diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index bf1cf50734..855ca05c60 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -1528,7 +1528,10 @@ func nullInt64Ptr(v sql.NullInt64) *int64 { // Chat converts a database.Chat to a codersdk.Chat. It coalesces // nil slices and maps to empty values for JSON serialization and // derives RootChatID from the parent chain when not explicitly set. -func Chat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat { +// When diffStatus is non-nil the response includes diff metadata. +// When files is non-empty the response includes file metadata; +// pass nil to omit the files field (e.g. list endpoints). +func Chat(c database.Chat, diffStatus *database.ChatDiffStatus, files []database.GetChatFileMetadataByChatIDRow) codersdk.Chat { mcpServerIDs := c.MCPServerIDs if mcpServerIDs == nil { mcpServerIDs = []uuid.UUID{} @@ -1581,6 +1584,19 @@ func Chat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat { convertedDiffStatus := ChatDiffStatus(c.ID, diffStatus) chat.DiffStatus = &convertedDiffStatus } + if len(files) > 0 { + chat.Files = make([]codersdk.ChatFileMetadata, 0, len(files)) + for _, row := range files { + chat.Files = append(chat.Files, codersdk.ChatFileMetadata{ + ID: row.ID, + OwnerID: row.OwnerID, + OrganizationID: row.OrganizationID, + Name: row.Name, + MimeType: row.Mimetype, + CreatedAt: row.CreatedAt, + }) + } + } if c.LastInjectedContext.Valid { var parts []codersdk.ChatMessagePart // Internal fields are stripped at write time in @@ -1604,9 +1620,9 @@ func ChatRows(rows []database.GetChatsRow, diffStatusesByChatID map[uuid.UUID]da for i, row := range rows { diffStatus, ok := diffStatusesByChatID[row.Chat.ID] if ok { - result[i] = Chat(row.Chat, &diffStatus) + result[i] = Chat(row.Chat, &diffStatus, nil) } else { - result[i] = Chat(row.Chat, nil) + result[i] = Chat(row.Chat, nil, nil) if diffStatusesByChatID != nil { emptyDiffStatus := ChatDiffStatus(row.Chat.ID, nil) result[i].DiffStatus = &emptyDiffStatus diff --git a/coderd/database/db2sdk/db2sdk_test.go b/coderd/database/db2sdk/db2sdk_test.go index 2738b5670a..2282f36fbd 100644 --- a/coderd/database/db2sdk/db2sdk_test.go +++ b/coderd/database/db2sdk/db2sdk_test.go @@ -561,14 +561,26 @@ func TestChat_AllFieldsPopulated(t *testing.T) { ChatID: input.ID, } - got := db2sdk.Chat(input, diffStatus) + fileRows := []database.GetChatFileMetadataByChatIDRow{ + { + ID: uuid.New(), + OwnerID: input.OwnerID, + OrganizationID: uuid.New(), + Name: "test.png", + Mimetype: "image/png", + CreatedAt: now, + }, + } + + got := db2sdk.Chat(input, diffStatus, fileRows) v := reflect.ValueOf(got) typ := v.Type() // HasUnread is populated by ChatRows (which joins the - // read-cursor query), not by Chat, so it is expected - // to remain zero here. - skip := map[string]bool{"HasUnread": true} + // read-cursor query), not by Chat. Warnings is a transient + // field populated by handlers, not the converter. Both are + // expected to remain zero here. + skip := map[string]bool{"HasUnread": true, "Warnings": true} for i := range typ.NumField() { field := typ.Field(i) if skip[field.Name] { @@ -581,6 +593,112 @@ func TestChat_AllFieldsPopulated(t *testing.T) { } } +func TestChat_FileMetadataConversion(t *testing.T) { + t.Parallel() + + ownerID := uuid.New() + orgID := uuid.New() + fileID := uuid.New() + now := dbtime.Now() + + chat := database.Chat{ + ID: uuid.New(), + OwnerID: ownerID, + LastModelConfigID: uuid.New(), + Title: "file metadata test", + Status: database.ChatStatusWaiting, + CreatedAt: now, + UpdatedAt: now, + } + + rows := []database.GetChatFileMetadataByChatIDRow{ + { + ID: fileID, + OwnerID: ownerID, + OrganizationID: orgID, + Name: "screenshot.png", + Mimetype: "image/png", + CreatedAt: now, + }, + } + + result := db2sdk.Chat(chat, nil, rows) + + require.Len(t, result.Files, 1) + f := result.Files[0] + require.Equal(t, fileID, f.ID) + require.Equal(t, ownerID, f.OwnerID, "OwnerID must be mapped from DB row") + require.Equal(t, orgID, f.OrganizationID, "OrganizationID must be mapped from DB row") + require.Equal(t, "screenshot.png", f.Name) + require.Equal(t, "image/png", f.MimeType) + require.Equal(t, now, f.CreatedAt) + + // Verify JSON serialization uses snake_case for mime_type. + data, err := json.Marshal(f) + require.NoError(t, err) + require.Contains(t, string(data), `"mime_type"`) + require.NotContains(t, string(data), `"mimetype"`) +} + +func TestChat_NilFilesOmitted(t *testing.T) { + t.Parallel() + + chat := database.Chat{ + ID: uuid.New(), + OwnerID: uuid.New(), + LastModelConfigID: uuid.New(), + Title: "no files", + Status: database.ChatStatusWaiting, + CreatedAt: dbtime.Now(), + UpdatedAt: dbtime.Now(), + } + + result := db2sdk.Chat(chat, nil, nil) + require.Empty(t, result.Files) +} + +func TestChat_MultipleFiles(t *testing.T) { + t.Parallel() + + now := dbtime.Now() + file1 := uuid.New() + file2 := uuid.New() + + chat := database.Chat{ + ID: uuid.New(), + OwnerID: uuid.New(), + LastModelConfigID: uuid.New(), + Title: "multi file test", + Status: database.ChatStatusWaiting, + CreatedAt: now, + UpdatedAt: now, + } + + rows := []database.GetChatFileMetadataByChatIDRow{ + { + ID: file1, + OwnerID: chat.OwnerID, + OrganizationID: uuid.New(), + Name: "a.png", + Mimetype: "image/png", + CreatedAt: now, + }, + { + ID: file2, + OwnerID: chat.OwnerID, + OrganizationID: uuid.New(), + Name: "b.txt", + Mimetype: "text/plain", + CreatedAt: now, + }, + } + + result := db2sdk.Chat(chat, nil, rows) + require.Len(t, result.Files, 2) + require.Equal(t, "a.png", result.Files[0].Name) + require.Equal(t, "b.txt", result.Files[1].Name) +} + func TestChatQueuedMessage_MalformedContent(t *testing.T) { t.Parallel() diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 03a09965bf..fa6d1b2398 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2583,6 +2583,10 @@ func (q *querier) GetChatFileByID(ctx context.Context, id uuid.UUID) (database.C return file, nil } +func (q *querier) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) { + return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatFileMetadataByChatID)(ctx, chatID) +} + func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) { files, err := q.db.GetChatFilesByIDs(ctx, ids) if err != nil { @@ -5393,6 +5397,17 @@ func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg datab return q.db.InsertWorkspaceResourceMetadata(ctx, arg) } +func (q *querier) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) { + chat, err := q.db.GetChatByID(ctx, arg.ChatID) + if err != nil { + return 0, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return 0, err + } + return q.db.LinkChatFiles(ctx, arg) +} + func (q *querier) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) { prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type) if err != nil { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 26c2cbb654..22648aa6a5 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -400,6 +400,17 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return([]database.Chat{chat}, nil).AnyTimes() check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns([]database.Chat{chat}) })) + s.Run("LinkChatFiles", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + arg := database.LinkChatFilesParams{ + ChatID: chat.ID, + MaxFileLinks: int32(codersdk.MaxChatFileIDs), + FileIds: []uuid.UUID{uuid.New()}, + } + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().LinkChatFiles(gomock.Any(), arg).Return(int32(0), nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int32(0)) + })) s.Run("PinChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() @@ -576,6 +587,19 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatFilesByIDs(gomock.Any(), []uuid.UUID{file.ID}).Return([]database.ChatFile{file}, nil).AnyTimes() check.Args([]uuid.UUID{file.ID}).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns([]database.ChatFile{file}) })) + s.Run("GetChatFileMetadataByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + file := testutil.Fake(s.T(), faker, database.ChatFile{}) + rows := []database.GetChatFileMetadataByChatIDRow{{ + ID: file.ID, + Name: file.Name, + Mimetype: file.Mimetype, + CreatedAt: file.CreatedAt, + OwnerID: file.OwnerID, + OrganizationID: file.OrganizationID, + }} + dbm.EXPECT().GetChatFileMetadataByChatID(gomock.Any(), file.ID).Return(rows, nil).AnyTimes() + check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(rows) + })) s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID}) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 1243a9138b..1c4e5955b0 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1128,6 +1128,14 @@ func (m queryMetricsStore) GetChatFileByID(ctx context.Context, id uuid.UUID) (d return r0, r1 } +func (m queryMetricsStore) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) { + start := time.Now() + r0, r1 := m.s.GetChatFileMetadataByChatID(ctx, chatID) + m.queryLatencies.WithLabelValues("GetChatFileMetadataByChatID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatFileMetadataByChatID").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) { start := time.Now() r0, r1 := m.s.GetChatFilesByIDs(ctx, ids) @@ -3776,6 +3784,14 @@ func (m queryMetricsStore) InsertWorkspaceResourceMetadata(ctx context.Context, return r0, r1 } +func (m queryMetricsStore) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) { + start := time.Now() + r0, r1 := m.s.LinkChatFiles(ctx, arg) + m.queryLatencies.WithLabelValues("LinkChatFiles").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "LinkChatFiles").Inc() + return r0, r1 +} + func (m queryMetricsStore) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) { start := time.Now() r0, r1 := m.s.ListAIBridgeClients(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 6d3a07699f..dcc142214e 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2072,6 +2072,21 @@ func (mr *MockStoreMockRecorder) GetChatFileByID(ctx, id any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileByID", reflect.TypeOf((*MockStore)(nil).GetChatFileByID), ctx, id) } +// GetChatFileMetadataByChatID mocks base method. +func (m *MockStore) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]database.GetChatFileMetadataByChatIDRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatFileMetadataByChatID", ctx, chatID) + ret0, _ := ret[0].([]database.GetChatFileMetadataByChatIDRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatFileMetadataByChatID indicates an expected call of GetChatFileMetadataByChatID. +func (mr *MockStoreMockRecorder) GetChatFileMetadataByChatID(ctx, chatID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFileMetadataByChatID", reflect.TypeOf((*MockStore)(nil).GetChatFileMetadataByChatID), ctx, chatID) +} + // GetChatFilesByIDs mocks base method. func (m *MockStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ChatFile, error) { m.ctrl.T.Helper() @@ -7066,6 +7081,21 @@ func (mr *MockStoreMockRecorder) InsertWorkspaceResourceMetadata(ctx, arg any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertWorkspaceResourceMetadata", reflect.TypeOf((*MockStore)(nil).InsertWorkspaceResourceMetadata), ctx, arg) } +// LinkChatFiles mocks base method. +func (m *MockStore) LinkChatFiles(ctx context.Context, arg database.LinkChatFilesParams) (int32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkChatFiles", ctx, arg) + ret0, _ := ret[0].(int32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LinkChatFiles indicates an expected call of LinkChatFiles. +func (mr *MockStoreMockRecorder) LinkChatFiles(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkChatFiles", reflect.TypeOf((*MockStore)(nil).LinkChatFiles), ctx, arg) +} + // ListAIBridgeClients mocks base method. func (m *MockStore) ListAIBridgeClients(ctx context.Context, arg database.ListAIBridgeClientsParams) ([]string, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 2b202d78c5..b3d7d9081d 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1269,6 +1269,11 @@ CREATE TABLE chat_diff_statuses ( head_branch text ); +CREATE TABLE chat_file_links ( + chat_id uuid NOT NULL, + file_id uuid NOT NULL +); + CREATE TABLE chat_files ( id uuid DEFAULT gen_random_uuid() NOT NULL, owner_id uuid NOT NULL, @@ -3344,6 +3349,9 @@ ALTER TABLE ONLY boundary_usage_stats ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id); +ALTER TABLE ONLY chat_file_links + ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id); + ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id); @@ -3734,6 +3742,8 @@ CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC); CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at); +CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links USING btree (chat_id); + CREATE INDEX idx_chat_files_org ON chat_files USING btree (organization_id); CREATE INDEX idx_chat_files_owner ON chat_files USING btree (owner_id); @@ -4036,6 +4046,12 @@ ALTER TABLE ONLY api_keys ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; +ALTER TABLE ONLY chat_file_links + ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; + +ALTER TABLE ONLY chat_file_links + ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE; + ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index ff4c021d77..f682e0d760 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -10,6 +10,8 @@ const ( ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id); ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; + ForeignKeyChatFileLinksChatID ForeignKeyConstraint = "chat_file_links_chat_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; + ForeignKeyChatFileLinksFileID ForeignKeyConstraint = "chat_file_links_file_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE; ForeignKeyChatFilesOrganizationID ForeignKeyConstraint = "chat_files_organization_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; diff --git a/coderd/database/migrations/000462_chat_file_links.down.sql b/coderd/database/migrations/000462_chat_file_links.down.sql new file mode 100644 index 0000000000..ceb5db9ef7 --- /dev/null +++ b/coderd/database/migrations/000462_chat_file_links.down.sql @@ -0,0 +1,9 @@ +ALTER TABLE chats ADD COLUMN file_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL; + +UPDATE chats SET file_ids = ( + SELECT COALESCE(array_agg(cfl.file_id), '{}') + FROM chat_file_links cfl + WHERE cfl.chat_id = chats.id +); + +DROP TABLE chat_file_links; diff --git a/coderd/database/migrations/000462_chat_file_links.up.sql b/coderd/database/migrations/000462_chat_file_links.up.sql new file mode 100644 index 0000000000..402bba7add --- /dev/null +++ b/coderd/database/migrations/000462_chat_file_links.up.sql @@ -0,0 +1,17 @@ +CREATE TABLE chat_file_links ( + chat_id uuid NOT NULL, + file_id uuid NOT NULL, + UNIQUE (chat_id, file_id) +); + +CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links (chat_id); + +ALTER TABLE chat_file_links + ADD CONSTRAINT chat_file_links_chat_id_fkey + FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; + +ALTER TABLE chat_file_links + ADD CONSTRAINT chat_file_links_file_id_fkey + FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE; + +ALTER TABLE chats DROP COLUMN IF EXISTS file_ids; diff --git a/coderd/database/migrations/testdata/fixtures/000462_chat_file_links.up.sql b/coderd/database/migrations/testdata/fixtures/000462_chat_file_links.up.sql new file mode 100644 index 0000000000..7007c90c96 --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000462_chat_file_links.up.sql @@ -0,0 +1,5 @@ +INSERT INTO chat_file_links (chat_id, file_id) +VALUES ( + '72c0438a-18eb-4688-ab80-e4c6a126ef96', + '00000000-0000-0000-0000-000000000099' +); diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 147161f03c..8d5ef4906d 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -187,6 +187,10 @@ func (c ChatFile) RBACObject() rbac.Object { return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID) } +func (c GetChatFileMetadataByChatIDRow) RBACObject() rbac.Object { + return rbac.ResourceChat.WithID(c.ID).WithOwner(c.OwnerID.String()).InOrg(c.OrganizationID) +} + func (s APIKeyScope) ToRBAC() rbac.ScopeName { switch s { case ApiKeyScopeCoderAll: diff --git a/coderd/database/models.go b/coderd/database/models.go index a38d39a21e..a9e70db01a 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4218,6 +4218,11 @@ type ChatFile struct { Data []byte `db:"data" json:"data"` } +type ChatFileLink struct { + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` + FileID uuid.UUID `db:"file_id" json:"file_id"` +} + type ChatMessage struct { ID int64 `db:"id" json:"id"` ChatID uuid.UUID `db:"chat_id" json:"chat_id"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 53072af425..fc442a56f3 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -244,6 +244,10 @@ type sqlcQuerier interface { GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error) GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error) + // GetChatFileMetadataByChatID returns lightweight file metadata for + // all files linked to a chat. The data column is excluded to avoid + // loading file content. + GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]GetChatFileMetadataByChatIDRow, error) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error) // GetChatIncludeDefaultSystemPrompt preserves the legacy default // for deployments created before the explicit include-default toggle. @@ -778,6 +782,15 @@ type sqlcQuerier interface { InsertWorkspaceProxy(ctx context.Context, arg InsertWorkspaceProxyParams) (WorkspaceProxy, error) InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error) InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error) + // LinkChatFiles inserts file associations into the chat_file_links + // join table with deduplication (ON CONFLICT DO NOTHING). The INSERT + // is conditional: it only proceeds when the total number of links + // (existing + genuinely new) does not exceed max_file_links. Returns + // the number of genuinely new file IDs that were NOT inserted due to + // the cap. A return value of 0 means all files were linked (or were + // already linked). A positive value means the cap blocked that many + // new links. + LinkChatFiles(ctx context.Context, arg LinkChatFilesParams) (int32, error) ListAIBridgeClients(ctx context.Context, arg ListAIBridgeClientsParams) ([]string, error) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams) ([]ListAIBridgeInterceptionsRow, error) // Finds all unique AI Bridge interception telemetry summaries combinations diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index faf6dd4e9c..58ae9bc037 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2889,6 +2889,56 @@ func (q *sqlQuerier) GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFil return i, err } +const getChatFileMetadataByChatID = `-- name: GetChatFileMetadataByChatID :many +SELECT cf.id, cf.owner_id, cf.organization_id, cf.name, cf.mimetype, cf.created_at +FROM chat_files cf +JOIN chat_file_links cfl ON cfl.file_id = cf.id +WHERE cfl.chat_id = $1::uuid +ORDER BY cf.created_at ASC +` + +type GetChatFileMetadataByChatIDRow struct { + ID uuid.UUID `db:"id" json:"id"` + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + Name string `db:"name" json:"name"` + Mimetype string `db:"mimetype" json:"mimetype"` + CreatedAt time.Time `db:"created_at" json:"created_at"` +} + +// GetChatFileMetadataByChatID returns lightweight file metadata for +// all files linked to a chat. The data column is excluded to avoid +// loading file content. +func (q *sqlQuerier) GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]GetChatFileMetadataByChatIDRow, error) { + rows, err := q.db.QueryContext(ctx, getChatFileMetadataByChatID, chatID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChatFileMetadataByChatIDRow + for rows.Next() { + var i GetChatFileMetadataByChatIDRow + if err := rows.Scan( + &i.ID, + &i.OwnerID, + &i.OrganizationID, + &i.Name, + &i.Mimetype, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getChatFilesByIDs = `-- name: GetChatFilesByIDs :many SELECT id, owner_id, organization_id, created_at, name, mimetype, data FROM chat_files WHERE id = ANY($1::uuid[]) ` @@ -6033,6 +6083,57 @@ func (q *sqlQuerier) InsertChatQueuedMessage(ctx context.Context, arg InsertChat return i, err } +const linkChatFiles = `-- name: LinkChatFiles :one +WITH current AS ( + SELECT COUNT(*) AS cnt + FROM chat_file_links + WHERE chat_id = $1::uuid +), +new_links AS ( + SELECT $1::uuid AS chat_id, unnest($2::uuid[]) AS file_id +), +genuinely_new AS ( + SELECT nl.chat_id, nl.file_id + FROM new_links nl + WHERE NOT EXISTS ( + SELECT 1 FROM chat_file_links cfl + WHERE cfl.chat_id = nl.chat_id AND cfl.file_id = nl.file_id + ) +), +inserted AS ( + INSERT INTO chat_file_links (chat_id, file_id) + SELECT gn.chat_id, gn.file_id + FROM genuinely_new gn, current c + WHERE c.cnt + (SELECT COUNT(*) FROM genuinely_new) <= $3::int + ON CONFLICT (chat_id, file_id) DO NOTHING + RETURNING file_id +) +SELECT + (SELECT COUNT(*)::int FROM genuinely_new) - + (SELECT COUNT(*)::int FROM inserted) AS rejected_new_files +` + +type LinkChatFilesParams struct { + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` + FileIds []uuid.UUID `db:"file_ids" json:"file_ids"` + MaxFileLinks int32 `db:"max_file_links" json:"max_file_links"` +} + +// LinkChatFiles inserts file associations into the chat_file_links +// join table with deduplication (ON CONFLICT DO NOTHING). The INSERT +// is conditional: it only proceeds when the total number of links +// (existing + genuinely new) does not exceed max_file_links. Returns +// the number of genuinely new file IDs that were NOT inserted due to +// the cap. A return value of 0 means all files were linked (or were +// already linked). A positive value means the cap blocked that many +// new links. +func (q *sqlQuerier) LinkChatFiles(ctx context.Context, arg LinkChatFilesParams) (int32, error) { + row := q.db.QueryRowContext(ctx, linkChatFiles, arg.ChatID, pq.Array(arg.FileIds), arg.MaxFileLinks) + var rejected_new_files int32 + err := row.Scan(&rejected_new_files) + return rejected_new_files, err +} + const listChatUsageLimitGroupOverrides = `-- name: ListChatUsageLimitGroupOverrides :many SELECT g.id AS group_id, diff --git a/coderd/database/queries/chatfiles.sql b/coderd/database/queries/chatfiles.sql index 5cb2ad89fe..ac0ec0782e 100644 --- a/coderd/database/queries/chatfiles.sql +++ b/coderd/database/queries/chatfiles.sql @@ -8,3 +8,13 @@ SELECT * FROM chat_files WHERE id = @id::uuid; -- name: GetChatFilesByIDs :many SELECT * FROM chat_files WHERE id = ANY(@ids::uuid[]); + +-- name: GetChatFileMetadataByChatID :many +-- GetChatFileMetadataByChatID returns lightweight file metadata for +-- all files linked to a chat. The data column is excluded to avoid +-- loading file content. +SELECT cf.id, cf.owner_id, cf.organization_id, cf.name, cf.mimetype, cf.created_at +FROM chat_files cf +JOIN chat_file_links cfl ON cfl.file_id = cf.id +WHERE cfl.chat_id = @chat_id::uuid +ORDER BY cf.created_at ASC; diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index a7c0bf6740..b29d766b07 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -567,6 +567,43 @@ WHERE RETURNING *; +-- name: LinkChatFiles :one +-- LinkChatFiles inserts file associations into the chat_file_links +-- join table with deduplication (ON CONFLICT DO NOTHING). The INSERT +-- is conditional: it only proceeds when the total number of links +-- (existing + genuinely new) does not exceed max_file_links. Returns +-- the number of genuinely new file IDs that were NOT inserted due to +-- the cap. A return value of 0 means all files were linked (or were +-- already linked). A positive value means the cap blocked that many +-- new links. +WITH current AS ( + SELECT COUNT(*) AS cnt + FROM chat_file_links + WHERE chat_id = @chat_id::uuid +), +new_links AS ( + SELECT @chat_id::uuid AS chat_id, unnest(@file_ids::uuid[]) AS file_id +), +genuinely_new AS ( + SELECT nl.chat_id, nl.file_id + FROM new_links nl + WHERE NOT EXISTS ( + SELECT 1 FROM chat_file_links cfl + WHERE cfl.chat_id = nl.chat_id AND cfl.file_id = nl.file_id + ) +), +inserted AS ( + INSERT INTO chat_file_links (chat_id, file_id) + SELECT gn.chat_id, gn.file_id + FROM genuinely_new gn, current c + WHERE c.cnt + (SELECT COUNT(*) FROM genuinely_new) <= @max_file_links::int + ON CONFLICT (chat_id, file_id) DO NOTHING + RETURNING file_id +) +SELECT + (SELECT COUNT(*)::int FROM genuinely_new) - + (SELECT COUNT(*)::int FROM inserted) AS rejected_new_files; + -- name: AcquireChats :many -- Acquires up to @num_chats pending chats for processing. Uses SKIP LOCKED -- to prevent multiple replicas from acquiring the same chat. diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index e9c933ca08..33c017b535 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -247,6 +247,7 @@ sql: mcp_server_tool_snapshots: MCPServerToolSnapshots mcp_server_config_id: MCPServerConfigID mcp_server_ids: MCPServerIDs + max_file_links: MaxFileLinks icon_url: IconURL oauth2_client_id: OAuth2ClientID oauth2_client_secret: OAuth2ClientSecret diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index a329058947..0a6bd3fa7b 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -16,6 +16,7 @@ const ( UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id); UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id); + UniqueChatFileLinksChatIDFileIDKey UniqueConstraint = "chat_file_links_chat_id_file_id_key" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id); UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id); UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id); UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id); diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 44d3ac3886..c870e2e56a 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -413,7 +413,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { return } - contentBlocks, titleSource, inputError := createChatInputFromRequest(ctx, api.Database, req) + contentBlocks, titleSource, fileIDs, inputError := createChatInputFromRequest(ctx, api.Database, req) if inputError != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, *inputError) return @@ -524,7 +524,32 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { return } - httpapi.Write(ctx, rw, http.StatusCreated, db2sdk.Chat(chat, nil)) + // Link any user-uploaded files referenced in the initial + // message to this newly created chat (best-effort; cap + // enforced in SQL). + unlinked, capExceeded := api.linkFilesToChat(ctx, chat.ID, fileIDs) + + // Re-read the chat so the response reflects the authoritative + // database state (file links are deduped in the join table). + chat, err = api.Database.GetChatByID(ctx, chat.ID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to read back chat after creation.", + Detail: err.Error(), + }) + return + } + + chatFiles := api.fetchChatFileMetadata(ctx, chat.ID) + response := db2sdk.Chat(chat, nil, chatFiles) + if len(unlinked) > 0 { + if capExceeded { + response.Warnings = append(response.Warnings, fileLinkCapWarning(len(unlinked))) + } else { + response.Warnings = append(response.Warnings, fileLinkErrorWarning(len(unlinked))) + } + } + httpapi.Write(ctx, rw, http.StatusCreated, response) } // EXPERIMENTAL: this endpoint is experimental and is subject to change. @@ -1301,7 +1326,11 @@ func (api *API) getChat(rw http.ResponseWriter, r *http.Request) { slog.Error(err), ) } - httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, diffStatus)) + + // Hydrate file metadata for all files linked to this chat. + chatFiles := api.fetchChatFileMetadata(ctx, chat.ID) + + httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, diffStatus, chatFiles)) } // EXPERIMENTAL: this endpoint is experimental and is subject to change. @@ -1791,7 +1820,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { return } - contentBlocks, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content") + contentBlocks, _, fileIDs, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content") if inputError != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: inputError.Message, @@ -1873,6 +1902,9 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { return } + // Link any user-uploaded files referenced in this message + // to the chat (best-effort; cap enforced in SQL). + unlinked, capExceeded := api.linkFilesToChat(ctx, chatID, fileIDs) response := codersdk.CreateChatMessageResponse{Queued: sendResult.Queued} if sendResult.Queued { if sendResult.QueuedMessage != nil { @@ -1882,6 +1914,13 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { message := convertChatMessage(sendResult.Message) response.Message = &message } + if len(unlinked) > 0 { + if capExceeded { + response.Warnings = append(response.Warnings, fileLinkCapWarning(len(unlinked))) + } else { + response.Warnings = append(response.Warnings, fileLinkErrorWarning(len(unlinked))) + } + } httpapi.Write(ctx, rw, http.StatusOK, response) } @@ -1915,7 +1954,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) { return } - contentBlocks, _, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content") + contentBlocks, _, fileIDs, inputError := createChatInputFromParts(ctx, api.Database, req.Content, "content") if inputError != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: inputError.Message, @@ -1954,8 +1993,20 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) { return } - message := convertChatMessage(editResult.Message) - httpapi.Write(ctx, rw, http.StatusOK, message) + // Link any user-uploaded files referenced in the edited + // message to the chat (best-effort; cap enforced in SQL). + unlinked, capExceeded := api.linkFilesToChat(ctx, chat.ID, fileIDs) + response := codersdk.EditChatMessageResponse{ + Message: convertChatMessage(editResult.Message), + } + if len(unlinked) > 0 { + if capExceeded { + response.Warnings = append(response.Warnings, fileLinkCapWarning(len(unlinked))) + } else { + response.Warnings = append(response.Warnings, fileLinkErrorWarning(len(unlinked))) + } + } + httpapi.Write(ctx, rw, http.StatusOK, response) } // EXPERIMENTAL: this endpoint is experimental and is subject to change. @@ -2232,7 +2283,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) { chat = updatedChat } - httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, nil)) + httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, nil, nil)) } // EXPERIMENTAL: this endpoint is experimental and is subject to change. @@ -2276,7 +2327,7 @@ func (api *API) regenerateChatTitle(rw http.ResponseWriter, r *http.Request) { return } - httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(updatedChat, nil)) + httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(updatedChat, nil, nil)) } // EXPERIMENTAL: this endpoint is experimental and is subject to change. @@ -3688,6 +3739,7 @@ func (api *API) chatFileByID(rw http.ResponseWriter, r *http.Request) { func createChatInputFromRequest(ctx context.Context, db database.Store, req codersdk.CreateChatRequest) ( []codersdk.ChatMessagePart, string, + []uuid.UUID, *codersdk.Response, ) { return createChatInputFromParts(ctx, db, req.Content, "content") @@ -3698,14 +3750,15 @@ func createChatInputFromParts( db database.Store, parts []codersdk.ChatInputPart, fieldName string, -) ([]codersdk.ChatMessagePart, string, *codersdk.Response) { +) ([]codersdk.ChatMessagePart, string, []uuid.UUID, *codersdk.Response) { if len(parts) == 0 { - return nil, "", &codersdk.Response{ + return nil, "", nil, &codersdk.Response{ Message: "Content is required.", Detail: "Content cannot be empty.", } } + var fileIDs []uuid.UUID content := make([]codersdk.ChatMessagePart, 0, len(parts)) textParts := make([]string, 0, len(parts)) for i, part := range parts { @@ -3713,7 +3766,7 @@ func createChatInputFromParts( case string(codersdk.ChatInputPartTypeText): text := strings.TrimSpace(part.Text) if text == "" { - return nil, "", &codersdk.Response{ + return nil, "", nil, &codersdk.Response{ Message: "Invalid input part.", Detail: fmt.Sprintf("%s[%d].text cannot be empty.", fieldName, i), } @@ -3722,7 +3775,7 @@ func createChatInputFromParts( textParts = append(textParts, text) case string(codersdk.ChatInputPartTypeFile): if part.FileID == uuid.Nil { - return nil, "", &codersdk.Response{ + return nil, "", nil, &codersdk.Response{ Message: "Invalid input part.", Detail: fmt.Sprintf("%s[%d].file_id is required for file parts.", fieldName, i), } @@ -3733,20 +3786,23 @@ func createChatInputFromParts( chatFile, err := db.GetChatFileByID(ctx, part.FileID) if err != nil { if httpapi.Is404Error(err) { - return nil, "", &codersdk.Response{ + return nil, "", nil, &codersdk.Response{ Message: "Invalid input part.", Detail: fmt.Sprintf("%s[%d].file_id references a file that does not exist.", fieldName, i), } } - return nil, "", &codersdk.Response{ + return nil, "", nil, &codersdk.Response{ Message: "Internal error.", Detail: fmt.Sprintf("Failed to retrieve file for %s[%d].", fieldName, i), } } content = append(content, codersdk.ChatMessageFile(part.FileID, chatFile.Mimetype)) + fileIDs = append(fileIDs, part.FileID) + // file-reference parts carry inline code snippets, not uploaded + // files. They have no FileID and are excluded from file tracking. case string(codersdk.ChatInputPartTypeFileReference): if part.FileName == "" { - return nil, "", &codersdk.Response{ + return nil, "", nil, &codersdk.Response{ Message: "Invalid input part.", Detail: fmt.Sprintf("%s[%d].file_name cannot be empty for file-reference.", fieldName, i), } @@ -3764,7 +3820,7 @@ func createChatInputFromParts( } textParts = append(textParts, sb.String()) default: - return nil, "", &codersdk.Response{ + return nil, "", nil, &codersdk.Response{ Message: "Invalid input part.", Detail: fmt.Sprintf( "%s[%d].type %q is not supported.", @@ -3779,13 +3835,13 @@ func createChatInputFromParts( // Allow file-only messages. The titleSource may be empty // when only file parts are provided, callers handle this. if len(content) == 0 { - return nil, "", &codersdk.Response{ + return nil, "", nil, &codersdk.Response{ Message: "Content is required.", Detail: fmt.Sprintf("%s must include at least one text or file part.", fieldName), } } titleSource := strings.TrimSpace(strings.Join(textParts, " ")) - return content, titleSource, nil + return content, titleSource, fileIDs, nil } func chatTitleFromMessage(message string) string { @@ -3820,6 +3876,70 @@ func truncateRunes(value string, maxLen int) string { return string(runes[:maxLen]) } +// linkFilesToChat inserts file-link rows into the chat_file_links +// join table. Cap enforcement and dedup are handled atomically in +// SQL. On success returns (nil, false). On failure returns the full +// input fileIDs slice — linking is all-or-nothing because the +// SQL operates on the batch atomically. capExceeded indicates +// whether the failure was due to the cap being exceeded (true) +// or a database error (false). +// Failures are logged but never block the caller. +func (api *API) linkFilesToChat(ctx context.Context, chatID uuid.UUID, fileIDs []uuid.UUID) (unlinked []uuid.UUID, capExceeded bool) { + if len(fileIDs) == 0 { + return nil, false + } + rejected, err := api.Database.LinkChatFiles(ctx, database.LinkChatFilesParams{ + ChatID: chatID, + MaxFileLinks: int32(codersdk.MaxChatFileIDs), + FileIds: fileIDs, + }) + if err != nil { + api.Logger.Error(ctx, "failed to link files to chat", + slog.F("chat_id", chatID), + slog.F("file_ids", fileIDs), + slog.Error(err), + ) + return fileIDs, false + } + if rejected > 0 { + api.Logger.Warn(ctx, "file cap reached, files not linked", + slog.F("chat_id", chatID), + slog.F("file_ids", fileIDs), + slog.F("max_file_links", codersdk.MaxChatFileIDs), + ) + return fileIDs, true + } + return nil, false +} + +// fileLinkCapWarning builds a user-facing warning when a batch +// of file IDs was atomically rejected because the resulting +// array would exceed the per-chat file cap. +func fileLinkCapWarning(count int) string { + return fmt.Sprintf("file linking skipped: batch of %d file(s) would exceed limit of %d", count, codersdk.MaxChatFileIDs) +} + +// fileLinkErrorWarning builds a user-facing warning when a +// database error prevented linking files to a chat. +func fileLinkErrorWarning(count int) string { + return fmt.Sprintf("%d file(s) could not be linked due to a server error", count) +} + +// fetchChatFileMetadata returns metadata for all files linked to +// the given chat. Errors are logged and result in a nil return +// (callers treat file metadata as best-effort). +func (api *API) fetchChatFileMetadata(ctx context.Context, chatID uuid.UUID) []database.GetChatFileMetadataByChatIDRow { + rows, err := api.Database.GetChatFileMetadataByChatID(ctx, chatID) + if err != nil { + api.Logger.Error(ctx, "failed to fetch chat file metadata", + slog.F("chat_id", chatID), + slog.Error(err), + ) + return nil + } + return rows +} + func convertChatCostModelBreakdown(model database.GetChatCostPerModelRow) codersdk.ChatCostModelBreakdown { displayName := strings.TrimSpace(model.DisplayName) if displayName == "" { diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index 65316dc02d..a32d117ccf 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -3227,6 +3227,153 @@ func TestGetChat(t *testing.T) { _, err = otherClient.GetChat(ctx, createdChat.ID) requireSDKError(t, err, http.StatusNotFound) }) + + t.Run("FilesHydrated", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + // Upload a file. + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "hydrated.png", bytes.NewReader(pngData)) + require.NoError(t, err) + + // Create a chat with a text + file part. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "check file hydration"}, + {Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID}, + }, + }) + require.NoError(t, err) + + // GET the chat — files must be hydrated with all metadata fields. + chatResult, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + require.Len(t, chatResult.Files, 1) + f := chatResult.Files[0] + require.Equal(t, uploadResp.ID, f.ID) + require.Equal(t, firstUser.UserID, f.OwnerID) + require.NotEqual(t, uuid.Nil, f.OrganizationID) + require.Equal(t, "image/png", f.MimeType) + require.Equal(t, "hydrated.png", f.Name) + require.NotZero(t, f.CreatedAt) + }) + + // ToolCreatedFilesLinked exercises the DB path that chatd uses + // when a tool (e.g. propose_plan) creates a file: InsertChatFile + // then LinkChatFiles. This is a DB-level test because driving + // the full chatd tool-call pipeline requires an LLM mock. + t.Run("ToolCreatedFilesLinked", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, store := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + // Create a chat via the API so all metadata is set up. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "tool file test"}, + }, + }) + require.NoError(t, err) + + // Mimic what chatd's StoreFile closure does: + // 1. InsertChatFile + // 2. LinkChatFiles + //nolint:gocritic // Using AsChatd to mimic the chatd background worker. + chatdCtx := dbauthz.AsChatd(ctx) + fileRow, err := store.InsertChatFile(chatdCtx, database.InsertChatFileParams{ + OwnerID: firstUser.UserID, + OrganizationID: firstUser.OrganizationID, + Name: "plan.md", + Mimetype: "text/markdown", + Data: []byte("# Plan"), + }) + require.NoError(t, err) + + rejected, err := store.LinkChatFiles(chatdCtx, database.LinkChatFilesParams{ + ChatID: chat.ID, + MaxFileLinks: int32(codersdk.MaxChatFileIDs), + FileIds: []uuid.UUID{fileRow.ID}, + }) + require.NoError(t, err) + require.Equal(t, int32(0), rejected, "0 rejected = all files linked") + + // Verify via the API that the file appears in the chat. + chatResult, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + require.Len(t, chatResult.Files, 1) + f := chatResult.Files[0] + require.Equal(t, fileRow.ID, f.ID) + require.Equal(t, firstUser.UserID, f.OwnerID) + require.Equal(t, firstUser.OrganizationID, f.OrganizationID) + require.Equal(t, "plan.md", f.Name) + require.Equal(t, "text/markdown", f.MimeType) + + // Fill up to the cap by inserting more files via the + // chatd DB path, then verify the cap is enforced. + for i := 1; i < codersdk.MaxChatFileIDs; i++ { + extra, err := store.InsertChatFile(chatdCtx, database.InsertChatFileParams{ + OwnerID: firstUser.UserID, + OrganizationID: firstUser.OrganizationID, + Name: fmt.Sprintf("file%d.md", i), + Mimetype: "text/markdown", + Data: []byte("data"), + }) + require.NoError(t, err) + _, err = store.LinkChatFiles(chatdCtx, database.LinkChatFilesParams{ + ChatID: chat.ID, + MaxFileLinks: int32(codersdk.MaxChatFileIDs), + FileIds: []uuid.UUID{extra.ID}, + }) + require.NoError(t, err) + } + + // Chat should now have exactly MaxChatFileIDs files. + chatResult, err = client.GetChat(ctx, chat.ID) + require.NoError(t, err) + require.Len(t, chatResult.Files, codersdk.MaxChatFileIDs) + + // Attempt to add one more file — should be rejected (0 rows). + overflow, err := store.InsertChatFile(chatdCtx, database.InsertChatFileParams{ + OwnerID: firstUser.UserID, + OrganizationID: firstUser.OrganizationID, + Name: "overflow.md", + Mimetype: "text/markdown", + Data: []byte("too many"), + }) + require.NoError(t, err) + rejected, err = store.LinkChatFiles(chatdCtx, database.LinkChatFilesParams{ + ChatID: chat.ID, + MaxFileLinks: int32(codersdk.MaxChatFileIDs), + FileIds: []uuid.UUID{overflow.ID}, + }) + require.NoError(t, err) + require.Equal(t, int32(1), rejected, "cap should reject the 21st file") + + // Re-appending an already-linked ID at cap should succeed + // (dedup means no array growth). + rejected, err = store.LinkChatFiles(chatdCtx, database.LinkChatFilesParams{ + ChatID: chat.ID, + MaxFileLinks: int32(codersdk.MaxChatFileIDs), + FileIds: []uuid.UUID{fileRow.ID}, + }) + require.NoError(t, err) + // ON CONFLICT DO NOTHING returns 0 rows when the link + // already exists, which is fine — the file is still linked. + require.Equal(t, int32(0), rejected, "dedup of existing ID should be a no-op") + + // Count should still be exactly MaxChatFileIDs. + chatResult, err = client.GetChat(ctx, chat.ID) + require.NoError(t, err) + require.Len(t, chatResult.Files, codersdk.MaxChatFileIDs) + }) } func TestArchiveChat(t *testing.T) { @@ -4373,6 +4520,14 @@ func TestChatMessageWithFiles(t *testing.T) { // With no text, chatTitleFromMessage("") returns "New Chat". require.Equal(t, "New Chat", chat.Title) + require.Len(t, chat.Files, 1) + f := chat.Files[0] + require.Equal(t, uploadResp.ID, f.ID) + require.Equal(t, firstUser.UserID, f.OwnerID) + require.NotEqual(t, uuid.Nil, f.OrganizationID) + require.Equal(t, "image/png", f.MimeType) + require.Equal(t, "test.png", f.Name) + require.NotZero(t, f.CreatedAt) }) t.Run("InvalidFileID", func(t *testing.T) { @@ -4407,6 +4562,189 @@ func TestChatMessageWithFiles(t *testing.T) { require.Equal(t, "Invalid input part.", sdkErr.Message) require.Contains(t, sdkErr.Detail, "does not exist") }) + + t.Run("FilesLinkedOnSend", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + // Create a text-only chat (no files initially). + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "no files yet"}, + }, + }) + require.NoError(t, err) + + // Upload a file. + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "linked.png", bytes.NewReader(pngData)) + require.NoError(t, err) + + // Send a message with the file. + _, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "here is a file"}, + {Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID}, + }, + }) + require.NoError(t, err) + + // GET the chat — file should be linked. + chatResult, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + require.Len(t, chatResult.Files, 1) + require.Equal(t, uploadResp.ID, chatResult.Files[0].ID) + require.Equal(t, "linked.png", chatResult.Files[0].Name) + }) + + t.Run("DedupFileIDs", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + // Upload a file. + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "dedup.png", bytes.NewReader(pngData)) + require.NoError(t, err) + + // Create a chat with a file. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "first mention"}, + {Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID}, + }, + }) + require.NoError(t, err) + + // Send another message with the SAME file. + msgResp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "same file again"}, + {Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID}, + }, + }) + require.NoError(t, err) + require.Empty(t, msgResp.Warnings, "dedup below cap should not produce warnings") + + // GET — should have exactly 1 file (deduped by SQL DISTINCT). + chatResult, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + require.Len(t, chatResult.Files, 1, "duplicate file IDs should be deduped") + require.Equal(t, uploadResp.ID, chatResult.Files[0].ID) + }) + + t.Run("FileCapExceeded", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + + // Upload MaxChatFileIDs files. + fileIDs := make([]uuid.UUID, 0, codersdk.MaxChatFileIDs) + for i := range codersdk.MaxChatFileIDs { + resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", fmt.Sprintf("file%d.png", i), bytes.NewReader(pngData)) + require.NoError(t, err) + fileIDs = append(fileIDs, resp.ID) + } + + // Create a chat using all MaxChatFileIDs files. + parts := []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "max files"}, + } + for _, fid := range fileIDs { + parts = append(parts, codersdk.ChatInputPart{Type: codersdk.ChatInputPartTypeFile, FileID: fid}) + } + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{Content: parts}) + require.NoError(t, err) + require.Empty(t, chat.Warnings, "creating a chat at exactly the cap should not warn") + require.Len(t, chat.Files, codersdk.MaxChatFileIDs, "all files should be linked on creation") + + // Upload one more file. + extraResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "one-too-many.png", bytes.NewReader(pngData)) + require.NoError(t, err) + + // Sending a message with the extra file should succeed + // (message goes through) but the file should NOT be linked + // (cap enforced in SQL). The response includes a warning. + msgResp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "one too many"}, + {Type: codersdk.ChatInputPartTypeFile, FileID: extraResp.ID}, + }, + }) + require.NoError(t, err) + require.NotEmpty(t, msgResp.Warnings, "response should warn about unlinked files") + require.Contains(t, msgResp.Warnings[0], "file linking skipped") + + // The extra file should NOT appear in the chat's files. + chatResult, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + require.Len(t, chatResult.Files, codersdk.MaxChatFileIDs, + "file count should not exceed the cap") + + // Sending a message referencing an already-linked file + // should succeed with no warnings (dedup, no array growth). + msgResp2, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "re-reference existing"}, + {Type: codersdk.ChatInputPartTypeFile, FileID: fileIDs[0]}, + }, + }) + require.NoError(t, err) + require.Empty(t, msgResp2.Warnings, "re-referencing an existing file should not warn") + }) + + t.Run("FileCapOnCreate", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + + // Upload MaxChatFileIDs + 1 files. + fileIDs := make([]uuid.UUID, 0, codersdk.MaxChatFileIDs+1) + for i := range codersdk.MaxChatFileIDs + 1 { + resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", fmt.Sprintf("create%d.png", i), bytes.NewReader(pngData)) + require.NoError(t, err) + fileIDs = append(fileIDs, resp.ID) + } + + // Create a chat with all files (one over the cap). + parts := []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "over cap on create"}, + } + for _, fid := range fileIDs { + parts = append(parts, codersdk.ChatInputPart{Type: codersdk.ChatInputPartTypeFile, FileID: fid}) + } + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{Content: parts}) + require.NoError(t, err, "chat creation should succeed even when cap is exceeded") + require.NotEmpty(t, chat.Warnings, "response should warn about unlinked files") + require.Contains(t, chat.Warnings[0], "file linking skipped") + + // Only MaxChatFileIDs files should actually be linked. + // With SQL-level batch rejection, ALL files are rejected + // when the result would exceed the cap. Since we're + // sending MaxChatFileIDs+1 files, the deduped count is + // 21 > 20, so 0 rows are affected and all files are + // unlinked. + chatResult, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + require.Empty(t, chatResult.Files, "no files should be linked when batch exceeds cap") + }) } func TestPatchChatMessage(t *testing.T) { @@ -4453,11 +4791,11 @@ func TestPatchChatMessage(t *testing.T) { require.NoError(t, err) // The edited message is soft-deleted and a new one is inserted, // so the returned ID will differ from the original. - require.NotEqual(t, userMessageID, edited.ID) - require.Equal(t, codersdk.ChatMessageRoleUser, edited.Role) + require.NotEqual(t, userMessageID, edited.Message.ID) + require.Equal(t, codersdk.ChatMessageRoleUser, edited.Message.Role) foundEditedText := false - for _, part := range edited.Content { + for _, part := range edited.Message.Content { if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "hello after edit" { foundEditedText = true } @@ -4545,11 +4883,11 @@ func TestPatchChatMessage(t *testing.T) { require.NoError(t, err) // The edited message is soft-deleted and a new one is inserted, // so the returned ID will differ from the original. - require.NotEqual(t, userMessageID, edited.ID) + require.NotEqual(t, userMessageID, edited.Message.ID) // Assert the edit response preserves the file_id. var foundText, foundFile bool - for _, part := range edited.Content { + for _, part := range edited.Message.Content { if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "after edit with file" { foundText = true } @@ -4692,6 +5030,112 @@ func TestPatchChatMessage(t *testing.T) { sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid chat message ID.", sdkErr.Message) }) + + t.Run("FilesLinkedOnEdit", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + // Create a text-only chat. + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "before file edit"}, + }, + }) + require.NoError(t, err) + + // Upload a file. + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "edit-linked.png", bytes.NewReader(pngData)) + require.NoError(t, err) + + // Find the user message ID. + messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil) + require.NoError(t, err) + var userMessageID int64 + for _, msg := range messagesResult.Messages { + if msg.Role == codersdk.ChatMessageRoleUser { + userMessageID = msg.ID + break + } + } + require.NotZero(t, userMessageID) + + // Edit the message to include the file. + _, err = client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "after file edit"}, + {Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID}, + }, + }) + require.NoError(t, err) + + // GET the chat — file should be linked. + chatResult, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + require.Len(t, chatResult.Files, 1) + f := chatResult.Files[0] + require.Equal(t, uploadResp.ID, f.ID) + require.Equal(t, "edit-linked.png", f.Name) + require.Equal(t, "image/png", f.MimeType) + }) + + t.Run("CapExceededOnEdit", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + // Create a chat with MaxChatFileIDs files already linked. + parts := []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "fill to cap"}, + } + pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) + for i := range codersdk.MaxChatFileIDs { + up, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", fmt.Sprintf("cap-%d.png", i), bytes.NewReader(pngData)) + require.NoError(t, err) + parts = append(parts, codersdk.ChatInputPart{Type: codersdk.ChatInputPartTypeFile, FileID: up.ID}) + } + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{Content: parts}) + require.NoError(t, err) + require.Empty(t, chat.Warnings, "all files should link on create") + + // Find the user message. + messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil) + require.NoError(t, err) + var userMessageID int64 + for _, msg := range messagesResult.Messages { + if msg.Role == codersdk.ChatMessageRoleUser { + userMessageID = msg.ID + break + } + } + require.NotZero(t, userMessageID) + + // Upload one more file and try to link via edit. + extra, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "one-too-many.png", bytes.NewReader(pngData)) + require.NoError(t, err) + edited, err := client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{ + Content: []codersdk.ChatInputPart{ + {Type: codersdk.ChatInputPartTypeText, Text: "edit with extra file"}, + {Type: codersdk.ChatInputPartTypeFile, FileID: extra.ID}, + }, + }) + require.NoError(t, err) + require.NotEmpty(t, edited.Warnings, "edit should surface cap warning") + require.Contains(t, edited.Warnings[0], "file linking skipped") + + // Verify the cap is still enforced. + chatResult, err := client.GetChat(ctx, chat.ID) + require.NoError(t, err) + require.Len(t, chatResult.Files, codersdk.MaxChatFileIDs, + "file count should not exceed the cap") + }) } func TestStreamChat(t *testing.T) { diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 92b79b675d..0219cad2a7 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -3204,7 +3204,11 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch if p.pubsub == nil { return } - sdkChat := db2sdk.Chat(chat, nil) // we have diffStatus already converted + // diffStatus is applied below. File metadata is intentionally + // omitted from pubsub events to avoid an extra DB query per + // publish. Clients must merge pubsub updates, not replace + // cached file metadata. + sdkChat := db2sdk.Chat(chat, nil, nil) if diffStatus != nil { sdkChat.DiffStatus = diffStatus } @@ -4525,6 +4529,27 @@ func (p *Server) runChat( return uuid.Nil, xerrors.Errorf("insert chat file: %w", err) } + // Cap enforcement and dedup are handled atomically + // in SQL. rejected > 0 = cap exceeded. + rejected, err := p.db.LinkChatFiles(ctx, database.LinkChatFilesParams{ + ChatID: chatSnapshot.ID, + MaxFileLinks: int32(codersdk.MaxChatFileIDs), + FileIds: []uuid.UUID{row.ID}, + }) + switch { + case err != nil: + p.logger.Error(ctx, "failed to link file to chat", + slog.F("chat_id", chatSnapshot.ID), + slog.F("file_id", row.ID), + slog.Error(err), + ) + case rejected > 0: + p.logger.Warn(ctx, "file cap reached, file not linked to chat", + slog.F("chat_id", chatSnapshot.ID), + slog.F("file_id", row.ID), + slog.F("max_file_links", codersdk.MaxChatFileIDs), + ) + } return row.ID, nil }, })) diff --git a/codersdk/chats.go b/codersdk/chats.go index f62a843553..ec1d575152 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -26,6 +26,12 @@ import ( // threshold settings. const ChatCompactionThresholdKeyPrefix = "chat_compaction_threshold_pct:" +// MaxChatFileIDs is the maximum number of file IDs that can be +// associated with a single chat. This limit prevents unbounded +// growth in the chat_file_links table. It is easier to raise +// this limit than to lower it. +const MaxChatFileIDs = 20 + // CompactionThresholdKey returns the user-config key for a specific // model configuration's compaction threshold. func CompactionThresholdKey(modelConfigID uuid.UUID) string { @@ -46,24 +52,25 @@ const ( // Chat represents a chat session with an AI agent. type Chat struct { - ID uuid.UUID `json:"id" format:"uuid"` - OwnerID uuid.UUID `json:"owner_id" format:"uuid"` - WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"` - BuildID *uuid.UUID `json:"build_id,omitempty" format:"uuid"` - AgentID *uuid.UUID `json:"agent_id,omitempty" format:"uuid"` - ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"` - RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"` - LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"` - Title string `json:"title"` - Status ChatStatus `json:"status"` - LastError *string `json:"last_error"` - DiffStatus *ChatDiffStatus `json:"diff_status,omitempty"` - CreatedAt time.Time `json:"created_at" format:"date-time"` - UpdatedAt time.Time `json:"updated_at" format:"date-time"` - Archived bool `json:"archived"` - PinOrder int32 `json:"pin_order"` - MCPServerIDs []uuid.UUID `json:"mcp_server_ids" format:"uuid"` - Labels map[string]string `json:"labels"` + ID uuid.UUID `json:"id" format:"uuid"` + OwnerID uuid.UUID `json:"owner_id" format:"uuid"` + WorkspaceID *uuid.UUID `json:"workspace_id,omitempty" format:"uuid"` + BuildID *uuid.UUID `json:"build_id,omitempty" format:"uuid"` + AgentID *uuid.UUID `json:"agent_id,omitempty" format:"uuid"` + ParentChatID *uuid.UUID `json:"parent_chat_id,omitempty" format:"uuid"` + RootChatID *uuid.UUID `json:"root_chat_id,omitempty" format:"uuid"` + LastModelConfigID uuid.UUID `json:"last_model_config_id" format:"uuid"` + Title string `json:"title"` + Status ChatStatus `json:"status"` + LastError *string `json:"last_error"` + DiffStatus *ChatDiffStatus `json:"diff_status,omitempty"` + CreatedAt time.Time `json:"created_at" format:"date-time"` + UpdatedAt time.Time `json:"updated_at" format:"date-time"` + Archived bool `json:"archived"` + PinOrder int32 `json:"pin_order"` + MCPServerIDs []uuid.UUID `json:"mcp_server_ids" format:"uuid"` + Labels map[string]string `json:"labels"` + Files []ChatFileMetadata `json:"files,omitempty"` // HasUnread is true when assistant messages exist beyond // the owner's read cursor, which updates on stream // connect and disconnect. @@ -73,6 +80,18 @@ type Chat struct { // is updated only when context changes — first workspace // attach or agent change. LastInjectedContext []ChatMessagePart `json:"last_injected_context,omitempty"` + Warnings []string `json:"warnings,omitempty"` +} + +// ChatFileMetadata contains lightweight metadata about a file +// associated with a chat, excluding the file content itself. +type ChatFileMetadata struct { + ID uuid.UUID `json:"id" format:"uuid"` + OwnerID uuid.UUID `json:"owner_id" format:"uuid"` + OrganizationID uuid.UUID `json:"organization_id" format:"uuid"` + Name string `json:"name"` + MimeType string `json:"mime_type"` + CreatedAt time.Time `json:"created_at" format:"date-time"` } // ChatMessage represents a single message in a chat. @@ -402,6 +421,15 @@ type CreateChatMessageResponse struct { Message *ChatMessage `json:"message,omitempty"` QueuedMessage *ChatQueuedMessage `json:"queued_message,omitempty"` Queued bool `json:"queued"` + Warnings []string `json:"warnings,omitempty"` +} + +// EditChatMessageResponse is the response from editing a message in a chat. +// Edits are always synchronous (no queueing), so the message is returned +// directly. +type EditChatMessageResponse struct { + Message ChatMessage `json:"message"` + Warnings []string `json:"warnings,omitempty"` } // UploadChatFileResponse is the response from uploading a chat file. @@ -1956,7 +1984,7 @@ func (c *ExperimentalClient) EditChatMessage( chatID uuid.UUID, messageID int64, req EditChatMessageRequest, -) (ChatMessage, error) { +) (EditChatMessageResponse, error) { res, err := c.Request( ctx, http.MethodPatch, @@ -1964,14 +1992,14 @@ func (c *ExperimentalClient) EditChatMessage( req, ) if err != nil { - return ChatMessage{}, err + return EditChatMessageResponse{}, err } if res.StatusCode != http.StatusOK { - return ChatMessage{}, readBodyAsChatUsageLimitError(res) + return EditChatMessageResponse{}, readBodyAsChatUsageLimitError(res) } defer res.Body.Close() - var message ChatMessage - return message, json.NewDecoder(res.Body).Decode(&message) + var resp EditChatMessageResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) } // InterruptChat cancels an in-flight chat run and leaves it waiting. diff --git a/docs/ai-coder/agents/chats-api.md b/docs/ai-coder/agents/chats-api.md index 7b755434ac..55d78b5ab5 100644 --- a/docs/ai-coder/agents/chats-api.md +++ b/docs/ai-coder/agents/chats-api.md @@ -134,6 +134,12 @@ edited message onward, truncating any messages that followed it. |-----------|-------------------|----------|----------------------------------| | `content` | `ChatInputPart[]` | yes | The replacement message content. | +The response is an `EditChatMessageResponse` with the edited `message` +and an optional `warnings` array. When file references in the edited +content cannot be linked (e.g. the per-chat file cap is reached), the +edit still succeeds and the `warnings` array describes which files +were not linked. + ### Stream updates `GET /api/experimental/chats/{chat}/stream` @@ -201,7 +207,9 @@ Each event is a JSON object with `kind` and `chat` fields: `GET /api/experimental/chats` -Returns all chats owned by the authenticated user. +Returns all chats owned by the authenticated user. The `files` field is +populated on `POST /chats` and `GET /chats/{id}`. Other endpoints that +return a `Chat` object omit it. | Query parameter | Type | Required | Description | |-----------------|----------|----------|------------------------------------------------------------------| @@ -212,7 +220,17 @@ Returns all chats owned by the authenticated user. `GET /api/experimental/chats/{chat}` -Returns the `Chat` object (metadata only, no messages). +Returns the `Chat` object (metadata only, no messages). The response +includes a `files` field (`ChatFileMetadata[]`) containing metadata for +files that have been successfully linked to the chat. File linking is +best-effort; if linking fails, the file remains in message content but +will be absent from this field. + +When file linking is skipped (e.g. the per-chat file cap is reached), +`POST /chats` includes a `warnings` array on the `Chat` response and +`POST /chats/{chat}/messages` includes a `warnings` array on the +`CreateChatMessageResponse`. The `warnings` field is `omitempty` and +absent when all files are linked successfully. ### Get chat messages @@ -295,6 +313,10 @@ file, use `GET /api/experimental/chats/files/{file}`. Supported formats: PNG, JPEG, GIF, WebP (up to 10 MB). The server validates actual file content regardless of the declared `Content-Type`. +Files referenced in messages are automatically linked to the chat and +appear in the `files` field on subsequent +`GET /api/experimental/chats/{chat}` responses. + ## Chat statuses | Status | Meaning | diff --git a/site/src/api/api.ts b/site/src/api/api.ts index a37ad02e73..260d7ab57a 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -3170,14 +3170,13 @@ class ExperimentalApiMethods { chatId: string, messageId: number, req: TypesGen.EditChatMessageRequest, - ): Promise => { - const response = await this.axios.patch( + ): Promise => { + const response = await this.axios.patch( `/api/experimental/chats/${chatId}/messages/${messageId}`, req, ); return response.data; }; - interruptChat = async (chatId: string): Promise => { const response = await this.axios.post( `/api/experimental/chats/${chatId}/interrupt`, diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 0e657b5d14..6df4b9b6a9 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1200,6 +1200,7 @@ export interface Chat { readonly pin_order: number; readonly mcp_server_ids: readonly string[]; readonly labels: Record; + readonly files?: readonly ChatFileMetadata[]; /** * HasUnread is true when assistant messages exist beyond * the owner's read cursor, which updates on stream @@ -1213,6 +1214,7 @@ export interface Chat { * attach or agent change. */ readonly last_injected_context?: readonly ChatMessagePart[]; + readonly warnings?: readonly string[]; } // From codersdk/chats.go @@ -1407,6 +1409,20 @@ export interface ChatDiffStatus { readonly stale_at?: string; } +// From codersdk/chats.go +/** + * ChatFileMetadata contains lightweight metadata about a file + * associated with a chat, excluding the file content itself. + */ +export interface ChatFileMetadata { + readonly id: string; + readonly owner_id: string; + readonly organization_id: string; + readonly name: string; + readonly mime_type: string; + readonly created_at: string; +} + // From codersdk/chats.go export interface ChatFilePart { readonly type: "file"; @@ -2354,6 +2370,7 @@ export interface CreateChatMessageResponse { readonly message?: ChatMessage; readonly queued_message?: ChatQueuedMessage; readonly queued: boolean; + readonly warnings?: readonly string[]; } // From codersdk/chats.go @@ -3205,6 +3222,17 @@ export interface EditChatMessageRequest { readonly content: readonly ChatInputPart[]; } +// From codersdk/chats.go +/** + * EditChatMessageResponse is the response from editing a message in a chat. + * Edits are always synchronous (no queueing), so the message is returned + * directly. + */ +export interface EditChatMessageResponse { + readonly message: ChatMessage; + readonly warnings?: readonly string[]; +} + // From codersdk/externalauth.go export type EnhancedExternalAuthProvider = | "azure-devops" @@ -4112,6 +4140,15 @@ export interface MatchedProvisioners { readonly most_recently_seen?: string; } +// From codersdk/chats.go +/** + * MaxChatFileIDs is the maximum number of file IDs that can be + * associated with a single chat. This limit prevents unbounded + * growth in the chat_file_links table. It is easier to raise + * this limit than to lower it. + */ +export const MaxChatFileIDs = 20; + // From codersdk/organizations.go export interface MinimalOrganization { readonly id: string;