diff --git a/agent/agentcontextconfig/api.go b/agent/agentcontextconfig/api.go index 75211df37c..d82c98902f 100644 --- a/agent/agentcontextconfig/api.go +++ b/agent/agentcontextconfig/api.go @@ -134,6 +134,33 @@ func Config(workingDir string) (workspacesdk.ContextConfigResponse, []string) { }, ResolvePaths(mcpConfigFile, workingDir) } +// ContextPartsFromDir reads instruction files and discovers skills +// from a specific directory, using default file names. This is used +// by the CLI chat context commands to read context from an arbitrary +// directory without consulting agent env vars. +func ContextPartsFromDir(dir string) []codersdk.ChatMessagePart { + var parts []codersdk.ChatMessagePart + + if entry, found := readInstructionFileFromDir(dir, DefaultInstructionsFile); found { + parts = append(parts, entry) + } + + // Reuse ResolvePaths so CLI skill discovery follows the same + // project-relative path handling as agent config resolution. + skillParts := discoverSkills( + ResolvePaths(strings.Join([]string{DefaultSkillsDir, "skills"}, ","), dir), + DefaultSkillMetaFile, + ) + parts = append(parts, skillParts...) + + // Guarantee non-nil slice. + if parts == nil { + parts = []codersdk.ChatMessagePart{} + } + + return parts +} + // MCPConfigFiles returns the resolved MCP configuration file // paths for the agent's MCP manager. func (api *API) MCPConfigFiles() []string { diff --git a/agent/agentcontextconfig/api_test.go b/agent/agentcontextconfig/api_test.go index 35cc75507e..be5075c6d9 100644 --- a/agent/agentcontextconfig/api_test.go +++ b/agent/agentcontextconfig/api_test.go @@ -23,18 +23,144 @@ func filterParts(parts []codersdk.ChatMessagePart, t codersdk.ChatMessagePartTyp return out } -func TestConfig(t *testing.T) { - t.Run("Defaults", func(t *testing.T) { - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) +func writeSkillMetaFileInRoot(t *testing.T, skillsRoot, name, description string) string { + t.Helper() - // Clear all env vars so defaults are used. - t.Setenv(agentcontextconfig.EnvInstructionsDirs, "") - t.Setenv(agentcontextconfig.EnvInstructionsFile, "") - t.Setenv(agentcontextconfig.EnvSkillsDirs, "") - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "") + skillDir := filepath.Join(skillsRoot, name) + require.NoError(t, os.MkdirAll(skillDir, 0o755)) + require.NoError(t, os.WriteFile( + filepath.Join(skillDir, "SKILL.md"), + []byte("---\nname: "+name+"\ndescription: "+description+"\n---\nSkill body"), + 0o600, + )) + + return skillDir +} + +func writeSkillMetaFile(t *testing.T, dir, name, description string) string { + t.Helper() + return writeSkillMetaFileInRoot(t, filepath.Join(dir, ".agents", "skills"), name, description) +} + +func TestContextPartsFromDir(t *testing.T) { + t.Parallel() + + t.Run("ReturnsInstructionFilePart", func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + instructionPath := filepath.Join(dir, "AGENTS.md") + require.NoError(t, os.WriteFile(instructionPath, []byte("project instructions"), 0o600)) + + parts := agentcontextconfig.ContextPartsFromDir(dir) + contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile) + skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill) + + require.Len(t, parts, 1) + require.Len(t, contextParts, 1) + require.Empty(t, skillParts) + require.Equal(t, instructionPath, contextParts[0].ContextFilePath) + require.Equal(t, "project instructions", contextParts[0].ContextFileContent) + require.False(t, contextParts[0].ContextFileTruncated) + }) + + t.Run("ReturnsSkillParts", func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + skillDir := writeSkillMetaFile(t, dir, "my-skill", "A test skill") + + parts := agentcontextconfig.ContextPartsFromDir(dir) + contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile) + skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill) + + require.Len(t, parts, 1) + require.Empty(t, contextParts) + require.Len(t, skillParts, 1) + require.Equal(t, "my-skill", skillParts[0].SkillName) + require.Equal(t, "A test skill", skillParts[0].SkillDescription) + require.Equal(t, skillDir, skillParts[0].SkillDir) + require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile) + }) + + t.Run("ReturnsSkillPartsFromSkillsDir", func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + skillDir := writeSkillMetaFileInRoot( + t, + filepath.Join(dir, "skills"), + "my-skill", + "A test skill", + ) + + parts := agentcontextconfig.ContextPartsFromDir(dir) + contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile) + skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill) + + require.Len(t, parts, 1) + require.Empty(t, contextParts) + require.Len(t, skillParts, 1) + require.Equal(t, "my-skill", skillParts[0].SkillName) + require.Equal(t, "A test skill", skillParts[0].SkillDescription) + require.Equal(t, skillDir, skillParts[0].SkillDir) + require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile) + }) + + t.Run("ReturnsEmptyForEmptyDir", func(t *testing.T) { + t.Parallel() + + parts := agentcontextconfig.ContextPartsFromDir(t.TempDir()) + + require.NotNil(t, parts) + require.Empty(t, parts) + }) + + t.Run("ReturnsCombinedResults", func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + instructionPath := filepath.Join(dir, "AGENTS.md") + require.NoError(t, os.WriteFile(instructionPath, []byte("combined instructions"), 0o600)) + skillDir := writeSkillMetaFile(t, dir, "combined-skill", "Combined test skill") + + parts := agentcontextconfig.ContextPartsFromDir(dir) + contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile) + skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill) + + require.Len(t, parts, 2) + require.Len(t, contextParts, 1) + require.Len(t, skillParts, 1) + require.Equal(t, instructionPath, contextParts[0].ContextFilePath) + require.Equal(t, "combined instructions", contextParts[0].ContextFileContent) + require.Equal(t, "combined-skill", skillParts[0].SkillName) + require.Equal(t, skillDir, skillParts[0].SkillDir) + }) +} + +func setupConfigTestEnv(t *testing.T, overrides map[string]string) string { + t.Helper() + + fakeHome := t.TempDir() + t.Setenv("HOME", fakeHome) + t.Setenv("USERPROFILE", fakeHome) + t.Setenv(agentcontextconfig.EnvInstructionsDirs, "") + t.Setenv(agentcontextconfig.EnvInstructionsFile, "") + t.Setenv(agentcontextconfig.EnvSkillsDirs, "") + t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") + t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "") + + for key, value := range overrides { + t.Setenv(key, value) + } + + return fakeHome +} + +func TestConfig(t *testing.T) { + //nolint:paralleltest // Uses t.Setenv to mutate process-wide environment. + t.Run("Defaults", func(t *testing.T) { + setupConfigTestEnv(t, nil) workDir := platformAbsPath("work") cfg, mcpFiles := agentcontextconfig.Config(workDir) @@ -46,20 +172,18 @@ func TestConfig(t *testing.T) { require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, mcpFiles) }) + //nolint:paralleltest // Uses t.Setenv to mutate process-wide environment. t.Run("CustomEnvVars", func(t *testing.T) { - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) - optInstructions := t.TempDir() optSkills := t.TempDir() optMCP := platformAbsPath("opt", "mcp.json") - - t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions) - t.Setenv(agentcontextconfig.EnvInstructionsFile, "CUSTOM.md") - t.Setenv(agentcontextconfig.EnvSkillsDirs, optSkills) - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP) + setupConfigTestEnv(t, map[string]string{ + agentcontextconfig.EnvInstructionsDirs: optInstructions, + agentcontextconfig.EnvInstructionsFile: "CUSTOM.md", + agentcontextconfig.EnvSkillsDirs: optSkills, + agentcontextconfig.EnvSkillMetaFile: "META.yaml", + agentcontextconfig.EnvMCPConfigFiles: optMCP, + }) // Create files matching the custom names so we can // verify the env vars actually change lookup behavior. @@ -85,15 +209,12 @@ func TestConfig(t *testing.T) { require.Equal(t, "META.yaml", skillParts[0].ContextFileSkillMetaFile) }) + //nolint:paralleltest // Uses t.Setenv to mutate process-wide environment. t.Run("WhitespaceInFileNames", func(t *testing.T) { - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) + fakeHome := setupConfigTestEnv(t, map[string]string{ + agentcontextconfig.EnvInstructionsFile: " CLAUDE.md ", + }) t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome) - t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ") - t.Setenv(agentcontextconfig.EnvSkillsDirs, "") - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "") workDir := t.TempDir() // Create a file matching the trimmed name. @@ -106,19 +227,13 @@ func TestConfig(t *testing.T) { require.Equal(t, "hello", ctxFiles[0].ContextFileContent) }) + //nolint:paralleltest // Uses t.Setenv to mutate process-wide environment. t.Run("CommaSeparatedDirs", func(t *testing.T) { - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) - a := t.TempDir() b := t.TempDir() - - t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b) - t.Setenv(agentcontextconfig.EnvInstructionsFile, "") - t.Setenv(agentcontextconfig.EnvSkillsDirs, "") - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "") + setupConfigTestEnv(t, map[string]string{ + agentcontextconfig.EnvInstructionsDirs: a + "," + b, + }) // Put instruction files in both dirs. require.NoError(t, os.WriteFile(filepath.Join(a, "AGENTS.md"), []byte("from a"), 0o600)) @@ -133,17 +248,10 @@ func TestConfig(t *testing.T) { require.Equal(t, "from b", ctxFiles[1].ContextFileContent) }) + //nolint:paralleltest // Uses t.Setenv to mutate process-wide environment. t.Run("ReadsInstructionFiles", func(t *testing.T) { - t.Setenv(agentcontextconfig.EnvInstructionsDirs, "") - t.Setenv(agentcontextconfig.EnvInstructionsFile, "") - t.Setenv(agentcontextconfig.EnvSkillsDirs, "") - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "") - workDir := t.TempDir() - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) + fakeHome := setupConfigTestEnv(t, nil) // Create ~/.coder/AGENTS.md coderDir := filepath.Join(fakeHome, ".coder") @@ -164,16 +272,9 @@ func TestConfig(t *testing.T) { require.False(t, ctxFiles[0].ContextFileTruncated) }) + //nolint:paralleltest // Uses t.Setenv to mutate process-wide environment. t.Run("ReadsWorkingDirInstructionFile", func(t *testing.T) { - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) - t.Setenv(agentcontextconfig.EnvInstructionsDirs, "") - t.Setenv(agentcontextconfig.EnvInstructionsFile, "") - t.Setenv(agentcontextconfig.EnvSkillsDirs, "") - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "") - + setupConfigTestEnv(t, nil) workDir := t.TempDir() // Create AGENTS.md in the working directory. @@ -193,16 +294,9 @@ func TestConfig(t *testing.T) { require.Equal(t, filepath.Join(workDir, "AGENTS.md"), ctxFiles[0].ContextFilePath) }) + //nolint:paralleltest // Uses t.Setenv to mutate process-wide environment. t.Run("TruncatesLargeInstructionFile", func(t *testing.T) { - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) - t.Setenv(agentcontextconfig.EnvInstructionsDirs, "") - t.Setenv(agentcontextconfig.EnvInstructionsFile, "") - t.Setenv(agentcontextconfig.EnvSkillsDirs, "") - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "") - + setupConfigTestEnv(t, nil) workDir := t.TempDir() largeContent := strings.Repeat("a", 64*1024+100) require.NoError(t, os.WriteFile(filepath.Join(workDir, "AGENTS.md"), []byte(largeContent), 0o600)) @@ -215,79 +309,47 @@ func TestConfig(t *testing.T) { require.Len(t, ctxFiles[0].ContextFileContent, 64*1024) }) - t.Run("SanitizesHTMLComments", func(t *testing.T) { - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) - t.Setenv(agentcontextconfig.EnvInstructionsDirs, "") - t.Setenv(agentcontextconfig.EnvInstructionsFile, "") - t.Setenv(agentcontextconfig.EnvSkillsDirs, "") - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "") + sanitizationTests := []struct { + name string + input string + expected string + }{ + { + name: "SanitizesHTMLComments", + input: "visible\ncontent", + expected: "visible\ncontent", + }, + { + name: "SanitizesInvisibleUnicode", + input: "before\u200bafter", + expected: "beforeafter", + }, + { + name: "NormalizesCRLF", + input: "line1\r\nline2\rline3", + expected: "line1\nline2\nline3", + }, + } + //nolint:paralleltest // Uses t.Setenv to mutate process-wide environment. + for _, tt := range sanitizationTests { + t.Run(tt.name, func(t *testing.T) { + setupConfigTestEnv(t, nil) + workDir := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(workDir, "AGENTS.md"), + []byte(tt.input), + 0o600, + )) - workDir := t.TempDir() - require.NoError(t, os.WriteFile( - filepath.Join(workDir, "AGENTS.md"), - []byte("visible\ncontent"), - 0o600, - )) + cfg, _ := agentcontextconfig.Config(workDir) - cfg, _ := agentcontextconfig.Config(workDir) - - ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile) - require.Len(t, ctxFiles, 1) - require.Equal(t, "visible\ncontent", ctxFiles[0].ContextFileContent) - }) - - t.Run("SanitizesInvisibleUnicode", func(t *testing.T) { - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) - t.Setenv(agentcontextconfig.EnvInstructionsDirs, "") - t.Setenv(agentcontextconfig.EnvInstructionsFile, "") - t.Setenv(agentcontextconfig.EnvSkillsDirs, "") - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "") - - workDir := t.TempDir() - // U+200B (zero-width space) should be stripped. - require.NoError(t, os.WriteFile( - filepath.Join(workDir, "AGENTS.md"), - []byte("before\u200bafter"), - 0o600, - )) - - cfg, _ := agentcontextconfig.Config(workDir) - - ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile) - require.Len(t, ctxFiles, 1) - require.Equal(t, "beforeafter", ctxFiles[0].ContextFileContent) - }) - - t.Run("NormalizesCRLF", func(t *testing.T) { - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) - t.Setenv(agentcontextconfig.EnvInstructionsDirs, "") - t.Setenv(agentcontextconfig.EnvInstructionsFile, "") - t.Setenv(agentcontextconfig.EnvSkillsDirs, "") - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "") - - workDir := t.TempDir() - require.NoError(t, os.WriteFile( - filepath.Join(workDir, "AGENTS.md"), - []byte("line1\r\nline2\rline3"), - 0o600, - )) - - cfg, _ := agentcontextconfig.Config(workDir) - - ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile) - require.Len(t, ctxFiles, 1) - require.Equal(t, "line1\nline2\nline3", ctxFiles[0].ContextFileContent) - }) + ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile) + require.Len(t, ctxFiles, 1) + require.Equal(t, tt.expected, ctxFiles[0].ContextFileContent) + }) + } + //nolint:paralleltest // Uses t.Setenv to mutate process-wide environment. t.Run("DiscoversSkills", func(t *testing.T) { fakeHome := t.TempDir() t.Setenv("HOME", fakeHome) @@ -320,17 +382,13 @@ func TestConfig(t *testing.T) { require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile) }) + //nolint:paralleltest // Uses t.Setenv to mutate process-wide environment. t.Run("SkipsMissingDirs", func(t *testing.T) { - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) - nonExistent := filepath.Join(t.TempDir(), "does-not-exist") - t.Setenv(agentcontextconfig.EnvInstructionsDirs, nonExistent) - t.Setenv(agentcontextconfig.EnvInstructionsFile, "") - t.Setenv(agentcontextconfig.EnvSkillsDirs, nonExistent) - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "") + setupConfigTestEnv(t, map[string]string{ + agentcontextconfig.EnvInstructionsDirs: nonExistent, + agentcontextconfig.EnvSkillsDirs: nonExistent, + }) workDir := t.TempDir() cfg, _ := agentcontextconfig.Config(workDir) @@ -340,17 +398,13 @@ func TestConfig(t *testing.T) { require.Empty(t, cfg.Parts) }) + //nolint:paralleltest // Uses t.Setenv to mutate process-wide environment. t.Run("MCPConfigFilesResolvedSeparately", func(t *testing.T) { - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) - t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome) - t.Setenv(agentcontextconfig.EnvInstructionsFile, "") - t.Setenv(agentcontextconfig.EnvSkillsDirs, "") - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") - optMCP := platformAbsPath("opt", "custom.json") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP) + fakeHome := setupConfigTestEnv(t, map[string]string{ + agentcontextconfig.EnvMCPConfigFiles: optMCP, + }) + t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome) workDir := t.TempDir() _, mcpFiles := agentcontextconfig.Config(workDir) @@ -358,14 +412,10 @@ func TestConfig(t *testing.T) { require.Equal(t, []string{optMCP}, mcpFiles) }) + //nolint:paralleltest // Uses t.Setenv to mutate process-wide environment. t.Run("SkillNameMustMatchDir", func(t *testing.T) { - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) + fakeHome := setupConfigTestEnv(t, nil) t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome) - t.Setenv(agentcontextconfig.EnvInstructionsFile, "") - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "") workDir := t.TempDir() skillsDir := filepath.Join(workDir, "skills") @@ -385,14 +435,10 @@ func TestConfig(t *testing.T) { require.Empty(t, skillParts) }) + //nolint:paralleltest // Uses t.Setenv to mutate process-wide environment. t.Run("DuplicateSkillsFirstWins", func(t *testing.T) { - fakeHome := t.TempDir() - t.Setenv("HOME", fakeHome) - t.Setenv("USERPROFILE", fakeHome) + fakeHome := setupConfigTestEnv(t, nil) t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome) - t.Setenv(agentcontextconfig.EnvInstructionsFile, "") - t.Setenv(agentcontextconfig.EnvSkillMetaFile, "") - t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "") workDir := t.TempDir() skillsDir1 := filepath.Join(workDir, "skills1") diff --git a/cli/exp_chat.go b/cli/exp_chat.go new file mode 100644 index 0000000000..61c017f172 --- /dev/null +++ b/cli/exp_chat.go @@ -0,0 +1,194 @@ +package cli + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/agent/agentcontextconfig" + "github.com/coder/coder/v2/codersdk/agentsdk" + "github.com/coder/serpent" +) + +func (r *RootCmd) chatCommand() *serpent.Command { + return &serpent.Command{ + Use: "chat", + Short: "Manage agent chats", + Long: "Commands for interacting with chats from within a workspace.", + Handler: func(i *serpent.Invocation) error { + return i.Command.HelpHandler(i) + }, + Children: []*serpent.Command{ + r.chatContextCommand(), + }, + } +} + +func (r *RootCmd) chatContextCommand() *serpent.Command { + return &serpent.Command{ + Use: "context", + Short: "Manage chat context", + Long: "Add or clear context files and skills for an active chat session.", + Handler: func(i *serpent.Invocation) error { + return i.Command.HelpHandler(i) + }, + Children: []*serpent.Command{ + r.chatContextAddCommand(), + r.chatContextClearCommand(), + }, + } +} + +func (*RootCmd) chatContextAddCommand() *serpent.Command { + var ( + dir string + chatID string + ) + agentAuth := &AgentAuth{} + cmd := &serpent.Command{ + Use: "add", + Short: "Add context to an active chat", + Long: "Read instruction files and discover skills from a directory, then add " + + "them as context to an active chat session. Multiple calls " + + "are additive.", + Handler: func(inv *serpent.Invocation) error { + ctx := inv.Context() + ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...) + defer stop() + + if dir == "" && inv.Environ.Get("CODER") != "true" { + return xerrors.New("this command must be run inside a Coder workspace (set --dir to override)") + } + + client, err := agentAuth.CreateClient() + if err != nil { + return xerrors.Errorf("create agent client: %w", err) + } + + resolvedDir := dir + if resolvedDir == "" { + resolvedDir, err = os.Getwd() + if err != nil { + return xerrors.Errorf("get working directory: %w", err) + } + } + resolvedDir, err = filepath.Abs(resolvedDir) + if err != nil { + return xerrors.Errorf("resolve directory: %w", err) + } + info, err := os.Stat(resolvedDir) + if err != nil { + return xerrors.Errorf("cannot read directory %q: %w", resolvedDir, err) + } + if !info.IsDir() { + return xerrors.Errorf("%q is not a directory", resolvedDir) + } + + parts := agentcontextconfig.ContextPartsFromDir(resolvedDir) + if len(parts) == 0 { + _, _ = fmt.Fprintln(inv.Stderr, "No context files or skills found in "+resolvedDir) + return nil + } + + // Resolve chat ID from flag or auto-detect. + resolvedChatID, err := parseChatID(chatID) + if err != nil { + return err + } + + resp, err := client.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: resolvedChatID, + Parts: parts, + }) + if err != nil { + return xerrors.Errorf("add chat context: %w", err) + } + + _, _ = fmt.Fprintf(inv.Stdout, "Added %d context part(s) to chat %s\n", resp.Count, resp.ChatID) + return nil + }, + Options: serpent.OptionSet{ + { + Name: "Directory", + Flag: "dir", + Description: "Directory to read context files and skills from. Defaults to the current working directory.", + Value: serpent.StringOf(&dir), + }, + { + Name: "Chat ID", + Flag: "chat", + Env: "CODER_CHAT_ID", + Description: "Chat ID to add context to. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.", + Value: serpent.StringOf(&chatID), + }, + }, + } + agentAuth.AttachOptions(cmd, false) + return cmd +} + +func (*RootCmd) chatContextClearCommand() *serpent.Command { + var chatID string + agentAuth := &AgentAuth{} + cmd := &serpent.Command{ + Use: "clear", + Short: "Clear context from an active chat", + Long: "Soft-delete all context-file and skill messages from an active chat. " + + "The next turn will re-fetch default context from the agent.", + Handler: func(inv *serpent.Invocation) error { + ctx := inv.Context() + ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...) + defer stop() + + client, err := agentAuth.CreateClient() + if err != nil { + return xerrors.Errorf("create agent client: %w", err) + } + + resolvedChatID, err := parseChatID(chatID) + if err != nil { + return err + } + + resp, err := client.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{ + ChatID: resolvedChatID, + }) + if err != nil { + return xerrors.Errorf("clear chat context: %w", err) + } + + if resp.ChatID == uuid.Nil { + _, _ = fmt.Fprintln(inv.Stdout, "No active chats to clear.") + } else { + _, _ = fmt.Fprintf(inv.Stdout, "Cleared context from chat %s\n", resp.ChatID) + } + return nil + }, + Options: serpent.OptionSet{{ + Name: "Chat ID", + Flag: "chat", + Env: "CODER_CHAT_ID", + Description: "Chat ID to clear context from. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.", + Value: serpent.StringOf(&chatID), + }}, + } + agentAuth.AttachOptions(cmd, false) + return cmd +} + +// parseChatID returns the chat UUID from the flag value (which +// serpent already populates from --chat or CODER_CHAT_ID). Returns +// uuid.Nil if empty (the server will auto-detect). +func parseChatID(flagValue string) (uuid.UUID, error) { + if flagValue == "" { + return uuid.Nil, nil + } + parsed, err := uuid.Parse(flagValue) + if err != nil { + return uuid.Nil, xerrors.Errorf("invalid chat ID %q: %w", flagValue, err) + } + return parsed, nil +} diff --git a/cli/exp_chat_test.go b/cli/exp_chat_test.go new file mode 100644 index 0000000000..30696c6eca --- /dev/null +++ b/cli/exp_chat_test.go @@ -0,0 +1,46 @@ +package cli_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/cli/clitest" +) + +func TestExpChatContextAdd(t *testing.T) { + t.Parallel() + + t.Run("RequiresWorkspaceOrDir", func(t *testing.T) { + t.Parallel() + + inv, _ := clitest.New(t, "exp", "chat", "context", "add") + + err := inv.Run() + require.Error(t, err) + require.Contains(t, err.Error(), "this command must be run inside a Coder workspace") + }) + + t.Run("AllowsExplicitDir", func(t *testing.T) { + t.Parallel() + + inv, _ := clitest.New(t, "exp", "chat", "context", "add", "--dir", t.TempDir()) + + err := inv.Run() + if err != nil { + require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace") + } + }) + + t.Run("AllowsWorkspaceEnv", func(t *testing.T) { + t.Parallel() + + inv, _ := clitest.New(t, "exp", "chat", "context", "add") + inv.Environ.Set("CODER", "true") + + err := inv.Run() + if err != nil { + require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace") + } + }) +} diff --git a/cli/root.go b/cli/root.go index a0b57923b4..0af41238f3 100644 --- a/cli/root.go +++ b/cli/root.go @@ -148,6 +148,7 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command { return []*serpent.Command{ r.scaletestCmd(), r.errorExample(), + r.chatCommand(), r.mcpCommand(), r.promptExample(), r.rptyCommand(), diff --git a/coderd/coderd.go b/coderd/coderd.go index e1ae2a9502..fa145d3fd2 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1653,6 +1653,10 @@ func New(options *Options) *API { r.Get("/gitsshkey", api.agentGitSSHKey) r.Post("/log-source", api.workspaceAgentPostLogSource) r.Get("/reinit", api.workspaceAgentReinit) + r.Route("/experimental", func(r chi.Router) { + r.Post("/chat-context", api.workspaceAgentAddChatContext) + r.Delete("/chat-context", api.workspaceAgentClearChatContext) + }) r.Route("/tasks/{task}", func(r chi.Router) { r.Post("/log-snapshot", api.postWorkspaceAgentTaskLogSnapshot) }) diff --git a/coderd/coderdtest/swaggerparser.go b/coderd/coderdtest/swaggerparser.go index efb6461fe0..05854d2f8f 100644 --- a/coderd/coderdtest/swaggerparser.go +++ b/coderd/coderdtest/swaggerparser.go @@ -147,6 +147,10 @@ func parseSwaggerComment(commentGroup *ast.CommentGroup) SwaggerComment { return c } +func isExperimentalEndpoint(route string) bool { + return strings.HasPrefix(route, "/workspaceagents/me/experimental/") +} + func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments []SwaggerComment) { assertUniqueRoutes(t, swaggerComments) assertSingleAnnotations(t, swaggerComments) @@ -165,6 +169,9 @@ func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments [ if strings.HasSuffix(route, "/*") { return } + if isExperimentalEndpoint(route) { + return + } c := findSwaggerCommentByMethodAndRoute(swaggerComments, method, route) assert.NotNil(t, c, "Missing @Router annotation") diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 6e2834305d..9e952dc002 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1708,6 +1708,17 @@ func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error return q.db.CleanupDeletedMCPServerIDsFromChats(ctx) } +func (q *querier) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error { + chat, err := q.db.GetChatByID(ctx, chatID) + if err != nil { + return err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return err + } + return q.db.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID) +} + func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) { prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type) if err != nil { @@ -2413,6 +2424,10 @@ func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) { return q.db.GetActiveAISeatCount(ctx) } +func (q *querier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) { + return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetActiveChatsByAgentID)(ctx, agentID) +} + func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil { return nil, err @@ -5728,6 +5743,17 @@ func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg databas return q.db.SoftDeleteChatMessagesAfterID(ctx, arg) } +func (q *querier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error { + chat, err := q.db.GetChatByID(ctx, chatID) + if err != nil { + return err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return err + } + return q.db.SoftDeleteContextFileMessages(ctx, chatID) +} + func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) { return q.db.TryAcquireLock(ctx, id) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index e9add7a2a7..c85769b562 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -478,6 +478,24 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), arg).Return([]database.Chat{chatA, chatB}, nil).AnyTimes() check.Args(arg).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB}) })) + s.Run("GetActiveChatsByAgentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + agentID := uuid.New() + dbm.EXPECT().GetActiveChatsByAgentID(gomock.Any(), agentID).Return([]database.Chat{chat}, nil).AnyTimes() + check.Args(agentID).Asserts(chat, policy.ActionRead).Returns([]database.Chat{chat}) + })) + s.Run("SoftDeleteContextFileMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().SoftDeleteContextFileMessages(gomock.Any(), chat.ID).Return(nil).AnyTimes() + check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns() + })) + s.Run("ClearChatMessageProviderResponseIDsByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().ClearChatMessageProviderResponseIDsByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes() + check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns() + })) s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { arg := database.GetChatCostPerChatParams{ OwnerID: uuid.New(), diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 664285853c..4b2d27ef0e 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -280,6 +280,14 @@ func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Conte return r0 } +func (m queryMetricsStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error { + start := time.Now() + r0 := m.s.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID) + m.queryLatencies.WithLabelValues("ClearChatMessageProviderResponseIDsByChatID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ClearChatMessageProviderResponseIDsByChatID").Inc() + return r0 +} + func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) { start := time.Now() r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg) @@ -968,6 +976,14 @@ func (m queryMetricsStore) GetActiveAISeatCount(ctx context.Context) (int64, err return r0, r1 } +func (m queryMetricsStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) { + start := time.Now() + r0, r1 := m.s.GetActiveChatsByAgentID(ctx, agentID) + m.queryLatencies.WithLabelValues("GetActiveChatsByAgentID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetActiveChatsByAgentID").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) { start := time.Now() r0, r1 := m.s.GetActivePresetPrebuildSchedules(ctx) @@ -4104,6 +4120,14 @@ func (m queryMetricsStore) SoftDeleteChatMessagesAfterID(ctx context.Context, ar return r0 } +func (m queryMetricsStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error { + start := time.Now() + r0 := m.s.SoftDeleteContextFileMessages(ctx, chatID) + m.queryLatencies.WithLabelValues("SoftDeleteContextFileMessages").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteContextFileMessages").Inc() + return r0 +} + func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) { start := time.Now() r0, r1 := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 30472ca133..8498ec0aa9 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -363,6 +363,20 @@ func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx) } +// ClearChatMessageProviderResponseIDsByChatID mocks base method. +func (m *MockStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClearChatMessageProviderResponseIDsByChatID", ctx, chatID) + ret0, _ := ret[0].(error) + return ret0 +} + +// ClearChatMessageProviderResponseIDsByChatID indicates an expected call of ClearChatMessageProviderResponseIDsByChatID. +func (mr *MockStoreMockRecorder) ClearChatMessageProviderResponseIDsByChatID(ctx, chatID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearChatMessageProviderResponseIDsByChatID", reflect.TypeOf((*MockStore)(nil).ClearChatMessageProviderResponseIDsByChatID), ctx, chatID) +} + // CountAIBridgeInterceptions mocks base method. func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) { m.ctrl.T.Helper() @@ -1667,6 +1681,21 @@ func (mr *MockStoreMockRecorder) GetActiveAISeatCount(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveAISeatCount", reflect.TypeOf((*MockStore)(nil).GetActiveAISeatCount), ctx) } +// GetActiveChatsByAgentID mocks base method. +func (m *MockStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveChatsByAgentID", ctx, agentID) + ret0, _ := ret[0].([]database.Chat) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveChatsByAgentID indicates an expected call of GetActiveChatsByAgentID. +func (mr *MockStoreMockRecorder) GetActiveChatsByAgentID(ctx, agentID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveChatsByAgentID", reflect.TypeOf((*MockStore)(nil).GetActiveChatsByAgentID), ctx, agentID) +} + // GetActivePresetPrebuildSchedules mocks base method. func (m *MockStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) { m.ctrl.T.Helper() @@ -7780,6 +7809,20 @@ func (mr *MockStoreMockRecorder) SoftDeleteChatMessagesAfterID(ctx, arg any) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessagesAfterID), ctx, arg) } +// SoftDeleteContextFileMessages mocks base method. +func (m *MockStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SoftDeleteContextFileMessages", ctx, chatID) + ret0, _ := ret[0].(error) + return ret0 +} + +// SoftDeleteContextFileMessages indicates an expected call of SoftDeleteContextFileMessages. +func (mr *MockStoreMockRecorder) SoftDeleteContextFileMessages(ctx, chatID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteContextFileMessages", reflect.TypeOf((*MockStore)(nil).SoftDeleteContextFileMessages), ctx, chatID) +} + // TryAcquireLock mocks base method. func (m *MockStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index c59dca8950..0c04ae18e5 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -3783,6 +3783,8 @@ CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled); CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id); +CREATE INDEX idx_chats_agent_id ON chats USING btree (agent_id) WHERE (agent_id IS NOT NULL); + CREATE INDEX idx_chats_labels ON chats USING gin (labels); CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_config_id); diff --git a/coderd/database/migrations/000465_chat_agent_id_index.down.sql b/coderd/database/migrations/000465_chat_agent_id_index.down.sql new file mode 100644 index 0000000000..7e7de2550c --- /dev/null +++ b/coderd/database/migrations/000465_chat_agent_id_index.down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS idx_chats_agent_id; diff --git a/coderd/database/migrations/000465_chat_agent_id_index.up.sql b/coderd/database/migrations/000465_chat_agent_id_index.up.sql new file mode 100644 index 0000000000..87f9684561 --- /dev/null +++ b/coderd/database/migrations/000465_chat_agent_id_index.up.sql @@ -0,0 +1 @@ +CREATE INDEX idx_chats_agent_id ON chats(agent_id) WHERE agent_id IS NOT NULL; diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 5dc83c2ccb..0bf767c17b 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -76,6 +76,7 @@ type sqlcQuerier interface { CleanTailnetLostPeers(ctx context.Context) error CleanTailnetTunnels(ctx context.Context) error CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error + ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error) CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) @@ -215,6 +216,7 @@ type sqlcQuerier interface { GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error) GetActiveAISeatCount(ctx context.Context) (int64, error) + GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]Chat, error) GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error) GetActiveUserCount(ctx context.Context, includeSystem bool) (int64, error) GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error) @@ -893,6 +895,7 @@ type sqlcQuerier interface { SelectUsageEventsForPublishing(ctx context.Context, now time.Time) ([]UsageEvent, error) SoftDeleteChatMessageByID(ctx context.Context, id int64) error SoftDeleteChatMessagesAfterID(ctx context.Context, arg SoftDeleteChatMessagesAfterIDParams) error + SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error // Non blocking lock. Returns true if the lock was acquired, false otherwise. // // This must be called from within a transaction. The lock will be automatically diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 016d2f4927..6a207470fd 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4505,6 +4505,19 @@ func (q *sqlQuerier) BackoffChatDiffStatus(ctx context.Context, arg BackoffChatD return err } +const clearChatMessageProviderResponseIDsByChatID = `-- name: ClearChatMessageProviderResponseIDsByChatID :exec +UPDATE chat_messages +SET provider_response_id = NULL +WHERE chat_id = $1::uuid + AND deleted = false + AND provider_response_id IS NOT NULL +` + +func (q *sqlQuerier) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error { + _, err := q.db.ExecContext(ctx, clearChatMessageProviderResponseIDsByChatID, chatID) + return err +} + const countEnabledModelsWithoutPricing = `-- name: CountEnabledModelsWithoutPricing :one SELECT COUNT(*)::bigint AS count FROM chat_model_configs @@ -4603,6 +4616,66 @@ func (q *sqlQuerier) DeleteOldChats(ctx context.Context, arg DeleteOldChatsParam return result.RowsAffected() } +const getActiveChatsByAgentID = `-- name: GetActiveChatsByAgentID :many +SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools +FROM chats +WHERE agent_id = $1::uuid + AND archived = false + -- Active statuses only: waiting, pending, running, paused, + -- requires_action. + -- Excludes completed and error (terminal states). + AND status IN ('waiting', 'running', 'paused', 'pending', 'requires_action') +ORDER BY updated_at DESC +` + +func (q *sqlQuerier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]Chat, error) { + rows, err := q.db.QueryContext(ctx, getActiveChatsByAgentID, agentID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Chat + for rows.Next() { + var i Chat + if err := rows.Scan( + &i.ID, + &i.OwnerID, + &i.WorkspaceID, + &i.Title, + &i.Status, + &i.WorkerID, + &i.StartedAt, + &i.HeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.ParentChatID, + &i.RootChatID, + &i.LastModelConfigID, + &i.Archived, + &i.LastError, + &i.Mode, + pq.Array(&i.MCPServerIDs), + &i.Labels, + &i.BuildID, + &i.AgentID, + &i.PinOrder, + &i.LastReadMessageID, + &i.LastInjectedContext, + &i.DynamicTools, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getChatByID = `-- name: GetChatByID :one SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools @@ -6706,6 +6779,18 @@ func (q *sqlQuerier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg Soft return err } +const softDeleteContextFileMessages = `-- name: SoftDeleteContextFileMessages :exec +UPDATE chat_messages SET deleted = true +WHERE chat_id = $1::uuid + AND deleted = false + AND content::jsonb @> '[{"type": "context-file"}]' +` + +func (q *sqlQuerier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error { + _, err := q.db.ExecContext(ctx, softDeleteContextFileMessages, chatID) + return err +} + const unarchiveChatByID = `-- name: UnarchiveChatByID :many WITH chats AS ( UPDATE chats SET diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 587b4a4da6..c54a8d51fa 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -1293,3 +1293,26 @@ GROUP BY cm.chat_id; SELECT id, provider, model, context_limit, enabled, is_default FROM chat_model_configs WHERE deleted = false; +-- name: GetActiveChatsByAgentID :many +SELECT * +FROM chats +WHERE agent_id = @agent_id::uuid + AND archived = false + -- Active statuses only: waiting, pending, running, paused, + -- requires_action. + -- Excludes completed and error (terminal states). + AND status IN ('waiting', 'running', 'paused', 'pending', 'requires_action') +ORDER BY updated_at DESC; + +-- name: ClearChatMessageProviderResponseIDsByChatID :exec +UPDATE chat_messages +SET provider_response_id = NULL +WHERE chat_id = @chat_id::uuid + AND deleted = false + AND provider_response_id IS NOT NULL; + +-- name: SoftDeleteContextFileMessages :exec +UPDATE chat_messages SET deleted = true +WHERE chat_id = @chat_id::uuid + AND deleted = false + AND content::jsonb @> '[{"type": "context-file"}]'; diff --git a/coderd/export_test.go b/coderd/export_test.go new file mode 100644 index 0000000000..95d8313cab --- /dev/null +++ b/coderd/export_test.go @@ -0,0 +1,4 @@ +package coderd + +// InsertAgentChatTestModelConfig exposes insertAgentChatTestModelConfig for external tests. +var InsertAgentChatTestModelConfig = insertAgentChatTestModelConfig diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index def90f23d2..855833d95a 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -42,6 +42,8 @@ import ( "github.com/coder/coder/v2/coderd/telemetry" maputil "github.com/coder/coder/v2/coderd/util/maps" "github.com/coder/coder/v2/coderd/wspubsub" + "github.com/coder/coder/v2/coderd/x/chatd" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" "github.com/coder/coder/v2/coderd/x/gitsync" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" @@ -2393,3 +2395,598 @@ func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.Wor } return sdk } + +// maxChatContextParts caps the number of parts per request to +// prevent unbounded message payloads. +const maxChatContextParts = 100 + +// maxChatContextFileBytes caps each context-file part to the same +// 64KiB budget used when the agent reads instruction files from disk. +const maxChatContextFileBytes = 64 * 1024 + +// maxChatContextRequestBodyBytes caps the JSON request body size for +// agent-added context to roughly the same per-part budget used when +// reading instruction files from disk. +const maxChatContextRequestBodyBytes int64 = maxChatContextParts * maxChatContextFileBytes + +// sanitizeWorkspaceAgentContextFileContent applies prompt +// sanitization, then enforces the 64KiB per-file budget. The +// truncated flag is preserved when the caller already capped the +// file before sending it. +func sanitizeWorkspaceAgentContextFileContent( + content string, + truncated bool, +) (string, bool) { + content = chatd.SanitizePromptText(content) + if len(content) > maxChatContextFileBytes { + content = content[:maxChatContextFileBytes] + truncated = true + } + return content, truncated +} + +// readChatContextBody reads and validates the request body for chat +// context endpoints. It handles MaxBytesReader wrapping, error +// responses, and body rewind. If the body is empty or whitespace-only +// and allowEmpty is true, it returns false without writing an error. +// +//nolint:revive // Add and clear endpoints only differ by empty-body handling. +func readChatContextBody(ctx context.Context, rw http.ResponseWriter, r *http.Request, dst any, allowEmpty bool) bool { + r.Body = http.MaxBytesReader(rw, r.Body, maxChatContextRequestBodyBytes) + body, err := io.ReadAll(r.Body) + if err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{ + Message: "Request body too large.", + Detail: fmt.Sprintf("Maximum request body size is %d bytes.", maxChatContextRequestBodyBytes), + }) + return false + } + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Failed to read request body.", + Detail: err.Error(), + }) + return false + } + if allowEmpty && len(bytes.TrimSpace(body)) == 0 { + r.Body = http.NoBody + return false + } + + r.Body = io.NopCloser(bytes.NewReader(body)) + return httpapi.Read(ctx, rw, r, dst) +} + +// @x-apidocgen {"skip": true} +func (api *API) workspaceAgentAddChatContext(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + workspaceAgent := httpmw.WorkspaceAgent(r) + + var req agentsdk.AddChatContextRequest + if !readChatContextBody(ctx, rw, r, &req, false) { + return + } + + if len(req.Parts) == 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "No context parts provided.", + }) + return + } + + if len(req.Parts) > maxChatContextParts { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Too many context parts (%d). Maximum is %d.", len(req.Parts), maxChatContextParts), + }) + return + } + + // Filter to only non-empty context-file and skill parts. + filtered := chatd.FilterContextParts(req.Parts, false) + if len(filtered) == 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "No context-file or skill parts provided.", + }) + return + } + req.Parts = filtered + responsePartCount := 0 + + // Use system context for chat operations since the + // workspace agent scope does not include chat resources. + // We verify agent-to-chat ownership explicitly below. + //nolint:gocritic // Agent needs system access to read/write chat resources. + sysCtx := dbauthz.AsSystemRestricted(ctx) + workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to determine workspace from agent token.", + Detail: err.Error(), + }) + return + } + + chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID) + if err != nil { + writeAgentChatError(ctx, rw, err) + return + } + + // Stamp each persisted part with the agent identity. Context-file + // parts also get server-authoritative workspace metadata. + directory := workspaceAgent.ExpandedDirectory + if directory == "" { + directory = workspaceAgent.Directory + } + for i := range req.Parts { + req.Parts[i].ContextFileAgentID = uuid.NullUUID{ + UUID: workspaceAgent.ID, + Valid: true, + } + if req.Parts[i].Type != codersdk.ChatMessagePartTypeContextFile { + continue + } + req.Parts[i].ContextFileContent, req.Parts[i].ContextFileTruncated = sanitizeWorkspaceAgentContextFileContent( + req.Parts[i].ContextFileContent, + req.Parts[i].ContextFileTruncated, + ) + req.Parts[i].ContextFileOS = workspaceAgent.OperatingSystem + req.Parts[i].ContextFileDirectory = directory + } + req.Parts = chatd.FilterContextParts(req.Parts, false) + if len(req.Parts) == 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "No context-file or skill parts provided.", + }) + return + } + responsePartCount = len(req.Parts) + + // Skill-only messages need a sentinel context-file part so the turn + // pipeline trusts the associated skill metadata. + req.Parts = prependAgentChatContextSentinelIfNeeded( + req.Parts, + workspaceAgent.ID, + workspaceAgent.OperatingSystem, + directory, + ) + + content, err := chatprompt.MarshalParts(req.Parts) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to marshal context parts.", + Detail: err.Error(), + }) + return + } + + err = api.Database.InTx(func(tx database.Store) error { + locked, err := tx.GetChatByIDForUpdate(sysCtx, chat.ID) + if err != nil { + return xerrors.Errorf("lock chat: %w", err) + } + if !isActiveAgentChat(locked) { + return errChatNotActive + } + if !locked.AgentID.Valid || locked.AgentID.UUID != workspaceAgent.ID { + return errChatDoesNotBelongToAgent + } + if locked.OwnerID != workspace.OwnerID { + return errChatDoesNotBelongToWorkspaceOwner + } + if _, err := tx.InsertChatMessages(sysCtx, chatd.BuildSingleChatMessageInsertParams( + chat.ID, + database.ChatMessageRoleUser, + content, + database.ChatMessageVisibilityBoth, + locked.LastModelConfigID, + chatprompt.CurrentContentVersion, + uuid.Nil, + )); err != nil { + return xerrors.Errorf("insert context message: %w", err) + } + if err := updateAgentChatLastInjectedContextFromMessages(sysCtx, api.Logger, tx, chat.ID); err != nil { + return xerrors.Errorf("rebuild injected context cache: %w", err) + } + return nil + }, nil) + if err != nil { + if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) { + writeAgentChatError(ctx, rw, err) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to persist context message.", + Detail: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, agentsdk.AddChatContextResponse{ + ChatID: chat.ID, + Count: responsePartCount, + }) +} + +// @x-apidocgen {"skip": true} +func (api *API) workspaceAgentClearChatContext(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + workspaceAgent := httpmw.WorkspaceAgent(r) + + var req agentsdk.ClearChatContextRequest + populated := readChatContextBody(ctx, rw, r, &req, true) + if !populated && r.Body != http.NoBody { + return + } + + // Use system context for chat operations since the + // workspace agent scope does not include chat resources. + //nolint:gocritic // Agent needs system access to read/write chat resources. + sysCtx := dbauthz.AsSystemRestricted(ctx) + workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to determine workspace from agent token.", + Detail: err.Error(), + }) + return + } + + chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID) + if err != nil { + // Zero active chats is not an error for clear. + if errors.Is(err, errNoActiveChats) { + httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{}) + return + } + writeAgentChatError(ctx, rw, err) + return + } + + err = clearAgentChatContext(sysCtx, api.Database, chat.ID, workspaceAgent.ID, workspace.OwnerID) + if err != nil { + if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) { + writeAgentChatError(ctx, rw, err) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to clear context from chat.", + Detail: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{ + ChatID: chat.ID, + }) +} + +var ( + errNoActiveChats = xerrors.New("no active chats found") + errChatNotFound = xerrors.New("chat not found") + errChatNotActive = xerrors.New("chat is not active") + errChatDoesNotBelongToAgent = xerrors.New("chat does not belong to this agent") + errChatDoesNotBelongToWorkspaceOwner = xerrors.New("chat does not belong to this workspace owner") +) + +type multipleActiveChatsError struct { + count int +} + +func (e *multipleActiveChatsError) Error() string { + return fmt.Sprintf( + "multiple active chats (%d) found for this agent, specify a chat ID", + e.count, + ) +} + +func resolveDefaultAgentChat(chats []database.Chat) (database.Chat, error) { + switch len(chats) { + case 0: + return database.Chat{}, errNoActiveChats + case 1: + return chats[0], nil + } + + var rootChat *database.Chat + for i := range chats { + chat := &chats[i] + if chat.ParentChatID.Valid { + continue + } + if rootChat != nil { + return database.Chat{}, &multipleActiveChatsError{count: len(chats)} + } + rootChat = chat + } + if rootChat != nil { + return *rootChat, nil + } + return database.Chat{}, &multipleActiveChatsError{count: len(chats)} +} + +// resolveAgentChat finds the target chat from either an explicit ID +// or auto-detection via the agent's active chats. +func resolveAgentChat( + ctx context.Context, + db database.Store, + agentID uuid.UUID, + workspaceOwnerID uuid.UUID, + explicitChatID uuid.UUID, +) (database.Chat, error) { + if explicitChatID == uuid.Nil { + chats, err := db.GetActiveChatsByAgentID(ctx, agentID) + if err != nil { + return database.Chat{}, xerrors.Errorf("list active chats: %w", err) + } + ownerChats := make([]database.Chat, 0, len(chats)) + for _, chat := range chats { + if chat.OwnerID != workspaceOwnerID { + continue + } + ownerChats = append(ownerChats, chat) + } + return resolveDefaultAgentChat(ownerChats) + } + + chat, err := db.GetChatByID(ctx, explicitChatID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return database.Chat{}, errChatNotFound + } + return database.Chat{}, xerrors.Errorf("get chat by id: %w", err) + } + if !chat.AgentID.Valid || chat.AgentID.UUID != agentID { + return database.Chat{}, errChatDoesNotBelongToAgent + } + if chat.OwnerID != workspaceOwnerID { + return database.Chat{}, errChatDoesNotBelongToWorkspaceOwner + } + if !isActiveAgentChat(chat) { + return database.Chat{}, errChatNotActive + } + return chat, nil +} + +func isActiveAgentChat(chat database.Chat) bool { + if chat.Archived { + return false + } + + switch chat.Status { + case database.ChatStatusWaiting, + database.ChatStatusPending, + database.ChatStatusRunning, + database.ChatStatusPaused, + database.ChatStatusRequiresAction: + return true + default: + return false + } +} + +func clearAgentChatContext( + ctx context.Context, + db database.Store, + chatID uuid.UUID, + agentID uuid.UUID, + workspaceOwnerID uuid.UUID, +) error { + return db.InTx(func(tx database.Store) error { + locked, err := tx.GetChatByIDForUpdate(ctx, chatID) + if err != nil { + return xerrors.Errorf("lock chat: %w", err) + } + if !isActiveAgentChat(locked) { + return errChatNotActive + } + if !locked.AgentID.Valid || locked.AgentID.UUID != agentID { + return errChatDoesNotBelongToAgent + } + if locked.OwnerID != workspaceOwnerID { + return errChatDoesNotBelongToWorkspaceOwner + } + messages, err := tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + AfterID: 0, + }) + if err != nil { + return xerrors.Errorf("get chat messages: %w", err) + } + hadInjectedContext := locked.LastInjectedContext.Valid + var skillOnlyMessageIDs []int64 + for _, msg := range messages { + if !msg.Content.Valid { + continue + } + hasContextFile := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeContextFile) + hasSkill := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeSkill) + if hasContextFile || hasSkill { + hadInjectedContext = true + } + if hasSkill && !hasContextFile { + skillOnlyMessageIDs = append(skillOnlyMessageIDs, msg.ID) + } + } + if !hadInjectedContext { + return nil + } + if err := tx.SoftDeleteContextFileMessages(ctx, chatID); err != nil { + return xerrors.Errorf("soft delete context-file messages: %w", err) + } + for _, messageID := range skillOnlyMessageIDs { + if err := tx.SoftDeleteChatMessageByID(ctx, messageID); err != nil { + return xerrors.Errorf("soft delete context message %d: %w", messageID, err) + } + } + // Reset provider-side Responses chaining so the next turn replays + // the post-clear history instead of inheriting cleared context. + if err := tx.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID); err != nil { + return xerrors.Errorf("clear provider response chain: %w", err) + } + // Clear the injected-context cache inside the transaction so it is + // atomic with the soft-deletes. + param, err := chatd.BuildLastInjectedContext(nil) + if err != nil { + return xerrors.Errorf("clear injected context cache: %w", err) + } + if _, err := tx.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{ + ID: chatID, + LastInjectedContext: param, + }); err != nil { + return xerrors.Errorf("clear injected context cache: %w", err) + } + return nil + }, nil) +} + +// prependAgentChatContextSentinelIfNeeded adds an empty context-file +// part when the request only carries skills. The turn pipeline uses +// the sentinel's agent metadata to trust the skill parts. +func prependAgentChatContextSentinelIfNeeded( + parts []codersdk.ChatMessagePart, + agentID uuid.UUID, + operatingSystem string, + directory string, +) []codersdk.ChatMessagePart { + hasContextFile := false + hasSkill := false + for _, part := range parts { + switch part.Type { + case codersdk.ChatMessagePartTypeContextFile: + hasContextFile = true + case codersdk.ChatMessagePartTypeSkill: + hasSkill = true + } + if hasContextFile && hasSkill { + return parts + } + } + if !hasSkill || hasContextFile { + return parts + } + return append([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: chatd.AgentChatContextSentinelPath, + ContextFileAgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + ContextFileOS: operatingSystem, + ContextFileDirectory: directory, + }}, parts...) +} + +func sortChatMessagesByCreatedAtAndID(messages []database.ChatMessage) { + sort.SliceStable(messages, func(i, j int) bool { + if messages[i].CreatedAt.Equal(messages[j].CreatedAt) { + return messages[i].ID < messages[j].ID + } + return messages[i].CreatedAt.Before(messages[j].CreatedAt) + }) +} + +// updateAgentChatLastInjectedContextFromMessages rebuilds the +// injected-context cache from all persisted context-file and skill parts. +func updateAgentChatLastInjectedContextFromMessages( + ctx context.Context, + logger slog.Logger, + db database.Store, + chatID uuid.UUID, +) error { + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + AfterID: 0, + }) + if err != nil { + return xerrors.Errorf("load context messages for injected context: %w", err) + } + + sortChatMessagesByCreatedAtAndID(messages) + + parts, err := chatd.CollectContextPartsFromMessages(ctx, logger, messages, true) + if err != nil { + return xerrors.Errorf("collect injected context parts: %w", err) + } + parts = chatd.FilterContextPartsToLatestAgent(parts) + + param, err := chatd.BuildLastInjectedContext(parts) + if err != nil { + return xerrors.Errorf("update injected context: %w", err) + } + if _, err := db.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{ + ID: chatID, + LastInjectedContext: param, + }); err != nil { + return xerrors.Errorf("update injected context: %w", err) + } + return nil +} + +func messageHasPartTypes(raw []byte, types ...codersdk.ChatMessagePartType) bool { + var parts []codersdk.ChatMessagePart + if err := json.Unmarshal(raw, &parts); err != nil { + return false + } + for _, part := range parts { + for _, typ := range types { + if part.Type == typ { + return true + } + } + } + return false +} + +// writeAgentChatError translates resolveAgentChat errors to HTTP +// responses. +func writeAgentChatError( + ctx context.Context, + rw http.ResponseWriter, + err error, +) { + if errors.Is(err, errNoActiveChats) { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ + Message: "No active chats found for this agent.", + }) + return + } + if errors.Is(err, errChatNotFound) { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ + Message: "Chat not found.", + }) + return + } + if errors.Is(err, errChatDoesNotBelongToAgent) { + httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{ + Message: "Chat does not belong to this agent.", + }) + return + } + if errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) { + httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{ + Message: "Chat does not belong to this workspace owner.", + }) + return + } + if errors.Is(err, errChatNotActive) { + httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ + Message: "Cannot modify context: this chat is no longer active.", + }) + return + } + + var multipleErr *multipleActiveChatsError + if errors.As(err, &multipleErr) { + httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ + Message: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to resolve chat.", + Detail: err.Error(), + }) +} diff --git a/coderd/workspaceagents_active_chat_internal_test.go b/coderd/workspaceagents_active_chat_internal_test.go new file mode 100644 index 0000000000..68e3beeda1 --- /dev/null +++ b/coderd/workspaceagents_active_chat_internal_test.go @@ -0,0 +1,76 @@ +package coderd + +import ( + "fmt" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbfake" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/testutil" +) + +func TestActiveAgentChatDefinitionsAgree(t *testing.T) { + t.Parallel() + + ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium)) + db, _ := dbtestutil.NewDB(t) + + org, err := db.GetDefaultOrganization(ctx) + require.NoError(t, err) + + owner := dbgen.User(t, db, database.User{}) + workspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: org.ID, + OwnerID: owner.ID, + }).WithAgent().Do() + modelConfig := insertAgentChatTestModelConfig(ctx, t, db, owner.ID) + + insertedChats := make([]database.Chat, 0, len(database.AllChatStatusValues())*2) + for _, archived := range []bool{false, true} { + for _, status := range database.AllChatStatusValues() { + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + Status: status, + OwnerID: owner.ID, + LastModelConfigID: modelConfig.ID, + Title: fmt.Sprintf("%s-archived-%t", status, archived), + AgentID: uuid.NullUUID{UUID: workspace.Agents[0].ID, Valid: true}, + }) + require.NoError(t, err) + + if archived { + _, err = db.ArchiveChatByID(ctx, chat.ID) + require.NoError(t, err) + + chat, err = db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + } + + insertedChats = append(insertedChats, chat) + } + } + + activeChats, err := db.GetActiveChatsByAgentID(ctx, workspace.Agents[0].ID) + require.NoError(t, err) + + activeByID := make(map[uuid.UUID]bool, len(activeChats)) + for _, chat := range activeChats { + activeByID[chat.ID] = true + } + + for _, chat := range insertedChats { + require.Equalf( + t, + isActiveAgentChat(chat), + activeByID[chat.ID], + "status=%s archived=%t", + chat.Status, + chat.Archived, + ) + } +} diff --git a/coderd/workspaceagents_chat_context_internal_test.go b/coderd/workspaceagents_chat_context_internal_test.go new file mode 100644 index 0000000000..377c79466c --- /dev/null +++ b/coderd/workspaceagents_chat_context_internal_test.go @@ -0,0 +1,128 @@ +package coderd + +import ( + "context" + "database/sql" + "encoding/json" + "testing" + "time" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/codersdk" +) + +func TestUpdateAgentChatLastInjectedContextFromMessagesUsesMessageIDTieBreaker(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + chatID := uuid.New() + createdAt := time.Date(2026, time.April, 9, 13, 0, 0, 0, time.UTC) + oldAgentID := uuid.New() + newAgentID := uuid.New() + + oldContent, err := json.Marshal([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/old/AGENTS.md", + ContextFileContent: "old instructions", + ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true}, + }}) + require.NoError(t, err) + newContent, err := json.Marshal([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/new/AGENTS.md", + ContextFileContent: "new instructions", + ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true}, + }}) + require.NoError(t, err) + + db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ + ChatID: chatID, + AfterID: 0, + }).Return([]database.ChatMessage{ + { + ID: 2, + CreatedAt: createdAt, + Content: pqtype.NullRawMessage{ + RawMessage: newContent, + Valid: true, + }, + }, + { + ID: 1, + CreatedAt: createdAt, + Content: pqtype.NullRawMessage{ + RawMessage: oldContent, + Valid: true, + }, + }, + }, nil) + + db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) { + require.Equal(t, chatID, arg.ID) + require.True(t, arg.LastInjectedContext.Valid) + var cached []codersdk.ChatMessagePart + require.NoError(t, json.Unmarshal(arg.LastInjectedContext.RawMessage, &cached)) + require.Len(t, cached, 1) + require.Equal(t, "/new/AGENTS.md", cached[0].ContextFilePath) + require.Equal(t, uuid.NullUUID{UUID: newAgentID, Valid: true}, cached[0].ContextFileAgentID) + return database.Chat{}, nil + }, + ) + + err = updateAgentChatLastInjectedContextFromMessages( + context.Background(), + slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + db, + chatID, + ) + require.NoError(t, err) +} + +func insertAgentChatTestModelConfig( + ctx context.Context, + t testing.TB, + db database.Store, + userID uuid.UUID, +) database.ChatModelConfig { + t.Helper() + + sysCtx := dbauthz.AsSystemRestricted(ctx) + createdBy := uuid.NullUUID{UUID: userID, Valid: true} + + _, err := db.InsertChatProvider(sysCtx, database.InsertChatProviderParams{ + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-api-key", + ApiKeyKeyID: sql.NullString{}, + CreatedBy: createdBy, + Enabled: true, + CentralApiKeyEnabled: true, + }) + require.NoError(t, err) + + model, err := db.InsertChatModelConfig(sysCtx, database.InsertChatModelConfigParams{ + Provider: "openai", + Model: "gpt-4o-mini", + DisplayName: "Test Model", + CreatedBy: createdBy, + UpdatedBy: createdBy, + Enabled: true, + IsDefault: true, + ContextLimit: 128000, + CompressionThreshold: 70, + Options: json.RawMessage(`{}`), + }) + require.NoError(t, err) + + return model +} diff --git a/coderd/workspaceagents_chat_context_test.go b/coderd/workspaceagents_chat_context_test.go new file mode 100644 index 0000000000..d3f67fb8bc --- /dev/null +++ b/coderd/workspaceagents_chat_context_test.go @@ -0,0 +1,1084 @@ +package coderd_test + +import ( + "context" + "encoding/json" + "net/http" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbfake" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/x/chatd" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/agentsdk" + "github.com/coder/coder/v2/testutil" +) + +type agentChatContextTestSetup struct { + client *codersdk.Client + db database.Store + user codersdk.CreateFirstUserResponse + workspace dbfake.WorkspaceResponse + agentClient *agentsdk.Client +} + +type agentChatContextBeforeInTxStore struct { + database.Store + beforeInTx func() +} + +func (s *agentChatContextBeforeInTxStore) InTx(fn func(database.Store) error, opts *database.TxOptions) error { + if s.beforeInTx != nil { + beforeInTx := s.beforeInTx + s.beforeInTx = nil + beforeInTx() + } + return s.Store.InTx(fn, opts) +} + +func TestAgentChatContext(t *testing.T) { + t.Parallel() + + type addSuccessStep struct { + req agentsdk.AddChatContextRequest + wantCount int + } + + type addSuccessCase struct { + name string + steps []addSuccessStep + wantStored [][]codersdk.ChatMessagePart + storedOrdered bool + wantCached []codersdk.ChatMessagePart + cachedOrdered bool + } + + agentInstructionsPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/AGENTS.md", + ContextFileContent: "context from the agent", + } + fileAPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/file-a.md", + ContextFileContent: "file A context", + } + fileBPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/file-b.md", + ContextFileContent: "file B context", + } + repoHelperSkillPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper", + SkillDescription: "Repository instructions", + SkillDir: "/workspace/.agents/skills/repo-helper", + ContextFileSkillMetaFile: "SKILL.md", + } + projectInstructionsPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/AGENTS.md", + ContextFileContent: "project instructions", + } + cachedAgentInstructionsPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: agentInstructionsPart.ContextFilePath, + } + cachedFileAPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: fileAPart.ContextFilePath, + } + cachedFileBPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: fileBPart.ContextFilePath, + } + cachedRepoHelperSkillPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: repoHelperSkillPart.SkillName, + SkillDescription: repoHelperSkillPart.SkillDescription, + } + cachedProjectInstructionsPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: projectInstructionsPart.ContextFilePath, + } + + addSuccessCases := []addSuccessCase{ + { + name: "AddSuccessFiltersPartsAndUpdatesCache", + steps: []addSuccessStep{{req: agentsdk.AddChatContextRequest{Parts: []codersdk.ChatMessagePart{codersdk.ChatMessageText("ignore this text part"), agentInstructionsPart}}, wantCount: 1}}, + wantStored: [][]codersdk.ChatMessagePart{{agentInstructionsPart}}, + storedOrdered: true, + wantCached: []codersdk.ChatMessagePart{cachedAgentInstructionsPart}, + cachedOrdered: true, + }, + { + name: "AddSuccessIsAdditive", + steps: []addSuccessStep{{req: agentsdk.AddChatContextRequest{Parts: []codersdk.ChatMessagePart{fileAPart}}, wantCount: 1}, {req: agentsdk.AddChatContextRequest{Parts: []codersdk.ChatMessagePart{fileBPart}}, wantCount: 1}}, + wantStored: [][]codersdk.ChatMessagePart{{fileAPart}, {fileBPart}}, + storedOrdered: false, + wantCached: []codersdk.ChatMessagePart{cachedFileAPart, cachedFileBPart}, + cachedOrdered: false, + }, + { + name: "AddSuccessWithSkillOnlyPartsGetsSentinel", + steps: []addSuccessStep{{req: agentsdk.AddChatContextRequest{Parts: []codersdk.ChatMessagePart{repoHelperSkillPart}}, wantCount: 1}}, + wantStored: [][]codersdk.ChatMessagePart{{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: chatd.AgentChatContextSentinelPath, + }, repoHelperSkillPart}}, + storedOrdered: true, + wantCached: []codersdk.ChatMessagePart{cachedRepoHelperSkillPart}, + cachedOrdered: true, + }, + { + name: "AddSuccessWithMixedPartsNoSentinel", + steps: []addSuccessStep{{req: agentsdk.AddChatContextRequest{Parts: []codersdk.ChatMessagePart{projectInstructionsPart, repoHelperSkillPart}}, wantCount: 2}}, + wantStored: [][]codersdk.ChatMessagePart{{projectInstructionsPart, repoHelperSkillPart}}, + storedOrdered: true, + wantCached: []codersdk.ChatMessagePart{cachedProjectInstructionsPart, cachedRepoHelperSkillPart}, + cachedOrdered: true, + }, + } + + for _, tc := range addSuccessCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + + for _, step := range tc.steps { + resp, err := setup.agentClient.AddChatContext(ctx, step.req) + require.NoError(t, err) + require.Equal(t, chat.ID, resp.ChatID) + require.Equal(t, step.wantCount, resp.Count) + } + + actualStored := requireAgentChatContextStoredMessages(t, requireAgentChatContextMessages(ctx, t, setup.db, chat.ID)) + agent := setup.workspace.Agents[0] + wantStored := agentChatContextExpectedMessages(agent, tc.wantStored) + if tc.storedOrdered { + require.Equal(t, wantStored, actualStored) + } else { + require.ElementsMatch(t, wantStored, actualStored) + } + + wantCached := agentChatContextExpectedCachedParts(agent, tc.wantCached) + actualCached := requireAgentChatContextCachedParts(ctx, t, setup.db, chat.ID) + if tc.cachedOrdered { + require.Equal(t, wantCached, actualCached) + } else { + require.ElementsMatch(t, wantCached, actualCached) + } + }) + } + + t.Run("AddUsesLockedChatModelConfig", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + baseDB, pubsub := dbtestutil.NewDB(t) + interceptDB := &agentChatContextBeforeInTxStore{Store: baseDB} + client := coderdtest.New(t, &coderdtest.Options{ + Database: interceptDB, + Pubsub: pubsub, + }) + user := coderdtest.CreateFirstUser(t, client) + workspace := dbfake.WorkspaceBuild(t, baseDB, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent().Do() + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(workspace.AgentToken)) + + originalModel := coderd.InsertAgentChatTestModelConfig(ctx, t, baseDB, user.UserID) + updatedModel, err := baseDB.InsertChatModelConfig( + dbauthz.AsSystemRestricted(ctx), + database.InsertChatModelConfigParams{ + Provider: originalModel.Provider, + Model: "gpt-4o-mini-updated", + DisplayName: "Updated Test Model", + CreatedBy: uuid.NullUUID{UUID: user.UserID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: user.UserID, Valid: true}, + Enabled: true, + IsDefault: false, + ContextLimit: originalModel.ContextLimit, + CompressionThreshold: originalModel.CompressionThreshold, + Options: json.RawMessage(`{}`), + }, + ) + require.NoError(t, err) + chat := createAgentChatContextChat(ctx, t, baseDB, user.UserID, originalModel.ID, workspace.Agents[0].ID, t.Name()) + + interceptDB.beforeInTx = func() { + _, err := baseDB.UpdateChatLastModelConfigByID( + dbauthz.AsSystemRestricted(ctx), + database.UpdateChatLastModelConfigByIDParams{ + ID: chat.ID, + LastModelConfigID: updatedModel.ID, + }, + ) + require.NoError(t, err) + } + + resp, err := agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: chat.ID, + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/instructions.md", + ContextFileContent: "remember this file", + }}, + }) + require.NoError(t, err) + require.Equal(t, chat.ID, resp.ChatID) + require.Equal(t, 1, resp.Count) + + messages := requireAgentChatContextMessages(ctx, t, baseDB, chat.ID) + require.Len(t, messages, 1) + require.True(t, messages[0].ModelConfigID.Valid) + require.Equal(t, updatedModel.ID, messages[0].ModelConfigID.UUID) + + persistedChat, err := baseDB.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Equal(t, updatedModel.ID, persistedChat.LastModelConfigID) + }) + + t.Run("ClearDeletesSkillMessages", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + + skillPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper", + SkillDescription: "Repository instructions", + SkillDir: "/workspace/.agents/skills/repo-helper", + ContextFileSkillMetaFile: "SKILL.md", + } + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + Parts: []codersdk.ChatMessagePart{skillPart}, + }) + require.NoError(t, err) + + messages, err := setup.db.GetChatMessagesByChatID( + dbauthz.AsSystemRestricted(ctx), + database.GetChatMessagesByChatIDParams{ChatID: chat.ID, AfterID: 0}, + ) + require.NoError(t, err) + require.Len(t, messages, 1) + + storedParts := requireAgentChatContextParts(t, messages[0].Content.RawMessage) + require.Len(t, storedParts, 2) + + // Strip the sentinel so clear must delete the skill message via + // the skill-part scan instead of the context-file bulk delete. + rawSkillOnly, err := json.Marshal([]codersdk.ChatMessagePart{storedParts[1]}) + require.NoError(t, err) + _, err = setup.db.UpdateChatMessageByID( + dbauthz.AsSystemRestricted(ctx), + database.UpdateChatMessageByIDParams{ + ID: messages[0].ID, + Content: pqtype.NullRawMessage{ + RawMessage: rawSkillOnly, + Valid: true, + }, + }, + ) + require.NoError(t, err) + + resp, err := setup.agentClient.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{}) + require.NoError(t, err) + require.Equal(t, chat.ID, resp.ChatID) + + messages, err = setup.db.GetChatMessagesByChatID( + dbauthz.AsSystemRestricted(ctx), + database.GetChatMessagesByChatIDParams{ChatID: chat.ID, AfterID: 0}, + ) + require.NoError(t, err) + require.Empty(t, messages) + + persistedChat, err := setup.db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.False(t, persistedChat.LastInjectedContext.Valid) + }) + + t.Run("ClearDeletesSkillMessagesBeforeCompressedSummary", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + + skillPart := codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper", + SkillDescription: "Repository instructions", + SkillDir: "/workspace/.agents/skills/repo-helper", + ContextFileSkillMetaFile: "SKILL.md", + } + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + Parts: []codersdk.ChatMessagePart{skillPart}, + }) + require.NoError(t, err) + + messages := requireAgentChatContextMessages(ctx, t, setup.db, chat.ID) + require.Len(t, messages, 1) + + storedParts := requireAgentChatContextParts(t, messages[0].Content.RawMessage) + require.Len(t, storedParts, 2) + + // Strip the sentinel so the skill message must be found by the + // full-history scan even after compaction hides it from the + // prompt-scoped query. + rawSkillOnly, err := json.Marshal([]codersdk.ChatMessagePart{storedParts[1]}) + require.NoError(t, err) + _, err = setup.db.UpdateChatMessageByID( + dbauthz.AsSystemRestricted(ctx), + database.UpdateChatMessageByIDParams{ + ID: messages[0].ID, + Content: pqtype.NullRawMessage{ + RawMessage: rawSkillOnly, + Valid: true, + }, + }, + ) + require.NoError(t, err) + + summaryContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("compressed summary"), + }) + require.NoError(t, err) + summaryParams := chatd.BuildSingleChatMessageInsertParams( + chat.ID, + database.ChatMessageRoleUser, + summaryContent, + database.ChatMessageVisibilityModel, + chat.LastModelConfigID, + chatprompt.CurrentContentVersion, + setup.user.UserID, + ) + summaryParams.Compressed[0] = true + _, err = setup.db.InsertChatMessages( + dbauthz.AsSystemRestricted(ctx), + summaryParams, + ) + require.NoError(t, err) + + regularContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("keep this user message"), + }) + require.NoError(t, err) + _, err = setup.db.InsertChatMessages( + dbauthz.AsSystemRestricted(ctx), + chatd.BuildSingleChatMessageInsertParams( + chat.ID, + database.ChatMessageRoleUser, + regularContent, + database.ChatMessageVisibilityBoth, + chat.LastModelConfigID, + chatprompt.CurrentContentVersion, + setup.user.UserID, + ), + ) + require.NoError(t, err) + + resp, err := setup.agentClient.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{}) + require.NoError(t, err) + require.Equal(t, chat.ID, resp.ChatID) + + messages = requireAgentChatContextMessages(ctx, t, setup.db, chat.ID) + require.Len(t, messages, 1) + require.Equal(t, database.ChatMessageRoleUser, messages[0].Role) + + remainingParts := requireAgentChatContextParts(t, messages[0].Content.RawMessage) + require.Len(t, remainingParts, 1) + require.Equal(t, codersdk.ChatMessagePartTypeText, remainingParts[0].Type) + require.Equal(t, "keep this user message", remainingParts[0].Text) + + persistedChat, err := setup.db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.False(t, persistedChat.LastInjectedContext.Valid) + }) + + t.Run("ClearSuccessDeletesInjectedContext", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/instructions.md", + ContextFileContent: "remember this file", + }}, + }) + require.NoError(t, err) + + regularContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("keep this user message"), + }) + require.NoError(t, err) + _, err = setup.db.InsertChatMessages( + dbauthz.AsSystemRestricted(ctx), + chatd.BuildSingleChatMessageInsertParams( + chat.ID, + database.ChatMessageRoleUser, + regularContent, + database.ChatMessageVisibilityBoth, + chat.LastModelConfigID, + chatprompt.CurrentContentVersion, + setup.user.UserID, + ), + ) + require.NoError(t, err) + + resp, err := setup.agentClient.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{}) + require.NoError(t, err) + require.Equal(t, chat.ID, resp.ChatID) + + messages, err := setup.db.GetChatMessagesByChatID( + dbauthz.AsSystemRestricted(ctx), + database.GetChatMessagesByChatIDParams{ChatID: chat.ID, AfterID: 0}, + ) + require.NoError(t, err) + require.Len(t, messages, 1) + require.Equal(t, database.ChatMessageRoleUser, messages[0].Role) + + remainingParts := requireAgentChatContextParts(t, messages[0].Content.RawMessage) + require.Len(t, remainingParts, 1) + require.Equal(t, codersdk.ChatMessagePartTypeText, remainingParts[0].Type) + require.Equal(t, "keep this user message", remainingParts[0].Text) + + persistedChat, err := setup.db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.False(t, persistedChat.LastInjectedContext.Valid) + }) + + t.Run("ClearSuccessResetsProviderResponseChain", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/instructions.md", + ContextFileContent: "remember this file", + }}, + }) + require.NoError(t, err) + + assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("assistant reply"), + }) + require.NoError(t, err) + assistantParams := chatd.BuildSingleChatMessageInsertParams( + chat.ID, + database.ChatMessageRoleAssistant, + assistantContent, + database.ChatMessageVisibilityBoth, + chat.LastModelConfigID, + chatprompt.CurrentContentVersion, + uuid.Nil, + ) + assistantParams.ProviderResponseID[0] = "resp-123" + _, err = setup.db.InsertChatMessages( + dbauthz.AsSystemRestricted(ctx), + assistantParams, + ) + require.NoError(t, err) + + messages := requireAgentChatContextMessages(ctx, t, setup.db, chat.ID) + require.Len(t, messages, 2) + require.Equal(t, database.ChatMessageRoleAssistant, messages[1].Role) + require.True(t, messages[1].ProviderResponseID.Valid) + require.Equal(t, "resp-123", messages[1].ProviderResponseID.String) + + resp, err := setup.agentClient.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{}) + require.NoError(t, err) + require.Equal(t, chat.ID, resp.ChatID) + + messages = requireAgentChatContextMessages(ctx, t, setup.db, chat.ID) + require.Len(t, messages, 1) + require.Equal(t, database.ChatMessageRoleAssistant, messages[0].Role) + require.False(t, messages[0].ProviderResponseID.Valid) + + remainingParts := requireAgentChatContextParts(t, messages[0].Content.RawMessage) + require.Len(t, remainingParts, 1) + require.Equal(t, codersdk.ChatMessagePartTypeText, remainingParts[0].Type) + require.Equal(t, "assistant reply", remainingParts[0].Text) + + persistedChat, err := setup.db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.False(t, persistedChat.LastInjectedContext.Valid) + }) + + t.Run("ClearWithoutContextPreservesProviderResponseChain", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + + assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("assistant reply"), + }) + require.NoError(t, err) + assistantParams := chatd.BuildSingleChatMessageInsertParams( + chat.ID, + database.ChatMessageRoleAssistant, + assistantContent, + database.ChatMessageVisibilityBoth, + chat.LastModelConfigID, + chatprompt.CurrentContentVersion, + uuid.Nil, + ) + assistantParams.ProviderResponseID[0] = "resp-123" + _, err = setup.db.InsertChatMessages( + dbauthz.AsSystemRestricted(ctx), + assistantParams, + ) + require.NoError(t, err) + + resp, err := setup.agentClient.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{ChatID: chat.ID}) + require.NoError(t, err) + require.Equal(t, chat.ID, resp.ChatID) + + messages := requireAgentChatContextMessages(ctx, t, setup.db, chat.ID) + require.Len(t, messages, 1) + require.True(t, messages[0].ProviderResponseID.Valid) + require.Equal(t, "resp-123", messages[0].ProviderResponseID.String) + }) + + t.Run("AddFailsWhenAgentHasNoActiveChat", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/AGENTS.md", + ContextFileContent: "missing chat", + }}, + }) + sdkErr := requireSDKError(t, err, http.StatusNotFound) + require.Equal(t, "No active chats found for this agent.", sdkErr.Message) + }) + + t.Run("AddRejectsChatOwnedByAnotherAgent", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := coderdtest.NewWithDatabase(t, nil) + user := coderdtest.CreateFirstUser(t, client) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, db, user.UserID) + + firstWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent().Do() + secondWorkspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent().Do() + + chat := createAgentChatContextChat(ctx, t, db, user.UserID, model.ID, firstWorkspace.Agents[0].ID, t.Name()) + secondAgentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(secondWorkspace.AgentToken)) + + _, err := secondAgentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: chat.ID, + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/foreign.md", + ContextFileContent: "not your chat", + }}, + }) + sdkErr := requireSDKError(t, err, http.StatusForbidden) + require.Equal(t, "Chat does not belong to this agent.", sdkErr.Message) + }) + + t.Run("AddRejectsChatOwnedByAnotherUserOnSameAgent", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + _, otherUser := coderdtest.CreateAnotherUser(t, setup.client, setup.user.OrganizationID) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(ctx, t, setup.db, otherUser.ID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: chat.ID, + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/foreign.md", + ContextFileContent: "not your chat", + }}, + }) + sdkErr := requireSDKError(t, err, http.StatusForbidden) + require.Equal(t, "Chat does not belong to this workspace owner.", sdkErr.Message) + }) + + t.Run("AddRejectsTooManyParts", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + parts := make([]codersdk.ChatMessagePart, 101) + for i := range parts { + parts[i] = codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/file.md", + ContextFileContent: "too many", + } + } + + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{Parts: parts}) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Contains(t, sdkErr.Message, "Too many context parts") + }) + + t.Run("AddRejectsEmptyContextFileParts", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/empty.md", + }}, + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "No context-file or skill parts provided.", sdkErr.Message) + }) + + t.Run("AddRejectsWhitespaceOnlyContextFileParts", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: chat.ID, + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/whitespace.md", + ContextFileContent: " \n\t", + }}, + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "No context-file or skill parts provided.", sdkErr.Message) + }) + + t.Run("AddTruncatesOversizedContextFileParts", func(t *testing.T) { + t.Parallel() + + const maxContextFileBytes = 64 * 1024 + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + largeContent := strings.Repeat("a", maxContextFileBytes+100) + + resp, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: chat.ID, + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/AGENTS.md", + ContextFileContent: largeContent, + }}, + }) + require.NoError(t, err) + require.Equal(t, chat.ID, resp.ChatID) + require.Equal(t, 1, resp.Count) + + messages := requireAgentChatContextStoredMessages(t, requireAgentChatContextMessages(ctx, t, setup.db, chat.ID)) + require.Len(t, messages, 1) + require.Len(t, messages[0], 1) + require.True(t, messages[0][0].ContextFileTruncated) + require.Len(t, messages[0][0].ContextFileContent, maxContextFileBytes) + require.Equal(t, largeContent[:maxContextFileBytes], messages[0][0].ContextFileContent) + + cached := requireAgentChatContextCachedParts(ctx, t, setup.db, chat.ID) + require.Len(t, cached, 1) + require.True(t, cached[0].ContextFileTruncated) + }) + + t.Run("AddSanitizesBeforeApplyingContextFileSizeCap", func(t *testing.T) { + t.Parallel() + + const maxContextFileBytes = 64 * 1024 + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + + visible := strings.Repeat("a", maxContextFileBytes-1) + content := visible + strings.Repeat("\u200b", 100) + "z" + + resp, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: chat.ID, + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/AGENTS.md", + ContextFileContent: content, + }}, + }) + require.NoError(t, err) + require.Equal(t, chat.ID, resp.ChatID) + require.Equal(t, 1, resp.Count) + + messages := requireAgentChatContextStoredMessages(t, requireAgentChatContextMessages(ctx, t, setup.db, chat.ID)) + require.Len(t, messages, 1) + require.Len(t, messages[0], 1) + require.False(t, messages[0][0].ContextFileTruncated) + require.Equal(t, visible+"z", messages[0][0].ContextFileContent) + }) + + t.Run("ClearIsIdempotentWhenNoActiveChatExists", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + + resp, err := setup.agentClient.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{}) + require.NoError(t, err) + require.Equal(t, uuid.Nil, resp.ChatID) + }) + + t.Run("AddUsesWorkspaceOwnerChatWhenAnotherUsersChatIsActive", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + _, otherUser := coderdtest.CreateAnotherUser(t, setup.client, setup.user.OrganizationID) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + ownerChat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()+"-owner") + foreignChat := createAgentChatContextChat(ctx, t, setup.db, otherUser.ID, model.ID, setup.workspace.Agents[0].ID, t.Name()+"-foreign") + + resp, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/file.go", + ContextFileContent: "content", + }}, + }) + require.NoError(t, err) + require.Equal(t, ownerChat.ID, resp.ChatID) + + ownerMessages := requireAgentChatContextMessages(ctx, t, setup.db, ownerChat.ID) + require.Len(t, ownerMessages, 1) + require.Empty(t, requireAgentChatContextMessages(ctx, t, setup.db, foreignChat.ID)) + }) + + t.Run("AddUsesRootChatWhenOnlySubagentMakesActiveChatAmbiguous", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + rootChat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()+"-root") + childChat := createAgentChatContextChildChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, rootChat.ID, t.Name()+"-child") + + resp, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/file.go", + ContextFileContent: "content", + }}, + }) + require.NoError(t, err) + require.Equal(t, rootChat.ID, resp.ChatID) + + rootMessages := requireAgentChatContextMessages(ctx, t, setup.db, rootChat.ID) + require.Len(t, rootMessages, 1) + require.Empty(t, requireAgentChatContextMessages(ctx, t, setup.db, childChat.ID)) + }) + + t.Run("AddFailsWithMultipleActiveChats", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()+"-chat1") + createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()+"-chat2") + + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/file.go", + ContextFileContent: "content", + }}, + }) + sdkErr := requireSDKError(t, err, http.StatusConflict) + require.Contains(t, sdkErr.Message, "multiple active chats") + }) + + t.Run("ClearUsesRootChatWhenOnlySubagentMakesActiveChatAmbiguous", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + rootChat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()+"-root") + childChat := createAgentChatContextChildChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, rootChat.ID, t.Name()+"-child") + + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: rootChat.ID, + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/file.go", + ContextFileContent: "content", + }}, + }) + require.NoError(t, err) + + resp, err := setup.agentClient.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{}) + require.NoError(t, err) + require.Equal(t, rootChat.ID, resp.ChatID) + + require.Empty(t, requireAgentChatContextMessages(ctx, t, setup.db, rootChat.ID)) + require.Empty(t, requireAgentChatContextMessages(ctx, t, setup.db, childChat.ID)) + }) + + t.Run("ClearUsesWorkspaceOwnerChatWhenAnotherUsersChatIsActive", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + _, otherUser := coderdtest.CreateAnotherUser(t, setup.client, setup.user.OrganizationID) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + ownerChat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()+"-owner") + _ = createAgentChatContextChat(ctx, t, setup.db, otherUser.ID, model.ID, setup.workspace.Agents[0].ID, t.Name()+"-foreign") + + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: ownerChat.ID, + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/file.go", + ContextFileContent: "content", + }}, + }) + require.NoError(t, err) + + resp, err := setup.agentClient.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{}) + require.NoError(t, err) + require.Equal(t, ownerChat.ID, resp.ChatID) + require.Empty(t, requireAgentChatContextMessages(ctx, t, setup.db, ownerChat.ID)) + }) + + t.Run("ClearRejectsChatOwnedByAnotherUserOnSameAgent", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + _, otherUser := coderdtest.CreateAnotherUser(t, setup.client, setup.user.OrganizationID) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(ctx, t, setup.db, otherUser.ID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + + _, err := setup.agentClient.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{ChatID: chat.ID}) + sdkErr := requireSDKError(t, err, http.StatusForbidden) + require.Equal(t, "Chat does not belong to this workspace owner.", sdkErr.Message) + }) + + t.Run("AddFailsWhenChatIsNotActive", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(ctx, t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(ctx, t, setup.db, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + + _, err := setup.db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusCompleted, + }) + require.NoError(t, err) + + _, err = setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: chat.ID, + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/file.go", + ContextFileContent: "content", + }}, + }) + sdkErr := requireSDKError(t, err, http.StatusConflict) + require.Equal(t, "Cannot modify context: this chat is no longer active.", sdkErr.Message) + }) +} + +func requireAgentChatContextMessages(ctx context.Context, t testing.TB, db database.Store, chatID uuid.UUID) []database.ChatMessage { + t.Helper() + + messages, err := db.GetChatMessagesByChatID( + dbauthz.AsSystemRestricted(ctx), + database.GetChatMessagesByChatIDParams{ChatID: chatID, AfterID: 0}, + ) + require.NoError(t, err) + return messages +} + +func requireAgentChatContextCachedParts(ctx context.Context, t testing.TB, db database.Store, chatID uuid.UUID) []codersdk.ChatMessagePart { + t.Helper() + + chat, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chatID) + require.NoError(t, err) + require.True(t, chat.LastInjectedContext.Valid) + return requireAgentChatContextParts(t, chat.LastInjectedContext.RawMessage) +} + +func requireAgentChatContextStoredMessages(t testing.TB, messages []database.ChatMessage) [][]codersdk.ChatMessagePart { + t.Helper() + + stored := make([][]codersdk.ChatMessagePart, len(messages)) + for i, message := range messages { + require.Equal(t, database.ChatMessageRoleUser, message.Role) + require.True(t, message.Content.Valid) + stored[i] = requireAgentChatContextParts(t, message.Content.RawMessage) + } + return stored +} + +func agentChatContextExpectedMessages(agent database.WorkspaceAgent, messages [][]codersdk.ChatMessagePart) [][]codersdk.ChatMessagePart { + expected := make([][]codersdk.ChatMessagePart, len(messages)) + for i, parts := range messages { + expected[i] = agentChatContextExpectedStoredParts(agent, parts) + } + return expected +} + +func agentChatContextExpectedStoredParts(agent database.WorkspaceAgent, parts []codersdk.ChatMessagePart) []codersdk.ChatMessagePart { + expected := make([]codersdk.ChatMessagePart, len(parts)) + for i, part := range parts { + part.ContextFileAgentID = uuid.NullUUID{UUID: agent.ID, Valid: true} + if part.Type == codersdk.ChatMessagePartTypeContextFile { + part.ContextFileOS = agent.OperatingSystem + part.ContextFileDirectory = agentChatContextDirectory(agent) + } + expected[i] = part + } + return expected +} + +func agentChatContextExpectedCachedParts(agent database.WorkspaceAgent, parts []codersdk.ChatMessagePart) []codersdk.ChatMessagePart { + expected := make([]codersdk.ChatMessagePart, len(parts)) + for i, part := range parts { + part.ContextFileAgentID = uuid.NullUUID{UUID: agent.ID, Valid: true} + expected[i] = part + } + return expected +} + +func newAgentChatContextTestSetup(t *testing.T) agentChatContextTestSetup { + t.Helper() + + client, db := coderdtest.NewWithDatabase(t, nil) + user := coderdtest.CreateFirstUser(t, client) + workspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent().Do() + + return agentChatContextTestSetup{ + client: client, + db: db, + user: user, + workspace: workspace, + agentClient: agentsdk.New(client.URL, agentsdk.WithFixedToken(workspace.AgentToken)), + } +} + +func createAgentChatContextChat( + ctx context.Context, + t testing.TB, + db database.Store, + ownerID uuid.UUID, + modelConfigID uuid.UUID, + agentID uuid.UUID, + title string, +) database.Chat { + t.Helper() + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + Status: database.ChatStatusWaiting, + OwnerID: ownerID, + LastModelConfigID: modelConfigID, + Title: title, + AgentID: uuid.NullUUID{UUID: agentID, Valid: true}, + }) + require.NoError(t, err) + + return chat +} + +func createAgentChatContextChildChat( + ctx context.Context, + t testing.TB, + db database.Store, + ownerID uuid.UUID, + modelConfigID uuid.UUID, + agentID uuid.UUID, + parentChatID uuid.UUID, + title string, +) database.Chat { + t.Helper() + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + Status: database.ChatStatusWaiting, + OwnerID: ownerID, + LastModelConfigID: modelConfigID, + Title: title, + AgentID: uuid.NullUUID{UUID: agentID, Valid: true}, + ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}, + RootChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}, + }) + require.NoError(t, err) + + return chat +} + +func requireAgentChatContextParts(t testing.TB, raw json.RawMessage) []codersdk.ChatMessagePart { + t.Helper() + + var parts []codersdk.ChatMessagePart + require.NoError(t, json.Unmarshal(raw, &parts)) + return parts +} + +func agentChatContextDirectory(agent database.WorkspaceAgent) string { + if agent.ExpandedDirectory != "" { + return agent.ExpandedDirectory + } + return agent.Directory +} diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 93a6956c85..2cd6a3a7fc 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -2461,6 +2461,33 @@ type chainModeInfo struct { // trailingUserCount is the number of contiguous user messages // at the end of the conversation that form the current turn. trailingUserCount int + // contributingTrailingUserCount counts the trailing user + // messages that materially change the provider input. + contributingTrailingUserCount int +} + +func userMessageContributesToChainMode(msg database.ChatMessage) bool { + parts, err := chatprompt.ParseContent(msg) + if err != nil { + return false + } + for _, part := range parts { + switch part.Type { + case codersdk.ChatMessagePartTypeText, + codersdk.ChatMessagePartTypeReasoning: + if strings.TrimSpace(part.Text) != "" { + return true + } + case codersdk.ChatMessagePartTypeFile, + codersdk.ChatMessagePartTypeFileReference: + return true + case codersdk.ChatMessagePartTypeContextFile: + if part.ContextFileContent != "" { + return true + } + } + } + return false } // resolveChainMode scans DB messages from the end to count trailing user @@ -2470,11 +2497,13 @@ func resolveChainMode(messages []database.ChatMessage) chainModeInfo { var info chainModeInfo i := len(messages) - 1 for ; i >= 0; i-- { - if messages[i].Role == database.ChatMessageRoleUser { - info.trailingUserCount++ - continue + if messages[i].Role != database.ChatMessageRoleUser { + break + } + info.trailingUserCount++ + if userMessageContributesToChainMode(messages[i]) { + info.contributingTrailingUserCount++ } - break } for ; i >= 0; i-- { switch messages[i].Role { @@ -2497,15 +2526,15 @@ func resolveChainMode(messages []database.ChatMessage) chainModeInfo { return info } -// filterPromptForChainMode keeps only system messages and the last -// trailingUserCount user messages from the prompt. Assistant and tool -// messages are dropped because the provider already has them via the -// previous_response_id chain. +// filterPromptForChainMode keeps only system messages and the trailing +// user messages that still contribute model-visible content to the +// current turn. Assistant and tool messages are dropped because the +// provider already has them via the previous_response_id chain. func filterPromptForChainMode( prompt []fantasy.Message, - trailingUserCount int, + info chainModeInfo, ) []fantasy.Message { - if trailingUserCount <= 0 { + if info.contributingTrailingUserCount <= 0 { return prompt } @@ -2516,7 +2545,12 @@ func filterPromptForChainMode( } } - usersToSkip := totalUsers - trailingUserCount + // Prompt construction already drops user turns with no model-visible + // content, such as skill-only sentinel messages. That means the user + // count here stays aligned with contributingTrailingUserCount even + // when non-contributing DB turns are interleaved in the trailing + // block. + usersToSkip := totalUsers - info.contributingTrailingUserCount if usersToSkip < 0 { usersToSkip = 0 } @@ -2562,6 +2596,28 @@ func appendChatMessage( params.ProviderResponseID = append(params.ProviderResponseID, msg.providerResponseID) } +// BuildSingleChatMessageInsertParams creates batch insert params for one +// message using the shared chat message builder. +func BuildSingleChatMessageInsertParams( + chatID uuid.UUID, + role database.ChatMessageRole, + content pqtype.NullRawMessage, + visibility database.ChatMessageVisibility, + modelConfigID uuid.UUID, + contentVersion int16, + createdBy uuid.UUID, +) database.InsertChatMessagesParams { + params := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + ChatID: chatID, + } + msg := newChatMessage(role, content, visibility, modelConfigID, contentVersion) + if createdBy != uuid.Nil { + msg = msg.withCreatedBy(createdBy) + } + appendChatMessage(¶ms, msg) + return params +} + func insertUserMessageAndSetPending( ctx context.Context, store database.Store, @@ -4430,13 +4486,21 @@ func (p *Server) runChat( // the workspace agent has changed (e.g. workspace rebuilt). needsInstructionPersist := false hasContextFiles := false + persistedSkills := skillsFromParts(messages) + latestInjectedAgentID, hasLatestInjectedAgent := latestContextAgentID(messages) + currentWorkspaceAgentID := uuid.Nil + hasCurrentWorkspaceAgent := false if chat.WorkspaceID.Valid { + if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil { + currentWorkspaceAgentID = agent.ID + hasCurrentWorkspaceAgent = true + } persistedAgentID, found := contextFileAgentID(messages) hasContextFiles = found - if !hasContextFiles { + if !hasPersistedInstructionFiles(messages) { needsInstructionPersist = true - } else if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil && agent.ID != persistedAgentID { - // Agent changed — persist fresh instruction files. + } else if hasCurrentWorkspaceAgent && currentWorkspaceAgentID != persistedAgentID { + // Agent changed. Persist fresh instruction files. // Old context-file messages remain in the conversation // to preserve the prompt cache prefix. needsInstructionPersist = true @@ -4459,7 +4523,8 @@ func (p *Server) runChat( if needsInstructionPersist { g2.Go(func() error { var persistErr error - instruction, skills, persistErr = p.persistInstructionFiles( + var discoveredSkills []chattool.SkillMeta + instruction, discoveredSkills, persistErr = p.persistInstructionFiles( ctx, chat, modelConfig.ID, @@ -4471,6 +4536,12 @@ func (p *Server) runChat( return workspaceCtx.getWorkspaceConn(instructionCtx) }, ) + skills = selectSkillMetasForInstructionRefresh( + persistedSkills, + discoveredSkills, + uuid.NullUUID{UUID: currentWorkspaceAgentID, Valid: hasCurrentWorkspaceAgent}, + uuid.NullUUID{UUID: latestInjectedAgentID, Valid: hasLatestInjectedAgent}, + ) if persistErr != nil { p.logger.Warn(ctx, "failed to persist instruction files", slog.F("chat_id", chat.ID), @@ -4485,7 +4556,7 @@ func (p *Server) runChat( // re-injected via InsertSystem after compaction drops // those messages. No workspace dial needed. instruction = instructionFromContextFiles(messages) - skills = skillsFromParts(messages) + skills = persistedSkills } g2.Go(func() error { resolvedUserPrompt = p.resolveUserPrompt(ctx, chat.OwnerID) @@ -5103,14 +5174,14 @@ func (p *Server) runChat( // assistant and tool messages that the provider already has. chainModeActive := chatprovider.IsResponsesStoreEnabled(providerOptions) && chainInfo.previousResponseID != "" && - chainInfo.trailingUserCount > 0 && + chainInfo.contributingTrailingUserCount > 0 && chainInfo.modelConfigID == modelConfig.ID if chainModeActive { providerOptions = chatprovider.CloneWithPreviousResponseID( providerOptions, chainInfo.previousResponseID, ) - prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount) + prompt = filterPromptForChainMode(prompt, chainInfo) } err = chatloop.Run(ctx, chatloop.RunOptions{ Model: model, @@ -5164,7 +5235,7 @@ func (p *Server) runChat( if chainModeActive { reloadedPrompt = filterPromptForChainMode( reloadedPrompt, - chainInfo.trailingUserCount, + chainInfo, ) } return reloadedPrompt, nil @@ -5537,8 +5608,9 @@ func refreshChatWorkspaceSnapshot( } // contextFileAgentID extracts the workspace agent ID from the most -// recent persisted context-file parts. Returns uuid.Nil, false if no -// context-file parts exist. +// recent persisted instruction-file parts. The skill-only sentinel is +// ignored because it does not represent persisted instruction content. +// Returns uuid.Nil, false if no instruction-file parts exist. func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) { var lastID uuid.UUID found := false @@ -5551,11 +5623,14 @@ func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) { continue } for _, p := range parts { - if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFileAgentID.Valid { - lastID = p.ContextFileAgentID.UUID - found = true - break + if p.Type != codersdk.ChatMessagePartTypeContextFile || + !p.ContextFileAgentID.Valid || + p.ContextFilePath == AgentChatContextSentinelPath { + continue } + lastID = p.ContextFileAgentID.UUID + found = true + break } } return lastID, found @@ -5625,13 +5700,14 @@ func (p *Server) persistInstructionFiles( // agent cannot know its own UUID, OS metadata, or // directory — those are added here at the trust boundary. var discoveredSkills []chattool.SkillMeta - var hasContent bool + var hasContent, hasContextFilePart bool agentID := uuid.NullUUID{UUID: agent.ID, Valid: true} for i := range agentParts { agentParts[i].ContextFileAgentID = agentID switch agentParts[i].Type { case codersdk.ChatMessagePartTypeContextFile: + hasContextFilePart = true agentParts[i].ContextFileContent = SanitizePromptText(agentParts[i].ContextFileContent) agentParts[i].ContextFileOS = agent.OperatingSystem agentParts[i].ContextFileDirectory = directory @@ -5652,13 +5728,13 @@ func (p *Server) persistInstructionFiles( if !workspaceConnOK { return "", nil, nil } - // Persist a sentinel (plus any skill-only parts) so - // subsequent turns skip the workspace agent dial. - if len(agentParts) == 0 { - agentParts = []codersdk.ChatMessagePart{{ + // Persist a blank context-file marker (plus any skill-only + // parts) so subsequent turns skip the workspace agent dial. + if !hasContextFilePart { + agentParts = append([]codersdk.ChatMessagePart{{ Type: codersdk.ChatMessagePartTypeContextFile, ContextFileAgentID: agentID, - }} + }}, agentParts...) } content, err := chatprompt.MarshalParts(agentParts) if err != nil { diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index a256200d3f..bf5ef8eaec 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "charm.land/fantasy" "github.com/google/uuid" "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/require" @@ -703,7 +704,33 @@ func TestPersistInstructionFilesSentinelWithSkills(t *testing.T) { gomock.Any(), agentID, ).Return(workspaceAgent, nil).Times(1) - db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + db.EXPECT().InsertChatMessages(gomock.Any(), + gomock.Cond(func(x any) bool { + arg, ok := x.(database.InsertChatMessagesParams) + if !ok || arg.ChatID != chat.ID || len(arg.Content) != 1 { + return false + } + var parts []codersdk.ChatMessagePart + if err := json.Unmarshal([]byte(arg.Content[0]), &parts); err != nil { + return false + } + foundMarker := false + foundSkill := false + for _, p := range parts { + switch p.Type { + case codersdk.ChatMessagePartTypeContextFile: + if p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) && p.ContextFileContent == "" { + foundMarker = true + } + case codersdk.ChatMessagePartTypeSkill: + if p.SkillName == "my-skill" && p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) { + foundSkill = true + } + } + } + return foundMarker && foundSkill + }), + ).Return(nil, nil).Times(1) db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), gomock.Cond(func(x any) bool { arg, ok := x.(database.UpdateChatLastInjectedContextParams) @@ -2020,6 +2047,30 @@ func TestContextFileAgentID(t *testing.T) { require.True(t, ok) }) + t.Run("IgnoresSkillOnlySentinel", func(t *testing.T) { + t.Parallel() + instructionAgentID := uuid.New() + sentinelAgentID := uuid.New() + msgs := []database.ChatMessage{ + chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/AGENTS.md", + ContextFileAgentID: uuid.NullUUID{UUID: instructionAgentID, Valid: true}, + }}), + chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: AgentChatContextSentinelPath, + ContextFileAgentID: uuid.NullUUID{ + UUID: sentinelAgentID, + Valid: true, + }, + }}), + } + id, ok := contextFileAgentID(msgs) + require.Equal(t, instructionAgentID, id) + require.True(t, ok) + }) + t.Run("SentinelWithoutAgentID", func(t *testing.T) { t.Parallel() msgs := []database.ChatMessage{ @@ -2036,6 +2087,492 @@ func TestContextFileAgentID(t *testing.T) { }) } +func TestHasPersistedInstructionFiles(t *testing.T) { + t.Parallel() + + t.Run("IgnoresAgentChatContextSentinel", func(t *testing.T) { + t.Parallel() + agentID := uuid.New() + msgs := []database.ChatMessage{ + chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: AgentChatContextSentinelPath, + ContextFileAgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + }}), + } + require.False(t, hasPersistedInstructionFiles(msgs)) + }) + + t.Run("AcceptsPersistedInstructionFile", func(t *testing.T) { + t.Parallel() + agentID := uuid.New() + msgs := []database.ChatMessage{ + chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/AGENTS.md", + ContextFileContent: "repo instructions", + ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true}, + }}), + } + require.True(t, hasPersistedInstructionFiles(msgs)) + }) +} + +func TestInstructionFromContextFilesUsesLatestContextAgent(t *testing.T) { + t.Parallel() + + oldAgentID := uuid.New() + newAgentID := uuid.New() + msgs := []database.ChatMessage{ + chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/old/AGENTS.md", + ContextFileContent: "old instructions", + ContextFileOS: "darwin", + ContextFileDirectory: "/old", + ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true}, + }}), + chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/new/AGENTS.md", + ContextFileContent: "new instructions", + ContextFileOS: "linux", + ContextFileDirectory: "/new", + ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true}, + }}), + } + + got := instructionFromContextFiles(msgs) + require.Contains(t, got, "new instructions") + require.Contains(t, got, "Operating System: linux") + require.Contains(t, got, "Working Directory: /new") + require.NotContains(t, got, "old instructions") + require.NotContains(t, got, "Operating System: darwin") +} + +func TestInstructionFromContextFilesKeepsLegacyUnstampedParts(t *testing.T) { + t.Parallel() + + oldAgentID := uuid.New() + newAgentID := uuid.New() + msgs := []database.ChatMessage{ + chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/legacy/AGENTS.md", + ContextFileContent: "legacy instructions", + }}), + chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/old/AGENTS.md", + ContextFileContent: "old instructions", + ContextFileOS: "darwin", + ContextFileDirectory: "/old", + ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true}, + }}), + chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/new/AGENTS.md", + ContextFileContent: "new instructions", + ContextFileOS: "linux", + ContextFileDirectory: "/new", + ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true}, + }}), + } + + got := instructionFromContextFiles(msgs) + require.Contains(t, got, "legacy instructions") + require.Contains(t, got, "new instructions") + require.Contains(t, got, "Operating System: linux") + require.Contains(t, got, "Working Directory: /new") + require.NotContains(t, got, "old instructions") + require.NotContains(t, got, "Operating System: darwin") +} + +func TestSkillsFromPartsKeepsLegacyUnstampedParts(t *testing.T) { + t.Parallel() + + oldAgentID := uuid.New() + newAgentID := uuid.New() + msgs := []database.ChatMessage{ + chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper-legacy", + SkillDir: "/skills/repo-helper-legacy", + }}), + chatMessageWithParts([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/old/AGENTS.md", + ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true}, + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper-old", + SkillDir: "/skills/repo-helper-old", + ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true}, + }, + }), + chatMessageWithParts([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: AgentChatContextSentinelPath, + ContextFileAgentID: uuid.NullUUID{ + UUID: newAgentID, + Valid: true, + }, + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper-new", + SkillDir: "/skills/repo-helper-new", + ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true}, + }, + }), + } + + got := skillsFromParts(msgs) + require.Equal(t, []chattool.SkillMeta{ + {Name: "repo-helper-legacy", Dir: "/skills/repo-helper-legacy"}, + {Name: "repo-helper-new", Dir: "/skills/repo-helper-new"}, + }, got) +} + +func TestSkillsFromPartsUsesLatestContextAgent(t *testing.T) { + t.Parallel() + + oldAgentID := uuid.New() + newAgentID := uuid.New() + msgs := []database.ChatMessage{ + chatMessageWithParts([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/old/AGENTS.md", + ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true}, + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper-old", + SkillDir: "/skills/repo-helper-old", + ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true}, + }, + }), + chatMessageWithParts([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: AgentChatContextSentinelPath, + ContextFileAgentID: uuid.NullUUID{ + UUID: newAgentID, + Valid: true, + }, + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper-new", + SkillDir: "/skills/repo-helper-new", + ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true}, + }, + }), + } + + got := skillsFromParts(msgs) + require.Equal(t, []chattool.SkillMeta{{ + Name: "repo-helper-new", + Dir: "/skills/repo-helper-new", + }}, got) +} + +func TestMergeSkillMetas(t *testing.T) { + t.Parallel() + + persisted := []chattool.SkillMeta{{ + Name: "repo-helper", + Description: "Persisted skill", + Dir: "/skills/repo-helper-old", + }} + discovered := []chattool.SkillMeta{ + { + Name: "repo-helper", + Description: "Discovered replacement", + Dir: "/skills/repo-helper-new", + MetaFile: "SKILL.md", + }, + { + Name: "deep-review", + Description: "Discovered skill", + Dir: "/skills/deep-review", + }, + } + + got := mergeSkillMetas(persisted, discovered) + require.Equal(t, []chattool.SkillMeta{ + discovered[0], + discovered[1], + }, got) +} + +func TestSelectSkillMetasForInstructionRefresh(t *testing.T) { + t.Parallel() + + persisted := []chattool.SkillMeta{{Name: "persisted", Dir: "/skills/persisted"}} + discovered := []chattool.SkillMeta{{Name: "discovered", Dir: "/skills/discovered"}} + currentAgentID := uuid.New() + otherAgentID := uuid.New() + + t.Run("MergesCurrentAgentSkills", func(t *testing.T) { + t.Parallel() + got := selectSkillMetasForInstructionRefresh( + persisted, + discovered, + uuid.NullUUID{UUID: currentAgentID, Valid: true}, + uuid.NullUUID{UUID: currentAgentID, Valid: true}, + ) + require.Equal(t, []chattool.SkillMeta{discovered[0], persisted[0]}, got) + }) + + t.Run("DropsStalePersistedSkillsWhenAgentChanged", func(t *testing.T) { + t.Parallel() + got := selectSkillMetasForInstructionRefresh( + persisted, + discovered, + uuid.NullUUID{UUID: currentAgentID, Valid: true}, + uuid.NullUUID{UUID: otherAgentID, Valid: true}, + ) + require.Equal(t, discovered, got) + }) + + t.Run("PreservesPersistedSkillsWhenAgentLookupFails", func(t *testing.T) { + t.Parallel() + got := selectSkillMetasForInstructionRefresh( + persisted, + nil, + uuid.NullUUID{}, + uuid.NullUUID{UUID: otherAgentID, Valid: true}, + ) + require.Equal(t, persisted, got) + }) +} + +func TestResolveChainModeIgnoresSkillOnlySentinelMessages(t *testing.T) { + t.Parallel() + + modelConfigID := uuid.New() + assistant := database.ChatMessage{ + Role: database.ChatMessageRoleAssistant, + ProviderResponseID: sql.NullString{String: "resp-123", Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true}, + } + skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: AgentChatContextSentinelPath, + ContextFileAgentID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper", + SkillDir: "/skills/repo-helper", + }, + }) + skillOnly.Role = database.ChatMessageRoleUser + user := chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeText, + Text: "latest user message", + }}) + user.Role = database.ChatMessageRoleUser + + got := resolveChainMode([]database.ChatMessage{assistant, skillOnly, user}) + require.Equal(t, "resp-123", got.previousResponseID) + require.Equal(t, modelConfigID, got.modelConfigID) + require.Equal(t, 2, got.trailingUserCount) + require.Equal(t, 1, got.contributingTrailingUserCount) +} + +func TestFilterPromptForChainModeKeepsContributingUsersAcrossSkippedSentinelTurns(t *testing.T) { + t.Parallel() + + modelConfigID := uuid.New() + priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeText, + Text: "prior user message", + }}) + priorUser.Role = database.ChatMessageRoleUser + assistant := database.ChatMessage{ + Role: database.ChatMessageRoleAssistant, + ProviderResponseID: sql.NullString{String: "resp-123", Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true}, + } + firstTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeText, + Text: "first trailing user", + }}) + firstTrailingUser.Role = database.ChatMessageRoleUser + skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: AgentChatContextSentinelPath, + ContextFileAgentID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper", + SkillDir: "/skills/repo-helper", + }, + }) + skillOnly.Role = database.ChatMessageRoleUser + lastTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeText, + Text: "last trailing user", + }}) + lastTrailingUser.Role = database.ChatMessageRoleUser + + chainInfo := resolveChainMode([]database.ChatMessage{ + priorUser, + assistant, + firstTrailingUser, + skillOnly, + lastTrailingUser, + }) + require.Equal(t, 3, chainInfo.trailingUserCount) + require.Equal(t, 2, chainInfo.contributingTrailingUserCount) + + prompt := []fantasy.Message{ + { + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "system instruction"}, + }, + }, + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "prior user message"}, + }, + }, + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "assistant reply"}, + }, + }, + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "first trailing user"}, + }, + }, + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "last trailing user"}, + }, + }, + } + + got := filterPromptForChainMode(prompt, chainInfo) + require.Len(t, got, 3) + require.Equal(t, fantasy.MessageRoleSystem, got[0].Role) + require.Equal(t, fantasy.MessageRoleUser, got[1].Role) + require.Equal(t, fantasy.MessageRoleUser, got[2].Role) + + firstPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0]) + require.True(t, ok) + require.Equal(t, "first trailing user", firstPart.Text) + lastPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[2].Content[0]) + require.True(t, ok) + require.Equal(t, "last trailing user", lastPart.Text) +} + +func TestFilterPromptForChainModeUsesContributingTrailingUsers(t *testing.T) { + t.Parallel() + + modelConfigID := uuid.New() + priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeText, + Text: "prior user message", + }}) + priorUser.Role = database.ChatMessageRoleUser + assistant := database.ChatMessage{ + Role: database.ChatMessageRoleAssistant, + ProviderResponseID: sql.NullString{String: "resp-123", Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true}, + } + skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: AgentChatContextSentinelPath, + ContextFileAgentID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper", + SkillDir: "/skills/repo-helper", + }, + }) + skillOnly.Role = database.ChatMessageRoleUser + latestUser := chatMessageWithParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeText, + Text: "latest user message", + }}) + latestUser.Role = database.ChatMessageRoleUser + + chainInfo := resolveChainMode([]database.ChatMessage{ + priorUser, + assistant, + skillOnly, + latestUser, + }) + require.Equal(t, 2, chainInfo.trailingUserCount) + require.Equal(t, 1, chainInfo.contributingTrailingUserCount) + + prompt := []fantasy.Message{ + { + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "system instruction"}, + }, + }, + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "prior user message"}, + }, + }, + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "assistant reply"}, + }, + }, + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "latest user message"}, + }, + }, + } + + got := filterPromptForChainMode(prompt, chainInfo) + require.Len(t, got, 2) + require.Equal(t, fantasy.MessageRoleSystem, got[0].Role) + require.Equal(t, fantasy.MessageRoleUser, got[1].Role) + + part, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0]) + require.True(t, ok) + require.Equal(t, "latest user message", part.Text) +} + func chatMessageWithParts(parts []codersdk.ChatMessagePart) database.ChatMessage { raw, _ := json.Marshal(parts) return database.ChatMessage{ diff --git a/coderd/x/chatd/contextparts.go b/coderd/x/chatd/contextparts.go new file mode 100644 index 0000000000..b013620b8c --- /dev/null +++ b/coderd/x/chatd/contextparts.go @@ -0,0 +1,153 @@ +package chatd + +import ( + "context" + "encoding/json" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/codersdk" +) + +// AgentChatContextSentinelPath marks the synthetic empty context-file +// part used to preserve skill-only workspace-agent additions across +// turns without treating them as persisted instruction files. +const AgentChatContextSentinelPath = ".coder/agent-chat-context-sentinel" + +// FilterContextParts keeps only context-file and skill parts from parts. +// When keepEmptyContextFiles is false, context-file parts with empty +// content are dropped. When keepEmptyContextFiles is true, empty +// context-file parts are preserved. +// revive:disable-next-line:flag-parameter // Required by shared helper callers. +func FilterContextParts( + parts []codersdk.ChatMessagePart, + keepEmptyContextFiles bool, +) []codersdk.ChatMessagePart { + var filtered []codersdk.ChatMessagePart + for _, part := range parts { + switch part.Type { + case codersdk.ChatMessagePartTypeContextFile: + if !keepEmptyContextFiles && part.ContextFileContent == "" { + continue + } + case codersdk.ChatMessagePartTypeSkill: + default: + continue + } + filtered = append(filtered, part) + } + return filtered +} + +// CollectContextPartsFromMessages unmarshals chat message content and +// collects the context-file and skill parts it contains. When +// keepEmptyContextFiles is false, empty context-file parts are skipped. +// When it is true, empty context-file parts are included in the result. +func CollectContextPartsFromMessages( + ctx context.Context, + logger slog.Logger, + messages []database.ChatMessage, + keepEmptyContextFiles bool, +) ([]codersdk.ChatMessagePart, error) { + var collected []codersdk.ChatMessagePart + for _, msg := range messages { + if !msg.Content.Valid { + continue + } + + var parts []codersdk.ChatMessagePart + if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil { + logger.Warn(ctx, "skipping malformed chat context message", + slog.F("chat_message_id", msg.ID), + slog.Error(err), + ) + continue + } + + collected = append( + collected, + FilterContextParts(parts, keepEmptyContextFiles)..., + ) + } + + return collected, nil +} + +func latestContextAgentIDFromParts(parts []codersdk.ChatMessagePart) (uuid.UUID, bool) { + var lastID uuid.UUID + found := false + for _, part := range parts { + if part.Type != codersdk.ChatMessagePartTypeContextFile || + !part.ContextFileAgentID.Valid { + continue + } + lastID = part.ContextFileAgentID.UUID + found = true + } + return lastID, found +} + +// FilterContextPartsToLatestAgent keeps parts stamped with the latest +// workspace-agent ID seen in the slice, plus legacy unstamped parts. +// When no stamped context-file parts exist, it returns the original +// slice unchanged. +func FilterContextPartsToLatestAgent(parts []codersdk.ChatMessagePart) []codersdk.ChatMessagePart { + latestAgentID, ok := latestContextAgentIDFromParts(parts) + if !ok { + return parts + } + + filtered := make([]codersdk.ChatMessagePart, 0, len(parts)) + for _, part := range parts { + switch part.Type { + case codersdk.ChatMessagePartTypeContextFile, + codersdk.ChatMessagePartTypeSkill: + if part.ContextFileAgentID.Valid && + part.ContextFileAgentID.UUID != latestAgentID { + continue + } + default: + continue + } + filtered = append(filtered, part) + } + return filtered +} + +// BuildLastInjectedContext filters parts down to non-empty context-file +// and skill parts, strips their internal fields, and marshals the +// result for LastInjectedContext. A nil or fully filtered input returns +// an invalid NullRawMessage. +func BuildLastInjectedContext( + parts []codersdk.ChatMessagePart, +) (pqtype.NullRawMessage, error) { + if parts == nil { + return pqtype.NullRawMessage{Valid: false}, nil + } + + filtered := FilterContextParts(parts, false) + if len(filtered) == 0 { + return pqtype.NullRawMessage{Valid: false}, nil + } + + stripped := make([]codersdk.ChatMessagePart, 0, len(filtered)) + for _, part := range filtered { + cp := part + cp.StripInternal() + stripped = append(stripped, cp) + } + + raw, err := json.Marshal(stripped) + if err != nil { + return pqtype.NullRawMessage{}, xerrors.Errorf( + "marshal injected context: %w", + err, + ) + } + + return pqtype.NullRawMessage{RawMessage: raw, Valid: true}, nil +} diff --git a/coderd/x/chatd/instruction.go b/coderd/x/chatd/instruction.go index 39575196ef..02f6dc675a 100644 --- a/coderd/x/chatd/instruction.go +++ b/coderd/x/chatd/instruction.go @@ -5,6 +5,8 @@ import ( "encoding/json" "strings" + "github.com/google/uuid" + "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/x/chatd/chattool" "github.com/coder/coder/v2/codersdk" @@ -57,6 +59,34 @@ func formatSystemInstructions( return b.String() } +// latestContextAgentID returns the most recent workspace-agent ID seen +// on any persisted context-file part, including the skill-only sentinel. +// Returns uuid.Nil, false when no stamped context-file parts exist. +func latestContextAgentID(messages []database.ChatMessage) (uuid.UUID, bool) { + var lastID uuid.UUID + found := false + for _, msg := range messages { + if !msg.Content.Valid || + !bytes.Contains(msg.Content.RawMessage, []byte(`"context-file"`)) { + continue + } + var parts []codersdk.ChatMessagePart + if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil { + continue + } + for _, part := range parts { + if part.Type != codersdk.ChatMessagePartTypeContextFile || + !part.ContextFileAgentID.Valid { + continue + } + lastID = part.ContextFileAgentID.UUID + found = true + break + } + } + return lastID, found +} + // instructionFromContextFiles reconstructs the formatted instruction // string from persisted context-file parts. This is used on non-first // turns so the instruction can be re-injected after compaction @@ -64,6 +94,7 @@ func formatSystemInstructions( func instructionFromContextFiles( messages []database.ChatMessage, ) string { + filterAgentID, filterByAgent := latestContextAgentID(messages) var contextParts []codersdk.ChatMessagePart var os, dir string for _, msg := range messages { @@ -79,6 +110,10 @@ func instructionFromContextFiles( if part.Type != codersdk.ChatMessagePartTypeContextFile { continue } + if filterByAgent && part.ContextFileAgentID.Valid && + part.ContextFileAgentID.UUID != filterAgentID { + continue + } if part.ContextFileOS != "" { os = part.ContextFileOS } @@ -93,6 +128,80 @@ func instructionFromContextFiles( return formatSystemInstructions(os, dir, contextParts) } +// hasPersistedInstructionFiles reports whether messages include a +// persisted context-file part that should suppress another baseline +// instruction-file lookup. The workspace-agent skill-only sentinel is +// ignored so default instructions still load on fresh chats. +func hasPersistedInstructionFiles( + messages []database.ChatMessage, +) bool { + for _, msg := range messages { + if !msg.Content.Valid || + !bytes.Contains(msg.Content.RawMessage, []byte(`"context-file"`)) { + continue + } + var parts []codersdk.ChatMessagePart + if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil { + continue + } + for _, part := range parts { + if part.Type != codersdk.ChatMessagePartTypeContextFile || + !part.ContextFileAgentID.Valid || + part.ContextFilePath == AgentChatContextSentinelPath { + continue + } + return true + } + } + return false +} + +func mergeSkillMetas( + persisted []chattool.SkillMeta, + discovered []chattool.SkillMeta, +) []chattool.SkillMeta { + if len(persisted) == 0 { + return discovered + } + if len(discovered) == 0 { + return persisted + } + + seen := make(map[string]struct{}, len(persisted)+len(discovered)) + merged := make([]chattool.SkillMeta, 0, len(persisted)+len(discovered)) + appendUnique := func(skill chattool.SkillMeta) { + if _, ok := seen[skill.Name]; ok { + return + } + seen[skill.Name] = struct{}{} + merged = append(merged, skill) + } + for _, skill := range discovered { + appendUnique(skill) + } + for _, skill := range persisted { + appendUnique(skill) + } + return merged +} + +// selectSkillMetasForInstructionRefresh chooses which skill metadata +// should be injected on a turn that refreshes instruction files. +func selectSkillMetasForInstructionRefresh( + persisted []chattool.SkillMeta, + discovered []chattool.SkillMeta, + currentAgentID uuid.NullUUID, + latestInjectedAgentID uuid.NullUUID, +) []chattool.SkillMeta { + if currentAgentID.Valid && latestInjectedAgentID.Valid && latestInjectedAgentID.UUID == currentAgentID.UUID { + return mergeSkillMetas(persisted, discovered) + } + if !currentAgentID.Valid && len(discovered) == 0 { + return persisted + } + return discovered +} + // skillsFromParts reconstructs skill metadata from persisted // skill parts. This is analogous to instructionFromContextFiles // so the skill index can be re-injected after compaction without @@ -100,6 +209,7 @@ func instructionFromContextFiles( func skillsFromParts( messages []database.ChatMessage, ) []chattool.SkillMeta { + filterAgentID, filterByAgent := latestContextAgentID(messages) var skills []chattool.SkillMeta for _, msg := range messages { if !msg.Content.Valid || @@ -114,6 +224,10 @@ func skillsFromParts( if part.Type != codersdk.ChatMessagePartTypeSkill { continue } + if filterByAgent && part.ContextFileAgentID.Valid && + part.ContextFileAgentID.UUID != filterAgentID { + continue + } skills = append(skills, chattool.SkillMeta{ Name: part.SkillName, Description: part.SkillDescription, diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 7be45aa102..330c22029a 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -11,6 +11,7 @@ import ( "charm.land/fantasy" "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" "golang.org/x/xerrors" "cdr.dev/slog/v3" @@ -361,48 +362,19 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database. return fantasy.NewTextErrorResponse(err.Error()), nil } - prompt := strings.TrimSpace(args.Prompt) - if prompt == "" { - return fantasy.NewTextErrorResponse("prompt is required"), nil - } - - title := strings.TrimSpace(args.Title) - if title == "" { - title = subagentFallbackChatTitle(prompt) - } - - rootChatID := parent.ID - if parent.RootChatID.Valid { - rootChatID = parent.RootChatID.UUID - } - if parent.LastModelConfigID == uuid.Nil { - return fantasy.NewTextErrorResponse("parent chat model config id is required"), nil - } - - // Create the child chat with Mode set to - // computer_use. This signals runChat to use the - // predefined computer use model and include the - // computer tool. - childChat, err := p.CreateChat(ctx, CreateOptions{ - OwnerID: parent.OwnerID, - WorkspaceID: parent.WorkspaceID, - BuildID: parent.BuildID, - AgentID: parent.AgentID, - ParentChatID: uuid.NullUUID{ - UUID: parent.ID, - Valid: true, + childChat, err := p.createChildSubagentChatWithOptions( + ctx, + parent, + args.Prompt, + args.Title, + childSubagentChatOptions{ + chatMode: database.NullChatMode{ + ChatMode: database.ChatModeComputerUse, + Valid: true, + }, + systemPrompt: computerUseSubagentSystemPrompt + "\n\n" + strings.TrimSpace(args.Prompt), }, - RootChatID: uuid.NullUUID{ - UUID: rootChatID, - Valid: true, - }, - ModelConfigID: parent.LastModelConfigID, - Title: title, - ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true}, - SystemPrompt: computerUseSubagentSystemPrompt + "\n\n" + prompt, - InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)}, - MCPServerIDs: parent.MCPServerIDs, - }) + ) if err != nil { return fantasy.NewTextErrorResponse(err.Error()), nil } @@ -427,11 +399,26 @@ func parseSubagentToolChatID(raw string) (uuid.UUID, error) { return chatID, nil } +type childSubagentChatOptions struct { + chatMode database.NullChatMode + systemPrompt string +} + func (p *Server) createChildSubagentChat( ctx context.Context, parent database.Chat, prompt string, title string, +) (database.Chat, error) { + return p.createChildSubagentChatWithOptions(ctx, parent, prompt, title, childSubagentChatOptions{}) +} + +func (p *Server) createChildSubagentChatWithOptions( + ctx context.Context, + parent database.Chat, + prompt string, + title string, + opts childSubagentChatOptions, ) (database.Chat, error) { if parent.ParentChatID.Valid { return database.Chat{}, xerrors.New("delegated chats cannot create child subagents") @@ -455,31 +442,251 @@ func (p *Server) createChildSubagentChat( return database.Chat{}, xerrors.New("parent chat model config id is required") } - child, err := p.CreateChat(ctx, CreateOptions{ - OwnerID: parent.OwnerID, - WorkspaceID: parent.WorkspaceID, - BuildID: parent.BuildID, - AgentID: parent.AgentID, - ParentChatID: uuid.NullUUID{ - UUID: parent.ID, - Valid: true, - }, - RootChatID: uuid.NullUUID{ - UUID: rootChatID, - Valid: true, - }, - ModelConfigID: parent.LastModelConfigID, - Title: title, - InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)}, - MCPServerIDs: parent.MCPServerIDs, - }) - if err != nil { - return database.Chat{}, xerrors.Errorf("create child chat: %w", err) + mcpServerIDs := parent.MCPServerIDs + if mcpServerIDs == nil { + mcpServerIDs = []uuid.UUID{} } + labelsJSON, err := json.Marshal(database.StringMap{}) + if err != nil { + return database.Chat{}, xerrors.Errorf("marshal labels: %w", err) + } + childSystemPrompt := SanitizePromptText(opts.systemPrompt) + + var child database.Chat + txErr := p.db.InTx(func(tx database.Store) error { + if limitErr := p.checkUsageLimit(ctx, tx, parent.OwnerID); limitErr != nil { + return limitErr + } + + insertedChat, err := tx.InsertChat(ctx, database.InsertChatParams{ + OwnerID: parent.OwnerID, + WorkspaceID: parent.WorkspaceID, + BuildID: parent.BuildID, + AgentID: parent.AgentID, + ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true}, + RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}, + LastModelConfigID: parent.LastModelConfigID, + Title: title, + Mode: opts.chatMode, + Status: database.ChatStatusPending, + MCPServerIDs: mcpServerIDs, + Labels: pqtype.NullRawMessage{ + RawMessage: labelsJSON, + Valid: true, + }, + DynamicTools: pqtype.NullRawMessage{}, + }) + if err != nil { + return xerrors.Errorf("insert child chat: %w", err) + } + + deploymentPrompt := p.resolveDeploymentSystemPrompt(ctx) + workspaceAwareness := "There is no workspace associated with this chat yet. Create one using the create_workspace tool before using workspace tools like execute, read_file, write_file, etc." + if insertedChat.WorkspaceID.Valid { + workspaceAwareness = "This chat is attached to a workspace. You can use workspace tools like execute, read_file, write_file, etc." + } + workspaceAwarenessContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText(workspaceAwareness), + }) + if err != nil { + return xerrors.Errorf("marshal workspace awareness: %w", err) + } + userContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)}) + if err != nil { + return xerrors.Errorf("marshal initial user content: %w", err) + } + + systemParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + ChatID: insertedChat.ID, + } + if deploymentPrompt != "" { + deploymentContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText(deploymentPrompt), + }) + if err != nil { + return xerrors.Errorf("marshal deployment system prompt: %w", err) + } + appendChatMessage(&systemParams, newChatMessage( + database.ChatMessageRoleSystem, + deploymentContent, + database.ChatMessageVisibilityModel, + parent.LastModelConfigID, + chatprompt.CurrentContentVersion, + )) + } + if childSystemPrompt != "" { + childSystemPromptContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText(childSystemPrompt), + }) + if err != nil { + return xerrors.Errorf("marshal child system prompt: %w", err) + } + appendChatMessage(&systemParams, newChatMessage( + database.ChatMessageRoleSystem, + childSystemPromptContent, + database.ChatMessageVisibilityModel, + parent.LastModelConfigID, + chatprompt.CurrentContentVersion, + )) + } + appendChatMessage(&systemParams, newChatMessage( + database.ChatMessageRoleSystem, + workspaceAwarenessContent, + database.ChatMessageVisibilityModel, + parent.LastModelConfigID, + chatprompt.CurrentContentVersion, + )) + if _, err := tx.InsertChatMessages(ctx, systemParams); err != nil { + return xerrors.Errorf("insert initial child system messages: %w", err) + } + + child = insertedChat + + // Copy persisted context before the initial child prompt so the + // child cannot be acquired until its inherited context is in + // place. signalWake runs only after commit. + copiedContextParts, err := copyParentContextMessages(ctx, p.logger, tx, parent, child) + if err != nil { + return xerrors.Errorf("copy parent context messages: %w", err) + } + if err := updateChildLastInjectedContext(ctx, p.logger, tx, child.ID, copiedContextParts); err != nil { + return xerrors.Errorf("update child injected context: %w", err) + } + + userParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + ChatID: insertedChat.ID, + } + appendChatMessage(&userParams, newChatMessage( + database.ChatMessageRoleUser, + userContent, + database.ChatMessageVisibilityBoth, + parent.LastModelConfigID, + chatprompt.CurrentContentVersion, + ).withCreatedBy(parent.OwnerID)) + if _, err := tx.InsertChatMessages(ctx, userParams); err != nil { + return xerrors.Errorf("insert initial child user message: %w", err) + } + + return nil + }, nil) + if txErr != nil { + return database.Chat{}, xerrors.Errorf("create child chat: %w", txErr) + } + + p.publishChatPubsubEvent(child, coderdpubsub.ChatEventKindCreated, nil) + p.signalWake() return child, nil } +// copyParentContextMessages reads persisted context-file and skill +// messages from the parent chat and inserts copies into the child +// chat. This ensures sub-agents inherit the same instruction and +// skill context as their parent without independently re-fetching +// from the agent. +func copyParentContextMessages( + ctx context.Context, + logger slog.Logger, + store database.Store, + parent database.Chat, + child database.Chat, +) ([]codersdk.ChatMessagePart, error) { + parentMessages, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: parent.ID, + AfterID: 0, + }) + if err != nil { + return nil, xerrors.Errorf("get parent messages: %w", err) + } + + var ( + copiedParts []codersdk.ChatMessagePart + copiedRole database.ChatMessageRole + copiedVisibility database.ChatMessageVisibility + copiedVersion int16 + ) + for _, msg := range parentMessages { + if !msg.Content.Valid { + continue + } + var parts []codersdk.ChatMessagePart + if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil { + logger.Warn(ctx, "failed to unmarshal parent context message", + slog.F("parent_chat_id", parent.ID), + slog.F("message_id", msg.ID), + slog.Error(err), + ) + continue + } + + messageContextParts := FilterContextParts(parts, true) + if len(messageContextParts) == 0 { + continue + } + if copiedParts == nil { + copiedRole = msg.Role + copiedVisibility = msg.Visibility + copiedVersion = msg.ContentVersion + } + copiedParts = append(copiedParts, messageContextParts...) + } + if len(copiedParts) == 0 { + return nil, nil + } + + copiedParts = FilterContextPartsToLatestAgent(copiedParts) + filteredContent, err := chatprompt.MarshalParts(copiedParts) + if err != nil { + return nil, xerrors.Errorf("marshal filtered context parts: %w", err) + } + + msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + ChatID: child.ID, + } + appendChatMessage(&msgParams, newChatMessage( + copiedRole, + filteredContent, + copiedVisibility, + child.LastModelConfigID, + copiedVersion, + )) + if _, err := store.InsertChatMessages(ctx, msgParams); err != nil { + return nil, xerrors.Errorf("insert context message: %w", err) + } + + return copiedParts, nil +} + +func updateChildLastInjectedContext( + ctx context.Context, + logger slog.Logger, + store database.Store, + chatID uuid.UUID, + parts []codersdk.ChatMessagePart, +) error { + parts = FilterContextPartsToLatestAgent(parts) + param, err := BuildLastInjectedContext(parts) + if err != nil { + logger.Warn(ctx, "failed to marshal inherited injected context", + slog.F("chat_id", chatID), + slog.Error(err), + ) + return xerrors.Errorf("marshal inherited injected context: %w", err) + } + if _, err := store.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{ + ID: chatID, + LastInjectedContext: param, + }); err != nil { + logger.Warn(ctx, "failed to update inherited injected context", + slog.F("chat_id", chatID), + slog.Error(err), + ) + return xerrors.Errorf("update inherited injected context: %w", err) + } + + return nil +} + func (p *Server) sendSubagentMessage( ctx context.Context, parentChatID uuid.UUID, diff --git a/coderd/x/chatd/subagent_context_internal_test.go b/coderd/x/chatd/subagent_context_internal_test.go new file mode 100644 index 0000000000..a2f908cb33 --- /dev/null +++ b/coderd/x/chatd/subagent_context_internal_test.go @@ -0,0 +1,506 @@ +package chatd + +import ( + "context" + "encoding/json" + "testing" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/codersdk" +) + +func TestCollectContextPartsFromMessagesSkipsSentinelContextFiles(t *testing.T) { + t.Parallel() + + content, err := json.Marshal([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/home/coder/project/.agents/skills/my-skill/SKILL.md", + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "my-skill", + SkillDescription: "A test skill", + }, + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/home/coder/project/AGENTS.md", + ContextFileContent: "# Project instructions", + }, + codersdk.ChatMessageText("ignored"), + }) + require.NoError(t, err) + + parts, err := CollectContextPartsFromMessages(context.Background(), slog.Make(), []database.ChatMessage{ //nolint:exhaustruct // Only content fields matter for this unit test. + { + ID: 1, + Content: pqtype.NullRawMessage{ + RawMessage: content, + Valid: true, + }, + }, + }, false) + require.NoError(t, err) + require.Len(t, parts, 2) + require.Equal(t, codersdk.ChatMessagePartTypeSkill, parts[0].Type) + require.Equal(t, "my-skill", parts[0].SkillName) + require.Equal(t, codersdk.ChatMessagePartTypeContextFile, parts[1].Type) + require.Equal(t, "/home/coder/project/AGENTS.md", parts[1].ContextFilePath) + require.Equal(t, "# Project instructions", parts[1].ContextFileContent) +} + +func TestCollectContextPartsFromMessagesKeepsEmptyContextFilesWhenRequested(t *testing.T) { + t.Parallel() + + content, err := json.Marshal([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: AgentChatContextSentinelPath, + ContextFileAgentID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "my-skill", + }, + }) + require.NoError(t, err) + + parts, err := CollectContextPartsFromMessages(context.Background(), slog.Make(), []database.ChatMessage{ //nolint:exhaustruct // Only content fields matter for this unit test. + { + ID: 1, + Content: pqtype.NullRawMessage{ + RawMessage: content, + Valid: true, + }, + }, + }, true) + require.NoError(t, err) + require.Len(t, parts, 2) + require.Equal(t, AgentChatContextSentinelPath, parts[0].ContextFilePath) + require.Equal(t, "my-skill", parts[1].SkillName) +} + +func TestFilterContextPartsToLatestAgent(t *testing.T) { + t.Parallel() + + oldAgentID := uuid.New() + newAgentID := uuid.New() + parts := []codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/legacy/AGENTS.md", + ContextFileContent: "legacy instructions", + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper-legacy", + }, + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/old/AGENTS.md", + ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true}, + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper-old", + ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true}, + }, + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: AgentChatContextSentinelPath, + ContextFileAgentID: uuid.NullUUID{ + UUID: newAgentID, + Valid: true, + }, + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "repo-helper-new", + ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true}, + }, + } + + got := FilterContextPartsToLatestAgent(parts) + require.Len(t, got, 4) + require.Equal(t, "/legacy/AGENTS.md", got[0].ContextFilePath) + require.Equal(t, "repo-helper-legacy", got[1].SkillName) + require.Equal(t, AgentChatContextSentinelPath, got[2].ContextFilePath) + require.Equal(t, "repo-helper-new", got[3].SkillName) +} + +func createParentChatWithInheritedContext( + ctx context.Context, + t *testing.T, + db database.Store, + server *Server, +) database.Chat { + t.Helper() + + user, model := seedInternalChatDeps(ctx, t, db) + + parent, err := server.CreateChat(ctx, CreateOptions{ + OwnerID: user.ID, + Title: "parent-with-context", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + }) + require.NoError(t, err) + + inheritedParts := []codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/home/coder/project/AGENTS.md", + ContextFileContent: "# Project instructions", + ContextFileOS: "linux", + ContextFileDirectory: "/home/coder/project", + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "my-skill", + SkillDescription: "A test skill", + SkillDir: "/home/coder/project/.agents/skills/my-skill", + ContextFileSkillMetaFile: "SKILL.md", + }, + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/home/coder/project/.agents/skills/my-skill/SKILL.md", + }, + } + content, err := json.Marshal(inheritedParts) + require.NoError(t, err) + + _, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{ + ChatID: parent.ID, + CreatedBy: []uuid.UUID{user.ID}, + ModelConfigID: []uuid.UUID{model.ID}, + Role: []database.ChatMessageRole{database.ChatMessageRoleUser}, + Content: []string{string(content)}, + ContentVersion: []int16{chatprompt.CurrentContentVersion}, + Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, + InputTokens: []int64{0}, + OutputTokens: []int64{0}, + TotalTokens: []int64{0}, + ReasoningTokens: []int64{0}, + CacheCreationTokens: []int64{0}, + CacheReadTokens: []int64{0}, + ContextLimit: []int64{0}, + Compressed: []bool{false}, + TotalCostMicros: []int64{0}, + RuntimeMs: []int64{0}, + }) + require.NoError(t, err) + + parentChat, err := db.GetChatByID(ctx, parent.ID) + require.NoError(t, err) + return parentChat +} + +func assertChildInheritedContext( + ctx context.Context, + t *testing.T, + db database.Store, + childID uuid.UUID, + prompt string, +) { + t.Helper() + + childChat, err := db.GetChatByID(ctx, childID) + require.NoError(t, err) + require.True(t, childChat.LastInjectedContext.Valid) + + var cached []codersdk.ChatMessagePart + require.NoError(t, json.Unmarshal(childChat.LastInjectedContext.RawMessage, &cached)) + require.Len(t, cached, 2) + + var sawContextFile bool + var sawSkill bool + for _, part := range cached { + switch part.Type { + case codersdk.ChatMessagePartTypeContextFile: + sawContextFile = true + require.Equal(t, "/home/coder/project/AGENTS.md", part.ContextFilePath) + require.Empty(t, part.ContextFileContent) + require.Empty(t, part.ContextFileOS) + require.Empty(t, part.ContextFileDirectory) + case codersdk.ChatMessagePartTypeSkill: + sawSkill = true + require.Equal(t, "my-skill", part.SkillName) + require.Equal(t, "A test skill", part.SkillDescription) + require.Empty(t, part.SkillDir) + require.Empty(t, part.ContextFileSkillMetaFile) + default: + t.Fatalf("unexpected cached part type %q", part.Type) + } + } + require.True(t, sawContextFile) + require.True(t, sawSkill) + + childMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: childID, + AfterID: 0, + }) + require.NoError(t, err) + + var ( + contextMessageIndexes []int + userPromptIndex = -1 + sawDBAgentsContextFile bool + sawDBSkillCompanionContext bool + sawDBSkill bool + ) + for i, msg := range childMessages { + if !msg.Content.Valid { + continue + } + + var parts []codersdk.ChatMessagePart + require.NoError(t, json.Unmarshal(msg.Content.RawMessage, &parts)) + + if len(parts) == 1 && parts[0].Type == codersdk.ChatMessagePartTypeText && parts[0].Text == prompt { + require.Equal(t, database.ChatMessageRoleUser, msg.Role) + userPromptIndex = i + continue + } + + hasInheritedContext := false + for _, part := range parts { + switch part.Type { + case codersdk.ChatMessagePartTypeContextFile: + hasInheritedContext = true + switch part.ContextFilePath { + case "/home/coder/project/AGENTS.md": + sawDBAgentsContextFile = true + require.Equal(t, "# Project instructions", part.ContextFileContent) + require.Equal(t, "linux", part.ContextFileOS) + require.Equal(t, "/home/coder/project", part.ContextFileDirectory) + case "/home/coder/project/.agents/skills/my-skill/SKILL.md": + sawDBSkillCompanionContext = true + require.Empty(t, part.ContextFileContent) + require.Empty(t, part.ContextFileOS) + require.Empty(t, part.ContextFileDirectory) + default: + t.Fatalf("unexpected child inherited context file path %q", part.ContextFilePath) + } + case codersdk.ChatMessagePartTypeSkill: + hasInheritedContext = true + sawDBSkill = true + require.Equal(t, "my-skill", part.SkillName) + require.Equal(t, "A test skill", part.SkillDescription) + require.Equal(t, "/home/coder/project/.agents/skills/my-skill", part.SkillDir) + require.Equal(t, "SKILL.md", part.ContextFileSkillMetaFile) + default: + t.Fatalf("unexpected child inherited part type %q", part.Type) + } + } + if hasInheritedContext { + require.Equal(t, database.ChatMessageRoleUser, msg.Role) + contextMessageIndexes = append(contextMessageIndexes, i) + } + } + + require.NotEmpty(t, contextMessageIndexes) + require.NotEqual(t, -1, userPromptIndex) + for _, idx := range contextMessageIndexes { + require.Less(t, idx, userPromptIndex) + } + require.True(t, sawDBAgentsContextFile) + require.True(t, sawDBSkillCompanionContext) + require.True(t, sawDBSkill) +} + +func createParentChatWithRotatedInheritedContext( + ctx context.Context, + t *testing.T, + db database.Store, + server *Server, +) database.Chat { + t.Helper() + + user, model := seedInternalChatDeps(ctx, t, db) + + parent, err := server.CreateChat(ctx, CreateOptions{ + OwnerID: user.ID, + Title: "parent-with-rotated-context", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + }) + require.NoError(t, err) + + oldAgentID := uuid.New() + newAgentID := uuid.New() + oldContent, err := json.Marshal([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/home/coder/project-old/AGENTS.md", + ContextFileContent: "# Old instructions", + ContextFileOS: "darwin", + ContextFileDirectory: "/home/coder/project-old", + ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true}, + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "old-skill", + SkillDescription: "Old skill", + SkillDir: "/home/coder/project-old/.agents/skills/old-skill", + ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true}, + }, + }) + require.NoError(t, err) + newContent, err := json.Marshal([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/home/coder/project-new/AGENTS.md", + ContextFileContent: "# New instructions", + ContextFileOS: "linux", + ContextFileDirectory: "/home/coder/project-new", + ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true}, + }, + { + Type: codersdk.ChatMessagePartTypeSkill, + SkillName: "new-skill", + SkillDescription: "New skill", + SkillDir: "/home/coder/project-new/.agents/skills/new-skill", + ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true}, + }, + }) + require.NoError(t, err) + + _, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{ + ChatID: parent.ID, + CreatedBy: []uuid.UUID{user.ID, user.ID}, + ModelConfigID: []uuid.UUID{model.ID, model.ID}, + Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleUser}, + Content: []string{string(oldContent), string(newContent)}, + ContentVersion: []int16{chatprompt.CurrentContentVersion, chatprompt.CurrentContentVersion}, + Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth}, + InputTokens: []int64{0, 0}, + OutputTokens: []int64{0, 0}, + TotalTokens: []int64{0, 0}, + ReasoningTokens: []int64{0, 0}, + CacheCreationTokens: []int64{0, 0}, + CacheReadTokens: []int64{0, 0}, + ContextLimit: []int64{0, 0}, + Compressed: []bool{false, false}, + TotalCostMicros: []int64{0, 0}, + RuntimeMs: []int64{0, 0}, + }) + require.NoError(t, err) + + parentChat, err := db.GetChatByID(ctx, parent.ID) + require.NoError(t, err) + return parentChat +} + +func TestCreateChildSubagentChatCopiesOnlyLatestAgentContext(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + ctx := chatdTestContext(t) + parentChat := createParentChatWithRotatedInheritedContext(ctx, t, db, server) + + child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "") + require.NoError(t, err) + + childChat, err := db.GetChatByID(ctx, child.ID) + require.NoError(t, err) + require.True(t, childChat.LastInjectedContext.Valid) + + var cached []codersdk.ChatMessagePart + require.NoError(t, json.Unmarshal(childChat.LastInjectedContext.RawMessage, &cached)) + require.Len(t, cached, 2) + require.Equal(t, "/home/coder/project-new/AGENTS.md", cached[0].ContextFilePath) + require.Equal(t, "new-skill", cached[1].SkillName) + + childMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: child.ID, + AfterID: 0, + }) + require.NoError(t, err) + + var inherited [][]codersdk.ChatMessagePart + for _, msg := range childMessages { + if !msg.Content.Valid { + continue + } + var parts []codersdk.ChatMessagePart + require.NoError(t, json.Unmarshal(msg.Content.RawMessage, &parts)) + if len(parts) == 0 || parts[0].Type == codersdk.ChatMessagePartTypeText { + continue + } + inherited = append(inherited, parts) + } + require.Len(t, inherited, 1) + require.Len(t, inherited[0], 2) + require.Equal(t, "/home/coder/project-new/AGENTS.md", inherited[0][0].ContextFilePath) + require.Equal(t, "# New instructions", inherited[0][0].ContextFileContent) + require.Equal(t, "new-skill", inherited[0][1].SkillName) +} + +func TestCreateChildSubagentChatUpdatesInheritedLastInjectedContext(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + ctx := chatdTestContext(t) + parentChat := createParentChatWithInheritedContext(ctx, t, db, server) + + child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "") + require.NoError(t, err) + + assertChildInheritedContext(ctx, t, db, child.ID, "inspect bindings") +} + +func TestSpawnComputerUseAgentInheritsContext(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true)) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{ + Anthropic: "test-anthropic-key", + }) + + ctx := chatdTestContext(t) + parentChat := createParentChatWithInheritedContext(ctx, t, db, server) + + tools := server.subagentTools(ctx, func() database.Chat { return parentChat }) + tool := findToolByName(tools, "spawn_computer_use_agent") + require.NotNil(t, tool) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ + ID: "call-context", + Name: "spawn_computer_use_agent", + Input: `{"prompt":"inspect bindings"}`, + }) + require.NoError(t, err) + require.False(t, resp.IsError, "expected success but got: %s", resp.Content) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + childIDStr, ok := result["chat_id"].(string) + require.True(t, ok) + + childID, err := uuid.Parse(childIDStr) + require.NoError(t, err) + + childChat, err := db.GetChatByID(ctx, childID) + require.NoError(t, err) + require.True(t, childChat.Mode.Valid) + require.Equal(t, database.ChatModeComputerUse, childChat.Mode.ChatMode) + + assertChildInheritedContext(ctx, t, db, childID, "inspect bindings") +} diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index fd49f764bb..5e72eef6c2 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -892,3 +892,66 @@ func (s *SSEAgentReinitReceiver) Receive(ctx context.Context) (*Reinitialization return &reinitEvent, nil } } + +// AddChatContextRequest is the request body for adding chat context. +type AddChatContextRequest struct { + // ChatID optionally identifies the chat to add context to. + // If empty, auto-detection is used (CODER_CHAT_ID env, the + // only active chat, or the only top-level active chat for this + // agent). + ChatID uuid.UUID `json:"chat_id,omitempty"` + // Parts are the context-file and skill parts to add. + Parts []codersdk.ChatMessagePart `json:"parts"` +} + +// AddChatContextResponse is the response for adding chat context. +type AddChatContextResponse struct { + ChatID uuid.UUID `json:"chat_id"` + Count int `json:"count"` +} + +// ClearChatContextRequest is the request body for clearing chat context. +type ClearChatContextRequest struct { + // ChatID optionally identifies the chat to clear context from. + // If empty, auto-detection is used (CODER_CHAT_ID env, the + // only active chat, or the only top-level active chat for this + // agent). + ChatID uuid.UUID `json:"chat_id,omitempty"` +} + +// ClearChatContextResponse is the response for clearing chat context. +type ClearChatContextResponse struct { + ChatID uuid.UUID `json:"chat_id"` +} + +// AddChatContext adds context-file and skill parts to an active chat. +func (c *Client) AddChatContext(ctx context.Context, req AddChatContextRequest) (AddChatContextResponse, error) { + res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/experimental/chat-context", req) + if err != nil { + return AddChatContextResponse{}, xerrors.Errorf("execute request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return AddChatContextResponse{}, codersdk.ReadBodyAsError(res) + } + + var resp AddChatContextResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} + +// ClearChatContext soft-deletes context-file and skill messages from an active chat. +func (c *Client) ClearChatContext(ctx context.Context, req ClearChatContextRequest) (ClearChatContextResponse, error) { + res, err := c.SDK.Request(ctx, http.MethodDelete, "/api/v2/workspaceagents/me/experimental/chat-context", req) + if err != nil { + return ClearChatContextResponse{}, xerrors.Errorf("execute request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return ClearChatContextResponse{}, codersdk.ReadBodyAsError(res) + } + + var resp ClearChatContextResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +}