diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 53b65527f0..09f8a5e26d 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -481,6 +481,128 @@ const docTemplate = `{ } } }, + "/chats/files": { + "post": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "consumes": [ + "application/octet-stream" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Chats" + ], + "summary": "Upload a chat file", + "operationId": "upload-chat-file", + "parameters": [ + { + "type": "string", + "description": "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)", + "name": "Content-Type", + "in": "header", + "required": true + }, + { + "type": "string", + "format": "uuid", + "description": "Organization ID", + "name": "organization", + "in": "query", + "required": true + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/codersdk.UploadChatFileResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "413": { + "description": "Request Entity Too Large", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + } + } + } + }, + "/chats/files/{file}": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "tags": [ + "Chats" + ], + "summary": "Get a chat file", + "operationId": "get-chat-file", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "File ID", + "name": "file", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "401": { + "description": 
"Unauthorized", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + } + } + } + }, "/chats/{chat}/archive": { "post": { "tags": [ @@ -20334,6 +20456,15 @@ const docTemplate = `{ } } }, + "codersdk.UploadChatFileResponse": { + "type": "object", + "properties": { + "id": { + "type": "string", + "format": "uuid" + } + } + }, "codersdk.UploadResponse": { "type": "object", "properties": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index b24b130c48..a234b00b51 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -410,6 +410,120 @@ } } }, + "/chats/files": { + "post": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "consumes": ["application/octet-stream"], + "produces": ["application/json"], + "tags": ["Chats"], + "summary": "Upload a chat file", + "operationId": "upload-chat-file", + "parameters": [ + { + "type": "string", + "description": "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)", + "name": "Content-Type", + "in": "header", + "required": true + }, + { + "type": "string", + "format": "uuid", + "description": "Organization ID", + "name": "organization", + "in": "query", + "required": true + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/codersdk.UploadChatFileResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "413": { + "description": "Request Entity Too Large", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + 
"$ref": "#/definitions/codersdk.Response" + } + } + } + } + }, + "/chats/files/{file}": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "tags": ["Chats"], + "summary": "Get a chat file", + "operationId": "get-chat-file", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "File ID", + "name": "file", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/codersdk.Response" + } + } + } + } + }, "/chats/{chat}/archive": { "post": { "tags": ["Chats"], @@ -18650,6 +18764,15 @@ } } }, + "codersdk.UploadChatFileResponse": { + "type": "object", + "properties": { + "id": { + "type": "string", + "format": "uuid" + } + } + }, "codersdk.UploadResponse": { "type": "object", "properties": { diff --git a/coderd/chatd/chatd.go b/coderd/chatd/chatd.go index b03dbb48c2..011fbc71e6 100644 --- a/coderd/chatd/chatd.go +++ b/coderd/chatd/chatd.go @@ -165,6 +165,9 @@ type CreateOptions struct { ModelConfigID uuid.UUID SystemPrompt string InitialUserContent []fantasy.Content + // ContentFileIDs maps content block indices to their chat_files IDs + // so the file_id can be preserved in the stored message JSON. + ContentFileIDs map[int]uuid.UUID } // SendMessageBusyBehavior controls what happens when a chat is already active. @@ -180,10 +183,11 @@ const ( // SendMessageOptions controls user message insertion with busy-state behavior. 
type SendMessageOptions struct {
-	ChatID        uuid.UUID
-	Content       []fantasy.Content
-	ModelConfigID *uuid.UUID
-	BusyBehavior  SendMessageBusyBehavior
+	ChatID         uuid.UUID
+	Content        []fantasy.Content
+	ContentFileIDs map[int]uuid.UUID
+	ModelConfigID  *uuid.UUID
+	BusyBehavior   SendMessageBusyBehavior
 }
 
 // SendMessageResult contains the outcome of user message processing.
@@ -199,6 +203,7 @@ type EditMessageOptions struct {
 	ChatID          uuid.UUID
 	EditedMessageID int64
 	Content         []fantasy.Content
+	ContentFileIDs  map[int]uuid.UUID
 }
 
 // EditMessageResult contains the updated user message and chat status.
@@ -278,7 +283,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
 		}
 	}
 
-	userContent, err := chatprompt.MarshalContent(opts.InitialUserContent)
+	userContent, err := chatprompt.MarshalContent(opts.InitialUserContent, opts.ContentFileIDs)
 	if err != nil {
 		return xerrors.Errorf("marshal initial user content: %w", err)
 	}
@@ -345,7 +350,7 @@ func (p *Server) SendMessage(
 		return SendMessageResult{}, xerrors.Errorf("invalid busy behavior %q", opts.BusyBehavior)
 	}
 
-	content, err := chatprompt.MarshalContent(opts.Content)
+	content, err := chatprompt.MarshalContent(opts.Content, opts.ContentFileIDs)
 	if err != nil {
 		return SendMessageResult{}, xerrors.Errorf("marshal message content: %w", err)
 	}
@@ -448,7 +453,7 @@ func (p *Server) EditMessage(
 		return EditMessageResult{}, xerrors.New("content is required")
 	}
 
-	content, err := chatprompt.MarshalContent(opts.Content)
+	content, err := chatprompt.MarshalContent(opts.Content, opts.ContentFileIDs)
 	if err != nil {
 		return EditMessageResult{}, xerrors.Errorf("marshal message content: %w", err)
 	}
@@ -1607,6 +1612,25 @@ func (p *Server) subscribeChatControl(
 	return controlCancel
 }
 
+// chatFileResolver returns a FileResolver that loads chat files from
+// the database with one GetChatFilesByIDs call, keyed by file ID.
+func (p *Server) chatFileResolver() chatprompt.FileResolver {
+	return func(ctx context.Context, ids []uuid.UUID) (map[uuid.UUID]chatprompt.FileData, error) {
+		files, err := p.db.GetChatFilesByIDs(ctx, ids)
+		if err != nil {
+			return nil, err
+		}
+		result := make(map[uuid.UUID]chatprompt.FileData, len(files)) // only rows returned by the query get an entry
+		for _, f := range files {
+			result[f.ID] = chatprompt.FileData{
+				Data:      f.Data,
+				MediaType: f.Mimetype,
+			}
+		}
+		return result, nil
+	}
+}
+
 func (p *Server) processChat(ctx context.Context, chat database.Chat) {
 	logger := p.logger.With(slog.F("chat_id", chat.ID))
 	logger.Info(ctx, "processing chat request")
@@ -1922,7 +1946,7 @@ func (p *Server) runChat(
 		p.maybeGenerateChatTitle(context.WithoutCancel(ctx), chat, messages, model, providerKeys, logger)
 	}()
 
-	prompt, err := chatprompt.ConvertMessages(messages)
+	prompt, err := chatprompt.ConvertMessagesWithFiles(ctx, messages, p.chatFileResolver())
 	if err != nil {
 		return xerrors.Errorf("build chat prompt: %w", err)
 	}
@@ -2064,7 +2088,7 @@ func (p *Server) runChat(
 	}
 
 	if len(assistantBlocks) > 0 {
-		assistantContent, err := chatprompt.MarshalContent(assistantBlocks)
+		assistantContent, err := chatprompt.MarshalContent(assistantBlocks, nil) // nil fileIDs: no file_id is injected
 		if err != nil {
 			return err
 		}
@@ -2270,7 +2294,7 @@ func (p *Server) runChat(
 		if err != nil {
 			return nil, xerrors.Errorf("reload chat messages: %w", err)
 		}
-		reloadedPrompt, err := chatprompt.ConvertMessages(reloadedMsgs)
+		reloadedPrompt, err := chatprompt.ConvertMessagesWithFiles(reloadCtx, reloadedMsgs, p.chatFileResolver())
 		if err != nil {
 			return nil, xerrors.Errorf("convert reloaded messages: %w", err)
 		}
@@ -2363,7 +2387,7 @@ func (p *Server) persistChatContextSummary(
 			ToolName: "chat_summarized",
 			Input:    string(args),
 		},
-	})
+	}, nil)
 	if err != nil {
 		return xerrors.Errorf("encode summary tool call: %w", err)
 	}
