mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
refactor: remove agents TUI (#25190)
This commit is contained in:
-205
@@ -1,205 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/google/uuid"
|
||||
"github.com/muesli/termenv"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func installTUISignalHandler(p *tea.Program) func() {
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
|
||||
defer func() {
|
||||
signal.Stop(sig)
|
||||
close(ch)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
return
|
||||
case <-sig:
|
||||
p.Send(terminateTUIMsg{})
|
||||
}
|
||||
}
|
||||
}()
|
||||
return func() {
|
||||
ch <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func fitHelpText(width int, candidates ...string) string {
|
||||
if len(candidates) == 0 {
|
||||
return ""
|
||||
}
|
||||
if width <= 0 {
|
||||
return candidates[0]
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
if lipgloss.Width(candidate) <= width {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
return truncateText(candidates[len(candidates)-1], width, " •|│:", 1)
|
||||
}
|
||||
|
||||
func truncateText(text string, width int, trimRightCutset string, ellipsisWidth int) string {
|
||||
if width <= 0 {
|
||||
return ""
|
||||
}
|
||||
if lipgloss.Width(text) <= width {
|
||||
return text
|
||||
}
|
||||
if width <= ellipsisWidth {
|
||||
return "…"
|
||||
}
|
||||
for runes := []rune(text); len(runes) > 0; runes = runes[:len(runes)-1] {
|
||||
truncated := strings.TrimRight(string(runes), trimRightCutset) + "…"
|
||||
if lipgloss.Width(truncated) <= width {
|
||||
return truncated
|
||||
}
|
||||
}
|
||||
return "…"
|
||||
}
|
||||
|
||||
func (r *RootCmd) agentsCommand() *serpent.Command {
|
||||
var (
|
||||
workspaceFlag string
|
||||
modelFlag string
|
||||
)
|
||||
|
||||
return &serpent.Command{
|
||||
Use: "agents [chat-id]",
|
||||
Short: "Interactive terminal UI for AI agents.",
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Name: "workspace",
|
||||
Flag: "workspace",
|
||||
Description: "Associate the chat with a workspace by name, owner/name, or UUID.",
|
||||
Value: serpent.StringOf(&workspaceFlag),
|
||||
},
|
||||
{
|
||||
Name: "model",
|
||||
Flag: "model",
|
||||
Description: "Choose a model by ID, provider/model, or display name.",
|
||||
Value: serpent.StringOf(&modelFlag),
|
||||
},
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
orgs, err := client.OrganizationsByUser(inv.Context(), codersdk.Me)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("list organizations: %w", err)
|
||||
}
|
||||
if len(orgs) == 0 {
|
||||
return xerrors.New("no organizations found")
|
||||
}
|
||||
defaultOrgID := orgs[0].ID
|
||||
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
if len(inv.Args) > 1 {
|
||||
return xerrors.New("expected zero or one chat ID")
|
||||
}
|
||||
|
||||
var initialChatID *uuid.UUID
|
||||
if len(inv.Args) == 1 {
|
||||
chatID, err := uuid.Parse(inv.Args[0])
|
||||
if err != nil {
|
||||
return xerrors.Errorf("invalid chat ID %q: %w", inv.Args[0], err)
|
||||
}
|
||||
initialChatID = &chatID
|
||||
}
|
||||
|
||||
var workspaceID *uuid.UUID
|
||||
if workspaceFlag != "" {
|
||||
workspace, err := client.ResolveWorkspace(inv.Context(), workspaceFlag)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve workspace %q: %w", workspaceFlag, err)
|
||||
}
|
||||
workspaceID = &workspace.ID
|
||||
}
|
||||
|
||||
modelID, err := resolveModel(inv.Context(), expClient, modelFlag)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set an explicit color profile before Bubble Tea acquires the
|
||||
// terminal so lipgloss/termenv don't send OSC color queries that
|
||||
// can leak back into stdin as literal input in some terminals.
|
||||
renderer := lipgloss.NewRenderer(
|
||||
inv.Stdout,
|
||||
termenv.WithProfile(termenv.TrueColor),
|
||||
)
|
||||
renderer.SetHasDarkBackground(true)
|
||||
|
||||
model := newChatsTUIModel(inv.Context(), expClient, initialChatID, workspaceID, modelID, defaultOrgID)
|
||||
model.setRenderer(renderer)
|
||||
program := tea.NewProgram(
|
||||
model,
|
||||
tea.WithAltScreen(),
|
||||
tea.WithoutSignalHandler(),
|
||||
tea.WithContext(inv.Context()),
|
||||
tea.WithInput(inv.Stdin),
|
||||
tea.WithOutput(inv.Stdout),
|
||||
)
|
||||
|
||||
closeSignalHandler := installTUISignalHandler(program)
|
||||
defer closeSignalHandler()
|
||||
|
||||
runModel, err := program.Run()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := runModel.(chatsTUIModel); !ok {
|
||||
return xerrors.Errorf("unknown model found %T (%+v)", runModel, runModel)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:nilnil // A nil string indicates that no model override was provided.
|
||||
func resolveModel(ctx context.Context, client *codersdk.ExperimentalClient, modelFlag string) (*string, error) {
|
||||
if modelFlag == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if _, err := uuid.Parse(modelFlag); err == nil {
|
||||
return &modelFlag, nil
|
||||
}
|
||||
|
||||
catalog, err := client.ListChatModels(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("listing models: %w", err)
|
||||
}
|
||||
|
||||
for _, provider := range catalog.Providers {
|
||||
for _, model := range provider.Models {
|
||||
if model.ID == modelFlag || model.Provider+"/"+model.Model == modelFlag || model.DisplayName == modelFlag {
|
||||
return &model.ID, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, xerrors.Errorf("unknown model %q", modelFlag)
|
||||
}
|
||||
-1444
File diff suppressed because it is too large
Load Diff
@@ -1,180 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
type (
|
||||
chatsListedMsg struct {
|
||||
chats []codersdk.Chat
|
||||
err error
|
||||
}
|
||||
chatOpenedMsg struct {
|
||||
generation uint64
|
||||
chatID uuid.UUID
|
||||
chat codersdk.Chat
|
||||
err error
|
||||
}
|
||||
chatHistoryMsg struct {
|
||||
generation uint64
|
||||
chatID uuid.UUID
|
||||
messages []codersdk.ChatMessage
|
||||
err error
|
||||
}
|
||||
chatCreatedMsg struct {
|
||||
generation uint64
|
||||
chatID uuid.UUID
|
||||
chat codersdk.Chat
|
||||
err error
|
||||
}
|
||||
chatPlanModeUpdatedMsg struct {
|
||||
generation uint64
|
||||
chatID uuid.UUID
|
||||
err error
|
||||
}
|
||||
messageSentMsg struct {
|
||||
generation uint64
|
||||
chatID uuid.UUID
|
||||
resp codersdk.CreateChatMessageResponse
|
||||
err error
|
||||
}
|
||||
chatInterruptedMsg struct {
|
||||
generation uint64
|
||||
chatID uuid.UUID
|
||||
chat codersdk.Chat
|
||||
err error
|
||||
}
|
||||
modelsListedMsg struct {
|
||||
catalog codersdk.ChatModelsResponse
|
||||
err error
|
||||
}
|
||||
diffContentsMsg struct {
|
||||
generation uint64
|
||||
chatID uuid.UUID
|
||||
diff codersdk.ChatDiffContents
|
||||
err error
|
||||
}
|
||||
chatStreamEventMsg struct {
|
||||
generation uint64
|
||||
chatID uuid.UUID
|
||||
event codersdk.ChatStreamEvent
|
||||
err error
|
||||
}
|
||||
// showAskUserQuestionMsg tells the parent model to open the
|
||||
// ask-user-question overlay.
|
||||
showAskUserQuestionMsg struct {
|
||||
state *askUserQuestionState
|
||||
}
|
||||
// hideAskUserQuestionMsg tells the parent model to close the
|
||||
// ask-user-question overlay.
|
||||
hideAskUserQuestionMsg struct{}
|
||||
// toolResultsSubmittedMsg is sent after the async SubmitToolResults
|
||||
// call completes.
|
||||
toolResultsSubmittedMsg struct {
|
||||
generation uint64
|
||||
chatID uuid.UUID
|
||||
err error
|
||||
}
|
||||
streamRetryMsg struct {
|
||||
generation uint64
|
||||
}
|
||||
toggleModelPickerMsg struct{}
|
||||
toggleDiffDrawerMsg struct{}
|
||||
)
|
||||
|
||||
func scheduleStreamRetry(generation uint64, delay time.Duration) tea.Cmd {
|
||||
return tea.Tick(delay, func(time.Time) tea.Msg {
|
||||
return streamRetryMsg{generation: generation}
|
||||
})
|
||||
}
|
||||
|
||||
func apiCmd[T any](fn func() (T, error), wrap func(T, error) tea.Msg) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
value, err := fn()
|
||||
return wrap(value, err)
|
||||
}
|
||||
}
|
||||
|
||||
func loadChatHistoryCmd(ctx context.Context, client *codersdk.ExperimentalClient, chatID uuid.UUID, generation uint64) tea.Cmd {
|
||||
return apiCmd(func() ([]codersdk.ChatMessage, error) {
|
||||
var (
|
||||
allMessages []codersdk.ChatMessage
|
||||
opts *codersdk.ChatMessagesPaginationOptions
|
||||
)
|
||||
|
||||
for {
|
||||
resp, err := client.GetChatMessages(ctx, chatID, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allMessages = append(allMessages, resp.Messages...)
|
||||
if !resp.HasMore || len(resp.Messages) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
opts = &codersdk.ChatMessagesPaginationOptions{
|
||||
BeforeID: resp.Messages[len(resp.Messages)-1].ID,
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortStableFunc(allMessages, func(a, b codersdk.ChatMessage) int {
|
||||
switch {
|
||||
case a.CreatedAt.Before(b.CreatedAt):
|
||||
return -1
|
||||
case a.CreatedAt.After(b.CreatedAt):
|
||||
return 1
|
||||
case a.ID < b.ID:
|
||||
return -1
|
||||
case a.ID > b.ID:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
|
||||
return allMessages, nil
|
||||
}, func(messages []codersdk.ChatMessage, err error) tea.Msg {
|
||||
return chatHistoryMsg{generation: generation, chatID: chatID, messages: messages, err: err}
|
||||
})
|
||||
}
|
||||
|
||||
func submitAskUserQuestionCmd(client *codersdk.Client, chatID uuid.UUID, generation uint64, state *askUserQuestionState) tea.Cmd {
|
||||
output, err := buildAskUserQuestionToolResult(state)
|
||||
if err != nil {
|
||||
return func() tea.Msg {
|
||||
return toolResultsSubmittedMsg{generation: generation, chatID: chatID, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
req := codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{{
|
||||
ToolCallID: state.ToolCallID,
|
||||
Output: output,
|
||||
IsError: false,
|
||||
}},
|
||||
}
|
||||
return apiCmd(func() (struct{}, error) {
|
||||
return struct{}{}, codersdk.NewExperimentalClient(client).SubmitToolResults(context.Background(), chatID, req)
|
||||
}, func(_ struct{}, err error) tea.Msg {
|
||||
return toolResultsSubmittedMsg{generation: generation, chatID: chatID, err: err}
|
||||
})
|
||||
}
|
||||
|
||||
func listenToStream(chatID uuid.UUID, generation uint64, eventCh <-chan codersdk.ChatStreamEvent) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
event, ok := <-eventCh
|
||||
if !ok {
|
||||
return chatStreamEventMsg{generation: generation, chatID: chatID, err: io.EOF}
|
||||
}
|
||||
return chatStreamEventMsg{generation: generation, chatID: chatID, event: event}
|
||||
}
|
||||
}
|
||||
@@ -1,332 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
const localChatDiffWatchTimeout = 5 * time.Second
|
||||
|
||||
// localChatDiffReadLimit bounds the size of the Changes message the
|
||||
// client is willing to receive from the chat git watcher. agentgit
|
||||
// caps each repository's UnifiedDiff at ~3 MiB (maxTotalDiffSize),
|
||||
// and a Changes payload can aggregate many repos plus metadata, so
|
||||
// 4 MiB is too tight for realistic multi-repo worktrees. 32 MiB
|
||||
// covers ~10 maxed-out repos; pathological payloads beyond that still
|
||||
// fall back to the remote empty diff via errLocalDiffWatchClosed /
|
||||
// shouldIgnoreLocalDiffFallbackError.
|
||||
const localChatDiffReadLimit = 32 << 20 // 32 MiB
|
||||
|
||||
// errLocalDiffWatchClosed is returned when the chat git watcher
|
||||
// websocket closes during the Changes read loop with one of the
|
||||
// known-safe close statuses:
|
||||
//
|
||||
// - StatusMessageTooBig: the Changes payload exceeded our local
|
||||
// 32 MiB client read limit (localChatDiffReadLimit).
|
||||
// - StatusGoingAway: the coderd watchChatGit proxy tore the
|
||||
// client stream down. This is the status the proxy always uses
|
||||
// in coderd/exp_chats.go, so it also covers the upstream 4 MiB
|
||||
// read limit on agent->coderd messages (see
|
||||
// workspacesdk/agentconn.go): when that limit is exceeded the
|
||||
// agent closes with StatusMessageTooBig, but the proxy does not
|
||||
// propagate that status and the client only ever observes
|
||||
// StatusGoingAway.
|
||||
//
|
||||
// Both cases degrade to the remote empty diff returned by /diff:
|
||||
// the local watcher is a supplementary enrichment source that
|
||||
// cannot improve on the remote when its stream is cut short. Other
|
||||
// close statuses (StatusInternalError, StatusProtocolError, ...)
|
||||
// and non-close read errors still surface as hard errors so real
|
||||
// protocol regressions are not hidden behind the fallback.
|
||||
var errLocalDiffWatchClosed = xerrors.New("chat git watcher connection closed before delivering a Changes message")
|
||||
|
||||
func fetchChatDiffContents(
|
||||
ctx context.Context,
|
||||
client *codersdk.ExperimentalClient,
|
||||
chatID uuid.UUID,
|
||||
) (codersdk.ChatDiffContents, error) {
|
||||
remoteDiff, err := client.GetChatDiffContents(ctx, chatID)
|
||||
if err != nil {
|
||||
return codersdk.ChatDiffContents{}, err
|
||||
}
|
||||
if strings.TrimSpace(remoteDiff.Diff) != "" {
|
||||
return remoteDiff, nil
|
||||
}
|
||||
|
||||
localDiff, localSingleRepo, err := fetchLocalChatDiffContents(ctx, client, chatID)
|
||||
if err != nil {
|
||||
if shouldIgnoreLocalDiffFallbackError(err) {
|
||||
return remoteDiff, nil
|
||||
}
|
||||
return codersdk.ChatDiffContents{}, err
|
||||
}
|
||||
if strings.TrimSpace(localDiff.Diff) == "" {
|
||||
return remoteDiff, nil
|
||||
}
|
||||
|
||||
// Backfill metadata from the remote diff only when the local
|
||||
// watcher produced a single contributing repository. Gate this on
|
||||
// the explicit single-repo signal from buildLocalChatDiffContents
|
||||
// rather than on Branch/RemoteOrigin being non-nil, because a
|
||||
// single contributing repo can legitimately have an empty branch
|
||||
// (detached HEAD) or no origin remote and we still want remote
|
||||
// fields like Provider/PullRequestURL to flow through. Multi-repo
|
||||
// aggregates cannot be described by a single remote's metadata, so
|
||||
// we leave them alone.
|
||||
if localSingleRepo {
|
||||
if localDiff.Provider == nil {
|
||||
localDiff.Provider = remoteDiff.Provider
|
||||
}
|
||||
if localDiff.RemoteOrigin == nil {
|
||||
localDiff.RemoteOrigin = remoteDiff.RemoteOrigin
|
||||
}
|
||||
if localDiff.Branch == nil {
|
||||
localDiff.Branch = remoteDiff.Branch
|
||||
}
|
||||
if localDiff.PullRequestURL == nil {
|
||||
localDiff.PullRequestURL = remoteDiff.PullRequestURL
|
||||
}
|
||||
}
|
||||
return localDiff, nil
|
||||
}
|
||||
|
||||
// fetchLocalChatDiffContents returns the aggregated local-watcher diff
|
||||
// and a singleRepo flag that indicates whether that aggregate came from
|
||||
// exactly one contributing repository. The caller uses singleRepo to
|
||||
// decide whether it is safe to backfill remote-only metadata onto the
|
||||
// local diff. All error paths return singleRepo=false.
|
||||
//
|
||||
// This intentionally bypasses wsjson.NewStream and reads the websocket
|
||||
// directly so we can inspect the close status: an oversized Changes
|
||||
// payload must degrade to the remote empty diff via
|
||||
// errLocalDiffWatchClosed + shouldIgnoreLocalDiffFallbackError,
|
||||
// but wsjson.Decoder swallows the read error (logs at debug) and
|
||||
// closes the channel, which would collapse that specific case into
|
||||
// the same generic "connection closed" bucket as server crashes or
|
||||
// decode failures. Reading directly lets us narrowly fall back only
|
||||
// for read-limit violations while still surfacing real protocol
|
||||
// regressions.
|
||||
func fetchLocalChatDiffContents(
|
||||
parentCtx context.Context,
|
||||
client *codersdk.ExperimentalClient,
|
||||
chatID uuid.UUID,
|
||||
) (codersdk.ChatDiffContents, bool, error) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, localChatDiffWatchTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := dialChatGit(ctx, client, chatID)
|
||||
if err != nil {
|
||||
return codersdk.ChatDiffContents{}, false, err
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close(websocket.StatusNormalClosure, "")
|
||||
}()
|
||||
conn.SetReadLimit(localChatDiffReadLimit)
|
||||
|
||||
refreshPayload, err := json.Marshal(codersdk.WorkspaceAgentGitClientMessage{
|
||||
Type: codersdk.WorkspaceAgentGitClientMessageTypeRefresh,
|
||||
})
|
||||
if err != nil {
|
||||
return codersdk.ChatDiffContents{}, false, xerrors.Errorf("marshal git refresh: %w", err)
|
||||
}
|
||||
if err := conn.Write(ctx, websocket.MessageText, refreshPayload); err != nil {
|
||||
return codersdk.ChatDiffContents{}, false, xerrors.Errorf("request git refresh: %w", err)
|
||||
}
|
||||
|
||||
for {
|
||||
msgType, payload, err := conn.Read(ctx)
|
||||
if err != nil {
|
||||
// Context expiration gets its own wrapping so it threads
|
||||
// cleanly through shouldIgnoreLocalDiffFallbackError's
|
||||
// context.DeadlineExceeded case.
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
return codersdk.ChatDiffContents{}, false, xerrors.Errorf("watch chat git: %w", ctxErr)
|
||||
}
|
||||
// A Changes payload that exceeds localChatDiffReadLimit
|
||||
// causes coder/websocket to close the connection with
|
||||
// StatusMessageTooBig. The coderd watchChatGit proxy
|
||||
// also always closes the client with StatusGoingAway
|
||||
// (see coderd/exp_chats.go), which is how we observe
|
||||
// the upstream 4 MiB agent->coderd read-limit breach:
|
||||
// the agent closes its own hop with StatusMessageTooBig,
|
||||
// but the proxy does not propagate that status, so the
|
||||
// client only ever sees StatusGoingAway. Map both onto
|
||||
// the narrow sentinel so shouldIgnoreLocalDiffFallbackError
|
||||
// can degrade to the remote empty diff instead of
|
||||
// surfacing a hard error. Every other close status
|
||||
// (StatusInternalError, StatusProtocolError, ...) and
|
||||
// every non-close read error still propagates so real
|
||||
// protocol regressions reach the user.
|
||||
switch websocket.CloseStatus(err) {
|
||||
case websocket.StatusMessageTooBig, websocket.StatusGoingAway:
|
||||
return codersdk.ChatDiffContents{}, false, errLocalDiffWatchClosed
|
||||
}
|
||||
return codersdk.ChatDiffContents{}, false, xerrors.Errorf("read git watch: %w", err)
|
||||
}
|
||||
// Ignore unexpected frame types instead of erroring; the
|
||||
// watcher only emits text frames today and a future binary
|
||||
// heartbeat should not break the overlay.
|
||||
if msgType != websocket.MessageText {
|
||||
continue
|
||||
}
|
||||
var msg codersdk.WorkspaceAgentGitServerMessage
|
||||
if err := json.Unmarshal(payload, &msg); err != nil {
|
||||
return codersdk.ChatDiffContents{}, false, xerrors.Errorf("decode git watch message: %w", err)
|
||||
}
|
||||
switch msg.Type {
|
||||
case codersdk.WorkspaceAgentGitServerMessageTypeError:
|
||||
message := strings.TrimSpace(msg.Message)
|
||||
if message == "" {
|
||||
message = "git watch returned an unknown error"
|
||||
}
|
||||
return codersdk.ChatDiffContents{}, false, xerrors.New(message)
|
||||
case codersdk.WorkspaceAgentGitServerMessageTypeChanges:
|
||||
diff, singleRepo := buildLocalChatDiffContents(chatID, msg.Repositories)
|
||||
return diff, singleRepo, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dialChatGit opens the chat git-watcher WebSocket. We dial the socket
|
||||
// manually instead of using codersdk.Client.Dial because that helper
|
||||
// closes the HTTP response body before surfacing the error, which
|
||||
// prevents codersdk.ReadBodyAsError from extracting the status code and
|
||||
// message that shouldIgnoreLocalDiffFallbackError needs to decide
|
||||
// whether to degrade to the empty remote diff. Keep this handrolled
|
||||
// path as long as the shared helper has that limitation.
|
||||
func dialChatGit(
|
||||
ctx context.Context,
|
||||
client *codersdk.ExperimentalClient,
|
||||
chatID uuid.UUID,
|
||||
) (*websocket.Conn, error) {
|
||||
requestURL, err := client.URL.Parse(
|
||||
fmt.Sprintf("/api/experimental/chats/%s/stream/git", chatID),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialOptions := &websocket.DialOptions{
|
||||
HTTPClient: client.HTTPClient,
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
}
|
||||
client.SessionTokenProvider.SetDialOption(dialOptions)
|
||||
|
||||
conn, resp, err := websocket.Dial(ctx, requestURL.String(), dialOptions)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
if resp != nil {
|
||||
return nil, codersdk.ReadBodyAsError(resp)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// buildLocalChatDiffContents aggregates the local watcher's
|
||||
// per-repository changes into a single ChatDiffContents. The returned
|
||||
// singleRepo flag is true iff the aggregated diff came from exactly
|
||||
// one contributing repository (one repo with a non-empty UnifiedDiff
|
||||
// that has not been removed). Callers use this flag to decide whether
|
||||
// it is safe to backfill remote-only metadata onto the local diff:
|
||||
// multi-repo aggregates cannot be described by a single remote's
|
||||
// branch/origin/PR URL, but a single-repo aggregate can even when the
|
||||
// contributing repo has an empty branch (detached HEAD) or no origin
|
||||
// remote configured.
|
||||
func buildLocalChatDiffContents(
|
||||
chatID uuid.UUID,
|
||||
repositories []codersdk.WorkspaceAgentRepoChanges,
|
||||
) (codersdk.ChatDiffContents, bool) {
|
||||
result := codersdk.ChatDiffContents{ChatID: chatID}
|
||||
if len(repositories) == 0 {
|
||||
return result, false
|
||||
}
|
||||
|
||||
repositories = slices.Clone(repositories)
|
||||
slices.SortFunc(repositories, func(a, b codersdk.WorkspaceAgentRepoChanges) int {
|
||||
return strings.Compare(a.RepoRoot, b.RepoRoot)
|
||||
})
|
||||
|
||||
diffSegments := make([]string, 0, len(repositories))
|
||||
diffRepositories := make([]codersdk.WorkspaceAgentRepoChanges, 0, len(repositories))
|
||||
for _, repo := range repositories {
|
||||
if repo.Removed || strings.TrimSpace(repo.UnifiedDiff) == "" {
|
||||
continue
|
||||
}
|
||||
diffRepositories = append(diffRepositories, repo)
|
||||
diffSegments = append(diffSegments, strings.TrimRight(repo.UnifiedDiff, "\n"))
|
||||
}
|
||||
if len(diffSegments) == 0 {
|
||||
return result, false
|
||||
}
|
||||
|
||||
result.Diff = strings.Join(diffSegments, "\n")
|
||||
singleRepo := len(diffRepositories) == 1
|
||||
if singleRepo {
|
||||
if branch := strings.TrimSpace(diffRepositories[0].Branch); branch != "" {
|
||||
result.Branch = &branch
|
||||
}
|
||||
if origin := strings.TrimSpace(diffRepositories[0].RemoteOrigin); origin != "" {
|
||||
result.RemoteOrigin = &origin
|
||||
}
|
||||
}
|
||||
return result, singleRepo
|
||||
}
|
||||
|
||||
func shouldIgnoreLocalDiffFallbackError(err error) bool {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
// A watcher stream closed with StatusMessageTooBig or
|
||||
// StatusGoingAway is a best-effort degradation point: the
|
||||
// remote /diff endpoint already returns the empty placeholder
|
||||
// in this case, so fall back to it instead of surfacing a hard
|
||||
// error. See errLocalDiffWatchClosed for the rationale on why
|
||||
// those two close statuses are safe while others still surface.
|
||||
if errors.Is(err, errLocalDiffWatchClosed) {
|
||||
return true
|
||||
}
|
||||
|
||||
sdkErr, ok := codersdk.AsError(err)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch sdkErr.StatusCode() {
|
||||
case http.StatusNotFound:
|
||||
return true
|
||||
case http.StatusForbidden:
|
||||
// authorizeChatWorkspaceExec returns 403 when the chat owner's
|
||||
// workspace permissions have been revoked. The remote diff
|
||||
// endpoint (getChatDiffContents) does not re-check workspace
|
||||
// permissions, so degrade to its empty response the same way
|
||||
// we do for the 400 variants below.
|
||||
return true
|
||||
case http.StatusBadRequest:
|
||||
// These correspond to the 400 responses from watchChatGit in
|
||||
// coderd/exp_chats.go when the chat cannot be observed through
|
||||
// a workspace agent (no workspace bound, workspace deleted, no
|
||||
// agents, or an agent that is not yet connected). Each should
|
||||
// fall back to the empty remote diff the same way a missing
|
||||
// chat (404) does instead of surfacing a hard error.
|
||||
// codersdk.IsChatGitWatchFallbackMessage keeps this list
|
||||
// mechanically linked to the server-side messages.
|
||||
return codersdk.IsChatGitWatchFallbackMessage(sdkErr.Message)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -1,743 +0,0 @@
|
||||
package cli //nolint:testpackage // Tests unexported local diff fallback helpers.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
func TestFetchChatDiffContents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("FallsBackToLocalGitWatcher", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := t.Context()
|
||||
chatID := uuid.New()
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s", chatID)
|
||||
client := newTestExperimentalClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case path + "/diff":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.ChatDiffContents{ChatID: chatID}))
|
||||
case path + "/stream/git":
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "")
|
||||
|
||||
_, payload, err := conn.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
var refresh codersdk.WorkspaceAgentGitClientMessage
|
||||
require.NoError(t, json.Unmarshal(payload, &refresh))
|
||||
require.Equal(t, codersdk.WorkspaceAgentGitClientMessageTypeRefresh, refresh.Type)
|
||||
|
||||
writer, err := conn.Writer(ctx, websocket.MessageText)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, json.NewEncoder(writer).Encode(codersdk.WorkspaceAgentGitServerMessage{
|
||||
Type: codersdk.WorkspaceAgentGitServerMessageTypeChanges,
|
||||
Repositories: []codersdk.WorkspaceAgentRepoChanges{{
|
||||
RepoRoot: "/workspace/repo",
|
||||
Branch: "feature/local-diff",
|
||||
RemoteOrigin: "https://github.com/coder/coder.git",
|
||||
UnifiedDiff: "diff --git a/a.txt b/a.txt\n--- a/a.txt\n+++ b/a.txt\n@@ -1 +1 @@\n-old\n+new\n",
|
||||
}},
|
||||
}))
|
||||
require.NoError(t, writer.Close())
|
||||
default:
|
||||
http.NotFound(rw, r)
|
||||
}
|
||||
}))
|
||||
|
||||
diff, err := fetchChatDiffContents(ctx, client, chatID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, diff.Branch)
|
||||
require.Equal(t, "feature/local-diff", *diff.Branch)
|
||||
require.NotNil(t, diff.RemoteOrigin)
|
||||
require.Equal(t, "https://github.com/coder/coder.git", *diff.RemoteOrigin)
|
||||
require.Contains(t, diff.Diff, "diff --git a/a.txt b/a.txt")
|
||||
require.Contains(t, diff.Diff, "+new")
|
||||
})
|
||||
|
||||
t.Run("IgnoresTimedOutWatcherFallbackErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), testutil.IntervalMedium)
|
||||
defer cancel()
|
||||
|
||||
handlerDone := make(chan struct{})
|
||||
chatID := uuid.New()
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s", chatID)
|
||||
client := newTestExperimentalClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case path + "/diff":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.ChatDiffContents{ChatID: chatID}))
|
||||
case path + "/stream/git":
|
||||
defer close(handlerDone)
|
||||
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "")
|
||||
|
||||
_, payload, err := conn.Read(r.Context())
|
||||
require.NoError(t, err)
|
||||
var refresh codersdk.WorkspaceAgentGitClientMessage
|
||||
require.NoError(t, json.Unmarshal(payload, &refresh))
|
||||
require.Equal(t, codersdk.WorkspaceAgentGitClientMessageTypeRefresh, refresh.Type)
|
||||
|
||||
// Keep the WebSocket open until the client disconnects
|
||||
// (either from fetchChatDiffContents hitting its watch
|
||||
// timeout or test cleanup closing the connection)
|
||||
// instead of sleeping for a fixed duration. The second
|
||||
// Read blocks on the socket and unblocks with an error
|
||||
// when the peer closes the connection, so this handler
|
||||
// drains cleanly without time.Sleep (see WORKFLOWS.md).
|
||||
_, _, _ = conn.Read(r.Context())
|
||||
default:
|
||||
http.NotFound(rw, r)
|
||||
}
|
||||
}))
|
||||
|
||||
diff, err := fetchChatDiffContents(ctx, client, chatID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chatID, diff.ChatID)
|
||||
require.Empty(t, diff.Diff)
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-handlerDone:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
})
|
||||
|
||||
t.Run("IgnoresMissingWorkspaceFallbackErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Each message here matches a 400 response that watchChatGit can
|
||||
// return when the chat cannot be observed through the workspace
|
||||
// agent. fetchChatDiffContents should swallow the error and fall
|
||||
// back to the empty remote diff instead of surfacing a hard
|
||||
// error in the TUI. Drive the subtests from the shared codersdk
|
||||
// constants so a server-side rewording automatically flows
|
||||
// through the test matrix.
|
||||
for _, message := range []string{
|
||||
codersdk.ChatGitWatchNoWorkspaceMessage,
|
||||
codersdk.ChatGitWatchWorkspaceNotFoundMessage,
|
||||
codersdk.ChatGitWatchWorkspaceNoAgentsMessage,
|
||||
codersdk.ChatGitWatchAgentStateMessage(codersdk.WorkspaceAgentConnecting),
|
||||
} {
|
||||
t.Run(message, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := t.Context()
|
||||
chatID := uuid.New()
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s", chatID)
|
||||
client := newTestExperimentalClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case path + "/diff":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.ChatDiffContents{ChatID: chatID}))
|
||||
case path + "/stream/git":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.Response{Message: message}))
|
||||
default:
|
||||
http.NotFound(rw, r)
|
||||
}
|
||||
}))
|
||||
|
||||
diff, err := fetchChatDiffContents(ctx, client, chatID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chatID, diff.ChatID)
|
||||
require.Empty(t, diff.Diff)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IgnoresForbiddenWatcherFallbackErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// authorizeChatWorkspaceExec in coderd/exp_chats.go returns 403
|
||||
// when the chat owner's workspace exec permission is revoked.
|
||||
// The remote /diff endpoint does not re-check workspace
|
||||
// permissions, so fetchChatDiffContents must swallow the 403
|
||||
// and fall back to the empty remote diff just like it does for
|
||||
// the 400 variants above. Without this subtest, removing the
|
||||
// `case http.StatusForbidden` branch in
|
||||
// shouldIgnoreLocalDiffFallbackError would silently regress.
|
||||
ctx := t.Context()
|
||||
chatID := uuid.New()
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s", chatID)
|
||||
client := newTestExperimentalClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case path + "/diff":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.ChatDiffContents{ChatID: chatID}))
|
||||
case path + "/stream/git":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusForbidden)
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.Response{Message: "forbidden"}))
|
||||
default:
|
||||
http.NotFound(rw, r)
|
||||
}
|
||||
}))
|
||||
|
||||
diff, err := fetchChatDiffContents(ctx, client, chatID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chatID, diff.ChatID)
|
||||
require.Empty(t, diff.Diff)
|
||||
})
|
||||
|
||||
t.Run("IgnoresNotFoundWatcherFallbackErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// watchChatGit in coderd/exp_chats.go returns 404 for missing
|
||||
// chats (httpapi.ResourceNotFound). The remote /diff endpoint
|
||||
// already handles the missing-chat case on its own, so
|
||||
// fetchChatDiffContents must swallow the 404 from /stream/git
|
||||
// and fall back to whatever the remote diff returned, the
|
||||
// same way it does for the 400 and 403 variants above.
|
||||
// Without this subtest, removing the `case http.StatusNotFound`
|
||||
// branch in shouldIgnoreLocalDiffFallbackError would silently
|
||||
// regress (mirrors the 403 coverage added for DEREM-16).
|
||||
ctx := t.Context()
|
||||
chatID := uuid.New()
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s", chatID)
|
||||
client := newTestExperimentalClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case path + "/diff":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.ChatDiffContents{ChatID: chatID}))
|
||||
case path + "/stream/git":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.Response{Message: "not found"}))
|
||||
default:
|
||||
http.NotFound(rw, r)
|
||||
}
|
||||
}))
|
||||
|
||||
diff, err := fetchChatDiffContents(ctx, client, chatID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chatID, diff.ChatID)
|
||||
require.Empty(t, diff.Diff)
|
||||
})
|
||||
|
||||
t.Run("BackfillsRemoteMetadataWhenLocalDiffIsSingleRepo", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// The scenario this PR was written for: a chat has remote
|
||||
// metadata (provider, pull-request URL, etc.) but the server
|
||||
// returns an empty Diff because the remote watcher has not
|
||||
// observed changes yet. The CLI fetches the local watcher
|
||||
// diff and must carry the remote metadata forward so the
|
||||
// Diff overlay still shows the PR URL / origin.
|
||||
ctx := t.Context()
|
||||
chatID := uuid.New()
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s", chatID)
|
||||
remoteBranch := "feature/remote-branch"
|
||||
remoteOrigin := "https://github.com/coder/coder.git"
|
||||
remotePR := "https://github.com/coder/coder/pull/42"
|
||||
remoteProvider := "github"
|
||||
client := newTestExperimentalClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case path + "/diff":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.ChatDiffContents{
|
||||
ChatID: chatID,
|
||||
Provider: &remoteProvider,
|
||||
RemoteOrigin: &remoteOrigin,
|
||||
Branch: &remoteBranch,
|
||||
PullRequestURL: &remotePR,
|
||||
}))
|
||||
case path + "/stream/git":
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "")
|
||||
|
||||
_, payload, err := conn.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
var refresh codersdk.WorkspaceAgentGitClientMessage
|
||||
require.NoError(t, json.Unmarshal(payload, &refresh))
|
||||
require.Equal(t, codersdk.WorkspaceAgentGitClientMessageTypeRefresh, refresh.Type)
|
||||
|
||||
writer, err := conn.Writer(ctx, websocket.MessageText)
|
||||
require.NoError(t, err)
|
||||
// Return exactly one repo so buildLocalChatDiffContents
|
||||
// sets Branch/RemoteOrigin, which is the signal that
|
||||
// fetchChatDiffContents uses to backfill missing
|
||||
// metadata from the remote response (Provider, PR URL)
|
||||
// without overwriting fields the local watcher
|
||||
// already populated.
|
||||
require.NoError(t, json.NewEncoder(writer).Encode(codersdk.WorkspaceAgentGitServerMessage{
|
||||
Type: codersdk.WorkspaceAgentGitServerMessageTypeChanges,
|
||||
Repositories: []codersdk.WorkspaceAgentRepoChanges{{
|
||||
RepoRoot: "/workspace/repo",
|
||||
Branch: "feature/local-branch",
|
||||
RemoteOrigin: "https://github.com/coder/local.git",
|
||||
UnifiedDiff: "diff --git a/a.txt b/a.txt\n--- a/a.txt\n+++ b/a.txt\n@@ -1 +1 @@\n-old\n+new\n",
|
||||
}},
|
||||
}))
|
||||
require.NoError(t, writer.Close())
|
||||
default:
|
||||
http.NotFound(rw, r)
|
||||
}
|
||||
}))
|
||||
|
||||
diff, err := fetchChatDiffContents(ctx, client, chatID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The aggregated diff comes from the local watcher.
|
||||
require.Contains(t, diff.Diff, "diff --git a/a.txt b/a.txt")
|
||||
require.Contains(t, diff.Diff, "+new")
|
||||
|
||||
// Branch and RemoteOrigin were populated by the single-repo
|
||||
// local watcher result, so they must NOT be overwritten by
|
||||
// the remote response.
|
||||
require.NotNil(t, diff.Branch)
|
||||
require.Equal(t, "feature/local-branch", *diff.Branch)
|
||||
require.NotNil(t, diff.RemoteOrigin)
|
||||
require.Equal(t, "https://github.com/coder/local.git", *diff.RemoteOrigin)
|
||||
|
||||
// Provider and PullRequestURL were nil on the local diff,
|
||||
// so they must be backfilled from the remote metadata.
|
||||
require.NotNil(t, diff.Provider)
|
||||
require.Equal(t, remoteProvider, *diff.Provider)
|
||||
require.NotNil(t, diff.PullRequestURL)
|
||||
require.Equal(t, remotePR, *diff.PullRequestURL)
|
||||
})
|
||||
|
||||
t.Run("BackfillsRemoteMetadataWhenSingleRepoHasBlankBranchAndOrigin", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// A single contributing repo can legitimately be in detached
|
||||
// HEAD with no origin remote configured: buildLocalChatDiffContents
|
||||
// then leaves both Branch and RemoteOrigin nil even though
|
||||
// exactly one repository produced the aggregated diff. Before
|
||||
// the singleRepo flag was introduced, the gate on
|
||||
// `localDiff.Branch != nil || localDiff.RemoteOrigin != nil`
|
||||
// skipped the backfill in this case and the drawer silently
|
||||
// lost remote Provider/PullRequestURL. fetchChatDiffContents
|
||||
// must now use the explicit singleRepo signal so remote
|
||||
// metadata still flows through, and must also populate the
|
||||
// nil Branch/RemoteOrigin from the remote response to keep the
|
||||
// drawer display consistent with all other single-repo diffs.
|
||||
ctx := t.Context()
|
||||
chatID := uuid.New()
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s", chatID)
|
||||
remoteBranch := "feature/remote-branch"
|
||||
remoteOrigin := "https://github.com/coder/coder.git"
|
||||
remotePR := "https://github.com/coder/coder/pull/42"
|
||||
remoteProvider := "github"
|
||||
client := newTestExperimentalClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case path + "/diff":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.ChatDiffContents{
|
||||
ChatID: chatID,
|
||||
Provider: &remoteProvider,
|
||||
RemoteOrigin: &remoteOrigin,
|
||||
Branch: &remoteBranch,
|
||||
PullRequestURL: &remotePR,
|
||||
}))
|
||||
case path + "/stream/git":
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "")
|
||||
|
||||
_, payload, err := conn.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
var refresh codersdk.WorkspaceAgentGitClientMessage
|
||||
require.NoError(t, json.Unmarshal(payload, &refresh))
|
||||
require.Equal(t, codersdk.WorkspaceAgentGitClientMessageTypeRefresh, refresh.Type)
|
||||
|
||||
writer, err := conn.Writer(ctx, websocket.MessageText)
|
||||
require.NoError(t, err)
|
||||
// Exactly one repository contributes, but both
|
||||
// Branch and RemoteOrigin are empty (detached HEAD,
|
||||
// no origin remote). buildLocalChatDiffContents
|
||||
// still flags this as singleRepo=true, so the
|
||||
// backfill must run and populate every nil field
|
||||
// from the remote response.
|
||||
require.NoError(t, json.NewEncoder(writer).Encode(codersdk.WorkspaceAgentGitServerMessage{
|
||||
Type: codersdk.WorkspaceAgentGitServerMessageTypeChanges,
|
||||
Repositories: []codersdk.WorkspaceAgentRepoChanges{{
|
||||
RepoRoot: "/workspace/repo",
|
||||
Branch: "",
|
||||
RemoteOrigin: "",
|
||||
UnifiedDiff: "diff --git a/a.txt b/a.txt\n--- a/a.txt\n+++ b/a.txt\n@@ -1 +1 @@\n-old\n+new\n",
|
||||
}},
|
||||
}))
|
||||
require.NoError(t, writer.Close())
|
||||
default:
|
||||
http.NotFound(rw, r)
|
||||
}
|
||||
}))
|
||||
|
||||
diff, err := fetchChatDiffContents(ctx, client, chatID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The aggregated diff still comes from the local watcher.
|
||||
require.Contains(t, diff.Diff, "diff --git a/a.txt b/a.txt")
|
||||
require.Contains(t, diff.Diff, "+new")
|
||||
|
||||
// Every remote-only field is backfilled because
|
||||
// buildLocalChatDiffContents flagged the aggregate as
|
||||
// singleRepo=true even with blank branch/origin.
|
||||
require.NotNil(t, diff.Branch)
|
||||
require.Equal(t, remoteBranch, *diff.Branch)
|
||||
require.NotNil(t, diff.RemoteOrigin)
|
||||
require.Equal(t, remoteOrigin, *diff.RemoteOrigin)
|
||||
require.NotNil(t, diff.Provider)
|
||||
require.Equal(t, remoteProvider, *diff.Provider)
|
||||
require.NotNil(t, diff.PullRequestURL)
|
||||
require.Equal(t, remotePR, *diff.PullRequestURL)
|
||||
})
|
||||
|
||||
t.Run("IgnoresWatcherMessageTooBigCloses", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// agentgit caps each repository's UnifiedDiff at ~3 MiB and a
|
||||
// Changes payload aggregates every repo plus metadata, so a
|
||||
// realistic multi-repo workspace can legitimately produce a
|
||||
// payload that exceeds the client's websocket read limit.
|
||||
// When that happens coder/websocket closes the connection
|
||||
// with StatusMessageTooBig. fetchChatDiffContents must map
|
||||
// that specific close status onto errLocalDiffWatchClosed
|
||||
// and fall back to the remote empty diff rather than
|
||||
// surfacing a hard error to the TUI. Without this subtest,
|
||||
// removing the StatusMessageTooBig branch in
|
||||
// fetchLocalChatDiffContents or the errLocalDiffWatchClosed
|
||||
// branch in shouldIgnoreLocalDiffFallbackError would
|
||||
// silently regress the large-multi-repo case this feature is
|
||||
// meant to improve.
|
||||
ctx := t.Context()
|
||||
chatID := uuid.New()
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s", chatID)
|
||||
client := newTestExperimentalClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case path + "/diff":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.ChatDiffContents{ChatID: chatID}))
|
||||
case path + "/stream/git":
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
require.NoError(t, err)
|
||||
// Drain the refresh before closing so the client
|
||||
// surfaces the close status from its next Read, not
|
||||
// an unrelated write error.
|
||||
_, _, err = conn.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.Close(websocket.StatusMessageTooBig, "too big"))
|
||||
default:
|
||||
http.NotFound(rw, r)
|
||||
}
|
||||
}))
|
||||
|
||||
diff, err := fetchChatDiffContents(ctx, client, chatID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chatID, diff.ChatID)
|
||||
require.Empty(t, diff.Diff)
|
||||
})
|
||||
|
||||
t.Run("IgnoresWatcherGoingAwayCloses", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// The coderd watchChatGit proxy always closes the client
|
||||
// stream with StatusGoingAway regardless of why the
|
||||
// upstream agent->coderd hop failed. In particular, when
|
||||
// that hop's 4 MiB read limit (workspacesdk/agentconn.go)
|
||||
// is exceeded, the agent closes its end with
|
||||
// StatusMessageTooBig but the proxy does not propagate
|
||||
// that status, so the client only observes
|
||||
// StatusGoingAway. That is the exact scenario this PR's
|
||||
// 32 MiB client read limit is meant to handle, so the
|
||||
// TUI must degrade to the remote empty diff for
|
||||
// StatusGoingAway just like it does for
|
||||
// StatusMessageTooBig. Without this subtest, narrowing
|
||||
// the close-status match back to StatusMessageTooBig
|
||||
// only would silently regress multi-repo worktrees whose
|
||||
// aggregate Changes payload sits between the 4 MiB
|
||||
// upstream limit and the 32 MiB client limit.
|
||||
ctx := t.Context()
|
||||
chatID := uuid.New()
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s", chatID)
|
||||
client := newTestExperimentalClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case path + "/diff":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.ChatDiffContents{ChatID: chatID}))
|
||||
case path + "/stream/git":
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
require.NoError(t, err)
|
||||
_, _, err = conn.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.Close(websocket.StatusGoingAway, "proxy tear-down"))
|
||||
default:
|
||||
http.NotFound(rw, r)
|
||||
}
|
||||
}))
|
||||
|
||||
diff, err := fetchChatDiffContents(ctx, client, chatID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chatID, diff.ChatID)
|
||||
require.Empty(t, diff.Diff)
|
||||
})
|
||||
|
||||
t.Run("SurfacesUnexpectedWatcherCloseErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// The StatusMessageTooBig fallback is intentionally narrow:
|
||||
// a generic websocket close (for example the server
|
||||
// crashing and closing with StatusInternalError) should
|
||||
// surface as an error rather than silently degrading,
|
||||
// because that would hide real protocol regressions behind
|
||||
// the best-effort fallback. This subtest pins that
|
||||
// distinction so a future attempt to blanket-ignore every
|
||||
// close reason immediately breaks the test.
|
||||
ctx := t.Context()
|
||||
chatID := uuid.New()
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s", chatID)
|
||||
client := newTestExperimentalClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case path + "/diff":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.ChatDiffContents{ChatID: chatID}))
|
||||
case path + "/stream/git":
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
require.NoError(t, err)
|
||||
_, _, err = conn.Read(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.Close(websocket.StatusInternalError, "boom"))
|
||||
default:
|
||||
http.NotFound(rw, r)
|
||||
}
|
||||
}))
|
||||
|
||||
_, err := fetchChatDiffContents(ctx, client, chatID)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("ReturnsRemoteDiffWithoutDialingWatcher", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// When the remote /diff endpoint returns a non-empty diff the
|
||||
// CLI short-circuits the WebSocket fallback. If the git stream
|
||||
// handler ever fires, the test fails the request explicitly so
|
||||
// an inverted condition regresses loudly.
|
||||
ctx := t.Context()
|
||||
chatID := uuid.New()
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s", chatID)
|
||||
branch := "feature/remote"
|
||||
prURL := "https://example.com/pr/1"
|
||||
remoteDiff := codersdk.ChatDiffContents{
|
||||
ChatID: chatID,
|
||||
Branch: &branch,
|
||||
PullRequestURL: &prURL,
|
||||
Diff: "diff --git a/remote.txt b/remote.txt\n--- a/remote.txt\n+++ b/remote.txt\n@@ -1 +1 @@\n-old\n+new\n",
|
||||
}
|
||||
client := newTestExperimentalClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case path + "/diff":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(remoteDiff))
|
||||
case path + "/stream/git":
|
||||
t.Errorf("local git watcher should not be dialed when the remote diff is non-empty")
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
default:
|
||||
http.NotFound(rw, r)
|
||||
}
|
||||
}))
|
||||
|
||||
got, err := fetchChatDiffContents(ctx, client, chatID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, chatID, got.ChatID)
|
||||
require.Equal(t, remoteDiff.Diff, got.Diff)
|
||||
require.NotNil(t, got.Branch)
|
||||
require.Equal(t, branch, *got.Branch)
|
||||
require.NotNil(t, got.PullRequestURL)
|
||||
require.Equal(t, prURL, *got.PullRequestURL)
|
||||
})
|
||||
|
||||
t.Run("PropagatesRemoteDiffAPIErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// A 500 from /diff is a hard failure that the CLI must surface
|
||||
// rather than silently fall back. The local watcher must not
|
||||
// be dialed when the remote endpoint returned an error.
|
||||
ctx := t.Context()
|
||||
chatID := uuid.New()
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s", chatID)
|
||||
client := newTestExperimentalClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case path + "/diff":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.Response{Message: "boom"}))
|
||||
case path + "/stream/git":
|
||||
t.Errorf("local git watcher should not be dialed when /diff errors")
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
default:
|
||||
http.NotFound(rw, r)
|
||||
}
|
||||
}))
|
||||
|
||||
_, err := fetchChatDiffContents(ctx, client, chatID)
|
||||
require.Error(t, err)
|
||||
sdkErr, ok := codersdk.AsError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, http.StatusInternalServerError, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("SurfacesNonIgnorableWatcherErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// A 500 from the git stream is not in the ignorable set, so
|
||||
// fetchChatDiffContents must return it verbatim instead of
|
||||
// silently collapsing to the empty remote diff.
|
||||
ctx := t.Context()
|
||||
chatID := uuid.New()
|
||||
path := fmt.Sprintf("/api/experimental/chats/%s", chatID)
|
||||
client := newTestExperimentalClient(t, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case path + "/diff":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.ChatDiffContents{ChatID: chatID}))
|
||||
case path + "/stream/git":
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
require.NoError(t, json.NewEncoder(rw).Encode(codersdk.Response{Message: "internal git watcher failure"}))
|
||||
default:
|
||||
http.NotFound(rw, r)
|
||||
}
|
||||
}))
|
||||
|
||||
_, err := fetchChatDiffContents(ctx, client, chatID)
|
||||
require.Error(t, err)
|
||||
sdkErr, ok := codersdk.AsError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, http.StatusInternalServerError, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildLocalChatDiffContents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("SortsMultipleReposByRepoRoot", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
diff, singleRepo := buildLocalChatDiffContents(chatID, []codersdk.WorkspaceAgentRepoChanges{
|
||||
{
|
||||
RepoRoot: "/workspace/z-repo",
|
||||
UnifiedDiff: "diff --git a/z.txt b/z.txt\n+z\n",
|
||||
},
|
||||
{
|
||||
RepoRoot: "/workspace/a-repo",
|
||||
Branch: "feature/local",
|
||||
RemoteOrigin: "https://github.com/coder/coder.git",
|
||||
UnifiedDiff: "diff --git a/a.txt b/a.txt\n+a\n",
|
||||
},
|
||||
})
|
||||
|
||||
// Multi-repo aggregation drops the per-repo metadata because
|
||||
// Branch/RemoteOrigin only make sense for a single repo. The
|
||||
// singleRepo flag must be false so callers know not to
|
||||
// backfill remote metadata onto a multi-repo aggregate.
|
||||
require.Equal(t, chatID, diff.ChatID)
|
||||
require.Contains(t, diff.Diff, "diff --git a/a.txt b/a.txt")
|
||||
require.Contains(t, diff.Diff, "diff --git a/z.txt b/z.txt")
|
||||
require.Less(t, strings.Index(diff.Diff, "a.txt"), strings.Index(diff.Diff, "z.txt"))
|
||||
require.Nil(t, diff.Branch)
|
||||
require.Nil(t, diff.RemoteOrigin)
|
||||
require.False(t, singleRepo)
|
||||
})
|
||||
|
||||
t.Run("ReturnsEmptyForNoRepositories", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
// No repos: exercise the early-return in buildLocalChatDiffContents
|
||||
// so the empty case is mechanically covered. singleRepo must
|
||||
// be false because no repository contributed any diff.
|
||||
for _, repos := range [][]codersdk.WorkspaceAgentRepoChanges{nil, {}} {
|
||||
diff, singleRepo := buildLocalChatDiffContents(chatID, repos)
|
||||
require.Equal(t, chatID, diff.ChatID)
|
||||
require.Empty(t, diff.Diff)
|
||||
require.Nil(t, diff.Branch)
|
||||
require.Nil(t, diff.RemoteOrigin)
|
||||
require.False(t, singleRepo)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SkipsRemovedAndEmptyRepositories", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
// Removed repos (Removed=true) and repos with whitespace-only
|
||||
// UnifiedDiff must not contribute to the aggregated diff. With
|
||||
// a single contributing repo, the per-repo Branch and
|
||||
// RemoteOrigin should still propagate to the result and
|
||||
// singleRepo must be true because only one repository
|
||||
// contributed.
|
||||
diff, singleRepo := buildLocalChatDiffContents(chatID, []codersdk.WorkspaceAgentRepoChanges{
|
||||
{
|
||||
RepoRoot: "/workspace/removed",
|
||||
Removed: true,
|
||||
UnifiedDiff: "diff --git a/removed.txt b/removed.txt\n+removed\n",
|
||||
},
|
||||
{
|
||||
RepoRoot: "/workspace/empty",
|
||||
UnifiedDiff: " \n",
|
||||
},
|
||||
{
|
||||
RepoRoot: "/workspace/only",
|
||||
Branch: "feature/only",
|
||||
RemoteOrigin: "https://github.com/coder/coder.git",
|
||||
UnifiedDiff: "diff --git a/only.txt b/only.txt\n+only\n",
|
||||
},
|
||||
})
|
||||
|
||||
require.Equal(t, chatID, diff.ChatID)
|
||||
require.Contains(t, diff.Diff, "diff --git a/only.txt b/only.txt")
|
||||
require.NotContains(t, diff.Diff, "removed.txt")
|
||||
require.NotContains(t, diff.Diff, "empty")
|
||||
require.NotNil(t, diff.Branch)
|
||||
require.Equal(t, "feature/only", *diff.Branch)
|
||||
require.NotNil(t, diff.RemoteOrigin)
|
||||
require.Equal(t, "https://github.com/coder/coder.git", *diff.RemoteOrigin)
|
||||
require.True(t, singleRepo)
|
||||
})
|
||||
|
||||
t.Run("ReturnsEmptyWhenAllRepositoriesAreSkipped", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
// If every repo is removed or empty, buildLocalChatDiffContents
|
||||
// returns the empty remote-diff shape so the caller falls back
|
||||
// to the placeholder overlay instead of rendering a diff-less
|
||||
// summary. singleRepo must be false because no repository
|
||||
// contributed any diff content.
|
||||
diff, singleRepo := buildLocalChatDiffContents(chatID, []codersdk.WorkspaceAgentRepoChanges{
|
||||
{RepoRoot: "/workspace/removed", Removed: true, UnifiedDiff: "diff --git a/removed.txt b/removed.txt\n+removed\n"},
|
||||
{RepoRoot: "/workspace/empty"},
|
||||
})
|
||||
|
||||
require.Equal(t, chatID, diff.ChatID)
|
||||
require.Empty(t, diff.Diff)
|
||||
require.Nil(t, diff.Branch)
|
||||
require.Nil(t, diff.RemoteOrigin)
|
||||
require.False(t, singleRepo)
|
||||
})
|
||||
}
|
||||
@@ -1,151 +0,0 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func agentsPtr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
func setupAgentsBackend(t *testing.T) (*codersdk.Client, *codersdk.ExperimentalClient, uuid.UUID) {
|
||||
t.Helper()
|
||||
|
||||
values := coderdtest.DeploymentValues(t)
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
DeploymentValues: values,
|
||||
})
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "test-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: agentsPtr(int64(4096)),
|
||||
IsDefault: agentsPtr(true),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return client, expClient, firstUser.OrganizationID
|
||||
}
|
||||
|
||||
//nolint:revive // Test helper signature keeps t first for consistency with other helpers.
|
||||
func seedChat(t *testing.T, ctx context.Context, expClient *codersdk.ExperimentalClient, orgID uuid.UUID, seed string) codersdk.Chat {
|
||||
t.Helper()
|
||||
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: orgID,
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: seed,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
|
||||
type agentsSession struct {
|
||||
t *testing.T
|
||||
pty *ptytest.PTY
|
||||
errCh <-chan error
|
||||
}
|
||||
|
||||
func (s *agentsSession) expect(ctx context.Context, text string) {
|
||||
s.t.Helper()
|
||||
s.pty.ExpectMatchContext(ctx, text)
|
||||
}
|
||||
|
||||
func (s *agentsSession) wait(ctx context.Context) error {
|
||||
s.t.Helper()
|
||||
return testutil.RequireReceive(ctx, s.t, s.errCh)
|
||||
}
|
||||
|
||||
//nolint:unused // Kept as a small PTY helper for future multi-character input.
|
||||
func (s *agentsSession) write(text string) {
|
||||
s.t.Helper()
|
||||
s.pty.WriteLine(text)
|
||||
}
|
||||
|
||||
func (s *agentsSession) writeRune(r rune) {
|
||||
s.t.Helper()
|
||||
_, err := s.pty.Input().Write([]byte(string(r)))
|
||||
require.NoError(s.t, err)
|
||||
}
|
||||
|
||||
func (s *agentsSession) enter() {
|
||||
s.t.Helper()
|
||||
_, err := s.pty.Input().Write([]byte("\r"))
|
||||
require.NoError(s.t, err)
|
||||
}
|
||||
|
||||
func (s *agentsSession) esc() {
|
||||
s.t.Helper()
|
||||
_, err := s.pty.Input().Write([]byte("\x1b"))
|
||||
require.NoError(s.t, err)
|
||||
}
|
||||
|
||||
func (s *agentsSession) ctrlC() {
|
||||
s.t.Helper()
|
||||
_, err := s.pty.Input().Write([]byte{3})
|
||||
require.NoError(s.t, err)
|
||||
}
|
||||
|
||||
func (s *agentsSession) quit() {
|
||||
s.t.Helper()
|
||||
s.writeRune('q')
|
||||
}
|
||||
|
||||
//nolint:revive // Test helper signature keeps t first for consistency with other helpers.
|
||||
func startAgentsSession(t *testing.T, ctx context.Context, client *codersdk.Client, args ...string) *agentsSession {
|
||||
t.Helper()
|
||||
|
||||
// Reading to / writing from the PTY is flaky on non-linux systems.
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("skipping on non-linux")
|
||||
}
|
||||
|
||||
fullArgs := append([]string{"agents"}, args...)
|
||||
inv, root := clitest.New(t, fullArgs...)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
pty := ptytest.New(t)
|
||||
tty, err := os.OpenFile(pty.Name(), os.O_RDWR, 0)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = tty.Close()
|
||||
})
|
||||
|
||||
inv.Stdin = tty
|
||||
inv.Stdout = tty
|
||||
inv.Stderr = tty
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
tGo(t, func() {
|
||||
errCh <- inv.WithContext(ctx).Run()
|
||||
})
|
||||
|
||||
return &agentsSession{t: t, pty: pty, errCh: errCh}
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestAgentsE2E(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("EmptyStateBoot", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, _, _ := setupAgentsBackend(t)
|
||||
session := startAgentsSession(t, ctx, client)
|
||||
|
||||
session.expect(ctx, "No chats yet. Press n to start a new chat.")
|
||||
session.quit()
|
||||
require.NoError(t, session.wait(ctx))
|
||||
})
|
||||
|
||||
t.Run("ListAndNavigate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, expClient, orgID := setupAgentsBackend(t)
|
||||
|
||||
_ = seedChat(t, ctx, expClient, orgID, "alpha nav seed")
|
||||
_ = seedChat(t, ctx, expClient, orgID, "bravo nav seed")
|
||||
_ = seedChat(t, ctx, expClient, orgID, "charlie nav seed")
|
||||
|
||||
session := startAgentsSession(t, ctx, client)
|
||||
|
||||
session.expect(ctx, "charlie nav seed")
|
||||
session.expect(ctx, "enter: open")
|
||||
session.enter()
|
||||
session.expect(ctx, "esc")
|
||||
session.esc()
|
||||
session.expect(ctx, "enter: open")
|
||||
session.quit()
|
||||
require.NoError(t, session.wait(ctx))
|
||||
})
|
||||
|
||||
t.Run("SearchFilter", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, expClient, orgID := setupAgentsBackend(t)
|
||||
|
||||
_ = seedChat(t, ctx, expClient, orgID, "alpha filter seed")
|
||||
_ = seedChat(t, ctx, expClient, orgID, "zulu filter seed")
|
||||
|
||||
session := startAgentsSession(t, ctx, client)
|
||||
|
||||
session.expect(ctx, "alpha filter seed")
|
||||
session.expect(ctx, "enter: open")
|
||||
session.writeRune('/')
|
||||
session.expect(ctx, "/ ")
|
||||
for _, r := range "zzzznotamatch" {
|
||||
session.writeRune(r)
|
||||
}
|
||||
session.expect(ctx, "No matches.")
|
||||
session.ctrlC()
|
||||
require.NoError(t, session.wait(ctx))
|
||||
})
|
||||
|
||||
t.Run("ExistingChatHistory", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, expClient, orgID := setupAgentsBackend(t)
|
||||
|
||||
chat := seedChat(t, ctx, expClient, orgID, "direct open seed")
|
||||
session := startAgentsSession(t, ctx, client, chat.ID.String())
|
||||
|
||||
// The initial render contains both the chat title/content
|
||||
// and the status bar in a single frame. Their relative
|
||||
// order in the PTY byte stream depends on async title
|
||||
// generation, so matching them with separate sequential
|
||||
// expects is racy. Instead, just confirm the seed text is
|
||||
// visible (proving we are in the chat view), then verify
|
||||
// esc navigates back to the list.
|
||||
session.expect(ctx, "direct open seed")
|
||||
session.esc()
|
||||
session.expect(ctx, "enter: open")
|
||||
session.quit()
|
||||
require.NoError(t, session.wait(ctx))
|
||||
})
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
var terminalEscapeSequenceRegexp = regexp.MustCompile(
|
||||
`\x1b\[[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]|` +
|
||||
"" + `[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]|` +
|
||||
`\x1b\][^\x07\x1b]*(?:\x07|\x1b\\)|` +
|
||||
"" + `[^\x07\x1b]*(?:\x07|\x1b\\)|` +
|
||||
`\x1b[^\[\]].`,
|
||||
)
|
||||
|
||||
func sanitizeTerminalRenderableText(text string) string {
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
text = terminalEscapeSequenceRegexp.ReplaceAllString(text, "")
|
||||
return strings.Map(func(r rune) rune {
|
||||
switch r {
|
||||
case '\n', '\t':
|
||||
return r
|
||||
}
|
||||
if unicode.IsControl(r) {
|
||||
return -1
|
||||
}
|
||||
return r
|
||||
}, text)
|
||||
}
|
||||
@@ -1,483 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/spinner"
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
type (
|
||||
openSelectedChatMsg struct {
|
||||
chatID uuid.UUID
|
||||
}
|
||||
openDraftChatMsg struct{}
|
||||
refreshChatsMsg struct{}
|
||||
)
|
||||
|
||||
type chatDisplayRow struct {
|
||||
chat codersdk.Chat
|
||||
depth int
|
||||
isSubagent bool
|
||||
childCount int
|
||||
isExpanded bool
|
||||
}
|
||||
|
||||
type chatListModel struct {
|
||||
styles tuiStyles
|
||||
chats []codersdk.Chat
|
||||
expanded map[uuid.UUID]bool
|
||||
cursor int
|
||||
offset int
|
||||
loading bool
|
||||
err error
|
||||
search textinput.Model
|
||||
searching bool
|
||||
spinner spinner.Model
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
func newChatListModel(styles tuiStyles) chatListModel {
|
||||
search := textinput.New()
|
||||
search.Placeholder = "Search chats..."
|
||||
search.Prompt = "/ "
|
||||
|
||||
s := spinner.New()
|
||||
s.Spinner = spinner.Dot
|
||||
s.Style = styles.dimmedText
|
||||
|
||||
return chatListModel{
|
||||
styles: styles,
|
||||
expanded: make(map[uuid.UUID]bool),
|
||||
loading: true,
|
||||
search: search,
|
||||
spinner: s,
|
||||
}
|
||||
}
|
||||
|
||||
func (m chatListModel) searchQuery() string {
|
||||
return strings.TrimSpace(strings.ToLower(m.search.Value()))
|
||||
}
|
||||
|
||||
func (m chatListModel) filteredChats() []codersdk.Chat {
|
||||
query := m.searchQuery()
|
||||
if query == "" {
|
||||
return m.chats
|
||||
}
|
||||
|
||||
filtered := make([]codersdk.Chat, 0, len(m.chats))
|
||||
for _, chat := range m.chats {
|
||||
if strings.Contains(strings.ToLower(chat.Title), query) || strings.Contains(strings.ToLower(chat.ID.String()), query) {
|
||||
filtered = append(filtered, chat)
|
||||
continue
|
||||
}
|
||||
if chat.LastError != nil && strings.Contains(strings.ToLower(chat.LastError.Message), query) {
|
||||
filtered = append(filtered, chat)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (m chatListModel) displayRows() []chatDisplayRow {
|
||||
filtered := m.filteredChats()
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
queryActive := m.searchQuery() != ""
|
||||
chatsByID := make(map[uuid.UUID]codersdk.Chat, len(m.chats))
|
||||
included := make(map[uuid.UUID]struct{}, len(filtered))
|
||||
for _, chat := range m.chats {
|
||||
chatsByID[chat.ID] = chat
|
||||
}
|
||||
for _, chat := range filtered {
|
||||
included[chat.ID] = struct{}{}
|
||||
if !queryActive {
|
||||
continue
|
||||
}
|
||||
for parentID := chat.ParentChatID; parentID != nil; {
|
||||
parent, ok := chatsByID[*parentID]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
included[parent.ID] = struct{}{}
|
||||
parentID = parent.ParentChatID
|
||||
}
|
||||
}
|
||||
|
||||
childrenOf := make(map[uuid.UUID][]codersdk.Chat)
|
||||
roots := make([]codersdk.Chat, 0, len(included))
|
||||
for _, chat := range m.chats {
|
||||
if _, ok := included[chat.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
if chat.ParentChatID == nil {
|
||||
roots = append(roots, chat)
|
||||
continue
|
||||
}
|
||||
if _, ok := included[*chat.ParentChatID]; ok {
|
||||
childrenOf[*chat.ParentChatID] = append(childrenOf[*chat.ParentChatID], chat)
|
||||
}
|
||||
}
|
||||
|
||||
rows := make([]chatDisplayRow, 0, len(included))
|
||||
var appendRows func(codersdk.Chat, int)
|
||||
appendRows = func(chat codersdk.Chat, depth int) {
|
||||
children := childrenOf[chat.ID]
|
||||
isExpanded := m.expanded[chat.ID]
|
||||
if queryActive && len(children) > 0 {
|
||||
isExpanded = true
|
||||
}
|
||||
|
||||
rows = append(rows, chatDisplayRow{
|
||||
chat: chat,
|
||||
depth: depth,
|
||||
isSubagent: depth > 0,
|
||||
childCount: len(children),
|
||||
isExpanded: isExpanded,
|
||||
})
|
||||
if !isExpanded {
|
||||
return
|
||||
}
|
||||
for _, child := range children {
|
||||
appendRows(child, depth+1)
|
||||
}
|
||||
}
|
||||
|
||||
for _, root := range roots {
|
||||
appendRows(root, 0)
|
||||
}
|
||||
|
||||
return rows
|
||||
}
|
||||
|
||||
func (m chatListModel) selectedRow() (chatDisplayRow, bool) {
|
||||
rows := m.displayRows()
|
||||
if len(rows) == 0 || m.cursor < 0 || m.cursor >= len(rows) {
|
||||
return chatDisplayRow{}, false
|
||||
}
|
||||
return rows[m.cursor], true
|
||||
}
|
||||
|
||||
func (m *chatListModel) moveCursorToChat(chatID uuid.UUID) {
|
||||
rows := m.displayRows()
|
||||
for i, row := range rows {
|
||||
if row.chat.ID == chatID {
|
||||
m.cursor = i
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type chatExpansionIntent int
|
||||
|
||||
const (
|
||||
chatExpansionToggle chatExpansionIntent = iota
|
||||
chatExpansionExpand
|
||||
chatExpansionCollapse
|
||||
)
|
||||
|
||||
func (m *chatListModel) updateSelectedRowExpansion(intent chatExpansionIntent) bool {
|
||||
row, ok := m.selectedRow()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if row.childCount == 0 {
|
||||
if intent == chatExpansionExpand || row.chat.ParentChatID == nil {
|
||||
return false
|
||||
}
|
||||
parentID := *row.chat.ParentChatID
|
||||
m.expanded[parentID] = false
|
||||
m.moveCursorToChat(parentID)
|
||||
return true
|
||||
}
|
||||
|
||||
switch intent {
|
||||
case chatExpansionExpand:
|
||||
if row.isExpanded {
|
||||
return false
|
||||
}
|
||||
m.expanded[row.chat.ID] = true
|
||||
case chatExpansionCollapse:
|
||||
if row.isExpanded {
|
||||
m.expanded[row.chat.ID] = false
|
||||
return true
|
||||
}
|
||||
if row.chat.ParentChatID == nil || !m.expanded[*row.chat.ParentChatID] {
|
||||
return false
|
||||
}
|
||||
parentID := *row.chat.ParentChatID
|
||||
m.expanded[parentID] = false
|
||||
m.moveCursorToChat(parentID)
|
||||
return true
|
||||
case chatExpansionToggle:
|
||||
if row.isExpanded && !m.expanded[row.chat.ID] {
|
||||
return false
|
||||
}
|
||||
m.expanded[row.chat.ID] = !row.isExpanded
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (m chatListModel) selectedChat() *codersdk.Chat {
|
||||
row, ok := m.selectedRow()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return &row.chat
|
||||
}
|
||||
|
||||
func (m *chatListModel) normalizeCursor() {
|
||||
total := len(m.displayRows())
|
||||
if total == 0 {
|
||||
m.cursor = 0
|
||||
m.offset = 0
|
||||
return
|
||||
}
|
||||
m.cursor = min(max(m.cursor, 0), total-1)
|
||||
m.offset, _ = m.visibleWindow(total)
|
||||
}
|
||||
|
||||
func (m chatListModel) visibleChatCount() int {
|
||||
overhead := 3
|
||||
if m.searching {
|
||||
overhead += 2
|
||||
}
|
||||
|
||||
visibleCount := m.height - overhead
|
||||
if visibleCount < 3 {
|
||||
visibleCount = 3
|
||||
}
|
||||
return visibleCount
|
||||
}
|
||||
|
||||
func (m chatListModel) visibleWindow(total int) (start int, end int) {
|
||||
if total == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
visibleCount := m.visibleChatCount()
|
||||
maxOffset := max(total-visibleCount, 0)
|
||||
cursor := min(max(m.cursor, 0), total-1)
|
||||
start = min(max(min(max(m.offset, 0), maxOffset), cursor-visibleCount+1), cursor)
|
||||
end = min(start+visibleCount, total)
|
||||
return start, end
|
||||
}
|
||||
|
||||
func (m chatListModel) Init() tea.Cmd {
|
||||
return m.spinner.Tick
|
||||
}
|
||||
|
||||
func (m chatListModel) Update(msg tea.Msg) (chatListModel, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
m.height = msg.Height
|
||||
m.normalizeCursor()
|
||||
return m, nil
|
||||
|
||||
case spinner.TickMsg:
|
||||
if m.loading {
|
||||
m.spinner, cmd = m.spinner.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case chatsListedMsg:
|
||||
m.chats = msg.chats
|
||||
m.err = msg.err
|
||||
m.loading = false
|
||||
m.normalizeCursor()
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
key := msg.String()
|
||||
if m.searching {
|
||||
switch key {
|
||||
case "esc":
|
||||
if m.search.Value() != "" {
|
||||
m.search.SetValue("")
|
||||
}
|
||||
m.search.Blur()
|
||||
m.searching = false
|
||||
m.normalizeCursor()
|
||||
return m, nil
|
||||
case "enter":
|
||||
m.search.Blur()
|
||||
m.searching = false
|
||||
m.normalizeCursor()
|
||||
return m, nil
|
||||
default:
|
||||
m.search, cmd = m.search.Update(msg)
|
||||
m.normalizeCursor()
|
||||
m.offset = 0
|
||||
return m, cmd
|
||||
}
|
||||
}
|
||||
|
||||
navigationHandled, normalizeNavigation := true, true
|
||||
switch key {
|
||||
case "/", "ctrl+f":
|
||||
m.searching = true
|
||||
m.search.Focus()
|
||||
case "up", "k":
|
||||
m.cursor--
|
||||
case "down", "j":
|
||||
m.cursor++
|
||||
case "right", "l":
|
||||
normalizeNavigation = m.updateSelectedRowExpansion(chatExpansionExpand)
|
||||
case "left", "h":
|
||||
normalizeNavigation = m.updateSelectedRowExpansion(chatExpansionCollapse)
|
||||
case "x":
|
||||
normalizeNavigation = m.updateSelectedRowExpansion(chatExpansionToggle)
|
||||
default:
|
||||
navigationHandled = false
|
||||
}
|
||||
if navigationHandled {
|
||||
if normalizeNavigation {
|
||||
m.normalizeCursor()
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch key {
|
||||
case "enter":
|
||||
selected := m.selectedChat()
|
||||
if selected == nil {
|
||||
return m, nil
|
||||
}
|
||||
return m, func() tea.Msg {
|
||||
return openSelectedChatMsg{chatID: selected.ID}
|
||||
}
|
||||
case "n":
|
||||
return m, func() tea.Msg {
|
||||
return openDraftChatMsg{}
|
||||
}
|
||||
case "r":
|
||||
m.loading = true
|
||||
m.err = nil
|
||||
return m, func() tea.Msg {
|
||||
return refreshChatsMsg{}
|
||||
}
|
||||
case "q":
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m chatListModel) View() string {
|
||||
if m.loading {
|
||||
return m.spinner.View() + " Loading chats…"
|
||||
}
|
||||
|
||||
if m.err != nil {
|
||||
return m.styles.errorText.Render(m.err.Error()) + "\n" + m.styles.helpText.Render("Press r to retry")
|
||||
}
|
||||
|
||||
rows := m.displayRows()
|
||||
lines := make([]string, 0, len(rows)+3)
|
||||
if m.searching {
|
||||
lines = append(lines, m.styles.searchInput.Render(m.search.View()))
|
||||
}
|
||||
|
||||
if len(rows) == 0 {
|
||||
if strings.TrimSpace(m.search.Value()) != "" {
|
||||
lines = append(lines, m.styles.dimmedText.Render("No matches."))
|
||||
} else {
|
||||
lines = append(lines, m.styles.dimmedText.Render("No chats yet. Press n to start a new chat."))
|
||||
}
|
||||
help := fitHelpText(
|
||||
m.width,
|
||||
"/: search • n: new chat • r: refresh • q: quit",
|
||||
"/ search • n new • r refresh • q quit",
|
||||
"/ • n • r • q",
|
||||
)
|
||||
lines = append(lines, m.styles.helpText.Render(help))
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
statusWidth := 12
|
||||
start, end := m.visibleWindow(len(rows))
|
||||
for i := start; i < end; i++ {
|
||||
row := rows[i]
|
||||
rowPrefix := " "
|
||||
rowStyle := m.styles.normalItem
|
||||
if i == m.cursor {
|
||||
rowPrefix = "> "
|
||||
rowStyle = m.styles.selectedItem
|
||||
}
|
||||
if row.depth > 0 {
|
||||
rowPrefix += strings.Repeat(" ", row.depth)
|
||||
}
|
||||
if row.childCount > 0 {
|
||||
if row.isExpanded {
|
||||
rowPrefix += "▼ "
|
||||
} else {
|
||||
rowPrefix += "▶ "
|
||||
}
|
||||
}
|
||||
|
||||
extraText := ""
|
||||
extra := ""
|
||||
if row.childCount > 0 {
|
||||
extraText = fmt.Sprintf(" (%d subagents)", row.childCount)
|
||||
extra = m.styles.dimmedText.Render(extraText)
|
||||
}
|
||||
|
||||
titleWidth := max(m.width-statusWidth-18-len(rowPrefix)-len(extraText), 20)
|
||||
title := m.styles.truncate(sanitizeTerminalRenderableText(row.chat.Title), titleWidth)
|
||||
status := m.styles.statusColor(row.chat.Status).Render(string(row.chat.Status))
|
||||
rowText := fmt.Sprintf("%s%s %s %s%s", rowPrefix, rowStyle.Render(title), status, m.styles.dimmedText.Render(timeAgo(row.chat.UpdatedAt)), extra)
|
||||
lines = append(lines, rowText)
|
||||
|
||||
if row.chat.Status == codersdk.ChatStatusError && row.chat.LastError != nil && row.chat.LastError.Message != "" {
|
||||
lastError := row.chat.LastError.Message
|
||||
errWidth := max(m.width-4, 20)
|
||||
errPrefix := " "
|
||||
if row.depth > 0 {
|
||||
errPrefix += strings.Repeat(" ", row.depth)
|
||||
}
|
||||
lines = append(lines, errPrefix+m.styles.dimmedText.Render(m.styles.truncate(sanitizeTerminalRenderableText(lastError), errWidth)))
|
||||
}
|
||||
}
|
||||
|
||||
lines = append(lines, "")
|
||||
help := fitHelpText(
|
||||
m.width,
|
||||
"↑/k: up • ↓/j: down • →/l: expand • ←/h: collapse • x: toggle • enter: open • /: search • n: new chat • r: refresh • q: quit",
|
||||
"↑/k up • ↓/j down • →/l expand • ←/h collapse • x toggle • ↵ open • / search • n new • q quit",
|
||||
"↑↓ nav • →← fold • x toggle • ↵ open • / search • n new • q quit",
|
||||
"↑↓ • →← • x • ↵ • / • n • q",
|
||||
)
|
||||
lines = append(lines, m.styles.helpText.Render(help))
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func timeAgo(t time.Time) string {
|
||||
elapsed := time.Since(t)
|
||||
if elapsed < time.Minute {
|
||||
return "just now"
|
||||
}
|
||||
if elapsed < time.Hour {
|
||||
return fmt.Sprintf("%dm ago", int(elapsed/time.Minute))
|
||||
}
|
||||
if elapsed < 24*time.Hour {
|
||||
return fmt.Sprintf("%dh ago", int(elapsed/time.Hour))
|
||||
}
|
||||
return fmt.Sprintf("%dd ago", int(elapsed/(24*time.Hour)))
|
||||
}
|
||||
@@ -1,514 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
type tuiView int
|
||||
|
||||
const (
|
||||
viewList tuiView = iota
|
||||
viewChat
|
||||
)
|
||||
|
||||
type tuiOverlay int
|
||||
|
||||
const (
|
||||
overlayNone tuiOverlay = iota
|
||||
overlayModelPicker
|
||||
overlayDiffDrawer
|
||||
overlayAskUserQuestion
|
||||
)
|
||||
|
||||
type (
|
||||
terminateTUIMsg struct{}
|
||||
chatsTUIModel struct {
|
||||
ctx context.Context
|
||||
client *codersdk.ExperimentalClient
|
||||
styles tuiStyles
|
||||
currentView tuiView
|
||||
overlay tuiOverlay
|
||||
list chatListModel
|
||||
chat chatViewModel
|
||||
initialChatID *uuid.UUID
|
||||
workspaceID *uuid.UUID
|
||||
modelOverride *string
|
||||
organizationID uuid.UUID
|
||||
chatGeneration uint64
|
||||
catalog *codersdk.ChatModelsResponse
|
||||
quitting bool
|
||||
width int
|
||||
height int
|
||||
}
|
||||
)
|
||||
|
||||
func newChatsTUIModel(
|
||||
ctx context.Context,
|
||||
client *codersdk.ExperimentalClient,
|
||||
initialChatID *uuid.UUID,
|
||||
workspaceID *uuid.UUID,
|
||||
modelOverride *string,
|
||||
organizationID uuid.UUID,
|
||||
) chatsTUIModel {
|
||||
styles := newTUIStyles()
|
||||
currentView := viewList
|
||||
if initialChatID != nil {
|
||||
currentView = viewChat
|
||||
}
|
||||
chat := newChatViewModel(ctx, client, workspaceID, modelOverride, organizationID, styles)
|
||||
chatGeneration := uint64(0)
|
||||
if initialChatID != nil {
|
||||
chat.activeChatID = *initialChatID
|
||||
chat.chatGeneration = 1
|
||||
chat.loading = true
|
||||
chat.metadataResolved = false
|
||||
chat.historyResolved = false
|
||||
chatGeneration = 1
|
||||
}
|
||||
return chatsTUIModel{
|
||||
ctx: ctx,
|
||||
client: client,
|
||||
styles: styles,
|
||||
currentView: currentView,
|
||||
overlay: overlayNone,
|
||||
list: newChatListModel(styles),
|
||||
chat: chat,
|
||||
initialChatID: initialChatID,
|
||||
workspaceID: workspaceID,
|
||||
modelOverride: modelOverride,
|
||||
organizationID: organizationID,
|
||||
chatGeneration: chatGeneration,
|
||||
}
|
||||
}
|
||||
|
||||
// resetChatSession creates a fresh chatViewModel, preserves the
|
||||
// window dimensions from the previous session, and advances
|
||||
// the monotonic generation counter so in-flight async messages
|
||||
// from the old session are ignored.
|
||||
func (m *chatsTUIModel) resetChatSession() {
|
||||
old := m.chat
|
||||
m.chat = newChatViewModel(m.ctx, m.client, m.workspaceID, m.modelOverride, m.organizationID, m.styles)
|
||||
m.chat.width = old.width
|
||||
m.chat.height = old.height
|
||||
m.chat.loading = true
|
||||
m.chat.metadataResolved = false
|
||||
m.chat.historyResolved = false
|
||||
m.chatGeneration++
|
||||
m.chat.chatGeneration = m.chatGeneration
|
||||
}
|
||||
|
||||
func (m *chatsTUIModel) setRenderer(renderer *lipgloss.Renderer) {
|
||||
styles := newTUIStyles(renderer)
|
||||
m.styles = styles
|
||||
m.list.styles = styles
|
||||
m.list.spinner.Style = styles.dimmedText
|
||||
m.chat.styles = styles
|
||||
m.chat.spinner.Style = styles.dimmedText
|
||||
}
|
||||
|
||||
func (m chatsTUIModel) Init() tea.Cmd {
|
||||
if m.initialChatID != nil {
|
||||
m.chat.activeChatID = *m.initialChatID
|
||||
return tea.Batch(append([]tea.Cmd{m.chat.Init()}, m.loadChatCmd(*m.initialChatID, m.chat.chatGeneration)...)...)
|
||||
}
|
||||
return tea.Batch(m.loadChatsCmd(), m.list.Init())
|
||||
}
|
||||
|
||||
func (m chatsTUIModel) loadChatsCmd() tea.Cmd {
|
||||
return apiCmd(func() ([]codersdk.Chat, error) { return m.client.ListChats(m.ctx, nil) }, func(chats []codersdk.Chat, err error) tea.Msg { return chatsListedMsg{chats: chats, err: err} })
|
||||
}
|
||||
|
||||
func (m chatsTUIModel) loadChatCmd(chatID uuid.UUID, generation uint64) []tea.Cmd {
|
||||
return []tea.Cmd{apiCmd(func() (codersdk.Chat, error) { return m.client.GetChat(m.ctx, chatID) }, func(chat codersdk.Chat, err error) tea.Msg {
|
||||
return chatOpenedMsg{generation: generation, chatID: chatID, chat: chat, err: err}
|
||||
}), loadChatHistoryCmd(m.ctx, m.client, chatID, generation)}
|
||||
}
|
||||
|
||||
func (m chatsTUIModel) childWindowSizeMsg() tea.WindowSizeMsg {
|
||||
h := m.height
|
||||
if m.currentView == viewList {
|
||||
h = max(0, h-1)
|
||||
}
|
||||
return tea.WindowSizeMsg{Width: m.width, Height: h}
|
||||
}
|
||||
|
||||
func (m *chatsTUIModel) toggleOverlay(overlay tuiOverlay) bool {
|
||||
if m.overlay == overlay {
|
||||
m.overlay = overlayNone
|
||||
return false
|
||||
}
|
||||
m.overlay = overlay
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *chatsTUIModel) handleEsc(msg tea.KeyMsg) tea.Cmd {
|
||||
if m.currentView == viewList && m.list.searching {
|
||||
var cmd tea.Cmd
|
||||
m.list, cmd = m.list.Update(msg)
|
||||
return cmd
|
||||
}
|
||||
if m.currentView == viewChat {
|
||||
m.chatGeneration++
|
||||
m.chat.chatGeneration = m.chatGeneration
|
||||
m.chat.stopStream()
|
||||
m.currentView = viewList
|
||||
m.list.loading = true
|
||||
return m.loadChatsCmd()
|
||||
}
|
||||
m.quitting = true
|
||||
return tea.Quit
|
||||
}
|
||||
|
||||
func isOverlayCloseKey(msg tea.KeyMsg) bool {
|
||||
if msg.Type == tea.KeyEsc || msg.Type == tea.KeyEscape {
|
||||
return true
|
||||
}
|
||||
|
||||
key := msg.String()
|
||||
return key == "esc" || key == "ctrl+["
|
||||
}
|
||||
|
||||
func (m *chatsTUIModel) handleModelPickerKey(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "up", "k":
|
||||
if m.chat.modelPickerCursor > 0 {
|
||||
m.chat.modelPickerCursor--
|
||||
}
|
||||
case "down", "j":
|
||||
if m.chat.modelPickerCursor < len(m.chat.modelPickerFlat)-1 {
|
||||
m.chat.modelPickerCursor++
|
||||
}
|
||||
case "enter":
|
||||
if len(m.chat.modelPickerFlat) > 0 && m.chat.modelPickerCursor < len(m.chat.modelPickerFlat) {
|
||||
selected := m.chat.modelPickerFlat[m.chat.modelPickerCursor]
|
||||
m.chat.modelOverride = &selected.ID
|
||||
m.modelOverride = &selected.ID
|
||||
m.overlay = overlayNone
|
||||
}
|
||||
case "ctrl+p", "q":
|
||||
m.overlay = overlayNone
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *chatsTUIModel) handleAskUserQuestionKey(msg tea.KeyMsg) tea.Cmd {
|
||||
state := m.chat.pendingAskUserQuestion
|
||||
if state == nil || state.Submitting || len(state.Questions) == 0 {
|
||||
return nil
|
||||
}
|
||||
if state.CurrentIndex < 0 || state.CurrentIndex >= len(state.Questions) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if state.OtherMode {
|
||||
switch msg.Type {
|
||||
case tea.KeyEsc:
|
||||
state.OtherMode = false
|
||||
state.OtherInput.Blur()
|
||||
return nil
|
||||
case tea.KeyEnter:
|
||||
answer := strings.TrimSpace(state.OtherInput.Value())
|
||||
if answer == "" {
|
||||
return nil
|
||||
}
|
||||
return m.recordAskAnswer(answer, "", true)
|
||||
default:
|
||||
var cmd tea.Cmd
|
||||
state.OtherInput, cmd = state.OtherInput.Update(msg)
|
||||
return cmd
|
||||
}
|
||||
}
|
||||
|
||||
question := state.Questions[state.CurrentIndex]
|
||||
optionCount := len(question.Options) + 1
|
||||
switch msg.String() {
|
||||
case "up", "k":
|
||||
state.OptionCursor--
|
||||
if state.OptionCursor < 0 {
|
||||
state.OptionCursor = optionCount - 1
|
||||
}
|
||||
case "down", "j":
|
||||
state.OptionCursor++
|
||||
if state.OptionCursor >= optionCount {
|
||||
state.OptionCursor = 0
|
||||
}
|
||||
case "left", "h":
|
||||
if state.CurrentIndex == 0 {
|
||||
return nil
|
||||
}
|
||||
state.CurrentIndex--
|
||||
state.OptionCursor = 0
|
||||
state.OtherMode = false
|
||||
state.OtherInput.Blur()
|
||||
state.Error = nil
|
||||
if len(state.Answers) > state.CurrentIndex {
|
||||
state.Answers = state.Answers[:state.CurrentIndex]
|
||||
}
|
||||
case "enter":
|
||||
state.Error = nil
|
||||
if state.OptionCursor < len(question.Options) {
|
||||
option := question.Options[state.OptionCursor]
|
||||
answer := strings.TrimSpace(option.Value)
|
||||
if answer == "" {
|
||||
answer = option.Label
|
||||
}
|
||||
return m.recordAskAnswer(answer, option.Label, false)
|
||||
}
|
||||
state.OtherMode = true
|
||||
state.OtherInput.SetValue("")
|
||||
state.OtherInput.Focus()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *chatsTUIModel) recordAskAnswer(answer, optionLabel string, freeform bool) tea.Cmd {
|
||||
state := m.chat.pendingAskUserQuestion
|
||||
if state == nil || len(state.Questions) == 0 {
|
||||
return nil
|
||||
}
|
||||
if state.CurrentIndex < 0 || state.CurrentIndex >= len(state.Questions) {
|
||||
return nil
|
||||
}
|
||||
|
||||
question := state.Questions[state.CurrentIndex]
|
||||
if len(state.Answers) > state.CurrentIndex {
|
||||
state.Answers = state.Answers[:state.CurrentIndex]
|
||||
}
|
||||
|
||||
state.Answers = append(state.Answers, askQuestionAnswer{
|
||||
Header: question.Header,
|
||||
Question: question.Question,
|
||||
Answer: answer,
|
||||
OptionLabel: optionLabel,
|
||||
Freeform: freeform,
|
||||
})
|
||||
state.OtherMode = false
|
||||
state.OtherInput.Blur()
|
||||
state.OtherInput.SetValue("")
|
||||
state.OptionCursor = 0
|
||||
state.Error = nil
|
||||
|
||||
if state.CurrentIndex+1 < len(state.Questions) {
|
||||
state.CurrentIndex++
|
||||
return nil
|
||||
}
|
||||
|
||||
state.Submitting = true
|
||||
return submitAskUserQuestionCmd(m.client.Client, m.chat.activeChatID, m.chat.chatGeneration, state)
|
||||
}
|
||||
|
||||
func (m *chatsTUIModel) openChatCmd(chatID *uuid.UUID) tea.Cmd {
|
||||
m.currentView = viewChat
|
||||
m.chat.stopStream()
|
||||
m.resetChatSession()
|
||||
if chatID == nil {
|
||||
m.chat.draft = true
|
||||
m.chat.loading = false
|
||||
m.chat.metadataResolved = true
|
||||
m.chat.historyResolved = true
|
||||
m.chat, _ = m.chat.Update(m.childWindowSizeMsg())
|
||||
return nil
|
||||
}
|
||||
m.chat.activeChatID = *chatID
|
||||
m.chat, _ = m.chat.Update(m.childWindowSizeMsg())
|
||||
return tea.Batch(append([]tea.Cmd{m.chat.Init()}, m.loadChatCmd(*chatID, m.chat.chatGeneration)...)...)
|
||||
}
|
||||
|
||||
func (m *chatsTUIModel) toggleModelPickerCmd() tea.Cmd {
|
||||
if !m.toggleOverlay(overlayModelPicker) {
|
||||
return nil
|
||||
}
|
||||
if m.catalog == nil {
|
||||
return apiCmd(func() (codersdk.ChatModelsResponse, error) { return m.client.ListChatModels(m.ctx) }, func(catalog codersdk.ChatModelsResponse, err error) tea.Msg {
|
||||
return modelsListedMsg{catalog: catalog, err: err}
|
||||
})
|
||||
}
|
||||
if len(m.chat.modelPickerFlat) == 0 {
|
||||
m.chat.modelPickerFlat = availableChatModels(*m.catalog)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *chatsTUIModel) toggleDiffDrawerCmd() tea.Cmd {
|
||||
if m.chat.chat == nil {
|
||||
return nil
|
||||
}
|
||||
if !m.toggleOverlay(overlayDiffDrawer) {
|
||||
return nil
|
||||
}
|
||||
if m.chat.diffContents == nil || m.chat.diffErr != nil {
|
||||
m.chat.diffErr = nil
|
||||
chatID := m.chat.chat.ID
|
||||
generation := m.chat.chatGeneration
|
||||
return apiCmd(func() (codersdk.ChatDiffContents, error) { return fetchChatDiffContents(m.ctx, m.client, chatID) }, func(diff codersdk.ChatDiffContents, err error) tea.Msg {
|
||||
return diffContentsMsg{generation: generation, chatID: chatID, diff: diff, err: err}
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m chatsTUIModel) updateChild(msg tea.Msg, view tuiView) (chatsTUIModel, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
if view == viewChat {
|
||||
m.chat, cmd = m.chat.Update(msg)
|
||||
} else {
|
||||
m.list, cmd = m.list.Update(msg)
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m chatsTUIModel) renderOverlay(title, body string) string {
|
||||
return renderOverlayFrame(m.styles, m.width, m.styles.title.Render(title), body, m.styles.helpText.Render("Esc to close"))
|
||||
}
|
||||
|
||||
func (m chatsTUIModel) diffOverlayView() string {
|
||||
switch {
|
||||
case m.chat.diffErr != nil:
|
||||
return m.renderOverlay("Diff", m.styles.errorText.Render(wrapPreservingNewlines(m.chat.diffErr.Error(), contentWidth(m.width, 6))))
|
||||
case m.chat.diffContents != nil:
|
||||
return renderDiffDrawer(m.styles, *m.chat.diffContents, m.chat.diffSummary, m.chat.diffStyledBody, m.width, m.height)
|
||||
default:
|
||||
return m.renderOverlay("Diff", m.styles.dimmedText.Render("Loading diff…"))
|
||||
}
|
||||
}
|
||||
|
||||
func padViewHeight(text string, height int) string {
|
||||
if height <= 0 {
|
||||
return text
|
||||
}
|
||||
if text == "" {
|
||||
return strings.Repeat("\n", max(height-1, 0))
|
||||
}
|
||||
lineCount := countRenderedLines(text)
|
||||
if lineCount >= height {
|
||||
return text
|
||||
}
|
||||
return text + strings.Repeat("\n", height-lineCount)
|
||||
}
|
||||
|
||||
func (m chatsTUIModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
m.height = msg.Height
|
||||
childMsg := m.childWindowSizeMsg()
|
||||
m.list, _ = m.list.Update(childMsg)
|
||||
m.chat, _ = m.chat.Update(childMsg)
|
||||
return m, nil
|
||||
case terminateTUIMsg:
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
case tea.KeyMsg:
|
||||
if msg.Type == tea.KeyCtrlC {
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
// Handle overlays first so their keys do not leak to the underlying
|
||||
// view.
|
||||
if m.overlay == overlayAskUserQuestion {
|
||||
return m, m.handleAskUserQuestionKey(msg)
|
||||
}
|
||||
if m.overlay == overlayModelPicker {
|
||||
if isOverlayCloseKey(msg) {
|
||||
m.overlay = overlayNone
|
||||
return m, tea.ClearScreen
|
||||
}
|
||||
cmd := m.handleModelPickerKey(msg)
|
||||
if m.overlay == overlayNone {
|
||||
return m, tea.Batch(cmd, tea.ClearScreen)
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
if m.overlay == overlayDiffDrawer {
|
||||
if isOverlayCloseKey(msg) {
|
||||
m.overlay = overlayNone
|
||||
return m, tea.ClearScreen
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
if msg.String() == "esc" {
|
||||
return m, m.handleEsc(msg)
|
||||
}
|
||||
case openSelectedChatMsg:
|
||||
return m, m.openChatCmd(&msg.chatID)
|
||||
case openDraftChatMsg:
|
||||
return m, m.openChatCmd(nil)
|
||||
case refreshChatsMsg:
|
||||
return m, m.loadChatsCmd()
|
||||
case toggleModelPickerMsg:
|
||||
return m, m.toggleModelPickerCmd()
|
||||
case toggleDiffDrawerMsg:
|
||||
return m, m.toggleDiffDrawerCmd()
|
||||
case showAskUserQuestionMsg:
|
||||
m.chat.pendingAskUserQuestion = msg.state
|
||||
m.overlay = overlayAskUserQuestion
|
||||
return m.updateChild(msg, viewChat)
|
||||
case hideAskUserQuestionMsg:
|
||||
if m.overlay == overlayAskUserQuestion {
|
||||
m.overlay = overlayNone
|
||||
}
|
||||
return m.updateChild(msg, viewChat)
|
||||
case toolResultsSubmittedMsg:
|
||||
if msg.err == nil && m.chat.matchesGeneration(msg.generation) && msg.chatID == m.chat.activeChatID {
|
||||
m.chat.pendingAskUserQuestion = nil
|
||||
if m.overlay == overlayAskUserQuestion {
|
||||
m.overlay = overlayNone
|
||||
}
|
||||
}
|
||||
return m.updateChild(msg, viewChat)
|
||||
case chatsListedMsg:
|
||||
return m.updateChild(msg, viewList)
|
||||
case chatOpenedMsg, chatHistoryMsg, chatStreamEventMsg, messageSentMsg, chatCreatedMsg, chatInterruptedMsg, diffContentsMsg:
|
||||
return m.updateChild(msg, viewChat)
|
||||
case modelsListedMsg:
|
||||
if msg.err != nil {
|
||||
m.overlay = overlayNone
|
||||
} else {
|
||||
catalog := msg.catalog
|
||||
m.catalog = &catalog
|
||||
}
|
||||
return m.updateChild(msg, viewChat)
|
||||
}
|
||||
return m.updateChild(msg, m.currentView)
|
||||
}
|
||||
|
||||
func (m chatsTUIModel) View() string {
|
||||
if m.quitting {
|
||||
return ""
|
||||
}
|
||||
|
||||
var base string
|
||||
if m.currentView == viewChat {
|
||||
base = m.chat.View()
|
||||
} else {
|
||||
base = m.styles.title.Render("Coder Chats") + "\n" + m.list.View()
|
||||
}
|
||||
|
||||
switch m.overlay {
|
||||
case overlayAskUserQuestion:
|
||||
if m.chat.pendingAskUserQuestion != nil {
|
||||
base += "\n" + renderAskUserQuestion(m.styles, m.chat.pendingAskUserQuestion, m.width, m.height)
|
||||
}
|
||||
case overlayModelPicker:
|
||||
if m.catalog == nil {
|
||||
base += "\n" + m.renderOverlay("Select Model", m.styles.dimmedText.Render("Loading models..."))
|
||||
break
|
||||
}
|
||||
selectedID := ""
|
||||
if m.chat.modelOverride != nil {
|
||||
selectedID = *m.chat.modelOverride
|
||||
}
|
||||
base += "\n" + renderModelPicker(m.styles, *m.catalog, selectedID, m.chat.modelPickerCursor, m.width, m.height)
|
||||
case overlayDiffDrawer:
|
||||
base += "\n" + m.diffOverlayView()
|
||||
}
|
||||
return padViewHeight(base, m.height)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,131 +0,0 @@
|
||||
package cli //nolint:testpackage // Tests unexported chat stream helpers.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
type chatWatchWriters struct{ stdout, stderr io.Writer }
|
||||
|
||||
func (w chatWatchWriters) Write(p []byte) (int, error) { return w.stdout.Write(p) }
|
||||
|
||||
func (w chatWatchWriters) Stderr() io.Writer {
|
||||
if w.stderr != nil {
|
||||
return w.stderr
|
||||
}
|
||||
return w.stdout
|
||||
}
|
||||
|
||||
func consumeChatStream(eventCh <-chan codersdk.ChatStreamEvent, out io.Writer) error {
|
||||
errOut := out
|
||||
if writer, ok := out.(interface{ Stderr() io.Writer }); ok {
|
||||
errOut = writer.Stderr()
|
||||
}
|
||||
|
||||
printedInline := false
|
||||
flush := func() error {
|
||||
if !printedInline {
|
||||
return nil
|
||||
}
|
||||
printedInline = false
|
||||
_, err := fmt.Fprintln(out)
|
||||
return err
|
||||
}
|
||||
|
||||
printLine := func(dst io.Writer, format string, args ...any) error {
|
||||
if err := flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := fmt.Fprintf(dst, format, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
for event := range eventCh {
|
||||
var err error
|
||||
switch event.Type {
|
||||
case codersdk.ChatStreamEventTypeMessagePart:
|
||||
if part := event.MessagePart; part != nil &&
|
||||
part.Part.Type == codersdk.ChatMessagePartTypeText && part.Part.Text != "" {
|
||||
printedInline = true
|
||||
_, err = fmt.Fprint(out, part.Part.Text)
|
||||
}
|
||||
case codersdk.ChatStreamEventTypeMessage:
|
||||
if message := event.Message; message != nil && !printedInline {
|
||||
for _, part := range message.Content {
|
||||
if part.Type != codersdk.ChatMessagePartTypeText || part.Text == "" {
|
||||
continue
|
||||
}
|
||||
printedInline = true
|
||||
if _, err = fmt.Fprint(out, part.Text); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
err = flush()
|
||||
}
|
||||
case codersdk.ChatStreamEventTypeStatus:
|
||||
if event.Status == nil {
|
||||
err = flush()
|
||||
break
|
||||
}
|
||||
err = printLine(out, "[Status: %s]\n", event.Status.Status)
|
||||
case codersdk.ChatStreamEventTypeError:
|
||||
if event.Error == nil {
|
||||
err = flush()
|
||||
break
|
||||
}
|
||||
err = printLine(errOut, "[Error: %s]\n", event.Error.Message)
|
||||
case codersdk.ChatStreamEventTypeRetry:
|
||||
if event.Retry == nil {
|
||||
err = flush()
|
||||
break
|
||||
}
|
||||
err = printLine(out, "[Retry attempt %d after error: %s]\n", event.Retry.Attempt, event.Retry.Error)
|
||||
case codersdk.ChatStreamEventTypeQueueUpdate:
|
||||
default:
|
||||
err = printLine(out, "[Event: %s]\n", event.Type)
|
||||
}
|
||||
if err != nil {
|
||||
return xerrors.Errorf("render chat stream event: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := flush(); err != nil {
|
||||
return xerrors.Errorf("flush chat stream output: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConsumeChatStreamText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
events := make(chan codersdk.ChatStreamEvent, 7)
|
||||
for _, event := range []codersdk.ChatStreamEvent{
|
||||
{Type: codersdk.ChatStreamEventTypeMessagePart, MessagePart: &codersdk.ChatStreamMessagePart{Role: codersdk.ChatMessageRoleAssistant, Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "Hello"}}},
|
||||
{Type: codersdk.ChatStreamEventTypeMessagePart, MessagePart: &codersdk.ChatStreamMessagePart{Role: codersdk.ChatMessageRoleAssistant, Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolCall, Text: "ignored"}}},
|
||||
{Type: codersdk.ChatStreamEventTypeMessagePart, MessagePart: &codersdk.ChatStreamMessagePart{Role: codersdk.ChatMessageRoleAssistant, Part: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: " world"}}},
|
||||
{Type: codersdk.ChatStreamEventTypeMessage, Message: &codersdk.ChatMessage{ID: 1, ChatID: uuid.New(), Role: codersdk.ChatMessageRoleAssistant, Content: []codersdk.ChatMessagePart{{Type: codersdk.ChatMessagePartTypeText, Text: "Hello world"}}}},
|
||||
{Type: codersdk.ChatStreamEventTypeStatus, Status: &codersdk.ChatStreamStatus{Status: codersdk.ChatStatusRunning}},
|
||||
{Type: codersdk.ChatStreamEventTypeRetry, Retry: &codersdk.ChatStreamRetry{Attempt: 2, Error: "rate limited"}},
|
||||
{Type: codersdk.ChatStreamEventTypeError, Error: &codersdk.ChatError{Message: "boom"}},
|
||||
} {
|
||||
events <- event
|
||||
}
|
||||
close(events)
|
||||
|
||||
var stdout bytes.Buffer
|
||||
var stderr bytes.Buffer
|
||||
err := consumeChatStream(events, chatWatchWriters{stdout: &stdout, stderr: &stderr})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Hello world\n[Status: running]\n[Retry attempt 2 after error: rate limited]\n", stdout.String())
|
||||
require.Equal(t, "[Error: boom]\n", stderr.String())
|
||||
}
|
||||
@@ -1,98 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
type tuiStyles struct {
|
||||
title lipgloss.Style
|
||||
subtitle lipgloss.Style
|
||||
statusBar lipgloss.Style
|
||||
statusBadge lipgloss.Style
|
||||
selectedItem lipgloss.Style
|
||||
selectedBlock lipgloss.Style
|
||||
normalItem lipgloss.Style
|
||||
dimmedText lipgloss.Style
|
||||
errorText lipgloss.Style
|
||||
searchInput lipgloss.Style
|
||||
separator lipgloss.Style
|
||||
helpText lipgloss.Style
|
||||
modeBadgeExec lipgloss.Style
|
||||
modeBadgePlan lipgloss.Style
|
||||
userMessage lipgloss.Style
|
||||
assistantMsg lipgloss.Style
|
||||
reasoning lipgloss.Style
|
||||
toolCallStyle lipgloss.Style
|
||||
toolPending lipgloss.Style
|
||||
toolSuccess lipgloss.Style
|
||||
compaction lipgloss.Style
|
||||
warningText lipgloss.Style
|
||||
criticalText lipgloss.Style
|
||||
overlayBorder lipgloss.Style
|
||||
composerStyle lipgloss.Style
|
||||
}
|
||||
|
||||
func newTUIStyles(renderers ...*lipgloss.Renderer) tuiStyles {
|
||||
renderer := lipgloss.DefaultRenderer()
|
||||
if len(renderers) > 0 && renderers[0] != nil {
|
||||
renderer = renderers[0]
|
||||
}
|
||||
|
||||
return tuiStyles{
|
||||
title: renderer.NewStyle().Bold(true),
|
||||
subtitle: renderer.NewStyle().Faint(true),
|
||||
statusBar: renderer.NewStyle(),
|
||||
statusBadge: renderer.NewStyle().Padding(0, 1),
|
||||
selectedItem: renderer.NewStyle().Bold(true),
|
||||
selectedBlock: renderer.NewStyle().
|
||||
BorderLeft(true).
|
||||
BorderStyle(lipgloss.NormalBorder()).
|
||||
BorderForeground(lipgloss.AdaptiveColor{Light: "63", Dark: "63"}).
|
||||
PaddingLeft(1),
|
||||
normalItem: renderer.NewStyle(),
|
||||
dimmedText: renderer.NewStyle().Faint(true),
|
||||
errorText: renderer.NewStyle().Foreground(lipgloss.Color("1")),
|
||||
searchInput: renderer.NewStyle().
|
||||
BorderStyle(lipgloss.NormalBorder()).
|
||||
BorderBottom(true),
|
||||
separator: renderer.NewStyle().Faint(true),
|
||||
helpText: renderer.NewStyle().Faint(true),
|
||||
modeBadgeExec: renderer.NewStyle().Bold(true).Foreground(lipgloss.AdaptiveColor{Light: "22", Dark: "42"}),
|
||||
modeBadgePlan: renderer.NewStyle().Bold(true).Foreground(lipgloss.AdaptiveColor{Light: "130", Dark: "214"}),
|
||||
userMessage: renderer.NewStyle().Bold(true).Foreground(lipgloss.Color("6")),
|
||||
assistantMsg: renderer.NewStyle(),
|
||||
reasoning: renderer.NewStyle().Faint(true).Italic(true),
|
||||
toolCallStyle: renderer.NewStyle().Foreground(lipgloss.Color("3")),
|
||||
toolPending: renderer.NewStyle().Faint(true).Foreground(lipgloss.Color("3")),
|
||||
toolSuccess: renderer.NewStyle().Foreground(lipgloss.Color("2")),
|
||||
compaction: renderer.NewStyle().Bold(true).Foreground(lipgloss.Color("5")),
|
||||
warningText: renderer.NewStyle().Foreground(lipgloss.Color("3")),
|
||||
criticalText: renderer.NewStyle().Foreground(lipgloss.Color("1")).Bold(true),
|
||||
overlayBorder: renderer.NewStyle().BorderStyle(lipgloss.RoundedBorder()).Padding(1),
|
||||
composerStyle: renderer.NewStyle().BorderStyle(lipgloss.NormalBorder()).BorderTop(true),
|
||||
}
|
||||
}
|
||||
|
||||
func (s tuiStyles) statusColor(status codersdk.ChatStatus) lipgloss.Style {
|
||||
color := lipgloss.Color("7")
|
||||
switch status {
|
||||
case codersdk.ChatStatusWaiting, codersdk.ChatStatusPending:
|
||||
color = lipgloss.Color("3")
|
||||
case codersdk.ChatStatusRunning:
|
||||
color = lipgloss.Color("4")
|
||||
case codersdk.ChatStatusPaused:
|
||||
color = lipgloss.Color("5")
|
||||
case codersdk.ChatStatusCompleted:
|
||||
color = lipgloss.Color("2")
|
||||
case codersdk.ChatStatusError:
|
||||
color = lipgloss.Color("1")
|
||||
}
|
||||
return s.statusBadge.Foreground(color)
|
||||
}
|
||||
|
||||
func (s tuiStyles) truncate(text string, maxWidth int) string {
|
||||
_ = s
|
||||
return truncateText(text, maxWidth, "", 3)
|
||||
}
|
||||
-3331
File diff suppressed because it is too large
Load Diff
@@ -100,7 +100,6 @@ const (
|
||||
func (r *RootCmd) CoreSubcommands() []*serpent.Command {
|
||||
// Please re-sort this list alphabetically if you change it!
|
||||
return []*serpent.Command{
|
||||
r.agentsCommand(),
|
||||
r.completion(),
|
||||
r.dotfiles(),
|
||||
externalAuth(),
|
||||
|
||||
Vendored
-1
@@ -14,7 +14,6 @@ USAGE:
|
||||
$ coder templates init
|
||||
|
||||
SUBCOMMANDS:
|
||||
agents Interactive terminal UI for AI agents.
|
||||
autoupdate Toggle auto-update policy for a workspace
|
||||
completion Install or update shell completion scripts for the
|
||||
detected or chosen shell.
|
||||
|
||||
-16
@@ -1,16 +0,0 @@
|
||||
coder v0.0.0-devel
|
||||
|
||||
USAGE:
|
||||
coder agents [flags] [chat-id]
|
||||
|
||||
Interactive terminal UI for AI agents.
|
||||
|
||||
OPTIONS:
|
||||
--model string
|
||||
Choose a model by ID, provider/model, or display name.
|
||||
|
||||
--workspace string
|
||||
Associate the chat with a workspace by name, owner/name, or UUID.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
Reference in New Issue
Block a user