diff --git a/go.mod b/go.mod index e0c11bb5fb..afbedf22f2 100644 --- a/go.mod +++ b/go.mod @@ -595,6 +595,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackmordaunt/icns/v3 v3.0.1 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect + github.com/joho/godotenv v1.5.1 github.com/kaptinlin/go-i18n v0.2.4 // indirect github.com/kaptinlin/jsonpointer v0.4.10 // indirect github.com/kaptinlin/jsonschema v0.6.10 // indirect diff --git a/go.sum b/go.sum index 644911e5ec..a290649c90 100644 --- a/go.sum +++ b/go.sum @@ -768,6 +768,8 @@ github.com/jedib0t/go-pretty/v6 v6.7.1/go.mod h1:YwC5CE4fJ1HFUDeivSV1r//AmANFHyq github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk= diff --git a/scripts/develop/main.go b/scripts/develop/main.go index b6905af7cf..e66f6f0936 100644 --- a/scripts/develop/main.go +++ b/scripts/develop/main.go @@ -27,6 +27,7 @@ import ( "time" "github.com/google/uuid" + "github.com/joho/godotenv" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" @@ -69,6 +70,24 @@ const ( ) func main() { + // Pre-parse --env-file before serpent runs so that variables from + // the file are visible to serpent's Env-tag resolution for other + // options. The flag is also registered in the serpent OptionSet + // below for --help discoverability. + envFile, err := parseEnvFileFlag() + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "develop: %v\n", err) + os.Exit(1) + } + if envFile != "" { + n, err := loadEnvFile(envFile) + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "develop: error loading env file %s: %v\n", envFile, err) + os.Exit(1) + } + _, _ = fmt.Fprintf(os.Stderr, "develop: loaded %d variable(s) from %s\n", n, envFile) + } + var cfg devConfig cmd := &serpent.Command{ @@ -182,6 +201,12 @@ func main() { Description: "Accept changed migration files and update tracking. Use when you've manually fixed the DB to match the new migrations.", Value: serpent.BoolOf(&cfg.dbContinue), }, + { + Flag: "env-file", + Env: "CODER_DEV_ENV_FILE", + Description: "Path to a .env file to load before starting. Variables in the file do not override existing environment variables. Note: unquoted and double-quoted values undergo $VAR expansion against other entries in the same file (not the process environment); use single quotes for literal dollar signs.", + Value: serpent.StringOf(&cfg.envFile), + }, }, Handler: func(inv *serpent.Invocation) error { cfg.serverExtraArgs = inv.Args @@ -198,7 +223,7 @@ func main() { }, } - err := cmd.Invoke(os.Args[1:]...).WithOS().Run() + err = cmd.Invoke(os.Args[1:]...).WithOS().Run() if err != nil { _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) @@ -223,6 +248,9 @@ type devConfig struct { dbRollback bool dbReset bool dbContinue bool + // envFile is populated by serpent for --help output; actual loading + // uses parseEnvFileFlag() before serpent runs. + envFile string projectRoot string binaryPath string configDir string @@ -435,8 +463,46 @@ func (c *devConfig) cmd(ctx context.Context, bin string, args ...string) *exec.C return cmd } -// filterEnv returns env with any variables whose key matches -// exclude removed. +// parseEnvFileFlag extracts the --env-file value from os.Args and +// CODER_DEV_ENV_FILE before serpent runs, so that loaded variables +// are visible to serpent's Env-tag resolution for other options. +func parseEnvFileFlag() (string, error) { + for i, arg := range os.Args[1:] { + if arg == "--env-file" { + if i+2 >= len(os.Args) { + return "", xerrors.New("--env-file requires a value") + } + return os.Args[i+2], nil + } + if v, ok := strings.CutPrefix(arg, "--env-file="); ok { + return v, nil + } + } + return os.Getenv("CODER_DEV_ENV_FILE"), nil +} + +// loadEnvFile reads the file at path using godotenv and sets any variables +// not already present in the process environment. It returns the number of +// variables set. +func loadEnvFile(path string) (int, error) { + vars, err := godotenv.Read(path) + if err != nil { + return 0, err + } + var n int + for key, val := range vars { + if _, exists := os.LookupEnv(key); exists { + continue + } + if err := os.Setenv(key, val); err != nil { + return n, err + } + n++ + } + return n, nil +} + +// filterEnv returns env with any variables whose key matches exclude removed. func filterEnv(env []string, exclude ...string) []string { out := make([]string, 0, len(env)) for _, e := range env { diff --git a/scripts/develop/main_test.go b/scripts/develop/main_test.go index 98bcd79f06..2491d52b4c 100644 --- a/scripts/develop/main_test.go +++ b/scripts/develop/main_test.go @@ -855,3 +855,134 @@ func TestPrometheusBannerEntry(t *testing.T) { }) } } + +//nolint:paralleltest // loadEnvFile mutates process-global environment. +func TestLoadEnvFile(t *testing.T) { + t.Run("LoadsVariablesFromFile", func(t *testing.T) { + tmpDir := t.TempDir() + envFile := filepath.Join(tmpDir, ".env") + err := os.WriteFile(envFile, []byte(strings.Join([]string{ + "# Comment line", + "", + "FOO_TEST_VAR=bar", + "export BAZ_TEST_VAR=qux", + `QUOTED_TEST_VAR="hello world"`, + "SINGLE_QUOTED_TEST_VAR='single quoted'", + }, "\n")), 0o600) + require.NoError(t, err) + + // Ensure none are set beforehand. + t.Setenv("FOO_TEST_VAR", "") + os.Unsetenv("FOO_TEST_VAR") + t.Setenv("BAZ_TEST_VAR", "") + os.Unsetenv("BAZ_TEST_VAR") + t.Setenv("QUOTED_TEST_VAR", "") + os.Unsetenv("QUOTED_TEST_VAR") + t.Setenv("SINGLE_QUOTED_TEST_VAR", "") + os.Unsetenv("SINGLE_QUOTED_TEST_VAR") + + n, err := loadEnvFile(envFile) + require.NoError(t, err) + assert.Equal(t, 4, n) + assert.Equal(t, "bar", os.Getenv("FOO_TEST_VAR")) + assert.Equal(t, "qux", os.Getenv("BAZ_TEST_VAR")) + assert.Equal(t, "hello world", os.Getenv("QUOTED_TEST_VAR")) + assert.Equal(t, "single quoted", os.Getenv("SINGLE_QUOTED_TEST_VAR")) + }) + + t.Run("DoesNotOverrideExisting", func(t *testing.T) { + tmpDir := t.TempDir() + envFile := filepath.Join(tmpDir, ".env") + err := os.WriteFile(envFile, []byte("EXISTING_TEST_VAR=new\n"), 0o600) + require.NoError(t, err) + + t.Setenv("EXISTING_TEST_VAR", "original") + + n, err := loadEnvFile(envFile) + require.NoError(t, err) + assert.Equal(t, 0, n) + assert.Equal(t, "original", os.Getenv("EXISTING_TEST_VAR")) + }) + + t.Run("ErrorsOnMissingFile", func(t *testing.T) { + _, err := loadEnvFile("/nonexistent/path/.env") + require.Error(t, err) + }) + + t.Run("ErrorsOnEmptyPath", func(t *testing.T) { + // This tests the caller logic (main), but we verify loadEnvFile + // would error on empty path since godotenv.Read("") fails. + _, err := loadEnvFile("") + require.Error(t, err) + }) +} + +//nolint:paralleltest // parseEnvFileFlag mutates process-global os.Args. +func TestParseEnvFileFlag(t *testing.T) { + t.Run("FlagWithSpace", func(t *testing.T) { + orig := os.Args + t.Cleanup(func() { os.Args = orig }) + os.Args = []string{"develop", "--env-file", "/tmp/test.env", "--port", "3000"} + + result, err := parseEnvFileFlag() + require.NoError(t, err) + assert.Equal(t, "/tmp/test.env", result) + }) + + t.Run("FlagWithEquals", func(t *testing.T) { + orig := os.Args + t.Cleanup(func() { os.Args = orig }) + os.Args = []string{"develop", "--env-file=/tmp/test.env", "--port", "3000"} + + result, err := parseEnvFileFlag() + require.NoError(t, err) + assert.Equal(t, "/tmp/test.env", result) + }) + + t.Run("FallsBackToEnvVar", func(t *testing.T) { + orig := os.Args + t.Cleanup(func() { os.Args = orig }) + os.Args = []string{"develop", "--port", "3000"} + + t.Setenv("CODER_DEV_ENV_FILE", "/tmp/from-env.env") + + result, err := parseEnvFileFlag() + require.NoError(t, err) + assert.Equal(t, "/tmp/from-env.env", result) + }) + + t.Run("FlagTakesPrecedenceOverEnvVar", func(t *testing.T) { + orig := os.Args + t.Cleanup(func() { os.Args = orig }) + os.Args = []string{"develop", "--env-file", "/tmp/from-flag.env"} + + t.Setenv("CODER_DEV_ENV_FILE", "/tmp/from-env.env") + + result, err := parseEnvFileFlag() + require.NoError(t, err) + assert.Equal(t, "/tmp/from-flag.env", result) + }) + + t.Run("ReturnsEmptyWhenUnset", func(t *testing.T) { + orig := os.Args + t.Cleanup(func() { os.Args = orig }) + os.Args = []string{"develop", "--port", "3000"} + + t.Setenv("CODER_DEV_ENV_FILE", "") + os.Unsetenv("CODER_DEV_ENV_FILE") + + result, err := parseEnvFileFlag() + require.NoError(t, err) + assert.Equal(t, "", result) + }) + + t.Run("ErrorsWhenValueMissing", func(t *testing.T) { + orig := os.Args + t.Cleanup(func() { os.Args = orig }) + os.Args = []string{"develop", "--env-file"} + + _, err := parseEnvFileFlag() + require.Error(t, err) + assert.Contains(t, err.Error(), "--env-file requires a value") + }) +}