chore: ensure proper rbac permissions on 'Acquire' file in the cache (#18348)

The file cache was caching the `Unauthorized` errors if a user without
the right perms opened the file first. So all future opens would fail.

Now the cache always opens with a subject that can read files. And authz
is checked on the Acquire per user.
This commit is contained in:
Steven Masley
2025-06-16 08:40:45 -05:00
committed by GitHub
parent d83706bd5b
commit 1d1070d051
16 changed files with 218 additions and 51 deletions
+4 -4
View File
@@ -19,7 +19,7 @@ import (
// objects that the user is authorized to perform the given action on.
// This is faster than calling Authorize() on each object.
func AuthorizeFilter[O rbac.Objecter](h *HTTPAuthorizer, r *http.Request, action policy.Action, objects []O) ([]O, error) {
roles := httpmw.UserAuthorization(r)
roles := httpmw.UserAuthorization(r.Context())
objects, err := rbac.Filter(r.Context(), h.Authorizer, roles, action, objects)
if err != nil {
// Log the error as Filter should not be erroring.
@@ -65,7 +65,7 @@ func (api *API) Authorize(r *http.Request, action policy.Action, object rbac.Obj
// return
// }
func (h *HTTPAuthorizer) Authorize(r *http.Request, action policy.Action, object rbac.Objecter) bool {
roles := httpmw.UserAuthorization(r)
roles := httpmw.UserAuthorization(r.Context())
err := h.Authorizer.Authorize(r.Context(), roles, action, object.RBACObject())
if err != nil {
// Log the errors for debugging
@@ -97,7 +97,7 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action policy.Action, object
// call 'Authorize()' on the returned objects.
// Note the authorization is only for the given action and object type.
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action policy.Action, objectType string) (rbac.PreparedAuthorized, error) {
roles := httpmw.UserAuthorization(r)
roles := httpmw.UserAuthorization(r.Context())
prepared, err := h.Authorizer.Prepare(r.Context(), roles, action, objectType)
if err != nil {
return nil, xerrors.Errorf("prepare filter: %w", err)
@@ -120,7 +120,7 @@ func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action policy.Actio
// @Router /authcheck [post]
func (api *API) checkAuthorization(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
auth := httpmw.UserAuthorization(r)
auth := httpmw.UserAuthorization(r.Context())
var params codersdk.AuthorizationRequest
if !httpapi.Read(ctx, rw, r, &params) {
+1 -1
View File
@@ -572,7 +572,7 @@ func New(options *Options) *API {
TemplateScheduleStore: options.TemplateScheduleStore,
UserQuietHoursScheduleStore: options.UserQuietHoursScheduleStore,
AccessControlStore: options.AccessControlStore,
FileCache: files.NewFromStore(options.Database, options.PrometheusRegistry),
FileCache: files.NewFromStore(options.Database, options.PrometheusRegistry, options.Authorizer),
Experiments: experiments,
WebpushDispatcher: options.WebPushDispatcher,
healthCheckGroup: &singleflight.Group[string, *healthsdk.HealthcheckReport]{},
+5 -1
View File
@@ -234,6 +234,10 @@ func (r *RecordingAuthorizer) AssertOutOfOrder(t *testing.T, actor rbac.Subject,
// AssertActor asserts in order. If the order of authz calls does not match,
// this will fail.
func (r *RecordingAuthorizer) AssertActor(t *testing.T, actor rbac.Subject, did ...ActionObjectPair) {
r.AssertActorID(t, actor.ID, did...)
}
func (r *RecordingAuthorizer) AssertActorID(t *testing.T, id string, did ...ActionObjectPair) {
r.Lock()
defer r.Unlock()
ptr := 0
@@ -242,7 +246,7 @@ func (r *RecordingAuthorizer) AssertActor(t *testing.T, actor rbac.Subject, did
// Finished all assertions
return
}
if call.Actor.ID == actor.ID {
if call.Actor.ID == id {
action, object := did[ptr].Action, did[ptr].Object
assert.Equalf(t, action, call.Action, "assert action %d", ptr)
assert.Equalf(t, object, call.Object, "assert object %d", ptr)
+23
View File
@@ -432,6 +432,25 @@ var (
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectFileReader = rbac.Subject{
Type: rbac.SubjectTypeFileReader,
FriendlyName: "Can Read All Files",
// Arbitrary uuid to have a unique ID for this subject.
ID: rbac.SubjectTypeFileReaderID,
Roles: rbac.Roles([]rbac.Role{
{
Identifier: rbac.RoleIdentifier{Name: "file-reader"},
DisplayName: "FileReader",
Site: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceFile.Type: {policy.ActionRead},
}),
Org: map[string][]rbac.Permission{},
User: []rbac.Permission{},
},
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
)
// AsProvisionerd returns a context with an actor that has permissions required
@@ -498,6 +517,10 @@ func AsPrebuildsOrchestrator(ctx context.Context) context.Context {
return As(ctx, subjectPrebuildsOrchestrator)
}
func AsFileReader(ctx context.Context) context.Context {
return As(ctx, subjectFileReader)
}
var AsRemoveActor = rbac.Subject{
ID: "remove-actor",
}
+39 -18
View File
@@ -13,33 +13,41 @@ import (
archivefs "github.com/coder/coder/v2/archive/fs"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/util/lazy"
)
// NewFromStore returns a file cache that will fetch files from the provided
// database.
func NewFromStore(store database.Store, registerer prometheus.Registerer) *Cache {
fetch := func(ctx context.Context, fileID uuid.UUID) (cacheEntryValue, error) {
file, err := store.GetFileByID(ctx, fileID)
func NewFromStore(store database.Store, registerer prometheus.Registerer, authz rbac.Authorizer) *Cache {
fetch := func(ctx context.Context, fileID uuid.UUID) (CacheEntryValue, error) {
// Make sure the read does not fail due to authorization issues.
// Authz is checked on the Acquire call, so this is safe.
//nolint:gocritic
file, err := store.GetFileByID(dbauthz.AsFileReader(ctx), fileID)
if err != nil {
return cacheEntryValue{}, xerrors.Errorf("failed to read file from database: %w", err)
return CacheEntryValue{}, xerrors.Errorf("failed to read file from database: %w", err)
}
content := bytes.NewBuffer(file.Data)
return cacheEntryValue{
FS: archivefs.FromTarReader(content),
size: int64(content.Len()),
return CacheEntryValue{
Object: file.RBACObject(),
FS: archivefs.FromTarReader(content),
Size: int64(content.Len()),
}, nil
}
return New(fetch, registerer)
return New(fetch, registerer, authz)
}
func New(fetch fetcher, registerer prometheus.Registerer) *Cache {
func New(fetch fetcher, registerer prometheus.Registerer, authz rbac.Authorizer) *Cache {
return (&Cache{
lock: sync.Mutex{},
data: make(map[uuid.UUID]*cacheEntry),
fetcher: fetch,
authz: authz,
}).registerMetrics(registerer)
}
@@ -101,6 +109,7 @@ type Cache struct {
lock sync.Mutex
data map[uuid.UUID]*cacheEntry
fetcher
authz rbac.Authorizer
// metrics
cacheMetrics
@@ -117,18 +126,19 @@ type cacheMetrics struct {
totalCacheSize prometheus.Counter
}
type cacheEntryValue struct {
type CacheEntryValue struct {
fs.FS
size int64
Object rbac.Object
Size int64
}
type cacheEntry struct {
// refCount must only be accessed while the Cache lock is held.
refCount int
value *lazy.ValueWithError[cacheEntryValue]
value *lazy.ValueWithError[CacheEntryValue]
}
type fetcher func(context.Context, uuid.UUID) (cacheEntryValue, error)
type fetcher func(context.Context, uuid.UUID) (CacheEntryValue, error)
// Acquire will load the fs.FS for the given file. It guarantees that parallel
// calls for the same fileID will only result in one fetch, and that parallel
@@ -146,22 +156,33 @@ func (c *Cache) Acquire(ctx context.Context, fileID uuid.UUID) (fs.FS, error) {
c.Release(fileID)
return nil, err
}
subject, ok := dbauthz.ActorFromContext(ctx)
if !ok {
return nil, dbauthz.ErrNoActor
}
// Always check the caller can actually read the file.
if err := c.authz.Authorize(ctx, subject, policy.ActionRead, it.Object); err != nil {
c.Release(fileID)
return nil, err
}
return it.FS, err
}
func (c *Cache) prepare(ctx context.Context, fileID uuid.UUID) *lazy.ValueWithError[cacheEntryValue] {
func (c *Cache) prepare(ctx context.Context, fileID uuid.UUID) *lazy.ValueWithError[CacheEntryValue] {
c.lock.Lock()
defer c.lock.Unlock()
entry, ok := c.data[fileID]
if !ok {
value := lazy.NewWithError(func() (cacheEntryValue, error) {
value := lazy.NewWithError(func() (CacheEntryValue, error) {
val, err := c.fetcher(ctx, fileID)
// Always add to the cache size the bytes of the file loaded.
if err == nil {
c.currentCacheSize.Add(float64(val.size))
c.totalCacheSize.Add(float64(val.size))
c.currentCacheSize.Add(float64(val.Size))
c.totalCacheSize.Add(float64(val.Size))
}
return val, err
@@ -206,7 +227,7 @@ func (c *Cache) Release(fileID uuid.UUID) {
ev, err := entry.value.Load()
if err == nil {
c.currentCacheSize.Add(-1 * float64(ev.size))
c.currentCacheSize.Add(-1 * float64(ev.Size))
}
delete(c.data, fileID)
@@ -1,4 +1,4 @@
package files
package files_test
import (
"context"
@@ -12,28 +12,114 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/promhelp"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/files"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/testutil"
)
// nolint:paralleltest,tparallel // Serially testing is easier
func TestCacheRBAC(t *testing.T) {
t.Parallel()
db, cache, rec := cacheAuthzSetup(t)
ctx := testutil.Context(t, testutil.WaitMedium)
file := dbgen.File(t, db, database.File{})
nobodyID := uuid.New()
nobody := dbauthz.As(ctx, rbac.Subject{
ID: nobodyID.String(),
Roles: rbac.Roles{},
Scope: rbac.ScopeAll,
})
userID := uuid.New()
userReader := dbauthz.As(ctx, rbac.Subject{
ID: userID.String(),
Roles: rbac.Roles{
must(rbac.RoleByName(rbac.RoleTemplateAdmin())),
},
Scope: rbac.ScopeAll,
})
//nolint:gocritic // Unit testing
cacheReader := dbauthz.AsFileReader(ctx)
t.Run("NoRolesOpen", func(t *testing.T) {
// Ensure start is clean
require.Equal(t, 0, cache.Count())
rec.Reset()
_, err := cache.Acquire(nobody, file.ID)
require.Error(t, err)
require.True(t, rbac.IsUnauthorizedError(err))
// Ensure that the cache is empty
require.Equal(t, 0, cache.Count())
// Check the assertions
rec.AssertActorID(t, nobodyID.String(), rec.Pair(policy.ActionRead, file))
rec.AssertActorID(t, rbac.SubjectTypeFileReaderID, rec.Pair(policy.ActionRead, file))
})
t.Run("CacheHasFile", func(t *testing.T) {
rec.Reset()
require.Equal(t, 0, cache.Count())
// Read the file with a file reader to put it into the cache.
_, err := cache.Acquire(cacheReader, file.ID)
require.NoError(t, err)
require.Equal(t, 1, cache.Count())
// "nobody" should not be able to read the file.
_, err = cache.Acquire(nobody, file.ID)
require.Error(t, err)
require.True(t, rbac.IsUnauthorizedError(err))
require.Equal(t, 1, cache.Count())
// UserReader can
_, err = cache.Acquire(userReader, file.ID)
require.NoError(t, err)
require.Equal(t, 1, cache.Count())
cache.Release(file.ID)
cache.Release(file.ID)
require.Equal(t, 0, cache.Count())
rec.AssertActorID(t, nobodyID.String(), rec.Pair(policy.ActionRead, file))
rec.AssertActorID(t, rbac.SubjectTypeFileReaderID, rec.Pair(policy.ActionRead, file))
rec.AssertActorID(t, userID.String(), rec.Pair(policy.ActionRead, file))
})
}
func cachePromMetricName(metric string) string {
return "coderd_file_cache_" + metric
}
func TestConcurrency(t *testing.T) {
t.Parallel()
//nolint:gocritic // Unit testing
ctx := dbauthz.AsFileReader(t.Context())
const fileSize = 10
emptyFS := afero.NewIOFS(afero.NewReadOnlyFs(afero.NewMemMapFs()))
var fetches atomic.Int64
reg := prometheus.NewRegistry()
c := New(func(_ context.Context, _ uuid.UUID) (cacheEntryValue, error) {
c := files.New(func(_ context.Context, _ uuid.UUID) (files.CacheEntryValue, error) {
fetches.Add(1)
// Wait long enough before returning to make sure that all of the goroutines
// will be waiting in line, ensuring that no one duplicated a fetch.
time.Sleep(testutil.IntervalMedium)
return cacheEntryValue{FS: emptyFS, size: fileSize}, nil
}, reg)
return files.CacheEntryValue{FS: emptyFS, Size: fileSize}, nil
}, reg, &coderdtest.FakeAuthorizer{})
batches := 1000
groups := make([]*errgroup.Group, 0, batches)
@@ -51,7 +137,7 @@ func TestConcurrency(t *testing.T) {
g.Go(func() error {
// We don't bother to Release these references because the Cache will be
// released at the end of the test anyway.
_, err := c.Acquire(t.Context(), id)
_, err := c.Acquire(ctx, id)
return err
})
}
@@ -74,16 +160,18 @@ func TestConcurrency(t *testing.T) {
func TestRelease(t *testing.T) {
t.Parallel()
//nolint:gocritic // Unit testing
ctx := dbauthz.AsFileReader(t.Context())
const fileSize = 10
emptyFS := afero.NewIOFS(afero.NewReadOnlyFs(afero.NewMemMapFs()))
reg := prometheus.NewRegistry()
c := New(func(_ context.Context, _ uuid.UUID) (cacheEntryValue, error) {
return cacheEntryValue{
c := files.New(func(_ context.Context, _ uuid.UUID) (files.CacheEntryValue, error) {
return files.CacheEntryValue{
FS: emptyFS,
size: fileSize,
Size: fileSize,
}, nil
}, reg)
}, reg, &coderdtest.FakeAuthorizer{})
batches := 100
ids := make([]uuid.UUID, 0, batches)
@@ -95,7 +183,7 @@ func TestRelease(t *testing.T) {
batchSize := 10
for openedIdx, id := range ids {
for batchIdx := range batchSize {
it, err := c.Acquire(t.Context(), id)
it, err := c.Acquire(ctx, id)
require.NoError(t, err)
require.Equal(t, emptyFS, it)
@@ -112,7 +200,7 @@ func TestRelease(t *testing.T) {
}
// Make sure cache is fully loaded
require.Equal(t, len(c.data), batches)
require.Equal(t, c.Count(), batches)
// Now release all of the references
for closedIdx, id := range ids {
@@ -136,7 +224,7 @@ func TestRelease(t *testing.T) {
}
// ...and make sure that the cache has emptied itself.
require.Equal(t, len(c.data), 0)
require.Equal(t, c.Count(), 0)
// Verify all the counts & metrics are correct.
// All existing files are closed
@@ -150,3 +238,29 @@ func TestRelease(t *testing.T) {
require.Equal(t, batches, promhelp.CounterValue(t, reg, cachePromMetricName("open_files_total"), nil))
require.Equal(t, batches*batchSize, promhelp.CounterValue(t, reg, cachePromMetricName("open_file_refs_total"), nil))
}
func cacheAuthzSetup(t *testing.T) (database.Store, *files.Cache, *coderdtest.RecordingAuthorizer) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{})
reg := prometheus.NewRegistry()
db, _ := dbtestutil.NewDB(t)
authz := rbac.NewAuthorizer(reg)
rec := &coderdtest.RecordingAuthorizer{
Called: nil,
Wrapped: authz,
}
// Dbauthz wrap the db
db = dbauthz.New(db, rec, logger, coderdtest.AccessControlStorePointer())
c := files.NewFromStore(db, reg, rec)
return db, c, rec
}
func must[T any](t T, err error) T {
if err != nil {
panic(err)
}
return t
}
+4 -4
View File
@@ -47,14 +47,14 @@ func APIKey(r *http.Request) database.APIKey {
// UserAuthorizationOptional may return the roles and scope used for
// authorization. Depends on the ExtractAPIKey handler.
func UserAuthorizationOptional(r *http.Request) (rbac.Subject, bool) {
return dbauthz.ActorFromContext(r.Context())
func UserAuthorizationOptional(ctx context.Context) (rbac.Subject, bool) {
return dbauthz.ActorFromContext(ctx)
}
// UserAuthorization returns the roles and scope used for authorization. Depends
// on the ExtractAPIKey handler.
func UserAuthorization(r *http.Request) rbac.Subject {
auth, ok := UserAuthorizationOptional(r)
func UserAuthorization(ctx context.Context) rbac.Subject {
auth, ok := UserAuthorizationOptional(ctx)
if !ok {
panic("developer error: ExtractAPIKey middleware not provided")
}
+3 -3
View File
@@ -58,7 +58,7 @@ func TestAPIKey(t *testing.T) {
assert.NoError(t, err, "actor rego ok")
}
auth, ok := httpmw.UserAuthorizationOptional(r)
auth, ok := httpmw.UserAuthorizationOptional(r.Context())
assert.True(t, ok, "httpmw auth ok")
if ok {
_, err := auth.Roles.Expand()
@@ -904,7 +904,7 @@ func TestAPIKey(t *testing.T) {
})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
assertActorOk(t, r)
auth := httpmw.UserAuthorization(r)
auth := httpmw.UserAuthorization(r.Context())
roles, err := auth.Roles.Expand()
assert.NoError(t, err, "expand user roles")
@@ -968,7 +968,7 @@ func TestAPIKey(t *testing.T) {
RedirectToLogin: false,
})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
assertActorOk(t, r)
auth := httpmw.UserAuthorization(r)
auth := httpmw.UserAuthorization(r.Context())
roles, err := auth.Roles.Expand()
assert.NoError(t, err, "expand user roles")
+1 -1
View File
@@ -125,7 +125,7 @@ func TestExtractUserRoles(t *testing.T) {
}),
)
rtr.Get("/", func(_ http.ResponseWriter, r *http.Request) {
roles := httpmw.UserAuthorization(r)
roles := httpmw.UserAuthorization(r.Context())
require.Equal(t, user.ID.String(), roles.ID)
require.ElementsMatch(t, expRoles, roles.Roles.Names())
})
+1 -1
View File
@@ -43,7 +43,7 @@ func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler
// Allow Owner to bypass rate limiting for load tests
// and automation.
auth := UserAuthorization(r)
auth := UserAuthorization(r.Context())
// We avoid using rbac.Authorizer since rego is CPU-intensive
// and undermines the DoS-prevention goal of the rate limiter.
+1 -1
View File
@@ -36,7 +36,7 @@ func authorizeMW(accessURL *url.URL) func(next http.Handler) http.Handler {
}
app := httpmw.OAuth2ProviderApp(r)
ua := httpmw.UserAuthorization(r)
ua := httpmw.UserAuthorization(r.Context())
// url.Parse() allows empty URLs, which is fine because the origin is not
// always set by browsers (or other tools like cURL). If the origin does
+1 -1
View File
@@ -133,7 +133,7 @@ func (api *API) handleDynamicParameters(listen bool, rw http.ResponseWriter, r *
// nolint:gocritic // We need to fetch the templates files for the Terraform
// evaluator, and the user likely does not have permission.
fileCtx := dbauthz.AsProvisionerd(ctx)
fileCtx := dbauthz.AsFileReader(ctx)
fileID, err := api.Database.GetFileIDByTemplateVersionID(fileCtx, templateVersion.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
+5
View File
@@ -74,6 +74,11 @@ const (
SubjectTypeSystemRestricted SubjectType = "system_restricted"
SubjectTypeNotifier SubjectType = "notifier"
SubjectTypeSubAgentAPI SubjectType = "sub_agent_api"
SubjectTypeFileReader SubjectType = "file_reader"
)
const (
SubjectTypeFileReaderID = "acbf0be6-6fed-47b6-8c43-962cb5cab994"
)
// Subject is a struct that contains all the elements of a subject in an rbac
+2 -2
View File
@@ -26,7 +26,7 @@ import (
// @Router /users/roles [get]
func (api *API) AssignableSiteRoles(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
actorRoles := httpmw.UserAuthorization(r)
actorRoles := httpmw.UserAuthorization(r.Context())
if !api.Authorize(r, policy.ActionRead, rbac.ResourceAssignRole) {
httpapi.Forbidden(rw)
return
@@ -59,7 +59,7 @@ func (api *API) AssignableSiteRoles(rw http.ResponseWriter, r *http.Request) {
func (api *API) assignableOrgRoles(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
organization := httpmw.OrganizationParam(r)
actorRoles := httpmw.UserAuthorization(r)
actorRoles := httpmw.UserAuthorization(r.Context())
if !api.Authorize(r, policy.ActionRead, rbac.ResourceAssignOrgRole.InOrg(organization.ID)) {
httpapi.ResourceNotFound(rw)
+1 -1
View File
@@ -525,7 +525,7 @@ func (api *API) deleteUser(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
auditor := *api.Auditor.Load()
user := httpmw.UserParam(r)
auth := httpmw.UserAuthorization(r)
auth := httpmw.UserAuthorization(r.Context())
aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{
Audit: auditor,
Log: api.Logger,
+1 -1
View File
@@ -133,7 +133,7 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, org database.Organiza
tags: tags,
}, nil
}
ua := httpmw.UserAuthorization(r)
ua := httpmw.UserAuthorization(r.Context())
err = p.authorizer.Authorize(ctx, ua, policy.ActionCreate, rbac.ResourceProvisionerDaemon.InOrg(org.ID))
if err != nil {
return provisiionerDaemonAuthResponse{}, xerrors.New("user unauthorized")