mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
perf: reduce DB calls to GetWorkspaceByAgentID via caching workspace info (#20662)
--------- Signed-off-by: Callum Styan <callumstyan@gmail.com>
This commit is contained in:
+72
-5
@@ -36,6 +36,8 @@ import (
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const workspaceCacheRefreshInterval = 5 * time.Minute
|
||||
|
||||
// API implements the DRPC agent API interface from agent/proto. This struct is
|
||||
// instantiated once per agent connection and kept alive for the duration of the
|
||||
// session.
|
||||
@@ -54,6 +56,8 @@ type API struct {
|
||||
*SubAgentAPI
|
||||
*tailnet.DRPCService
|
||||
|
||||
cachedWorkspaceFields *CachedWorkspaceFields
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
@@ -92,7 +96,7 @@ type Options struct {
|
||||
UpdateAgentMetricsFn func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
|
||||
}
|
||||
|
||||
func New(opts Options) *API {
|
||||
func New(opts Options, workspace database.Workspace) *API {
|
||||
if opts.Clock == nil {
|
||||
opts.Clock = quartz.NewReal()
|
||||
}
|
||||
@@ -114,6 +118,13 @@ func New(opts Options) *API {
|
||||
WorkspaceID: opts.WorkspaceID,
|
||||
}
|
||||
|
||||
// Don't cache details for prebuilds, though the cached fields will eventually be updated
|
||||
// by the refresh routine once the prebuild workspace is claimed.
|
||||
api.cachedWorkspaceFields = &CachedWorkspaceFields{}
|
||||
if !workspace.IsPrebuild() {
|
||||
api.cachedWorkspaceFields.UpdateValues(workspace)
|
||||
}
|
||||
|
||||
api.AnnouncementBannerAPI = &AnnouncementBannerAPI{
|
||||
appearanceFetcher: opts.AppearanceFetcher,
|
||||
}
|
||||
@@ -139,6 +150,7 @@ func New(opts Options) *API {
|
||||
|
||||
api.StatsAPI = &StatsAPI{
|
||||
AgentFn: api.agent,
|
||||
Workspace: api.cachedWorkspaceFields,
|
||||
Database: opts.Database,
|
||||
Log: opts.Log,
|
||||
StatsReporter: opts.StatsReporter,
|
||||
@@ -162,10 +174,11 @@ func New(opts Options) *API {
|
||||
}
|
||||
|
||||
api.MetadataAPI = &MetadataAPI{
|
||||
AgentFn: api.agent,
|
||||
Database: opts.Database,
|
||||
Pubsub: opts.Pubsub,
|
||||
Log: opts.Log,
|
||||
AgentFn: api.agent,
|
||||
Workspace: api.cachedWorkspaceFields,
|
||||
Database: opts.Database,
|
||||
Pubsub: opts.Pubsub,
|
||||
Log: opts.Log,
|
||||
}
|
||||
|
||||
api.LogsAPI = &LogsAPI{
|
||||
@@ -205,6 +218,10 @@ func New(opts Options) *API {
|
||||
Database: opts.Database,
|
||||
}
|
||||
|
||||
// Start background cache refresh loop to handle workspace changes
|
||||
// like prebuild claims where owner_id and other fields may be modified in the DB.
|
||||
go api.startCacheRefreshLoop(opts.Ctx)
|
||||
|
||||
return api
|
||||
}
|
||||
|
||||
@@ -254,6 +271,56 @@ func (a *API) agent(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
// refreshCachedWorkspace periodically updates the cached workspace fields.
|
||||
// This ensures that changes like prebuild claims (which modify owner_id, name, etc.)
|
||||
// are eventually reflected in the cache without requiring agent reconnection.
|
||||
func (a *API) refreshCachedWorkspace(ctx context.Context) {
|
||||
ws, err := a.opts.Database.GetWorkspaceByID(ctx, a.opts.WorkspaceID)
|
||||
if err != nil {
|
||||
a.opts.Log.Warn(ctx, "failed to refresh cached workspace fields", slog.Error(err))
|
||||
a.cachedWorkspaceFields.Clear()
|
||||
return
|
||||
}
|
||||
|
||||
if ws.IsPrebuild() {
|
||||
return
|
||||
}
|
||||
|
||||
// If we still have the same values, skip the update and logging calls.
|
||||
if a.cachedWorkspaceFields.identity.Equal(database.WorkspaceIdentityFromWorkspace(ws)) {
|
||||
return
|
||||
}
|
||||
// Update fields that can change during workspace lifecycle (e.g., AutostartSchedule)
|
||||
a.cachedWorkspaceFields.UpdateValues(ws)
|
||||
|
||||
a.opts.Log.Debug(ctx, "refreshed cached workspace fields",
|
||||
slog.F("workspace_id", ws.ID),
|
||||
slog.F("owner_id", ws.OwnerID),
|
||||
slog.F("name", ws.Name))
|
||||
}
|
||||
|
||||
// startCacheRefreshLoop runs a background goroutine that periodically refreshes
|
||||
// the cached workspace fields. This is primarily needed to handle prebuild claims
|
||||
// where the owner_id and other fields change while the agent connection persists.
|
||||
func (a *API) startCacheRefreshLoop(ctx context.Context) {
|
||||
// Refresh every 5 minutes. This provides a reasonable balance between:
|
||||
// - Keeping cache fresh for prebuild claims and other workspace updates
|
||||
// - Minimizing unnecessary database queries
|
||||
ticker := a.opts.Clock.TickerFunc(ctx, workspaceCacheRefreshInterval, func() error {
|
||||
a.refreshCachedWorkspace(ctx)
|
||||
return nil
|
||||
}, "cache_refresh")
|
||||
|
||||
// We need to wait on the ticker exiting.
|
||||
_ = ticker.Wait()
|
||||
|
||||
a.opts.Log.Debug(ctx, "cache refresh loop exited, invalidating the workspace cache on agent API",
|
||||
slog.F("workspace_id", a.cachedWorkspaceFields.identity.ID),
|
||||
slog.F("owner_id", a.cachedWorkspaceFields.identity.OwnerUsername),
|
||||
slog.F("name", a.cachedWorkspaceFields.identity.Name))
|
||||
a.cachedWorkspaceFields.Clear()
|
||||
}
|
||||
|
||||
func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
|
||||
a.opts.PublishWorkspaceUpdateFn(ctx, a.opts.OwnerID, wspubsub.WorkspaceEvent{
|
||||
Kind: kind,
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package agentapi
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
// CachedWorkspaceFields contains workspace data that is safe to cache for the
|
||||
// duration of an agent connection. These fields are used to reduce database calls
|
||||
// in high-frequency operations like stats reporting and metadata updates.
|
||||
// Prebuild workspaces should not be cached using this struct within the API struct,
|
||||
// however some of these fields for a workspace can be updated live so there is a
|
||||
// routine in the API for refreshing the workspace on a timed interval.
|
||||
//
|
||||
// IMPORTANT: ACL fields (GroupACL, UserACL) are NOT cached because they can be
|
||||
// modified in the database and we must use fresh data for authorization checks.
|
||||
type CachedWorkspaceFields struct {
|
||||
lock sync.RWMutex
|
||||
|
||||
identity database.WorkspaceIdentity
|
||||
}
|
||||
|
||||
func (cws *CachedWorkspaceFields) Clear() {
|
||||
cws.lock.Lock()
|
||||
defer cws.lock.Unlock()
|
||||
cws.identity = database.WorkspaceIdentity{}
|
||||
}
|
||||
|
||||
func (cws *CachedWorkspaceFields) UpdateValues(ws database.Workspace) {
|
||||
cws.lock.Lock()
|
||||
defer cws.lock.Unlock()
|
||||
cws.identity.ID = ws.ID
|
||||
cws.identity.OwnerID = ws.OwnerID
|
||||
cws.identity.OrganizationID = ws.OrganizationID
|
||||
cws.identity.TemplateID = ws.TemplateID
|
||||
cws.identity.Name = ws.Name
|
||||
cws.identity.OwnerUsername = ws.OwnerUsername
|
||||
cws.identity.TemplateName = ws.TemplateName
|
||||
cws.identity.AutostartSchedule = ws.AutostartSchedule
|
||||
}
|
||||
|
||||
// Returns the Workspace, true, unless the workspace has not been cached (nuked or was a prebuild).
|
||||
func (cws *CachedWorkspaceFields) AsWorkspaceIdentity() (database.WorkspaceIdentity, bool) {
|
||||
cws.lock.RLock()
|
||||
defer cws.lock.RUnlock()
|
||||
// Should we be more explicit about all fields being set to be valid?
|
||||
if cws.identity.Equal(database.WorkspaceIdentity{}) {
|
||||
return database.WorkspaceIdentity{}, false
|
||||
}
|
||||
return cws.identity, true
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package agentapi_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
func TestCacheClear(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
user = database.User{
|
||||
ID: uuid.New(),
|
||||
Username: "bill",
|
||||
}
|
||||
template = database.Template{
|
||||
ID: uuid.New(),
|
||||
Name: "tpl",
|
||||
}
|
||||
workspace = database.Workspace{
|
||||
ID: uuid.New(),
|
||||
OwnerID: user.ID,
|
||||
OwnerUsername: user.Username,
|
||||
TemplateID: template.ID,
|
||||
Name: "xyz",
|
||||
TemplateName: template.Name,
|
||||
}
|
||||
workspaceAsCacheFields = agentapi.CachedWorkspaceFields{}
|
||||
)
|
||||
|
||||
workspaceAsCacheFields.UpdateValues(database.Workspace{
|
||||
ID: workspace.ID,
|
||||
OwnerID: workspace.OwnerID,
|
||||
OwnerUsername: workspace.OwnerUsername,
|
||||
TemplateID: workspace.TemplateID,
|
||||
Name: workspace.Name,
|
||||
TemplateName: workspace.TemplateName,
|
||||
AutostartSchedule: workspace.AutostartSchedule,
|
||||
},
|
||||
)
|
||||
|
||||
emptyCws := agentapi.CachedWorkspaceFields{}
|
||||
workspaceAsCacheFields.Clear()
|
||||
wsi, ok := workspaceAsCacheFields.AsWorkspaceIdentity()
|
||||
require.False(t, ok)
|
||||
ecwsi, ok := emptyCws.AsWorkspaceIdentity()
|
||||
require.False(t, ok)
|
||||
require.True(t, ecwsi.Equal(wsi))
|
||||
}
|
||||
|
||||
func TestCacheUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
user = database.User{
|
||||
ID: uuid.New(),
|
||||
Username: "bill",
|
||||
}
|
||||
template = database.Template{
|
||||
ID: uuid.New(),
|
||||
Name: "tpl",
|
||||
}
|
||||
workspace = database.Workspace{
|
||||
ID: uuid.New(),
|
||||
OwnerID: user.ID,
|
||||
OwnerUsername: user.Username,
|
||||
TemplateID: template.ID,
|
||||
Name: "xyz",
|
||||
TemplateName: template.Name,
|
||||
}
|
||||
workspaceAsCacheFields = agentapi.CachedWorkspaceFields{}
|
||||
)
|
||||
|
||||
workspaceAsCacheFields.UpdateValues(database.Workspace{
|
||||
ID: workspace.ID,
|
||||
OwnerID: workspace.OwnerID,
|
||||
OwnerUsername: workspace.OwnerUsername,
|
||||
TemplateID: workspace.TemplateID,
|
||||
Name: workspace.Name,
|
||||
TemplateName: workspace.TemplateName,
|
||||
AutostartSchedule: workspace.AutostartSchedule,
|
||||
},
|
||||
)
|
||||
|
||||
cws := agentapi.CachedWorkspaceFields{}
|
||||
cws.UpdateValues(workspace)
|
||||
wsi, ok := workspaceAsCacheFields.AsWorkspaceIdentity()
|
||||
require.True(t, ok)
|
||||
cwsi, ok := cws.AsWorkspaceIdentity()
|
||||
require.True(t, ok)
|
||||
require.True(t, wsi.Equal(cwsi))
|
||||
}
|
||||
@@ -12,15 +12,17 @@ import (
|
||||
"cdr.dev/slog"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
type MetadataAPI struct {
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
Database database.Store
|
||||
Pubsub pubsub.Pubsub
|
||||
Log slog.Logger
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
Workspace *CachedWorkspaceFields
|
||||
Database database.Store
|
||||
Pubsub pubsub.Pubsub
|
||||
Log slog.Logger
|
||||
|
||||
TimeNowFn func() time.Time // defaults to dbtime.Now()
|
||||
}
|
||||
@@ -107,7 +109,19 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
|
||||
)
|
||||
}
|
||||
|
||||
err = a.Database.UpdateWorkspaceAgentMetadata(ctx, dbUpdate)
|
||||
// Inject RBAC object into context for dbauthz fast path, avoid having to
|
||||
// call GetWorkspaceByAgentID on every metadata update.
|
||||
rbacCtx := ctx
|
||||
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
|
||||
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
|
||||
if err != nil {
|
||||
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
|
||||
//nolint:gocritic
|
||||
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
|
||||
}
|
||||
}
|
||||
|
||||
err = a.Database.UpdateWorkspaceAgentMetadata(rbacCtx, dbUpdate)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err)
|
||||
}
|
||||
|
||||
@@ -2,12 +2,14 @@ package agentapi_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
@@ -15,10 +17,14 @@ import (
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
type fakePublisher struct {
|
||||
@@ -84,9 +90,10 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
@@ -169,9 +176,10 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
@@ -238,9 +246,10 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbM,
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
@@ -272,4 +281,421 @@ func TestBatchUpdateMetadata(t *testing.T) {
|
||||
Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key},
|
||||
}, gotEvent)
|
||||
})
|
||||
|
||||
// Test RBAC fast path with valid RBAC object - should NOT call GetWorkspaceByAgentID
|
||||
// This test verifies that when a valid RBAC object is present in context, the dbauthz layer
|
||||
// uses the fast path and skips the GetWorkspaceByAgentID database call.
|
||||
t.Run("WorkspaceCached_SkipsDBCall", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
dbM = dbmock.NewMockStore(ctrl)
|
||||
pub = &fakePublisher{}
|
||||
now = dbtime.Now()
|
||||
// Set up consistent IDs that represent a valid workspace->agent relationship
|
||||
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||
ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321")
|
||||
orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||
agentID = uuid.MustParse("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
|
||||
)
|
||||
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
// In a real scenario, this agent would belong to a resource in the workspace above
|
||||
}
|
||||
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
Metadata: []*agentproto.Metadata{
|
||||
{
|
||||
Key: "test_key",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
CollectedAt: timestamppb.New(now.Add(-time.Second)),
|
||||
Age: 1,
|
||||
Value: "test_value",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Expect UpdateWorkspaceAgentMetadata to be called
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
Key: []string{"test_key"},
|
||||
Value: []string{"test_value"},
|
||||
Error: []string{""},
|
||||
CollectedAt: []time.Time{now},
|
||||
}).Return(nil)
|
||||
|
||||
// DO NOT expect GetWorkspaceByAgentID - the fast path should skip this call
|
||||
// If GetWorkspaceByAgentID is called, the test will fail with "unexpected call"
|
||||
|
||||
// dbauthz will call Wrappers() to check for wrapped databases
|
||||
dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes()
|
||||
|
||||
// Set up dbauthz to test the actual authorization layer
|
||||
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||||
accessControlStore.Store(&acs)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
api.Workspace.UpdateValues(database.Workspace{
|
||||
ID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
})
|
||||
|
||||
// Create context with system actor so authorization passes
|
||||
ctx := dbauthz.AsSystemRestricted(context.Background())
|
||||
resp, err := api.BatchUpdateMetadata(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
})
|
||||
// Test RBAC slow path - invalid RBAC object should fall back to GetWorkspaceByAgentID
|
||||
// This test verifies that when the RBAC object has invalid IDs (nil UUIDs), the dbauthz layer
|
||||
// falls back to the slow path and calls GetWorkspaceByAgentID.
|
||||
t.Run("InvalidWorkspaceCached_RequiresDBCall", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
dbM = dbmock.NewMockStore(ctrl)
|
||||
pub = &fakePublisher{}
|
||||
now = dbtime.Now()
|
||||
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||
ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321")
|
||||
orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||
agentID = uuid.MustParse("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
|
||||
)
|
||||
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
}
|
||||
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
Metadata: []*agentproto.Metadata{
|
||||
{
|
||||
Key: "test_key",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
CollectedAt: timestamppb.New(now.Add(-time.Second)),
|
||||
Age: 1,
|
||||
Value: "test_value",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// EXPECT GetWorkspaceByAgentID to be called because the RBAC fast path validation fails
|
||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agentID).Return(database.Workspace{
|
||||
ID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
}, nil)
|
||||
|
||||
// Expect UpdateWorkspaceAgentMetadata to be called after authorization
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
Key: []string{"test_key"},
|
||||
Value: []string{"test_value"},
|
||||
Error: []string{""},
|
||||
CollectedAt: []time.Time{now},
|
||||
}).Return(nil)
|
||||
|
||||
// dbauthz will call Wrappers() to check for wrapped databases
|
||||
dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes()
|
||||
|
||||
// Set up dbauthz to test the actual authorization layer
|
||||
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||||
accessControlStore.Store(&acs)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
// Create an invalid RBAC object with nil UUIDs for owner/org
|
||||
// This will fail dbauthz fast path validation and trigger GetWorkspaceByAgentID
|
||||
api.Workspace.UpdateValues(database.Workspace{
|
||||
ID: uuid.MustParse("cccccccc-cccc-cccc-cccc-cccccccccccc"),
|
||||
OwnerID: uuid.Nil, // Invalid: fails dbauthz fast path validation
|
||||
OrganizationID: uuid.Nil, // Invalid: fails dbauthz fast path validation
|
||||
})
|
||||
|
||||
// Create context with system actor so authorization passes
|
||||
ctx := dbauthz.AsSystemRestricted(context.Background())
|
||||
resp, err := api.BatchUpdateMetadata(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
})
|
||||
// Test RBAC slow path - no RBAC object in context
|
||||
// This test verifies that when no RBAC object is present in context, the dbauthz layer
|
||||
// falls back to the slow path and calls GetWorkspaceByAgentID.
|
||||
t.Run("WorkspaceNotCached_RequiresDBCall", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
dbM = dbmock.NewMockStore(ctrl)
|
||||
pub = &fakePublisher{}
|
||||
now = dbtime.Now()
|
||||
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||
ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321")
|
||||
orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||
agentID = uuid.MustParse("dddddddd-dddd-dddd-dddd-dddddddddddd")
|
||||
)
|
||||
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
}
|
||||
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
Metadata: []*agentproto.Metadata{
|
||||
{
|
||||
Key: "test_key",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
CollectedAt: timestamppb.New(now.Add(-time.Second)),
|
||||
Age: 1,
|
||||
Value: "test_value",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// EXPECT GetWorkspaceByAgentID to be called because no RBAC object is in context
|
||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agentID).Return(database.Workspace{
|
||||
ID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
}, nil)
|
||||
|
||||
// Expect UpdateWorkspaceAgentMetadata to be called after authorization
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{
|
||||
WorkspaceAgentID: agent.ID,
|
||||
Key: []string{"test_key"},
|
||||
Value: []string{"test_value"},
|
||||
Error: []string{""},
|
||||
CollectedAt: []time.Time{now},
|
||||
}).Return(nil)
|
||||
|
||||
// dbauthz will call Wrappers() to check for wrapped databases
|
||||
dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes()
|
||||
|
||||
// Set up dbauthz to test the actual authorization layer
|
||||
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||||
accessControlStore.Store(&acs)
|
||||
|
||||
api := &agentapi.MetadataAPI{
|
||||
AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Workspace: &agentapi.CachedWorkspaceFields{},
|
||||
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||
Pubsub: pub,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
// Create context with system actor so authorization passes
|
||||
ctx := dbauthz.AsSystemRestricted(context.Background())
|
||||
resp, err := api.BatchUpdateMetadata(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
})
|
||||
|
||||
// Test cache refresh - AutostartSchedule updated
|
||||
// This test verifies that the cache refresh mechanism actually calls GetWorkspaceByID
|
||||
// and updates the cached workspace fields when the workspace is modified (e.g., autostart schedule changes).
|
||||
t.Run("CacheRefreshed_AutostartScheduleUpdated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
dbM = dbmock.NewMockStore(ctrl)
|
||||
pub = &fakePublisher{}
|
||||
now = dbtime.Now()
|
||||
mClock = quartz.NewMock(t)
|
||||
tickerTrap = mClock.Trap().TickerFunc("cache_refresh")
|
||||
|
||||
workspaceID = uuid.MustParse("12345678-1234-1234-1234-123456789012")
|
||||
ownerID = uuid.MustParse("87654321-4321-4321-4321-210987654321")
|
||||
orgID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||
templateID = uuid.MustParse("aaaabbbb-cccc-dddd-eeee-ffffffff0000")
|
||||
agentID = uuid.MustParse("ffffffff-ffff-ffff-ffff-ffffffffffff")
|
||||
)
|
||||
|
||||
agent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
}
|
||||
|
||||
// Initial workspace - has Monday-Friday 9am autostart
|
||||
initialWorkspace := database.Workspace{
|
||||
ID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
TemplateID: templateID,
|
||||
Name: "my-workspace",
|
||||
OwnerUsername: "testuser",
|
||||
TemplateName: "test-template",
|
||||
AutostartSchedule: sql.NullString{Valid: true, String: "CRON_TZ=UTC 0 9 * * 1-5"},
|
||||
}
|
||||
|
||||
// Updated workspace - user changed autostart to 5pm and renamed workspace
|
||||
updatedWorkspace := database.Workspace{
|
||||
ID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
TemplateID: templateID,
|
||||
Name: "my-workspace-renamed", // Changed!
|
||||
OwnerUsername: "testuser",
|
||||
TemplateName: "test-template",
|
||||
AutostartSchedule: sql.NullString{Valid: true, String: "CRON_TZ=UTC 0 17 * * 1-5"}, // Changed!
|
||||
DormantAt: sql.NullTime{},
|
||||
}
|
||||
|
||||
req := &agentproto.BatchUpdateMetadataRequest{
|
||||
Metadata: []*agentproto.Metadata{
|
||||
{
|
||||
Key: "test_key",
|
||||
Result: &agentproto.WorkspaceAgentMetadata_Result{
|
||||
CollectedAt: timestamppb.New(now.Add(-time.Second)),
|
||||
Age: 1,
|
||||
Value: "test_value",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// EXPECT GetWorkspaceByID to be called during cache refresh
|
||||
// This is the key assertion - proves the refresh mechanism is working
|
||||
dbM.EXPECT().GetWorkspaceByID(gomock.Any(), workspaceID).Return(updatedWorkspace, nil)
|
||||
|
||||
// API needs to fetch the agent when calling metadata update
|
||||
dbM.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(agent, nil)
|
||||
|
||||
// After refresh, metadata update should work with updated cache
|
||||
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(ctx context.Context, params database.UpdateWorkspaceAgentMetadataParams) error {
|
||||
require.Equal(t, agent.ID, params.WorkspaceAgentID)
|
||||
require.Equal(t, []string{"test_key"}, params.Key)
|
||||
require.Equal(t, []string{"test_value"}, params.Value)
|
||||
require.Equal(t, []string{""}, params.Error)
|
||||
require.Len(t, params.CollectedAt, 1)
|
||||
return nil
|
||||
},
|
||||
).AnyTimes()
|
||||
|
||||
// May call GetWorkspaceByAgentID if slow path is used before refresh
|
||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agentID).Return(updatedWorkspace, nil).AnyTimes()
|
||||
|
||||
// dbauthz will call Wrappers()
|
||||
dbM.EXPECT().Wrappers().Return([]string{}).AnyTimes()
|
||||
|
||||
// Set up dbauthz
|
||||
auth := rbac.NewStrictCachingAuthorizer(prometheus.NewRegistry())
|
||||
accessControlStore := &atomic.Pointer[dbauthz.AccessControlStore]{}
|
||||
var acs dbauthz.AccessControlStore = dbauthz.AGPLTemplateAccessControlStore{}
|
||||
accessControlStore.Store(&acs)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create roles with workspace permissions
|
||||
userRoles := rbac.Roles([]rbac.Role{
|
||||
{
|
||||
Identifier: rbac.RoleMember(),
|
||||
User: []rbac.Permission{
|
||||
{
|
||||
Negate: false,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.WildcardSymbol,
|
||||
},
|
||||
},
|
||||
ByOrgID: map[string]rbac.OrgPermissions{
|
||||
orgID.String(): {
|
||||
Member: []rbac.Permission{
|
||||
{
|
||||
Negate: false,
|
||||
ResourceType: rbac.ResourceWorkspace.Type,
|
||||
Action: policy.WildcardSymbol,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
agentScope := rbac.WorkspaceAgentScope(rbac.WorkspaceAgentScopeParams{
|
||||
WorkspaceID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
TemplateID: templateID,
|
||||
VersionID: uuid.New(),
|
||||
})
|
||||
|
||||
ctxWithActor := dbauthz.As(ctx, rbac.Subject{
|
||||
Type: rbac.SubjectTypeUser,
|
||||
FriendlyName: "testuser",
|
||||
Email: "testuser@example.com",
|
||||
ID: ownerID.String(),
|
||||
Roles: userRoles,
|
||||
Groups: []string{orgID.String()},
|
||||
Scope: agentScope,
|
||||
}.WithCachedASTValue())
|
||||
|
||||
// Create full API with cached workspace fields (initial state)
|
||||
api := agentapi.New(agentapi.Options{
|
||||
Ctx: ctxWithActor,
|
||||
AgentID: agentID,
|
||||
WorkspaceID: workspaceID,
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
Database: dbauthz.New(dbM, auth, testutil.Logger(t), accessControlStore),
|
||||
Log: testutil.Logger(t),
|
||||
Clock: mClock,
|
||||
Pubsub: pub,
|
||||
}, initialWorkspace) // Cache is initialized with 9am schedule and "my-workspace" name
|
||||
|
||||
// Wait for ticker to be set up and release it so it can fire
|
||||
tickerTrap.MustWait(ctx).MustRelease(ctx)
|
||||
tickerTrap.Close()
|
||||
|
||||
// Advance clock to trigger cache refresh and wait for it to complete
|
||||
_, aw := mClock.AdvanceNext()
|
||||
aw.MustWait(ctx)
|
||||
|
||||
// At this point, GetWorkspaceByID should have been called and cache updated
|
||||
// The cache now has the 5pm schedule and "my-workspace-renamed" name
|
||||
|
||||
// Now call metadata update to verify the refreshed cache works
|
||||
resp, err := api.MetadataAPI.BatchUpdateMetadata(ctxWithActor, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
type StatsAPI struct {
|
||||
AgentFn func(context.Context) (database.WorkspaceAgent, error)
|
||||
Workspace *CachedWorkspaceFields
|
||||
Database database.Store
|
||||
Log slog.Logger
|
||||
StatsReporter *workspacestats.Reporter
|
||||
@@ -46,14 +47,21 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
getWorkspaceAgentByIDRow, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace by agent ID %q: %w", workspaceAgent.ID, err)
|
||||
|
||||
// If cache is empty (prebuild or invalid), fall back to DB
|
||||
var ws database.WorkspaceIdentity
|
||||
var ok bool
|
||||
if ws, ok = a.Workspace.AsWorkspaceIdentity(); !ok {
|
||||
w, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace by agent ID %q: %w", workspaceAgent.ID, err)
|
||||
}
|
||||
ws = database.WorkspaceIdentityFromWorkspace(w)
|
||||
}
|
||||
workspace := getWorkspaceAgentByIDRow
|
||||
|
||||
a.Log.Debug(ctx, "read stats report",
|
||||
slog.F("interval", a.AgentStatsRefreshInterval),
|
||||
slog.F("workspace_id", workspace.ID),
|
||||
slog.F("workspace_id", ws.ID),
|
||||
slog.F("payload", req),
|
||||
)
|
||||
|
||||
@@ -70,9 +78,8 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
|
||||
err = a.StatsReporter.ReportAgentStats(
|
||||
ctx,
|
||||
a.now(),
|
||||
workspace,
|
||||
ws,
|
||||
workspaceAgent,
|
||||
getWorkspaceAgentByIDRow.TemplateName,
|
||||
req.Stats,
|
||||
false,
|
||||
)
|
||||
|
||||
@@ -52,8 +52,19 @@ func TestUpdateStates(t *testing.T) {
|
||||
ID: uuid.New(),
|
||||
Name: "abc",
|
||||
}
|
||||
workspaceAsCacheFields = agentapi.CachedWorkspaceFields{}
|
||||
)
|
||||
|
||||
workspaceAsCacheFields.UpdateValues(database.Workspace{
|
||||
ID: workspace.ID,
|
||||
OwnerID: workspace.OwnerID,
|
||||
OwnerUsername: workspace.OwnerUsername,
|
||||
TemplateID: workspace.TemplateID,
|
||||
Name: workspace.Name,
|
||||
TemplateName: workspace.TemplateName,
|
||||
AutostartSchedule: workspace.AutostartSchedule,
|
||||
})
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -111,7 +122,8 @@ func TestUpdateStates(t *testing.T) {
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Workspace: &workspaceAsCacheFields,
|
||||
Database: dbM,
|
||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||
Database: dbM,
|
||||
Pubsub: ps,
|
||||
@@ -136,9 +148,6 @@ func TestUpdateStates(t *testing.T) {
|
||||
}
|
||||
defer wut.Close()
|
||||
|
||||
// Workspace gets fetched.
|
||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(workspace, nil)
|
||||
|
||||
// We expect an activity bump because ConnectionCount > 0.
|
||||
dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
@@ -223,7 +232,8 @@ func TestUpdateStates(t *testing.T) {
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Workspace: &workspaceAsCacheFields,
|
||||
Database: dbM,
|
||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||
Database: dbM,
|
||||
Pubsub: ps,
|
||||
@@ -239,9 +249,6 @@ func TestUpdateStates(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Workspace gets fetched.
|
||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(workspace, nil)
|
||||
|
||||
_, err := api.UpdateStats(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
@@ -260,7 +267,8 @@ func TestUpdateStates(t *testing.T) {
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Workspace: &workspaceAsCacheFields,
|
||||
Database: dbM,
|
||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||
Database: dbM,
|
||||
Pubsub: ps,
|
||||
@@ -333,11 +341,17 @@ func TestUpdateStates(t *testing.T) {
|
||||
},
|
||||
}
|
||||
)
|
||||
// need to overwrite the cached fields for this test, but the struct has a lock
|
||||
ws := agentapi.CachedWorkspaceFields{}
|
||||
ws.UpdateValues(workspace)
|
||||
// ws.AutostartSchedule = workspace.AutostartSchedule
|
||||
|
||||
api := agentapi.StatsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Workspace: &ws,
|
||||
Database: dbM,
|
||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||
Database: dbM,
|
||||
Pubsub: ps,
|
||||
@@ -362,9 +376,6 @@ func TestUpdateStates(t *testing.T) {
|
||||
}
|
||||
defer wut.Close()
|
||||
|
||||
// Workspace gets fetched.
|
||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(workspace, nil)
|
||||
|
||||
// We expect an activity bump because ConnectionCount > 0. However, the
|
||||
// next autostart time will be set on the bump.
|
||||
dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{
|
||||
@@ -451,7 +462,8 @@ func TestUpdateStates(t *testing.T) {
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Workspace: &workspaceAsCacheFields,
|
||||
Database: dbM,
|
||||
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
|
||||
Database: dbM,
|
||||
Pubsub: ps,
|
||||
@@ -478,9 +490,6 @@ func TestUpdateStates(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Workspace gets fetched.
|
||||
dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(workspace, nil)
|
||||
|
||||
// We expect an activity bump because ConnectionCount > 0.
|
||||
dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
|
||||
@@ -5556,6 +5556,22 @@ func (q *querier) UpdateWorkspaceAgentLogOverflowByID(ctx context.Context, arg d
|
||||
}
|
||||
|
||||
func (q *querier) UpdateWorkspaceAgentMetadata(ctx context.Context, arg database.UpdateWorkspaceAgentMetadataParams) error {
|
||||
// Fast path: Check if we have an RBAC object in context.
|
||||
// This is set by the workspace agent RPC handler to avoid the expensive
|
||||
// GetWorkspaceByAgentID query for every metadata update.
|
||||
// NOTE: The cached RBAC object is refreshed every 5 minutes in agentapi/api.go.
|
||||
if rbacObj, ok := WorkspaceRBACFromContext(ctx); ok {
|
||||
// Errors here will result in falling back to the GetWorkspaceAgentByID query, skipping
|
||||
// the cache in case the cached data is stale.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbacObj); err == nil {
|
||||
return q.db.UpdateWorkspaceAgentMetadata(ctx, arg)
|
||||
}
|
||||
q.log.Debug(ctx, "fast path authorization failed, using slow path",
|
||||
slog.F("agent_id", arg.WorkspaceAgentID))
|
||||
}
|
||||
|
||||
// Slow path: Fallback to fetching the workspace for authorization if the RBAC object is not present (or is invalid)
|
||||
// in the request context.
|
||||
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.WorkspaceAgentID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
package dbauthz
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
)
|
||||
|
||||
func isWorkspaceRBACObjectEmpty(rbacObj rbac.Object) bool {
|
||||
// if any of these are true then the rbac.Object work a workspace is considered empty
|
||||
return rbacObj.Owner == "" || rbacObj.OrgID == "" || rbacObj.Owner == uuid.Nil.String() || rbacObj.OrgID == uuid.Nil.String()
|
||||
}
|
||||
|
||||
type workspaceRBACContextKey struct{}
|
||||
|
||||
// WithWorkspaceRBAC attaches a workspace RBAC object to the context.
|
||||
// RBAC fields on this RBAC object should not be used.
|
||||
//
|
||||
// This is primarily used by the workspace agent RPC handler to cache workspace
|
||||
// authorization data for the duration of an agent connection.
|
||||
func WithWorkspaceRBAC(ctx context.Context, rbacObj rbac.Object) (context.Context, error) {
|
||||
if rbacObj.Type != rbac.ResourceWorkspace.Type {
|
||||
return ctx, xerrors.New("RBAC Object must be of type Workspace")
|
||||
}
|
||||
if isWorkspaceRBACObjectEmpty(rbacObj) {
|
||||
return ctx, xerrors.Errorf("cannot attach empty RBAC object to context: %+v", rbacObj)
|
||||
}
|
||||
if len(rbacObj.ACLGroupList) != 0 || len(rbacObj.ACLUserList) != 0 {
|
||||
return ctx, xerrors.New("ACL fields for Workspace RBAC object must be nullified, the can be changed during runtime and should not be cached")
|
||||
}
|
||||
return context.WithValue(ctx, workspaceRBACContextKey{}, rbacObj), nil
|
||||
}
|
||||
|
||||
// WorkspaceRBACFromContext attempts to retrieve the workspace RBAC object from context.
|
||||
func WorkspaceRBACFromContext(ctx context.Context) (rbac.Object, bool) {
|
||||
obj, ok := ctx.Value(workspaceRBACContextKey{}).(rbac.Object)
|
||||
return obj, ok
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"slices"
|
||||
"sort"
|
||||
@@ -796,3 +797,60 @@ func (s UserSecret) RBACObject() rbac.Object {
|
||||
func (s AIBridgeInterception) RBACObject() rbac.Object {
|
||||
return rbac.ResourceAibridgeInterception.WithOwner(s.InitiatorID.String())
|
||||
}
|
||||
|
||||
// WorkspaceIdentity contains the minimal workspace fields needed for agent API metadata/stats reporting
|
||||
// and RBAC checks, without requiring a full database.Workspace object.
|
||||
type WorkspaceIdentity struct {
|
||||
// Add any other fields needed for IsPrebuild() if it relies on workspace fields
|
||||
// Identity fields
|
||||
ID uuid.UUID
|
||||
OwnerID uuid.UUID
|
||||
OrganizationID uuid.UUID
|
||||
TemplateID uuid.UUID
|
||||
|
||||
// Display fields for logging/metrics
|
||||
Name string
|
||||
OwnerUsername string
|
||||
TemplateName string
|
||||
|
||||
// Lifecycle fields needed for stats reporting
|
||||
AutostartSchedule sql.NullString
|
||||
}
|
||||
|
||||
func (w WorkspaceIdentity) RBACObject() rbac.Object {
|
||||
return Workspace{
|
||||
ID: w.ID,
|
||||
OwnerID: w.OwnerID,
|
||||
OrganizationID: w.OrganizationID,
|
||||
TemplateID: w.TemplateID,
|
||||
Name: w.Name,
|
||||
OwnerUsername: w.OwnerUsername,
|
||||
TemplateName: w.TemplateName,
|
||||
AutostartSchedule: w.AutostartSchedule,
|
||||
}.RBACObject()
|
||||
}
|
||||
|
||||
// IsPrebuild returns true if the workspace is a prebuild workspace.
|
||||
// A workspace is considered a prebuild if its owner is the prebuild system user.
|
||||
func (w WorkspaceIdentity) IsPrebuild() bool {
|
||||
return w.OwnerID == PrebuildsSystemUserID
|
||||
}
|
||||
|
||||
func (w WorkspaceIdentity) Equal(w2 WorkspaceIdentity) bool {
|
||||
return w.ID == w2.ID && w.OwnerID == w2.OwnerID && w.OrganizationID == w2.OrganizationID &&
|
||||
w.TemplateID == w2.TemplateID && w.Name == w2.Name && w.OwnerUsername == w2.OwnerUsername &&
|
||||
w.TemplateName == w2.TemplateName && w.AutostartSchedule == w2.AutostartSchedule
|
||||
}
|
||||
|
||||
func WorkspaceIdentityFromWorkspace(w Workspace) WorkspaceIdentity {
|
||||
return WorkspaceIdentity{
|
||||
ID: w.ID,
|
||||
OwnerID: w.OwnerID,
|
||||
OrganizationID: w.OrganizationID,
|
||||
TemplateID: w.TemplateID,
|
||||
Name: w.Name,
|
||||
OwnerUsername: w.OwnerUsername,
|
||||
TemplateName: w.TemplateName,
|
||||
AutostartSchedule: w.AutostartSchedule,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Optional:
|
||||
UpdateAgentMetricsFn: api.UpdateAgentMetrics,
|
||||
})
|
||||
}, workspace)
|
||||
|
||||
streamID := tailnet.StreamID{
|
||||
Name: fmt.Sprintf("%s-%s-%s", workspace.OwnerUsername, workspace.Name, workspaceAgent.Name),
|
||||
|
||||
@@ -1717,13 +1717,13 @@ func (api *API) postWorkspaceUsage(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
template, err := api.Database.GetTemplateByID(ctx, workspace.TemplateID)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
// template, err := api.Database.GetTemplateByID(ctx, workspace.TemplateID)
|
||||
// if err != nil {
|
||||
// httpapi.InternalServerError(rw, err)
|
||||
// return
|
||||
// }
|
||||
|
||||
err = api.statsReporter.ReportAgentStats(ctx, dbtime.Now(), workspace, agent, template.Name, stat, true)
|
||||
err = api.statsReporter.ReportAgentStats(ctx, dbtime.Now(), database.WorkspaceIdentityFromWorkspace(workspace), agent, stat, true)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
|
||||
@@ -120,7 +120,7 @@ func (r *Reporter) ReportAppStats(ctx context.Context, stats []workspaceapps.Sta
|
||||
}
|
||||
|
||||
// nolint:revive // usage is a control flag while we have the experiment
|
||||
func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspace database.Workspace, workspaceAgent database.WorkspaceAgent, templateName string, stats *agentproto.Stats, usage bool) error {
|
||||
func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspace database.WorkspaceIdentity, workspaceAgent database.WorkspaceAgent, stats *agentproto.Stats, usage bool) error {
|
||||
// update agent stats
|
||||
r.opts.StatsBatcher.Add(now, workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, stats, usage)
|
||||
|
||||
@@ -130,7 +130,7 @@ func (r *Reporter) ReportAgentStats(ctx context.Context, now time.Time, workspac
|
||||
Username: workspace.OwnerUsername,
|
||||
WorkspaceName: workspace.Name,
|
||||
AgentName: workspaceAgent.Name,
|
||||
TemplateName: templateName,
|
||||
TemplateName: workspace.TemplateName,
|
||||
}, stats.Metrics)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user