fix: limit calls to GetWorkspaceAgentByID in agentapi (#23015)

We currently call GetWorkspaceAgentByID millions of times at scale,
unnecessarily. This PR embeds the agent's immutable fields (such as its ID
and name) into the relevant services instead of fetching them every time.

resolves https://github.com/coder/scaletest/issues/84

Confirmed with a 10k scaletest that this changeset reduces the number of
GetWorkspaceAgentByID queries from 10M+ down to 39k.
This commit is contained in:
Jon Ayers
2026-03-20 15:42:05 -05:00
committed by GitHub
parent 32021b3ac2
commit f135ffdb3a
23 changed files with 173 additions and 228 deletions
+13 -8
View File
@@ -103,7 +103,7 @@ type Options struct {
UpdateAgentMetricsFn func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
}
func New(opts Options, workspace database.Workspace) *API {
func New(opts Options, workspace database.Workspace, agent database.WorkspaceAgent) *API {
if opts.Clock == nil {
opts.Clock = quartz.NewReal()
}
@@ -156,7 +156,8 @@ func New(opts Options, workspace database.Workspace) *API {
}
api.StatsAPI = &StatsAPI{
AgentFn: api.agent,
AgentID: agent.ID,
AgentName: agent.Name,
Workspace: api.cachedWorkspaceFields,
Database: opts.Database,
Log: opts.Log,
@@ -175,16 +176,18 @@ func New(opts Options, workspace database.Workspace) *API {
}
api.AppsAPI = &AppsAPI{
AgentID: agent.ID,
AgentFn: api.agent,
Database: opts.Database,
Log: opts.Log,
Workspace: api.cachedWorkspaceFields,
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
Clock: opts.Clock,
NotificationsEnqueuer: opts.NotificationsEnqueuer,
}
api.MetadataAPI = &MetadataAPI{
AgentFn: api.agent,
AgentID: agent.ID,
Workspace: api.cachedWorkspaceFields,
Database: opts.Database,
Log: opts.Log,
@@ -204,7 +207,8 @@ func New(opts Options, workspace database.Workspace) *API {
}
api.ConnLogAPI = &ConnLogAPI{
AgentFn: api.agent,
AgentID: agent.ID,
AgentName: agent.Name,
ConnectionLogger: opts.ConnectionLogger,
Database: opts.Database,
Workspace: api.cachedWorkspaceFields,
@@ -222,7 +226,6 @@ func New(opts Options, workspace database.Workspace) *API {
api.SubAgentAPI = &SubAgentAPI{
OwnerID: opts.OwnerID,
OrganizationID: opts.OrganizationID,
AgentID: opts.AgentID,
AgentFn: api.agent,
Log: opts.Log,
Clock: opts.Clock,
@@ -297,8 +300,10 @@ func (a *API) agent(ctx context.Context) (database.WorkspaceAgent, error) {
func (a *API) refreshCachedWorkspace(ctx context.Context) {
ws, err := a.opts.Database.GetWorkspaceByID(ctx, a.opts.WorkspaceID)
if err != nil {
// Do not clear the cache on transient DB errors. Stale data is
// preferable to no data, which forces callers to fall back to
// expensive queries like GetWorkspaceByAgentID.
a.opts.Log.Warn(ctx, "failed to refresh cached workspace fields", slog.Error(err))
a.cachedWorkspaceFields.Clear()
return
}
@@ -341,11 +346,11 @@ func (a *API) startCacheRefreshLoop(ctx context.Context) {
a.cachedWorkspaceFields.Clear()
}
func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
func (a *API) publishWorkspaceUpdate(ctx context.Context, agentID uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
a.opts.PublishWorkspaceUpdateFn(ctx, a.opts.OwnerID, wspubsub.WorkspaceEvent{
Kind: kind,
WorkspaceID: a.opts.WorkspaceID,
AgentID: &agent.ID,
AgentID: &agentID,
})
return nil
}
+38 -33
View File
@@ -24,22 +24,19 @@ import (
)
type AppsAPI struct {
AgentID uuid.UUID
AgentFn func(context.Context) (database.WorkspaceAgent, error)
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
Workspace *CachedWorkspaceFields
PublishWorkspaceUpdateFn func(context.Context, uuid.UUID, wspubsub.WorkspaceEventKind) error
NotificationsEnqueuer notifications.Enqueuer
Clock quartz.Clock
}
func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
a.Log.Debug(ctx, "got batch app health update",
slog.F("agent_id", workspaceAgent.ID.String()),
slog.F("agent_id", a.AgentID.String()),
slog.F("updates", req.Updates),
)
@@ -47,9 +44,9 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
return &agentproto.BatchUpdateAppHealthResponse{}, nil
}
apps, err := a.Database.GetWorkspaceAppsByAgentID(ctx, workspaceAgent.ID)
apps, err := a.Database.GetWorkspaceAppsByAgentID(ctx, a.AgentID)
if err != nil {
return nil, xerrors.Errorf("get workspace apps by agent ID %q: %w", workspaceAgent.ID, err)
return nil, xerrors.Errorf("get workspace apps by agent ID %q: %w", a.AgentID, err)
}
var newApps []database.WorkspaceApp
@@ -110,7 +107,7 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
}
if a.PublishWorkspaceUpdateFn != nil && len(newApps) > 0 {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAppHealthUpdate)
err = a.PublishWorkspaceUpdateFn(ctx, a.AgentID, wspubsub.WorkspaceEventKindAppHealthUpdate)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
@@ -149,12 +146,8 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
})
}
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
app, err := a.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
AgentID: workspaceAgent.ID,
AgentID: a.AgentID,
Slug: req.Slug,
})
if err != nil {
@@ -164,11 +157,10 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
})
}
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace.",
Detail: err.Error(),
ws, ok := a.Workspace.AsWorkspaceIdentity()
if !ok {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Workspace identity not cached.",
})
}
@@ -190,8 +182,8 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
_, err = a.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
ID: uuid.New(),
CreatedAt: dbtime.Now(),
WorkspaceID: workspace.ID,
AgentID: workspaceAgent.ID,
WorkspaceID: ws.ID,
AgentID: a.AgentID,
AppID: app.ID,
State: dbState,
Message: cleaned,
@@ -208,7 +200,7 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
}
if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
err = a.PublishWorkspaceUpdateFn(ctx, a.AgentID, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
if err != nil {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to publish workspace update.",
@@ -217,14 +209,14 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
}
}
// Notify on state change to Working/Idle for AI tasks
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState, workspace, workspaceAgent)
// Notify on state change to Working/Idle for AI tasks.
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState)
if shouldBump(dbState, latestAppStatus) {
// We pass time.Time{} for nextAutostart since we don't have access to
// TemplateScheduleStore here. The activity bump logic handles this by
// defaulting to the template's activity_bump duration (typically 1 hour).
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, workspace.ID, time.Time{})
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, ws.ID, time.Time{})
}
// just return a blank response because it doesn't contain any settable fields at present.
return new(agentproto.UpdateAppStatusResponse), nil
@@ -261,8 +253,6 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
appID uuid.UUID,
latestAppStatus database.WorkspaceAppStatus,
newAppStatus database.WorkspaceAppStatusState,
workspace database.Workspace,
agent database.WorkspaceAgent,
) {
var notificationTemplate uuid.UUID
switch newAppStatus {
@@ -279,11 +269,20 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
return
}
if !workspace.TaskID.Valid {
taskID := a.Workspace.TaskID()
if !taskID.Valid {
// Workspace has no task ID, do nothing.
return
}
// Only fetch fresh agent state for task workspaces, since we need
// the current lifecycle state to decide whether to send notifications.
agent, err := a.AgentFn(ctx)
if err != nil {
a.Log.Warn(ctx, "failed to get agent for AI task notification", slog.Error(err))
return
}
// Only send notifications when the agent is ready. We want to skip
// any state transitions that occur whilst the workspace is starting
// up as it doesn't make sense to receive them.
@@ -296,7 +295,7 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
return
}
task, err := a.Database.GetTaskByID(ctx, workspace.TaskID.UUID)
task, err := a.Database.GetTaskByID(ctx, taskID.UUID)
if err != nil {
a.Log.Warn(ctx, "failed to get task", slog.Error(err))
return
@@ -321,14 +320,20 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
return
}
ws, ok := a.Workspace.AsWorkspaceIdentity()
if !ok {
a.Log.Warn(ctx, "failed to get workspace identity for AI task notification")
return
}
if _, err := a.NotificationsEnqueuer.EnqueueWithData(
// nolint:gocritic // Need notifier actor to enqueue notifications
dbauthz.AsNotifier(ctx),
workspace.OwnerID,
ws.OwnerID,
notificationTemplate,
map[string]string{
"task": task.Name,
"workspace": workspace.Name,
"workspace": ws.Name,
},
map[string]any{
// Use a 1-minute bucketed timestamp to bypass per-day dedupe,
@@ -338,7 +343,7 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
},
"api-workspace-agent-app-status",
// Associate this notification with related entities
workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID,
ws.ID, ws.OwnerID, ws.OrganizationID, appID,
); err != nil {
a.Log.Warn(ctx, "failed to notify of task state", slog.Error(err))
return
+28 -42
View File
@@ -67,12 +67,10 @@ func TestBatchUpdateAppHealths(t *testing.T) {
publishCalled := false
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true
return nil
},
@@ -105,12 +103,10 @@ func TestBatchUpdateAppHealths(t *testing.T) {
publishCalled := false
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true
return nil
},
@@ -144,12 +140,10 @@ func TestBatchUpdateAppHealths(t *testing.T) {
publishCalled := false
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true
return nil
},
@@ -180,9 +174,7 @@ func TestBatchUpdateAppHealths(t *testing.T) {
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app3}, nil)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: nil,
@@ -209,9 +201,7 @@ func TestBatchUpdateAppHealths(t *testing.T) {
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: nil,
@@ -239,9 +229,7 @@ func TestBatchUpdateAppHealths(t *testing.T) {
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: nil,
@@ -279,14 +267,26 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
}
workspaceUpdates := make(chan wspubsub.WorkspaceEventKind, 100)
workspace := database.Workspace{
ID: uuid.UUID{9},
TaskID: uuid.NullUUID{
Valid: true,
UUID: uuid.UUID{7},
},
}
cachedWs := &agentapi.CachedWorkspaceFields{}
cachedWs.UpdateValues(workspace)
api := &agentapi.AppsAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(_ context.Context, agnt *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
assert.Equal(t, *agnt, agent)
Database: mDB,
Log: testutil.Logger(t),
Workspace: cachedWs,
PublishWorkspaceUpdateFn: func(_ context.Context, agnt uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
assert.Equal(t, agnt, agent.ID)
testutil.AssertSend(ctx, t, workspaceUpdates, kind)
return nil
},
@@ -309,14 +309,6 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
},
}
mDB.EXPECT().GetTaskByID(gomock.Any(), task.ID).Times(1).Return(task, nil)
workspace := database.Workspace{
ID: uuid.UUID{9},
TaskID: uuid.NullUUID{
Valid: true,
UUID: task.ID,
},
}
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Times(1).Return(workspace, nil)
appStatus := database.WorkspaceAppStatus{
ID: uuid.UUID{6},
}
@@ -363,9 +355,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
Return(database.WorkspaceApp{}, sql.ErrNoRows)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
Database: mDB,
Log: testutil.Logger(t),
}
@@ -392,9 +382,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
}
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
Database: mDB,
Log: testutil.Logger(t),
}
@@ -422,9 +410,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
}
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
Database: mDB,
Log: testutil.Logger(t),
}
+10
View File
@@ -4,6 +4,7 @@ import (
"context"
"sync"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
@@ -23,12 +24,14 @@ type CachedWorkspaceFields struct {
lock sync.RWMutex
identity database.WorkspaceIdentity
taskID uuid.NullUUID
}
// Clear resets the cached workspace identity and task ID to their zero
// values. It holds the write lock for the duration of the reset.
func (cws *CachedWorkspaceFields) Clear() {
cws.lock.Lock()
defer cws.lock.Unlock()
cws.identity = database.WorkspaceIdentity{}
// Reset the task ID together with the identity so both fields are
// cleared under the same lock acquisition.
cws.taskID = uuid.NullUUID{}
}
func (cws *CachedWorkspaceFields) UpdateValues(ws database.Workspace) {
@@ -42,6 +45,13 @@ func (cws *CachedWorkspaceFields) UpdateValues(ws database.Workspace) {
cws.identity.OwnerUsername = ws.OwnerUsername
cws.identity.TemplateName = ws.TemplateName
cws.identity.AutostartSchedule = ws.AutostartSchedule
cws.taskID = ws.TaskID
}
// TaskID returns the cached task ID of the workspace, read under the read
// lock. The value is whatever UpdateValues last stored from the workspace's
// TaskID; after Clear it is the zero uuid.NullUUID, so callers should check
// Valid before using the UUID.
func (cws *CachedWorkspaceFields) TaskID() uuid.NullUUID {
cws.lock.RLock()
defer cws.lock.RUnlock()
return cws.taskID
}
// Returns the Workspace, true, unless the workspace has not been cached (nuked or was a prebuild).
+4 -19
View File
@@ -14,11 +14,11 @@ import (
"github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
)
type ConnLogAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
AgentID uuid.UUID
AgentName string
ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger]
Workspace *CachedWorkspaceFields
Database database.Store
@@ -53,27 +53,12 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
}
}
// Inject RBAC object into context for dbauthz fast path, avoid having to
// call GetWorkspaceByAgentID on every metadata update.
rbacCtx := ctx
var ws database.WorkspaceIdentity
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
ws = dbws
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
if err != nil {
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
//nolint:gocritic
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
}
}
// Fetch contextual data for this connection log event.
workspaceAgent, err := a.AgentFn(rbacCtx)
if err != nil {
return nil, xerrors.Errorf("get agent: %w", err)
}
if ws.Equal(database.WorkspaceIdentity{}) {
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, a.AgentID)
if err != nil {
return nil, xerrors.Errorf("get workspace by agent id: %w", err)
}
@@ -97,7 +82,7 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: workspaceAgent.Name,
AgentName: a.AgentName,
Type: connectionType,
Code: code,
Ip: logIP,
+3 -4
View File
@@ -114,10 +114,9 @@ func TestConnectionLog(t *testing.T) {
api := &agentapi.ConnLogAPI{
ConnectionLogger: asAtomicPointer[connectionlog.ConnectionLogger](connLogger),
Database: mDB,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Workspace: &agentapi.CachedWorkspaceFields{},
AgentID: agent.ID,
AgentName: agent.Name,
Workspace: &agentapi.CachedWorkspaceFields{},
}
api.ReportConnection(context.Background(), &agentproto.ReportConnectionRequest{
Connection: &agentproto.Connection{
+2 -2
View File
@@ -30,7 +30,7 @@ type LifecycleAPI struct {
WorkspaceID uuid.UUID
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
PublishWorkspaceUpdateFn func(context.Context, uuid.UUID, wspubsub.WorkspaceEventKind) error
TimeNowFn func() time.Time // defaults to dbtime.Now()
Metrics *LifecycleMetrics
@@ -122,7 +122,7 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
}
if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentLifecycleUpdate)
err = a.PublishWorkspaceUpdateFn(ctx, workspaceAgent.ID, wspubsub.WorkspaceEventKindAgentLifecycleUpdate)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
+4 -4
View File
@@ -85,7 +85,7 @@ func TestUpdateLifecycle(t *testing.T) {
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true
return nil
},
@@ -206,7 +206,7 @@ func TestUpdateLifecycle(t *testing.T) {
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true
return nil
},
@@ -323,7 +323,7 @@ func TestUpdateLifecycle(t *testing.T) {
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
atomic.AddInt64(&publishCalled, 1)
return nil
},
@@ -410,7 +410,7 @@ func TestUpdateLifecycle(t *testing.T) {
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true
return nil
},
+3 -3
View File
@@ -19,7 +19,7 @@ type LogsAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
PublishWorkspaceUpdateFn func(context.Context, uuid.UUID, wspubsub.WorkspaceEventKind) error
PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage)
TimeNowFn func() time.Time // defaults to dbtime.Now()
@@ -125,7 +125,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
}
if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentLogsOverflow)
err = a.PublishWorkspaceUpdateFn(ctx, workspaceAgent.ID, wspubsub.WorkspaceEventKindAgentLogsOverflow)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
@@ -145,7 +145,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
if workspaceAgent.LogsLength == 0 && a.PublishWorkspaceUpdateFn != nil {
// If these are the first logs being appended, we publish a UI update
// to notify the UI that logs are now available.
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentFirstLogs)
err = a.PublishWorkspaceUpdateFn(ctx, workspaceAgent.ID, wspubsub.WorkspaceEventKindAgentFirstLogs)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
+6 -6
View File
@@ -51,7 +51,7 @@ func TestBatchCreateLogs(t *testing.T) {
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true
return nil
},
@@ -155,7 +155,7 @@ func TestBatchCreateLogs(t *testing.T) {
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true
return nil
},
@@ -203,7 +203,7 @@ func TestBatchCreateLogs(t *testing.T) {
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true
return nil
},
@@ -296,7 +296,7 @@ func TestBatchCreateLogs(t *testing.T) {
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true
return nil
},
@@ -340,7 +340,7 @@ func TestBatchCreateLogs(t *testing.T) {
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true
return nil
},
@@ -387,7 +387,7 @@ func TestBatchCreateLogs(t *testing.T) {
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true
return nil
},
+6 -5
View File
@@ -32,16 +32,12 @@ type ManifestAPI struct {
DerpForceWebSockets bool
WorkspaceID uuid.UUID
AgentFn func(context.Context) (database.WorkspaceAgent, error)
AgentFn func(ctx context.Context) (database.WorkspaceAgent, error)
Database database.Store
DerpMapFn func() *tailcfg.DERPMap
}
func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
var (
dbApps []database.WorkspaceApp
scripts []database.WorkspaceAgentScript
@@ -50,6 +46,11 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
devcontainers []database.WorkspaceAgentDevcontainer
)
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, xerrors.Errorf("getting workspace agent: %w", err)
}
var eg errgroup.Group
eg.Go(func() (err error) {
dbApps, err = a.Database.GetWorkspaceAppsByAgentID(ctx, workspaceAgent.ID)
+3 -9
View File
@@ -322,9 +322,7 @@ func TestGetManifest(t *testing.T) {
DisableDirectConnections: true,
DerpForceWebSockets: true,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil },
WorkspaceID: workspace.ID,
Database: mDB,
DerpMapFn: derpMapFn,
@@ -389,9 +387,7 @@ func TestGetManifest(t *testing.T) {
DisableDirectConnections: true,
DerpForceWebSockets: true,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return childAgent, nil
},
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return childAgent, nil },
WorkspaceID: workspace.ID,
Database: mDB,
DerpMapFn: derpMapFn,
@@ -512,9 +508,7 @@ func TestGetManifest(t *testing.T) {
DisableDirectConnections: true,
DerpForceWebSockets: true,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil },
WorkspaceID: workspace.ID,
Database: mDB,
DerpMapFn: derpMapFn,
+4 -22
View File
@@ -5,18 +5,18 @@ import (
"fmt"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
)
type MetadataAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
AgentID uuid.UUID
Workspace *CachedWorkspaceFields
Database database.Store
Log slog.Logger
@@ -45,29 +45,11 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
maxErrorLen = maxValueLen
)
// Inject RBAC object into context for dbauthz fast path, avoid having to
// call GetWorkspaceByAgentID on every metadata update.
var err error
rbacCtx := ctx
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
if err != nil {
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
//nolint:gocritic
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
}
}
workspaceAgent, err := a.AgentFn(rbacCtx)
if err != nil {
return nil, err
}
var (
collectedAt = a.now()
allKeysLen = 0
dbUpdate = database.UpdateWorkspaceAgentMetadataParams{
WorkspaceAgentID: workspaceAgent.ID,
WorkspaceAgentID: a.AgentID,
// These need to be `make(x, 0, len(req.Metadata))` instead of
// `make(x, len(req.Metadata))` because we may not insert all
// metadata if the keys are large.
@@ -121,7 +103,7 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
}
// Use batcher to batch metadata updates.
err = a.Batcher.Add(workspaceAgent.ID, dbUpdate.Key, dbUpdate.Value, dbUpdate.Error, dbUpdate.CollectedAt)
err := a.Batcher.Add(a.AgentID, dbUpdate.Key, dbUpdate.Value, dbUpdate.Error, dbUpdate.CollectedAt)
if err != nil {
return nil, xerrors.Errorf("add metadata to batcher: %w", err)
}
+3 -9
View File
@@ -80,9 +80,7 @@ func TestBatchUpdateMetadata(t *testing.T) {
t.Cleanup(batcher.Close)
api := &agentapi.MetadataAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
Workspace: &agentapi.CachedWorkspaceFields{},
Log: testutil.Logger(t),
Batcher: batcher,
@@ -159,9 +157,7 @@ func TestBatchUpdateMetadata(t *testing.T) {
t.Cleanup(batcher.Close)
api := &agentapi.MetadataAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
Workspace: &agentapi.CachedWorkspaceFields{},
Log: testutil.Logger(t),
Batcher: batcher,
@@ -241,9 +237,7 @@ func TestBatchUpdateMetadata(t *testing.T) {
t.Cleanup(batcher.Close)
api := &agentapi.MetadataAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
Workspace: &agentapi.CachedWorkspaceFields{},
Log: testutil.Logger(t),
Batcher: batcher,
+8 -25
View File
@@ -4,20 +4,21 @@ import (
"context"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/durationpb"
"cdr.dev/slog/v3"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/codersdk"
)
type StatsAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
AgentID uuid.UUID
AgentName string
Workspace *CachedWorkspaceFields
Database database.Store
Log slog.Logger
@@ -44,32 +45,13 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
return res, nil
}
// Inject RBAC object into context for dbauthz fast path, avoid having to
// call GetWorkspaceAgentByID on every stats update.
rbacCtx := ctx
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
var err error
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
if err != nil {
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
//nolint:gocritic
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
}
}
workspaceAgent, err := a.AgentFn(rbacCtx)
if err != nil {
return nil, err
}
// If cache is empty (prebuild or invalid), fall back to DB
var ws database.WorkspaceIdentity
var ok bool
if ws, ok = a.Workspace.AsWorkspaceIdentity(); !ok {
w, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
w, err := a.Database.GetWorkspaceByAgentID(ctx, a.AgentID)
if err != nil {
return nil, xerrors.Errorf("get workspace by agent ID %q: %w", workspaceAgent.ID, err)
return nil, xerrors.Errorf("get workspace by agent ID %q: %w", a.AgentID, err)
}
ws = database.WorkspaceIdentityFromWorkspace(w)
}
@@ -90,11 +72,12 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
req.Stats.SessionCountReconnectingPty = 0
}
err = a.StatsReporter.ReportAgentStats(
err := a.StatsReporter.ReportAgentStats(
ctx,
a.now(),
ws,
workspaceAgent,
a.AgentID,
a.AgentName,
req.Stats,
false,
)
+12 -18
View File
@@ -119,9 +119,8 @@ func TestUpdateStats(t *testing.T) {
}
)
api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
AgentName: agent.Name,
Workspace: &workspaceAsCacheFields,
Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -229,9 +228,8 @@ func TestUpdateStats(t *testing.T) {
}
)
api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
AgentName: agent.Name,
Workspace: &workspaceAsCacheFields,
Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -264,9 +262,8 @@ func TestUpdateStats(t *testing.T) {
}
)
api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
AgentName: agent.Name,
Workspace: &workspaceAsCacheFields,
Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -347,9 +344,8 @@ func TestUpdateStats(t *testing.T) {
// ws.AutostartSchedule = workspace.AutostartSchedule
api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
AgentName: agent.Name,
Workspace: &ws,
Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -459,9 +455,8 @@ func TestUpdateStats(t *testing.T) {
)
defer wut.Close()
api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
AgentName: agent.Name,
Workspace: &workspaceAsCacheFields,
Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -596,9 +591,8 @@ func TestUpdateStats(t *testing.T) {
}
)
api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
AgentID: agent.ID,
AgentName: agent.Name,
Workspace: &workspaceAsCacheFields,
Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
+6 -2
View File
@@ -25,7 +25,6 @@ import (
type SubAgentAPI struct {
OwnerID uuid.UUID
OrganizationID uuid.UUID
AgentID uuid.UUID
AgentFn func(context.Context) (database.WorkspaceAgent, error)
Log slog.Logger
@@ -295,7 +294,12 @@ func (a *SubAgentAPI) ListSubAgents(ctx context.Context, _ *agentproto.ListSubAg
//nolint:gocritic // This gives us only the permissions required to do the job.
ctx = dbauthz.AsSubAgentAPI(ctx, a.OrganizationID, a.OwnerID)
workspaceAgents, err := a.Database.GetWorkspaceAgentsByParentID(ctx, a.AgentID)
parentAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, xerrors.Errorf("get parent agent: %w", err)
}
workspaceAgents, err := a.Database.GetWorkspaceAgentsByParentID(ctx, parentAgent.ID)
if err != nil {
return nil, err
}
+3 -6
View File
@@ -81,12 +81,9 @@ func TestSubAgentAPI(t *testing.T) {
return &agentapi.SubAgentAPI{
OwnerID: user.ID,
OrganizationID: org.ID,
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Clock: clock,
Database: dbauthz.New(db, auth, logger, accessControlStore),
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil },
Clock: clock,
Database: dbauthz.New(db, auth, logger, accessControlStore),
}
}
+11 -6
View File
@@ -314,17 +314,22 @@ func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Req
// This functionality has been moved to the AppsAPI in the agentapi. We keep this HTTP handler around for backward
// compatibility with older agents. We translate the request into the protobuf form so there is only one primary
// implementation.
cachedWs := &agentapi.CachedWorkspaceFields{}
cachedWs.UpdateValues(workspace)
appAPI := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return workspaceAgent, nil
AgentID: workspaceAgent.ID,
Database: api.Database,
Log: api.Logger,
Workspace: cachedWs,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return api.Database.GetWorkspaceAgentByID(ctx, workspaceAgent.ID)
},
Database: api.Database,
Log: api.Logger,
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, agentID uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{
Kind: kind,
WorkspaceID: workspace.ID,
AgentID: &agent.ID,
AgentID: &agentID,
})
return nil
},
+1 -1
View File
@@ -178,7 +178,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
// Optional:
UpdateAgentMetricsFn: api.UpdateAgentMetrics,
}, workspace)
}, workspace, workspaceAgent)
streamID := tailnet.StreamID{
Name: fmt.Sprintf("%s-%s-%s", workspace.OwnerUsername, workspace.Name, workspaceAgent.Name),
+1 -1
View File
@@ -1753,7 +1753,7 @@ func (api *API) postWorkspaceUsage(rw http.ResponseWriter, r *http.Request) {
// return
// }
err = api.statsReporter.ReportAgentStats(ctx, dbtime.Now(), database.WorkspaceIdentityFromWorkspace(workspace), agent, stat, true)
err = api.statsReporter.ReportAgentStats(ctx, dbtime.Now(), database.WorkspaceIdentityFromWorkspace(workspace), agent.ID, agent.Name, stat, true)
if err != nil {
httpapi.InternalServerError(rw, err)
return
+3 -3
View File
@@ -137,10 +137,10 @@ func (r *Reporter) ReportAppStats(ctx context.Context, stats []workspaceapps.Sta
}
// nolint:revive // usage is a control flag while we have the experiment
func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspace database.WorkspaceIdentity, workspaceAgent database.WorkspaceAgent, stats *agentproto.Stats, usage bool) error {
func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspace database.WorkspaceIdentity, agentID uuid.UUID, agentName string, stats *agentproto.Stats, usage bool) error {
// update agent stats
if !r.opts.DisableDatabaseInserts {
r.opts.StatsBatcher.Add(now, workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, stats, usage)
r.opts.StatsBatcher.Add(now, agentID, workspace.TemplateID, workspace.OwnerID, workspace.ID, stats, usage)
}
// update prometheus metrics (even if template insights are disabled)
@@ -148,7 +148,7 @@ func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspac
r.opts.UpdateAgentMetricsFn(ctx, prometheusmetrics.AgentMetricLabels{
Username: workspace.OwnerUsername,
WorkspaceName: workspace.Name,
AgentName: workspaceAgent.Name,
AgentName: agentName,
TemplateName: workspace.TemplateName,
}, stats.Metrics)
}
+1
View File
@@ -182,6 +182,7 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) (err error)
wchRef, rchRef := wch, rch
for {
if wchRef == nil && rchRef == nil {
logger.Info(ctx, "reading and writing to agent complete! Closing connection")
return nil
}