diff --git a/coderd/database/db.go b/coderd/database/db.go index 6d5ad99576..8a3a6f1055 100644 --- a/coderd/database/db.go +++ b/coderd/database/db.go @@ -182,7 +182,7 @@ func (q *sqlQuerier) InTx(function func(Store) error, txOpts *TxOptions) error { } // InTx performs database operations inside a transaction. -func (q *sqlQuerier) runTx(function func(Store) error, txOpts *sql.TxOptions) error { +func (q *sqlQuerier) runTx(function func(Store) error, txOpts *sql.TxOptions) (err error) { if _, ok := q.db.(*sqlx.Tx); ok { // If the current inner "db" is already a transaction, we just reuse it. // We do not need to handle commit/rollback as the outer tx will handle diff --git a/coderd/database/db_test.go b/coderd/database/db_test.go index 9941ef5ba3..bec132e0fb 100644 --- a/coderd/database/db_test.go +++ b/coderd/database/db_test.go @@ -5,9 +5,11 @@ import ( "database/sql" "testing" + "github.com/DATA-DOG/go-sqlmock" "github.com/google/uuid" "github.com/lib/pq" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtestutil" @@ -82,6 +84,33 @@ func TestNestedInTx(t *testing.T) { require.Equal(t, uid, user.ID, "user id expected") } +func TestInTx_CapturesRollbackError(t *testing.T) { + t.Parallel() + + sqlDB, mock, err := sqlmock.New() + require.NoError(t, err) + t.Cleanup(func() { _ = sqlDB.Close() }) + + db := database.New(sqlDB) + + callbackErr := xerrors.New("callback failed") + rollbackErr := xerrors.New("rollback failed") + + mock.ExpectBegin() + mock.ExpectRollback().WillReturnError(rollbackErr) + + err = db.InTx(func(_ database.Store) error { + return callbackErr + }, nil) + require.EqualError(t, err, "defer (rollback failed): execute transaction: callback failed") + require.ErrorIs(t, err, callbackErr, + "returned error should still match the callback error when rollback fails") + require.NotErrorIs(t, err, rollbackErr, + "rollback failure should be reported in the message, not wrapped in the error chain") + + require.NoError(t, mock.ExpectationsWereMet()) +} + func testSQLDB(t testing.TB) *sql.DB { t.Helper() diff --git a/go.mod b/go.mod index f4c5d57268..7dcd74dbf2 100644 --- a/go.mod +++ b/go.mod @@ -105,6 +105,7 @@ replace github.com/openai/openai-go/v3 => github.com/kylecarbs/openai-go/v3 v3.0 require ( cdr.dev/slog/v3 v3.0.0 cloud.google.com/go/compute/metadata v0.9.0 + github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/Microsoft/go-winio v0.6.2 github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d github.com/adrg/xdg v0.5.0 diff --git a/go.sum b/go.sum index 5374bc9af4..4037a81e94 100644 --- a/go.sum +++ b/go.sum @@ -796,6 +796,7 @@ github.com/kirsle/configdir v0.0.0-20170128060238-e45d2f54772f h1:dKccXx7xA56UNq github.com/kirsle/configdir v0.0.0-20170128060238-e45d2f54772f/go.mod h1:4rEELDSfUAlBSyUjPG0JnaNGjf13JySHFeRdD/3dLP0= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=