mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: chat desktop backend (#23005)
Implement the backend for the desktop feature for agents. - Adds a new `/api/experimental/chats/$id/desktop` endpoint to coderd which exposes a VNC stream from a [portabledesktop](https://github.com/coder/portabledesktop) process running inside the workspace - Adds a new `spawn_computer_use_agent` tool to chatd, which spawns a subagent that has access to the `computer` tool which lets it interact with the `portabledesktop` process running inside the workspace - Adds the plumbing to make the above possible There's a follow up frontend PR here: https://github.com/coder/coder/pull/23006
This commit is contained in:
+10
-1
@@ -39,6 +39,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/clistat"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentdesktop"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentfiles"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
@@ -310,6 +311,7 @@ type agent struct {
|
||||
filesAPI *agentfiles.API
|
||||
gitAPI *agentgit.API
|
||||
processAPI *agentproc.API
|
||||
desktopAPI *agentdesktop.API
|
||||
|
||||
socketServerEnabled bool
|
||||
socketPath string
|
||||
@@ -386,7 +388,10 @@ func (a *agent) init() {
|
||||
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore)
|
||||
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
|
||||
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
|
||||
|
||||
desktop := agentdesktop.NewPortableDesktop(
|
||||
a.logger.Named("desktop"), a.execer, a.scriptDataDir,
|
||||
)
|
||||
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
|
||||
a.reconnectingPTYServer = reconnectingpty.NewServer(
|
||||
a.logger.Named("reconnecting-pty"),
|
||||
a.sshServer,
|
||||
@@ -2057,6 +2062,10 @@ func (a *agent) Close() error {
|
||||
a.logger.Error(a.hardCtx, "process API close", slog.Error(err))
|
||||
}
|
||||
|
||||
if err := a.desktopAPI.Close(); err != nil {
|
||||
a.logger.Error(a.hardCtx, "desktop API close", slog.Error(err))
|
||||
}
|
||||
|
||||
if a.boundaryLogProxy != nil {
|
||||
err = a.boundaryLogProxy.Close()
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,536 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// DesktopAction is the request body for the desktop action endpoint.
|
||||
type DesktopAction struct {
|
||||
Action string `json:"action"`
|
||||
Coordinate *[2]int `json:"coordinate,omitempty"`
|
||||
StartCoordinate *[2]int `json:"start_coordinate,omitempty"`
|
||||
Text *string `json:"text,omitempty"`
|
||||
Duration *int `json:"duration,omitempty"`
|
||||
ScrollAmount *int `json:"scroll_amount,omitempty"`
|
||||
ScrollDirection *string `json:"scroll_direction,omitempty"`
|
||||
// ScaledWidth and ScaledHeight are the coordinate space the
|
||||
// model is using. When provided, coordinates are linearly
|
||||
// mapped from scaled → native before dispatching.
|
||||
ScaledWidth *int `json:"scaled_width,omitempty"`
|
||||
ScaledHeight *int `json:"scaled_height,omitempty"`
|
||||
}
|
||||
|
||||
// DesktopActionResponse is the response from the desktop action
|
||||
// endpoint.
|
||||
type DesktopActionResponse struct {
|
||||
Output string `json:"output,omitempty"`
|
||||
ScreenshotData string `json:"screenshot_data,omitempty"`
|
||||
ScreenshotWidth int `json:"screenshot_width,omitempty"`
|
||||
ScreenshotHeight int `json:"screenshot_height,omitempty"`
|
||||
}
|
||||
|
||||
// API exposes the desktop streaming HTTP routes for the agent.
|
||||
type API struct {
|
||||
logger slog.Logger
|
||||
desktop Desktop
|
||||
clock quartz.Clock
|
||||
}
|
||||
|
||||
// NewAPI creates a new desktop streaming API.
|
||||
func NewAPI(logger slog.Logger, desktop Desktop, clock quartz.Clock) *API {
|
||||
if clock == nil {
|
||||
clock = quartz.NewReal()
|
||||
}
|
||||
return &API{
|
||||
logger: logger,
|
||||
desktop: desktop,
|
||||
clock: clock,
|
||||
}
|
||||
}
|
||||
|
||||
// Routes returns the chi router for mounting at /api/v0/desktop.
|
||||
func (a *API) Routes() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/vnc", a.handleDesktopVNC)
|
||||
r.Post("/action", a.handleAction)
|
||||
return r
|
||||
}
|
||||
|
||||
func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Start the desktop session (idempotent).
|
||||
_, err := a.desktop.Start(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to start desktop session.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get a VNC connection.
|
||||
vncConn, err := a.desktop.VNCConn(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to connect to VNC server.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer vncConn.Close()
|
||||
|
||||
// Accept WebSocket from coderd.
|
||||
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Error(ctx, "failed to accept websocket", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// No read limit — RFB framebuffer updates can be large.
|
||||
conn.SetReadLimit(-1)
|
||||
|
||||
wsCtx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
// Bicopy raw bytes between WebSocket and VNC TCP.
|
||||
agentssh.Bicopy(wsCtx, wsNetConn, vncConn)
|
||||
}
|
||||
|
||||
func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
handlerStart := a.clock.Now()
|
||||
|
||||
// Ensure the desktop is running and grab native dimensions.
|
||||
cfg, err := a.desktop.Start(ctx)
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "handleAction: desktop.Start failed",
|
||||
slog.Error(err),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to start desktop session.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var action DesktopAction
|
||||
if err := json.NewDecoder(r.Body).Decode(&action); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to decode request body.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
a.logger.Info(ctx, "handleAction: started",
|
||||
slog.F("action", action.Action),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
)
|
||||
|
||||
// Helper to scale a coordinate pair from the model's space to
|
||||
// native display pixels.
|
||||
scaleXY := func(x, y int) (int, int) {
|
||||
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
|
||||
x = scaleCoordinate(x, *action.ScaledWidth, cfg.Width)
|
||||
}
|
||||
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
|
||||
y = scaleCoordinate(y, *action.ScaledHeight, cfg.Height)
|
||||
}
|
||||
return x, y
|
||||
}
|
||||
|
||||
var resp DesktopActionResponse
|
||||
|
||||
switch action.Action {
|
||||
case "key":
|
||||
if action.Text == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing \"text\" for key action.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := a.desktop.KeyPress(ctx, *action.Text); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Key press failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "key action performed"
|
||||
|
||||
case "type":
|
||||
if action.Text == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing \"text\" for type action.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := a.desktop.Type(ctx, *action.Text); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Type action failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "type action performed"
|
||||
|
||||
case "cursor_position":
|
||||
x, y, err := a.desktop.CursorPosition(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Cursor position failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "x=" + strconv.Itoa(x) + ",y=" + strconv.Itoa(y)
|
||||
|
||||
case "mouse_move":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
if err := a.desktop.Move(ctx, x, y); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Mouse move failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "mouse_move action performed"
|
||||
|
||||
case "left_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
stepStart := a.clock.Now()
|
||||
if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil {
|
||||
a.logger.Warn(ctx, "handleAction: Click failed",
|
||||
slog.F("action", "left_click"),
|
||||
slog.F("step", "click"),
|
||||
slog.F("step_ms", time.Since(stepStart).Milliseconds()),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
slog.Error(err),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Left click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
a.logger.Debug(ctx, "handleAction: Click completed",
|
||||
slog.F("action", "left_click"),
|
||||
slog.F("step_ms", time.Since(stepStart).Milliseconds()),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
)
|
||||
resp.Output = "left_click action performed"
|
||||
|
||||
case "left_click_drag":
|
||||
if action.Coordinate == nil || action.StartCoordinate == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing \"coordinate\" or \"start_coordinate\" for left_click_drag.",
|
||||
})
|
||||
return
|
||||
}
|
||||
sx, sy := scaleXY(action.StartCoordinate[0], action.StartCoordinate[1])
|
||||
ex, ey := scaleXY(action.Coordinate[0], action.Coordinate[1])
|
||||
if err := a.desktop.Drag(ctx, sx, sy, ex, ey); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Left click drag failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "left_click_drag action performed"
|
||||
|
||||
case "left_mouse_down":
|
||||
if err := a.desktop.ButtonDown(ctx, MouseButtonLeft); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Left mouse down failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "left_mouse_down action performed"
|
||||
|
||||
case "left_mouse_up":
|
||||
if err := a.desktop.ButtonUp(ctx, MouseButtonLeft); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Left mouse up failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "left_mouse_up action performed"
|
||||
|
||||
case "right_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
if err := a.desktop.Click(ctx, x, y, MouseButtonRight); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Right click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "right_click action performed"
|
||||
|
||||
case "middle_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
if err := a.desktop.Click(ctx, x, y, MouseButtonMiddle); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Middle click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "middle_click action performed"
|
||||
|
||||
case "double_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
if err := a.desktop.DoubleClick(ctx, x, y, MouseButtonLeft); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Double click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "double_click action performed"
|
||||
|
||||
case "triple_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
for range 3 {
|
||||
if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Triple click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
resp.Output = "triple_click action performed"
|
||||
|
||||
case "scroll":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
|
||||
amount := 3
|
||||
if action.ScrollAmount != nil {
|
||||
amount = *action.ScrollAmount
|
||||
}
|
||||
direction := "down"
|
||||
if action.ScrollDirection != nil {
|
||||
direction = *action.ScrollDirection
|
||||
}
|
||||
|
||||
var dx, dy int
|
||||
switch direction {
|
||||
case "up":
|
||||
dy = -amount
|
||||
case "down":
|
||||
dy = amount
|
||||
case "left":
|
||||
dx = -amount
|
||||
case "right":
|
||||
dx = amount
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid scroll direction: " + direction,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := a.desktop.Scroll(ctx, x, y, dx, dy); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Scroll failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "scroll action performed"
|
||||
|
||||
case "hold_key":
|
||||
if action.Text == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing \"text\" for hold_key action.",
|
||||
})
|
||||
return
|
||||
}
|
||||
dur := 1000
|
||||
if action.Duration != nil {
|
||||
dur = *action.Duration
|
||||
}
|
||||
if err := a.desktop.KeyDown(ctx, *action.Text); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Key down failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
timer := a.clock.NewTimer(time.Duration(dur)*time.Millisecond, "agentdesktop", "hold_key")
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context canceled; release the key immediately.
|
||||
if err := a.desktop.KeyUp(ctx, *action.Text); err != nil {
|
||||
a.logger.Warn(ctx, "handleAction: KeyUp after context cancel", slog.Error(err))
|
||||
}
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
if err := a.desktop.KeyUp(ctx, *action.Text); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Key up failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "hold_key action performed"
|
||||
|
||||
case "screenshot":
|
||||
var opts ScreenshotOptions
|
||||
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
|
||||
opts.TargetWidth = *action.ScaledWidth
|
||||
}
|
||||
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
|
||||
opts.TargetHeight = *action.ScaledHeight
|
||||
}
|
||||
result, err := a.desktop.Screenshot(ctx, opts)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Screenshot failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "screenshot"
|
||||
resp.ScreenshotData = result.Data
|
||||
if action.ScaledWidth != nil && *action.ScaledWidth > 0 && *action.ScaledWidth != cfg.Width {
|
||||
resp.ScreenshotWidth = *action.ScaledWidth
|
||||
} else {
|
||||
resp.ScreenshotWidth = cfg.Width
|
||||
}
|
||||
if action.ScaledHeight != nil && *action.ScaledHeight > 0 && *action.ScaledHeight != cfg.Height {
|
||||
resp.ScreenshotHeight = *action.ScaledHeight
|
||||
} else {
|
||||
resp.ScreenshotHeight = cfg.Height
|
||||
}
|
||||
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Unknown action: " + action.Action,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
elapsedMs := a.clock.Since(handlerStart).Milliseconds()
|
||||
if ctx.Err() != nil {
|
||||
a.logger.Error(ctx, "handleAction: context canceled before writing response",
|
||||
slog.F("action", action.Action),
|
||||
slog.F("elapsed_ms", elapsedMs),
|
||||
slog.Error(ctx.Err()),
|
||||
)
|
||||
return
|
||||
}
|
||||
a.logger.Info(ctx, "handleAction: writing response",
|
||||
slog.F("action", action.Action),
|
||||
slog.F("elapsed_ms", elapsedMs),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Close shuts down the desktop session if one is running.
|
||||
func (a *API) Close() error {
|
||||
return a.desktop.Close()
|
||||
}
|
||||
|
||||
// coordFromAction extracts the coordinate pair from a DesktopAction,
|
||||
// returning an error if the coordinate field is missing.
|
||||
func coordFromAction(action DesktopAction) (x, y int, err error) {
|
||||
if action.Coordinate == nil {
|
||||
return 0, 0, &missingFieldError{field: "coordinate", action: action.Action}
|
||||
}
|
||||
return action.Coordinate[0], action.Coordinate[1], nil
|
||||
}
|
||||
|
||||
// missingFieldError is returned when a required field is absent from
|
||||
// a DesktopAction.
|
||||
type missingFieldError struct {
|
||||
field string
|
||||
action string
|
||||
}
|
||||
|
||||
func (e *missingFieldError) Error() string {
|
||||
return "Missing \"" + e.field + "\" for " + e.action + " action."
|
||||
}
|
||||
|
||||
// scaleCoordinate maps a coordinate from scaled → native space.
|
||||
func scaleCoordinate(scaled, scaledDim, nativeDim int) int {
|
||||
if scaledDim == 0 || scaledDim == nativeDim {
|
||||
return scaled
|
||||
}
|
||||
native := (float64(scaled)+0.5)*float64(nativeDim)/float64(scaledDim) - 0.5
|
||||
// Clamp to valid range.
|
||||
native = math.Max(native, 0)
|
||||
native = math.Min(native, float64(nativeDim-1))
|
||||
return int(native)
|
||||
}
|
||||
@@ -0,0 +1,467 @@
|
||||
package agentdesktop_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentdesktop"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// Ensure fakeDesktop satisfies the Desktop interface at compile time.
|
||||
var _ agentdesktop.Desktop = (*fakeDesktop)(nil)
|
||||
|
||||
// fakeDesktop is a minimal Desktop implementation for unit tests.
|
||||
type fakeDesktop struct {
|
||||
startErr error
|
||||
startCfg agentdesktop.DisplayConfig
|
||||
vncConnErr error
|
||||
screenshotErr error
|
||||
screenshotRes agentdesktop.ScreenshotResult
|
||||
closed bool
|
||||
|
||||
// Track calls for assertions.
|
||||
lastMove [2]int
|
||||
lastClick [3]int // x, y, button
|
||||
lastScroll [4]int // x, y, dx, dy
|
||||
lastKey string
|
||||
lastTyped string
|
||||
lastKeyDown string
|
||||
lastKeyUp string
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Start(context.Context) (agentdesktop.DisplayConfig, error) {
|
||||
return f.startCfg, f.startErr
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) VNCConn(context.Context) (net.Conn, error) {
|
||||
return nil, f.vncConnErr
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Screenshot(_ context.Context, _ agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
|
||||
return f.screenshotRes, f.screenshotErr
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Move(_ context.Context, x, y int) error {
|
||||
f.lastMove = [2]int{x, y}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Click(_ context.Context, x, y int, _ agentdesktop.MouseButton) error {
|
||||
f.lastClick = [3]int{x, y, 1}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) DoubleClick(_ context.Context, x, y int, _ agentdesktop.MouseButton) error {
|
||||
f.lastClick = [3]int{x, y, 2}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeDesktop) ButtonDown(context.Context, agentdesktop.MouseButton) error { return nil }
|
||||
func (*fakeDesktop) ButtonUp(context.Context, agentdesktop.MouseButton) error { return nil }
|
||||
|
||||
func (f *fakeDesktop) Scroll(_ context.Context, x, y, dx, dy int) error {
|
||||
f.lastScroll = [4]int{x, y, dx, dy}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeDesktop) Drag(context.Context, int, int, int, int) error { return nil }
|
||||
|
||||
func (f *fakeDesktop) KeyPress(_ context.Context, key string) error {
|
||||
f.lastKey = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) KeyDown(_ context.Context, key string) error {
|
||||
f.lastKeyDown = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) KeyUp(_ context.Context, key string) error {
|
||||
f.lastKeyUp = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Type(_ context.Context, text string) error {
|
||||
f.lastTyped = text
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
|
||||
return 10, 20, nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Close() error {
|
||||
f.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHandleDesktopVNC_StartError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{startErr: xerrors.New("no desktop")}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/vnc", nil)
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Failed to start desktop session.", resp.Message)
|
||||
}
|
||||
|
||||
func TestHandleAction_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: workspacesdk.DesktopDisplayWidth, Height: workspacesdk.DesktopDisplayHeight},
|
||||
screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{Action: "screenshot"}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var result agentdesktop.DesktopActionResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&result)
|
||||
require.NoError(t, err)
|
||||
// Dimensions come from DisplayConfig, not the screenshot CLI.
|
||||
assert.Equal(t, "screenshot", result.Output)
|
||||
assert.Equal(t, "base64data", result.ScreenshotData)
|
||||
assert.Equal(t, workspacesdk.DesktopDisplayWidth, result.ScreenshotWidth)
|
||||
assert.Equal(t, workspacesdk.DesktopDisplayHeight, result.ScreenshotHeight)
|
||||
}
|
||||
|
||||
func TestHandleAction_LeftClick(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "left_click",
|
||||
Coordinate: &[2]int{100, 200},
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var resp agentdesktop.DesktopActionResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "left_click action performed", resp.Output)
|
||||
assert.Equal(t, [3]int{100, 200, 1}, fake.lastClick)
|
||||
}
|
||||
|
||||
func TestHandleAction_UnknownAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{Action: "explode"}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
}
|
||||
|
||||
func TestHandleAction_KeyAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
text := "Return"
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "key",
|
||||
Text: &text,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "Return", fake.lastKey)
|
||||
}
|
||||
|
||||
func TestHandleAction_TypeAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
text := "hello world"
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "type",
|
||||
Text: &text,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "hello world", fake.lastTyped)
|
||||
}
|
||||
|
||||
func TestHandleAction_HoldKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
mClk := quartz.NewMock(t)
|
||||
trap := mClk.Trap().NewTimer("agentdesktop", "hold_key")
|
||||
defer trap.Close()
|
||||
api := agentdesktop.NewAPI(logger, fake, mClk)
|
||||
defer api.Close()
|
||||
|
||||
text := "Shift_L"
|
||||
dur := 100
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "hold_key",
|
||||
Text: &text,
|
||||
Duration: &dur,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
handler.ServeHTTP(rr, req)
|
||||
}()
|
||||
|
||||
// Wait for the timer to be created, then advance past it.
|
||||
trap.MustWait(req.Context()).MustRelease(req.Context())
|
||||
mClk.Advance(time.Duration(dur) * time.Millisecond).MustWait(req.Context())
|
||||
|
||||
<-done
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var resp agentdesktop.DesktopActionResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hold_key action performed", resp.Output)
|
||||
assert.Equal(t, "Shift_L", fake.lastKeyDown)
|
||||
assert.Equal(t, "Shift_L", fake.lastKeyUp)
|
||||
}
|
||||
|
||||
func TestHandleAction_HoldKeyMissingText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{Action: "hold_key"}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Missing \"text\" for hold_key action.", resp.Message)
|
||||
}
|
||||
|
||||
func TestHandleAction_ScrollDown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
dir := "down"
|
||||
amount := 5
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "scroll",
|
||||
Coordinate: &[2]int{500, 400},
|
||||
ScrollDirection: &dir,
|
||||
ScrollAmount: &amount,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
// dy should be positive 5 for "down".
|
||||
assert.Equal(t, [4]int{500, 400, 0, 5}, fake.lastScroll)
|
||||
}
|
||||
|
||||
func TestHandleAction_CoordinateScaling(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
// Native display is 1920x1080.
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
// Model is working in a 1280x720 coordinate space.
|
||||
sw := 1280
|
||||
sh := 720
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "mouse_move",
|
||||
Coordinate: &[2]int{640, 360},
|
||||
ScaledWidth: &sw,
|
||||
ScaledHeight: &sh,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
// 640 in 1280-space → 960 in 1920-space (midpoint maps to
|
||||
// midpoint).
|
||||
assert.Equal(t, 960, fake.lastMove[0])
|
||||
assert.Equal(t, 540, fake.lastMove[1])
|
||||
}
|
||||
|
||||
func TestClose_DelegatesToDesktop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
|
||||
err := api.Close()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, fake.closed)
|
||||
}
|
||||
|
||||
func TestClose_PreventsNewSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// After Close(), Start() will return an error because the
|
||||
// underlying Desktop is closed.
|
||||
fake := &fakeDesktop{}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
|
||||
err := api.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate the closed desktop returning an error on Start().
|
||||
fake.startErr = xerrors.New("desktop is closed")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/vnc", nil)
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Desktop abstracts a virtual desktop session running inside a workspace.
|
||||
type Desktop interface {
|
||||
// Start launches the desktop session. It is idempotent — calling
|
||||
// Start on an already-running session returns the existing
|
||||
// config. The returned DisplayConfig describes the running
|
||||
// session.
|
||||
Start(ctx context.Context) (DisplayConfig, error)
|
||||
|
||||
// VNCConn dials the desktop's VNC server and returns a raw
|
||||
// net.Conn carrying RFB binary frames. Each call returns a new
|
||||
// connection; multiple clients can connect simultaneously.
|
||||
// Start must be called before VNCConn.
|
||||
VNCConn(ctx context.Context) (net.Conn, error)
|
||||
|
||||
// Screenshot captures the current framebuffer as a PNG and
|
||||
// returns it base64-encoded. TargetWidth/TargetHeight in opts
|
||||
// are the desired output dimensions (the implementation
|
||||
// rescales); pass 0 to use native resolution.
|
||||
Screenshot(ctx context.Context, opts ScreenshotOptions) (ScreenshotResult, error)
|
||||
|
||||
// Mouse operations.
|
||||
|
||||
// Move moves the mouse cursor to absolute coordinates.
|
||||
Move(ctx context.Context, x, y int) error
|
||||
// Click performs a mouse button click at the given coordinates.
|
||||
Click(ctx context.Context, x, y int, button MouseButton) error
|
||||
// DoubleClick performs a double-click at the given coordinates.
|
||||
DoubleClick(ctx context.Context, x, y int, button MouseButton) error
|
||||
// ButtonDown presses and holds a mouse button.
|
||||
ButtonDown(ctx context.Context, button MouseButton) error
|
||||
// ButtonUp releases a mouse button.
|
||||
ButtonUp(ctx context.Context, button MouseButton) error
|
||||
// Scroll scrolls by (dx, dy) clicks at the given coordinates.
|
||||
Scroll(ctx context.Context, x, y, dx, dy int) error
|
||||
// Drag moves from (startX,startY) to (endX,endY) while holding
|
||||
// the left mouse button.
|
||||
Drag(ctx context.Context, startX, startY, endX, endY int) error
|
||||
|
||||
// Keyboard operations.
|
||||
|
||||
// KeyPress sends a key-down then key-up for a key combo string
|
||||
// (e.g. "Return", "ctrl+c").
|
||||
KeyPress(ctx context.Context, keys string) error
|
||||
// KeyDown presses and holds a key.
|
||||
KeyDown(ctx context.Context, key string) error
|
||||
// KeyUp releases a key.
|
||||
KeyUp(ctx context.Context, key string) error
|
||||
// Type types a string of text character-by-character.
|
||||
Type(ctx context.Context, text string) error
|
||||
|
||||
// CursorPosition returns the current cursor coordinates.
|
||||
CursorPosition(ctx context.Context) (x, y int, err error)
|
||||
|
||||
// Close shuts down the desktop session and cleans up resources.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// DisplayConfig describes a running desktop session.
|
||||
type DisplayConfig struct {
|
||||
Width int // native width in pixels
|
||||
Height int // native height in pixels
|
||||
VNCPort int // local TCP port for the VNC server
|
||||
Display int // X11 display number (e.g. 1 for :1), -1 if N/A
|
||||
}
|
||||
|
||||
// MouseButton identifies a mouse button.
|
||||
type MouseButton string
|
||||
|
||||
const (
|
||||
MouseButtonLeft MouseButton = "left"
|
||||
MouseButtonRight MouseButton = "right"
|
||||
MouseButtonMiddle MouseButton = "middle"
|
||||
)
|
||||
|
||||
// ScreenshotOptions configures a screenshot capture.
|
||||
type ScreenshotOptions struct {
|
||||
TargetWidth int // 0 = native
|
||||
TargetHeight int // 0 = native
|
||||
}
|
||||
|
||||
// ScreenshotResult is a captured screenshot.
|
||||
type ScreenshotResult struct {
|
||||
Data string // base64-encoded PNG
|
||||
}
|
||||
@@ -0,0 +1,544 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
portableDesktopVersion = "v0.0.4"
|
||||
downloadRetries = 3
|
||||
downloadRetryDelay = time.Second
|
||||
)
|
||||
|
||||
// platformBinaries maps GOARCH to download URL and expected SHA-256
|
||||
// digest for each supported platform.
|
||||
var platformBinaries = map[string]struct {
|
||||
URL string
|
||||
SHA256 string
|
||||
}{
|
||||
"amd64": {
|
||||
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-x64",
|
||||
SHA256: "a04e05e6c7d6f2e6b3acbf1729a7b21271276300b4fee321f4ffee6136538317",
|
||||
},
|
||||
"arm64": {
|
||||
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-arm64",
|
||||
SHA256: "b8cb9142dc32d46a608f25229cbe8168ff2a3aadc54253c74ff54cd347e16ca6",
|
||||
},
|
||||
}
|
||||
|
||||
// portableDesktopOutput is the JSON output from
|
||||
// `portabledesktop up --json`.
|
||||
type portableDesktopOutput struct {
|
||||
VNCPort int `json:"vncPort"`
|
||||
Geometry string `json:"geometry"` // e.g. "1920x1080"
|
||||
}
|
||||
|
||||
// desktopSession tracks a running portabledesktop process.
|
||||
type desktopSession struct {
|
||||
cmd *exec.Cmd
|
||||
vncPort int
|
||||
width int // native width, parsed from geometry
|
||||
height int // native height, parsed from geometry
|
||||
display int // X11 display number, -1 if not available
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// cursorOutput is the JSON output from `portabledesktop cursor --json`.
|
||||
type cursorOutput struct {
|
||||
X int `json:"x"`
|
||||
Y int `json:"y"`
|
||||
}
|
||||
|
||||
// screenshotOutput is the JSON output from
|
||||
// `portabledesktop screenshot --json`.
|
||||
type screenshotOutput struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// portableDesktop implements Desktop by shelling out to the
|
||||
// portabledesktop CLI via agentexec.Execer.
|
||||
type portableDesktop struct {
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
dataDir string // agent's ScriptDataDir, used for binary caching
|
||||
|
||||
mu sync.Mutex
|
||||
session *desktopSession // nil until started
|
||||
binPath string // resolved path to binary, cached
|
||||
closed bool
|
||||
|
||||
// httpClient is used for downloading the binary. If nil,
|
||||
// http.DefaultClient is used.
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewPortableDesktop creates a Desktop backed by the portabledesktop
|
||||
// CLI binary, using execer to spawn child processes. dataDir is used
|
||||
// to cache the downloaded binary.
|
||||
func NewPortableDesktop(
|
||||
logger slog.Logger,
|
||||
execer agentexec.Execer,
|
||||
dataDir string,
|
||||
) Desktop {
|
||||
return &portableDesktop{
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
dataDir: dataDir,
|
||||
}
|
||||
}
|
||||
|
||||
// httpDo returns the HTTP client to use for downloads.
|
||||
func (p *portableDesktop) httpDo() *http.Client {
|
||||
if p.httpClient != nil {
|
||||
return p.httpClient
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
// Start launches the desktop session (idempotent).
|
||||
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return DisplayConfig{}, xerrors.New("desktop is closed")
|
||||
}
|
||||
|
||||
if err := p.ensureBinary(ctx); err != nil {
|
||||
return DisplayConfig{}, xerrors.Errorf("ensure portabledesktop binary: %w", err)
|
||||
}
|
||||
|
||||
// If we have an existing session, check if it's still alive.
|
||||
if p.session != nil {
|
||||
if !(p.session.cmd.ProcessState != nil && p.session.cmd.ProcessState.Exited()) {
|
||||
return DisplayConfig{
|
||||
Width: p.session.width,
|
||||
Height: p.session.height,
|
||||
VNCPort: p.session.vncPort,
|
||||
Display: p.session.display,
|
||||
}, nil
|
||||
}
|
||||
// Process died — clean up and recreate.
|
||||
p.logger.Warn(ctx, "portabledesktop process died, recreating session")
|
||||
p.session.cancel()
|
||||
p.session = nil
|
||||
}
|
||||
|
||||
// Spawn portabledesktop up --json.
|
||||
sessionCtx, sessionCancel := context.WithCancel(context.Background())
|
||||
|
||||
//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
|
||||
cmd := p.execer.CommandContext(sessionCtx, p.binPath, "up", "--json",
|
||||
"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight))
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
sessionCancel()
|
||||
return DisplayConfig{}, xerrors.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
sessionCancel()
|
||||
return DisplayConfig{}, xerrors.Errorf("start portabledesktop: %w", err)
|
||||
}
|
||||
|
||||
// Parse the JSON output to get VNC port and geometry.
|
||||
var output portableDesktopOutput
|
||||
if err := json.NewDecoder(stdout).Decode(&output); err != nil {
|
||||
sessionCancel()
|
||||
_ = cmd.Process.Kill()
|
||||
_ = cmd.Wait()
|
||||
return DisplayConfig{}, xerrors.Errorf("parse portabledesktop output: %w", err)
|
||||
}
|
||||
|
||||
if output.VNCPort == 0 {
|
||||
sessionCancel()
|
||||
_ = cmd.Process.Kill()
|
||||
_ = cmd.Wait()
|
||||
return DisplayConfig{}, xerrors.New("portabledesktop returned port 0")
|
||||
}
|
||||
|
||||
var w, h int
|
||||
if output.Geometry != "" {
|
||||
if _, err := fmt.Sscanf(output.Geometry, "%dx%d", &w, &h); err != nil {
|
||||
p.logger.Warn(ctx, "failed to parse geometry, using defaults",
|
||||
slog.F("geometry", output.Geometry),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
p.logger.Info(ctx, "started portabledesktop session",
|
||||
slog.F("vnc_port", output.VNCPort),
|
||||
slog.F("width", w),
|
||||
slog.F("height", h),
|
||||
slog.F("pid", cmd.Process.Pid),
|
||||
)
|
||||
|
||||
p.session = &desktopSession{
|
||||
cmd: cmd,
|
||||
vncPort: output.VNCPort,
|
||||
width: w,
|
||||
height: h,
|
||||
display: -1,
|
||||
cancel: sessionCancel,
|
||||
}
|
||||
|
||||
return DisplayConfig{
|
||||
Width: w,
|
||||
Height: h,
|
||||
VNCPort: output.VNCPort,
|
||||
Display: -1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VNCConn dials the desktop's VNC server and returns a raw
|
||||
// net.Conn carrying RFB binary frames.
|
||||
func (p *portableDesktop) VNCConn(_ context.Context) (net.Conn, error) {
|
||||
p.mu.Lock()
|
||||
session := p.session
|
||||
p.mu.Unlock()
|
||||
|
||||
if session == nil {
|
||||
return nil, xerrors.New("desktop session not started")
|
||||
}
|
||||
|
||||
return net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", session.vncPort))
|
||||
}
|
||||
|
||||
// Screenshot captures the current framebuffer as a base64-encoded PNG.
|
||||
func (p *portableDesktop) Screenshot(ctx context.Context, opts ScreenshotOptions) (ScreenshotResult, error) {
|
||||
args := []string{"screenshot", "--json"}
|
||||
if opts.TargetWidth > 0 {
|
||||
args = append(args, "--target-width", strconv.Itoa(opts.TargetWidth))
|
||||
}
|
||||
if opts.TargetHeight > 0 {
|
||||
args = append(args, "--target-height", strconv.Itoa(opts.TargetHeight))
|
||||
}
|
||||
|
||||
out, err := p.runCmd(ctx, args...)
|
||||
if err != nil {
|
||||
return ScreenshotResult{}, err
|
||||
}
|
||||
|
||||
var result screenshotOutput
|
||||
if err := json.Unmarshal([]byte(out), &result); err != nil {
|
||||
return ScreenshotResult{}, xerrors.Errorf("parse screenshot output: %w", err)
|
||||
}
|
||||
|
||||
return ScreenshotResult(result), nil
|
||||
}
|
||||
|
||||
// Move moves the mouse cursor to absolute coordinates.
|
||||
func (p *portableDesktop) Move(ctx context.Context, x, y int) error {
|
||||
_, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y))
|
||||
return err
|
||||
}
|
||||
|
||||
// Click performs a mouse button click at the given coordinates.
|
||||
func (p *portableDesktop) Click(ctx context.Context, x, y int, button MouseButton) error {
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := p.runCmd(ctx, "mouse", "click", string(button))
|
||||
return err
|
||||
}
|
||||
|
||||
// DoubleClick performs a double-click at the given coordinates.
|
||||
func (p *portableDesktop) DoubleClick(ctx context.Context, x, y int, button MouseButton) error {
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := p.runCmd(ctx, "mouse", "click", string(button)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := p.runCmd(ctx, "mouse", "click", string(button))
|
||||
return err
|
||||
}
|
||||
|
||||
// ButtonDown presses and holds a mouse button.
|
||||
func (p *portableDesktop) ButtonDown(ctx context.Context, button MouseButton) error {
|
||||
_, err := p.runCmd(ctx, "mouse", "down", string(button))
|
||||
return err
|
||||
}
|
||||
|
||||
// ButtonUp releases a mouse button.
|
||||
func (p *portableDesktop) ButtonUp(ctx context.Context, button MouseButton) error {
|
||||
_, err := p.runCmd(ctx, "mouse", "up", string(button))
|
||||
return err
|
||||
}
|
||||
|
||||
// Scroll scrolls by (dx, dy) clicks at the given coordinates.
|
||||
func (p *portableDesktop) Scroll(ctx context.Context, x, y, dx, dy int) error {
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := p.runCmd(ctx, "mouse", "scroll", strconv.Itoa(dx), strconv.Itoa(dy))
|
||||
return err
|
||||
}
|
||||
|
||||
// Drag moves from (startX,startY) to (endX,endY) while holding the
|
||||
// left mouse button.
|
||||
func (p *portableDesktop) Drag(ctx context.Context, startX, startY, endX, endY int) error {
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(startX), strconv.Itoa(startY)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := p.runCmd(ctx, "mouse", "down", string(MouseButtonLeft)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(endX), strconv.Itoa(endY)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := p.runCmd(ctx, "mouse", "up", string(MouseButtonLeft))
|
||||
return err
|
||||
}
|
||||
|
||||
// KeyPress sends a key-down then key-up for a key combo string.
|
||||
func (p *portableDesktop) KeyPress(ctx context.Context, keys string) error {
|
||||
_, err := p.runCmd(ctx, "keyboard", "key", keys)
|
||||
return err
|
||||
}
|
||||
|
||||
// KeyDown presses and holds a key.
|
||||
func (p *portableDesktop) KeyDown(ctx context.Context, key string) error {
|
||||
_, err := p.runCmd(ctx, "keyboard", "down", key)
|
||||
return err
|
||||
}
|
||||
|
||||
// KeyUp releases a key.
|
||||
func (p *portableDesktop) KeyUp(ctx context.Context, key string) error {
|
||||
_, err := p.runCmd(ctx, "keyboard", "up", key)
|
||||
return err
|
||||
}
|
||||
|
||||
// Type types a string of text character-by-character.
|
||||
func (p *portableDesktop) Type(ctx context.Context, text string) error {
|
||||
_, err := p.runCmd(ctx, "keyboard", "type", text)
|
||||
return err
|
||||
}
|
||||
|
||||
// CursorPosition returns the current cursor coordinates.
|
||||
func (p *portableDesktop) CursorPosition(ctx context.Context) (x int, y int, err error) {
|
||||
out, err := p.runCmd(ctx, "cursor", "--json")
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
var result cursorOutput
|
||||
if err := json.Unmarshal([]byte(out), &result); err != nil {
|
||||
return 0, 0, xerrors.Errorf("parse cursor output: %w", err)
|
||||
}
|
||||
|
||||
return result.X, result.Y, nil
|
||||
}
|
||||
|
||||
// Close shuts down the desktop session and cleans up resources.
|
||||
func (p *portableDesktop) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.closed = true
|
||||
if p.session != nil {
|
||||
p.session.cancel()
|
||||
// Xvnc is a child process — killing it cleans up the X
|
||||
// session.
|
||||
_ = p.session.cmd.Process.Kill()
|
||||
_ = p.session.cmd.Wait()
|
||||
p.session = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runCmd executes a portabledesktop subcommand and returns combined
|
||||
// output. The caller must have previously called ensureBinary.
|
||||
func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, error) {
|
||||
start := time.Now()
|
||||
//nolint:gosec // args are constructed by the caller, not user input.
|
||||
cmd := p.execer.CommandContext(ctx, p.binPath, args...)
|
||||
out, err := cmd.CombinedOutput()
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "portabledesktop command failed",
|
||||
slog.F("args", args),
|
||||
slog.F("elapsed_ms", elapsed.Milliseconds()),
|
||||
slog.Error(err),
|
||||
slog.F("output", string(out)),
|
||||
)
|
||||
return "", xerrors.Errorf("portabledesktop %s: %w: %s", args[0], err, string(out))
|
||||
}
|
||||
if elapsed > 5*time.Second {
|
||||
p.logger.Warn(ctx, "portabledesktop command slow",
|
||||
slog.F("args", args),
|
||||
slog.F("elapsed_ms", elapsed.Milliseconds()),
|
||||
)
|
||||
} else {
|
||||
p.logger.Debug(ctx, "portabledesktop command completed",
|
||||
slog.F("args", args),
|
||||
slog.F("elapsed_ms", elapsed.Milliseconds()),
|
||||
)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
// ensureBinary resolves or downloads the portabledesktop binary. It
|
||||
// must be called while p.mu is held.
|
||||
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
|
||||
if p.binPath != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. Check PATH.
|
||||
if path, err := exec.LookPath("portabledesktop"); err == nil {
|
||||
p.logger.Info(ctx, "found portabledesktop in PATH",
|
||||
slog.F("path", path),
|
||||
)
|
||||
p.binPath = path
|
||||
return nil
|
||||
}
|
||||
|
||||
// 2. Platform checks.
|
||||
if runtime.GOOS != "linux" {
|
||||
return xerrors.New("portabledesktop is only supported on Linux")
|
||||
}
|
||||
bin, ok := platformBinaries[runtime.GOARCH]
|
||||
if !ok {
|
||||
return xerrors.Errorf("unsupported architecture for portabledesktop: %s", runtime.GOARCH)
|
||||
}
|
||||
|
||||
// 3. Check cache.
|
||||
cacheDir := filepath.Join(p.dataDir, "portabledesktop", bin.SHA256)
|
||||
cachedPath := filepath.Join(cacheDir, "portabledesktop")
|
||||
|
||||
if info, err := os.Stat(cachedPath); err == nil && !info.IsDir() {
|
||||
// Verify it is executable.
|
||||
if info.Mode()&0o100 != 0 {
|
||||
p.logger.Info(ctx, "using cached portabledesktop binary",
|
||||
slog.F("path", cachedPath),
|
||||
)
|
||||
p.binPath = cachedPath
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Download with retry.
|
||||
p.logger.Info(ctx, "downloading portabledesktop binary",
|
||||
slog.F("url", bin.URL),
|
||||
slog.F("version", portableDesktopVersion),
|
||||
slog.F("arch", runtime.GOARCH),
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
for attempt := range downloadRetries {
|
||||
if err := downloadBinary(ctx, p.httpDo(), bin.URL, bin.SHA256, cachedPath); err != nil {
|
||||
lastErr = err
|
||||
p.logger.Warn(ctx, "download attempt failed",
|
||||
slog.F("attempt", attempt+1),
|
||||
slog.F("max_attempts", downloadRetries),
|
||||
slog.Error(err),
|
||||
)
|
||||
if attempt < downloadRetries-1 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(downloadRetryDelay):
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
p.binPath = cachedPath
|
||||
p.logger.Info(ctx, "downloaded portabledesktop binary",
|
||||
slog.F("path", cachedPath),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
return xerrors.Errorf("download portabledesktop after %d attempts: %w", downloadRetries, lastErr)
|
||||
}
|
||||
|
||||
// downloadBinary fetches a binary from url, verifies its SHA-256
|
||||
// digest matches expectedSHA256, and atomically writes it to destPath.
|
||||
func downloadBinary(ctx context.Context, client *http.Client, url, expectedSHA256, destPath string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(destPath), 0o700); err != nil {
|
||||
return xerrors.Errorf("create cache directory: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create HTTP request: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("HTTP GET %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return xerrors.Errorf("HTTP GET %s: status %d", url, resp.StatusCode)
|
||||
}
|
||||
|
||||
// Write to a temp file in the same directory so the final rename
|
||||
// is atomic on the same filesystem.
|
||||
tmpFile, err := os.CreateTemp(filepath.Dir(destPath), "portabledesktop-download-*")
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
// Clean up the temp file on any error path.
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
// Stream the response body while computing SHA-256.
|
||||
hasher := sha256.New()
|
||||
if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil {
|
||||
return xerrors.Errorf("download body: %w", err)
|
||||
}
|
||||
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return xerrors.Errorf("close temp file: %w", err)
|
||||
}
|
||||
|
||||
// Verify digest.
|
||||
actualSHA256 := hex.EncodeToString(hasher.Sum(nil))
|
||||
if actualSHA256 != expectedSHA256 {
|
||||
return xerrors.Errorf(
|
||||
"SHA-256 mismatch: expected %s, got %s",
|
||||
expectedSHA256, actualSHA256,
|
||||
)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tmpPath, 0o700); err != nil {
|
||||
return xerrors.Errorf("chmod: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, destPath); err != nil {
|
||||
return xerrors.Errorf("rename to final path: %w", err)
|
||||
}
|
||||
|
||||
success = true
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,713 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/pty"
|
||||
)
|
||||
|
||||
// recordedExecer implements agentexec.Execer by recording every
|
||||
// invocation and delegating to a real shell command built from a
|
||||
// caller-supplied mapping of subcommand → shell script body.
|
||||
type recordedExecer struct {
|
||||
mu sync.Mutex
|
||||
commands [][]string
|
||||
// scripts maps a subcommand keyword (e.g. "up", "screenshot")
|
||||
// to a shell snippet whose stdout will be the command output.
|
||||
scripts map[string]string
|
||||
}
|
||||
|
||||
func (r *recordedExecer) record(cmd string, args ...string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.commands = append(r.commands, append([]string{cmd}, args...))
|
||||
}
|
||||
|
||||
func (r *recordedExecer) allCommands() [][]string {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
out := make([][]string, len(r.commands))
|
||||
copy(out, r.commands)
|
||||
return out
|
||||
}
|
||||
|
||||
// scriptFor finds the first matching script key present in args.
|
||||
func (r *recordedExecer) scriptFor(args []string) string {
|
||||
for _, a := range args {
|
||||
if s, ok := r.scripts[a]; ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
// Fallback: succeed silently.
|
||||
return "true"
|
||||
}
|
||||
|
||||
func (r *recordedExecer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd {
|
||||
r.record(cmd, args...)
|
||||
script := r.scriptFor(args)
|
||||
//nolint:gosec // Test helper — script content is controlled by the test.
|
||||
return exec.CommandContext(ctx, "sh", "-c", script)
|
||||
}
|
||||
|
||||
func (r *recordedExecer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd {
|
||||
r.record(cmd, args...)
|
||||
return pty.CommandContext(ctx, "sh", "-c", r.scriptFor(args))
|
||||
}
|
||||
|
||||
// --- portableDesktop tests ---
|
||||
|
||||
func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
// The "up" script prints the JSON line then sleeps until
|
||||
// the context is canceled (simulating a long-running process).
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
cfg, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1920, cfg.Width)
|
||||
assert.Equal(t, 1080, cfg.Height)
|
||||
assert.Equal(t, 5901, cfg.VNCPort)
|
||||
assert.Equal(t, -1, cfg.Display)
|
||||
|
||||
// Clean up the long-running process.
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Start_Idempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
cfg1, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg2, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, cfg1, cfg2, "second Start should return the same config")
|
||||
|
||||
// The execer should have been called exactly once for "up".
|
||||
cmds := rec.allCommands()
|
||||
upCalls := 0
|
||||
for _, c := range cmds {
|
||||
for _, a := range c {
|
||||
if a == "up" {
|
||||
upCalls++
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, upCalls, "expected exactly one 'up' invocation")
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"screenshot": `echo '{"data":"abc123"}'`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := pd.Screenshot(ctx, ScreenshotOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "abc123", result.Data)
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"screenshot": `echo '{"data":"x"}'`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := pd.Screenshot(ctx, ScreenshotOptions{
|
||||
TargetWidth: 800,
|
||||
TargetHeight: 600,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
require.NotEmpty(t, cmds)
|
||||
|
||||
// The last command should contain the target dimension flags.
|
||||
last := cmds[len(cmds)-1]
|
||||
joined := strings.Join(last, " ")
|
||||
assert.Contains(t, joined, "--target-width 800")
|
||||
assert.Contains(t, joined, "--target-height 600")
|
||||
}
|
||||
|
||||
func TestPortableDesktop_MouseMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Each sub-test verifies a single mouse method dispatches the
|
||||
// correct CLI arguments.
|
||||
tests := []struct {
|
||||
name string
|
||||
invoke func(context.Context, *portableDesktop) error
|
||||
wantArgs []string // substrings expected in a recorded command
|
||||
}{
|
||||
{
|
||||
name: "Move",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Move(ctx, 42, 99)
|
||||
},
|
||||
wantArgs: []string{"mouse", "move", "42", "99"},
|
||||
},
|
||||
{
|
||||
name: "Click",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Click(ctx, 10, 20, MouseButtonLeft)
|
||||
},
|
||||
// Click does move then click.
|
||||
wantArgs: []string{"mouse", "click", "left"},
|
||||
},
|
||||
{
|
||||
name: "DoubleClick",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.DoubleClick(ctx, 5, 6, MouseButtonRight)
|
||||
},
|
||||
wantArgs: []string{"mouse", "click", "right"},
|
||||
},
|
||||
{
|
||||
name: "ButtonDown",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.ButtonDown(ctx, MouseButtonMiddle)
|
||||
},
|
||||
wantArgs: []string{"mouse", "down", "middle"},
|
||||
},
|
||||
{
|
||||
name: "ButtonUp",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.ButtonUp(ctx, MouseButtonLeft)
|
||||
},
|
||||
wantArgs: []string{"mouse", "up", "left"},
|
||||
},
|
||||
{
|
||||
name: "Scroll",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Scroll(ctx, 50, 60, 3, 4)
|
||||
},
|
||||
wantArgs: []string{"mouse", "scroll", "3", "4"},
|
||||
},
|
||||
{
|
||||
name: "Drag",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Drag(ctx, 10, 20, 30, 40)
|
||||
},
|
||||
// Drag ends with mouse up left.
|
||||
wantArgs: []string{"mouse", "up", "left"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"mouse": `echo ok`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
err := tt.invoke(context.Background(), pd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
require.NotEmpty(t, cmds, "expected at least one command")
|
||||
|
||||
// Find at least one recorded command that contains
|
||||
// all expected argument substrings.
|
||||
found := false
|
||||
for _, cmd := range cmds {
|
||||
joined := strings.Join(cmd, " ")
|
||||
match := true
|
||||
for _, want := range tt.wantArgs {
|
||||
if !strings.Contains(joined, want) {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if match {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found,
|
||||
"no recorded command matched %v; got %v", tt.wantArgs, cmds)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortableDesktop_KeyboardMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
invoke func(context.Context, *portableDesktop) error
|
||||
wantArgs []string
|
||||
}{
|
||||
{
|
||||
name: "KeyPress",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.KeyPress(ctx, "Return")
|
||||
},
|
||||
wantArgs: []string{"keyboard", "key", "Return"},
|
||||
},
|
||||
{
|
||||
name: "KeyDown",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.KeyDown(ctx, "shift")
|
||||
},
|
||||
wantArgs: []string{"keyboard", "down", "shift"},
|
||||
},
|
||||
{
|
||||
name: "KeyUp",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.KeyUp(ctx, "shift")
|
||||
},
|
||||
wantArgs: []string{"keyboard", "up", "shift"},
|
||||
},
|
||||
{
|
||||
name: "Type",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Type(ctx, "hello world")
|
||||
},
|
||||
wantArgs: []string{"keyboard", "type", "hello world"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"keyboard": `echo ok`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
err := tt.invoke(context.Background(), pd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
require.NotEmpty(t, cmds)
|
||||
|
||||
last := cmds[len(cmds)-1]
|
||||
joined := strings.Join(last, " ")
|
||||
for _, want := range tt.wantArgs {
|
||||
assert.Contains(t, joined, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortableDesktop_CursorPosition(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"cursor": `echo '{"x":100,"y":200}'`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
x, y, err := pd.CursorPosition(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, x)
|
||||
assert.Equal(t, 200, y)
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1024x768"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Session should exist.
|
||||
pd.mu.Lock()
|
||||
require.NotNil(t, pd.session)
|
||||
pd.mu.Unlock()
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
|
||||
// Session should be cleaned up.
|
||||
pd.mu.Lock()
|
||||
assert.Nil(t, pd.session)
|
||||
assert.True(t, pd.closed)
|
||||
pd.mu.Unlock()
|
||||
|
||||
// Subsequent Start must fail.
|
||||
_, err = pd.Start(ctx)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "desktop is closed")
|
||||
}
|
||||
|
||||
// --- downloadBinary tests ---
|
||||
|
||||
func TestDownloadBinary_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho portable\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the file exists and has correct content.
|
||||
got, err := os.ReadFile(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
|
||||
// Verify executable permissions.
|
||||
info, err := os.Stat(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, info.Mode()&0o700, "binary should be executable")
|
||||
}
|
||||
|
||||
func TestDownloadBinary_ChecksumMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("real binary content"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, wrongSHA, destPath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "SHA-256 mismatch")
|
||||
|
||||
// The destination file should not exist (temp file cleaned up).
|
||||
_, statErr := os.Stat(destPath)
|
||||
assert.True(t, os.IsNotExist(statErr), "dest file should not exist after checksum failure")
|
||||
|
||||
// No leftover temp files in the directory.
|
||||
entries, err := os.ReadDir(destDir)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, entries, "no leftover temp files should remain")
|
||||
}
|
||||
|
||||
func TestDownloadBinary_HTTPError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, "irrelevant", destPath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "status 404")
|
||||
}
|
||||
|
||||
// --- ensureBinary tests ---
|
||||
|
||||
func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// When binPath is already set, ensureBinary should return
|
||||
// immediately without doing any work.
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "/already/set",
|
||||
}
|
||||
|
||||
err := pd.ensureBinary(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/already/set", pd.binPath)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_UsesCachedBinary(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment.
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
|
||||
bin, ok := platformBinaries[runtime.GOARCH]
|
||||
if !ok {
|
||||
t.Skipf("no platformBinary entry for %s", runtime.GOARCH)
|
||||
}
|
||||
|
||||
dataDir := t.TempDir()
|
||||
cacheDir := filepath.Join(dataDir, "portabledesktop", bin.SHA256)
|
||||
require.NoError(t, os.MkdirAll(cacheDir, 0o700))
|
||||
|
||||
cachedPath := filepath.Join(cacheDir, "portabledesktop")
|
||||
require.NoError(t, os.WriteFile(cachedPath, []byte("#!/bin/sh\n"), 0o600))
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: dataDir,
|
||||
}
|
||||
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
err := pd.ensureBinary(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cachedPath, pd.binPath)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_Downloads(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment and we override the package-level platformBinaries.
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho downloaded\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Save and restore platformBinaries for this test.
|
||||
origBinaries := platformBinaries
|
||||
platformBinaries = map[string]struct {
|
||||
URL string
|
||||
SHA256 string
|
||||
}{
|
||||
runtime.GOARCH: {
|
||||
URL: srv.URL + "/portabledesktop",
|
||||
SHA256: expectedSHA,
|
||||
},
|
||||
}
|
||||
t.Cleanup(func() { platformBinaries = origBinaries })
|
||||
|
||||
dataDir := t.TempDir()
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: dataDir,
|
||||
httpClient: srv.Client(),
|
||||
}
|
||||
|
||||
// Ensure PATH doesn't contain a real portabledesktop binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
err := pd.ensureBinary(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPath := filepath.Join(dataDir, "portabledesktop", expectedSHA, "portabledesktop")
|
||||
assert.Equal(t, expectedPath, pd.binPath)
|
||||
|
||||
// Verify the downloaded file has correct content.
|
||||
got, err := os.ReadFile(expectedPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_RetriesOnFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho retried\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
|
||||
var mu sync.Mutex
|
||||
attempt := 0
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
mu.Lock()
|
||||
current := attempt
|
||||
attempt++
|
||||
mu.Unlock()
|
||||
|
||||
// Fail the first 2 attempts, succeed on the third.
|
||||
if current < 2 {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Test downloadBinary directly to avoid time.Sleep in
|
||||
// ensureBinary's retry loop. We call it 3 times to simulate
|
||||
// what ensureBinary would do.
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
var lastErr error
|
||||
for i := range 3 {
|
||||
lastErr = downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
|
||||
if lastErr == nil {
|
||||
break
|
||||
}
|
||||
if i < 2 {
|
||||
// In the real code, ensureBinary sleeps here.
|
||||
// We skip the sleep in tests.
|
||||
continue
|
||||
}
|
||||
}
|
||||
require.NoError(t, lastErr, "download should succeed on the third attempt")
|
||||
|
||||
got, err := os.ReadFile(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
|
||||
mu.Lock()
|
||||
assert.Equal(t, 3, attempt, "server should have been hit 3 times")
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// Ensure that portableDesktop satisfies the Desktop interface at
|
||||
// compile time. This uses the unexported type so it lives in the
|
||||
// internal test package.
|
||||
var _ Desktop = (*portableDesktop)(nil)
|
||||
|
||||
// Silence the linter about unused imports — agentexec.DefaultExecer
|
||||
// is used in TestEnsureBinary_UsesCachedBinPath and others, and
|
||||
// fmt.Sscanf is used indirectly via the implementation.
|
||||
var (
|
||||
_ = agentexec.DefaultExecer
|
||||
_ = fmt.Sprintf
|
||||
)
|
||||
@@ -30,6 +30,7 @@ func (a *agent) apiHandler() http.Handler {
|
||||
r.Mount("/api/v0", a.filesAPI.Routes())
|
||||
r.Mount("/api/v0/git", a.gitAPI.Routes())
|
||||
r.Mount("/api/v0/processes", a.processAPI.Routes())
|
||||
r.Mount("/api/v0/desktop", a.desktopAPI.Routes())
|
||||
|
||||
if a.devcontainers {
|
||||
r.Mount("/api/v0/containers", a.containerAPI.Routes())
|
||||
|
||||
Generated
+43
@@ -481,6 +481,49 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/archive": {
|
||||
"post": {
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Archive a chat",
|
||||
"operationId": "archive-chat",
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/desktop": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Watch chat desktop",
|
||||
"operationId": "watch-chat-desktop",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"101": {
|
||||
"description": "Switching Protocols"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/connectionlog": {
|
||||
"get": {
|
||||
"security": [
|
||||
|
||||
Generated
+39
@@ -410,6 +410,45 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/archive": {
|
||||
"post": {
|
||||
"tags": ["Chats"],
|
||||
"summary": "Archive a chat",
|
||||
"operationId": "archive-chat",
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/desktop": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Watch chat desktop",
|
||||
"operationId": "watch-chat-desktop",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"101": {
|
||||
"description": "Switching Protocols"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/connectionlog": {
|
||||
"get": {
|
||||
"security": [
|
||||
|
||||
+58
-18
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/webpush"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -179,6 +180,7 @@ type CreateOptions struct {
|
||||
RootChatID uuid.NullUUID
|
||||
Title string
|
||||
ModelConfigID uuid.UUID
|
||||
ChatMode database.NullChatMode
|
||||
SystemPrompt string
|
||||
InitialUserContent []codersdk.ChatMessagePart
|
||||
}
|
||||
@@ -262,6 +264,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
RootChatID: opts.RootChatID,
|
||||
LastModelConfigID: opts.ModelConfigID,
|
||||
Title: opts.Title,
|
||||
Mode: opts.ChatMode,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert chat: %w", err)
|
||||
@@ -2143,10 +2146,13 @@ func (p *Server) runChat(
|
||||
// Fire title generation asynchronously so it doesn't block the
|
||||
// chat response. It uses a detached context so it can finish
|
||||
// even after the chat processing context is canceled.
|
||||
// Snapshot model so the goroutine doesn't race with the
|
||||
// model = cuModel reassignment below.
|
||||
titleModel := model
|
||||
p.inflight.Add(1)
|
||||
go func() {
|
||||
defer p.inflight.Done()
|
||||
p.maybeGenerateChatTitle(context.WithoutCancel(ctx), chat, messages, model, providerKeys, logger)
|
||||
p.maybeGenerateChatTitle(context.WithoutCancel(ctx), chat, messages, titleModel, providerKeys, logger)
|
||||
}()
|
||||
|
||||
prompt, err := chatprompt.ConvertMessagesWithFiles(ctx, messages, p.chatFileResolver(), logger)
|
||||
@@ -2157,6 +2163,9 @@ func (p *Server) runChat(
|
||||
prompt = chatprompt.InsertSystem(prompt, defaultSubagentInstruction)
|
||||
}
|
||||
|
||||
// Detect computer-use subagent via the mode column.
|
||||
isComputerUse := chat.Mode.Valid && chat.Mode.ChatMode == database.ChatModeComputerUse
|
||||
|
||||
// NOTE: Buffering was already started in processChat before
|
||||
// the running status was published, so message_part events
|
||||
// are captured from the moment subscribers can see
|
||||
@@ -2512,6 +2521,20 @@ func (p *Server) runChat(
|
||||
},
|
||||
}
|
||||
|
||||
if isComputerUse {
|
||||
// Override model for computer use subagent.
|
||||
cuModel, cuErr := chatprovider.ModelFromConfig(
|
||||
chattool.ComputerUseModelProvider,
|
||||
chattool.ComputerUseModelName,
|
||||
providerKeys,
|
||||
chatprovider.UserAgent(),
|
||||
)
|
||||
if cuErr != nil {
|
||||
return xerrors.Errorf("resolve computer use model: %w", cuErr)
|
||||
}
|
||||
model = cuModel
|
||||
}
|
||||
|
||||
// Here are all the tools we have for the chat.
|
||||
tools := []fantasy.AgentTool{
|
||||
chattool.ReadFile(chattool.ReadFileOptions{
|
||||
@@ -2568,23 +2591,34 @@ func (p *Server) runChat(
|
||||
WorkspaceMu: &workspaceMu,
|
||||
}),
|
||||
)
|
||||
tools = append(tools, p.subagentTools(func() database.Chat {
|
||||
tools = append(tools, p.subagentTools(ctx, func() database.Chat {
|
||||
return chat
|
||||
})...)
|
||||
}
|
||||
|
||||
// Build provider-native tools (e.g., web search) based on
|
||||
// the model configuration.
|
||||
var providerTools []fantasy.Tool
|
||||
var providerTools []chatloop.ProviderTool
|
||||
if callConfig.ProviderOptions != nil {
|
||||
providerTools = buildProviderTools(model.Provider(), callConfig.ProviderOptions)
|
||||
}
|
||||
|
||||
if isComputerUse {
|
||||
providerTools = append(providerTools, chatloop.ProviderTool{
|
||||
Definition: chattool.ComputerUseProviderTool(
|
||||
workspacesdk.DesktopDisplayWidth,
|
||||
workspacesdk.DesktopDisplayHeight),
|
||||
Runner: chattool.NewComputerUseTool(
|
||||
workspacesdk.DesktopDisplayWidth,
|
||||
workspacesdk.DesktopDisplayHeight,
|
||||
getWorkspaceConn, quartz.NewReal(),
|
||||
),
|
||||
})
|
||||
}
|
||||
err = chatloop.Run(ctx, chatloop.RunOptions{
|
||||
Model: model,
|
||||
Messages: prompt,
|
||||
Tools: tools,
|
||||
MaxSteps: maxChatSteps,
|
||||
Tools: tools, MaxSteps: maxChatSteps,
|
||||
|
||||
ModelConfig: callConfig,
|
||||
ProviderOptions: chatprovider.ProviderOptionsFromChatModelConfig(model, callConfig.ProviderOptions),
|
||||
@@ -2668,14 +2702,16 @@ func (p *Server) runChat(
|
||||
// buildProviderTools creates provider-native tool definitions
|
||||
// (like web search) based on the model configuration. These
|
||||
// tools are executed server-side by the LLM provider.
|
||||
func buildProviderTools(_ string, options *codersdk.ChatModelProviderOptions) []fantasy.Tool {
|
||||
var tools []fantasy.Tool
|
||||
func buildProviderTools(_ string, options *codersdk.ChatModelProviderOptions) []chatloop.ProviderTool {
|
||||
var tools []chatloop.ProviderTool
|
||||
|
||||
if options.Anthropic != nil && options.Anthropic.WebSearchEnabled != nil && *options.Anthropic.WebSearchEnabled {
|
||||
tools = append(tools, anthropic.WebSearchTool(&anthropic.WebSearchToolOptions{
|
||||
AllowedDomains: options.Anthropic.AllowedDomains,
|
||||
BlockedDomains: options.Anthropic.BlockedDomains,
|
||||
}))
|
||||
tools = append(tools, chatloop.ProviderTool{
|
||||
Definition: anthropic.WebSearchTool(&anthropic.WebSearchToolOptions{
|
||||
AllowedDomains: options.Anthropic.AllowedDomains,
|
||||
BlockedDomains: options.Anthropic.BlockedDomains,
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
if options.OpenAI != nil && options.OpenAI.WebSearchEnabled != nil && *options.OpenAI.WebSearchEnabled {
|
||||
@@ -2686,17 +2722,21 @@ func buildProviderTools(_ string, options *codersdk.ChatModelProviderOptions) []
|
||||
if len(options.OpenAI.AllowedDomains) > 0 {
|
||||
args["allowed_domains"] = options.OpenAI.AllowedDomains
|
||||
}
|
||||
tools = append(tools, fantasy.ProviderDefinedTool{
|
||||
ID: "web_search",
|
||||
Name: "web_search",
|
||||
Args: args,
|
||||
tools = append(tools, chatloop.ProviderTool{
|
||||
Definition: fantasy.ProviderDefinedTool{
|
||||
ID: "web_search",
|
||||
Name: "web_search",
|
||||
Args: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if options.Google != nil && options.Google.WebSearchEnabled != nil && *options.Google.WebSearchEnabled {
|
||||
tools = append(tools, fantasy.ProviderDefinedTool{
|
||||
ID: "web_search",
|
||||
Name: "web_search",
|
||||
tools = append(tools, chatloop.ProviderTool{
|
||||
Definition: fantasy.ProviderDefinedTool{
|
||||
ID: "web_search",
|
||||
Name: "web_search",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,9 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -14,12 +17,15 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
@@ -28,6 +34,8 @@ import (
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
proto "github.com/coder/coder/v2/provisionersdk/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -1767,3 +1775,349 @@ func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) {
|
||||
require.NotEqual(t, "Agent has finished running.", msg.Body,
|
||||
"push body should not use the default fallback text")
|
||||
}
|
||||
|
||||
func TestComputerUseSubagentToolsAndModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Track tools and model from the Anthropic LLM calls (the
|
||||
// computer use child chat). We use a raw HTTP handler because
|
||||
// the chattest AnthropicRequest struct does not capture tools.
|
||||
type anthropicCall struct {
|
||||
Model string
|
||||
Tools []string
|
||||
}
|
||||
var anthropicMu sync.Mutex
|
||||
var anthropicCalls []anthropicCall
|
||||
|
||||
anthropicSrv := httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"tools"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
names := make([]string, len(req.Tools))
|
||||
for i, tool := range req.Tools {
|
||||
names[i] = tool.Name
|
||||
}
|
||||
anthropicMu.Lock()
|
||||
anthropicCalls = append(anthropicCalls, anthropicCall{
|
||||
Model: req.Model,
|
||||
Tools: names,
|
||||
})
|
||||
anthropicMu.Unlock()
|
||||
|
||||
if !req.Stream {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg-test",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": chattool.ComputerUseModelName,
|
||||
"content": []map[string]any{{"type": "text", "text": "Done."}},
|
||||
"stop_reason": "end_turn",
|
||||
"usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Stream a minimal Anthropic SSE response.
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
chunks := []map[string]any{
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": "msg-test",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": chattool.ComputerUseModelName,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": "Done.",
|
||||
},
|
||||
},
|
||||
{"type": "content_block_stop", "index": 0},
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{"stop_reason": "end_turn"},
|
||||
"usage": map[string]any{"output_tokens": 5},
|
||||
},
|
||||
{"type": "message_stop"},
|
||||
}
|
||||
|
||||
for _, chunk := range chunks {
|
||||
chunkBytes, _ := json.Marshal(chunk)
|
||||
eventType, _ := chunk["type"].(string)
|
||||
_, _ = fmt.Fprintf(w, "event: %s\ndata: %s\n\n",
|
||||
eventType, chunkBytes)
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
))
|
||||
t.Cleanup(anthropicSrv.Close)
|
||||
|
||||
// OpenAI mock for the root chat. The first streaming call
|
||||
// triggers spawn_computer_use_agent; subsequent calls reply
|
||||
// with text.
|
||||
var openAICallCount atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
if openAICallCount.Add(1) == 1 {
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk(
|
||||
"spawn_computer_use_agent",
|
||||
`{"prompt":"do the desktop thing","title":"cu-sub"}`,
|
||||
),
|
||||
)
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("Done.")...,
|
||||
)
|
||||
})
|
||||
|
||||
// Seed the DB: user, openai-compat provider, model config.
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai-compat",
|
||||
DisplayName: "OpenAI Compat",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: openAIURL,
|
||||
CreatedBy: uuid.NullUUID{},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai-compat",
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{},
|
||||
UpdatedBy: uuid.NullUUID{},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add an Anthropic provider pointing to our mock server.
|
||||
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "anthropic",
|
||||
DisplayName: "Anthropic",
|
||||
APIKey: "test-anthropic-key",
|
||||
BaseUrl: anthropicSrv.URL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Build workspace + agent records so getWorkspaceConn can
|
||||
// resolve the agent for the computer use child.
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
})
|
||||
tpl := dbgen.Template(t, db, database.Template{
|
||||
CreatedBy: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
ActiveVersionID: tv.ID,
|
||||
})
|
||||
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
TemplateID: tpl.ID,
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
InitiatorID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
TemplateVersionID: tv.ID,
|
||||
WorkspaceID: ws.ID,
|
||||
JobID: pj.ID,
|
||||
})
|
||||
res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
JobID: pj.ID,
|
||||
})
|
||||
dbAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ResourceID: res.ID,
|
||||
})
|
||||
|
||||
// Mock agent connection that returns valid display dimensions
|
||||
// for the initial screenshot check in the computer use path.
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
mockConn.EXPECT().
|
||||
ExecuteDesktopAction(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.DesktopActionResponse{
|
||||
ScreenshotWidth: 1920,
|
||||
ScreenshotHeight: 1080,
|
||||
ScreenshotData: "iVBOR",
|
||||
}, nil).
|
||||
AnyTimes()
|
||||
mockConn.EXPECT().
|
||||
SetExtraHeaders(gomock.Any()).
|
||||
AnyTimes()
|
||||
mockConn.EXPECT().
|
||||
LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.LSResponse{}, xerrors.New("not found")).
|
||||
AnyTimes()
|
||||
|
||||
agentConnFn := func(
|
||||
_ context.Context, agentID uuid.UUID,
|
||||
) (workspacesdk.AgentConn, func(), error) {
|
||||
require.Equal(t, dbAgent.ID, agentID)
|
||||
return mockConn, func() {}, nil
|
||||
}
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: 10 * time.Millisecond,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
AgentConn: agentConnFn,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
// Create a root chat with a workspace so the child inherits it.
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "computer-use-detection",
|
||||
ModelConfigID: model.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("Use the desktop to check the UI"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the root chat AND the computer use child to finish.
|
||||
// The root chat spawns the child, then the chatd server picks
|
||||
// up and runs the child (which hits the Anthropic mock).
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
if got.Status != database.ChatStatusWaiting &&
|
||||
got.Status != database.ChatStatusError {
|
||||
return false
|
||||
}
|
||||
// Ensure the Anthropic mock received at least one call.
|
||||
anthropicMu.Lock()
|
||||
n := len(anthropicCalls)
|
||||
anthropicMu.Unlock()
|
||||
return n >= 1
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
anthropicMu.Lock()
|
||||
calls := append([]anthropicCall(nil), anthropicCalls...)
|
||||
anthropicMu.Unlock()
|
||||
|
||||
require.NotEmpty(t, calls,
|
||||
"expected at least one Anthropic LLM call")
|
||||
|
||||
childModel := calls[0].Model
|
||||
childTools := calls[0].Tools
|
||||
|
||||
// 1. Verify the model is the computer use model.
|
||||
require.Equal(t, chattool.ComputerUseModelName, childModel,
|
||||
"computer use subagent should use %s",
|
||||
chattool.ComputerUseModelName)
|
||||
|
||||
// 2. Verify the computer tool is present.
|
||||
require.Contains(t, childTools, "computer",
|
||||
"computer use subagent should have the computer tool")
|
||||
|
||||
// 3. Verify standard workspace tools are present (the same
|
||||
// set a regular subagent gets).
|
||||
standardTools := []string{
|
||||
"read_file", "write_file", "edit_files", "execute",
|
||||
"process_output", "process_list", "process_signal",
|
||||
}
|
||||
for _, tool := range standardTools {
|
||||
require.Contains(t, childTools, tool,
|
||||
"computer use subagent should have standard tool %q",
|
||||
tool)
|
||||
}
|
||||
|
||||
// 4. Verify workspace provisioning tools are NOT present.
|
||||
workspaceProvisioningTools := []string{
|
||||
"list_templates", "read_template",
|
||||
"create_workspace", "start_workspace",
|
||||
}
|
||||
for _, tool := range workspaceProvisioningTools {
|
||||
require.NotContains(t, childTools, tool,
|
||||
"computer use subagent should NOT have workspace "+
|
||||
"provisioning tool %q", tool)
|
||||
}
|
||||
|
||||
// 5. Verify subagent tools are NOT present.
|
||||
subagentTools := []string{
|
||||
"spawn_agent", "spawn_computer_use_agent",
|
||||
"wait_agent", "message_agent", "close_agent",
|
||||
}
|
||||
for _, tool := range subagentTools {
|
||||
require.NotContains(t, childTools, tool,
|
||||
"computer use subagent should NOT have subagent "+
|
||||
"tool %q", tool)
|
||||
}
|
||||
|
||||
// 6. Verify the child chat has Mode = computer_use in
|
||||
// the DB.
|
||||
allChats, err := db.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
|
||||
OwnerID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
var children []database.Chat
|
||||
for _, c := range allChats {
|
||||
if c.ParentChatID.Valid && c.ParentChatID.UUID == chat.ID {
|
||||
children = append(children, c)
|
||||
}
|
||||
}
|
||||
require.Len(t, children, 1)
|
||||
require.True(t, children[0].Mode.Valid)
|
||||
require.Equal(t, database.ChatModeComputerUse,
|
||||
children[0].Mode.ChatMode)
|
||||
}
|
||||
|
||||
@@ -63,11 +63,12 @@ type RunOptions struct {
|
||||
// of the provider, which lives in chatd, not chatloop.
|
||||
ProviderOptions fantasy.ProviderOptions
|
||||
|
||||
// ProviderTools are provider-native tools (like web search)
|
||||
// that are passed directly to the provider API alongside
|
||||
// function tool definitions. These are not necessarily
|
||||
// executed server-side; handling is provider-specific.
|
||||
ProviderTools []fantasy.Tool
|
||||
// ProviderTools are provider-native tools (like web search
|
||||
// and computer use) whose definitions are passed directly
|
||||
// to the provider API. When a ProviderTool has a non-nil
|
||||
// Runner, tool calls are executed locally; otherwise the
|
||||
// provider handles execution (e.g. web search).
|
||||
ProviderTools []ProviderTool
|
||||
|
||||
PersistStep func(context.Context, PersistedStep) error
|
||||
PublishMessagePart func(
|
||||
@@ -88,6 +89,16 @@ type RunOptions struct {
|
||||
OnInterruptedPersistError func(error)
|
||||
}
|
||||
|
||||
// ProviderTool pairs a provider-native tool definition with an
|
||||
// optional local executor. When Runner is nil the tool is fully
|
||||
// provider-executed (e.g. web search). When Runner is non-nil
|
||||
// the definition is sent to the API but execution is handled
|
||||
// locally (e.g. computer use).
|
||||
type ProviderTool struct {
|
||||
Definition fantasy.Tool
|
||||
Runner fantasy.AgentTool
|
||||
}
|
||||
|
||||
// stepResult holds the accumulated output of a single streaming
|
||||
// step. Since we own the stream consumer, all content is tracked
|
||||
// directly here — no shadow draft state needed.
|
||||
@@ -315,7 +326,7 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
toolResults = executeTools(ctx, opts.Tools, result.toolCalls, func(tr fantasy.ToolResultContent) {
|
||||
toolResults = executeTools(ctx, opts.Tools, opts.ProviderTools, result.toolCalls, func(tr fantasy.ToolResultContent) {
|
||||
publishMessagePart(
|
||||
codersdk.ChatMessageRoleTool,
|
||||
chatprompt.PartFromContent(tr),
|
||||
@@ -639,6 +650,7 @@ func processStepStream(
|
||||
func executeTools(
|
||||
ctx context.Context,
|
||||
allTools []fantasy.AgentTool,
|
||||
providerTools []ProviderTool,
|
||||
toolCalls []fantasy.ToolCallContent,
|
||||
onResult func(fantasy.ToolResultContent),
|
||||
) []fantasy.ToolResultContent {
|
||||
@@ -664,6 +676,13 @@ func executeTools(
|
||||
for _, t := range allTools {
|
||||
toolMap[t.Info().Name] = t
|
||||
}
|
||||
// Include runners from provider tools so locally-executed
|
||||
// provider tools (e.g. computer use) can be dispatched.
|
||||
for _, pt := range providerTools {
|
||||
if pt.Runner != nil {
|
||||
toolMap[pt.Runner.Info().Name] = pt.Runner
|
||||
}
|
||||
}
|
||||
|
||||
results := make([]fantasy.ToolResultContent, len(localToolCalls))
|
||||
var wg sync.WaitGroup
|
||||
@@ -863,15 +882,16 @@ func persistInterruptedStep(
|
||||
// buildToolDefinitions converts AgentTool definitions into the
|
||||
// fantasy.Tool slice expected by fantasy.Call. When activeTools
|
||||
// is non-empty, only function tools whose name appears in the
|
||||
// list are included. Provider tools bypass this filter and are
|
||||
// always appended unconditionally.
|
||||
func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, providerTools []fantasy.Tool) []fantasy.Tool {
|
||||
prepared := make([]fantasy.Tool, 0, len(tools))
|
||||
// list are included. Provider tool definitions are always
|
||||
// appended unconditionally.
|
||||
func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, providerTools []ProviderTool) []fantasy.Tool {
|
||||
prepared := make([]fantasy.Tool, 0, len(tools)+len(providerTools))
|
||||
for _, tool := range tools {
|
||||
info := tool.Info()
|
||||
if len(activeTools) > 0 && !slices.Contains(activeTools, info.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
inputSchema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": info.Parameters,
|
||||
@@ -885,7 +905,9 @@ func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, provi
|
||||
ProviderOptions: tool.ProviderOptions(),
|
||||
})
|
||||
}
|
||||
prepared = append(prepared, providerTools...)
|
||||
for _, pt := range providerTools {
|
||||
prepared = append(prepared, pt.Definition)
|
||||
}
|
||||
return prepared
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
// ComputerUseModelProvider is the provider for the computer
|
||||
// use model.
|
||||
ComputerUseModelProvider = "anthropic"
|
||||
// ComputerUseModelName is the model used for computer use
|
||||
// subagents.
|
||||
ComputerUseModelName = "claude-opus-4-6"
|
||||
)
|
||||
|
||||
// computerUseTool implements fantasy.AgentTool and
|
||||
// chatloop.ToolDefiner for Anthropic computer use.
|
||||
type computerUseTool struct {
|
||||
displayWidth int
|
||||
displayHeight int
|
||||
getWorkspaceConn func(ctx context.Context) (workspacesdk.AgentConn, error)
|
||||
providerOptions fantasy.ProviderOptions
|
||||
clock quartz.Clock
|
||||
}
|
||||
|
||||
// NewComputerUseTool creates a computer use AgentTool that
|
||||
// delegates to the agent's desktop endpoints.
|
||||
func NewComputerUseTool(
|
||||
displayWidth, displayHeight int,
|
||||
getWorkspaceConn func(ctx context.Context) (workspacesdk.AgentConn, error),
|
||||
clock quartz.Clock,
|
||||
) fantasy.AgentTool {
|
||||
return &computerUseTool{
|
||||
displayWidth: displayWidth,
|
||||
displayHeight: displayHeight,
|
||||
getWorkspaceConn: getWorkspaceConn,
|
||||
clock: clock,
|
||||
}
|
||||
}
|
||||
|
||||
func (*computerUseTool) Info() fantasy.ToolInfo {
|
||||
return fantasy.ToolInfo{
|
||||
Name: "computer",
|
||||
Description: "Control the desktop: take screenshots, move the mouse, click, type, and scroll.",
|
||||
Parameters: map[string]any{},
|
||||
Required: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// ComputerUseProviderTool creates the provider-defined tool
|
||||
// definition for Anthropic computer use. This is passed via
|
||||
// ProviderTools so the API receives the correct wire format.
|
||||
func ComputerUseProviderTool(displayWidth, displayHeight int) fantasy.Tool {
|
||||
return fantasyanthropic.NewComputerUseTool(
|
||||
fantasyanthropic.ComputerUseToolOptions{
|
||||
DisplayWidthPx: int64(displayWidth),
|
||||
DisplayHeightPx: int64(displayHeight),
|
||||
ToolVersion: fantasyanthropic.ComputerUse20251124,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (t *computerUseTool) ProviderOptions() fantasy.ProviderOptions {
|
||||
return t.providerOptions
|
||||
}
|
||||
|
||||
func (t *computerUseTool) SetProviderOptions(opts fantasy.ProviderOptions) {
|
||||
t.providerOptions = opts
|
||||
}
|
||||
|
||||
func (t *computerUseTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
input, err := fantasyanthropic.ParseComputerUseInput(call.Input)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("invalid computer use input: %v", err),
|
||||
), nil
|
||||
}
|
||||
|
||||
conn, err := t.getWorkspaceConn(ctx)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("failed to connect to workspace: %v", err),
|
||||
), nil
|
||||
}
|
||||
|
||||
// Compute scaled screenshot size for Anthropic constraints.
|
||||
scaledW, scaledH := computeScaledScreenshotSize(
|
||||
t.displayWidth, t.displayHeight,
|
||||
)
|
||||
|
||||
// For wait actions, sleep then return a screenshot.
|
||||
if input.Action == fantasyanthropic.ActionWait {
|
||||
d := input.Duration
|
||||
if d <= 0 {
|
||||
d = 1000
|
||||
}
|
||||
timer := t.clock.NewTimer(time.Duration(d)*time.Millisecond, "computeruse", "wait")
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-timer.C:
|
||||
}
|
||||
screenshotAction := workspacesdk.DesktopAction{
|
||||
Action: "screenshot",
|
||||
ScaledWidth: &scaledW,
|
||||
ScaledHeight: &scaledH,
|
||||
}
|
||||
screenResp, sErr := conn.ExecuteDesktopAction(ctx, screenshotAction)
|
||||
if sErr != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("screenshot failed: %v", sErr),
|
||||
), nil
|
||||
}
|
||||
return fantasy.NewImageResponse(
|
||||
[]byte(screenResp.ScreenshotData), "image/png",
|
||||
), nil
|
||||
}
|
||||
|
||||
// For screenshot action, use ExecuteDesktopAction.
|
||||
if input.Action == fantasyanthropic.ActionScreenshot {
|
||||
screenshotAction := workspacesdk.DesktopAction{
|
||||
Action: "screenshot",
|
||||
ScaledWidth: &scaledW,
|
||||
ScaledHeight: &scaledH,
|
||||
}
|
||||
screenResp, sErr := conn.ExecuteDesktopAction(ctx, screenshotAction)
|
||||
if sErr != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("screenshot failed: %v", sErr),
|
||||
), nil
|
||||
}
|
||||
return fantasy.NewImageResponse(
|
||||
[]byte(screenResp.ScreenshotData), "image/png",
|
||||
), nil
|
||||
}
|
||||
|
||||
// Build the action request.
|
||||
action := workspacesdk.DesktopAction{
|
||||
Action: string(input.Action),
|
||||
ScaledWidth: &scaledW,
|
||||
ScaledHeight: &scaledH,
|
||||
}
|
||||
if input.Coordinate != ([2]int64{}) {
|
||||
coord := [2]int{int(input.Coordinate[0]), int(input.Coordinate[1])}
|
||||
action.Coordinate = &coord
|
||||
}
|
||||
if input.StartCoordinate != ([2]int64{}) {
|
||||
coord := [2]int{int(input.StartCoordinate[0]), int(input.StartCoordinate[1])}
|
||||
action.StartCoordinate = &coord
|
||||
}
|
||||
if input.Text != "" {
|
||||
action.Text = &input.Text
|
||||
}
|
||||
if input.Duration > 0 {
|
||||
d := int(input.Duration)
|
||||
action.Duration = &d
|
||||
}
|
||||
if input.ScrollAmount > 0 {
|
||||
s := int(input.ScrollAmount)
|
||||
action.ScrollAmount = &s
|
||||
}
|
||||
if input.ScrollDirection != "" {
|
||||
action.ScrollDirection = &input.ScrollDirection
|
||||
}
|
||||
|
||||
// Execute the action.
|
||||
_, err = conn.ExecuteDesktopAction(ctx, action)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("action %q failed: %v", input.Action, err),
|
||||
), nil
|
||||
}
|
||||
|
||||
// Take a screenshot after every action (Anthropic pattern).
|
||||
screenshotAction := workspacesdk.DesktopAction{
|
||||
Action: "screenshot",
|
||||
ScaledWidth: &scaledW,
|
||||
ScaledHeight: &scaledH,
|
||||
}
|
||||
screenResp, sErr := conn.ExecuteDesktopAction(ctx, screenshotAction)
|
||||
if sErr != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("screenshot failed: %v", sErr),
|
||||
), nil
|
||||
}
|
||||
|
||||
return fantasy.NewImageResponse(
|
||||
[]byte(screenResp.ScreenshotData), "image/png",
|
||||
), nil
|
||||
}
|
||||
|
||||
// computeScaledScreenshotSize computes the target screenshot
|
||||
// dimensions to fit within Anthropic's constraints.
|
||||
func computeScaledScreenshotSize(width, height int) (scaledWidth int, scaledHeight int) {
|
||||
const maxLongEdge = 1568
|
||||
const maxTotalPixels = 1_150_000
|
||||
|
||||
longEdge := max(width, height)
|
||||
totalPixels := width * height
|
||||
longEdgeScale := float64(maxLongEdge) / float64(longEdge)
|
||||
totalPixelsScale := math.Sqrt(
|
||||
float64(maxTotalPixels) / float64(totalPixels),
|
||||
)
|
||||
scale := min(1.0, longEdgeScale, totalPixelsScale)
|
||||
|
||||
if scale >= 1.0 {
|
||||
return width, height
|
||||
}
|
||||
return max(1, int(float64(width)*scale)),
|
||||
max(1, int(float64(height)*scale))
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestComputeScaledScreenshotSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
width, height int
|
||||
wantW, wantH int
|
||||
}{
|
||||
{
|
||||
name: "1920x1080_scales_down",
|
||||
width: 1920,
|
||||
height: 1080,
|
||||
wantW: 1429,
|
||||
wantH: 804,
|
||||
},
|
||||
{
|
||||
name: "1280x800_no_scaling",
|
||||
width: 1280,
|
||||
height: 800,
|
||||
wantW: 1280,
|
||||
wantH: 800,
|
||||
},
|
||||
{
|
||||
name: "3840x2160_large_display",
|
||||
width: 3840,
|
||||
height: 2160,
|
||||
wantW: 1429,
|
||||
wantH: 804,
|
||||
},
|
||||
{
|
||||
name: "1568x1000_pixel_cap_applies",
|
||||
width: 1568,
|
||||
height: 1000,
|
||||
wantW: 1342,
|
||||
wantH: 856,
|
||||
},
|
||||
{
|
||||
name: "100x100_small_display",
|
||||
width: 100,
|
||||
height: 100,
|
||||
wantW: 100,
|
||||
wantH: 100,
|
||||
},
|
||||
{
|
||||
name: "4000x3000_stays_within_limits",
|
||||
width: 4000,
|
||||
// Both constraints apply. The function should keep
|
||||
// the result within maxLongEdge=1568 and
|
||||
// totalPixels<=1,150,000.
|
||||
height: 3000,
|
||||
wantW: 1238,
|
||||
wantH: 928,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
gotW, gotH := computeScaledScreenshotSize(tt.width, tt.height)
|
||||
assert.Equal(t, tt.wantW, gotW)
|
||||
assert.Equal(t, tt.wantH, gotH)
|
||||
|
||||
// Invariant: results must respect Anthropic constraints.
|
||||
const maxLongEdge = 1568
|
||||
const maxTotalPixels = 1_150_000
|
||||
longEdge := max(gotW, gotH)
|
||||
assert.LessOrEqual(t, longEdge, maxLongEdge,
|
||||
"long edge %d exceeds max %d", longEdge, maxLongEdge)
|
||||
assert.LessOrEqual(t, gotW*gotH, maxTotalPixels,
|
||||
"total pixels %d exceeds max %d", gotW*gotH, maxTotalPixels)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package chattool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestComputerUseTool_Info(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, nil, quartz.NewReal())
|
||||
info := tool.Info()
|
||||
assert.Equal(t, "computer", info.Name)
|
||||
assert.NotEmpty(t, info.Description)
|
||||
}
|
||||
|
||||
func TestComputerUseProviderTool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
def := chattool.ComputerUseProviderTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight)
|
||||
pdt, ok := def.(fantasy.ProviderDefinedTool)
|
||||
require.True(t, ok, "ComputerUseProviderTool should return a ProviderDefinedTool")
|
||||
assert.Contains(t, pdt.ID, "computer")
|
||||
assert.Equal(t, "computer", pdt.Name)
|
||||
// Verify display dimensions are passed through.
|
||||
assert.Equal(t, int64(workspacesdk.DesktopDisplayWidth), pdt.Args["display_width_px"])
|
||||
assert.Equal(t, int64(workspacesdk.DesktopDisplayHeight), pdt.Args["display_height_px"])
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).Return(workspacesdk.DesktopActionResponse{
|
||||
Output: "screenshot",
|
||||
ScreenshotData: "base64png",
|
||||
ScreenshotWidth: 1024,
|
||||
ScreenshotHeight: 768,
|
||||
}, nil)
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-1",
|
||||
Name: "computer",
|
||||
Input: `{"action":"screenshot"}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image", resp.Type)
|
||||
assert.Equal(t, "image/png", resp.MediaType)
|
||||
assert.Equal(t, []byte("base64png"), resp.Data)
|
||||
assert.False(t, resp.IsError)
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_LeftClick(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
// Expect the action call first.
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).Return(workspacesdk.DesktopActionResponse{
|
||||
Output: "left_click performed",
|
||||
}, nil)
|
||||
|
||||
// Then expect a screenshot (auto-screenshot after action).
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).Return(workspacesdk.DesktopActionResponse{
|
||||
Output: "screenshot",
|
||||
ScreenshotData: "after-click",
|
||||
ScreenshotWidth: 1024,
|
||||
ScreenshotHeight: 768,
|
||||
}, nil)
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-2",
|
||||
Name: "computer",
|
||||
Input: `{"action":"left_click","coordinate":[100,200]}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image", resp.Type)
|
||||
assert.Equal(t, []byte("after-click"), resp.Data)
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_Wait(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
// Expect a screenshot after the wait completes.
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).Return(workspacesdk.DesktopActionResponse{
|
||||
Output: "screenshot",
|
||||
ScreenshotData: "after-wait",
|
||||
ScreenshotWidth: 1024,
|
||||
ScreenshotHeight: 768,
|
||||
}, nil)
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-3",
|
||||
Name: "computer",
|
||||
Input: `{"action":"wait","duration":10}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image", resp.Type)
|
||||
assert.Equal(t, "image/png", resp.MediaType)
|
||||
assert.Equal(t, []byte("after-wait"), resp.Data)
|
||||
assert.False(t, resp.IsError)
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_ConnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return nil, xerrors.New("workspace not available")
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-4",
|
||||
Name: "computer",
|
||||
Input: `{"action":"screenshot"}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
assert.Contains(t, resp.Content, "workspace not available")
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return nil, xerrors.New("should not be called")
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-5",
|
||||
Name: "computer",
|
||||
Input: `{invalid json`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
assert.Contains(t, resp.Content, "invalid computer use input")
|
||||
}
|
||||
+123
-2
@@ -13,6 +13,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -26,11 +27,30 @@ const (
|
||||
defaultSubagentWaitTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
// computerUseSubagentSystemPrompt is the system prompt prepended to
|
||||
// every computer use subagent chat. It instructs the model on how to
|
||||
// interact with the desktop environment via the computer tool.
|
||||
const computerUseSubagentSystemPrompt = `You are a computer use agent with access to a desktop environment. You can see the screen, move the mouse, click, type, scroll, and drag.
|
||||
|
||||
Your primary tool is the "computer" tool which lets you interact with the desktop. After every action you take, you will receive a screenshot showing the current state of the screen. Use these screenshots to verify your actions and plan next steps.
|
||||
|
||||
Guidelines:
|
||||
- Always start by taking a screenshot to see the current state of the desktop.
|
||||
- Be precise with coordinates when clicking or typing.
|
||||
- Wait for UI elements to load before interacting with them.
|
||||
- If an action doesn't produce the expected result, try alternative approaches.
|
||||
- Report what you accomplished when done.`
|
||||
|
||||
type spawnAgentArgs struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Title string `json:"title,omitempty"`
|
||||
}
|
||||
|
||||
type spawnComputerUseAgentArgs struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Title string `json:"title,omitempty"`
|
||||
}
|
||||
|
||||
type waitAgentArgs struct {
|
||||
ChatID string `json:"chat_id"`
|
||||
TimeoutSeconds *int `json:"timeout_seconds,omitempty"`
|
||||
@@ -46,8 +66,26 @@ type closeAgentArgs struct {
|
||||
ChatID string `json:"chat_id"`
|
||||
}
|
||||
|
||||
func (p *Server) subagentTools(currentChat func() database.Chat) []fantasy.AgentTool {
|
||||
return []fantasy.AgentTool{
|
||||
// isAnthropicConfigured reports whether an Anthropic API key is
|
||||
// available, either from static provider keys or from the database.
|
||||
func (p *Server) isAnthropicConfigured(ctx context.Context) bool {
|
||||
if p.providerAPIKeys.APIKey("anthropic") != "" {
|
||||
return true
|
||||
}
|
||||
dbProviders, err := p.db.GetEnabledChatProviders(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, prov := range dbProviders {
|
||||
if chatprovider.NormalizeProvider(prov.Provider) == "anthropic" && strings.TrimSpace(prov.APIKey) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Server) subagentTools(ctx context.Context, currentChat func() database.Chat) []fantasy.AgentTool {
|
||||
tools := []fantasy.AgentTool{
|
||||
fantasy.NewAgentTool(
|
||||
"spawn_agent",
|
||||
"Spawn a delegated child agent to work on a clearly scoped, "+
|
||||
@@ -213,6 +251,89 @@ func (p *Server) subagentTools(currentChat func() database.Chat) []fantasy.Agent
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
// Only include the computer use tool when an Anthropic
|
||||
// provider is configured, since it requires an Anthropic
|
||||
// model.
|
||||
if p.isAnthropicConfigured(ctx) {
|
||||
tools = append(tools, fantasy.NewAgentTool(
|
||||
"spawn_computer_use_agent",
|
||||
"Spawn a dedicated computer use agent that can see the desktop "+
|
||||
"(take screenshots) and interact with it (mouse, keyboard, "+
|
||||
"scroll). The agent runs on a model optimized for computer "+
|
||||
"use and has the same workspace tools as a standard subagent "+
|
||||
"plus the native Anthropic computer tool. Use this for tasks "+
|
||||
"that require visual interaction with a desktop GUI (e.g. "+
|
||||
"browser automation, GUI testing, visual inspection). After "+
|
||||
"spawning, use wait_agent to collect the result.",
|
||||
func(ctx context.Context, args spawnComputerUseAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if currentChat == nil {
|
||||
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
|
||||
}
|
||||
|
||||
parent := currentChat()
|
||||
if parent.ParentChatID.Valid {
|
||||
return fantasy.NewTextErrorResponse("delegated chats cannot create child subagents"), nil
|
||||
}
|
||||
|
||||
parent, err := p.db.GetChatByID(ctx, parent.ID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
prompt := strings.TrimSpace(args.Prompt)
|
||||
if prompt == "" {
|
||||
return fantasy.NewTextErrorResponse("prompt is required"), nil
|
||||
}
|
||||
|
||||
title := strings.TrimSpace(args.Title)
|
||||
if title == "" {
|
||||
title = subagentFallbackChatTitle(prompt)
|
||||
}
|
||||
|
||||
rootChatID := parent.ID
|
||||
if parent.RootChatID.Valid {
|
||||
rootChatID = parent.RootChatID.UUID
|
||||
}
|
||||
if parent.LastModelConfigID == uuid.Nil {
|
||||
return fantasy.NewTextErrorResponse("parent chat model config id is required"), nil
|
||||
}
|
||||
|
||||
// Create the child chat with Mode set to
|
||||
// computer_use. This signals runChat to use the
|
||||
// predefined computer use model and include the
|
||||
// computer tool.
|
||||
childChat, err := p.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
WorkspaceID: parent.WorkspaceID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: rootChatID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: parent.LastModelConfigID,
|
||||
Title: title,
|
||||
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
SystemPrompt: computerUseSubagentSystemPrompt + "\n\n" + prompt,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
})
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
return toolJSONResponse(map[string]any{
|
||||
"chat_id": childChat.ID.String(),
|
||||
"title": childChat.Title,
|
||||
"status": string(childChat.Status),
|
||||
}), nil
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
func parseSubagentToolChatID(raw string) (uuid.UUID, error) {
|
||||
|
||||
@@ -0,0 +1,300 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestComputerUseSubagentSystemPrompt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Verify the system prompt constant is non-empty and contains
|
||||
// key instructions for the computer use agent.
|
||||
assert.NotEmpty(t, computerUseSubagentSystemPrompt)
|
||||
assert.Contains(t, computerUseSubagentSystemPrompt, "computer")
|
||||
assert.Contains(t, computerUseSubagentSystemPrompt, "screenshot")
|
||||
}
|
||||
|
||||
func TestSubagentFallbackChatTitle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "EmptyPrompt",
|
||||
input: "",
|
||||
want: "New Chat",
|
||||
},
|
||||
{
|
||||
name: "ShortPrompt",
|
||||
input: "Open Firefox",
|
||||
want: "Open Firefox",
|
||||
},
|
||||
{
|
||||
name: "LongPrompt",
|
||||
input: "Please open the Firefox browser and navigate to the settings page",
|
||||
want: "Please open the Firefox browser and...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := subagentFallbackChatTitle(tt.input)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newInternalTestServer creates a Server for internal tests with
|
||||
// custom provider API keys. The server is automatically closed
|
||||
// when the test finishes.
|
||||
func newInternalTestServer(
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
ps pubsub.Pubsub,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
) *Server {
|
||||
t.Helper()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := New(Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
// Use a very long interval so the background loop
|
||||
// does not interfere with test assertions.
|
||||
PendingChatAcquireInterval: testutil.WaitLong,
|
||||
ProviderAPIKeys: keys,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
return server
|
||||
}
|
||||
|
||||
// seedInternalChatDeps inserts an OpenAI provider and model config
|
||||
// into the database and returns the created user and model. This
|
||||
// deliberately does NOT create an Anthropic provider.
|
||||
func seedInternalChatDeps(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
) (database.User, database.ChatModelConfig) {
|
||||
t.Helper()
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: "",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return user, model
|
||||
}
|
||||
|
||||
// findToolByName returns the tool with the given name from the
|
||||
// slice, or nil if no match is found.
|
||||
func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool {
|
||||
for _, tool := range tools {
|
||||
if tool.Info().Name == name {
|
||||
return tool
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
// No Anthropic key in ProviderAPIKeys.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// Create a root parent chat.
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-no-anthropic",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Re-fetch so LastModelConfigID is populated from the DB.
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
|
||||
tool := findToolByName(tools, "spawn_computer_use_agent")
|
||||
assert.Nil(t, tool, "spawn_computer_use_agent tool must be omitted when Anthropic is not configured")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
// Provide an Anthropic key so the provider check passes.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// Create a root parent chat.
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "root-parent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a child chat under the parent.
|
||||
child, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Title: "child-subagent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do something")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Re-fetch the child so ParentChatID is populated.
|
||||
childChat, err := db.GetChatByID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, childChat.ParentChatID.Valid,
|
||||
"child chat must have a parent")
|
||||
|
||||
// Get tools as if the child chat is the current chat.
|
||||
tools := server.subagentTools(ctx, func() database.Chat { return childChat })
|
||||
tool := findToolByName(tools, "spawn_computer_use_agent")
|
||||
require.NotNil(t, tool, "spawn_computer_use_agent tool must be present")
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-2",
|
||||
Name: "spawn_computer_use_agent",
|
||||
Input: `{"prompt":"open browser"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, resp.IsError, "expected an error response")
|
||||
assert.Contains(t, resp.Content, "delegated chats cannot create child subagents")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
// Provide an Anthropic key so the tool can proceed.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// The parent uses an OpenAI model.
|
||||
require.Equal(t, "openai", model.Provider,
|
||||
"seed helper must create an OpenAI model")
|
||||
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-openai",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
|
||||
tool := findToolByName(tools, "spawn_computer_use_agent")
|
||||
require.NotNil(t, tool)
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-3",
|
||||
Name: "spawn_computer_use_agent",
|
||||
Input: `{"prompt":"take a screenshot"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.IsError, "expected success but got: %s", resp.Content)
|
||||
|
||||
// Parse the response to get the child chat ID.
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
childIDStr, ok := result["chat_id"].(string)
|
||||
require.True(t, ok, "response must contain chat_id")
|
||||
|
||||
childID, err := uuid.Parse(childIDStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
childChat, err := db.GetChatByID(ctx, childID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The child must have Mode=computer_use which causes
|
||||
// runChat to override the model to the predefined computer
|
||||
// use model instead of using the parent's model config.
|
||||
require.True(t, childChat.Mode.Valid)
|
||||
assert.Equal(t, database.ChatModeComputerUse, childChat.Mode.ChatMode)
|
||||
|
||||
// The predefined computer use model is Anthropic, which
|
||||
// differs from the parent's OpenAI model. This confirms
|
||||
// that the child will not inherit the parent's model at
|
||||
// runtime.
|
||||
assert.NotEqual(t, model.Provider, chattool.ComputerUseModelProvider,
|
||||
"computer use model provider must differ from parent model provider")
|
||||
assert.Equal(t, "anthropic", chattool.ComputerUseModelProvider)
|
||||
assert.NotEmpty(t, chattool.ComputerUseModelName)
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
package chatd_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestSpawnComputerUseAgent_CreatesChildWithChatMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newTestServer(t, db, ps, uuid.New())
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
// Create a parent chat.
|
||||
parent, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate what spawn_computer_use_agent does: set ChatMode
|
||||
// to computer_use and provide a system prompt.
|
||||
prompt := "Use the desktop to open Firefox"
|
||||
|
||||
child, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: model.ID,
|
||||
Title: "computer-use",
|
||||
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
SystemPrompt: "Computer use instructions\n\n" + prompt,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify parent-child relationship.
|
||||
require.True(t, child.ParentChatID.Valid)
|
||||
require.Equal(t, parent.ID, child.ParentChatID.UUID)
|
||||
|
||||
// Verify the chat type is set correctly.
|
||||
require.True(t, child.Mode.Valid)
|
||||
assert.Equal(t, database.ChatModeComputerUse, child.Mode.ChatMode)
|
||||
|
||||
// Confirm via a fresh DB read as well.
|
||||
got, err := db.GetChatByID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, got.Mode.Valid)
|
||||
assert.Equal(t, database.ChatModeComputerUse, got.Mode.ChatMode)
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_SystemPromptFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newTestServer(t, db, ps, uuid.New())
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
parent, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
prompt := "Navigate to settings page"
|
||||
systemPrompt := "Computer use instructions\n\n" + prompt
|
||||
|
||||
child, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: model.ID,
|
||||
Title: "computer-use-format",
|
||||
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
SystemPrompt: systemPrompt,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
messages, err := db.GetChatMessagesForPromptByChatID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The system message raw content is a JSON-encoded string.
|
||||
// It should contain the system prompt with the user prompt.
|
||||
var rawSystemContent string
|
||||
for _, msg := range messages {
|
||||
if msg.Role != "system" {
|
||||
continue
|
||||
}
|
||||
if msg.Content.Valid {
|
||||
rawSystemContent = string(msg.Content.RawMessage)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.Contains(t, rawSystemContent, prompt,
|
||||
"system prompt raw content should contain the user prompt")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_ChildIsListedUnderParent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newTestServer(t, db, ps, uuid.New())
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
parent, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
prompt := "Check the UI layout"
|
||||
|
||||
child, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: model.ID,
|
||||
Title: "computer-use-child",
|
||||
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
SystemPrompt: "Computer use instructions\n\n" + prompt,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the child is linked to the parent.
|
||||
fetchedChild, err := db.GetChatByID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, fetchedChild.ParentChatID.Valid)
|
||||
assert.Equal(t, parent.ID, fetchedChild.ParentChatID.UUID)
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_RootChatIDPropagation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newTestServer(t, db, ps, uuid.New())
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
// Create a root parent chat (no parent of its own).
|
||||
parent, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "root-parent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
prompt := "Take a screenshot"
|
||||
|
||||
child, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: model.ID,
|
||||
Title: "computer-use-root-test",
|
||||
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
SystemPrompt: "Computer use instructions\n\n" + prompt,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// When the parent has no RootChatID, the child's RootChatID
|
||||
// should point to the parent.
|
||||
require.True(t, child.RootChatID.Valid)
|
||||
assert.Equal(t, parent.ID, child.RootChatID.UUID)
|
||||
|
||||
// Verify chat was retrieved correctly from the DB.
|
||||
got, err := db.GetChatByID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, got.RootChatID.Valid)
|
||||
assert.Equal(t, parent.ID, got.RootChatID.UUID)
|
||||
}
|
||||
+117
@@ -25,6 +25,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
@@ -43,6 +44,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/searchquery"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/wsjson"
|
||||
"github.com/coder/websocket"
|
||||
@@ -735,6 +737,121 @@ proxyLoop:
|
||||
_ = clientStream.Close(websocket.StatusGoingAway)
|
||||
}
|
||||
|
||||
// @Summary Watch chat desktop
|
||||
// @ID watch-chat-desktop
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Param chat path string true "Chat ID" format(uuid)
|
||||
// @Success 101
|
||||
// @Router /chats/{chat}/desktop [get]
|
||||
//
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) watchChatDesktop(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
chat = httpmw.ChatParam(r)
|
||||
logger = api.Logger.Named("chat_desktop").With(slog.F("chat_id", chat.ID))
|
||||
)
|
||||
|
||||
if !chat.WorkspaceID.Valid {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Chat has no workspace.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
agents, err := api.Database.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, chat.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspace agents.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(agents) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Chat workspace has no agents.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
apiAgent, err := db2sdk.WorkspaceAgent(
|
||||
api.DERPMap(),
|
||||
*api.TailnetCoordinator.Load(),
|
||||
agents[0],
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
api.AgentInactiveDisconnectTimeout,
|
||||
api.DeploymentValues.AgentFallbackTroubleshootingURL.String(),
|
||||
)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error reading workspace agent.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if apiAgent.Status != codersdk.WorkspaceAgentConnected {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Agent state is %q, must be connected.", apiAgent.Status),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer dialCancel()
|
||||
|
||||
agentConn, release, err := api.agentProvider.AgentConn(dialCtx, agents[0].ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to dial workspace agent.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer release()
|
||||
|
||||
desktopConn, err := agentConn.ConnectDesktopVNC(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to connect to agent desktop.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer desktopConn.Close()
|
||||
|
||||
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to accept websocket", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// No read limit — RFB framebuffer updates can be large.
|
||||
conn.SetReadLimit(-1)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
ctx, wsNetConn := workspaceapps.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
agentssh.Bicopy(ctx, wsNetConn, desktopConn)
|
||||
logger.Debug(ctx, "desktop Bicopy finished")
|
||||
}
|
||||
|
||||
// @Summary Archive a chat
|
||||
// @ID archive-chat
|
||||
// @Tags Chats
|
||||
// @Success 204
|
||||
// @Router /chats/{chat}/archive [post]
|
||||
func (api *API) archiveChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
|
||||
@@ -3802,6 +3802,41 @@ func decRef(value string) *decimal.Decimal {
|
||||
return &d
|
||||
}
|
||||
|
||||
func TestWatchChatDesktop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NoWorkspace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "desktop no workspace test",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to connect to the desktop endpoint — should fail because
|
||||
// chat has no workspace.
|
||||
res, err := client.Request(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
fmt.Sprintf("/api/experimental/chats/%s/desktop", createdChat.ID),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusBadRequest, res.StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
func createChatModelConfig(t *testing.T, client *codersdk.Client) codersdk.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -1179,6 +1179,7 @@ func New(options *Options) *API {
|
||||
r.Use(httpmw.ExtractChatParam(options.Database))
|
||||
r.Get("/", api.getChat)
|
||||
r.Get("/git/watch", api.watchChatGit)
|
||||
r.Get("/desktop", api.watchChatDesktop)
|
||||
r.Post("/archive", api.archiveChat)
|
||||
r.Post("/unarchive", api.unarchiveChat)
|
||||
r.Get("/messages", api.getChatMessages)
|
||||
|
||||
Generated
+6
-1
@@ -278,6 +278,10 @@ CREATE TYPE chat_message_visibility AS ENUM (
|
||||
'both'
|
||||
);
|
||||
|
||||
CREATE TYPE chat_mode AS ENUM (
|
||||
'computer_use'
|
||||
);
|
||||
|
||||
CREATE TYPE chat_status AS ENUM (
|
||||
'waiting',
|
||||
'pending',
|
||||
@@ -1306,7 +1310,8 @@ CREATE TABLE chats (
|
||||
root_chat_id uuid,
|
||||
last_model_config_id uuid NOT NULL,
|
||||
archived boolean DEFAULT false NOT NULL,
|
||||
last_error text
|
||||
last_error text,
|
||||
mode chat_mode
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE chats DROP COLUMN mode;
|
||||
DROP TYPE IF EXISTS chat_mode;
|
||||
@@ -0,0 +1,3 @@
|
||||
CREATE TYPE chat_mode AS ENUM ('computer_use');
|
||||
|
||||
ALTER TABLE chats ADD COLUMN mode chat_mode;
|
||||
@@ -1174,6 +1174,61 @@ func AllChatMessageVisibilityValues() []ChatMessageVisibility {
|
||||
}
|
||||
}
|
||||
|
||||
type ChatMode string
|
||||
|
||||
const (
|
||||
ChatModeComputerUse ChatMode = "computer_use"
|
||||
)
|
||||
|
||||
func (e *ChatMode) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ChatMode(s)
|
||||
case string:
|
||||
*e = ChatMode(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ChatMode: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullChatMode struct {
|
||||
ChatMode ChatMode `json:"chat_mode"`
|
||||
Valid bool `json:"valid"` // Valid is true if ChatMode is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (ns *NullChatMode) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
ns.ChatMode, ns.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
ns.Valid = true
|
||||
return ns.ChatMode.Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (ns NullChatMode) Value() (driver.Value, error) {
|
||||
if !ns.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return string(ns.ChatMode), nil
|
||||
}
|
||||
|
||||
func (e ChatMode) Valid() bool {
|
||||
switch e {
|
||||
case ChatModeComputerUse:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AllChatModeValues() []ChatMode {
|
||||
return []ChatMode{
|
||||
ChatModeComputerUse,
|
||||
}
|
||||
}
|
||||
|
||||
type ChatStatus string
|
||||
|
||||
const (
|
||||
@@ -3972,6 +4027,7 @@ type Chat struct {
|
||||
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
|
||||
Archived bool `db:"archived" json:"archived"`
|
||||
LastError sql.NullString `db:"last_error" json:"last_error"`
|
||||
Mode NullChatMode `db:"mode" json:"mode"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
|
||||
@@ -2949,7 +2949,7 @@ WHERE
|
||||
$3::int
|
||||
)
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode
|
||||
`
|
||||
|
||||
type AcquireChatsParams struct {
|
||||
@@ -2985,6 +2985,7 @@ func (q *sqlQuerier) AcquireChats(ctx context.Context, arg AcquireChatsParams) (
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3172,7 +3173,7 @@ func (q *sqlQuerier) DeleteChatQueuedMessage(ctx context.Context, arg DeleteChat
|
||||
|
||||
const getChatByID = `-- name: GetChatByID :one
|
||||
SELECT
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
@@ -3198,12 +3199,13 @@ func (q *sqlQuerier) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatByIDForUpdate = `-- name: GetChatByIDForUpdate :one
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error FROM chats WHERE id = $1::uuid FOR UPDATE
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode FROM chats WHERE id = $1::uuid FOR UPDATE
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) {
|
||||
@@ -3225,6 +3227,7 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -3886,7 +3889,7 @@ func (q *sqlQuerier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID
|
||||
|
||||
const getChatsByOwnerID = `-- name: GetChatsByOwnerID :many
|
||||
SELECT
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
@@ -3963,6 +3966,7 @@ func (q *sqlQuerier) GetChatsByOwnerID(ctx context.Context, arg GetChatsByOwnerI
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4024,7 +4028,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh
|
||||
|
||||
const getStaleChats = `-- name: GetStaleChats :many
|
||||
SELECT
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
@@ -4059,6 +4063,7 @@ func (q *sqlQuerier) GetStaleChats(ctx context.Context, staleThreshold time.Time
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4080,17 +4085,19 @@ INSERT INTO chats (
|
||||
parent_chat_id,
|
||||
root_chat_id,
|
||||
last_model_config_id,
|
||||
title
|
||||
title,
|
||||
mode
|
||||
) VALUES (
|
||||
$1::uuid,
|
||||
$2::uuid,
|
||||
$3::uuid,
|
||||
$4::uuid,
|
||||
$5::uuid,
|
||||
$6::text
|
||||
$6::text,
|
||||
$7::chat_mode
|
||||
)
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode
|
||||
`
|
||||
|
||||
type InsertChatParams struct {
|
||||
@@ -4100,6 +4107,7 @@ type InsertChatParams struct {
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
|
||||
Title string `db:"title" json:"title"`
|
||||
Mode NullChatMode `db:"mode" json:"mode"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error) {
|
||||
@@ -4110,6 +4118,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
|
||||
arg.RootChatID,
|
||||
arg.LastModelConfigID,
|
||||
arg.Title,
|
||||
arg.Mode,
|
||||
)
|
||||
var i Chat
|
||||
err := row.Scan(
|
||||
@@ -4128,6 +4137,7 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -4307,7 +4317,7 @@ SET
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode
|
||||
`
|
||||
|
||||
type UpdateChatByIDParams struct {
|
||||
@@ -4334,6 +4344,7 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -4421,7 +4432,7 @@ SET
|
||||
WHERE
|
||||
id = $6::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode
|
||||
`
|
||||
|
||||
type UpdateChatStatusParams struct {
|
||||
@@ -4459,6 +4470,7 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -4472,7 +4484,7 @@ SET
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode
|
||||
`
|
||||
|
||||
type UpdateChatWorkspaceParams struct {
|
||||
@@ -4499,6 +4511,7 @@ func (q *sqlQuerier) UpdateChatWorkspace(ctx context.Context, arg UpdateChatWork
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
@@ -142,14 +142,16 @@ INSERT INTO chats (
|
||||
parent_chat_id,
|
||||
root_chat_id,
|
||||
last_model_config_id,
|
||||
title
|
||||
title,
|
||||
mode
|
||||
) VALUES (
|
||||
@owner_id::uuid,
|
||||
sqlc.narg('workspace_id')::uuid,
|
||||
sqlc.narg('parent_chat_id')::uuid,
|
||||
sqlc.narg('root_chat_id')::uuid,
|
||||
@last_model_config_id::uuid,
|
||||
@title::text
|
||||
@title::text,
|
||||
sqlc.narg('mode')::chat_mode
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
@@ -89,6 +89,8 @@ type AgentConn interface {
|
||||
Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error)
|
||||
WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error)
|
||||
WatchGit(ctx context.Context, logger slog.Logger, chatID uuid.UUID) (*wsjson.Stream[codersdk.WorkspaceAgentGitServerMessage, codersdk.WorkspaceAgentGitClientMessage], error)
|
||||
ConnectDesktopVNC(ctx context.Context) (net.Conn, error)
|
||||
ExecuteDesktopAction(ctx context.Context, action DesktopAction) (DesktopActionResponse, error)
|
||||
}
|
||||
|
||||
// AgentConn represents a connection to a workspace agent.
|
||||
@@ -530,6 +532,112 @@ func (c *agentConn) WatchGit(ctx context.Context, logger slog.Logger, chatID uui
|
||||
](conn, websocket.MessageText, websocket.MessageText, logger), nil
|
||||
}
|
||||
|
||||
// ConnectDesktopVNC opens a WebSocket to the agent's desktop endpoint and
|
||||
// returns a net.Conn carrying raw RFB (VNC) binary data.
|
||||
func (c *agentConn) ConnectDesktopVNC(ctx context.Context) (net.Conn, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
host := net.JoinHostPort(c.agentAddress().String(), strconv.Itoa(AgentHTTPAPIServerPort))
|
||||
|
||||
dialOpts := &websocket.DialOptions{
|
||||
HTTPClient: c.apiClient(),
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
}
|
||||
c.headersMu.RLock()
|
||||
if len(c.extraHeaders) > 0 {
|
||||
dialOpts.HTTPHeader = c.extraHeaders.Clone()
|
||||
}
|
||||
c.headersMu.RUnlock()
|
||||
|
||||
url := fmt.Sprintf("http://%s/api/v0/desktop/vnc", host)
|
||||
conn, res, err := websocket.Dial(ctx, url, dialOpts)
|
||||
if err != nil {
|
||||
if res == nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
if res != nil && res.Body != nil {
|
||||
defer res.Body.Close()
|
||||
}
|
||||
|
||||
// No read limit — RFB framebuffer updates can be large.
|
||||
conn.SetReadLimit(-1)
|
||||
|
||||
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
|
||||
}
|
||||
|
||||
// DesktopAction is the request body for the desktop action
|
||||
// endpoint.
|
||||
type DesktopAction struct {
|
||||
Action string `json:"action"`
|
||||
Coordinate *[2]int `json:"coordinate,omitempty"`
|
||||
StartCoordinate *[2]int `json:"start_coordinate,omitempty"`
|
||||
Text *string `json:"text,omitempty"`
|
||||
Duration *int `json:"duration,omitempty"`
|
||||
ScrollAmount *int `json:"scroll_amount,omitempty"`
|
||||
ScrollDirection *string `json:"scroll_direction,omitempty"`
|
||||
ScaledWidth *int `json:"scaled_width,omitempty"`
|
||||
ScaledHeight *int `json:"scaled_height,omitempty"`
|
||||
}
|
||||
|
||||
// DesktopActionResponse is the response from the desktop action
|
||||
// endpoint.
|
||||
type DesktopActionResponse struct {
|
||||
Output string `json:"output,omitempty"`
|
||||
ScreenshotData string `json:"screenshot_data,omitempty"`
|
||||
ScreenshotWidth int `json:"screenshot_width,omitempty"`
|
||||
ScreenshotHeight int `json:"screenshot_height,omitempty"`
|
||||
}
|
||||
|
||||
// ExecuteDesktopAction executes a mouse/keyboard/scroll action on the
|
||||
// agent's desktop.
|
||||
func (c *agentConn) ExecuteDesktopAction(ctx context.Context, action DesktopAction) (DesktopActionResponse, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
host := net.JoinHostPort(
|
||||
c.agentAddress().String(),
|
||||
strconv.Itoa(AgentHTTPAPIServerPort),
|
||||
)
|
||||
|
||||
body, err := json.Marshal(action)
|
||||
if err != nil {
|
||||
return DesktopActionResponse{}, xerrors.Errorf("marshal action: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s/api/v0/desktop/action", host)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return DesktopActionResponse{}, xerrors.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.headersMu.RLock()
|
||||
if len(c.extraHeaders) > 0 {
|
||||
for k, v := range c.extraHeaders {
|
||||
req.Header[k] = v
|
||||
}
|
||||
}
|
||||
c.headersMu.RUnlock()
|
||||
|
||||
resp, err := c.apiClient().Do(req)
|
||||
if err != nil {
|
||||
return DesktopActionResponse{}, xerrors.Errorf("action request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return DesktopActionResponse{}, codersdk.ReadBodyAsError(resp)
|
||||
}
|
||||
|
||||
var result DesktopActionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return DesktopActionResponse{}, xerrors.Errorf("decode action response: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteDevcontainer deletes the provided devcontainer.
|
||||
// This is a blocking call and will wait for the container to be deleted.
|
||||
func (c *agentConn) DeleteDevcontainer(ctx context.Context, devcontainerID string) error {
|
||||
|
||||
@@ -83,6 +83,21 @@ func (mr *MockAgentConnMockRecorder) Close() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAgentConn)(nil).Close))
|
||||
}
|
||||
|
||||
// ConnectDesktopVNC mocks base method.
|
||||
func (m *MockAgentConn) ConnectDesktopVNC(ctx context.Context) (net.Conn, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ConnectDesktopVNC", ctx)
|
||||
ret0, _ := ret[0].(net.Conn)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ConnectDesktopVNC indicates an expected call of ConnectDesktopVNC.
|
||||
func (mr *MockAgentConnMockRecorder) ConnectDesktopVNC(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectDesktopVNC", reflect.TypeOf((*MockAgentConn)(nil).ConnectDesktopVNC), ctx)
|
||||
}
|
||||
|
||||
// DebugLogs mocks base method.
|
||||
func (m *MockAgentConn) DebugLogs(ctx context.Context) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -171,6 +186,21 @@ func (mr *MockAgentConnMockRecorder) EditFiles(ctx, edits any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EditFiles", reflect.TypeOf((*MockAgentConn)(nil).EditFiles), ctx, edits)
|
||||
}
|
||||
|
||||
// ExecuteDesktopAction mocks base method.
|
||||
func (m *MockAgentConn) ExecuteDesktopAction(ctx context.Context, action workspacesdk.DesktopAction) (workspacesdk.DesktopActionResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ExecuteDesktopAction", ctx, action)
|
||||
ret0, _ := ret[0].(workspacesdk.DesktopActionResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ExecuteDesktopAction indicates an expected call of ExecuteDesktopAction.
|
||||
func (mr *MockAgentConnMockRecorder) ExecuteDesktopAction(ctx, action any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteDesktopAction", reflect.TypeOf((*MockAgentConn)(nil).ExecuteDesktopAction), ctx, action)
|
||||
}
|
||||
|
||||
// GetPeerDiagnostics mocks base method.
|
||||
func (m *MockAgentConn) GetPeerDiagnostics() tailnet.PeerDiagnostics {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
package workspacesdk
|
||||
|
||||
const (
|
||||
// DesktopDisplayWidth is the default display width in pixels
|
||||
// used for computer-use desktop sessions.
|
||||
DesktopDisplayWidth = 1366
|
||||
// DesktopDisplayHeight is the default display height in pixels
|
||||
// used for computer-use desktop sessions.
|
||||
DesktopDisplayHeight = 768
|
||||
)
|
||||
@@ -1410,6 +1410,10 @@
|
||||
"title": "Builds",
|
||||
"path": "./reference/api/builds.md"
|
||||
},
|
||||
{
|
||||
"title": "Chats",
|
||||
"path": "./reference/api/chats.md"
|
||||
},
|
||||
{
|
||||
"title": "Debug",
|
||||
"path": "./reference/api/debug.md"
|
||||
|
||||
Generated
+45
@@ -0,0 +1,45 @@
|
||||
# Chats
|
||||
|
||||
## Archive a chat
|
||||
|
||||
### Code samples
|
||||
|
||||
```shell
|
||||
# Example request using curl
|
||||
curl -X POST http://coder-server:8080/api/v2/chats/{chat}/archive
|
||||
|
||||
```
|
||||
|
||||
`POST /chats/{chat}/archive`
|
||||
|
||||
### Responses
|
||||
|
||||
| Status | Meaning | Description | Schema |
|
||||
|--------|-----------------------------------------------------------------|-------------|--------|
|
||||
| 204 | [No Content](https://tools.ietf.org/html/rfc7231#section-6.3.5) | No Content | |
|
||||
|
||||
## Watch chat desktop
|
||||
|
||||
### Code samples
|
||||
|
||||
```shell
|
||||
# Example request using curl
|
||||
curl -X GET http://coder-server:8080/api/v2/chats/{chat}/desktop \
|
||||
-H 'Coder-Session-Token: API_KEY'
|
||||
```
|
||||
|
||||
`GET /chats/{chat}/desktop`
|
||||
|
||||
### Parameters
|
||||
|
||||
| Name | In | Type | Required | Description |
|
||||
|--------|------|--------------|----------|-------------|
|
||||
| `chat` | path | string(uuid) | true | Chat ID |
|
||||
|
||||
### Responses
|
||||
|
||||
| Status | Meaning | Description | Schema |
|
||||
|--------|--------------------------------------------------------------------------|---------------------|--------|
|
||||
| 101 | [Switching Protocols](https://tools.ietf.org/html/rfc7231#section-6.2.2) | Switching Protocols | |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
@@ -76,7 +76,7 @@ replace github.com/spf13/afero => github.com/aslilac/afero v0.0.0-20250403163713
|
||||
// 1) Adds thinking effort to Anthropic provider
|
||||
// 2) Downgraded to Go 1.25 due to issue with Windows CI
|
||||
// https://github.com/kylecarbs/fantasy/compare/main...kylecarbs:fantasy:cj/go1.25
|
||||
replace charm.land/fantasy => github.com/kylecarbs/fantasy v0.0.0-20260312195846-2681eb9ddd20
|
||||
replace charm.land/fantasy => github.com/kylecarbs/fantasy v0.0.0-20260313123746-578317bb0e5b
|
||||
|
||||
replace github.com/charmbracelet/anthropic-sdk-go => github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab
|
||||
|
||||
|
||||
@@ -797,8 +797,8 @@ github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab h1:5UMY
|
||||
github.com/kylecarbs/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab/go.mod h1:hqlYqR7uPKOKfnNeicUbZp0Ps0GeYFlKYtwh5HGDCx8=
|
||||
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3 h1:Z9/bo5PSeMutpdiKYNt/TTSfGM1Ll0naj3QzYX9VxTc=
|
||||
github.com/kylecarbs/chroma/v2 v2.0.0-20240401211003-9e036e0631f3/go.mod h1:BUGjjsD+ndS6eX37YgTchSEG+Jg9Jv1GiZs9sqPqztk=
|
||||
github.com/kylecarbs/fantasy v0.0.0-20260312195846-2681eb9ddd20 h1:AEaj4CwdJelIN8GgDZH5xVBP4WFvGqkETGmRO8YBKSA=
|
||||
github.com/kylecarbs/fantasy v0.0.0-20260312195846-2681eb9ddd20/go.mod h1:p6cYJVG8D8AC51MgejAKCMu0myRyQ+vKLuoJQ3biaXo=
|
||||
github.com/kylecarbs/fantasy v0.0.0-20260313123746-578317bb0e5b h1:sC/Qw4tgnzsYQ04i8RU/RIL9UGzLYOSVWKK83CEPoJk=
|
||||
github.com/kylecarbs/fantasy v0.0.0-20260313123746-578317bb0e5b/go.mod h1:p6cYJVG8D8AC51MgejAKCMu0myRyQ+vKLuoJQ3biaXo=
|
||||
github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e h1:OP0ZMFeZkUnOzTFRfpuK3m7Kp4fNvC6qN+exwj7aI4M=
|
||||
github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e/go.mod h1:mQak9GHqbspjC/5iUx3qMlIho8xBS/ppAL/hX5SmPJU=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
|
||||
Reference in New Issue
Block a user