diff --git a/coderd/aiseats/aiseats.go b/coderd/aiseats/aiseats.go new file mode 100644 index 0000000000..06c48e28a6 --- /dev/null +++ b/coderd/aiseats/aiseats.go @@ -0,0 +1,38 @@ +// Package aiseats is the AGPL version the package. +// The actual implementation is in `enterprise/aiseats`. +package aiseats + +import ( + "context" + + "github.com/google/uuid" + + "github.com/coder/coder/v2/coderd/database" +) + +type Reason struct { + EventType database.AiSeatUsageReason + Description string +} + +// ReasonAIBridge constructs a reason for usage originating from AI Bridge. +func ReasonAIBridge(description string) Reason { + return Reason{EventType: database.AiSeatUsageReasonAibridge, Description: description} +} + +// ReasonTask constructs a reason for usage originating from tasks. +func ReasonTask(description string) Reason { + return Reason{EventType: database.AiSeatUsageReasonTask, Description: description} +} + +// SeatTracker records AI seat consumption state. +type SeatTracker interface { + // RecordUsage does not return an error to prevent blocking the user from using + // AI features. This method is used to record usage, not enforce it. + RecordUsage(ctx context.Context, userID uuid.UUID, reason Reason) +} + +// Noop is an AGPL seat tracker that does nothing. +type Noop struct{} + +func (Noop) RecordUsage(context.Context, uuid.UUID, Reason) {} diff --git a/coderd/coderd.go b/coderd/coderd.go index 4314367945..4a3f5cd1ab 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -44,6 +44,7 @@ import ( "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/coderd/agentapi" "github.com/coder/coder/v2/coderd/agentapi/metadatabatcher" + "github.com/coder/coder/v2/coderd/aiseats" _ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs. "github.com/coder/coder/v2/coderd/appearance" "github.com/coder/coder/v2/coderd/audit" @@ -630,6 +631,8 @@ func New(options *Options) *API { dbRolluper: options.DatabaseRolluper, ProfileCollector: defaultProfileCollector{}, } + api.AISeatTracker = aiseats.Noop{} + api.WorkspaceAppsProvider = workspaceapps.NewDBTokenProvider( ctx, options.Logger.Named("workspaceapps"), @@ -2033,6 +2036,8 @@ type API struct { dbRolluper *dbrollup.Rolluper // chatDaemon handles background processing of pending chats. chatDaemon *chatd.Server + // AISeatTracker records AI seat usage. + AISeatTracker aiseats.SeatTracker // gitSyncWorker refreshes stale chat diff statuses in the // background. gitSyncWorker *gitsync.Worker @@ -2245,6 +2250,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n provisionerdserver.Options{ OIDCConfig: api.OIDCConfig, ExternalAuthConfigs: api.ExternalAuthConfigs, + AISeatTracker: api.AISeatTracker, Clock: api.Clock, HeartbeatFn: options.heartbeatFn, }, diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index b776b75bf9..db0aff780a 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -28,6 +28,7 @@ import ( protobuf "google.golang.org/protobuf/proto" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/aiseats" "github.com/coder/coder/v2/coderd/apikey" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/database" @@ -76,6 +77,7 @@ const ( type Options struct { OIDCConfig promoauth.OAuth2Config ExternalAuthConfigs []*externalauth.Config + AISeatTracker aiseats.SeatTracker // Clock for testing Clock quartz.Clock @@ -120,6 +122,7 @@ type server struct { NotificationsEnqueuer notifications.Enqueuer PrebuildsOrchestrator *atomic.Pointer[prebuilds.ReconciliationOrchestrator] UsageInserter *atomic.Pointer[usage.Inserter] + AISeatTracker aiseats.SeatTracker Experiments codersdk.Experiments OIDCConfig promoauth.OAuth2Config @@ -215,6 +218,9 @@ func NewServer( if err := tags.Valid(); err != nil { return nil, xerrors.Errorf("invalid tags: %w", err) } + if options.AISeatTracker == nil { + options.AISeatTracker = aiseats.Noop{} + } if options.AcquireJobLongPollDur == 0 { options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur } @@ -253,6 +259,7 @@ func NewServer( heartbeatFn: options.HeartbeatFn, PrebuildsOrchestrator: prebuildsOrchestrator, UsageInserter: usageInserter, + AISeatTracker: options.AISeatTracker, metrics: metrics, Experiments: experiments, } @@ -2437,6 +2444,12 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro }) } + // Record AI seat usage for successful task workspace builds. + if workspaceBuild.Transition == database.WorkspaceTransitionStart && workspace.TaskID.Valid { + s.AISeatTracker.RecordUsage(ctx, workspace.OwnerID, + aiseats.ReasonTask("task workspace build succeeded")) + } + if s.PrebuildsOrchestrator != nil && input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM { // Track resource replacements, if there are any. orchestrator := s.PrebuildsOrchestrator.Load() diff --git a/enterprise/aibridgedserver/aibridgedserver.go b/enterprise/aibridgedserver/aibridgedserver.go index 54104b7f4f..49f6b51bc6 100644 --- a/enterprise/aibridgedserver/aibridgedserver.go +++ b/enterprise/aibridgedserver/aibridgedserver.go @@ -15,6 +15,7 @@ import ( "google.golang.org/protobuf/types/known/structpb" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/aiseats" "github.com/coder/coder/v2/coderd/apikey" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" @@ -81,10 +82,12 @@ type Server struct { coderMCPConfig *proto.MCPServerConfig // may be nil if not available structuredLogging bool + aiSeatTracker aiseats.SeatTracker } func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, accessURL string, bridgeCfg codersdk.AIBridgeConfig, externalAuthConfigs []*externalauth.Config, experiments codersdk.Experiments, + aiSeatTracker aiseats.SeatTracker, ) (*Server, error) { eac := make(map[string]*externalauth.Config, len(externalAuthConfigs)) @@ -102,6 +105,7 @@ func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, ac logger: logger, externalAuthConfigs: eac, structuredLogging: bridgeCfg.StructuredLogging.Value(), + aiSeatTracker: aiSeatTracker, } if bridgeCfg.InjectCoderMCPTools { @@ -184,6 +188,8 @@ func (s *Server) RecordInterception(ctx context.Context, in *proto.RecordInterce return nil, xerrors.Errorf("start interception: %w", err) } + reason := aiseats.ReasonAIBridge("provider=" + in.Provider + ", model=" + in.Model) + s.aiSeatTracker.RecordUsage(ctx, initID, reason) return &proto.RecordInterceptionResponse{}, nil } diff --git a/enterprise/aibridgedserver/aibridgedserver_test.go b/enterprise/aibridgedserver/aibridgedserver_test.go index bae2197d76..b195829534 100644 --- a/enterprise/aibridgedserver/aibridgedserver_test.go +++ b/enterprise/aibridgedserver/aibridgedserver_test.go @@ -24,6 +24,7 @@ import ( "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogjson" + "github.com/coder/coder/v2/coderd/aiseats" "github.com/coder/coder/v2/coderd/apikey" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" @@ -176,7 +177,7 @@ func TestAuthorization(t *testing.T) { tc.mocksFn(db, apiKey, user) } - srv, err := aibridgedserver.NewServer(t.Context(), db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments) + srv, err := aibridgedserver.NewServer(t.Context(), db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, aiseats.Noop{}) require.NoError(t, err) require.NotNil(t, srv) @@ -268,7 +269,7 @@ func TestGetMCPServerConfigs(t *testing.T) { accessURL := "https://my-cool-deployment.com" srv, err := aibridgedserver.NewServer(t.Context(), db, logger, accessURL, codersdk.AIBridgeConfig{ InjectCoderMCPTools: serpent.Bool(!tc.disableCoderMCPInjection), - }, tc.externalAuthConfigs, tc.experiments) + }, tc.externalAuthConfigs, tc.experiments, aiseats.Noop{}) require.NoError(t, err) require.NotNil(t, srv) @@ -318,7 +319,7 @@ func TestGetMCPServerAccessTokensBatch(t *testing.T) { { ID: "3", }, - }, requiredExperiments) + }, requiredExperiments, aiseats.Noop{}) require.NoError(t, err) require.NotNil(t, srv) @@ -1014,7 +1015,7 @@ func testRecordMethod[Req any, Resp any]( } ctx := testutil.Context(t, testutil.WaitLong) - srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments) + srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, aiseats.Noop{}) require.NoError(t, err) resp, err := callMethod(srv, ctx, tc.request) @@ -1309,7 +1310,7 @@ func TestStructuredLogging(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{ StructuredLogging: serpent.Bool(tc.structuredLogging), - }, nil, requiredExperiments) + }, nil, requiredExperiments, aiseats.Noop{}) require.NoError(t, err) err = tc.recordFn(srv, ctx, interceptionID) @@ -1351,7 +1352,7 @@ func TestInferredThreadsByToolCalls(t *testing.T) { user := dbgen.User(t, db, database.User{}) - srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments) + srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, aiseats.Noop{}) require.NoError(t, err) aID := uuid.New() diff --git a/enterprise/aiseats/tracker.go b/enterprise/aiseats/tracker.go new file mode 100644 index 0000000000..d69dc0491c --- /dev/null +++ b/enterprise/aiseats/tracker.go @@ -0,0 +1,91 @@ +package aiseats + +import ( + "context" + "sync" + "time" + + "github.com/google/uuid" + + "cdr.dev/slog/v3" + agplaiseats "github.com/coder/coder/v2/coderd/aiseats" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/quartz" +) + +type store interface { + UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) error +} + +// throttleInterval is the minimum time between DB writes for the same user. This +// is to prevent ai seat tracking from consuming more db resources. +// +// These events are not critical to be recorded in real time, so we can afford to +// skip almost all of them. The first write is the most important, as it +// indicates a seat is consumed. Subsequent writes are purely informative and has +// no functional impact. +const ( + throttleInterval = 6 * time.Hour + // failedThrottleInterval exists to prevent a transient error from causing no + // usage to be recorded. Still debounce. + failedThrottleInterval = 30 * time.Minute +) + +// SeatTracker records current AI seat state for users. +type SeatTracker struct { + db store + logger slog.Logger + clock quartz.Clock + + mu sync.RWMutex + retryAfter map[uuid.UUID]time.Time +} + +func New(db store, logger slog.Logger, clock quartz.Clock) *SeatTracker { + if clock == nil { + clock = quartz.NewReal() + } + return &SeatTracker{db: db, logger: logger, clock: clock, retryAfter: make(map[uuid.UUID]time.Time)} +} + +// skipRecord returns true when the user is still in the retry cooldown +// window and we should skip a DB write attempt. +func (t *SeatTracker) skipRecord(userID uuid.UUID, now time.Time) bool { + t.mu.RLock() + defer t.mu.RUnlock() + + retryAfter, ok := t.retryAfter[userID] + return ok && now.Before(retryAfter) +} + +// recordThrottle sets the next time when DB writes for this user are allowed. +func (t *SeatTracker) recordThrottle(userID uuid.UUID, now time.Time, d time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + t.retryAfter[userID] = now.Add(d) +} + +// RecordUsage will record the AI seat usage for the user. There is a race condition between +// checking if the user should be recorded or throttled and actually recording. This is fine, as +// it just means we record the usage twice. +// The throttle just exists to prevent excessive database queries. +func (t *SeatTracker) RecordUsage(ctx context.Context, userID uuid.UUID, reason agplaiseats.Reason) { + now := t.clock.Now() + if t.skipRecord(userID, now) { + return + } + + err := t.db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{ + UserID: userID, + FirstUsedAt: now, + LastEventType: reason.EventType, + LastEventDescription: reason.Description, + }) + if err != nil { + t.logger.Warn(ctx, "upsert AI seat state", slog.Error(err), slog.F("user_id", userID), slog.F("event_type", reason.EventType)) + t.recordThrottle(userID, now, failedThrottleInterval) + return + } + + t.recordThrottle(userID, now, throttleInterval) +} diff --git a/enterprise/aiseats/tracker_test.go b/enterprise/aiseats/tracker_test.go new file mode 100644 index 0000000000..574e80fcbb --- /dev/null +++ b/enterprise/aiseats/tracker_test.go @@ -0,0 +1,94 @@ +package aiseats_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + agplaiseats "github.com/coder/coder/v2/coderd/aiseats" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" + enterpriseaiseats "github.com/coder/coder/v2/enterprise/aiseats" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func TestSeatTrackerDB(t *testing.T) { + t.Parallel() + + t.Run("ActiveUserRecorded", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + clock := quartz.NewMock(t) + tracker := enterpriseaiseats.New(db, testutil.Logger(t), clock) + + user := dbgen.User(t, db, database.User{Status: database.UserStatusActive}) + tracker.RecordUsage(ctx, user.ID, agplaiseats.ReasonAIBridge("active user event")) + + count, err := db.GetActiveAISeatCount(ctx) + require.NoError(t, err) + require.EqualValues(t, 1, count) + }) + + t.Run("InactiveUsersExcluded", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + tracker := enterpriseaiseats.New(db, testutil.Logger(t), quartz.NewMock(t)) + + dormantUser := dbgen.User(t, db, database.User{Status: database.UserStatusDormant}) + tracker.RecordUsage(ctx, dormantUser.ID, agplaiseats.ReasonTask("dormant user event")) + + suspendedUser := dbgen.User(t, db, database.User{Status: database.UserStatusSuspended}) + tracker.RecordUsage(ctx, suspendedUser.ID, agplaiseats.ReasonTask("suspended user event")) + + count, err := db.GetActiveAISeatCount(ctx) + require.NoError(t, err) + require.EqualValues(t, 0, count) + }) + + t.Run("StatusTransitions", func(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + tracker := enterpriseaiseats.New(db, testutil.Logger(t), quartz.NewMock(t)) + + user := dbgen.User(t, db, database.User{Status: database.UserStatusActive}) + tracker.RecordUsage(ctx, user.ID, agplaiseats.ReasonAIBridge("status transition")) + + count, err := db.GetActiveAISeatCount(ctx) + require.NoError(t, err) + require.EqualValues(t, 1, count) + + _, err = db.UpdateUserStatus(ctx, database.UpdateUserStatusParams{ + ID: user.ID, + Status: database.UserStatusDormant, + UpdatedAt: dbtime.Now(), + UserIsSeen: false, + }) + require.NoError(t, err) + + count, err = db.GetActiveAISeatCount(ctx) + require.NoError(t, err) + require.EqualValues(t, 0, count) + + _, err = db.UpdateUserStatus(ctx, database.UpdateUserStatusParams{ + ID: user.ID, + Status: database.UserStatusActive, + UpdatedAt: dbtime.Now().Add(time.Second), + UserIsSeen: false, + }) + require.NoError(t, err) + + count, err = db.GetActiveAISeatCount(ctx) + require.NoError(t, err) + require.EqualValues(t, 1, count) + }) +} diff --git a/enterprise/coderd/aibridged.go b/enterprise/coderd/aibridged.go index 95c06fd5c9..3eff01d497 100644 --- a/enterprise/coderd/aibridged.go +++ b/enterprise/coderd/aibridged.go @@ -48,7 +48,7 @@ func (api *API) CreateInMemoryAIBridgeServer(dialCtx context.Context) (client ai mux := drpcmux.New() srv, err := aibridgedserver.NewServer(api.ctx, api.Database, api.Logger.Named("aibridgedserver"), - api.AccessURL.String(), api.DeploymentValues.AI.BridgeConfig, api.ExternalAuthConfigs, api.AGPL.Experiments) + api.AccessURL.String(), api.DeploymentValues.AI.BridgeConfig, api.ExternalAuthConfigs, api.AGPL.Experiments, api.aiSeatTracker) if err != nil { return nil, err } diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index caf8baae63..bfefc89404 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -45,6 +45,7 @@ import ( agplusage "github.com/coder/coder/v2/coderd/usage" "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/enterprise/aiseats" entchatd "github.com/coder/coder/v2/enterprise/coderd/chatd" "github.com/coder/coder/v2/enterprise/coderd/connectionlog" "github.com/coder/coder/v2/enterprise/coderd/dbauthz" @@ -217,7 +218,10 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { }, }) + api.aiSeatTracker = aiseats.New(options.Database, api.Logger.Named("aiseats"), quartz.NewReal()) + api.AGPL = coderd.New(options.Options) + api.AGPL.AISeatTracker = api.aiSeatTracker defer func() { if err != nil { _ = api.Close() @@ -785,6 +789,7 @@ type API struct { aibridgedHandler http.Handler aibridgeproxydHandler http.Handler + aiSeatTracker *aiseats.SeatTracker } // writeEntitlementWarningsHeader writes the entitlement warnings to the response header diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index af52fc9b6e..c293abced2 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -356,6 +356,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) provisionerdserver.Options{ ExternalAuthConfigs: api.ExternalAuthConfigs, OIDCConfig: api.OIDCConfig, + AISeatTracker: api.AGPL.AISeatTracker, Clock: api.Clock, }, api.NotificationsEnqueuer,