diff --git a/cli/ssh.go b/cli/ssh.go index 323d2913ae..b1d66f28a4 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -109,6 +109,51 @@ func (r *RootCmd) ssh() *serpent.Command { } }, ), + CompletionHandler: func(inv *serpent.Invocation) []string { + client, err := r.InitClient(inv) + if err != nil { + return []string{} + } + + res, err := client.Workspaces(inv.Context(), codersdk.WorkspaceFilter{ + Owner: codersdk.Me, + }) + if err != nil { + return []string{} + } + + var mu sync.Mutex + var completions []string + var wg sync.WaitGroup + for _, ws := range res.Workspaces { + wg.Add(1) + go func() { + defer wg.Done() + resources, err := client.TemplateVersionResources(inv.Context(), ws.LatestBuild.TemplateVersionID) + if err != nil { + return + } + var agents []codersdk.WorkspaceAgent + for _, resource := range resources { + agents = append(agents, resource.Agents...) + } + + mu.Lock() + defer mu.Unlock() + if len(agents) == 1 { + completions = append(completions, ws.Name) + } else { + for _, agent := range agents { + completions = append(completions, fmt.Sprintf("%s.%s", ws.Name, agent.Name)) + } + } + }() + } + wg.Wait() + + slices.Sort(completions) + return completions + }, Handler: func(inv *serpent.Invocation) (retErr error) { client, err := r.InitClient(inv) if err != nil { diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 652e0e9c01..7ce9d85258 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -2447,3 +2447,99 @@ func tempDirUnixSocket(t *testing.T) string { return t.TempDir() } + +func TestSSH_Completion(t *testing.T) { + t.Parallel() + + t.Run("SingleAgent", func(t *testing.T) { + t.Parallel() + + client, workspace, agentToken := setupWorkspaceForAgent(t) + _ = agenttest.New(t, client.URL, agentToken) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) + + var stdout bytes.Buffer + inv, root := clitest.New(t, "ssh", "") + inv.Stdout = &stdout + inv.Environ.Set("COMPLETION_MODE", "1") + clitest.SetupConfig(t, client, root) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + err := inv.WithContext(ctx).Run() + require.NoError(t, err) + + // For single-agent workspaces, the only completion should be the + // bare workspace name. + output := stdout.String() + t.Logf("Completion output: %q", output) + require.Contains(t, output, workspace.Name) + }) + + t.Run("MultiAgent", func(t *testing.T) { + t.Parallel() + + client, store := coderdtest.NewWithDatabase(t, nil) + first := coderdtest.CreateFirstUser(t, client) + userClient, user := coderdtest.CreateAnotherUserMutators(t, client, first.OrganizationID, nil, func(r *codersdk.CreateUserRequestWithOrgs) { + r.Username = "multiuser" + }) + + r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{ + Name: "multiworkspace", + OrganizationID: first.OrganizationID, + OwnerID: user.ID, + }).WithAgent(func(agents []*proto.Agent) []*proto.Agent { + return []*proto.Agent{ + { + Name: "agent1", + Auth: &proto.Agent_Token{}, + }, + { + Name: "agent2", + Auth: &proto.Agent_Token{}, + }, + } + }).Do() + + var stdout bytes.Buffer + inv, root := clitest.New(t, "ssh", "") + inv.Stdout = &stdout + inv.Environ.Set("COMPLETION_MODE", "1") + clitest.SetupConfig(t, userClient, root) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) + defer cancel() + + err := inv.WithContext(ctx).Run() + require.NoError(t, err) + + // For multi-agent workspaces, completions should include the + // workspace.agent format but NOT the bare workspace name. + output := stdout.String() + t.Logf("Completion output: %q", output) + lines := strings.Split(strings.TrimSpace(output), "\n") + require.NotContains(t, lines, r.Workspace.Name) + require.Contains(t, output, r.Workspace.Name+".agent1") + require.Contains(t, output, r.Workspace.Name+".agent2") + }) + + t.Run("NetworkError", func(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + inv, _ := clitest.New(t, "ssh", "") + inv.Stdout = &stdout + inv.Environ.Set("COMPLETION_MODE", "1") + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + err := inv.WithContext(ctx).Run() + require.NoError(t, err) + + output := stdout.String() + require.Empty(t, output) + }) +}