From a104d608a35ddc474175b28e2c5e68b544d57c0a Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 6 Mar 2026 21:05:26 +0200 Subject: [PATCH] feat: add file/image attachment support to chat input (#22604) This change adds support for image attachments to chat via add button and clipboard paste. Files are stored in a new `chat_files` table and referenced by ID in message content. File data is resolved from storage at LLM dispatch time, keeping the message content column small. Upload validates MIME types via content type or content sniffing against an allowlist (png, jpeg, gif, webp). The retrieval endpoint serves files with immutable caching headers. On the frontend, uploads start eagerly on attach with a background fetch to pre-warm the browser HTTP cache so the timeline renders instantly after send. --- coderd/apidoc/docs.go | 131 +++++ coderd/apidoc/swagger.json | 123 +++++ coderd/chatd/chatd.go | 46 +- coderd/chatd/chatprompt/chatprompt.go | 187 ++++++- coderd/chatd/chatprompt/chatprompt_test.go | 141 ++++- coderd/chats.go | 298 +++++++++- coderd/chats_test.go | 518 ++++++++++++++++++ coderd/coderd.go | 13 + coderd/database/db2sdk/db2sdk.go | 20 +- coderd/database/dbauthz/dbauthz.go | 29 + coderd/database/dbauthz/dbauthz_test.go | 16 + coderd/database/dbmetrics/querymetrics.go | 24 + coderd/database/dbmock/dbmock.go | 45 ++ coderd/database/dump.sql | 23 + coderd/database/foreign_key_constraint.go | 2 + .../migrations/000429_chat_files.down.sql | 2 + .../migrations/000429_chat_files.up.sql | 12 + .../fixtures/000429_chat_files.up.sql | 13 + coderd/database/modelmethods.go | 4 + coderd/database/models.go | 10 + coderd/database/querier.go | 3 + coderd/database/queries.sql.go | 97 ++++ coderd/database/queries/chatfiles.sql | 10 + coderd/database/unique_constraint.go | 1 + codersdk/chats.go | 49 +- docs/ai-coder/agents/index.md | 12 + docs/reference/api/chats.md | 78 +++ docs/reference/api/schemas.md | 14 + site/src/api/api.ts | 17 + site/src/api/typesGenerated.ts | 14 +- .../ChatMessageInput/ChatMessageInput.tsx | 29 +- .../AgentsPage/AgentChatInput.stories.tsx | 79 ++- site/src/pages/AgentsPage/AgentChatInput.tsx | 245 ++++++++- site/src/pages/AgentsPage/AgentDetail.test.ts | 2 +- site/src/pages/AgentsPage/AgentDetail.tsx | 195 +++++-- .../AgentDetail/ConversationTimeline.tsx | 248 ++++++--- .../AgentDetail/messageParsing.test.ts | 34 ++ .../AgentsPage/AgentDetail/messageParsing.ts | 17 + .../AgentsPage/AgentDetail/streamState.ts | 20 + .../src/pages/AgentsPage/AgentDetail/types.ts | 6 + .../pages/AgentsPage/AgentsPage.stories.tsx | 2 + site/src/pages/AgentsPage/AgentsPage.tsx | 66 ++- .../AgentsPage/AttachmentPreview.stories.tsx | 143 +++++ .../AgentsPage/ImageLightbox.stories.tsx | 30 + site/src/pages/AgentsPage/ImageLightbox.tsx | 25 + .../pages/AgentsPage/useFileAttachments.ts | 160 ++++++ 46 files changed, 3090 insertions(+), 163 deletions(-) create mode 100644 coderd/database/migrations/000429_chat_files.down.sql create mode 100644 coderd/database/migrations/000429_chat_files.up.sql create mode 100644 coderd/database/migrations/testdata/fixtures/000429_chat_files.up.sql create mode 100644 coderd/database/queries/chatfiles.sql create mode 100644 site/src/pages/AgentsPage/AttachmentPreview.stories.tsx create mode 100644 site/src/pages/AgentsPage/ImageLightbox.stories.tsx create mode 100644 site/src/pages/AgentsPage/ImageLightbox.tsx create mode 100644 site/src/pages/AgentsPage/useFileAttachments.ts diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 53b65527f0..09f8a5e26d 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -481,6 +481,128 @@ const docTemplate = `{ } } }, + "/chats/files": { + "post": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "consumes": [ + "application/octet-stream" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Chats" + ], + "summary": "Upload a chat file", + "operationId": "upload-chat-file", + "parameters": [ + { + "type": "string", + "description": "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)", + "name": "Content-Type", + "in": "header", + "required": true + }, + { + "type": "string", + "format": "uuid", + "description": "Organization ID", + "name": "organization", + "in": "query", + "required": true + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/codersdk.UploadChatFileResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "413": { + "description": "Request Entity Too Large", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + } + } + } + }, + "/chats/files/{file}": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "tags": [ + "Chats" + ], + "summary": "Get a chat file", + "operationId": "get-chat-file", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "File ID", + "name": "file", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + } + } + } + }, "/chats/{chat}/archive": { "post": { "tags": [ @@ -20334,6 +20456,15 @@ const docTemplate = `{ } } }, + "codersdk.UploadChatFileResponse": { + "type": "object", + "properties": { + "id": { + "type": "string", + "format": "uuid" + } + } + }, "codersdk.UploadResponse": { "type": "object", "properties": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index b24b130c48..a234b00b51 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -410,6 +410,120 @@ } } }, + "/chats/files": { + "post": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "consumes": ["application/octet-stream"], + "produces": ["application/json"], + "tags": ["Chats"], + "summary": "Upload a chat file", + "operationId": "upload-chat-file", + "parameters": [ + { + "type": "string", + "description": "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)", + "name": "Content-Type", + "in": "header", + "required": true + }, + { + "type": "string", + "format": "uuid", + "description": "Organization ID", + "name": "organization", + "in": "query", + "required": true + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/codersdk.UploadChatFileResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "413": { + "description": "Request Entity Too Large", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + } + } + } + }, + "/chats/files/{file}": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "tags": ["Chats"], + "summary": "Get a chat file", + "operationId": "get-chat-file", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "File ID", + "name": "file", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + } + } + } + }, "/chats/{chat}/archive": { "post": { "tags": ["Chats"], @@ -18650,6 +18764,15 @@ } } }, + "codersdk.UploadChatFileResponse": { + "type": "object", + "properties": { + "id": { + "type": "string", + "format": "uuid" + } + } + }, "codersdk.UploadResponse": { "type": "object", "properties": { diff --git a/coderd/chatd/chatd.go b/coderd/chatd/chatd.go index b03dbb48c2..011fbc71e6 100644 --- a/coderd/chatd/chatd.go +++ b/coderd/chatd/chatd.go @@ -165,6 +165,9 @@ type CreateOptions struct { ModelConfigID uuid.UUID SystemPrompt string InitialUserContent []fantasy.Content + // ContentFileIDs maps content block indices to their chat_files IDs + // so the file_id can be preserved in the stored message JSON. + ContentFileIDs map[int]uuid.UUID } // SendMessageBusyBehavior controls what happens when a chat is already active. @@ -180,10 +183,11 @@ const ( // SendMessageOptions controls user message insertion with busy-state behavior. type SendMessageOptions struct { - ChatID uuid.UUID - Content []fantasy.Content - ModelConfigID *uuid.UUID - BusyBehavior SendMessageBusyBehavior + ChatID uuid.UUID + Content []fantasy.Content + ContentFileIDs map[int]uuid.UUID + ModelConfigID *uuid.UUID + BusyBehavior SendMessageBusyBehavior } // SendMessageResult contains the outcome of user message processing. @@ -199,6 +203,7 @@ type EditMessageOptions struct { ChatID uuid.UUID EditedMessageID int64 Content []fantasy.Content + ContentFileIDs map[int]uuid.UUID } // EditMessageResult contains the updated user message and chat status. @@ -278,7 +283,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C } } - userContent, err := chatprompt.MarshalContent(opts.InitialUserContent) + userContent, err := chatprompt.MarshalContent(opts.InitialUserContent, opts.ContentFileIDs) if err != nil { return xerrors.Errorf("marshal initial user content: %w", err) } @@ -345,7 +350,7 @@ func (p *Server) SendMessage( return SendMessageResult{}, xerrors.Errorf("invalid busy behavior %q", opts.BusyBehavior) } - content, err := chatprompt.MarshalContent(opts.Content) + content, err := chatprompt.MarshalContent(opts.Content, opts.ContentFileIDs) if err != nil { return SendMessageResult{}, xerrors.Errorf("marshal message content: %w", err) } @@ -448,7 +453,7 @@ func (p *Server) EditMessage( return EditMessageResult{}, xerrors.New("content is required") } - content, err := chatprompt.MarshalContent(opts.Content) + content, err := chatprompt.MarshalContent(opts.Content, opts.ContentFileIDs) if err != nil { return EditMessageResult{}, xerrors.Errorf("marshal message content: %w", err) } @@ -1607,6 +1612,25 @@ func (p *Server) subscribeChatControl( return controlCancel } +// chatFileResolver returns a FileResolver that fetches chat file +// content from the database by ID. +func (p *Server) chatFileResolver() chatprompt.FileResolver { + return func(ctx context.Context, ids []uuid.UUID) (map[uuid.UUID]chatprompt.FileData, error) { + files, err := p.db.GetChatFilesByIDs(ctx, ids) + if err != nil { + return nil, err + } + result := make(map[uuid.UUID]chatprompt.FileData, len(files)) + for _, f := range files { + result[f.ID] = chatprompt.FileData{ + Data: f.Data, + MediaType: f.Mimetype, + } + } + return result, nil + } +} + func (p *Server) processChat(ctx context.Context, chat database.Chat) { logger := p.logger.With(slog.F("chat_id", chat.ID)) logger.Info(ctx, "processing chat request") @@ -1922,7 +1946,7 @@ func (p *Server) runChat( p.maybeGenerateChatTitle(context.WithoutCancel(ctx), chat, messages, model, providerKeys, logger) }() - prompt, err := chatprompt.ConvertMessages(messages) + prompt, err := chatprompt.ConvertMessagesWithFiles(ctx, messages, p.chatFileResolver()) if err != nil { return xerrors.Errorf("build chat prompt: %w", err) } @@ -2064,7 +2088,7 @@ func (p *Server) runChat( } if len(assistantBlocks) > 0 { - assistantContent, err := chatprompt.MarshalContent(assistantBlocks) + assistantContent, err := chatprompt.MarshalContent(assistantBlocks, nil) if err != nil { return err } @@ -2270,7 +2294,7 @@ func (p *Server) runChat( if err != nil { return nil, xerrors.Errorf("reload chat messages: %w", err) } - reloadedPrompt, err := chatprompt.ConvertMessages(reloadedMsgs) + reloadedPrompt, err := chatprompt.ConvertMessagesWithFiles(reloadCtx, reloadedMsgs, p.chatFileResolver()) if err != nil { return nil, xerrors.Errorf("convert reloaded messages: %w", err) } @@ -2363,7 +2387,7 @@ func (p *Server) persistChatContextSummary( ToolName: "chat_summarized", Input: string(args), }, - }) + }, nil) if err != nil { return xerrors.Errorf("encode summary tool call: %w", err) } diff --git a/coderd/chatd/chatprompt/chatprompt.go b/coderd/chatd/chatprompt/chatprompt.go index ff6ad99368..a7d07100cc 100644 --- a/coderd/chatd/chatprompt/chatprompt.go +++ b/coderd/chatd/chatprompt/chatprompt.go @@ -1,12 +1,14 @@ package chatprompt import ( + "context" "encoding/json" "regexp" "strings" "charm.land/fantasy" fantasyopenai "charm.land/fantasy/providers/openai" + "github.com/google/uuid" "github.com/sqlc-dev/pqtype" "golang.org/x/xerrors" @@ -16,12 +18,156 @@ import ( var toolCallIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`) +// FileData holds resolved file content for LLM prompt building. +type FileData struct { + Data []byte + MediaType string +} + +// FileResolver fetches file content by ID for LLM prompt building. +type FileResolver func(ctx context.Context, ids []uuid.UUID) (map[uuid.UUID]FileData, error) + +// ExtractFileID parses the file_id from a serialized file content +// block envelope. Returns uuid.Nil and an error when the block is +// not a file-type block or has no file_id. +func ExtractFileID(raw json.RawMessage) (uuid.UUID, error) { + var envelope struct { + Type string `json:"type"` + Data struct { + FileID string `json:"file_id"` + } `json:"data"` + } + if err := json.Unmarshal(raw, &envelope); err != nil { + return uuid.Nil, xerrors.Errorf("unmarshal content block: %w", err) + } + if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeFile)) { + return uuid.Nil, xerrors.Errorf("not a file content block: %s", envelope.Type) + } + if envelope.Data.FileID == "" { + return uuid.Nil, xerrors.New("no file_id") + } + return uuid.Parse(envelope.Data.FileID) +} + +// extractFileIDs scans raw message content for file_id references. +// Returns a map of block index to file ID. Returns nil for +// non-array content or content with no file references. +func extractFileIDs(raw pqtype.NullRawMessage) map[int]uuid.UUID { + if !raw.Valid || len(raw.RawMessage) == 0 { + return nil + } + var rawBlocks []json.RawMessage + if err := json.Unmarshal(raw.RawMessage, &rawBlocks); err != nil { + return nil + } + var result map[int]uuid.UUID + for i, block := range rawBlocks { + fid, err := ExtractFileID(block) + if err == nil { + if result == nil { + result = make(map[int]uuid.UUID) + } + result[i] = fid + } + } + return result +} + +// patchFileContent fills in empty Data on FileContent blocks from +// resolved file data. Blocks that already have inline data (backward +// compat) or have no resolved data are left unchanged. +func patchFileContent( + content []fantasy.Content, + fileIDs map[int]uuid.UUID, + resolved map[uuid.UUID]FileData, +) { + for blockIdx, fid := range fileIDs { + if blockIdx >= len(content) { + continue + } + switch fc := content[blockIdx].(type) { + case fantasy.FileContent: + if len(fc.Data) > 0 { + continue + } + if data, found := resolved[fid]; found { + fc.Data = data.Data + content[blockIdx] = fc + } + case *fantasy.FileContent: + if len(fc.Data) > 0 { + continue + } + if data, found := resolved[fid]; found { + fc.Data = data.Data + } + } + } +} + +// ConvertMessages converts persisted chat messages into LLM prompt +// messages without resolving file references from storage. Inline +// file data is preserved when present (backward compat). func ConvertMessages( messages []database.ChatMessage, ) ([]fantasy.Message, error) { + return ConvertMessagesWithFiles(context.Background(), messages, nil) +} + +// ConvertMessagesWithFiles converts persisted chat messages into LLM +// prompt messages, resolving file references via the provided +// resolver. When resolver is nil, file blocks without inline data +// are passed through as-is (same behavior as ConvertMessages). +func ConvertMessagesWithFiles( + ctx context.Context, + messages []database.ChatMessage, + resolver FileResolver, +) ([]fantasy.Message, error) { + // Phase 1: Pre-scan user messages for file_id references. + var allFileIDs []uuid.UUID + seenFileIDs := make(map[uuid.UUID]struct{}) + fileIDsByMsg := make(map[int]map[int]uuid.UUID) + + if resolver != nil { + for i, msg := range messages { + visibility := msg.Visibility + if visibility == "" { + visibility = database.ChatMessageVisibilityBoth + } + if visibility != database.ChatMessageVisibilityModel && + visibility != database.ChatMessageVisibilityBoth { + continue + } + if msg.Role != string(fantasy.MessageRoleUser) { + continue + } + fids := extractFileIDs(msg.Content) + if len(fids) > 0 { + fileIDsByMsg[i] = fids + for _, fid := range fids { + if _, seen := seenFileIDs[fid]; !seen { + seenFileIDs[fid] = struct{}{} + allFileIDs = append(allFileIDs, fid) + } + } + } + } + } + + // Phase 2: Batch resolve file data. + var resolved map[uuid.UUID]FileData + if len(allFileIDs) > 0 { + var err error + resolved, err = resolver(ctx, allFileIDs) + if err != nil { + return nil, xerrors.Errorf("resolve chat files: %w", err) + } + } + + // Phase 3: Convert messages, patching file content as needed. prompt := make([]fantasy.Message, 0, len(messages)) toolNameByCallID := make(map[string]string) - for _, message := range messages { + for i, message := range messages { visibility := message.Visibility if visibility == "" { visibility = database.ChatMessageVisibilityBoth @@ -51,6 +197,9 @@ func ConvertMessages( if err != nil { return nil, err } + if fids, ok := fileIDsByMsg[i]; ok { + patchFileContent(content, fids, resolved) + } prompt = append(prompt, fantasy.Message{ Role: fantasy.MessageRoleUser, Content: ToMessageParts(content), @@ -400,7 +549,10 @@ func ExtractToolCalls(parts []fantasy.MessagePart) []fantasy.ToolCallContent { } // MarshalContent encodes message content blocks for persistence. -func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) { +// fileIDs optionally maps block indices to chat_files IDs, which +// are injected into the JSON envelope for file-type blocks so +// the reference survives round-trips through storage. +func MarshalContent(blocks []fantasy.Content, fileIDs map[int]uuid.UUID) (pqtype.NullRawMessage, error) { if len(blocks) == 0 { return pqtype.NullRawMessage{}, nil } @@ -415,6 +567,16 @@ func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) { err, ) } + if fid, ok := fileIDs[i]; ok { + encoded, err = injectFileID(encoded, fid) + if err != nil { + return pqtype.NullRawMessage{}, xerrors.Errorf( + "inject file_id into content block %d: %w", + i, + err, + ) + } + } encodedBlocks = append(encodedBlocks, encoded) } @@ -425,6 +587,27 @@ func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) { return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil } +// injectFileID adds a file_id field into the data sub-object of a +// serialized content block envelope. This follows the same pattern +// as the reasoning title injection in marshalContentBlock. +func injectFileID(encoded json.RawMessage, fileID uuid.UUID) (json.RawMessage, error) { + var envelope struct { + Type string `json:"type"` + Data struct { + MediaType string `json:"media_type"` + Data json.RawMessage `json:"data,omitempty"` + FileID string `json:"file_id,omitempty"` + ProviderMetadata *json.RawMessage `json:"provider_metadata,omitempty"` + } `json:"data"` + } + if err := json.Unmarshal(encoded, &envelope); err != nil { + return encoded, err + } + envelope.Data.FileID = fileID.String() + envelope.Data.Data = nil // Strip inline data; resolved at LLM dispatch time. + return json.Marshal(envelope) +} + // MarshalToolResult encodes a single tool result for persistence as // an opaque JSON blob. The stored shape is // [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…}]. diff --git a/coderd/chatd/chatprompt/chatprompt_test.go b/coderd/chatd/chatprompt/chatprompt_test.go index ba398446a1..56d3124366 100644 --- a/coderd/chatd/chatprompt/chatprompt_test.go +++ b/coderd/chatd/chatprompt/chatprompt_test.go @@ -1,10 +1,13 @@ package chatprompt_test import ( + "context" "encoding/json" "testing" "charm.land/fantasy" + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/chatd/chatprompt" @@ -52,7 +55,7 @@ func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) { ToolName: "execute", Input: tc.input, }, - }) + }, nil) require.NoError(t, err) toolContent, err := chatprompt.MarshalToolResult( @@ -89,3 +92,139 @@ func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) { }) } } + +func TestConvertMessagesWithFiles_ResolvesFileData(t *testing.T) { + t.Parallel() + + fileID := uuid.New() + fileData := []byte("fake-image-bytes") + + // Build a user message with file_id but no inline data, as + // would be stored after injectFileID strips the data. + rawContent := mustJSON(t, []json.RawMessage{ + mustJSON(t, map[string]any{ + "type": "file", + "data": map[string]any{ + "media_type": "image/png", + "file_id": fileID.String(), + }, + }), + }) + + resolver := func(_ context.Context, ids []uuid.UUID) (map[uuid.UUID]chatprompt.FileData, error) { + result := make(map[uuid.UUID]chatprompt.FileData) + for _, id := range ids { + if id == fileID { + result[id] = chatprompt.FileData{ + Data: fileData, + MediaType: "image/png", + } + } + } + return result, nil + } + + prompt, err := chatprompt.ConvertMessagesWithFiles( + context.Background(), + []database.ChatMessage{ + { + Role: string(fantasy.MessageRoleUser), + Visibility: database.ChatMessageVisibilityBoth, + Content: pqtype.NullRawMessage{RawMessage: rawContent, Valid: true}, + }, + }, + resolver, + ) + require.NoError(t, err) + require.Len(t, prompt, 1) + require.Equal(t, fantasy.MessageRoleUser, prompt[0].Role) + require.Len(t, prompt[0].Content, 1) + + filePart, ok := fantasy.AsMessagePart[fantasy.FilePart](prompt[0].Content[0]) + require.True(t, ok, "expected FilePart") + require.Equal(t, fileData, filePart.Data) + require.Equal(t, "image/png", filePart.MediaType) +} + +func TestConvertMessagesWithFiles_BackwardCompat(t *testing.T) { + t.Parallel() + + // A message with inline data and a file_id should use the + // inline data even when the resolver returns nothing. + fileID := uuid.New() + inlineData := []byte("inline-image-data") + + rawContent := mustJSON(t, []json.RawMessage{ + mustJSON(t, map[string]any{ + "type": "file", + "data": map[string]any{ + "media_type": "image/png", + "data": inlineData, + "file_id": fileID.String(), + }, + }), + }) + + prompt, err := chatprompt.ConvertMessagesWithFiles( + context.Background(), + []database.ChatMessage{ + { + Role: string(fantasy.MessageRoleUser), + Visibility: database.ChatMessageVisibilityBoth, + Content: pqtype.NullRawMessage{RawMessage: rawContent, Valid: true}, + }, + }, + nil, // No resolver. + ) + require.NoError(t, err) + require.Len(t, prompt, 1) + require.Len(t, prompt[0].Content, 1) + + filePart, ok := fantasy.AsMessagePart[fantasy.FilePart](prompt[0].Content[0]) + require.True(t, ok, "expected FilePart") + require.Equal(t, inlineData, filePart.Data) +} + +func TestInjectFileID_StripsInlineData(t *testing.T) { + t.Parallel() + + fileID := uuid.New() + imageData := []byte("raw-image-bytes") + + // Marshal a file content block with inline data, then inject + // a file_id. The result should have file_id but no data. + content, err := chatprompt.MarshalContent([]fantasy.Content{ + fantasy.FileContent{ + MediaType: "image/png", + Data: imageData, + }, + }, map[int]uuid.UUID{0: fileID}) + require.NoError(t, err) + + // Parse the stored content to verify shape. + var blocks []json.RawMessage + require.NoError(t, json.Unmarshal(content.RawMessage, &blocks)) + require.Len(t, blocks, 1) + + var envelope struct { + Type string `json:"type"` + Data struct { + MediaType string `json:"media_type"` + Data *json.RawMessage `json:"data,omitempty"` + FileID string `json:"file_id"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(blocks[0], &envelope)) + require.Equal(t, "file", envelope.Type) + require.Equal(t, "image/png", envelope.Data.MediaType) + require.Equal(t, fileID.String(), envelope.Data.FileID) + // Data should be nil (omitted) since injectFileID strips it. + require.Nil(t, envelope.Data.Data, "inline data should be stripped") +} + +func mustJSON(t *testing.T, v any) json.RawMessage { + t.Helper() + data, err := json.Marshal(v) + require.NoError(t, err) + return data +} diff --git a/coderd/chats.go b/coderd/chats.go index 26c7828d07..a8094a97d0 100644 --- a/coderd/chats.go +++ b/coderd/chats.go @@ -1,11 +1,15 @@ package coderd import ( + "bufio" + "bytes" "context" "database/sql" "encoding/json" + "errors" "fmt" "io" + "mime" "net/http" "net/http/httptest" "net/url" @@ -247,7 +251,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { return } - contentBlocks, titleSource, inputError := createChatInputFromRequest(req) + contentBlocks, contentFileIDs, titleSource, inputError := createChatInputFromRequest(ctx, api.Database, req) if inputError != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, *inputError) return @@ -282,6 +286,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { ModelConfigID: modelConfigID, SystemPrompt: defaultChatSystemPrompt(), InitialUserContent: contentBlocks, + ContentFileIDs: contentFileIDs, }) if err != nil { if database.IsForeignKeyViolation( @@ -647,7 +652,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { return } - contentBlocks, _, inputError := createChatInputFromParts(req.Content, "content") + contentBlocks, contentFileIDs, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content") if inputError != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: inputError.Message, @@ -659,10 +664,11 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { sendResult, sendErr := api.chatDaemon.SendMessage( ctx, chatd.SendMessageOptions{ - ChatID: chatID, - Content: contentBlocks, - ModelConfigID: req.ModelConfigID, - BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + ChatID: chatID, + Content: contentBlocks, + ContentFileIDs: contentFileIDs, + ModelConfigID: req.ModelConfigID, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, }, ) if sendErr != nil { @@ -721,7 +727,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) { return } - contentBlocks, _, inputError := createChatInputFromParts(req.Content, "content") + contentBlocks, contentFileIDs, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content") if inputError != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: inputError.Message, @@ -734,6 +740,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) { ChatID: chat.ID, EditedMessageID: messageID, Content: contentBlocks, + ContentFileIDs: contentFileIDs, }) if editErr != nil { switch { @@ -2196,45 +2203,298 @@ func normalizeChatCompressionThreshold( return threshold, nil } +const ( + // maxChatFileSize is the maximum size of a chat file upload (10 MB). + maxChatFileSize = 10 << 20 + // maxChatFileName is the maximum length of an uploaded file name. + maxChatFileName = 255 +) + +// allowedChatFileMIMETypes lists the content types accepted for chat +// file uploads. SVG is explicitly excluded because it can contain scripts. +var allowedChatFileMIMETypes = map[string]bool{ + "image/png": true, + "image/jpeg": true, + "image/gif": true, + "image/webp": true, + "image/svg+xml": false, // SVG can contain scripts. +} + +var ( + webpMagicRIFF = []byte("RIFF") + webpMagicWEBP = []byte("WEBP") +) + +// detectChatFileType detects the MIME type of the given data. +// It extends http.DetectContentType with support for WebP, which +// Go's standard sniffer does not recognize. +func detectChatFileType(data []byte) string { + if len(data) >= 12 && + bytes.Equal(data[0:4], webpMagicRIFF) && + bytes.Equal(data[8:12], webpMagicWEBP) { + return "image/webp" + } + return http.DetectContentType(data) +} + func defaultChatSystemPrompt() string { return chatd.DefaultSystemPrompt } -func createChatInputFromRequest(req codersdk.CreateChatRequest) ( +// @Summary Upload a chat file +// @ID upload-chat-file +// @Security CoderSessionToken +// @Accept application/octet-stream +// @Produce json +// @Tags Chats +// @Param Content-Type header string true "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)" +// @Param organization query string true "Organization ID" format(uuid) +// @Success 201 {object} codersdk.UploadChatFileResponse +// @Failure 400 {object} codersdk.Response +// @Failure 401 {object} codersdk.Response +// @Failure 413 {object} codersdk.Response +// @Failure 500 {object} codersdk.Response +// @Router /chats/files [post] +func (api *API) postChatFile(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + + if !api.Authorize(r, policy.ActionCreate, rbac.ResourceChat.WithOwner(apiKey.UserID.String())) { + httpapi.Forbidden(rw) + return + } + + orgIDStr := r.URL.Query().Get("organization") + if orgIDStr == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Missing organization query parameter.", + }) + return + } + orgID, err := uuid.Parse(orgIDStr) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid organization ID.", + }) + return + } + + contentType := r.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/octet-stream" + } + // Strip parameters (e.g. "image/png; charset=utf-8" → "image/png") + // so the allowlist check matches the base media type. + if mediaType, _, err := mime.ParseMediaType(contentType); err == nil { + contentType = mediaType + } + + if allowed, ok := allowedChatFileMIMETypes[contentType]; !ok || !allowed { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Unsupported file type.", + Detail: "Allowed types: image/png, image/jpeg, image/gif, image/webp.", + }) + return + } + + r.Body = http.MaxBytesReader(rw, r.Body, maxChatFileSize) + br := bufio.NewReader(r.Body) + + // Peek at the leading bytes to sniff the real content type + // before reading the entire body. + peek, peekErr := br.Peek(512) + if peekErr != nil && !errors.Is(peekErr, io.EOF) && !errors.Is(peekErr, bufio.ErrBufferFull) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Failed to read file from request.", + Detail: peekErr.Error(), + }) + return + } + + // Verify the actual content matches a safe image type so that + // a client cannot spoof Content-Type to serve active content. + detected := detectChatFileType(peek) + if allowed, ok := allowedChatFileMIMETypes[detected]; !ok || !allowed { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Unsupported file type.", + Detail: "Allowed types: image/png, image/jpeg, image/gif, image/webp.", + }) + return + } + + // Read the full body now that we know the type is valid. + data, err := io.ReadAll(br) + if err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{ + Message: "File too large.", + Detail: fmt.Sprintf("Maximum file size is %d bytes.", maxChatFileSize), + }) + return + } + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Failed to read file from request.", + Detail: err.Error(), + }) + return + } + + // Extract filename from Content-Disposition header if provided. + var filename string + if cd := r.Header.Get("Content-Disposition"); cd != "" { + if _, params, err := mime.ParseMediaType(cd); err == nil { + filename = params["filename"] + if len(filename) > maxChatFileName { + // Truncate at rune boundary to avoid splitting + // multi-byte UTF-8 characters. + var truncated []byte + for _, r := range filename { + encoded := []byte(string(r)) + if len(truncated)+len(encoded) > maxChatFileName { + break + } + truncated = append(truncated, encoded...) + } + filename = string(truncated) + } + } + } + + chatFile, err := api.Database.InsertChatFile(ctx, database.InsertChatFileParams{ + OwnerID: apiKey.UserID, + OrganizationID: orgID, + Name: filename, + Mimetype: detected, + Data: data, + }) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to save chat file.", + Detail: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusCreated, codersdk.UploadChatFileResponse{ + ID: chatFile.ID, + }) +} + +// @Summary Get a chat file +// @ID get-chat-file +// @Security CoderSessionToken +// @Tags Chats +// @Param file path string true "File ID" format(uuid) +// @Success 200 +// @Failure 400 {object} codersdk.Response +// @Failure 401 {object} codersdk.Response +// @Failure 404 {object} codersdk.Response +// @Failure 500 {object} codersdk.Response +// @Router /chats/files/{file} [get] +func (api *API) chatFileByID(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + fileIDStr := chi.URLParam(r, "file") + fileID, err := uuid.Parse(fileIDStr) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid file ID.", + }) + return + } + + chatFile, err := api.Database.GetChatFileByID(ctx, fileID) + if err != nil { + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to get chat file.", + Detail: err.Error(), + }) + return + } + + rw.Header().Set("Content-Type", chatFile.Mimetype) + if chatFile.Name != "" { + rw.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": chatFile.Name})) + } else { + rw.Header().Set("Content-Disposition", "inline") + } + rw.Header().Set("Cache-Control", "private, max-age=31536000, immutable") + rw.Header().Set("Content-Length", strconv.Itoa(len(chatFile.Data))) + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write(chatFile.Data) +} + +func createChatInputFromRequest(ctx context.Context, db database.Store, req codersdk.CreateChatRequest) ( []fantasy.Content, + map[int]uuid.UUID, string, *codersdk.Response, ) { - return createChatInputFromParts(req.Content, "content") + return createChatInputFromParts(ctx, db, req.Content, "content") } func createChatInputFromParts( + ctx context.Context, + db database.Store, parts []codersdk.ChatInputPart, fieldName string, -) ([]fantasy.Content, string, *codersdk.Response) { +) ([]fantasy.Content, map[int]uuid.UUID, string, *codersdk.Response) { if len(parts) == 0 { - return nil, "", &codersdk.Response{ + return nil, nil, "", &codersdk.Response{ Message: "Content is required.", Detail: "Content cannot be empty.", } } content := make([]fantasy.Content, 0, len(parts)) + fileIDs := make(map[int]uuid.UUID) textParts := make([]string, 0, len(parts)) for i, part := range parts { switch strings.ToLower(strings.TrimSpace(string(part.Type))) { case string(codersdk.ChatInputPartTypeText): text := strings.TrimSpace(part.Text) if text == "" { - return nil, "", &codersdk.Response{ + return nil, nil, "", &codersdk.Response{ Message: "Invalid input part.", Detail: fmt.Sprintf("%s[%d].text cannot be empty.", fieldName, i), } } content = append(content, fantasy.TextContent{Text: text}) textParts = append(textParts, text) + case string(codersdk.ChatInputPartTypeFile): + if part.FileID == uuid.Nil { + return nil, nil, "", &codersdk.Response{ + Message: "Invalid input part.", + Detail: fmt.Sprintf("%s[%d].file_id is required for file parts.", fieldName, i), + } + } + // Validate that the file exists and get its media type. + // File data is not loaded here; it's resolved at LLM + // dispatch time via chatFileResolver. + chatFile, err := db.GetChatFileByID(ctx, part.FileID) + if err != nil { + if httpapi.Is404Error(err) { + return nil, nil, "", &codersdk.Response{ + Message: "Invalid input part.", + Detail: fmt.Sprintf("%s[%d].file_id references a file that does not exist.", fieldName, i), + } + } + return nil, nil, "", &codersdk.Response{ + Message: "Internal error.", + Detail: fmt.Sprintf("Failed to retrieve file for %s[%d].", fieldName, i), + } + } + content = append(content, fantasy.FileContent{ + MediaType: chatFile.Mimetype, + }) + fileIDs[len(content)-1] = part.FileID default: - return nil, "", &codersdk.Response{ + return nil, nil, "", &codersdk.Response{ Message: "Invalid input part.", Detail: fmt.Sprintf( "%s[%d].type %q is not supported.", @@ -2246,14 +2506,16 @@ func createChatInputFromParts( } } - titleSource := strings.TrimSpace(strings.Join(textParts, " ")) - if titleSource == "" { - return nil, "", &codersdk.Response{ + // Allow file-only messages. The titleSource may be empty + // when only file parts are provided, callers handle this. + if len(content) == 0 { + return nil, nil, "", &codersdk.Response{ Message: "Content is required.", - Detail: "Content must include at least one text part.", + Detail: fmt.Sprintf("%s must include at least one text or file part.", fieldName), } } - return content, titleSource, nil + titleSource := strings.TrimSpace(strings.Join(textParts, " ")) + return content, fileIDs, titleSource, nil } func chatTitleFromMessage(message string) string { diff --git a/coderd/chats_test.go b/coderd/chats_test.go index 911cc6133c..30a45a23b9 100644 --- a/coderd/chats_test.go +++ b/coderd/chats_test.go @@ -1,11 +1,13 @@ package coderd_test import ( + "bytes" "database/sql" "encoding/json" "fmt" "net/http" "regexp" + "strings" "testing" "time" @@ -1525,6 +1527,175 @@ func TestPostChatMessages(t *testing.T) { }) } +func TestChatMessageWithFiles(t *testing.T) { + t.Parallel() + + t.Run("FileOnly", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + _ = createChatModelConfig(t, client) + + // Upload a file. + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData)) + require.NoError(t, err) + + // Create a chat with text first. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "initial message", + }, + }, + }) + require.NoError(t, err) + + // Send a file-only message (no text). + resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeFile, + FileID: uploadResp.ID, + }, + }, + }) + require.NoError(t, err) + + // Verify the message was accepted. + if resp.Queued { + require.NotNil(t, resp.QueuedMessage) + } else { + require.NotNil(t, resp.Message) + require.Equal(t, "user", resp.Message.Role) + } + }) + + t.Run("TextAndFile", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + _ = createChatModelConfig(t, client) + + // Upload a file. + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData)) + require.NoError(t, err) + + // Create a chat with text first. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "initial message", + }, + }, + }) + require.NoError(t, err) + + // Send a message with both text and file. + resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "here is an image", + }, + { + Type: codersdk.ChatInputPartTypeFile, + FileID: uploadResp.ID, + }, + }, + }) + require.NoError(t, err) + + if resp.Queued { + require.NotNil(t, resp.QueuedMessage) + } else { + require.NotNil(t, resp.Message) + require.Equal(t, "user", resp.Message.Role) + } + + // Verify file parts omit inline data in the API response. + chatWithMessages, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + for _, msg := range chatWithMessages.Messages { + for _, part := range msg.Content { + if part.Type == codersdk.ChatMessagePartTypeFile { + require.True(t, part.FileID.Valid, "file part should have a valid file_id") + require.Equal(t, uploadResp.ID, part.FileID.UUID) + require.Nil(t, part.Data, "file data should not be sent when file_id is present") + } + } + } + }) + + t.Run("FileOnlyOnCreate", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + _ = createChatModelConfig(t, client) + + // Upload a file. + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData)) + require.NoError(t, err) + + // Create a new chat with only a file part. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeFile, + FileID: uploadResp.ID, + }, + }, + }) + require.NoError(t, err) + + // With no text, chatTitleFromMessage("") returns "New Chat". + require.Equal(t, "New Chat", chat.Title) + }) + + t.Run("InvalidFileID", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client) + _ = createChatModelConfig(t, client) + + // Create a chat with text first. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "initial message", + }, + }, + }) + require.NoError(t, err) + + // Send a message with a non-existent file ID. + _, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeFile, + FileID: uuid.New(), + }, + }, + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Invalid input part.", sdkErr.Message) + require.Contains(t, sdkErr.Detail, "does not exist") + }) +} + func TestPatchChatMessage(t *testing.T) { t.Parallel() @@ -1602,6 +1773,100 @@ func TestPatchChatMessage(t *testing.T) { require.False(t, foundOriginalInChat) }) + t.Run("PreservesFileID", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + _ = createChatModelConfig(t, client) + + // Upload a file. + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData)) + require.NoError(t, err) + + // Create a chat with a text + file part. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "before edit with file", + }, + { + Type: codersdk.ChatInputPartTypeFile, + FileID: uploadResp.ID, + }, + }, + }) + require.NoError(t, err) + + // Find the user message ID. + chatWithMessages, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + + var userMessageID int64 + for _, message := range chatWithMessages.Messages { + if message.Role == "user" { + userMessageID = message.ID + break + } + } + require.NotZero(t, userMessageID) + + // Edit the message: new text, same file_id. + edited, err := client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "after edit with file", + }, + { + Type: codersdk.ChatInputPartTypeFile, + FileID: uploadResp.ID, + }, + }, + }) + require.NoError(t, err) + require.Equal(t, userMessageID, edited.ID) + + // Assert the edit response preserves the file_id. + var foundText, foundFile bool + for _, part := range edited.Content { + if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "after edit with file" { + foundText = true + } + if part.Type == codersdk.ChatMessagePartTypeFile && part.FileID.Valid && part.FileID.UUID == uploadResp.ID { + foundFile = true + require.Nil(t, part.Data, "file data should not be sent when file_id is present") + } + } + require.True(t, foundText, "edited message should contain updated text") + require.True(t, foundFile, "edited message should preserve file_id") + + // GET the chat and verify the file_id persists. + updatedChat, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + + var foundTextInChat, foundFileInChat bool + for _, message := range updatedChat.Messages { + if message.Role != "user" { + continue + } + for _, part := range message.Content { + if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "after edit with file" { + foundTextInChat = true + } + if part.Type == codersdk.ChatMessagePartTypeFile && part.FileID.Valid && part.FileID.UUID == uploadResp.ID { + foundFileInChat = true + require.Nil(t, part.Data, "file data should not be sent when file_id is present") + } + } + } + require.True(t, foundTextInChat, "chat should contain edited text") + require.True(t, foundFileInChat, "chat should preserve file_id after edit") + }) + t.Run("MessageNotFound", func(t *testing.T) { t.Parallel() @@ -2212,6 +2477,259 @@ func TestPromoteChatQueuedMessage(t *testing.T) { }) } +func TestPostChatFile(t *testing.T) { + t.Parallel() + + t.Run("Success/PNG", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + // Valid PNG header + padding. + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) + require.NoError(t, err) + require.NotEqual(t, uuid.Nil, resp.ID) + }) + + t.Run("Success/JPEG", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0xFF, 0xD8, 0xFF, 0xE0}, make([]byte, 64)...) + resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/jpeg", "test.jpg", bytes.NewReader(data)) + require.NoError(t, err) + require.NotEqual(t, uuid.Nil, resp.ID) + }) + + t.Run("Success/WebP", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + // WebP: RIFF + 4-byte size + WEBP + padding. + data := append([]byte("RIFF"), make([]byte, 4)...) + data = append(data, []byte("WEBP")...) + data = append(data, make([]byte, 64)...) + resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/webp", "test.webp", bytes.NewReader(data)) + require.NoError(t, err) + require.NotEqual(t, uuid.Nil, resp.ID) + }) + + t.Run("UnsupportedContentType", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + _, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "text/plain", "test.txt", bytes.NewReader([]byte("hello"))) + requireSDKError(t, err, http.StatusBadRequest) + }) + + t.Run("SVGBlocked", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + _, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/svg+xml", "test.svg", bytes.NewReader([]byte(""))) + requireSDKError(t, err, http.StatusBadRequest) + }) + + t.Run("ContentSniffingRejects", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + // Header says PNG but body is plain text. + _, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader([]byte("hello world"))) + requireSDKError(t, err, http.StatusBadRequest) + }) + + t.Run("TooLarge", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + // 10 MB + 1 byte, with valid PNG header to pass MIME check. + data := make([]byte, 10<<20+1) + copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}) + _, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) + require.Error(t, err) + }) + + t.Run("MissingOrganization", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + res, err := client.Request(ctx, http.MethodPost, "/api/experimental/chats/files", bytes.NewReader(data), func(r *http.Request) { + r.Header.Set("Content-Type", "image/png") + }) + require.NoError(t, err) + defer res.Body.Close() + err = codersdk.ReadBodyAsError(res) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Contains(t, sdkErr.Message, "Missing organization") + }) + + t.Run("InvalidOrganization", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + res, err := client.Request(ctx, http.MethodPost, "/api/experimental/chats/files?organization=not-a-uuid", bytes.NewReader(data), func(r *http.Request) { + r.Header.Set("Content-Type", "image/png") + }) + require.NoError(t, err) + defer res.Body.Close() + err = codersdk.ReadBodyAsError(res) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Contains(t, sdkErr.Message, "Invalid organization ID") + }) + + t.Run("WrongOrganization", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + _, err := client.UploadChatFile(ctx, uuid.New(), "image/png", "test.png", bytes.NewReader(data)) + require.Error(t, err) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + // dbauthz returns 404 or 500 depending on how the org lookup + // fails; 403 is also possible. Any non-success code is valid. + require.GreaterOrEqual(t, sdkErr.StatusCode(), http.StatusBadRequest, + "expected error status, got %d", sdkErr.StatusCode()) + }) + + t.Run("Unauthenticated", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + unauthed := codersdk.New(client.URL) + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + _, err := unauthed.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) + requireSDKError(t, err, http.StatusUnauthorized) + }) +} + +func TestGetChatFile(t *testing.T) { + t.Parallel() + + t.Run("Success", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) + require.NoError(t, err) + + got, contentType, err := client.GetChatFile(ctx, uploaded.ID) + require.NoError(t, err) + require.Equal(t, "image/png", contentType) + require.Equal(t, data, got) + }) + + t.Run("CacheHeaders", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) + require.NoError(t, err) + + res, err := client.Request(ctx, http.MethodGet, + fmt.Sprintf("/api/experimental/chats/files/%s", uploaded.ID), nil) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, "private, max-age=31536000, immutable", res.Header.Get("Cache-Control")) + require.Contains(t, res.Header.Get("Content-Disposition"), "inline") + require.Contains(t, res.Header.Get("Content-Disposition"), "test.png") + }) + + t.Run("LongFilename", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + longName := strings.Repeat("a", 300) + ".png" + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", longName, bytes.NewReader(data)) + require.NoError(t, err) + + res, err := client.Request(ctx, http.MethodGet, + fmt.Sprintf("/api/experimental/chats/files/%s", uploaded.ID), nil) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + // Filename should be truncated to maxChatFileName (255) bytes. + cd := res.Header.Get("Content-Disposition") + require.Contains(t, cd, "inline") + require.Contains(t, cd, strings.Repeat("a", 255)) + require.NotContains(t, cd, strings.Repeat("a", 256)) + }) + + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + coderdtest.CreateFirstUser(t, client) + + _, _, err := client.GetChatFile(ctx, uuid.New()) + requireSDKError(t, err, http.StatusNotFound) + }) + + t.Run("InvalidUUID", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + coderdtest.CreateFirstUser(t, client) + + res, err := client.Request(ctx, http.MethodGet, + "/api/experimental/chats/files/not-a-uuid", nil) + require.NoError(t, err) + defer res.Body.Close() + err = codersdk.ReadBodyAsError(res) + requireSDKError(t, err, http.StatusBadRequest) + }) + + t.Run("OtherUserForbidden", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) + require.NoError(t, err) + + otherClient, _ := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID) + _, _, err = otherClient.GetChatFile(ctx, uploaded.ID) + requireSDKError(t, err, http.StatusNotFound) + }) +} + func createChatModelConfig(t *testing.T, client *codersdk.Client) codersdk.ChatModelConfig { t.Helper() diff --git a/coderd/coderd.go b/coderd/coderd.go index dc9a4bc5f3..ae6b0bc159 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1113,6 +1113,11 @@ func New(options *Options) *API { r.Post("/", api.postChats) r.Get("/models", api.listChatModels) r.Get("/watch", api.watchChats) + r.Route("/files", func(r chi.Router) { + r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute)) + r.Post("/", api.postChatFile) + r.Get("/{file}", api.chatFileByID) + }) r.Route("/providers", func(r chi.Router) { r.Get("/", api.listChatProviders) r.Post("/", api.createChatProvider) @@ -1842,6 +1847,14 @@ func New(options *Options) *API { "parsing additional CSP headers", slog.Error(cspParseErrors)) } + // Add blob: to img-src for chat file attachment previews when + // the agents experiment is enabled. + if api.Experiments.Enabled(codersdk.ExperimentAgents) { + additionalCSPHeaders[httpmw.CSPDirectiveImgSrc] = append( + additionalCSPHeaders[httpmw.CSPDirectiveImgSrc], "blob:", + ) + } + // Add CSP headers to all static assets and pages. CSP headers only affect // browsers, so these don't make sense on api routes. cspMW := httpmw.CSPHeaders( diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index 9cfb4a344a..4ef02d3e42 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -1156,9 +1156,7 @@ func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMe } var rawBlocks []json.RawMessage - if role == string(fantasy.MessageRoleAssistant) { - _ = json.Unmarshal(raw.RawMessage, &rawBlocks) - } + _ = json.Unmarshal(raw.RawMessage, &rawBlocks) parts := make([]codersdk.ChatMessagePart, 0, len(content)) for i, block := range content { @@ -1166,10 +1164,20 @@ func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMe if part.Type == "" { continue } - if part.Type == codersdk.ChatMessagePartTypeReasoning { - part.Title = "" - if i < len(rawBlocks) { + if i < len(rawBlocks) { + switch part.Type { + case codersdk.ChatMessagePartTypeReasoning: part.Title = reasoningStoredTitle(rawBlocks[i]) + case codersdk.ChatMessagePartTypeFile: + if fid, err := chatprompt.ExtractFileID(rawBlocks[i]); err == nil { + part.FileID = uuid.NullUUID{UUID: fid, Valid: true} + } + // When a file_id is present, omit inline data + // from the response. Clients fetch content via + // the GET /chats/files/{id} endpoint instead. + if part.FileID.Valid { + part.Data = nil + } } } parts = append(parts, part) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 839c439cd6..83710fca9e 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2457,6 +2457,30 @@ func (q *querier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uu return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs) } +func (q *querier) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) { + file, err := q.db.GetChatFileByID(ctx, id) + if err != nil { + return database.ChatFile{}, err + } + if err := q.authorizeContext(ctx, policy.ActionRead, file); err != nil { + return database.ChatFile{}, err + } + return file, nil +} + +func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) { + files, err := q.db.GetChatFilesByIDs(ctx, ids) + if err != nil { + return nil, err + } + for _, f := range files { + if err := q.authorizeContext(ctx, policy.ActionRead, f); err != nil { + return nil, err + } + } + return files, nil +} + func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) { // ChatMessages are authorized through their parent Chat. // We need to fetch the message first to get its chat_id. @@ -4491,6 +4515,11 @@ func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams) return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg) } +func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) { + // Authorize create on chat resource scoped to the owner and org. + return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg) +} + func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) { // Authorize create on the parent chat (using update permission). chat, err := q.db.GetChatByID(ctx, arg.ChatID) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 252f2075f4..668d016db8 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -463,6 +463,16 @@ func (s *MethodTestSuite) TestChats() { Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead). Returns([]database.ChatDiffStatus{diffStatusA, diffStatusB}) })) + s.Run("GetChatFileByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + file := testutil.Fake(s.T(), faker, database.ChatFile{}) + dbm.EXPECT().GetChatFileByID(gomock.Any(), file.ID).Return(file, nil).AnyTimes() + check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(file) + })) + s.Run("GetChatFilesByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + file := testutil.Fake(s.T(), faker, database.ChatFile{}) + dbm.EXPECT().GetChatFilesByIDs(gomock.Any(), []uuid.UUID{file.ID}).Return([]database.ChatFile{file}, nil).AnyTimes() + check.Args([]uuid.UUID{file.ID}).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns([]database.ChatFile{file}) + })) s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID}) @@ -579,6 +589,12 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().InsertChat(gomock.Any(), arg).Return(chat, nil).AnyTimes() check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionCreate).Returns(chat) })) + s.Run("InsertChatFile", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + arg := testutil.Fake(s.T(), faker, database.InsertChatFileParams{}) + file := testutil.Fake(s.T(), faker, database.InsertChatFileRow{OwnerID: arg.OwnerID, OrganizationID: arg.OrganizationID}) + dbm.EXPECT().InsertChatFile(gomock.Any(), arg).Return(file, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), policy.ActionCreate).Returns(file) + })) s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID}) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 60e12d7c96..82ad4baf61 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1007,6 +1007,22 @@ func (m queryMetricsStore) GetChatDiffStatusesByChatIDs(ctx context.Context, cha return r0, r1 } +func (m queryMetricsStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) { + start := time.Now() + r0, r1 := m.s.GetChatFileByID(ctx, id) + m.queryLatencies.WithLabelValues("GetChatFileByID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFileByID").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) { + start := time.Now() + r0, r1 := m.s.GetChatFilesByIDs(ctx, ids) + m.queryLatencies.WithLabelValues("GetChatFilesByIDs").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFilesByIDs").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) { start := time.Now() r0, r1 := m.s.GetChatMessageByID(ctx, id) @@ -2943,6 +2959,14 @@ func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertCh return r0, r1 } +func (m queryMetricsStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) { + start := time.Now() + r0, r1 := m.s.InsertChatFile(ctx, arg) + m.queryLatencies.WithLabelValues("InsertChatFile").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatFile").Inc() + return r0, r1 +} + func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) { start := time.Now() r0, r1 := m.s.InsertChatMessage(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 724d7f2b7b..6a1b286ac5 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1837,6 +1837,36 @@ func (mr *MockStoreMockRecorder) GetChatDiffStatusesByChatIDs(ctx, chatIds any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDiffStatusesByChatIDs", reflect.TypeOf((*MockStore)(nil).GetChatDiffStatusesByChatIDs), ctx, chatIds) } +// GetChatFileByID mocks base method. +func (m *MockStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatFileByID", ctx, id) + ret0, _ := ret[0].(database.ChatFile) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatFileByID indicates an expected call of GetChatFileByID. +func (mr *MockStoreMockRecorder) GetChatFileByID(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileByID", reflect.TypeOf((*MockStore)(nil).GetChatFileByID), ctx, id) +} + +// GetChatFilesByIDs mocks base method. +func (m *MockStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatFilesByIDs", ctx, ids) + ret0, _ := ret[0].([]database.ChatFile) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatFilesByIDs indicates an expected call of GetChatFilesByIDs. +func (mr *MockStoreMockRecorder) GetChatFilesByIDs(ctx, ids any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFilesByIDs", reflect.TypeOf((*MockStore)(nil).GetChatFilesByIDs), ctx, ids) +} + // GetChatMessageByID mocks base method. func (m *MockStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) { m.ctrl.T.Helper() @@ -5511,6 +5541,21 @@ func (mr *MockStoreMockRecorder) InsertChat(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChat", reflect.TypeOf((*MockStore)(nil).InsertChat), ctx, arg) } +// InsertChatFile mocks base method. +func (m *MockStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertChatFile", ctx, arg) + ret0, _ := ret[0].(database.InsertChatFileRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertChatFile indicates an expected call of InsertChatFile. +func (mr *MockStoreMockRecorder) InsertChatFile(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatFile", reflect.TypeOf((*MockStore)(nil).InsertChatFile), ctx, arg) +} + // InsertChatMessage mocks base method. func (m *MockStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index a6d1ef960a..2cb2071ede 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1190,6 +1190,16 @@ CREATE TABLE chat_diff_statuses ( git_remote_origin text DEFAULT ''::text NOT NULL ); +CREATE TABLE chat_files ( + id uuid DEFAULT gen_random_uuid() NOT NULL, + owner_id uuid NOT NULL, + organization_id uuid NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + name text DEFAULT ''::text NOT NULL, + mimetype text NOT NULL, + data bytea NOT NULL +); + CREATE TABLE chat_messages ( id bigint NOT NULL, chat_id uuid NOT NULL, @@ -3140,6 +3150,9 @@ ALTER TABLE ONLY boundary_usage_stats ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id); +ALTER TABLE ONLY chat_files + ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id); + ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id); @@ -3495,6 +3508,10 @@ CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC); CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at); +CREATE INDEX idx_chat_files_org ON chat_files USING btree (organization_id); + +CREATE INDEX idx_chat_files_owner ON chat_files USING btree (owner_id); + CREATE INDEX idx_chat_messages_chat ON chat_messages USING btree (chat_id); CREATE INDEX idx_chat_messages_chat_created ON chat_messages USING btree (chat_id, created_at); @@ -3774,6 +3791,12 @@ ALTER TABLE ONLY api_keys ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; +ALTER TABLE ONLY chat_files + ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; + +ALTER TABLE ONLY chat_files + ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE; + ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 8c8797c2a8..2fb45a6963 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -9,6 +9,8 @@ const ( ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id); ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; + ForeignKeyChatFilesOrganizationID ForeignKeyConstraint = "chat_files_organization_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; + ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; ForeignKeyChatMessagesModelConfigID ForeignKeyConstraint = "chat_messages_model_config_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id); ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id); diff --git a/coderd/database/migrations/000429_chat_files.down.sql b/coderd/database/migrations/000429_chat_files.down.sql new file mode 100644 index 0000000000..37044f07df --- /dev/null +++ b/coderd/database/migrations/000429_chat_files.down.sql @@ -0,0 +1,2 @@ +DROP INDEX IF EXISTS idx_chat_files_org; +DROP TABLE IF EXISTS chat_files; diff --git a/coderd/database/migrations/000429_chat_files.up.sql b/coderd/database/migrations/000429_chat_files.up.sql new file mode 100644 index 0000000000..42abedaeb5 --- /dev/null +++ b/coderd/database/migrations/000429_chat_files.up.sql @@ -0,0 +1,12 @@ +CREATE TABLE chat_files ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + name TEXT NOT NULL DEFAULT '', + mimetype TEXT NOT NULL, + data BYTEA NOT NULL +); + +CREATE INDEX idx_chat_files_owner ON chat_files(owner_id); +CREATE INDEX idx_chat_files_org ON chat_files(organization_id); diff --git a/coderd/database/migrations/testdata/fixtures/000429_chat_files.up.sql b/coderd/database/migrations/testdata/fixtures/000429_chat_files.up.sql new file mode 100644 index 0000000000..cd546f8f28 --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000429_chat_files.up.sql @@ -0,0 +1,13 @@ +INSERT INTO chat_files (id, owner_id, organization_id, created_at, name, mimetype, data) +SELECT + '00000000-0000-0000-0000-000000000099', + u.id, + om.organization_id, + '2024-01-01 00:00:00+00', + 'test.png', + 'image/png', + E'\\x89504E47' +FROM users u +JOIN organization_members om ON om.user_id = u.id +ORDER BY u.created_at, u.id +LIMIT 1; diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 3408ab20d5..a978840726 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -178,6 +178,10 @@ func (c Chat) RBACObject() rbac.Object { return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()) } +func (c ChatFile) RBACObject() rbac.Object { + return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID) +} + func (s APIKeyScope) ToRBAC() rbac.ScopeName { switch s { case ApiKeyScopeCoderAll: diff --git a/coderd/database/models.go b/coderd/database/models.go index 9007d046b4..f1f31313cf 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -3926,6 +3926,16 @@ type ChatDiffStatus struct { GitRemoteOrigin string `db:"git_remote_origin" json:"git_remote_origin"` } +type ChatFile struct { + ID uuid.UUID `db:"id" json:"id"` + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + Name string `db:"name" json:"name"` + Mimetype string `db:"mimetype" json:"mimetype"` + Data []byte `db:"data" json:"data"` +} + type ChatMessage struct { ID int64 `db:"id" json:"id"` ChatID uuid.UUID `db:"chat_id" json:"chat_id"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 467c97801e..014f9fd3b0 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -218,6 +218,8 @@ type sqlcQuerier interface { GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error) + GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error) + GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error) GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error) GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error) @@ -601,6 +603,7 @@ type sqlcQuerier interface { InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error) InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error) + InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error) InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error) InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error) InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 0e796864f7..931aa9bea6 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2214,6 +2214,103 @@ func (q *sqlQuerier) UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBou return new_period, err } +const getChatFileByID = `-- name: GetChatFileByID :one +SELECT id, owner_id, organization_id, created_at, name, mimetype, data FROM chat_files WHERE id = $1::uuid +` + +func (q *sqlQuerier) GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error) { + row := q.db.QueryRowContext(ctx, getChatFileByID, id) + var i ChatFile + err := row.Scan( + &i.ID, + &i.OwnerID, + &i.OrganizationID, + &i.CreatedAt, + &i.Name, + &i.Mimetype, + &i.Data, + ) + return i, err +} + +const getChatFilesByIDs = `-- name: GetChatFilesByIDs :many +SELECT id, owner_id, organization_id, created_at, name, mimetype, data FROM chat_files WHERE id = ANY($1::uuid[]) +` + +func (q *sqlQuerier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error) { + rows, err := q.db.QueryContext(ctx, getChatFilesByIDs, pq.Array(ids)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ChatFile + for rows.Next() { + var i ChatFile + if err := rows.Scan( + &i.ID, + &i.OwnerID, + &i.OrganizationID, + &i.CreatedAt, + &i.Name, + &i.Mimetype, + &i.Data, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertChatFile = `-- name: InsertChatFile :one +INSERT INTO chat_files (owner_id, organization_id, name, mimetype, data) +VALUES ($1::uuid, $2::uuid, $3::text, $4::text, $5::bytea) +RETURNING id, owner_id, organization_id, created_at, name, mimetype +` + +type InsertChatFileParams struct { + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + Name string `db:"name" json:"name"` + Mimetype string `db:"mimetype" json:"mimetype"` + Data []byte `db:"data" json:"data"` +} + +type InsertChatFileRow struct { + ID uuid.UUID `db:"id" json:"id"` + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + Name string `db:"name" json:"name"` + Mimetype string `db:"mimetype" json:"mimetype"` +} + +func (q *sqlQuerier) InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error) { + row := q.db.QueryRowContext(ctx, insertChatFile, + arg.OwnerID, + arg.OrganizationID, + arg.Name, + arg.Mimetype, + arg.Data, + ) + var i InsertChatFileRow + err := row.Scan( + &i.ID, + &i.OwnerID, + &i.OrganizationID, + &i.CreatedAt, + &i.Name, + &i.Mimetype, + ) + return i, err +} + const deleteChatModelConfigByID = `-- name: DeleteChatModelConfigByID :exec UPDATE chat_model_configs diff --git a/coderd/database/queries/chatfiles.sql b/coderd/database/queries/chatfiles.sql new file mode 100644 index 0000000000..5cb2ad89fe --- /dev/null +++ b/coderd/database/queries/chatfiles.sql @@ -0,0 +1,10 @@ +-- name: InsertChatFile :one +INSERT INTO chat_files (owner_id, organization_id, name, mimetype, data) +VALUES (@owner_id::uuid, @organization_id::uuid, @name::text, @mimetype::text, @data::bytea) +RETURNING id, owner_id, organization_id, created_at, name, mimetype; + +-- name: GetChatFileByID :one +SELECT * FROM chat_files WHERE id = @id::uuid; + +-- name: GetChatFilesByIDs :many +SELECT * FROM chat_files WHERE id = ANY(@ids::uuid[]); diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index e7f8489915..0ecf890017 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -15,6 +15,7 @@ const ( UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id); UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id); + UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id); UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id); UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id); UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id); diff --git a/codersdk/chats.go b/codersdk/chats.go index 8175ae5bc7..e3b81caa9c 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "mime" "net/http" "net/url" "strings" @@ -96,6 +97,7 @@ type ChatMessagePart struct { Title string `json:"title,omitempty"` MediaType string `json:"media_type,omitempty"` Data []byte `json:"data,omitempty"` + FileID uuid.NullUUID `json:"file_id,omitempty" format:"uuid"` } // ChatInputPartType represents an input part type for user chat input. @@ -103,12 +105,14 @@ type ChatInputPartType string const ( ChatInputPartTypeText ChatInputPartType = "text" + ChatInputPartTypeFile ChatInputPartType = "file" ) // ChatInputPart is a single user input part for creating a chat. type ChatInputPart struct { - Type ChatInputPartType `json:"type"` - Text string `json:"text,omitempty"` + Type ChatInputPartType `json:"type"` + Text string `json:"text,omitempty"` + FileID uuid.UUID `json:"file_id,omitempty" format:"uuid"` } // CreateChatRequest is the request to create a new chat. @@ -141,6 +145,11 @@ type CreateChatMessageResponse struct { Queued bool `json:"queued"` } +// UploadChatFileResponse is the response from uploading a chat file. +type UploadChatFileResponse struct { + ID uuid.UUID `json:"id" format:"uuid"` +} + // ChatWithMessages is a chat along with its messages. type ChatWithMessages struct { Chat Chat `json:"chat"` @@ -938,6 +947,42 @@ func (c *Client) GetChatDiffContents(ctx context.Context, chatID uuid.UUID) (Cha return diff, json.NewDecoder(res.Body).Decode(&diff) } +// UploadChatFile uploads a file for use in chat messages. +func (c *Client) UploadChatFile(ctx context.Context, organizationID uuid.UUID, contentType string, filename string, rd io.Reader) (UploadChatFileResponse, error) { + res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/files?organization=%s", organizationID), rd, func(r *http.Request) { + r.Header.Set("Content-Type", contentType) + if filename != "" { + r.Header.Set("Content-Disposition", mime.FormatMediaType("attachment", map[string]string{"filename": filename})) + } + }) + if err != nil { + return UploadChatFileResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusCreated { + return UploadChatFileResponse{}, ReadBodyAsError(res) + } + var resp UploadChatFileResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} + +// GetChatFile retrieves a previously uploaded chat file by ID. +func (c *Client) GetChatFile(ctx context.Context, fileID uuid.UUID) ([]byte, string, error) { + res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/files/%s", fileID), nil) + if err != nil { + return nil, "", err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, "", ReadBodyAsError(res) + } + data, err := io.ReadAll(res.Body) + if err != nil { + return nil, "", err + } + return data, res.Header.Get("Content-Type"), nil +} + func formatChatStreamResponseError(response Response) string { message := strings.TrimSpace(response.Message) detail := strings.TrimSpace(response.Detail) diff --git a/docs/ai-coder/agents/index.md b/docs/ai-coder/agents/index.md index 92d62520e4..3c426e4d40 100644 --- a/docs/ai-coder/agents/index.md +++ b/docs/ai-coder/agents/index.md @@ -132,6 +132,18 @@ are queued and delivered when the agent completes its current step, so there is no need to wait for a response before providing additional context or changing direction. +### Image attachments + +Users can attach images to chat messages by pasting from the clipboard, dragging +files into the input area, or using the attachment button. Supported formats are +PNG, JPEG, GIF, and WebP up to 10 MB per file. Images are sent to the model as +multimodal content alongside the text prompt. + +This is useful for sharing screenshots of errors, UI mockups, terminal output, +or other visual context that helps the agent understand the task. Messages can +contain images alone or combined with text. Image attachments require a model +that supports vision input. + ## Security benefits of the control plane architecture Running the agent loop in the control plane rather than inside the developer diff --git a/docs/reference/api/chats.md b/docs/reference/api/chats.md index 655993b962..073b509a08 100644 --- a/docs/reference/api/chats.md +++ b/docs/reference/api/chats.md @@ -1,5 +1,83 @@ # Chats +## Upload a chat file + +### Code samples + +```shell +# Example request using curl +curl -X POST http://coder-server:8080/api/v2/chats/files?organization=497f6eca-6276-4993-bfeb-53cbbbba6f08 \ + -H 'Accept: application/json' \ + -H 'Content-Type: string' \ + -H 'Coder-Session-Token: API_KEY' +``` + +`POST /chats/files` + +### Parameters + +| Name | In | Type | Required | Description | +|----------------|--------|--------------|----------|-----------------------------------------------------------------------------------| +| `Content-Type` | header | string | true | Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp) | +| `organization` | query | string(uuid) | true | Organization ID | + +### Example responses + +> 201 Response + +```json +{ + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08" +} +``` + +### Responses + +| Status | Meaning | Description | Schema | +|--------|----------------------------------------------------------------------------|--------------------------|------------------------------------------------------------------------------| +| 201 | [Created](https://tools.ietf.org/html/rfc7231#section-6.3.2) | Created | [codersdk.UploadChatFileResponse](schemas.md#codersdkuploadchatfileresponse) | +| 400 | [Bad Request](https://tools.ietf.org/html/rfc7231#section-6.5.1) | Bad Request | [codersdk.Response](schemas.md#codersdkresponse) | +| 401 | [Unauthorized](https://tools.ietf.org/html/rfc7235#section-3.1) | Unauthorized | [codersdk.Response](schemas.md#codersdkresponse) | +| 413 | [Payload Too Large](https://tools.ietf.org/html/rfc7231#section-6.5.11) | Request Entity Too Large | [codersdk.Response](schemas.md#codersdkresponse) | +| 500 | [Internal Server Error](https://tools.ietf.org/html/rfc7231#section-6.6.1) | Internal Server Error | [codersdk.Response](schemas.md#codersdkresponse) | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). + +## Get a chat file + +### Code samples + +```shell +# Example request using curl +curl -X GET http://coder-server:8080/api/v2/chats/files/{file} \ + -H 'Accept: */*' \ + -H 'Coder-Session-Token: API_KEY' +``` + +`GET /chats/files/{file}` + +### Parameters + +| Name | In | Type | Required | Description | +|--------|------|--------------|----------|-------------| +| `file` | path | string(uuid) | true | File ID | + +### Example responses + +> 400 Response + +### Responses + +| Status | Meaning | Description | Schema | +|--------|----------------------------------------------------------------------------|-----------------------|--------------------------------------------------| +| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | | +| 400 | [Bad Request](https://tools.ietf.org/html/rfc7231#section-6.5.1) | Bad Request | [codersdk.Response](schemas.md#codersdkresponse) | +| 401 | [Unauthorized](https://tools.ietf.org/html/rfc7235#section-3.1) | Unauthorized | [codersdk.Response](schemas.md#codersdkresponse) | +| 404 | [Not Found](https://tools.ietf.org/html/rfc7231#section-6.5.4) | Not Found | [codersdk.Response](schemas.md#codersdkresponse) | +| 500 | [Internal Server Error](https://tools.ietf.org/html/rfc7231#section-6.6.1) | Internal Server Error | [codersdk.Response](schemas.md#codersdkresponse) | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). + ## Archive a chat ### Code samples diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index cd67ea783d..c113917502 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -9847,6 +9847,20 @@ If the schedule is empty, the user will be updated to use the default schedule.| |----------|---------|----------|--------------|-------------| | `ttl_ms` | integer | false | | | +## codersdk.UploadChatFileResponse + +```json +{ + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08" +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|------|--------|----------|--------------|-------------| +| `id` | string | false | | | + ## codersdk.UploadResponse ```json diff --git a/site/src/api/api.ts b/site/src/api/api.ts index b8849444a9..36c7d4ae9e 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -2296,6 +2296,23 @@ class ApiMethods { return response.data; }; + uploadChatFile = async ( + file: File, + organizationId: string, + ): Promise => { + const response = await this.axios.post( + `/api/experimental/chats/files?organization=${organizationId}`, + file, + { + headers: { + "Content-Type": file.type || "application/octet-stream", + "Content-Disposition": `attachment; filename="${file.name}"`, + }, + }, + ); + return response.data; + }; + getTemplateVersionLogs = async ( versionId: string, ): Promise => { diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index fb58a6a963..4efd1153d4 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1120,12 +1120,13 @@ export interface ChatGitChange { export interface ChatInputPart { readonly type: ChatInputPartType; readonly text?: string; + readonly file_id?: string; } // From codersdk/chats.go -export type ChatInputPartType = "text"; +export type ChatInputPartType = "file" | "text"; -export const ChatInputPartTypes: ChatInputPartType[] = ["text"]; +export const ChatInputPartTypes: ChatInputPartType[] = ["file", "text"]; // From codersdk/chats.go /** @@ -1161,6 +1162,7 @@ export interface ChatMessagePart { readonly title?: string; readonly media_type?: string; readonly data?: string; + readonly file_id?: string; } // From codersdk/chats.go @@ -6556,6 +6558,14 @@ export interface UpdateWorkspaceTTLRequest { readonly ttl_ms: number | null; } +// From codersdk/chats.go +/** + * UploadChatFileResponse is the response from uploading a chat file. + */ +export interface UploadChatFileResponse { + readonly id: string; +} + // From codersdk/files.go /** * UploadResponse contains the hash to reference the uploaded file. diff --git a/site/src/components/ChatMessageInput/ChatMessageInput.tsx b/site/src/components/ChatMessageInput/ChatMessageInput.tsx index 23ad9c3c31..77e7cecea9 100644 --- a/site/src/components/ChatMessageInput/ChatMessageInput.tsx +++ b/site/src/components/ChatMessageInput/ChatMessageInput.tsx @@ -57,8 +57,11 @@ const DisableFormattingPlugin: FC = memo(function DisableFormattingPlugin() { }); // Intercepts paste events and inserts clipboard content as plain text, -// stripping any rich-text formatting. -const PasteSanitizationPlugin: FC = memo(function PasteSanitizationPlugin() { +// stripping any rich-text formatting. Image files are forwarded to +// the parent via the onFilePaste callback instead of being inserted. +const PasteSanitizationPlugin: FC<{ + onFilePaste?: (file: File) => void; +}> = memo(function PasteSanitizationPlugin({ onFilePaste }) { const [editor] = useLexicalComposerContext(); useEffect(() => { @@ -69,6 +72,22 @@ const PasteSanitizationPlugin: FC = memo(function PasteSanitizationPlugin() { const clipboardData = event.clipboardData; if (!clipboardData) return false; + // Check for image files in the clipboard (e.g. pasted + // screenshots). Forward them to the parent via callback + // instead of inserting text. + if (onFilePaste && clipboardData.files.length > 0) { + const images = Array.from(clipboardData.files).filter((f) => + f.type.startsWith("image/"), + ); + if (images.length > 0) { + event.preventDefault(); + for (const file of images) { + onFilePaste(file); + } + return true; + } + } + const text = clipboardData.getData("text/plain"); if (!text) return false; @@ -106,7 +125,7 @@ const PasteSanitizationPlugin: FC = memo(function PasteSanitizationPlugin() { }, COMMAND_PRIORITY_HIGH, ); - }, [editor]); + }, [editor, onFilePaste]); return null; }); @@ -217,6 +236,7 @@ interface ChatMessageInputProps onChange?: (content: string) => void; rows?: number; onEnter?: () => void; + onFilePaste?: (file: File) => void; disabled?: boolean; autoFocus?: boolean; "aria-label"?: string; @@ -245,6 +265,7 @@ const ChatMessageInput = memo( onChange, rows, onEnter, + onFilePaste, disabled, autoFocus, "aria-label": ariaLabel, @@ -392,7 +413,7 @@ const ChatMessageInput = memo( /> - + diff --git a/site/src/pages/AgentsPage/AgentChatInput.stories.tsx b/site/src/pages/AgentsPage/AgentChatInput.stories.tsx index 7f3a287e6f..47245db2f6 100644 --- a/site/src/pages/AgentsPage/AgentChatInput.stories.tsx +++ b/site/src/pages/AgentsPage/AgentChatInput.stories.tsx @@ -1,6 +1,6 @@ import type { Meta, StoryObj } from "@storybook/react-vite"; import { expect, fn, userEvent, waitFor, within } from "storybook/test"; -import { AgentChatInput } from "./AgentChatInput"; +import { AgentChatInput, type UploadState } from "./AgentChatInput"; const defaultModelOptions = [ { @@ -144,3 +144,80 @@ export const LongContentScrollable: Story = { initialValue: longContent, }, }; + +// Tiny 1x1 transparent PNG as data URI for attachment previews. +const TINY_PNG = + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="; + +const createMockFile = (name: string, type: string) => + new File(["mock-data"], name, { type }); + +export const WithAttachments: Story = { + args: (() => { + const file1 = createMockFile("screenshot.png", "image/png"); + const file2 = createMockFile("diagram.jpg", "image/jpeg"); + const attachments = [file1, file2]; + return { + attachments, + uploadStates: new Map([ + [file1, { status: "uploaded", fileId: "f1" }], + [file2, { status: "uploaded", fileId: "f2" }], + ]), + previewUrls: new Map([ + [file1, TINY_PNG], + [file2, TINY_PNG], + ]), + onAttach: fn(), + onRemoveAttachment: fn(), + initialValue: "Here are the images", + }; + })(), +}; + +export const WithUploadingAttachment: Story = { + args: (() => { + const file = createMockFile("uploading.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "uploading" }], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + onAttach: fn(), + onRemoveAttachment: fn(), + initialValue: "Waiting for upload", + }; + })(), +}; + +export const WithAttachmentError: Story = { + args: (() => { + const file = createMockFile("broken.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "error", error: "Upload failed: server error" }], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + onAttach: fn(), + onRemoveAttachment: fn(), + initialValue: "Upload had an error", + }; + })(), +}; + +export const AttachmentsOnly: Story = { + args: (() => { + const file = createMockFile("photo.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "uploaded", fileId: "f-only" }], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + onAttach: fn(), + onRemoveAttachment: fn(), + initialValue: "", + }; + })(), +}; diff --git a/site/src/pages/AgentsPage/AgentChatInput.tsx b/site/src/pages/AgentsPage/AgentChatInput.tsx index a67775dd8a..9df0d0489e 100644 --- a/site/src/pages/AgentsPage/AgentChatInput.tsx +++ b/site/src/pages/AgentsPage/AgentChatInput.tsx @@ -13,14 +13,29 @@ import { TooltipContent, TooltipTrigger, } from "components/Tooltip/Tooltip"; -import { ArrowUpIcon, Loader2Icon, Square, XIcon } from "lucide-react"; +import { + AlertTriangleIcon, + ArrowUpIcon, + ImageIcon, + Loader2Icon, + Square, + XIcon, +} from "lucide-react"; +import type React from "react"; import { memo, type ReactNode, useCallback, useRef, useState } from "react"; import { cn } from "utils/cn"; +import { ImageLightbox } from "./ImageLightbox"; import { formatProviderLabel } from "./modelOptions"; import { QueuedMessagesList } from "./QueuedMessagesList"; export type { ChatMessageInputRef } from "components/ChatMessageInput/ChatMessageInput"; +export type UploadState = { + status: "uploading" | "uploaded" | "error"; + fileId?: string; + error?: string; +}; + export interface AgentContextUsage { readonly usedTokens?: number; readonly contextLimitTokens?: number; @@ -76,6 +91,11 @@ interface AgentChatInputProps { // Pass `null` to render fallback values (e.g. when limit is unknown). // Omit entirely to hide the indicator. contextUsage?: AgentContextUsage | null; + attachments?: File[]; + onAttach?: (files: File[]) => void; + onRemoveAttachment?: (index: number) => void; + uploadStates?: Map; + previewUrls?: Map; } const hasFiniteTokenValue = (value: number | undefined): value is number => @@ -201,6 +221,97 @@ const ContextUsageIndicator = memo<{ usage: AgentContextUsage | null }>( ); ContextUsageIndicator.displayName = "ContextUsageIndicator"; +/** Renders an image thumbnail from a pre-created preview URL. */ +export const ImageThumbnail = memo<{ + previewUrl: string; + name: string; + className?: string; +}>(({ previewUrl, name, className }) => ( + {name} +)); +ImageThumbnail.displayName = "ImageThumbnail"; + +/** Renders a horizontal strip of attachment thumbnails above the input. */ +export const AttachmentPreview = memo<{ + attachments: File[]; + onRemove: (index: number) => void; + uploadStates?: Map; + previewUrls?: Map; + onPreview?: (url: string) => void; +}>(({ attachments, onRemove, uploadStates, previewUrls, onPreview }) => { + if (attachments.length === 0) return null; + + return ( +
+ {attachments.map((file, index) => { + const uploadState = uploadStates?.get(file); + const previewUrl = previewUrls?.get(file) ?? ""; + return ( +
+ {file.type.startsWith("image/") && previewUrl ? ( + + ) : ( +
+ {file.name.split(".").pop()?.toUpperCase() || "FILE"} +
+ )} + {uploadState?.status === "uploading" && ( +
+ +
+ )} + {uploadState?.status === "error" && ( + + +
+ +
+
+ +

+ {uploadState.error ?? "Upload failed"} +

+
+
+ )} + +
+ ); + })} +
+ ); +}); +AttachmentPreview.displayName = "AttachmentPreview"; + export const AgentChatInput = memo( ({ onSend, @@ -230,8 +341,14 @@ export const AgentChatInput = memo( isEditingHistoryMessage = false, onCancelHistoryEdit, contextUsage, + attachments = [], + onAttach, + onRemoveAttachment, + uploadStates, + previewUrls, }) => { const internalRef = useRef(null); + const [previewImage, setPreviewImage] = useState(null); // Merge the external inputRef with our internal ref so both // point to the same ChatMessageInputRef instance. @@ -251,6 +368,57 @@ export const AgentChatInput = memo( [inputRef], ); + const fileInputRef = useRef(null); + + const handleFileSelect = useCallback( + (e: React.ChangeEvent) => { + if (e.target.files && onAttach) { + onAttach(Array.from(e.target.files)); + } + // Reset so the same file can be selected again. + e.target.value = ""; + }, + [onAttach], + ); + + const handleFilePaste = useCallback( + (file: File) => { + onAttach?.([file]); + }, + [onAttach], + ); + + // Drag-and-drop support for image files. + const [isDragging, setIsDragging] = useState(false); + + const handleDragOver = useCallback((e: React.DragEvent) => { + e.preventDefault(); + if (e.dataTransfer.types.includes("Files")) { + setIsDragging(true); + } + }, []); + + const handleDragLeave = useCallback((e: React.DragEvent) => { + if (!e.currentTarget.contains(e.relatedTarget as Node)) { + setIsDragging(false); + } + }, []); + + const handleDrop = useCallback( + (e: React.DragEvent) => { + e.preventDefault(); + setIsDragging(false); + if (!onAttach || !e.dataTransfer.files.length) return; + const images = Array.from(e.dataTransfer.files).filter((f) => + f.type.startsWith("image/"), + ); + if (images.length > 0) { + onAttach(images); + } + }, + [onAttach], + ); + // Track whether the editor has content so we can gate the // send button without a controlled value prop. const [hasContent, setHasContent] = useState(() => !!initialValue?.trim()); @@ -275,7 +443,18 @@ export const AgentChatInput = memo( } } - const canSend = !isDisabled && !isLoading && hasModelOptions && hasContent; + const isUploading = attachments.some( + (f) => uploadStates?.get(f)?.status === "uploading", + ); + const hasUploadedAttachments = attachments.some( + (f) => uploadStates?.get(f)?.status === "uploaded", + ); + const canSend = + !isDisabled && + !isLoading && + hasModelOptions && + (hasContent || hasUploadedAttachments) && + !isUploading; const handleSubmit = useCallback(() => { const text = internalRef.current?.getValue()?.trim() ?? ""; @@ -284,6 +463,7 @@ export const AgentChatInput = memo( // promote the first one instead of submitting. if ( !text && + !hasUploadedAttachments && !isDisabled && !isLoading && queuedMessages.length > 0 && @@ -293,7 +473,12 @@ export const AgentChatInput = memo( return; } - if (!text || isDisabled || isLoading || !hasModelOptions) { + if ( + (!text && !hasUploadedAttachments) || + isDisabled || + isLoading || + !hasModelOptions + ) { return; } @@ -303,6 +488,7 @@ export const AgentChatInput = memo( isDisabled, isLoading, hasModelOptions, + hasUploadedAttachments, onSend, queuedMessages, onPromoteQueuedMessage, @@ -348,8 +534,14 @@ export const AgentChatInput = memo( /> )}
{editingQueuedMessageID !== null && (
@@ -388,8 +580,18 @@ export const AgentChatInput = memo(
)} + {onRemoveAttachment && ( + + )} ( {contextUsage !== undefined && ( )} + {onAttach && ( + <> + + + + )} {isStreaming && onInterrupt && (
); - return content; + return ( + <> + {content} + {previewImage && ( + setPreviewImage(null)} + /> + )} + + ); }, ); AgentChatInput.displayName = "AgentChatInput"; diff --git a/site/src/pages/AgentsPage/AgentDetail.test.ts b/site/src/pages/AgentsPage/AgentDetail.test.ts index f6549df834..82b7f8f0cb 100644 --- a/site/src/pages/AgentsPage/AgentDetail.test.ts +++ b/site/src/pages/AgentsPage/AgentDetail.test.ts @@ -110,7 +110,7 @@ describe("useConversationEditingState", () => { await act(async () => { result.current.handleSendFromInput("hello"); await vi.waitFor(() => { - expect(onSend).toHaveBeenCalledWith("hello", undefined); + expect(onSend).toHaveBeenCalledWith("hello", undefined, undefined); }); }); diff --git a/site/src/pages/AgentsPage/AgentDetail.tsx b/site/src/pages/AgentsPage/AgentDetail.tsx index 91bc8ff04a..af7506304b 100644 --- a/site/src/pages/AgentsPage/AgentDetail.tsx +++ b/site/src/pages/AgentsPage/AgentDetail.tsx @@ -22,6 +22,7 @@ import { getVSCodeHref, openAppInNewWindow, } from "modules/apps/apps"; +import { useDashboard } from "modules/dashboard/useDashboard"; import { type FC, useCallback, @@ -35,7 +36,11 @@ import { useNavigate, useOutletContext, useParams } from "react-router"; import { toast } from "sonner"; import { cn } from "utils/cn"; import { pageTitle } from "utils/page"; -import { AgentChatInput, type ChatMessageInputRef } from "./AgentChatInput"; +import { + AgentChatInput, + type ChatMessageInputRef, + type UploadState, +} from "./AgentChatInput"; import { selectChatStatus, selectHasStreamState, @@ -74,6 +79,7 @@ import { } from "./modelOptions"; import { RightPanel } from "./RightPanel"; import { SidebarTabView } from "./SidebarTabView"; +import { useFileAttachments } from "./useFileAttachments"; import { useGitWatcher } from "./useGitWatcher"; const noopSetChatErrorReason: AgentsOutletContext["setChatErrorReason"] = @@ -99,7 +105,11 @@ interface AgentDetailTimelineProps { store: ChatStoreHandle; chatID: string; persistedErrorReason: string | undefined; - onEditUserMessage?: (messageId: number, text: string) => void; + onEditUserMessage?: ( + messageId: number, + text: string, + fileBlocks?: Array<{ mediaType: string; data?: string }>, + ) => void; editingMessageId?: number | null; savingMessageId?: number | null; } @@ -186,7 +196,7 @@ const AgentDetailTimeline: FC = ({ interface AgentDetailInputProps { store: ChatStoreHandle; compressionThreshold: number | undefined; - onSend: (message: string) => void; + onSend: (message: string, fileIds?: string[]) => void; onDeleteQueuedMessage: (id: number) => Promise; onPromoteQueuedMessage: (id: number) => Promise; onInterrupt: () => void; @@ -210,6 +220,13 @@ interface AgentDetailInputProps { onCancelQueueEdit: () => void; isEditingHistoryMessage: boolean; onCancelHistoryEdit: () => void; + // File blocks from the message being edited, converted to + // File objects and pre-populated into attachments. + editingFileBlocks?: Array<{ + mediaType: string; + data?: string; + fileId?: string; + }>; } const AgentDetailInput: FC = ({ @@ -237,6 +254,7 @@ const AgentDetailInput: FC = ({ onCancelQueueEdit, isEditingHistoryMessage, onCancelHistoryEdit, + editingFileBlocks, }) => { const messagesByID = useChatSelector(store, selectMessagesByID); const orderedMessageIDs = useChatSelector(store, selectOrderedMessageIDs); @@ -251,6 +269,8 @@ const AgentDetailInput: FC = ({ .filter(isChatMessage), [messagesByID, orderedMessageIDs], ); + const { organizations } = useDashboard(); + const organizationId = organizations[0]?.id; const latestContextUsage = useMemo(() => { const usage = getLatestContextUsage(messages); if (!usage) { @@ -258,12 +278,96 @@ const AgentDetailInput: FC = ({ } return { ...usage, compressionThreshold }; }, [messages, compressionThreshold]); + const { + attachments, + uploadStates, + previewUrls, + handleAttach, + handleRemoveAttachment, + resetAttachments, + setAttachments, + setPreviewUrls, + setUploadStates, + } = useFileAttachments(organizationId); + // Pre-populate attachments from existing file blocks when + // entering edit mode on a message with images. + useEffect(() => { + if (!editingFileBlocks || editingFileBlocks.length === 0) { + // Clear attachments when exiting edit mode. + setAttachments([]); + setUploadStates(new Map()); + setPreviewUrls(new Map()); + return; + } + const files = editingFileBlocks.map((block, i) => { + const ext = block.mediaType.split("/")[1] ?? "png"; + // Empty File used as a Map key only, its content is never + // read because the existing fileId is reused at send time. + return new File([], `attachment-${i}.${ext}`, { + type: block.mediaType, + }); + }); + setAttachments(files); + setPreviewUrls( + new Map( + files.map((f, i) => [ + f, + `/api/experimental/chats/files/${editingFileBlocks[i].fileId}`, + ]), + ), + ); + const newUploadStates = new Map(); + for (const [i, file] of files.entries()) { + const block = editingFileBlocks[i]; + if (block.fileId) { + newUploadStates.set(file, { + status: "uploaded", + fileId: block.fileId, + }); + } + } + setUploadStates(newUploadStates); + }, [editingFileBlocks, setAttachments, setPreviewUrls, setUploadStates]); + const isStreaming = hasStreamState || chatStatus === "running" || chatStatus === "pending"; return ( { + void (async () => { + try { + // Collect file IDs from already-uploaded attachments. + // Skip files in error state (e.g. too large). + const fileIds: string[] = []; + let skippedErrors = 0; + for (const file of attachments) { + const state = uploadStates.get(file); + if (state?.status === "error") { + skippedErrors++; + continue; + } + if (state?.status === "uploaded" && state.fileId) { + fileIds.push(state.fileId); + } + } + if (skippedErrors > 0) { + toast.warning( + `${skippedErrors} attachment${skippedErrors > 1 ? "s" : ""} could not be sent (upload failed)`, + ); + } + await onSend(message, fileIds.length > 0 ? fileIds : undefined); + resetAttachments(); + } catch { + // Attachments preserved for retry on failure. + } + })(); + }} + attachments={attachments} + onAttach={handleAttach} + onRemoveAttachment={handleRemoveAttachment} + uploadStates={uploadStates} + previewUrls={previewUrls} inputRef={inputRef} initialValue={initialValue} onContentChange={onContentChange} @@ -295,7 +399,11 @@ const AgentDetailInput: FC = ({ /** @internal Exported for testing. */ export function useConversationEditingState(deps: { chatID: string | undefined; - onSend: (message: string, editedMessageID?: number) => Promise; + onSend: ( + message: string, + fileIds?: string[], + editedMessageID?: number, + ) => Promise; onDeleteQueuedMessage: (id: number) => Promise; chatInputRef: React.RefObject; inputValueRef: React.RefObject; @@ -321,15 +429,23 @@ export function useConversationEditingState(deps: { const [draftBeforeHistoryEdit, setDraftBeforeHistoryEdit] = useState< string | null >(null); + const [editingFileBlocks, setEditingFileBlocks] = useState< + Array<{ mediaType: string; data?: string; fileId?: string }> + >([]); const handleEditUserMessage = useCallback( - (messageId: number, text: string) => { + ( + messageId: number, + text: string, + fileBlocks?: Array<{ mediaType: string; data?: string; fileId?: string }>, + ) => { setDraftBeforeHistoryEdit((prev) => editingMessageId !== null ? prev : inputValueRef.current, ); setEditingMessageId(messageId); setEditorInitialValue(text); inputValueRef.current = text; + setEditingFileBlocks(fileBlocks ?? []); }, [editingMessageId, inputValueRef], ); @@ -339,6 +455,7 @@ export function useConversationEditingState(deps: { inputValueRef.current = draftBeforeHistoryEdit ?? ""; setEditingMessageId(null); setDraftBeforeHistoryEdit(null); + setEditingFileBlocks([]); }, [draftBeforeHistoryEdit, inputValueRef]); // -- Queue editing state -- @@ -371,29 +488,29 @@ export function useConversationEditingState(deps: { // Wraps the parent onSend to clear local input/editing state // and handle queue-edit deletion. const handleSendFromInput = useCallback( - (message: string) => { + async (message: string, fileIds?: string[]) => { const editedMessageID = editingMessageId !== null ? editingMessageId : undefined; const queueEditID = editingQueuedMessageID; - void onSend(message, editedMessageID).then(() => { - // Clear input and editing state on success. - chatInputRef.current?.clear(); - chatInputRef.current?.focus(); - inputValueRef.current = ""; - if (typeof window !== "undefined" && draftStorageKey) { - localStorage.removeItem(draftStorageKey); - } - if (editingMessageId !== null) { - setEditingMessageId(null); - setDraftBeforeHistoryEdit(null); - } - if (queueEditID !== null) { - setEditingQueuedMessageID(null); - setDraftBeforeQueueEdit(null); - void onDeleteQueuedMessage(queueEditID); - } - }); + await onSend(message, fileIds, editedMessageID); + // Clear input and editing state on success. + chatInputRef.current?.clear(); + chatInputRef.current?.focus(); + inputValueRef.current = ""; + if (typeof window !== "undefined" && draftStorageKey) { + localStorage.removeItem(draftStorageKey); + } + if (editingMessageId !== null) { + setEditingMessageId(null); + setDraftBeforeHistoryEdit(null); + setEditingFileBlocks([]); + } + if (queueEditID !== null) { + setEditingQueuedMessageID(null); + setDraftBeforeQueueEdit(null); + void onDeleteQueuedMessage(queueEditID); + } }, [ chatInputRef, @@ -425,6 +542,7 @@ export function useConversationEditingState(deps: { chatInputRef, editorInitialValue, editingMessageId, + editingFileBlocks, handleEditUserMessage, handleCancelHistoryEdit, editingQueuedMessageID, @@ -658,16 +776,26 @@ const AgentDetail: FC = () => { interruptMutation.isPending; const isInputDisabled = !hasModelOptions || isArchived; - const handleSend = async (message: string, editedMessageID?: number) => { - if ( - !message.trim() || - isSubmissionPending || - !agentId || - !hasModelOptions - ) { + const handleSend = async ( + message: string, + fileIds?: string[], + editedMessageID?: number, + ) => { + const hasContent = message.trim() || (fileIds && fileIds.length > 0); + if (!hasContent || isSubmissionPending || !agentId || !hasModelOptions) { return; } - const content: TypesGen.ChatInputPart[] = [{ type: "text", text: message }]; + const content: TypesGen.ChatInputPart[] = []; + if (message.trim()) { + content.push({ type: "text", text: message }); + } + + // Add pre-uploaded file references. + if (fileIds && fileIds.length > 0) { + for (const fileId of fileIds) { + content.push({ type: "file", file_id: fileId }); + } + } if (editedMessageID !== undefined) { const request: TypesGen.EditChatMessageRequest = { content }; clearChatErrorReason(agentId); @@ -1091,6 +1219,7 @@ const AgentDetail: FC = () => { onCancelQueueEdit={editing.handleCancelQueueEdit} isEditingHistoryMessage={editing.editingMessageId !== null} onCancelHistoryEdit={editing.handleCancelHistoryEdit} + editingFileBlocks={editing.editingFileBlocks} /> diff --git a/site/src/pages/AgentsPage/AgentDetail/ConversationTimeline.tsx b/site/src/pages/AgentsPage/AgentDetail/ConversationTimeline.tsx index aba7a8844d..5867d2328f 100644 --- a/site/src/pages/AgentsPage/AgentDetail/ConversationTimeline.tsx +++ b/site/src/pages/AgentsPage/AgentDetail/ConversationTimeline.tsx @@ -18,6 +18,8 @@ import { useState, } from "react"; import { cn } from "utils/cn"; +import { ImageThumbnail } from "../AgentChatInput"; +import { ImageLightbox } from "../ImageLightbox"; import { useSmoothStreamingText } from "./SmoothText"; import type { MergedTool, @@ -102,6 +104,7 @@ type RenderBlockListParams = { isStreaming?: boolean; subagentTitles?: Map; subagentStatusOverrides?: Map; + onImageClick?: (src: string) => void; }; // Wrapper that runs the smooth-streaming jitter buffer on a single @@ -132,6 +135,7 @@ function renderBlockList({ isStreaming = false, subagentTitles, subagentStatusOverrides, + onImageClick, }: RenderBlockListParams): RenderBlockListResult { const renderedToolIDs = new Set(); const elements = blocks @@ -194,6 +198,30 @@ function renderBlockList({ /> ); } + case "file": + if (block.mediaType.startsWith("image/")) { + const src = block.fileId + ? `/api/experimental/chats/files/${block.fileId}` + : `data:${block.mediaType};base64,${block.data}`; + return ( + + ); + } + return null; default: return null; } @@ -205,7 +233,11 @@ function renderBlockList({ const ChatMessageItem = memo<{ message: TypesGen.ChatMessage; parsed: ParsedMessageContent; - onEditUserMessage?: (messageId: number, text: string) => void; + onEditUserMessage?: ( + messageId: number, + text: string, + fileBlocks?: Array<{ mediaType: string; data?: string; fileId?: string }>, + ) => void; editingMessageId?: number | null; savingMessageId?: number | null; // When true, renders a gradient overlay inside the bubble @@ -223,6 +255,7 @@ const ChatMessageItem = memo<{ }) => { const isUser = message.role === "user"; const isSavingMessage = savingMessageId === message.id; + const [previewImage, setPreviewImage] = useState(null); const toolByID = new Map(parsed.tools.map((tool) => [tool.id, tool])); if ( @@ -243,82 +276,137 @@ const ChatMessageItem = memo<{ blocks: parsed.blocks, toolByID, keyPrefix: String(message.id), + onImageClick: setPreviewImage, }); const remainingTools = parsed.tools.filter( (tool) => !renderedToolIDs.has(tool.id), ); return ( - - {isUser ? ( - - onEditUserMessage(message.id, parsed.markdown || "") - : undefined - } - > -
- {parsed.markdown || ""} - {isSavingMessage && ( - + + {isUser ? ( + + { + const fileBlocks = parsed.blocks.filter( + (b): b is Extract => + b.type === "file" && + b.mediaType.startsWith("image/"), + ); + onEditUserMessage( + message.id, + parsed.markdown || "", + fileBlocks.length > 0 ? fileBlocks : undefined, + ); + } + : undefined + } + > +
+ + {parsed.markdown || ""} + + {isSavingMessage && ( + + )} +
+ {(() => { + const imageBlocks = parsed.blocks.filter( + (b): b is Extract => + b.type === "file" && b.mediaType.startsWith("image/"), + ); + if (imageBlocks.length === 0) return null; + return ( +
+ {imageBlocks.map((block, i) => { + const src = block.fileId + ? `/api/experimental/chats/files/${block.fileId}` + : `data:${block.mediaType};base64,${block.data}`; + return ( + + ); + })} +
+ ); + })()} + {fadeFromBottom && ( +
)} -
- {fadeFromBottom && ( -
- )} - - - ) : ( - - -
- {orderedBlocks} - {remainingTools.map((tool) => ( - - ))} - {!hasRenderableContent && ( -
- Message has no renderable content. -
- )} -
-
-
+ + + ) : ( + + +
+ {orderedBlocks} + {remainingTools.map((tool) => ( + + ))} + {!hasRenderableContent && ( +
+ Message has no renderable content. +
+ )} +
+
+
+ )} + + {previewImage && ( + setPreviewImage(null)} + /> )} - + ); }, ); @@ -405,7 +493,11 @@ StreamingOutput.displayName = "StreamingOutput"; const StickyUserMessage: FC<{ message: TypesGen.ChatMessage; parsed: ParsedMessageContent; - onEditUserMessage?: (messageId: number, text: string) => void; + onEditUserMessage?: ( + messageId: number, + text: string, + fileBlocks?: Array<{ mediaType: string; data?: string; fileId?: string }>, + ) => void; editingMessageId?: number | null; savingMessageId?: number | null; }> = ({ @@ -540,8 +632,16 @@ const StickyUserMessage: FC<{ }, [isStuck]); const handleEditUserMessage = onEditUserMessage - ? (messageId: number, text: string) => { - onEditUserMessage(messageId, text); + ? ( + messageId: number, + text: string, + fileBlocks?: Array<{ + mediaType: string; + data?: string; + fileId?: string; + }>, + ) => { + onEditUserMessage(messageId, text, fileBlocks); requestAnimationFrame(() => { const sentinel = sentinelRef.current; if (!sentinel) return; @@ -653,7 +753,11 @@ type ConversationTimelineProps = { retryState?: { attempt: number; error: string } | null; isAwaitingFirstStreamChunk: boolean; detailErrorMessage?: string | null; - onEditUserMessage?: (messageId: number, text: string) => void; + onEditUserMessage?: ( + messageId: number, + text: string, + fileBlocks?: Array<{ mediaType: string; data?: string; fileId?: string }>, + ) => void; editingMessageId?: number | null; savingMessageId?: number | null; }; diff --git a/site/src/pages/AgentsPage/AgentDetail/messageParsing.test.ts b/site/src/pages/AgentsPage/AgentDetail/messageParsing.test.ts index edb58526b8..279bd29d26 100644 --- a/site/src/pages/AgentsPage/AgentDetail/messageParsing.test.ts +++ b/site/src/pages/AgentsPage/AgentDetail/messageParsing.test.ts @@ -229,6 +229,40 @@ describe("parseMessageContent", () => { expect(result.toolCalls).toHaveLength(1); expect(result.toolCalls[0].name).toBe("test"); }); + + it("extracts fileId from a file block with file_id", () => { + const result = parseMessageContent([ + { + type: "file", + media_type: "image/png", + file_id: "abc-123-def", + }, + ]); + expect(result.blocks).toHaveLength(1); + expect(result.blocks[0]).toEqual({ + type: "file", + mediaType: "image/png", + data: undefined, + fileId: "abc-123-def", + }); + }); + + it("parses a file block without file_id (backward compat)", () => { + const result = parseMessageContent([ + { + type: "file", + media_type: "image/png", + data: "iVBORw0KGgo=", + }, + ]); + expect(result.blocks).toHaveLength(1); + expect(result.blocks[0]).toEqual({ + type: "file", + mediaType: "image/png", + data: "iVBORw0KGgo=", + fileId: undefined, + }); + }); }); describe("mergeTools", () => { diff --git a/site/src/pages/AgentsPage/AgentDetail/messageParsing.ts b/site/src/pages/AgentsPage/AgentDetail/messageParsing.ts index a4731f414d..fff9eca30c 100644 --- a/site/src/pages/AgentsPage/AgentDetail/messageParsing.ts +++ b/site/src/pages/AgentsPage/AgentDetail/messageParsing.ts @@ -216,6 +216,23 @@ export const parseMessageContent = (content: unknown): ParsedMessageContent => { parsed.blocks = ensureToolBlock(parsed.blocks, id); break; } + case "file": { + const mediaType = asString(typedBlock.media_type); + const data = asString(typedBlock.data); + const fileId = asString(typedBlock.file_id); + if (mediaType && (data || fileId)) { + parsed.blocks = [ + ...parsed.blocks, + { + type: "file", + mediaType, + data: data || undefined, + fileId: fileId || undefined, + }, + ]; + } + break; + } default: { const text = asString(typedBlock.text); parsed.markdown = appendText(parsed.markdown, text); diff --git a/site/src/pages/AgentsPage/AgentDetail/streamState.ts b/site/src/pages/AgentsPage/AgentDetail/streamState.ts index 9681be3f5b..b4291d6ff9 100644 --- a/site/src/pages/AgentsPage/AgentDetail/streamState.ts +++ b/site/src/pages/AgentsPage/AgentDetail/streamState.ts @@ -150,6 +150,26 @@ export const applyMessagePartToStreamState = ( }, }; } + case "file": { + const mediaType = asString(part.media_type); + const data = asString(part.data); + const fileId = asString(part.file_id); + if (!mediaType || (!data && !fileId)) { + return prev; + } + return { + ...nextState, + blocks: [ + ...nextState.blocks, + { + type: "file", + mediaType, + data: data || undefined, + fileId: fileId || undefined, + }, + ], + }; + } default: return prev; } diff --git a/site/src/pages/AgentsPage/AgentDetail/types.ts b/site/src/pages/AgentsPage/AgentDetail/types.ts index efb436e0dc..7e01d864c2 100644 --- a/site/src/pages/AgentsPage/AgentDetail/types.ts +++ b/site/src/pages/AgentsPage/AgentDetail/types.ts @@ -35,6 +35,12 @@ export type RenderBlock = | { type: "tool"; id: string; + } + | { + type: "file"; + mediaType: string; + data?: string; // base64, absent when file_id is available + fileId?: string; }; export type ParsedMessageContent = { diff --git a/site/src/pages/AgentsPage/AgentsPage.stories.tsx b/site/src/pages/AgentsPage/AgentsPage.stories.tsx index c6a2f81958..7b65c9b35f 100644 --- a/site/src/pages/AgentsPage/AgentsPage.stories.tsx +++ b/site/src/pages/AgentsPage/AgentsPage.stories.tsx @@ -1,4 +1,5 @@ import { MockWorkspace } from "testHelpers/entities"; +import { withDashboardProvider } from "testHelpers/storybook"; import type { Meta, StoryObj } from "@storybook/react-vite"; import { API } from "api/api"; import { @@ -26,6 +27,7 @@ const behaviorStorageKey = "agents.system-prompt"; const meta: Meta = { title: "pages/AgentsPage/AgentsEmptyState", component: AgentsEmptyState, + decorators: [withDashboardProvider], args: { onCreateChat: fn(), isCreating: false, diff --git a/site/src/pages/AgentsPage/AgentsPage.tsx b/site/src/pages/AgentsPage/AgentsPage.tsx index cfbd60318c..a00a78c204 100644 --- a/site/src/pages/AgentsPage/AgentsPage.tsx +++ b/site/src/pages/AgentsPage/AgentsPage.tsx @@ -59,6 +59,7 @@ import { } from "./modelOptions"; import { useAgentsPageKeybindings } from "./useAgentsPageKeybindings"; import { useAgentsPWA } from "./useAgentsPWA"; +import { useFileAttachments } from "./useFileAttachments"; import { WebPushButton } from "./WebPushButton"; /** @internal Exported for testing. */ @@ -72,6 +73,7 @@ type ChatModelOption = ModelSelectorOption; type CreateChatOptions = { message: string; + fileIDs?: string[]; workspaceId?: string; model?: string; }; @@ -329,11 +331,20 @@ const AgentsPage: FC = () => { ], ); const handleCreateChat = async (options: CreateChatOptions) => { - const { message, workspaceId, model } = options; + const { message, fileIDs, workspaceId, model } = options; const modelConfigID = (model && modelConfigIDByModelID.get(model)) || nilUUID; + const content: TypesGen.ChatInputPart[] = []; + if (message.trim()) { + content.push({ type: "text", text: message }); + } + if (fileIDs) { + for (const fileID of fileIDs) { + content.push({ type: "file", file_id: fileID }); + } + } const createdChat = await createMutation.mutateAsync({ - content: [{ type: "text", text: message }], + content, workspace_id: workspaceId, model_config_id: modelConfigID, }); @@ -686,6 +697,7 @@ export const AgentsEmptyState: FC = ({ isConfigureAgentsDialogOpen, onConfigureAgentsDialogOpenChange, }) => { + const { organizations } = useDashboard(); const { initialInputValue, handleContentChange, submitDraft, resetDraft } = useEmptyStateDraft(); const initialSystemPrompt = () => { @@ -855,10 +867,11 @@ export const AgentsEmptyState: FC = ({ ); const handleSend = useCallback( - (message: string) => { + async (message: string, fileIDs?: string[]) => { submitDraft(); - void onCreateChat({ + await onCreateChat({ message, + fileIDs, workspaceId: selectedWorkspaceIdRef.current ?? undefined, model: selectedModelRef.current || undefined, }).catch(() => { @@ -877,6 +890,44 @@ export const AgentsEmptyState: FC = ({ ? `${selectedWorkspace.owner_name}/${selectedWorkspace.name}` : undefined; + const { + attachments, + uploadStates, + previewUrls, + handleAttach, + handleRemoveAttachment, + resetAttachments, + } = useFileAttachments(organizations[0]?.id); + + const handleSendWithAttachments = useCallback( + async (message: string) => { + const fileIds: string[] = []; + let skippedErrors = 0; + for (const file of attachments) { + const state = uploadStates.get(file); + if (state?.status === "error") { + skippedErrors++; + continue; + } + if (state?.status === "uploaded" && state.fileId) { + fileIds.push(state.fileId); + } + } + if (skippedErrors > 0) { + toast.warning( + `${skippedErrors} attachment${skippedErrors > 1 ? "s" : ""} could not be sent (upload failed)`, + ); + } + try { + await handleSend(message, fileIds.length > 0 ? fileIds : undefined); + resetAttachments(); + } catch { + // Attachments preserved for retry on failure. + } + }, + [attachments, handleSend, resetAttachments, uploadStates], + ); + return (
@@ -886,7 +937,7 @@ export const AgentsEmptyState: FC = ({ )} = ({ hasModelOptions={hasModelOptions} inputStatusText={inputStatusText} modelCatalogStatusMessage={modelCatalogStatusMessage} + attachments={attachments} + onAttach={handleAttach} + onRemoveAttachment={handleRemoveAttachment} + uploadStates={uploadStates} + previewUrls={previewUrls} leftActions={ + new File(["mock-data"], name, { type }); + +const meta: Meta = { + title: "pages/AgentsPage/AttachmentPreview", + component: AttachmentPreview, + decorators: [ + (Story) => ( +
+ +
+ ), + ], + args: { + onRemove: fn(), + onPreview: fn(), + }, +}; + +export default meta; +type Story = StoryObj; + +export const SingleImage: Story = { + args: (() => { + const file = createMockFile("photo.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "uploaded", fileId: "file-1" }], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + }; + })(), +}; + +export const MultipleImages: Story = { + args: (() => { + const files = [ + createMockFile("photo-1.png", "image/png"), + createMockFile("photo-2.jpg", "image/jpeg"), + createMockFile("photo-3.png", "image/png"), + ]; + return { + attachments: files, + uploadStates: new Map( + files.map((f) => [f, { status: "uploaded", fileId: f.name }]), + ), + previewUrls: new Map(files.map((f) => [f, TINY_PNG])), + }; + })(), +}; + +export const Uploading: Story = { + args: (() => { + const file = createMockFile("uploading.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "uploading" }], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + }; + })(), +}; + +export const UploadError: Story = { + args: (() => { + const file = createMockFile("broken.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "error", error: "Upload failed: server error" }], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + }; + })(), +}; + +export const FileTooLarge: Story = { + args: (() => { + const file = createMockFile("huge-screenshot.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [ + file, + { + status: "error", + error: "File too large (12.4 MB). Maximum is 10 MB.", + }, + ], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + }; + })(), + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + const overlay = canvas.getByLabelText("Upload error"); + await userEvent.hover(overlay); + }, +}; + +export const NonImageFile: Story = { + args: (() => { + const file = createMockFile("readme.txt", "text/plain"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "uploaded", fileId: "file-txt" }], + ]), + }; + })(), +}; + +export const MixedStates: Story = { + args: (() => { + const uploaded = createMockFile("done.png", "image/png"); + const uploading = createMockFile("pending.jpg", "image/jpeg"); + const errored = createMockFile("failed.png", "image/png"); + const attachments = [uploaded, uploading, errored]; + return { + attachments, + uploadStates: new Map([ + [uploaded, { status: "uploaded", fileId: "file-ok" }], + [uploading, { status: "uploading" }], + [errored, { status: "error", error: "Network timeout" }], + ]), + previewUrls: new Map([ + [uploaded, TINY_PNG], + [uploading, TINY_PNG], + [errored, TINY_PNG], + ]), + }; + })(), +}; diff --git a/site/src/pages/AgentsPage/ImageLightbox.stories.tsx b/site/src/pages/AgentsPage/ImageLightbox.stories.tsx new file mode 100644 index 0000000000..6bc3bc9b27 --- /dev/null +++ b/site/src/pages/AgentsPage/ImageLightbox.stories.tsx @@ -0,0 +1,30 @@ +import type { Meta, StoryObj } from "@storybook/react-vite"; +import { fn } from "storybook/test"; +import { ImageLightbox } from "./ImageLightbox"; + +// Tiny 1x1 colored PNG so the lightbox has something visible to display. +const TINY_PNG = + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="; + +const meta: Meta = { + title: "pages/AgentsPage/ImageLightbox", + component: ImageLightbox, + decorators: [ + (Story) => ( +
+

Background content behind the lightbox overlay

+ +
+ ), + ], +}; + +export default meta; +type Story = StoryObj; + +export const Default: Story = { + args: { + src: TINY_PNG, + onClose: fn(), + }, +}; diff --git a/site/src/pages/AgentsPage/ImageLightbox.tsx b/site/src/pages/AgentsPage/ImageLightbox.tsx new file mode 100644 index 0000000000..a6a2ae8b2b --- /dev/null +++ b/site/src/pages/AgentsPage/ImageLightbox.tsx @@ -0,0 +1,25 @@ +import { Dialog, DialogContent, DialogTitle } from "components/Dialog/Dialog"; +import type { FC } from "react"; + +interface ImageLightboxProps { + src: string; + onClose: () => void; +} + +export const ImageLightbox: FC = ({ src, onClose }) => { + return ( + !open && onClose()}> + + Image preview + Attachment preview + + + ); +}; diff --git a/site/src/pages/AgentsPage/useFileAttachments.ts b/site/src/pages/AgentsPage/useFileAttachments.ts new file mode 100644 index 0000000000..2c1a7b6228 --- /dev/null +++ b/site/src/pages/AgentsPage/useFileAttachments.ts @@ -0,0 +1,160 @@ +import { API } from "api/api"; +import { getErrorDetail, getErrorMessage } from "api/errors"; +import { + type Dispatch, + type SetStateAction, + useCallback, + useEffect, + useRef, + useState, +} from "react"; +import type { UploadState } from "./AgentChatInput"; + +interface UseFileAttachmentsReturn { + attachments: File[]; + uploadStates: Map; + previewUrls: Map; + handleAttach: (files: File[]) => void; + handleRemoveAttachment: (index: number) => void; + startUpload: (file: File) => void; + resetAttachments: () => void; + setAttachments: Dispatch>; + setPreviewUrls: Dispatch>>; + setUploadStates: Dispatch>>; +} + +export function useFileAttachments( + organizationId: string | undefined, +): UseFileAttachmentsReturn { + const [attachments, setAttachments] = useState([]); + const [uploadStates, setUploadStates] = useState( + () => new Map(), + ); + const [previewUrls, setPreviewUrls] = useState(() => new Map()); + + // Revoke blob URLs on unmount to prevent memory leaks. + const previewUrlsRef = useRef(previewUrls); + previewUrlsRef.current = previewUrls; + useEffect(() => { + return () => { + for (const [, url] of previewUrlsRef.current) { + if (url.startsWith("blob:")) URL.revokeObjectURL(url); + } + }; + }, []); + + const startUpload = useCallback( + (file: File) => { + if (!organizationId) { + setUploadStates((prev) => + new Map(prev).set(file, { + status: "error", + error: "Unable to upload: no organization context.", + }), + ); + return; + } + setUploadStates((prev) => + new Map(prev).set(file, { status: "uploading" }), + ); + void (async () => { + try { + const result = await API.uploadChatFile(file, organizationId); + setUploadStates((prev) => + new Map(prev).set(file, { + status: "uploaded", + fileId: result.id, + }), + ); + // Pre-warm the browser HTTP cache so the timeline + // can render this image instantly after send. The + // server responds with Cache-Control: private, + // immutable, so the never hits the + // network again. + void fetch(`/api/experimental/chats/files/${result.id}`); + } catch (err: unknown) { + const message = getErrorMessage(err, "Upload failed"); + const detail = getErrorDetail(err); + const errorMessage = detail ? `${message} ${detail}` : message; + setUploadStates((prev) => + new Map(prev).set(file, { + status: "error", + error: errorMessage, + }), + ); + } + })(); + }, + [organizationId], + ); + + const handleAttach = useCallback( + (files: File[]) => { + const maxSize = 10 * 1024 * 1024; // 10 MB + setAttachments((prev) => [...prev, ...files]); + setPreviewUrls((prev) => { + const next = new Map(prev); + for (const file of files) { + next.set(file, URL.createObjectURL(file)); + } + return next; + }); + for (const file of files) { + if (file.size > maxSize) { + setUploadStates((prev) => + new Map(prev).set(file, { + status: "error" as const, + error: `File too large (${(file.size / 1024 / 1024).toFixed(1)} MB). Maximum is 10 MB.`, + }), + ); + } else { + startUpload(file); + } + } + }, + [startUpload], + ); + + const handleRemoveAttachment = useCallback((index: number) => { + setAttachments((prev) => { + const removed = prev[index]; + if (removed) { + setUploadStates((prevStates) => { + const next = new Map(prevStates); + next.delete(removed); + return next; + }); + setPreviewUrls((prevUrls) => { + const url = prevUrls.get(removed); + if (url?.startsWith("blob:")) URL.revokeObjectURL(url); + const next = new Map(prevUrls); + next.delete(removed); + return next; + }); + } + return prev.filter((_, i) => i !== index); + }); + }, []); + + const resetAttachments = useCallback(() => { + for (const [, url] of previewUrlsRef.current) { + if (url.startsWith("blob:")) URL.revokeObjectURL(url); + } + setPreviewUrls(new Map()); + setUploadStates(new Map()); + setAttachments([]); + }, []); + + return { + attachments, + uploadStates, + previewUrls, + handleAttach, + handleRemoveAttachment, + startUpload, + resetAttachments, + setAttachments, + setPreviewUrls, + setUploadStates, + }; +}