mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
perf: reduce DB calls to GetWorkspaceByAgentID via caching workspace info (#20662)
--------- Signed-off-by: Callum Styan <callumstyan@gmail.com>
This commit is contained in:
+72
-5
@@ -36,6 +36,8 @@ import (
|
|||||||
"github.com/coder/quartz"
|
"github.com/coder/quartz"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const workspaceCacheRefreshInterval = 5 * time.Minute
|
||||||
|
|
||||||
// API implements the DRPC agent API interface from agent/proto. This struct is
|
// API implements the DRPC agent API interface from agent/proto. This struct is
|
||||||
// instantiated once per agent connection and kept alive for the duration of the
|
// instantiated once per agent connection and kept alive for the duration of the
|
||||||
// session.
|
// session.
|
||||||
@@ -54,6 +56,8 @@ type API struct {
|
|||||||
*SubAgentAPI
|
*SubAgentAPI
|
||||||
*tailnet.DRPCService
|
*tailnet.DRPCService
|
||||||
|
|
||||||
|
cachedWorkspaceFields *CachedWorkspaceFields
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,7 +96,7 @@ type Options struct {
|
|||||||
UpdateAgentMetricsFn func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
|
UpdateAgentMetricsFn func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(opts Options) *API {
|
func New(opts Options, workspace database.Workspace) *API {
|
||||||
if opts.Clock == nil {
|
if opts.Clock == nil {
|
||||||
opts.Clock = quartz.NewReal()
|
opts.Clock = quartz.NewReal()
|
||||||
}
|
}
|
||||||
@@ -114,6 +118,13 @@ func New(opts Options) *API {
|
|||||||
WorkspaceID: opts.WorkspaceID,
|
WorkspaceID: opts.WorkspaceID,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Don't cache details for prebuilds, though the cached fields will eventually be updated
|
||||||
|
// by the refresh routine once the prebuild workspace is claimed.
|
||||||
|
api.cachedWorkspaceFields = &CachedWorkspaceFields{}
|
||||||
|
if !workspace.IsPrebuild() {
|
||||||
|
api.cachedWorkspaceFields.UpdateValues(workspace)
|
||||||
|
}
|
||||||
|
|
||||||
api.AnnouncementBannerAPI = &AnnouncementBannerAPI{
|
api.AnnouncementBannerAPI = &AnnouncementBannerAPI{
|
||||||
appearanceFetcher: opts.AppearanceFetcher,
|
appearanceFetcher: opts.AppearanceFetcher,
|
||||||
}
|
}
|
||||||
@@ -139,6 +150,7 @@ func New(opts Options) *API {
|
|||||||
|
|
||||||
api.StatsAPI = &StatsAPI{
|
api.StatsAPI = &StatsAPI{
|
||||||
AgentFn: api.agent,
|
AgentFn: api.agent,
|
||||||
|
Workspace: api.cachedWorkspaceFields,
|
||||||
Database: opts.Database,
|
Database: opts.Database,
|
||||||
Log: opts.Log,
|
Log: opts.Log,
|
||||||
StatsReporter: opts.StatsReporter,
|
StatsReporter: opts.StatsReporter,
|
||||||
@@ -162,10 +174,11 @@ func New(opts Options) *API {
|
|||||||
}
|
}
|
||||||
|
|
||||||
api.MetadataAPI = &MetadataAPI{
|
api.MetadataAPI = &MetadataAPI{
|
||||||
AgentFn: api.agent,
|
AgentFn: api.agent,
|
||||||
Database: opts.Database,
|
Workspace: api.cachedWorkspaceFields,
|
||||||
Pubsub: opts.Pubsub,
|
Database: opts.Database,
|
||||||
Log: opts.Log,
|
Pubsub: opts.Pubsub,
|
||||||
|
Log: opts.Log,
|
||||||
}
|
}
|
||||||
|
|
||||||
api.LogsAPI = &LogsAPI{
|
api.LogsAPI = &LogsAPI{
|
||||||
@@ -205,6 +218,10 @@ func New(opts Options) *API {
|
|||||||
Database: opts.Database,
|
Database: opts.Database,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start background cache refresh loop to handle workspace changes
|
||||||
|
// like prebuild claims where owner_id and other fields may be modified in the DB.
|
||||||
|
go api.startCacheRefreshLoop(opts.Ctx)
|
||||||
|
|
||||||
return api
|
return api
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -254,6 +271,56 @@ func (a *API) agent(ctx context.Context) (database.WorkspaceAgent, error) {
|
|||||||
return agent, nil
|
return agent, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// refreshCachedWorkspace periodically updates the cached workspace fields.
|
||||||
|
// This ensures that changes like prebuild claims (which modify owner_id, name, etc.)
|
||||||
|
// are eventually reflected in the cache without requiring agent reconnection.
|
||||||
|
func (a *API) refreshCachedWorkspace(ctx context.Context) {
|
||||||
|
ws, err := a.opts.Database.GetWorkspaceByID(ctx, a.opts.WorkspaceID)
|
||||||
|
if err != nil {
|
||||||
|
a.opts.Log.Warn(ctx, "failed to refresh cached workspace fields", slog.Error(err))
|
||||||
|
a.cachedWorkspaceFields.Clear()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ws.IsPrebuild() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we still have the same values, skip the update and logging calls.
|
||||||
|
if a.cachedWorkspaceFields.identity.Equal(database.WorkspaceIdentityFromWorkspace(ws)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Update fields that can change during workspace lifecycle (e.g., AutostartSchedule)
|
||||||
|
a.cachedWorkspaceFields.UpdateValues(ws)
|
||||||
|
|
||||||
|
a.opts.Log.Debug(ctx, "refreshed cached workspace fields",
|
||||||
|
slog.F("workspace_id", ws.ID),
|
||||||
|
slog.F("owner_id", ws.OwnerID),
|
||||||
|
slog.F("name", ws.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
// startCacheRefreshLoop runs a background goroutine that periodically refreshes
|
||||||
|
// the cached workspace fields. This is primarily needed to handle prebuild claims
|
||||||
|
// where the owner_id and other fields change while the agent connection persists.
|
||||||
|
func (a *API) startCacheRefreshLoop(ctx context.Context) {
|
||||||
|
// Refresh every 5 minutes. This provides a reasonable balance between:
|
||||||
|
// - Keeping cache fresh for prebuild claims and other workspace updates
|
||||||
|
// - Minimizing unnecessary database queries
|
||||||
|
ticker := a.opts.Clock.TickerFunc(ctx, workspaceCacheRefreshInterval, func() error {
|
||||||
|
a.refreshCachedWorkspace(ctx)
|
||||||
|
return nil
|
||||||
|
}, "cache_refresh")
|
||||||
|
|
||||||
|
// We need to wait on the ticker exiting.
|
||||||
|
_ = ticker.Wait()
|
||||||
|
|
||||||
|
a.opts.Log.Debug(ctx, "cache refresh loop exited, invalidating the workspace cache on agent API",
|
||||||
|
slog.F("workspace_id", a.cachedWorkspaceFields.identity.ID),
|
||||||
|
slog.F("owner_id", a.cachedWorkspaceFields.identity.OwnerUsername),
|
||||||
|
slog.F("name", a.cachedWorkspaceFields.identity.Name))
|
||||||
|
a.cachedWorkspaceFields.Clear()
|
||||||
|
}
|
||||||
|
|
||||||
func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||||
a.opts.PublishWorkspaceUpdateFn(ctx, a.opts.OwnerID, wspubsub.WorkspaceEvent{
|
a.opts.PublishWorkspaceUpdateFn(ctx, a.opts.OwnerID, wspubsub.WorkspaceEvent{
|
||||||
Kind: kind,
|
Kind: kind,
|
||||||
|
|||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package agentapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CachedWorkspaceFields contains workspace data that is safe to cache for the
|
||||||
|
// duration of an agent connection. These fields are used to reduce database calls
|
||||||
|
// in high-frequency operations like stats reporting and metadata updates.
|
||||||
|
// Prebuild workspaces should not be cached using this struct within the API struct,
|
||||||
|
// however some of these fields for a workspace can be updated live so there is a
|
||||||
|
// routine in the API for refreshing the workspace on a timed interval.
|
||||||
|
//
|
||||||
|
// IMPORTANT: ACL fields (GroupACL, UserACL) are NOT cached because they can be
|
||||||
|
// modified in the database and we must use fresh data for authorization checks.
|
||||||
|
type CachedWorkspaceFields struct {
|
||||||
|
lock sync.RWMutex
|
||||||
|
|
||||||
|
identity database.WorkspaceIdentity
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cws *CachedWorkspaceFields) Clear() {
|
||||||
|
cws.lock.Lock()
|
||||||
|
defer cws.lock.Unlock()
|
||||||
|
cws.identity = database.WorkspaceIdentity{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cws *CachedWorkspaceFields) UpdateValues(ws database.Workspace) {
|
||||||
|
cws.lock.Lock()
|
||||||
|
defer cws.lock.Unlock()
|
||||||
|
cws.identity.ID = ws.ID
|
||||||
|
cws.identity.OwnerID = ws.OwnerID
|
||||||
|
cws.identity.OrganizationID = ws.OrganizationID
|
||||||
|
cws.identity.TemplateID = ws.TemplateID
|
||||||
|
cws.identity.Name = ws.Name
|
||||||
|
cws.identity.OwnerUsername = ws.OwnerUsername
|
||||||
|
cws.identity.TemplateName = ws.TemplateName
|
||||||
|
cws.identity.AutostartSchedule = ws.AutostartSchedule
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the Workspace, true, unless the workspace has not been cached (nuked or was a prebuild).
|
||||||
|
func (cws *CachedWorkspaceFields) AsWorkspaceIdentity() (database.WorkspaceIdentity, bool) {
|
||||||
|
cws.lock.RLock()
|
||||||
|
defer cws.lock.RUnlock()
|
||||||
|
// Should we be more explicit about all fields being set to be valid?
|
||||||
|
if cws.identity.Equal(database.WorkspaceIdentity{}) {
|
||||||
|
return database.WorkspaceIdentity{}, false
|
||||||
|
}
|
||||||
|
return cws.identity, true
|
||||||
|
}
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
package agentapi_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/coderd/agentapi"
|
||||||
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCacheClear(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
user = database.User{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Username: "bill",
|
||||||
|
}
|
||||||
|
template = database.Template{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Name: "tpl",
|
||||||
|
}
|
||||||
|
workspace = database.Workspace{
|
||||||
|
ID: uuid.New(),
|
||||||
|
OwnerID: user.ID,
|
||||||
|
OwnerUsername: user.Username,
|
||||||
|
TemplateID: template.ID,
|
||||||
|
Name: "xyz",
|
||||||
|
TemplateName: template.Name,
|
||||||
|
}
|
||||||
|
workspaceAsCacheFields = agentapi.CachedWorkspaceFields{}
|
||||||
|
)
|
||||||
|
|
||||||
|
workspaceAsCacheFields.UpdateValues(database.Workspace{
|
||||||
|
ID: workspace.ID,
|
||||||
|
OwnerID: workspace.OwnerID,
|
||||||
|
OwnerUsername: workspace.OwnerUsername,
|
||||||
|
TemplateID: workspace.TemplateID,
|
||||||
|
Name: workspace.Name,
|
||||||
|
TemplateName: workspace.TemplateName,
|
||||||
|
AutostartSchedule: workspace.AutostartSchedule,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
emptyCws := agentapi.CachedWorkspaceFields{}
|
||||||
|
workspaceAsCacheFields.Clear()
|
||||||
|
wsi, ok := workspaceAsCacheFields.AsWorkspaceIdentity()
|
||||||
|
require.False(t, ok)
|
||||||
|
ecwsi, ok := emptyCws.AsWorkspaceIdentity()
|
||||||
|
require.False(t, ok)
|
||||||
|
require.True(t, ecwsi.Equal(wsi))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheUpdate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
user = database.User{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Username: "bill",
|
||||||
|
}
|
||||||
|
template = database.Template{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Name: "tpl",
|
||||||
|
}
|
||||||
|
workspace = database.Workspace{
|
||||||
|
ID: uuid.New(),
|
||||||
|
OwnerID: user.ID,
|
||||||
|
OwnerUsername: user.Username,
|
||||||
|
TemplateID: template.ID,
|
||||||
|
Name: "xyz",
|
||||||
|
TemplateName: template.Name,
|
||||||
|
}
|
||||||
|
workspaceAsCacheFields = agentapi.CachedWorkspaceFields{}
|
||||||
|
)
|
||||||
|
|
||||||
|
workspaceAsCacheFields.UpdateValues(database.Workspace{
|
||||||
|
ID: workspace.ID,
|
||||||
|
OwnerID: workspace.OwnerID,
|
||||||
|
OwnerUsername: workspace.OwnerUsername,
|
||||||
|
TemplateID: workspace.TemplateID,
|
||||||
|
Name: workspace.Name,
|
||||||
|
TemplateName: workspace.TemplateName,
|
||||||
|
AutostartSchedule: workspace.AutostartSchedule,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
cws := agentapi.CachedWorkspaceFields{}
|
||||||
|
cws.UpdateValues(workspace)
|
||||||
|
wsi, ok := workspaceAsCacheFields.AsWorkspaceIdentity()
|
||||||
|
require.True(t, ok)
|
||||||
|
cwsi, ok := cws.AsWorkspaceIdentity()
|
||||||
|
require.True(t, ok)
|
||||||
|
require.True(t, wsi.Equal(cwsi))
|
||||||
|
}
|
||||||
@@ -12,15 +12,17 @@ import (
|
|||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||||
"github.com/coder/coder/v2/coderd/database"
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MetadataAPI struct {
|
type MetadataAPI struct {
|
||||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||||
Database database.Store
|
Workspace *CachedWorkspaceFields
|
||||||
Pubsub pubsub.Pubsub
|
Database database.Store
|
||||||
Log slog.Logger
|
Pubsub pubsub.Pubsub
|
||||||
|
Log slog.Logger
|
||||||
|
|
||||||
TimeNowFn func() time.Time // defaults to dbtime.Now()
|
TimeNowFn func() time.Time // defaults to dbtime.Now()
|
||||||
}
|
}
|
||||||
@@ -107,7 +109,19 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.Database.UpdateWorkspaceAgentMetadata(ctx, dbUpdate)
|
// Inject RBAC object into context for dbauthz fast path, avoid having to
|
||||||
|
// call GetWorkspaceByAgentID on every metadata update.
|
||||||
|
rbacCtx := ctx
|
||||||
|
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
|
||||||
|
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
|
||||||
|
if err != nil {
|
||||||
|
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
|
||||||
|
//nolint:gocritic
|
||||||
|
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.Database.UpdateWorkspaceAgentMetadata(rbacCtx, dbUpdate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err)
|
return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,12 +2,14 @@ package agentapi_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/mock/gomock"
|
"go.uber.org/mock/gomock"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
@@ -15,10 +17,14 @@ import (
|
|||||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||||
"github.com/coder/coder/v2/coderd/agentapi"
|
"github.com/coder/coder/v2/coderd/agentapi"
|
||||||
"github.com/coder/coder/v2/coderd/database"
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||||
|
"github.com/coder/coder/v2/coderd/rbac"
|
||||||
|
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
|
"github.com/coder/quartz"
|
||||||
)
|
)
|
||||||
|
|
||||||
type fakePublisher struct {
|
type fakePublisher struct {
|
||||||
@@ -84,9 +90,10 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
|||||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||||
return agent, nil
|
return agent, nil
|
||||||
},
|
},
|
||||||
Database: dbM,
|
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||||
Pubsub: pub,
|
Database: dbM,
|
||||||
Log: testutil.Logger(t),
|
Pubsub: pub,
|
||||||
|
Log: testutil.Logger(t),
|
||||||
TimeNowFn: func() time.Time {
|
TimeNowFn: func() time.Time {
|
||||||
return now
|
return now
|
||||||
},
|
},
|
||||||
@@ -169,9 +176,10 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
|||||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||||
return agent, nil
|
return agent, nil
|
||||||
},
|
},
|
||||||
Database: dbM,
|
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||||
Pubsub: pub,
|
Database: dbM,
|
||||||
Log: testutil.Logger(t),
|
Pubsub: pub,
|
||||||
|
Log: testutil.Logger(t),
|
||||||
TimeNowFn: func() time.Time {
|
TimeNowFn: func() time.Time {
|
||||||
return now
|
return now
|
||||||
},
|
},
|
||||||
@@ -238,9 +246,10 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
|||||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||||
return agent, nil
|
return agent, nil
|
||||||
},
|
},
|
||||||
Database: dbM,
|
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||||
Pubsub: pub,
|
Database: dbM,
|
||||||
Log: testutil.Logger(t),
|
Pubsub: pub,
|
||||||
|
Log: testutil.Logger(t),
|
||||||
TimeNowFn: func() time.Time {
|
TimeNowFn: func() time.Time {
|
||||||
return now
|
return now
|
||||||
},
|
},
|
||||||
@@ -272,4 +281,421 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
|||||||
Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key},
|
Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key},
|
||||||
}, gotEvent)
|
}, gotEvent)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Test RBAC fast path with valid RBAC object - should NOT call GetWorkspaceByAgentID
|
||||||
|
// This test verifies that when a valid RBAC object is present in context, the dbauthz layer
|
||||||
|
// uses the fast path and skips the GetWorkspaceByAgentID database call.
|
||||||
|
t.Run("WorkspaceCached_SkipsDBCall", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctrl = gomock.NewController(t)
|
||||||
|
dbM = dbmock.NewMockStore(ctrl)
|
||||||
|
pub = &fakePublisher{}
|
||||||
|
now = dbtime.Now()
|
||||||
|
// Set up consistent IDs that represent a valid workspace->agent relationship
|
||||||
|
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||||
|
ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321")
|
||||||
|
orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||||
|
agentID = uuid.MustParse("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
|
||||||
|
)
|
||||||
|
|
||||||
|
agent := database.WorkspaceAgent{
|
||||||
|
ID: agentID,
|
||||||
|
// In a real scenario, this agent would belong to a resource in the workspace above
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &agentproto.BatchUpdateMetadataRequest{
|
||||||
|
Metadata: []*agentproto.Metadata{
|
||||||
|
{
|
||||||
|
Key: "test_key",
|
||||||
|
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||||
|
CollectedAt: timestamppb.New(now.Add(-time.Second)),
|
||||||
|
Age: 1,
|
||||||
|
Value: "test_value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect UpdateWorkspaceAgentMetadata to be called
|
||||||
|
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||||
|
WorkspaceAgentID: agent.ID,
|
||||||
|
Key: []string{"test_key"},
|
||||||
|
Value: []string{"test_value"},
|
||||||
|
Error: []string{""},
|
||||||
|
CollectedAt: []time.Time{now},
|
||||||
|
}).Return(nil)
|
||||||
|
|
||||||
|
// DO NOT expect GetWorkspaceByAgentID - the fast path should skip this call
|
||||||
|
// If GetWorkspaceByAgentID is called, the test will fail with "unexpected call"
|
||||||
|
|
||||||
|
// dbauthz will call Wrappers() to check for wrapped databases
|
||||||
|
dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes()
|
||||||
|
|
||||||
|
// Set up dbauthz to test the actual authorization layer
|
||||||
|
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||||
|
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||||
|
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||||||
|
accessControlStore.Store(&acs)
|
||||||
|
|
||||||
|
api := &agentapi.MetadataAPI{
|
||||||
|
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) {
|
||||||
|
return agent, nil
|
||||||
|
},
|
||||||
|
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||||
|
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||||
|
Pubsub: pub,
|
||||||
|
Log: testutil.Logger(t),
|
||||||
|
TimeNowFn: func() time.Time {
|
||||||
|
return now
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
api.Workspace.UpdateValues(database.Workspace{
|
||||||
|
ID: workspaceID,
|
||||||
|
OwnerID: ownerID,
|
||||||
|
OrganizationID: orgID,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create context with system actor so authorization passes
|
||||||
|
ctx := dbauthz.AsSystemRestricted(context.Background())
|
||||||
|
resp, err := api.BatchUpdateMetadata(ctx, req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
})
|
||||||
|
// Test RBAC slow path - invalid RBAC object should fall back to GetWorkspaceByAgentID
|
||||||
|
// This test verifies that when the RBAC object has invalid IDs (nil UUIDs), the dbauthz layer
|
||||||
|
// falls back to the slow path and calls GetWorkspaceByAgentID.
|
||||||
|
t.Run("InvalidWorkspaceCached_RequiresDBCall", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctrl = gomock.NewController(t)
|
||||||
|
dbM = dbmock.NewMockStore(ctrl)
|
||||||
|
pub = &fakePublisher{}
|
||||||
|
now = dbtime.Now()
|
||||||
|
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||||
|
ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321")
|
||||||
|
orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||||
|
agentID = uuid.MustParse("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
|
||||||
|
)
|
||||||
|
|
||||||
|
agent := database.WorkspaceAgent{
|
||||||
|
ID: agentID,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &agentproto.BatchUpdateMetadataRequest{
|
||||||
|
Metadata: []*agentproto.Metadata{
|
||||||
|
{
|
||||||
|
Key: "test_key",
|
||||||
|
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||||
|
CollectedAt: timestamppb.New(now.Add(-time.Second)),
|
||||||
|
Age: 1,
|
||||||
|
Value: "test_value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT GetWorkspaceByAgentID to be called because the RBAC fast path validation fails
|
||||||
|
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agentID).Return(database.Workspace{
|
||||||
|
ID: workspaceID,
|
||||||
|
OwnerID: ownerID,
|
||||||
|
OrganizationID: orgID,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
// Expect UpdateWorkspaceAgentMetadata to be called after authorization
|
||||||
|
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||||
|
WorkspaceAgentID: agent.ID,
|
||||||
|
Key: []string{"test_key"},
|
||||||
|
Value: []string{"test_value"},
|
||||||
|
Error: []string{""},
|
||||||
|
CollectedAt: []time.Time{now},
|
||||||
|
}).Return(nil)
|
||||||
|
|
||||||
|
// dbauthz will call Wrappers() to check for wrapped databases
|
||||||
|
dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes()
|
||||||
|
|
||||||
|
// Set up dbauthz to test the actual authorization layer
|
||||||
|
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||||
|
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||||
|
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||||||
|
accessControlStore.Store(&acs)
|
||||||
|
|
||||||
|
api := &agentapi.MetadataAPI{
|
||||||
|
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) {
|
||||||
|
return agent, nil
|
||||||
|
},
|
||||||
|
|
||||||
|
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||||
|
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||||
|
Pubsub: pub,
|
||||||
|
Log: testutil.Logger(t),
|
||||||
|
TimeNowFn: func() time.Time {
|
||||||
|
return now
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an invalid RBAC object with nil UUIDs for owner/org
|
||||||
|
// This will fail dbauthz fast path validation and trigger GetWorkspaceByAgentID
|
||||||
|
api.Workspace.UpdateValues(database.Workspace{
|
||||||
|
ID: uuid.MustParse("cccccccc-cccc-cccc-cccc-cccccccccccc"),
|
||||||
|
OwnerID: uuid.Nil, // Invalid: fails dbauthz fast path validation
|
||||||
|
OrganizationID: uuid.Nil, // Invalid: fails dbauthz fast path validation
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create context with system actor so authorization passes
|
||||||
|
ctx := dbauthz.AsSystemRestricted(context.Background())
|
||||||
|
resp, err := api.BatchUpdateMetadata(ctx, req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
})
|
||||||
|
// Test RBAC slow path - no RBAC object in context
|
||||||
|
// This test verifies that when no RBAC object is present in context, the dbauthz layer
|
||||||
|
// falls back to the slow path and calls GetWorkspaceByAgentID.
|
||||||
|
t.Run("WorkspaceNotCached_RequiresDBCall", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctrl = gomock.NewController(t)
|
||||||
|
dbM = dbmock.NewMockStore(ctrl)
|
||||||
|
pub = &fakePublisher{}
|
||||||
|
now = dbtime.Now()
|
||||||
|
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||||
|
ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321")
|
||||||
|
orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||||
|
agentID = uuid.MustParse("dddddddd-dddd-dddd-dddd-dddddddddddd")
|
||||||
|
)
|
||||||
|
|
||||||
|
agent := database.WorkspaceAgent{
|
||||||
|
ID: agentID,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &agentproto.BatchUpdateMetadataRequest{
|
||||||
|
Metadata: []*agentproto.Metadata{
|
||||||
|
{
|
||||||
|
Key: "test_key",
|
||||||
|
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||||
|
CollectedAt: timestamppb.New(now.Add(-time.Second)),
|
||||||
|
Age: 1,
|
||||||
|
Value: "test_value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT GetWorkspaceByAgentID to be called because no RBAC object is in context
|
||||||
|
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agentID).Return(database.Workspace{
|
||||||
|
ID: workspaceID,
|
||||||
|
OwnerID: ownerID,
|
||||||
|
OrganizationID: orgID,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
// Expect UpdateWorkspaceAgentMetadata to be called after authorization
|
||||||
|
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||||
|
WorkspaceAgentID: agent.ID,
|
||||||
|
Key: []string{"test_key"},
|
||||||
|
Value: []string{"test_value"},
|
||||||
|
Error: []string{""},
|
||||||
|
CollectedAt: []time.Time{now},
|
||||||
|
}).Return(nil)
|
||||||
|
|
||||||
|
// dbauthz will call Wrappers() to check for wrapped databases
|
||||||
|
dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes()
|
||||||
|
|
||||||
|
// Set up dbauthz to test the actual authorization layer
|
||||||
|
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||||
|
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||||
|
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||||||
|
accessControlStore.Store(&acs)
|
||||||
|
|
||||||
|
api := &agentapi.MetadataAPI{
|
||||||
|
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) {
|
||||||
|
return agent, nil
|
||||||
|
},
|
||||||
|
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||||
|
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||||
|
Pubsub: pub,
|
||||||
|
Log: testutil.Logger(t),
|
||||||
|
TimeNowFn: func() time.Time {
|
||||||
|
return now
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create context with system actor so authorization passes
|
||||||
|
ctx := dbauthz.AsSystemRestricted(context.Background())
|
||||||
|
resp, err := api.BatchUpdateMetadata(ctx, req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test cache refresh - AutostartSchedule updated
|
||||||
|
// This test verifies that the cache refresh mechanism actually calls GetWorkspaceByID
|
||||||
|
// and updates the cached workspace fields when the workspace is modified (e.g., autostart schedule changes).
|
||||||
|
t.Run("CacheRefreshed_AutostartScheduleUpdated", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctrl = gomock.NewController(t)
|
||||||
|
dbM = dbmock.NewMockStore(ctrl)
|
||||||
|
pub = &fakePublisher{}
|
||||||
|
now = dbtime.Now()
|
||||||
|
mClock = quartz.NewMock(t)
|
||||||
|
tickerTrap = mClock.Trap().TickerFunc("cache_refresh")
|
||||||
|
|
||||||
|
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||||
|
ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321")
|
||||||
|
orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||||
|
templateID = uuid.MustParse("aaaabbbb-cccc-dddd-eeee-ffffffff0000")
|
||||||
|
agentID = uuid.MustParse("ffffffff-ffff-ffff-ffff-ffffffffffff")
|
||||||
|
)
|
||||||
|
|
||||||
|
agent := database.WorkspaceAgent{
|
||||||
|
ID: agentID,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initial workspace - has Monday-Friday 9am autostart
|
||||||
|
initialWorkspace := database.Workspace{
|
||||||
|
ID: workspaceID,
|
||||||
|
OwnerID: ownerID,
|
||||||
|
OrganizationID: orgID,
|
||||||
|
TemplateID: templateID,
|
||||||
|
Name: "my-workspace",
|
||||||
|
OwnerUsername: "testuser",
|
||||||
|
TemplateName: "test-template",
|
||||||
|
AutostartSchedule: sql.NullString{Valid: true, String: "CRON_TZ=UTC 0 9 * * 1-5"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Updated workspace - user changed autostart to 5pm and renamed workspace
|
||||||
|
updatedWorkspace := database.Workspace{
|
||||||
|
ID: workspaceID,
|
||||||
|
OwnerID: ownerID,
|
||||||
|
OrganizationID: orgID,
|
||||||
|
TemplateID: templateID,
|
||||||
|
Name: "my-workspace-renamed", // Changed!
|
||||||
|
OwnerUsername: "testuser",
|
||||||
|
TemplateName: "test-template",
|
||||||
|
AutostartSchedule: sql.NullString{Valid: true, String: "CRON_TZ=UTC 0 17 * * 1-5"}, // Changed!
|
||||||
|
DormantAt: sql.NullTime{},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &agentproto.BatchUpdateMetadataRequest{
|
||||||
|
Metadata: []*agentproto.Metadata{
|
||||||
|
{
|
||||||
|
Key: "test_key",
|
||||||
|
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||||
|
CollectedAt: timestamppb.New(now.Add(-time.Second)),
|
||||||
|
Age: 1,
|
||||||
|
Value: "test_value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT GetWorkspaceByID to be called during cache refresh
|
||||||
|
// This is the key assertion - proves the refresh mechanism is working
|
||||||
|
dbM.EXPECT().GetWorkspaceByID(gomock.Any(), workspaceID).Return(updatedWorkspace, nil)
|
||||||
|
|
||||||
|
// API needs to fetch the agent when calling metadata update
|
||||||
|
dbM.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(agent, nil)
|
||||||
|
|
||||||
|
// After refresh, metadata update should work with updated cache
|
||||||
|
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||||
|
func(ctx context.Context, params database.UpdateWorkspaceAgentMetadataParams) error {
|
||||||
|
require.Equal(t, agent.ID, params.WorkspaceAgentID)
|
||||||
|
require.Equal(t, []string{"test_key"}, params.Key)
|
||||||
|
require.Equal(t, []string{"test_value"}, params.Value)
|
||||||
|
require.Equal(t, []string{""}, params.Error)
|
||||||
|
require.Len(t, params.CollectedAt, 1)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
).AnyTimes()
|
||||||
|
|
||||||
|
// May call GetWorkspaceByAgentID if slow path is used before refresh
|
||||||
|
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agentID).Return(updatedWorkspace, nil).AnyTimes()
|
||||||
|
|
||||||
|
// dbauthz will call Wrappers()
|
||||||
|
dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes()
|
||||||
|
|
||||||
|
// Set up dbauthz
|
||||||
|
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||||
|
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||||
|
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||||||
|
accessControlStore.Store(&acs)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Create roles with workspace permissions
|
||||||
|
userRoles := rbac.Roles([]rbac.Role{
|
||||||
|
{
|
||||||
|
Identifier: rbac.RoleMember(),
|
||||||
|
User: []rbac.Permission{
|
||||||
|
{
|
||||||
|
Negate: false,
|
||||||
|
ResourceType: rbac.ResourceWorkspace.Type,
|
||||||
|
Action: policy.WildcardSymbol,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ByOrgID: map[string]rbac.OrgPermissions{
|
||||||
|
orgID.String(): {
|
||||||
|
Member: []rbac.Permission{
|
||||||
|
{
|
||||||
|
Negate: false,
|
||||||
|
ResourceType: rbac.ResourceWorkspace.Type,
|
||||||
|
Action: policy.WildcardSymbol,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
agentScope := rbac.WorkspaceAgentScope(rbac.WorkspaceAgentScopeParams{
|
||||||
|
WorkspaceID: workspaceID,
|
||||||
|
OwnerID: ownerID,
|
||||||
|
TemplateID: templateID,
|
||||||
|
VersionID: uuid.New(),
|
||||||
|
})
|
||||||
|
|
||||||
|
ctxWithActor := dbauthz.As(ctx, rbac.Subject{
|
||||||
|
Type: rbac.SubjectTypeUser,
|
||||||
|
FriendlyName: "testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
ID: ownerID.String(),
|
||||||
|
Roles: userRoles,
|
||||||
|
Groups: []string{orgID.String()},
|
||||||
|
Scope: agentScope,
|
||||||
|
}.WithCachedASTValue())
|
||||||
|
|
||||||
|
// Create full API with cached workspace fields (initial state)
|
||||||
|
api := agentapi.New(agentapi.Options{
|
||||||
|
Ctx: ctxWithActor,
|
||||||
|
AgentID: agentID,
|
||||||
|
WorkspaceID: workspaceID,
|
||||||
|
OwnerID: ownerID,
|
||||||
|
OrganizationID: orgID,
|
||||||
|
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||||
|
Log: testutil.Logger(t),
|
||||||
|
Clock: mClock,
|
||||||
|
Pubsub: pub,
|
||||||
|
}, initialWorkspace) // Cache is initialized with 9am schedule and "my-workspace" name
|
||||||
|
|
||||||
|
// Wait for ticker to be set up and release it so it can fire
|
||||||
|
tickerTrap.MustWait(ctx).MustRelease(ctx)
|
||||||
|
tickerTrap.Close()
|
||||||
|
|
||||||
|
// Advance clock to trigger cache refresh and wait for it to complete
|
||||||
|
_, aw := mClock.AdvanceNext()
|
||||||
|
aw.MustWait(ctx)
|
||||||
|
|
||||||
|
// At this point, GetWorkspaceByID should have been called and cache updated
|
||||||
|
// The cache now has the 5pm schedule and "my-workspace-renamed" name
|
||||||
|
|
||||||
|
// Now call metadata update to verify the refreshed cache works
|
||||||
|
resp, err := api.MetadataAPI.BatchUpdateMetadata(ctxWithActor, req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
type StatsAPI struct {
|
type StatsAPI struct {
|
||||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||||
|
Workspace *CachedWorkspaceFields
|
||||||
Database database.Store
|
Database database.Store
|
||||||
Log slog.Logger
|
Log slog.Logger
|
||||||
StatsReporter *workspacestats.Reporter
|
StatsReporter *workspacestats.Reporter
|
||||||
@@ -46,14 +47,21 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
getWorkspaceAgentByIDRow, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
|
||||||
if err != nil {
|
// If cache is empty (prebuild or invalid), fall back to DB
|
||||||
return nil, xerrors.Errorf("get workspace by agent ID %q: %w", workspaceAgent.ID, err)
|
var ws database.WorkspaceIdentity
|
||||||
|
var ok bool
|
||||||
|
if ws, ok = a.Workspace.AsWorkspaceIdentity(); !ok {
|
||||||
|
w, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("get workspace by agent ID %q: %w", workspaceAgent.ID, err)
|
||||||
|
}
|
||||||
|
ws = database.WorkspaceIdentityFromWorkspace(w)
|
||||||
}
|
}
|
||||||
workspace := getWorkspaceAgentByIDRow
|
|
||||||
a.Log.Debug(ctx, "read stats report",
|
a.Log.Debug(ctx, "read stats report",
|
||||||
slog.F("interval", a.AgentStatsRefreshInterval),
|
slog.F("interval", a.AgentStatsRefreshInterval),
|
||||||
slog.F("workspace_id", workspace.ID),
|
slog.F("workspace_id", ws.ID),
|
||||||
slog.F("payload", req),
|
slog.F("payload", req),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -70,9 +78,8 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
|
|||||||
err = a.StatsReporter.ReportAgentStats(
|
err = a.StatsReporter.ReportAgentStats(
|
||||||
ctx,
|
ctx,
|
||||||
a.now(),
|
a.now(),
|
||||||
workspace,
|
ws,
|
||||||
workspaceAgent,
|
workspaceAgent,
|
||||||
getWorkspaceAgentByIDRow.TemplateName,
|
|
||||||
req.Stats,
|
req.Stats,
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -52,8 +52,19 @@ func TestUpdateStates(t *testing.T) {
|
|||||||
ID: uuid.New(),
|
ID: uuid.New(),
|
||||||
Name: "abc",
|
Name: "abc",
|
||||||
}
|
}
|
||||||
|
workspaceAsCacheFields = agentapi.CachedWorkspaceFields{}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
workspaceAsCacheFields.UpdateValues(database.Workspace{
|
||||||
|
ID: workspace.ID,
|
||||||
|
OwnerID: workspace.OwnerID,
|
||||||
|
OwnerUsername: workspace.OwnerUsername,
|
||||||
|
TemplateID: workspace.TemplateID,
|
||||||
|
Name: workspace.Name,
|
||||||
|
TemplateName: workspace.TemplateName,
|
||||||
|
AutostartSchedule: workspace.AutostartSchedule,
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("OK", func(t *testing.T) {
|
t.Run("OK", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -111,7 +122,8 @@ func TestUpdateStates(t *testing.T) {
|
|||||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||||
return agent, nil
|
return agent, nil
|
||||||
},
|
},
|
||||||
Database: dbM,
|
Workspace: &workspaceAsCacheFields,
|
||||||
|
Database: dbM,
|
||||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||||
Database: dbM,
|
Database: dbM,
|
||||||
Pubsub: ps,
|
Pubsub: ps,
|
||||||
@@ -136,9 +148,6 @@ func TestUpdateStates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wut.Close()
|
defer wut.Close()
|
||||||
|
|
||||||
// Workspace gets fetched.
|
|
||||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(workspace, nil)
|
|
||||||
|
|
||||||
// We expect an activity bump because ConnectionCount > 0.
|
// We expect an activity bump because ConnectionCount > 0.
|
||||||
dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{
|
dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{
|
||||||
WorkspaceID: workspace.ID,
|
WorkspaceID: workspace.ID,
|
||||||
@@ -223,7 +232,8 @@ func TestUpdateStates(t *testing.T) {
|
|||||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||||
return agent, nil
|
return agent, nil
|
||||||
},
|
},
|
||||||
Database: dbM,
|
Workspace: &workspaceAsCacheFields,
|
||||||
|
Database: dbM,
|
||||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||||
Database: dbM,
|
Database: dbM,
|
||||||
Pubsub: ps,
|
Pubsub: ps,
|
||||||
@@ -239,9 +249,6 @@ func TestUpdateStates(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Workspace gets fetched.
|
|
||||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(workspace, nil)
|
|
||||||
|
|
||||||
_, err := api.UpdateStats(context.Background(), req)
|
_, err := api.UpdateStats(context.Background(), req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
@@ -260,7 +267,8 @@ func TestUpdateStates(t *testing.T) {
|
|||||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||||
return agent, nil
|
return agent, nil
|
||||||
},
|
},
|
||||||
Database: dbM,
|
Workspace: &workspaceAsCacheFields,
|
||||||
|
Database: dbM,
|
||||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||||
Database: dbM,
|
Database: dbM,
|
||||||
Pubsub: ps,
|
Pubsub: ps,
|
||||||
@@ -333,11 +341,17 @@ func TestUpdateStates(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
// need to overwrite the cached fields for this test, but the struct has a lock
|
||||||
|
ws := agentapi.CachedWorkspaceFields{}
|
||||||
|
ws.UpdateValues(workspace)
|
||||||
|
// ws.AutostartSchedule = workspace.AutostartSchedule
|
||||||
|
|
||||||
api := agentapi.StatsAPI{
|
api := agentapi.StatsAPI{
|
||||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||||
return agent, nil
|
return agent, nil
|
||||||
},
|
},
|
||||||
Database: dbM,
|
Workspace: &ws,
|
||||||
|
Database: dbM,
|
||||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||||
Database: dbM,
|
Database: dbM,
|
||||||
Pubsub: ps,
|
Pubsub: ps,
|
||||||
@@ -362,9 +376,6 @@ func TestUpdateStates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wut.Close()
|
defer wut.Close()
|
||||||
|
|
||||||
// Workspace gets fetched.
|
|
||||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(workspace, nil)
|
|
||||||
|
|
||||||
// We expect an activity bump because ConnectionCount > 0. However, the
|
// We expect an activity bump because ConnectionCount > 0. However, the
|
||||||
// next autostart time will be set on the bump.
|
// next autostart time will be set on the bump.
|
||||||
dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{
|
dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{
|
||||||
@@ -451,7 +462,8 @@ func TestUpdateStates(t *testing.T) {
|
|||||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||||
return agent, nil
|
return agent, nil
|
||||||
},
|
},
|
||||||
Database: dbM,
|
Workspace: &workspaceAsCacheFields,
|
||||||
|
Database: dbM,
|
||||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||||
Database: dbM,
|
Database: dbM,
|
||||||
Pubsub: ps,
|
Pubsub: ps,
|
||||||
@@ -478,9 +490,6 @@ func TestUpdateStates(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Workspace gets fetched.
|
|
||||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(workspace, nil)
|
|
||||||
|
|
||||||
// We expect an activity bump because ConnectionCount > 0.
|
// We expect an activity bump because ConnectionCount > 0.
|
||||||
dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{
|
dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{
|
||||||
WorkspaceID: workspace.ID,
|
WorkspaceID: workspace.ID,
|
||||||
|
|||||||
@@ -5556,6 +5556,22 @@ func (q *querier) UpdateWorkspaceAgentLogOverflowByID(ctx context.Context, arg d
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *querier) UpdateWorkspaceAgentMetadata(ctx context.Context, arg database.UpdateWorkspaceAgentMetadataParams) error {
|
func (q *querier) UpdateWorkspaceAgentMetadata(ctx context.Context, arg database.UpdateWorkspaceAgentMetadataParams) error {
|
||||||
|
// Fast path: Check if we have an RBAC object in context.
|
||||||
|
// This is set by the workspace agent RPC handler to avoid the expensive
|
||||||
|
// GetWorkspaceByAgentID query for every metadata update.
|
||||||
|
// NOTE: The cached RBAC object is refreshed every 5 minutes in agentapi/api.go.
|
||||||
|
if rbacObj, ok := WorkspaceRBACFromContext(ctx); ok {
|
||||||
|
// Errors here will result in falling back to the GetWorkspaceAgentByID query, skipping
|
||||||
|
// the cache in case the cached data is stale.
|
||||||
|
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbacObj); err == nil {
|
||||||
|
return q.db.UpdateWorkspaceAgentMetadata(ctx, arg)
|
||||||
|
}
|
||||||
|
q.log.Debug(ctx, "fast path authorization failed, using slow path",
|
||||||
|
slog.F("agent_id", arg.WorkspaceAgentID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Slow path: Fallback to fetching the workspace for authorization if the RBAC object is not present (or is invalid)
|
||||||
|
// in the request context.
|
||||||
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.WorkspaceAgentID)
|
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.WorkspaceAgentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -0,0 +1,41 @@
|
|||||||
|
package dbauthz
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/coderd/rbac"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isWorkspaceRBACObjectEmpty(rbacObj rbac.Object) bool {
|
||||||
|
// if any of these are true then the rbac.Object work a workspace is considered empty
|
||||||
|
return rbacObj.Owner == "" || rbacObj.OrgID == "" || rbacObj.Owner == uuid.Nil.String() || rbacObj.OrgID == uuid.Nil.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
type workspaceRBACContextKey struct{}
|
||||||
|
|
||||||
|
// WithWorkspaceRBAC attaches a workspace RBAC object to the context.
|
||||||
|
// RBAC fields on this RBAC object should not be used.
|
||||||
|
//
|
||||||
|
// This is primarily used by the workspace agent RPC handler to cache workspace
|
||||||
|
// authorization data for the duration of an agent connection.
|
||||||
|
func WithWorkspaceRBAC(ctx context.Context, rbacObj rbac.Object) (context.Context, error) {
|
||||||
|
if rbacObj.Type != rbac.ResourceWorkspace.Type {
|
||||||
|
return ctx, xerrors.New("RBAC Object must be of type Workspace")
|
||||||
|
}
|
||||||
|
if isWorkspaceRBACObjectEmpty(rbacObj) {
|
||||||
|
return ctx, xerrors.Errorf("cannot attach empty RBAC object to context: %+v", rbacObj)
|
||||||
|
}
|
||||||
|
if len(rbacObj.ACLGroupList) != 0 || len(rbacObj.ACLUserList) != 0 {
|
||||||
|
return ctx, xerrors.New("ACL fields for Workspace RBAC object must be nullified, the can be changed during runtime and should not be cached")
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, workspaceRBACContextKey{}, rbacObj), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WorkspaceRBACFromContext attempts to retrieve the workspace RBAC object from context.
|
||||||
|
func WorkspaceRBACFromContext(ctx context.Context) (rbac.Object, bool) {
|
||||||
|
obj, ok := ctx.Value(workspaceRBACContextKey{}).(rbac.Object)
|
||||||
|
return obj, ok
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"slices"
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
@@ -796,3 +797,60 @@ func (s UserSecret) RBACObject() rbac.Object {
|
|||||||
func (s AIBridgeInterception) RBACObject() rbac.Object {
|
func (s AIBridgeInterception) RBACObject() rbac.Object {
|
||||||
return rbac.ResourceAibridgeInterception.WithOwner(s.InitiatorID.String())
|
return rbac.ResourceAibridgeInterception.WithOwner(s.InitiatorID.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WorkspaceIdentity contains the minimal workspace fields needed for agent API metadata/stats reporting
|
||||||
|
// and RBAC checks, without requiring a full database.Workspace object.
|
||||||
|
type WorkspaceIdentity struct {
|
||||||
|
// Add any other fields needed for IsPrebuild() if it relies on workspace fields
|
||||||
|
// Identity fields
|
||||||
|
ID uuid.UUID
|
||||||
|
OwnerID uuid.UUID
|
||||||
|
OrganizationID uuid.UUID
|
||||||
|
TemplateID uuid.UUID
|
||||||
|
|
||||||
|
// Display fields for logging/metrics
|
||||||
|
Name string
|
||||||
|
OwnerUsername string
|
||||||
|
TemplateName string
|
||||||
|
|
||||||
|
// Lifecycle fields needed for stats reporting
|
||||||
|
AutostartSchedule sql.NullString
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w WorkspaceIdentity) RBACObject() rbac.Object {
|
||||||
|
return Workspace{
|
||||||
|
ID: w.ID,
|
||||||
|
OwnerID: w.OwnerID,
|
||||||
|
OrganizationID: w.OrganizationID,
|
||||||
|
TemplateID: w.TemplateID,
|
||||||
|
Name: w.Name,
|
||||||
|
OwnerUsername: w.OwnerUsername,
|
||||||
|
TemplateName: w.TemplateName,
|
||||||
|
AutostartSchedule: w.AutostartSchedule,
|
||||||
|
}.RBACObject()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPrebuild returns true if the workspace is a prebuild workspace.
|
||||||
|
// A workspace is considered a prebuild if its owner is the prebuild system user.
|
||||||
|
func (w WorkspaceIdentity) IsPrebuild() bool {
|
||||||
|
return w.OwnerID == PrebuildsSystemUserID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w WorkspaceIdentity) Equal(w2 WorkspaceIdentity) bool {
|
||||||
|
return w.ID == w2.ID && w.OwnerID == w2.OwnerID && w.OrganizationID == w2.OrganizationID &&
|
||||||
|
w.TemplateID == w2.TemplateID && w.Name == w2.Name && w.OwnerUsername == w2.OwnerUsername &&
|
||||||
|
w.TemplateName == w2.TemplateName && w.AutostartSchedule == w2.AutostartSchedule
|
||||||
|
}
|
||||||
|
|
||||||
|
func WorkspaceIdentityFromWorkspace(w Workspace) WorkspaceIdentity {
|
||||||
|
return WorkspaceIdentity{
|
||||||
|
ID: w.ID,
|
||||||
|
OwnerID: w.OwnerID,
|
||||||
|
OrganizationID: w.OrganizationID,
|
||||||
|
TemplateID: w.TemplateID,
|
||||||
|
Name: w.Name,
|
||||||
|
OwnerUsername: w.OwnerUsername,
|
||||||
|
TemplateName: w.TemplateName,
|
||||||
|
AutostartSchedule: w.AutostartSchedule,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Optional:
|
// Optional:
|
||||||
UpdateAgentMetricsFn: api.UpdateAgentMetrics,
|
UpdateAgentMetricsFn: api.UpdateAgentMetrics,
|
||||||
})
|
}, workspace)
|
||||||
|
|
||||||
streamID := tailnet.StreamID{
|
streamID := tailnet.StreamID{
|
||||||
Name: fmt.Sprintf("%s-%s-%s", workspace.OwnerUsername, workspace.Name, workspaceAgent.Name),
|
Name: fmt.Sprintf("%s-%s-%s", workspace.OwnerUsername, workspace.Name, workspaceAgent.Name),
|
||||||
|
|||||||
@@ -1717,13 +1717,13 @@ func (api *API) postWorkspaceUsage(rw http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
template, err := api.Database.GetTemplateByID(ctx, workspace.TemplateID)
|
// template, err := api.Database.GetTemplateByID(ctx, workspace.TemplateID)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
httpapi.InternalServerError(rw, err)
|
// httpapi.InternalServerError(rw, err)
|
||||||
return
|
// return
|
||||||
}
|
// }
|
||||||
|
|
||||||
err = api.statsReporter.ReportAgentStats(ctx, dbtime.Now(), workspace, agent, template.Name, stat, true)
|
err = api.statsReporter.ReportAgentStats(ctx, dbtime.Now(), database.WorkspaceIdentityFromWorkspace(workspace), agent, stat, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpapi.InternalServerError(rw, err)
|
httpapi.InternalServerError(rw, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ func (r *Reporter) ReportAppStats(ctx context.Context, stats []workspaceapps.Sta
|
|||||||
}
|
}
|
||||||
|
|
||||||
// nolint:revive // usage is a control flag while we have the experiment
|
// nolint:revive // usage is a control flag while we have the experiment
|
||||||
func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspace database.Workspace, workspaceAgent database.WorkspaceAgent, templateName string, stats *agentproto.Stats, usage bool) error {
|
func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspace database.WorkspaceIdentity, workspaceAgent database.WorkspaceAgent, stats *agentproto.Stats, usage bool) error {
|
||||||
// update agent stats
|
// update agent stats
|
||||||
r.opts.StatsBatcher.Add(now, workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, stats, usage)
|
r.opts.StatsBatcher.Add(now, workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, stats, usage)
|
||||||
|
|
||||||
@@ -130,7 +130,7 @@ func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspac
|
|||||||
Username: workspace.OwnerUsername,
|
Username: workspace.OwnerUsername,
|
||||||
WorkspaceName: workspace.Name,
|
WorkspaceName: workspace.Name,
|
||||||
AgentName: workspaceAgent.Name,
|
AgentName: workspaceAgent.Name,
|
||||||
TemplateName: templateName,
|
TemplateName: workspace.TemplateName,
|
||||||
}, stats.Metrics)
|
}, stats.Metrics)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user