test: start migrating dbauthz tests to mocked db (#19257)

This PR adds a framework to move to a mocked db. And therefore massively speed up these tests.
This commit is contained in:
Steven Masley
2025-08-08 13:46:24 -05:00
committed by GitHub
parent 155c7bbc65
commit ce935657f6
6 changed files with 186 additions and 17 deletions
+15 -13
View File
@@ -11,9 +11,11 @@ import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v7"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"cdr.dev/slog"
@@ -22,6 +24,7 @@ import (
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/notifications"
@@ -204,14 +207,15 @@ func defaultIPAddress() pqtype.Inet {
}
func (s *MethodTestSuite) TestAPIKey() {
s.Run("DeleteAPIKeyByID", s.Subtest(func(db database.Store, check *expects) {
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
key, _ := dbgen.APIKey(s.T(), db, database.APIKey{})
s.Run("DeleteAPIKeyByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
key := testutil.Fake(s.T(), faker, database.APIKey{})
dbm.EXPECT().GetAPIKeyByID(gomock.Any(), key.ID).Return(key, nil).AnyTimes()
dbm.EXPECT().DeleteAPIKeyByID(gomock.Any(), key.ID).Return(nil).AnyTimes()
check.Args(key.ID).Asserts(key, policy.ActionDelete).Returns()
}))
s.Run("GetAPIKeyByID", s.Subtest(func(db database.Store, check *expects) {
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
key, _ := dbgen.APIKey(s.T(), db, database.APIKey{})
s.Run("GetAPIKeyByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
key := testutil.Fake(s.T(), faker, database.APIKey{})
dbm.EXPECT().GetAPIKeyByID(gomock.Any(), key.ID).Return(key, nil).AnyTimes()
check.Args(key.ID).Asserts(key, policy.ActionRead).Returns(key)
}))
s.Run("GetAPIKeyByName", s.Subtest(func(db database.Store, check *expects) {
@@ -234,14 +238,12 @@ func (s *MethodTestSuite) TestAPIKey() {
Asserts(a, policy.ActionRead, b, policy.ActionRead).
Returns(slice.New(a, b))
}))
s.Run("GetAPIKeysByUserID", s.Subtest(func(db database.Store, check *expects) {
u1 := dbgen.User(s.T(), db, database.User{})
u2 := dbgen.User(s.T(), db, database.User{})
keyA, _ := dbgen.APIKey(s.T(), db, database.APIKey{UserID: u1.ID, LoginType: database.LoginTypeToken, TokenName: "key-a"})
keyB, _ := dbgen.APIKey(s.T(), db, database.APIKey{UserID: u1.ID, LoginType: database.LoginTypeToken, TokenName: "key-b"})
_, _ = dbgen.APIKey(s.T(), db, database.APIKey{UserID: u2.ID, LoginType: database.LoginTypeToken})
s.Run("GetAPIKeysByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u1 := testutil.Fake(s.T(), faker, database.User{})
keyA := testutil.Fake(s.T(), faker, database.APIKey{UserID: u1.ID, LoginType: database.LoginTypeToken, TokenName: "key-a"})
keyB := testutil.Fake(s.T(), faker, database.APIKey{UserID: u1.ID, LoginType: database.LoginTypeToken, TokenName: "key-b"})
dbm.EXPECT().GetAPIKeysByUserID(gomock.Any(), gomock.Any()).Return(slice.New(keyA, keyB), nil).AnyTimes()
check.Args(database.GetAPIKeysByUserIDParams{LoginType: database.LoginTypeToken, UserID: u1.ID}).
Asserts(keyA, policy.ActionRead, keyB, policy.ActionRead).
Returns(slice.New(keyA, keyB))
+30 -4
View File
@@ -10,6 +10,7 @@ import (
"strings"
"testing"
"github.com/brianvoe/gofakeit/v7"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
@@ -20,7 +21,6 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
@@ -28,6 +28,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/rbac/regosql"
"github.com/coder/coder/v2/coderd/util/slice"
)
@@ -105,11 +106,37 @@ func (s *MethodTestSuite) TearDownSuite() {
var testActorID = uuid.New()
// Subtest is a helper function that returns a function that can be passed to
// Mocked runs a subtest with a mocked database. Removing the overhead of a real
// postgres database resulting in much faster tests.
func (s *MethodTestSuite) Mocked(testCaseF func(dmb *dbmock.MockStore, faker *gofakeit.Faker, check *expects)) func() {
t := s.T()
mDB := dbmock.NewMockStore(gomock.NewController(t))
mDB.EXPECT().Wrappers().Return([]string{}).AnyTimes()
// Use a constant seed to prevent flakes from random data generation.
faker := gofakeit.New(0)
// The usual Subtest assumes the test setup will use a real database to populate
// with data. In this mocked case, we want to pass the underlying mocked database
// to the test case instead.
return s.SubtestWithDB(mDB, func(_ database.Store, check *expects) {
testCaseF(mDB, faker, check)
})
}
// Subtest starts up a real postgres database for each test case.
// Deprecated: Use 'Mocked' instead for much faster tests.
func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expects)) func() {
t := s.T()
db, _ := dbtestutil.NewDB(t)
return s.SubtestWithDB(db, testCaseF)
}
// SubtestWithDB is a helper function that returns a function that can be passed to
// s.Run(). This function will run the test case for the method that is being
// tested. The check parameter is used to assert the results of the method.
// If the caller does not use the `check` parameter, the test will fail.
func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expects)) func() {
func (s *MethodTestSuite) SubtestWithDB(db database.Store, testCaseF func(db database.Store, check *expects)) func() {
return func() {
t := s.T()
testName := s.T().Name()
@@ -117,7 +144,6 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec
methodName := names[len(names)-1]
s.methodAccounting[methodName]++
db, _ := dbtestutil.NewDB(t)
fakeAuthorizer := &coderdtest.FakeAuthorizer{}
rec := &coderdtest.RecordingAuthorizer{
Wrapped: fakeAuthorizer,
+1
View File
@@ -477,6 +477,7 @@ require (
)
require (
github.com/brianvoe/gofakeit/v7 v7.3.0
github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225
github.com/coder/aisdk-go v0.0.9
github.com/coder/preview v1.0.3
+2
View File
@@ -830,6 +830,8 @@ github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl
github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bramvdbogaerde/go-scp v1.5.0 h1:a9BinAjTfQh273eh7vd3qUgmBC+bx+3TRDtkZWmIpzM=
github.com/bramvdbogaerde/go-scp v1.5.0/go.mod h1:on2aH5AxaFb2G0N5Vsdy6B0Ml7k9HuHSwfo1y0QzAbQ=
github.com/brianvoe/gofakeit/v7 v7.3.0 h1:TWStf7/lLpAjKw+bqwzeORo9jvrxToWEwp9b1J2vApQ=
github.com/brianvoe/gofakeit/v7 v7.3.0/go.mod h1:QXuPeBw164PJCzCUZVmgpgHJ3Llj49jSLVkKPMtxtxA=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/bytecodealliance/wasmtime-go/v3 v3.0.2 h1:3uZCA/BLTIu+DqCfguByNMJa2HVHpXvjfy0Dy7g6fuA=
+67
View File
@@ -0,0 +1,67 @@
package testutil
import (
"reflect"
"testing"
"github.com/brianvoe/gofakeit/v7"
"github.com/stretchr/testify/require"
)
// Fake will populate any zero fields in the provided struct with fake data.
// Non-zero fields will remain unchanged.
// Usage:
//
// key := Fake(t, faker, database.APIKey{
// TokenName: "keep-my-name",
// })
func Fake[T any](t *testing.T, faker *gofakeit.Faker, seed T) T {
t.Helper()
var tmp T
err := faker.Struct(&tmp)
require.NoError(t, err, "failed to generate fake data for type %T", tmp)
mergeZero(&seed, tmp)
return seed
}
// mergeZero merges the fields of src into dst, but only if the field in dst is
// currently the zero value.
// Make sure `dst` is a pointer to a struct, otherwise the fields are not assignable.
func mergeZero(dst any, src any) {
srcv := reflect.ValueOf(src)
if srcv.Kind() == reflect.Ptr {
srcv = srcv.Elem()
}
remain := [][2]reflect.Value{
{reflect.ValueOf(dst).Elem(), srcv},
}
// Traverse the struct fields and set them only if they are currently zero.
// This is a breadth-first traversal of the struct fields. Struct definitions
// Should not be that deep, so we should not hit any stack overflow issues.
for {
if len(remain) == 0 {
return
}
dv, sv := remain[0][0], remain[0][1]
remain = remain[1:] //
for i := 0; i < dv.NumField(); i++ {
df := dv.Field(i)
sf := sv.Field(i)
if !df.CanSet() {
continue
}
if df.IsZero() { // only write if currently zero
df.Set(sf)
continue
}
if dv.Field(i).Kind() == reflect.Struct {
// If the field is a struct, we need to traverse it as well.
remain = append(remain, [2]reflect.Value{df, sf})
}
}
}
}
+71
View File
@@ -0,0 +1,71 @@
package testutil_test
import (
"testing"
"github.com/brianvoe/gofakeit/v7"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/testutil"
)
type simpleStruct struct {
ID uuid.UUID
Name string
Description string
Age int `fake:"{number:18,60}"`
}
type nestedStruct struct {
Person simpleStruct
Address string
}
func TestFake(t *testing.T) {
t.Parallel()
t.Run("Simple", func(t *testing.T) {
t.Parallel()
faker := gofakeit.New(0)
person := testutil.Fake(t, faker, simpleStruct{
Name: "alice",
})
require.Equal(t, "alice", person.Name)
require.NotEqual(t, uuid.Nil, person.ID)
require.NotEmpty(t, person.Description)
require.Greater(t, person.Age, 17, "Age should be greater than 17")
require.Less(t, person.Age, 61, "Age should be less than 61")
})
t.Run("Nested", func(t *testing.T) {
t.Parallel()
faker := gofakeit.New(0)
person := testutil.Fake(t, faker, nestedStruct{
Person: simpleStruct{
Name: "alice",
},
})
require.Equal(t, "alice", person.Person.Name)
require.NotEqual(t, uuid.Nil, person.Person.ID)
require.NotEmpty(t, person.Person.Description)
require.Greater(t, person.Person.Age, 17, "Age should be greater than 17")
require.NotEmpty(t, person.Address)
})
t.Run("DatabaseType", func(t *testing.T) {
t.Parallel()
faker := gofakeit.New(0)
id := uuid.New()
key := testutil.Fake(t, faker, database.APIKey{
UserID: id,
TokenName: "keep-my-name",
})
require.Equal(t, id, key.UserID)
require.NotEmpty(t, key.TokenName)
})
}