mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: limit calls to GetWorkspaceAgentByID in agentapi (#23015)
We currently call GetWorkspaceAgentByID millions of times at scale unnecessarily. This PR embeds immutable fields into the relevant services instead of fetching for them every time. resolves https://github.com/coder/scaletest/issues/84 Confirmed with a 10k scaletest that this changeset takes the query from 10M+ queries down to 39k
This commit is contained in:
+13
-8
@@ -103,7 +103,7 @@ type Options struct {
|
||||
UpdateAgentMetricsFn func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
|
||||
}
|
||||
|
||||
func New(opts Options, workspace database.Workspace) *API {
|
||||
func New(opts Options, workspace database.Workspace, agent database.WorkspaceAgent) *API {
|
||||
if opts.Clock == nil {
|
||||
opts.Clock = quartz.NewReal()
|
||||
}
|
||||
@@ -156,7 +156,8 @@ func New(opts Options, workspace database.Workspace) *API {
|
||||
}
|
||||
|
||||
api.StatsAPI = &StatsAPI{
|
||||
AgentFn: api.agent,
|
||||
AgentID: agent.ID,
|
||||
AgentName: agent.Name,
|
||||
Workspace: api.cachedWorkspaceFields,
|
||||
Database: opts.Database,
|
||||
Log: opts.Log,
|
||||
@@ -175,16 +176,18 @@ func New(opts Options, workspace database.Workspace) *API {
|
||||
}
|
||||
|
||||
api.AppsAPI = &AppsAPI{
|
||||
AgentID: agent.ID,
|
||||
AgentFn: api.agent,
|
||||
Database: opts.Database,
|
||||
Log: opts.Log,
|
||||
Workspace: api.cachedWorkspaceFields,
|
||||
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
|
||||
Clock: opts.Clock,
|
||||
NotificationsEnqueuer: opts.NotificationsEnqueuer,
|
||||
}
|
||||
|
||||
api.MetadataAPI = &MetadataAPI{
|
||||
AgentFn: api.agent,
|
||||
AgentID: agent.ID,
|
||||
Workspace: api.cachedWorkspaceFields,
|
||||
Database: opts.Database,
|
||||
Log: opts.Log,
|
||||
@@ -204,7 +207,8 @@ func New(opts Options, workspace database.Workspace) *API {
|
||||
}
|
||||
|
||||
api.ConnLogAPI = &ConnLogAPI{
|
||||
AgentFn: api.agent,
|
||||
AgentID: agent.ID,
|
||||
AgentName: agent.Name,
|
||||
ConnectionLogger: opts.ConnectionLogger,
|
||||
Database: opts.Database,
|
||||
Workspace: api.cachedWorkspaceFields,
|
||||
@@ -222,7 +226,6 @@ func New(opts Options, workspace database.Workspace) *API {
|
||||
api.SubAgentAPI = &SubAgentAPI{
|
||||
OwnerID: opts.OwnerID,
|
||||
OrganizationID: opts.OrganizationID,
|
||||
AgentID: opts.AgentID,
|
||||
AgentFn: api.agent,
|
||||
Log: opts.Log,
|
||||
Clock: opts.Clock,
|
||||
@@ -297,8 +300,10 @@ func (a *API) agent(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
func (a *API) refreshCachedWorkspace(ctx context.Context) {
|
||||
ws, err := a.opts.Database.GetWorkspaceByID(ctx, a.opts.WorkspaceID)
|
||||
if err != nil {
|
||||
// Do not clear the cache on transient DB errors. Stale data is
|
||||
// preferable to no data, which forces callers to fall back to
|
||||
// expensive queries like GetWorkspaceByAgentID.
|
||||
a.opts.Log.Warn(ctx, "failed to refresh cached workspace fields", slog.Error(err))
|
||||
a.cachedWorkspaceFields.Clear()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -341,11 +346,11 @@ func (a *API) startCacheRefreshLoop(ctx context.Context) {
|
||||
a.cachedWorkspaceFields.Clear()
|
||||
}
|
||||
|
||||
func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
func (a *API) publishWorkspaceUpdate(ctx context.Context, agentID uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
a.opts.PublishWorkspaceUpdateFn(ctx, a.opts.OwnerID, wspubsub.WorkspaceEvent{
|
||||
Kind: kind,
|
||||
WorkspaceID: a.opts.WorkspaceID,
|
||||
AgentID: &agent.ID,
|
||||
AgentID: &agentID,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
+38
-33
@@ -24,22 +24,19 @@ import (
|
||||
)
|
||||
|
||||
type AppsAPI struct {
|
||||
AgentID uuid.UUID
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
Database database.Store
|
||||
Log slog.Logger
|
||||
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
|
||||
Workspace *CachedWorkspaceFields
|
||||
PublishWorkspaceUpdateFn func(context.Context, uuid.UUID, wspubsub.WorkspaceEventKind) error
|
||||
NotificationsEnqueuer notifications.Enqueuer
|
||||
Clock quartz.Clock
|
||||
}
|
||||
|
||||
func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
|
||||
workspaceAgent, err := a.AgentFn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a.Log.Debug(ctx, "got batch app health update",
|
||||
slog.F("agent_id", workspaceAgent.ID.String()),
|
||||
slog.F("agent_id", a.AgentID.String()),
|
||||
slog.F("updates", req.Updates),
|
||||
)
|
||||
|
||||
@@ -47,9 +44,9 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
|
||||
return &agentproto.BatchUpdateAppHealthResponse{}, nil
|
||||
}
|
||||
|
||||
apps, err := a.Database.GetWorkspaceAppsByAgentID(ctx, workspaceAgent.ID)
|
||||
apps, err := a.Database.GetWorkspaceAppsByAgentID(ctx, a.AgentID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace apps by agent ID %q: %w", workspaceAgent.ID, err)
|
||||
return nil, xerrors.Errorf("get workspace apps by agent ID %q: %w", a.AgentID, err)
|
||||
}
|
||||
|
||||
var newApps []database.WorkspaceApp
|
||||
@@ -110,7 +107,7 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
|
||||
}
|
||||
|
||||
if a.PublishWorkspaceUpdateFn != nil && len(newApps) > 0 {
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAppHealthUpdate)
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, a.AgentID, wspubsub.WorkspaceEventKindAppHealthUpdate)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish workspace update: %w", err)
|
||||
}
|
||||
@@ -149,12 +146,8 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
|
||||
})
|
||||
}
|
||||
|
||||
workspaceAgent, err := a.AgentFn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
app, err := a.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
|
||||
AgentID: workspaceAgent.ID,
|
||||
AgentID: a.AgentID,
|
||||
Slug: req.Slug,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -164,11 +157,10 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
|
||||
})
|
||||
}
|
||||
|
||||
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to get workspace.",
|
||||
Detail: err.Error(),
|
||||
ws, ok := a.Workspace.AsWorkspaceIdentity()
|
||||
if !ok {
|
||||
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Workspace identity not cached.",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -190,8 +182,8 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
|
||||
_, err = a.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
|
||||
ID: uuid.New(),
|
||||
CreatedAt: dbtime.Now(),
|
||||
WorkspaceID: workspace.ID,
|
||||
AgentID: workspaceAgent.ID,
|
||||
WorkspaceID: ws.ID,
|
||||
AgentID: a.AgentID,
|
||||
AppID: app.ID,
|
||||
State: dbState,
|
||||
Message: cleaned,
|
||||
@@ -208,7 +200,7 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
|
||||
}
|
||||
|
||||
if a.PublishWorkspaceUpdateFn != nil {
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, a.AgentID, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
|
||||
if err != nil {
|
||||
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to publish workspace update.",
|
||||
@@ -217,14 +209,14 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
|
||||
}
|
||||
}
|
||||
|
||||
// Notify on state change to Working/Idle for AI tasks
|
||||
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState, workspace, workspaceAgent)
|
||||
// Notify on state change to Working/Idle for AI tasks.
|
||||
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState)
|
||||
|
||||
if shouldBump(dbState, latestAppStatus) {
|
||||
// We pass time.Time{} for nextAutostart since we don't have access to
|
||||
// TemplateScheduleStore here. The activity bump logic handles this by
|
||||
// defaulting to the template's activity_bump duration (typically 1 hour).
|
||||
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, workspace.ID, time.Time{})
|
||||
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, ws.ID, time.Time{})
|
||||
}
|
||||
// just return a blank response because it doesn't contain any settable fields at present.
|
||||
return new(agentproto.UpdateAppStatusResponse), nil
|
||||
@@ -261,8 +253,6 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
|
||||
appID uuid.UUID,
|
||||
latestAppStatus database.WorkspaceAppStatus,
|
||||
newAppStatus database.WorkspaceAppStatusState,
|
||||
workspace database.Workspace,
|
||||
agent database.WorkspaceAgent,
|
||||
) {
|
||||
var notificationTemplate uuid.UUID
|
||||
switch newAppStatus {
|
||||
@@ -279,11 +269,20 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
|
||||
return
|
||||
}
|
||||
|
||||
if !workspace.TaskID.Valid {
|
||||
taskID := a.Workspace.TaskID()
|
||||
if !taskID.Valid {
|
||||
// Workspace has no task ID, do nothing.
|
||||
return
|
||||
}
|
||||
|
||||
// Only fetch fresh agent state for task workspaces, since we need
|
||||
// the current lifecycle state to decide whether to send notifications.
|
||||
agent, err := a.AgentFn(ctx)
|
||||
if err != nil {
|
||||
a.Log.Warn(ctx, "failed to get agent for AI task notification", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Only send notifications when the agent is ready. We want to skip
|
||||
// any state transitions that occur whilst the workspace is starting
|
||||
// up as it doesn't make sense to receive them.
|
||||
@@ -296,7 +295,7 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
|
||||
return
|
||||
}
|
||||
|
||||
task, err := a.Database.GetTaskByID(ctx, workspace.TaskID.UUID)
|
||||
task, err := a.Database.GetTaskByID(ctx, taskID.UUID)
|
||||
if err != nil {
|
||||
a.Log.Warn(ctx, "failed to get task", slog.Error(err))
|
||||
return
|
||||
@@ -321,14 +320,20 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
|
||||
return
|
||||
}
|
||||
|
||||
ws, ok := a.Workspace.AsWorkspaceIdentity()
|
||||
if !ok {
|
||||
a.Log.Warn(ctx, "failed to get workspace identity for AI task notification")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := a.NotificationsEnqueuer.EnqueueWithData(
|
||||
// nolint:gocritic // Need notifier actor to enqueue notifications
|
||||
dbauthz.AsNotifier(ctx),
|
||||
workspace.OwnerID,
|
||||
ws.OwnerID,
|
||||
notificationTemplate,
|
||||
map[string]string{
|
||||
"task": task.Name,
|
||||
"workspace": workspace.Name,
|
||||
"workspace": ws.Name,
|
||||
},
|
||||
map[string]any{
|
||||
// Use a 1-minute bucketed timestamp to bypass per-day dedupe,
|
||||
@@ -338,7 +343,7 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
|
||||
},
|
||||
"api-workspace-agent-app-status",
|
||||
// Associate this notification with related entities
|
||||
workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID,
|
||||
ws.ID, ws.OwnerID, ws.OrganizationID, appID,
|
||||
); err != nil {
|
||||
a.Log.Warn(ctx, "failed to notify of task state", slog.Error(err))
|
||||
return
|
||||
|
||||
@@ -67,12 +67,10 @@ func TestBatchUpdateAppHealths(t *testing.T) {
|
||||
|
||||
publishCalled := false
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
publishCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -105,12 +103,10 @@ func TestBatchUpdateAppHealths(t *testing.T) {
|
||||
|
||||
publishCalled := false
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
publishCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -144,12 +140,10 @@ func TestBatchUpdateAppHealths(t *testing.T) {
|
||||
|
||||
publishCalled := false
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
publishCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -180,9 +174,7 @@ func TestBatchUpdateAppHealths(t *testing.T) {
|
||||
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app3}, nil)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
@@ -209,9 +201,7 @@ func TestBatchUpdateAppHealths(t *testing.T) {
|
||||
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
@@ -239,9 +229,7 @@ func TestBatchUpdateAppHealths(t *testing.T) {
|
||||
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
@@ -279,14 +267,26 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
|
||||
}
|
||||
workspaceUpdates := make(chan wspubsub.WorkspaceEventKind, 100)
|
||||
|
||||
workspace := database.Workspace{
|
||||
ID: uuid.UUID{9},
|
||||
TaskID: uuid.NullUUID{
|
||||
Valid: true,
|
||||
UUID: uuid.UUID{7},
|
||||
},
|
||||
}
|
||||
cachedWs := &agentapi.CachedWorkspaceFields{}
|
||||
cachedWs.UpdateValues(workspace)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentID: agent.ID,
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(_ context.Context, agnt *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
assert.Equal(t, *agnt, agent)
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
Workspace: cachedWs,
|
||||
PublishWorkspaceUpdateFn: func(_ context.Context, agnt uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
assert.Equal(t, agnt, agent.ID)
|
||||
testutil.AssertSend(ctx, t, workspaceUpdates, kind)
|
||||
return nil
|
||||
},
|
||||
@@ -309,14 +309,6 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
|
||||
},
|
||||
}
|
||||
mDB.EXPECT().GetTaskByID(gomock.Any(), task.ID).Times(1).Return(task, nil)
|
||||
workspace := database.Workspace{
|
||||
ID: uuid.UUID{9},
|
||||
TaskID: uuid.NullUUID{
|
||||
Valid: true,
|
||||
UUID: task.ID,
|
||||
},
|
||||
}
|
||||
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Times(1).Return(workspace, nil)
|
||||
appStatus := database.WorkspaceAppStatus{
|
||||
ID: uuid.UUID{6},
|
||||
}
|
||||
@@ -363,9 +355,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
|
||||
Return(database.WorkspaceApp{}, sql.ErrNoRows)
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
}
|
||||
@@ -392,9 +382,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
|
||||
}
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
}
|
||||
@@ -422,9 +410,7 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
|
||||
}
|
||||
|
||||
api := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
Database: mDB,
|
||||
Log: testutil.Logger(t),
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -23,12 +24,14 @@ type CachedWorkspaceFields struct {
|
||||
lock sync.RWMutex
|
||||
|
||||
identity database.WorkspaceIdentity
|
||||
taskID uuid.NullUUID
|
||||
}
|
||||
|
||||
func (cws *CachedWorkspaceFields) Clear() {
|
||||
cws.lock.Lock()
|
||||
defer cws.lock.Unlock()
|
||||
cws.identity = database.WorkspaceIdentity{}
|
||||
cws.taskID = uuid.NullUUID{}
|
||||
}
|
||||
|
||||
func (cws *CachedWorkspaceFields) UpdateValues(ws database.Workspace) {
|
||||
@@ -42,6 +45,13 @@ func (cws *CachedWorkspaceFields) UpdateValues(ws database.Workspace) {
|
||||
cws.identity.OwnerUsername = ws.OwnerUsername
|
||||
cws.identity.TemplateName = ws.TemplateName
|
||||
cws.identity.AutostartSchedule = ws.AutostartSchedule
|
||||
cws.taskID = ws.TaskID
|
||||
}
|
||||
|
||||
func (cws *CachedWorkspaceFields) TaskID() uuid.NullUUID {
|
||||
cws.lock.RLock()
|
||||
defer cws.lock.RUnlock()
|
||||
return cws.taskID
|
||||
}
|
||||
|
||||
// Returns the Workspace, true, unless the workspace has not been cached (nuked or was a prebuild).
|
||||
|
||||
@@ -14,11 +14,11 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/connectionlog"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
)
|
||||
|
||||
type ConnLogAPI struct {
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
AgentID uuid.UUID
|
||||
AgentName string
|
||||
ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger]
|
||||
Workspace *CachedWorkspaceFields
|
||||
Database database.Store
|
||||
@@ -53,27 +53,12 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
|
||||
}
|
||||
}
|
||||
|
||||
// Inject RBAC object into context for dbauthz fast path, avoid having to
|
||||
// call GetWorkspaceByAgentID on every metadata update.
|
||||
rbacCtx := ctx
|
||||
var ws database.WorkspaceIdentity
|
||||
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
|
||||
ws = dbws
|
||||
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
|
||||
if err != nil {
|
||||
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
|
||||
//nolint:gocritic
|
||||
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch contextual data for this connection log event.
|
||||
workspaceAgent, err := a.AgentFn(rbacCtx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get agent: %w", err)
|
||||
}
|
||||
if ws.Equal(database.WorkspaceIdentity{}) {
|
||||
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
||||
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, a.AgentID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace by agent id: %w", err)
|
||||
}
|
||||
@@ -97,7 +82,7 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
|
||||
WorkspaceOwnerID: ws.OwnerID,
|
||||
WorkspaceID: ws.ID,
|
||||
WorkspaceName: ws.Name,
|
||||
AgentName: workspaceAgent.Name,
|
||||
AgentName: a.AgentName,
|
||||
Type: connectionType,
|
||||
Code: code,
|
||||
Ip: logIP,
|
||||
|
||||
@@ -114,10 +114,9 @@ func TestConnectionLog(t *testing.T) {
|
||||
api := &agentapi.ConnLogAPI{
|
||||
ConnectionLogger: asAtomicPointer[connectionlog.ConnectionLogger](connLogger),
|
||||
Database: mDB,
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
AgentID: agent.ID,
|
||||
AgentName: agent.Name,
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
}
|
||||
api.ReportConnection(context.Background(), &agentproto.ReportConnectionRequest{
|
||||
Connection: &agentproto.Connection{
|
||||
|
||||
@@ -30,7 +30,7 @@ type LifecycleAPI struct {
|
||||
WorkspaceID uuid.UUID
|
||||
Database database.Store
|
||||
Log slog.Logger
|
||||
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
|
||||
PublishWorkspaceUpdateFn func(context.Context, uuid.UUID, wspubsub.WorkspaceEventKind) error
|
||||
|
||||
TimeNowFn func() time.Time // defaults to dbtime.Now()
|
||||
Metrics *LifecycleMetrics
|
||||
@@ -122,7 +122,7 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
|
||||
}
|
||||
|
||||
if a.PublishWorkspaceUpdateFn != nil {
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentLifecycleUpdate)
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, workspaceAgent.ID, wspubsub.WorkspaceEventKindAgentLifecycleUpdate)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish workspace update: %w", err)
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ func TestUpdateLifecycle(t *testing.T) {
|
||||
WorkspaceID: workspaceID,
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
publishCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -206,7 +206,7 @@ func TestUpdateLifecycle(t *testing.T) {
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
Metrics: metrics,
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
publishCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -323,7 +323,7 @@ func TestUpdateLifecycle(t *testing.T) {
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
Metrics: metrics,
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
atomic.AddInt64(&publishCalled, 1)
|
||||
return nil
|
||||
},
|
||||
@@ -410,7 +410,7 @@ func TestUpdateLifecycle(t *testing.T) {
|
||||
WorkspaceID: workspaceID,
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
publishCalled = true
|
||||
return nil
|
||||
},
|
||||
|
||||
@@ -19,7 +19,7 @@ type LogsAPI struct {
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
Database database.Store
|
||||
Log slog.Logger
|
||||
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
|
||||
PublishWorkspaceUpdateFn func(context.Context, uuid.UUID, wspubsub.WorkspaceEventKind) error
|
||||
PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage)
|
||||
|
||||
TimeNowFn func() time.Time // defaults to dbtime.Now()
|
||||
@@ -125,7 +125,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
||||
}
|
||||
|
||||
if a.PublishWorkspaceUpdateFn != nil {
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentLogsOverflow)
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, workspaceAgent.ID, wspubsub.WorkspaceEventKindAgentLogsOverflow)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish workspace update: %w", err)
|
||||
}
|
||||
@@ -145,7 +145,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
||||
if workspaceAgent.LogsLength == 0 && a.PublishWorkspaceUpdateFn != nil {
|
||||
// If these are the first logs being appended, we publish a UI update
|
||||
// to notify the UI that logs are now available.
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentFirstLogs)
|
||||
err = a.PublishWorkspaceUpdateFn(ctx, workspaceAgent.ID, wspubsub.WorkspaceEventKindAgentFirstLogs)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish workspace update: %w", err)
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func TestBatchCreateLogs(t *testing.T) {
|
||||
},
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
publishWorkspaceUpdateCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -155,7 +155,7 @@ func TestBatchCreateLogs(t *testing.T) {
|
||||
},
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
publishWorkspaceUpdateCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -203,7 +203,7 @@ func TestBatchCreateLogs(t *testing.T) {
|
||||
},
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
publishWorkspaceUpdateCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -296,7 +296,7 @@ func TestBatchCreateLogs(t *testing.T) {
|
||||
},
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
publishWorkspaceUpdateCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -340,7 +340,7 @@ func TestBatchCreateLogs(t *testing.T) {
|
||||
},
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
publishWorkspaceUpdateCalled = true
|
||||
return nil
|
||||
},
|
||||
@@ -387,7 +387,7 @@ func TestBatchCreateLogs(t *testing.T) {
|
||||
},
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
publishWorkspaceUpdateCalled = true
|
||||
return nil
|
||||
},
|
||||
|
||||
@@ -32,16 +32,12 @@ type ManifestAPI struct {
|
||||
DerpForceWebSockets bool
|
||||
WorkspaceID uuid.UUID
|
||||
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
AgentFn func(ctx context.Context) (database.WorkspaceAgent, error)
|
||||
Database database.Store
|
||||
DerpMapFn func() *tailcfg.DERPMap
|
||||
}
|
||||
|
||||
func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
|
||||
workspaceAgent, err := a.AgentFn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var (
|
||||
dbApps []database.WorkspaceApp
|
||||
scripts []database.WorkspaceAgentScript
|
||||
@@ -50,6 +46,11 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
|
||||
devcontainers []database.WorkspaceAgentDevcontainer
|
||||
)
|
||||
|
||||
workspaceAgent, err := a.AgentFn(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("getting workspace agent: %w", err)
|
||||
}
|
||||
|
||||
var eg errgroup.Group
|
||||
eg.Go(func() (err error) {
|
||||
dbApps, err = a.Database.GetWorkspaceAppsByAgentID(ctx, workspaceAgent.ID)
|
||||
|
||||
@@ -322,9 +322,7 @@ func TestGetManifest(t *testing.T) {
|
||||
DisableDirectConnections: true,
|
||||
DerpForceWebSockets: true,
|
||||
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil },
|
||||
WorkspaceID: workspace.ID,
|
||||
Database: mDB,
|
||||
DerpMapFn: derpMapFn,
|
||||
@@ -389,9 +387,7 @@ func TestGetManifest(t *testing.T) {
|
||||
DisableDirectConnections: true,
|
||||
DerpForceWebSockets: true,
|
||||
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return childAgent, nil
|
||||
},
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return childAgent, nil },
|
||||
WorkspaceID: workspace.ID,
|
||||
Database: mDB,
|
||||
DerpMapFn: derpMapFn,
|
||||
@@ -512,9 +508,7 @@ func TestGetManifest(t *testing.T) {
|
||||
DisableDirectConnections: true,
|
||||
DerpForceWebSockets: true,
|
||||
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil },
|
||||
WorkspaceID: workspace.ID,
|
||||
Database: mDB,
|
||||
DerpMapFn: derpMapFn,
|
||||
|
||||
@@ -5,18 +5,18 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
)
|
||||
|
||||
type MetadataAPI struct {
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
AgentID uuid.UUID
|
||||
Workspace *CachedWorkspaceFields
|
||||
Database database.Store
|
||||
Log slog.Logger
|
||||
@@ -45,29 +45,11 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
|
||||
maxErrorLen = maxValueLen
|
||||
)
|
||||
|
||||
// Inject RBAC object into context for dbauthz fast path, avoid having to
|
||||
// call GetWorkspaceByAgentID on every metadata update.
|
||||
var err error
|
||||
rbacCtx := ctx
|
||||
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
|
||||
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
|
||||
if err != nil {
|
||||
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
|
||||
//nolint:gocritic
|
||||
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
|
||||
}
|
||||
}
|
||||
|
||||
workspaceAgent, err := a.AgentFn(rbacCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
collectedAt = a.now()
|
||||
allKeysLen = 0
|
||||
dbUpdate = database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: workspaceAgent.ID,
|
||||
WorkspaceAgentID: a.AgentID,
|
||||
// These need to be `make(x, 0, len(req.Metadata))` instead of
|
||||
// `make(x, len(req.Metadata))` because we may not insert all
|
||||
// metadata if the keys are large.
|
||||
@@ -121,7 +103,7 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
|
||||
}
|
||||
|
||||
// Use batcher to batch metadata updates.
|
||||
err = a.Batcher.Add(workspaceAgent.ID, dbUpdate.Key, dbUpdate.Value, dbUpdate.Error, dbUpdate.CollectedAt)
|
||||
err := a.Batcher.Add(a.AgentID, dbUpdate.Key, dbUpdate.Value, dbUpdate.Error, dbUpdate.CollectedAt)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("add metadata to batcher: %w", err)
|
||||
}
|
||||
|
||||
@@ -80,9 +80,7 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
t.Cleanup(batcher.Close)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Log: testutil.Logger(t),
|
||||
Batcher: batcher,
|
||||
@@ -159,9 +157,7 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
t.Cleanup(batcher.Close)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Log: testutil.Logger(t),
|
||||
Batcher: batcher,
|
||||
@@ -241,9 +237,7 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
t.Cleanup(batcher.Close)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Log: testutil.Logger(t),
|
||||
Batcher: batcher,
|
||||
|
||||
@@ -4,20 +4,21 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
type StatsAPI struct {
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
AgentID uuid.UUID
|
||||
AgentName string
|
||||
Workspace *CachedWorkspaceFields
|
||||
Database database.Store
|
||||
Log slog.Logger
|
||||
@@ -44,32 +45,13 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Inject RBAC object into context for dbauthz fast path, avoid having to
|
||||
// call GetWorkspaceAgentByID on every stats update.
|
||||
|
||||
rbacCtx := ctx
|
||||
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
|
||||
var err error
|
||||
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
|
||||
if err != nil {
|
||||
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
|
||||
//nolint:gocritic
|
||||
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
|
||||
}
|
||||
}
|
||||
|
||||
workspaceAgent, err := a.AgentFn(rbacCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If cache is empty (prebuild or invalid), fall back to DB
|
||||
var ws database.WorkspaceIdentity
|
||||
var ok bool
|
||||
if ws, ok = a.Workspace.AsWorkspaceIdentity(); !ok {
|
||||
w, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
||||
w, err := a.Database.GetWorkspaceByAgentID(ctx, a.AgentID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace by agent ID %q: %w", workspaceAgent.ID, err)
|
||||
return nil, xerrors.Errorf("get workspace by agent ID %q: %w", a.AgentID, err)
|
||||
}
|
||||
ws = database.WorkspaceIdentityFromWorkspace(w)
|
||||
}
|
||||
@@ -90,11 +72,12 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
|
||||
req.Stats.SessionCountReconnectingPty = 0
|
||||
}
|
||||
|
||||
err = a.StatsReporter.ReportAgentStats(
|
||||
err := a.StatsReporter.ReportAgentStats(
|
||||
ctx,
|
||||
a.now(),
|
||||
ws,
|
||||
workspaceAgent,
|
||||
a.AgentID,
|
||||
a.AgentName,
|
||||
req.Stats,
|
||||
false,
|
||||
)
|
||||
|
||||
@@ -119,9 +119,8 @@ func TestUpdateStats(t *testing.T) {
|
||||
}
|
||||
)
|
||||
api := agentapi.StatsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
AgentName: agent.Name,
|
||||
Workspace: &workspaceAsCacheFields,
|
||||
Database: dbM,
|
||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||
@@ -229,9 +228,8 @@ func TestUpdateStats(t *testing.T) {
|
||||
}
|
||||
)
|
||||
api := agentapi.StatsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
AgentName: agent.Name,
|
||||
Workspace: &workspaceAsCacheFields,
|
||||
Database: dbM,
|
||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||
@@ -264,9 +262,8 @@ func TestUpdateStats(t *testing.T) {
|
||||
}
|
||||
)
|
||||
api := agentapi.StatsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
AgentName: agent.Name,
|
||||
Workspace: &workspaceAsCacheFields,
|
||||
Database: dbM,
|
||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||
@@ -347,9 +344,8 @@ func TestUpdateStats(t *testing.T) {
|
||||
// ws.AutostartSchedule = workspace.AutostartSchedule
|
||||
|
||||
api := agentapi.StatsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
AgentName: agent.Name,
|
||||
Workspace: &ws,
|
||||
Database: dbM,
|
||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||
@@ -459,9 +455,8 @@ func TestUpdateStats(t *testing.T) {
|
||||
)
|
||||
defer wut.Close()
|
||||
api := agentapi.StatsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
AgentName: agent.Name,
|
||||
Workspace: &workspaceAsCacheFields,
|
||||
Database: dbM,
|
||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||
@@ -596,9 +591,8 @@ func TestUpdateStats(t *testing.T) {
|
||||
}
|
||||
)
|
||||
api := agentapi.StatsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
AgentID: agent.ID,
|
||||
AgentName: agent.Name,
|
||||
Workspace: &workspaceAsCacheFields,
|
||||
Database: dbM,
|
||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||
|
||||
@@ -25,7 +25,6 @@ import (
|
||||
type SubAgentAPI struct {
|
||||
OwnerID uuid.UUID
|
||||
OrganizationID uuid.UUID
|
||||
AgentID uuid.UUID
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
|
||||
Log slog.Logger
|
||||
@@ -295,7 +294,12 @@ func (a *SubAgentAPI) ListSubAgents(ctx context.Context, _ *agentproto.ListSubAg
|
||||
//nolint:gocritic // This gives us only the permissions required to do the job.
|
||||
ctx = dbauthz.AsSubAgentAPI(ctx, a.OrganizationID, a.OwnerID)
|
||||
|
||||
workspaceAgents, err := a.Database.GetWorkspaceAgentsByParentID(ctx, a.AgentID)
|
||||
parentAgent, err := a.AgentFn(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get parent agent: %w", err)
|
||||
}
|
||||
|
||||
workspaceAgents, err := a.Database.GetWorkspaceAgentsByParentID(ctx, parentAgent.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -81,12 +81,9 @@ func TestSubAgentAPI(t *testing.T) {
|
||||
return &agentapi.SubAgentAPI{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
AgentID: agent.ID,
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Clock: clock,
|
||||
Database: dbauthz.New(db, auth, logger, accessControlStore),
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil },
|
||||
Clock: clock,
|
||||
Database: dbauthz.New(db, auth, logger, accessControlStore),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -314,17 +314,22 @@ func (api *API) patchWorkspaceAgentAppStatus(rw http.ResponseWriter, r *http.Req
|
||||
// This functionality has been moved to the AppsAPI in the agentapi. We keep this HTTP handler around for back
|
||||
// compatibility with older agents. We'll translate the request into the protobuf so there is only one primary
|
||||
// implementation.
|
||||
cachedWs := &agentapi.CachedWorkspaceFields{}
|
||||
cachedWs.UpdateValues(workspace)
|
||||
|
||||
appAPI := &agentapi.AppsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return workspaceAgent, nil
|
||||
AgentID: workspaceAgent.ID,
|
||||
Database: api.Database,
|
||||
Log: api.Logger,
|
||||
Workspace: cachedWs,
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return api.Database.GetWorkspaceAgentByID(ctx, workspaceAgent.ID)
|
||||
},
|
||||
Database: api.Database,
|
||||
Log: api.Logger,
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
PublishWorkspaceUpdateFn: func(ctx context.Context, agentID uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
|
||||
api.publishWorkspaceUpdate(ctx, workspace.OwnerID, wspubsub.WorkspaceEvent{
|
||||
Kind: kind,
|
||||
WorkspaceID: workspace.ID,
|
||||
AgentID: &agent.ID,
|
||||
AgentID: &agentID,
|
||||
})
|
||||
return nil
|
||||
},
|
||||
|
||||
@@ -178,7 +178,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Optional:
|
||||
UpdateAgentMetricsFn: api.UpdateAgentMetrics,
|
||||
}, workspace)
|
||||
}, workspace, workspaceAgent)
|
||||
|
||||
streamID := tailnet.StreamID{
|
||||
Name: fmt.Sprintf("%s-%s-%s", workspace.OwnerUsername, workspace.Name, workspaceAgent.Name),
|
||||
|
||||
@@ -1753,7 +1753,7 @@ func (api *API) postWorkspaceUsage(rw http.ResponseWriter, r *http.Request) {
|
||||
// return
|
||||
// }
|
||||
|
||||
err = api.statsReporter.ReportAgentStats(ctx, dbtime.Now(), database.WorkspaceIdentityFromWorkspace(workspace), agent, stat, true)
|
||||
err = api.statsReporter.ReportAgentStats(ctx, dbtime.Now(), database.WorkspaceIdentityFromWorkspace(workspace), agent.ID, agent.Name, stat, true)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
|
||||
@@ -137,10 +137,10 @@ func (r *Reporter) ReportAppStats(ctx context.Context, stats []workspaceapps.Sta
|
||||
}
|
||||
|
||||
// nolint:revive // usage is a control flag while we have the experiment
|
||||
func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspace database.WorkspaceIdentity, workspaceAgent database.WorkspaceAgent, stats *agentproto.Stats, usage bool) error {
|
||||
func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspace database.WorkspaceIdentity, agentID uuid.UUID, agentName string, stats *agentproto.Stats, usage bool) error {
|
||||
// update agent stats
|
||||
if !r.opts.DisableDatabaseInserts {
|
||||
r.opts.StatsBatcher.Add(now, workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, stats, usage)
|
||||
r.opts.StatsBatcher.Add(now, agentID, workspace.TemplateID, workspace.OwnerID, workspace.ID, stats, usage)
|
||||
}
|
||||
|
||||
// update prometheus metrics (even if template insights are disabled)
|
||||
@@ -148,7 +148,7 @@ func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspac
|
||||
r.opts.UpdateAgentMetricsFn(ctx, prometheusmetrics.AgentMetricLabels{
|
||||
Username: workspace.OwnerUsername,
|
||||
WorkspaceName: workspace.Name,
|
||||
AgentName: workspaceAgent.Name,
|
||||
AgentName: agentName,
|
||||
TemplateName: workspace.TemplateName,
|
||||
}, stats.Metrics)
|
||||
}
|
||||
|
||||
@@ -182,6 +182,7 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) (err error)
|
||||
wchRef, rchRef := wch, rch
|
||||
for {
|
||||
if wchRef == nil && rchRef == nil {
|
||||
logger.Info(ctx, "reading and writing to agent complete! Closing connection")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user