diff --git a/agent/x/agentdesktop/api.go b/agent/x/agentdesktop/api.go index 2cae89bd04..33ff0fb7ca 100644 --- a/agent/x/agentdesktop/api.go +++ b/agent/x/agentdesktop/api.go @@ -5,7 +5,9 @@ import ( "encoding/json" "errors" "io" + "mime/multipart" "net/http" + "net/textproto" "strconv" "sync" "time" @@ -620,6 +622,11 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) { return } defer artifact.Reader.Close() + defer func() { + if artifact.ThumbnailReader != nil { + _ = artifact.ThumbnailReader.Close() + } + }() if artifact.Size > workspacesdk.MaxRecordingSize { a.logger.Warn(ctx, "recording file exceeds maximum size", @@ -633,10 +640,60 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) { return } - rw.Header().Set("Content-Type", "video/mp4") - rw.Header().Set("Content-Length", strconv.FormatInt(artifact.Size, 10)) + // Discard the thumbnail if it exceeds the maximum size. + // The server-side consumer also enforces this per-part, but + // rejecting it here avoids streaming a large thumbnail over + // the wire for nothing. + if artifact.ThumbnailReader != nil && artifact.ThumbnailSize > workspacesdk.MaxThumbnailSize { + a.logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting", + slog.F("recording_id", recordingID), + slog.F("size", artifact.ThumbnailSize), + slog.F("max_size", workspacesdk.MaxThumbnailSize), + ) + _ = artifact.ThumbnailReader.Close() + artifact.ThumbnailReader = nil + artifact.ThumbnailSize = 0 + } + + // The multipart response is best-effort: once WriteHeader(200) is + // called, CreatePart failures produce a truncated response without + // the closing boundary. The server-side consumer handles this + // gracefully, preserving any parts read before the error. + mw := multipart.NewWriter(rw) + defer mw.Close() + rw.Header().Set("Content-Type", "multipart/mixed; boundary="+mw.Boundary()) rw.WriteHeader(http.StatusOK) - _, _ = io.Copy(rw, artifact.Reader) + + // Part 1: video/mp4 (always present). + videoPart, err := mw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"video/mp4"}, + }) + if err != nil { + a.logger.Warn(ctx, "failed to create video multipart part", + slog.F("recording_id", recordingID), + slog.Error(err)) + return + } + if _, err := io.Copy(videoPart, artifact.Reader); err != nil { + a.logger.Warn(ctx, "failed to write video multipart part", + slog.F("recording_id", recordingID), + slog.Error(err)) + return + } + + // Part 2: image/jpeg (present only when thumbnail was extracted). + if artifact.ThumbnailReader != nil { + thumbPart, err := mw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"image/jpeg"}, + }) + if err != nil { + a.logger.Warn(ctx, "failed to create thumbnail multipart part", + slog.F("recording_id", recordingID), + slog.Error(err)) + return + } + _, _ = io.Copy(thumbPart, artifact.ThumbnailReader) + } } // coordFromAction extracts the coordinate pair from a DesktopAction, diff --git a/agent/x/agentdesktop/api_test.go b/agent/x/agentdesktop/api_test.go index a919cf53bd..7663d677bc 100644 --- a/agent/x/agentdesktop/api_test.go +++ b/agent/x/agentdesktop/api_test.go @@ -4,12 +4,17 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" + "io" + "mime" + "mime/multipart" "net" "net/http" "net/http/httptest" "os" "slices" + "strings" "sync" "testing" "time" @@ -59,6 +64,8 @@ type fakeDesktop struct { lastKeyDown string lastKeyUp string + thumbnailData []byte // if set, StopRecording includes a thumbnail + // Recording tracking (guarded by recMu). recMu sync.Mutex recordings map[string]string // ID → file path @@ -187,10 +194,15 @@ func (f *fakeDesktop) StopRecording(_ context.Context, recordingID string) (*age _ = file.Close() return nil, err } - return &agentdesktop.RecordingArtifact{ + artifact := &agentdesktop.RecordingArtifact{ Reader: file, Size: info.Size(), - }, nil + } + if f.thumbnailData != nil { + artifact.ThumbnailReader = io.NopCloser(bytes.NewReader(f.thumbnailData)) + artifact.ThumbnailSize = int64(len(f.thumbnailData)) + } + return artifact, nil } func (f *fakeDesktop) RecordActivity() { @@ -785,8 +797,8 @@ func TestRecordingStartStop(t *testing.T) { req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody)) handler.ServeHTTP(rr, req) require.Equal(t, http.StatusOK, rr.Code) - assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type")) - assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), rr.Body.Bytes()) + parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes()) + assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), parts["video/mp4"]) } func TestRecordingStartFails(t *testing.T) { @@ -847,8 +859,8 @@ func TestRecordingStartIdempotent(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody)) handler.ServeHTTP(rr, req) require.Equal(t, http.StatusOK, rr.Code) - assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type")) - assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), rr.Body.Bytes()) + parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes()) + assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), parts["video/mp4"]) } func TestRecordingStopIdempotent(t *testing.T) { @@ -872,7 +884,7 @@ func TestRecordingStopIdempotent(t *testing.T) { require.Equal(t, http.StatusOK, rr.Code) // Stop twice - both should succeed with identical data. - var bodies [2][]byte + var videoParts [2][]byte for i := range 2 { body, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent}) require.NoError(t, err) @@ -880,10 +892,10 @@ func TestRecordingStopIdempotent(t *testing.T) { request := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body)) handler.ServeHTTP(recorder, request) require.Equal(t, http.StatusOK, recorder.Code) - assert.Equal(t, "video/mp4", recorder.Header().Get("Content-Type")) - bodies[i] = recorder.Body.Bytes() + parts := parseMultipartParts(t, recorder.Header().Get("Content-Type"), recorder.Body.Bytes()) + videoParts[i] = parts["video/mp4"] } - assert.Equal(t, bodies[0], bodies[1]) + assert.Equal(t, videoParts[0], videoParts[1]) } func TestRecordingStopInvalidIDFormat(t *testing.T) { @@ -1004,8 +1016,8 @@ func TestRecordingMultipleSimultaneous(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body)) handler.ServeHTTP(rr, req) require.Equal(t, http.StatusOK, rr.Code) - assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type")) - assert.Equal(t, expected[id], rr.Body.Bytes()) + parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes()) + assert.Equal(t, expected[id], parts["video/mp4"]) } } @@ -1112,8 +1124,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) { req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody)) handler.ServeHTTP(rr, req) require.Equal(t, http.StatusOK, rr.Code) - assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type")) - firstData := rr.Body.Bytes() + firstParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes()) + firstData := firstParts["video/mp4"] require.NotEmpty(t, firstData) // Step 3: Start again with the same ID - should succeed @@ -1128,8 +1140,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) { req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody)) handler.ServeHTTP(rr, req) require.Equal(t, http.StatusOK, rr.Code) - assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type")) - secondData := rr.Body.Bytes() + secondParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes()) + secondData := secondParts["video/mp4"] require.NotEmpty(t, secondData) // The two recordings should have different data because the @@ -1235,3 +1247,166 @@ func TestRecordingStopCorrupted(t *testing.T) { require.NoError(t, err) assert.Equal(t, "Recording is corrupted.", respStop.Message) } + +// parseMultipartParts parses a multipart/mixed response and returns +// a map from Content-Type to body bytes. +func parseMultipartParts(t *testing.T, contentType string, body []byte) map[string][]byte { + t.Helper() + _, params, err := mime.ParseMediaType(contentType) + require.NoError(t, err, "parse Content-Type") + boundary := params["boundary"] + require.NotEmpty(t, boundary, "missing boundary") + mr := multipart.NewReader(bytes.NewReader(body), boundary) + parts := make(map[string][]byte) + for { + part, err := mr.NextPart() + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err, "unexpected multipart parse error") + ct := part.Header.Get("Content-Type") + data, readErr := io.ReadAll(part) + require.NoError(t, readErr) + parts[ct] = data + } + return parts +} + +func TestHandleRecordingStop_WithThumbnail(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil) + // Create a fake JPEG header: 0xFF 0xD8 0xFF followed by 509 zero bytes. + thumbnail := make([]byte, 512) + thumbnail[0] = 0xff + thumbnail[1] = 0xd8 + thumbnail[2] = 0xff + + fake := &fakeDesktop{ + startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080}, + thumbnailData: thumbnail, + } + api := agentdesktop.NewAPI(logger, fake, nil) + defer api.Close() + + handler := api.Routes() + + // Start recording. + recID := uuid.New().String() + startBody, err := json.Marshal(map[string]string{"recording_id": recID}) + require.NoError(t, err) + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody)) + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusOK, rr.Code) + + // Stop recording. + stopBody, err := json.Marshal(map[string]string{"recording_id": recID}) + require.NoError(t, err) + rr = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody)) + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusOK, rr.Code) + + // Verify multipart response. + ct := rr.Header().Get("Content-Type") + assert.True(t, strings.HasPrefix(ct, "multipart/mixed"), + "expected multipart/mixed Content-Type, got %s", ct) + + parts := parseMultipartParts(t, ct, rr.Body.Bytes()) + assert.Len(t, parts, 2, "expected exactly 2 parts (video + thumbnail)") + + // The fake writes "fake-mp4-data--" as the MP4 content. + expectedMP4 := []byte("fake-mp4-data-" + recID + "-1") + assert.Equal(t, expectedMP4, parts["video/mp4"]) + assert.Equal(t, thumbnail, parts["image/jpeg"]) +} + +func TestHandleRecordingStop_NoThumbnail(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil) + fake := &fakeDesktop{ + startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080}, + } + api := agentdesktop.NewAPI(logger, fake, nil) + defer api.Close() + + handler := api.Routes() + + // Start recording. + recID := uuid.New().String() + startBody, err := json.Marshal(map[string]string{"recording_id": recID}) + require.NoError(t, err) + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody)) + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusOK, rr.Code) + + // Stop recording. + stopBody, err := json.Marshal(map[string]string{"recording_id": recID}) + require.NoError(t, err) + rr = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody)) + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusOK, rr.Code) + + // Verify multipart response. + ct := rr.Header().Get("Content-Type") + assert.True(t, strings.HasPrefix(ct, "multipart/mixed"), + "expected multipart/mixed Content-Type, got %s", ct) + + parts := parseMultipartParts(t, ct, rr.Body.Bytes()) + assert.Len(t, parts, 1, "expected exactly 1 part (video only)") + + expectedMP4 := []byte("fake-mp4-data-" + recID + "-1") + assert.Equal(t, expectedMP4, parts["video/mp4"]) +} + +func TestHandleRecordingStop_OversizedThumbnail(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil) + // Create thumbnail data that exceeds MaxThumbnailSize. + oversizedThumb := make([]byte, workspacesdk.MaxThumbnailSize+1) + oversizedThumb[0] = 0xff + oversizedThumb[1] = 0xd8 + oversizedThumb[2] = 0xff + + fake := &fakeDesktop{ + startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080}, + thumbnailData: oversizedThumb, + } + api := agentdesktop.NewAPI(logger, fake, nil) + defer api.Close() + + handler := api.Routes() + + // Start recording. + recID := uuid.New().String() + startBody, err := json.Marshal(map[string]string{"recording_id": recID}) + require.NoError(t, err) + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody)) + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusOK, rr.Code) + + // Stop recording. + stopBody, err := json.Marshal(map[string]string{"recording_id": recID}) + require.NoError(t, err) + rr = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody)) + handler.ServeHTTP(rr, req) + require.Equal(t, http.StatusOK, rr.Code) + + // Verify multipart response contains only the video part. + ct := rr.Header().Get("Content-Type") + assert.True(t, strings.HasPrefix(ct, "multipart/mixed"), + "expected multipart/mixed Content-Type, got %s", ct) + + parts := parseMultipartParts(t, ct, rr.Body.Bytes()) + assert.Len(t, parts, 1, "expected exactly 1 part (video only, oversized thumbnail discarded)") + + expectedMP4 := []byte("fake-mp4-data-" + recID + "-1") + assert.Equal(t, expectedMP4, parts["video/mp4"]) +} diff --git a/agent/x/agentdesktop/desktop.go b/agent/x/agentdesktop/desktop.go index 82760d314d..9f2ac424b3 100644 --- a/agent/x/agentdesktop/desktop.go +++ b/agent/x/agentdesktop/desktop.go @@ -105,6 +105,11 @@ type RecordingArtifact struct { Reader io.ReadCloser // Size is the byte length of the MP4 content. Size int64 + // ThumbnailReader is the JPEG thumbnail. May be nil if no + // thumbnail was produced. Callers must close it when done. + ThumbnailReader io.ReadCloser + // ThumbnailSize is the byte length of the thumbnail. + ThumbnailSize int64 } // DisplayConfig describes a running desktop session. diff --git a/agent/x/agentdesktop/portabledesktop.go b/agent/x/agentdesktop/portabledesktop.go index 47e922c565..99fa422db4 100644 --- a/agent/x/agentdesktop/portabledesktop.go +++ b/agent/x/agentdesktop/portabledesktop.go @@ -56,6 +56,7 @@ type screenshotOutput struct { type recordingProcess struct { cmd *exec.Cmd filePath string + thumbPath string stopped bool killed bool // true when the process was SIGKILLed done chan struct{} // closed when cmd.Wait() returns @@ -383,13 +384,20 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string } } // Completed recording - discard old file, start fresh. - if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) { + if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) { p.logger.Warn(ctx, "failed to remove old recording file", slog.F("recording_id", recordingID), slog.F("file_path", rec.filePath), slog.Error(err), ) } + if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) { + p.logger.Warn(ctx, "failed to remove old thumbnail file", + slog.F("recording_id", recordingID), + slog.F("thumbnail_path", rec.thumbPath), + slog.Error(err), + ) + } delete(p.recordings, recordingID) } @@ -406,6 +414,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string } filePath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".mp4") + thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".thumb.jpg") // Use a background context so the process outlives the HTTP // request that triggered it. @@ -419,6 +428,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string "--idle-speedup", "20", "--idle-min-duration", "0.35", "--idle-noise-tolerance", "-38dB", + "--thumbnail", thumbPath, filePath) if err := cmd.Start(); err != nil { @@ -427,9 +437,10 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string } rec := &recordingProcess{ - cmd: cmd, - filePath: filePath, - done: make(chan struct{}), + cmd: cmd, + filePath: filePath, + thumbPath: thumbPath, + done: make(chan struct{}), } go func() { rec.waitErr = cmd.Wait() @@ -499,10 +510,35 @@ func (p *portableDesktop) StopRecording(ctx context.Context, recordingID string) _ = f.Close() return nil, xerrors.Errorf("stat recording artifact: %w", err) } - return &RecordingArtifact{ + artifact := &RecordingArtifact{ Reader: f, Size: info.Size(), - }, nil + } + // Attach thumbnail if the subprocess wrote one. + thumbFile, err := os.Open(rec.thumbPath) + if err != nil { + p.logger.Warn(ctx, "thumbnail not available", + slog.F("thumbnail_path", rec.thumbPath), + slog.Error(err)) + return artifact, nil + } + thumbInfo, err := thumbFile.Stat() + if err != nil { + _ = thumbFile.Close() + p.logger.Warn(ctx, "thumbnail stat failed", + slog.F("thumbnail_path", rec.thumbPath), + slog.Error(err)) + return artifact, nil + } + if thumbInfo.Size() == 0 { + _ = thumbFile.Close() + p.logger.Warn(ctx, "thumbnail file is empty", + slog.F("thumbnail_path", rec.thumbPath)) + return artifact, nil + } + artifact.ThumbnailReader = thumbFile + artifact.ThumbnailSize = thumbInfo.Size() + return artifact, nil } // lockedStopRecordingProcess stops a single recording via stopOnce. @@ -571,18 +607,33 @@ func (p *portableDesktop) lockedCleanStaleRecordings(ctx context.Context) { } info, err := os.Stat(rec.filePath) if err != nil { - // File already removed or inaccessible; drop entry. + // File already removed or inaccessible; clean up + // any leftover thumbnail and drop the entry. + if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) { + p.logger.Warn(ctx, "failed to remove stale thumbnail file", + slog.F("recording_id", id), + slog.F("thumbnail_path", rec.thumbPath), + slog.Error(err), + ) + } delete(p.recordings, id) continue } if p.clock.Since(info.ModTime()) > time.Hour { - if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) { + if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) { p.logger.Warn(ctx, "failed to remove stale recording file", slog.F("recording_id", id), slog.F("file_path", rec.filePath), slog.Error(err), ) } + if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) { + p.logger.Warn(ctx, "failed to remove stale thumbnail file", + slog.F("recording_id", id), + slog.F("thumbnail_path", rec.thumbPath), + slog.Error(err), + ) + } delete(p.recordings, id) } } @@ -603,13 +654,14 @@ func (p *portableDesktop) Close() error { // Snapshot recording file paths and idle goroutine channels // for cleanup, then clear the map. type recEntry struct { - id string - filePath string - idleDone chan struct{} + id string + filePath string + thumbPath string + idleDone chan struct{} } var allRecs []recEntry for id, rec := range p.recordings { - allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, idleDone: rec.idleDone}) + allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, thumbPath: rec.thumbPath, idleDone: rec.idleDone}) delete(p.recordings, id) } session := p.session @@ -630,13 +682,20 @@ func (p *portableDesktop) Close() error { go func() { defer close(cleanupDone) for _, entry := range allRecs { - if err := os.Remove(entry.filePath); err != nil && !os.IsNotExist(err) { + if err := os.Remove(entry.filePath); err != nil && !errors.Is(err, os.ErrNotExist) { p.logger.Warn(context.Background(), "failed to remove recording file on close", slog.F("recording_id", entry.id), slog.F("file_path", entry.filePath), slog.Error(err), ) } + if err := os.Remove(entry.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) { + p.logger.Warn(context.Background(), "failed to remove thumbnail file on close", + slog.F("recording_id", entry.id), + slog.F("thumbnail_path", entry.thumbPath), + slog.Error(err), + ) + } } if session != nil { session.cancel() diff --git a/agent/x/agentdesktop/portabledesktop_internal_test.go b/agent/x/agentdesktop/portabledesktop_internal_test.go index 64fa9ceb7e..562ce4e23f 100644 --- a/agent/x/agentdesktop/portabledesktop_internal_test.go +++ b/agent/x/agentdesktop/portabledesktop_internal_test.go @@ -2,6 +2,7 @@ package agentdesktop import ( "context" + "io" "os" "os/exec" "path/filepath" @@ -584,6 +585,7 @@ func TestPortableDesktop_StartRecording(t *testing.T) { joined := strings.Join(cmd, " ") if strings.Contains(joined, "record") && strings.Contains(joined, "coder-recording-"+recID) { found = true + assert.Contains(t, joined, "--thumbnail", "record command should include --thumbnail flag") break } } @@ -666,6 +668,66 @@ func TestPortableDesktop_StopRecording_ReturnsArtifact(t *testing.T) { defer artifact.Reader.Close() assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size) + // No thumbnail file exists, so ThumbnailReader should be nil. + assert.Nil(t, artifact.ThumbnailReader, "ThumbnailReader should be nil when no thumbnail file exists") + + require.NoError(t, pd.Close()) +} + +func TestPortableDesktop_StopRecording_WithThumbnail(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil) + rec := &recordedExecer{ + scripts: map[string]string{ + "record": `trap 'exit 0' INT; sleep 120 & wait`, + "up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`, + }, + } + + clk := quartz.NewReal() + pd := &portableDesktop{ + logger: logger, + execer: rec, + scriptBinDir: t.TempDir(), + clock: clk, + binPath: "portabledesktop", + recordings: make(map[string]*recordingProcess), + } + pd.lastDesktopActionAt.Store(clk.Now().UnixNano()) + + ctx := t.Context() + recID := uuid.New().String() + err := pd.StartRecording(ctx, recID) + require.NoError(t, err) + + // Write a dummy MP4 file at the expected path. + filePath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".mp4") + require.NoError(t, os.WriteFile(filePath, []byte("fake-mp4-data"), 0o600)) + t.Cleanup(func() { _ = os.Remove(filePath) }) + + // Write a thumbnail file at the expected path. + thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".thumb.jpg") + thumbContent := []byte("fake-jpeg-thumbnail") + require.NoError(t, os.WriteFile(thumbPath, thumbContent, 0o600)) + t.Cleanup(func() { _ = os.Remove(thumbPath) }) + + artifact, err := pd.StopRecording(ctx, recID) + require.NoError(t, err) + defer artifact.Reader.Close() + + assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size) + + // Thumbnail should be attached. + require.NotNil(t, artifact.ThumbnailReader, "ThumbnailReader should be non-nil when thumbnail file exists") + defer artifact.ThumbnailReader.Close() + assert.Equal(t, int64(len(thumbContent)), artifact.ThumbnailSize) + + // Read and verify thumbnail content. + thumbData, err := io.ReadAll(artifact.ThumbnailReader) + require.NoError(t, err) + assert.Equal(t, thumbContent, thumbData) + require.NoError(t, pd.Close()) } diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index bd9476a6eb..93a6956c85 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -73,9 +73,9 @@ const ( // maxConcurrentRecordingUploads caps the number of recording // stop-and-store operations that can run concurrently. Each - // slot buffers up to MaxRecordingSize (100 MB) in memory, so - // this value implicitly bounds memory to roughly - // maxConcurrentRecordingUploads * 100 MB. + // slot buffers up to MaxRecordingSize + MaxThumbnailSize + // (110 MB) in memory, so this value implicitly bounds memory + // to roughly maxConcurrentRecordingUploads * 110 MB. maxConcurrentRecordingUploads = 25 // staleRecoveryIntervalDivisor determines how often the stale diff --git a/coderd/x/chatd/recording.go b/coderd/x/chatd/recording.go index 2d4dc403ec..eaf5b1c009 100644 --- a/coderd/x/chatd/recording.go +++ b/coderd/x/chatd/recording.go @@ -2,8 +2,11 @@ package chatd import ( "context" + "errors" "fmt" "io" + "mime" + "mime/multipart" "github.com/google/uuid" @@ -13,71 +16,60 @@ import ( "github.com/coder/coder/v2/codersdk/workspacesdk" ) +type recordingResult struct { + recordingFileID string + thumbnailFileID string +} + // stopAndStoreRecording stops the desktop recording, downloads the -// MP4, and stores it in chat_files. Only called when the subagent -// completed successfully. Returns the file ID on success, empty -// string on any failure. All errors are logged but not propagated -// — recording is best-effort. +// multipart response containing the MP4 and optional thumbnail, and +// stores them in chat_files. Only called when the subagent completed +// successfully. Returns file IDs on success, empty fields on any +// failure. All errors are logged but not propagated; recording is +// best-effort. func (p *Server) stopAndStoreRecording( ctx context.Context, conn workspacesdk.AgentConn, recordingID string, ownerID uuid.UUID, workspaceID uuid.NullUUID, -) string { +) recordingResult { + var result recordingResult + select { case p.recordingSem <- struct{}{}: defer func() { <-p.recordingSem }() case <-ctx.Done(): p.logger.Warn(ctx, "context canceled waiting for recording semaphore", slog.Error(ctx.Err())) - return "" + return result } - body, err := conn.StopDesktopRecording(ctx, + resp, err := conn.StopDesktopRecording(ctx, workspacesdk.StopDesktopRecordingRequest{RecordingID: recordingID}) if err != nil { p.logger.Warn(ctx, "failed to stop desktop recording", slog.Error(err)) - return "" + return result } - type readResult struct { - data []byte - err error - } - ch := make(chan readResult, 1) - go func() { - data, err := io.ReadAll(io.LimitReader(body, workspacesdk.MaxRecordingSize+1)) - ch <- readResult{data, err} - }() + defer resp.Body.Close() - var data []byte - select { - case res := <-ch: - body.Close() - data = res.data - if res.err != nil { - p.logger.Warn(ctx, "failed to read recording data", slog.Error(res.err)) - return "" - } - case <-ctx.Done(): - body.Close() - p.logger.Warn(ctx, "context canceled while reading recording data", slog.Error(ctx.Err())) - return "" + _, params, err := mime.ParseMediaType(resp.ContentType) + if err != nil { + p.logger.Warn(ctx, "failed to parse content type from recording response", + slog.F("content_type", resp.ContentType), + slog.Error(err)) + return result } - if len(data) > workspacesdk.MaxRecordingSize { - p.logger.Warn(ctx, "recording data exceeds maximum size, skipping store", - slog.F("size", len(data)), - slog.F("max_size", workspacesdk.MaxRecordingSize)) - return "" - } - if len(data) == 0 { - p.logger.Warn(ctx, "recording data is empty, skipping store") - return "" + boundary := params["boundary"] + if boundary == "" { + p.logger.Warn(ctx, "missing boundary in recording response content type", + slog.F("content_type", resp.ContentType)) + return result } if !workspaceID.Valid { p.logger.Warn(ctx, "chat has no workspace, cannot store recording") - return "" + return result } // The chatd actor is used here because the recording is stored on @@ -87,21 +79,123 @@ func (p *Server) stopAndStoreRecording( if err != nil { p.logger.Warn(ctx, "failed to resolve workspace for recording", slog.Error(err)) - return "" + return result } - //nolint:gocritic // AsChatd is required to insert chat files from the recording pipeline. - row, err := p.db.InsertChatFile(dbauthz.AsChatd(ctx), database.InsertChatFileParams{ - OwnerID: ownerID, - OrganizationID: ws.OrganizationID, - Name: fmt.Sprintf("recording-%s.mp4", p.clock.Now().UTC().Format("2006-01-02T15-04-05Z")), - Mimetype: "video/mp4", - Data: data, - }) - if err != nil { - p.logger.Warn(ctx, "failed to store recording in database", - slog.Error(err)) - return "" + mr := multipart.NewReader(resp.Body, boundary) + // Context cancellation is checked between parts. Within a + // part read, cancellation relies on Go's HTTP transport closing + // the underlying connection when the context is done, which + // interrupts the blocked io.ReadAll. + // First pass: parse all multipart parts into memory. + // The agent sends at most two parts: one video/mp4 and one + // optional image/jpeg thumbnail. Cap the number of parts to + // prevent a malicious or broken agent from forcing the server + // into an unbounded parsing loop. + const maxParts = 2 + var videoData, thumbnailData []byte + for range maxParts { + if ctx.Err() != nil { + p.logger.Warn(ctx, "context canceled while reading recording parts", slog.Error(ctx.Err())) + break + } + + part, err := mr.NextPart() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + p.logger.Warn(ctx, "error reading next multipart part", slog.Error(err)) + break + } + + contentType := part.Header.Get("Content-Type") + + // Select the read limit based on content type so that + // thumbnails (image/jpeg) do not allocate up to + // MaxRecordingSize (100 MB) before the size check rejects + // them. Unknown types use a small default since they are + // discarded below. + maxSize := int64(1 << 20) // 1 MB default for unknown types + switch contentType { + case "video/mp4": + maxSize = int64(workspacesdk.MaxRecordingSize) + case "image/jpeg": + maxSize = int64(workspacesdk.MaxThumbnailSize) + } + + data, err := io.ReadAll(io.LimitReader(part, maxSize+1)) + if err != nil { + p.logger.Warn(ctx, "failed to read recording part data", + slog.F("content_type", contentType), + slog.Error(err)) + continue + } + if int64(len(data)) > maxSize { + p.logger.Warn(ctx, "recording part exceeds maximum size, skipping", + slog.F("content_type", contentType), + slog.F("size", len(data)), + slog.F("max_size", maxSize)) + continue + } + if len(data) == 0 { + p.logger.Warn(ctx, "recording part is empty, skipping", + slog.F("content_type", contentType)) + continue + } + + switch contentType { + case "video/mp4": + if videoData != nil { + p.logger.Warn(ctx, "duplicate video/mp4 part in recording response, skipping") + continue + } + videoData = data + case "image/jpeg": + if thumbnailData != nil { + p.logger.Warn(ctx, "duplicate image/jpeg part in recording response, skipping") + continue + } + thumbnailData = data + default: + p.logger.Debug(ctx, "skipping unknown part content type", + slog.F("content_type", contentType)) + } } - return row.ID.String() + + // Second pass: store the collected data in the database. + if videoData != nil { + //nolint:gocritic // AsChatd is required to insert chat files from the recording pipeline. + row, err := p.db.InsertChatFile(dbauthz.AsChatd(ctx), database.InsertChatFileParams{ + OwnerID: ownerID, + OrganizationID: ws.OrganizationID, + Name: fmt.Sprintf("recording-%s.mp4", p.clock.Now().UTC().Format("2006-01-02T15-04-05Z")), + Mimetype: "video/mp4", + Data: videoData, + }) + if err != nil { + p.logger.Warn(ctx, "failed to store recording in database", + slog.Error(err)) + } else { + result.recordingFileID = row.ID.String() + } + } + if thumbnailData != nil && result.recordingFileID != "" { + //nolint:gocritic // AsChatd is required to insert chat files from the recording pipeline. + row, err := p.db.InsertChatFile(dbauthz.AsChatd(ctx), database.InsertChatFileParams{ + OwnerID: ownerID, + OrganizationID: ws.OrganizationID, + Name: fmt.Sprintf("thumbnail-%s.jpg", p.clock.Now().UTC().Format("2006-01-02T15-04-05Z")), + Mimetype: "image/jpeg", + Data: thumbnailData, + }) + if err != nil { + p.logger.Warn(ctx, "failed to store thumbnail in database", + slog.Error(err)) + } else { + result.thumbnailFileID = row.ID.String() + } + } + + return result } diff --git a/coderd/x/chatd/recording_internal_test.go b/coderd/x/chatd/recording_internal_test.go index b511d18168..67ea48dc2f 100644 --- a/coderd/x/chatd/recording_internal_test.go +++ b/coderd/x/chatd/recording_internal_test.go @@ -5,6 +5,8 @@ import ( "context" "encoding/json" "io" + "mime/multipart" + "net/textproto" "strings" "testing" "time" @@ -34,6 +36,30 @@ func (zeroReader) Read(p []byte) (int, error) { return len(p), nil } +// partSpec describes a single part for buildMultipartResponse. +type partSpec struct { + contentType string + data []byte +} + +// buildMultipartResponse constructs a StopDesktopRecordingResponse +// with the given content type/data pairs encoded as multipart/mixed. +func buildMultipartResponse(parts ...partSpec) workspacesdk.StopDesktopRecordingResponse { + var buf bytes.Buffer + mw := multipart.NewWriter(&buf) + for _, p := range parts { + partWriter, _ := mw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {p.contentType}, + }) + _, _ = partWriter.Write(p.data) + } + _ = mw.Close() + return workspacesdk.StopDesktopRecordingResponse{ + Body: io.NopCloser(bytes.NewReader(buf.Bytes())), + ContentType: "multipart/mixed; boundary=" + mw.Boundary(), + } +} + // createComputerUseParentChild creates a parent chat and a // computer_use child chat bound to the given workspace/agent. // Both chats are inserted directly via DB to avoid triggering @@ -170,8 +196,7 @@ func TestWaitAgentComputerUseRecording(t *testing.T) { mockConn.EXPECT(). StopDesktopRecording(gomock.Any(), gomock.Any()). - Return(io.NopCloser(bytes.NewReader(fakeMp4)), nil). - Times(1) + Return(buildMultipartResponse(partSpec{"video/mp4", fakeMp4}), nil).Times(1) // Invoke wait_agent via the tool closure. resp, err := invokeWaitAgentTool(ctx, t, server, db, parent.ID, child.ID, 5) @@ -198,6 +223,87 @@ func TestWaitAgentComputerUseRecording(t *testing.T) { assert.Equal(t, fakeMp4, chatFile.Data) } +// TestWaitAgentComputerUseRecordingWithThumbnail verifies the +// recording flow when the agent produces both video and thumbnail: +// both file IDs appear in the wait_agent tool response. +func TestWaitAgentComputerUseRecordingWithThumbnail(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + + user, model := seedInternalChatDeps(ctx, t, db) + workspace, _, agent := seedWorkspaceBinding(t, db, user.ID) + + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + parent, child := createComputerUseParentChild( + ctx, t, server, user, model, workspace, agent, + "parent-recording-thumb", "computer-use-child-thumb", + ) + + server.drainInflight() + + server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, agent.ID, agentID) + return mockConn, func() {}, nil + } + + insertAssistantMessage(ctx, t, db, child.ID, model.ID, "I opened Firefox and took a screenshot.") + + setChatStatus(ctx, t, db, child.ID, database.ChatStatusWaiting, "") + + fakeMp4 := []byte("fake-mp4-data-with-thumbnail-test") + fakeThumb := []byte("fake-jpeg-thumbnail-data") + + mockConn.EXPECT(). + StartDesktopRecording(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, req workspacesdk.StartDesktopRecordingRequest) error { + require.NotEmpty(t, req.RecordingID) + return nil + }). + Times(1) + + mockConn.EXPECT(). + StopDesktopRecording(gomock.Any(), gomock.Any()). + Return(buildMultipartResponse( + partSpec{"video/mp4", fakeMp4}, + partSpec{"image/jpeg", fakeThumb}, + ), nil).Times(1) + + resp, err := invokeWaitAgentTool(ctx, t, server, db, parent.ID, child.ID, 5) + require.NoError(t, err) + require.False(t, resp.IsError, "expected successful response, got: %s", resp.Content) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + + // Verify recording_file_id is present and valid. + storedFileID, ok := result["recording_file_id"].(string) + require.True(t, ok, "recording_file_id must be present in response") + require.NotEmpty(t, storedFileID) + fileUUID, err := uuid.Parse(storedFileID) + require.NoError(t, err) + chatFile, err := db.GetChatFileByID(ctx, fileUUID) + require.NoError(t, err) + assert.Equal(t, "video/mp4", chatFile.Mimetype) + assert.Equal(t, fakeMp4, chatFile.Data) + + // Verify thumbnail_file_id is present and valid. + thumbFileID, ok := result["thumbnail_file_id"].(string) + require.True(t, ok, "thumbnail_file_id must be present in response") + require.NotEmpty(t, thumbFileID) + thumbUUID, err := uuid.Parse(thumbFileID) + require.NoError(t, err) + thumbFile, err := db.GetChatFileByID(ctx, thumbUUID) + require.NoError(t, err) + assert.Equal(t, "image/jpeg", thumbFile.Mimetype) + assert.Equal(t, fakeThumb, thumbFile.Data) +} + // TestWaitAgentNonComputerUseNoRecording verifies that when the // child chat is NOT a computer_use chat, no recording is attempted. // StartDesktopRecording must never be called. @@ -342,7 +448,7 @@ func TestWaitAgentRecordingStopFails(t *testing.T) { mockConn.EXPECT(). StopDesktopRecording(gomock.Any(), gomock.Any()). - Return(nil, xerrors.New("disk full")). + Return(workspacesdk.StopDesktopRecordingResponse{}, xerrors.New("disk full")). Times(1) // Invoke wait_agent via the tool closure. @@ -446,10 +552,10 @@ func TestWaitAgentTimeoutLeavesRecordingRunning(t *testing.T) { assert.Contains(t, result.resp.Content, "timed out") } -// TestStopAndStoreRecordingOversized verifies that when the recording -// data exceeds MaxRecordingSize, stopAndStoreRecording returns an -// empty string and does NOT call InsertChatFile. -func TestStopAndStoreRecordingOversized(t *testing.T) { +// TestStopAndStoreRecording_Oversized verifies that when the +// recording data exceeds MaxRecordingSize, stopAndStoreRecording +// returns an empty string and does NOT call InsertChatFile. +func TestStopAndStoreRecording_Oversized(t *testing.T) { t.Parallel() db, ps := dbtestutil.NewDB(t) @@ -463,29 +569,146 @@ func TestStopAndStoreRecordingOversized(t *testing.T) { server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) - // Create a reader that produces MaxRecordingSize+1 bytes without - // allocating the full buffer in memory. - oversizedReader := io.LimitReader( - &zeroReader{}, - int64(workspacesdk.MaxRecordingSize+1), - ) + // Build a streaming multipart response with a video/mp4 part + // that exceeds MaxRecordingSize without allocating the full + // buffer in memory. + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + go func() { + partWriter, _ := mw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"video/mp4"}, + }) + // Stream MaxRecordingSize+1 zero bytes. + _, _ = io.Copy(partWriter, io.LimitReader(&zeroReader{}, int64(workspacesdk.MaxRecordingSize+1))) + _ = mw.Close() + _ = pw.Close() + }() + mockConn.EXPECT(). StopDesktopRecording(gomock.Any(), gomock.Any()). - Return(io.NopCloser(oversizedReader), nil). + Return(workspacesdk.StopDesktopRecordingResponse{ + Body: pr, + ContentType: "multipart/mixed; boundary=" + mw.Boundary(), + }, nil). Times(1) recordingID := uuid.New().String() - storedFileID := server.stopAndStoreRecording( + result := server.stopAndStoreRecording( ctx, mockConn, recordingID, user.ID, uuid.NullUUID{UUID: workspace.ID, Valid: true}, ) - assert.Empty(t, storedFileID, "oversized recording should not be stored") + assert.Empty(t, result.recordingFileID, "oversized recording should not be stored") } -// TestStopAndStoreRecordingEmpty verifies that when the recording +// TestStopAndStoreRecording_OversizedThumbnail verifies that when the +// thumbnail part exceeds MaxThumbnailSize it is skipped while the +// normal-sized video part is still stored. +func TestStopAndStoreRecording_OversizedThumbnail(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + + user, _ := seedInternalChatDeps(ctx, t, db) + workspace, _, _ := seedWorkspaceBinding(t, db, user.ID) + + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + videoData := bytes.Repeat([]byte{0xAA}, 1024) + + // Build a streaming multipart response with a normal video part + // and an oversized thumbnail part. + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + go func() { + vw, _ := mw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"video/mp4"}, + }) + _, _ = vw.Write(videoData) + tw, _ := mw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"image/jpeg"}, + }) + // Stream MaxThumbnailSize+1 zero bytes for the thumbnail. + _, _ = io.Copy(tw, io.LimitReader(&zeroReader{}, int64(workspacesdk.MaxThumbnailSize+1))) + _ = mw.Close() + _ = pw.Close() + }() + + mockConn.EXPECT(). + StopDesktopRecording(gomock.Any(), gomock.Any()). + Return(workspacesdk.StopDesktopRecordingResponse{ + Body: pr, + ContentType: "multipart/mixed; boundary=" + mw.Boundary(), + }, nil). + Times(1) + + recordingID := uuid.New().String() + result := server.stopAndStoreRecording( + ctx, mockConn, recordingID, user.ID, + uuid.NullUUID{UUID: workspace.ID, Valid: true}, + ) + + // Video should be stored. + recUUID, err := uuid.Parse(result.recordingFileID) + require.NoError(t, err, "RecordingFileID should be a valid UUID") + recFile, err := db.GetChatFileByID(ctx, recUUID) + require.NoError(t, err) + assert.Equal(t, "video/mp4", recFile.Mimetype) + assert.Equal(t, videoData, recFile.Data) + + // Thumbnail should be skipped (oversized). + assert.Empty(t, result.thumbnailFileID, "oversized thumbnail should not be stored") +} + +// TestStopAndStoreRecording_DuplicatePartsIgnored verifies that when +// a multipart response contains two video/mp4 parts, only the first +// is stored and the duplicate is skipped. +func TestStopAndStoreRecording_DuplicatePartsIgnored(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + + user, _ := seedInternalChatDeps(ctx, t, db) + workspace, _, _ := seedWorkspaceBinding(t, db, user.ID) + + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + firstVideo := bytes.Repeat([]byte{0x01}, 512) + secondVideo := bytes.Repeat([]byte{0x02}, 512) + + mockConn.EXPECT(). + StopDesktopRecording(gomock.Any(), gomock.Any()). + Return(buildMultipartResponse( + partSpec{"video/mp4", firstVideo}, + partSpec{"video/mp4", secondVideo}, + ), nil). + Times(1) + + recordingID := uuid.New().String() + result := server.stopAndStoreRecording( + ctx, mockConn, recordingID, user.ID, + uuid.NullUUID{UUID: workspace.ID, Valid: true}, + ) + + // Only the first video part should be stored. + recUUID, err := uuid.Parse(result.recordingFileID) + require.NoError(t, err) + recFile, err := db.GetChatFileByID(ctx, recUUID) + require.NoError(t, err) + assert.Equal(t, firstVideo, recFile.Data, "first video part should be stored, not the duplicate") +} + +// TestStopAndStoreRecording_Empty verifies that when the recording // data is empty, stopAndStoreRecording returns an empty string and // does NOT call InsertChatFile. -func TestStopAndStoreRecordingEmpty(t *testing.T) { +func TestStopAndStoreRecording_Empty(t *testing.T) { t.Parallel() db, ps := dbtestutil.NewDB(t) @@ -499,16 +722,265 @@ func TestStopAndStoreRecordingEmpty(t *testing.T) { server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) - // Return empty data. + // Build a multipart response with an empty video/mp4 part. mockConn.EXPECT(). StopDesktopRecording(gomock.Any(), gomock.Any()). - Return(io.NopCloser(bytes.NewReader(nil)), nil). - Times(1) + Return(buildMultipartResponse(partSpec{"video/mp4", nil}), nil).Times(1) recordingID := uuid.New().String() - storedFileID := server.stopAndStoreRecording( + result := server.stopAndStoreRecording( ctx, mockConn, recordingID, user.ID, uuid.NullUUID{UUID: workspace.ID, Valid: true}, ) - assert.Empty(t, storedFileID, "empty recording should not be stored") + assert.Empty(t, result.recordingFileID, "empty recording should not be stored") +} + +// TestStopAndStoreRecording_WithThumbnail verifies that a multipart +// response containing both a video/mp4 part and an image/jpeg part +// results in both files being stored with correct mimetypes. +func TestStopAndStoreRecording_WithThumbnail(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + + user, _ := seedInternalChatDeps(ctx, t, db) + workspace, _, _ := seedWorkspaceBinding(t, db, user.ID) + + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + videoData := bytes.Repeat([]byte{0xDE, 0xAD}, 512) // 1024 bytes + thumbData := bytes.Repeat([]byte{0xFF, 0xD8}, 256) // 512 bytes + + mockConn.EXPECT(). + StopDesktopRecording(gomock.Any(), gomock.Any()). + Return(buildMultipartResponse( + partSpec{"video/mp4", videoData}, + partSpec{"image/jpeg", thumbData}, + ), nil). + Times(1) + + recordingID := uuid.New().String() + result := server.stopAndStoreRecording( + ctx, mockConn, recordingID, user.ID, + uuid.NullUUID{UUID: workspace.ID, Valid: true}, + ) + + // Both file IDs should be valid UUIDs. + recUUID, err := uuid.Parse(result.recordingFileID) + require.NoError(t, err, "RecordingFileID should be a valid UUID") + + thumbUUID, err := uuid.Parse(result.thumbnailFileID) + require.NoError(t, err, "ThumbnailFileID should be a valid UUID") + // Verify the recording file in the database. + recFile, err := db.GetChatFileByID(ctx, recUUID) + require.NoError(t, err) + assert.Equal(t, "video/mp4", recFile.Mimetype) + assert.Equal(t, videoData, recFile.Data) + + // Verify the thumbnail file in the database. + thumbFile, err := db.GetChatFileByID(ctx, thumbUUID) + require.NoError(t, err) + assert.Equal(t, "image/jpeg", thumbFile.Mimetype) + assert.Equal(t, thumbData, thumbFile.Data) +} + +// TestStopAndStoreRecording_VideoOnly verifies that a multipart +// response with only a video/mp4 part stores the recording but +// leaves thumbnailFileID empty. +func TestStopAndStoreRecording_VideoOnly(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + + user, _ := seedInternalChatDeps(ctx, t, db) + workspace, _, _ := seedWorkspaceBinding(t, db, user.ID) + + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + videoData := make([]byte, 1024) + + mockConn.EXPECT(). + StopDesktopRecording(gomock.Any(), gomock.Any()). + Return(buildMultipartResponse(partSpec{"video/mp4", videoData}), nil).Times(1) + + recordingID := uuid.New().String() + result := server.stopAndStoreRecording( + ctx, mockConn, recordingID, user.ID, + uuid.NullUUID{UUID: workspace.ID, Valid: true}, + ) + + // Recording should be stored. + recUUID, err := uuid.Parse(result.recordingFileID) + require.NoError(t, err, "RecordingFileID should be a valid UUID") + + recFile, err := db.GetChatFileByID(ctx, recUUID) + require.NoError(t, err) + assert.Equal(t, "video/mp4", recFile.Mimetype) + assert.Equal(t, videoData, recFile.Data) + + // No thumbnail. + assert.Empty(t, result.thumbnailFileID, "ThumbnailFileID should be empty when no thumbnail part is present") +} + +// TestStopAndStoreRecording_DownloadFailure verifies that when +// StopDesktopRecording returns an error, stopAndStoreRecording +// returns an empty recordingResult without panicking. +func TestStopAndStoreRecording_DownloadFailure(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + + user, _ := seedInternalChatDeps(ctx, t, db) + workspace, _, _ := seedWorkspaceBinding(t, db, user.ID) + + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + mockConn.EXPECT(). + StopDesktopRecording(gomock.Any(), gomock.Any()). + Return(workspacesdk.StopDesktopRecordingResponse{}, xerrors.New("network error")). + Times(1) + + recordingID := uuid.New().String() + result := server.stopAndStoreRecording( + ctx, mockConn, recordingID, user.ID, + uuid.NullUUID{UUID: workspace.ID, Valid: true}, + ) + + assert.Empty(t, result.recordingFileID, "RecordingFileID should be empty on download failure") + assert.Empty(t, result.thumbnailFileID, "ThumbnailFileID should be empty on download failure") +} + +// TestStopAndStoreRecording_UnknownPartIgnored verifies that parts +// with unrecognized content types are silently skipped while known +// parts (video/mp4 and image/jpeg) are still stored. +func TestStopAndStoreRecording_UnknownPartIgnored(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + + user, _ := seedInternalChatDeps(ctx, t, db) + workspace, _, _ := seedWorkspaceBinding(t, db, user.ID) + + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + videoData := make([]byte, 1024) + thumbData := make([]byte, 512) + unknownData := make([]byte, 256) + + mockConn.EXPECT(). + StopDesktopRecording(gomock.Any(), gomock.Any()). + Return(buildMultipartResponse( + partSpec{"video/mp4", videoData}, + partSpec{"image/jpeg", thumbData}, + partSpec{"application/octet-stream", unknownData}, + ), nil).Times(1) + + recordingID := uuid.New().String() + result := server.stopAndStoreRecording( + ctx, mockConn, recordingID, user.ID, + uuid.NullUUID{UUID: workspace.ID, Valid: true}, + ) + + // Both known parts should be stored. + recUUID, err := uuid.Parse(result.recordingFileID) + require.NoError(t, err, "RecordingFileID should be a valid UUID") + + thumbUUID, err := uuid.Parse(result.thumbnailFileID) + require.NoError(t, err, "ThumbnailFileID should be a valid UUID") + + // Verify only 2 files exist (unknown part was skipped). + recFile, err := db.GetChatFileByID(ctx, recUUID) + require.NoError(t, err) + assert.Equal(t, "video/mp4", recFile.Mimetype) + assert.Equal(t, videoData, recFile.Data) + + thumbFile, err := db.GetChatFileByID(ctx, thumbUUID) + require.NoError(t, err) + assert.Equal(t, "image/jpeg", thumbFile.Mimetype) + assert.Equal(t, thumbData, thumbFile.Data) +} + +// TestStopAndStoreRecording_MalformedContentType verifies that a +// response with an unparseable Content-Type returns an empty result. +func TestStopAndStoreRecording_MalformedContentType(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + + user, _ := seedInternalChatDeps(ctx, t, db) + workspace, _, _ := seedWorkspaceBinding(t, db, user.ID) + + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + mockConn.EXPECT(). + StopDesktopRecording(gomock.Any(), gomock.Any()). + Return(workspacesdk.StopDesktopRecordingResponse{ + Body: io.NopCloser(bytes.NewReader(nil)), + ContentType: "", + }, nil). + Times(1) + + recordingID := uuid.New().String() + result := server.stopAndStoreRecording( + ctx, mockConn, recordingID, user.ID, + uuid.NullUUID{UUID: workspace.ID, Valid: true}, + ) + + assert.Empty(t, result.recordingFileID, "RecordingFileID should be empty for malformed content type") + assert.Empty(t, result.thumbnailFileID, "ThumbnailFileID should be empty for malformed content type") +} + +// TestStopAndStoreRecording_MissingBoundary verifies that a +// multipart response without a boundary parameter returns an empty +// result. +func TestStopAndStoreRecording_MissingBoundary(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + + user, _ := seedInternalChatDeps(ctx, t, db) + workspace, _, _ := seedWorkspaceBinding(t, db, user.ID) + + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + mockConn.EXPECT(). + StopDesktopRecording(gomock.Any(), gomock.Any()). + Return(workspacesdk.StopDesktopRecordingResponse{ + Body: io.NopCloser(bytes.NewReader(nil)), + ContentType: "multipart/mixed", + }, nil). + Times(1) + + recordingID := uuid.New().String() + result := server.stopAndStoreRecording( + ctx, mockConn, recordingID, user.ID, + uuid.NullUUID{UUID: workspace.ID, Valid: true}, + ) + + assert.Empty(t, result.recordingFileID, "RecordingFileID should be empty when boundary is missing") + assert.Empty(t, result.thumbnailFileID, "ThumbnailFileID should be empty when boundary is missing") } diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 16753e72d1..7be45aa102 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -233,13 +233,13 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database. } // Only stop and store the recording on success. - var storedFileID string + var recResult recordingResult if recordingID != "" && agentConn != nil { // Use a fresh context for cleanup so a canceled // parent context doesn't prevent recording storage. stopCtx, stopCancel := context.WithTimeout(context.WithoutCancel(ctx), 90*time.Second) defer stopCancel() - storedFileID = p.stopAndStoreRecording(stopCtx, agentConn, + recResult = p.stopAndStoreRecording(stopCtx, agentConn, recordingID, parent.OwnerID, parent.WorkspaceID) } resp := map[string]any{ @@ -248,8 +248,11 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database. "report": report, "status": string(targetChat.Status), } - if storedFileID != "" { - resp["recording_file_id"] = storedFileID + if recResult.recordingFileID != "" { + resp["recording_file_id"] = recResult.recordingFileID + } + if recResult.thumbnailFileID != "" { + resp["thumbnail_file_id"] = recResult.thumbnailFileID } return toolJSONResponse(resp), nil }, diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index 831c94113e..9968a2f27b 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -95,7 +95,7 @@ type AgentConn interface { ConnectDesktopVNC(ctx context.Context) (net.Conn, error) ExecuteDesktopAction(ctx context.Context, action DesktopAction) (DesktopActionResponse, error) StartDesktopRecording(ctx context.Context, req StartDesktopRecordingRequest) error - StopDesktopRecording(ctx context.Context, req StopDesktopRecordingRequest) (io.ReadCloser, error) + StopDesktopRecording(ctx context.Context, req StopDesktopRecordingRequest) (StopDesktopRecordingResponse, error) } // AgentConn represents a connection to a workspace agent. @@ -610,11 +610,25 @@ type StopDesktopRecordingRequest struct { RecordingID string `json:"recording_id"` } +// StopDesktopRecordingResponse wraps the response from stopping a +// desktop recording. Body contains the recording data as a +// multipart/mixed stream. ContentType holds the Content-Type +// header (including boundary) so callers can parse the body. +type StopDesktopRecordingResponse struct { + Body io.ReadCloser + ContentType string +} + // MaxRecordingSize is the largest desktop recording (in bytes) // that will be accepted. Used by both the agent-side stop handler // and the server-side storage pipeline. const MaxRecordingSize = 100 << 20 // 100 MB +// MaxThumbnailSize is the largest thumbnail (in bytes) that will +// be accepted. Applied both agent-side (before streaming) and +// server-side (when parsing multipart parts). +const MaxThumbnailSize = 10 << 20 // 10 MB + // ExecuteDesktopAction executes a mouse/keyboard/scroll action on the // agent's desktop. func (c *agentConn) ExecuteDesktopAction(ctx context.Context, action DesktopAction) (DesktopActionResponse, error) { @@ -681,22 +695,27 @@ func (c *agentConn) StartDesktopRecording(ctx context.Context, req StartDesktopR } // StopDesktopRecording stops a desktop recording session on the -// agent and returns the MP4 data as an io.ReadCloser. The caller -// is responsible for closing the returned reader. Idempotent — -// safe to call on an already-stopped recording. -func (c *agentConn) StopDesktopRecording(ctx context.Context, req StopDesktopRecordingRequest) (io.ReadCloser, error) { +// agent and returns the recording as a StopDesktopRecordingResponse. +// The response body is a multipart/mixed stream containing the +// video (and optionally a JPEG thumbnail). The caller is +// responsible for closing the returned Body. Idempotent — safe +// to call on an already-stopped recording. +func (c *agentConn) StopDesktopRecording(ctx context.Context, req StopDesktopRecordingRequest) (StopDesktopRecordingResponse, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodPost, "/api/v0/desktop/recording/stop", req) if err != nil { - return nil, xerrors.Errorf("stop recording request: %w", err) + return StopDesktopRecordingResponse{}, xerrors.Errorf("stop recording request: %w", err) } if res.StatusCode != http.StatusOK { defer res.Body.Close() - return nil, codersdk.ReadBodyAsError(res) + return StopDesktopRecordingResponse{}, codersdk.ReadBodyAsError(res) } // Caller is responsible for closing res.Body. - return res.Body, nil + return StopDesktopRecordingResponse{ + Body: res.Body, + ContentType: res.Header.Get("Content-Type"), + }, nil } // DeleteDevcontainer deletes the provided devcontainer. diff --git a/codersdk/workspacesdk/agentconnmock/agentconnmock.go b/codersdk/workspacesdk/agentconnmock/agentconnmock.go index 2d90863a21..f895a4c9ad 100644 --- a/codersdk/workspacesdk/agentconnmock/agentconnmock.go +++ b/codersdk/workspacesdk/agentconnmock/agentconnmock.go @@ -580,10 +580,10 @@ func (mr *MockAgentConnMockRecorder) StartProcess(ctx, req any) *gomock.Call { } // StopDesktopRecording mocks base method. -func (m *MockAgentConn) StopDesktopRecording(ctx context.Context, req workspacesdk.StopDesktopRecordingRequest) (io.ReadCloser, error) { +func (m *MockAgentConn) StopDesktopRecording(ctx context.Context, req workspacesdk.StopDesktopRecordingRequest) (workspacesdk.StopDesktopRecordingResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "StopDesktopRecording", ctx, req) - ret0, _ := ret[0].(io.ReadCloser) + ret0, _ := ret[0].(workspacesdk.StopDesktopRecordingResponse) ret1, _ := ret[1].(error) return ret0, ret1 }