fix: limit calls to GetWorkspaceAgentByID in agentapi (#23015)

We currently call GetWorkspaceAgentByID millions of times at scale
unnecessarily. This PR embeds immutable fields into the relevant
services instead of fetching for them every time.

resolves https://github.com/coder/scaletest/issues/84

Confirmed with a 10k scaletest that this changeset takes the query from
10M+ queries down to 39k
This commit is contained in:
Jon Ayers
2026-03-20 15:42:05 -05:00
committed by GitHub
parent 32021b3ac2
commit f135ffdb3a
23 changed files with 173 additions and 228 deletions
+13 -8
View File
@@ -103,7 +103,7 @@ type Options struct {
UpdateAgentMetricsFn func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) UpdateAgentMetricsFn func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
} }
func New(opts Options, workspace database.Workspace) *API { func New(opts Options, workspace database.Workspace, agent database.WorkspaceAgent) *API {
if opts.Clock == nil { if opts.Clock == nil {
opts.Clock = quartz.NewReal() opts.Clock = quartz.NewReal()
} }
@@ -156,7 +156,8 @@ func New(opts Options, workspace database.Workspace) *API {
} }
api.StatsAPI = &StatsAPI{ api.StatsAPI = &StatsAPI{
AgentFn: api.agent, AgentID: agent.ID,
AgentName: agent.Name,
Workspace: api.cachedWorkspaceFields, Workspace: api.cachedWorkspaceFields,
Database: opts.Database, Database: opts.Database,
Log: opts.Log, Log: opts.Log,
@@ -175,16 +176,18 @@ func New(opts Options, workspace database.Workspace) *API {
} }
api.AppsAPI = &AppsAPI{ api.AppsAPI = &AppsAPI{
AgentID: agent.ID,
AgentFn: api.agent, AgentFn: api.agent,
Database: opts.Database, Database: opts.Database,
Log: opts.Log, Log: opts.Log,
Workspace: api.cachedWorkspaceFields,
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate, PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
Clock: opts.Clock, Clock: opts.Clock,
NotificationsEnqueuer: opts.NotificationsEnqueuer, NotificationsEnqueuer: opts.NotificationsEnqueuer,
} }
api.MetadataAPI = &MetadataAPI{ api.MetadataAPI = &MetadataAPI{
AgentFn: api.agent, AgentID: agent.ID,
Workspace: api.cachedWorkspaceFields, Workspace: api.cachedWorkspaceFields,
Database: opts.Database, Database: opts.Database,
Log: opts.Log, Log: opts.Log,
@@ -204,7 +207,8 @@ func New(opts Options, workspace database.Workspace) *API {
} }
api.ConnLogAPI = &ConnLogAPI{ api.ConnLogAPI = &ConnLogAPI{
AgentFn: api.agent, AgentID: agent.ID,
AgentName: agent.Name,
ConnectionLogger: opts.ConnectionLogger, ConnectionLogger: opts.ConnectionLogger,
Database: opts.Database, Database: opts.Database,
Workspace: api.cachedWorkspaceFields, Workspace: api.cachedWorkspaceFields,
@@ -222,7 +226,6 @@ func New(opts Options, workspace database.Workspace) *API {
api.SubAgentAPI = &SubAgentAPI{ api.SubAgentAPI = &SubAgentAPI{
OwnerID: opts.OwnerID, OwnerID: opts.OwnerID,
OrganizationID: opts.OrganizationID, OrganizationID: opts.OrganizationID,
AgentID: opts.AgentID,
AgentFn: api.agent, AgentFn: api.agent,
Log: opts.Log, Log: opts.Log,
Clock: opts.Clock, Clock: opts.Clock,
@@ -297,8 +300,10 @@ func (a *API) agent(ctx context.Context) (database.WorkspaceAgent, error) {
func (a *API) refreshCachedWorkspace(ctx context.Context) { func (a *API) refreshCachedWorkspace(ctx context.Context) {
ws, err := a.opts.Database.GetWorkspaceByID(ctx, a.opts.WorkspaceID) ws, err := a.opts.Database.GetWorkspaceByID(ctx, a.opts.WorkspaceID)
if err != nil { if err != nil {
// Do not clear the cache on transient DB errors. Stale data is
// preferable to no data, which forces callers to fall back to
// expensive queries like GetWorkspaceByAgentID.
a.opts.Log.Warn(ctx, "failed to refresh cached workspace fields", slog.Error(err)) a.opts.Log.Warn(ctx, "failed to refresh cached workspace fields", slog.Error(err))
a.cachedWorkspaceFields.Clear()
return return
} }
@@ -341,11 +346,11 @@ func (a *API) startCacheRefreshLoop(ctx context.Context) {
a.cachedWorkspaceFields.Clear() a.cachedWorkspaceFields.Clear()
} }
func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { func (a *API) publishWorkspaceUpdate(ctx context.Context, agentID uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
a.opts.PublishWorkspaceUpdateFn(ctx, a.opts.OwnerID, wspubsub.WorkspaceEvent{ a.opts.PublishWorkspaceUpdateFn(ctx, a.opts.OwnerID, wspubsub.WorkspaceEvent{
Kind: kind, Kind: kind,
WorkspaceID: a.opts.WorkspaceID, WorkspaceID: a.opts.WorkspaceID,
AgentID: &agent.ID, AgentID: &agentID,
}) })
return nil return nil
} }
+38 -33
View File
@@ -24,22 +24,19 @@ import (
) )
type AppsAPI struct { type AppsAPI struct {
AgentID uuid.UUID
AgentFn func(context.Context) (database.WorkspaceAgent, error) AgentFn func(context.Context) (database.WorkspaceAgent, error)
Database database.Store Database database.Store
Log slog.Logger Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error Workspace *CachedWorkspaceFields
PublishWorkspaceUpdateFn func(context.Context, uuid.UUID, wspubsub.WorkspaceEventKind) error
NotificationsEnqueuer notifications.Enqueuer NotificationsEnqueuer notifications.Enqueuer
Clock quartz.Clock Clock quartz.Clock
} }
func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) { func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
a.Log.Debug(ctx, "got batch app health update", a.Log.Debug(ctx, "got batch app health update",
slog.F("agent_id", workspaceAgent.ID.String()), slog.F("agent_id", a.AgentID.String()),
slog.F("updates", req.Updates), slog.F("updates", req.Updates),
) )
@@ -47,9 +44,9 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
return &agentproto.BatchUpdateAppHealthResponse{}, nil return &agentproto.BatchUpdateAppHealthResponse{}, nil
} }
apps, err := a.Database.GetWorkspaceAppsByAgentID(ctx, workspaceAgent.ID) apps, err := a.Database.GetWorkspaceAppsByAgentID(ctx, a.AgentID)
if err != nil { if err != nil {
return nil, xerrors.Errorf("get workspace apps by agent ID %q: %w", workspaceAgent.ID, err) return nil, xerrors.Errorf("get workspace apps by agent ID %q: %w", a.AgentID, err)
} }
var newApps []database.WorkspaceApp var newApps []database.WorkspaceApp
@@ -110,7 +107,7 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
} }
if a.PublishWorkspaceUpdateFn != nil && len(newApps) > 0 { if a.PublishWorkspaceUpdateFn != nil && len(newApps) > 0 {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAppHealthUpdate) err = a.PublishWorkspaceUpdateFn(ctx, a.AgentID, wspubsub.WorkspaceEventKindAppHealthUpdate)
if err != nil { if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err) return nil, xerrors.Errorf("publish workspace update: %w", err)
} }
@@ -149,12 +146,8 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
}) })
} }
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
app, err := a.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{ app, err := a.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
AgentID: workspaceAgent.ID, AgentID: a.AgentID,
Slug: req.Slug, Slug: req.Slug,
}) })
if err != nil { if err != nil {
@@ -164,11 +157,10 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
}) })
} }
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) ws, ok := a.Workspace.AsWorkspaceIdentity()
if err != nil { if !ok {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{ return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get workspace.", Message: "Workspace identity not cached.",
Detail: err.Error(),
}) })
} }
@@ -190,8 +182,8 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
_, err = a.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{ _, err = a.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
ID: uuid.New(), ID: uuid.New(),
CreatedAt: dbtime.Now(), CreatedAt: dbtime.Now(),
WorkspaceID: workspace.ID, WorkspaceID: ws.ID,
AgentID: workspaceAgent.ID, AgentID: a.AgentID,
AppID: app.ID, AppID: app.ID,
State: dbState, State: dbState,
Message: cleaned, Message: cleaned,
@@ -208,7 +200,7 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
} }
if a.PublishWorkspaceUpdateFn != nil { if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentAppStatusUpdate) err = a.PublishWorkspaceUpdateFn(ctx, a.AgentID, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
if err != nil { if err != nil {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{ return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to publish workspace update.", Message: "Failed to publish workspace update.",
@@ -217,14 +209,14 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
} }
} }
// Notify on state change to Working/Idle for AI tasks // Notify on state change to Working/Idle for AI tasks.
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState, workspace, workspaceAgent) a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState)
if shouldBump(dbState, latestAppStatus) { if shouldBump(dbState, latestAppStatus) {
// We pass time.Time{} for nextAutostart since we don't have access to // We pass time.Time{} for nextAutostart since we don't have access to
// TemplateScheduleStore here. The activity bump logic handles this by // TemplateScheduleStore here. The activity bump logic handles this by
// defaulting to the template's activity_bump duration (typically 1 hour). // defaulting to the template's activity_bump duration (typically 1 hour).
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, workspace.ID, time.Time{}) workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, ws.ID, time.Time{})
} }
// just return a blank response because it doesn't contain any settable fields at present. // just return a blank response because it doesn't contain any settable fields at present.
return new(agentproto.UpdateAppStatusResponse), nil return new(agentproto.UpdateAppStatusResponse), nil
@@ -261,8 +253,6 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
appID uuid.UUID, appID uuid.UUID,
latestAppStatus database.WorkspaceAppStatus, latestAppStatus database.WorkspaceAppStatus,
newAppStatus database.WorkspaceAppStatusState, newAppStatus database.WorkspaceAppStatusState,
workspace database.Workspace,
agent database.WorkspaceAgent,
) { ) {
var notificationTemplate uuid.UUID var notificationTemplate uuid.UUID
switch newAppStatus { switch newAppStatus {
@@ -279,11 +269,20 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
return return
} }
if !workspace.TaskID.Valid { taskID := a.Workspace.TaskID()
if !taskID.Valid {
// Workspace has no task ID, do nothing. // Workspace has no task ID, do nothing.
return return
} }
// Only fetch fresh agent state for task workspaces, since we need
// the current lifecycle state to decide whether to send notifications.
agent, err := a.AgentFn(ctx)
if err != nil {
a.Log.Warn(ctx, "failed to get agent for AI task notification", slog.Error(err))
return
}
// Only send notifications when the agent is ready. We want to skip // Only send notifications when the agent is ready. We want to skip
// any state transitions that occur whilst the workspace is starting // any state transitions that occur whilst the workspace is starting
// up as it doesn't make sense to receive them. // up as it doesn't make sense to receive them.
@@ -296,7 +295,7 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
return return
} }
task, err := a.Database.GetTaskByID(ctx, workspace.TaskID.UUID) task, err := a.Database.GetTaskByID(ctx, taskID.UUID)
if err != nil { if err != nil {
a.Log.Warn(ctx, "failed to get task", slog.Error(err)) a.Log.Warn(ctx, "failed to get task", slog.Error(err))
return return
@@ -321,14 +320,20 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
return return
} }
ws, ok := a.Workspace.AsWorkspaceIdentity()
if !ok {
a.Log.Warn(ctx, "failed to get workspace identity for AI task notification")
return
}
if _, err := a.NotificationsEnqueuer.EnqueueWithData( if _, err := a.NotificationsEnqueuer.EnqueueWithData(
// nolint:gocritic // Need notifier actor to enqueue notifications // nolint:gocritic // Need notifier actor to enqueue notifications
dbauthz.AsNotifier(ctx), dbauthz.AsNotifier(ctx),
workspace.OwnerID, ws.OwnerID,
notificationTemplate, notificationTemplate,
map[string]string{ map[string]string{
"task": task.Name, "task": task.Name,
"workspace": workspace.Name, "workspace": ws.Name,
}, },
map[string]any{ map[string]any{
// Use a 1-minute bucketed timestamp to bypass per-day dedupe, // Use a 1-minute bucketed timestamp to bypass per-day dedupe,
@@ -338,7 +343,7 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
}, },
"api-workspace-agent-app-status", "api-workspace-agent-app-status",
// Associate this notification with related entities // Associate this notification with related entities
workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID, ws.ID, ws.OwnerID, ws.OrganizationID, appID,
); err != nil { ); err != nil {
a.Log.Warn(ctx, "failed to notify of task state", slog.Error(err)) a.Log.Warn(ctx, "failed to notify of task state", slog.Error(err))
return return
+28 -42
View File
@@ -67,12 +67,10 @@ func TestBatchUpdateAppHealths(t *testing.T) {
publishCalled := false publishCalled := false
api := &agentapi.AppsAPI{ api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil
},
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true publishCalled = true
return nil return nil
}, },
@@ -105,12 +103,10 @@ func TestBatchUpdateAppHealths(t *testing.T) {
publishCalled := false publishCalled := false
api := &agentapi.AppsAPI{ api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil
},
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true publishCalled = true
return nil return nil
}, },
@@ -144,12 +140,10 @@ func TestBatchUpdateAppHealths(t *testing.T) {
publishCalled := false publishCalled := false
api := &agentapi.AppsAPI{ api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil
},
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true publishCalled = true
return nil return nil
}, },
@@ -180,9 +174,7 @@ func TestBatchUpdateAppHealths(t *testing.T) {
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app3}, nil) dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app3}, nil)
api := &agentapi.AppsAPI{ api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil
},
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: nil, PublishWorkspaceUpdateFn: nil,
@@ -209,9 +201,7 @@ func TestBatchUpdateAppHealths(t *testing.T) {
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil) dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
api := &agentapi.AppsAPI{ api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil
},
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: nil, PublishWorkspaceUpdateFn: nil,
@@ -239,9 +229,7 @@ func TestBatchUpdateAppHealths(t *testing.T) {
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil) dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
api := &agentapi.AppsAPI{ api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil
},
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: nil, PublishWorkspaceUpdateFn: nil,
@@ -279,14 +267,26 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
} }
workspaceUpdates := make(chan wspubsub.WorkspaceEventKind, 100) workspaceUpdates := make(chan wspubsub.WorkspaceEventKind, 100)
workspace := database.Workspace{
ID: uuid.UUID{9},
TaskID: uuid.NullUUID{
Valid: true,
UUID: uuid.UUID{7},
},
}
cachedWs := &agentapi.CachedWorkspaceFields{}
cachedWs.UpdateValues(workspace)
api := &agentapi.AppsAPI{ api := &agentapi.AppsAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil return agent, nil
}, },
Database: mDB, Database: mDB,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(_ context.Context, agnt *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { Workspace: cachedWs,
assert.Equal(t, *agnt, agent) PublishWorkspaceUpdateFn: func(_ context.Context, agnt uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
assert.Equal(t, agnt, agent.ID)
testutil.AssertSend(ctx, t, workspaceUpdates, kind) testutil.AssertSend(ctx, t, workspaceUpdates, kind)
return nil return nil
}, },
@@ -309,14 +309,6 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
}, },
} }
mDB.EXPECT().GetTaskByID(gomock.Any(), task.ID).Times(1).Return(task, nil) mDB.EXPECT().GetTaskByID(gomock.Any(), task.ID).Times(1).Return(task, nil)
workspace := database.Workspace{
ID: uuid.UUID{9},
TaskID: uuid.NullUUID{
Valid: true,
UUID: task.ID,
},
}
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Times(1).Return(workspace, nil)
appStatus := database.WorkspaceAppStatus{ appStatus := database.WorkspaceAppStatus{
ID: uuid.UUID{6}, ID: uuid.UUID{6},
} }
@@ -363,9 +355,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
Return(database.WorkspaceApp{}, sql.ErrNoRows) Return(database.WorkspaceApp{}, sql.ErrNoRows)
api := &agentapi.AppsAPI{ api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil
},
Database: mDB, Database: mDB,
Log: testutil.Logger(t), Log: testutil.Logger(t),
} }
@@ -392,9 +382,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
} }
api := &agentapi.AppsAPI{ api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil
},
Database: mDB, Database: mDB,
Log: testutil.Logger(t), Log: testutil.Logger(t),
} }
@@ -422,9 +410,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
} }
api := &agentapi.AppsAPI{ api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil
},
Database: mDB, Database: mDB,
Log: testutil.Logger(t), Log: testutil.Logger(t),
} }
+10
View File
@@ -4,6 +4,7 @@ import (
"context" "context"
"sync" "sync"
"github.com/google/uuid"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
@@ -23,12 +24,14 @@ type CachedWorkspaceFields struct {
lock sync.RWMutex lock sync.RWMutex
identity database.WorkspaceIdentity identity database.WorkspaceIdentity
taskID uuid.NullUUID
} }
func (cws *CachedWorkspaceFields) Clear() { func (cws *CachedWorkspaceFields) Clear() {
cws.lock.Lock() cws.lock.Lock()
defer cws.lock.Unlock() defer cws.lock.Unlock()
cws.identity = database.WorkspaceIdentity{} cws.identity = database.WorkspaceIdentity{}
cws.taskID = uuid.NullUUID{}
} }
func (cws *CachedWorkspaceFields) UpdateValues(ws database.Workspace) { func (cws *CachedWorkspaceFields) UpdateValues(ws database.Workspace) {
@@ -42,6 +45,13 @@ func (cws *CachedWorkspaceFields) UpdateValues(ws database.Workspace) {
cws.identity.OwnerUsername = ws.OwnerUsername cws.identity.OwnerUsername = ws.OwnerUsername
cws.identity.TemplateName = ws.TemplateName cws.identity.TemplateName = ws.TemplateName
cws.identity.AutostartSchedule = ws.AutostartSchedule cws.identity.AutostartSchedule = ws.AutostartSchedule
cws.taskID = ws.TaskID
}
func (cws *CachedWorkspaceFields) TaskID() uuid.NullUUID {
cws.lock.RLock()
defer cws.lock.RUnlock()
return cws.taskID
} }
// Returns the Workspace, true, unless the workspace has not been cached (nuked or was a prebuild). // Returns the Workspace, true, unless the workspace has not been cached (nuked or was a prebuild).
+4 -19
View File
@@ -14,11 +14,11 @@ import (
"github.com/coder/coder/v2/coderd/connectionlog" "github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
) )
type ConnLogAPI struct { type ConnLogAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error) AgentID uuid.UUID
AgentName string
ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger] ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger]
Workspace *CachedWorkspaceFields Workspace *CachedWorkspaceFields
Database database.Store Database database.Store
@@ -53,27 +53,12 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
} }
} }
// Inject RBAC object into context for dbauthz fast path, avoid having to
// call GetWorkspaceByAgentID on every metadata update.
rbacCtx := ctx
var ws database.WorkspaceIdentity var ws database.WorkspaceIdentity
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok { if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
ws = dbws ws = dbws
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
if err != nil {
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
//nolint:gocritic
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
}
}
// Fetch contextual data for this connection log event.
workspaceAgent, err := a.AgentFn(rbacCtx)
if err != nil {
return nil, xerrors.Errorf("get agent: %w", err)
} }
if ws.Equal(database.WorkspaceIdentity{}) { if ws.Equal(database.WorkspaceIdentity{}) {
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) workspace, err := a.Database.GetWorkspaceByAgentID(ctx, a.AgentID)
if err != nil { if err != nil {
return nil, xerrors.Errorf("get workspace by agent id: %w", err) return nil, xerrors.Errorf("get workspace by agent id: %w", err)
} }
@@ -97,7 +82,7 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
WorkspaceOwnerID: ws.OwnerID, WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID, WorkspaceID: ws.ID,
WorkspaceName: ws.Name, WorkspaceName: ws.Name,
AgentName: workspaceAgent.Name, AgentName: a.AgentName,
Type: connectionType, Type: connectionType,
Code: code, Code: code,
Ip: logIP, Ip: logIP,
+3 -4
View File
@@ -114,10 +114,9 @@ func TestConnectionLog(t *testing.T) {
api := &agentapi.ConnLogAPI{ api := &agentapi.ConnLogAPI{
ConnectionLogger: asAtomicPointer[connectionlog.ConnectionLogger](connLogger), ConnectionLogger: asAtomicPointer[connectionlog.ConnectionLogger](connLogger),
Database: mDB, Database: mDB,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil AgentName: agent.Name,
}, Workspace: &agentapi.CachedWorkspaceFields{},
Workspace: &agentapi.CachedWorkspaceFields{},
} }
api.ReportConnection(context.Background(), &agentproto.ReportConnectionRequest{ api.ReportConnection(context.Background(), &agentproto.ReportConnectionRequest{
Connection: &agentproto.Connection{ Connection: &agentproto.Connection{
+2 -2
View File
@@ -30,7 +30,7 @@ type LifecycleAPI struct {
WorkspaceID uuid.UUID WorkspaceID uuid.UUID
Database database.Store Database database.Store
Log slog.Logger Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error PublishWorkspaceUpdateFn func(context.Context, uuid.UUID, wspubsub.WorkspaceEventKind) error
TimeNowFn func() time.Time // defaults to dbtime.Now() TimeNowFn func() time.Time // defaults to dbtime.Now()
Metrics *LifecycleMetrics Metrics *LifecycleMetrics
@@ -122,7 +122,7 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
} }
if a.PublishWorkspaceUpdateFn != nil { if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentLifecycleUpdate) err = a.PublishWorkspaceUpdateFn(ctx, workspaceAgent.ID, wspubsub.WorkspaceEventKindAgentLifecycleUpdate)
if err != nil { if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err) return nil, xerrors.Errorf("publish workspace update: %w", err)
} }
+4 -4
View File
@@ -85,7 +85,7 @@ func TestUpdateLifecycle(t *testing.T) {
WorkspaceID: workspaceID, WorkspaceID: workspaceID,
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true publishCalled = true
return nil return nil
}, },
@@ -206,7 +206,7 @@ func TestUpdateLifecycle(t *testing.T) {
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
Metrics: metrics, Metrics: metrics,
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true publishCalled = true
return nil return nil
}, },
@@ -323,7 +323,7 @@ func TestUpdateLifecycle(t *testing.T) {
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
Metrics: metrics, Metrics: metrics,
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
atomic.AddInt64(&publishCalled, 1) atomic.AddInt64(&publishCalled, 1)
return nil return nil
}, },
@@ -410,7 +410,7 @@ func TestUpdateLifecycle(t *testing.T) {
WorkspaceID: workspaceID, WorkspaceID: workspaceID,
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true publishCalled = true
return nil return nil
}, },
+3 -3
View File
@@ -19,7 +19,7 @@ type LogsAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error) AgentFn func(context.Context) (database.WorkspaceAgent, error)
Database database.Store Database database.Store
Log slog.Logger Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error PublishWorkspaceUpdateFn func(context.Context, uuid.UUID, wspubsub.WorkspaceEventKind) error
PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage)
TimeNowFn func() time.Time // defaults to dbtime.Now() TimeNowFn func() time.Time // defaults to dbtime.Now()
@@ -125,7 +125,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
} }
if a.PublishWorkspaceUpdateFn != nil { if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentLogsOverflow) err = a.PublishWorkspaceUpdateFn(ctx, workspaceAgent.ID, wspubsub.WorkspaceEventKindAgentLogsOverflow)
if err != nil { if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err) return nil, xerrors.Errorf("publish workspace update: %w", err)
} }
@@ -145,7 +145,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
if workspaceAgent.LogsLength == 0 && a.PublishWorkspaceUpdateFn != nil { if workspaceAgent.LogsLength == 0 && a.PublishWorkspaceUpdateFn != nil {
// If these are the first logs being appended, we publish a UI update // If these are the first logs being appended, we publish a UI update
// to notify the UI that logs are now available. // to notify the UI that logs are now available.
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentFirstLogs) err = a.PublishWorkspaceUpdateFn(ctx, workspaceAgent.ID, wspubsub.WorkspaceEventKindAgentFirstLogs)
if err != nil { if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err) return nil, xerrors.Errorf("publish workspace update: %w", err)
} }
+6 -6
View File
@@ -51,7 +51,7 @@ func TestBatchCreateLogs(t *testing.T) {
}, },
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true publishWorkspaceUpdateCalled = true
return nil return nil
}, },
@@ -155,7 +155,7 @@ func TestBatchCreateLogs(t *testing.T) {
}, },
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true publishWorkspaceUpdateCalled = true
return nil return nil
}, },
@@ -203,7 +203,7 @@ func TestBatchCreateLogs(t *testing.T) {
}, },
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true publishWorkspaceUpdateCalled = true
return nil return nil
}, },
@@ -296,7 +296,7 @@ func TestBatchCreateLogs(t *testing.T) {
}, },
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true publishWorkspaceUpdateCalled = true
return nil return nil
}, },
@@ -340,7 +340,7 @@ func TestBatchCreateLogs(t *testing.T) {
}, },
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true publishWorkspaceUpdateCalled = true
return nil return nil
}, },
@@ -387,7 +387,7 @@ func TestBatchCreateLogs(t *testing.T) {
}, },
Database: dbM, Database: dbM,
Log: testutil.Logger(t), Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error { PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true publishWorkspaceUpdateCalled = true
return nil return nil
}, },
+6 -5
View File
@@ -32,16 +32,12 @@ type ManifestAPI struct {
DerpForceWebSockets bool DerpForceWebSockets bool
WorkspaceID uuid.UUID WorkspaceID uuid.UUID
AgentFn func(context.Context) (database.WorkspaceAgent, error) AgentFn func(ctx context.Context) (database.WorkspaceAgent, error)
Database database.Store Database database.Store
DerpMapFn func() *tailcfg.DERPMap DerpMapFn func() *tailcfg.DERPMap
} }
func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifestRequest) (*agentproto.Manifest, error) { func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
var ( var (
dbApps []database.WorkspaceApp dbApps []database.WorkspaceApp
scripts []database.WorkspaceAgentScript scripts []database.WorkspaceAgentScript
@@ -50,6 +46,11 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
devcontainers []database.WorkspaceAgentDevcontainer devcontainers []database.WorkspaceAgentDevcontainer
) )
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, xerrors.Errorf("getting workspace agent: %w", err)
}
var eg errgroup.Group var eg errgroup.Group
eg.Go(func() (err error) { eg.Go(func() (err error) {
dbApps, err = a.Database.GetWorkspaceAppsByAgentID(ctx, workspaceAgent.ID) dbApps, err = a.Database.GetWorkspaceAppsByAgentID(ctx, workspaceAgent.ID)
+3 -9
View File
@@ -322,9 +322,7 @@ func TestGetManifest(t *testing.T) {
DisableDirectConnections: true, DisableDirectConnections: true,
DerpForceWebSockets: true, DerpForceWebSockets: true,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil },
return agent, nil
},
WorkspaceID: workspace.ID, WorkspaceID: workspace.ID,
Database: mDB, Database: mDB,
DerpMapFn: derpMapFn, DerpMapFn: derpMapFn,
@@ -389,9 +387,7 @@ func TestGetManifest(t *testing.T) {
DisableDirectConnections: true, DisableDirectConnections: true,
DerpForceWebSockets: true, DerpForceWebSockets: true,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return childAgent, nil },
return childAgent, nil
},
WorkspaceID: workspace.ID, WorkspaceID: workspace.ID,
Database: mDB, Database: mDB,
DerpMapFn: derpMapFn, DerpMapFn: derpMapFn,
@@ -512,9 +508,7 @@ func TestGetManifest(t *testing.T) {
DisableDirectConnections: true, DisableDirectConnections: true,
DerpForceWebSockets: true, DerpForceWebSockets: true,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil },
return agent, nil
},
WorkspaceID: workspace.ID, WorkspaceID: workspace.ID,
Database: mDB, Database: mDB,
DerpMapFn: derpMapFn, DerpMapFn: derpMapFn,
+4 -22
View File
@@ -5,18 +5,18 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/google/uuid"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"cdr.dev/slog/v3" "cdr.dev/slog/v3"
agentproto "github.com/coder/coder/v2/agent/proto" agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher" "github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/dbtime"
) )
type MetadataAPI struct { type MetadataAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error) AgentID uuid.UUID
Workspace *CachedWorkspaceFields Workspace *CachedWorkspaceFields
Database database.Store Database database.Store
Log slog.Logger Log slog.Logger
@@ -45,29 +45,11 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
maxErrorLen = maxValueLen maxErrorLen = maxValueLen
) )
// Inject RBAC object into context for dbauthz fast path, avoid having to
// call GetWorkspaceByAgentID on every metadata update.
var err error
rbacCtx := ctx
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
if err != nil {
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
//nolint:gocritic
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
}
}
workspaceAgent, err := a.AgentFn(rbacCtx)
if err != nil {
return nil, err
}
var ( var (
collectedAt = a.now() collectedAt = a.now()
allKeysLen = 0 allKeysLen = 0
dbUpdate = database.UpdateWorkspaceAgentMetadataParams{ dbUpdate = database.UpdateWorkspaceAgentMetadataParams{
WorkspaceAgentID: workspaceAgent.ID, WorkspaceAgentID: a.AgentID,
// These need to be `make(x, 0, len(req.Metadata))` instead of // These need to be `make(x, 0, len(req.Metadata))` instead of
// `make(x, len(req.Metadata))` because we may not insert all // `make(x, len(req.Metadata))` because we may not insert all
// metadata if the keys are large. // metadata if the keys are large.
@@ -121,7 +103,7 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
} }
// Use batcher to batch metadata updates. // Use batcher to batch metadata updates.
err = a.Batcher.Add(workspaceAgent.ID, dbUpdate.Key, dbUpdate.Value, dbUpdate.Error, dbUpdate.CollectedAt) err := a.Batcher.Add(a.AgentID, dbUpdate.Key, dbUpdate.Value, dbUpdate.Error, dbUpdate.CollectedAt)
if err != nil { if err != nil {
return nil, xerrors.Errorf("add metadata to batcher: %w", err) return nil, xerrors.Errorf("add metadata to batcher: %w", err)
} }
+3 -9
View File
@@ -80,9 +80,7 @@ func TestBatchUpdateMetadata(t *testing.T) {
t.Cleanup(batcher.Close) t.Cleanup(batcher.Close)
api := &agentapi.MetadataAPI{ api := &agentapi.MetadataAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil
},
Workspace: &agentapi.CachedWorkspaceFields{}, Workspace: &agentapi.CachedWorkspaceFields{},
Log: testutil.Logger(t), Log: testutil.Logger(t),
Batcher: batcher, Batcher: batcher,
@@ -159,9 +157,7 @@ func TestBatchUpdateMetadata(t *testing.T) {
t.Cleanup(batcher.Close) t.Cleanup(batcher.Close)
api := &agentapi.MetadataAPI{ api := &agentapi.MetadataAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil
},
Workspace: &agentapi.CachedWorkspaceFields{}, Workspace: &agentapi.CachedWorkspaceFields{},
Log: testutil.Logger(t), Log: testutil.Logger(t),
Batcher: batcher, Batcher: batcher,
@@ -241,9 +237,7 @@ func TestBatchUpdateMetadata(t *testing.T) {
t.Cleanup(batcher.Close) t.Cleanup(batcher.Close)
api := &agentapi.MetadataAPI{ api := &agentapi.MetadataAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil
},
Workspace: &agentapi.CachedWorkspaceFields{}, Workspace: &agentapi.CachedWorkspaceFields{},
Log: testutil.Logger(t), Log: testutil.Logger(t),
Batcher: batcher, Batcher: batcher,
+8 -25
View File
@@ -4,20 +4,21 @@ import (
"context" "context"
"time" "time"
"github.com/google/uuid"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/durationpb"
"cdr.dev/slog/v3" "cdr.dev/slog/v3"
agentproto "github.com/coder/coder/v2/agent/proto" agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/workspacestats" "github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
) )
type StatsAPI struct { type StatsAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error) AgentID uuid.UUID
AgentName string
Workspace *CachedWorkspaceFields Workspace *CachedWorkspaceFields
Database database.Store Database database.Store
Log slog.Logger Log slog.Logger
@@ -44,32 +45,13 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
return res, nil return res, nil
} }
// Inject RBAC object into context for dbauthz fast path, avoid having to
// call GetWorkspaceAgentByID on every stats update.
rbacCtx := ctx
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
var err error
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
if err != nil {
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
//nolint:gocritic
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
}
}
workspaceAgent, err := a.AgentFn(rbacCtx)
if err != nil {
return nil, err
}
// If cache is empty (prebuild or invalid), fall back to DB // If cache is empty (prebuild or invalid), fall back to DB
var ws database.WorkspaceIdentity var ws database.WorkspaceIdentity
var ok bool var ok bool
if ws, ok = a.Workspace.AsWorkspaceIdentity(); !ok { if ws, ok = a.Workspace.AsWorkspaceIdentity(); !ok {
w, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) w, err := a.Database.GetWorkspaceByAgentID(ctx, a.AgentID)
if err != nil { if err != nil {
return nil, xerrors.Errorf("get workspace by agent ID %q: %w", workspaceAgent.ID, err) return nil, xerrors.Errorf("get workspace by agent ID %q: %w", a.AgentID, err)
} }
ws = database.WorkspaceIdentityFromWorkspace(w) ws = database.WorkspaceIdentityFromWorkspace(w)
} }
@@ -90,11 +72,12 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
req.Stats.SessionCountReconnectingPty = 0 req.Stats.SessionCountReconnectingPty = 0
} }
err = a.StatsReporter.ReportAgentStats( err := a.StatsReporter.ReportAgentStats(
ctx, ctx,
a.now(), a.now(),
ws, ws,
workspaceAgent, a.AgentID,
a.AgentName,
req.Stats, req.Stats,
false, false,
) )
+12 -18
View File
@@ -119,9 +119,8 @@ func TestUpdateStats(t *testing.T) {
} }
) )
api := agentapi.StatsAPI{ api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil AgentName: agent.Name,
},
Workspace: &workspaceAsCacheFields, Workspace: &workspaceAsCacheFields,
Database: dbM, Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{ StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -229,9 +228,8 @@ func TestUpdateStats(t *testing.T) {
} }
) )
api := agentapi.StatsAPI{ api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil AgentName: agent.Name,
},
Workspace: &workspaceAsCacheFields, Workspace: &workspaceAsCacheFields,
Database: dbM, Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{ StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -264,9 +262,8 @@ func TestUpdateStats(t *testing.T) {
} }
) )
api := agentapi.StatsAPI{ api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil AgentName: agent.Name,
},
Workspace: &workspaceAsCacheFields, Workspace: &workspaceAsCacheFields,
Database: dbM, Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{ StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -347,9 +344,8 @@ func TestUpdateStats(t *testing.T) {
// ws.AutostartSchedule = workspace.AutostartSchedule // ws.AutostartSchedule = workspace.AutostartSchedule
api := agentapi.StatsAPI{ api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil AgentName: agent.Name,
},
Workspace: &ws, Workspace: &ws,
Database: dbM, Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{ StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -459,9 +455,8 @@ func TestUpdateStats(t *testing.T) {
) )
defer wut.Close() defer wut.Close()
api := agentapi.StatsAPI{ api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil AgentName: agent.Name,
},
Workspace: &workspaceAsCacheFields, Workspace: &workspaceAsCacheFields,
Database: dbM, Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{ StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -596,9 +591,8 @@ func TestUpdateStats(t *testing.T) {
} }
) )
api := agentapi.StatsAPI{ api := agentapi.StatsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: agent.ID,
return agent, nil AgentName: agent.Name,
},
Workspace: &workspaceAsCacheFields, Workspace: &workspaceAsCacheFields,
Database: dbM, Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{ StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
+6 -2
View File
@@ -25,7 +25,6 @@ import (
type SubAgentAPI struct { type SubAgentAPI struct {
OwnerID uuid.UUID OwnerID uuid.UUID
OrganizationID uuid.UUID OrganizationID uuid.UUID
AgentID uuid.UUID
AgentFn func(context.Context) (database.WorkspaceAgent, error) AgentFn func(context.Context) (database.WorkspaceAgent, error)
Log slog.Logger Log slog.Logger
@@ -295,7 +294,12 @@ func (a *SubAgentAPI) ListSubAgents(ctx context.Context, _ *agentproto.ListSubAg
//nolint:gocritic // This gives us only the permissions required to do the job. //nolint:gocritic // This gives us only the permissions required to do the job.
ctx = dbauthz.AsSubAgentAPI(ctx, a.OrganizationID, a.OwnerID) ctx = dbauthz.AsSubAgentAPI(ctx, a.OrganizationID, a.OwnerID)
workspaceAgents, err := a.Database.GetWorkspaceAgentsByParentID(ctx, a.AgentID) parentAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, xerrors.Errorf("get parent agent: %w", err)
}
workspaceAgents, err := a.Database.GetWorkspaceAgentsByParentID(ctx, parentAgent.ID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
+3 -6
View File
@@ -81,12 +81,9 @@ func TestSubAgentAPI(t *testing.T) {
return &agentapi.SubAgentAPI{ return &agentapi.SubAgentAPI{
OwnerID: user.ID, OwnerID: user.ID,
OrganizationID: org.ID, OrganizationID: org.ID,
AgentID: agent.ID, AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil },
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { Clock: clock,
return agent, nil Database: dbauthz.New(db, auth, logger, accessControlStore),
},
Clock: clock,
Database: dbauthz.New(db, auth, logger, accessControlStore),
} }
} }
+11 -6
View File
@@ -314,17 +314,22 @@ func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Req
// This functionality has been moved to the AppsAPI in the agentapi. We keep this HTTP handler around for back // This functionality has been moved to the AppsAPI in the agentapi. We keep this HTTP handler around for back
// compatibility with older agents. We'll translate the request into the protobuf so there is only one primary // compatibility with older agents. We'll translate the request into the protobuf so there is only one primary
// implementation. // implementation.
cachedWs := &agentapi.CachedWorkspaceFields{}
cachedWs.UpdateValues(workspace)
appAPI := &agentapi.AppsAPI{ appAPI := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) { AgentID: workspaceAgent.ID,
return workspaceAgent, nil Database: api.Database,
Log: api.Logger,
Workspace: cachedWs,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return api.Database.GetWorkspaceAgentByID(ctx, workspaceAgent.ID)
}, },
Database: api.Database, PublishWorkspaceUpdateFn: func(ctx context.Context, agentID uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
Log: api.Logger,
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{ api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{
Kind: kind, Kind: kind,
WorkspaceID: workspace.ID, WorkspaceID: workspace.ID,
AgentID: &agent.ID, AgentID: &agentID,
}) })
return nil return nil
}, },
+1 -1
View File
@@ -178,7 +178,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
// Optional: // Optional:
UpdateAgentMetricsFn: api.UpdateAgentMetrics, UpdateAgentMetricsFn: api.UpdateAgentMetrics,
}, workspace) }, workspace, workspaceAgent)
streamID := tailnet.StreamID{ streamID := tailnet.StreamID{
Name: fmt.Sprintf("%s-%s-%s", workspace.OwnerUsername, workspace.Name, workspaceAgent.Name), Name: fmt.Sprintf("%s-%s-%s", workspace.OwnerUsername, workspace.Name, workspaceAgent.Name),
+1 -1
View File
@@ -1753,7 +1753,7 @@ func (api *API) postWorkspaceUsage(rw http.ResponseWriter, r *http.Request) {
// return // return
// } // }
err = api.statsReporter.ReportAgentStats(ctx, dbtime.Now(), database.WorkspaceIdentityFromWorkspace(workspace), agent, stat, true) err = api.statsReporter.ReportAgentStats(ctx, dbtime.Now(), database.WorkspaceIdentityFromWorkspace(workspace), agent.ID, agent.Name, stat, true)
if err != nil { if err != nil {
httpapi.InternalServerError(rw, err) httpapi.InternalServerError(rw, err)
return return
+3 -3
View File
@@ -137,10 +137,10 @@ func (r *Reporter) ReportAppStats(ctx context.Context, stats []workspaceapps.Sta
} }
// nolint:revive // usage is a control flag while we have the experiment // nolint:revive // usage is a control flag while we have the experiment
func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspace database.WorkspaceIdentity, workspaceAgent database.WorkspaceAgent, stats *agentproto.Stats, usage bool) error { func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspace database.WorkspaceIdentity, agentID uuid.UUID, agentName string, stats *agentproto.Stats, usage bool) error {
// update agent stats // update agent stats
if !r.opts.DisableDatabaseInserts { if !r.opts.DisableDatabaseInserts {
r.opts.StatsBatcher.Add(now, workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, stats, usage) r.opts.StatsBatcher.Add(now, agentID, workspace.TemplateID, workspace.OwnerID, workspace.ID, stats, usage)
} }
// update prometheus metrics (even if template insights are disabled) // update prometheus metrics (even if template insights are disabled)
@@ -148,7 +148,7 @@ func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspac
r.opts.UpdateAgentMetricsFn(ctx, prometheusmetrics.AgentMetricLabels{ r.opts.UpdateAgentMetricsFn(ctx, prometheusmetrics.AgentMetricLabels{
Username: workspace.OwnerUsername, Username: workspace.OwnerUsername,
WorkspaceName: workspace.Name, WorkspaceName: workspace.Name,
AgentName: workspaceAgent.Name, AgentName: agentName,
TemplateName: workspace.TemplateName, TemplateName: workspace.TemplateName,
}, stats.Metrics) }, stats.Metrics)
} }
+1
View File
@@ -182,6 +182,7 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) (err error)
wchRef, rchRef := wch, rch wchRef, rchRef := wch, rch
for { for {
if wchRef == nil && rchRef == nil { if wchRef == nil && rchRef == nil {
logger.Info(ctx, "reading and writing to agent complete! Closing connection")
return nil return nil
} }