diff --git a/cli/clibase/cmd.go b/cli/clibase/cmd.go index 2b9da50022..49a3cae718 100644 --- a/cli/clibase/cmd.go +++ b/cli/clibase/cmd.go @@ -172,8 +172,8 @@ type Invocation struct { // WithOS returns the invocation as a main package, filling in the invocation's unset // fields with OS defaults. -func (inv *Invocation) WithOS() *Invocation { - return inv.with(func(i *Invocation) { +func (i *Invocation) WithOS() *Invocation { + return i.with(func(i *Invocation) { i.Stdout = os.Stdout i.Stderr = os.Stderr i.Stdin = os.Stdin @@ -182,18 +182,18 @@ func (inv *Invocation) WithOS() *Invocation { }) } -func (inv *Invocation) Context() context.Context { - if inv.ctx == nil { +func (i *Invocation) Context() context.Context { + if i.ctx == nil { return context.Background() } - return inv.ctx + return i.ctx } -func (inv *Invocation) ParsedFlags() *pflag.FlagSet { - if inv.parsedFlags == nil { +func (i *Invocation) ParsedFlags() *pflag.FlagSet { + if i.parsedFlags == nil { panic("flags not parsed, has Run() been called?") } - return inv.parsedFlags + return i.parsedFlags } type runState struct { @@ -218,8 +218,30 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet { // run recursively executes the command and its children. // allArgs is wired through the stack so that global flags can be accepted // anywhere in the command invocation. -func (inv *Invocation) run(state *runState) error { - err := inv.Command.Options.ParseEnv(inv.Environ) +func (i *Invocation) run(state *runState) error { + err := i.Command.Options.SetDefaults() + if err != nil { + return xerrors.Errorf("setting defaults: %w", err) + } + + // If we set the Default of an array but later see a flag for it, we + // don't want to append, we want to replace. So, we need to keep the state + // of defaulted array options. + defaultedArrays := make(map[string]int) + for _, opt := range i.Command.Options { + sv, ok := opt.Value.(pflag.SliceValue) + if !ok { + continue + } + + if opt.Flag == "" { + continue + } + + defaultedArrays[opt.Flag] = len(sv.GetSlice()) + } + + err = i.Command.Options.ParseEnv(i.Environ) if err != nil { return xerrors.Errorf("parsing env: %w", err) } @@ -227,8 +249,8 @@ func (inv *Invocation) run(state *runState) error { // Now the fun part, argument parsing! children := make(map[string]*Cmd) - for _, child := range inv.Command.Children { - child.Parent = inv.Command + for _, child := range i.Command.Children { + child.Parent = i.Command for _, name := range append(child.Aliases, child.Name()) { if _, ok := children[name]; ok { return xerrors.Errorf("duplicate command name: %s", name) @@ -237,44 +259,57 @@ func (inv *Invocation) run(state *runState) error { } } - if inv.parsedFlags == nil { - inv.parsedFlags = pflag.NewFlagSet(inv.Command.Name(), pflag.ContinueOnError) + if i.parsedFlags == nil { + i.parsedFlags = pflag.NewFlagSet(i.Command.Name(), pflag.ContinueOnError) // We handle Usage ourselves. - inv.parsedFlags.Usage = func() {} + i.parsedFlags.Usage = func() {} } // If we find a duplicate flag, we want the deeper command's flag to override // the shallow one. Unfortunately, pflag has no way to remove a flag, so we // have to create a copy of the flagset without a value. - inv.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) { - if inv.parsedFlags.Lookup(f.Name) != nil { - inv.parsedFlags = copyFlagSetWithout(inv.parsedFlags, f.Name) + i.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) { + if i.parsedFlags.Lookup(f.Name) != nil { + i.parsedFlags = copyFlagSetWithout(i.parsedFlags, f.Name) } - inv.parsedFlags.AddFlag(f) + i.parsedFlags.AddFlag(f) }) var parsedArgs []string - if !inv.Command.RawArgs { + if !i.Command.RawArgs { // Flag parsing will fail on intermediate commands in the command tree, // so we check the error after looking for a child command. - state.flagParseErr = inv.parsedFlags.Parse(state.allArgs) - parsedArgs = inv.parsedFlags.Args() - } + state.flagParseErr = i.parsedFlags.Parse(state.allArgs) + parsedArgs = i.parsedFlags.Args() - // Set defaults for flags that weren't set by the user. - skipDefaults := make(map[int]struct{}, len(inv.Command.Options)) - for i, opt := range inv.Command.Options { - if fl := inv.parsedFlags.Lookup(opt.Flag); fl != nil && fl.Changed { - skipDefaults[i] = struct{}{} - } - if opt.envChanged { - skipDefaults[i] = struct{}{} - } - } - err = inv.Command.Options.SetDefaults(skipDefaults) - if err != nil { - return xerrors.Errorf("setting defaults: %w", err) + i.parsedFlags.VisitAll(func(f *pflag.Flag) { + i, ok := defaultedArrays[f.Name] + if !ok { + return + } + + if !f.Changed { + return + } + + // If flag was changed, we need to remove the default values. + sv, ok := f.Value.(pflag.SliceValue) + if !ok { + panic("defaulted array option is not a slice value") + } + ss := sv.GetSlice() + if len(ss) == 0 { + // Slice likely zeroed by a flag. + // E.g. "--fruit" may default to "apples,oranges" but the user + // provided "--fruit=""". + return + } + err := sv.Replace(ss[i:]) + if err != nil { + panic(err) + } + }) } // Run child command if found (next child only) @@ -283,64 +318,64 @@ func (inv *Invocation) run(state *runState) error { if len(parsedArgs) > state.commandDepth { nextArg := parsedArgs[state.commandDepth] if child, ok := children[nextArg]; ok { - child.Parent = inv.Command - inv.Command = child + child.Parent = i.Command + i.Command = child state.commandDepth++ - return inv.run(state) + return i.run(state) } } // Flag parse errors are irrelevant for raw args commands. - if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) { + if !i.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) { return xerrors.Errorf( "parsing flags (%v) for %q: %w", state.allArgs, - inv.Command.FullName(), state.flagParseErr, + i.Command.FullName(), state.flagParseErr, ) } - if inv.Command.RawArgs { + if i.Command.RawArgs { // If we're at the root command, then the name is omitted // from the arguments, so we can just use the entire slice. if state.commandDepth == 0 { - inv.Args = state.allArgs + i.Args = state.allArgs } else { - argPos, err := findArg(inv.Command.Name(), state.allArgs, inv.parsedFlags) + argPos, err := findArg(i.Command.Name(), state.allArgs, i.parsedFlags) if err != nil { panic(err) } - inv.Args = state.allArgs[argPos+1:] + i.Args = state.allArgs[argPos+1:] } } else { // In non-raw-arg mode, we want to skip over flags. - inv.Args = parsedArgs[state.commandDepth:] + i.Args = parsedArgs[state.commandDepth:] } - mw := inv.Command.Middleware + mw := i.Command.Middleware if mw == nil { mw = Chain() } - ctx := inv.ctx + ctx := i.ctx if ctx == nil { ctx = context.Background() } ctx, cancel := context.WithCancel(ctx) defer cancel() - inv = inv.WithContext(ctx) + i = i.WithContext(ctx) - if inv.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) { - if inv.Command.HelpHandler == nil { - return xerrors.Errorf("no handler or help for command %s", inv.Command.FullName()) + if i.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) { + if i.Command.HelpHandler == nil { + return xerrors.Errorf("no handler or help for command %s", i.Command.FullName()) } - return inv.Command.HelpHandler(inv) + return i.Command.HelpHandler(i) } - err = mw(inv.Command.Handler)(inv) + err = mw(i.Command.Handler)(i) if err != nil { return &RunCommandError{ - Cmd: inv.Command, + Cmd: i.Command, Err: err, } } @@ -403,33 +438,33 @@ func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) { // If two command share a flag name, the first command wins. // //nolint:revive -func (inv *Invocation) Run() (err error) { +func (i *Invocation) Run() (err error) { defer func() { // Pflag is panicky, so additional context is helpful in tests. if flag.Lookup("test.v") == nil { return } if r := recover(); r != nil { - err = xerrors.Errorf("panic recovered for %s: %v", inv.Command.FullName(), r) + err = xerrors.Errorf("panic recovered for %s: %v", i.Command.FullName(), r) panic(err) } }() - err = inv.run(&runState{ - allArgs: inv.Args, + err = i.run(&runState{ + allArgs: i.Args, }) return err } // WithContext returns a copy of the Invocation with the given context. -func (inv *Invocation) WithContext(ctx context.Context) *Invocation { - return inv.with(func(i *Invocation) { +func (i *Invocation) WithContext(ctx context.Context) *Invocation { + return i.with(func(i *Invocation) { i.ctx = ctx }) } // with returns a copy of the Invocation with the given function applied. -func (inv *Invocation) with(fn func(*Invocation)) *Invocation { - i2 := *inv +func (i *Invocation) with(fn func(*Invocation)) *Invocation { + i2 := *i fn(&i2) return &i2 } diff --git a/cli/clibase/cmd_test.go b/cli/clibase/cmd_test.go index f5ed6f6763..cf835327cf 100644 --- a/cli/clibase/cmd_test.go +++ b/cli/clibase/cmd_test.go @@ -247,7 +247,6 @@ func TestCommand_FlagOverride(t *testing.T) { Use: "1", Options: clibase.OptionSet{ { - Name: "flag", Flag: "f", Value: clibase.DiscardValue, }, @@ -257,7 +256,6 @@ func TestCommand_FlagOverride(t *testing.T) { Use: "2", Options: clibase.OptionSet{ { - Name: "flag", Flag: "f", Value: clibase.StringOf(&flag), }, @@ -529,17 +527,11 @@ func TestCommand_EmptySlice(t *testing.T) { } } - // Base-case, uses default. + // Base-case err := cmd("bad", "bad", "bad").Invoke().Run() require.NoError(t, err) - // Reset to nothing at all. inv := cmd().Invoke("--arr", "") err = inv.Run() require.NoError(t, err) - - // Override - inv = cmd("great").Invoke("--arr", "great") - err = inv.Run() - require.NoError(t, err) } diff --git a/cli/clibase/option.go b/cli/clibase/option.go index 7b294f4884..05b444c248 100644 --- a/cli/clibase/option.go +++ b/cli/clibase/option.go @@ -46,8 +46,6 @@ type Option struct { UseInstead []Option `json:"use_instead,omitempty"` Hidden bool `json:"hidden,omitempty"` - - envChanged bool } // OptionSet is a group of options that can be applied to a command. @@ -135,7 +133,6 @@ func (s *OptionSet) ParseEnv(vs []EnvVar) error { continue } - opt.envChanged = true if err := opt.Value.Set(envVal); err != nil { merr = multierror.Append( merr, xerrors.Errorf("parse %q: %w", opt.Name, err), @@ -146,27 +143,19 @@ func (s *OptionSet) ParseEnv(vs []EnvVar) error { return merr.ErrorOrNil() } -// SetDefaults sets the default values for each Option, skipping values -// that have already been set as indicated by the skip map. -func (s *OptionSet) SetDefaults(skip map[int]struct{}) error { +// SetDefaults sets the default values for each Option. +// It should be called before all parsing (e.g. ParseFlags, ParseEnv). +func (s *OptionSet) SetDefaults() error { if s == nil { return nil } var merr *multierror.Error - for i, opt := range *s { - // Skip values that may have already been set by the user. - if len(skip) > 0 { - if _, ok := skip[i]; ok { - continue - } - } - + for _, opt := range *s { if opt.Default == "" { continue } - if opt.Value == nil { merr = multierror.Append( merr, diff --git a/cli/clibase/option_test.go b/cli/clibase/option_test.go index 862e8098db..d9d38cc6c7 100644 --- a/cli/clibase/option_test.go +++ b/cli/clibase/option_test.go @@ -49,7 +49,7 @@ func TestOptionSet_ParseFlags(t *testing.T) { }, } - err := os.SetDefaults(nil) + err := os.SetDefaults() require.NoError(t, err) err = os.FlagSet().Parse([]string{"--name", "foo", "--name", "bar"}) @@ -111,7 +111,7 @@ func TestOptionSet_ParseEnv(t *testing.T) { }, } - err := os.SetDefaults(nil) + err := os.SetDefaults() require.NoError(t, err) err = os.ParseEnv(clibase.ParseEnviron([]string{"CODER_WORKSPACE_NAME="}, "CODER_")) diff --git a/cli/clibase/yaml_test.go b/cli/clibase/yaml_test.go index 62582a5252..3efad6ee54 100644 --- a/cli/clibase/yaml_test.go +++ b/cli/clibase/yaml_test.go @@ -44,7 +44,7 @@ func TestOption_ToYAML(t *testing.T) { }, } - err := os.SetDefaults(nil) + err := os.SetDefaults() require.NoError(t, err) n, err := os.ToYAML() diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index aaf9779d2e..ff9bf82add 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -1075,7 +1075,7 @@ QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8 func DeploymentValues(t *testing.T) *codersdk.DeploymentValues { var cfg codersdk.DeploymentValues opts := cfg.Options() - err := opts.SetDefaults(nil) + err := opts.SetDefaults() require.NoError(t, err) return &cfg }