diff --git a/cli/root.go b/cli/root.go index b20520b192..6064796534 100644 --- a/cli/root.go +++ b/cli/root.go @@ -1275,6 +1275,12 @@ func (e *exitError) Unwrap() error { return e.err } +// ExitCode returns the OS exit code that the CLI will use when this error is +// returned from a command handler. +func (e *exitError) ExitCode() int { + return e.code +} + // ExitError returns an error that will cause the CLI to exit with the given // exit code. If err is non-nil, it will be wrapped by the returned error. func ExitError(code int, err error) error { diff --git a/cli/ssh.go b/cli/ssh.go index d638aefb36..e7d62b29d4 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -116,6 +116,7 @@ func retryWithInterval(ctx context.Context, logger slog.Logger, interval time.Du func (r *RootCmd) ssh() *serpent.Command { var ( stdio bool + tty bool hostPrefix string hostnameSuffix string forceNewTunnel bool @@ -633,9 +634,15 @@ func (r *RootCmd) ssh() *serpent.Command { } } + // Command mode must not request a PTY by default. A PTY + // interposes line discipline on the remote stdin which would + // prevent EOF from propagating to commands that read until + // EOF (e.g. `cat`, `wc`, `tar`). Interactive shell sessions + // always need a PTY, and command mode can opt in via --tty. + requestPTY := command == "" || tty stdinFile, validIn := inv.Stdin.(*os.File) stdoutFile, validOut := inv.Stdout.(*os.File) - if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) { + if requestPTY && validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) { inState, err := pty.MakeInputRaw(stdinFile.Fd()) if err != nil { return err @@ -685,18 +692,29 @@ func (r *RootCmd) ssh() *serpent.Command { } } - err = sshSession.RequestPty("xterm-256color", 128, 128, gossh.TerminalModes{}) - if err != nil { - return xerrors.Errorf("request pty: %w", err) - } - sshSession.Stdin = inv.Stdin sshSession.Stdout = inv.Stdout sshSession.Stderr = inv.Stderr + if requestPTY { + err = sshSession.RequestPty("xterm-256color", 128, 128, gossh.TerminalModes{}) + if err != nil { + return xerrors.Errorf("request pty: %w", err) + } + } + if command != "" { err := sshSession.Run(command) if err != nil { + if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) { + // Preserve the remote command's exit status as the CLI + // exit code, but clear the error since it's not useful + // beyond reporting status. + return ExitError(exitErr.ExitStatus(), nil) + } + if missingErr := (&gossh.ExitMissingError{}); errors.As(err, &missingErr) { + return ExitError(255, xerrors.New("SSH connection ended unexpectedly")) + } return xerrors.Errorf("run command: %w", err) } } else { @@ -728,7 +746,7 @@ func (r *RootCmd) ssh() *serpent.Command { // If the connection drops unexpectedly, we get an // ExitMissingError but no other error details, so try to at // least give the user a better message - if errors.Is(err, &gossh.ExitMissingError{}) { + if missingErr := (&gossh.ExitMissingError{}); errors.As(err, &missingErr) { return ExitError(255, xerrors.New("SSH connection ended unexpectedly")) } return xerrors.Errorf("session ended: %w", err) @@ -751,6 +769,13 @@ func (r *RootCmd) ssh() *serpent.Command { Description: "Specifies whether to emit SSH output over stdin/stdout.", Value: serpent.BoolOf(&stdio), }, + { + Flag: "tty", + FlagShorthand: "t", + Env: "CODER_SSH_TTY", + Description: "Request a pseudo-terminal for the SSH session. Interactive shell sessions request one by default; command sessions do not unless this flag is set.", + Value: serpent.BoolOf(&tty), + }, { Flag: "ssh-host-prefix", Env: "CODER_SSH_SSH_HOST_PREFIX", diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 8f4c74e1ec..6b8392060c 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -2302,9 +2302,9 @@ func TestSSH_CoderConnect(t *testing.T) { err := inv.WithContext(ctx).Run() assert.Error(t, err) - var exitErr *ssh.ExitError + var exitErr interface{ ExitCode() int } assert.True(t, errors.As(err, &exitErr)) - assert.Equal(t, 1, exitErr.ExitStatus()) + assert.Equal(t, 1, exitErr.ExitCode()) }) }) @@ -2368,6 +2368,81 @@ func TestSSH_CoderConnect(t *testing.T) { }) } +func TestSSH_OneShotCommandMode(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("'test' shell command and wc are not available on Windows") + } + + client, workspace, agentToken := setupWorkspaceForAgent(t) + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + t.Run("DoesNotRequestPTY", func(t *testing.T) { + t.Parallel() + + output := new(bytes.Buffer) + inv, root := clitest.New(t, "ssh", workspace.Name, "test -t 0 && echo tty || echo not-tty") + clitest.SetupConfig(t, client, root) + inv.Stdout = output + inv.Stderr = io.Discard + + ctx := testutil.Context(t, testutil.WaitShort) + err := inv.WithContext(ctx).Run() + require.NoError(t, err) + require.Equal(t, "not-tty", strings.TrimSpace(output.String())) + }) + + t.Run("RequestsPTYWithFlag", func(t *testing.T) { + t.Parallel() + + output := new(bytes.Buffer) + inv, root := clitest.New(t, "ssh", "--tty", workspace.Name, "test -t 0 && echo tty || echo not-tty") + clitest.SetupConfig(t, client, root) + inv.Stdout = output + inv.Stderr = io.Discard + + ctx := testutil.Context(t, testutil.WaitShort) + err := inv.WithContext(ctx).Run() + require.NoError(t, err) + require.Equal(t, "tty", strings.TrimSpace(output.String())) + }) + + t.Run("ClosesStdinOnEOF", func(t *testing.T) { + t.Parallel() + + output := new(bytes.Buffer) + inv, root := clitest.New(t, "ssh", workspace.Name, "wc -l") + clitest.SetupConfig(t, client, root) + inv.Stdin = strings.NewReader("a\nb\nc\n") + inv.Stdout = output + inv.Stderr = io.Discard + + ctx := testutil.Context(t, testutil.WaitShort) + err := inv.WithContext(ctx).Run() + require.NoError(t, err) + require.Equal(t, "3", strings.TrimSpace(output.String())) + }) + + t.Run("PropagatesExitCode", func(t *testing.T) { + t.Parallel() + + // Use a non-1 exit code so that we don't accidentally pass when the + // CLI falls back to the default exit code of 1 for any error. + inv, root := clitest.New(t, "ssh", workspace.Name, "exit 2") + clitest.SetupConfig(t, client, root) + inv.Stderr = io.Discard + + ctx := testutil.Context(t, testutil.WaitShort) + err := inv.WithContext(ctx).Run() + require.Error(t, err) + + var cliExitErr interface{ ExitCode() int } + require.ErrorAs(t, err, &cliExitErr) + require.Equal(t, 2, cliExitErr.ExitCode()) + }) +} + type fakeCoderConnectDialer struct{} func (*fakeCoderConnectDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { diff --git a/cli/testdata/coder_ssh_--help.golden b/cli/testdata/coder_ssh_--help.golden index 8019dbdc2a..b75ad909dd 100644 --- a/cli/testdata/coder_ssh_--help.golden +++ b/cli/testdata/coder_ssh_--help.golden @@ -67,6 +67,11 @@ OPTIONS: --stdio bool, $CODER_SSH_STDIO Specifies whether to emit SSH output over stdin/stdout. + -t, --tty bool, $CODER_SSH_TTY + Request a pseudo-terminal for the SSH session. Interactive shell + sessions request one by default; command sessions do not unless this + flag is set. + --wait yes|no|auto, $CODER_SSH_WAIT (default: auto) Specifies whether or not to wait for the startup script to finish executing. Auto means that the agent startup script behavior diff --git a/docs/reference/cli/ssh.md b/docs/reference/cli/ssh.md index aaa76bd256..4f5ec13177 100644 --- a/docs/reference/cli/ssh.md +++ b/docs/reference/cli/ssh.md @@ -30,6 +30,15 @@ This command does not have full parity with the standard SSH command. For users Specifies whether to emit SSH output over stdin/stdout. +### -t, --tty + +| | | +|-------------|-----------------------------| +| Type | bool | +| Environment | $CODER_SSH_TTY | + +Request a pseudo-terminal for the SSH session. Interactive shell sessions request one by default; command sessions do not unless this flag is set. + ### --ssh-host-prefix | | |