diff --git a/coderd/chatd/chatprompt/chatprompt.go b/coderd/chatd/chatprompt/chatprompt.go
index ff6ad99368..a7d07100cc 100644
---
a/coderd/chatd/chatprompt/chatprompt.go
+++ b/coderd/chatd/chatprompt/chatprompt.go
@@ -1,12 +1,14 @@
 package chatprompt
 
 import (
+	"context"
 	"encoding/json"
 	"regexp"
 	"strings"
 
 	"charm.land/fantasy"
 	fantasyopenai "charm.land/fantasy/providers/openai"
+	"github.com/google/uuid"
 	"github.com/sqlc-dev/pqtype"
 	"golang.org/x/xerrors"
 
@@ -16,12 +18,169 @@ import (
 
 var toolCallIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
 
+// FileData holds resolved file content for LLM prompt building.
+type FileData struct {
+	Data      []byte
+	MediaType string
+}
+
+// FileResolver fetches file content by ID for LLM prompt building.
+type FileResolver func(ctx context.Context, ids []uuid.UUID) (map[uuid.UUID]FileData, error)
+
+// ExtractFileID parses the file_id from a serialized file content
+// block envelope. Returns uuid.Nil and an error when the block is
+// not a file-type block or has no file_id.
+func ExtractFileID(raw json.RawMessage) (uuid.UUID, error) {
+	var envelope struct {
+		Type string `json:"type"`
+		Data struct {
+			FileID string `json:"file_id"`
+		} `json:"data"`
+	}
+	if err := json.Unmarshal(raw, &envelope); err != nil {
+		return uuid.Nil, xerrors.Errorf("unmarshal content block: %w", err)
+	}
+	if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeFile)) {
+		return uuid.Nil, xerrors.Errorf("not a file content block: %s", envelope.Type)
+	}
+	if envelope.Data.FileID == "" {
+		return uuid.Nil, xerrors.New("no file_id")
+	}
+	return uuid.Parse(envelope.Data.FileID)
+}
+
+// extractFileIDs scans raw message content for file_id references.
+// Returns a map of block index to file ID. Returns nil for
+// non-array content or content with no file references.
+func extractFileIDs(raw pqtype.NullRawMessage) map[int]uuid.UUID {
+	if !raw.Valid || len(raw.RawMessage) == 0 {
+		return nil
+	}
+	var rawBlocks []json.RawMessage
+	if err := json.Unmarshal(raw.RawMessage, &rawBlocks); err != nil {
+		return nil
+	}
+	var result map[int]uuid.UUID
+	for i, block := range rawBlocks {
+		// Non-file blocks and blocks without a parsable file_id are skipped.
+		fid, err := ExtractFileID(block)
+		if err == nil {
+			if result == nil {
+				result = make(map[int]uuid.UUID)
+			}
+			result[i] = fid
+		}
+	}
+	return result
+}
+
+// patchFileContent fills in empty Data on FileContent blocks from
+// resolved file data. Blocks that already have inline data (backward
+// compat) or have no resolved data are left unchanged. When a patched
+// block carries no MediaType of its own, the resolved MediaType is
+// used so the media type supplied by the resolver is not discarded.
+func patchFileContent(
+	content []fantasy.Content,
+	fileIDs map[int]uuid.UUID,
+	resolved map[uuid.UUID]FileData,
+) {
+	for blockIdx, fid := range fileIDs {
+		// Skip indices that do not address a block of this message.
+		if blockIdx < 0 || blockIdx >= len(content) {
+			continue
+		}
+		data, found := resolved[fid]
+		if !found {
+			continue
+		}
+		switch fc := content[blockIdx].(type) {
+		case fantasy.FileContent:
+			if len(fc.Data) > 0 {
+				continue
+			}
+			fc.Data = data.Data
+			if fc.MediaType == "" {
+				fc.MediaType = data.MediaType
+			}
+			// fc is a copy of the stored value; write it back.
+			content[blockIdx] = fc
+		case *fantasy.FileContent:
+			// A nil pointer has nothing to patch and must not be
+			// dereferenced.
+			if fc == nil || len(fc.Data) > 0 {
+				continue
+			}
+			fc.Data = data.Data
+			if fc.MediaType == "" {
+				fc.MediaType = data.MediaType
+			}
+		}
+	}
+}
+
+// ConvertMessages converts persisted chat messages into LLM prompt
+// messages without resolving file references from storage. Inline
+// file data is preserved when present (backward compat).
 func ConvertMessages(
 	messages []database.ChatMessage,
 ) ([]fantasy.Message, error) {
+	return ConvertMessagesWithFiles(context.Background(), messages, nil)
+}
+
+// ConvertMessagesWithFiles converts persisted chat messages into LLM
+// prompt messages, resolving file references via the provided
+// resolver. When resolver is nil, file blocks without inline data
+// are passed through as-is (same behavior as ConvertMessages).
+func ConvertMessagesWithFiles(
+	ctx context.Context,
+	messages []database.ChatMessage,
+	resolver FileResolver,
+) ([]fantasy.Message, error) {
+	// Phase 1: Pre-scan user messages for file_id references.
+	var allFileIDs []uuid.UUID
+	seenFileIDs := make(map[uuid.UUID]struct{})
+	fileIDsByMsg := make(map[int]map[int]uuid.UUID)
+
+	if resolver != nil {
+		for i, msg := range messages {
+			visibility := msg.Visibility
+			if visibility == "" {
+				visibility = database.ChatMessageVisibilityBoth
+			}
+			if visibility != database.ChatMessageVisibilityModel &&
+				visibility != database.ChatMessageVisibilityBoth {
+				continue
+			}
+			if msg.Role != string(fantasy.MessageRoleUser) {
+				continue
+			}
+			fids := extractFileIDs(msg.Content)
+			if len(fids) > 0 {
+				fileIDsByMsg[i] = fids
+				for _, fid := range fids {
+					if _, seen := seenFileIDs[fid]; !seen {
+						seenFileIDs[fid] = struct{}{}
+						allFileIDs = append(allFileIDs, fid)
+					}
+				}
+			}
+		}
+	}
+
+	// Phase 2: Batch resolve file data.
+	var resolved map[uuid.UUID]FileData
+	if len(allFileIDs) > 0 {
+		var err error
+		resolved, err = resolver(ctx, allFileIDs)
+		if err != nil {
+			return nil, xerrors.Errorf("resolve chat files: %w", err)
+		}
+	}
+
+	// Phase 3: Convert messages, patching file content as needed.
 	prompt := make([]fantasy.Message, 0, len(messages))
 	toolNameByCallID := make(map[string]string)
-	for _, message := range messages {
+	for i, message := range messages {
 		visibility := message.Visibility
 		if visibility == "" {
 			visibility = database.ChatMessageVisibilityBoth
@@ -51,6 +210,9 @@ func ConvertMessages(
 			if err != nil {
 				return nil, err
 			}
+			if fids, ok := fileIDsByMsg[i]; ok {
+				patchFileContent(content, fids, resolved)
+			}
 			prompt = append(prompt, fantasy.Message{
 				Role:    fantasy.MessageRoleUser,
 				Content: ToMessageParts(content),
@@ -400,7 +562,10 @@ func ExtractToolCalls(parts []fantasy.MessagePart) []fantasy.ToolCallContent {
 }
 
 // MarshalContent encodes message content blocks for persistence.
-func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) {
+// fileIDs optionally maps block indices to chat_files IDs, which
+// are injected into the JSON envelope for file-type blocks so
+// the reference survives round-trips through storage.
+func MarshalContent(blocks []fantasy.Content, fileIDs map[int]uuid.UUID) (pqtype.NullRawMessage, error) {
 	if len(blocks) == 0 {
 		return pqtype.NullRawMessage{}, nil
 	}
@@ -415,6 +580,16 @@ func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) {
 				err,
 			)
 		}
+		if fid, ok := fileIDs[i]; ok {
+			encoded, err = injectFileID(encoded, fid)
+			if err != nil {
+				return pqtype.NullRawMessage{}, xerrors.Errorf(
+					"inject file_id into content block %d: %w",
+					i,
+					err,
+				)
+			}
+		}
 		encodedBlocks = append(encodedBlocks, encoded)
 	}
 
@@ -425,6 +600,33 @@ func MarshalContent(blocks []fantasy.Content) (pqtype.NullRawMessage, error) {
 	return pqtype.NullRawMessage{RawMessage: data, Valid: true}, nil
 }
 
+// injectFileID adds a file_id field into the data sub-object of a
+// serialized file content block envelope and strips its inline data.
+// This follows the same pattern as the reasoning title injection in
+// marshalContentBlock. Blocks whose type is not "file" are returned
+// unchanged, because re-encoding them through the file envelope
+// would drop every field that envelope does not declare.
+func injectFileID(encoded json.RawMessage, fileID uuid.UUID) (json.RawMessage, error) {
+	var envelope struct {
+		Type string `json:"type"`
+		Data struct {
+			MediaType        string           `json:"media_type"`
+			Data             json.RawMessage  `json:"data,omitempty"`
+			FileID           string           `json:"file_id,omitempty"`
+			ProviderMetadata *json.RawMessage `json:"provider_metadata,omitempty"`
+		} `json:"data"`
+	}
+	if err := json.Unmarshal(encoded, &envelope); err != nil {
+		return encoded, err
+	}
+	if !strings.EqualFold(envelope.Type, string(fantasy.ContentTypeFile)) {
+		return encoded, nil
+	}
+	envelope.Data.FileID = fileID.String()
+	envelope.Data.Data = nil // Strip inline data; resolved at LLM dispatch time.
+	return json.Marshal(envelope)
+}
+
 // MarshalToolResult encodes a single tool result for persistence as
 // an opaque JSON blob. The stored shape is
 // [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…}].
diff --git a/coderd/chatd/chatprompt/chatprompt_test.go b/coderd/chatd/chatprompt/chatprompt_test.go index ba398446a1..56d3124366 100644 --- a/coderd/chatd/chatprompt/chatprompt_test.go +++ b/coderd/chatd/chatprompt/chatprompt_test.go @@ -1,10 +1,13 @@ package chatprompt_test import ( + "context" "encoding/json" "testing" "charm.land/fantasy" + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/chatd/chatprompt" @@ -52,7 +55,7 @@ func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) { ToolName: "execute", Input: tc.input, }, - }) + }, nil) require.NoError(t, err) toolContent, err := chatprompt.MarshalToolResult( @@ -89,3 +92,139 @@ func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) { }) } } + +func TestConvertMessagesWithFiles_ResolvesFileData(t *testing.T) { + t.Parallel() + + fileID := uuid.New() + fileData := []byte("fake-image-bytes") + + // Build a user message with file_id but no inline data, as + // would be stored after injectFileID strips the data. 
+ rawContent := mustJSON(t, []json.RawMessage{ + mustJSON(t, map[string]any{ + "type": "file", + "data": map[string]any{ + "media_type": "image/png", + "file_id": fileID.String(), + }, + }), + }) + + resolver := func(_ context.Context, ids []uuid.UUID) (map[uuid.UUID]chatprompt.FileData, error) { + result := make(map[uuid.UUID]chatprompt.FileData) + for _, id := range ids { + if id == fileID { + result[id] = chatprompt.FileData{ + Data: fileData, + MediaType: "image/png", + } + } + } + return result, nil + } + + prompt, err := chatprompt.ConvertMessagesWithFiles( + context.Background(), + []database.ChatMessage{ + { + Role: string(fantasy.MessageRoleUser), + Visibility: database.ChatMessageVisibilityBoth, + Content: pqtype.NullRawMessage{RawMessage: rawContent, Valid: true}, + }, + }, + resolver, + ) + require.NoError(t, err) + require.Len(t, prompt, 1) + require.Equal(t, fantasy.MessageRoleUser, prompt[0].Role) + require.Len(t, prompt[0].Content, 1) + + filePart, ok := fantasy.AsMessagePart[fantasy.FilePart](prompt[0].Content[0]) + require.True(t, ok, "expected FilePart") + require.Equal(t, fileData, filePart.Data) + require.Equal(t, "image/png", filePart.MediaType) +} + +func TestConvertMessagesWithFiles_BackwardCompat(t *testing.T) { + t.Parallel() + + // A message with inline data and a file_id should use the + // inline data even when the resolver returns nothing. + fileID := uuid.New() + inlineData := []byte("inline-image-data") + + rawContent := mustJSON(t, []json.RawMessage{ + mustJSON(t, map[string]any{ + "type": "file", + "data": map[string]any{ + "media_type": "image/png", + "data": inlineData, + "file_id": fileID.String(), + }, + }), + }) + + prompt, err := chatprompt.ConvertMessagesWithFiles( + context.Background(), + []database.ChatMessage{ + { + Role: string(fantasy.MessageRoleUser), + Visibility: database.ChatMessageVisibilityBoth, + Content: pqtype.NullRawMessage{RawMessage: rawContent, Valid: true}, + }, + }, + nil, // No resolver. 
+ ) + require.NoError(t, err) + require.Len(t, prompt, 1) + require.Len(t, prompt[0].Content, 1) + + filePart, ok := fantasy.AsMessagePart[fantasy.FilePart](prompt[0].Content[0]) + require.True(t, ok, "expected FilePart") + require.Equal(t, inlineData, filePart.Data) +} + +func TestInjectFileID_StripsInlineData(t *testing.T) { + t.Parallel() + + fileID := uuid.New() + imageData := []byte("raw-image-bytes") + + // Marshal a file content block with inline data, then inject + // a file_id. The result should have file_id but no data. + content, err := chatprompt.MarshalContent([]fantasy.Content{ + fantasy.FileContent{ + MediaType: "image/png", + Data: imageData, + }, + }, map[int]uuid.UUID{0: fileID}) + require.NoError(t, err) + + // Parse the stored content to verify shape. + var blocks []json.RawMessage + require.NoError(t, json.Unmarshal(content.RawMessage, &blocks)) + require.Len(t, blocks, 1) + + var envelope struct { + Type string `json:"type"` + Data struct { + MediaType string `json:"media_type"` + Data *json.RawMessage `json:"data,omitempty"` + FileID string `json:"file_id"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(blocks[0], &envelope)) + require.Equal(t, "file", envelope.Type) + require.Equal(t, "image/png", envelope.Data.MediaType) + require.Equal(t, fileID.String(), envelope.Data.FileID) + // Data should be nil (omitted) since injectFileID strips it. 
+ require.Nil(t, envelope.Data.Data, "inline data should be stripped") +} + +func mustJSON(t *testing.T, v any) json.RawMessage { + t.Helper() + data, err := json.Marshal(v) + require.NoError(t, err) + return data +} diff --git a/coderd/chats.go b/coderd/chats.go index 26c7828d07..a8094a97d0 100644 --- a/coderd/chats.go +++ b/coderd/chats.go @@ -1,11 +1,15 @@ package coderd import ( + "bufio" + "bytes" "context" "database/sql" "encoding/json" + "errors" "fmt" "io" + "mime" "net/http" "net/http/httptest" "net/url" @@ -247,7 +251,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { return } - contentBlocks, titleSource, inputError := createChatInputFromRequest(req) + contentBlocks, contentFileIDs, titleSource, inputError := createChatInputFromRequest(ctx, api.Database, req) if inputError != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, *inputError) return @@ -282,6 +286,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { ModelConfigID: modelConfigID, SystemPrompt: defaultChatSystemPrompt(), InitialUserContent: contentBlocks, + ContentFileIDs: contentFileIDs, }) if err != nil { if database.IsForeignKeyViolation( @@ -647,7 +652,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { return } - contentBlocks, _, inputError := createChatInputFromParts(req.Content, "content") + contentBlocks, contentFileIDs, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content") if inputError != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: inputError.Message, @@ -659,10 +664,11 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { sendResult, sendErr := api.chatDaemon.SendMessage( ctx, chatd.SendMessageOptions{ - ChatID: chatID, - Content: contentBlocks, - ModelConfigID: req.ModelConfigID, - BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + ChatID: chatID, + Content: contentBlocks, + ContentFileIDs: contentFileIDs, + 
ModelConfigID: req.ModelConfigID, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, }, ) if sendErr != nil { @@ -721,7 +727,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) { return } - contentBlocks, _, inputError := createChatInputFromParts(req.Content, "content") + contentBlocks, contentFileIDs, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content") if inputError != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: inputError.Message, @@ -734,6 +740,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) { ChatID: chat.ID, EditedMessageID: messageID, Content: contentBlocks, + ContentFileIDs: contentFileIDs, }) if editErr != nil { switch { @@ -2196,45 +2203,298 @@ func normalizeChatCompressionThreshold( return threshold, nil } +const ( + // maxChatFileSize is the maximum size of a chat file upload (10 MB). + maxChatFileSize = 10 << 20 + // maxChatFileName is the maximum length of an uploaded file name. + maxChatFileName = 255 +) + +// allowedChatFileMIMETypes lists the content types accepted for chat +// file uploads. SVG is explicitly excluded because it can contain scripts. +var allowedChatFileMIMETypes = map[string]bool{ + "image/png": true, + "image/jpeg": true, + "image/gif": true, + "image/webp": true, + "image/svg+xml": false, // SVG can contain scripts. +} + +var ( + webpMagicRIFF = []byte("RIFF") + webpMagicWEBP = []byte("WEBP") +) + +// detectChatFileType detects the MIME type of the given data. +// It extends http.DetectContentType with support for WebP, which +// Go's standard sniffer does not recognize. 
+func detectChatFileType(data []byte) string { + if len(data) >= 12 && + bytes.Equal(data[0:4], webpMagicRIFF) && + bytes.Equal(data[8:12], webpMagicWEBP) { + return "image/webp" + } + return http.DetectContentType(data) +} + func defaultChatSystemPrompt() string { return chatd.DefaultSystemPrompt } -func createChatInputFromRequest(req codersdk.CreateChatRequest) ( +// @Summary Upload a chat file +// @ID upload-chat-file +// @Security CoderSessionToken +// @Accept application/octet-stream +// @Produce json +// @Tags Chats +// @Param Content-Type header string true "Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp)" +// @Param organization query string true "Organization ID" format(uuid) +// @Success 201 {object} codersdk.UploadChatFileResponse +// @Failure 400 {object} codersdk.Response +// @Failure 401 {object} codersdk.Response +// @Failure 413 {object} codersdk.Response +// @Failure 500 {object} codersdk.Response +// @Router /chats/files [post] +func (api *API) postChatFile(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + + if !api.Authorize(r, policy.ActionCreate, rbac.ResourceChat.WithOwner(apiKey.UserID.String())) { + httpapi.Forbidden(rw) + return + } + + orgIDStr := r.URL.Query().Get("organization") + if orgIDStr == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Missing organization query parameter.", + }) + return + } + orgID, err := uuid.Parse(orgIDStr) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid organization ID.", + }) + return + } + + contentType := r.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/octet-stream" + } + // Strip parameters (e.g. "image/png; charset=utf-8" → "image/png") + // so the allowlist check matches the base media type. 
+ if mediaType, _, err := mime.ParseMediaType(contentType); err == nil { + contentType = mediaType + } + + if allowed, ok := allowedChatFileMIMETypes[contentType]; !ok || !allowed { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Unsupported file type.", + Detail: "Allowed types: image/png, image/jpeg, image/gif, image/webp.", + }) + return + } + + r.Body = http.MaxBytesReader(rw, r.Body, maxChatFileSize) + br := bufio.NewReader(r.Body) + + // Peek at the leading bytes to sniff the real content type + // before reading the entire body. + peek, peekErr := br.Peek(512) + if peekErr != nil && !errors.Is(peekErr, io.EOF) && !errors.Is(peekErr, bufio.ErrBufferFull) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Failed to read file from request.", + Detail: peekErr.Error(), + }) + return + } + + // Verify the actual content matches a safe image type so that + // a client cannot spoof Content-Type to serve active content. + detected := detectChatFileType(peek) + if allowed, ok := allowedChatFileMIMETypes[detected]; !ok || !allowed { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Unsupported file type.", + Detail: "Allowed types: image/png, image/jpeg, image/gif, image/webp.", + }) + return + } + + // Read the full body now that we know the type is valid. + data, err := io.ReadAll(br) + if err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{ + Message: "File too large.", + Detail: fmt.Sprintf("Maximum file size is %d bytes.", maxChatFileSize), + }) + return + } + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Failed to read file from request.", + Detail: err.Error(), + }) + return + } + + // Extract filename from Content-Disposition header if provided. 
+ var filename string + if cd := r.Header.Get("Content-Disposition"); cd != "" { + if _, params, err := mime.ParseMediaType(cd); err == nil { + filename = params["filename"] + if len(filename) > maxChatFileName { + // Truncate at rune boundary to avoid splitting + // multi-byte UTF-8 characters. + var truncated []byte + for _, r := range filename { + encoded := []byte(string(r)) + if len(truncated)+len(encoded) > maxChatFileName { + break + } + truncated = append(truncated, encoded...) + } + filename = string(truncated) + } + } + } + + chatFile, err := api.Database.InsertChatFile(ctx, database.InsertChatFileParams{ + OwnerID: apiKey.UserID, + OrganizationID: orgID, + Name: filename, + Mimetype: detected, + Data: data, + }) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to save chat file.", + Detail: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusCreated, codersdk.UploadChatFileResponse{ + ID: chatFile.ID, + }) +} + +// @Summary Get a chat file +// @ID get-chat-file +// @Security CoderSessionToken +// @Tags Chats +// @Param file path string true "File ID" format(uuid) +// @Success 200 +// @Failure 400 {object} codersdk.Response +// @Failure 401 {object} codersdk.Response +// @Failure 404 {object} codersdk.Response +// @Failure 500 {object} codersdk.Response +// @Router /chats/files/{file} [get] +func (api *API) chatFileByID(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + fileIDStr := chi.URLParam(r, "file") + fileID, err := uuid.Parse(fileIDStr) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid file ID.", + }) + return + } + + chatFile, err := api.Database.GetChatFileByID(ctx, fileID) + if err != nil { + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to get chat file.", + Detail: 
err.Error(), + }) + return + } + + rw.Header().Set("Content-Type", chatFile.Mimetype) + if chatFile.Name != "" { + rw.Header().Set("Content-Disposition", mime.FormatMediaType("inline", map[string]string{"filename": chatFile.Name})) + } else { + rw.Header().Set("Content-Disposition", "inline") + } + rw.Header().Set("Cache-Control", "private, max-age=31536000, immutable") + rw.Header().Set("Content-Length", strconv.Itoa(len(chatFile.Data))) + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write(chatFile.Data) +} + +func createChatInputFromRequest(ctx context.Context, db database.Store, req codersdk.CreateChatRequest) ( []fantasy.Content, + map[int]uuid.UUID, string, *codersdk.Response, ) { - return createChatInputFromParts(req.Content, "content") + return createChatInputFromParts(ctx, db, req.Content, "content") } func createChatInputFromParts( + ctx context.Context, + db database.Store, parts []codersdk.ChatInputPart, fieldName string, -) ([]fantasy.Content, string, *codersdk.Response) { +) ([]fantasy.Content, map[int]uuid.UUID, string, *codersdk.Response) { if len(parts) == 0 { - return nil, "", &codersdk.Response{ + return nil, nil, "", &codersdk.Response{ Message: "Content is required.", Detail: "Content cannot be empty.", } } content := make([]fantasy.Content, 0, len(parts)) + fileIDs := make(map[int]uuid.UUID) textParts := make([]string, 0, len(parts)) for i, part := range parts { switch strings.ToLower(strings.TrimSpace(string(part.Type))) { case string(codersdk.ChatInputPartTypeText): text := strings.TrimSpace(part.Text) if text == "" { - return nil, "", &codersdk.Response{ + return nil, nil, "", &codersdk.Response{ Message: "Invalid input part.", Detail: fmt.Sprintf("%s[%d].text cannot be empty.", fieldName, i), } } content = append(content, fantasy.TextContent{Text: text}) textParts = append(textParts, text) + case string(codersdk.ChatInputPartTypeFile): + if part.FileID == uuid.Nil { + return nil, nil, "", &codersdk.Response{ + Message: "Invalid input part.", + 
Detail: fmt.Sprintf("%s[%d].file_id is required for file parts.", fieldName, i), + } + } + // Validate that the file exists and get its media type. + // File data is not loaded here; it's resolved at LLM + // dispatch time via chatFileResolver. + chatFile, err := db.GetChatFileByID(ctx, part.FileID) + if err != nil { + if httpapi.Is404Error(err) { + return nil, nil, "", &codersdk.Response{ + Message: "Invalid input part.", + Detail: fmt.Sprintf("%s[%d].file_id references a file that does not exist.", fieldName, i), + } + } + return nil, nil, "", &codersdk.Response{ + Message: "Internal error.", + Detail: fmt.Sprintf("Failed to retrieve file for %s[%d].", fieldName, i), + } + } + content = append(content, fantasy.FileContent{ + MediaType: chatFile.Mimetype, + }) + fileIDs[len(content)-1] = part.FileID default: - return nil, "", &codersdk.Response{ + return nil, nil, "", &codersdk.Response{ Message: "Invalid input part.", Detail: fmt.Sprintf( "%s[%d].type %q is not supported.", @@ -2246,14 +2506,16 @@ func createChatInputFromParts( } } - titleSource := strings.TrimSpace(strings.Join(textParts, " ")) - if titleSource == "" { - return nil, "", &codersdk.Response{ + // Allow file-only messages. The titleSource may be empty + // when only file parts are provided, callers handle this. 
+ if len(content) == 0 { + return nil, nil, "", &codersdk.Response{ Message: "Content is required.", - Detail: "Content must include at least one text part.", + Detail: fmt.Sprintf("%s must include at least one text or file part.", fieldName), } } - return content, titleSource, nil + titleSource := strings.TrimSpace(strings.Join(textParts, " ")) + return content, fileIDs, titleSource, nil } func chatTitleFromMessage(message string) string { diff --git a/coderd/chats_test.go b/coderd/chats_test.go index 911cc6133c..30a45a23b9 100644 --- a/coderd/chats_test.go +++ b/coderd/chats_test.go @@ -1,11 +1,13 @@ package coderd_test import ( + "bytes" "database/sql" "encoding/json" "fmt" "net/http" "regexp" + "strings" "testing" "time" @@ -1525,6 +1527,175 @@ func TestPostChatMessages(t *testing.T) { }) } +func TestChatMessageWithFiles(t *testing.T) { + t.Parallel() + + t.Run("FileOnly", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + _ = createChatModelConfig(t, client) + + // Upload a file. + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData)) + require.NoError(t, err) + + // Create a chat with text first. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "initial message", + }, + }, + }) + require.NoError(t, err) + + // Send a file-only message (no text). + resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeFile, + FileID: uploadResp.ID, + }, + }, + }) + require.NoError(t, err) + + // Verify the message was accepted. 
+ if resp.Queued { + require.NotNil(t, resp.QueuedMessage) + } else { + require.NotNil(t, resp.Message) + require.Equal(t, "user", resp.Message.Role) + } + }) + + t.Run("TextAndFile", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + _ = createChatModelConfig(t, client) + + // Upload a file. + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData)) + require.NoError(t, err) + + // Create a chat with text first. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "initial message", + }, + }, + }) + require.NoError(t, err) + + // Send a message with both text and file. + resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "here is an image", + }, + { + Type: codersdk.ChatInputPartTypeFile, + FileID: uploadResp.ID, + }, + }, + }) + require.NoError(t, err) + + if resp.Queued { + require.NotNil(t, resp.QueuedMessage) + } else { + require.NotNil(t, resp.Message) + require.Equal(t, "user", resp.Message.Role) + } + + // Verify file parts omit inline data in the API response. 
+ chatWithMessages, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + for _, msg := range chatWithMessages.Messages { + for _, part := range msg.Content { + if part.Type == codersdk.ChatMessagePartTypeFile { + require.True(t, part.FileID.Valid, "file part should have a valid file_id") + require.Equal(t, uploadResp.ID, part.FileID.UUID) + require.Nil(t, part.Data, "file data should not be sent when file_id is present") + } + } + } + }) + + t.Run("FileOnlyOnCreate", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + _ = createChatModelConfig(t, client) + + // Upload a file. + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData)) + require.NoError(t, err) + + // Create a new chat with only a file part. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeFile, + FileID: uploadResp.ID, + }, + }, + }) + require.NoError(t, err) + + // With no text, chatTitleFromMessage("") returns "New Chat". + require.Equal(t, "New Chat", chat.Title) + }) + + t.Run("InvalidFileID", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client) + _ = createChatModelConfig(t, client) + + // Create a chat with text first. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "initial message", + }, + }, + }) + require.NoError(t, err) + + // Send a message with a non-existent file ID. 
+ _, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeFile, + FileID: uuid.New(), + }, + }, + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Invalid input part.", sdkErr.Message) + require.Contains(t, sdkErr.Detail, "does not exist") + }) +} + func TestPatchChatMessage(t *testing.T) { t.Parallel() @@ -1602,6 +1773,100 @@ func TestPatchChatMessage(t *testing.T) { require.False(t, foundOriginalInChat) }) + t.Run("PreservesFileID", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + _ = createChatModelConfig(t, client) + + // Upload a file. + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData)) + require.NoError(t, err) + + // Create a chat with a text + file part. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "before edit with file", + }, + { + Type: codersdk.ChatInputPartTypeFile, + FileID: uploadResp.ID, + }, + }, + }) + require.NoError(t, err) + + // Find the user message ID. + chatWithMessages, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + + var userMessageID int64 + for _, message := range chatWithMessages.Messages { + if message.Role == "user" { + userMessageID = message.ID + break + } + } + require.NotZero(t, userMessageID) + + // Edit the message: new text, same file_id. 
+ edited, err := client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "after edit with file", + }, + { + Type: codersdk.ChatInputPartTypeFile, + FileID: uploadResp.ID, + }, + }, + }) + require.NoError(t, err) + require.Equal(t, userMessageID, edited.ID) + + // Assert the edit response preserves the file_id. + var foundText, foundFile bool + for _, part := range edited.Content { + if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "after edit with file" { + foundText = true + } + if part.Type == codersdk.ChatMessagePartTypeFile && part.FileID.Valid && part.FileID.UUID == uploadResp.ID { + foundFile = true + require.Nil(t, part.Data, "file data should not be sent when file_id is present") + } + } + require.True(t, foundText, "edited message should contain updated text") + require.True(t, foundFile, "edited message should preserve file_id") + + // GET the chat and verify the file_id persists. 
+ updatedChat, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + + var foundTextInChat, foundFileInChat bool + for _, message := range updatedChat.Messages { + if message.Role != "user" { + continue + } + for _, part := range message.Content { + if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "after edit with file" { + foundTextInChat = true + } + if part.Type == codersdk.ChatMessagePartTypeFile && part.FileID.Valid && part.FileID.UUID == uploadResp.ID { + foundFileInChat = true + require.Nil(t, part.Data, "file data should not be sent when file_id is present") + } + } + } + require.True(t, foundTextInChat, "chat should contain edited text") + require.True(t, foundFileInChat, "chat should preserve file_id after edit") + }) + t.Run("MessageNotFound", func(t *testing.T) { t.Parallel() @@ -2212,6 +2477,259 @@ func TestPromoteChatQueuedMessage(t *testing.T) { }) } +func TestPostChatFile(t *testing.T) { + t.Parallel() + + t.Run("Success/PNG", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + // Valid PNG header + padding. + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) + require.NoError(t, err) + require.NotEqual(t, uuid.Nil, resp.ID) + }) + + t.Run("Success/JPEG", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0xFF, 0xD8, 0xFF, 0xE0}, make([]byte, 64)...) 
+ resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/jpeg", "test.jpg", bytes.NewReader(data)) + require.NoError(t, err) + require.NotEqual(t, uuid.Nil, resp.ID) + }) + + t.Run("Success/WebP", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + // WebP: RIFF + 4-byte size + WEBP + padding. + data := append([]byte("RIFF"), make([]byte, 4)...) + data = append(data, []byte("WEBP")...) + data = append(data, make([]byte, 64)...) + resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/webp", "test.webp", bytes.NewReader(data)) + require.NoError(t, err) + require.NotEqual(t, uuid.Nil, resp.ID) + }) + + t.Run("UnsupportedContentType", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + _, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "text/plain", "test.txt", bytes.NewReader([]byte("hello"))) + requireSDKError(t, err, http.StatusBadRequest) + }) + + t.Run("SVGBlocked", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + _, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/svg+xml", "test.svg", bytes.NewReader([]byte(""))) + requireSDKError(t, err, http.StatusBadRequest) + }) + + t.Run("ContentSniffingRejects", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + // Header says PNG but body is plain text. 
+ _, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader([]byte("hello world"))) + requireSDKError(t, err, http.StatusBadRequest) + }) + + t.Run("TooLarge", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + // 10 MB + 1 byte, with valid PNG header to pass MIME check. + data := make([]byte, 10<<20+1) + copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}) + _, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) + require.Error(t, err) + }) + + t.Run("MissingOrganization", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + res, err := client.Request(ctx, http.MethodPost, "/api/experimental/chats/files", bytes.NewReader(data), func(r *http.Request) { + r.Header.Set("Content-Type", "image/png") + }) + require.NoError(t, err) + defer res.Body.Close() + err = codersdk.ReadBodyAsError(res) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Contains(t, sdkErr.Message, "Missing organization") + }) + + t.Run("InvalidOrganization", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) 
+ res, err := client.Request(ctx, http.MethodPost, "/api/experimental/chats/files?organization=not-a-uuid", bytes.NewReader(data), func(r *http.Request) { + r.Header.Set("Content-Type", "image/png") + }) + require.NoError(t, err) + defer res.Body.Close() + err = codersdk.ReadBodyAsError(res) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Contains(t, sdkErr.Message, "Invalid organization ID") + }) + + t.Run("WrongOrganization", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + _, err := client.UploadChatFile(ctx, uuid.New(), "image/png", "test.png", bytes.NewReader(data)) + require.Error(t, err) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + // dbauthz returns 404 or 500 depending on how the org lookup + // fails; 403 is also possible. Any non-success code is valid. + require.GreaterOrEqual(t, sdkErr.StatusCode(), http.StatusBadRequest, + "expected error status, got %d", sdkErr.StatusCode()) + }) + + t.Run("Unauthenticated", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + unauthed := codersdk.New(client.URL) + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + _, err := unauthed.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) + requireSDKError(t, err, http.StatusUnauthorized) + }) +} + +func TestGetChatFile(t *testing.T) { + t.Parallel() + + t.Run("Success", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) 
+ uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) + require.NoError(t, err) + + got, contentType, err := client.GetChatFile(ctx, uploaded.ID) + require.NoError(t, err) + require.Equal(t, "image/png", contentType) + require.Equal(t, data, got) + }) + + t.Run("CacheHeaders", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) + require.NoError(t, err) + + res, err := client.Request(ctx, http.MethodGet, + fmt.Sprintf("/api/experimental/chats/files/%s", uploaded.ID), nil) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, "private, max-age=31536000, immutable", res.Header.Get("Cache-Control")) + require.Contains(t, res.Header.Get("Content-Disposition"), "inline") + require.Contains(t, res.Header.Get("Content-Disposition"), "test.png") + }) + + t.Run("LongFilename", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + longName := strings.Repeat("a", 300) + ".png" + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", longName, bytes.NewReader(data)) + require.NoError(t, err) + + res, err := client.Request(ctx, http.MethodGet, + fmt.Sprintf("/api/experimental/chats/files/%s", uploaded.ID), nil) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + // Filename should be truncated to maxChatFileName (255) bytes. 
+ cd := res.Header.Get("Content-Disposition") + require.Contains(t, cd, "inline") + require.Contains(t, cd, strings.Repeat("a", 255)) + require.NotContains(t, cd, strings.Repeat("a", 256)) + }) + + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + coderdtest.CreateFirstUser(t, client) + + _, _, err := client.GetChatFile(ctx, uuid.New()) + requireSDKError(t, err, http.StatusNotFound) + }) + + t.Run("InvalidUUID", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + coderdtest.CreateFirstUser(t, client) + + res, err := client.Request(ctx, http.MethodGet, + "/api/experimental/chats/files/not-a-uuid", nil) + require.NoError(t, err) + defer res.Body.Close() + err = codersdk.ReadBodyAsError(res) + requireSDKError(t, err, http.StatusBadRequest) + }) + + t.Run("OtherUserForbidden", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client) + + data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) 
+ uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) + require.NoError(t, err) + + otherClient, _ := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID) + _, _, err = otherClient.GetChatFile(ctx, uploaded.ID) + requireSDKError(t, err, http.StatusNotFound) + }) +} + func createChatModelConfig(t *testing.T, client *codersdk.Client) codersdk.ChatModelConfig { t.Helper() diff --git a/coderd/coderd.go b/coderd/coderd.go index dc9a4bc5f3..ae6b0bc159 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1113,6 +1113,11 @@ func New(options *Options) *API { r.Post("/", api.postChats) r.Get("/models", api.listChatModels) r.Get("/watch", api.watchChats) + r.Route("/files", func(r chi.Router) { + r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute)) + r.Post("/", api.postChatFile) + r.Get("/{file}", api.chatFileByID) + }) r.Route("/providers", func(r chi.Router) { r.Get("/", api.listChatProviders) r.Post("/", api.createChatProvider) @@ -1842,6 +1847,14 @@ func New(options *Options) *API { "parsing additional CSP headers", slog.Error(cspParseErrors)) } + // Add blob: to img-src for chat file attachment previews when + // the agents experiment is enabled. + if api.Experiments.Enabled(codersdk.ExperimentAgents) { + additionalCSPHeaders[httpmw.CSPDirectiveImgSrc] = append( + additionalCSPHeaders[httpmw.CSPDirectiveImgSrc], "blob:", + ) + } + // Add CSP headers to all static assets and pages. CSP headers only affect // browsers, so these don't make sense on api routes. 
cspMW := httpmw.CSPHeaders( diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index 9cfb4a344a..4ef02d3e42 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -1156,9 +1156,7 @@ func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMe } var rawBlocks []json.RawMessage - if role == string(fantasy.MessageRoleAssistant) { - _ = json.Unmarshal(raw.RawMessage, &rawBlocks) - } + _ = json.Unmarshal(raw.RawMessage, &rawBlocks) parts := make([]codersdk.ChatMessagePart, 0, len(content)) for i, block := range content { @@ -1166,10 +1164,20 @@ func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMe if part.Type == "" { continue } - if part.Type == codersdk.ChatMessagePartTypeReasoning { - part.Title = "" - if i < len(rawBlocks) { + if i < len(rawBlocks) { + switch part.Type { + case codersdk.ChatMessagePartTypeReasoning: part.Title = reasoningStoredTitle(rawBlocks[i]) + case codersdk.ChatMessagePartTypeFile: + if fid, err := chatprompt.ExtractFileID(rawBlocks[i]); err == nil { + part.FileID = uuid.NullUUID{UUID: fid, Valid: true} + } + // When a file_id is present, omit inline data + // from the response. Clients fetch content via + // the GET /chats/files/{id} endpoint instead. 
+ if part.FileID.Valid { + part.Data = nil + } } } parts = append(parts, part) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 839c439cd6..83710fca9e 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2457,6 +2457,30 @@ func (q *querier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIDs []uu return q.db.GetChatDiffStatusesByChatIDs(ctx, chatIDs) } +func (q *querier) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) { + file, err := q.db.GetChatFileByID(ctx, id) + if err != nil { + return database.ChatFile{}, err + } + if err := q.authorizeContext(ctx, policy.ActionRead, file); err != nil { + return database.ChatFile{}, err + } + return file, nil +} + +func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) { + files, err := q.db.GetChatFilesByIDs(ctx, ids) + if err != nil { + return nil, err + } + for _, f := range files { + if err := q.authorizeContext(ctx, policy.ActionRead, f); err != nil { + return nil, err + } + } + return files, nil +} + func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) { // ChatMessages are authorized through their parent Chat. // We need to fetch the message first to get its chat_id. @@ -4491,6 +4515,11 @@ func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams) return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg) } +func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) { + // Authorize create on chat resource scoped to the owner and org. 
+ return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg) +} + func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) { // Authorize create on the parent chat (using update permission). chat, err := q.db.GetChatByID(ctx, arg.ChatID) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 252f2075f4..668d016db8 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -463,6 +463,16 @@ func (s *MethodTestSuite) TestChats() { Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead). Returns([]database.ChatDiffStatus{diffStatusA, diffStatusB}) })) + s.Run("GetChatFileByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + file := testutil.Fake(s.T(), faker, database.ChatFile{}) + dbm.EXPECT().GetChatFileByID(gomock.Any(), file.ID).Return(file, nil).AnyTimes() + check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(file) + })) + s.Run("GetChatFilesByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + file := testutil.Fake(s.T(), faker, database.ChatFile{}) + dbm.EXPECT().GetChatFilesByIDs(gomock.Any(), []uuid.UUID{file.ID}).Return([]database.ChatFile{file}, nil).AnyTimes() + check.Args([]uuid.UUID{file.ID}).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns([]database.ChatFile{file}) + })) s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID}) @@ -579,6 +589,12 @@ func (s *MethodTestSuite) TestChats() { 
dbm.EXPECT().InsertChat(gomock.Any(), arg).Return(chat, nil).AnyTimes() check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionCreate).Returns(chat) })) + s.Run("InsertChatFile", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + arg := testutil.Fake(s.T(), faker, database.InsertChatFileParams{}) + file := testutil.Fake(s.T(), faker, database.InsertChatFileRow{OwnerID: arg.OwnerID, OrganizationID: arg.OrganizationID}) + dbm.EXPECT().InsertChatFile(gomock.Any(), arg).Return(file, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), policy.ActionCreate).Returns(file) + })) s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID}) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 60e12d7c96..82ad4baf61 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1007,6 +1007,22 @@ func (m queryMetricsStore) GetChatDiffStatusesByChatIDs(ctx context.Context, cha return r0, r1 } +func (m queryMetricsStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) { + start := time.Now() + r0, r1 := m.s.GetChatFileByID(ctx, id) + m.queryLatencies.WithLabelValues("GetChatFileByID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFileByID").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) { + start := time.Now() + r0, r1 := m.s.GetChatFilesByIDs(ctx, ids) + m.queryLatencies.WithLabelValues("GetChatFilesByIDs").Observe(time.Since(start).Seconds()) + 
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFilesByIDs").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) { start := time.Now() r0, r1 := m.s.GetChatMessageByID(ctx, id) @@ -2943,6 +2959,14 @@ func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertCh return r0, r1 } +func (m queryMetricsStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) { + start := time.Now() + r0, r1 := m.s.InsertChatFile(ctx, arg) + m.queryLatencies.WithLabelValues("InsertChatFile").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatFile").Inc() + return r0, r1 +} + func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) { start := time.Now() r0, r1 := m.s.InsertChatMessage(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 724d7f2b7b..6a1b286ac5 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1837,6 +1837,36 @@ func (mr *MockStoreMockRecorder) GetChatDiffStatusesByChatIDs(ctx, chatIds any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDiffStatusesByChatIDs", reflect.TypeOf((*MockStore)(nil).GetChatDiffStatusesByChatIDs), ctx, chatIds) } +// GetChatFileByID mocks base method. +func (m *MockStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.ChatFile, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatFileByID", ctx, id) + ret0, _ := ret[0].(database.ChatFile) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatFileByID indicates an expected call of GetChatFileByID. 
+func (mr *MockStoreMockRecorder) GetChatFileByID(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileByID", reflect.TypeOf((*MockStore)(nil).GetChatFileByID), ctx, id) +} + +// GetChatFilesByIDs mocks base method. +func (m *MockStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatFilesByIDs", ctx, ids) + ret0, _ := ret[0].([]database.ChatFile) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatFilesByIDs indicates an expected call of GetChatFilesByIDs. +func (mr *MockStoreMockRecorder) GetChatFilesByIDs(ctx, ids any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFilesByIDs", reflect.TypeOf((*MockStore)(nil).GetChatFilesByIDs), ctx, ids) +} + // GetChatMessageByID mocks base method. func (m *MockStore) GetChatMessageByID(ctx context.Context, id int64) (database.ChatMessage, error) { m.ctrl.T.Helper() @@ -5511,6 +5541,21 @@ func (mr *MockStoreMockRecorder) InsertChat(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChat", reflect.TypeOf((*MockStore)(nil).InsertChat), ctx, arg) } +// InsertChatFile mocks base method. +func (m *MockStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertChatFile", ctx, arg) + ret0, _ := ret[0].(database.InsertChatFileRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertChatFile indicates an expected call of InsertChatFile. +func (mr *MockStoreMockRecorder) InsertChatFile(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatFile", reflect.TypeOf((*MockStore)(nil).InsertChatFile), ctx, arg) +} + // InsertChatMessage mocks base method. 
func (m *MockStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index a6d1ef960a..2cb2071ede 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1190,6 +1190,16 @@ CREATE TABLE chat_diff_statuses ( git_remote_origin text DEFAULT ''::text NOT NULL ); +CREATE TABLE chat_files ( + id uuid DEFAULT gen_random_uuid() NOT NULL, + owner_id uuid NOT NULL, + organization_id uuid NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + name text DEFAULT ''::text NOT NULL, + mimetype text NOT NULL, + data bytea NOT NULL +); + CREATE TABLE chat_messages ( id bigint NOT NULL, chat_id uuid NOT NULL, @@ -3140,6 +3150,9 @@ ALTER TABLE ONLY boundary_usage_stats ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id); +ALTER TABLE ONLY chat_files + ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id); + ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id); @@ -3495,6 +3508,10 @@ CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC); CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at); +CREATE INDEX idx_chat_files_org ON chat_files USING btree (organization_id); + +CREATE INDEX idx_chat_files_owner ON chat_files USING btree (owner_id); + CREATE INDEX idx_chat_messages_chat ON chat_messages USING btree (chat_id); CREATE INDEX idx_chat_messages_chat_created ON chat_messages USING btree (chat_id, created_at); @@ -3774,6 +3791,12 @@ ALTER TABLE ONLY api_keys ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; +ALTER TABLE ONLY chat_files + ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; + +ALTER TABLE ONLY 
chat_files + ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE; + ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 8c8797c2a8..2fb45a6963 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -9,6 +9,8 @@ const ( ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id); ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; + ForeignKeyChatFilesOrganizationID ForeignKeyConstraint = "chat_files_organization_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; + ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; ForeignKeyChatMessagesModelConfigID ForeignKeyConstraint = "chat_messages_model_config_id_fkey" // 
ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id); ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id); diff --git a/coderd/database/migrations/000429_chat_files.down.sql b/coderd/database/migrations/000429_chat_files.down.sql new file mode 100644 index 0000000000..37044f07df --- /dev/null +++ b/coderd/database/migrations/000429_chat_files.down.sql @@ -0,0 +1,2 @@ +DROP INDEX IF EXISTS idx_chat_files_org; +DROP TABLE IF EXISTS chat_files; diff --git a/coderd/database/migrations/000429_chat_files.up.sql b/coderd/database/migrations/000429_chat_files.up.sql new file mode 100644 index 0000000000..42abedaeb5 --- /dev/null +++ b/coderd/database/migrations/000429_chat_files.up.sql @@ -0,0 +1,12 @@ +CREATE TABLE chat_files ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + owner_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + name TEXT NOT NULL DEFAULT '', + mimetype TEXT NOT NULL, + data BYTEA NOT NULL +); + +CREATE INDEX idx_chat_files_owner ON chat_files(owner_id); +CREATE INDEX idx_chat_files_org ON chat_files(organization_id); diff --git a/coderd/database/migrations/testdata/fixtures/000429_chat_files.up.sql b/coderd/database/migrations/testdata/fixtures/000429_chat_files.up.sql new file mode 100644 index 0000000000..cd546f8f28 --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000429_chat_files.up.sql @@ -0,0 +1,13 @@ +INSERT INTO chat_files (id, owner_id, organization_id, created_at, name, mimetype, data) +SELECT + '00000000-0000-0000-0000-000000000099', + u.id, + om.organization_id, + '2024-01-01 00:00:00+00', + 'test.png', + 'image/png', + 
E'\\x89504E47' +FROM users u +JOIN organization_members om ON om.user_id = u.id +ORDER BY u.created_at, u.id +LIMIT 1; diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 3408ab20d5..a978840726 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -178,6 +178,10 @@ func (c Chat) RBACObject() rbac.Object { return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()) } +func (c ChatFile) RBACObject() rbac.Object { + return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID) +} + func (s APIKeyScope) ToRBAC() rbac.ScopeName { switch s { case ApiKeyScopeCoderAll: diff --git a/coderd/database/models.go b/coderd/database/models.go index 9007d046b4..f1f31313cf 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -3926,6 +3926,16 @@ type ChatDiffStatus struct { GitRemoteOrigin string `db:"git_remote_origin" json:"git_remote_origin"` } +type ChatFile struct { + ID uuid.UUID `db:"id" json:"id"` + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + Name string `db:"name" json:"name"` + Mimetype string `db:"mimetype" json:"mimetype"` + Data []byte `db:"data" json:"data"` +} + type ChatMessage struct { ID int64 `db:"id" json:"id"` ChatID uuid.UUID `db:"chat_id" json:"chat_id"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 467c97801e..014f9fd3b0 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -218,6 +218,8 @@ type sqlcQuerier interface { GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error) + GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error) + 
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error) GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error) GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error) @@ -601,6 +603,7 @@ type sqlcQuerier interface { InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error) InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error) + InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error) InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error) InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error) InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 0e796864f7..931aa9bea6 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2214,6 +2214,103 @@ func (q *sqlQuerier) UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBou return new_period, err } +const getChatFileByID = `-- name: GetChatFileByID :one +SELECT id, owner_id, organization_id, created_at, name, mimetype, data FROM chat_files WHERE id = $1::uuid +` + +func (q *sqlQuerier) GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error) { + row := q.db.QueryRowContext(ctx, getChatFileByID, id) + var i ChatFile + err := row.Scan( + &i.ID, + &i.OwnerID, + &i.OrganizationID, + &i.CreatedAt, + &i.Name, + &i.Mimetype, + &i.Data, + ) + return i, err +} + +const getChatFilesByIDs = `-- name: GetChatFilesByIDs :many +SELECT id, owner_id, organization_id, created_at, name, mimetype, data FROM chat_files WHERE id = ANY($1::uuid[]) +` + +func (q *sqlQuerier) 
GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error) { + rows, err := q.db.QueryContext(ctx, getChatFilesByIDs, pq.Array(ids)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ChatFile + for rows.Next() { + var i ChatFile + if err := rows.Scan( + &i.ID, + &i.OwnerID, + &i.OrganizationID, + &i.CreatedAt, + &i.Name, + &i.Mimetype, + &i.Data, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertChatFile = `-- name: InsertChatFile :one +INSERT INTO chat_files (owner_id, organization_id, name, mimetype, data) +VALUES ($1::uuid, $2::uuid, $3::text, $4::text, $5::bytea) +RETURNING id, owner_id, organization_id, created_at, name, mimetype +` + +type InsertChatFileParams struct { + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + Name string `db:"name" json:"name"` + Mimetype string `db:"mimetype" json:"mimetype"` + Data []byte `db:"data" json:"data"` +} + +type InsertChatFileRow struct { + ID uuid.UUID `db:"id" json:"id"` + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + Name string `db:"name" json:"name"` + Mimetype string `db:"mimetype" json:"mimetype"` +} + +func (q *sqlQuerier) InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error) { + row := q.db.QueryRowContext(ctx, insertChatFile, + arg.OwnerID, + arg.OrganizationID, + arg.Name, + arg.Mimetype, + arg.Data, + ) + var i InsertChatFileRow + err := row.Scan( + &i.ID, + &i.OwnerID, + &i.OrganizationID, + &i.CreatedAt, + &i.Name, + &i.Mimetype, + ) + return i, err +} + const deleteChatModelConfigByID = `-- name: DeleteChatModelConfigByID :exec UPDATE 
chat_model_configs diff --git a/coderd/database/queries/chatfiles.sql b/coderd/database/queries/chatfiles.sql new file mode 100644 index 0000000000..5cb2ad89fe --- /dev/null +++ b/coderd/database/queries/chatfiles.sql @@ -0,0 +1,10 @@ +-- name: InsertChatFile :one +INSERT INTO chat_files (owner_id, organization_id, name, mimetype, data) +VALUES (@owner_id::uuid, @organization_id::uuid, @name::text, @mimetype::text, @data::bytea) +RETURNING id, owner_id, organization_id, created_at, name, mimetype; + +-- name: GetChatFileByID :one +SELECT * FROM chat_files WHERE id = @id::uuid; + +-- name: GetChatFilesByIDs :many +SELECT * FROM chat_files WHERE id = ANY(@ids::uuid[]); diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index e7f8489915..0ecf890017 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -15,6 +15,7 @@ const ( UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id); UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id); + UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id); UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id); UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id); UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey 
PRIMARY KEY (id); diff --git a/codersdk/chats.go b/codersdk/chats.go index 8175ae5bc7..e3b81caa9c 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "mime" "net/http" "net/url" "strings" @@ -96,6 +97,7 @@ type ChatMessagePart struct { Title string `json:"title,omitempty"` MediaType string `json:"media_type,omitempty"` Data []byte `json:"data,omitempty"` + FileID uuid.NullUUID `json:"file_id,omitempty" format:"uuid"` } // ChatInputPartType represents an input part type for user chat input. @@ -103,12 +105,14 @@ type ChatInputPartType string const ( ChatInputPartTypeText ChatInputPartType = "text" + ChatInputPartTypeFile ChatInputPartType = "file" ) // ChatInputPart is a single user input part for creating a chat. type ChatInputPart struct { - Type ChatInputPartType `json:"type"` - Text string `json:"text,omitempty"` + Type ChatInputPartType `json:"type"` + Text string `json:"text,omitempty"` + FileID uuid.UUID `json:"file_id,omitempty" format:"uuid"` } // CreateChatRequest is the request to create a new chat. @@ -141,6 +145,11 @@ type CreateChatMessageResponse struct { Queued bool `json:"queued"` } +// UploadChatFileResponse is the response from uploading a chat file. +type UploadChatFileResponse struct { + ID uuid.UUID `json:"id" format:"uuid"` +} + // ChatWithMessages is a chat along with its messages. type ChatWithMessages struct { Chat Chat `json:"chat"` @@ -938,6 +947,42 @@ func (c *Client) GetChatDiffContents(ctx context.Context, chatID uuid.UUID) (Cha return diff, json.NewDecoder(res.Body).Decode(&diff) } +// UploadChatFile uploads a file for use in chat messages. 
+func (c *Client) UploadChatFile(ctx context.Context, organizationID uuid.UUID, contentType string, filename string, rd io.Reader) (UploadChatFileResponse, error) { + res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/files?organization=%s", organizationID), rd, func(r *http.Request) { + r.Header.Set("Content-Type", contentType) + if filename != "" { + r.Header.Set("Content-Disposition", mime.FormatMediaType("attachment", map[string]string{"filename": filename})) + } + }) + if err != nil { + return UploadChatFileResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusCreated { + return UploadChatFileResponse{}, ReadBodyAsError(res) + } + var resp UploadChatFileResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} + +// GetChatFile retrieves a previously uploaded chat file by ID. +func (c *Client) GetChatFile(ctx context.Context, fileID uuid.UUID) ([]byte, string, error) { + res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/files/%s", fileID), nil) + if err != nil { + return nil, "", err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, "", ReadBodyAsError(res) + } + data, err := io.ReadAll(res.Body) + if err != nil { + return nil, "", err + } + return data, res.Header.Get("Content-Type"), nil +} + func formatChatStreamResponseError(response Response) string { message := strings.TrimSpace(response.Message) detail := strings.TrimSpace(response.Detail) diff --git a/docs/ai-coder/agents/index.md b/docs/ai-coder/agents/index.md index 92d62520e4..3c426e4d40 100644 --- a/docs/ai-coder/agents/index.md +++ b/docs/ai-coder/agents/index.md @@ -132,6 +132,18 @@ are queued and delivered when the agent completes its current step, so there is no need to wait for a response before providing additional context or changing direction. 
+### Image attachments + +Users can attach images to chat messages by pasting from the clipboard, dragging +files into the input area, or using the attachment button. Supported formats are +PNG, JPEG, GIF, and WebP up to 10 MB per file. Images are sent to the model as +multimodal content alongside the text prompt. + +This is useful for sharing screenshots of errors, UI mockups, terminal output, +or other visual context that helps the agent understand the task. Messages can +contain images alone or combined with text. Image attachments require a model +that supports vision input. + ## Security benefits of the control plane architecture Running the agent loop in the control plane rather than inside the developer diff --git a/docs/reference/api/chats.md b/docs/reference/api/chats.md index 655993b962..073b509a08 100644 --- a/docs/reference/api/chats.md +++ b/docs/reference/api/chats.md @@ -1,5 +1,83 @@ # Chats +## Upload a chat file + +### Code samples + +```shell +# Example request using curl +curl -X POST http://coder-server:8080/api/v2/chats/files?organization=497f6eca-6276-4993-bfeb-53cbbbba6f08 \ + -H 'Accept: application/json' \ + -H 'Content-Type: string' \ + -H 'Coder-Session-Token: API_KEY' +``` + +`POST /chats/files` + +### Parameters + +| Name | In | Type | Required | Description | +|----------------|--------|--------------|----------|-----------------------------------------------------------------------------------| +| `Content-Type` | header | string | true | Content-Type must be an image type (image/png, image/jpeg, image/gif, image/webp) | +| `organization` | query | string(uuid) | true | Organization ID | + +### Example responses + +> 201 Response + +```json +{ + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08" +} +``` + +### Responses + +| Status | Meaning | Description | Schema | 
+|--------|----------------------------------------------------------------------------|--------------------------|------------------------------------------------------------------------------| +| 201 | [Created](https://tools.ietf.org/html/rfc7231#section-6.3.2) | Created | [codersdk.UploadChatFileResponse](schemas.md#codersdkuploadchatfileresponse) | +| 400 | [Bad Request](https://tools.ietf.org/html/rfc7231#section-6.5.1) | Bad Request | [codersdk.Response](schemas.md#codersdkresponse) | +| 401 | [Unauthorized](https://tools.ietf.org/html/rfc7235#section-3.1) | Unauthorized | [codersdk.Response](schemas.md#codersdkresponse) | +| 413 | [Payload Too Large](https://tools.ietf.org/html/rfc7231#section-6.5.11) | Request Entity Too Large | [codersdk.Response](schemas.md#codersdkresponse) | +| 500 | [Internal Server Error](https://tools.ietf.org/html/rfc7231#section-6.6.1) | Internal Server Error | [codersdk.Response](schemas.md#codersdkresponse) | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). 
+ +## Get a chat file + +### Code samples + +```shell +# Example request using curl +curl -X GET http://coder-server:8080/api/v2/chats/files/{file} \ + -H 'Accept: */*' \ + -H 'Coder-Session-Token: API_KEY' +``` + +`GET /chats/files/{file}` + +### Parameters + +| Name | In | Type | Required | Description | +|--------|------|--------------|----------|-------------| +| `file` | path | string(uuid) | true | File ID | + +### Example responses + +> 400 Response + +### Responses + +| Status | Meaning | Description | Schema | +|--------|----------------------------------------------------------------------------|-----------------------|--------------------------------------------------| +| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | | +| 400 | [Bad Request](https://tools.ietf.org/html/rfc7231#section-6.5.1) | Bad Request | [codersdk.Response](schemas.md#codersdkresponse) | +| 401 | [Unauthorized](https://tools.ietf.org/html/rfc7235#section-3.1) | Unauthorized | [codersdk.Response](schemas.md#codersdkresponse) | +| 404 | [Not Found](https://tools.ietf.org/html/rfc7231#section-6.5.4) | Not Found | [codersdk.Response](schemas.md#codersdkresponse) | +| 500 | [Internal Server Error](https://tools.ietf.org/html/rfc7231#section-6.6.1) | Internal Server Error | [codersdk.Response](schemas.md#codersdkresponse) | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). 
+ ## Archive a chat ### Code samples diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index cd67ea783d..c113917502 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -9847,6 +9847,20 @@ If the schedule is empty, the user will be updated to use the default schedule.| |----------|---------|----------|--------------|-------------| | `ttl_ms` | integer | false | | | +## codersdk.UploadChatFileResponse + +```json +{ + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08" +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|------|--------|----------|--------------|-------------| +| `id` | string | false | | | + ## codersdk.UploadResponse ```json diff --git a/site/src/api/api.ts b/site/src/api/api.ts index b8849444a9..36c7d4ae9e 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -2296,6 +2296,33 @@ class ApiMethods { return response.data; }; + /** + * Uploads a file for use in chat messages. The file is sent as the raw + * request body; its name is sent in the Content-Disposition header. + */ + uploadChatFile = async ( + file: File, + organizationId: string, + ): Promise => { + // Browsers reject header values with non-Latin-1 characters, and a quoted + // filename breaks on `"`. Send the name as an RFC 5987 percent-encoded UTF-8 + // value. NOTE(review): confirm the handler parses `filename*` (Go's + // mime.ParseMediaType does). + const encodedName = encodeURIComponent(file.name).replace( + /['()*]/g, + (c) => `%${c.charCodeAt(0).toString(16).toUpperCase()}`, + ); + const response = await this.axios.post( + `/api/experimental/chats/files?organization=${organizationId}`, + file, + { + headers: { + "Content-Type": file.type || "application/octet-stream", + "Content-Disposition": `attachment; filename*=UTF-8''${encodedName}`, + }, + }, + ); + return response.data; + }; + getTemplateVersionLogs = async ( versionId: string, ): Promise => { diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index fb58a6a963..4efd1153d4 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1120,12 +1120,13 @@ export interface ChatGitChange { export interface ChatInputPart { readonly type: ChatInputPartType; readonly text?: string; + readonly file_id?: string; } // From codersdk/chats.go -export type ChatInputPartType = "text"; +export type ChatInputPartType = "file" | "text"; -export const ChatInputPartTypes: ChatInputPartType[]
= ["file", "text"]; // From codersdk/chats.go /** @@ -1161,6 +1162,7 @@ export interface ChatMessagePart { readonly title?: string; readonly media_type?: string; readonly data?: string; + readonly file_id?: string; } // From codersdk/chats.go @@ -6556,6 +6558,14 @@ export interface UpdateWorkspaceTTLRequest { readonly ttl_ms: number | null; } +// From codersdk/chats.go +/** + * UploadChatFileResponse is the response from uploading a chat file. + */ +export interface UploadChatFileResponse { + readonly id: string; +} + // From codersdk/files.go /** * UploadResponse contains the hash to reference the uploaded file. diff --git a/site/src/components/ChatMessageInput/ChatMessageInput.tsx b/site/src/components/ChatMessageInput/ChatMessageInput.tsx index 23ad9c3c31..77e7cecea9 100644 --- a/site/src/components/ChatMessageInput/ChatMessageInput.tsx +++ b/site/src/components/ChatMessageInput/ChatMessageInput.tsx @@ -57,8 +57,11 @@ const DisableFormattingPlugin: FC = memo(function DisableFormattingPlugin() { }); // Intercepts paste events and inserts clipboard content as plain text, -// stripping any rich-text formatting. -const PasteSanitizationPlugin: FC = memo(function PasteSanitizationPlugin() { +// stripping any rich-text formatting. Image files are forwarded to +// the parent via the onFilePaste callback instead of being inserted. +const PasteSanitizationPlugin: FC<{ + onFilePaste?: (file: File) => void; +}> = memo(function PasteSanitizationPlugin({ onFilePaste }) { const [editor] = useLexicalComposerContext(); useEffect(() => { @@ -69,6 +72,22 @@ const PasteSanitizationPlugin: FC = memo(function PasteSanitizationPlugin() { const clipboardData = event.clipboardData; if (!clipboardData) return false; + // Check for image files in the clipboard (e.g. pasted + // screenshots). Forward them to the parent via callback + // instead of inserting text. 
+ if (onFilePaste && clipboardData.files.length > 0) { + const images = Array.from(clipboardData.files).filter((f) => + f.type.startsWith("image/"), + ); + if (images.length > 0) { + event.preventDefault(); + for (const file of images) { + onFilePaste(file); + } + return true; + } + } + const text = clipboardData.getData("text/plain"); if (!text) return false; @@ -106,7 +125,7 @@ const PasteSanitizationPlugin: FC = memo(function PasteSanitizationPlugin() { }, COMMAND_PRIORITY_HIGH, ); - }, [editor]); + }, [editor, onFilePaste]); return null; }); @@ -217,6 +236,7 @@ interface ChatMessageInputProps onChange?: (content: string) => void; rows?: number; onEnter?: () => void; + onFilePaste?: (file: File) => void; disabled?: boolean; autoFocus?: boolean; "aria-label"?: string; @@ -245,6 +265,7 @@ const ChatMessageInput = memo( onChange, rows, onEnter, + onFilePaste, disabled, autoFocus, "aria-label": ariaLabel, @@ -392,7 +413,7 @@ const ChatMessageInput = memo( /> - + diff --git a/site/src/pages/AgentsPage/AgentChatInput.stories.tsx b/site/src/pages/AgentsPage/AgentChatInput.stories.tsx index 7f3a287e6f..47245db2f6 100644 --- a/site/src/pages/AgentsPage/AgentChatInput.stories.tsx +++ b/site/src/pages/AgentsPage/AgentChatInput.stories.tsx @@ -1,6 +1,6 @@ import type { Meta, StoryObj } from "@storybook/react-vite"; import { expect, fn, userEvent, waitFor, within } from "storybook/test"; -import { AgentChatInput } from "./AgentChatInput"; +import { AgentChatInput, type UploadState } from "./AgentChatInput"; const defaultModelOptions = [ { @@ -144,3 +144,80 @@ export const LongContentScrollable: Story = { initialValue: longContent, }, }; + +// Tiny 1x1 transparent PNG as data URI for attachment previews. 
+const TINY_PNG = + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="; + +const createMockFile = (name: string, type: string) => + new File(["mock-data"], name, { type }); + +export const WithAttachments: Story = { + args: (() => { + const file1 = createMockFile("screenshot.png", "image/png"); + const file2 = createMockFile("diagram.jpg", "image/jpeg"); + const attachments = [file1, file2]; + return { + attachments, + uploadStates: new Map([ + [file1, { status: "uploaded", fileId: "f1" }], + [file2, { status: "uploaded", fileId: "f2" }], + ]), + previewUrls: new Map([ + [file1, TINY_PNG], + [file2, TINY_PNG], + ]), + onAttach: fn(), + onRemoveAttachment: fn(), + initialValue: "Here are the images", + }; + })(), +}; + +export const WithUploadingAttachment: Story = { + args: (() => { + const file = createMockFile("uploading.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "uploading" }], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + onAttach: fn(), + onRemoveAttachment: fn(), + initialValue: "Waiting for upload", + }; + })(), +}; + +export const WithAttachmentError: Story = { + args: (() => { + const file = createMockFile("broken.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "error", error: "Upload failed: server error" }], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + onAttach: fn(), + onRemoveAttachment: fn(), + initialValue: "Upload had an error", + }; + })(), +}; + +export const AttachmentsOnly: Story = { + args: (() => { + const file = createMockFile("photo.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "uploaded", fileId: "f-only" }], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + onAttach: fn(), + onRemoveAttachment: fn(), + initialValue: "", + }; + })(), +}; diff --git a/site/src/pages/AgentsPage/AgentChatInput.tsx 
b/site/src/pages/AgentsPage/AgentChatInput.tsx index a67775dd8a..9df0d0489e 100644 --- a/site/src/pages/AgentsPage/AgentChatInput.tsx +++ b/site/src/pages/AgentsPage/AgentChatInput.tsx @@ -13,14 +13,29 @@ import { TooltipContent, TooltipTrigger, } from "components/Tooltip/Tooltip"; -import { ArrowUpIcon, Loader2Icon, Square, XIcon } from "lucide-react"; +import { + AlertTriangleIcon, + ArrowUpIcon, + ImageIcon, + Loader2Icon, + Square, + XIcon, +} from "lucide-react"; +import type React from "react"; import { memo, type ReactNode, useCallback, useRef, useState } from "react"; import { cn } from "utils/cn"; +import { ImageLightbox } from "./ImageLightbox"; import { formatProviderLabel } from "./modelOptions"; import { QueuedMessagesList } from "./QueuedMessagesList"; export type { ChatMessageInputRef } from "components/ChatMessageInput/ChatMessageInput"; +export type UploadState = { + status: "uploading" | "uploaded" | "error"; + fileId?: string; + error?: string; +}; + export interface AgentContextUsage { readonly usedTokens?: number; readonly contextLimitTokens?: number; @@ -76,6 +91,11 @@ interface AgentChatInputProps { // Pass `null` to render fallback values (e.g. when limit is unknown). // Omit entirely to hide the indicator. contextUsage?: AgentContextUsage | null; + attachments?: File[]; + onAttach?: (files: File[]) => void; + onRemoveAttachment?: (index: number) => void; + uploadStates?: Map; + previewUrls?: Map; } const hasFiniteTokenValue = (value: number | undefined): value is number => @@ -201,6 +221,97 @@ const ContextUsageIndicator = memo<{ usage: AgentContextUsage | null }>( ); ContextUsageIndicator.displayName = "ContextUsageIndicator"; +/** Renders an image thumbnail from a pre-created preview URL. 
*/ +export const ImageThumbnail = memo<{ + previewUrl: string; + name: string; + className?: string; +}>(({ previewUrl, name, className }) => ( + {name} +)); +ImageThumbnail.displayName = "ImageThumbnail"; + +/** Renders a horizontal strip of attachment thumbnails above the input. */ +export const AttachmentPreview = memo<{ + attachments: File[]; + onRemove: (index: number) => void; + uploadStates?: Map; + previewUrls?: Map; + onPreview?: (url: string) => void; +}>(({ attachments, onRemove, uploadStates, previewUrls, onPreview }) => { + if (attachments.length === 0) return null; + + return ( +
+ {attachments.map((file, index) => { + const uploadState = uploadStates?.get(file); + const previewUrl = previewUrls?.get(file) ?? ""; + return ( +
+ {file.type.startsWith("image/") && previewUrl ? ( + + ) : ( +
+ {file.name.split(".").pop()?.toUpperCase() || "FILE"} +
+ )} + {uploadState?.status === "uploading" && ( +
+ +
+ )} + {uploadState?.status === "error" && ( + + +
+ +
+
+ +

+ {uploadState.error ?? "Upload failed"} +

+
+
+ )} + +
+ ); + })} +
+ ); +}); +AttachmentPreview.displayName = "AttachmentPreview"; + export const AgentChatInput = memo( ({ onSend, @@ -230,8 +341,14 @@ export const AgentChatInput = memo( isEditingHistoryMessage = false, onCancelHistoryEdit, contextUsage, + attachments = [], + onAttach, + onRemoveAttachment, + uploadStates, + previewUrls, }) => { const internalRef = useRef(null); + const [previewImage, setPreviewImage] = useState(null); // Merge the external inputRef with our internal ref so both // point to the same ChatMessageInputRef instance. @@ -251,6 +368,57 @@ export const AgentChatInput = memo( [inputRef], ); + const fileInputRef = useRef(null); + + const handleFileSelect = useCallback( + (e: React.ChangeEvent) => { + if (e.target.files && onAttach) { + onAttach(Array.from(e.target.files)); + } + // Reset so the same file can be selected again. + e.target.value = ""; + }, + [onAttach], + ); + + const handleFilePaste = useCallback( + (file: File) => { + onAttach?.([file]); + }, + [onAttach], + ); + + // Drag-and-drop support for image files. + const [isDragging, setIsDragging] = useState(false); + + const handleDragOver = useCallback((e: React.DragEvent) => { + e.preventDefault(); + if (e.dataTransfer.types.includes("Files")) { + setIsDragging(true); + } + }, []); + + const handleDragLeave = useCallback((e: React.DragEvent) => { + if (!e.currentTarget.contains(e.relatedTarget as Node)) { + setIsDragging(false); + } + }, []); + + const handleDrop = useCallback( + (e: React.DragEvent) => { + e.preventDefault(); + setIsDragging(false); + if (!onAttach || !e.dataTransfer.files.length) return; + const images = Array.from(e.dataTransfer.files).filter((f) => + f.type.startsWith("image/"), + ); + if (images.length > 0) { + onAttach(images); + } + }, + [onAttach], + ); + // Track whether the editor has content so we can gate the // send button without a controlled value prop. 
const [hasContent, setHasContent] = useState(() => !!initialValue?.trim()); @@ -275,7 +443,18 @@ export const AgentChatInput = memo( } } - const canSend = !isDisabled && !isLoading && hasModelOptions && hasContent; + const isUploading = attachments.some( + (f) => uploadStates?.get(f)?.status === "uploading", + ); + const hasUploadedAttachments = attachments.some( + (f) => uploadStates?.get(f)?.status === "uploaded", + ); + const canSend = + !isDisabled && + !isLoading && + hasModelOptions && + (hasContent || hasUploadedAttachments) && + !isUploading; const handleSubmit = useCallback(() => { const text = internalRef.current?.getValue()?.trim() ?? ""; @@ -284,6 +463,7 @@ export const AgentChatInput = memo( // promote the first one instead of submitting. if ( !text && + !hasUploadedAttachments && !isDisabled && !isLoading && queuedMessages.length > 0 && @@ -293,7 +473,12 @@ export const AgentChatInput = memo( return; } - if (!text || isDisabled || isLoading || !hasModelOptions) { + if ( + (!text && !hasUploadedAttachments) || + isDisabled || + isLoading || + !hasModelOptions + ) { return; } @@ -303,6 +488,7 @@ export const AgentChatInput = memo( isDisabled, isLoading, hasModelOptions, + hasUploadedAttachments, onSend, queuedMessages, onPromoteQueuedMessage, @@ -348,8 +534,14 @@ export const AgentChatInput = memo( /> )}
{editingQueuedMessageID !== null && (
@@ -388,8 +580,18 @@ export const AgentChatInput = memo(
)} + {onRemoveAttachment && ( + + )} ( {contextUsage !== undefined && ( )} + {onAttach && ( + <> + + + + )} {isStreaming && onInterrupt && (
); - return content; + return ( + <> + {content} + {previewImage && ( + setPreviewImage(null)} + /> + )} + + ); }, ); AgentChatInput.displayName = "AgentChatInput"; diff --git a/site/src/pages/AgentsPage/AgentDetail.test.ts b/site/src/pages/AgentsPage/AgentDetail.test.ts index f6549df834..82b7f8f0cb 100644 --- a/site/src/pages/AgentsPage/AgentDetail.test.ts +++ b/site/src/pages/AgentsPage/AgentDetail.test.ts @@ -110,7 +110,7 @@ describe("useConversationEditingState", () => { await act(async () => { result.current.handleSendFromInput("hello"); await vi.waitFor(() => { - expect(onSend).toHaveBeenCalledWith("hello", undefined); + expect(onSend).toHaveBeenCalledWith("hello", undefined, undefined); }); }); diff --git a/site/src/pages/AgentsPage/AgentDetail.tsx b/site/src/pages/AgentsPage/AgentDetail.tsx index 91bc8ff04a..af7506304b 100644 --- a/site/src/pages/AgentsPage/AgentDetail.tsx +++ b/site/src/pages/AgentsPage/AgentDetail.tsx @@ -22,6 +22,7 @@ import { getVSCodeHref, openAppInNewWindow, } from "modules/apps/apps"; +import { useDashboard } from "modules/dashboard/useDashboard"; import { type FC, useCallback, @@ -35,7 +36,11 @@ import { useNavigate, useOutletContext, useParams } from "react-router"; import { toast } from "sonner"; import { cn } from "utils/cn"; import { pageTitle } from "utils/page"; -import { AgentChatInput, type ChatMessageInputRef } from "./AgentChatInput"; +import { + AgentChatInput, + type ChatMessageInputRef, + type UploadState, +} from "./AgentChatInput"; import { selectChatStatus, selectHasStreamState, @@ -74,6 +79,7 @@ import { } from "./modelOptions"; import { RightPanel } from "./RightPanel"; import { SidebarTabView } from "./SidebarTabView"; +import { useFileAttachments } from "./useFileAttachments"; import { useGitWatcher } from "./useGitWatcher"; const noopSetChatErrorReason: AgentsOutletContext["setChatErrorReason"] = @@ -99,7 +105,11 @@ interface AgentDetailTimelineProps { store: ChatStoreHandle; chatID: string; 
persistedErrorReason: string | undefined; - onEditUserMessage?: (messageId: number, text: string) => void; + onEditUserMessage?: ( + messageId: number, + text: string, + fileBlocks?: Array<{ mediaType: string; data?: string }>, + ) => void; editingMessageId?: number | null; savingMessageId?: number | null; } @@ -186,7 +196,7 @@ const AgentDetailTimeline: FC = ({ interface AgentDetailInputProps { store: ChatStoreHandle; compressionThreshold: number | undefined; - onSend: (message: string) => void; + onSend: (message: string, fileIds?: string[]) => void; onDeleteQueuedMessage: (id: number) => Promise; onPromoteQueuedMessage: (id: number) => Promise; onInterrupt: () => void; @@ -210,6 +220,13 @@ interface AgentDetailInputProps { onCancelQueueEdit: () => void; isEditingHistoryMessage: boolean; onCancelHistoryEdit: () => void; + // File blocks from the message being edited, converted to + // File objects and pre-populated into attachments. + editingFileBlocks?: Array<{ + mediaType: string; + data?: string; + fileId?: string; + }>; } const AgentDetailInput: FC = ({ @@ -237,6 +254,7 @@ const AgentDetailInput: FC = ({ onCancelQueueEdit, isEditingHistoryMessage, onCancelHistoryEdit, + editingFileBlocks, }) => { const messagesByID = useChatSelector(store, selectMessagesByID); const orderedMessageIDs = useChatSelector(store, selectOrderedMessageIDs); @@ -251,6 +269,8 @@ const AgentDetailInput: FC = ({ .filter(isChatMessage), [messagesByID, orderedMessageIDs], ); + const { organizations } = useDashboard(); + const organizationId = organizations[0]?.id; const latestContextUsage = useMemo(() => { const usage = getLatestContextUsage(messages); if (!usage) { @@ -258,12 +278,96 @@ const AgentDetailInput: FC = ({ } return { ...usage, compressionThreshold }; }, [messages, compressionThreshold]); + const { + attachments, + uploadStates, + previewUrls, + handleAttach, + handleRemoveAttachment, + resetAttachments, + setAttachments, + setPreviewUrls, + setUploadStates, + } = 
useFileAttachments(organizationId); + // Pre-populate attachments from existing file blocks when + // entering edit mode on a message with images. + useEffect(() => { + if (!editingFileBlocks || editingFileBlocks.length === 0) { + // Clear attachments when exiting edit mode. + setAttachments([]); + setUploadStates(new Map()); + setPreviewUrls(new Map()); + return; + } + const files = editingFileBlocks.map((block, i) => { + const ext = block.mediaType.split("/")[1] ?? "png"; + // Empty File used as a Map key only, its content is never + // read because the existing fileId is reused at send time. + return new File([], `attachment-${i}.${ext}`, { + type: block.mediaType, + }); + }); + setAttachments(files); + setPreviewUrls( + new Map( + files.map((f, i) => [ + f, + `/api/experimental/chats/files/${editingFileBlocks[i].fileId}`, + ]), + ), + ); + const newUploadStates = new Map(); + for (const [i, file] of files.entries()) { + const block = editingFileBlocks[i]; + if (block.fileId) { + newUploadStates.set(file, { + status: "uploaded", + fileId: block.fileId, + }); + } + } + setUploadStates(newUploadStates); + }, [editingFileBlocks, setAttachments, setPreviewUrls, setUploadStates]); + const isStreaming = hasStreamState || chatStatus === "running" || chatStatus === "pending"; return ( { + void (async () => { + try { + // Collect file IDs from already-uploaded attachments. + // Skip files in error state (e.g. too large). + const fileIds: string[] = []; + let skippedErrors = 0; + for (const file of attachments) { + const state = uploadStates.get(file); + if (state?.status === "error") { + skippedErrors++; + continue; + } + if (state?.status === "uploaded" && state.fileId) { + fileIds.push(state.fileId); + } + } + if (skippedErrors > 0) { + toast.warning( + `${skippedErrors} attachment${skippedErrors > 1 ? "s" : ""} could not be sent (upload failed)`, + ); + } + await onSend(message, fileIds.length > 0 ? 
fileIds : undefined); + resetAttachments(); + } catch { + // Attachments preserved for retry on failure. + } + })(); + }} + attachments={attachments} + onAttach={handleAttach} + onRemoveAttachment={handleRemoveAttachment} + uploadStates={uploadStates} + previewUrls={previewUrls} inputRef={inputRef} initialValue={initialValue} onContentChange={onContentChange} @@ -295,7 +399,11 @@ const AgentDetailInput: FC = ({ /** @internal Exported for testing. */ export function useConversationEditingState(deps: { chatID: string | undefined; - onSend: (message: string, editedMessageID?: number) => Promise; + onSend: ( + message: string, + fileIds?: string[], + editedMessageID?: number, + ) => Promise; onDeleteQueuedMessage: (id: number) => Promise; chatInputRef: React.RefObject; inputValueRef: React.RefObject; @@ -321,15 +429,23 @@ export function useConversationEditingState(deps: { const [draftBeforeHistoryEdit, setDraftBeforeHistoryEdit] = useState< string | null >(null); + const [editingFileBlocks, setEditingFileBlocks] = useState< + Array<{ mediaType: string; data?: string; fileId?: string }> + >([]); const handleEditUserMessage = useCallback( - (messageId: number, text: string) => { + ( + messageId: number, + text: string, + fileBlocks?: Array<{ mediaType: string; data?: string; fileId?: string }>, + ) => { setDraftBeforeHistoryEdit((prev) => editingMessageId !== null ? prev : inputValueRef.current, ); setEditingMessageId(messageId); setEditorInitialValue(text); inputValueRef.current = text; + setEditingFileBlocks(fileBlocks ?? []); }, [editingMessageId, inputValueRef], ); @@ -339,6 +455,7 @@ export function useConversationEditingState(deps: { inputValueRef.current = draftBeforeHistoryEdit ?? 
""; setEditingMessageId(null); setDraftBeforeHistoryEdit(null); + setEditingFileBlocks([]); }, [draftBeforeHistoryEdit, inputValueRef]); // -- Queue editing state -- @@ -371,29 +488,29 @@ export function useConversationEditingState(deps: { // Wraps the parent onSend to clear local input/editing state // and handle queue-edit deletion. const handleSendFromInput = useCallback( - (message: string) => { + async (message: string, fileIds?: string[]) => { const editedMessageID = editingMessageId !== null ? editingMessageId : undefined; const queueEditID = editingQueuedMessageID; - void onSend(message, editedMessageID).then(() => { - // Clear input and editing state on success. - chatInputRef.current?.clear(); - chatInputRef.current?.focus(); - inputValueRef.current = ""; - if (typeof window !== "undefined" && draftStorageKey) { - localStorage.removeItem(draftStorageKey); - } - if (editingMessageId !== null) { - setEditingMessageId(null); - setDraftBeforeHistoryEdit(null); - } - if (queueEditID !== null) { - setEditingQueuedMessageID(null); - setDraftBeforeQueueEdit(null); - void onDeleteQueuedMessage(queueEditID); - } - }); + await onSend(message, fileIds, editedMessageID); + // Clear input and editing state on success. 
+ chatInputRef.current?.clear(); + chatInputRef.current?.focus(); + inputValueRef.current = ""; + if (typeof window !== "undefined" && draftStorageKey) { + localStorage.removeItem(draftStorageKey); + } + if (editingMessageId !== null) { + setEditingMessageId(null); + setDraftBeforeHistoryEdit(null); + setEditingFileBlocks([]); + } + if (queueEditID !== null) { + setEditingQueuedMessageID(null); + setDraftBeforeQueueEdit(null); + void onDeleteQueuedMessage(queueEditID); + } }, [ chatInputRef, @@ -425,6 +542,7 @@ export function useConversationEditingState(deps: { chatInputRef, editorInitialValue, editingMessageId, + editingFileBlocks, handleEditUserMessage, handleCancelHistoryEdit, editingQueuedMessageID, @@ -658,16 +776,26 @@ const AgentDetail: FC = () => { interruptMutation.isPending; const isInputDisabled = !hasModelOptions || isArchived; - const handleSend = async (message: string, editedMessageID?: number) => { - if ( - !message.trim() || - isSubmissionPending || - !agentId || - !hasModelOptions - ) { + const handleSend = async ( + message: string, + fileIds?: string[], + editedMessageID?: number, + ) => { + const hasContent = message.trim() || (fileIds && fileIds.length > 0); + if (!hasContent || isSubmissionPending || !agentId || !hasModelOptions) { return; } - const content: TypesGen.ChatInputPart[] = [{ type: "text", text: message }]; + const content: TypesGen.ChatInputPart[] = []; + if (message.trim()) { + content.push({ type: "text", text: message }); + } + + // Add pre-uploaded file references. 
+ if (fileIds && fileIds.length > 0) { + for (const fileId of fileIds) { + content.push({ type: "file", file_id: fileId }); + } + } if (editedMessageID !== undefined) { const request: TypesGen.EditChatMessageRequest = { content }; clearChatErrorReason(agentId); @@ -1091,6 +1219,7 @@ const AgentDetail: FC = () => { onCancelQueueEdit={editing.handleCancelQueueEdit} isEditingHistoryMessage={editing.editingMessageId !== null} onCancelHistoryEdit={editing.handleCancelHistoryEdit} + editingFileBlocks={editing.editingFileBlocks} /> diff --git a/site/src/pages/AgentsPage/AgentDetail/ConversationTimeline.tsx b/site/src/pages/AgentsPage/AgentDetail/ConversationTimeline.tsx index aba7a8844d..5867d2328f 100644 --- a/site/src/pages/AgentsPage/AgentDetail/ConversationTimeline.tsx +++ b/site/src/pages/AgentsPage/AgentDetail/ConversationTimeline.tsx @@ -18,6 +18,8 @@ import { useState, } from "react"; import { cn } from "utils/cn"; +import { ImageThumbnail } from "../AgentChatInput"; +import { ImageLightbox } from "../ImageLightbox"; import { useSmoothStreamingText } from "./SmoothText"; import type { MergedTool, @@ -102,6 +104,7 @@ type RenderBlockListParams = { isStreaming?: boolean; subagentTitles?: Map; subagentStatusOverrides?: Map; + onImageClick?: (src: string) => void; }; // Wrapper that runs the smooth-streaming jitter buffer on a single @@ -132,6 +135,7 @@ function renderBlockList({ isStreaming = false, subagentTitles, subagentStatusOverrides, + onImageClick, }: RenderBlockListParams): RenderBlockListResult { const renderedToolIDs = new Set(); const elements = blocks @@ -194,6 +198,30 @@ function renderBlockList({ /> ); } + case "file": + if (block.mediaType.startsWith("image/")) { + const src = block.fileId + ? 
`/api/experimental/chats/files/${block.fileId}` + : `data:${block.mediaType};base64,${block.data}`; + return ( + + ); + } + return null; default: return null; } @@ -205,7 +233,11 @@ function renderBlockList({ const ChatMessageItem = memo<{ message: TypesGen.ChatMessage; parsed: ParsedMessageContent; - onEditUserMessage?: (messageId: number, text: string) => void; + onEditUserMessage?: ( + messageId: number, + text: string, + fileBlocks?: Array<{ mediaType: string; data?: string; fileId?: string }>, + ) => void; editingMessageId?: number | null; savingMessageId?: number | null; // When true, renders a gradient overlay inside the bubble @@ -223,6 +255,7 @@ const ChatMessageItem = memo<{ }) => { const isUser = message.role === "user"; const isSavingMessage = savingMessageId === message.id; + const [previewImage, setPreviewImage] = useState(null); const toolByID = new Map(parsed.tools.map((tool) => [tool.id, tool])); if ( @@ -243,82 +276,137 @@ const ChatMessageItem = memo<{ blocks: parsed.blocks, toolByID, keyPrefix: String(message.id), + onImageClick: setPreviewImage, }); const remainingTools = parsed.tools.filter( (tool) => !renderedToolIDs.has(tool.id), ); return ( - - {isUser ? ( - - onEditUserMessage(message.id, parsed.markdown || "") - : undefined - } - > -
- {parsed.markdown || ""} - {isSavingMessage && ( - + + {isUser ? ( + + { + const fileBlocks = parsed.blocks.filter( + (b): b is Extract => + b.type === "file" && + b.mediaType.startsWith("image/"), + ); + onEditUserMessage( + message.id, + parsed.markdown || "", + fileBlocks.length > 0 ? fileBlocks : undefined, + ); + } + : undefined + } + > +
+ + {parsed.markdown || ""} + + {isSavingMessage && ( + + )} +
+ {(() => { + const imageBlocks = parsed.blocks.filter( + (b): b is Extract => + b.type === "file" && b.mediaType.startsWith("image/"), + ); + if (imageBlocks.length === 0) return null; + return ( +
+ {imageBlocks.map((block, i) => { + const src = block.fileId + ? `/api/experimental/chats/files/${block.fileId}` + : `data:${block.mediaType};base64,${block.data}`; + return ( + + ); + })} +
+ ); + })()} + {fadeFromBottom && ( +
)} -
- {fadeFromBottom && ( -
- )} - - - ) : ( - - -
- {orderedBlocks} - {remainingTools.map((tool) => ( - - ))} - {!hasRenderableContent && ( -
- Message has no renderable content. -
- )} -
-
-
+ + + ) : ( + + +
+ {orderedBlocks} + {remainingTools.map((tool) => ( + + ))} + {!hasRenderableContent && ( +
+ Message has no renderable content. +
+ )} +
+
+
+ )} + + {previewImage && ( + setPreviewImage(null)} + /> )} - + ); }, ); @@ -405,7 +493,11 @@ StreamingOutput.displayName = "StreamingOutput"; const StickyUserMessage: FC<{ message: TypesGen.ChatMessage; parsed: ParsedMessageContent; - onEditUserMessage?: (messageId: number, text: string) => void; + onEditUserMessage?: ( + messageId: number, + text: string, + fileBlocks?: Array<{ mediaType: string; data?: string; fileId?: string }>, + ) => void; editingMessageId?: number | null; savingMessageId?: number | null; }> = ({ @@ -540,8 +632,16 @@ const StickyUserMessage: FC<{ }, [isStuck]); const handleEditUserMessage = onEditUserMessage - ? (messageId: number, text: string) => { - onEditUserMessage(messageId, text); + ? ( + messageId: number, + text: string, + fileBlocks?: Array<{ + mediaType: string; + data?: string; + fileId?: string; + }>, + ) => { + onEditUserMessage(messageId, text, fileBlocks); requestAnimationFrame(() => { const sentinel = sentinelRef.current; if (!sentinel) return; @@ -653,7 +753,11 @@ type ConversationTimelineProps = { retryState?: { attempt: number; error: string } | null; isAwaitingFirstStreamChunk: boolean; detailErrorMessage?: string | null; - onEditUserMessage?: (messageId: number, text: string) => void; + onEditUserMessage?: ( + messageId: number, + text: string, + fileBlocks?: Array<{ mediaType: string; data?: string; fileId?: string }>, + ) => void; editingMessageId?: number | null; savingMessageId?: number | null; }; diff --git a/site/src/pages/AgentsPage/AgentDetail/messageParsing.test.ts b/site/src/pages/AgentsPage/AgentDetail/messageParsing.test.ts index edb58526b8..279bd29d26 100644 --- a/site/src/pages/AgentsPage/AgentDetail/messageParsing.test.ts +++ b/site/src/pages/AgentsPage/AgentDetail/messageParsing.test.ts @@ -229,6 +229,40 @@ describe("parseMessageContent", () => { expect(result.toolCalls).toHaveLength(1); expect(result.toolCalls[0].name).toBe("test"); }); + + it("extracts fileId from a file block with file_id", () => { + 
const result = parseMessageContent([ + { + type: "file", + media_type: "image/png", + file_id: "abc-123-def", + }, + ]); + expect(result.blocks).toHaveLength(1); + expect(result.blocks[0]).toEqual({ + type: "file", + mediaType: "image/png", + data: undefined, + fileId: "abc-123-def", + }); + }); + + it("parses a file block without file_id (backward compat)", () => { + const result = parseMessageContent([ + { + type: "file", + media_type: "image/png", + data: "iVBORw0KGgo=", + }, + ]); + expect(result.blocks).toHaveLength(1); + expect(result.blocks[0]).toEqual({ + type: "file", + mediaType: "image/png", + data: "iVBORw0KGgo=", + fileId: undefined, + }); + }); }); describe("mergeTools", () => { diff --git a/site/src/pages/AgentsPage/AgentDetail/messageParsing.ts b/site/src/pages/AgentsPage/AgentDetail/messageParsing.ts index a4731f414d..fff9eca30c 100644 --- a/site/src/pages/AgentsPage/AgentDetail/messageParsing.ts +++ b/site/src/pages/AgentsPage/AgentDetail/messageParsing.ts @@ -216,6 +216,23 @@ export const parseMessageContent = (content: unknown): ParsedMessageContent => { parsed.blocks = ensureToolBlock(parsed.blocks, id); break; } + case "file": { + const mediaType = asString(typedBlock.media_type); + const data = asString(typedBlock.data); + const fileId = asString(typedBlock.file_id); + if (mediaType && (data || fileId)) { + parsed.blocks = [ + ...parsed.blocks, + { + type: "file", + mediaType, + data: data || undefined, + fileId: fileId || undefined, + }, + ]; + } + break; + } default: { const text = asString(typedBlock.text); parsed.markdown = appendText(parsed.markdown, text); diff --git a/site/src/pages/AgentsPage/AgentDetail/streamState.ts b/site/src/pages/AgentsPage/AgentDetail/streamState.ts index 9681be3f5b..b4291d6ff9 100644 --- a/site/src/pages/AgentsPage/AgentDetail/streamState.ts +++ b/site/src/pages/AgentsPage/AgentDetail/streamState.ts @@ -150,6 +150,26 @@ export const applyMessagePartToStreamState = ( }, }; } + case "file": { + const mediaType = 
asString(part.media_type); + const data = asString(part.data); + const fileId = asString(part.file_id); + if (!mediaType || (!data && !fileId)) { + return prev; + } + return { + ...nextState, + blocks: [ + ...nextState.blocks, + { + type: "file", + mediaType, + data: data || undefined, + fileId: fileId || undefined, + }, + ], + }; + } default: return prev; } diff --git a/site/src/pages/AgentsPage/AgentDetail/types.ts b/site/src/pages/AgentsPage/AgentDetail/types.ts index efb436e0dc..7e01d864c2 100644 --- a/site/src/pages/AgentsPage/AgentDetail/types.ts +++ b/site/src/pages/AgentsPage/AgentDetail/types.ts @@ -35,6 +35,12 @@ export type RenderBlock = | { type: "tool"; id: string; + } + | { + type: "file"; + mediaType: string; + data?: string; // base64, absent when file_id is available + fileId?: string; }; export type ParsedMessageContent = { diff --git a/site/src/pages/AgentsPage/AgentsPage.stories.tsx b/site/src/pages/AgentsPage/AgentsPage.stories.tsx index c6a2f81958..7b65c9b35f 100644 --- a/site/src/pages/AgentsPage/AgentsPage.stories.tsx +++ b/site/src/pages/AgentsPage/AgentsPage.stories.tsx @@ -1,4 +1,5 @@ import { MockWorkspace } from "testHelpers/entities"; +import { withDashboardProvider } from "testHelpers/storybook"; import type { Meta, StoryObj } from "@storybook/react-vite"; import { API } from "api/api"; import { @@ -26,6 +27,7 @@ const behaviorStorageKey = "agents.system-prompt"; const meta: Meta = { title: "pages/AgentsPage/AgentsEmptyState", component: AgentsEmptyState, + decorators: [withDashboardProvider], args: { onCreateChat: fn(), isCreating: false, diff --git a/site/src/pages/AgentsPage/AgentsPage.tsx b/site/src/pages/AgentsPage/AgentsPage.tsx index cfbd60318c..a00a78c204 100644 --- a/site/src/pages/AgentsPage/AgentsPage.tsx +++ b/site/src/pages/AgentsPage/AgentsPage.tsx @@ -59,6 +59,7 @@ import { } from "./modelOptions"; import { useAgentsPageKeybindings } from "./useAgentsPageKeybindings"; import { useAgentsPWA } from "./useAgentsPWA"; 
+import { useFileAttachments } from "./useFileAttachments"; import { WebPushButton } from "./WebPushButton"; /** @internal Exported for testing. */ @@ -72,6 +73,7 @@ type ChatModelOption = ModelSelectorOption; type CreateChatOptions = { message: string; + fileIDs?: string[]; workspaceId?: string; model?: string; }; @@ -329,11 +331,20 @@ const AgentsPage: FC = () => { ], ); const handleCreateChat = async (options: CreateChatOptions) => { - const { message, workspaceId, model } = options; + const { message, fileIDs, workspaceId, model } = options; const modelConfigID = (model && modelConfigIDByModelID.get(model)) || nilUUID; + const content: TypesGen.ChatInputPart[] = []; + if (message.trim()) { + content.push({ type: "text", text: message }); + } + if (fileIDs) { + for (const fileID of fileIDs) { + content.push({ type: "file", file_id: fileID }); + } + } const createdChat = await createMutation.mutateAsync({ - content: [{ type: "text", text: message }], + content, workspace_id: workspaceId, model_config_id: modelConfigID, }); @@ -686,6 +697,7 @@ export const AgentsEmptyState: FC = ({ isConfigureAgentsDialogOpen, onConfigureAgentsDialogOpenChange, }) => { + const { organizations } = useDashboard(); const { initialInputValue, handleContentChange, submitDraft, resetDraft } = useEmptyStateDraft(); const initialSystemPrompt = () => { @@ -855,10 +867,11 @@ export const AgentsEmptyState: FC = ({ ); const handleSend = useCallback( - (message: string) => { + async (message: string, fileIDs?: string[]) => { submitDraft(); - void onCreateChat({ + await onCreateChat({ message, + fileIDs, workspaceId: selectedWorkspaceIdRef.current ?? undefined, model: selectedModelRef.current || undefined, }).catch(() => { @@ -877,6 +890,44 @@ export const AgentsEmptyState: FC = ({ ? 
`${selectedWorkspace.owner_name}/${selectedWorkspace.name}` : undefined; + const { + attachments, + uploadStates, + previewUrls, + handleAttach, + handleRemoveAttachment, + resetAttachments, + } = useFileAttachments(organizations[0]?.id); + + const handleSendWithAttachments = useCallback( + async (message: string) => { + const fileIds: string[] = []; + let skippedErrors = 0; + for (const file of attachments) { + const state = uploadStates.get(file); + if (state?.status === "error") { + skippedErrors++; + continue; + } + if (state?.status === "uploaded" && state.fileId) { + fileIds.push(state.fileId); + } + } + if (skippedErrors > 0) { + toast.warning( + `${skippedErrors} attachment${skippedErrors > 1 ? "s" : ""} could not be sent (upload failed)`, + ); + } + try { + await handleSend(message, fileIds.length > 0 ? fileIds : undefined); + resetAttachments(); + } catch { + // Attachments preserved for retry on failure. + } + }, + [attachments, handleSend, resetAttachments, uploadStates], + ); + return (
@@ -886,7 +937,7 @@ export const AgentsEmptyState: FC = ({ )} = ({ hasModelOptions={hasModelOptions} inputStatusText={inputStatusText} modelCatalogStatusMessage={modelCatalogStatusMessage} + attachments={attachments} + onAttach={handleAttach} + onRemoveAttachment={handleRemoveAttachment} + uploadStates={uploadStates} + previewUrls={previewUrls} leftActions={ + new File(["mock-data"], name, { type }); + +const meta: Meta = { + title: "pages/AgentsPage/AttachmentPreview", + component: AttachmentPreview, + decorators: [ + (Story) => ( +
+ +
+ ), + ], + args: { + onRemove: fn(), + onPreview: fn(), + }, +}; + +export default meta; +type Story = StoryObj; + +export const SingleImage: Story = { + args: (() => { + const file = createMockFile("photo.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "uploaded", fileId: "file-1" }], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + }; + })(), +}; + +export const MultipleImages: Story = { + args: (() => { + const files = [ + createMockFile("photo-1.png", "image/png"), + createMockFile("photo-2.jpg", "image/jpeg"), + createMockFile("photo-3.png", "image/png"), + ]; + return { + attachments: files, + uploadStates: new Map( + files.map((f) => [f, { status: "uploaded", fileId: f.name }]), + ), + previewUrls: new Map(files.map((f) => [f, TINY_PNG])), + }; + })(), +}; + +export const Uploading: Story = { + args: (() => { + const file = createMockFile("uploading.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "uploading" }], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + }; + })(), +}; + +export const UploadError: Story = { + args: (() => { + const file = createMockFile("broken.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "error", error: "Upload failed: server error" }], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + }; + })(), +}; + +export const FileTooLarge: Story = { + args: (() => { + const file = createMockFile("huge-screenshot.png", "image/png"); + return { + attachments: [file], + uploadStates: new Map([ + [ + file, + { + status: "error", + error: "File too large (12.4 MB). 
Maximum is 10 MB.", + }, + ], + ]), + previewUrls: new Map([[file, TINY_PNG]]), + }; + })(), + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + const overlay = canvas.getByLabelText("Upload error"); + await userEvent.hover(overlay); + }, +}; + +export const NonImageFile: Story = { + args: (() => { + const file = createMockFile("readme.txt", "text/plain"); + return { + attachments: [file], + uploadStates: new Map([ + [file, { status: "uploaded", fileId: "file-txt" }], + ]), + }; + })(), +}; + +export const MixedStates: Story = { + args: (() => { + const uploaded = createMockFile("done.png", "image/png"); + const uploading = createMockFile("pending.jpg", "image/jpeg"); + const errored = createMockFile("failed.png", "image/png"); + const attachments = [uploaded, uploading, errored]; + return { + attachments, + uploadStates: new Map([ + [uploaded, { status: "uploaded", fileId: "file-ok" }], + [uploading, { status: "uploading" }], + [errored, { status: "error", error: "Network timeout" }], + ]), + previewUrls: new Map([ + [uploaded, TINY_PNG], + [uploading, TINY_PNG], + [errored, TINY_PNG], + ]), + }; + })(), +}; diff --git a/site/src/pages/AgentsPage/ImageLightbox.stories.tsx b/site/src/pages/AgentsPage/ImageLightbox.stories.tsx new file mode 100644 index 0000000000..6bc3bc9b27 --- /dev/null +++ b/site/src/pages/AgentsPage/ImageLightbox.stories.tsx @@ -0,0 +1,30 @@ +import type { Meta, StoryObj } from "@storybook/react-vite"; +import { fn } from "storybook/test"; +import { ImageLightbox } from "./ImageLightbox"; + +// Tiny 1x1 colored PNG so the lightbox has something visible to display. +const TINY_PNG = + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="; + +const meta: Meta = { + title: "pages/AgentsPage/ImageLightbox", + component: ImageLightbox, + decorators: [ + (Story) => ( +
+

Background content behind the lightbox overlay

+ +
+ ), + ], +}; + +export default meta; +type Story = StoryObj; + +export const Default: Story = { + args: { + src: TINY_PNG, + onClose: fn(), + }, +}; diff --git a/site/src/pages/AgentsPage/ImageLightbox.tsx b/site/src/pages/AgentsPage/ImageLightbox.tsx new file mode 100644 index 0000000000..a6a2ae8b2b --- /dev/null +++ b/site/src/pages/AgentsPage/ImageLightbox.tsx @@ -0,0 +1,25 @@ +import { Dialog, DialogContent, DialogTitle } from "components/Dialog/Dialog"; +import type { FC } from "react"; + +/** Props accepted by ImageLightbox. */ +interface ImageLightboxProps { + /** String handed to the preview; presumably the image URL (the markup that consumes it is not visible here, confirm). */ + src: string; + /** Called with no arguments when the dialog handler receives a falsy open value. */ + onClose: () => void; +} + +/** Dialog-based attachment preview that calls onClose once the dialog reports it is no longer open. */ +export const ImageLightbox: FC = ({ src, onClose }) => { + return ( + !open && onClose()}> + + Image preview + Attachment preview + + + ); +}; diff --git a/site/src/pages/AgentsPage/useFileAttachments.ts b/site/src/pages/AgentsPage/useFileAttachments.ts new file mode 100644 index 0000000000..2c1a7b6228 --- /dev/null +++ b/site/src/pages/AgentsPage/useFileAttachments.ts @@ -0,0 +1,160 @@ +import { API } from "api/api"; +import { getErrorDetail, getErrorMessage } from "api/errors"; +import { + type Dispatch, + type SetStateAction, + useCallback, + useEffect, + useRef, + useState, +} from "react"; +import type { UploadState } from "./AgentChatInput"; + +/** Values returned by useFileAttachments: the attachment state plus the handlers that mutate it. */ +interface UseFileAttachmentsReturn { + /** Attached files; handleAttach appends new files to the end of this list. */ + attachments: File[]; + /** Upload state keyed by File: status is "uploading", "uploaded" (with fileId) or "error" (with an error message). */ + uploadStates: Map; + /** Preview URL keyed by File, created with URL.createObjectURL when the file is attached. */ + previewUrls: Map; + /** Appends files, creates their preview URLs, marks files larger than 10 MB as errored and starts an upload for the rest. */ + handleAttach: (files: File[]) => void; + /** Removes the attachment at index and deletes its upload state and preview URL, revoking the URL when it starts with "blob:". NOTE(review): the other two setters are called from inside the setAttachments updater; React expects updater functions to be pure, so confirm this behaves as intended when an updater runs more than once (e.g. StrictMode). */ + handleRemoveAttachment: (index: number) => void; + /** Uploads one file via API.uploadChatFile and records the outcome in uploadStates; records an error immediately when no organization id is available. */ + startUpload: (file: File) => void; + /** Revokes every "blob:" preview URL held by the hook and clears attachments, preview URLs and upload states. */ + resetAttachments: () => void; + /** Raw state setter for attachments, exposed to callers. */ + setAttachments: Dispatch>; + /** Raw state setter for previewUrls, exposed to callers. */ + setPreviewUrls: Dispatch>>; + /** Raw state setter for uploadStates, exposed to callers. */ + setUploadStates: Dispatch>>; +} + +/** + * React hook that tracks chat file attachments together with their per-file upload state and preview URLs. + * + * @param organizationId - Forwarded to API.uploadChatFile; when falsy, startUpload records an error state instead of uploading. + * @returns The current attachment state, the attach/remove/upload/reset handlers and the raw state setters. + */ +export function useFileAttachments( + organizationId: string | undefined, +): UseFileAttachmentsReturn { + const [attachments, setAttachments] = useState([]); + const [uploadStates, setUploadStates] = useState( + () => new Map(), + ); + const [previewUrls, setPreviewUrls] = useState(() => new Map()); + + // Revoke blob URLs on unmount to prevent memory leaks. + // The ref is reassigned on every render so the cleanup, which is registered once with an empty dependency list, reads the latest map. 
+ const previewUrlsRef = useRef(previewUrls); + previewUrlsRef.current = previewUrls; + useEffect(() => { + return () => { + for (const [, url] of previewUrlsRef.current) { + if (url.startsWith("blob:")) URL.revokeObjectURL(url); + } + }; + }, []); + + /** Uploads one file through API.uploadChatFile and records "uploading", then "uploaded" (with the returned id) or "error" in uploadStates. */ + const startUpload = useCallback( + (file: File) => { + if (!organizationId) { + setUploadStates((prev) => + new Map(prev).set(file, { + status: "error", + error: "Unable to upload: no organization context.", + }), + ); + return; + } + setUploadStates((prev) => + new Map(prev).set(file, { status: "uploading" }), + ); + void (async () => { + try { + const result = await API.uploadChatFile(file, organizationId); + setUploadStates((prev) => + new Map(prev).set(file, { + status: "uploaded", + fileId: result.id, + }), + ); + // Pre-warm the browser HTTP cache so the timeline + // can render this image instantly after send. The + // server responds with Cache-Control: private, + // immutable, so the img element never hits the + // network again. + // NOTE(review): this fetch is neither awaited nor given a catch handler, so a rejection + // would not reach the surrounding catch block; confirm that is acceptable. + void fetch(`/api/experimental/chats/files/${result.id}`); + } catch (err: unknown) { + const message = getErrorMessage(err, "Upload failed"); + const detail = getErrorDetail(err); + const errorMessage = detail ? `${message} ${detail}` : message; + setUploadStates((prev) => + new Map(prev).set(file, { + status: "error", + error: errorMessage, + }), + ); + } + })(); + }, + [organizationId], + ); + + /** Appends files, creates an object URL preview for each, then marks files above maxSize as errored and passes the rest to startUpload. */ + const handleAttach = useCallback( + (files: File[]) => { + const maxSize = 10 * 1024 * 1024; // 10 MB + setAttachments((prev) => [...prev, ...files]); + setPreviewUrls((prev) => { + const next = new Map(prev); + for (const file of files) { + next.set(file, URL.createObjectURL(file)); + } + return next; + }); + for (const file of files) { + if (file.size > maxSize) { + setUploadStates((prev) => + new Map(prev).set(file, { + status: "error" as const, + error: `File too large (${(file.size / 1024 / 1024).toFixed(1)} MB). 
Maximum is 10 MB.`, + }), + ); + } else { + startUpload(file); + } + } + }, + [startUpload], + ); + + const handleRemoveAttachment = useCallback((index: number) => { + setAttachments((prev) => { + const removed = prev[index]; + if (removed) { + setUploadStates((prevStates) => { + const next = new Map(prevStates); + next.delete(removed); + return next; + }); + setPreviewUrls((prevUrls) => { + const url = prevUrls.get(removed); + if (url?.startsWith("blob:")) URL.revokeObjectURL(url); + const next = new Map(prevUrls); + next.delete(removed); + return next; + }); + } + return prev.filter((_, i) => i !== index); + }); + }, []); + + const resetAttachments = useCallback(() => { + for (const [, url] of previewUrlsRef.current) { + if (url.startsWith("blob:")) URL.revokeObjectURL(url); + } + setPreviewUrls(new Map()); + setUploadStates(new Map()); + setAttachments([]); + }, []); + + return { + attachments, + uploadStates, + previewUrls, + handleAttach, + handleRemoveAttachment, + startUpload, + resetAttachments, + setAttachments, + setPreviewUrls, + setUploadStates, + }; +}