mirror of
https://github.com/coder/coder.git
synced 2026-06-04 05:28:20 +00:00
e2f5401fb2
Fixes https://github.com/coder/internal/issues/927 — adds a small subprocess that outlives the testing process to clean up any leaked test databases.
189 lines
5.2 KiB
Go
189 lines
5.2 KiB
Go
package dbtestutil
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
_ "embed"
|
|
"fmt"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/lib/pq"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/cryptorand"
|
|
)
|
|
|
|
// CoderTestingDBName is the name of the database the Broker connects to in
// order to record the test databases it creates (see the test_databases
// inserts in Create). It is created on demand by createCoderTestingDB.
const CoderTestingDBName = "coder_testing"
|
|
|
|
// coderTestingSQLInit holds the contents of coder_testing.sql, embedded at
// build time. createCoderTestingDB executes it against the freshly created
// coder_testing database.
//go:embed coder_testing.sql
|
|
var coderTestingSQLInit string
|
|
|
|
type Broker struct {
|
|
sync.Mutex
|
|
uuid uuid.UUID
|
|
coderTestingDB *sql.DB
|
|
refCount int
|
|
// we keep a reference to the stdin of the cleaner so that Go doesn't garbage collect it.
|
|
cleanerFD any
|
|
}
|
|
|
|
func (b *Broker) Create(t TBSubset, opts ...OpenOption) (ConnectionParams, error) {
|
|
if err := b.init(t); err != nil {
|
|
return ConnectionParams{}, err
|
|
}
|
|
openOptions := OpenOptions{}
|
|
for _, opt := range opts {
|
|
opt(&openOptions)
|
|
}
|
|
|
|
var (
|
|
username = defaultConnectionParams.Username
|
|
password = defaultConnectionParams.Password
|
|
host = defaultConnectionParams.Host
|
|
port = defaultConnectionParams.Port
|
|
)
|
|
|
|
// Use a time-based prefix to make it easier to find the database
|
|
// when debugging.
|
|
now := time.Now().Format("test_2006_01_02_15_04_05")
|
|
dbSuffix, err := cryptorand.StringCharset(cryptorand.Lower, 10)
|
|
if err != nil {
|
|
return ConnectionParams{}, xerrors.Errorf("generate db suffix: %w", err)
|
|
}
|
|
dbName := now + "_" + dbSuffix
|
|
|
|
// TODO: add package and test name
|
|
_, err = b.coderTestingDB.Exec(
|
|
"INSERT INTO test_databases (name, process_uuid) VALUES ($1, $2)", dbName, b.uuid)
|
|
if err != nil {
|
|
return ConnectionParams{}, xerrors.Errorf("insert test_database row: %w", err)
|
|
}
|
|
|
|
// if empty createDatabaseFromTemplate will create a new template db
|
|
templateDBName := os.Getenv("DB_FROM")
|
|
if openOptions.DBFrom != nil {
|
|
templateDBName = *openOptions.DBFrom
|
|
}
|
|
if err = createDatabaseFromTemplate(t, defaultConnectionParams, b.coderTestingDB, dbName, templateDBName); err != nil {
|
|
return ConnectionParams{}, xerrors.Errorf("create database: %w", err)
|
|
}
|
|
|
|
testDBParams := ConnectionParams{
|
|
Username: username,
|
|
Password: password,
|
|
Host: host,
|
|
Port: port,
|
|
DBName: dbName,
|
|
}
|
|
|
|
// Optionally log the DSN to help connect to the test database.
|
|
if openOptions.LogDSN {
|
|
_, _ = fmt.Fprintf(os.Stderr, "Connect to the database for %s using: psql '%s'\n", t.Name(), testDBParams.DSN())
|
|
}
|
|
t.Cleanup(b.clean(t, dbName))
|
|
return testDBParams, nil
|
|
}
|
|
|
|
func (b *Broker) clean(t TBSubset, dbName string) func() {
|
|
return func() {
|
|
_, err := b.coderTestingDB.Exec("DROP DATABASE " + dbName + ";")
|
|
if err != nil {
|
|
t.Logf("failed to clean up database %q: %s\n", dbName, err.Error())
|
|
return
|
|
}
|
|
_, err = b.coderTestingDB.Exec("UPDATE test_databases SET dropped_at = CURRENT_TIMESTAMP WHERE name = $1", dbName)
|
|
if err != nil {
|
|
t.Logf("failed to mark test database '%s' dropped: %s\n", dbName, err.Error())
|
|
}
|
|
}
|
|
}
|
|
|
|
func (b *Broker) init(t TBSubset) error {
|
|
b.Lock()
|
|
defer b.Unlock()
|
|
b.refCount++
|
|
t.Cleanup(b.decRef)
|
|
if b.coderTestingDB != nil {
|
|
// already initialized
|
|
return nil
|
|
}
|
|
|
|
connectionParamsInitOnce.Do(func() {
|
|
errDefaultConnectionParamsInit = initDefaultConnection(t)
|
|
})
|
|
if errDefaultConnectionParamsInit != nil {
|
|
return xerrors.Errorf("init default connection params: %w", errDefaultConnectionParamsInit)
|
|
}
|
|
coderTestingParams := defaultConnectionParams
|
|
coderTestingParams.DBName = CoderTestingDBName
|
|
coderTestingDB, err := sql.Open("postgres", coderTestingParams.DSN())
|
|
if err != nil {
|
|
return xerrors.Errorf("open postgres connection: %w", err)
|
|
}
|
|
|
|
// creating the db can succeed even if the database doesn't exist. Ping it to find out.
|
|
err = coderTestingDB.Ping()
|
|
var pqErr *pq.Error
|
|
if xerrors.As(err, &pqErr) && pqErr.Code == "3D000" {
|
|
// database does not exist.
|
|
if closeErr := coderTestingDB.Close(); closeErr != nil {
|
|
return xerrors.Errorf("close postgres connection: %w", closeErr)
|
|
}
|
|
err = createCoderTestingDB(t)
|
|
if err != nil {
|
|
return xerrors.Errorf("create coder testing db: %w", err)
|
|
}
|
|
coderTestingDB, err = sql.Open("postgres", coderTestingParams.DSN())
|
|
if err != nil {
|
|
return xerrors.Errorf("open postgres connection: %w", err)
|
|
}
|
|
} else if err != nil {
|
|
_ = coderTestingDB.Close()
|
|
return xerrors.Errorf("ping '%s' database: %w", CoderTestingDBName, err)
|
|
}
|
|
b.coderTestingDB = coderTestingDB
|
|
|
|
if b.uuid == uuid.Nil {
|
|
b.uuid = uuid.New()
|
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
|
defer cancel()
|
|
b.cleanerFD, err = startCleaner(ctx, b.uuid, coderTestingParams.DSN())
|
|
if err != nil {
|
|
return xerrors.Errorf("start test db cleaner: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func createCoderTestingDB(t TBSubset) error {
|
|
db, err := sql.Open("postgres", defaultConnectionParams.DSN())
|
|
if err != nil {
|
|
return xerrors.Errorf("open postgres connection: %w", err)
|
|
}
|
|
defer func() {
|
|
_ = db.Close()
|
|
}()
|
|
err = createAndInitDatabase(t, defaultConnectionParams, db, CoderTestingDBName, func(testDB *sql.DB) error {
|
|
_, err := testDB.Exec(coderTestingSQLInit)
|
|
return err
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("create coder testing db: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (b *Broker) decRef() {
|
|
b.Lock()
|
|
defer b.Unlock()
|
|
b.refCount--
|
|
if b.refCount == 0 {
|
|
// ensures we don't leave go routines around for GoLeak to find.
|
|
_ = b.coderTestingDB.Close()
|
|
b.coderTestingDB = nil
|
|
}
|
|
}
|