mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
chore: prevent db migrations from running on all cli commands (#15980)
This commit is contained in:
+31
-12
@@ -3,22 +3,27 @@
|
|||||||
package cli
|
package cli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
"cdr.dev/slog/sloggers/sloghuman"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/awsiamrds"
|
||||||
|
"github.com/coder/coder/v2/codersdk"
|
||||||
"github.com/coder/pretty"
|
"github.com/coder/pretty"
|
||||||
"github.com/coder/serpent"
|
"github.com/coder/serpent"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/cli/cliui"
|
"github.com/coder/coder/v2/cli/cliui"
|
||||||
"github.com/coder/coder/v2/coderd/database"
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
"github.com/coder/coder/v2/coderd/database/migrations"
|
|
||||||
"github.com/coder/coder/v2/coderd/userpassword"
|
"github.com/coder/coder/v2/coderd/userpassword"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*RootCmd) resetPassword() *serpent.Command {
|
func (*RootCmd) resetPassword() *serpent.Command {
|
||||||
var postgresURL string
|
var (
|
||||||
|
postgresURL string
|
||||||
|
postgresAuth string
|
||||||
|
)
|
||||||
|
|
||||||
root := &serpent.Command{
|
root := &serpent.Command{
|
||||||
Use: "reset-password <username>",
|
Use: "reset-password <username>",
|
||||||
@@ -27,20 +32,26 @@ func (*RootCmd) resetPassword() *serpent.Command {
|
|||||||
Handler: func(inv *serpent.Invocation) error {
|
Handler: func(inv *serpent.Invocation) error {
|
||||||
username := inv.Args[0]
|
username := inv.Args[0]
|
||||||
|
|
||||||
sqlDB, err := sql.Open("postgres", postgresURL)
|
logger := slog.Make(sloghuman.Sink(inv.Stdout))
|
||||||
|
if ok, _ := inv.ParsedFlags().GetBool("verbose"); ok {
|
||||||
|
logger = logger.Leveled(slog.LevelDebug)
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDriver := "postgres"
|
||||||
|
if codersdk.PostgresAuth(postgresAuth) == codersdk.PostgresAuthAWSIAMRDS {
|
||||||
|
var err error
|
||||||
|
sqlDriver, err = awsiamrds.Register(inv.Context(), sqlDriver)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("register aws rds iam auth: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDB, err := ConnectToPostgres(inv.Context(), logger, sqlDriver, postgresURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("dial postgres: %w", err)
|
return xerrors.Errorf("dial postgres: %w", err)
|
||||||
}
|
}
|
||||||
defer sqlDB.Close()
|
defer sqlDB.Close()
|
||||||
err = sqlDB.Ping()
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("ping postgres: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = migrations.EnsureClean(sqlDB)
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("database needs migration: %w", err)
|
|
||||||
}
|
|
||||||
db := database.New(sqlDB)
|
db := database.New(sqlDB)
|
||||||
|
|
||||||
user, err := db.GetUserByEmailOrUsername(inv.Context(), database.GetUserByEmailOrUsernameParams{
|
user, err := db.GetUserByEmailOrUsername(inv.Context(), database.GetUserByEmailOrUsernameParams{
|
||||||
@@ -97,6 +108,14 @@ func (*RootCmd) resetPassword() *serpent.Command {
|
|||||||
Env: "CODER_PG_CONNECTION_URL",
|
Env: "CODER_PG_CONNECTION_URL",
|
||||||
Value: serpent.StringOf(&postgresURL),
|
Value: serpent.StringOf(&postgresURL),
|
||||||
},
|
},
|
||||||
|
serpent.Option{
|
||||||
|
Name: "Postgres Connection Auth",
|
||||||
|
Description: "Type of auth to use when connecting to postgres.",
|
||||||
|
Flag: "postgres-connection-auth",
|
||||||
|
Env: "CODER_PG_CONNECTION_AUTH",
|
||||||
|
Default: "password",
|
||||||
|
Value: serpent.EnumOf(&postgresAuth, codersdk.PostgresAuthDrivers...),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return root
|
return root
|
||||||
|
|||||||
+21
-5
@@ -697,7 +697,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
|||||||
options.Database = dbmem.New()
|
options.Database = dbmem.New()
|
||||||
options.Pubsub = pubsub.NewInMemory()
|
options.Pubsub = pubsub.NewInMemory()
|
||||||
} else {
|
} else {
|
||||||
sqlDB, dbURL, err := getPostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
|
sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("connect to postgres: %w", err)
|
return xerrors.Errorf("connect to postgres: %w", err)
|
||||||
}
|
}
|
||||||
@@ -2090,9 +2090,18 @@ func IsLocalhost(host string) bool {
|
|||||||
return host == "localhost" || host == "127.0.0.1" || host == "::1"
|
return host == "localhost" || host == "127.0.0.1" || host == "::1"
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string) (sqlDB *sql.DB, err error) {
|
// ConnectToPostgres takes in the migration command to run on the database once
|
||||||
|
// it connects. To avoid running migrations, pass in `nil` or a no-op function.
|
||||||
|
// Regardless of the passed in migration function, if the database is not fully
|
||||||
|
// migrated, an error will be returned. This can happen if the database is on a
|
||||||
|
// future or past migration version.
|
||||||
|
//
|
||||||
|
// If no error is returned, the database is fully migrated and up to date.
|
||||||
|
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string, migrate func(db *sql.DB) error) (*sql.DB, error) {
|
||||||
logger.Debug(ctx, "connecting to postgresql")
|
logger.Debug(ctx, "connecting to postgresql")
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var sqlDB *sql.DB
|
||||||
// Try to connect for 30 seconds.
|
// Try to connect for 30 seconds.
|
||||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -2155,10 +2164,17 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d
|
|||||||
}
|
}
|
||||||
logger.Debug(ctx, "connected to postgresql", slog.F("version", versionNum))
|
logger.Debug(ctx, "connected to postgresql", slog.F("version", versionNum))
|
||||||
|
|
||||||
err = migrations.Up(sqlDB)
|
if migrate != nil {
|
||||||
|
err = migrate(sqlDB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, xerrors.Errorf("migrate up: %w", err)
|
return nil, xerrors.Errorf("migrate up: %w", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = migrations.EnsureClean(sqlDB)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("migrations in database: %w", err)
|
||||||
|
}
|
||||||
// The default is 0 but the request will fail with a 500 if the DB
|
// The default is 0 but the request will fail with a 500 if the DB
|
||||||
// cannot accept new connections, so we try to limit that here.
|
// cannot accept new connections, so we try to limit that here.
|
||||||
// Requests will wait for a new connection instead of a hard error
|
// Requests will wait for a new connection instead of a hard error
|
||||||
@@ -2561,7 +2577,7 @@ func signalNotifyContext(ctx context.Context, inv *serpent.Invocation, sig ...os
|
|||||||
return inv.SignalNotifyContext(ctx, sig...)
|
return inv.SignalNotifyContext(ctx, sig...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
|
func getAndMigratePostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
|
||||||
dbURL, err := escapePostgresURLUserInfo(postgresURL)
|
dbURL, err := escapePostgresURLUserInfo(postgresURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", xerrors.Errorf("escaping postgres URL: %w", err)
|
return nil, "", xerrors.Errorf("escaping postgres URL: %w", err)
|
||||||
@@ -2574,7 +2590,7 @@ func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL)
|
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL, migrations.Up)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", xerrors.Errorf("connect to postgres: %w", err)
|
return nil, "", xerrors.Errorf("connect to postgres: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *serpent.Command {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL)
|
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("connect to postgres: %w", err)
|
return xerrors.Errorf("connect to postgres: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+34
-1
@@ -38,11 +38,13 @@ import (
|
|||||||
"tailscale.com/derp/derphttp"
|
"tailscale.com/derp/derphttp"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
|
||||||
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
"github.com/coder/coder/v2/cli"
|
"github.com/coder/coder/v2/cli"
|
||||||
"github.com/coder/coder/v2/cli/clitest"
|
"github.com/coder/coder/v2/cli/clitest"
|
||||||
"github.com/coder/coder/v2/cli/config"
|
"github.com/coder/coder/v2/cli/config"
|
||||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/migrations"
|
||||||
"github.com/coder/coder/v2/coderd/httpapi"
|
"github.com/coder/coder/v2/coderd/httpapi"
|
||||||
"github.com/coder/coder/v2/coderd/telemetry"
|
"github.com/coder/coder/v2/coderd/telemetry"
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
@@ -1828,6 +1830,10 @@ func TestConnectToPostgres(t *testing.T) {
|
|||||||
if !dbtestutil.WillUsePostgres() {
|
if !dbtestutil.WillUsePostgres() {
|
||||||
t.Skip("this test does not make sense without postgres")
|
t.Skip("this test does not make sense without postgres")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("Migrate", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||||
t.Cleanup(cancel)
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
@@ -1836,12 +1842,39 @@ func TestConnectToPostgres(t *testing.T) {
|
|||||||
dbURL, err := dbtestutil.Open(t)
|
dbURL, err := dbtestutil.Open(t)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL)
|
sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, migrations.Up)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
_ = sqlDB.Close()
|
_ = sqlDB.Close()
|
||||||
})
|
})
|
||||||
require.NoError(t, sqlDB.PingContext(ctx))
|
require.NoError(t, sqlDB.PingContext(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NoMigrate", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
|
|
||||||
|
dbURL, err := dbtestutil.Open(t)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
okDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer okDB.Close()
|
||||||
|
|
||||||
|
// Set the migration number forward
|
||||||
|
_, err = okDB.Exec(`UPDATE schema_migrations SET version = version + 1`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, "database needs migration")
|
||||||
|
|
||||||
|
require.NoError(t, okDB.PingContext(ctx))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_InvalidDERP(t *testing.T) {
|
func TestServer_InvalidDERP(t *testing.T) {
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ USAGE:
|
|||||||
Directly connect to the database to reset a user's password
|
Directly connect to the database to reset a user's password
|
||||||
|
|
||||||
OPTIONS:
|
OPTIONS:
|
||||||
|
--postgres-connection-auth password|awsiamrds, $CODER_PG_CONNECTION_AUTH (default: password)
|
||||||
|
Type of auth to use when connecting to postgres.
|
||||||
|
|
||||||
--postgres-url string, $CODER_PG_CONNECTION_URL
|
--postgres-url string, $CODER_PG_CONNECTION_URL
|
||||||
URL of a PostgreSQL database to connect to.
|
URL of a PostgreSQL database to connect to.
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/coder/coder/v2/cli"
|
"github.com/coder/coder/v2/cli"
|
||||||
"github.com/coder/coder/v2/coderd/database/awsiamrds"
|
"github.com/coder/coder/v2/coderd/database/awsiamrds"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/migrations"
|
||||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
)
|
)
|
||||||
@@ -32,7 +33,7 @@ func TestDriver(t *testing.T) {
|
|||||||
sqlDriver, err := awsiamrds.Register(ctx, "postgres")
|
sqlDriver, err := awsiamrds.Register(ctx, "postgres")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url)
|
db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url, migrations.Up)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = db.Close()
|
_ = db.Close()
|
||||||
|
|||||||
Generated
+10
@@ -19,3 +19,13 @@ coder reset-password [flags] <username>
|
|||||||
| Environment | <code>$CODER_PG_CONNECTION_URL</code> |
|
| Environment | <code>$CODER_PG_CONNECTION_URL</code> |
|
||||||
|
|
||||||
URL of a PostgreSQL database to connect to.
|
URL of a PostgreSQL database to connect to.
|
||||||
|
|
||||||
|
### --postgres-connection-auth
|
||||||
|
|
||||||
|
| | |
|
||||||
|
|-------------|----------------------------------------|
|
||||||
|
| Type | <code>password\|awsiamrds</code> |
|
||||||
|
| Environment | <code>$CODER_PG_CONNECTION_AUTH</code> |
|
||||||
|
| Default | <code>password</code> |
|
||||||
|
|
||||||
|
Type of auth to use when connecting to postgres.
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ func (*RootCmd) dbcryptRotateCmd() *serpent.Command {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
|
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("connect to postgres: %w", err)
|
return xerrors.Errorf("connect to postgres: %w", err)
|
||||||
}
|
}
|
||||||
@@ -163,7 +163,7 @@ func (*RootCmd) dbcryptDecryptCmd() *serpent.Command {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
|
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("connect to postgres: %w", err)
|
return xerrors.Errorf("connect to postgres: %w", err)
|
||||||
}
|
}
|
||||||
@@ -219,7 +219,7 @@ Are you sure you want to continue?`
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
|
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("connect to postgres: %w", err)
|
return xerrors.Errorf("connect to postgres: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user