package coderd_test import ( "bytes" "context" "database/sql" "encoding/json" stderrors "errors" "fmt" "mime" "net/http" "regexp" "strings" "sync/atomic" "testing" "time" "github.com/google/uuid" "github.com/mark3labs/mcp-go/mcp" "github.com/shopspring/decimal" "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/require" "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/externalauth" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/coderd/x/chatd" "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" "github.com/coder/serpent" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) const chatProviderAPIKeySizeLimit = 10240 func chatDeploymentValues(t testing.TB) *codersdk.DeploymentValues { t.Helper() values := coderdtest.DeploymentValues(t) values.Experiments = []string{string(codersdk.ExperimentAgents)} return values } func newChatClient(t testing.TB) *codersdk.ExperimentalClient { t.Helper() client := coderdtest.New(t, &coderdtest.Options{ DeploymentValues: chatDeploymentValues(t), }) return codersdk.NewExperimentalClient(client) } func newChatClientWithDeploymentValues( t testing.TB, values *codersdk.DeploymentValues, ) *codersdk.ExperimentalClient { t.Helper() client := coderdtest.New(t, &coderdtest.Options{ DeploymentValues: values, }) return codersdk.NewExperimentalClient(client) } func newChatClientWithDatabase(t testing.TB) (*codersdk.ExperimentalClient, database.Store) { t.Helper() client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{ DeploymentValues: chatDeploymentValues(t), }) return codersdk.NewExperimentalClient(client), db } type failNextChatSystemPromptStore struct { database.Store failNextGetChatIncludeDefaultSystemPrompt atomic.Bool failNextGetChatSystemPromptConfig atomic.Bool failNextUpsertChatIncludeDefaultSystemPrompt atomic.Bool } func (s *failNextChatSystemPromptStore) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) { if s.failNextGetChatIncludeDefaultSystemPrompt.CompareAndSwap(true, false) { return false, stderrors.New("forced include-default read failure") } return s.Store.GetChatIncludeDefaultSystemPrompt(ctx) } func (s *failNextChatSystemPromptStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefault bool) error { if s.failNextUpsertChatIncludeDefaultSystemPrompt.CompareAndSwap(true, false) { return stderrors.New("forced include-default upsert failure") } return s.Store.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefault) } func (s *failNextChatSystemPromptStore) GetChatSystemPromptConfig(ctx context.Context) (database.GetChatSystemPromptConfigRow, error) { if s.failNextGetChatSystemPromptConfig.CompareAndSwap(true, false) { return database.GetChatSystemPromptConfigRow{}, stderrors.New("forced chat system prompt configuration read failure") } return s.Store.GetChatSystemPromptConfig(ctx) } func requireChatUsageLimitExceededError( t *testing.T, err error, wantSpentMicros int64, wantLimitMicros int64, wantResetsAt time.Time, ) *codersdk.ChatUsageLimitExceededResponse { t.Helper() sdkErr, ok := codersdk.AsError(err) require.True(t, ok) require.Equal(t, http.StatusConflict, sdkErr.StatusCode()) require.Equal(t, "Chat usage limit exceeded.", sdkErr.Message) limitErr := codersdk.ChatUsageLimitExceededFrom(err) require.NotNil(t, limitErr) require.Equal(t, "Chat usage limit exceeded.", limitErr.Message) require.Equal(t, wantSpentMicros, limitErr.SpentMicros) require.Equal(t, wantLimitMicros, limitErr.LimitMicros) require.True( t, limitErr.ResetsAt.Equal(wantResetsAt), "expected resets_at %s, got %s", wantResetsAt.UTC().Format(time.RFC3339), limitErr.ResetsAt.UTC().Format(time.RFC3339), ) return limitErr } func enableDailyChatUsageLimit( ctx context.Context, t *testing.T, db database.Store, limitMicros int64, ) time.Time { t.Helper() _, err := db.UpsertChatUsageLimitConfig( dbauthz.AsSystemRestricted(ctx), database.UpsertChatUsageLimitConfigParams{ Enabled: true, DefaultLimitMicros: limitMicros, Period: string(codersdk.ChatUsageLimitPeriodDay), }, ) require.NoError(t, err) _, periodEnd := chatd.ComputeUsagePeriodBounds(time.Now(), codersdk.ChatUsageLimitPeriodDay) return periodEnd } func insertAssistantCostMessage( ctx context.Context, t *testing.T, db database.Store, chatID uuid.UUID, modelConfigID uuid.UUID, totalCostMicros int64, ) { t.Helper() assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ codersdk.ChatMessageText("assistant"), }) require.NoError(t, err) _, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{ ChatID: chatID, CreatedBy: []uuid.UUID{uuid.Nil}, ModelConfigID: []uuid.UUID{modelConfigID}, Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, ContentVersion: []int16{chatprompt.CurrentContentVersion}, Content: []string{string(assistantContent.RawMessage)}, Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, InputTokens: []int64{0}, OutputTokens: []int64{0}, TotalTokens: []int64{0}, ReasoningTokens: []int64{0}, CacheCreationTokens: []int64{0}, CacheReadTokens: []int64{0}, ContextLimit: []int64{0}, Compressed: []bool{false}, TotalCostMicros: []int64{totalCostMicros}, RuntimeMs: []int64{0}, }) require.NoError(t, err) } func TestPostChats(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) // Use a member with agents-access instead of the owner to // verify least-privilege access. memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID, rbac.RoleAgentsAccess()) memberClient := codersdk.NewExperimentalClient(memberClientRaw) chat, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello from chats route tests", }, }, }) require.NoError(t, err) require.NotEqual(t, uuid.Nil, chat.ID) require.Equal(t, member.ID, chat.OwnerID) require.Equal(t, modelConfig.ID, chat.LastModelConfigID) require.Equal(t, "hello from chats route tests", chat.Title) require.NotZero(t, chat.CreatedAt) require.NotZero(t, chat.UpdatedAt) require.Nil(t, chat.WorkspaceID) require.NotNil(t, chat.RootChatID) require.Equal(t, chat.ID, *chat.RootChatID) chatResult, err := memberClient.GetChat(ctx, chat.ID) require.NoError(t, err) messagesResult, err := memberClient.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) require.Equal(t, chat.ID, chatResult.ID) foundUserMessage := false for _, message := range messagesResult.Messages { if message.Role != codersdk.ChatMessageRoleUser { continue } for _, part := range message.Content { if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "hello from chats route tests" { foundUserMessage = true break } } } require.True(t, foundUserMessage) }) t.Run("MemberWithoutAgentsAccess", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Member without agents-access should be denied. memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) _, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "this should fail", }, }, }) requireSDKError(t, err, http.StatusForbidden) }) t.Run("HidesSystemPromptMessages", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "verify hidden system prompt", }, }, }) require.NoError(t, err) messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) for _, message := range messagesResult.Messages { require.NotEqual(t, codersdk.ChatMessageRoleSystem, message.Role) } }) t.Run("WithPerChatSystemPrompt", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello with system prompt", }, }, SystemPrompt: "You are a Go expert.", }) require.NoError(t, err) require.NotEqual(t, uuid.Nil, chat.ID) // Use the DB directly to see system messages, which are // hidden from the public API. dbMessages, err := db.GetChatMessagesForPromptByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) // Expect: deployment system prompt, per-chat system prompt, // workspace awareness, user message. var systemMessages []database.ChatMessage for _, msg := range dbMessages { if msg.Role == database.ChatMessageRoleSystem { systemMessages = append(systemMessages, msg) } } require.GreaterOrEqual(t, len(systemMessages), 2, "expected at least deployment + per-chat system messages") // The per-chat system prompt should be the second system // message and contain the user-specified text. foundPerChat := false for _, msg := range systemMessages { if msg.Content.Valid { raw := string(msg.Content.RawMessage) if strings.Contains(raw, "You are a Go expert.") { foundPerChat = true break } } } require.True(t, foundPerChat, "per-chat system prompt not found in system messages") }) t.Run("PerChatSystemPromptEmpty", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello without system prompt", }, }, SystemPrompt: "", }) require.NoError(t, err) dbMessages, err := db.GetChatMessagesForPromptByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) // No per-chat system prompt should be present. for _, msg := range dbMessages { if msg.Role == database.ChatMessageRoleSystem && msg.Content.Valid { raw := string(msg.Content.RawMessage) require.NotContains(t, raw, "You are a Go expert.", "unexpected per-chat system prompt in messages") } } }) t.Run("PerChatSystemPromptTooLong", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) user := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) longPrompt := strings.Repeat("a", 10001) _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello", }, }, SystemPrompt: longPrompt, }) requireSDKError(t, err, http.StatusBadRequest) }) t.Run("WorkspaceNotAccessible", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID, rbac.RoleAgentsAccess()) memberClient := codersdk.NewExperimentalClient(memberClientRaw) workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ OrganizationID: firstUser.OrganizationID, OwnerID: firstUser.UserID, }).WithAgent().Do() _, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello", }, }, WorkspaceID: &workspaceBuild.Workspace.ID, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal( t, "Workspace not found or you do not have access to this resource", sdkErr.Message, ) }) t.Run("WorkspaceAccessibleButNoSSH", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) orgAdminClientRaw, _ := coderdtest.CreateAnotherUser( t, adminClient.Client, firstUser.OrganizationID, rbac.ScopedRoleOrgAdmin(firstUser.OrganizationID), rbac.RoleAgentsAccess(), ) orgAdminClient := codersdk.NewExperimentalClient(orgAdminClientRaw) workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ OrganizationID: firstUser.OrganizationID, OwnerID: firstUser.UserID, }).WithAgent().Do() _, err := orgAdminClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello", }, }, WorkspaceID: &workspaceBuild.Workspace.ID, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal( t, "Workspace not found or you do not have access to this resource", sdkErr.Message, ) }) t.Run("WorkspaceNotFound", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) workspaceID := uuid.New() _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello", }, }, WorkspaceID: &workspaceID, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal( t, "Workspace not found or you do not have access to this resource", sdkErr.Message, ) }) t.Run("WorkspaceSelectsFirstAgent", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ OrganizationID: user.OrganizationID, OwnerID: user.UserID, }).WithAgent().Do() chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello", }, }, WorkspaceID: &workspaceBuild.Workspace.ID, }) require.NoError(t, err) require.NotNil(t, chat.WorkspaceID) require.Equal(t, workspaceBuild.Workspace.ID, *chat.WorkspaceID) require.Equal(t, modelConfig.ID, chat.LastModelConfigID) }) t.Run("MissingDefaultModelConfig", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello", }, }, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "No default chat model config is configured.", sdkErr.Message) }) t.Run("EmptyContent", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: nil, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Content is required.", sdkErr.Message) require.Equal(t, "Content cannot be empty.", sdkErr.Detail) }) t.Run("EmptyText", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: " ", }, }, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid input part.", sdkErr.Message) require.Equal(t, "content[0].text cannot be empty.", sdkErr.Detail) }) t.Run("UnsupportedPartType", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartType("image"), Text: "hello", }, }, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid input part.", sdkErr.Message) require.Equal(t, `content[0].type "image" is not supported.`, sdkErr.Detail) }) t.Run("UsageLimitExceeded", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100) existingChat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "existing-limit-chat", }) require.NoError(t, err) insertAssistantCostMessage(ctx, t, db, existingChat.ID, modelConfig.ID, 100) _, err = client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "over limit", }}, }) requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt) }) t.Run("NilOrganizationID", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID, rbac.RoleAgentsAccess()) memberClient := codersdk.NewExperimentalClient(memberClientRaw) _, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: uuid.Nil, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "hello", }}, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "organization_id is required.", sdkErr.Message) }) t.Run("NonMemberOrganization", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID, rbac.RoleAgentsAccess()) memberClient := codersdk.NewExperimentalClient(memberClientRaw) // Create a second organization via the database since the // API endpoint is enterprise-only. secondOrg := dbgen.Organization(t, db, database.Organization{}) _, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: secondOrg.ID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "hello", }}, }) sdkErr := requireSDKError(t, err, http.StatusForbidden) require.Equal(t, "You are not a member of the specified organization.", sdkErr.Message) }) t.Run("CrossOrgWorkspaceMismatch", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ OrganizationID: firstUser.OrganizationID, OwnerID: firstUser.UserID, }).WithAgent().Do() // Create a second organization and add the admin as a member // so the request passes the membership check but fails on // the workspace org mismatch. secondOrg := dbgen.Organization(t, db, database.Organization{}) dbgen.OrganizationMember(t, db, database.OrganizationMember{ OrganizationID: secondOrg.ID, UserID: firstUser.UserID, }) _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: secondOrg.ID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "hello", }}, WorkspaceID: &workspaceBuild.Workspace.ID, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Workspace does not belong to the specified organization.", sdkErr.Message) }) } func TestPostChats_ClientType(t *testing.T) { t.Parallel() client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID, rbac.RoleAgentsAccess()) memberClient := codersdk.NewExperimentalClient(memberClientRaw) newChat := func(t *testing.T, clientType codersdk.ChatClientType) codersdk.Chat { t.Helper() ctx := testutil.Context(t, testutil.WaitLong) chat, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "client type test", }}, ClientType: clientType, }) require.NoError(t, err) return chat } t.Run("DefaultIsAPI", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) // Omit ClientType entirely — should default to "api". chat := newChat(t, "") require.Equal(t, codersdk.ChatClientTypeAPI, chat.ClientType) got, err := memberClient.GetChat(ctx, chat.ID) require.NoError(t, err) require.Equal(t, codersdk.ChatClientTypeAPI, got.ClientType) }) t.Run("ExplicitAPI", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) chat := newChat(t, codersdk.ChatClientTypeAPI) require.Equal(t, codersdk.ChatClientTypeAPI, chat.ClientType) got, err := memberClient.GetChat(ctx, chat.ID) require.NoError(t, err) require.Equal(t, codersdk.ChatClientTypeAPI, got.ClientType) }) t.Run("ExplicitUI", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) chat := newChat(t, codersdk.ChatClientTypeUI) require.Equal(t, codersdk.ChatClientTypeUI, chat.ClientType) got, err := memberClient.GetChat(ctx, chat.ID) require.NoError(t, err) require.Equal(t, codersdk.ChatClientTypeUI, got.ClientType) }) t.Run("InvalidClientType", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) _, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "bad client type", }}, ClientType: "bogus", }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Contains(t, sdkErr.Message, "Invalid client_type") }) } func TestListChats(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) firstChatA, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "first owner chat", }, }, }) require.NoError(t, err) firstChatB, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "second owner chat", }, }, }) require.NoError(t, err) memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID, rbac.RoleAgentsAccess()) memberClient := codersdk.NewExperimentalClient(memberClientRaw) memberDBChat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: firstUser.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: member.ID, LastModelConfigID: modelConfig.ID, Title: "member chat only", }) require.NoError(t, err) chats, err := client.ListChats(ctx, nil) require.NoError(t, err) require.Len(t, chats, 2) chatIndexes := make(map[uuid.UUID]int, len(chats)) chatsByID := make(map[uuid.UUID]codersdk.Chat, len(chats)) for i, chat := range chats { chatIndexes[chat.ID] = i chatsByID[chat.ID] = chat require.Equal(t, firstUser.UserID, chat.OwnerID) require.Equal(t, modelConfig.ID, chat.LastModelConfigID) // The chat may have been picked up by the background // processor (via signalWake) before we list, so // accept any active status. require.Contains(t, []codersdk.ChatStatus{ codersdk.ChatStatusPending, codersdk.ChatStatusRunning, codersdk.ChatStatusError, codersdk.ChatStatusWaiting, codersdk.ChatStatusCompleted, }, chat.Status, "unexpected chat status: %s", chat.Status) require.NotZero(t, chat.CreatedAt) require.NotZero(t, chat.UpdatedAt) require.Nil(t, chat.ParentChatID) require.Nil(t, chat.WorkspaceID) require.NotNil(t, chat.RootChatID) require.Equal(t, chat.ID, *chat.RootChatID) require.NotNil(t, chat.DiffStatus) require.Equal(t, chat.ID, chat.DiffStatus.ChatID) } require.Contains(t, chatsByID, firstChatA.ID) require.Contains(t, chatsByID, firstChatB.ID) require.NotContains(t, chatsByID, memberDBChat.ID) require.Equal(t, "first owner chat", chatsByID[firstChatA.ID].Title) require.Equal(t, "second owner chat", chatsByID[firstChatB.ID].Title) for i := 1; i < len(chats); i++ { require.False(t, chats[i-1].UpdatedAt.Before(chats[i].UpdatedAt)) } // The list is already verified as sorted by UpdatedAt // descending (loop above). We intentionally do NOT // compare positions using the creation-time UpdatedAt // values because signalWake() may trigger background // processing that mutates UpdatedAt between CreateChat // and ListChats. memberChats, err := memberClient.ListChats(ctx, nil) require.NoError(t, err) require.Len(t, memberChats, 1) require.Equal(t, memberDBChat.ID, memberChats[0].ID) require.Equal(t, member.ID, memberChats[0].OwnerID) require.Equal(t, "member chat only", memberChats[0].Title) require.NotNil(t, memberChats[0].RootChatID) require.Equal(t, memberChats[0].ID, *memberChats[0].RootChatID) require.NotNil(t, memberChats[0].DiffStatus) require.Equal(t, memberChats[0].ID, memberChats[0].DiffStatus.ChatID) }) t.Run("OrgMemberWithoutAgentsAccessCanAccessOwnChats", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) // Create a member without agents-access and insert a chat // owned by them via system context. With org-scoped chats, // org members get full CRUD on their own chats through // OrgMemberPermissions, without needing agents-access. // The agents-access role only gates chat creation (postChats) // and message sending (postChatMessages). Metadata operations // like archive/pin/label and reading are not gated. // See: https://github.com/coder/coder/issues/24250 memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) _, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: firstUser.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: member.ID, LastModelConfigID: modelConfig.ID, Title: "member chat", }) require.NoError(t, err) chats, err := memberClient.ListChats(ctx, nil) require.NoError(t, err) require.Len(t, chats, 1) // Verify member without agents-access can update own chat. err = memberClient.UpdateChat(ctx, chats[0].ID, codersdk.UpdateChatRequest{ Title: ptr.Ref("new title"), }) require.NoError(t, err) }) t.Run("Unauthenticated", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) unauthenticatedClient := codersdk.NewExperimentalClient(codersdk.New(client.URL)) _, err := unauthenticatedClient.ListChats(ctx, nil) requireSDKError(t, err, http.StatusUnauthorized) }) t.Run("Pagination", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, _ := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Create 5 chats. const totalChats = 5 createdChats := make([]codersdk.Chat, 0, totalChats) for i := 0; i < totalChats; i++ { chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: fmt.Sprintf("chat-%d", i), }, }, }) require.NoError(t, err) createdChats = append(createdChats, chat) } // Wait for all chats to reach a terminal status so // updated_at is stable before paginating. for _, c := range createdChats { require.Eventually(t, func() bool { all, listErr := client.ListChats(ctx, nil) if listErr != nil { return false } for _, ch := range all { if ch.ID == c.ID { return ch.Status != codersdk.ChatStatusPending && ch.Status != codersdk.ChatStatusRunning } } return false }, testutil.WaitShort, testutil.IntervalFast) } // Fetch first page with limit=2. page1, err := client.ListChats(ctx, &codersdk.ListChatsOptions{ Pagination: codersdk.Pagination{Limit: 2}, }) require.NoError(t, err) require.Len(t, page1, 2) // Fetch second page using after_id from last item of page 1. page2, err := client.ListChats(ctx, &codersdk.ListChatsOptions{ Pagination: codersdk.Pagination{ AfterID: uuid.MustParse(page1[len(page1)-1].ID.String()), Limit: 2, }, }) require.NoError(t, err) require.Len(t, page2, 2) // Ensure page1 and page2 have no overlap. page1IDs := make(map[uuid.UUID]struct{}) for _, c := range page1 { page1IDs[c.ID] = struct{}{} } for _, c := range page2 { _, overlap := page1IDs[c.ID] require.False(t, overlap, "page2 should not contain items from page1") } // Fetch third page — should have 1 remaining chat. page3, err := client.ListChats(ctx, &codersdk.ListChatsOptions{ Pagination: codersdk.Pagination{ AfterID: uuid.MustParse(page2[len(page2)-1].ID.String()), Limit: 2, }, }) require.NoError(t, err) require.Len(t, page3, 1) // All 5 chats should be accounted for. allIDs := make(map[uuid.UUID]struct{}) for _, c := range append(append(page1, page2...), page3...) { allIDs[c.ID] = struct{}{} } for _, c := range createdChats { _, found := allIDs[c.ID] require.True(t, found, "chat %s should appear in paginated results", c.ID) } // Fetch with offset=3, limit=2 — should return 2 chats. offsetPage, err := client.ListChats(ctx, &codersdk.ListChatsOptions{ Pagination: codersdk.Pagination{Offset: 3, Limit: 2}, }) require.NoError(t, err) require.Len(t, offsetPage, 2) // No limit should return all chats. allChats, err := client.ListChats(ctx, nil) require.NoError(t, err) require.Len(t, allChats, totalChats) }) // Test that a pinned chat with an old updated_at appears on page 1. t.Run("PinnedOnFirstPage", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, _ := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Create the chat that will later be pinned. It gets the // earliest updated_at because it is inserted first. pinnedChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "pinned-chat", }}, }) require.NoError(t, err) // Fill page 1 with newer chats so the pinned chat would // normally be pushed off the first page (default limit 50). const fillerCount = 51 fillerChats := make([]codersdk.Chat, 0, fillerCount) for i := range fillerCount { c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: fmt.Sprintf("filler-%d", i), }}, }) require.NoError(t, createErr) fillerChats = append(fillerChats, c) } // Wait for all chats to reach a terminal status so // updated_at is stable before paginating. A single // polling loop checks every chat per tick to avoid // O(N) separate Eventually loops. allCreated := append([]codersdk.Chat{pinnedChat}, fillerChats...) pending := make(map[uuid.UUID]struct{}, len(allCreated)) for _, c := range allCreated { pending[c.ID] = struct{}{} } testutil.Eventually(ctx, t, func(_ context.Context) bool { all, listErr := client.ListChats(ctx, &codersdk.ListChatsOptions{ Pagination: codersdk.Pagination{Limit: fillerCount + 10}, }) if listErr != nil { return false } for _, ch := range all { if _, ok := pending[ch.ID]; ok && ch.Status != codersdk.ChatStatusPending && ch.Status != codersdk.ChatStatusRunning { delete(pending, ch.ID) } } return len(pending) == 0 }, testutil.IntervalFast) // Pin the earliest chat. err = client.UpdateChat(ctx, pinnedChat.ID, codersdk.UpdateChatRequest{ PinOrder: ptr.Ref(int32(1)), }) require.NoError(t, err) // Fetch page 1 with default limit (50). page1, err := client.ListChats(ctx, &codersdk.ListChatsOptions{ Pagination: codersdk.Pagination{Limit: 50}, }) require.NoError(t, err) // The pinned chat must appear on page 1. page1IDs := make(map[uuid.UUID]struct{}, len(page1)) for _, c := range page1 { page1IDs[c.ID] = struct{}{} } _, found := page1IDs[pinnedChat.ID] require.True(t, found, "pinned chat should appear on page 1") // The pinned chat should be the first item in the list. require.Equal(t, pinnedChat.ID, page1[0].ID, "pinned chat should be first") }) // Test cursor pagination with a mix of pinned and unpinned chats. t.Run("CursorWithPins", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, _ := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Create 5 chats: 2 will be pinned, 3 unpinned. const totalChats = 5 createdChats := make([]codersdk.Chat, 0, totalChats) for i := range totalChats { c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: fmt.Sprintf("cursor-pin-chat-%d", i), }}, }) require.NoError(t, createErr) createdChats = append(createdChats, c) } // Wait for all chats to reach terminal status. // Check each chat by ID rather than fetching the full list. testutil.Eventually(ctx, t, func(_ context.Context) bool { for _, c := range createdChats { ch, err := client.GetChat(ctx, c.ID) require.NoError(t, err, "GetChat should succeed for just-created chat %s", c.ID) if ch.Status == codersdk.ChatStatusPending || ch.Status == codersdk.ChatStatusRunning { return false } } return true }, testutil.IntervalFast) // Pin the first two chats (oldest updated_at). err := client.UpdateChat(ctx, createdChats[0].ID, codersdk.UpdateChatRequest{ PinOrder: ptr.Ref(int32(1)), }) require.NoError(t, err) err = client.UpdateChat(ctx, createdChats[1].ID, codersdk.UpdateChatRequest{ PinOrder: ptr.Ref(int32(1)), }) require.NoError(t, err) // Paginate with limit=2 using cursor (after_id). const pageSize = 2 maxPages := totalChats/pageSize + 2 var allPaginated []codersdk.Chat var afterID uuid.UUID for range maxPages { opts := &codersdk.ListChatsOptions{ Pagination: codersdk.Pagination{Limit: pageSize}, } if afterID != uuid.Nil { opts.Pagination.AfterID = afterID } page, listErr := client.ListChats(ctx, opts) require.NoError(t, listErr) if len(page) == 0 { break } allPaginated = append(allPaginated, page...) afterID = page[len(page)-1].ID } // All chats should appear exactly once. seenIDs := make(map[uuid.UUID]struct{}, len(allPaginated)) for _, c := range allPaginated { _, dup := seenIDs[c.ID] require.False(t, dup, "chat %s appeared more than once", c.ID) seenIDs[c.ID] = struct{}{} } require.Len(t, seenIDs, totalChats, "all chats should appear in paginated results") // Pinned chats should come before unpinned ones, and // within the pinned group, lower pin_order sorts first. pinnedSeen := false unpinnedSeen := false for _, c := range allPaginated { if c.PinOrder > 0 { require.False(t, unpinnedSeen, "pinned chat %s appeared after unpinned chat", c.ID) pinnedSeen = true } else { unpinnedSeen = true } } require.True(t, pinnedSeen, "at least one pinned chat should exist") // Verify within-pinned ordering: pin_order=1 before // pin_order=2 (the -pin_order DESC column). require.Equal(t, createdChats[0].ID, allPaginated[0].ID, "pin_order=1 chat should be first") require.Equal(t, createdChats[1].ID, allPaginated[1].ID, "pin_order=2 chat should be second") }) } func TestListChatModels(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) models, err := client.ListChatModels(ctx) require.NoError(t, err) var openAIProvider *codersdk.ChatModelProvider for i := range models.Providers { if models.Providers[i].Provider == "openai" { openAIProvider = &models.Providers[i] break } } require.NotNil(t, openAIProvider) require.True(t, openAIProvider.Available) foundModel := false for _, model := range openAIProvider.Models { if model.Provider == "openai" && model.Model == "gpt-4o-mini" { foundModel = true break } } require.True(t, foundModel) }) t.Run("Unauthenticated", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) unauthenticatedClient := codersdk.NewExperimentalClient(codersdk.New(client.URL)) _, err := unauthenticatedClient.ListChatModels(ctx) requireSDKError(t, err, http.StatusUnauthorized) }) t.Run("CentralOnlyProviderAvailable", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) models, err := client.ListChatModels(ctx) require.NoError(t, err) var openAIProvider *codersdk.ChatModelProvider for i := range models.Providers { if models.Providers[i].Provider == "openai" { openAIProvider = &models.Providers[i] break } } require.NotNil(t, openAIProvider) require.True(t, openAIProvider.Available) }) t.Run("UserOnlyProviderRequiresUserKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "anthropic", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) contextLimit := int64(4096) _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "anthropic", Model: "claude-sonnet", ContextLimit: &contextLimit, }) require.NoError(t, err) models, err := client.ListChatModels(ctx) require.NoError(t, err) var anthropicProvider *codersdk.ChatModelProvider for i := range models.Providers { if models.Providers[i].Provider == "anthropic" { anthropicProvider = &models.Providers[i] break } } require.NotNil(t, anthropicProvider) require.False(t, anthropicProvider.Available) require.Equal(t, codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, anthropicProvider.UnavailableReason) _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: "user-api-key", }) require.NoError(t, err) models, err = client.ListChatModels(ctx) require.NoError(t, err) anthropicProvider = nil for i := range models.Providers { if models.Providers[i].Provider == "anthropic" { anthropicProvider = &models.Providers[i] break } } require.NotNil(t, anthropicProvider) require.True(t, anthropicProvider.Available) }) t.Run("CentralAndUserWithFallback", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "google", APIKey: "central-api-key", CentralAPIKeyEnabled: ptr.Ref(true), AllowUserAPIKey: ptr.Ref(true), AllowCentralAPIKeyFallback: ptr.Ref(true), }) require.NoError(t, err) contextLimit := int64(4096) _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "google", Model: "gemini-1.5-pro", ContextLimit: &contextLimit, }) require.NoError(t, err) models, err := client.ListChatModels(ctx) require.NoError(t, err) var googleProvider *codersdk.ChatModelProvider for i := range models.Providers { if models.Providers[i].Provider == "google" { googleProvider = &models.Providers[i] break } } require.NotNil(t, googleProvider) require.True(t, googleProvider.Available) _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: "user-api-key", }) require.NoError(t, err) models, err = client.ListChatModels(ctx) require.NoError(t, err) googleProvider = nil for i := range models.Providers { if models.Providers[i].Provider == "google" { googleProvider = &models.Providers[i] break } } require.NotNil(t, googleProvider) require.True(t, googleProvider.Available) }) t.Run("DisabledProvidersAndModelsAreFilteredOut", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) values := chatDeploymentValues(t) values.AI.BridgeConfig.LegacyOpenAI.Key = serpent.String("deployment-openai-key") client := newChatClientWithDeploymentValues(t, values) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-key", }) require.NoError(t, err) contextLimit := int64(4096) _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", Model: "gpt-4o-mini", ContextLimit: &contextLimit, }) require.NoError(t, err) models, err := client.ListChatModels(ctx) require.NoError(t, err) require.Len(t, models.Providers, 1) require.Equal(t, "openai", models.Providers[0].Provider) require.Len(t, models.Providers[0].Models, 1) require.Equal(t, "gpt-4o-mini", models.Providers[0].Models[0].Model) enabled := false _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ Enabled: &enabled, }) require.NoError(t, err) models, err = client.ListChatModels(ctx) require.NoError(t, err) require.Empty(t, models.Providers) }) } func TestWatchChats(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) conn, err := client.Dial(ctx, "/api/experimental/chats/watch", nil) require.NoError(t, err) defer conn.Close(websocket.StatusNormalClosure, "done") createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "watch route created event", }, }, }) require.NoError(t, err) for { var payload codersdk.ChatWatchEvent err = wsjson.Read(ctx, conn, &payload) require.NoError(t, err) if payload.Kind == codersdk.ChatWatchEventKindCreated && payload.Chat.ID == createdChat.ID { break } } }) t.Run("CreatedEventIncludesAllChatFields", func(t *testing.T) { t.Parallel() // This test verifies that the pubsub "created" event // carries a fully-populated codersdk.Chat. Exhaustive // field-level coverage of the converter is handled by // TestChat_AllFieldsPopulated (db2sdk) and // TestChat_JSONRoundTrip (codersdk). This integration // test only checks that key fields survive the full // API → pubsub → websocket pipeline. ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) conn, err := client.Dial(ctx, "/api/experimental/chats/watch", nil) require.NoError(t, err) defer conn.Close(websocket.StatusNormalClosure, "done") createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "watch route fields completeness test", }, }, }) require.NoError(t, err) var got codersdk.Chat testutil.Eventually(ctx, t, func(_ context.Context) bool { var payload codersdk.ChatWatchEvent if readErr := wsjson.Read(ctx, conn, &payload); readErr != nil { return false } if payload.Kind == codersdk.ChatWatchEventKindCreated && payload.Chat.ID == createdChat.ID { got = payload.Chat return true } return false }, testutil.IntervalFast, "expected a created event for chat %s", createdChat.ID) require.Equal(t, createdChat.ID, got.ID) require.Equal(t, createdChat.OwnerID, got.OwnerID) require.Equal(t, modelConfig.ID, got.LastModelConfigID) require.Equal(t, createdChat.Title, got.Title) require.Equal(t, codersdk.ChatStatusPending, got.Status) require.NotNil(t, got.RootChatID) require.Equal(t, createdChat.ID, *got.RootChatID) require.NotZero(t, got.CreatedAt) require.NotZero(t, got.UpdatedAt) }) t.Run("DiffStatusChangeIncludesDiffStatus", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) rawClient, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ DeploymentValues: chatDeploymentValues(t), }) client := codersdk.NewExperimentalClient(rawClient) db := api.Database user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) // Insert a chat and a diff status row. chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "diff status watch test", }) require.NoError(t, err) refreshedAt := time.Now().UTC().Truncate(time.Second) staleAt := refreshedAt.Add(time.Hour) _, err = db.UpsertChatDiffStatusReference( dbauthz.AsSystemRestricted(ctx), database.UpsertChatDiffStatusReferenceParams{ ChatID: chat.ID, Url: sql.NullString{String: "https://github.com/coder/coder/pull/99", Valid: true}, GitBranch: "feature/test", GitRemoteOrigin: "git@github.com:coder/coder.git", StaleAt: staleAt, }, ) require.NoError(t, err) _, err = db.UpsertChatDiffStatus( dbauthz.AsSystemRestricted(ctx), database.UpsertChatDiffStatusParams{ ChatID: chat.ID, Url: sql.NullString{String: "https://github.com/coder/coder/pull/99", Valid: true}, PullRequestState: sql.NullString{String: "open", Valid: true}, Additions: 42, Deletions: 7, ChangedFiles: 5, RefreshedAt: refreshedAt, StaleAt: staleAt, }, ) require.NoError(t, err) // Open the watch WebSocket. conn, err := client.Dial(ctx, "/api/experimental/chats/watch", nil) require.NoError(t, err) defer conn.Close(websocket.StatusNormalClosure, "done") // Publish a diff_status_change event via pubsub, // mimicking what PublishDiffStatusChange does after // it reads the diff status from the DB. dbStatus, err := db.GetChatDiffStatusByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) sdkDiffStatus := db2sdk.ChatDiffStatus(chat.ID, &dbStatus) event := codersdk.ChatWatchEvent{ Kind: codersdk.ChatWatchEventKindDiffStatusChange, Chat: codersdk.Chat{ ID: chat.ID, OwnerID: chat.OwnerID, Title: chat.Title, Status: codersdk.ChatStatus(chat.Status), CreatedAt: chat.CreatedAt, UpdatedAt: chat.UpdatedAt, DiffStatus: &sdkDiffStatus, }, } payload, err := json.Marshal(event) require.NoError(t, err) // Publish the event in a goroutine that keeps retrying. // When the WebSocket Dial returns, the server has completed // the HTTP upgrade but may not have called SubscribeWithErr // yet. If we publish only once, the message can arrive // before the subscription is active and be silently dropped, // causing the read loop to block until the context deadline. // Re-publishing on a short ticker guarantees that at least // one publish lands after the subscription is ready. publishDone := make(chan struct{}) go func() { ticker := time.NewTicker(testutil.IntervalFast) defer ticker.Stop() for { // Publish immediately on the first iteration, // then again on each tick. _ = api.Pubsub.Publish(coderdpubsub.ChatWatchEventChannel(user.UserID), payload) select { case <-publishDone: return case <-ctx.Done(): return case <-ticker.C: } } }() var received codersdk.ChatWatchEvent for { err = wsjson.Read(ctx, conn, &received) require.NoError(t, err) if received.Kind == codersdk.ChatWatchEventKindDiffStatusChange && received.Chat.ID == chat.ID { break } } close(publishDone) // Verify the event carries the full DiffStatus. require.NotNil(t, received.Chat.DiffStatus, "diff_status_change event must include DiffStatus") ds := received.Chat.DiffStatus require.Equal(t, chat.ID, ds.ChatID) require.NotNil(t, ds.URL) require.Equal(t, "https://github.com/coder/coder/pull/99", *ds.URL) require.NotNil(t, ds.PullRequestState) require.Equal(t, "open", *ds.PullRequestState) require.EqualValues(t, 42, ds.Additions) require.EqualValues(t, 7, ds.Deletions) require.EqualValues(t, 5, ds.ChangedFiles) }) t.Run("ArchiveAndUnarchiveEmitEventsForDescendants", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) parentChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "watch root chat", }, }, }) require.NoError(t, err) childOne, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "watch child 1", ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, }) require.NoError(t, err) childTwo, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "watch child 2", ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, }) require.NoError(t, err) conn, err := client.Dial(ctx, "/api/experimental/chats/watch", nil) require.NoError(t, err) defer conn.Close(websocket.StatusNormalClosure, "done") collectLifecycleEvents := func(expectedKind codersdk.ChatWatchEventKind) map[uuid.UUID]codersdk.ChatWatchEvent { t.Helper() events := make(map[uuid.UUID]codersdk.ChatWatchEvent, 3) for len(events) < 3 { var payload codersdk.ChatWatchEvent err = wsjson.Read(ctx, conn, &payload) require.NoError(t, err) if payload.Kind != expectedKind { continue } events[payload.Chat.ID] = payload } return events } assertLifecycleEvents := func(events map[uuid.UUID]codersdk.ChatWatchEvent, archived bool) { t.Helper() require.Len(t, events, 3) for _, chatID := range []uuid.UUID{parentChat.ID, childOne.ID, childTwo.ID} { payload, ok := events[chatID] require.True(t, ok, "missing event for chat %s", chatID) require.Equal(t, archived, payload.Chat.Archived) } } err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)}) require.NoError(t, err) deletedEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindDeleted) assertLifecycleEvents(deletedEvents, true) err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)}) require.NoError(t, err) createdEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindCreated) assertLifecycleEvents(createdEvents, false) }) t.Run("Unauthenticated", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) unauthenticatedClient := codersdk.New(client.URL) res, err := unauthenticatedClient.Request( ctx, http.MethodGet, "/api/experimental/chats/watch", nil, ) require.NoError(t, err) defer res.Body.Close() require.Equal(t, http.StatusUnauthorized, res.StatusCode) }) } func TestListChatProviders(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) providers, err := client.ListChatProviders(ctx) require.NoError(t, err) var openAIProvider *codersdk.ChatProviderConfig for i := range providers { if providers[i].Provider == "openai" { openAIProvider = &providers[i] break } } require.NotNil(t, openAIProvider) require.Equal(t, codersdk.ChatProviderConfigSourceDatabase, openAIProvider.Source) require.True(t, openAIProvider.Enabled) require.True(t, openAIProvider.HasAPIKey) }) t.Run("IgnoresDeploymentKeyWhenCentralKeyDisabled", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) values := chatDeploymentValues(t) values.AI.BridgeConfig.LegacyOpenAI.Key = serpent.String("deployment-openai-key") client := newChatClientWithDeploymentValues(t, values) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) require.False(t, provider.HasAPIKey) providers, err := client.ListChatProviders(ctx) require.NoError(t, err) for _, listed := range providers { if listed.Provider == "openai" { require.False(t, listed.HasAPIKey) return } } t.Fatal("openai provider not found") }) t.Run("ForbiddenForOrganizationMember", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) _, err := memberClient.ListChatProviders(ctx) requireSDKError(t, err, http.StatusForbidden) }) } func TestCreateChatProvider(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", DisplayName: "OpenAI Primary", APIKey: "test-api-key", }) require.NoError(t, err) require.NotEqual(t, uuid.Nil, provider.ID) require.Equal(t, "openai", provider.Provider) require.Equal(t, "OpenAI Primary", provider.DisplayName) require.True(t, provider.Enabled) require.True(t, provider.HasAPIKey) require.Equal(t, codersdk.ChatProviderConfigSourceDatabase, provider.Source) }) t.Run("InvalidProvider", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "not-a-provider", APIKey: "test-api-key", }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid provider.", sdkErr.Message) }) t.Run("Conflict", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) _, err = client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "other-api-key", }) sdkErr := requireSDKError(t, err, http.StatusConflict) require.Equal(t, "Chat provider already exists.", sdkErr.Message) }) t.Run("ForbiddenForOrganizationMember", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) _, err := memberClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "member-key", }) requireSDKError(t, err, http.StatusForbidden) }) t.Run("DefaultsPolicyFields", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) require.True(t, provider.CentralAPIKeyEnabled) require.False(t, provider.AllowUserAPIKey) require.False(t, provider.AllowCentralAPIKeyFallback) }) t.Run("UserOnlyDoesNotRequireCentralKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) require.False(t, provider.CentralAPIKeyEnabled) require.True(t, provider.AllowUserAPIKey) require.False(t, provider.AllowCentralAPIKeyFallback) require.False(t, provider.HasAPIKey) }) t.Run("RejectsDeploymentBackedCentralKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) values := chatDeploymentValues(t) values.AI.BridgeConfig.LegacyOpenAI.Key = serpent.String("deployment-openai-key") client := newChatClientWithDeploymentValues(t, values) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "API key is required when central API key is enabled.", sdkErr.Message) }) t.Run("RejectsInvalidPolicyTuple", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) testCases := []struct { name string central bool user bool fallback bool }{ { name: "NoneEnabled", central: false, user: false, fallback: false, }, { name: "FallbackWithoutCentral", central: false, user: true, fallback: true, }, { name: "FallbackWithoutUser", central: true, user: false, fallback: true, }, } for _, testCase := range testCases { _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", CentralAPIKeyEnabled: ptr.Ref(testCase.central), AllowUserAPIKey: ptr.Ref(testCase.user), AllowCentralAPIKeyFallback: ptr.Ref(testCase.fallback), }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equalf(t, "Invalid credential policy.", sdkErr.Message, "case %s", testCase.name) } }) t.Run("RejectsTooLargeAPIKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: strings.Repeat("a", chatProviderAPIKeySizeLimit+1), }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "API key too large.", sdkErr.Message) require.Equal(t, fmt.Sprintf("API key exceeds maximum size of %d bytes", chatProviderAPIKeySizeLimit), sdkErr.Detail) }) t.Run("AllowsMaxSizedAPIKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: strings.Repeat("a", chatProviderAPIKeySizeLimit), }) require.NoError(t, err) require.True(t, provider.HasAPIKey) }) } func TestUpdateChatProvider(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) enabled := false baseURL := "https://example.com/v1" updated, err := client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ DisplayName: "OpenAI Updated", Enabled: &enabled, BaseURL: &baseURL, }) require.NoError(t, err) require.Equal(t, provider.ID, updated.ID) require.Equal(t, "OpenAI Updated", updated.DisplayName) require.False(t, updated.Enabled) require.Equal(t, baseURL, updated.BaseURL) }) t.Run("NotFound", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.UpdateChatProvider(ctx, uuid.New(), codersdk.UpdateChatProviderConfigRequest{ DisplayName: "missing", }) requireSDKError(t, err, http.StatusNotFound) }) t.Run("InvalidProviderID", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) res, err := client.Request( ctx, http.MethodPatch, "/api/experimental/chats/providers/not-a-uuid", codersdk.UpdateChatProviderConfigRequest{DisplayName: "ignored"}, ) require.NoError(t, err) defer res.Body.Close() err = codersdk.ReadBodyAsError(res) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid chat provider ID.", sdkErr.Message) }) t.Run("ForbiddenForOrganizationMember", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) provider, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) _, err = memberClient.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ DisplayName: "member update", }) requireSDKError(t, err, http.StatusForbidden) }) t.Run("AppliesPolicyOverrides", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) updated, err := client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) require.True(t, updated.AllowUserAPIKey) require.False(t, updated.CentralAPIKeyEnabled) require.False(t, updated.HasAPIKey) }) t.Run("RejectsDeploymentBackedCentralKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) values := chatDeploymentValues(t) values.AI.BridgeConfig.LegacyOpenAI.Key = serpent.String("deployment-openai-key") client := newChatClientWithDeploymentValues(t, values) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ CentralAPIKeyEnabled: ptr.Ref(true), }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "API key is required when central API key is enabled.", sdkErr.Message) }) t.Run("RejectsClearingLastCentralKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ APIKey: ptr.Ref(""), }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "API key is required when central API key is enabled.", sdkErr.Message) }) t.Run("RejectsEnablingCentralKeyWithoutKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ CentralAPIKeyEnabled: ptr.Ref(true), }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "API key is required when central API key is enabled.", sdkErr.Message) }) t.Run("RejectsInvalidPolicyTuple", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) testCases := []struct { name string central bool user bool fallback bool }{ { name: "NoneEnabled", central: false, user: false, fallback: false, }, { name: "FallbackWithoutCentral", central: false, user: true, fallback: true, }, { name: "FallbackWithoutUser", central: true, user: false, fallback: true, }, } for _, testCase := range testCases { _, err := client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ CentralAPIKeyEnabled: ptr.Ref(testCase.central), AllowUserAPIKey: ptr.Ref(testCase.user), AllowCentralAPIKeyFallback: ptr.Ref(testCase.fallback), }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equalf(t, "Invalid credential policy.", sdkErr.Message, "case %s", testCase.name) } }) t.Run("RejectsTooLargeAPIKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ APIKey: ptr.Ref(strings.Repeat("a", chatProviderAPIKeySizeLimit+1)), }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "API key too large.", sdkErr.Message) require.Equal(t, fmt.Sprintf("API key exceeds maximum size of %d bytes", chatProviderAPIKeySizeLimit), sdkErr.Detail) }) t.Run("AllowsMaxSizedAPIKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) updated, err := client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ APIKey: ptr.Ref(strings.Repeat("a", chatProviderAPIKeySizeLimit)), }) require.NoError(t, err) require.True(t, updated.HasAPIKey) }) } func TestDeleteChatProvider(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) err = client.DeleteChatProvider(ctx, provider.ID) require.NoError(t, err) providers, err := client.ListChatProviders(ctx) require.NoError(t, err) for _, listed := range providers { require.NotEqual(t, provider.ID, listed.ID) } }) t.Run("NotFound", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) err := client.DeleteChatProvider(ctx, uuid.New()) requireSDKError(t, err, http.StatusNotFound) }) t.Run("InvalidProviderID", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) res, err := client.Request( ctx, http.MethodDelete, "/api/experimental/chats/providers/not-a-uuid", nil, ) require.NoError(t, err) defer res.Body.Close() err = codersdk.ReadBodyAsError(res) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid chat provider ID.", sdkErr.Message) }) t.Run("ForbiddenForOrganizationMember", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) provider, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) err = memberClient.DeleteChatProvider(ctx, provider.ID) requireSDKError(t, err, http.StatusForbidden) }) } func TestChatProviderAPIKeysFromDeploymentValues(t *testing.T) { t.Parallel() t.Run("DoesNotReuseBridgeConfig", func(t *testing.T) { t.Parallel() values := chatDeploymentValues(t) values.AI.BridgeConfig.LegacyOpenAI.Key = serpent.String("deployment-openai-key") values.AI.BridgeConfig.LegacyAnthropic.Key = serpent.String("deployment-anthropic-key") values.AI.BridgeConfig.LegacyOpenAI.BaseURL = serpent.String("https://custom-openai.example.com") keys := coderd.ChatProviderAPIKeysFromDeploymentValues(values) require.Equal(t, chatprovider.ProviderAPIKeys{}, keys) }) t.Run("NilDeploymentValues", func(t *testing.T) { t.Parallel() keys := coderd.ChatProviderAPIKeysFromDeploymentValues(nil) require.Equal(t, chatprovider.ProviderAPIKeys{}, keys) }) } func TestUserChatProviderConfigs(t *testing.T) { t.Parallel() requireUserProviderConfig := func(t *testing.T, configs []codersdk.UserChatProviderConfig, provider string) codersdk.UserChatProviderConfig { t.Helper() for _, config := range configs { if config.Provider == provider { return config } } t.Fatalf("provider %q not found", provider) return codersdk.UserChatProviderConfig{} } requireNoUserProviderConfig := func(t *testing.T, configs []codersdk.UserChatProviderConfig, provider string) { t.Helper() for _, config := range configs { if config.Provider == provider { t.Fatalf("provider %q unexpectedly found", provider) } } } t.Run("ListOnlyUserKeyProviders", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) anthropicProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "anthropic", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) _, err = client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "google", APIKey: "central-api-key", }) require.NoError(t, err) configs, err := client.ListUserChatProviderConfigs(ctx) require.NoError(t, err) require.Len(t, configs, 1) require.Equal(t, anthropicProvider.ID, configs[0].ProviderID) require.Equal(t, anthropicProvider.Provider, configs[0].Provider) }) t.Run("ListReportsHasUserAPIKeyFalse", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "anthropic", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) configs, err := client.ListUserChatProviderConfigs(ctx) require.NoError(t, err) require.Len(t, configs, 1) require.Equal(t, provider.ID, configs[0].ProviderID) require.False(t, configs[0].HasUserAPIKey) }) t.Run("ListHidesDisabledProviderEvenWithSavedKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "anthropic", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: "user-key", }) require.NoError(t, err) _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ Enabled: ptr.Ref(false), }) require.NoError(t, err) configs, err := client.ListUserChatProviderConfigs(ctx) require.NoError(t, err) require.Empty(t, configs) requireNoUserProviderConfig(t, configs, "anthropic") }) t.Run("ListHidesUserKeyDisabledProviderAndRestoresOnReEnable", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "anthropic", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: "user-key", }) require.NoError(t, err) centralAPIKey := "central-key" _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ APIKey: ¢ralAPIKey, CentralAPIKeyEnabled: ptr.Ref(true), AllowUserAPIKey: ptr.Ref(false), }) require.NoError(t, err) configs, err := client.ListUserChatProviderConfigs(ctx) require.NoError(t, err) require.Empty(t, configs) requireNoUserProviderConfig(t, configs, "anthropic") _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) configs, err = client.ListUserChatProviderConfigs(ctx) require.NoError(t, err) listed := requireUserProviderConfig(t, configs, "anthropic") require.Equal(t, provider.ID, listed.ProviderID) require.True(t, listed.HasUserAPIKey) require.False(t, listed.HasCentralAPIKeyFallback) }) t.Run("UpsertCreatesKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "anthropic", APIKey: "central-key", CentralAPIKeyEnabled: ptr.Ref(true), AllowUserAPIKey: ptr.Ref(true), AllowCentralAPIKeyFallback: ptr.Ref(true), }) require.NoError(t, err) config, err := client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: "user-key", }) require.NoError(t, err) require.Equal(t, provider.ID, config.ProviderID) require.Equal(t, provider.Provider, config.Provider) require.Equal(t, provider.DisplayName, config.DisplayName) require.True(t, config.HasUserAPIKey) require.True(t, config.HasCentralAPIKeyFallback) configs, err := client.ListUserChatProviderConfigs(ctx) require.NoError(t, err) listed := requireUserProviderConfig(t, configs, "anthropic") require.Equal(t, provider.ID, listed.ProviderID) require.Equal(t, provider.DisplayName, listed.DisplayName) require.True(t, listed.HasUserAPIKey) require.True(t, listed.HasCentralAPIKeyFallback) }) t.Run("ListRecomputesFallbackAvailability", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) values := chatDeploymentValues(t) values.AI.BridgeConfig.LegacyOpenAI.Key = serpent.String("deployment-openai-key") client := newChatClientWithDeploymentValues(t, values) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-central-key", AllowUserAPIKey: ptr.Ref(true), AllowCentralAPIKeyFallback: ptr.Ref(true), }) require.NoError(t, err) _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: "user-key", }) require.NoError(t, err) configs, err := client.ListUserChatProviderConfigs(ctx) require.NoError(t, err) listed := requireUserProviderConfig(t, configs, "openai") require.True(t, listed.HasCentralAPIKeyFallback) _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ CentralAPIKeyEnabled: ptr.Ref(false), AllowCentralAPIKeyFallback: ptr.Ref(false), }) require.NoError(t, err) configs, err = client.ListUserChatProviderConfigs(ctx) require.NoError(t, err) listed = requireUserProviderConfig(t, configs, "openai") require.False(t, listed.HasCentralAPIKeyFallback) }) t.Run("UpsertUpdatesKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "anthropic", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: "key-1", }) require.NoError(t, err) _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: "key-2", }) require.NoError(t, err) configs, err := client.ListUserChatProviderConfigs(ctx) require.NoError(t, err) listed := requireUserProviderConfig(t, configs, "anthropic") require.True(t, listed.HasUserAPIKey) }) t.Run("UpsertRejectsMissingProvider", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.UpsertUserChatProviderKey(ctx, uuid.New(), codersdk.CreateUserChatProviderKeyRequest{ APIKey: "user-key", }) requireSDKError(t, err, http.StatusNotFound) }) t.Run("UpsertRejectsDisabledProvider", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "anthropic", Enabled: ptr.Ref(false), CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: "user-key", }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Provider is disabled.", sdkErr.Message) }) t.Run("UpsertRejectsProviderWithoutUserKeys", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "google", APIKey: "central-api-key", }) require.NoError(t, err) _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: "user-key", }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Provider does not allow user API keys.", sdkErr.Message) }) t.Run("UpsertRejectsEmptyAPIKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "anthropic", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: "", }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "API key is required.", sdkErr.Message) }) t.Run("DeleteRemovesKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "anthropic", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: "user-key", }) require.NoError(t, err) configs, err := client.ListUserChatProviderConfigs(ctx) require.NoError(t, err) listed := requireUserProviderConfig(t, configs, "anthropic") require.True(t, listed.HasUserAPIKey) err = client.DeleteUserChatProviderKey(ctx, provider.ID) require.NoError(t, err) configs, err = client.ListUserChatProviderConfigs(ctx) require.NoError(t, err) listed = requireUserProviderConfig(t, configs, "anthropic") require.False(t, listed.HasUserAPIKey) err = client.DeleteUserChatProviderKey(ctx, provider.ID) require.NoError(t, err) }) t.Run("OtherUserDoesNotSeeKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) provider, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "anthropic", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) _, err = adminClient.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: "admin-user-key", }) require.NoError(t, err) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) configs, err := memberClient.ListUserChatProviderConfigs(ctx) require.NoError(t, err) listed := requireUserProviderConfig(t, configs, "anthropic") require.Equal(t, provider.ID, listed.ProviderID) require.False(t, listed.HasUserAPIKey) }) } func TestUpsertUserChatProviderKey(t *testing.T) { t.Parallel() t.Run("RejectsTooLargeAPIKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "anthropic", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: strings.Repeat("a", chatProviderAPIKeySizeLimit+1), }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "API key too large.", sdkErr.Message) require.Equal(t, fmt.Sprintf("API key exceeds maximum size of %d bytes", chatProviderAPIKeySizeLimit), sdkErr.Detail) }) t.Run("AllowsMaxSizedAPIKey", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "anthropic", CentralAPIKeyEnabled: ptr.Ref(false), AllowUserAPIKey: ptr.Ref(true), }) require.NoError(t, err) config, err := client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ APIKey: strings.Repeat("a", chatProviderAPIKeySizeLimit), }) require.NoError(t, err) require.True(t, config.HasUserAPIKey) }) } func TestListChatModelConfigs(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) configs, err := client.ListChatModelConfigs(ctx) require.NoError(t, err) require.NotEmpty(t, configs) found := false for _, config := range configs { if config.ID == modelConfig.ID { found = true require.Equal(t, "openai", config.Provider) require.Equal(t, "gpt-4o-mini", config.Model) require.True(t, config.IsDefault) } } require.True(t, found) }) t.Run("AdminIncludesDisabledModelConfigs", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) contextLimit := int64(4096) enabled := false disabledConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", Model: "gpt-4o-disabled", DisplayName: "GPT-4o Disabled", Enabled: &enabled, ContextLimit: &contextLimit, }) require.NoError(t, err) require.False(t, disabledConfig.Enabled) configs, err := client.ListChatModelConfigs(ctx) require.NoError(t, err) found := false for _, config := range configs { if config.ID == disabledConfig.ID { found = true require.False(t, config.Enabled) require.Equal(t, disabledConfig.DisplayName, config.DisplayName) } } require.True(t, found) }) t.Run("NonAdminExcludesDisabledModelConfigs", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) enabledConfig := createChatModelConfig(t, adminClient) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) contextLimit := int64(4096) enabled := false _, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", Model: "gpt-4o-disabled", DisplayName: "GPT-4o Disabled", Enabled: &enabled, ContextLimit: &contextLimit, }) require.NoError(t, err) configs, err := memberClient.ListChatModelConfigs(ctx) require.NoError(t, err) require.Len(t, configs, 1) require.Equal(t, enabledConfig.ID, configs[0].ID) require.True(t, configs[0].Enabled) }) t.Run("DeserializesLegacyPricingJSON", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) legacyOptions := json.RawMessage(`{"input_price_per_million_tokens":0.15,"output_price_per_million_tokens":0.6,"cache_read_price_per_million_tokens":0.03,"cache_write_price_per_million_tokens":0.3}`) storedConfig, err := db.InsertChatModelConfig(dbauthz.AsSystemRestricted(ctx), database.InsertChatModelConfigParams{ Provider: "openai", Model: "gpt-4o-mini-legacy", DisplayName: "GPT-4o Mini Legacy", CreatedBy: uuid.NullUUID{UUID: firstUser.UserID, Valid: true}, UpdatedBy: uuid.NullUUID{UUID: firstUser.UserID, Valid: true}, Enabled: true, IsDefault: false, ContextLimit: 4096, CompressionThreshold: 80, Options: legacyOptions, }) require.NoError(t, err) configs, err := client.ListChatModelConfigs(ctx) require.NoError(t, err) require.Len(t, configs, 1) require.Equal(t, storedConfig.ID, configs[0].ID) requireChatModelPricing(t, configs[0].ModelConfig, &codersdk.ChatModelCallConfig{ Cost: &codersdk.ModelCostConfig{ InputPricePerMillionTokens: decRef("0.15"), OutputPricePerMillionTokens: decRef("0.6"), CacheReadPricePerMillionTokens: decRef("0.03"), CacheWritePricePerMillionTokens: decRef("0.3"), }, }) }) t.Run("SuccessForOrganizationMember", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) modelConfig := createChatModelConfig(t, adminClient) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) // Non-admin users should see only enabled model configs. configs, err := memberClient.ListChatModelConfigs(ctx) require.NoError(t, err) require.NotEmpty(t, configs) found := false for _, config := range configs { if config.ID == modelConfig.ID { found = true require.Equal(t, "openai", config.Provider) require.Equal(t, "gpt-4o-mini", config.Model) } } require.True(t, found) }) } func TestCreateChatModelConfig(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) contextLimit := int64(4096) isDefault := true pricing := &codersdk.ChatModelCallConfig{ Cost: &codersdk.ModelCostConfig{ InputPricePerMillionTokens: decRef("0.15"), OutputPricePerMillionTokens: decRef("0.6"), CacheReadPricePerMillionTokens: decRef("0.03"), CacheWritePricePerMillionTokens: decRef("0.3"), }, } modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", Model: "gpt-4o-mini", ContextLimit: &contextLimit, IsDefault: &isDefault, ModelConfig: pricing, }) require.NoError(t, err) require.NotEqual(t, uuid.Nil, modelConfig.ID) require.Equal(t, "openai", modelConfig.Provider) require.Equal(t, "gpt-4o-mini", modelConfig.Model) require.EqualValues(t, 4096, modelConfig.ContextLimit) require.True(t, modelConfig.IsDefault) requireChatModelPricing(t, modelConfig.ModelConfig, pricing) configs, err := client.ListChatModelConfigs(ctx) require.NoError(t, err) require.Len(t, configs, 1) requireChatModelPricing(t, configs[0].ModelConfig, pricing) }) t.Run("RejectsNegativePricing", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) contextLimit := int64(4096) _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", Model: "gpt-4o-mini", ContextLimit: &contextLimit, ModelConfig: &codersdk.ChatModelCallConfig{ Cost: &codersdk.ModelCostConfig{ InputPricePerMillionTokens: decRef("-0.01"), }, }, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid model config.", sdkErr.Message) require.Equal( t, "cost.input_price_per_million_tokens must be greater than or equal to zero", sdkErr.Detail, ) }) t.Run("MissingContextLimit", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", Model: "gpt-4o-mini", }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Context limit is required.", sdkErr.Message) }) t.Run("ProviderNotConfigured", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) contextLimit := int64(4096) _, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", Model: "gpt-4o-mini", ContextLimit: &contextLimit, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Chat provider is not configured.", sdkErr.Message) }) t.Run("ForbiddenForOrganizationMember", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) _, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) contextLimit := int64(4096) _, err = memberClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", Model: "gpt-4o-mini", ContextLimit: &contextLimit, }) requireSDKError(t, err, http.StatusForbidden) }) } func TestUpdateChatModelConfig(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) contextLimit := int64(8192) pricing := &codersdk.ChatModelCallConfig{ Cost: &codersdk.ModelCostConfig{ InputPricePerMillionTokens: decRef("0.2"), OutputPricePerMillionTokens: decRef("0.8"), CacheReadPricePerMillionTokens: decRef("0.04"), CacheWritePricePerMillionTokens: decRef("0.4"), }, } updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ DisplayName: "GPT-4o Mini Updated", ContextLimit: &contextLimit, ModelConfig: pricing, }) require.NoError(t, err) require.Equal(t, modelConfig.ID, updated.ID) require.Equal(t, "GPT-4o Mini Updated", updated.DisplayName) require.EqualValues(t, 8192, updated.ContextLimit) requireChatModelPricing(t, updated.ModelConfig, pricing) configs, err := client.ListChatModelConfigs(ctx) require.NoError(t, err) require.Len(t, configs, 1) requireChatModelPricing(t, configs[0].ModelConfig, pricing) }) t.Run("DisablePreservesRecordAndHidesItFromNonAdmins", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) modelConfig := createChatModelConfig(t, adminClient) enabled := false updated, err := adminClient.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ Enabled: &enabled, }) require.NoError(t, err) require.Equal(t, modelConfig.ID, updated.ID) require.False(t, updated.Enabled) adminConfigs, err := adminClient.ListChatModelConfigs(ctx) require.NoError(t, err) foundForAdmin := false for _, config := range adminConfigs { if config.ID == modelConfig.ID { foundForAdmin = true require.False(t, config.Enabled) } } require.True(t, foundForAdmin) memberConfigs, err := memberClient.ListChatModelConfigs(ctx) require.NoError(t, err) for _, config := range memberConfigs { require.NotEqual(t, modelConfig.ID, config.ID) } }) t.Run("ReEnableRestoresVisibilityForNonAdmins", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) _, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) contextLimit := int64(4096) enabled := false modelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", Model: "gpt-4o-reenable", DisplayName: "GPT-4o Re-enable", Enabled: &enabled, ContextLimit: &contextLimit, }) require.NoError(t, err) require.False(t, modelConfig.Enabled) memberConfigs, err := memberClient.ListChatModelConfigs(ctx) require.NoError(t, err) foundForMember := false for _, config := range memberConfigs { if config.ID == modelConfig.ID { foundForMember = true } } require.False(t, foundForMember) enabled = true updated, err := adminClient.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ Enabled: &enabled, }) require.NoError(t, err) require.Equal(t, modelConfig.ID, updated.ID) require.True(t, updated.Enabled) memberConfigs, err = memberClient.ListChatModelConfigs(ctx) require.NoError(t, err) foundForMember = false for _, config := range memberConfigs { if config.ID == modelConfig.ID { foundForMember = true require.True(t, config.Enabled) } } require.True(t, foundForMember) }) t.Run("RejectsNegativePricing", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) _, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ ModelConfig: &codersdk.ChatModelCallConfig{ Cost: &codersdk.ModelCostConfig{ OutputPricePerMillionTokens: decRef("-1.0"), }, }, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid model config.", sdkErr.Message) require.Equal( t, "cost.output_price_per_million_tokens must be greater than or equal to zero", sdkErr.Detail, ) }) t.Run("NotFound", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.UpdateChatModelConfig(ctx, uuid.New(), codersdk.UpdateChatModelConfigRequest{ DisplayName: "missing", }) requireSDKError(t, err, http.StatusNotFound) }) t.Run("InvalidContextLimit", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) contextLimit := int64(0) _, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ ContextLimit: &contextLimit, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Context limit must be greater than zero.", sdkErr.Message) }) t.Run("InvalidModelConfigID", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) res, err := client.Request( ctx, http.MethodPatch, "/api/experimental/chats/model-configs/not-a-uuid", codersdk.UpdateChatModelConfigRequest{DisplayName: "ignored"}, ) require.NoError(t, err) defer res.Body.Close() err = codersdk.ReadBodyAsError(res) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid chat model config ID.", sdkErr.Message) }) t.Run("ForbiddenForOrganizationMember", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) modelConfig := createChatModelConfig(t, adminClient) _, err := memberClient.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ DisplayName: "member update", }) requireSDKError(t, err, http.StatusForbidden) }) } func TestDeleteChatModelConfig(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) err := client.DeleteChatModelConfig(ctx, modelConfig.ID) require.NoError(t, err) configs, err := client.ListChatModelConfigs(ctx) require.NoError(t, err) for _, config := range configs { require.NotEqual(t, modelConfig.ID, config.ID) } }) t.Run("NotFound", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) err := client.DeleteChatModelConfig(ctx, uuid.New()) requireSDKError(t, err, http.StatusNotFound) }) t.Run("InvalidModelConfigID", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) res, err := client.Request( ctx, http.MethodDelete, "/api/experimental/chats/model-configs/not-a-uuid", nil, ) require.NoError(t, err) defer res.Body.Close() err = codersdk.ReadBodyAsError(res) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid chat model config ID.", sdkErr.Message) }) t.Run("ForbiddenForOrganizationMember", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) modelConfig := createChatModelConfig(t, adminClient) err := memberClient.DeleteChatModelConfig(ctx, modelConfig.ID) requireSDKError(t, err, http.StatusForbidden) }) } func TestGetChat(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "get chat route payload", }, }, }) require.NoError(t, err) chatResult, err := client.GetChat(ctx, createdChat.ID) require.NoError(t, err) messagesResult, err := client.GetChatMessages(ctx, createdChat.ID, nil) require.NoError(t, err) require.Equal(t, createdChat.ID, chatResult.ID) require.Equal(t, firstUser.UserID, chatResult.OwnerID) require.Equal(t, modelConfig.ID, chatResult.LastModelConfigID) require.Equal(t, "get chat route payload", chatResult.Title) require.NotZero(t, chatResult.CreatedAt) require.NotZero(t, chatResult.UpdatedAt) require.NotEmpty(t, messagesResult.Messages) require.Empty(t, messagesResult.QueuedMessages) foundUserMessage := false for _, message := range messagesResult.Messages { require.Equal(t, createdChat.ID, message.ChatID) require.NotEqual(t, codersdk.ChatMessageRoleSystem, message.Role) for _, part := range message.Content { if message.Role == codersdk.ChatMessageRoleUser && part.Type == codersdk.ChatMessagePartTypeText && part.Text == "get chat route payload" { foundUserMessage = true } } } require.True(t, foundUserMessage) }) t.Run("NotFoundForDifferentUser", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "private chat", }, }, }) require.NoError(t, err) otherClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID, rbac.RoleAgentsAccess()) otherClient := codersdk.NewExperimentalClient(otherClientRaw) _, err = otherClient.GetChat(ctx, createdChat.ID) requireSDKError(t, err, http.StatusNotFound) }) t.Run("FilesHydrated", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Upload a file. pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "hydrated.png", bytes.NewReader(pngData)) require.NoError(t, err) // Create a chat with a text + file part. chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "check file hydration"}, {Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID}, }, }) require.NoError(t, err) // GET the chat — files must be hydrated with all metadata fields. chatResult, err := client.GetChat(ctx, chat.ID) require.NoError(t, err) require.Len(t, chatResult.Files, 1) f := chatResult.Files[0] require.Equal(t, uploadResp.ID, f.ID) require.Equal(t, firstUser.UserID, f.OwnerID) require.NotEqual(t, uuid.Nil, f.OrganizationID) require.Equal(t, "image/png", f.MimeType) require.Equal(t, "hydrated.png", f.Name) require.NotZero(t, f.CreatedAt) }) // ToolCreatedFilesLinked exercises the DB path that chatd uses // when a tool (e.g. propose_plan) creates a file: InsertChatFile // then LinkChatFiles. This is a DB-level test because driving // the full chatd tool-call pipeline requires an LLM mock. t.Run("ToolCreatedFilesLinked", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, store := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Create a chat via the API so all metadata is set up. chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "tool file test"}, }, }) require.NoError(t, err) // Mimic what chatd's StoreFile closure does: // 1. InsertChatFile // 2. LinkChatFiles //nolint:gocritic // Using AsChatd to mimic the chatd background worker. chatdCtx := dbauthz.AsChatd(ctx) fileRow, err := store.InsertChatFile(chatdCtx, database.InsertChatFileParams{ OwnerID: firstUser.UserID, OrganizationID: firstUser.OrganizationID, Name: "plan.md", Mimetype: "text/markdown", Data: []byte("# Plan"), }) require.NoError(t, err) rejected, err := store.LinkChatFiles(chatdCtx, database.LinkChatFilesParams{ ChatID: chat.ID, MaxFileLinks: int32(codersdk.MaxChatFileIDs), FileIds: []uuid.UUID{fileRow.ID}, }) require.NoError(t, err) require.Equal(t, int32(0), rejected, "0 rejected = all files linked") // Verify via the API that the file appears in the chat. chatResult, err := client.GetChat(ctx, chat.ID) require.NoError(t, err) require.Len(t, chatResult.Files, 1) f := chatResult.Files[0] require.Equal(t, fileRow.ID, f.ID) require.Equal(t, firstUser.UserID, f.OwnerID) require.Equal(t, firstUser.OrganizationID, f.OrganizationID) require.Equal(t, "plan.md", f.Name) require.Equal(t, "text/markdown", f.MimeType) // Fill up to the cap by inserting more files via the // chatd DB path, then verify the cap is enforced. for i := 1; i < codersdk.MaxChatFileIDs; i++ { extra, err := store.InsertChatFile(chatdCtx, database.InsertChatFileParams{ OwnerID: firstUser.UserID, OrganizationID: firstUser.OrganizationID, Name: fmt.Sprintf("file%d.md", i), Mimetype: "text/markdown", Data: []byte("data"), }) require.NoError(t, err) _, err = store.LinkChatFiles(chatdCtx, database.LinkChatFilesParams{ ChatID: chat.ID, MaxFileLinks: int32(codersdk.MaxChatFileIDs), FileIds: []uuid.UUID{extra.ID}, }) require.NoError(t, err) } // Chat should now have exactly MaxChatFileIDs files. chatResult, err = client.GetChat(ctx, chat.ID) require.NoError(t, err) require.Len(t, chatResult.Files, codersdk.MaxChatFileIDs) // Attempt to add one more file — should be rejected (0 rows). overflow, err := store.InsertChatFile(chatdCtx, database.InsertChatFileParams{ OwnerID: firstUser.UserID, OrganizationID: firstUser.OrganizationID, Name: "overflow.md", Mimetype: "text/markdown", Data: []byte("too many"), }) require.NoError(t, err) rejected, err = store.LinkChatFiles(chatdCtx, database.LinkChatFilesParams{ ChatID: chat.ID, MaxFileLinks: int32(codersdk.MaxChatFileIDs), FileIds: []uuid.UUID{overflow.ID}, }) require.NoError(t, err) require.Equal(t, int32(1), rejected, "cap should reject the 21st file") // Re-appending an already-linked ID at cap should succeed // (dedup means no array growth). rejected, err = store.LinkChatFiles(chatdCtx, database.LinkChatFilesParams{ ChatID: chat.ID, MaxFileLinks: int32(codersdk.MaxChatFileIDs), FileIds: []uuid.UUID{fileRow.ID}, }) require.NoError(t, err) // ON CONFLICT DO NOTHING returns 0 rows when the link // already exists, which is fine — the file is still linked. require.Equal(t, int32(0), rejected, "dedup of existing ID should be a no-op") // Count should still be exactly MaxChatFileIDs. chatResult, err = client.GetChat(ctx, chat.ID) require.NoError(t, err) require.Len(t, chatResult.Files, codersdk.MaxChatFileIDs) }) } func TestPatchChat(t *testing.T) { t.Parallel() createChat := func(ctx context.Context, t *testing.T, client *codersdk.ExperimentalClient, orgID uuid.UUID, text string) codersdk.Chat { t.Helper() chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: orgID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: text, }, }, }) require.NoError(t, err) return chat } getChat := func(ctx context.Context, t *testing.T, client *codersdk.ExperimentalClient, chatID uuid.UUID) codersdk.Chat { t.Helper() chat, err := client.GetChat(ctx, chatID) require.NoError(t, err) return chat } createStoredChat := func( ctx context.Context, t *testing.T, db database.Store, ownerID uuid.UUID, orgID uuid.UUID, modelConfigID uuid.UUID, title string, ) codersdk.Chat { t.Helper() dbChat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: orgID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: ownerID, LastModelConfigID: modelConfigID, Title: title, }) require.NoError(t, err) return db2sdk.Chat(dbChat, nil, nil) } t.Run("PlanMode", func(t *testing.T) { t.Parallel() t.Run("SetToPlan", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat := createChat(ctx, t, client, firstUser.OrganizationID, "set plan mode") err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ PlanMode: ptr.Ref(codersdk.ChatPlanModePlan), }) require.NoError(t, err) updated := getChat(ctx, t, client, chat.ID) require.Equal(t, codersdk.ChatPlanModePlan, updated.PlanMode) }) t.Run("Clear", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat := createChat(ctx, t, client, firstUser.OrganizationID, "clear plan mode") err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ PlanMode: ptr.Ref(codersdk.ChatPlanModePlan), }) require.NoError(t, err) err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ PlanMode: ptr.Ref(codersdk.ChatPlanMode("")), }) require.NoError(t, err) updated := getChat(ctx, t, client, chat.ID) require.Empty(t, updated.PlanMode) }) t.Run("RejectsInvalidValue", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat := createChat(ctx, t, client, firstUser.OrganizationID, "invalid plan mode") invalidPlanMode := codersdk.ChatPlanMode("invalid") err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ PlanMode: &invalidPlanMode, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid plan_mode value.", sdkErr.Message) }) }) t.Run("WorkspaceBinding", func(t *testing.T) { t.Parallel() t.Run("BindValidWorkspace", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ OrganizationID: firstUser.OrganizationID, OwnerID: firstUser.UserID, }).WithAgent().Do() chat := createStoredChat( ctx, t, db, firstUser.UserID, firstUser.OrganizationID, modelConfig.ID, "bind workspace", ) err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ WorkspaceID: &workspaceBuild.Workspace.ID, }) require.NoError(t, err) updated := getChat(ctx, t, client, chat.ID) require.NotNil(t, updated.WorkspaceID) require.Equal(t, workspaceBuild.Workspace.ID, *updated.WorkspaceID) }) t.Run("WorkspaceNotFound", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat := createStoredChat( ctx, t, db, firstUser.UserID, firstUser.OrganizationID, modelConfig.ID, "missing workspace", ) workspaceID := uuid.New() err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ WorkspaceID: &workspaceID, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Workspace not found or you do not have access to this resource", sdkErr.Message) }) t.Run("RejectsCrossOrgWorkspaceBinding", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) secondOrg := dbgen.Organization(t, db, database.Organization{}) dbgen.OrganizationMember(t, db, database.OrganizationMember{ OrganizationID: secondOrg.ID, UserID: firstUser.UserID, }) workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ OrganizationID: secondOrg.ID, OwnerID: firstUser.UserID, }).WithAgent().Do() chat := createStoredChat( ctx, t, db, firstUser.UserID, firstUser.OrganizationID, modelConfig.ID, "cross org workspace binding", ) err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ WorkspaceID: &workspaceBuild.Workspace.ID, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Workspace does not belong to this chat's organization.", sdkErr.Message) }) t.Run("ClearWorkspaceBinding", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ OrganizationID: firstUser.OrganizationID, OwnerID: firstUser.UserID, }).WithAgent().Do() chat := createStoredChat( ctx, t, db, firstUser.UserID, firstUser.OrganizationID, modelConfig.ID, "clear workspace binding", ) err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ WorkspaceID: &workspaceBuild.Workspace.ID, }) require.NoError(t, err) workspaceID := uuid.Nil err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{ WorkspaceID: &workspaceID, }) require.NoError(t, err) updated := getChat(ctx, t, client, chat.ID) require.Nil(t, updated.WorkspaceID) require.Nil(t, updated.BuildID) require.Nil(t, updated.AgentID) }) }) } func TestArchiveChat(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chatToArchive, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "archive me", }, }, }) require.NoError(t, err) chatToKeep, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "keep me", }, }, }) require.NoError(t, err) chatsBeforeArchive, err := client.ListChats(ctx, nil) require.NoError(t, err) require.Len(t, chatsBeforeArchive, 2) err = client.UpdateChat(ctx, chatToArchive.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)}) require.NoError(t, err) // Default (no filter) returns only non-archived chats. allChats, err := client.ListChats(ctx, nil) require.NoError(t, err) require.Len(t, allChats, 1) require.Equal(t, chatToKeep.ID, allChats[0].ID) // archived:false returns only non-archived chats. activeChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{ Query: "archived:false", }) require.NoError(t, err) require.Len(t, activeChats, 1) require.Equal(t, chatToKeep.ID, activeChats[0].ID) require.False(t, activeChats[0].Archived) // archived:true returns only archived chats. archivedChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{ Query: "archived:true", }) require.NoError(t, err) require.Len(t, archivedChats, 1) require.Equal(t, chatToArchive.ID, archivedChats[0].ID) require.True(t, archivedChats[0].Archived) }) t.Run("NotFound", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(true)}) requireSDKError(t, err, http.StatusNotFound) }) t.Run("ArchivesChildren", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) // Create a parent chat via the API. parentChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "parent chat", }, }, }) require.NoError(t, err) // Insert child chats directly via the database. child1, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "child 1", ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, }) require.NoError(t, err) child2, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "child 2", ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, }) require.NoError(t, err) // Archive the parent via the API. err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)}) require.NoError(t, err) // archived:false should exclude the entire archived family. activeChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{ Query: "archived:false", }) require.NoError(t, err) for _, c := range activeChats { require.NotEqual(t, parentChat.ID, c.ID, "parent should not appear") require.NotEqual(t, child1.ID, c.ID, "child1 should not appear") require.NotEqual(t, child2.ID, c.ID, "child2 should not appear") } // Verify children are archived directly in the DB. dbChild1, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), child1.ID) require.NoError(t, err) require.True(t, dbChild1.Archived, "child1 should be archived") dbChild2, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), child2.ID) require.NoError(t, err) require.True(t, dbChild2.Archived, "child2 should be archived") }) } func TestUnarchiveChat(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "archive then unarchive me", }, }, }) require.NoError(t, err) // Archive the chat first. err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)}) require.NoError(t, err) // Verify it's archived. archivedChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{ Query: "archived:true", }) require.NoError(t, err) require.Len(t, archivedChats, 1) require.True(t, archivedChats[0].Archived) // Unarchive the chat. err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)}) require.NoError(t, err) // Verify it's no longer archived. activeChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{ Query: "archived:false", }) require.NoError(t, err) require.Len(t, activeChats, 1) require.Equal(t, chat.ID, activeChats[0].ID) require.False(t, activeChats[0].Archived) // No archived chats remain. archivedChats, err = client.ListChats(ctx, &codersdk.ListChatsOptions{ Query: "archived:true", }) require.NoError(t, err) require.Empty(t, archivedChats) }) t.Run("UnarchivesChildren", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) parentChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "parent chat", }, }, }) require.NoError(t, err) child1, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "child 1", ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, }) require.NoError(t, err) child2, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "child 2", ParentChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, RootChatID: uuid.NullUUID{UUID: parentChat.ID, Valid: true}, }) require.NoError(t, err) err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)}) require.NoError(t, err) err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)}) require.NoError(t, err) activeChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{ Query: "archived:false", }) require.NoError(t, err) var foundParent bool var foundChild1 bool var foundChild2 bool for _, chat := range activeChats { switch chat.ID { case parentChat.ID: foundParent = true require.False(t, chat.Archived) case child1.ID: foundChild1 = true require.False(t, chat.Archived) case child2.ID: foundChild2 = true require.False(t, chat.Archived) } } require.True(t, foundParent, "parent should be listed as active") require.True(t, foundChild1, "child1 should be listed as active") require.True(t, foundChild2, "child2 should be listed as active") archivedChats, err := client.ListChats(ctx, &codersdk.ListChatsOptions{ Query: "archived:true", }) require.NoError(t, err) for _, chat := range archivedChats { require.NotEqual(t, parentChat.ID, chat.ID, "parent should not remain archived") require.NotEqual(t, child1.ID, chat.ID, "child1 should not remain archived") require.NotEqual(t, child2.ID, chat.ID, "child2 should not remain archived") } dbParent, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), parentChat.ID) require.NoError(t, err) require.False(t, dbParent.Archived, "parent should be unarchived") dbChild1, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), child1.ID) require.NoError(t, err) require.False(t, dbChild1.Archived, "child1 should be unarchived") dbChild2, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), child2.ID) require.NoError(t, err) require.False(t, dbChild2.Archived, "child2 should be unarchived") }) t.Run("NotArchived", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "not archived", }, }, }) require.NoError(t, err) // Trying to unarchive a non-archived chat should fail. err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)}) requireSDKError(t, err, http.StatusBadRequest) }) t.Run("NotFound", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(false)}) requireSDKError(t, err, http.StatusNotFound) }) } func TestChatPinOrder(t *testing.T) { t.Parallel() createChat := func(ctx context.Context, t *testing.T, client *codersdk.ExperimentalClient, orgID uuid.UUID, title string) codersdk.Chat { t.Helper() chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: orgID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: title, }, }, }) require.NoError(t, err) return chat } getChat := func(ctx context.Context, t *testing.T, client *codersdk.ExperimentalClient, chatID uuid.UUID) codersdk.Chat { t.Helper() chat, err := client.GetChat(ctx, chatID) require.NoError(t, err) return chat } t.Run("PinReorderAndUnpin", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) first := createChat(ctx, t, client, firstUser.OrganizationID, "first pinned chat") second := createChat(ctx, t, client, firstUser.OrganizationID, "second pinned chat") third := createChat(ctx, t, client, firstUser.OrganizationID, "third pinned chat") err := client.UpdateChat(ctx, first.ID, codersdk.UpdateChatRequest{PinOrder: ptr.Ref(int32(1))}) require.NoError(t, err) err = client.UpdateChat(ctx, second.ID, codersdk.UpdateChatRequest{PinOrder: ptr.Ref(int32(1))}) require.NoError(t, err) err = client.UpdateChat(ctx, third.ID, codersdk.UpdateChatRequest{PinOrder: ptr.Ref(int32(1))}) require.NoError(t, err) first = getChat(ctx, t, client, first.ID) second = getChat(ctx, t, client, second.ID) third = getChat(ctx, t, client, third.ID) require.EqualValues(t, 1, first.PinOrder) require.EqualValues(t, 2, second.PinOrder) require.EqualValues(t, 3, third.PinOrder) err = client.UpdateChat(ctx, third.ID, codersdk.UpdateChatRequest{PinOrder: ptr.Ref(int32(1))}) require.NoError(t, err) first = getChat(ctx, t, client, first.ID) second = getChat(ctx, t, client, second.ID) third = getChat(ctx, t, client, third.ID) require.EqualValues(t, 2, first.PinOrder) require.EqualValues(t, 3, second.PinOrder) require.EqualValues(t, 1, third.PinOrder) err = client.UpdateChat(ctx, first.ID, codersdk.UpdateChatRequest{PinOrder: ptr.Ref(int32(0))}) require.NoError(t, err) first = getChat(ctx, t, client, first.ID) second = getChat(ctx, t, client, second.ID) third = getChat(ctx, t, client, third.ID) require.Zero(t, first.PinOrder) require.EqualValues(t, 2, second.PinOrder) require.EqualValues(t, 1, third.PinOrder) }) t.Run("ArchiveClearsPinOrder", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) first := createChat(ctx, t, client, firstUser.OrganizationID, "pinned then archived") second := createChat(ctx, t, client, firstUser.OrganizationID, "stays pinned") // Pin both. err := client.UpdateChat(ctx, first.ID, codersdk.UpdateChatRequest{PinOrder: ptr.Ref(int32(1))}) require.NoError(t, err) err = client.UpdateChat(ctx, second.ID, codersdk.UpdateChatRequest{PinOrder: ptr.Ref(int32(1))}) require.NoError(t, err) // Archive the first — pin_order should be cleared. err = client.UpdateChat(ctx, first.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)}) require.NoError(t, err) first = getChat(ctx, t, client, first.ID) second = getChat(ctx, t, client, second.ID) require.Zero(t, first.PinOrder, "archived chat should have pin_order 0") require.True(t, first.Archived) // The remaining pin keeps its original position. The next // pin/unpin/reorder operation compacts via ROW_NUMBER(). require.EqualValues(t, 2, second.PinOrder, "remaining pin keeps original position") }) t.Run("RejectsNegative", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat := createChat(ctx, t, client, firstUser.OrganizationID, "negative pin order") err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{PinOrder: ptr.Ref(int32(-1))}) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Pin order must be non-negative.", sdkErr.Message) chat = getChat(ctx, t, client, chat.ID) require.Zero(t, chat.PinOrder) }) } func TestPostChatMessages(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "initial message for post route test", }, }, }) require.NoError(t, err) hasTextPart := func(parts []codersdk.ChatMessagePart, want string) bool { for _, part := range parts { if part.Type == codersdk.ChatMessagePartTypeText && part.Text == want { return true } } return false } messageText := "post message route success " + uuid.NewString() created, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: messageText, }, }, }) require.NoError(t, err) if created.Queued { require.Nil(t, created.Message) require.NotNil(t, created.QueuedMessage) require.Equal(t, chat.ID, created.QueuedMessage.ChatID) require.NotZero(t, created.QueuedMessage.ID) require.True(t, hasTextPart(created.QueuedMessage.Content, messageText)) require.Eventually(t, func() bool { messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil) if getErr != nil { return false } for _, queued := range messagesResult.QueuedMessages { if queued.ID == created.QueuedMessage.ID && queued.ChatID == chat.ID && hasTextPart(queued.Content, messageText) { return true } } for _, message := range messagesResult.Messages { if message.Role == codersdk.ChatMessageRoleUser && hasTextPart(message.Content, messageText) { return true } } return false }, testutil.WaitLong, testutil.IntervalFast) } else { require.Nil(t, created.QueuedMessage) require.NotNil(t, created.Message) require.Equal(t, chat.ID, created.Message.ChatID) require.Equal(t, codersdk.ChatMessageRoleUser, created.Message.Role) require.NotZero(t, created.Message.ID) require.True(t, hasTextPart(created.Message.Content, messageText)) require.Eventually(t, func() bool { messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil) if getErr != nil { return false } for _, message := range messagesResult.Messages { if message.ID == created.Message.ID && message.Role == codersdk.ChatMessageRoleUser && hasTextPart(message.Content, messageText) { return true } } return false }, testutil.WaitLong, testutil.IntervalFast) } }) t.Run("MemberWithoutAgentsAccess", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) // Create a member without agents-access and insert a // chat owned by them via system context. Even though // the member can read the chat through org membership, // sending messages should be gated by agents-access // because it triggers AI/LLM inference. memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: firstUser.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: member.ID, LastModelConfigID: modelConfig.ID, Title: "member chat", }) require.NoError(t, err) _, err = memberClient.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "this should fail", }, }, }) requireSDKError(t, err, http.StatusForbidden) }) t.Run("EmptyText", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "initial message for validation test", }, }, }) require.NoError(t, err) _, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: " ", }, }, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid input part.", sdkErr.Message) require.Equal(t, "content[0].text cannot be empty.", sdkErr.Detail) }) t.Run("UsageLimitExceeded", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "initial message for usage-limit test", }}, }) require.NoError(t, err) wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100) insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100) _, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "over limit", }}, }) requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt) }) t.Run("ChatNotFound", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) _, err := client.CreateChatMessage(ctx, uuid.New(), codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello", }, }, }) requireSDKError(t, err, http.StatusNotFound) }) } func TestChatMessageWithFileReferences(t *testing.T) { t.Parallel() // createChat is a helper that creates a chat so we can post messages to it. createChatForTest := func(t *testing.T, client *codersdk.ExperimentalClient, orgID uuid.UUID) codersdk.Chat { t.Helper() ctx := testutil.Context(t, testutil.WaitLong) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: orgID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "initial message", }}, }) require.NoError(t, err) return chat } t.Run("FileReferenceOnly", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat := createChatForTest(t, client, firstUser.OrganizationID) created, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeFileReference, FileName: "main.go", StartLine: 10, EndLine: 15, Content: "func broken() {}", }}, }) require.NoError(t, err) // File-reference parts are stored as structured parts. checkFileRef := func(part codersdk.ChatMessagePart) bool { return part.Type == codersdk.ChatMessagePartTypeFileReference && part.FileName == "main.go" && part.StartLine == 10 && part.EndLine == 15 && part.Content == "func broken() {}" } var found bool require.Eventually(t, func() bool { messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil) if getErr != nil { return false } for _, message := range messagesResult.Messages { if message.Role != codersdk.ChatMessageRoleUser { continue } for _, part := range message.Content { if checkFileRef(part) { found = true return true } } } // The message may have been queued. if created.Queued && created.QueuedMessage != nil { for _, queued := range messagesResult.QueuedMessages { for _, part := range queued.Content { if checkFileRef(part) { found = true return true } } } } return false }, testutil.WaitLong, testutil.IntervalFast) require.True(t, found, "expected to find file-reference part in stored message") }) t.Run("FileReferenceSingleLine", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat := createChatForTest(t, client, firstUser.OrganizationID) created, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeFileReference, FileName: "lib/utils.ts", StartLine: 42, EndLine: 42, Content: "const x = 1;", }}, }) require.NoError(t, err) checkFileRef := func(part codersdk.ChatMessagePart) bool { return part.Type == codersdk.ChatMessagePartTypeFileReference && part.FileName == "lib/utils.ts" && part.StartLine == 42 && part.EndLine == 42 && part.Content == "const x = 1;" } require.Eventually(t, func() bool { messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil) if getErr != nil { return false } for _, msg := range messagesResult.Messages { for _, part := range msg.Content { if checkFileRef(part) { return true } } } if created.Queued && created.QueuedMessage != nil { for _, queued := range messagesResult.QueuedMessages { for _, part := range queued.Content { if checkFileRef(part) { return true } } } } return false }, testutil.WaitLong, testutil.IntervalFast) }) t.Run("FileReferenceWithoutContent", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat := createChatForTest(t, client, firstUser.OrganizationID) created, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeFileReference, FileName: "README.md", StartLine: 1, EndLine: 1, // No code content — just a file reference. }}, }) require.NoError(t, err) checkFileRef := func(part codersdk.ChatMessagePart) bool { return part.Type == codersdk.ChatMessagePartTypeFileReference && part.FileName == "README.md" && part.StartLine == 1 && part.EndLine == 1 && part.Content == "" } require.Eventually(t, func() bool { messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil) if getErr != nil { return false } for _, msg := range messagesResult.Messages { for _, part := range msg.Content { if checkFileRef(part) { return true } } } if created.Queued && created.QueuedMessage != nil { for _, queued := range messagesResult.QueuedMessages { for _, part := range queued.Content { if checkFileRef(part) { return true } } } } return false }, testutil.WaitLong, testutil.IntervalFast) }) t.Run("FileReferenceWithCode", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat := createChatForTest(t, client, firstUser.OrganizationID) created, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeFileReference, FileName: "server.go", StartLine: 5, EndLine: 8, Content: "func main() {\n\tfmt.Println()\n}", }}, }) require.NoError(t, err) checkFileRef := func(part codersdk.ChatMessagePart) bool { return part.Type == codersdk.ChatMessagePartTypeFileReference && part.FileName == "server.go" && part.StartLine == 5 && part.EndLine == 8 && part.Content == "func main() {\n\tfmt.Println()\n}" } require.Eventually(t, func() bool { messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil) if getErr != nil { return false } for _, msg := range messagesResult.Messages { for _, part := range msg.Content { if checkFileRef(part) { return true } } } if created.Queued && created.QueuedMessage != nil { for _, queued := range messagesResult.QueuedMessages { for _, part := range queued.Content { if checkFileRef(part) { return true } } } } return false }, testutil.WaitLong, testutil.IntervalFast) }) t.Run("InterleavedTextAndFileReferences", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat := createChatForTest(t, client, firstUser.OrganizationID) created, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "Please review these two issues:", }, { Type: codersdk.ChatInputPartTypeFileReference, FileName: "a.go", StartLine: 1, EndLine: 3, Content: "line1\nline2\nline3", }, { Type: codersdk.ChatInputPartTypeText, Text: "first issue", }, { Type: codersdk.ChatInputPartTypeText, Text: "and also:", }, { Type: codersdk.ChatInputPartTypeFileReference, FileName: "b.go", StartLine: 10, EndLine: 10, Content: "return nil", }, { Type: codersdk.ChatInputPartTypeText, Text: "second issue", }, }, }) require.NoError(t, err) // Verify that all six parts are stored in order with // correct types: text, file-reference, text, text, // file-reference, text. type wantPart struct { typ codersdk.ChatMessagePartType text string fileName string startLine int endLine int content string } want := []wantPart{ {typ: codersdk.ChatMessagePartTypeText, text: "Please review these two issues:"}, {typ: codersdk.ChatMessagePartTypeFileReference, fileName: "a.go", startLine: 1, endLine: 3, content: "line1\nline2\nline3"}, {typ: codersdk.ChatMessagePartTypeText, text: "first issue"}, {typ: codersdk.ChatMessagePartTypeText, text: "and also:"}, {typ: codersdk.ChatMessagePartTypeFileReference, fileName: "b.go", startLine: 10, endLine: 10, content: "return nil"}, {typ: codersdk.ChatMessagePartTypeText, text: "second issue"}, } require.Eventually(t, func() bool { messagesResult, getErr := client.GetChatMessages(ctx, chat.ID, nil) if getErr != nil { return false } checkParts := func(parts []codersdk.ChatMessagePart) bool { if len(parts) != len(want) { return false } for i, w := range want { p := parts[i] if p.Type != w.typ { return false } switch w.typ { case codersdk.ChatMessagePartTypeText: if p.Text != w.text { return false } case codersdk.ChatMessagePartTypeFileReference: if p.FileName != w.fileName || p.StartLine != w.startLine || p.EndLine != w.endLine || p.Content != w.content { return false } } } return true } for _, msg := range messagesResult.Messages { if msg.Role == codersdk.ChatMessageRoleUser && checkParts(msg.Content) { return true } } if created.Queued && created.QueuedMessage != nil { for _, queued := range messagesResult.QueuedMessages { if checkParts(queued.Content) { return true } } } return false }, testutil.WaitLong, testutil.IntervalFast) }) t.Run("EmptyFileName", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat := createChatForTest(t, client, firstUser.OrganizationID) _, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeFileReference, FileName: "", StartLine: 1, EndLine: 1, }}, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid input part.", sdkErr.Message) require.Equal(t, "content[0].file_name cannot be empty for file-reference.", sdkErr.Detail) }) t.Run("CreateChatWithFileReference", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // File references should also work in the initial CreateChat call. chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeFileReference, FileName: "bug.py", StartLine: 7, EndLine: 7, Content: "x = None", }}, }) require.NoError(t, err) require.NotEqual(t, uuid.Nil, chat.ID) // Title is derived from the text parts. For file-references // the formatted text becomes the title source. require.NotEmpty(t, chat.Title) }) } func TestChatMessageWithFiles(t *testing.T) { t.Parallel() t.Run("FileOnly", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Upload a file. pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData)) require.NoError(t, err) // Create a chat with text first. chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "initial message", }, }, }) require.NoError(t, err) // Send a file-only message (no text). resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID, }, }, }) require.NoError(t, err) // Verify the message was accepted. if resp.Queued { require.NotNil(t, resp.QueuedMessage) } else { require.NotNil(t, resp.Message) require.Equal(t, codersdk.ChatMessageRoleUser, resp.Message.Role) } }) t.Run("TextAndFile", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Upload a file. pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData)) require.NoError(t, err) // Create a chat with text first. chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "initial message", }, }, }) require.NoError(t, err) // Send a message with both text and file. resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "here is an image", }, { Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID, }, }, }) require.NoError(t, err) if resp.Queued { require.NotNil(t, resp.QueuedMessage) } else { require.NotNil(t, resp.Message) require.Equal(t, codersdk.ChatMessageRoleUser, resp.Message.Role) } // Verify file parts omit inline data in the API response. messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) for _, msg := range messagesResult.Messages { for _, part := range msg.Content { if part.Type == codersdk.ChatMessagePartTypeFile { require.True(t, part.FileID.Valid, "file part should have a valid file_id") require.Equal(t, uploadResp.ID, part.FileID.UUID) require.Nil(t, part.Data, "file data should not be sent when file_id is present") } } } }) t.Run("FileOnlyOnCreate", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Upload a file. pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData)) require.NoError(t, err) // Create a new chat with only a file part. chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID, }, }, }) require.NoError(t, err) // With no text, chatTitleFromMessage("") returns "New Chat". require.Equal(t, "New Chat", chat.Title) require.Len(t, chat.Files, 1) f := chat.Files[0] require.Equal(t, uploadResp.ID, f.ID) require.Equal(t, firstUser.UserID, f.OwnerID) require.NotEqual(t, uuid.Nil, f.OrganizationID) require.Equal(t, "image/png", f.MimeType) require.Equal(t, "test.png", f.Name) require.NotZero(t, f.CreatedAt) }) t.Run("InvalidFileID", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Create a chat with text first. chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "initial message", }, }, }) require.NoError(t, err) // Send a message with a non-existent file ID. _, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeFile, FileID: uuid.New(), }, }, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid input part.", sdkErr.Message) require.Contains(t, sdkErr.Detail, "does not exist") }) t.Run("FilesLinkedOnSend", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Create a text-only chat (no files initially). chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "no files yet"}, }, }) require.NoError(t, err) // Upload a file. pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "linked.png", bytes.NewReader(pngData)) require.NoError(t, err) // Send a message with the file. _, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "here is a file"}, {Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID}, }, }) require.NoError(t, err) // GET the chat — file should be linked. chatResult, err := client.GetChat(ctx, chat.ID) require.NoError(t, err) require.Len(t, chatResult.Files, 1) require.Equal(t, uploadResp.ID, chatResult.Files[0].ID) require.Equal(t, "linked.png", chatResult.Files[0].Name) }) t.Run("DedupFileIDs", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Upload a file. pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "dedup.png", bytes.NewReader(pngData)) require.NoError(t, err) // Create a chat with a file. chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "first mention"}, {Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID}, }, }) require.NoError(t, err) // Send another message with the SAME file. msgResp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "same file again"}, {Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID}, }, }) require.NoError(t, err) require.Empty(t, msgResp.Warnings, "dedup below cap should not produce warnings") // GET — should have exactly 1 file (deduped by SQL DISTINCT). chatResult, err := client.GetChat(ctx, chat.ID) require.NoError(t, err) require.Len(t, chatResult.Files, 1, "duplicate file IDs should be deduped") require.Equal(t, uploadResp.ID, chatResult.Files[0].ID) }) t.Run("FileCapExceeded", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) // Upload MaxChatFileIDs files. fileIDs := make([]uuid.UUID, 0, codersdk.MaxChatFileIDs) for i := range codersdk.MaxChatFileIDs { resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", fmt.Sprintf("file%d.png", i), bytes.NewReader(pngData)) require.NoError(t, err) fileIDs = append(fileIDs, resp.ID) } // Create a chat using all MaxChatFileIDs files. parts := []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "max files"}, } for _, fid := range fileIDs { parts = append(parts, codersdk.ChatInputPart{Type: codersdk.ChatInputPartTypeFile, FileID: fid}) } chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{OrganizationID: firstUser.OrganizationID, Content: parts}) require.NoError(t, err) require.Empty(t, chat.Warnings, "creating a chat at exactly the cap should not warn") require.Len(t, chat.Files, codersdk.MaxChatFileIDs, "all files should be linked on creation") // Upload one more file. extraResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "one-too-many.png", bytes.NewReader(pngData)) require.NoError(t, err) // Sending a message with the extra file should succeed // (message goes through) but the file should NOT be linked // (cap enforced in SQL). The response includes a warning. msgResp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "one too many"}, {Type: codersdk.ChatInputPartTypeFile, FileID: extraResp.ID}, }, }) require.NoError(t, err) require.NotEmpty(t, msgResp.Warnings, "response should warn about unlinked files") require.Contains(t, msgResp.Warnings[0], "file linking skipped") // The extra file should NOT appear in the chat's files. chatResult, err := client.GetChat(ctx, chat.ID) require.NoError(t, err) require.Len(t, chatResult.Files, codersdk.MaxChatFileIDs, "file count should not exceed the cap") // Sending a message referencing an already-linked file // should succeed with no warnings (dedup, no array growth). msgResp2, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ Content: []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "re-reference existing"}, {Type: codersdk.ChatInputPartTypeFile, FileID: fileIDs[0]}, }, }) require.NoError(t, err) require.Empty(t, msgResp2.Warnings, "re-referencing an existing file should not warn") }) t.Run("FileCapOnCreate", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) // Upload MaxChatFileIDs + 1 files. fileIDs := make([]uuid.UUID, 0, codersdk.MaxChatFileIDs+1) for i := range codersdk.MaxChatFileIDs + 1 { resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", fmt.Sprintf("create%d.png", i), bytes.NewReader(pngData)) require.NoError(t, err) fileIDs = append(fileIDs, resp.ID) } // Create a chat with all files (one over the cap). parts := []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "over cap on create"}, } for _, fid := range fileIDs { parts = append(parts, codersdk.ChatInputPart{Type: codersdk.ChatInputPartTypeFile, FileID: fid}) } chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{OrganizationID: firstUser.OrganizationID, Content: parts}) require.NoError(t, err, "chat creation should succeed even when cap is exceeded") require.NotEmpty(t, chat.Warnings, "response should warn about unlinked files") require.Contains(t, chat.Warnings[0], "file linking skipped") // Only MaxChatFileIDs files should actually be linked. // With SQL-level batch rejection, ALL files are rejected // when the result would exceed the cap. Since we're // sending MaxChatFileIDs+1 files, the deduped count is // 21 > 20, so 0 rows are affected and all files are // unlinked. chatResult, err := client.GetChat(ctx, chat.ID) require.NoError(t, err) require.Empty(t, chatResult.Files, "no files should be linked when batch exceeds cap") }) } func TestPatchChatMessage(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello before edit", }, }, }) require.NoError(t, err) messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) var userMessageID int64 for _, message := range messagesResult.Messages { if message.Role == codersdk.ChatMessageRoleUser { userMessageID = message.ID break } } require.NotZero(t, userMessageID) edited, err := client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{ Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello after edit", }, }, }) require.NoError(t, err) // The edited message is soft-deleted and a new one is inserted, // so the returned ID will differ from the original. require.NotEqual(t, userMessageID, edited.Message.ID) require.Equal(t, codersdk.ChatMessageRoleUser, edited.Message.Role) foundEditedText := false for _, part := range edited.Message.Content { if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "hello after edit" { foundEditedText = true } } require.True(t, foundEditedText) messagesResult, err = client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) foundEditedInChat := false foundOriginalInChat := false for _, message := range messagesResult.Messages { if message.Role != codersdk.ChatMessageRoleUser { continue } for _, part := range message.Content { if part.Type != codersdk.ChatMessagePartTypeText { continue } if part.Text == "hello after edit" { foundEditedInChat = true } if part.Text == "hello before edit" { foundOriginalInChat = true } } } require.True(t, foundEditedInChat) require.False(t, foundOriginalInChat) }) t.Run("PreservesFileID", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Upload a file. pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(pngData)) require.NoError(t, err) // Create a chat with a text + file part. chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "before edit with file", }, { Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID, }, }, }) require.NoError(t, err) // Find the user message ID. messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) var userMessageID int64 for _, message := range messagesResult.Messages { if message.Role == codersdk.ChatMessageRoleUser { userMessageID = message.ID break } } require.NotZero(t, userMessageID) // Edit the message: new text, same file_id. edited, err := client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{ Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "after edit with file", }, { Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID, }, }, }) require.NoError(t, err) // The edited message is soft-deleted and a new one is inserted, // so the returned ID will differ from the original. require.NotEqual(t, userMessageID, edited.Message.ID) // Assert the edit response preserves the file_id. var foundText, foundFile bool for _, part := range edited.Message.Content { if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "after edit with file" { foundText = true } if part.Type == codersdk.ChatMessagePartTypeFile && part.FileID.Valid && part.FileID.UUID == uploadResp.ID { foundFile = true require.Nil(t, part.Data, "file data should not be sent when file_id is present") } } require.True(t, foundText, "edited message should contain updated text") require.True(t, foundFile, "edited message should preserve file_id") // GET the chat messages and verify the file_id persists. messagesResult, err = client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) var foundTextInChat, foundFileInChat bool for _, message := range messagesResult.Messages { if message.Role != codersdk.ChatMessageRoleUser { continue } for _, part := range message.Content { if part.Type == codersdk.ChatMessagePartTypeText && part.Text == "after edit with file" { foundTextInChat = true } if part.Type == codersdk.ChatMessagePartTypeFile && part.FileID.Valid && part.FileID.UUID == uploadResp.ID { foundFileInChat = true require.Nil(t, part.Data, "file data should not be sent when file_id is present") } } } require.True(t, foundTextInChat, "chat should contain edited text") require.True(t, foundFileInChat, "chat should preserve file_id after edit") }) t.Run("UsageLimitExceeded", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "hello before edit", }}, }) require.NoError(t, err) messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) var userMessageID int64 for _, message := range messagesResult.Messages { if message.Role == codersdk.ChatMessageRoleUser { userMessageID = message.ID break } } require.NotZero(t, userMessageID) wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100) insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100) _, err = client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{ Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "edited over limit", }}, }) requireChatUsageLimitExceededError(t, err, 100, 100, wantResetsAt) }) t.Run("MessageNotFound", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello", }, }, }) require.NoError(t, err) _, err = client.EditChatMessage(ctx, chat.ID, 999999, codersdk.EditChatMessageRequest{ Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "edited", }, }, }) sdkErr := requireSDKError(t, err, http.StatusNotFound) require.Equal(t, "Chat message not found.", sdkErr.Message) }) t.Run("InvalidMessageID", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "hello", }, }, }) require.NoError(t, err) res, err := client.Request( ctx, http.MethodPatch, fmt.Sprintf("/api/experimental/chats/%s/messages/not-an-int", chat.ID), codersdk.EditChatMessageRequest{ Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "ignored", }, }, }, ) require.NoError(t, err) defer res.Body.Close() err = codersdk.ReadBodyAsError(res) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid chat message ID.", sdkErr.Message) }) t.Run("FilesLinkedOnEdit", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Create a text-only chat. chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "before file edit"}, }, }) require.NoError(t, err) // Upload a file. pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) uploadResp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "edit-linked.png", bytes.NewReader(pngData)) require.NoError(t, err) // Find the user message ID. messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) var userMessageID int64 for _, msg := range messagesResult.Messages { if msg.Role == codersdk.ChatMessageRoleUser { userMessageID = msg.ID break } } require.NotZero(t, userMessageID) // Edit the message to include the file. _, err = client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{ Content: []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "after file edit"}, {Type: codersdk.ChatInputPartTypeFile, FileID: uploadResp.ID}, }, }) require.NoError(t, err) // GET the chat — file should be linked. chatResult, err := client.GetChat(ctx, chat.ID) require.NoError(t, err) require.Len(t, chatResult.Files, 1) f := chatResult.Files[0] require.Equal(t, uploadResp.ID, f.ID) require.Equal(t, "edit-linked.png", f.Name) require.Equal(t, "image/png", f.MimeType) }) t.Run("CapExceededOnEdit", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) // Create a chat with MaxChatFileIDs files already linked. parts := []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "fill to cap"}, } pngData := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) for i := range codersdk.MaxChatFileIDs { up, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", fmt.Sprintf("cap-%d.png", i), bytes.NewReader(pngData)) require.NoError(t, err) parts = append(parts, codersdk.ChatInputPart{Type: codersdk.ChatInputPartTypeFile, FileID: up.ID}) } chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{OrganizationID: firstUser.OrganizationID, Content: parts}) require.NoError(t, err) require.Empty(t, chat.Warnings, "all files should link on create") // Find the user message. messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) var userMessageID int64 for _, msg := range messagesResult.Messages { if msg.Role == codersdk.ChatMessageRoleUser { userMessageID = msg.ID break } } require.NotZero(t, userMessageID) // Upload one more file and try to link via edit. extra, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "one-too-many.png", bytes.NewReader(pngData)) require.NoError(t, err) edited, err := client.EditChatMessage(ctx, chat.ID, userMessageID, codersdk.EditChatMessageRequest{ Content: []codersdk.ChatInputPart{ {Type: codersdk.ChatInputPartTypeText, Text: "edit with extra file"}, {Type: codersdk.ChatInputPartTypeFile, FileID: extra.ID}, }, }) require.NoError(t, err) require.NotEmpty(t, edited.Warnings, "edit should surface cap warning") require.Contains(t, edited.Warnings[0], "file linking skipped") // Verify the cap is still enforced. chatResult, err := client.GetChat(ctx, chat.ID) require.NoError(t, err) require.Len(t, chatResult.Files, codersdk.MaxChatFileIDs, "file count should not exceed the cap") }) } func TestStreamChat(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) const initialMessage = "stream chat route initial message" chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: initialMessage, }, }, }) require.NoError(t, err) events, closer, err := client.StreamChat(ctx, chat.ID, nil) require.NoError(t, err) defer closer.Close() hasTextPart := func(parts []codersdk.ChatMessagePart, want string) bool { for _, part := range parts { if part.Type == codersdk.ChatMessagePartTypeText && part.Text == want { return true } } return false } foundInitialUserMessage := false for !foundInitialUserMessage { select { case <-ctx.Done(): require.FailNow(t, "timed out waiting for expected stream chat event") case event, ok := <-events: require.True(t, ok, "stream closed before expected event") require.Equal(t, chat.ID, event.ChatID) require.NotEqual(t, codersdk.ChatStreamEventTypeError, event.Type) if event.Type == codersdk.ChatStreamEventTypeMessage && event.Message != nil && event.Message.Role == codersdk.ChatMessageRoleUser && hasTextPart(event.Message.Content, initialMessage) { foundInitialUserMessage = true } } } }) t.Run("Unauthenticated", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) unauthenticatedClient := codersdk.New(client.URL) res, err := unauthenticatedClient.Request( ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s/stream", uuid.New()), nil, ) require.NoError(t, err) defer res.Body.Close() require.Equal(t, http.StatusUnauthorized, res.StatusCode) }) } func TestInterruptChat(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "interrupt route test", }) require.NoError(t, err) runningWorkerID := uuid.New() chat, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ ID: chat.ID, Status: database.ChatStatusRunning, WorkerID: uuid.NullUUID{UUID: runningWorkerID, Valid: true}, StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, }) require.NoError(t, err) require.Equal(t, database.ChatStatusRunning, chat.Status) require.True(t, chat.WorkerID.Valid) require.True(t, chat.StartedAt.Valid) require.True(t, chat.HeartbeatAt.Valid) interrupted, err := client.InterruptChat(ctx, chat.ID) require.NoError(t, err) require.Equal(t, chat.ID, interrupted.ID) require.Equal(t, codersdk.ChatStatusWaiting, interrupted.Status) persisted, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) require.Equal(t, database.ChatStatusWaiting, persisted.Status) require.False(t, persisted.WorkerID.Valid) require.False(t, persisted.StartedAt.Valid) require.False(t, persisted.HeartbeatAt.Valid) }) t.Run("ChatNotFound", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.InterruptChat(ctx, uuid.New()) requireSDKError(t, err, http.StatusNotFound) }) } func TestRegenerateChatTitle(t *testing.T) { t.Parallel() t.Run("ChatNotFound", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.RegenerateChatTitle(ctx, uuid.New()) requireSDKError(t, err, http.StatusNotFound) }) t.Run("UpdateDenied", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) clientRaw, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{ Authorizer: &coderdtest.FakeAuthorizer{ ConditionalReturn: func(_ context.Context, _ rbac.Subject, action policy.Action, object rbac.Object) error { if action == policy.ActionUpdate && object.Type == rbac.ResourceChat.Type { return xerrors.New("denied") } return nil }, }, DeploymentValues: chatDeploymentValues(t), }) client := codersdk.NewExperimentalClient(clientRaw) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "chat with update denied", }) require.NoError(t, err) _, err = client.RegenerateChatTitle(ctx, chat.ID) requireSDKError(t, err, http.StatusNotFound) }) t.Run("NotFoundForDifferentUser", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "private chat", }, }, }) require.NoError(t, err) otherClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID, rbac.RoleAgentsAccess()) otherClient := codersdk.NewExperimentalClient(otherClientRaw) _, err = otherClient.RegenerateChatTitle(ctx, createdChat.ID) requireSDKError(t, err, http.StatusNotFound) }) t.Run("Unauthenticated", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "chat for unauthenticated regeneration", }}, }) require.NoError(t, err) unauthenticatedClient := codersdk.NewExperimentalClient(codersdk.New(client.URL)) _, err = unauthenticatedClient.RegenerateChatTitle(ctx, chat.ID) requireSDKError(t, err, http.StatusUnauthorized) }) t.Run("UsageLimitExceeded", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "chat over usage limit", }}, }) require.NoError(t, err) wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100) insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100) _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ ID: chat.ID, Status: database.ChatStatusCompleted, WorkerID: uuid.NullUUID{}, StartedAt: sql.NullTime{}, HeartbeatAt: sql.NullTime{}, LastError: sql.NullString{}, }) require.NoError(t, err) _, err = client.RegenerateChatTitle(ctx, chat.ID) limitErr := codersdk.ChatUsageLimitExceededFrom(err) require.NotNil(t, limitErr) require.Equal(t, "Chat usage limit exceeded.", limitErr.Message) require.Equal(t, int64(100), limitErr.SpentMicros) require.Equal(t, int64(100), limitErr.LimitMicros) require.True( t, limitErr.ResetsAt.Equal(wantResetsAt), "expected resets_at %s, got %s", wantResetsAt.UTC().Format(time.RFC3339), limitErr.ResetsAt.UTC().Format(time.RFC3339), ) }) t.Run("AlreadyInProgress", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "chat with lock held", }) require.NoError(t, err) _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ ID: chat.ID, Status: database.ChatStatusCompleted, WorkerID: uuid.NullUUID{UUID: uuid.MustParse("00000000-0000-0000-0000-000000000001"), Valid: true}, StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, LastError: sql.NullString{}, }) require.NoError(t, err) res, err := client.Request( ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/title/regenerate", chat.ID), nil, ) require.NoError(t, err) defer res.Body.Close() require.Equal(t, http.StatusConflict, res.StatusCode) var resp codersdk.Response require.NoError(t, json.NewDecoder(res.Body).Decode(&resp)) require.Equal(t, "Title regeneration already in progress for this chat.", resp.Message) }) t.Run("PendingWithoutWorker", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "pending chat without worker", }) require.NoError(t, err) chat, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ ID: chat.ID, Status: database.ChatStatusPending, WorkerID: uuid.NullUUID{}, StartedAt: sql.NullTime{}, HeartbeatAt: sql.NullTime{}, LastError: sql.NullString{}, }) require.NoError(t, err) before, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) res, err := client.Request( ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/title/regenerate", chat.ID), nil, ) require.NoError(t, err) defer res.Body.Close() require.Equal(t, http.StatusConflict, res.StatusCode) var resp codersdk.Response require.NoError(t, json.NewDecoder(res.Body).Decode(&resp)) require.Equal(t, "Title regeneration already in progress for this chat.", resp.Message) persisted, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) require.Equal(t, database.ChatStatusPending, persisted.Status) require.False(t, persisted.WorkerID.Valid) require.True(t, persisted.UpdatedAt.Equal(before.UpdatedAt)) }) t.Run("RegenerationFailure", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "test chat", }, }, }) require.NoError(t, err) // Wait for background processing triggered by signalWake // to finish before setting the status, otherwise the // processor may update updated_at concurrently. require.Eventually(t, func() bool { c, getErr := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) if getErr != nil { return false } return c.Status != database.ChatStatusPending && c.Status != database.ChatStatusRunning }, testutil.WaitShort, testutil.IntervalFast) _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ ID: chat.ID, Status: database.ChatStatusCompleted, WorkerID: uuid.NullUUID{}, StartedAt: sql.NullTime{}, HeartbeatAt: sql.NullTime{}, LastError: sql.NullString{}, }) require.NoError(t, err) before, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) _, err = client.RegenerateChatTitle(ctx, chat.ID) requireSDKError(t, err, http.StatusInternalServerError) after, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) require.True(t, after.UpdatedAt.Equal(before.UpdatedAt)) }) } func TestGetChatDiffStatus(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) rawClient, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ DeploymentValues: chatDeploymentValues(t), ExternalAuthConfigs: []*externalauth.Config{ { ID: "gitlab-test", Type: "gitlab", Regex: regexp.MustCompile(`github\.com`), }, }, }) client := codersdk.NewExperimentalClient(rawClient) db := api.Database user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) noCachedStatusChat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "get diff status route no cache", }) require.NoError(t, err) noCachedChat, err := client.GetChat(ctx, noCachedStatusChat.ID) require.NoError(t, err) require.Equal(t, noCachedStatusChat.ID, noCachedChat.ID) require.Nil(t, noCachedChat.DiffStatus) cachedStatusChat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "get diff status route cached", }) require.NoError(t, err) refreshedAt := time.Now().UTC().Truncate(time.Second) staleAt := refreshedAt.Add(time.Hour) _, err = db.UpsertChatDiffStatusReference( dbauthz.AsSystemRestricted(ctx), database.UpsertChatDiffStatusReferenceParams{ ChatID: cachedStatusChat.ID, Url: sql.NullString{}, GitBranch: "feature/diff-status", GitRemoteOrigin: "git@github.com:coder/coder.git", StaleAt: staleAt, }, ) require.NoError(t, err) _, err = db.UpsertChatDiffStatus( dbauthz.AsSystemRestricted(ctx), database.UpsertChatDiffStatusParams{ ChatID: cachedStatusChat.ID, Url: sql.NullString{}, PullRequestState: sql.NullString{ String: " open ", Valid: true, }, ChangesRequested: true, Additions: 11, Deletions: 4, ChangedFiles: 3, RefreshedAt: refreshedAt, StaleAt: staleAt, }, ) require.NoError(t, err) cachedChat, err := client.GetChat(ctx, cachedStatusChat.ID) require.NoError(t, err) require.Equal(t, cachedStatusChat.ID, cachedChat.ID) require.NotNil(t, cachedChat.DiffStatus) cachedStatus := cachedChat.DiffStatus require.Equal(t, cachedStatusChat.ID, cachedStatus.ChatID) require.NotNil(t, cachedStatus.URL) require.Equal(t, "https://github.com/coder/coder/tree/feature/diff-status", *cachedStatus.URL) require.NotNil(t, cachedStatus.PullRequestState) require.Equal(t, "open", *cachedStatus.PullRequestState) require.True(t, cachedStatus.ChangesRequested) require.EqualValues(t, 11, cachedStatus.Additions) require.EqualValues(t, 4, cachedStatus.Deletions) require.EqualValues(t, 3, cachedStatus.ChangedFiles) require.NotNil(t, cachedStatus.RefreshedAt) require.WithinDuration(t, refreshedAt, *cachedStatus.RefreshedAt, time.Second) require.NotNil(t, cachedStatus.StaleAt) require.WithinDuration(t, staleAt, *cachedStatus.StaleAt, time.Second) }) t.Run("NotFoundForDifferentUser", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "private chat", }, }, }) require.NoError(t, err) otherClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID, rbac.RoleAgentsAccess()) otherClient := codersdk.NewExperimentalClient(otherClientRaw) _, err = otherClient.GetChat(ctx, createdChat.ID) requireSDKError(t, err, http.StatusNotFound) }) } func TestGetChatDiffContents(t *testing.T) { t.Parallel() t.Run("SuccessWithCachedRepositoryReference", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) rawClient, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ DeploymentValues: chatDeploymentValues(t), ExternalAuthConfigs: []*externalauth.Config{ { ID: "gitlab-test", Type: "gitlab", Regex: regexp.MustCompile(`gitlab\.example\.com`), }, }, }) client := codersdk.NewExperimentalClient(rawClient) db := api.Database user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "diff contents with cached repository reference", }) require.NoError(t, err) _, err = db.UpsertChatDiffStatusReference( dbauthz.AsSystemRestricted(ctx), database.UpsertChatDiffStatusReferenceParams{ ChatID: chat.ID, Url: sql.NullString{}, GitBranch: "feature/cached-diff", GitRemoteOrigin: "https://gitlab.example.com/acme/project.git", StaleAt: time.Now().UTC().Add(time.Hour), }, ) require.NoError(t, err) diffContents, err := client.GetChatDiffContents(ctx, chat.ID) require.NoError(t, err) require.Equal(t, chat.ID, diffContents.ChatID) require.NotNil(t, diffContents.Provider) require.Equal(t, "gitlab", *diffContents.Provider) require.NotNil(t, diffContents.RemoteOrigin) require.Equal(t, "https://gitlab.example.com/acme/project.git", *diffContents.RemoteOrigin) require.NotNil(t, diffContents.Branch) require.Equal(t, "feature/cached-diff", *diffContents.Branch) require.Nil(t, diffContents.PullRequestURL) require.Empty(t, diffContents.Diff) }) t.Run("SuccessWithoutCachedReference", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "diff contents test", }, }, }) require.NoError(t, err) diffContents, err := client.GetChatDiffContents(ctx, chat.ID) require.NoError(t, err) require.Equal(t, chat.ID, diffContents.ChatID) require.Nil(t, diffContents.Provider) require.Nil(t, diffContents.RemoteOrigin) require.Nil(t, diffContents.Branch) require.Nil(t, diffContents.PullRequestURL) require.Empty(t, diffContents.Diff) }) t.Run("NotFoundForDifferentUser", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "private chat", }, }, }) require.NoError(t, err) otherClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID, rbac.RoleAgentsAccess()) otherClient := codersdk.NewExperimentalClient(otherClientRaw) _, err = otherClient.GetChatDiffContents(ctx, createdChat.ID) requireSDKError(t, err, http.StatusNotFound) }) } func TestDeleteChatQueuedMessage(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "delete queued message route test", }) require.NoError(t, err) deleteContent, err := json.Marshal([]codersdk.ChatMessagePart{ codersdk.ChatMessageText("queued message for delete route"), }) require.NoError(t, err) queuedMessage, err := db.InsertChatQueuedMessage( dbauthz.AsSystemRestricted(ctx), database.InsertChatQueuedMessageParams{ ChatID: chat.ID, Content: deleteContent, }, ) require.NoError(t, err) res, err := client.Request( ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/%s/queue/%d", chat.ID, queuedMessage.ID), nil, ) require.NoError(t, err) res.Body.Close() require.Equal(t, http.StatusNoContent, res.StatusCode) messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) for _, queued := range messagesResult.QueuedMessages { require.NotEqual(t, queuedMessage.ID, queued.ID) } queuedMessages, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) for _, queued := range queuedMessages { require.NotEqual(t, queuedMessage.ID, queued.ID) } }) t.Run("InvalidQueuedMessageID", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "delete queued invalid id", }) require.NoError(t, err) invalidRes, err := client.Request( ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/%s/queue/not-an-int", chat.ID), nil, ) require.NoError(t, err) defer invalidRes.Body.Close() err = codersdk.ReadBodyAsError(invalidRes) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid queued message ID.", sdkErr.Message) require.Contains(t, sdkErr.Detail, "invalid syntax") }) } func TestPromoteChatQueuedMessage(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "promote queued message route test", }) require.NoError(t, err) const queuedText = "queued message for promote route" queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{ codersdk.ChatMessageText(queuedText), }) require.NoError(t, err) queuedMessage, err := db.InsertChatQueuedMessage( dbauthz.AsSystemRestricted(ctx), database.InsertChatQueuedMessageParams{ ChatID: chat.ID, Content: queuedContent, }, ) require.NoError(t, err) promoteRes, err := client.Request( ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/queue/%d/promote", chat.ID, queuedMessage.ID), nil, ) require.NoError(t, err) defer promoteRes.Body.Close() require.Equal(t, http.StatusOK, promoteRes.StatusCode) var promoted codersdk.ChatMessage err = json.NewDecoder(promoteRes.Body).Decode(&promoted) require.NoError(t, err) require.NotZero(t, promoted.ID) require.Equal(t, chat.ID, promoted.ChatID) require.Equal(t, codersdk.ChatMessageRoleUser, promoted.Role) foundPromotedText := false for _, part := range promoted.Content { if part.Type == codersdk.ChatMessagePartTypeText && part.Text == queuedText { foundPromotedText = true break } } require.True(t, foundPromotedText) messagesResult, err := client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) for _, queued := range messagesResult.QueuedMessages { require.NotEqual(t, queuedMessage.ID, queued.ID) } queuedMessages, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) for _, queued := range queuedMessages { require.NotEqual(t, queuedMessage.ID, queued.ID) } }) t.Run("PromotesAlreadyQueuedMessageAfterLimitReached", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) enableDailyChatUsageLimit(ctx, t, db, 100) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "promote queued usage limit", }) require.NoError(t, err) const queuedText = "queued message for promote route" queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{ codersdk.ChatMessageText(queuedText), }) require.NoError(t, err) queuedMessage, err := db.InsertChatQueuedMessage( dbauthz.AsSystemRestricted(ctx), database.InsertChatQueuedMessageParams{ ChatID: chat.ID, Content: queuedContent, }, ) require.NoError(t, err) insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100) _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ ID: chat.ID, Status: database.ChatStatusWaiting, WorkerID: uuid.NullUUID{}, StartedAt: sql.NullTime{}, HeartbeatAt: sql.NullTime{}, LastError: sql.NullString{}, }) require.NoError(t, err) promoteRes, err := client.Request( ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/queue/%d/promote", chat.ID, queuedMessage.ID), nil, ) require.NoError(t, err) defer promoteRes.Body.Close() require.Equal(t, http.StatusOK, promoteRes.StatusCode) var promoted codersdk.ChatMessage err = json.NewDecoder(promoteRes.Body).Decode(&promoted) require.NoError(t, err) require.NotZero(t, promoted.ID) require.Equal(t, chat.ID, promoted.ChatID) require.Equal(t, codersdk.ChatMessageRoleUser, promoted.Role) foundPromotedText := false for _, part := range promoted.Content { if part.Type == codersdk.ChatMessagePartTypeText && part.Text == queuedText { foundPromotedText = true break } } require.True(t, foundPromotedText) queuedMessages, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) for _, queued := range queuedMessages { require.NotEqual(t, queuedMessage.ID, queued.ID) } }) t.Run("InvalidQueuedMessageID", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "promote queued invalid id", }) require.NoError(t, err) invalidRes, err := client.Request( ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/queue/not-an-int/promote", chat.ID), nil, ) require.NoError(t, err) defer invalidRes.Body.Close() err = codersdk.ReadBodyAsError(invalidRes) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid queued message ID.", sdkErr.Message) require.Contains(t, sdkErr.Detail, "invalid syntax") }) t.Run("MemberWithoutAgentsAccess", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) // Create a member without agents-access. Even though the // member owns the chat, promoting a queued message should // be gated by agents-access because it triggers inference. memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: firstUser.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: member.ID, LastModelConfigID: modelConfig.ID, Title: "promote queued no agents access", }) require.NoError(t, err) queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{ codersdk.ChatMessageText("queued message no agents access"), }) require.NoError(t, err) queuedMessage, err := db.InsertChatQueuedMessage( dbauthz.AsSystemRestricted(ctx), database.InsertChatQueuedMessageParams{ ChatID: chat.ID, Content: queuedContent, }, ) require.NoError(t, err) promoteRes, err := memberClient.Request( ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/queue/%d/promote", chat.ID, queuedMessage.ID), nil, ) require.NoError(t, err) defer promoteRes.Body.Close() require.Equal(t, http.StatusForbidden, promoteRes.StatusCode) }) } func TestChatUsageLimitOverrideRoutes(t *testing.T) { t.Parallel() t.Run("UpsertUserOverrideRequiresPositiveSpendLimit", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, _ := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) res, err := client.Request( ctx, http.MethodPut, fmt.Sprintf("/api/experimental/chats/usage-limits/overrides/%s", member.ID), map[string]any{}, ) require.NoError(t, err) defer res.Body.Close() err = codersdk.ReadBodyAsError(res) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid chat usage limit override.", sdkErr.Message) require.Equal(t, "Spend limit must be greater than 0.", sdkErr.Detail) }) t.Run("UpsertUserOverrideMissingUser", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.UpsertChatUsageLimitOverride(ctx, uuid.New(), codersdk.UpsertChatUsageLimitOverrideRequest{ SpendLimitMicros: 7_000_000, }) sdkErr := requireSDKError(t, err, http.StatusNotFound) require.Equal(t, "User not found.", sdkErr.Message) }) t.Run("DeleteUserOverrideMissingUser", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) err := client.DeleteChatUsageLimitOverride(ctx, uuid.New()) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "User not found.", sdkErr.Message) }) t.Run("DeleteUserOverrideMissingOverride", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) err := client.DeleteChatUsageLimitOverride(ctx, member.ID) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Chat usage limit override not found.", sdkErr.Message) }) t.Run("UpdateUserOverride", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, _ := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) _, err := client.UpsertChatUsageLimitOverride(ctx, member.ID, codersdk.UpsertChatUsageLimitOverrideRequest{ SpendLimitMicros: 5_000_000, }) require.NoError(t, err) override, err := client.UpsertChatUsageLimitOverride(ctx, member.ID, codersdk.UpsertChatUsageLimitOverrideRequest{ SpendLimitMicros: 10_000_000, }) require.NoError(t, err) require.Equal(t, member.ID, override.UserID) require.NotNil(t, override.SpendLimitMicros) require.EqualValues(t, 10_000_000, *override.SpendLimitMicros) config, err := client.GetChatUsageLimitConfig(ctx) require.NoError(t, err) require.Len(t, config.Overrides, 1) require.Equal(t, member.ID, config.Overrides[0].UserID) require.NotNil(t, config.Overrides[0].SpendLimitMicros) require.EqualValues(t, 10_000_000, *config.Overrides[0].SpendLimitMicros) }) t.Run("UpsertGroupOverrideIncludesMemberCount", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) group := dbgen.Group(t, db, database.Group{OrganizationID: firstUser.OrganizationID}) dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: group.ID, UserID: member.ID}) dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: group.ID, UserID: database.PrebuildsSystemUserID}) override, err := client.UpsertChatUsageLimitGroupOverride(ctx, group.ID, codersdk.UpsertChatUsageLimitGroupOverrideRequest{ SpendLimitMicros: 7_000_000, }) require.NoError(t, err) require.Equal(t, group.ID, override.GroupID) require.EqualValues(t, 1, override.MemberCount) require.NotNil(t, override.SpendLimitMicros) require.EqualValues(t, 7_000_000, *override.SpendLimitMicros) config, err := client.GetChatUsageLimitConfig(ctx) require.NoError(t, err) var listed *codersdk.ChatUsageLimitGroupOverride for i := range config.GroupOverrides { if config.GroupOverrides[i].GroupID == group.ID { listed = &config.GroupOverrides[i] break } } require.NotNil(t, listed) require.EqualValues(t, 1, listed.MemberCount) }) t.Run("UpdateGroupOverride", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) group := dbgen.Group(t, db, database.Group{OrganizationID: firstUser.OrganizationID}) dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: group.ID, UserID: firstUser.UserID}) dbgen.GroupMember(t, db, database.GroupMemberTable{GroupID: group.ID, UserID: member.ID}) _, err := client.UpsertChatUsageLimitGroupOverride(ctx, group.ID, codersdk.UpsertChatUsageLimitGroupOverrideRequest{ SpendLimitMicros: 5_000_000, }) require.NoError(t, err) override, err := client.UpsertChatUsageLimitGroupOverride(ctx, group.ID, codersdk.UpsertChatUsageLimitGroupOverrideRequest{ SpendLimitMicros: 10_000_000, }) require.NoError(t, err) require.Equal(t, group.ID, override.GroupID) require.EqualValues(t, 2, override.MemberCount) require.NotNil(t, override.SpendLimitMicros) require.EqualValues(t, 10_000_000, *override.SpendLimitMicros) config, err := client.GetChatUsageLimitConfig(ctx) require.NoError(t, err) require.Len(t, config.GroupOverrides, 1) require.Equal(t, group.ID, config.GroupOverrides[0].GroupID) require.EqualValues(t, 2, config.GroupOverrides[0].MemberCount) require.NotNil(t, config.GroupOverrides[0].SpendLimitMicros) require.EqualValues(t, 10_000_000, *config.GroupOverrides[0].SpendLimitMicros) }) t.Run("UpsertGroupOverrideMissingGroup", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) _, err := client.UpsertChatUsageLimitGroupOverride(ctx, uuid.New(), codersdk.UpsertChatUsageLimitGroupOverrideRequest{ SpendLimitMicros: 7_000_000, }) sdkErr := requireSDKError(t, err, http.StatusNotFound) require.Equal(t, "Group not found.", sdkErr.Message) }) t.Run("DeleteGroupOverrideMissingOverride", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) group := dbgen.Group(t, db, database.Group{OrganizationID: firstUser.OrganizationID}) err := client.DeleteChatUsageLimitGroupOverride(ctx, group.ID) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Chat usage limit group override not found.", sdkErr.Message) }) } func TestPostChatFile(t *testing.T) { t.Parallel() t.Run("Success/PNG", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) // Valid PNG header + padding. data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) require.NoError(t, err) require.NotEqual(t, uuid.Nil, resp.ID) }) t.Run("Success/JPEG", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) data := append([]byte{0xFF, 0xD8, 0xFF, 0xE0}, make([]byte, 64)...) resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/jpeg", "test.jpg", bytes.NewReader(data)) require.NoError(t, err) require.NotEqual(t, uuid.Nil, resp.ID) }) t.Run("Success/WebP", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) // WebP: RIFF + 4-byte size + WEBP + padding. data := append([]byte("RIFF"), make([]byte, 4)...) data = append(data, []byte("WEBP")...) data = append(data, make([]byte, 64)...) resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/webp", "test.webp", bytes.NewReader(data)) require.NoError(t, err) require.NotEqual(t, uuid.Nil, resp.ID) }) t.Run("Success/TextPlain", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) data := []byte("This is a test paste.\nWith multiple lines.\n") resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "text/plain", "test.txt", bytes.NewReader(data)) require.NoError(t, err) require.NotEqual(t, uuid.Nil, resp.ID) }) t.Run("UnsupportedContentType", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "application/pdf", "test.pdf", bytes.NewReader([]byte("%PDF-1.7"))) requireSDKError(t, err, http.StatusBadRequest) }) t.Run("SVGBlocked", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/svg+xml", "test.svg", bytes.NewReader([]byte(""))) requireSDKError(t, err, http.StatusBadRequest) }) t.Run("ContentSniffingRejects", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) // Header says PNG but body is plain text. _, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader([]byte("hello world"))) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Contains(t, sdkErr.Message, "does not match") }) t.Run("ContentSniffingRejectsPNGAsText", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) // Valid 1x1 PNG declared as text/plain should still be rejected. data := []byte{ 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x04, 0x00, 0x00, 0x00, 0xB5, 0x1C, 0x0C, 0x02, 0x00, 0x00, 0x00, 0x0B, 0x49, 0x44, 0x41, 0x54, 0x78, 0xDA, 0x63, 0xFC, 0xFF, 0x1F, 0x00, 0x03, 0x03, 0x02, 0x00, 0xEF, 0x9A, 0x1A, 0x2A, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82, } _, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "text/plain", "test.txt", bytes.NewReader(data)) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Contains(t, sdkErr.Message, "does not match") }) t.Run("TooLarge", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) // 10 MB + 1 byte, with valid PNG header to pass MIME check. data := make([]byte, 10<<20+1) copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}) _, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) require.Error(t, err) }) t.Run("Success/TextPlainHTMLLikeContent", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) data := []byte("\n

