diff --git a/coderd/database/dbtestutil/broker.go b/coderd/database/dbtestutil/broker.go new file mode 100644 index 0000000000..371a8dc9bb --- /dev/null +++ b/coderd/database/dbtestutil/broker.go @@ -0,0 +1,176 @@ +package dbtestutil + +import ( + "database/sql" + _ "embed" + "fmt" + "os" + "sync" + "time" + + "github.com/google/uuid" + "github.com/lib/pq" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/cryptorand" +) + +const CoderTestingDBName = "coder_testing" + +//go:embed coder_testing.sql +var coderTestingSQLInit string + +type Broker struct { + sync.Mutex + uuid uuid.UUID + coderTestingDB *sql.DB + refCount int +} + +func (b *Broker) Create(t TBSubset, opts ...OpenOption) (ConnectionParams, error) { + if err := b.init(t); err != nil { + return ConnectionParams{}, err + } + openOptions := OpenOptions{} + for _, opt := range opts { + opt(&openOptions) + } + + var ( + username = defaultConnectionParams.Username + password = defaultConnectionParams.Password + host = defaultConnectionParams.Host + port = defaultConnectionParams.Port + ) + + // Use a time-based prefix to make it easier to find the database + // when debugging. + now := time.Now().Format("test_2006_01_02_15_04_05") + dbSuffix, err := cryptorand.StringCharset(cryptorand.Lower, 10) + if err != nil { + return ConnectionParams{}, xerrors.Errorf("generate db suffix: %w", err) + } + dbName := now + "_" + dbSuffix + + // TODO: add package and test name + _, err = b.coderTestingDB.Exec( + "INSERT INTO test_databases (name, process_uuid) VALUES ($1, $2)", dbName, b.uuid) + if err != nil { + return ConnectionParams{}, xerrors.Errorf("insert test_database row: %w", err) + } + + // if empty createDatabaseFromTemplate will create a new template db + templateDBName := os.Getenv("DB_FROM") + if openOptions.DBFrom != nil { + templateDBName = *openOptions.DBFrom + } + if err = createDatabaseFromTemplate(t, defaultConnectionParams, b.coderTestingDB, dbName, templateDBName); err != nil { + return ConnectionParams{}, xerrors.Errorf("create database: %w", err) + } + + testDBParams := ConnectionParams{ + Username: username, + Password: password, + Host: host, + Port: port, + DBName: dbName, + } + + // Optionally log the DSN to help connect to the test database. + if openOptions.LogDSN { + _, _ = fmt.Fprintf(os.Stderr, "Connect to the database for %s using: psql '%s'\n", t.Name(), testDBParams.DSN()) + } + t.Cleanup(b.clean(t, dbName)) + return testDBParams, nil +} + +func (b *Broker) clean(t TBSubset, dbName string) func() { + return func() { + _, err := b.coderTestingDB.Exec("DROP DATABASE " + dbName + ";") + if err != nil { + t.Logf("failed to clean up database %q: %s\n", dbName, err.Error()) + return + } + _, err = b.coderTestingDB.Exec("UPDATE test_databases SET dropped_at = CURRENT_TIMESTAMP WHERE name = $1", dbName) + if err != nil { + t.Logf("failed to mark test database '%s' dropped: %s\n", dbName, err.Error()) + } + } +} + +func (b *Broker) init(t TBSubset) error { + b.Lock() + defer b.Unlock() + b.refCount++ + t.Cleanup(b.decRef) + if b.coderTestingDB != nil { + // already initialized + return nil + } + + connectionParamsInitOnce.Do(func() { + errDefaultConnectionParamsInit = initDefaultConnection(t) + }) + if errDefaultConnectionParamsInit != nil { + return xerrors.Errorf("init default connection params: %w", errDefaultConnectionParamsInit) + } + coderTestingParams := defaultConnectionParams + coderTestingParams.DBName = CoderTestingDBName + coderTestingDB, err := sql.Open("postgres", coderTestingParams.DSN()) + if err != nil { + return xerrors.Errorf("open postgres connection: %w", err) + } + + // creating the db can succeed even if the database doesn't exist. Ping it to find out. + err = coderTestingDB.Ping() + var pqErr *pq.Error + if xerrors.As(err, &pqErr) && pqErr.Code == "3D000" { + // database does not exist. + if closeErr := coderTestingDB.Close(); closeErr != nil { + return xerrors.Errorf("close postgres connection: %w", closeErr) + } + err = createCoderTestingDB(t) + if err != nil { + return xerrors.Errorf("create coder testing db: %w", err) + } + coderTestingDB, err = sql.Open("postgres", coderTestingParams.DSN()) + if err != nil { + return xerrors.Errorf("open postgres connection: %w", err) + } + } else if err != nil { + _ = coderTestingDB.Close() + return xerrors.Errorf("ping '%s' database: %w", CoderTestingDBName, err) + } + b.coderTestingDB = coderTestingDB + b.uuid = uuid.New() + return nil +} + +func createCoderTestingDB(t TBSubset) error { + db, err := sql.Open("postgres", defaultConnectionParams.DSN()) + if err != nil { + return xerrors.Errorf("open postgres connection: %w", err) + } + defer func() { + _ = db.Close() + }() + err = createAndInitDatabase(t, defaultConnectionParams, db, CoderTestingDBName, func(testDB *sql.DB) error { + _, err := testDB.Exec(coderTestingSQLInit) + return err + }) + if err != nil { + return xerrors.Errorf("create coder testing db: %w", err) + } + return nil +} + +func (b *Broker) decRef() { + b.Lock() + defer b.Unlock() + b.refCount-- + if b.refCount == 0 { + // ensures we don't leave go routines around for GoLeak to find. + _ = b.coderTestingDB.Close() + b.coderTestingDB = nil + } +} diff --git a/coderd/database/dbtestutil/coder_testing.sql b/coderd/database/dbtestutil/coder_testing.sql new file mode 100644 index 0000000000..edaab486c8 --- /dev/null +++ b/coderd/database/dbtestutil/coder_testing.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS test_databases ( + name text PRIMARY KEY, + created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP, + dropped_at timestamp with time zone, -- null means it hasn't been dropped + process_uuid uuid NOT NULL +); + +CREATE INDEX IF NOT EXISTS test_databases_process_uuid ON test_databases (process_uuid, dropped_at); diff --git a/coderd/database/dbtestutil/postgres.go b/coderd/database/dbtestutil/postgres.go index 1ab80569de..567fae0daf 100644 --- a/coderd/database/dbtestutil/postgres.go +++ b/coderd/database/dbtestutil/postgres.go @@ -22,7 +22,6 @@ import ( "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/database/migrations" - "github.com/coder/coder/v2/cryptorand" "github.com/coder/retry" ) @@ -52,6 +51,7 @@ var ( "connection refused", // nothing is listening on the port "No connection could be made", // Windows variant of the above } + DefaultBroker = Broker{} ) // initDefaultConnection initializes the default postgres connection parameters. @@ -173,101 +173,25 @@ type TBSubset interface { // Otherwise, it will start a new postgres container. func Open(t TBSubset, opts ...OpenOption) (string, error) { t.Helper() - - connectionParamsInitOnce.Do(func() { - errDefaultConnectionParamsInit = initDefaultConnection(t) - }) - if errDefaultConnectionParamsInit != nil { - return "", xerrors.Errorf("init default connection params: %w", errDefaultConnectionParamsInit) - } - - openOptions := OpenOptions{} - for _, opt := range opts { - opt(&openOptions) - } - - var ( - username = defaultConnectionParams.Username - password = defaultConnectionParams.Password - host = defaultConnectionParams.Host - port = defaultConnectionParams.Port - ) - - // Use a time-based prefix to make it easier to find the database - // when debugging. - now := time.Now().Format("test_2006_01_02_15_04_05") - dbSuffix, err := cryptorand.StringCharset(cryptorand.Lower, 10) + params, err := DefaultBroker.Create(t, opts...) if err != nil { - return "", xerrors.Errorf("generate db suffix: %w", err) + return "", err } - dbName := now + "_" + dbSuffix - - // if empty createDatabaseFromTemplate will create a new template db - templateDBName := os.Getenv("DB_FROM") - if openOptions.DBFrom != nil { - templateDBName = *openOptions.DBFrom - } - if err = createDatabaseFromTemplate(t, defaultConnectionParams, dbName, templateDBName); err != nil { - return "", xerrors.Errorf("create database: %w", err) - } - - t.Cleanup(func() { - cleanupDbURL := defaultConnectionParams.DSN() - cleanupConn, err := sql.Open("postgres", cleanupDbURL) - if err != nil { - t.Logf("cleanup database %q: failed to connect to postgres: %s\n", dbName, err.Error()) - return - } - defer func() { - if err := cleanupConn.Close(); err != nil { - t.Logf("cleanup database %q: failed to close connection: %s\n", dbName, err.Error()) - } - }() - _, err = cleanupConn.Exec("DROP DATABASE " + dbName + ";") - if err != nil { - t.Logf("failed to clean up database %q: %s\n", dbName, err.Error()) - return - } - }) - - dsn := ConnectionParams{ - Username: username, - Password: password, - Host: host, - Port: port, - DBName: dbName, - }.DSN() - - // Optionally log the DSN to help connect to the test database. - if openOptions.LogDSN { - _, _ = fmt.Fprintf(os.Stderr, "Connect to the database for %s using: psql '%s'\n", t.Name(), dsn) - } - return dsn, nil + return params.DSN(), nil } // createDatabaseFromTemplate creates a new database from a template database. // If templateDBName is empty, it will create a new template database based on // the current migrations, and name it "tpl_". Or if it's // already been created, it will use that. -func createDatabaseFromTemplate(t TBSubset, connParams ConnectionParams, newDBName string, templateDBName string) error { +func createDatabaseFromTemplate(t TBSubset, connParams ConnectionParams, db *sql.DB, newDBName string, templateDBName string) error { t.Helper() - dbURL := connParams.DSN() - db, err := sql.Open("postgres", dbURL) - if err != nil { - return xerrors.Errorf("connect to postgres: %w", err) - } - defer func() { - if err := db.Close(); err != nil { - t.Logf("create database from template: failed to close connection: %s\n", err.Error()) - } - }() - emptyTemplateDBName := templateDBName == "" if emptyTemplateDBName { templateDBName = fmt.Sprintf("tpl_%s", migrations.GetMigrationsHash()[:32]) } - _, err = db.Exec("CREATE DATABASE " + newDBName + " WITH TEMPLATE " + templateDBName) + _, err := db.Exec("CREATE DATABASE " + newDBName + " WITH TEMPLATE " + templateDBName) if err == nil { // Template database already exists and we successfully created the new database. return nil @@ -282,82 +206,96 @@ func createDatabaseFromTemplate(t TBSubset, connParams ConnectionParams, newDBNa // sanity check panic("templateDBName is not empty. there's a bug in the code above") } - // The templateDBName is empty, so we need to create the template database. - // We will use a tx to obtain a lock, so another test or process doesn't race with us. - tx, err := db.BeginTx(context.Background(), nil) - if err != nil { - return xerrors.Errorf("begin tx: %w", err) - } - defer func() { - err := tx.Rollback() - if err != nil && !errors.Is(err, sql.ErrTxDone) { - t.Logf("create database from template: failed to rollback tx: %s\n", err.Error()) - } - }() - // 2137 is an arbitrary number. We just need a lock that is unique to creating - // the template database. - _, err = tx.Exec("SELECT pg_advisory_xact_lock(2137)") - if err != nil { - return xerrors.Errorf("acquire lock: %w", err) - } - // Someone else might have created the template db while we were waiting. - tplDbExistsRes, err := tx.Query("SELECT 1 FROM pg_database WHERE datname = $1", templateDBName) - if err != nil { - return xerrors.Errorf("check if db exists: %w", err) - } - tplDbAlreadyExists := tplDbExistsRes.Next() - if err := tplDbExistsRes.Close(); err != nil { - return xerrors.Errorf("close tpl db exists res: %w", err) - } - if !tplDbAlreadyExists { - // We will use a temporary template database to avoid race conditions. We will - // rename it to the real template database name after we're sure it was fully - // initialized. - // It's dropped here to ensure that if a previous run of this function failed - // midway, we don't encounter issues with the temporary database still existing. - tmpTemplateDBName := "tmp_" + templateDBName - // We're using db instead of tx here because you can't run `DROP DATABASE` inside - // a transaction. - if _, err := db.Exec("DROP DATABASE IF EXISTS " + tmpTemplateDBName); err != nil { - return xerrors.Errorf("drop tmp template db: %w", err) - } - if _, err := db.Exec("CREATE DATABASE " + tmpTemplateDBName); err != nil { - return xerrors.Errorf("create tmp template db: %w", err) - } - tplDbURL := ConnectionParams{ - Username: connParams.Username, - Password: connParams.Password, - Host: connParams.Host, - Port: connParams.Port, - DBName: tmpTemplateDBName, - }.DSN() - tplDb, err := sql.Open("postgres", tplDbURL) - if err != nil { - return xerrors.Errorf("connect to template db: %w", err) - } - defer func() { - if err := tplDb.Close(); err != nil { - t.Logf("create database from template: failed to close template db: %s\n", err.Error()) - } - }() + // The templateDBName is empty, so we need to create the template database. + err = createAndInitDatabase(t, connParams, db, templateDBName, func(tplDb *sql.DB) error { if err := migrations.Up(tplDb); err != nil { return xerrors.Errorf("migrate template db: %w", err) } - if err := tplDb.Close(); err != nil { - return xerrors.Errorf("close template db: %w", err) - } - if _, err := db.Exec("ALTER DATABASE " + tmpTemplateDBName + " RENAME TO " + templateDBName); err != nil { - return xerrors.Errorf("rename tmp template db: %w", err) - } + return nil + }) + if err != nil { + return xerrors.Errorf("create template database: %w", err) } // Try to create the database again now that a template exists. if _, err = db.Exec("CREATE DATABASE " + newDBName + " WITH TEMPLATE " + templateDBName); err != nil { return xerrors.Errorf("create db with template after migrations: %w", err) } - if err = tx.Commit(); err != nil { - return xerrors.Errorf("commit tx: %w", err) + return nil +} + +func createAndInitDatabase(t TBSubset, connParams ConnectionParams, db *sql.DB, name string, initialize func(*sql.DB) error) error { + // We will use a tx to obtain a lock, so another test or process doesn't race with us. + tx, err := db.BeginTx(context.Background(), nil) + if err != nil { + return xerrors.Errorf("begin tx: %w", err) + } + // we only use the transaction for locking and querying, so it's fine to always roll it back. + defer func() { + err := tx.Rollback() + if err != nil && !errors.Is(err, sql.ErrTxDone) { + t.Logf("create database: failed to rollback tx: %s\n", err.Error()) + } + }() + // 2137 is an arbitrary number. We just need a lock that is unique to creating + // the database. + _, err = tx.Exec("SELECT pg_advisory_xact_lock(2137)") + if err != nil { + return xerrors.Errorf("acquire lock: %w", err) + } + + // Someone else might have created the db while we were waiting. + dbExistsRes, err := tx.Query("SELECT 1 FROM pg_database WHERE datname = $1", name) + if err != nil { + return xerrors.Errorf("check if db exists: %w", err) + } + dbAlreadyExists := dbExistsRes.Next() + if err := dbExistsRes.Close(); err != nil { + return xerrors.Errorf("close tpl db exists res: %w", err) + } + if dbAlreadyExists { + return nil + } + + // We will use a temporary database to avoid race conditions. We will + // rename it to the real database name after we're sure it was fully + // initialized. + // It's dropped here to ensure that if a previous run of this function failed + // midway, we don't encounter issues with the temporary database still existing. + tmpDBName := "tmp_" + name + // We're using db instead of tx here because you can't run `DROP DATABASE` inside + // a transaction. + if _, err := db.Exec("DROP DATABASE IF EXISTS " + tmpDBName); err != nil { + return xerrors.Errorf("drop tmp db: %w", err) + } + if _, err := db.Exec("CREATE DATABASE " + tmpDBName); err != nil { + return xerrors.Errorf("create tmp db: %w", err) + } + tmpDbURL := ConnectionParams{ + Username: connParams.Username, + Password: connParams.Password, + Host: connParams.Host, + Port: connParams.Port, + DBName: tmpDBName, + }.DSN() + tmpDb, err := sql.Open("postgres", tmpDbURL) + if err != nil { + return xerrors.Errorf("connect to template db: %w", err) + } + defer func() { + if err := tmpDb.Close(); err != nil { + t.Logf("failed to close temp db: %s\n", err.Error()) + } + }() + if err := initialize(tmpDb); err != nil { + return xerrors.Errorf("initialize: %w", err) + } + if err := tmpDb.Close(); err != nil { + return xerrors.Errorf("close template db: %w", err) + } + if _, err := db.Exec("ALTER DATABASE " + tmpDBName + " RENAME TO " + name); err != nil { + return xerrors.Errorf("rename tmp db: %w", err) } return nil }