diff --git a/cli/server.go b/cli/server.go index 6edd14b7d5..b12f5e0189 100644 --- a/cli/server.go +++ b/cli/server.go @@ -29,6 +29,7 @@ import ( "strings" "sync" "sync/atomic" + "testing" "time" "github.com/charmbracelet/lipgloss" @@ -1377,6 +1378,7 @@ func IsLocalURL(ctx context.Context, u *url.URL) (bool, error) { } func shutdownWithTimeout(shutdown func(context.Context) error, timeout time.Duration) error { + // nolint:gocritic // The magic number is parameterized. ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() return shutdown(ctx) @@ -2134,50 +2136,83 @@ func startBuiltinPostgres(ctx context.Context, cfg config.Root, logger slog.Logg return "", nil, xerrors.New("The built-in PostgreSQL cannot run as the root user. Create a non-root user and run again!") } - // Ensure a password and port have been generated! - connectionURL, err := embeddedPostgresURL(cfg) - if err != nil { - return "", nil, err - } - pgPassword, err := cfg.PostgresPassword().Read() - if err != nil { - return "", nil, xerrors.Errorf("read postgres password: %w", err) - } - pgPortRaw, err := cfg.PostgresPort().Read() - if err != nil { - return "", nil, xerrors.Errorf("read postgres port: %w", err) - } - pgPort, err := strconv.ParseUint(pgPortRaw, 10, 16) - if err != nil { - return "", nil, xerrors.Errorf("parse postgres port: %w", err) - } - cachePath := filepath.Join(cfg.PostgresPath(), "cache") if customCacheDir != "" { cachePath = filepath.Join(customCacheDir, "postgres") } stdlibLogger := slog.Stdlib(ctx, logger.Named("postgres"), slog.LevelDebug) - ep := embeddedpostgres.NewDatabase( - embeddedpostgres.DefaultConfig(). - Version(embeddedpostgres.V13). - BinariesPath(filepath.Join(cfg.PostgresPath(), "bin")). - // Default BinaryRepositoryURL repo1.maven.org is flaky. - BinaryRepositoryURL("https://repo.maven.apache.org/maven2"). - DataPath(filepath.Join(cfg.PostgresPath(), "data")). - RuntimePath(filepath.Join(cfg.PostgresPath(), "runtime")). 
- CachePath(cachePath). - Username("coder"). - Password(pgPassword). - Database("coder"). - Encoding("UTF8"). - Port(uint32(pgPort)). - Logger(stdlibLogger.Writer()), - ) - err = ep.Start() - if err != nil { - return "", nil, xerrors.Errorf("Failed to start built-in PostgreSQL. Optionally, specify an external deployment with `--postgres-url`: %w", err) + + // If the port is not defined, an available port will be found dynamically. + maxAttempts := 1 + _, err = cfg.PostgresPort().Read() + retryPortDiscovery := errors.Is(err, os.ErrNotExist) && testing.Testing() + if retryPortDiscovery { + // There is no way to tell Postgres to use an ephemeral port, so in order to avoid + // flaky tests in CI we need to retry EmbeddedPostgres.Start in case of a race + // condition where the port we quickly listen on and close in embeddedPostgresURL() + // is not free by the time the embedded postgres starts up. This maximum should + // cover most cases where port conflicts occur in CI and cause flaky tests. + maxAttempts = 3 } - return connectionURL, ep.Stop, nil + + var startErr error + for attempt := 0; attempt < maxAttempts; attempt++ { + // Ensure a password and port have been generated. + connectionURL, err := embeddedPostgresURL(cfg) + if err != nil { + return "", nil, err + } + pgPassword, err := cfg.PostgresPassword().Read() + if err != nil { + return "", nil, xerrors.Errorf("read postgres password: %w", err) + } + pgPortRaw, err := cfg.PostgresPort().Read() + if err != nil { + return "", nil, xerrors.Errorf("read postgres port: %w", err) + } + pgPort, err := strconv.ParseUint(pgPortRaw, 10, 16) + if err != nil { + return "", nil, xerrors.Errorf("parse postgres port: %w", err) + } + + ep := embeddedpostgres.NewDatabase( + embeddedpostgres.DefaultConfig(). + Version(embeddedpostgres.V13). + BinariesPath(filepath.Join(cfg.PostgresPath(), "bin")). + // Default BinaryRepositoryURL repo1.maven.org is flaky. + BinaryRepositoryURL("https://repo.maven.apache.org/maven2"). 
+ DataPath(filepath.Join(cfg.PostgresPath(), "data")). + RuntimePath(filepath.Join(cfg.PostgresPath(), "runtime")). + CachePath(cachePath). + Username("coder"). + Password(pgPassword). + Database("coder"). + Encoding("UTF8"). + Port(uint32(pgPort)). + Logger(stdlibLogger.Writer()), + ) + + startErr = ep.Start() + if startErr == nil { + return connectionURL, ep.Stop, nil + } + + logger.Warn(ctx, "failed to start embedded postgres", + slog.F("attempt", attempt+1), + slog.F("max_attempts", maxAttempts), + slog.F("port", pgPort), + slog.Error(startErr), + ) + + if retryPortDiscovery { + // Since a retry is needed, we wipe the port stored here at the beginning of the loop. + _ = cfg.PostgresPort().Delete() + } + } + + return "", nil, xerrors.Errorf("failed to start built-in PostgreSQL after %d attempts. "+ + "Optionally, specify an external deployment. See https://coder.com/docs/tutorials/external-database "+ + "for more details: %w", maxAttempts, startErr) } func ConfigureHTTPClient(ctx context.Context, clientCertFile, clientKeyFile string, tlsClientCAFile string) (context.Context, *http.Client, error) { @@ -2286,7 +2321,7 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d var err error var sqlDB *sql.DB dbNeedsClosing := true - // Try to connect for 30 seconds. + // nolint:gocritic // Try to connect for 30 seconds. ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -2382,6 +2417,7 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d } func pingPostgres(ctx context.Context, db *sql.DB) error { + // nolint:gocritic // This is a reasonable magic number for a ping timeout. ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() return db.PingContext(ctx)