diff --git a/coderd/agentapi/cached_workspace.go b/coderd/agentapi/cached_workspace.go index 7c1bc0ff63..cb2ab19990 100644 --- a/coderd/agentapi/cached_workspace.go +++ b/coderd/agentapi/cached_workspace.go @@ -1,9 +1,13 @@ package agentapi import ( + "context" "sync" + "golang.org/x/xerrors" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" ) // CachedWorkspaceFields contains workspace data that is safe to cache for the @@ -50,3 +54,19 @@ func (cws *CachedWorkspaceFields) AsWorkspaceIdentity() (database.WorkspaceIdent } return cws.identity, true } + +// ContextInject attempts to inject the rbac object for the cached workspace fields +// into the given context, either returning the wrapped context or the original. +func (cws *CachedWorkspaceFields) ContextInject(ctx context.Context) (context.Context, error) { + var err error + rbacCtx := ctx + if dbws, ok := cws.AsWorkspaceIdentity(); ok { + rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject()) + if err != nil { + // Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID. + //nolint:gocritic + return ctx, xerrors.Errorf("Cached workspace was present but RBAC object was invalid: %w", err) + } + } + return rbacCtx, nil +} diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index ec56e244a0..3f8138fb10 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2455,6 +2455,18 @@ func (q *querier) GetLatestWorkspaceAppStatusesByWorkspaceIDs(ctx context.Contex } func (q *querier) GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { + // Fast path: Check if we have a workspace RBAC object in context. + if rbacObj, ok := WorkspaceRBACFromContext(ctx); ok { + // Errors here will result in falling back to GetWorkspaceByAgentID, + // in case the cached data is stale. + if err := q.authorizeContext(ctx, policy.ActionRead, rbacObj); err == nil { + return q.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID) + } + + q.log.Debug(ctx, "fast path authorization failed for GetLatestWorkspaceBuildByWorkspaceID, using slow path", + slog.F("workspace_id", workspaceID)) + } + if _, err := q.GetWorkspaceByID(ctx, workspaceID); err != nil { return database.WorkspaceBuild{}, err } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 0e1e684587..b3a3acb890 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -4731,3 +4731,77 @@ func (s *MethodTestSuite) TestTelemetry() { check.Args(database.CalculateAIBridgeInterceptionsTelemetrySummaryParams{}).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead) })) } + +func TestGetLatestWorkspaceBuildByWorkspaceID_FastPath(t *testing.T) { + t.Parallel() + + ownerID := uuid.New() + wsID := uuid.New() + orgID := uuid.New() + + workspace := database.Workspace{ + ID: wsID, + OwnerID: ownerID, + OrganizationID: orgID, + } + + build := database.WorkspaceBuild{ + ID: uuid.New(), + WorkspaceID: wsID, + } + + wsIdentity := database.WorkspaceIdentity{ + ID: wsID, + OwnerID: ownerID, + OrganizationID: orgID, + } + + actor := rbac.Subject{ + ID: ownerID.String(), + Roles: rbac.RoleIdentifiers{rbac.RoleOwner()}, + Groups: []string{orgID.String()}, + Scope: rbac.ScopeAll, + } + + authorizer := &coderdtest.RecordingAuthorizer{ + Wrapped: (&coderdtest.FakeAuthorizer{}).AlwaysReturn(nil), + } + + t.Run("WithWorkspaceRBAC", func(t *testing.T) { + t.Parallel() + + ctx := dbauthz.As(context.Background(), actor) + ctrl := gomock.NewController(t) + dbm := dbmock.NewMockStore(ctrl) + + rbacObj := wsIdentity.RBACObject() + ctx, err := dbauthz.WithWorkspaceRBAC(ctx, rbacObj) + require.NoError(t, err) + + dbm.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspace.ID).Return(build, nil).AnyTimes() + dbm.EXPECT().Wrappers().Return([]string{}) + + q := dbauthz.New(dbm, authorizer, slogtest.Make(t, nil), coderdtest.AccessControlStorePointer()) + + result, err := q.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID) + require.NoError(t, err) + require.Equal(t, build, result) + }) + t.Run("WithoutWorkspaceRBAC", func(t *testing.T) { + t.Parallel() + + ctx := dbauthz.As(context.Background(), actor) + ctrl := gomock.NewController(t) + dbm := dbmock.NewMockStore(ctrl) + + dbm.EXPECT().GetWorkspaceByID(gomock.Any(), wsID).Return(workspace, nil).AnyTimes() + dbm.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspace.ID).Return(build, nil).AnyTimes() + dbm.EXPECT().Wrappers().Return([]string{}) + + q := dbauthz.New(dbm, authorizer, slogtest.Make(t, nil), coderdtest.AccessControlStorePointer()) + + result, err := q.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID) + require.NoError(t, err) + require.Equal(t, build, result) + }) +} diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index 50a14768c1..37d5e6d3b7 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -227,10 +227,11 @@ func (api *API) startAgentYamuxMonitor(ctx context.Context, mux *yamux.Session, ) *agentConnectionMonitor { monitor := &agentConnectionMonitor{ - apiCtx: api.ctx, - workspace: workspace, - workspaceAgent: workspaceAgent, - workspaceBuild: workspaceBuild, + apiCtx: api.ctx, + workspace: workspace, + workspaceAgent: workspaceAgent, + workspaceBuild: workspaceBuild, + conn: &yamuxPingerCloser{mux: mux}, pingPeriod: api.AgentConnectionUpdateFrequency, db: api.Database, @@ -453,6 +454,13 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) { AgentID: &m.workspaceAgent.ID, }) } + + ctx, err := dbauthz.WithWorkspaceRBAC(ctx, m.workspace.RBACObject()) + if err != nil { + // Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID. + //nolint:gocritic + m.logger.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err)) + } err = checkBuildIsLatest(ctx, m.db, m.workspaceBuild) if err != nil { reason = err.Error()