fix(agent/usershell): check shell on darwin via dscl (#8366)

This commit is contained in:
Mathias Fredriksson
2023-07-11 20:27:50 +03:00
committed by GitHub
parent de1d04d7bb
commit e508d9aa6e
4 changed files with 70 additions and 31 deletions
+20 -3
View File
@@ -1,8 +1,25 @@
package usershell
import "os"
import (
"os"
"os/exec"
"path/filepath"
"strings"
"golang.org/x/xerrors"
)
// Get returns the $SHELL environment variable.
func Get(_ string) (string, error) {
return os.Getenv("SHELL"), nil
func Get(username string) (string, error) {
// This command will output "UserShell: /bin/zsh" if successful, we
// can ignore the error since we have fallback behavior.
out, _ := exec.Command("dscl", ".", "-read", filepath.Join("/Users", username), "UserShell").Output()
s, ok := strings.CutPrefix(string(out), "UserShell: ")
if ok {
return strings.TrimSpace(s), nil
}
if s = os.Getenv("SHELL"); s != "" {
return s, nil
}
return "", xerrors.Errorf("shell for user %q not found via dscl or in $SHELL", username)
}
+4 -1
View File
@@ -27,5 +27,8 @@ func Get(username string) (string, error) {
}
return parts[6], nil
}
return "", xerrors.Errorf("user %q not found in /etc/passwd", username)
if s := os.Getenv("SHELL"); s != "" {
return s, nil
}
return "", xerrors.Errorf("shell for user %q not found in /etc/passwd or $SHELL", username)
}
-27
View File
@@ -1,27 +0,0 @@
//go:build !windows && !darwin
// +build !windows,!darwin
package usershell_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/agent/usershell"
)
func TestGet(t *testing.T) {
t.Parallel()
t.Run("Has", func(t *testing.T) {
t.Parallel()
shell, err := usershell.Get("root")
require.NoError(t, err)
require.NotEmpty(t, shell)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
_, err := usershell.Get("notauser")
require.Error(t, err)
})
}
+46
View File
@@ -0,0 +1,46 @@
package usershell_test
import (
"os/user"
"runtime"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/agent/usershell"
)
//nolint:paralleltest,tparallel // This test sets an environment variable.
func TestGet(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
t.Run("Fallback", func(t *testing.T) {
t.Setenv("SHELL", "/bin/sh")
t.Run("NonExistentUser", func(t *testing.T) {
shell, err := usershell.Get("notauser")
require.NoError(t, err)
require.Equal(t, "/bin/sh", shell)
})
})
t.Run("NoFallback", func(t *testing.T) {
// Disable env fallback for these tests.
t.Setenv("SHELL", "")
t.Run("NotFound", func(t *testing.T) {
_, err := usershell.Get("notauser")
require.Error(t, err)
})
t.Run("User", func(t *testing.T) {
u, err := user.Current()
require.NoError(t, err)
shell, err := usershell.Get(u.Username)
require.NoError(t, err)
require.NotEmpty(t, shell)
})
})
}