feat: agents desktop recording thumbnail backend (#24022)

The agents chat interface displays thumbnails for videos recorded by the
computer use agent. Currently, to display a thumbnail, the frontend
downloads the entire video and shows the first frame. This PR starts
storing a new thumbnail file in the database for every recorded video,
and exposes the file id in the `wait_agent` tool result alongside the
recording file id, so the frontend can fetch just the thumbnail.
This commit is contained in:
Hugo Dutka
2026-04-09 13:47:54 +02:00
committed by GitHub
parent 2c499484b7
commit efb19eb748
11 changed files with 1072 additions and 126 deletions
+60 -3
View File
@@ -5,7 +5,9 @@ import (
"encoding/json"
"errors"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"strconv"
"sync"
"time"
@@ -620,6 +622,11 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
return
}
defer artifact.Reader.Close()
defer func() {
if artifact.ThumbnailReader != nil {
_ = artifact.ThumbnailReader.Close()
}
}()
if artifact.Size > workspacesdk.MaxRecordingSize {
a.logger.Warn(ctx, "recording file exceeds maximum size",
@@ -633,10 +640,60 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
return
}
rw.Header().Set("Content-Type", "video/mp4")
rw.Header().Set("Content-Length", strconv.FormatInt(artifact.Size, 10))
// Discard the thumbnail if it exceeds the maximum size.
// The server-side consumer also enforces this per-part, but
// rejecting it here avoids streaming a large thumbnail over
// the wire for nothing.
if artifact.ThumbnailReader != nil && artifact.ThumbnailSize > workspacesdk.MaxThumbnailSize {
a.logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting",
slog.F("recording_id", recordingID),
slog.F("size", artifact.ThumbnailSize),
slog.F("max_size", workspacesdk.MaxThumbnailSize),
)
_ = artifact.ThumbnailReader.Close()
artifact.ThumbnailReader = nil
artifact.ThumbnailSize = 0
}
// The multipart response is best-effort: once WriteHeader(200) is
// called, CreatePart failures produce a truncated response without
// the closing boundary. The server-side consumer handles this
// gracefully, preserving any parts read before the error.
mw := multipart.NewWriter(rw)
defer mw.Close()
rw.Header().Set("Content-Type", "multipart/mixed; boundary="+mw.Boundary())
rw.WriteHeader(http.StatusOK)
_, _ = io.Copy(rw, artifact.Reader)
// Part 1: video/mp4 (always present).
videoPart, err := mw.CreatePart(textproto.MIMEHeader{
"Content-Type": {"video/mp4"},
})
if err != nil {
a.logger.Warn(ctx, "failed to create video multipart part",
slog.F("recording_id", recordingID),
slog.Error(err))
return
}
if _, err := io.Copy(videoPart, artifact.Reader); err != nil {
a.logger.Warn(ctx, "failed to write video multipart part",
slog.F("recording_id", recordingID),
slog.Error(err))
return
}
// Part 2: image/jpeg (present only when thumbnail was extracted).
if artifact.ThumbnailReader != nil {
thumbPart, err := mw.CreatePart(textproto.MIMEHeader{
"Content-Type": {"image/jpeg"},
})
if err != nil {
a.logger.Warn(ctx, "failed to create thumbnail multipart part",
slog.F("recording_id", recordingID),
slog.Error(err))
return
}
_, _ = io.Copy(thumbPart, artifact.ThumbnailReader)
}
}
// coordFromAction extracts the coordinate pair from a DesktopAction,
+191 -16
View File
@@ -4,12 +4,17 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
"net"
"net/http"
"net/http/httptest"
"os"
"slices"
"strings"
"sync"
"testing"
"time"
@@ -59,6 +64,8 @@ type fakeDesktop struct {
lastKeyDown string
lastKeyUp string
thumbnailData []byte // if set, StopRecording includes a thumbnail
// Recording tracking (guarded by recMu).
recMu sync.Mutex
recordings map[string]string // ID → file path
@@ -187,10 +194,15 @@ func (f *fakeDesktop) StopRecording(_ context.Context, recordingID string) (*age
_ = file.Close()
return nil, err
}
return &agentdesktop.RecordingArtifact{
artifact := &agentdesktop.RecordingArtifact{
Reader: file,
Size: info.Size(),
}, nil
}
if f.thumbnailData != nil {
artifact.ThumbnailReader = io.NopCloser(bytes.NewReader(f.thumbnailData))
artifact.ThumbnailSize = int64(len(f.thumbnailData))
}
return artifact, nil
}
func (f *fakeDesktop) RecordActivity() {
@@ -785,8 +797,8 @@ func TestRecordingStartStop(t *testing.T) {
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), rr.Body.Bytes())
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), parts["video/mp4"])
}
func TestRecordingStartFails(t *testing.T) {
@@ -847,8 +859,8 @@ func TestRecordingStartIdempotent(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), rr.Body.Bytes())
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), parts["video/mp4"])
}
func TestRecordingStopIdempotent(t *testing.T) {
@@ -872,7 +884,7 @@ func TestRecordingStopIdempotent(t *testing.T) {
require.Equal(t, http.StatusOK, rr.Code)
// Stop twice - both should succeed with identical data.
var bodies [2][]byte
var videoParts [2][]byte
for i := range 2 {
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
require.NoError(t, err)
@@ -880,10 +892,10 @@ func TestRecordingStopIdempotent(t *testing.T) {
request := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
handler.ServeHTTP(recorder, request)
require.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "video/mp4", recorder.Header().Get("Content-Type"))
bodies[i] = recorder.Body.Bytes()
parts := parseMultipartParts(t, recorder.Header().Get("Content-Type"), recorder.Body.Bytes())
videoParts[i] = parts["video/mp4"]
}
assert.Equal(t, bodies[0], bodies[1])
assert.Equal(t, videoParts[0], videoParts[1])
}
func TestRecordingStopInvalidIDFormat(t *testing.T) {
@@ -1004,8 +1016,8 @@ func TestRecordingMultipleSimultaneous(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
assert.Equal(t, expected[id], rr.Body.Bytes())
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
assert.Equal(t, expected[id], parts["video/mp4"])
}
}
@@ -1112,8 +1124,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
firstData := rr.Body.Bytes()
firstParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
firstData := firstParts["video/mp4"]
require.NotEmpty(t, firstData)
// Step 3: Start again with the same ID - should succeed
@@ -1128,8 +1140,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
secondData := rr.Body.Bytes()
secondParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
secondData := secondParts["video/mp4"]
require.NotEmpty(t, secondData)
// The two recordings should have different data because the
@@ -1235,3 +1247,166 @@ func TestRecordingStopCorrupted(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, "Recording is corrupted.", respStop.Message)
}
// parseMultipartParts parses a multipart/mixed response body and returns
// a map from each part's Content-Type header to that part's raw bytes.
//
// The boundary is taken from the contentType argument (the response's
// Content-Type header). The helper fails the test immediately if the
// header cannot be parsed, the boundary is missing, a part cannot be
// read, or two parts share the same Content-Type. Without the duplicate
// check a repeated part would silently overwrite the earlier one in the
// map and slip past the callers' Len assertions.
func parseMultipartParts(t *testing.T, contentType string, body []byte) map[string][]byte {
	t.Helper()

	_, params, err := mime.ParseMediaType(contentType)
	require.NoError(t, err, "parse Content-Type")
	boundary := params["boundary"]
	require.NotEmpty(t, boundary, "missing boundary")

	mr := multipart.NewReader(bytes.NewReader(body), boundary)
	parts := make(map[string][]byte)
	for {
		part, err := mr.NextPart()
		if errors.Is(err, io.EOF) {
			// Closing boundary reached: every part has been consumed.
			break
		}
		require.NoError(t, err, "unexpected multipart parse error")

		ct := part.Header.Get("Content-Type")
		_, duplicate := parts[ct]
		require.False(t, duplicate, "duplicate multipart part with Content-Type %q", ct)

		data, readErr := io.ReadAll(part)
		require.NoError(t, readErr)
		parts[ct] = data
	}
	return parts
}
func TestHandleRecordingStop_WithThumbnail(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
// Create a fake JPEG header: 0xFF 0xD8 0xFF followed by 509 zero bytes.
thumbnail := make([]byte, 512)
thumbnail[0] = 0xff
thumbnail[1] = 0xd8
thumbnail[2] = 0xff
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
thumbnailData: thumbnail,
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Start recording.
recID := uuid.New().String()
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Stop recording.
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Verify multipart response.
ct := rr.Header().Get("Content-Type")
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
"expected multipart/mixed Content-Type, got %s", ct)
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
assert.Len(t, parts, 2, "expected exactly 2 parts (video + thumbnail)")
// The fake writes "fake-mp4-data-<id>-<counter>" as the MP4 content.
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
assert.Equal(t, expectedMP4, parts["video/mp4"])
assert.Equal(t, thumbnail, parts["image/jpeg"])
}
func TestHandleRecordingStop_NoThumbnail(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Start recording.
recID := uuid.New().String()
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Stop recording.
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Verify multipart response.
ct := rr.Header().Get("Content-Type")
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
"expected multipart/mixed Content-Type, got %s", ct)
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
assert.Len(t, parts, 1, "expected exactly 1 part (video only)")
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
assert.Equal(t, expectedMP4, parts["video/mp4"])
}
func TestHandleRecordingStop_OversizedThumbnail(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
// Create thumbnail data that exceeds MaxThumbnailSize.
oversizedThumb := make([]byte, workspacesdk.MaxThumbnailSize+1)
oversizedThumb[0] = 0xff
oversizedThumb[1] = 0xd8
oversizedThumb[2] = 0xff
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
thumbnailData: oversizedThumb,
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Start recording.
recID := uuid.New().String()
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Stop recording.
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Verify multipart response contains only the video part.
ct := rr.Header().Get("Content-Type")
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
"expected multipart/mixed Content-Type, got %s", ct)
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
assert.Len(t, parts, 1, "expected exactly 1 part (video only, oversized thumbnail discarded)")
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
assert.Equal(t, expectedMP4, parts["video/mp4"])
}
+5
View File
@@ -105,6 +105,11 @@ type RecordingArtifact struct {
Reader io.ReadCloser
// Size is the byte length of the MP4 content.
Size int64
// ThumbnailReader is the JPEG thumbnail. May be nil if no
// thumbnail was produced. Callers must close it when done.
ThumbnailReader io.ReadCloser
// ThumbnailSize is the byte length of the thumbnail.
ThumbnailSize int64
}
// DisplayConfig describes a running desktop session.
+72 -13
View File
@@ -56,6 +56,7 @@ type screenshotOutput struct {
type recordingProcess struct {
cmd *exec.Cmd
filePath string
thumbPath string
stopped bool
killed bool // true when the process was SIGKILLed
done chan struct{} // closed when cmd.Wait() returns
@@ -383,13 +384,20 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
}
}
// Completed recording - discard old file, start fresh.
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
p.logger.Warn(ctx, "failed to remove old recording file",
slog.F("recording_id", recordingID),
slog.F("file_path", rec.filePath),
slog.Error(err),
)
}
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
p.logger.Warn(ctx, "failed to remove old thumbnail file",
slog.F("recording_id", recordingID),
slog.F("thumbnail_path", rec.thumbPath),
slog.Error(err),
)
}
delete(p.recordings, recordingID)
}
@@ -406,6 +414,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
}
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".mp4")
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".thumb.jpg")
// Use a background context so the process outlives the HTTP
// request that triggered it.
@@ -419,6 +428,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
"--idle-speedup", "20",
"--idle-min-duration", "0.35",
"--idle-noise-tolerance", "-38dB",
"--thumbnail", thumbPath,
filePath)
if err := cmd.Start(); err != nil {
@@ -427,9 +437,10 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
}
rec := &recordingProcess{
cmd: cmd,
filePath: filePath,
done: make(chan struct{}),
cmd: cmd,
filePath: filePath,
thumbPath: thumbPath,
done: make(chan struct{}),
}
go func() {
rec.waitErr = cmd.Wait()
@@ -499,10 +510,35 @@ func (p *portableDesktop) StopRecording(ctx context.Context, recordingID string)
_ = f.Close()
return nil, xerrors.Errorf("stat recording artifact: %w", err)
}
return &RecordingArtifact{
artifact := &RecordingArtifact{
Reader: f,
Size: info.Size(),
}, nil
}
// Attach thumbnail if the subprocess wrote one.
thumbFile, err := os.Open(rec.thumbPath)
if err != nil {
p.logger.Warn(ctx, "thumbnail not available",
slog.F("thumbnail_path", rec.thumbPath),
slog.Error(err))
return artifact, nil
}
thumbInfo, err := thumbFile.Stat()
if err != nil {
_ = thumbFile.Close()
p.logger.Warn(ctx, "thumbnail stat failed",
slog.F("thumbnail_path", rec.thumbPath),
slog.Error(err))
return artifact, nil
}
if thumbInfo.Size() == 0 {
_ = thumbFile.Close()
p.logger.Warn(ctx, "thumbnail file is empty",
slog.F("thumbnail_path", rec.thumbPath))
return artifact, nil
}
artifact.ThumbnailReader = thumbFile
artifact.ThumbnailSize = thumbInfo.Size()
return artifact, nil
}
// lockedStopRecordingProcess stops a single recording via stopOnce.
@@ -571,18 +607,33 @@ func (p *portableDesktop) lockedCleanStaleRecordings(ctx context.Context) {
}
info, err := os.Stat(rec.filePath)
if err != nil {
// File already removed or inaccessible; drop entry.
// File already removed or inaccessible; clean up
// any leftover thumbnail and drop the entry.
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
slog.F("recording_id", id),
slog.F("thumbnail_path", rec.thumbPath),
slog.Error(err),
)
}
delete(p.recordings, id)
continue
}
if p.clock.Since(info.ModTime()) > time.Hour {
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
p.logger.Warn(ctx, "failed to remove stale recording file",
slog.F("recording_id", id),
slog.F("file_path", rec.filePath),
slog.Error(err),
)
}
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
slog.F("recording_id", id),
slog.F("thumbnail_path", rec.thumbPath),
slog.Error(err),
)
}
delete(p.recordings, id)
}
}
@@ -603,13 +654,14 @@ func (p *portableDesktop) Close() error {
// Snapshot recording file paths and idle goroutine channels
// for cleanup, then clear the map.
type recEntry struct {
id string
filePath string
idleDone chan struct{}
id string
filePath string
thumbPath string
idleDone chan struct{}
}
var allRecs []recEntry
for id, rec := range p.recordings {
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, idleDone: rec.idleDone})
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, thumbPath: rec.thumbPath, idleDone: rec.idleDone})
delete(p.recordings, id)
}
session := p.session
@@ -630,13 +682,20 @@ func (p *portableDesktop) Close() error {
go func() {
defer close(cleanupDone)
for _, entry := range allRecs {
if err := os.Remove(entry.filePath); err != nil && !os.IsNotExist(err) {
if err := os.Remove(entry.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
p.logger.Warn(context.Background(), "failed to remove recording file on close",
slog.F("recording_id", entry.id),
slog.F("file_path", entry.filePath),
slog.Error(err),
)
}
if err := os.Remove(entry.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
p.logger.Warn(context.Background(), "failed to remove thumbnail file on close",
slog.F("recording_id", entry.id),
slog.F("thumbnail_path", entry.thumbPath),
slog.Error(err),
)
}
}
if session != nil {
session.cancel()
@@ -2,6 +2,7 @@ package agentdesktop
import (
"context"
"io"
"os"
"os/exec"
"path/filepath"
@@ -584,6 +585,7 @@ func TestPortableDesktop_StartRecording(t *testing.T) {
joined := strings.Join(cmd, " ")
if strings.Contains(joined, "record") && strings.Contains(joined, "coder-recording-"+recID) {
found = true
assert.Contains(t, joined, "--thumbnail", "record command should include --thumbnail flag")
break
}
}
@@ -666,6 +668,66 @@ func TestPortableDesktop_StopRecording_ReturnsArtifact(t *testing.T) {
defer artifact.Reader.Close()
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
// No thumbnail file exists, so ThumbnailReader should be nil.
assert.Nil(t, artifact.ThumbnailReader, "ThumbnailReader should be nil when no thumbnail file exists")
require.NoError(t, pd.Close())
}
// TestPortableDesktop_StopRecording_WithThumbnail checks that
// StopRecording attaches a thumbnail to the returned artifact when a
// file exists at the recording's thumbnail path: ThumbnailReader must be
// non-nil, ThumbnailSize must match the file length, and the reader must
// yield the file's bytes.
func TestPortableDesktop_StopRecording_WithThumbnail(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
// NOTE(review): presumably recordedExecer runs these shell snippets in
// place of the real "record" and "up" subcommands — confirm against
// the recordedExecer definition.
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewReal()
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
// Seed the last-activity timestamp with the current clock time.
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
recID := uuid.New().String()
err := pd.StartRecording(ctx, recID)
require.NoError(t, err)
// Write a dummy MP4 file at the expected path.
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".mp4")
require.NoError(t, os.WriteFile(filePath, []byte("fake-mp4-data"), 0o600))
t.Cleanup(func() { _ = os.Remove(filePath) })
// Write a thumbnail file at the expected path.
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".thumb.jpg")
thumbContent := []byte("fake-jpeg-thumbnail")
require.NoError(t, os.WriteFile(thumbPath, thumbContent, 0o600))
t.Cleanup(func() { _ = os.Remove(thumbPath) })
artifact, err := pd.StopRecording(ctx, recID)
require.NoError(t, err)
defer artifact.Reader.Close()
// The reported video size must match the dummy MP4 written above.
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
// Thumbnail should be attached.
require.NotNil(t, artifact.ThumbnailReader, "ThumbnailReader should be non-nil when thumbnail file exists")
defer artifact.ThumbnailReader.Close()
assert.Equal(t, int64(len(thumbContent)), artifact.ThumbnailSize)
// Read and verify thumbnail content.
thumbData, err := io.ReadAll(artifact.ThumbnailReader)
require.NoError(t, err)
assert.Equal(t, thumbContent, thumbData)
require.NoError(t, pd.Close())
}
+3 -3
View File
@@ -73,9 +73,9 @@ const (
// maxConcurrentRecordingUploads caps the number of recording
// stop-and-store operations that can run concurrently. Each
// slot buffers up to MaxRecordingSize (100 MB) in memory, so
// this value implicitly bounds memory to roughly
// maxConcurrentRecordingUploads * 100 MB.
// slot buffers up to MaxRecordingSize + MaxThumbnailSize
// (110 MB) in memory, so this value implicitly bounds memory
// to roughly maxConcurrentRecordingUploads * 110 MB.
maxConcurrentRecordingUploads = 25
// staleRecoveryIntervalDivisor determines how often the stale
+148 -54
View File
@@ -2,8 +2,11 @@ package chatd
import (
"context"
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
"github.com/google/uuid"
@@ -13,71 +16,60 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// recordingResult carries the chat file IDs produced by
// stopAndStoreRecording. A field is left empty when the corresponding
// file was not stored.
type recordingResult struct {
// recordingFileID is the string form of the ID of the stored
// video/mp4 chat file.
recordingFileID string
// thumbnailFileID is the string form of the ID of the stored
// image/jpeg chat file. It is only set when the recording itself
// was stored.
thumbnailFileID string
}
// stopAndStoreRecording stops the desktop recording, downloads the
// MP4, and stores it in chat_files. Only called when the subagent
// completed successfully. Returns the file ID on success, empty
// string on any failure. All errors are logged but not propagated
// — recording is best-effort.
// multipart response containing the MP4 and optional thumbnail, and
// stores them in chat_files. Only called when the subagent completed
// successfully. Returns file IDs on success, empty fields on any
// failure. All errors are logged but not propagated; recording is
// best-effort.
func (p *Server) stopAndStoreRecording(
ctx context.Context,
conn workspacesdk.AgentConn,
recordingID string,
ownerID uuid.UUID,
workspaceID uuid.NullUUID,
) string {
) recordingResult {
var result recordingResult
select {
case p.recordingSem <- struct{}{}:
defer func() { <-p.recordingSem }()
case <-ctx.Done():
p.logger.Warn(ctx, "context canceled waiting for recording semaphore", slog.Error(ctx.Err()))
return ""
return result
}
body, err := conn.StopDesktopRecording(ctx,
resp, err := conn.StopDesktopRecording(ctx,
workspacesdk.StopDesktopRecordingRequest{RecordingID: recordingID})
if err != nil {
p.logger.Warn(ctx, "failed to stop desktop recording",
slog.Error(err))
return ""
return result
}
type readResult struct {
data []byte
err error
}
ch := make(chan readResult, 1)
go func() {
data, err := io.ReadAll(io.LimitReader(body, workspacesdk.MaxRecordingSize+1))
ch <- readResult{data, err}
}()
defer resp.Body.Close()
var data []byte
select {
case res := <-ch:
body.Close()
data = res.data
if res.err != nil {
p.logger.Warn(ctx, "failed to read recording data", slog.Error(res.err))
return ""
}
case <-ctx.Done():
body.Close()
p.logger.Warn(ctx, "context canceled while reading recording data", slog.Error(ctx.Err()))
return ""
_, params, err := mime.ParseMediaType(resp.ContentType)
if err != nil {
p.logger.Warn(ctx, "failed to parse content type from recording response",
slog.F("content_type", resp.ContentType),
slog.Error(err))
return result
}
if len(data) > workspacesdk.MaxRecordingSize {
p.logger.Warn(ctx, "recording data exceeds maximum size, skipping store",
slog.F("size", len(data)),
slog.F("max_size", workspacesdk.MaxRecordingSize))
return ""
}
if len(data) == 0 {
p.logger.Warn(ctx, "recording data is empty, skipping store")
return ""
boundary := params["boundary"]
if boundary == "" {
p.logger.Warn(ctx, "missing boundary in recording response content type",
slog.F("content_type", resp.ContentType))
return result
}
if !workspaceID.Valid {
p.logger.Warn(ctx, "chat has no workspace, cannot store recording")
return ""
return result
}
// The chatd actor is used here because the recording is stored on
@@ -87,21 +79,123 @@ func (p *Server) stopAndStoreRecording(
if err != nil {
p.logger.Warn(ctx, "failed to resolve workspace for recording",
slog.Error(err))
return ""
return result
}
//nolint:gocritic // AsChatd is required to insert chat files from the recording pipeline.
row, err := p.db.InsertChatFile(dbauthz.AsChatd(ctx), database.InsertChatFileParams{
OwnerID: ownerID,
OrganizationID: ws.OrganizationID,
Name: fmt.Sprintf("recording-%s.mp4", p.clock.Now().UTC().Format("2006-01-02T15-04-05Z")),
Mimetype: "video/mp4",
Data: data,
})
if err != nil {
p.logger.Warn(ctx, "failed to store recording in database",
slog.Error(err))
return ""
mr := multipart.NewReader(resp.Body, boundary)
// Context cancellation is checked between parts. Within a
// part read, cancellation relies on Go's HTTP transport closing
// the underlying connection when the context is done, which
// interrupts the blocked io.ReadAll.
// First pass: parse all multipart parts into memory.
// The agent sends at most two parts: one video/mp4 and one
// optional image/jpeg thumbnail. Cap the number of parts to
// prevent a malicious or broken agent from forcing the server
// into an unbounded parsing loop.
const maxParts = 2
var videoData, thumbnailData []byte
for range maxParts {
if ctx.Err() != nil {
p.logger.Warn(ctx, "context canceled while reading recording parts", slog.Error(ctx.Err()))
break
}
part, err := mr.NextPart()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
p.logger.Warn(ctx, "error reading next multipart part", slog.Error(err))
break
}
contentType := part.Header.Get("Content-Type")
// Select the read limit based on content type so that
// thumbnails (image/jpeg) do not allocate up to
// MaxRecordingSize (100 MB) before the size check rejects
// them. Unknown types use a small default since they are
// discarded below.
maxSize := int64(1 << 20) // 1 MB default for unknown types
switch contentType {
case "video/mp4":
maxSize = int64(workspacesdk.MaxRecordingSize)
case "image/jpeg":
maxSize = int64(workspacesdk.MaxThumbnailSize)
}
data, err := io.ReadAll(io.LimitReader(part, maxSize+1))
if err != nil {
p.logger.Warn(ctx, "failed to read recording part data",
slog.F("content_type", contentType),
slog.Error(err))
continue
}
if int64(len(data)) > maxSize {
p.logger.Warn(ctx, "recording part exceeds maximum size, skipping",
slog.F("content_type", contentType),
slog.F("size", len(data)),
slog.F("max_size", maxSize))
continue
}
if len(data) == 0 {
p.logger.Warn(ctx, "recording part is empty, skipping",
slog.F("content_type", contentType))
continue
}
switch contentType {
case "video/mp4":
if videoData != nil {
p.logger.Warn(ctx, "duplicate video/mp4 part in recording response, skipping")
continue
}
videoData = data
case "image/jpeg":
if thumbnailData != nil {
p.logger.Warn(ctx, "duplicate image/jpeg part in recording response, skipping")
continue
}
thumbnailData = data
default:
p.logger.Debug(ctx, "skipping unknown part content type",
slog.F("content_type", contentType))
}
}
return row.ID.String()
// Second pass: store the collected data in the database.
if videoData != nil {
//nolint:gocritic // AsChatd is required to insert chat files from the recording pipeline.
row, err := p.db.InsertChatFile(dbauthz.AsChatd(ctx), database.InsertChatFileParams{
OwnerID: ownerID,
OrganizationID: ws.OrganizationID,
Name: fmt.Sprintf("recording-%s.mp4", p.clock.Now().UTC().Format("2006-01-02T15-04-05Z")),
Mimetype: "video/mp4",
Data: videoData,
})
if err != nil {
p.logger.Warn(ctx, "failed to store recording in database",
slog.Error(err))
} else {
result.recordingFileID = row.ID.String()
}
}
if thumbnailData != nil && result.recordingFileID != "" {
//nolint:gocritic // AsChatd is required to insert chat files from the recording pipeline.
row, err := p.db.InsertChatFile(dbauthz.AsChatd(ctx), database.InsertChatFileParams{
OwnerID: ownerID,
OrganizationID: ws.OrganizationID,
Name: fmt.Sprintf("thumbnail-%s.jpg", p.clock.Now().UTC().Format("2006-01-02T15-04-05Z")),
Mimetype: "image/jpeg",
Data: thumbnailData,
})
if err != nil {
p.logger.Warn(ctx, "failed to store thumbnail in database",
slog.Error(err))
} else {
result.thumbnailFileID = row.ID.String()
}
}
return result
}
+495 -23
View File
@@ -5,6 +5,8 @@ import (
"context"
"encoding/json"
"io"
"mime/multipart"
"net/textproto"
"strings"
"testing"
"time"
@@ -34,6 +36,30 @@ func (zeroReader) Read(p []byte) (int, error) {
return len(p), nil
}
// partSpec describes a single part for buildMultipartResponse.
type partSpec struct {
// contentType is written as the part's Content-Type header.
contentType string
// data is written as the part's body.
data []byte
}
// buildMultipartResponse encodes the given parts, in order, as a
// multipart/mixed body and wraps it in a StopDesktopRecordingResponse
// whose ContentType carries the writer's boundary. Errors from the
// multipart writer are discarded; the destination is an in-memory
// buffer.
func buildMultipartResponse(parts ...partSpec) workspacesdk.StopDesktopRecordingResponse {
	var encoded bytes.Buffer
	writer := multipart.NewWriter(&encoded)
	for i := range parts {
		header := textproto.MIMEHeader{}
		header["Content-Type"] = []string{parts[i].contentType}
		section, _ := writer.CreatePart(header)
		_, _ = section.Write(parts[i].data)
	}
	// Close writes the trailing boundary; it must run before the
	// buffer is snapshotted below.
	_ = writer.Close()

	var resp workspacesdk.StopDesktopRecordingResponse
	resp.Body = io.NopCloser(bytes.NewReader(encoded.Bytes()))
	resp.ContentType = "multipart/mixed; boundary=" + writer.Boundary()
	return resp
}
// createComputerUseParentChild creates a parent chat and a
// computer_use child chat bound to the given workspace/agent.
// Both chats are inserted directly via DB to avoid triggering
@@ -170,8 +196,7 @@ func TestWaitAgentComputerUseRecording(t *testing.T) {
mockConn.EXPECT().
StopDesktopRecording(gomock.Any(), gomock.Any()).
Return(io.NopCloser(bytes.NewReader(fakeMp4)), nil).
Times(1)
Return(buildMultipartResponse(partSpec{"video/mp4", fakeMp4}), nil).Times(1)
// Invoke wait_agent via the tool closure.
resp, err := invokeWaitAgentTool(ctx, t, server, db, parent.ID, child.ID, 5)
@@ -198,6 +223,87 @@ func TestWaitAgentComputerUseRecording(t *testing.T) {
assert.Equal(t, fakeMp4, chatFile.Data)
}
// TestWaitAgentComputerUseRecordingWithThumbnail drives wait_agent for
// a computer_use child whose agent answers the stop request with a
// video part and a thumbnail part. The tool response must expose a
// file ID for each, and each ID must resolve to the stored bytes
// with the matching mimetype.
func TestWaitAgentComputerUseRecordingWithThumbnail(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := chatdTestContext(t)
	agentMock := agentconnmock.NewMockAgentConn(gomock.NewController(t))

	user, model := seedInternalChatDeps(ctx, t, db)
	workspace, _, agent := seedWorkspaceBinding(t, db, user.ID)
	server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})

	parent, child := createComputerUseParentChild(
		ctx, t, server, user, model, workspace, agent,
		"parent-recording-thumb", "computer-use-child-thumb",
	)
	server.drainInflight()

	// Route every agent connection to the mock, and check that the
	// server asks for the agent bound to the workspace.
	server.agentConnFn = func(_ context.Context, gotAgentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
		require.Equal(t, agent.ID, gotAgentID)
		release := func() {}
		return agentMock, release, nil
	}

	insertAssistantMessage(ctx, t, db, child.ID, model.ID, "I opened Firefox and took a screenshot.")
	setChatStatus(ctx, t, db, child.ID, database.ChatStatusWaiting, "")

	videoBytes := []byte("fake-mp4-data-with-thumbnail-test")
	thumbBytes := []byte("fake-jpeg-thumbnail-data")

	agentMock.EXPECT().
		StartDesktopRecording(gomock.Any(), gomock.Any()).
		DoAndReturn(func(_ context.Context, startReq workspacesdk.StartDesktopRecordingRequest) error {
			require.NotEmpty(t, startReq.RecordingID)
			return nil
		}).
		Times(1)

	stopResp := buildMultipartResponse(
		partSpec{contentType: "video/mp4", data: videoBytes},
		partSpec{contentType: "image/jpeg", data: thumbBytes},
	)
	agentMock.EXPECT().
		StopDesktopRecording(gomock.Any(), gomock.Any()).
		Return(stopResp, nil).
		Times(1)

	resp, err := invokeWaitAgentTool(ctx, t, server, db, parent.ID, child.ID, 5)
	require.NoError(t, err)
	require.False(t, resp.IsError, "expected successful response, got: %s", resp.Content)

	payload := map[string]any{}
	require.NoError(t, json.Unmarshal([]byte(resp.Content), &payload))

	// Recording first, thumbnail second: each response key must hold
	// a UUID that points at a stored file with the expected content.
	expectations := []struct {
		key      string
		mimetype string
		data     []byte
	}{
		{key: "recording_file_id", mimetype: "video/mp4", data: videoBytes},
		{key: "thumbnail_file_id", mimetype: "image/jpeg", data: thumbBytes},
	}
	for _, want := range expectations {
		rawID, ok := payload[want.key].(string)
		require.True(t, ok, want.key+" must be present in response")
		require.NotEmpty(t, rawID)

		fileID, err := uuid.Parse(rawID)
		require.NoError(t, err)

		stored, err := db.GetChatFileByID(ctx, fileID)
		require.NoError(t, err)
		assert.Equal(t, want.mimetype, stored.Mimetype)
		assert.Equal(t, want.data, stored.Data)
	}
}
// TestWaitAgentNonComputerUseNoRecording verifies that when the
// child chat is NOT a computer_use chat, no recording is attempted.
// StartDesktopRecording must never be called.
@@ -342,7 +448,7 @@ func TestWaitAgentRecordingStopFails(t *testing.T) {
mockConn.EXPECT().
StopDesktopRecording(gomock.Any(), gomock.Any()).
Return(nil, xerrors.New("disk full")).
Return(workspacesdk.StopDesktopRecordingResponse{}, xerrors.New("disk full")).
Times(1)
// Invoke wait_agent via the tool closure.
@@ -446,10 +552,10 @@ func TestWaitAgentTimeoutLeavesRecordingRunning(t *testing.T) {
assert.Contains(t, result.resp.Content, "timed out")
}
// TestStopAndStoreRecordingOversized verifies that when the recording
// data exceeds MaxRecordingSize, stopAndStoreRecording returns an
// empty string and does NOT call InsertChatFile.
func TestStopAndStoreRecordingOversized(t *testing.T) {
// TestStopAndStoreRecording_Oversized verifies that when the
// recording data exceeds MaxRecordingSize, stopAndStoreRecording
// returns an empty string and does NOT call InsertChatFile.
func TestStopAndStoreRecording_Oversized(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
@@ -463,29 +569,146 @@ func TestStopAndStoreRecordingOversized(t *testing.T) {
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
// Create a reader that produces MaxRecordingSize+1 bytes without
// allocating the full buffer in memory.
oversizedReader := io.LimitReader(
&zeroReader{},
int64(workspacesdk.MaxRecordingSize+1),
)
// Build a streaming multipart response with a video/mp4 part
// that exceeds MaxRecordingSize without allocating the full
// buffer in memory.
pr, pw := io.Pipe()
mw := multipart.NewWriter(pw)
go func() {
partWriter, _ := mw.CreatePart(textproto.MIMEHeader{
"Content-Type": {"video/mp4"},
})
// Stream MaxRecordingSize+1 zero bytes.
_, _ = io.Copy(partWriter, io.LimitReader(&zeroReader{}, int64(workspacesdk.MaxRecordingSize+1)))
_ = mw.Close()
_ = pw.Close()
}()
mockConn.EXPECT().
StopDesktopRecording(gomock.Any(), gomock.Any()).
Return(io.NopCloser(oversizedReader), nil).
Return(workspacesdk.StopDesktopRecordingResponse{
Body: pr,
ContentType: "multipart/mixed; boundary=" + mw.Boundary(),
}, nil).
Times(1)
recordingID := uuid.New().String()
storedFileID := server.stopAndStoreRecording(
result := server.stopAndStoreRecording(
ctx, mockConn, recordingID, user.ID,
uuid.NullUUID{UUID: workspace.ID, Valid: true},
)
assert.Empty(t, storedFileID, "oversized recording should not be stored")
assert.Empty(t, result.recordingFileID, "oversized recording should not be stored")
}
// TestStopAndStoreRecordingEmpty verifies that when the recording
// TestStopAndStoreRecording_OversizedThumbnail verifies that when the
// thumbnail part exceeds MaxThumbnailSize it is skipped while the
// normal-sized video part is still stored.
func TestStopAndStoreRecording_OversizedThumbnail(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := chatdTestContext(t)
	ctrl := gomock.NewController(t)
	mockConn := agentconnmock.NewMockAgentConn(ctrl)

	user, _ := seedInternalChatDeps(ctx, t, db)
	workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
	server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})

	videoData := bytes.Repeat([]byte{0xAA}, 1024)

	// Build a streaming multipart response with a normal video part
	// and an oversized thumbnail part. The pipe keeps the oversized
	// payload out of memory.
	pr, pw := io.Pipe()
	// Closing the read half makes any pending or future pipe write
	// fail, so the writer goroutine below always terminates even if
	// the consumer returns without closing the body. Closing an
	// already-closed pipe reader is harmless.
	t.Cleanup(func() { _ = pr.Close() })

	mw := multipart.NewWriter(pw)
	go func() {
		// Stop at the first write error instead of carrying on with
		// a nil part writer (which would panic inside this
		// goroutine). CloseWithError(nil) behaves like Close, so a
		// clean run still ends the stream with EOF.
		_ = pw.CloseWithError(func() error {
			vw, err := mw.CreatePart(textproto.MIMEHeader{
				"Content-Type": {"video/mp4"},
			})
			if err != nil {
				return err
			}
			if _, err := vw.Write(videoData); err != nil {
				return err
			}

			tw, err := mw.CreatePart(textproto.MIMEHeader{
				"Content-Type": {"image/jpeg"},
			})
			if err != nil {
				return err
			}
			// Stream MaxThumbnailSize+1 zero bytes for the thumbnail.
			if _, err := io.Copy(tw, io.LimitReader(&zeroReader{}, int64(workspacesdk.MaxThumbnailSize+1))); err != nil {
				return err
			}
			return mw.Close()
		}())
	}()

	mockConn.EXPECT().
		StopDesktopRecording(gomock.Any(), gomock.Any()).
		Return(workspacesdk.StopDesktopRecordingResponse{
			Body:        pr,
			ContentType: "multipart/mixed; boundary=" + mw.Boundary(),
		}, nil).
		Times(1)

	recordingID := uuid.New().String()
	result := server.stopAndStoreRecording(
		ctx, mockConn, recordingID, user.ID,
		uuid.NullUUID{UUID: workspace.ID, Valid: true},
	)

	// Video should be stored.
	recUUID, err := uuid.Parse(result.recordingFileID)
	require.NoError(t, err, "RecordingFileID should be a valid UUID")
	recFile, err := db.GetChatFileByID(ctx, recUUID)
	require.NoError(t, err)
	assert.Equal(t, "video/mp4", recFile.Mimetype)
	assert.Equal(t, videoData, recFile.Data)

	// Thumbnail should be skipped (oversized).
	assert.Empty(t, result.thumbnailFileID, "oversized thumbnail should not be stored")
}
// TestStopAndStoreRecording_DuplicatePartsIgnored feeds
// stopAndStoreRecording a response with two video/mp4 parts carrying
// different payloads and checks that the stored recording holds the
// bytes of the first one.
func TestStopAndStoreRecording_DuplicatePartsIgnored(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := chatdTestContext(t)
	agentMock := agentconnmock.NewMockAgentConn(gomock.NewController(t))

	user, _ := seedInternalChatDeps(ctx, t, db)
	workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
	server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})

	// Distinct fill bytes make it obvious which part was persisted.
	firstVideo := bytes.Repeat([]byte{0x01}, 512)
	secondVideo := bytes.Repeat([]byte{0x02}, 512)

	stopResp := buildMultipartResponse(
		partSpec{contentType: "video/mp4", data: firstVideo},
		partSpec{contentType: "video/mp4", data: secondVideo},
	)
	agentMock.EXPECT().
		StopDesktopRecording(gomock.Any(), gomock.Any()).
		Return(stopResp, nil).
		Times(1)

	workspaceID := uuid.NullUUID{UUID: workspace.ID, Valid: true}
	outcome := server.stopAndStoreRecording(
		ctx, agentMock, uuid.New().String(), user.ID, workspaceID,
	)

	storedID, err := uuid.Parse(outcome.recordingFileID)
	require.NoError(t, err)

	stored, err := db.GetChatFileByID(ctx, storedID)
	require.NoError(t, err)
	assert.Equal(t, firstVideo, stored.Data, "first video part should be stored, not the duplicate")
}
// TestStopAndStoreRecording_Empty verifies that when the recording
// data is empty, stopAndStoreRecording returns an empty string and
// does NOT call InsertChatFile.
func TestStopAndStoreRecordingEmpty(t *testing.T) {
func TestStopAndStoreRecording_Empty(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
@@ -499,16 +722,265 @@ func TestStopAndStoreRecordingEmpty(t *testing.T) {
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
// Return empty data.
// Build a multipart response with an empty video/mp4 part.
mockConn.EXPECT().
StopDesktopRecording(gomock.Any(), gomock.Any()).
Return(io.NopCloser(bytes.NewReader(nil)), nil).
Times(1)
Return(buildMultipartResponse(partSpec{"video/mp4", nil}), nil).Times(1)
recordingID := uuid.New().String()
storedFileID := server.stopAndStoreRecording(
result := server.stopAndStoreRecording(
ctx, mockConn, recordingID, user.ID,
uuid.NullUUID{UUID: workspace.ID, Valid: true},
)
assert.Empty(t, storedFileID, "empty recording should not be stored")
assert.Empty(t, result.recordingFileID, "empty recording should not be stored")
}
// TestStopAndStoreRecording_WithThumbnail checks the two-part happy
// path: a video/mp4 part followed by an image/jpeg part must produce
// two stored files, each retrievable by the returned ID with the
// right mimetype and the exact bytes that were sent.
func TestStopAndStoreRecording_WithThumbnail(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := chatdTestContext(t)
	agentMock := agentconnmock.NewMockAgentConn(gomock.NewController(t))

	user, _ := seedInternalChatDeps(ctx, t, db)
	workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
	server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})

	videoData := bytes.Repeat([]byte{0xDE, 0xAD}, 512) // 1024 bytes
	thumbData := bytes.Repeat([]byte{0xFF, 0xD8}, 256) // 512 bytes

	stopResp := buildMultipartResponse(
		partSpec{contentType: "video/mp4", data: videoData},
		partSpec{contentType: "image/jpeg", data: thumbData},
	)
	agentMock.EXPECT().
		StopDesktopRecording(gomock.Any(), gomock.Any()).
		Return(stopResp, nil).
		Times(1)

	workspaceID := uuid.NullUUID{UUID: workspace.ID, Valid: true}
	outcome := server.stopAndStoreRecording(
		ctx, agentMock, uuid.New().String(), user.ID, workspaceID,
	)

	// Parse both IDs up front so a missing ID fails before any lookup.
	videoID, err := uuid.Parse(outcome.recordingFileID)
	require.NoError(t, err, "RecordingFileID should be a valid UUID")
	thumbID, err := uuid.Parse(outcome.thumbnailFileID)
	require.NoError(t, err, "ThumbnailFileID should be a valid UUID")

	// Stored recording.
	storedVideo, err := db.GetChatFileByID(ctx, videoID)
	require.NoError(t, err)
	assert.Equal(t, "video/mp4", storedVideo.Mimetype)
	assert.Equal(t, videoData, storedVideo.Data)

	// Stored thumbnail.
	storedThumb, err := db.GetChatFileByID(ctx, thumbID)
	require.NoError(t, err)
	assert.Equal(t, "image/jpeg", storedThumb.Mimetype)
	assert.Equal(t, thumbData, storedThumb.Data)
}
// TestStopAndStoreRecording_VideoOnly covers a response that has a
// video/mp4 part and nothing else: the recording is stored and the
// result's thumbnailFileID stays empty.
func TestStopAndStoreRecording_VideoOnly(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := chatdTestContext(t)
	agentMock := agentconnmock.NewMockAgentConn(gomock.NewController(t))

	user, _ := seedInternalChatDeps(ctx, t, db)
	workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
	server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})

	videoData := make([]byte, 1024)

	stopResp := buildMultipartResponse(partSpec{contentType: "video/mp4", data: videoData})
	agentMock.EXPECT().
		StopDesktopRecording(gomock.Any(), gomock.Any()).
		Return(stopResp, nil).
		Times(1)

	workspaceID := uuid.NullUUID{UUID: workspace.ID, Valid: true}
	outcome := server.stopAndStoreRecording(
		ctx, agentMock, uuid.New().String(), user.ID, workspaceID,
	)

	// The video must have been persisted unchanged.
	videoID, err := uuid.Parse(outcome.recordingFileID)
	require.NoError(t, err, "RecordingFileID should be a valid UUID")

	storedVideo, err := db.GetChatFileByID(ctx, videoID)
	require.NoError(t, err)
	assert.Equal(t, "video/mp4", storedVideo.Mimetype)
	assert.Equal(t, videoData, storedVideo.Data)

	// With no image/jpeg part there is nothing to report as thumbnail.
	assert.Empty(t, outcome.thumbnailFileID, "ThumbnailFileID should be empty when no thumbnail part is present")
}
// TestStopAndStoreRecording_DownloadFailure makes the agent's
// StopDesktopRecording call fail and checks that
// stopAndStoreRecording comes back with both file IDs empty.
func TestStopAndStoreRecording_DownloadFailure(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := chatdTestContext(t)
	agentMock := agentconnmock.NewMockAgentConn(gomock.NewController(t))

	user, _ := seedInternalChatDeps(ctx, t, db)
	workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
	server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})

	// A zero-value response paired with an error: there is no body
	// to read at all.
	var noResponse workspacesdk.StopDesktopRecordingResponse
	agentMock.EXPECT().
		StopDesktopRecording(gomock.Any(), gomock.Any()).
		Return(noResponse, xerrors.New("network error")).
		Times(1)

	workspaceID := uuid.NullUUID{UUID: workspace.ID, Valid: true}
	outcome := server.stopAndStoreRecording(
		ctx, agentMock, uuid.New().String(), user.ID, workspaceID,
	)

	assert.Empty(t, outcome.recordingFileID, "RecordingFileID should be empty on download failure")
	assert.Empty(t, outcome.thumbnailFileID, "ThumbnailFileID should be empty on download failure")
}
// TestStopAndStoreRecording_UnknownPartIgnored sends a video/mp4
// part, an image/jpeg part and a trailing application/octet-stream
// part, and checks that the two recognized parts are stored with
// their original bytes despite the extra one.
func TestStopAndStoreRecording_UnknownPartIgnored(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := chatdTestContext(t)
	agentMock := agentconnmock.NewMockAgentConn(gomock.NewController(t))

	user, _ := seedInternalChatDeps(ctx, t, db)
	workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
	server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})

	videoData := make([]byte, 1024)
	thumbData := make([]byte, 512)
	unknownData := make([]byte, 256)

	stopResp := buildMultipartResponse(
		partSpec{contentType: "video/mp4", data: videoData},
		partSpec{contentType: "image/jpeg", data: thumbData},
		partSpec{contentType: "application/octet-stream", data: unknownData},
	)
	agentMock.EXPECT().
		StopDesktopRecording(gomock.Any(), gomock.Any()).
		Return(stopResp, nil).
		Times(1)

	workspaceID := uuid.NullUUID{UUID: workspace.ID, Valid: true}
	outcome := server.stopAndStoreRecording(
		ctx, agentMock, uuid.New().String(), user.ID, workspaceID,
	)

	// Both recognized parts must have yielded a parseable file ID.
	videoID, err := uuid.Parse(outcome.recordingFileID)
	require.NoError(t, err, "RecordingFileID should be a valid UUID")
	thumbID, err := uuid.Parse(outcome.thumbnailFileID)
	require.NoError(t, err, "ThumbnailFileID should be a valid UUID")

	// The result has no slot for the unknown part, so the check here
	// is that the known parts were stored intact alongside it.
	storedVideo, err := db.GetChatFileByID(ctx, videoID)
	require.NoError(t, err)
	assert.Equal(t, "video/mp4", storedVideo.Mimetype)
	assert.Equal(t, videoData, storedVideo.Data)

	storedThumb, err := db.GetChatFileByID(ctx, thumbID)
	require.NoError(t, err)
	assert.Equal(t, "image/jpeg", storedThumb.Mimetype)
	assert.Equal(t, thumbData, storedThumb.Data)
}
// TestStopAndStoreRecording_MalformedContentType hands
// stopAndStoreRecording a response whose ContentType is the empty
// string (with an empty body) and checks that neither a recording
// nor a thumbnail ID is produced.
func TestStopAndStoreRecording_MalformedContentType(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := chatdTestContext(t)
	agentMock := agentconnmock.NewMockAgentConn(gomock.NewController(t))

	user, _ := seedInternalChatDeps(ctx, t, db)
	workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
	server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})

	// The call itself succeeds; only the media type is unusable.
	stopResp := workspacesdk.StopDesktopRecordingResponse{
		Body:        io.NopCloser(bytes.NewReader(nil)),
		ContentType: "",
	}
	agentMock.EXPECT().
		StopDesktopRecording(gomock.Any(), gomock.Any()).
		Return(stopResp, nil).
		Times(1)

	workspaceID := uuid.NullUUID{UUID: workspace.ID, Valid: true}
	outcome := server.stopAndStoreRecording(
		ctx, agentMock, uuid.New().String(), user.ID, workspaceID,
	)

	assert.Empty(t, outcome.recordingFileID, "RecordingFileID should be empty for malformed content type")
	assert.Empty(t, outcome.thumbnailFileID, "ThumbnailFileID should be empty for malformed content type")
}
// TestStopAndStoreRecording_MissingBoundary hands
// stopAndStoreRecording a response typed "multipart/mixed" with no
// boundary parameter and checks that neither a recording nor a
// thumbnail ID is produced.
func TestStopAndStoreRecording_MissingBoundary(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := chatdTestContext(t)
	agentMock := agentconnmock.NewMockAgentConn(gomock.NewController(t))

	user, _ := seedInternalChatDeps(ctx, t, db)
	workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
	server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})

	// Right media type, but nothing to split the body on.
	stopResp := workspacesdk.StopDesktopRecordingResponse{
		Body:        io.NopCloser(bytes.NewReader(nil)),
		ContentType: "multipart/mixed",
	}
	agentMock.EXPECT().
		StopDesktopRecording(gomock.Any(), gomock.Any()).
		Return(stopResp, nil).
		Times(1)

	workspaceID := uuid.NullUUID{UUID: workspace.ID, Valid: true}
	outcome := server.stopAndStoreRecording(
		ctx, agentMock, uuid.New().String(), user.ID, workspaceID,
	)

	assert.Empty(t, outcome.recordingFileID, "RecordingFileID should be empty when boundary is missing")
	assert.Empty(t, outcome.thumbnailFileID, "ThumbnailFileID should be empty when boundary is missing")
}
+7 -4
View File
@@ -233,13 +233,13 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
}
// Only stop and store the recording on success.
var storedFileID string
var recResult recordingResult
if recordingID != "" && agentConn != nil {
// Use a fresh context for cleanup so a canceled
// parent context doesn't prevent recording storage.
stopCtx, stopCancel := context.WithTimeout(context.WithoutCancel(ctx), 90*time.Second)
defer stopCancel()
storedFileID = p.stopAndStoreRecording(stopCtx, agentConn,
recResult = p.stopAndStoreRecording(stopCtx, agentConn,
recordingID, parent.OwnerID, parent.WorkspaceID)
}
resp := map[string]any{
@@ -248,8 +248,11 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
"report": report,
"status": string(targetChat.Status),
}
if storedFileID != "" {
resp["recording_file_id"] = storedFileID
if recResult.recordingFileID != "" {
resp["recording_file_id"] = recResult.recordingFileID
}
if recResult.thumbnailFileID != "" {
resp["thumbnail_file_id"] = recResult.thumbnailFileID
}
return toolJSONResponse(resp), nil
},
+27 -8
View File
@@ -95,7 +95,7 @@ type AgentConn interface {
ConnectDesktopVNC(ctx context.Context) (net.Conn, error)
ExecuteDesktopAction(ctx context.Context, action DesktopAction) (DesktopActionResponse, error)
StartDesktopRecording(ctx context.Context, req StartDesktopRecordingRequest) error
StopDesktopRecording(ctx context.Context, req StopDesktopRecordingRequest) (io.ReadCloser, error)
StopDesktopRecording(ctx context.Context, req StopDesktopRecordingRequest) (StopDesktopRecordingResponse, error)
}
// AgentConn represents a connection to a workspace agent.
@@ -610,11 +610,25 @@ type StopDesktopRecordingRequest struct {
RecordingID string `json:"recording_id"`
}
// StopDesktopRecordingResponse wraps the response from stopping a
// desktop recording. Body contains the recording data as a
// multipart/mixed stream. ContentType holds the Content-Type
// header (including boundary) so callers can parse the body.
type StopDesktopRecordingResponse struct {
	// Body is the raw multipart stream. The caller is responsible
	// for closing it.
	Body io.ReadCloser
	// ContentType is the Content-Type header value that accompanied
	// Body; the multipart boundary needed to split Body is carried
	// here.
	ContentType string
}
// MaxRecordingSize is the largest desktop recording (in bytes)
// that will be accepted. Used by both the agent-side stop handler
// and the server-side storage pipeline.
const MaxRecordingSize = 100 << 20 // 100 MiB
// MaxThumbnailSize is the largest thumbnail (in bytes) that will
// be accepted. Applied both agent-side (before streaming) and
// server-side (when parsing multipart parts).
const MaxThumbnailSize = 10 << 20 // 10 MiB
// ExecuteDesktopAction executes a mouse/keyboard/scroll action on the
// agent's desktop.
func (c *agentConn) ExecuteDesktopAction(ctx context.Context, action DesktopAction) (DesktopActionResponse, error) {
@@ -681,22 +695,27 @@ func (c *agentConn) StartDesktopRecording(ctx context.Context, req StartDesktopR
}
// StopDesktopRecording stops a desktop recording session on the
// agent and returns the MP4 data as an io.ReadCloser. The caller
// is responsible for closing the returned reader. Idempotent —
// safe to call on an already-stopped recording.
func (c *agentConn) StopDesktopRecording(ctx context.Context, req StopDesktopRecordingRequest) (io.ReadCloser, error) {
// agent and returns the recording as a StopDesktopRecordingResponse.
// The response body is a multipart/mixed stream containing the
// video (and optionally a JPEG thumbnail). The caller is
// responsible for closing the returned Body. Idempotent — safe
// to call on an already-stopped recording.
func (c *agentConn) StopDesktopRecording(ctx context.Context, req StopDesktopRecordingRequest) (StopDesktopRecordingResponse, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
res, err := c.apiRequest(ctx, http.MethodPost, "/api/v0/desktop/recording/stop", req)
if err != nil {
return nil, xerrors.Errorf("stop recording request: %w", err)
return StopDesktopRecordingResponse{}, xerrors.Errorf("stop recording request: %w", err)
}
if res.StatusCode != http.StatusOK {
defer res.Body.Close()
return nil, codersdk.ReadBodyAsError(res)
return StopDesktopRecordingResponse{}, codersdk.ReadBodyAsError(res)
}
// Caller is responsible for closing res.Body.
return res.Body, nil
return StopDesktopRecordingResponse{
Body: res.Body,
ContentType: res.Header.Get("Content-Type"),
}, nil
}
// DeleteDevcontainer deletes the provided devcontainer.
@@ -580,10 +580,10 @@ func (mr *MockAgentConnMockRecorder) StartProcess(ctx, req any) *gomock.Call {
}
// StopDesktopRecording mocks base method.
func (m *MockAgentConn) StopDesktopRecording(ctx context.Context, req workspacesdk.StopDesktopRecordingRequest) (io.ReadCloser, error) {
func (m *MockAgentConn) StopDesktopRecording(ctx context.Context, req workspacesdk.StopDesktopRecordingRequest) (workspacesdk.StopDesktopRecordingResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StopDesktopRecording", ctx, req)
ret0, _ := ret[0].(io.ReadCloser)
ret0, _ := ret[0].(workspacesdk.StopDesktopRecordingResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}