feat: add CLI commands for managing chat context from workspaces (#24105)

Adds `coder exp chat context add` and `coder exp chat context clear`
commands that run inside a workspace to manage chat context files via
the agent token.

`add` reads instruction and skill files from a directory (defaulting to
cwd) and inserts them as context-file messages into an active chat.
Multiple calls are additive — `instructionFromContextFiles` already
accumulates all context-file parts across messages.

`clear` soft-deletes all context-file messages, causing
`contextFileAgentID()` to return `!found` on the next turn, which
triggers `needsInstructionPersist=true` and re-fetches defaults from the
agent.

Both commands auto-detect the target chat via `CODER_CHAT_ID` (already
set by `agentproc` on chat-spawned processes), or fall back to
single-active-chat resolution for the agent. The `--chat` flag overrides
both.

Also adds sub-agent context inheritance: `createChildSubagentChat` now
copies parent context-file messages to child chats at spawn time, so
delegated sub-agents share the same instruction context without
independently re-fetching from the workspace agent.

<details><summary>Implementation details</summary>

**New files:**
- `cli/exp_chat.go` — CLI command tree under `coder exp chat context`

**Modified files:**
- `agent/agentcontextconfig/api.go` — `ConfigFromDir()` reads context
from an arbitrary directory without env vars
- `codersdk/agentsdk/agentsdk.go` — `AddChatContext`/`ClearChatContext`
SDK methods
- `coderd/workspaceagents.go` — POST/DELETE handlers on
`/workspaceagents/me/chat-context`
- `coderd/coderd.go` — Route registration
- `coderd/database/queries/chats.sql` — `GetActiveChatsByAgentID`,
`SoftDeleteContextFileMessages`
- `coderd/database/dbauthz/dbauthz.go` — RBAC implementations for new
queries
- `coderd/x/chatd/subagent.go` — `copyParentContextFiles` for sub-agent
inheritance
- `cli/root.go` — Register `chatCommand()` in `AGPLExperimental()`

**Auth pattern:** Uses `AgentAuth` (same as `coder external-auth`) —
agent token via `CODER_AGENT_TOKEN` + `CODER_AGENT_URL` env vars.

</details>

> 🤖 Generated by Coder Agents

---------

Co-authored-by: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com>
This commit is contained in:
Kyle Carberry
2026-04-09 10:33:00 -04:00
committed by GitHub
parent f8e8f979a2
commit 391b22aef7
29 changed files with 4354 additions and 258 deletions
+27
View File
@@ -134,6 +134,33 @@ func Config(workingDir string) (workspacesdk.ContextConfigResponse, []string) {
}, ResolvePaths(mcpConfigFile, workingDir)
}
// ContextPartsFromDir reads instruction files and discovers skills
// from a specific directory, using default file names. This is used
// by the CLI chat context commands to read context from an arbitrary
// directory without consulting agent env vars.
func ContextPartsFromDir(dir string) []codersdk.ChatMessagePart {
var parts []codersdk.ChatMessagePart
if entry, found := readInstructionFileFromDir(dir, DefaultInstructionsFile); found {
parts = append(parts, entry)
}
// Reuse ResolvePaths so CLI skill discovery follows the same
// project-relative path handling as agent config resolution.
skillParts := discoverSkills(
ResolvePaths(strings.Join([]string{DefaultSkillsDir, "skills"}, ","), dir),
DefaultSkillMetaFile,
)
parts = append(parts, skillParts...)
// Guarantee non-nil slice.
if parts == nil {
parts = []codersdk.ChatMessagePart{}
}
return parts
}
// MCPConfigFiles returns the resolved MCP configuration file
// paths for the agent's MCP manager.
func (api *API) MCPConfigFiles() []string {
+211 -165
View File
@@ -23,18 +23,144 @@ func filterParts(parts []codersdk.ChatMessagePart, t codersdk.ChatMessagePartTyp
return out
}
func TestConfig(t *testing.T) {
t.Run("Defaults", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
func writeSkillMetaFileInRoot(t *testing.T, skillsRoot, name, description string) string {
t.Helper()
// Clear all env vars so defaults are used.
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
skillDir := filepath.Join(skillsRoot, name)
require.NoError(t, os.MkdirAll(skillDir, 0o755))
require.NoError(t, os.WriteFile(
filepath.Join(skillDir, "SKILL.md"),
[]byte("---\nname: "+name+"\ndescription: "+description+"\n---\nSkill body"),
0o600,
))
return skillDir
}
func writeSkillMetaFile(t *testing.T, dir, name, description string) string {
t.Helper()
return writeSkillMetaFileInRoot(t, filepath.Join(dir, ".agents", "skills"), name, description)
}
func TestContextPartsFromDir(t *testing.T) {
t.Parallel()
t.Run("ReturnsInstructionFilePart", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
instructionPath := filepath.Join(dir, "AGENTS.md")
require.NoError(t, os.WriteFile(instructionPath, []byte("project instructions"), 0o600))
parts := agentcontextconfig.ContextPartsFromDir(dir)
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, parts, 1)
require.Len(t, contextParts, 1)
require.Empty(t, skillParts)
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
require.Equal(t, "project instructions", contextParts[0].ContextFileContent)
require.False(t, contextParts[0].ContextFileTruncated)
})
t.Run("ReturnsSkillParts", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
skillDir := writeSkillMetaFile(t, dir, "my-skill", "A test skill")
parts := agentcontextconfig.ContextPartsFromDir(dir)
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, parts, 1)
require.Empty(t, contextParts)
require.Len(t, skillParts, 1)
require.Equal(t, "my-skill", skillParts[0].SkillName)
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
require.Equal(t, skillDir, skillParts[0].SkillDir)
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
})
t.Run("ReturnsSkillPartsFromSkillsDir", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
skillDir := writeSkillMetaFileInRoot(
t,
filepath.Join(dir, "skills"),
"my-skill",
"A test skill",
)
parts := agentcontextconfig.ContextPartsFromDir(dir)
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, parts, 1)
require.Empty(t, contextParts)
require.Len(t, skillParts, 1)
require.Equal(t, "my-skill", skillParts[0].SkillName)
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
require.Equal(t, skillDir, skillParts[0].SkillDir)
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
})
t.Run("ReturnsEmptyForEmptyDir", func(t *testing.T) {
t.Parallel()
parts := agentcontextconfig.ContextPartsFromDir(t.TempDir())
require.NotNil(t, parts)
require.Empty(t, parts)
})
t.Run("ReturnsCombinedResults", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
instructionPath := filepath.Join(dir, "AGENTS.md")
require.NoError(t, os.WriteFile(instructionPath, []byte("combined instructions"), 0o600))
skillDir := writeSkillMetaFile(t, dir, "combined-skill", "Combined test skill")
parts := agentcontextconfig.ContextPartsFromDir(dir)
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, parts, 2)
require.Len(t, contextParts, 1)
require.Len(t, skillParts, 1)
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
require.Equal(t, "combined instructions", contextParts[0].ContextFileContent)
require.Equal(t, "combined-skill", skillParts[0].SkillName)
require.Equal(t, skillDir, skillParts[0].SkillDir)
})
}
func setupConfigTestEnv(t *testing.T, overrides map[string]string) string {
t.Helper()
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
for key, value := range overrides {
t.Setenv(key, value)
}
return fakeHome
}
func TestConfig(t *testing.T) {
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("Defaults", func(t *testing.T) {
setupConfigTestEnv(t, nil)
workDir := platformAbsPath("work")
cfg, mcpFiles := agentcontextconfig.Config(workDir)
@@ -46,20 +172,18 @@ func TestConfig(t *testing.T) {
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, mcpFiles)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("CustomEnvVars", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
optInstructions := t.TempDir()
optSkills := t.TempDir()
optMCP := platformAbsPath("opt", "mcp.json")
t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "CUSTOM.md")
t.Setenv(agentcontextconfig.EnvSkillsDirs, optSkills)
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
setupConfigTestEnv(t, map[string]string{
agentcontextconfig.EnvInstructionsDirs: optInstructions,
agentcontextconfig.EnvInstructionsFile: "CUSTOM.md",
agentcontextconfig.EnvSkillsDirs: optSkills,
agentcontextconfig.EnvSkillMetaFile: "META.yaml",
agentcontextconfig.EnvMCPConfigFiles: optMCP,
})
// Create files matching the custom names so we can
// verify the env vars actually change lookup behavior.
@@ -85,15 +209,12 @@ func TestConfig(t *testing.T) {
require.Equal(t, "META.yaml", skillParts[0].ContextFileSkillMetaFile)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("WhitespaceInFileNames", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
fakeHome := setupConfigTestEnv(t, map[string]string{
agentcontextconfig.EnvInstructionsFile: " CLAUDE.md ",
})
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
// Create a file matching the trimmed name.
@@ -106,19 +227,13 @@ func TestConfig(t *testing.T) {
require.Equal(t, "hello", ctxFiles[0].ContextFileContent)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("CommaSeparatedDirs", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
a := t.TempDir()
b := t.TempDir()
t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
setupConfigTestEnv(t, map[string]string{
agentcontextconfig.EnvInstructionsDirs: a + "," + b,
})
// Put instruction files in both dirs.
require.NoError(t, os.WriteFile(filepath.Join(a, "AGENTS.md"), []byte("from a"), 0o600))
@@ -133,17 +248,10 @@ func TestConfig(t *testing.T) {
require.Equal(t, "from b", ctxFiles[1].ContextFileContent)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("ReadsInstructionFiles", func(t *testing.T) {
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
fakeHome := setupConfigTestEnv(t, nil)
// Create ~/.coder/AGENTS.md
coderDir := filepath.Join(fakeHome, ".coder")
@@ -164,16 +272,9 @@ func TestConfig(t *testing.T) {
require.False(t, ctxFiles[0].ContextFileTruncated)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("ReadsWorkingDirInstructionFile", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
setupConfigTestEnv(t, nil)
workDir := t.TempDir()
// Create AGENTS.md in the working directory.
@@ -193,16 +294,9 @@ func TestConfig(t *testing.T) {
require.Equal(t, filepath.Join(workDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("TruncatesLargeInstructionFile", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
setupConfigTestEnv(t, nil)
workDir := t.TempDir()
largeContent := strings.Repeat("a", 64*1024+100)
require.NoError(t, os.WriteFile(filepath.Join(workDir, "AGENTS.md"), []byte(largeContent), 0o600))
@@ -215,79 +309,47 @@ func TestConfig(t *testing.T) {
require.Len(t, ctxFiles[0].ContextFileContent, 64*1024)
})
t.Run("SanitizesHTMLComments", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
sanitizationTests := []struct {
name string
input string
expected string
}{
{
name: "SanitizesHTMLComments",
input: "visible\n<!-- hidden -->content",
expected: "visible\ncontent",
},
{
name: "SanitizesInvisibleUnicode",
input: "before\u200bafter",
expected: "beforeafter",
},
{
name: "NormalizesCRLF",
input: "line1\r\nline2\rline3",
expected: "line1\nline2\nline3",
},
}
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
for _, tt := range sanitizationTests {
t.Run(tt.name, func(t *testing.T) {
setupConfigTestEnv(t, nil)
workDir := t.TempDir()
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte(tt.input),
0o600,
))
workDir := t.TempDir()
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("visible\n<!-- hidden -->content"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "visible\ncontent", ctxFiles[0].ContextFileContent)
})
t.Run("SanitizesInvisibleUnicode", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
// U+200B (zero-width space) should be stripped.
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("before\u200bafter"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "beforeafter", ctxFiles[0].ContextFileContent)
})
t.Run("NormalizesCRLF", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("line1\r\nline2\rline3"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "line1\nline2\nline3", ctxFiles[0].ContextFileContent)
})
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, tt.expected, ctxFiles[0].ContextFileContent)
})
}
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("DiscoversSkills", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
@@ -320,17 +382,13 @@ func TestConfig(t *testing.T) {
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("SkipsMissingDirs", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
nonExistent := filepath.Join(t.TempDir(), "does-not-exist")
t.Setenv(agentcontextconfig.EnvInstructionsDirs, nonExistent)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, nonExistent)
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
setupConfigTestEnv(t, map[string]string{
agentcontextconfig.EnvInstructionsDirs: nonExistent,
agentcontextconfig.EnvSkillsDirs: nonExistent,
})
workDir := t.TempDir()
cfg, _ := agentcontextconfig.Config(workDir)
@@ -340,17 +398,13 @@ func TestConfig(t *testing.T) {
require.Empty(t, cfg.Parts)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("MCPConfigFilesResolvedSeparately", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
optMCP := platformAbsPath("opt", "custom.json")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
fakeHome := setupConfigTestEnv(t, map[string]string{
agentcontextconfig.EnvMCPConfigFiles: optMCP,
})
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
workDir := t.TempDir()
_, mcpFiles := agentcontextconfig.Config(workDir)
@@ -358,14 +412,10 @@ func TestConfig(t *testing.T) {
require.Equal(t, []string{optMCP}, mcpFiles)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("SkillNameMustMatchDir", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
fakeHome := setupConfigTestEnv(t, nil)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
skillsDir := filepath.Join(workDir, "skills")
@@ -385,14 +435,10 @@ func TestConfig(t *testing.T) {
require.Empty(t, skillParts)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("DuplicateSkillsFirstWins", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
fakeHome := setupConfigTestEnv(t, nil)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
skillsDir1 := filepath.Join(workDir, "skills1")
+194
View File
@@ -0,0 +1,194 @@
package cli
import (
"fmt"
"os"
"path/filepath"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/agent/agentcontextconfig"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/serpent"
)
func (r *RootCmd) chatCommand() *serpent.Command {
return &serpent.Command{
Use: "chat",
Short: "Manage agent chats",
Long: "Commands for interacting with chats from within a workspace.",
Handler: func(i *serpent.Invocation) error {
return i.Command.HelpHandler(i)
},
Children: []*serpent.Command{
r.chatContextCommand(),
},
}
}
func (r *RootCmd) chatContextCommand() *serpent.Command {
return &serpent.Command{
Use: "context",
Short: "Manage chat context",
Long: "Add or clear context files and skills for an active chat session.",
Handler: func(i *serpent.Invocation) error {
return i.Command.HelpHandler(i)
},
Children: []*serpent.Command{
r.chatContextAddCommand(),
r.chatContextClearCommand(),
},
}
}
func (*RootCmd) chatContextAddCommand() *serpent.Command {
var (
dir string
chatID string
)
agentAuth := &AgentAuth{}
cmd := &serpent.Command{
Use: "add",
Short: "Add context to an active chat",
Long: "Read instruction files and discover skills from a directory, then add " +
"them as context to an active chat session. Multiple calls " +
"are additive.",
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
defer stop()
if dir == "" && inv.Environ.Get("CODER") != "true" {
return xerrors.New("this command must be run inside a Coder workspace (set --dir to override)")
}
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
resolvedDir := dir
if resolvedDir == "" {
resolvedDir, err = os.Getwd()
if err != nil {
return xerrors.Errorf("get working directory: %w", err)
}
}
resolvedDir, err = filepath.Abs(resolvedDir)
if err != nil {
return xerrors.Errorf("resolve directory: %w", err)
}
info, err := os.Stat(resolvedDir)
if err != nil {
return xerrors.Errorf("cannot read directory %q: %w", resolvedDir, err)
}
if !info.IsDir() {
return xerrors.Errorf("%q is not a directory", resolvedDir)
}
parts := agentcontextconfig.ContextPartsFromDir(resolvedDir)
if len(parts) == 0 {
_, _ = fmt.Fprintln(inv.Stderr, "No context files or skills found in "+resolvedDir)
return nil
}
// Resolve chat ID from flag or auto-detect.
resolvedChatID, err := parseChatID(chatID)
if err != nil {
return err
}
resp, err := client.AddChatContext(ctx, agentsdk.AddChatContextRequest{
ChatID: resolvedChatID,
Parts: parts,
})
if err != nil {
return xerrors.Errorf("add chat context: %w", err)
}
_, _ = fmt.Fprintf(inv.Stdout, "Added %d context part(s) to chat %s\n", resp.Count, resp.ChatID)
return nil
},
Options: serpent.OptionSet{
{
Name: "Directory",
Flag: "dir",
Description: "Directory to read context files and skills from. Defaults to the current working directory.",
Value: serpent.StringOf(&dir),
},
{
Name: "Chat ID",
Flag: "chat",
Env: "CODER_CHAT_ID",
Description: "Chat ID to add context to. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
Value: serpent.StringOf(&chatID),
},
},
}
agentAuth.AttachOptions(cmd, false)
return cmd
}
func (*RootCmd) chatContextClearCommand() *serpent.Command {
var chatID string
agentAuth := &AgentAuth{}
cmd := &serpent.Command{
Use: "clear",
Short: "Clear context from an active chat",
Long: "Soft-delete all context-file and skill messages from an active chat. " +
"The next turn will re-fetch default context from the agent.",
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
defer stop()
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
resolvedChatID, err := parseChatID(chatID)
if err != nil {
return err
}
resp, err := client.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{
ChatID: resolvedChatID,
})
if err != nil {
return xerrors.Errorf("clear chat context: %w", err)
}
if resp.ChatID == uuid.Nil {
_, _ = fmt.Fprintln(inv.Stdout, "No active chats to clear.")
} else {
_, _ = fmt.Fprintf(inv.Stdout, "Cleared context from chat %s\n", resp.ChatID)
}
return nil
},
Options: serpent.OptionSet{{
Name: "Chat ID",
Flag: "chat",
Env: "CODER_CHAT_ID",
Description: "Chat ID to clear context from. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
Value: serpent.StringOf(&chatID),
}},
}
agentAuth.AttachOptions(cmd, false)
return cmd
}
// parseChatID returns the chat UUID from the flag value (which
// serpent already populates from --chat or CODER_CHAT_ID). Returns
// uuid.Nil if empty (the server will auto-detect).
func parseChatID(flagValue string) (uuid.UUID, error) {
if flagValue == "" {
return uuid.Nil, nil
}
parsed, err := uuid.Parse(flagValue)
if err != nil {
return uuid.Nil, xerrors.Errorf("invalid chat ID %q: %w", flagValue, err)
}
return parsed, nil
}
+46
View File
@@ -0,0 +1,46 @@
package cli_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
)
func TestExpChatContextAdd(t *testing.T) {
t.Parallel()
t.Run("RequiresWorkspaceOrDir", func(t *testing.T) {
t.Parallel()
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
err := inv.Run()
require.Error(t, err)
require.Contains(t, err.Error(), "this command must be run inside a Coder workspace")
})
t.Run("AllowsExplicitDir", func(t *testing.T) {
t.Parallel()
inv, _ := clitest.New(t, "exp", "chat", "context", "add", "--dir", t.TempDir())
err := inv.Run()
if err != nil {
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
}
})
t.Run("AllowsWorkspaceEnv", func(t *testing.T) {
t.Parallel()
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
inv.Environ.Set("CODER", "true")
err := inv.Run()
if err != nil {
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
}
})
}
+1
View File
@@ -148,6 +148,7 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command {
return []*serpent.Command{
r.scaletestCmd(),
r.errorExample(),
r.chatCommand(),
r.mcpCommand(),
r.promptExample(),
r.rptyCommand(),
+4
View File
@@ -1653,6 +1653,10 @@ func New(options *Options) *API {
r.Get("/gitsshkey", api.agentGitSSHKey)
r.Post("/log-source", api.workspaceAgentPostLogSource)
r.Get("/reinit", api.workspaceAgentReinit)
r.Route("/experimental", func(r chi.Router) {
r.Post("/chat-context", api.workspaceAgentAddChatContext)
r.Delete("/chat-context", api.workspaceAgentClearChatContext)
})
r.Route("/tasks/{task}", func(r chi.Router) {
r.Post("/log-snapshot", api.postWorkspaceAgentTaskLogSnapshot)
})
+7
View File
@@ -147,6 +147,10 @@ func parseSwaggerComment(commentGroup *ast.CommentGroup) SwaggerComment {
return c
}
func isExperimentalEndpoint(route string) bool {
return strings.HasPrefix(route, "/workspaceagents/me/experimental/")
}
func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments []SwaggerComment) {
assertUniqueRoutes(t, swaggerComments)
assertSingleAnnotations(t, swaggerComments)
@@ -165,6 +169,9 @@ func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments [
if strings.HasSuffix(route, "/*") {
return
}
if isExperimentalEndpoint(route) {
return
}
c := findSwaggerCommentByMethodAndRoute(swaggerComments, method, route)
assert.NotNil(t, c, "Missing @Router annotation")
+26
View File
@@ -1708,6 +1708,17 @@ func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
return q.db.CleanupDeletedMCPServerIDsFromChats(ctx)
}
func (q *querier) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
chat, err := q.db.GetChatByID(ctx, chatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
}
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
@@ -2413,6 +2424,10 @@ func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
return q.db.GetActiveAISeatCount(ctx)
}
func (q *querier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetActiveChatsByAgentID)(ctx, agentID)
}
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
return nil, err
@@ -5728,6 +5743,17 @@ func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg databas
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
}
func (q *querier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
chat, err := q.db.GetChatByID(ctx, chatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.SoftDeleteContextFileMessages(ctx, chatID)
}
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
return q.db.TryAcquireLock(ctx, id)
}
+18
View File
@@ -478,6 +478,24 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), arg).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
check.Args(arg).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
}))
s.Run("GetActiveChatsByAgentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
agentID := uuid.New()
dbm.EXPECT().GetActiveChatsByAgentID(gomock.Any(), agentID).Return([]database.Chat{chat}, nil).AnyTimes()
check.Args(agentID).Asserts(chat, policy.ActionRead).Returns([]database.Chat{chat})
}))
s.Run("SoftDeleteContextFileMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().SoftDeleteContextFileMessages(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("ClearChatMessageProviderResponseIDsByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().ClearChatMessageProviderResponseIDsByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetChatCostPerChatParams{
OwnerID: uuid.New(),
+24
View File
@@ -280,6 +280,14 @@ func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Conte
return r0
}
func (m queryMetricsStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
start := time.Now()
r0 := m.s.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("ClearChatMessageProviderResponseIDsByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ClearChatMessageProviderResponseIDsByChatID").Inc()
return r0
}
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
@@ -968,6 +976,14 @@ func (m queryMetricsStore) GetActiveAISeatCount(ctx context.Context) (int64, err
return r0, r1
}
func (m queryMetricsStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetActiveChatsByAgentID(ctx, agentID)
m.queryLatencies.WithLabelValues("GetActiveChatsByAgentID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetActiveChatsByAgentID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
start := time.Now()
r0, r1 := m.s.GetActivePresetPrebuildSchedules(ctx)
@@ -4104,6 +4120,14 @@ func (m queryMetricsStore) SoftDeleteChatMessagesAfterID(ctx context.Context, ar
return r0
}
func (m queryMetricsStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
start := time.Now()
r0 := m.s.SoftDeleteContextFileMessages(ctx, chatID)
m.queryLatencies.WithLabelValues("SoftDeleteContextFileMessages").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteContextFileMessages").Inc()
return r0
}
func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
start := time.Now()
r0, r1 := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock)
+43
View File
@@ -363,6 +363,20 @@ func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx)
}
// ClearChatMessageProviderResponseIDsByChatID mocks base method.
func (m *MockStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClearChatMessageProviderResponseIDsByChatID", ctx, chatID)
ret0, _ := ret[0].(error)
return ret0
}
// ClearChatMessageProviderResponseIDsByChatID indicates an expected call of ClearChatMessageProviderResponseIDsByChatID.
func (mr *MockStoreMockRecorder) ClearChatMessageProviderResponseIDsByChatID(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearChatMessageProviderResponseIDsByChatID", reflect.TypeOf((*MockStore)(nil).ClearChatMessageProviderResponseIDsByChatID), ctx, chatID)
}
// CountAIBridgeInterceptions mocks base method.
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
m.ctrl.T.Helper()
@@ -1667,6 +1681,21 @@ func (mr *MockStoreMockRecorder) GetActiveAISeatCount(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveAISeatCount", reflect.TypeOf((*MockStore)(nil).GetActiveAISeatCount), ctx)
}
// GetActiveChatsByAgentID mocks base method.
func (m *MockStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveChatsByAgentID", ctx, agentID)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveChatsByAgentID indicates an expected call of GetActiveChatsByAgentID.
func (mr *MockStoreMockRecorder) GetActiveChatsByAgentID(ctx, agentID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveChatsByAgentID", reflect.TypeOf((*MockStore)(nil).GetActiveChatsByAgentID), ctx, agentID)
}
// GetActivePresetPrebuildSchedules mocks base method.
func (m *MockStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
m.ctrl.T.Helper()
@@ -7780,6 +7809,20 @@ func (mr *MockStoreMockRecorder) SoftDeleteChatMessagesAfterID(ctx, arg any) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessagesAfterID), ctx, arg)
}
// SoftDeleteContextFileMessages mocks base method.
func (m *MockStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SoftDeleteContextFileMessages", ctx, chatID)
ret0, _ := ret[0].(error)
return ret0
}
// SoftDeleteContextFileMessages indicates an expected call of SoftDeleteContextFileMessages.
func (mr *MockStoreMockRecorder) SoftDeleteContextFileMessages(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteContextFileMessages", reflect.TypeOf((*MockStore)(nil).SoftDeleteContextFileMessages), ctx, chatID)
}
// TryAcquireLock mocks base method.
func (m *MockStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
m.ctrl.T.Helper()
+2
View File
@@ -3783,6 +3783,8 @@ CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
CREATE INDEX idx_chats_agent_id ON chats USING btree (agent_id) WHERE (agent_id IS NOT NULL);
CREATE INDEX idx_chats_labels ON chats USING gin (labels);
CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_config_id);
@@ -0,0 +1 @@
DROP INDEX IF EXISTS idx_chats_agent_id;
@@ -0,0 +1 @@
CREATE INDEX idx_chats_agent_id ON chats(agent_id) WHERE agent_id IS NOT NULL;
+3
View File
@@ -76,6 +76,7 @@ type sqlcQuerier interface {
CleanTailnetLostPeers(ctx context.Context) error
CleanTailnetTunnels(ctx context.Context) error
CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error)
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
@@ -215,6 +216,7 @@ type sqlcQuerier interface {
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
GetActiveAISeatCount(ctx context.Context) (int64, error)
GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]Chat, error)
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
GetActiveUserCount(ctx context.Context, includeSystem bool) (int64, error)
GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error)
@@ -893,6 +895,7 @@ type sqlcQuerier interface {
SelectUsageEventsForPublishing(ctx context.Context, now time.Time) ([]UsageEvent, error)
SoftDeleteChatMessageByID(ctx context.Context, id int64) error
SoftDeleteChatMessagesAfterID(ctx context.Context, arg SoftDeleteChatMessagesAfterIDParams) error
SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
//
// This must be called from within a transaction. The lock will be automatically
+85
View File
@@ -4505,6 +4505,19 @@ func (q *sqlQuerier) BackoffChatDiffStatus(ctx context.Context, arg BackoffChatD
return err
}
const clearChatMessageProviderResponseIDsByChatID = `-- name: ClearChatMessageProviderResponseIDsByChatID :exec
UPDATE chat_messages
SET provider_response_id = NULL
WHERE chat_id = $1::uuid
AND deleted = false
AND provider_response_id IS NOT NULL
`
func (q *sqlQuerier) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
_, err := q.db.ExecContext(ctx, clearChatMessageProviderResponseIDsByChatID, chatID)
return err
}
const countEnabledModelsWithoutPricing = `-- name: CountEnabledModelsWithoutPricing :one
SELECT COUNT(*)::bigint AS count
FROM chat_model_configs
@@ -4603,6 +4616,66 @@ func (q *sqlQuerier) DeleteOldChats(ctx context.Context, arg DeleteOldChatsParam
return result.RowsAffected()
}
const getActiveChatsByAgentID = `-- name: GetActiveChatsByAgentID :many
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
FROM chats
WHERE agent_id = $1::uuid
AND archived = false
-- Active statuses only: waiting, pending, running, paused,
-- requires_action.
-- Excludes completed and error (terminal states).
AND status IN ('waiting', 'running', 'paused', 'pending', 'requires_action')
ORDER BY updated_at DESC
`
func (q *sqlQuerier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]Chat, error) {
rows, err := q.db.QueryContext(ctx, getActiveChatsByAgentID, agentID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Chat
for rows.Next() {
var i Chat
if err := rows.Scan(
&i.ID,
&i.OwnerID,
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
&i.CreatedAt,
&i.UpdatedAt,
&i.ParentChatID,
&i.RootChatID,
&i.LastModelConfigID,
&i.Archived,
&i.LastError,
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
&i.PinOrder,
&i.LastReadMessageID,
&i.LastInjectedContext,
&i.DynamicTools,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getChatByID = `-- name: GetChatByID :one
SELECT
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
@@ -6706,6 +6779,18 @@ func (q *sqlQuerier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg Soft
return err
}
const softDeleteContextFileMessages = `-- name: SoftDeleteContextFileMessages :exec
UPDATE chat_messages SET deleted = true
WHERE chat_id = $1::uuid
AND deleted = false
AND content::jsonb @> '[{"type": "context-file"}]'
`
func (q *sqlQuerier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
_, err := q.db.ExecContext(ctx, softDeleteContextFileMessages, chatID)
return err
}
const unarchiveChatByID = `-- name: UnarchiveChatByID :many
WITH chats AS (
UPDATE chats SET
+23
View File
@@ -1293,3 +1293,26 @@ GROUP BY cm.chat_id;
SELECT id, provider, model, context_limit, enabled, is_default
FROM chat_model_configs
WHERE deleted = false;
-- name: GetActiveChatsByAgentID :many
SELECT *
FROM chats
WHERE agent_id = @agent_id::uuid
AND archived = false
-- Active statuses only: waiting, pending, running, paused,
-- requires_action.
-- Excludes completed and error (terminal states).
AND status IN ('waiting', 'running', 'paused', 'pending', 'requires_action')
ORDER BY updated_at DESC;
-- name: ClearChatMessageProviderResponseIDsByChatID :exec
UPDATE chat_messages
SET provider_response_id = NULL
WHERE chat_id = @chat_id::uuid
AND deleted = false
AND provider_response_id IS NOT NULL;
-- name: SoftDeleteContextFileMessages :exec
UPDATE chat_messages SET deleted = true
WHERE chat_id = @chat_id::uuid
AND deleted = false
AND content::jsonb @> '[{"type": "context-file"}]';
+4
View File
@@ -0,0 +1,4 @@
package coderd
// InsertAgentChatTestModelConfig exposes insertAgentChatTestModelConfig for external tests.
var InsertAgentChatTestModelConfig = insertAgentChatTestModelConfig
+597
View File
@@ -42,6 +42,8 @@ import (
"github.com/coder/coder/v2/coderd/telemetry"
maputil "github.com/coder/coder/v2/coderd/util/maps"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/coderd/x/chatd"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/gitsync"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
@@ -2393,3 +2395,598 @@ func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.Wor
}
return sdk
}
// maxChatContextParts caps the number of parts per request to
// prevent unbounded message payloads.
const maxChatContextParts = 100
// maxChatContextFileBytes caps each context-file part to the same
// 64KiB budget used when the agent reads instruction files from disk.
const maxChatContextFileBytes = 64 * 1024
// maxChatContextRequestBodyBytes caps the JSON request body size for
// agent-added context to roughly the same per-part budget used when
// reading instruction files from disk.
const maxChatContextRequestBodyBytes int64 = maxChatContextParts * maxChatContextFileBytes
// sanitizeWorkspaceAgentContextFileContent applies prompt
// sanitization, then enforces the 64KiB per-file budget. The
// truncated flag is preserved when the caller already capped the
// file before sending it.
func sanitizeWorkspaceAgentContextFileContent(
content string,
truncated bool,
) (string, bool) {
content = chatd.SanitizePromptText(content)
if len(content) > maxChatContextFileBytes {
content = content[:maxChatContextFileBytes]
truncated = true
}
return content, truncated
}
// readChatContextBody reads and validates the request body for chat
// context endpoints. It handles MaxBytesReader wrapping, error
// responses, and body rewind. If the body is empty or whitespace-only
// and allowEmpty is true, it returns false without writing an error.
//
//nolint:revive // Add and clear endpoints only differ by empty-body handling.
func readChatContextBody(ctx context.Context, rw http.ResponseWriter, r *http.Request, dst any, allowEmpty bool) bool {
r.Body = http.MaxBytesReader(rw, r.Body, maxChatContextRequestBodyBytes)
body, err := io.ReadAll(r.Body)
if err != nil {
var maxBytesErr *http.MaxBytesError
if errors.As(err, &maxBytesErr) {
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
Message: "Request body too large.",
Detail: fmt.Sprintf("Maximum request body size is %d bytes.", maxChatContextRequestBodyBytes),
})
return false
}
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to read request body.",
Detail: err.Error(),
})
return false
}
if allowEmpty && len(bytes.TrimSpace(body)) == 0 {
r.Body = http.NoBody
return false
}
r.Body = io.NopCloser(bytes.NewReader(body))
return httpapi.Read(ctx, rw, r, dst)
}
// @x-apidocgen {"skip": true}
func (api *API) workspaceAgentAddChatContext(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
var req agentsdk.AddChatContextRequest
if !readChatContextBody(ctx, rw, r, &req, false) {
return
}
if len(req.Parts) == 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "No context parts provided.",
})
return
}
if len(req.Parts) > maxChatContextParts {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Too many context parts (%d). Maximum is %d.", len(req.Parts), maxChatContextParts),
})
return
}
// Filter to only non-empty context-file and skill parts.
filtered := chatd.FilterContextParts(req.Parts, false)
if len(filtered) == 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "No context-file or skill parts provided.",
})
return
}
req.Parts = filtered
responsePartCount := 0
// Use system context for chat operations since the
// workspace agent scope does not include chat resources.
// We verify agent-to-chat ownership explicitly below.
//nolint:gocritic // Agent needs system access to read/write chat resources.
sysCtx := dbauthz.AsSystemRestricted(ctx)
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to determine workspace from agent token.",
Detail: err.Error(),
})
return
}
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
if err != nil {
writeAgentChatError(ctx, rw, err)
return
}
// Stamp each persisted part with the agent identity. Context-file
// parts also get server-authoritative workspace metadata.
directory := workspaceAgent.ExpandedDirectory
if directory == "" {
directory = workspaceAgent.Directory
}
for i := range req.Parts {
req.Parts[i].ContextFileAgentID = uuid.NullUUID{
UUID: workspaceAgent.ID,
Valid: true,
}
if req.Parts[i].Type != codersdk.ChatMessagePartTypeContextFile {
continue
}
req.Parts[i].ContextFileContent, req.Parts[i].ContextFileTruncated = sanitizeWorkspaceAgentContextFileContent(
req.Parts[i].ContextFileContent,
req.Parts[i].ContextFileTruncated,
)
req.Parts[i].ContextFileOS = workspaceAgent.OperatingSystem
req.Parts[i].ContextFileDirectory = directory
}
req.Parts = chatd.FilterContextParts(req.Parts, false)
if len(req.Parts) == 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "No context-file or skill parts provided.",
})
return
}
responsePartCount = len(req.Parts)
// Skill-only messages need a sentinel context-file part so the turn
// pipeline trusts the associated skill metadata.
req.Parts = prependAgentChatContextSentinelIfNeeded(
req.Parts,
workspaceAgent.ID,
workspaceAgent.OperatingSystem,
directory,
)
content, err := chatprompt.MarshalParts(req.Parts)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to marshal context parts.",
Detail: err.Error(),
})
return
}
err = api.Database.InTx(func(tx database.Store) error {
locked, err := tx.GetChatByIDForUpdate(sysCtx, chat.ID)
if err != nil {
return xerrors.Errorf("lock chat: %w", err)
}
if !isActiveAgentChat(locked) {
return errChatNotActive
}
if !locked.AgentID.Valid || locked.AgentID.UUID != workspaceAgent.ID {
return errChatDoesNotBelongToAgent
}
if locked.OwnerID != workspace.OwnerID {
return errChatDoesNotBelongToWorkspaceOwner
}
if _, err := tx.InsertChatMessages(sysCtx, chatd.BuildSingleChatMessageInsertParams(
chat.ID,
database.ChatMessageRoleUser,
content,
database.ChatMessageVisibilityBoth,
locked.LastModelConfigID,
chatprompt.CurrentContentVersion,
uuid.Nil,
)); err != nil {
return xerrors.Errorf("insert context message: %w", err)
}
if err := updateAgentChatLastInjectedContextFromMessages(sysCtx, api.Logger, tx, chat.ID); err != nil {
return xerrors.Errorf("rebuild injected context cache: %w", err)
}
return nil
}, nil)
if err != nil {
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
writeAgentChatError(ctx, rw, err)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to persist context message.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.AddChatContextResponse{
ChatID: chat.ID,
Count: responsePartCount,
})
}
// @x-apidocgen {"skip": true}
func (api *API) workspaceAgentClearChatContext(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
var req agentsdk.ClearChatContextRequest
populated := readChatContextBody(ctx, rw, r, &req, true)
if !populated && r.Body != http.NoBody {
return
}
// Use system context for chat operations since the
// workspace agent scope does not include chat resources.
//nolint:gocritic // Agent needs system access to read/write chat resources.
sysCtx := dbauthz.AsSystemRestricted(ctx)
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to determine workspace from agent token.",
Detail: err.Error(),
})
return
}
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
if err != nil {
// Zero active chats is not an error for clear.
if errors.Is(err, errNoActiveChats) {
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{})
return
}
writeAgentChatError(ctx, rw, err)
return
}
err = clearAgentChatContext(sysCtx, api.Database, chat.ID, workspaceAgent.ID, workspace.OwnerID)
if err != nil {
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
writeAgentChatError(ctx, rw, err)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to clear context from chat.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{
ChatID: chat.ID,
})
}
var (
errNoActiveChats = xerrors.New("no active chats found")
errChatNotFound = xerrors.New("chat not found")
errChatNotActive = xerrors.New("chat is not active")
errChatDoesNotBelongToAgent = xerrors.New("chat does not belong to this agent")
errChatDoesNotBelongToWorkspaceOwner = xerrors.New("chat does not belong to this workspace owner")
)
type multipleActiveChatsError struct {
count int
}
func (e *multipleActiveChatsError) Error() string {
return fmt.Sprintf(
"multiple active chats (%d) found for this agent, specify a chat ID",
e.count,
)
}
func resolveDefaultAgentChat(chats []database.Chat) (database.Chat, error) {
switch len(chats) {
case 0:
return database.Chat{}, errNoActiveChats
case 1:
return chats[0], nil
}
var rootChat *database.Chat
for i := range chats {
chat := &chats[i]
if chat.ParentChatID.Valid {
continue
}
if rootChat != nil {
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
}
rootChat = chat
}
if rootChat != nil {
return *rootChat, nil
}
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
}
// resolveAgentChat finds the target chat from either an explicit ID
// or auto-detection via the agent's active chats.
func resolveAgentChat(
ctx context.Context,
db database.Store,
agentID uuid.UUID,
workspaceOwnerID uuid.UUID,
explicitChatID uuid.UUID,
) (database.Chat, error) {
if explicitChatID == uuid.Nil {
chats, err := db.GetActiveChatsByAgentID(ctx, agentID)
if err != nil {
return database.Chat{}, xerrors.Errorf("list active chats: %w", err)
}
ownerChats := make([]database.Chat, 0, len(chats))
for _, chat := range chats {
if chat.OwnerID != workspaceOwnerID {
continue
}
ownerChats = append(ownerChats, chat)
}
return resolveDefaultAgentChat(ownerChats)
}
chat, err := db.GetChatByID(ctx, explicitChatID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return database.Chat{}, errChatNotFound
}
return database.Chat{}, xerrors.Errorf("get chat by id: %w", err)
}
if !chat.AgentID.Valid || chat.AgentID.UUID != agentID {
return database.Chat{}, errChatDoesNotBelongToAgent
}
if chat.OwnerID != workspaceOwnerID {
return database.Chat{}, errChatDoesNotBelongToWorkspaceOwner
}
if !isActiveAgentChat(chat) {
return database.Chat{}, errChatNotActive
}
return chat, nil
}
func isActiveAgentChat(chat database.Chat) bool {
if chat.Archived {
return false
}
switch chat.Status {
case database.ChatStatusWaiting,
database.ChatStatusPending,
database.ChatStatusRunning,
database.ChatStatusPaused,
database.ChatStatusRequiresAction:
return true
default:
return false
}
}
func clearAgentChatContext(
ctx context.Context,
db database.Store,
chatID uuid.UUID,
agentID uuid.UUID,
workspaceOwnerID uuid.UUID,
) error {
return db.InTx(func(tx database.Store) error {
locked, err := tx.GetChatByIDForUpdate(ctx, chatID)
if err != nil {
return xerrors.Errorf("lock chat: %w", err)
}
if !isActiveAgentChat(locked) {
return errChatNotActive
}
if !locked.AgentID.Valid || locked.AgentID.UUID != agentID {
return errChatDoesNotBelongToAgent
}
if locked.OwnerID != workspaceOwnerID {
return errChatDoesNotBelongToWorkspaceOwner
}
messages, err := tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
})
if err != nil {
return xerrors.Errorf("get chat messages: %w", err)
}
hadInjectedContext := locked.LastInjectedContext.Valid
var skillOnlyMessageIDs []int64
for _, msg := range messages {
if !msg.Content.Valid {
continue
}
hasContextFile := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeContextFile)
hasSkill := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeSkill)
if hasContextFile || hasSkill {
hadInjectedContext = true
}
if hasSkill && !hasContextFile {
skillOnlyMessageIDs = append(skillOnlyMessageIDs, msg.ID)
}
}
if !hadInjectedContext {
return nil
}
if err := tx.SoftDeleteContextFileMessages(ctx, chatID); err != nil {
return xerrors.Errorf("soft delete context-file messages: %w", err)
}
for _, messageID := range skillOnlyMessageIDs {
if err := tx.SoftDeleteChatMessageByID(ctx, messageID); err != nil {
return xerrors.Errorf("soft delete context message %d: %w", messageID, err)
}
}
// Reset provider-side Responses chaining so the next turn replays
// the post-clear history instead of inheriting cleared context.
if err := tx.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID); err != nil {
return xerrors.Errorf("clear provider response chain: %w", err)
}
// Clear the injected-context cache inside the transaction so it is
// atomic with the soft-deletes.
param, err := chatd.BuildLastInjectedContext(nil)
if err != nil {
return xerrors.Errorf("clear injected context cache: %w", err)
}
if _, err := tx.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
ID: chatID,
LastInjectedContext: param,
}); err != nil {
return xerrors.Errorf("clear injected context cache: %w", err)
}
return nil
}, nil)
}
// prependAgentChatContextSentinelIfNeeded adds an empty context-file
// part when the request only carries skills. The turn pipeline uses
// the sentinel's agent metadata to trust the skill parts.
func prependAgentChatContextSentinelIfNeeded(
parts []codersdk.ChatMessagePart,
agentID uuid.UUID,
operatingSystem string,
directory string,
) []codersdk.ChatMessagePart {
hasContextFile := false
hasSkill := false
for _, part := range parts {
switch part.Type {
case codersdk.ChatMessagePartTypeContextFile:
hasContextFile = true
case codersdk.ChatMessagePartTypeSkill:
hasSkill = true
}
if hasContextFile && hasSkill {
return parts
}
}
if !hasSkill || hasContextFile {
return parts
}
return append([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: chatd.AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
ContextFileOS: operatingSystem,
ContextFileDirectory: directory,
}}, parts...)
}
func sortChatMessagesByCreatedAtAndID(messages []database.ChatMessage) {
sort.SliceStable(messages, func(i, j int) bool {
if messages[i].CreatedAt.Equal(messages[j].CreatedAt) {
return messages[i].ID < messages[j].ID
}
return messages[i].CreatedAt.Before(messages[j].CreatedAt)
})
}
// updateAgentChatLastInjectedContextFromMessages rebuilds the
// injected-context cache from all persisted context-file and skill parts.
func updateAgentChatLastInjectedContextFromMessages(
ctx context.Context,
logger slog.Logger,
db database.Store,
chatID uuid.UUID,
) error {
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
})
if err != nil {
return xerrors.Errorf("load context messages for injected context: %w", err)
}
sortChatMessagesByCreatedAtAndID(messages)
parts, err := chatd.CollectContextPartsFromMessages(ctx, logger, messages, true)
if err != nil {
return xerrors.Errorf("collect injected context parts: %w", err)
}
parts = chatd.FilterContextPartsToLatestAgent(parts)
param, err := chatd.BuildLastInjectedContext(parts)
if err != nil {
return xerrors.Errorf("update injected context: %w", err)
}
if _, err := db.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
ID: chatID,
LastInjectedContext: param,
}); err != nil {
return xerrors.Errorf("update injected context: %w", err)
}
return nil
}
func messageHasPartTypes(raw []byte, types ...codersdk.ChatMessagePartType) bool {
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal(raw, &parts); err != nil {
return false
}
for _, part := range parts {
for _, typ := range types {
if part.Type == typ {
return true
}
}
}
return false
}
// writeAgentChatError translates resolveAgentChat errors to HTTP
// responses.
func writeAgentChatError(
ctx context.Context,
rw http.ResponseWriter,
err error,
) {
if errors.Is(err, errNoActiveChats) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "No active chats found for this agent.",
})
return
}
if errors.Is(err, errChatNotFound) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "Chat not found.",
})
return
}
if errors.Is(err, errChatDoesNotBelongToAgent) {
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
Message: "Chat does not belong to this agent.",
})
return
}
if errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
Message: "Chat does not belong to this workspace owner.",
})
return
}
if errors.Is(err, errChatNotActive) {
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "Cannot modify context: this chat is no longer active.",
})
return
}
var multipleErr *multipleActiveChatsError
if errors.As(err, &multipleErr) {
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to resolve chat.",
Detail: err.Error(),
})
}
@@ -0,0 +1,76 @@
package coderd
import (
"fmt"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/testutil"
)
func TestActiveAgentChatDefinitionsAgree(t *testing.T) {
t.Parallel()
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
db, _ := dbtestutil.NewDB(t)
org, err := db.GetDefaultOrganization(ctx)
require.NoError(t, err)
owner := dbgen.User(t, db, database.User{})
workspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: org.ID,
OwnerID: owner.ID,
}).WithAgent().Do()
modelConfig := insertAgentChatTestModelConfig(ctx, t, db, owner.ID)
insertedChats := make([]database.Chat, 0, len(database.AllChatStatusValues())*2)
for _, archived := range []bool{false, true} {
for _, status := range database.AllChatStatusValues() {
chat, err := db.InsertChat(ctx, database.InsertChatParams{
Status: status,
OwnerID: owner.ID,
LastModelConfigID: modelConfig.ID,
Title: fmt.Sprintf("%s-archived-%t", status, archived),
AgentID: uuid.NullUUID{UUID: workspace.Agents[0].ID, Valid: true},
})
require.NoError(t, err)
if archived {
_, err = db.ArchiveChatByID(ctx, chat.ID)
require.NoError(t, err)
chat, err = db.GetChatByID(ctx, chat.ID)
require.NoError(t, err)
}
insertedChats = append(insertedChats, chat)
}
}
activeChats, err := db.GetActiveChatsByAgentID(ctx, workspace.Agents[0].ID)
require.NoError(t, err)
activeByID := make(map[uuid.UUID]bool, len(activeChats))
for _, chat := range activeChats {
activeByID[chat.ID] = true
}
for _, chat := range insertedChats {
require.Equalf(
t,
isActiveAgentChat(chat),
activeByID[chat.ID],
"status=%s archived=%t",
chat.Status,
chat.Archived,
)
}
}
@@ -0,0 +1,128 @@
package coderd
import (
"context"
"database/sql"
"encoding/json"
"testing"
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/codersdk"
)
func TestUpdateAgentChatLastInjectedContextFromMessagesUsesMessageIDTieBreaker(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
createdAt := time.Date(2026, time.April, 9, 13, 0, 0, 0, time.UTC)
oldAgentID := uuid.New()
newAgentID := uuid.New()
oldContent, err := json.Marshal([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileContent: "old instructions",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
}})
require.NoError(t, err)
newContent, err := json.Marshal([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/new/AGENTS.md",
ContextFileContent: "new instructions",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
}})
require.NoError(t, err)
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{
{
ID: 2,
CreatedAt: createdAt,
Content: pqtype.NullRawMessage{
RawMessage: newContent,
Valid: true,
},
},
{
ID: 1,
CreatedAt: createdAt,
Content: pqtype.NullRawMessage{
RawMessage: oldContent,
Valid: true,
},
},
}, nil)
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
require.Equal(t, chatID, arg.ID)
require.True(t, arg.LastInjectedContext.Valid)
var cached []codersdk.ChatMessagePart
require.NoError(t, json.Unmarshal(arg.LastInjectedContext.RawMessage, &cached))
require.Len(t, cached, 1)
require.Equal(t, "/new/AGENTS.md", cached[0].ContextFilePath)
require.Equal(t, uuid.NullUUID{UUID: newAgentID, Valid: true}, cached[0].ContextFileAgentID)
return database.Chat{}, nil
},
)
err = updateAgentChatLastInjectedContextFromMessages(
context.Background(),
slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
db,
chatID,
)
require.NoError(t, err)
}
func insertAgentChatTestModelConfig(
ctx context.Context,
t testing.TB,
db database.Store,
userID uuid.UUID,
) database.ChatModelConfig {
t.Helper()
sysCtx := dbauthz.AsSystemRestricted(ctx)
createdBy := uuid.NullUUID{UUID: userID, Valid: true}
_, err := db.InsertChatProvider(sysCtx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-api-key",
ApiKeyKeyID: sql.NullString{},
CreatedBy: createdBy,
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
model, err := db.InsertChatModelConfig(sysCtx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "gpt-4o-mini",
DisplayName: "Test Model",
CreatedBy: createdBy,
UpdatedBy: createdBy,
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 70,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
return model
}
File diff suppressed because it is too large Load Diff
+107 -31
View File
@@ -2461,6 +2461,33 @@ type chainModeInfo struct {
// trailingUserCount is the number of contiguous user messages
// at the end of the conversation that form the current turn.
trailingUserCount int
// contributingTrailingUserCount counts the trailing user
// messages that materially change the provider input.
contributingTrailingUserCount int
}
func userMessageContributesToChainMode(msg database.ChatMessage) bool {
parts, err := chatprompt.ParseContent(msg)
if err != nil {
return false
}
for _, part := range parts {
switch part.Type {
case codersdk.ChatMessagePartTypeText,
codersdk.ChatMessagePartTypeReasoning:
if strings.TrimSpace(part.Text) != "" {
return true
}
case codersdk.ChatMessagePartTypeFile,
codersdk.ChatMessagePartTypeFileReference:
return true
case codersdk.ChatMessagePartTypeContextFile:
if part.ContextFileContent != "" {
return true
}
}
}
return false
}
// resolveChainMode scans DB messages from the end to count trailing user
@@ -2470,11 +2497,13 @@ func resolveChainMode(messages []database.ChatMessage) chainModeInfo {
var info chainModeInfo
i := len(messages) - 1
for ; i >= 0; i-- {
if messages[i].Role == database.ChatMessageRoleUser {
info.trailingUserCount++
continue
if messages[i].Role != database.ChatMessageRoleUser {
break
}
info.trailingUserCount++
if userMessageContributesToChainMode(messages[i]) {
info.contributingTrailingUserCount++
}
break
}
for ; i >= 0; i-- {
switch messages[i].Role {
@@ -2497,15 +2526,15 @@ func resolveChainMode(messages []database.ChatMessage) chainModeInfo {
return info
}
// filterPromptForChainMode keeps only system messages and the last
// trailingUserCount user messages from the prompt. Assistant and tool
// messages are dropped because the provider already has them via the
// previous_response_id chain.
// filterPromptForChainMode keeps only system messages and the trailing
// user messages that still contribute model-visible content to the
// current turn. Assistant and tool messages are dropped because the
// provider already has them via the previous_response_id chain.
func filterPromptForChainMode(
prompt []fantasy.Message,
trailingUserCount int,
info chainModeInfo,
) []fantasy.Message {
if trailingUserCount <= 0 {
if info.contributingTrailingUserCount <= 0 {
return prompt
}
@@ -2516,7 +2545,12 @@ func filterPromptForChainMode(
}
}
usersToSkip := totalUsers - trailingUserCount
// Prompt construction already drops user turns with no model-visible
// content, such as skill-only sentinel messages. That means the user
// count here stays aligned with contributingTrailingUserCount even
// when non-contributing DB turns are interleaved in the trailing
// block.
usersToSkip := totalUsers - info.contributingTrailingUserCount
if usersToSkip < 0 {
usersToSkip = 0
}
@@ -2562,6 +2596,28 @@ func appendChatMessage(
params.ProviderResponseID = append(params.ProviderResponseID, msg.providerResponseID)
}
// BuildSingleChatMessageInsertParams creates batch insert params for one
// message using the shared chat message builder.
func BuildSingleChatMessageInsertParams(
chatID uuid.UUID,
role database.ChatMessageRole,
content pqtype.NullRawMessage,
visibility database.ChatMessageVisibility,
modelConfigID uuid.UUID,
contentVersion int16,
createdBy uuid.UUID,
) database.InsertChatMessagesParams {
params := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
ChatID: chatID,
}
msg := newChatMessage(role, content, visibility, modelConfigID, contentVersion)
if createdBy != uuid.Nil {
msg = msg.withCreatedBy(createdBy)
}
appendChatMessage(&params, msg)
return params
}
func insertUserMessageAndSetPending(
ctx context.Context,
store database.Store,
@@ -4430,13 +4486,21 @@ func (p *Server) runChat(
// the workspace agent has changed (e.g. workspace rebuilt).
needsInstructionPersist := false
hasContextFiles := false
persistedSkills := skillsFromParts(messages)
latestInjectedAgentID, hasLatestInjectedAgent := latestContextAgentID(messages)
currentWorkspaceAgentID := uuid.Nil
hasCurrentWorkspaceAgent := false
if chat.WorkspaceID.Valid {
if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil {
currentWorkspaceAgentID = agent.ID
hasCurrentWorkspaceAgent = true
}
persistedAgentID, found := contextFileAgentID(messages)
hasContextFiles = found
if !hasContextFiles {
if !hasPersistedInstructionFiles(messages) {
needsInstructionPersist = true
} else if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil && agent.ID != persistedAgentID {
// Agent changed — persist fresh instruction files.
} else if hasCurrentWorkspaceAgent && currentWorkspaceAgentID != persistedAgentID {
// Agent changed. Persist fresh instruction files.
// Old context-file messages remain in the conversation
// to preserve the prompt cache prefix.
needsInstructionPersist = true
@@ -4459,7 +4523,8 @@ func (p *Server) runChat(
if needsInstructionPersist {
g2.Go(func() error {
var persistErr error
instruction, skills, persistErr = p.persistInstructionFiles(
var discoveredSkills []chattool.SkillMeta
instruction, discoveredSkills, persistErr = p.persistInstructionFiles(
ctx,
chat,
modelConfig.ID,
@@ -4471,6 +4536,12 @@ func (p *Server) runChat(
return workspaceCtx.getWorkspaceConn(instructionCtx)
},
)
skills = selectSkillMetasForInstructionRefresh(
persistedSkills,
discoveredSkills,
uuid.NullUUID{UUID: currentWorkspaceAgentID, Valid: hasCurrentWorkspaceAgent},
uuid.NullUUID{UUID: latestInjectedAgentID, Valid: hasLatestInjectedAgent},
)
if persistErr != nil {
p.logger.Warn(ctx, "failed to persist instruction files",
slog.F("chat_id", chat.ID),
@@ -4485,7 +4556,7 @@ func (p *Server) runChat(
// re-injected via InsertSystem after compaction drops
// those messages. No workspace dial needed.
instruction = instructionFromContextFiles(messages)
skills = skillsFromParts(messages)
skills = persistedSkills
}
g2.Go(func() error {
resolvedUserPrompt = p.resolveUserPrompt(ctx, chat.OwnerID)
@@ -5103,14 +5174,14 @@ func (p *Server) runChat(
// assistant and tool messages that the provider already has.
chainModeActive := chatprovider.IsResponsesStoreEnabled(providerOptions) &&
chainInfo.previousResponseID != "" &&
chainInfo.trailingUserCount > 0 &&
chainInfo.contributingTrailingUserCount > 0 &&
chainInfo.modelConfigID == modelConfig.ID
if chainModeActive {
providerOptions = chatprovider.CloneWithPreviousResponseID(
providerOptions,
chainInfo.previousResponseID,
)
prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount)
prompt = filterPromptForChainMode(prompt, chainInfo)
}
err = chatloop.Run(ctx, chatloop.RunOptions{
Model: model,
@@ -5164,7 +5235,7 @@ func (p *Server) runChat(
if chainModeActive {
reloadedPrompt = filterPromptForChainMode(
reloadedPrompt,
chainInfo.trailingUserCount,
chainInfo,
)
}
return reloadedPrompt, nil
@@ -5537,8 +5608,9 @@ func refreshChatWorkspaceSnapshot(
}
// contextFileAgentID extracts the workspace agent ID from the most
// recent persisted context-file parts. Returns uuid.Nil, false if no
// context-file parts exist.
// recent persisted instruction-file parts. The skill-only sentinel is
// ignored because it does not represent persisted instruction content.
// Returns uuid.Nil, false if no instruction-file parts exist.
func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
var lastID uuid.UUID
found := false
@@ -5551,11 +5623,14 @@ func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
continue
}
for _, p := range parts {
if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFileAgentID.Valid {
lastID = p.ContextFileAgentID.UUID
found = true
break
if p.Type != codersdk.ChatMessagePartTypeContextFile ||
!p.ContextFileAgentID.Valid ||
p.ContextFilePath == AgentChatContextSentinelPath {
continue
}
lastID = p.ContextFileAgentID.UUID
found = true
break
}
}
return lastID, found
@@ -5625,13 +5700,14 @@ func (p *Server) persistInstructionFiles(
// agent cannot know its own UUID, OS metadata, or
// directory — those are added here at the trust boundary.
var discoveredSkills []chattool.SkillMeta
var hasContent bool
var hasContent, hasContextFilePart bool
agentID := uuid.NullUUID{UUID: agent.ID, Valid: true}
for i := range agentParts {
agentParts[i].ContextFileAgentID = agentID
switch agentParts[i].Type {
case codersdk.ChatMessagePartTypeContextFile:
hasContextFilePart = true
agentParts[i].ContextFileContent = SanitizePromptText(agentParts[i].ContextFileContent)
agentParts[i].ContextFileOS = agent.OperatingSystem
agentParts[i].ContextFileDirectory = directory
@@ -5652,13 +5728,13 @@ func (p *Server) persistInstructionFiles(
if !workspaceConnOK {
return "", nil, nil
}
// Persist a sentinel (plus any skill-only parts) so
// subsequent turns skip the workspace agent dial.
if len(agentParts) == 0 {
agentParts = []codersdk.ChatMessagePart{{
// Persist a blank context-file marker (plus any skill-only
// parts) so subsequent turns skip the workspace agent dial.
if !hasContextFilePart {
agentParts = append([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFileAgentID: agentID,
}}
}}, agentParts...)
}
content, err := chatprompt.MarshalParts(agentParts)
if err != nil {
+538 -1
View File
@@ -8,6 +8,7 @@ import (
"testing"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
@@ -703,7 +704,33 @@ func TestPersistInstructionFilesSentinelWithSkills(t *testing.T) {
gomock.Any(),
agentID,
).Return(workspaceAgent, nil).Times(1)
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().InsertChatMessages(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.InsertChatMessagesParams)
if !ok || arg.ChatID != chat.ID || len(arg.Content) != 1 {
return false
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal([]byte(arg.Content[0]), &parts); err != nil {
return false
}
foundMarker := false
foundSkill := false
for _, p := range parts {
switch p.Type {
case codersdk.ChatMessagePartTypeContextFile:
if p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) && p.ContextFileContent == "" {
foundMarker = true
}
case codersdk.ChatMessagePartTypeSkill:
if p.SkillName == "my-skill" && p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
foundSkill = true
}
}
}
return foundMarker && foundSkill
}),
).Return(nil, nil).Times(1)
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
@@ -2020,6 +2047,30 @@ func TestContextFileAgentID(t *testing.T) {
require.True(t, ok)
})
t.Run("IgnoresSkillOnlySentinel", func(t *testing.T) {
t.Parallel()
instructionAgentID := uuid.New()
sentinelAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/workspace/AGENTS.md",
ContextFileAgentID: uuid.NullUUID{UUID: instructionAgentID, Valid: true},
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: sentinelAgentID,
Valid: true,
},
}}),
}
id, ok := contextFileAgentID(msgs)
require.Equal(t, instructionAgentID, id)
require.True(t, ok)
})
t.Run("SentinelWithoutAgentID", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
@@ -2036,6 +2087,492 @@ func TestContextFileAgentID(t *testing.T) {
})
}
func TestHasPersistedInstructionFiles(t *testing.T) {
t.Parallel()
t.Run("IgnoresAgentChatContextSentinel", func(t *testing.T) {
t.Parallel()
agentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}}),
}
require.False(t, hasPersistedInstructionFiles(msgs))
})
t.Run("AcceptsPersistedInstructionFile", func(t *testing.T) {
t.Parallel()
agentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/workspace/AGENTS.md",
ContextFileContent: "repo instructions",
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
}}),
}
require.True(t, hasPersistedInstructionFiles(msgs))
})
}
func TestInstructionFromContextFilesUsesLatestContextAgent(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileContent: "old instructions",
ContextFileOS: "darwin",
ContextFileDirectory: "/old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/new/AGENTS.md",
ContextFileContent: "new instructions",
ContextFileOS: "linux",
ContextFileDirectory: "/new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
}}),
}
got := instructionFromContextFiles(msgs)
require.Contains(t, got, "new instructions")
require.Contains(t, got, "Operating System: linux")
require.Contains(t, got, "Working Directory: /new")
require.NotContains(t, got, "old instructions")
require.NotContains(t, got, "Operating System: darwin")
}
func TestInstructionFromContextFilesKeepsLegacyUnstampedParts(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/legacy/AGENTS.md",
ContextFileContent: "legacy instructions",
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileContent: "old instructions",
ContextFileOS: "darwin",
ContextFileDirectory: "/old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/new/AGENTS.md",
ContextFileContent: "new instructions",
ContextFileOS: "linux",
ContextFileDirectory: "/new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
}}),
}
got := instructionFromContextFiles(msgs)
require.Contains(t, got, "legacy instructions")
require.Contains(t, got, "new instructions")
require.Contains(t, got, "Operating System: linux")
require.Contains(t, got, "Working Directory: /new")
require.NotContains(t, got, "old instructions")
require.NotContains(t, got, "Operating System: darwin")
}
func TestSkillsFromPartsKeepsLegacyUnstampedParts(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-legacy",
SkillDir: "/skills/repo-helper-legacy",
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-old",
SkillDir: "/skills/repo-helper-old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: newAgentID,
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-new",
SkillDir: "/skills/repo-helper-new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
},
}),
}
got := skillsFromParts(msgs)
require.Equal(t, []chattool.SkillMeta{
{Name: "repo-helper-legacy", Dir: "/skills/repo-helper-legacy"},
{Name: "repo-helper-new", Dir: "/skills/repo-helper-new"},
}, got)
}
func TestSkillsFromPartsUsesLatestContextAgent(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-old",
SkillDir: "/skills/repo-helper-old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: newAgentID,
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-new",
SkillDir: "/skills/repo-helper-new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
},
}),
}
got := skillsFromParts(msgs)
require.Equal(t, []chattool.SkillMeta{{
Name: "repo-helper-new",
Dir: "/skills/repo-helper-new",
}}, got)
}
func TestMergeSkillMetas(t *testing.T) {
t.Parallel()
persisted := []chattool.SkillMeta{{
Name: "repo-helper",
Description: "Persisted skill",
Dir: "/skills/repo-helper-old",
}}
discovered := []chattool.SkillMeta{
{
Name: "repo-helper",
Description: "Discovered replacement",
Dir: "/skills/repo-helper-new",
MetaFile: "SKILL.md",
},
{
Name: "deep-review",
Description: "Discovered skill",
Dir: "/skills/deep-review",
},
}
got := mergeSkillMetas(persisted, discovered)
require.Equal(t, []chattool.SkillMeta{
discovered[0],
discovered[1],
}, got)
}
func TestSelectSkillMetasForInstructionRefresh(t *testing.T) {
t.Parallel()
persisted := []chattool.SkillMeta{{Name: "persisted", Dir: "/skills/persisted"}}
discovered := []chattool.SkillMeta{{Name: "discovered", Dir: "/skills/discovered"}}
currentAgentID := uuid.New()
otherAgentID := uuid.New()
t.Run("MergesCurrentAgentSkills", func(t *testing.T) {
t.Parallel()
got := selectSkillMetasForInstructionRefresh(
persisted,
discovered,
uuid.NullUUID{UUID: currentAgentID, Valid: true},
uuid.NullUUID{UUID: currentAgentID, Valid: true},
)
require.Equal(t, []chattool.SkillMeta{discovered[0], persisted[0]}, got)
})
t.Run("DropsStalePersistedSkillsWhenAgentChanged", func(t *testing.T) {
t.Parallel()
got := selectSkillMetasForInstructionRefresh(
persisted,
discovered,
uuid.NullUUID{UUID: currentAgentID, Valid: true},
uuid.NullUUID{UUID: otherAgentID, Valid: true},
)
require.Equal(t, discovered, got)
})
t.Run("PreservesPersistedSkillsWhenAgentLookupFails", func(t *testing.T) {
t.Parallel()
got := selectSkillMetasForInstructionRefresh(
persisted,
nil,
uuid.NullUUID{},
uuid.NullUUID{UUID: otherAgentID, Valid: true},
)
require.Equal(t, persisted, got)
})
}
func TestResolveChainModeIgnoresSkillOnlySentinelMessages(t *testing.T) {
t.Parallel()
modelConfigID := uuid.New()
assistant := database.ChatMessage{
Role: database.ChatMessageRoleAssistant,
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
}
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper",
SkillDir: "/skills/repo-helper",
},
})
skillOnly.Role = database.ChatMessageRoleUser
user := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "latest user message",
}})
user.Role = database.ChatMessageRoleUser
got := resolveChainMode([]database.ChatMessage{assistant, skillOnly, user})
require.Equal(t, "resp-123", got.previousResponseID)
require.Equal(t, modelConfigID, got.modelConfigID)
require.Equal(t, 2, got.trailingUserCount)
require.Equal(t, 1, got.contributingTrailingUserCount)
}
func TestFilterPromptForChainModeKeepsContributingUsersAcrossSkippedSentinelTurns(t *testing.T) {
t.Parallel()
modelConfigID := uuid.New()
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "prior user message",
}})
priorUser.Role = database.ChatMessageRoleUser
assistant := database.ChatMessage{
Role: database.ChatMessageRoleAssistant,
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
}
firstTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "first trailing user",
}})
firstTrailingUser.Role = database.ChatMessageRoleUser
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper",
SkillDir: "/skills/repo-helper",
},
})
skillOnly.Role = database.ChatMessageRoleUser
lastTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "last trailing user",
}})
lastTrailingUser.Role = database.ChatMessageRoleUser
chainInfo := resolveChainMode([]database.ChatMessage{
priorUser,
assistant,
firstTrailingUser,
skillOnly,
lastTrailingUser,
})
require.Equal(t, 3, chainInfo.trailingUserCount)
require.Equal(t, 2, chainInfo.contributingTrailingUserCount)
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "system instruction"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "prior user message"},
},
},
{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "assistant reply"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "first trailing user"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "last trailing user"},
},
},
}
got := filterPromptForChainMode(prompt, chainInfo)
require.Len(t, got, 3)
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
require.Equal(t, fantasy.MessageRoleUser, got[2].Role)
firstPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
require.True(t, ok)
require.Equal(t, "first trailing user", firstPart.Text)
lastPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[2].Content[0])
require.True(t, ok)
require.Equal(t, "last trailing user", lastPart.Text)
}
func TestFilterPromptForChainModeUsesContributingTrailingUsers(t *testing.T) {
t.Parallel()
modelConfigID := uuid.New()
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "prior user message",
}})
priorUser.Role = database.ChatMessageRoleUser
assistant := database.ChatMessage{
Role: database.ChatMessageRoleAssistant,
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
}
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper",
SkillDir: "/skills/repo-helper",
},
})
skillOnly.Role = database.ChatMessageRoleUser
latestUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "latest user message",
}})
latestUser.Role = database.ChatMessageRoleUser
chainInfo := resolveChainMode([]database.ChatMessage{
priorUser,
assistant,
skillOnly,
latestUser,
})
require.Equal(t, 2, chainInfo.trailingUserCount)
require.Equal(t, 1, chainInfo.contributingTrailingUserCount)
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "system instruction"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "prior user message"},
},
},
{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "assistant reply"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "latest user message"},
},
},
}
got := filterPromptForChainMode(prompt, chainInfo)
require.Len(t, got, 2)
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
part, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
require.True(t, ok)
require.Equal(t, "latest user message", part.Text)
}
func chatMessageWithParts(parts []codersdk.ChatMessagePart) database.ChatMessage {
raw, _ := json.Marshal(parts)
return database.ChatMessage{
+153
View File
@@ -0,0 +1,153 @@
package chatd
import (
"context"
"encoding/json"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk"
)
// AgentChatContextSentinelPath marks the synthetic empty context-file
// part used to preserve skill-only workspace-agent additions across
// turns without treating them as persisted instruction files.
const AgentChatContextSentinelPath = ".coder/agent-chat-context-sentinel"
// FilterContextParts keeps only context-file and skill parts from parts.
// When keepEmptyContextFiles is false, context-file parts with empty
// content are dropped. When keepEmptyContextFiles is true, empty
// context-file parts are preserved.
// revive:disable-next-line:flag-parameter // Required by shared helper callers.
func FilterContextParts(
parts []codersdk.ChatMessagePart,
keepEmptyContextFiles bool,
) []codersdk.ChatMessagePart {
var filtered []codersdk.ChatMessagePart
for _, part := range parts {
switch part.Type {
case codersdk.ChatMessagePartTypeContextFile:
if !keepEmptyContextFiles && part.ContextFileContent == "" {
continue
}
case codersdk.ChatMessagePartTypeSkill:
default:
continue
}
filtered = append(filtered, part)
}
return filtered
}
// CollectContextPartsFromMessages unmarshals chat message content and
// collects the context-file and skill parts it contains. When
// keepEmptyContextFiles is false, empty context-file parts are skipped.
// When it is true, empty context-file parts are included in the result.
func CollectContextPartsFromMessages(
ctx context.Context,
logger slog.Logger,
messages []database.ChatMessage,
keepEmptyContextFiles bool,
) ([]codersdk.ChatMessagePart, error) {
var collected []codersdk.ChatMessagePart
for _, msg := range messages {
if !msg.Content.Valid {
continue
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil {
logger.Warn(ctx, "skipping malformed chat context message",
slog.F("chat_message_id", msg.ID),
slog.Error(err),
)
continue
}
collected = append(
collected,
FilterContextParts(parts, keepEmptyContextFiles)...,
)
}
return collected, nil
}
func latestContextAgentIDFromParts(parts []codersdk.ChatMessagePart) (uuid.UUID, bool) {
var lastID uuid.UUID
found := false
for _, part := range parts {
if part.Type != codersdk.ChatMessagePartTypeContextFile ||
!part.ContextFileAgentID.Valid {
continue
}
lastID = part.ContextFileAgentID.UUID
found = true
}
return lastID, found
}
// FilterContextPartsToLatestAgent keeps parts stamped with the latest
// workspace-agent ID seen in the slice, plus legacy unstamped parts.
// When no stamped context-file parts exist, it returns the original
// slice unchanged.
func FilterContextPartsToLatestAgent(parts []codersdk.ChatMessagePart) []codersdk.ChatMessagePart {
latestAgentID, ok := latestContextAgentIDFromParts(parts)
if !ok {
return parts
}
filtered := make([]codersdk.ChatMessagePart, 0, len(parts))
for _, part := range parts {
switch part.Type {
case codersdk.ChatMessagePartTypeContextFile,
codersdk.ChatMessagePartTypeSkill:
if part.ContextFileAgentID.Valid &&
part.ContextFileAgentID.UUID != latestAgentID {
continue
}
default:
continue
}
filtered = append(filtered, part)
}
return filtered
}
// BuildLastInjectedContext filters parts down to non-empty context-file
// and skill parts, strips their internal fields, and marshals the
// result for LastInjectedContext. A nil or fully filtered input returns
// an invalid NullRawMessage.
func BuildLastInjectedContext(
parts []codersdk.ChatMessagePart,
) (pqtype.NullRawMessage, error) {
if parts == nil {
return pqtype.NullRawMessage{Valid: false}, nil
}
filtered := FilterContextParts(parts, false)
if len(filtered) == 0 {
return pqtype.NullRawMessage{Valid: false}, nil
}
stripped := make([]codersdk.ChatMessagePart, 0, len(filtered))
for _, part := range filtered {
cp := part
cp.StripInternal()
stripped = append(stripped, cp)
}
raw, err := json.Marshal(stripped)
if err != nil {
return pqtype.NullRawMessage{}, xerrors.Errorf(
"marshal injected context: %w",
err,
)
}
return pqtype.NullRawMessage{RawMessage: raw, Valid: true}, nil
}
+114
View File
@@ -5,6 +5,8 @@ import (
"encoding/json"
"strings"
"github.com/google/uuid"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/codersdk"
@@ -57,6 +59,34 @@ func formatSystemInstructions(
return b.String()
}
// latestContextAgentID returns the most recent workspace-agent ID seen
// on any persisted context-file part, including the skill-only sentinel.
// Returns uuid.Nil, false when no stamped context-file parts exist.
func latestContextAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
var lastID uuid.UUID
found := false
for _, msg := range messages {
if !msg.Content.Valid ||
!bytes.Contains(msg.Content.RawMessage, []byte(`"context-file"`)) {
continue
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil {
continue
}
for _, part := range parts {
if part.Type != codersdk.ChatMessagePartTypeContextFile ||
!part.ContextFileAgentID.Valid {
continue
}
lastID = part.ContextFileAgentID.UUID
found = true
break
}
}
return lastID, found
}
// instructionFromContextFiles reconstructs the formatted instruction
// string from persisted context-file parts. This is used on non-first
// turns so the instruction can be re-injected after compaction
@@ -64,6 +94,7 @@ func formatSystemInstructions(
func instructionFromContextFiles(
messages []database.ChatMessage,
) string {
filterAgentID, filterByAgent := latestContextAgentID(messages)
var contextParts []codersdk.ChatMessagePart
var os, dir string
for _, msg := range messages {
@@ -79,6 +110,10 @@ func instructionFromContextFiles(
if part.Type != codersdk.ChatMessagePartTypeContextFile {
continue
}
if filterByAgent && part.ContextFileAgentID.Valid &&
part.ContextFileAgentID.UUID != filterAgentID {
continue
}
if part.ContextFileOS != "" {
os = part.ContextFileOS
}
@@ -93,6 +128,80 @@ func instructionFromContextFiles(
return formatSystemInstructions(os, dir, contextParts)
}
// hasPersistedInstructionFiles reports whether messages include a
// persisted context-file part that should suppress another baseline
// instruction-file lookup. The workspace-agent skill-only sentinel is
// ignored so default instructions still load on fresh chats.
func hasPersistedInstructionFiles(
messages []database.ChatMessage,
) bool {
for _, msg := range messages {
if !msg.Content.Valid ||
!bytes.Contains(msg.Content.RawMessage, []byte(`"context-file"`)) {
continue
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil {
continue
}
for _, part := range parts {
if part.Type != codersdk.ChatMessagePartTypeContextFile ||
!part.ContextFileAgentID.Valid ||
part.ContextFilePath == AgentChatContextSentinelPath {
continue
}
return true
}
}
return false
}
func mergeSkillMetas(
persisted []chattool.SkillMeta,
discovered []chattool.SkillMeta,
) []chattool.SkillMeta {
if len(persisted) == 0 {
return discovered
}
if len(discovered) == 0 {
return persisted
}
seen := make(map[string]struct{}, len(persisted)+len(discovered))
merged := make([]chattool.SkillMeta, 0, len(persisted)+len(discovered))
appendUnique := func(skill chattool.SkillMeta) {
if _, ok := seen[skill.Name]; ok {
return
}
seen[skill.Name] = struct{}{}
merged = append(merged, skill)
}
for _, skill := range discovered {
appendUnique(skill)
}
for _, skill := range persisted {
appendUnique(skill)
}
return merged
}
// selectSkillMetasForInstructionRefresh chooses which skill metadata
// should be injected on a turn that refreshes instruction files.
func selectSkillMetasForInstructionRefresh(
persisted []chattool.SkillMeta,
discovered []chattool.SkillMeta,
currentAgentID uuid.NullUUID,
latestInjectedAgentID uuid.NullUUID,
) []chattool.SkillMeta {
if currentAgentID.Valid && latestInjectedAgentID.Valid && latestInjectedAgentID.UUID == currentAgentID.UUID {
return mergeSkillMetas(persisted, discovered)
}
if !currentAgentID.Valid && len(discovered) == 0 {
return persisted
}
return discovered
}
// skillsFromParts reconstructs skill metadata from persisted
// skill parts. This is analogous to instructionFromContextFiles
// so the skill index can be re-injected after compaction without
@@ -100,6 +209,7 @@ func instructionFromContextFiles(
func skillsFromParts(
messages []database.ChatMessage,
) []chattool.SkillMeta {
filterAgentID, filterByAgent := latestContextAgentID(messages)
var skills []chattool.SkillMeta
for _, msg := range messages {
if !msg.Content.Valid ||
@@ -114,6 +224,10 @@ func skillsFromParts(
if part.Type != codersdk.ChatMessagePartTypeSkill {
continue
}
if filterByAgent && part.ContextFileAgentID.Valid &&
part.ContextFileAgentID.UUID != filterAgentID {
continue
}
skills = append(skills, chattool.SkillMeta{
Name: part.SkillName,
Description: part.SkillDescription,
+268 -61
View File
@@ -11,6 +11,7 @@ import (
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
@@ -361,48 +362,19 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
return fantasy.NewTextErrorResponse(err.Error()), nil
}
prompt := strings.TrimSpace(args.Prompt)
if prompt == "" {
return fantasy.NewTextErrorResponse("prompt is required"), nil
}
title := strings.TrimSpace(args.Title)
if title == "" {
title = subagentFallbackChatTitle(prompt)
}
rootChatID := parent.ID
if parent.RootChatID.Valid {
rootChatID = parent.RootChatID.UUID
}
if parent.LastModelConfigID == uuid.Nil {
return fantasy.NewTextErrorResponse("parent chat model config id is required"), nil
}
// Create the child chat with Mode set to
// computer_use. This signals runChat to use the
// predefined computer use model and include the
// computer tool.
childChat, err := p.CreateChat(ctx, CreateOptions{
OwnerID: parent.OwnerID,
WorkspaceID: parent.WorkspaceID,
BuildID: parent.BuildID,
AgentID: parent.AgentID,
ParentChatID: uuid.NullUUID{
UUID: parent.ID,
Valid: true,
childChat, err := p.createChildSubagentChatWithOptions(
ctx,
parent,
args.Prompt,
args.Title,
childSubagentChatOptions{
chatMode: database.NullChatMode{
ChatMode: database.ChatModeComputerUse,
Valid: true,
},
systemPrompt: computerUseSubagentSystemPrompt + "\n\n" + strings.TrimSpace(args.Prompt),
},
RootChatID: uuid.NullUUID{
UUID: rootChatID,
Valid: true,
},
ModelConfigID: parent.LastModelConfigID,
Title: title,
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
SystemPrompt: computerUseSubagentSystemPrompt + "\n\n" + prompt,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
MCPServerIDs: parent.MCPServerIDs,
})
)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
@@ -427,11 +399,26 @@ func parseSubagentToolChatID(raw string) (uuid.UUID, error) {
return chatID, nil
}
type childSubagentChatOptions struct {
chatMode database.NullChatMode
systemPrompt string
}
func (p *Server) createChildSubagentChat(
ctx context.Context,
parent database.Chat,
prompt string,
title string,
) (database.Chat, error) {
return p.createChildSubagentChatWithOptions(ctx, parent, prompt, title, childSubagentChatOptions{})
}
func (p *Server) createChildSubagentChatWithOptions(
ctx context.Context,
parent database.Chat,
prompt string,
title string,
opts childSubagentChatOptions,
) (database.Chat, error) {
if parent.ParentChatID.Valid {
return database.Chat{}, xerrors.New("delegated chats cannot create child subagents")
@@ -455,31 +442,251 @@ func (p *Server) createChildSubagentChat(
return database.Chat{}, xerrors.New("parent chat model config id is required")
}
child, err := p.CreateChat(ctx, CreateOptions{
OwnerID: parent.OwnerID,
WorkspaceID: parent.WorkspaceID,
BuildID: parent.BuildID,
AgentID: parent.AgentID,
ParentChatID: uuid.NullUUID{
UUID: parent.ID,
Valid: true,
},
RootChatID: uuid.NullUUID{
UUID: rootChatID,
Valid: true,
},
ModelConfigID: parent.LastModelConfigID,
Title: title,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
MCPServerIDs: parent.MCPServerIDs,
})
if err != nil {
return database.Chat{}, xerrors.Errorf("create child chat: %w", err)
mcpServerIDs := parent.MCPServerIDs
if mcpServerIDs == nil {
mcpServerIDs = []uuid.UUID{}
}
labelsJSON, err := json.Marshal(database.StringMap{})
if err != nil {
return database.Chat{}, xerrors.Errorf("marshal labels: %w", err)
}
childSystemPrompt := SanitizePromptText(opts.systemPrompt)
var child database.Chat
txErr := p.db.InTx(func(tx database.Store) error {
if limitErr := p.checkUsageLimit(ctx, tx, parent.OwnerID); limitErr != nil {
return limitErr
}
insertedChat, err := tx.InsertChat(ctx, database.InsertChatParams{
OwnerID: parent.OwnerID,
WorkspaceID: parent.WorkspaceID,
BuildID: parent.BuildID,
AgentID: parent.AgentID,
ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true},
LastModelConfigID: parent.LastModelConfigID,
Title: title,
Mode: opts.chatMode,
Status: database.ChatStatusPending,
MCPServerIDs: mcpServerIDs,
Labels: pqtype.NullRawMessage{
RawMessage: labelsJSON,
Valid: true,
},
DynamicTools: pqtype.NullRawMessage{},
})
if err != nil {
return xerrors.Errorf("insert child chat: %w", err)
}
deploymentPrompt := p.resolveDeploymentSystemPrompt(ctx)
workspaceAwareness := "There is no workspace associated with this chat yet. Create one using the create_workspace tool before using workspace tools like execute, read_file, write_file, etc."
if insertedChat.WorkspaceID.Valid {
workspaceAwareness = "This chat is attached to a workspace. You can use workspace tools like execute, read_file, write_file, etc."
}
workspaceAwarenessContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText(workspaceAwareness),
})
if err != nil {
return xerrors.Errorf("marshal workspace awareness: %w", err)
}
userContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)})
if err != nil {
return xerrors.Errorf("marshal initial user content: %w", err)
}
systemParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
ChatID: insertedChat.ID,
}
if deploymentPrompt != "" {
deploymentContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText(deploymentPrompt),
})
if err != nil {
return xerrors.Errorf("marshal deployment system prompt: %w", err)
}
appendChatMessage(&systemParams, newChatMessage(
database.ChatMessageRoleSystem,
deploymentContent,
database.ChatMessageVisibilityModel,
parent.LastModelConfigID,
chatprompt.CurrentContentVersion,
))
}
if childSystemPrompt != "" {
childSystemPromptContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText(childSystemPrompt),
})
if err != nil {
return xerrors.Errorf("marshal child system prompt: %w", err)
}
appendChatMessage(&systemParams, newChatMessage(
database.ChatMessageRoleSystem,
childSystemPromptContent,
database.ChatMessageVisibilityModel,
parent.LastModelConfigID,
chatprompt.CurrentContentVersion,
))
}
appendChatMessage(&systemParams, newChatMessage(
database.ChatMessageRoleSystem,
workspaceAwarenessContent,
database.ChatMessageVisibilityModel,
parent.LastModelConfigID,
chatprompt.CurrentContentVersion,
))
if _, err := tx.InsertChatMessages(ctx, systemParams); err != nil {
return xerrors.Errorf("insert initial child system messages: %w", err)
}
child = insertedChat
// Copy persisted context before the initial child prompt so the
// child cannot be acquired until its inherited context is in
// place. signalWake runs only after commit.
copiedContextParts, err := copyParentContextMessages(ctx, p.logger, tx, parent, child)
if err != nil {
return xerrors.Errorf("copy parent context messages: %w", err)
}
if err := updateChildLastInjectedContext(ctx, p.logger, tx, child.ID, copiedContextParts); err != nil {
return xerrors.Errorf("update child injected context: %w", err)
}
userParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
ChatID: insertedChat.ID,
}
appendChatMessage(&userParams, newChatMessage(
database.ChatMessageRoleUser,
userContent,
database.ChatMessageVisibilityBoth,
parent.LastModelConfigID,
chatprompt.CurrentContentVersion,
).withCreatedBy(parent.OwnerID))
if _, err := tx.InsertChatMessages(ctx, userParams); err != nil {
return xerrors.Errorf("insert initial child user message: %w", err)
}
return nil
}, nil)
if txErr != nil {
return database.Chat{}, xerrors.Errorf("create child chat: %w", txErr)
}
p.publishChatPubsubEvent(child, coderdpubsub.ChatEventKindCreated, nil)
p.signalWake()
return child, nil
}
// copyParentContextMessages reads persisted context-file and skill
// messages from the parent chat and inserts copies into the child
// chat. This ensures sub-agents inherit the same instruction and
// skill context as their parent without independently re-fetching
// from the agent.
func copyParentContextMessages(
ctx context.Context,
logger slog.Logger,
store database.Store,
parent database.Chat,
child database.Chat,
) ([]codersdk.ChatMessagePart, error) {
parentMessages, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: parent.ID,
AfterID: 0,
})
if err != nil {
return nil, xerrors.Errorf("get parent messages: %w", err)
}
var (
copiedParts []codersdk.ChatMessagePart
copiedRole database.ChatMessageRole
copiedVisibility database.ChatMessageVisibility
copiedVersion int16
)
for _, msg := range parentMessages {
if !msg.Content.Valid {
continue
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil {
logger.Warn(ctx, "failed to unmarshal parent context message",
slog.F("parent_chat_id", parent.ID),
slog.F("message_id", msg.ID),
slog.Error(err),
)
continue
}
messageContextParts := FilterContextParts(parts, true)
if len(messageContextParts) == 0 {
continue
}
if copiedParts == nil {
copiedRole = msg.Role
copiedVisibility = msg.Visibility
copiedVersion = msg.ContentVersion
}
copiedParts = append(copiedParts, messageContextParts...)
}
if len(copiedParts) == 0 {
return nil, nil
}
copiedParts = FilterContextPartsToLatestAgent(copiedParts)
filteredContent, err := chatprompt.MarshalParts(copiedParts)
if err != nil {
return nil, xerrors.Errorf("marshal filtered context parts: %w", err)
}
msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
ChatID: child.ID,
}
appendChatMessage(&msgParams, newChatMessage(
copiedRole,
filteredContent,
copiedVisibility,
child.LastModelConfigID,
copiedVersion,
))
if _, err := store.InsertChatMessages(ctx, msgParams); err != nil {
return nil, xerrors.Errorf("insert context message: %w", err)
}
return copiedParts, nil
}
func updateChildLastInjectedContext(
ctx context.Context,
logger slog.Logger,
store database.Store,
chatID uuid.UUID,
parts []codersdk.ChatMessagePart,
) error {
parts = FilterContextPartsToLatestAgent(parts)
param, err := BuildLastInjectedContext(parts)
if err != nil {
logger.Warn(ctx, "failed to marshal inherited injected context",
slog.F("chat_id", chatID),
slog.Error(err),
)
return xerrors.Errorf("marshal inherited injected context: %w", err)
}
if _, err := store.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
ID: chatID,
LastInjectedContext: param,
}); err != nil {
logger.Warn(ctx, "failed to update inherited injected context",
slog.F("chat_id", chatID),
slog.Error(err),
)
return xerrors.Errorf("update inherited injected context: %w", err)
}
return nil
}
func (p *Server) sendSubagentMessage(
ctx context.Context,
parentChatID uuid.UUID,
@@ -0,0 +1,506 @@
package chatd
import (
"context"
"encoding/json"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/codersdk"
)
func TestCollectContextPartsFromMessagesSkipsSentinelContextFiles(t *testing.T) {
t.Parallel()
content, err := json.Marshal([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/home/coder/project/.agents/skills/my-skill/SKILL.md",
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "my-skill",
SkillDescription: "A test skill",
},
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/home/coder/project/AGENTS.md",
ContextFileContent: "# Project instructions",
},
codersdk.ChatMessageText("ignored"),
})
require.NoError(t, err)
parts, err := CollectContextPartsFromMessages(context.Background(), slog.Make(), []database.ChatMessage{ //nolint:exhaustruct // Only content fields matter for this unit test.
{
ID: 1,
Content: pqtype.NullRawMessage{
RawMessage: content,
Valid: true,
},
},
}, false)
require.NoError(t, err)
require.Len(t, parts, 2)
require.Equal(t, codersdk.ChatMessagePartTypeSkill, parts[0].Type)
require.Equal(t, "my-skill", parts[0].SkillName)
require.Equal(t, codersdk.ChatMessagePartTypeContextFile, parts[1].Type)
require.Equal(t, "/home/coder/project/AGENTS.md", parts[1].ContextFilePath)
require.Equal(t, "# Project instructions", parts[1].ContextFileContent)
}
func TestCollectContextPartsFromMessagesKeepsEmptyContextFilesWhenRequested(t *testing.T) {
t.Parallel()
content, err := json.Marshal([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "my-skill",
},
})
require.NoError(t, err)
parts, err := CollectContextPartsFromMessages(context.Background(), slog.Make(), []database.ChatMessage{ //nolint:exhaustruct // Only content fields matter for this unit test.
{
ID: 1,
Content: pqtype.NullRawMessage{
RawMessage: content,
Valid: true,
},
},
}, true)
require.NoError(t, err)
require.Len(t, parts, 2)
require.Equal(t, AgentChatContextSentinelPath, parts[0].ContextFilePath)
require.Equal(t, "my-skill", parts[1].SkillName)
}
func TestFilterContextPartsToLatestAgent(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
parts := []codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/legacy/AGENTS.md",
ContextFileContent: "legacy instructions",
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-legacy",
},
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: newAgentID,
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
},
}
got := FilterContextPartsToLatestAgent(parts)
require.Len(t, got, 4)
require.Equal(t, "/legacy/AGENTS.md", got[0].ContextFilePath)
require.Equal(t, "repo-helper-legacy", got[1].SkillName)
require.Equal(t, AgentChatContextSentinelPath, got[2].ContextFilePath)
require.Equal(t, "repo-helper-new", got[3].SkillName)
}
func createParentChatWithInheritedContext(
ctx context.Context,
t *testing.T,
db database.Store,
server *Server,
) database.Chat {
t.Helper()
user, model := seedInternalChatDeps(ctx, t, db)
parent, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
Title: "parent-with-context",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
inheritedParts := []codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/home/coder/project/AGENTS.md",
ContextFileContent: "# Project instructions",
ContextFileOS: "linux",
ContextFileDirectory: "/home/coder/project",
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "my-skill",
SkillDescription: "A test skill",
SkillDir: "/home/coder/project/.agents/skills/my-skill",
ContextFileSkillMetaFile: "SKILL.md",
},
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/home/coder/project/.agents/skills/my-skill/SKILL.md",
},
}
content, err := json.Marshal(inheritedParts)
require.NoError(t, err)
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: parent.ID,
CreatedBy: []uuid.UUID{user.ID},
ModelConfigID: []uuid.UUID{model.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleUser},
Content: []string{string(content)},
ContentVersion: []int16{chatprompt.CurrentContentVersion},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{0},
OutputTokens: []int64{0},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{0},
RuntimeMs: []int64{0},
})
require.NoError(t, err)
parentChat, err := db.GetChatByID(ctx, parent.ID)
require.NoError(t, err)
return parentChat
}
func assertChildInheritedContext(
ctx context.Context,
t *testing.T,
db database.Store,
childID uuid.UUID,
prompt string,
) {
t.Helper()
childChat, err := db.GetChatByID(ctx, childID)
require.NoError(t, err)
require.True(t, childChat.LastInjectedContext.Valid)
var cached []codersdk.ChatMessagePart
require.NoError(t, json.Unmarshal(childChat.LastInjectedContext.RawMessage, &cached))
require.Len(t, cached, 2)
var sawContextFile bool
var sawSkill bool
for _, part := range cached {
switch part.Type {
case codersdk.ChatMessagePartTypeContextFile:
sawContextFile = true
require.Equal(t, "/home/coder/project/AGENTS.md", part.ContextFilePath)
require.Empty(t, part.ContextFileContent)
require.Empty(t, part.ContextFileOS)
require.Empty(t, part.ContextFileDirectory)
case codersdk.ChatMessagePartTypeSkill:
sawSkill = true
require.Equal(t, "my-skill", part.SkillName)
require.Equal(t, "A test skill", part.SkillDescription)
require.Empty(t, part.SkillDir)
require.Empty(t, part.ContextFileSkillMetaFile)
default:
t.Fatalf("unexpected cached part type %q", part.Type)
}
}
require.True(t, sawContextFile)
require.True(t, sawSkill)
childMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: childID,
AfterID: 0,
})
require.NoError(t, err)
var (
contextMessageIndexes []int
userPromptIndex = -1
sawDBAgentsContextFile bool
sawDBSkillCompanionContext bool
sawDBSkill bool
)
for i, msg := range childMessages {
if !msg.Content.Valid {
continue
}
var parts []codersdk.ChatMessagePart
require.NoError(t, json.Unmarshal(msg.Content.RawMessage, &parts))
if len(parts) == 1 && parts[0].Type == codersdk.ChatMessagePartTypeText && parts[0].Text == prompt {
require.Equal(t, database.ChatMessageRoleUser, msg.Role)
userPromptIndex = i
continue
}
hasInheritedContext := false
for _, part := range parts {
switch part.Type {
case codersdk.ChatMessagePartTypeContextFile:
hasInheritedContext = true
switch part.ContextFilePath {
case "/home/coder/project/AGENTS.md":
sawDBAgentsContextFile = true
require.Equal(t, "# Project instructions", part.ContextFileContent)
require.Equal(t, "linux", part.ContextFileOS)
require.Equal(t, "/home/coder/project", part.ContextFileDirectory)
case "/home/coder/project/.agents/skills/my-skill/SKILL.md":
sawDBSkillCompanionContext = true
require.Empty(t, part.ContextFileContent)
require.Empty(t, part.ContextFileOS)
require.Empty(t, part.ContextFileDirectory)
default:
t.Fatalf("unexpected child inherited context file path %q", part.ContextFilePath)
}
case codersdk.ChatMessagePartTypeSkill:
hasInheritedContext = true
sawDBSkill = true
require.Equal(t, "my-skill", part.SkillName)
require.Equal(t, "A test skill", part.SkillDescription)
require.Equal(t, "/home/coder/project/.agents/skills/my-skill", part.SkillDir)
require.Equal(t, "SKILL.md", part.ContextFileSkillMetaFile)
default:
t.Fatalf("unexpected child inherited part type %q", part.Type)
}
}
if hasInheritedContext {
require.Equal(t, database.ChatMessageRoleUser, msg.Role)
contextMessageIndexes = append(contextMessageIndexes, i)
}
}
require.NotEmpty(t, contextMessageIndexes)
require.NotEqual(t, -1, userPromptIndex)
for _, idx := range contextMessageIndexes {
require.Less(t, idx, userPromptIndex)
}
require.True(t, sawDBAgentsContextFile)
require.True(t, sawDBSkillCompanionContext)
require.True(t, sawDBSkill)
}
func createParentChatWithRotatedInheritedContext(
ctx context.Context,
t *testing.T,
db database.Store,
server *Server,
) database.Chat {
t.Helper()
user, model := seedInternalChatDeps(ctx, t, db)
parent, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
Title: "parent-with-rotated-context",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
oldAgentID := uuid.New()
newAgentID := uuid.New()
oldContent, err := json.Marshal([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/home/coder/project-old/AGENTS.md",
ContextFileContent: "# Old instructions",
ContextFileOS: "darwin",
ContextFileDirectory: "/home/coder/project-old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "old-skill",
SkillDescription: "Old skill",
SkillDir: "/home/coder/project-old/.agents/skills/old-skill",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
})
require.NoError(t, err)
newContent, err := json.Marshal([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/home/coder/project-new/AGENTS.md",
ContextFileContent: "# New instructions",
ContextFileOS: "linux",
ContextFileDirectory: "/home/coder/project-new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "new-skill",
SkillDescription: "New skill",
SkillDir: "/home/coder/project-new/.agents/skills/new-skill",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
},
})
require.NoError(t, err)
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: parent.ID,
CreatedBy: []uuid.UUID{user.ID, user.ID},
ModelConfigID: []uuid.UUID{model.ID, model.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleUser},
Content: []string{string(oldContent), string(newContent)},
ContentVersion: []int16{chatprompt.CurrentContentVersion, chatprompt.CurrentContentVersion},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
InputTokens: []int64{0, 0},
OutputTokens: []int64{0, 0},
TotalTokens: []int64{0, 0},
ReasoningTokens: []int64{0, 0},
CacheCreationTokens: []int64{0, 0},
CacheReadTokens: []int64{0, 0},
ContextLimit: []int64{0, 0},
Compressed: []bool{false, false},
TotalCostMicros: []int64{0, 0},
RuntimeMs: []int64{0, 0},
})
require.NoError(t, err)
parentChat, err := db.GetChatByID(ctx, parent.ID)
require.NoError(t, err)
return parentChat
}
func TestCreateChildSubagentChatCopiesOnlyLatestAgentContext(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
ctx := chatdTestContext(t)
parentChat := createParentChatWithRotatedInheritedContext(ctx, t, db, server)
child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "")
require.NoError(t, err)
childChat, err := db.GetChatByID(ctx, child.ID)
require.NoError(t, err)
require.True(t, childChat.LastInjectedContext.Valid)
var cached []codersdk.ChatMessagePart
require.NoError(t, json.Unmarshal(childChat.LastInjectedContext.RawMessage, &cached))
require.Len(t, cached, 2)
require.Equal(t, "/home/coder/project-new/AGENTS.md", cached[0].ContextFilePath)
require.Equal(t, "new-skill", cached[1].SkillName)
childMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: child.ID,
AfterID: 0,
})
require.NoError(t, err)
var inherited [][]codersdk.ChatMessagePart
for _, msg := range childMessages {
if !msg.Content.Valid {
continue
}
var parts []codersdk.ChatMessagePart
require.NoError(t, json.Unmarshal(msg.Content.RawMessage, &parts))
if len(parts) == 0 || parts[0].Type == codersdk.ChatMessagePartTypeText {
continue
}
inherited = append(inherited, parts)
}
require.Len(t, inherited, 1)
require.Len(t, inherited[0], 2)
require.Equal(t, "/home/coder/project-new/AGENTS.md", inherited[0][0].ContextFilePath)
require.Equal(t, "# New instructions", inherited[0][0].ContextFileContent)
require.Equal(t, "new-skill", inherited[0][1].SkillName)
}
func TestCreateChildSubagentChatUpdatesInheritedLastInjectedContext(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
ctx := chatdTestContext(t)
parentChat := createParentChatWithInheritedContext(ctx, t, db, server)
child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "")
require.NoError(t, err)
assertChildInheritedContext(ctx, t, db, child.ID, "inspect bindings")
}
func TestSpawnComputerUseAgentInheritsContext(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
Anthropic: "test-anthropic-key",
})
ctx := chatdTestContext(t)
parentChat := createParentChatWithInheritedContext(ctx, t, db, server)
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
tool := findToolByName(tools, "spawn_computer_use_agent")
require.NotNil(t, tool)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-context",
Name: "spawn_computer_use_agent",
Input: `{"prompt":"inspect bindings"}`,
})
require.NoError(t, err)
require.False(t, resp.IsError, "expected success but got: %s", resp.Content)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
childIDStr, ok := result["chat_id"].(string)
require.True(t, ok)
childID, err := uuid.Parse(childIDStr)
require.NoError(t, err)
childChat, err := db.GetChatByID(ctx, childID)
require.NoError(t, err)
require.True(t, childChat.Mode.Valid)
require.Equal(t, database.ChatModeComputerUse, childChat.Mode.ChatMode)
assertChildInheritedContext(ctx, t, db, childID, "inspect bindings")
}
+63
View File
@@ -892,3 +892,66 @@ func (s *SSEAgentReinitReceiver) Receive(ctx context.Context) (*Reinitialization
return &reinitEvent, nil
}
}
// AddChatContextRequest is the request body for adding chat context.
type AddChatContextRequest struct {
// ChatID optionally identifies the chat to add context to.
// If empty, auto-detection is used (CODER_CHAT_ID env, the
// only active chat, or the only top-level active chat for this
// agent).
ChatID uuid.UUID `json:"chat_id,omitempty"`
// Parts are the context-file and skill parts to add.
Parts []codersdk.ChatMessagePart `json:"parts"`
}
// AddChatContextResponse is the response for adding chat context.
type AddChatContextResponse struct {
ChatID uuid.UUID `json:"chat_id"`
Count int `json:"count"`
}
// ClearChatContextRequest is the request body for clearing chat context.
type ClearChatContextRequest struct {
// ChatID optionally identifies the chat to clear context from.
// If empty, auto-detection is used (CODER_CHAT_ID env, the
// only active chat, or the only top-level active chat for this
// agent).
ChatID uuid.UUID `json:"chat_id,omitempty"`
}
// ClearChatContextResponse is the response for clearing chat context.
type ClearChatContextResponse struct {
ChatID uuid.UUID `json:"chat_id"`
}
// AddChatContext adds context-file and skill parts to an active chat.
func (c *Client) AddChatContext(ctx context.Context, req AddChatContextRequest) (AddChatContextResponse, error) {
res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/experimental/chat-context", req)
if err != nil {
return AddChatContextResponse{}, xerrors.Errorf("execute request: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return AddChatContextResponse{}, codersdk.ReadBodyAsError(res)
}
var resp AddChatContextResponse
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
// ClearChatContext soft-deletes context-file and skill messages from an active chat.
func (c *Client) ClearChatContext(ctx context.Context, req ClearChatContextRequest) (ClearChatContextResponse, error) {
res, err := c.SDK.Request(ctx, http.MethodDelete, "/api/v2/workspaceagents/me/experimental/chat-context", req)
if err != nil {
return ClearChatContextResponse{}, xerrors.Errorf("execute request: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ClearChatContextResponse{}, codersdk.ReadBodyAsError(res)
}
var resp ClearChatContextResponse
return resp, json.NewDecoder(res.Body).Decode(&resp)
}