diff --git a/coderd/database/check_constraint.go b/coderd/database/check_constraint.go index 5223dc17b7..1a209d785c 100644 --- a/coderd/database/check_constraint.go +++ b/coderd/database/check_constraint.go @@ -14,6 +14,8 @@ const ( CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config + CheckChatsPinOrderArchivedCheck CheckConstraint = "chats_pin_order_archived_check" // chats + CheckChatsPinOrderParentCheck CheckConstraint = "chats_pin_order_parent_check" // chats CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles CheckGroupsChatSpendLimitMicrosCheck CheckConstraint = "groups_chat_spend_limit_micros_check" // groups CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 8a293ecca1..9706259cd5 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1483,7 +1483,9 @@ CREATE TABLE chats ( dynamic_tools jsonb, organization_id uuid NOT NULL, plan_mode chat_plan_mode, - client_type chat_client_type DEFAULT 'api'::chat_client_type NOT NULL + client_type chat_client_type DEFAULT 'api'::chat_client_type NOT NULL, + CONSTRAINT chats_pin_order_archived_check CHECK (((pin_order = 0) OR (archived = false))), + CONSTRAINT chats_pin_order_parent_check CHECK (((pin_order = 0) OR (parent_chat_id IS NULL))) ); CREATE TABLE connection_logs ( diff --git a/coderd/database/migrations/000476_chat_pin_order_constraints.down.sql b/coderd/database/migrations/000476_chat_pin_order_constraints.down.sql new file mode 100644 index 0000000000..d59780914a --- /dev/null +++ b/coderd/database/migrations/000476_chat_pin_order_constraints.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE chats DROP CONSTRAINT IF EXISTS chats_pin_order_parent_check; +ALTER TABLE chats DROP CONSTRAINT IF EXISTS chats_pin_order_archived_check; diff --git a/coderd/database/migrations/000476_chat_pin_order_constraints.up.sql b/coderd/database/migrations/000476_chat_pin_order_constraints.up.sql new file mode 100644 index 0000000000..66d0237199 --- /dev/null +++ b/coderd/database/migrations/000476_chat_pin_order_constraints.up.sql @@ -0,0 +1,14 @@ +-- Defensive: fix any existing violating rows before adding constraints. +UPDATE chats SET pin_order = 0 + WHERE pin_order > 0 AND parent_chat_id IS NOT NULL; + +UPDATE chats SET pin_order = 0 + WHERE pin_order > 0 AND archived = true; + +ALTER TABLE chats + ADD CONSTRAINT chats_pin_order_parent_check + CHECK (pin_order = 0 OR parent_chat_id IS NULL); + +ALTER TABLE chats + ADD CONSTRAINT chats_pin_order_archived_check + CHECK (pin_order = 0 OR archived = false); diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index 1817fa241a..0f96dcc172 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -11151,6 +11151,95 @@ func TestChatPinOrderQueries(t *testing.T) { }) } +func TestChatPinOrderConstraints(t *testing.T) { + t.Parallel() + if testing.Short() { + t.SkipNow() + } + + db, _ := dbtestutil.NewDB(t) + org := dbgen.Organization(t, db, database.Organization{}) + owner := dbgen.User(t, db, database.User{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID}) + + bg := context.Background() + _, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{ + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, + CentralApiKeyEnabled: true, + }) + require.NoError(t, err) + + modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{ + Provider: "openai", + Model: "test-model", + DisplayName: "Test Model", + CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true}, + Enabled: true, + IsDefault: true, + ContextLimit: 128000, + CompressionThreshold: 80, + Options: json.RawMessage(`{}`), + }) + require.NoError(t, err) + + t.Run("ChildChatCannotBePinned", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) + + parent, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusCompleted, + ClientType: database.ChatClientTypeUi, + OwnerID: owner.ID, + LastModelConfigID: modelCfg.ID, + Title: "parent", + }) + require.NoError(t, err) + + child, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusCompleted, + ClientType: database.ChatClientTypeUi, + OwnerID: owner.ID, + LastModelConfigID: modelCfg.ID, + Title: "child", + ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true}, + RootChatID: uuid.NullUUID{UUID: parent.ID, Valid: true}, + }) + require.NoError(t, err) + + err = db.PinChatByID(ctx, child.ID) + require.Error(t, err) + require.True(t, database.IsCheckViolation(err, database.CheckChatsPinOrderParentCheck)) + }) + + t.Run("ArchivedChatCannotBePinned", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusCompleted, + ClientType: database.ChatClientTypeUi, + OwnerID: owner.ID, + LastModelConfigID: modelCfg.ID, + Title: "will be archived", + }) + require.NoError(t, err) + + _, err = db.ArchiveChatByID(ctx, chat.ID) + require.NoError(t, err) + + err = db.PinChatByID(ctx, chat.ID) + require.Error(t, err) + require.True(t, database.IsCheckViolation(err, database.CheckChatsPinOrderArchivedCheck)) + }) +} + func TestChatLabels(t *testing.T) { t.Parallel() if testing.Short() { diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index c9cca1b4c8..455d84862b 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -2208,6 +2208,13 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) { return } + if pinOrder > 0 && chat.ParentChatID.Valid { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Cannot pin a child chat.", + }) + return + } + // The behavior depends on current pin state: // - pinOrder == 0: unpin. // - pinOrder > 0 && already pinned: reorder (shift @@ -2232,10 +2239,21 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) { err = api.Database.PinChatByID(ctx, chat.ID) } if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: errMsg, - Detail: err.Error(), - }) + switch { + case database.IsCheckViolation(err, database.CheckChatsPinOrderParentCheck): + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Cannot pin a child chat.", + }) + case database.IsCheckViolation(err, database.CheckChatsPinOrderArchivedCheck): + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Cannot pin an archived chat.", + }) + default: + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: errMsg, + Detail: err.Error(), + }) + } return } } diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index ea35d4c4c0..07907ebed7 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -5575,6 +5575,36 @@ func TestChatPinOrder(t *testing.T) { chat = getChat(ctx, t, client, chat.ID) require.Zero(t, chat.PinOrder) }) + + t.Run("RejectsChildChat", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + parentChat := createChat(ctx, t, client, firstUser.OrganizationID, "parent chat") + + child, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: firstUser.OrganizationID, + OwnerID: firstUser.UserID, + LastModelConfigID: modelConfig.ID, + Title: "child chat", + Status: database.ChatStatusCompleted, + ClientType: database.ChatClientTypeUi, + ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, + RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, + }) + require.NoError(t, err) + + err = client.UpdateChat(ctx, child.ID, codersdk.UpdateChatRequest{PinOrder: ptr.Ref(int32(1))}) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Cannot pin a child chat.", sdkErr.Message) + + result := getChat(ctx, t, client, child.ID) + require.Zero(t, result.PinOrder) + }) } func TestPostChatMessages(t *testing.T) {