Paste me as plain text.

\n") resp, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "text/plain", "snippet.txt", bytes.NewReader(data)) require.NoError(t, err) require.NotEqual(t, uuid.Nil, resp.ID) }) t.Run("MissingOrganization", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) coderdtest.CreateFirstUser(t, client.Client) data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) res, err := client.Request(ctx, http.MethodPost, "/api/experimental/chats/files", bytes.NewReader(data), func(r *http.Request) { r.Header.Set("Content-Type", "image/png") }) require.NoError(t, err) defer res.Body.Close() err = codersdk.ReadBodyAsError(res) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Contains(t, sdkErr.Message, "Missing organization") }) t.Run("InvalidOrganization", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) coderdtest.CreateFirstUser(t, client.Client) data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) res, err := client.Request(ctx, http.MethodPost, "/api/experimental/chats/files?organization=not-a-uuid", bytes.NewReader(data), func(r *http.Request) { r.Header.Set("Content-Type", "image/png") }) require.NoError(t, err) defer res.Body.Close() err = codersdk.ReadBodyAsError(res) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Contains(t, sdkErr.Message, "Invalid organization ID") }) t.Run("WrongOrganization", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) coderdtest.CreateFirstUser(t, client.Client) data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) _, err := client.UploadChatFile(ctx, uuid.New(), "image/png", "test.png", bytes.NewReader(data)) require.Error(t, err) var sdkErr *codersdk.Error require.ErrorAs(t, err, &sdkErr) // dbauthz returns 404 or 500 depending on how the org lookup // fails; 403 is also possible. Any non-success code is valid. require.GreaterOrEqual(t, sdkErr.StatusCode(), http.StatusBadRequest, "expected error status, got %d", sdkErr.StatusCode()) }) t.Run("Unauthenticated", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) unauthed := codersdk.NewExperimentalClient(codersdk.New(client.URL)) data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) _, err := unauthed.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) requireSDKError(t, err, http.StatusUnauthorized) }) } func TestGetChatFile(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) require.NoError(t, err) got, contentType, err := client.GetChatFile(ctx, uploaded.ID) require.NoError(t, err) require.Equal(t, "image/png", contentType) require.Equal(t, data, got) }) t.Run("Success/TextPlain", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) data := []byte("This is a test paste.\nWith multiple lines.\n") uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "text/plain", "test.txt", bytes.NewReader(data)) require.NoError(t, err) got, contentType, err := client.GetChatFile(ctx, uploaded.ID) require.NoError(t, err) require.Equal(t, "text/plain", contentType) require.Equal(t, data, got) }) t.Run("CacheHeaders", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) require.NoError(t, err) res, err := client.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/files/%s", uploaded.ID), nil) require.NoError(t, err) defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) require.Equal(t, "private, max-age=31536000, immutable", res.Header.Get("Cache-Control")) require.Equal(t, "nosniff", res.Header.Get("X-Content-Type-Options")) require.Contains(t, res.Header.Get("Content-Disposition"), "inline") require.Contains(t, res.Header.Get("Content-Disposition"), "test.png") }) t.Run("LongFilename", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) longName := strings.Repeat("a", 300) + ".png" data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", longName, bytes.NewReader(data)) require.NoError(t, err) res, err := client.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/files/%s", uploaded.ID), nil) require.NoError(t, err) defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) // Filename should be truncated to maxChatFileName (255) bytes. cd := res.Header.Get("Content-Disposition") require.Contains(t, cd, "inline") require.Contains(t, cd, strings.Repeat("a", 255)) require.NotContains(t, cd, strings.Repeat("a", 256)) }) t.Run("UnicodeFilename", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) // Upload with a non-ASCII filename using RFC 5987 encoding, // which is what the frontend sends for Unicode filenames. data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "スクリーンショット.png", bytes.NewReader(data)) require.NoError(t, err) res, err := client.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/files/%s", uploaded.ID), nil) require.NoError(t, err) defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) cd := res.Header.Get("Content-Disposition") require.Contains(t, cd, "inline") _, params, err := mime.ParseMediaType(cd) require.NoError(t, err) require.Equal(t, "スクリーンショット.png", params["filename"]) }) t.Run("NotFound", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) coderdtest.CreateFirstUser(t, client.Client) _, _, err := client.GetChatFile(ctx, uuid.New()) requireSDKError(t, err, http.StatusNotFound) }) t.Run("InvalidUUID", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) coderdtest.CreateFirstUser(t, client.Client) res, err := client.Request(ctx, http.MethodGet, "/api/experimental/chats/files/not-a-uuid", nil) require.NoError(t, err) defer res.Body.Close() err = codersdk.ReadBodyAsError(res) requireSDKError(t, err, http.StatusBadRequest) }) t.Run("OtherUserForbidden", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...) uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "test.png", bytes.NewReader(data)) require.NoError(t, err) otherClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID, rbac.RoleAgentsAccess()) otherClient := codersdk.NewExperimentalClient(otherClientRaw) _, _, err = otherClient.GetChatFile(ctx, uploaded.ID) requireSDKError(t, err, http.StatusNotFound) }) } type chatCostTestFixture struct { Client *codersdk.ExperimentalClient DB database.Store ModelConfigID uuid.UUID ChatID uuid.UUID EarliestCreatedAt time.Time LatestCreatedAt time.Time } // safeOptions returns an explicit time window around the fixture messages to // avoid app-time/database-time boundary flakes in summary tests. func (f chatCostTestFixture) safeOptions() codersdk.ChatCostSummaryOptions { return codersdk.ChatCostSummaryOptions{ StartDate: f.EarliestCreatedAt.Add(-time.Minute), EndDate: f.LatestCreatedAt.Add(time.Minute), } } func seedChatCostFixture(t *testing.T) chatCostTestFixture { t.Helper() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: firstUser.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: firstUser.UserID, LastModelConfigID: modelConfig.ID, Title: "test chat", }) require.NoError(t, err) results, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{ ChatID: chat.ID, CreatedBy: []uuid.UUID{uuid.Nil, uuid.Nil}, ModelConfigID: []uuid.UUID{modelConfig.ID, modelConfig.ID}, Role: []database.ChatMessageRole{"assistant", "assistant"}, Content: []string{"null", "null"}, ContentVersion: []int16{0, 0}, Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth}, InputTokens: []int64{100, 100}, OutputTokens: []int64{50, 50}, TotalTokens: []int64{0, 0}, ReasoningTokens: []int64{0, 0}, CacheCreationTokens: []int64{0, 0}, CacheReadTokens: []int64{0, 0}, ContextLimit: []int64{0, 0}, Compressed: []bool{false, false}, TotalCostMicros: []int64{500, 500}, RuntimeMs: []int64{1500, 2500}, }) require.NoError(t, err) require.Len(t, results, 2) earliestCreatedAt := results[0].CreatedAt latestCreatedAt := results[0].CreatedAt for _, msg := range results { if msg.CreatedAt.Before(earliestCreatedAt) { earliestCreatedAt = msg.CreatedAt } if msg.CreatedAt.After(latestCreatedAt) { latestCreatedAt = msg.CreatedAt } } return chatCostTestFixture{ Client: client, DB: db, ModelConfigID: modelConfig.ID, ChatID: chat.ID, EarliestCreatedAt: earliestCreatedAt, LatestCreatedAt: latestCreatedAt, } } func assertChatCostSummary(t *testing.T, summary codersdk.ChatCostSummary, modelConfigID, chatID uuid.UUID) { t.Helper() require.Equal(t, int64(1000), summary.TotalCostMicros) require.Equal(t, int64(2), summary.PricedMessageCount) require.Equal(t, int64(0), summary.UnpricedMessageCount) require.Equal(t, int64(200), summary.TotalInputTokens) require.Equal(t, int64(100), summary.TotalOutputTokens) require.Equal(t, int64(4000), summary.TotalRuntimeMs) require.Len(t, summary.ByModel, 1) require.Equal(t, modelConfigID, summary.ByModel[0].ModelConfigID) require.Equal(t, int64(1000), summary.ByModel[0].TotalCostMicros) require.Equal(t, int64(2), summary.ByModel[0].MessageCount) require.Equal(t, int64(4000), summary.ByModel[0].TotalRuntimeMs) require.Len(t, summary.ByChat, 1) require.Equal(t, chatID, summary.ByChat[0].RootChatID) require.Equal(t, int64(1000), summary.ByChat[0].TotalCostMicros) require.Equal(t, int64(2), summary.ByChat[0].MessageCount) require.Equal(t, int64(4000), summary.ByChat[0].TotalRuntimeMs) } func TestChatCostSummary(t *testing.T) { t.Parallel() t.Run("BasicSummary", func(t *testing.T) { t.Parallel() f := seedChatCostFixture(t) ctx := testutil.Context(t, testutil.WaitLong) // Use a window derived from DB timestamps to avoid time boundary flakes. summary, err := f.Client.GetChatCostSummary(ctx, "me", f.safeOptions()) require.NoError(t, err) assertChatCostSummary(t, summary, f.ModelConfigID, f.ChatID) }) } func TestChatCostSummary_AfterModelDeletion(t *testing.T) { t.Parallel() f := seedChatCostFixture(t) ctx := testutil.Context(t, testutil.WaitLong) options := f.safeOptions() // Baseline: use DB-derived timestamps to avoid time boundary flakes. summary, err := f.Client.GetChatCostSummary(ctx, "me", options) require.NoError(t, err) assertChatCostSummary(t, summary, f.ModelConfigID, f.ChatID) // Soft-delete the model config. err = f.Client.DeleteChatModelConfig(ctx, f.ModelConfigID) require.NoError(t, err) // Costs must survive the deletion unchanged within the same safe window. summary, err = f.Client.GetChatCostSummary(ctx, "me", options) require.NoError(t, err) assertChatCostSummary(t, summary, f.ModelConfigID, f.ChatID) } func TestChatCostSummary_AdminDrilldown(t *testing.T) { t.Parallel() seedCtx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) modelConfig := createChatModelConfig(t, client) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{ OrganizationID: firstUser.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: member.ID, LastModelConfigID: modelConfig.ID, Title: "member chat", }) require.NoError(t, err) results, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessagesParams{ ChatID: chat.ID, CreatedBy: []uuid.UUID{uuid.Nil}, ModelConfigID: []uuid.UUID{modelConfig.ID}, Role: []database.ChatMessageRole{"assistant"}, Content: []string{"null"}, ContentVersion: []int16{0}, Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, InputTokens: []int64{200}, OutputTokens: []int64{100}, TotalTokens: []int64{0}, ReasoningTokens: []int64{0}, CacheCreationTokens: []int64{0}, CacheReadTokens: []int64{0}, ContextLimit: []int64{0}, Compressed: []bool{false}, TotalCostMicros: []int64{750}, RuntimeMs: []int64{0}, }) require.NoError(t, err) message := results[0] options := codersdk.ChatCostSummaryOptions{ // Pad the DB-assigned timestamp so the query window cannot race it. StartDate: message.CreatedAt.Add(-time.Minute), EndDate: message.CreatedAt.Add(time.Minute), } t.Run("AdminCanDrilldown", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) summary, err := client.GetChatCostSummary(ctx, member.ID.String(), options) require.NoError(t, err) require.Equal(t, int64(750), summary.TotalCostMicros) require.Equal(t, int64(1), summary.PricedMessageCount) }) t.Run("MemberCannotDrilldownOtherUser", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) _, err := memberClient.GetChatCostSummary(ctx, firstUser.UserID.String(), options) require.Error(t, err) var sdkErr *codersdk.Error require.ErrorAs(t, err, &sdkErr) require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) }) } func TestChatCostUsers(t *testing.T) { t.Parallel() seedCtx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) firstUserRecord, err := db.GetUserByID(dbauthz.AsSystemRestricted(seedCtx), firstUser.UserID) require.NoError(t, err) modelConfig := createChatModelConfig(t, client) adminChat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{ OrganizationID: firstUser.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: firstUser.UserID, LastModelConfigID: modelConfig.ID, Title: "admin chat", }) require.NoError(t, err) _, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessagesParams{ ChatID: adminChat.ID, CreatedBy: []uuid.UUID{uuid.Nil}, ModelConfigID: []uuid.UUID{modelConfig.ID}, Role: []database.ChatMessageRole{"assistant"}, Content: []string{"null"}, ContentVersion: []int16{0}, Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, InputTokens: []int64{100}, OutputTokens: []int64{50}, TotalTokens: []int64{0}, ReasoningTokens: []int64{0}, CacheCreationTokens: []int64{0}, CacheReadTokens: []int64{0}, ContextLimit: []int64{0}, Compressed: []bool{false}, TotalCostMicros: []int64{300}, RuntimeMs: []int64{0}, }) require.NoError(t, err) memberChat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{ OrganizationID: firstUser.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: member.ID, LastModelConfigID: modelConfig.ID, Title: "member chat", }) require.NoError(t, err) _, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessagesParams{ ChatID: memberChat.ID, CreatedBy: []uuid.UUID{uuid.Nil}, ModelConfigID: []uuid.UUID{modelConfig.ID}, Role: []database.ChatMessageRole{"assistant"}, Content: []string{"null"}, ContentVersion: []int16{0}, Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, InputTokens: []int64{200}, OutputTokens: []int64{100}, TotalTokens: []int64{0}, ReasoningTokens: []int64{0}, CacheCreationTokens: []int64{0}, CacheReadTokens: []int64{0}, ContextLimit: []int64{0}, Compressed: []bool{false}, TotalCostMicros: []int64{800}, RuntimeMs: []int64{0}, }) require.NoError(t, err) t.Run("AdminCanListUsers", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) resp, err := client.GetChatCostUsers(ctx, codersdk.ChatCostUsersOptions{}) require.NoError(t, err) require.Equal(t, int64(2), resp.Count) require.Len(t, resp.Users, 2) require.Equal(t, member.ID, resp.Users[0].UserID) require.Equal(t, member.Username, resp.Users[0].Username) require.Equal(t, int64(800), resp.Users[0].TotalCostMicros) require.Equal(t, int64(1), resp.Users[0].MessageCount) require.Equal(t, int64(1), resp.Users[0].ChatCount) require.Equal(t, firstUser.UserID, resp.Users[1].UserID) require.Equal(t, firstUserRecord.Username, resp.Users[1].Username) require.Equal(t, int64(300), resp.Users[1].TotalCostMicros) }) t.Run("AdminCanFilterAndPaginateUsers", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) resp, err := client.GetChatCostUsers(ctx, codersdk.ChatCostUsersOptions{ Username: member.Username, Pagination: codersdk.Pagination{ Limit: 1, Offset: 0, }, }) require.NoError(t, err) require.Equal(t, int64(1), resp.Count) require.Len(t, resp.Users, 1) require.Equal(t, member.ID, resp.Users[0].UserID) require.Equal(t, member.Username, resp.Users[0].Username) }) t.Run("MemberCannotListUsers", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) _, err := memberClient.GetChatCostUsers(ctx, codersdk.ChatCostUsersOptions{}) require.Error(t, err) var sdkErr *codersdk.Error require.ErrorAs(t, err, &sdkErr) require.Equal(t, http.StatusForbidden, sdkErr.StatusCode()) }) } func TestChatCostSummary_DateRange(t *testing.T) { t.Parallel() seedCtx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{ OrganizationID: firstUser.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: firstUser.UserID, LastModelConfigID: modelConfig.ID, Title: "date range test", }) require.NoError(t, err) _, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessagesParams{ ChatID: chat.ID, CreatedBy: []uuid.UUID{uuid.Nil}, ModelConfigID: []uuid.UUID{modelConfig.ID}, Role: []database.ChatMessageRole{"assistant"}, Content: []string{"null"}, ContentVersion: []int16{0}, Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, InputTokens: []int64{100}, OutputTokens: []int64{50}, TotalTokens: []int64{0}, ReasoningTokens: []int64{0}, CacheCreationTokens: []int64{0}, CacheReadTokens: []int64{0}, ContextLimit: []int64{0}, Compressed: []bool{false}, TotalCostMicros: []int64{500}, RuntimeMs: []int64{0}, }) require.NoError(t, err) now := time.Now() t.Run("MessageInRange", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{ StartDate: now.Add(-time.Hour), EndDate: now.Add(time.Hour), }) require.NoError(t, err) require.Equal(t, int64(500), summary.TotalCostMicros) require.Equal(t, int64(1), summary.PricedMessageCount) }) t.Run("MessageOutOfRange", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{ StartDate: now.Add(time.Hour), EndDate: now.Add(2 * time.Hour), }) require.NoError(t, err) require.Equal(t, int64(0), summary.TotalCostMicros) require.Equal(t, int64(0), summary.PricedMessageCount) }) } func TestChatCostSummary_UnpricedMessages(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: firstUser.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: firstUser.UserID, LastModelConfigID: modelConfig.ID, Title: "unpriced test", }) require.NoError(t, err) pricedResults, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{ ChatID: chat.ID, CreatedBy: []uuid.UUID{uuid.Nil}, ModelConfigID: []uuid.UUID{modelConfig.ID}, Role: []database.ChatMessageRole{"assistant"}, Content: []string{"null"}, ContentVersion: []int16{0}, Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, InputTokens: []int64{100}, OutputTokens: []int64{50}, TotalTokens: []int64{0}, ReasoningTokens: []int64{0}, CacheCreationTokens: []int64{0}, CacheReadTokens: []int64{0}, ContextLimit: []int64{0}, Compressed: []bool{false}, TotalCostMicros: []int64{500}, RuntimeMs: []int64{0}, }) require.NoError(t, err) pricedMessage := pricedResults[0] unpricedResults, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{ ChatID: chat.ID, CreatedBy: []uuid.UUID{uuid.Nil}, ModelConfigID: []uuid.UUID{modelConfig.ID}, Role: []database.ChatMessageRole{"assistant"}, Content: []string{"null"}, ContentVersion: []int16{0}, Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, InputTokens: []int64{200}, OutputTokens: []int64{75}, TotalTokens: []int64{0}, ReasoningTokens: []int64{0}, CacheCreationTokens: []int64{0}, CacheReadTokens: []int64{0}, ContextLimit: []int64{0}, Compressed: []bool{false}, TotalCostMicros: []int64{0}, RuntimeMs: []int64{0}, }) require.NoError(t, err) unpricedMessage := unpricedResults[0] earliestCreatedAt := pricedMessage.CreatedAt latestCreatedAt := pricedMessage.CreatedAt if unpricedMessage.CreatedAt.Before(earliestCreatedAt) { earliestCreatedAt = unpricedMessage.CreatedAt } if unpricedMessage.CreatedAt.After(latestCreatedAt) { latestCreatedAt = unpricedMessage.CreatedAt } options := codersdk.ChatCostSummaryOptions{ // Pad the DB-assigned timestamps to avoid time boundary flakes. StartDate: earliestCreatedAt.Add(-time.Minute), EndDate: latestCreatedAt.Add(time.Minute), } summary, err := client.GetChatCostSummary(ctx, "me", options) require.NoError(t, err) require.Equal(t, int64(500), summary.TotalCostMicros) require.Equal(t, int64(1), summary.PricedMessageCount) require.Equal(t, int64(1), summary.UnpricedMessageCount) require.Equal(t, int64(300), summary.TotalInputTokens) require.Equal(t, int64(125), summary.TotalOutputTokens) } func requireChatModelPricing( t *testing.T, actual *codersdk.ChatModelCallConfig, expected *codersdk.ChatModelCallConfig, ) { t.Helper() require.NotNil(t, actual) require.NotNil(t, expected) require.NotNil(t, actual.Cost) require.NotNil(t, expected.Cost) require.NotNil(t, actual.Cost.InputPricePerMillionTokens) require.NotNil(t, actual.Cost.OutputPricePerMillionTokens) require.NotNil(t, actual.Cost.CacheReadPricePerMillionTokens) require.NotNil(t, actual.Cost.CacheWritePricePerMillionTokens) require.True(t, expected.Cost.InputPricePerMillionTokens.Equal(*actual.Cost.InputPricePerMillionTokens)) require.True(t, expected.Cost.OutputPricePerMillionTokens.Equal(*actual.Cost.OutputPricePerMillionTokens)) require.True(t, expected.Cost.CacheReadPricePerMillionTokens.Equal(*actual.Cost.CacheReadPricePerMillionTokens)) require.True(t, expected.Cost.CacheWritePricePerMillionTokens.Equal(*actual.Cost.CacheWritePricePerMillionTokens)) } func decRef(value string) *decimal.Decimal { d := decimal.RequireFromString(value) return &d } func TestWatchChatDesktop(t *testing.T) { t.Parallel() t.Run("NoWorkspace", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: "desktop no workspace test", }, }, }) require.NoError(t, err) // Try to connect to the desktop endpoint — should fail because // chat has no workspace. res, err := client.Request( ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s/stream/desktop", createdChat.ID), nil, ) require.NoError(t, err) defer res.Body.Close() require.Equal(t, http.StatusBadRequest, res.StatusCode) }) } func createChatModelConfig(t *testing.T, client *codersdk.ExperimentalClient) codersdk.ChatModelConfig { t.Helper() ctx := testutil.Context(t, testutil.WaitLong) _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ Provider: "openai", APIKey: "test-api-key", }) require.NoError(t, err) contextLimit := int64(4096) isDefault := true modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", Model: "gpt-4o-mini", ContextLimit: &contextLimit, IsDefault: &isDefault, }) require.NoError(t, err) return modelConfig } //nolint:tparallel,paralleltest // Subtests share a single coderdtest instance. func TestChatSystemPrompt(t *testing.T) { t.Parallel() adminClient, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) _ = createChatModelConfig(t, adminClient) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) const workspaceAwareness = "There is no workspace associated with this chat yet. Create one using the create_workspace tool before using workspace tools like execute, read_file, write_file, etc." updateChatSystemPrompt := func(t *testing.T, ctx context.Context, req codersdk.UpdateChatSystemPromptRequest) { t.Helper() err := adminClient.UpdateChatSystemPrompt(ctx, req) require.NoError(t, err) } getChatSystemPrompt := func(t *testing.T, ctx context.Context) codersdk.ChatSystemPromptResponse { t.Helper() resp, err := adminClient.GetChatSystemPrompt(ctx) require.NoError(t, err) return resp } assertInjectedSystemMessages := func(t *testing.T, ctx context.Context, wantResolvedPrompt string) { t.Helper() chat, err := adminClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{ { Type: codersdk.ChatInputPartTypeText, Text: fmt.Sprintf("system prompt composition %s", t.Name()), }, }, }) require.NoError(t, err) messages, err := db.GetChatMessagesForPromptByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) var systemTexts []string for _, message := range messages { if message.Role != database.ChatMessageRoleSystem { continue } parts, err := chatprompt.ParseContent(message) require.NoError(t, err) require.Len(t, parts, 1) require.Equal(t, codersdk.ChatMessagePartTypeText, parts[0].Type) systemTexts = append(systemTexts, parts[0].Text) } if wantResolvedPrompt == "" { require.Equal(t, []string{workspaceAwareness}, systemTexts) return } require.Equal(t, []string{wantResolvedPrompt, workspaceAwareness}, systemTexts) } t.Run("ReturnsEmptyWhenUnset", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) resp := getChatSystemPrompt(t, ctx) require.Equal(t, "", resp.SystemPrompt) require.True(t, resp.IncludeDefaultSystemPrompt, "should default to true") require.Equal(t, chatd.DefaultSystemPrompt, resp.DefaultSystemPrompt, "should return the built-in default prompt for preview") }) t.Run("AdminCanSet", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) updateChatSystemPrompt(t, ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "You are a helpful coding assistant.", IncludeDefaultSystemPrompt: ptr.Ref(true), }) resp := getChatSystemPrompt(t, ctx) require.Equal(t, "You are a helpful coding assistant.", resp.SystemPrompt) require.True(t, resp.IncludeDefaultSystemPrompt) require.Equal(t, chatd.DefaultSystemPrompt, resp.DefaultSystemPrompt) }) t.Run("AdminCanUnset", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) // Unset by sending an empty string. updateChatSystemPrompt(t, ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "", IncludeDefaultSystemPrompt: ptr.Ref(true), }) resp := getChatSystemPrompt(t, ctx) require.Empty(t, resp.SystemPrompt) require.True(t, resp.IncludeDefaultSystemPrompt) require.Equal(t, chatd.DefaultSystemPrompt, resp.DefaultSystemPrompt) }) t.Run("ToggleIncludeDefault", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) updateChatSystemPrompt(t, ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "", IncludeDefaultSystemPrompt: ptr.Ref(false), }) resp := getChatSystemPrompt(t, ctx) require.Empty(t, resp.SystemPrompt) require.False(t, resp.IncludeDefaultSystemPrompt) require.Equal(t, chatd.DefaultSystemPrompt, resp.DefaultSystemPrompt) updateChatSystemPrompt(t, ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "", IncludeDefaultSystemPrompt: ptr.Ref(true), }) resp = getChatSystemPrompt(t, ctx) require.Empty(t, resp.SystemPrompt) require.True(t, resp.IncludeDefaultSystemPrompt) require.Equal(t, chatd.DefaultSystemPrompt, resp.DefaultSystemPrompt) }) t.Run("PreservesIncludeDefaultWhenOmitted", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) rawDB, pubsub := dbtestutil.NewDB(t) store := &failNextChatSystemPromptStore{Store: rawDB} client := codersdk.NewExperimentalClient(coderdtest.New(t, &coderdtest.Options{ Database: store, Pubsub: pubsub, DeploymentValues: chatDeploymentValues(t), })) _ = coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) err := client.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "", IncludeDefaultSystemPrompt: ptr.Ref(false), }) require.NoError(t, err) store.failNextGetChatIncludeDefaultSystemPrompt.Store(true) store.failNextUpsertChatIncludeDefaultSystemPrompt.Store(true) err = client.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "Omitted toggle request", }) require.NoError(t, err) resp, err := client.GetChatSystemPrompt(ctx) require.NoError(t, err) require.Equal(t, "Omitted toggle request", resp.SystemPrompt) require.False(t, resp.IncludeDefaultSystemPrompt) require.Equal(t, chatd.DefaultSystemPrompt, resp.DefaultSystemPrompt) }) t.Run("ExistingCustomPromptDefaultsIncludeDefaultOff", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) legacyClient, legacyDB := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, legacyClient.Client) _ = createChatModelConfig(t, legacyClient) require.NoError(t, legacyDB.UpsertChatSystemPrompt(dbauthz.AsSystemRestricted(ctx), "Legacy custom instructions")) resp, err := legacyClient.GetChatSystemPrompt(ctx) require.NoError(t, err) require.Equal(t, "Legacy custom instructions", resp.SystemPrompt) require.False(t, resp.IncludeDefaultSystemPrompt) require.Equal(t, chatd.DefaultSystemPrompt, resp.DefaultSystemPrompt) chat, err := legacyClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: fmt.Sprintf("legacy custom prompt %s", t.Name()), }}, }) require.NoError(t, err) messages, err := legacyDB.GetChatMessagesForPromptByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) var systemTexts []string for _, message := range messages { if message.Role != database.ChatMessageRoleSystem { continue } parts, err := chatprompt.ParseContent(message) require.NoError(t, err) require.Len(t, parts, 1) require.Equal(t, codersdk.ChatMessagePartTypeText, parts[0].Type) systemTexts = append(systemTexts, parts[0].Text) } require.Equal(t, []string{"Legacy custom instructions", workspaceAwareness}, systemTexts) }) t.Run("DefaultSystemPromptPreview", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) resp := getChatSystemPrompt(t, ctx) require.Equal(t, chatd.DefaultSystemPrompt, resp.DefaultSystemPrompt) require.NotEmpty(t, resp.DefaultSystemPrompt, "built-in default prompt should not be empty") }) t.Run("SavesBothFieldsTogether", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) updateChatSystemPrompt(t, ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "Custom instructions for all users.", IncludeDefaultSystemPrompt: ptr.Ref(false), }) resp := getChatSystemPrompt(t, ctx) require.Equal(t, "Custom instructions for all users.", resp.SystemPrompt) require.False(t, resp.IncludeDefaultSystemPrompt) updateChatSystemPrompt(t, ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "Different instructions.", IncludeDefaultSystemPrompt: ptr.Ref(true), }) resp = getChatSystemPrompt(t, ctx) require.Equal(t, "Different instructions.", resp.SystemPrompt) require.True(t, resp.IncludeDefaultSystemPrompt) }) t.Run("PromptComposition", func(t *testing.T) { t.Run("DefaultOnlyWhenToggleOnAndEmpty", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) updateChatSystemPrompt(t, ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "", IncludeDefaultSystemPrompt: ptr.Ref(true), }) resp := getChatSystemPrompt(t, ctx) require.Empty(t, resp.SystemPrompt) require.True(t, resp.IncludeDefaultSystemPrompt) require.Equal(t, chatd.DefaultSystemPrompt, resp.DefaultSystemPrompt) assertInjectedSystemMessages(t, ctx, chatd.DefaultSystemPrompt) }) t.Run("BothWhenToggleOnAndNonEmpty", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) updateChatSystemPrompt(t, ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "Custom instructions", IncludeDefaultSystemPrompt: ptr.Ref(true), }) resp := getChatSystemPrompt(t, ctx) require.Equal(t, "Custom instructions", resp.SystemPrompt) require.True(t, resp.IncludeDefaultSystemPrompt) require.Equal(t, chatd.DefaultSystemPrompt, resp.DefaultSystemPrompt) assertInjectedSystemMessages(t, ctx, chatd.DefaultSystemPrompt+"\n\nCustom instructions") }) t.Run("CustomOnlyWhenToggleOff", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) updateChatSystemPrompt(t, ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "Custom only", IncludeDefaultSystemPrompt: ptr.Ref(false), }) resp := getChatSystemPrompt(t, ctx) require.Equal(t, "Custom only", resp.SystemPrompt) require.False(t, resp.IncludeDefaultSystemPrompt) require.Equal(t, chatd.DefaultSystemPrompt, resp.DefaultSystemPrompt) assertInjectedSystemMessages(t, ctx, "Custom only") }) t.Run("EmptyWhenToggleOffAndEmpty", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) updateChatSystemPrompt(t, ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "", IncludeDefaultSystemPrompt: ptr.Ref(false), }) resp := getChatSystemPrompt(t, ctx) require.Empty(t, resp.SystemPrompt) require.False(t, resp.IncludeDefaultSystemPrompt) require.Equal(t, chatd.DefaultSystemPrompt, resp.DefaultSystemPrompt) assertInjectedSystemMessages(t, ctx, "") }) }) t.Run("CreateChatFallsBackToDefaultWhenSystemPromptConfigReadFailsWithIncludeDefaultEnabled", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) rawDB, pubsub := dbtestutil.NewDB(t) store := &failNextChatSystemPromptStore{Store: rawDB} client := codersdk.NewExperimentalClient(coderdtest.New(t, &coderdtest.Options{ Database: store, Pubsub: pubsub, DeploymentValues: chatDeploymentValues(t), })) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) err := client.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "Keep custom instructions", IncludeDefaultSystemPrompt: ptr.Ref(true), }) require.NoError(t, err) store.failNextGetChatSystemPromptConfig.Store(true) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: fmt.Sprintf("config-read fallback %s", t.Name()), }}, }) require.NoError(t, err) messages, err := rawDB.GetChatMessagesForPromptByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) var systemTexts []string for _, message := range messages { if message.Role != database.ChatMessageRoleSystem { continue } parts, err := chatprompt.ParseContent(message) require.NoError(t, err) require.Len(t, parts, 1) require.Equal(t, codersdk.ChatMessagePartTypeText, parts[0].Type) systemTexts = append(systemTexts, parts[0].Text) } require.Equal(t, []string{chatd.DefaultSystemPrompt, workspaceAwareness}, systemTexts) }) t.Run("CreateChatFallbackIgnoresDisabledPreferenceWhenConfigReadFails", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) rawDB, pubsub := dbtestutil.NewDB(t) store := &failNextChatSystemPromptStore{Store: rawDB} client := codersdk.NewExperimentalClient(coderdtest.New(t, &coderdtest.Options{ Database: store, Pubsub: pubsub, DeploymentValues: chatDeploymentValues(t), })) firstUser := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) err := client.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "Do not use the default prompt", IncludeDefaultSystemPrompt: ptr.Ref(false), }) require.NoError(t, err) // A config read failure loses all admin preferences, including // include_default=false, so chat creation falls back to the built-in default. store.failNextGetChatSystemPromptConfig.Store(true) chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: firstUser.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: fmt.Sprintf("config-read fallback %s", t.Name()), }}, }) require.NoError(t, err) messages, err := rawDB.GetChatMessagesForPromptByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID) require.NoError(t, err) var systemTexts []string for _, message := range messages { if message.Role != database.ChatMessageRoleSystem { continue } parts, err := chatprompt.ParseContent(message) require.NoError(t, err) require.Len(t, parts, 1) require.Equal(t, codersdk.ChatMessagePartTypeText, parts[0].Type) systemTexts = append(systemTexts, parts[0].Text) } require.Equal(t, []string{chatd.DefaultSystemPrompt, workspaceAwareness}, systemTexts) }) t.Run("NonAdminFails", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: "This should fail.", IncludeDefaultSystemPrompt: ptr.Ref(true), }) requireSDKError(t, err, http.StatusForbidden) _, err = memberClient.GetChatSystemPrompt(ctx) requireSDKError(t, err, http.StatusNotFound) }) t.Run("UnauthenticatedFails", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) anonClient := codersdk.NewExperimentalClient(codersdk.New(adminClient.URL)) _, err := anonClient.GetChatSystemPrompt(ctx) var sdkErr *codersdk.Error require.ErrorAs(t, err, &sdkErr) require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode()) }) t.Run("TooLong", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) tooLong := strings.Repeat("a", 131073) err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{ SystemPrompt: tooLong, IncludeDefaultSystemPrompt: ptr.Ref(true), }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "System prompt exceeds maximum length.", sdkErr.Message) }) } //nolint:tparallel,paralleltest // Subtests share a single coderdtest instance. func TestChatPlanModeInstructions(t *testing.T) { t.Parallel() adminClient, _ := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) _ = createChatModelConfig(t, adminClient) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) updateChatPlanModeInstructions := func(t *testing.T, ctx context.Context, req codersdk.UpdateChatPlanModeInstructionsRequest) { t.Helper() err := adminClient.UpdateChatPlanModeInstructions(ctx, req) require.NoError(t, err) } getChatPlanModeInstructions := func(t *testing.T, ctx context.Context) codersdk.ChatPlanModeInstructionsResponse { t.Helper() resp, err := adminClient.GetChatPlanModeInstructions(ctx) require.NoError(t, err) return resp } roundTripTests := []struct { name string updates []string want string }{ { name: "DefaultGETReturnsEmpty", want: "", }, { name: "PUTThenGETRoundTrips", updates: []string{"Use plan mode for multi-step changes."}, want: "Use plan mode for multi-step changes.", }, } for _, tt := range roundTripTests { t.Run(tt.name, func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) for _, instructions := range tt.updates { updateChatPlanModeInstructions(t, ctx, codersdk.UpdateChatPlanModeInstructionsRequest{ PlanModeInstructions: instructions, }) } resp := getChatPlanModeInstructions(t, ctx) require.Equal(t, tt.want, resp.PlanModeInstructions) }) } t.Run("OversizedPayloadReturns400", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) tooLong := strings.Repeat("a", 131073) err := adminClient.UpdateChatPlanModeInstructions(ctx, codersdk.UpdateChatPlanModeInstructionsRequest{ PlanModeInstructions: tooLong, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Plan mode instructions exceed maximum length.", sdkErr.Message) }) t.Run("NonAdminGETReturns404", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) _, err := memberClient.GetChatPlanModeInstructions(ctx) requireSDKError(t, err, http.StatusNotFound) }) } //nolint:tparallel,paralleltest // Subtests share a single coderdtest instance. func TestChatExploreModelOverride(t *testing.T) { t.Parallel() adminClient, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) defaultModel := createChatModelConfig(t, adminClient) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) createAdditionalModel := func(t *testing.T, model string, enabled bool) codersdk.ChatModelConfig { t.Helper() ctx := testutil.Context(t, testutil.WaitLong) contextLimit := int64(4096) isDefault := false modelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: defaultModel.Provider, Model: model, ContextLimit: &contextLimit, IsDefault: &isDefault, }) require.NoError(t, err) if enabled { return modelConfig } updated, err := adminClient.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ Enabled: ptr.Ref(false), }) require.NoError(t, err) return updated } t.Run("DefaultGETReturnsEmpty", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) resp, err := adminClient.GetChatExploreModelOverride(ctx) require.NoError(t, err) require.Nil(t, resp.ModelConfigID) require.False(t, resp.HasMalformedOverride) }) t.Run("AdminCanSetAndClear", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) overrideModel := createAdditionalModel(t, "gpt-4.1-mini", true) err := adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{ ModelConfigID: &overrideModel.ID, }) require.NoError(t, err) resp, err := adminClient.GetChatExploreModelOverride(ctx) require.NoError(t, err) require.NotNil(t, resp.ModelConfigID) require.Equal(t, overrideModel.ID, *resp.ModelConfigID) require.False(t, resp.HasMalformedOverride) err = adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{}) require.NoError(t, err) resp, err = adminClient.GetChatExploreModelOverride(ctx) require.NoError(t, err) require.Nil(t, resp.ModelConfigID) require.False(t, resp.HasMalformedOverride) }) t.Run("MalformedStoredOverrideIsReportedAndCanBeCleared", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) require.NoError(t, db.UpsertChatExploreModelOverride( dbauthz.AsSystemRestricted(ctx), "not-a-uuid", )) resp, err := adminClient.GetChatExploreModelOverride(ctx) require.NoError(t, err) require.Nil(t, resp.ModelConfigID) require.True(t, resp.HasMalformedOverride) err = adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{}) require.NoError(t, err) resp, err = adminClient.GetChatExploreModelOverride(ctx) require.NoError(t, err) require.Nil(t, resp.ModelConfigID) require.False(t, resp.HasMalformedOverride) }) t.Run("DisabledModelReturns400", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) disabledModel := createAdditionalModel(t, "gpt-4.1-disabled", false) err := adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{ ModelConfigID: &disabledModel.ID, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid model_config_id.", sdkErr.Message) }) t.Run("UnknownModelReturns400", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) unknownModelID := uuid.New() err := adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{ ModelConfigID: &unknownModelID, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Invalid model_config_id.", sdkErr.Message) }) t.Run("NonAdminGETReturns404", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) _, err := memberClient.GetChatExploreModelOverride(ctx) requireSDKError(t, err, http.StatusNotFound) }) t.Run("NonAdminPUTReturns403", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) err := memberClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{ ModelConfigID: &defaultModel.ID, }) requireSDKError(t, err, http.StatusForbidden) }) } func TestChatDesktopEnabled(t *testing.T) { t.Parallel() t.Run("ReturnsFalseWhenUnset", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) coderdtest.CreateFirstUser(t, adminClient.Client) resp, err := adminClient.GetChatDesktopEnabled(ctx) require.NoError(t, err) require.False(t, resp.EnableDesktop) }) t.Run("AdminCanSetTrue", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) coderdtest.CreateFirstUser(t, adminClient.Client) err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{ EnableDesktop: true, }) require.NoError(t, err) resp, err := adminClient.GetChatDesktopEnabled(ctx) require.NoError(t, err) require.True(t, resp.EnableDesktop) }) t.Run("AdminCanSetFalse", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) coderdtest.CreateFirstUser(t, adminClient.Client) // Set true first, then set false. err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{ EnableDesktop: true, }) require.NoError(t, err) err = adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{ EnableDesktop: false, }) require.NoError(t, err) resp, err := adminClient.GetChatDesktopEnabled(ctx) require.NoError(t, err) require.False(t, resp.EnableDesktop) }) t.Run("NonAdminCanRead", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{ EnableDesktop: true, }) require.NoError(t, err) resp, err := memberClient.GetChatDesktopEnabled(ctx) require.NoError(t, err) require.True(t, resp.EnableDesktop) }) t.Run("NonAdminWriteFails", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) err := memberClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{ EnableDesktop: true, }) requireSDKError(t, err, http.StatusForbidden) }) t.Run("UnauthenticatedFails", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) coderdtest.CreateFirstUser(t, adminClient.Client) anonClient := codersdk.NewExperimentalClient(codersdk.New(adminClient.URL)) _, err := anonClient.GetChatDesktopEnabled(ctx) var sdkErr *codersdk.Error require.ErrorAs(t, err, &sdkErr) require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode()) }) } func TestChatWorkspaceTTL(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) anonClient := codersdk.NewExperimentalClient(codersdk.New(adminClient.URL)) // Default value is 0 (disabled) when nothing has been configured. resp, err := adminClient.GetChatWorkspaceTTL(ctx) require.NoError(t, err, "get default") require.Equal(t, int64(0), resp.WorkspaceTTLMillis, "default should be 0") // Admin can set a positive TTL (2h = 7_200_000 ms). err = adminClient.UpdateChatWorkspaceTTL(ctx, codersdk.UpdateChatWorkspaceTTLRequest{ WorkspaceTTLMillis: 7_200_000, }) require.NoError(t, err, "admin set 2h") resp, err = adminClient.GetChatWorkspaceTTL(ctx) require.NoError(t, err, "get after set") require.Equal(t, int64(7_200_000), resp.WorkspaceTTLMillis, "should return 7200000 ms (2h)") // Non-admin can read the value. resp, err = memberClient.GetChatWorkspaceTTL(ctx) require.NoError(t, err, "member get") require.Equal(t, int64(7_200_000), resp.WorkspaceTTLMillis, "member should see same value") // Admin can set back to zero (disabled / template default). err = adminClient.UpdateChatWorkspaceTTL(ctx, codersdk.UpdateChatWorkspaceTTLRequest{ WorkspaceTTLMillis: 0, }) require.NoError(t, err, "admin set 0") resp, err = adminClient.GetChatWorkspaceTTL(ctx) require.NoError(t, err, "get after zero") require.Equal(t, int64(0), resp.WorkspaceTTLMillis, "should be 0 after reset") // Non-admin write is forbidden. err = memberClient.UpdateChatWorkspaceTTL(ctx, codersdk.UpdateChatWorkspaceTTLRequest{ WorkspaceTTLMillis: 3_600_000, }) requireSDKError(t, err, http.StatusForbidden) // Unauthenticated read is rejected. _, err = anonClient.GetChatWorkspaceTTL(ctx) var sdkErr *codersdk.Error require.ErrorAs(t, err, &sdkErr, "anon get") require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode(), "anon should get 401") // Validation: negative duration. err = adminClient.UpdateChatWorkspaceTTL(ctx, codersdk.UpdateChatWorkspaceTTLRequest{ WorkspaceTTLMillis: -3_600_000, }) requireSDKError(t, err, http.StatusBadRequest) // Validation: less than 1 minute (30s = 30_000 ms). err = adminClient.UpdateChatWorkspaceTTL(ctx, codersdk.UpdateChatWorkspaceTTLRequest{ WorkspaceTTLMillis: 30_000, }) requireSDKError(t, err, http.StatusBadRequest) // Boundary: just under 1 minute should be rejected (59_999 ms). err = adminClient.UpdateChatWorkspaceTTL(ctx, codersdk.UpdateChatWorkspaceTTLRequest{ WorkspaceTTLMillis: 59_999, }) requireSDKError(t, err, http.StatusBadRequest) // Boundary: exactly 1 minute should succeed (60_000 ms). err = adminClient.UpdateChatWorkspaceTTL(ctx, codersdk.UpdateChatWorkspaceTTLRequest{ WorkspaceTTLMillis: 60_000, }) require.NoError(t, err, "exactly 1 minute should be accepted") // Boundary: exactly 30 days should succeed (720h = 2_592_000_000 ms). err = adminClient.UpdateChatWorkspaceTTL(ctx, codersdk.UpdateChatWorkspaceTTLRequest{ WorkspaceTTLMillis: 2_592_000_000, }) require.NoError(t, err, "720h (exactly 30 days) should be accepted") // Validation: exceeds 30-day maximum (721h = 2_595_600_000 ms). err = adminClient.UpdateChatWorkspaceTTL(ctx, codersdk.UpdateChatWorkspaceTTLRequest{ WorkspaceTTLMillis: 2_595_600_000, }) requireSDKError(t, err, http.StatusBadRequest) } func TestChatRetentionDays(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) adminClient := newChatClient(t) firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) // Default value is 30 (days) when nothing has been configured. resp, err := adminClient.GetChatRetentionDays(ctx) require.NoError(t, err, "get default") require.Equal(t, int32(30), resp.RetentionDays, "default should be 30") // Admin can set retention days to 90. err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{ RetentionDays: 90, }) require.NoError(t, err, "admin set 90") resp, err = adminClient.GetChatRetentionDays(ctx) require.NoError(t, err, "get after set") require.Equal(t, int32(90), resp.RetentionDays, "should return 90") // Non-admin member can read the value. resp, err = memberClient.GetChatRetentionDays(ctx) require.NoError(t, err, "member get") require.Equal(t, int32(90), resp.RetentionDays, "member should see same value") // Non-admin member cannot write. err = memberClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{RetentionDays: 7}) requireSDKError(t, err, http.StatusForbidden) // Admin can disable purge by setting 0. err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{ RetentionDays: 0, }) require.NoError(t, err, "admin set 0") resp, err = adminClient.GetChatRetentionDays(ctx) require.NoError(t, err, "get after zero") require.Equal(t, int32(0), resp.RetentionDays, "should be 0 after disable") // Validation: negative value is rejected. err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{ RetentionDays: -1, }) requireSDKError(t, err, http.StatusBadRequest) // Validation: exceeding the 3650-day maximum is rejected. err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{ RetentionDays: 3651, // retentionDaysMaximum + 1; keep in sync with coderd/exp_chats.go. }) requireSDKError(t, err, http.StatusBadRequest) } //nolint:tparallel,paralleltest // Subtests share a single coderdtest instance. func TestUserChatCompactionThresholds(t *testing.T) { t.Parallel() client, _ := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) t.Run("EmptyByDefault", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) thresholds, err := client.GetUserChatCompactionThresholds(ctx) require.NoError(t, err) require.Empty(t, thresholds.Thresholds) }) t.Run("PutAndGet", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) override, err := client.UpdateUserChatCompactionThreshold(ctx, modelConfig.ID, codersdk.UpdateUserChatCompactionThresholdRequest{ ThresholdPercent: 75, }) require.NoError(t, err) require.Equal(t, modelConfig.ID, override.ModelConfigID) require.EqualValues(t, 75, override.ThresholdPercent) thresholds, err := client.GetUserChatCompactionThresholds(ctx) require.NoError(t, err) require.Len(t, thresholds.Thresholds, 1) require.Equal(t, modelConfig.ID, thresholds.Thresholds[0].ModelConfigID) require.EqualValues(t, 75, thresholds.Thresholds[0].ThresholdPercent) }) t.Run("UpsertChangesValue", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) _, err := client.UpdateUserChatCompactionThreshold(ctx, modelConfig.ID, codersdk.UpdateUserChatCompactionThresholdRequest{ ThresholdPercent: 50, }) require.NoError(t, err) override, err := client.UpdateUserChatCompactionThreshold(ctx, modelConfig.ID, codersdk.UpdateUserChatCompactionThresholdRequest{ ThresholdPercent: 75, }) require.NoError(t, err) require.EqualValues(t, 75, override.ThresholdPercent) thresholds, err := client.GetUserChatCompactionThresholds(ctx) require.NoError(t, err) require.Len(t, thresholds.Thresholds, 1) require.EqualValues(t, 75, thresholds.Thresholds[0].ThresholdPercent) }) t.Run("BoundaryValues", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) override, err := client.UpdateUserChatCompactionThreshold(ctx, modelConfig.ID, codersdk.UpdateUserChatCompactionThresholdRequest{ ThresholdPercent: 0, }) require.NoError(t, err) require.EqualValues(t, 0, override.ThresholdPercent) thresholds, err := client.GetUserChatCompactionThresholds(ctx) require.NoError(t, err) require.Len(t, thresholds.Thresholds, 1) require.EqualValues(t, 0, thresholds.Thresholds[0].ThresholdPercent) override, err = client.UpdateUserChatCompactionThreshold(ctx, modelConfig.ID, codersdk.UpdateUserChatCompactionThresholdRequest{ ThresholdPercent: 100, }) require.NoError(t, err) require.EqualValues(t, 100, override.ThresholdPercent) thresholds, err = client.GetUserChatCompactionThresholds(ctx) require.NoError(t, err) require.Len(t, thresholds.Thresholds, 1) require.EqualValues(t, 100, thresholds.Thresholds[0].ThresholdPercent) }) t.Run("ValidationRejectsInvalid", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) _, err := client.UpdateUserChatCompactionThreshold(ctx, modelConfig.ID, codersdk.UpdateUserChatCompactionThresholdRequest{ ThresholdPercent: -1, }) requireSDKError(t, err, http.StatusBadRequest) _, err = client.UpdateUserChatCompactionThreshold(ctx, modelConfig.ID, codersdk.UpdateUserChatCompactionThresholdRequest{ ThresholdPercent: 101, }) requireSDKError(t, err, http.StatusBadRequest) }) t.Run("Delete", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) err := client.DeleteUserChatCompactionThreshold(ctx, modelConfig.ID) require.NoError(t, err) thresholds, err := client.GetUserChatCompactionThresholds(ctx) require.NoError(t, err) require.Empty(t, thresholds.Thresholds) }) t.Run("DeleteIdempotent", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) err := client.DeleteUserChatCompactionThreshold(ctx, modelConfig.ID) require.NoError(t, err) }) t.Run("NonExistentModelConfig", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) fakeID := uuid.New() _, err := client.UpdateUserChatCompactionThreshold(ctx, fakeID, codersdk.UpdateUserChatCompactionThresholdRequest{ ThresholdPercent: 50, }) requireSDKError(t, err, http.StatusNotFound) }) t.Run("IsolatedPerUser", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) override, err := client.UpdateUserChatCompactionThreshold(ctx, modelConfig.ID, codersdk.UpdateUserChatCompactionThresholdRequest{ ThresholdPercent: 75, }) require.NoError(t, err) require.Equal(t, modelConfig.ID, override.ModelConfigID) require.EqualValues(t, 75, override.ThresholdPercent) adminThresholds, err := client.GetUserChatCompactionThresholds(ctx) require.NoError(t, err) require.Len(t, adminThresholds.Thresholds, 1) require.Equal(t, modelConfig.ID, adminThresholds.Thresholds[0].ModelConfigID) require.EqualValues(t, 75, adminThresholds.Thresholds[0].ThresholdPercent) memberThresholds, err := memberClient.GetUserChatCompactionThresholds(ctx) require.NoError(t, err) require.Empty(t, memberThresholds.Thresholds) }) } //nolint:tparallel // Subtests share a single coderdtest instance and run sequentially. func TestChatTemplateAllowlist(t *testing.T) { t.Parallel() // Shared setup: one coderdtest instance with two real templates. // Subtests that need valid template IDs use these. client, store := newChatClientWithDatabase(t) admin := coderdtest.CreateFirstUser(t, client.Client) tmpl1 := dbgen.Template(t, store, database.Template{ OrganizationID: admin.OrganizationID, CreatedBy: admin.UserID, }) tmpl2 := dbgen.Template(t, store, database.Template{ OrganizationID: admin.OrganizationID, CreatedBy: admin.UserID, }) deprecatedTmpl := dbgen.Template(t, store, database.Template{ OrganizationID: admin.OrganizationID, CreatedBy: admin.UserID, }) //nolint:gocritic // Owner context needed to deprecate the template in test setup. ownerRoles, err := rbac.RoleIdentifiers{rbac.RoleOwner()}.Expand() require.NoError(t, err) err = store.UpdateTemplateAccessControlByID(dbauthz.As(context.Background(), rbac.Subject{ ID: "owner", Roles: rbac.Roles(ownerRoles), Scope: rbac.ExpandableScope(rbac.ScopeAll), }), database.UpdateTemplateAccessControlByIDParams{ ID: deprecatedTmpl.ID, Deprecated: "this template is deprecated", }) require.NoError(t, err, "deprecate template") //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. t.Run("ReturnsEmptyWhenUnset", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) resp, err := client.GetChatTemplateAllowlist(ctx) require.NoError(t, err) require.Empty(t, resp.TemplateIDs) }) //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. t.Run("AdminCanSet", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) ids := []string{tmpl1.ID.String(), tmpl2.ID.String()} err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: ids}) require.NoError(t, err) resp, err := client.GetChatTemplateAllowlist(ctx) require.NoError(t, err) require.ElementsMatch(t, ids, resp.TemplateIDs) }) //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. t.Run("AdminCanClear", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{}}) require.NoError(t, err) resp, err := client.GetChatTemplateAllowlist(ctx) require.NoError(t, err) require.Empty(t, resp.TemplateIDs) }) //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. t.Run("NonAdminReadFails", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, admin.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) _, err := memberClient.GetChatTemplateAllowlist(ctx) requireSDKError(t, err, http.StatusNotFound) }) //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. t.Run("NonAdminWriteFails", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, admin.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) // Uses a random UUID — hits 404 before template validation. err := memberClient.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{uuid.NewString()}}) requireSDKError(t, err, http.StatusNotFound) }) //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. t.Run("UnauthenticatedFails", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) anonClient := codersdk.NewExperimentalClient(codersdk.New(client.URL)) // Uses a random UUID — hits 401 before template validation. err := anonClient.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{uuid.NewString()}}) requireSDKError(t, err, http.StatusUnauthorized) }) //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. t.Run("InvalidUUIDRejected", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{"not-a-uuid"}}) requireSDKError(t, err, http.StatusBadRequest) }) //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. t.Run("NonexistentTemplateRejected", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{TemplateIDs: []string{uuid.NewString()}}) requireSDKError(t, err, http.StatusBadRequest) }) //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. t.Run("DeprecatedTemplateRejected", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{ TemplateIDs: []string{deprecatedTmpl.ID.String()}, }) requireSDKError(t, err, http.StatusBadRequest) }) //nolint:paralleltest // Sequential: subtests share a single coderdtest instance. t.Run("DeduplicatesIDs", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) id := tmpl1.ID.String() err := client.UpdateChatTemplateAllowlist(ctx, codersdk.ChatTemplateAllowlist{ TemplateIDs: []string{id, id, id}, }) require.NoError(t, err) resp, err := client.GetChatTemplateAllowlist(ctx) require.NoError(t, err) require.Len(t, resp.TemplateIDs, 1) require.Equal(t, id, resp.TemplateIDs[0]) }) } func TestGetChatsByWorkspace(t *testing.T) { t.Parallel() client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) // Helper to create a workspace owned by the test user. newWorkspace := func() dbfake.WorkspaceBuildBuilder { return dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ OrganizationID: user.OrganizationID, OwnerID: user.UserID, }).WithAgent() } // Helper to insert a chat linked to a workspace. insertChat := func(ctx context.Context, title string, workspaceID uuid.UUID) database.Chat { chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: title, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}, }) require.NoError(t, err) return chat } t.Run("EmptyRequestReturnsEmptyMap", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) result, err := client.GetChatsByWorkspace(ctx, []uuid.UUID{}) require.NoError(t, err) require.Empty(t, result) }) t.Run("WorkspaceWithNoChatsOmitted", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) ws := newWorkspace().Do() result, err := client.GetChatsByWorkspace(ctx, []uuid.UUID{ws.Workspace.ID}) require.NoError(t, err) require.Empty(t, result) }) t.Run("ReturnsChatLinkedToWorkspace", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) ws := newWorkspace().Do() chat := insertChat(ctx, "workspace chat", ws.Workspace.ID) result, err := client.GetChatsByWorkspace(ctx, []uuid.UUID{ws.Workspace.ID}) require.NoError(t, err) require.Len(t, result, 1) require.Equal(t, chat.ID, result[ws.Workspace.ID]) }) t.Run("ArchivedChatsExcluded", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) ws := newWorkspace().Do() chat := insertChat(ctx, "soon to be archived", ws.Workspace.ID) err := client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)}) require.NoError(t, err) result, err := client.GetChatsByWorkspace(ctx, []uuid.UUID{ws.Workspace.ID}) require.NoError(t, err) require.Empty(t, result) }) t.Run("ReturnsLatestNonArchivedChat", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) ws := newWorkspace().Do() // Insert an older chat and archive it. olderChat := insertChat(ctx, "older archived", ws.Workspace.ID) err := client.UpdateChat(ctx, olderChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)}) require.NoError(t, err) // Insert two active chats — the second is newer due to insert // ordering and should win the "latest" selection in Go after // the SQL returns both ordered by updated_at DESC. _ = insertChat(ctx, "older active", ws.Workspace.ID) newerChat := insertChat(ctx, "newer active", ws.Workspace.ID) result, err := client.GetChatsByWorkspace(ctx, []uuid.UUID{ws.Workspace.ID}) require.NoError(t, err) require.Len(t, result, 1) require.Equal(t, newerChat.ID, result[ws.Workspace.ID]) }) t.Run("MultipleWorkspaces", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) wsA := newWorkspace().Do() wsB := newWorkspace().Do() wsC := newWorkspace().Do() chatA := insertChat(ctx, "chat for workspace A", wsA.Workspace.ID) chatB := insertChat(ctx, "chat for workspace B", wsB.Workspace.ID) // Query all three workspaces; C has no chats. result, err := client.GetChatsByWorkspace(ctx, []uuid.UUID{ wsA.Workspace.ID, wsB.Workspace.ID, wsC.Workspace.ID, }) require.NoError(t, err) require.Len(t, result, 2) require.Equal(t, chatA.ID, result[wsA.Workspace.ID]) require.Equal(t, chatB.ID, result[wsB.Workspace.ID]) _, hasC := result[wsC.Workspace.ID] require.False(t, hasC, "workspace C should not appear in result") }) t.Run("RejectsTooManyWorkspaceIDs", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) ids := make([]uuid.UUID, 26) for i := range ids { ids[i] = uuid.New() } _, err := client.GetChatsByWorkspace(ctx, ids) require.Error(t, err) requireSDKError(t, err, http.StatusBadRequest) }) } func TestSubmitToolResults(t *testing.T) { t.Parallel() // setupRequiresAction creates a chat via the DB with dynamic tools, // inserts an assistant message containing tool-call parts for each // given toolCallID, and sets the chat status to requires_action. // It returns the chat row so callers can exercise the endpoint. setupRequiresAction := func( ctx context.Context, t *testing.T, db database.Store, ownerID uuid.UUID, organizationID uuid.UUID, modelConfigID uuid.UUID, dynamicToolName string, toolCallIDs []string, ) database.Chat { t.Helper() // Marshal dynamic tools into the chat row. dynamicTools := []mcp.Tool{{ Name: dynamicToolName, Description: "a test dynamic tool", InputSchema: mcp.ToolInputSchema{Type: "object"}, }} dtJSON, err := json.Marshal(dynamicTools) require.NoError(t, err) chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: organizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: ownerID, LastModelConfigID: modelConfigID, Title: "tool-results-test", DynamicTools: pqtype.NullRawMessage{RawMessage: dtJSON, Valid: true}, }) require.NoError(t, err) // Build assistant message with tool-call parts. parts := make([]codersdk.ChatMessagePart, 0, len(toolCallIDs)) for _, id := range toolCallIDs { parts = append(parts, codersdk.ChatMessagePart{ Type: codersdk.ChatMessagePartTypeToolCall, ToolCallID: id, ToolName: dynamicToolName, Args: json.RawMessage(`{"key":"value"}`), }) } content, err := chatprompt.MarshalParts(parts) require.NoError(t, err) _, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{ ChatID: chat.ID, CreatedBy: []uuid.UUID{uuid.Nil}, ModelConfigID: []uuid.UUID{modelConfigID}, Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, ContentVersion: []int16{chatprompt.CurrentContentVersion}, Content: []string{string(content.RawMessage)}, Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, InputTokens: []int64{0}, OutputTokens: []int64{0}, TotalTokens: []int64{0}, ReasoningTokens: []int64{0}, CacheCreationTokens: []int64{0}, CacheReadTokens: []int64{0}, ContextLimit: []int64{0}, Compressed: []bool{false}, TotalCostMicros: []int64{0}, RuntimeMs: []int64{0}, }) require.NoError(t, err) // Transition to requires_action. chat, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ ID: chat.ID, Status: database.ChatStatusRequiresAction, }) require.NoError(t, err) require.Equal(t, database.ChatStatusRequiresAction, chat.Status) return chat } t.Run("Success", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) const toolName = "my_dynamic_tool" toolCallIDs := []string{"call_abc", "call_def"} chat := setupRequiresAction(ctx, t, db, user.UserID, user.OrganizationID, modelConfig.ID, toolName, toolCallIDs) err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ Results: []codersdk.ToolResult{ {ToolCallID: "call_abc", Output: json.RawMessage(`"result_a"`)}, {ToolCallID: "call_def", Output: json.RawMessage(`"result_b"`)}, }, }) require.NoError(t, err) // Verify status is no longer requires_action. The chatd // loop may have already picked the chat up and // transitioned it further (pending → running → …), so we // accept any non-requires_action status. gotChat, err := client.GetChat(ctx, chat.ID) require.NoError(t, err) require.NotEqual(t, codersdk.ChatStatusRequiresAction, gotChat.Status, "chat should no longer be in requires_action after submitting tool results") // Verify tool-result messages were persisted. msgsResp, err := client.GetChatMessages(ctx, chat.ID, nil) require.NoError(t, err) var toolResultCount int for _, msg := range msgsResp.Messages { if msg.Role == codersdk.ChatMessageRoleTool { toolResultCount++ } } require.Equal(t, len(toolCallIDs), toolResultCount, "expected one tool-result message per submitted result") }) t.Run("WrongStatus", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) // Create a chat that is NOT in requires_action status. chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ OrganizationID: user.OrganizationID, Status: database.ChatStatusWaiting, ClientType: database.ChatClientTypeUi, OwnerID: user.UserID, LastModelConfigID: modelConfig.ID, Title: "wrong-status-test", }) require.NoError(t, err) err = client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ Results: []codersdk.ToolResult{ {ToolCallID: "call_xyz", Output: json.RawMessage(`"nope"`)}, }, }) requireSDKError(t, err, http.StatusConflict) }) t.Run("MissingResult", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) const toolName = "my_dynamic_tool" toolCallIDs := []string{"call_one", "call_two"} chat := setupRequiresAction(ctx, t, db, user.UserID, user.OrganizationID, modelConfig.ID, toolName, toolCallIDs) // Submit only one of the two required results. err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ Results: []codersdk.ToolResult{ {ToolCallID: "call_one", Output: json.RawMessage(`"partial"`)}, }, }) requireSDKError(t, err, http.StatusBadRequest) }) t.Run("UnexpectedResult", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) const toolName = "my_dynamic_tool" toolCallIDs := []string{"call_real"} chat := setupRequiresAction(ctx, t, db, user.UserID, user.OrganizationID, modelConfig.ID, toolName, toolCallIDs) // Submit a result with a wrong tool_call_id. err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ Results: []codersdk.ToolResult{ {ToolCallID: "call_bogus", Output: json.RawMessage(`"wrong"`)}, }, }) requireSDKError(t, err, http.StatusBadRequest) }) t.Run("InvalidJSONOutput", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) const toolName = "my_dynamic_tool" toolCallIDs := []string{"call_json"} chat := setupRequiresAction(ctx, t, db, user.UserID, user.OrganizationID, modelConfig.ID, toolName, toolCallIDs) // We must bypass the SDK client because json.RawMessage // rejects invalid JSON during json.Marshal. A raw HTTP // request lets the invalid payload reach the server so we // can verify server-side validation. rawBody := `{"results":[{"tool_call_id":"call_json","output":not-json,"is_error":false}]}` url := client.URL.JoinPath(fmt.Sprintf("/api/experimental/chats/%s/tool-results", chat.ID)).String() req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBufferString(rawBody)) require.NoError(t, err) req.Header.Set("Content-Type", "application/json") req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() require.Equal(t, http.StatusBadRequest, resp.StatusCode) }) t.Run("DuplicateToolCallID", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) const toolName = "my_dynamic_tool" toolCallIDs := []string{"call_dup1", "call_dup2"} chat := setupRequiresAction(ctx, t, db, user.UserID, user.OrganizationID, modelConfig.ID, toolName, toolCallIDs) err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ Results: []codersdk.ToolResult{ {ToolCallID: "call_dup1", Output: json.RawMessage(`"result_a"`)}, {ToolCallID: "call_dup1", Output: json.RawMessage(`"result_b"`)}, }, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Contains(t, sdkErr.Message, "Duplicate tool_call_id") }) t.Run("EmptyResults", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) const toolName = "my_dynamic_tool" toolCallIDs := []string{"call_empty"} chat := setupRequiresAction(ctx, t, db, user.UserID, user.OrganizationID, modelConfig.ID, toolName, toolCallIDs) err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ Results: []codersdk.ToolResult{}, }) requireSDKError(t, err, http.StatusBadRequest) }) t.Run("NotFoundForDifferentUser", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) const toolName = "my_dynamic_tool" toolCallIDs := []string{"call_other"} chat := setupRequiresAction(ctx, t, db, user.UserID, user.OrganizationID, modelConfig.ID, toolName, toolCallIDs) // Create a second user and try to submit tool results // to user A's chat. otherClientRaw, _ := coderdtest.CreateAnotherUser( t, client.Client, user.OrganizationID, rbac.RoleAgentsAccess(), ) otherClient := codersdk.NewExperimentalClient(otherClientRaw) err := otherClient.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ Results: []codersdk.ToolResult{ {ToolCallID: "call_other", Output: json.RawMessage(`"nope"`)}, }, }) requireSDKError(t, err, http.StatusNotFound) }) t.Run("MemberWithoutAgentsAccess", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) // Create a member without agents-access. Even though the // member owns the chat, submitting tool results should be // gated by agents-access because it triggers inference. memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) const toolName = "my_dynamic_tool" toolCallIDs := []string{"call_noaccess"} chat := setupRequiresAction(ctx, t, db, member.ID, firstUser.OrganizationID, modelConfig.ID, toolName, toolCallIDs) err := memberClient.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{ Results: []codersdk.ToolResult{ {ToolCallID: "call_noaccess", Output: json.RawMessage(`"should fail"`)}, }, }) requireSDKError(t, err, http.StatusForbidden) }) } func TestPostChats_DynamicToolValidation(t *testing.T) { t.Parallel() t.Run("TooManyTools", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) user := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) tools := make([]codersdk.DynamicTool, 251) for i := range tools { tools[i] = codersdk.DynamicTool{ Name: fmt.Sprintf("tool-%d", i), } } _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "hello", }}, UnsafeDynamicTools: tools, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Too many dynamic tools.", sdkErr.Message) }) t.Run("EmptyToolName", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) user := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "hello", }}, UnsafeDynamicTools: []codersdk.DynamicTool{ {Name: ""}, }, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Dynamic tool name must not be empty.", sdkErr.Message) }) t.Run("DuplicateToolName", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) user := coderdtest.CreateFirstUser(t, client.Client) _ = createChatModelConfig(t, client) _, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, Content: []codersdk.ChatInputPart{{ Type: codersdk.ChatInputPartTypeText, Text: "hello", }}, UnsafeDynamicTools: []codersdk.DynamicTool{ {Name: "dup-tool"}, {Name: "dup-tool"}, }, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Duplicate dynamic tool name.", sdkErr.Message) }) } func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error { t.Helper() var sdkErr *codersdk.Error require.ErrorAs(t, err, &sdkErr) require.Equal(t, expectedStatus, sdkErr.StatusCode()) return sdkErr }