mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: include OS and working directory in workspace agent prompt injection (#22399)
When injecting system instructions into the chat prompt, include: 1. **Operating system** and **working directory** from the `workspace_agents` table 2. **Home-level instructions** from `~/.coder/AGENTS.md` (existing behavior) 3. **Project-level instructions** from `<pwd>/AGENTS.md` (new) The XML tag is renamed from `<coder-home-instructions>` to `<system-instructions>` since it now carries more than just the home instruction file. ### Example output (both files present) ```xml <system-instructions> Operating System: linux Working Directory: /home/coder/coder Source: /home/coder/.coder/AGENTS.md ... home instructions ... Source: /home/coder/coder/AGENTS.md ... project instructions ... </system-instructions> ``` ### Example output (no AGENTS.md files) ```xml <system-instructions> Operating System: linux Working Directory: /home/coder/coder </system-instructions> ``` ### Changes - **`coderd/chatd/instruction.go`**: - Renamed types: `homeInstructionContext` → `agentContext`, added `instructionFile` struct - Extracted `readInstructionFileAtPath` shared helper - Added `readWorkingDirectoryInstructionFile` to read `<pwd>/AGENTS.md` - Replaced `formatHomeInstruction` with `formatInstructions` that renders both files under `<system-instructions>` - **`coderd/chatd/chatd.go`**: - Renamed `resolveHomeInstruction` → `resolveInstructions`; now reads both home and pwd instruction files - `resolveAgentContext` returns `agentContext` (renamed from `homeInstructionContext`) - pwd file read is skipped gracefully if directory is empty or file doesn't exist - **`coderd/chatd/instruction_test.go`**: - Added `TestReadWorkingDirectoryInstructionFile` (success, not-found, empty-directory) - Replaced `TestFormatHomeInstruction` with `TestFormatInstructions` covering all combinations - Added ordering test (`AgentContextBeforeFiles`) to verify OS/pwd appear before file sources
This commit is contained in:
+52
-42
@@ -1832,12 +1832,9 @@ func (p *Server) runChat(
|
||||
return currentConn, nil
|
||||
}
|
||||
|
||||
prompt = p.appendHomeInstructionToPrompt(
|
||||
ctx,
|
||||
chat,
|
||||
prompt,
|
||||
getWorkspaceConn,
|
||||
)
|
||||
if instruction := p.resolveInstructions(ctx, chat, getWorkspaceConn); instruction != "" {
|
||||
prompt = chatprompt.InsertSystem(prompt, instruction)
|
||||
}
|
||||
|
||||
// Use the model config's context_limit as a fallback when the LLM
|
||||
// provider doesn't include context_limit in its response metadata
|
||||
@@ -2268,31 +2265,18 @@ func refreshChatWorkspaceSnapshot(
|
||||
return refreshedChat, nil
|
||||
}
|
||||
|
||||
func (p *Server) appendHomeInstructionToPrompt(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
prompt []fantasy.Message,
|
||||
getWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error),
|
||||
) []fantasy.Message {
|
||||
if !chat.WorkspaceAgentID.Valid || getWorkspaceConn == nil {
|
||||
return prompt
|
||||
}
|
||||
|
||||
instruction := p.resolveHomeInstruction(ctx, chat, getWorkspaceConn)
|
||||
if instruction == "" {
|
||||
return prompt
|
||||
}
|
||||
|
||||
return chatprompt.InsertSystem(prompt, instruction)
|
||||
}
|
||||
|
||||
// resolveHomeInstruction returns cached home instructions for the
|
||||
// workspace agent, fetching them on cache miss or expiry.
|
||||
func (p *Server) resolveHomeInstruction(
|
||||
// resolveInstructions returns the combined system instructions for the
|
||||
// workspace agent. It reads the home-level (~/.coder/AGENTS.md) and
|
||||
// working-directory-level (<pwd>/AGENTS.md) instruction files, combines
|
||||
// them with agent metadata (OS, directory), and caches the result.
|
||||
func (p *Server) resolveInstructions(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
getWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error),
|
||||
) string {
|
||||
if !chat.WorkspaceAgentID.Valid {
|
||||
return ""
|
||||
}
|
||||
agentID := chat.WorkspaceAgentID.UUID
|
||||
|
||||
p.instructionCacheMu.Lock()
|
||||
@@ -2303,28 +2287,54 @@ func (p *Server) resolveHomeInstruction(
|
||||
return cached.instruction
|
||||
}
|
||||
|
||||
instructionCtx, cancel := context.WithTimeout(ctx, homeInstructionLookupTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := getWorkspaceConn(instructionCtx)
|
||||
// Look up the agent's OS and working directory.
|
||||
//nolint:gocritic // System context needed to read workspace agent metadata.
|
||||
agent, err := p.db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "failed to resolve workspace connection for home instruction file",
|
||||
slog.F("chat_id", chat.ID),
|
||||
p.logger.Debug(ctx, "failed to look up workspace agent for instruction context",
|
||||
slog.F("agent_id", agentID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return cached.instruction
|
||||
}
|
||||
directory := agent.ExpandedDirectory
|
||||
if directory == "" {
|
||||
directory = agent.Directory
|
||||
}
|
||||
|
||||
content, sourcePath, truncated, err := readHomeInstructionFile(instructionCtx, conn)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "failed to load home instruction file",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return cached.instruction
|
||||
// Read instruction files from the workspace agent.
|
||||
var sections []instructionFileSection
|
||||
if getWorkspaceConn != nil {
|
||||
instructionCtx, cancel := context.WithTimeout(ctx, homeInstructionLookupTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, connErr := getWorkspaceConn(instructionCtx)
|
||||
if connErr != nil {
|
||||
p.logger.Debug(ctx, "failed to resolve workspace connection for instruction files",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(connErr),
|
||||
)
|
||||
} else {
|
||||
// ~/.coder/AGENTS.md
|
||||
if content, source, truncated, err := readHomeInstructionFile(instructionCtx, conn); err != nil {
|
||||
p.logger.Debug(ctx, "failed to load home instruction file",
|
||||
slog.F("chat_id", chat.ID), slog.Error(err))
|
||||
} else if content != "" {
|
||||
sections = append(sections, instructionFileSection{content, source, truncated})
|
||||
}
|
||||
|
||||
// <pwd>/AGENTS.md
|
||||
if pwdPath := pwdInstructionFilePath(directory); pwdPath != "" {
|
||||
if content, source, truncated, err := readInstructionFile(instructionCtx, conn, pwdPath); err != nil {
|
||||
p.logger.Debug(ctx, "failed to load working directory instruction file",
|
||||
slog.F("chat_id", chat.ID), slog.F("directory", directory), slog.Error(err))
|
||||
} else if content != "" {
|
||||
sections = append(sections, instructionFileSection{content, source, truncated})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
instruction := formatHomeInstruction(content, sourcePath, truncated)
|
||||
instruction := formatSystemInstructions(agent.OperatingSystem, directory, sections)
|
||||
|
||||
p.instructionCacheMu.Lock()
|
||||
p.instructionCache[agentID] = cachedInstruction{
|
||||
|
||||
+70
-18
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"path"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
@@ -21,6 +22,8 @@ const (
|
||||
|
||||
var markdownCommentPattern = regexp.MustCompile(`<!--[\s\S]*?-->`)
|
||||
|
||||
// readHomeInstructionFile reads the ~/.coder/AGENTS.md file from the
|
||||
// workspace agent's home directory.
|
||||
func readHomeInstructionFile(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
@@ -54,6 +57,16 @@ func readHomeInstructionFile(
|
||||
return "", "", false, nil
|
||||
}
|
||||
|
||||
return readInstructionFile(ctx, conn, filePath)
|
||||
}
|
||||
|
||||
// readInstructionFile reads and sanitizes an instruction file at the
|
||||
// given absolute path.
|
||||
func readInstructionFile(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
filePath string,
|
||||
) (content string, sourcePath string, truncated bool, err error) {
|
||||
reader, _, err := conn.ReadFile(
|
||||
ctx,
|
||||
filePath,
|
||||
@@ -64,13 +77,13 @@ func readHomeInstructionFile(
|
||||
if isCodersdkStatusCode(err, http.StatusNotFound) {
|
||||
return "", "", false, nil
|
||||
}
|
||||
return "", "", false, xerrors.Errorf("read home instruction file: %w", err)
|
||||
return "", "", false, xerrors.Errorf("read instruction file: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
raw, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return "", "", false, xerrors.Errorf("read home instruction bytes: %w", err)
|
||||
return "", "", false, xerrors.Errorf("read instruction bytes: %w", err)
|
||||
}
|
||||
|
||||
truncated = int64(len(raw)) > maxInstructionFileBytes
|
||||
@@ -93,30 +106,69 @@ func sanitizeInstructionMarkdown(content string) string {
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
//nolint:revive // Boolean indicates content was truncated.
|
||||
func formatHomeInstruction(content string, sourcePath string, truncated bool) string {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" {
|
||||
return ""
|
||||
// formatSystemInstructions builds the <workspace-context> block from
|
||||
// agent metadata and zero or more instruction file sections.
|
||||
func formatSystemInstructions(
|
||||
operatingSystem, directory string,
|
||||
sections []instructionFileSection,
|
||||
) string {
|
||||
hasSections := false
|
||||
for _, s := range sections {
|
||||
if s.content != "" {
|
||||
hasSections = true
|
||||
break
|
||||
}
|
||||
}
|
||||
sourcePath = strings.TrimSpace(sourcePath)
|
||||
if sourcePath == "" {
|
||||
sourcePath = "~/.coder/AGENTS.md"
|
||||
if !hasSections && operatingSystem == "" && directory == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
_, _ = b.WriteString("<coder-home-instructions>\n")
|
||||
_, _ = b.WriteString("Source: ")
|
||||
_, _ = b.WriteString(sourcePath)
|
||||
if truncated {
|
||||
_, _ = b.WriteString(" (truncated to 64KiB)")
|
||||
_, _ = b.WriteString("<workspace-context>\n")
|
||||
if operatingSystem != "" {
|
||||
_, _ = b.WriteString("Operating System: ")
|
||||
_, _ = b.WriteString(operatingSystem)
|
||||
_, _ = b.WriteString("\n")
|
||||
}
|
||||
_, _ = b.WriteString("\n\n")
|
||||
_, _ = b.WriteString(content)
|
||||
_, _ = b.WriteString("\n</coder-home-instructions>")
|
||||
if directory != "" {
|
||||
_, _ = b.WriteString("Working Directory: ")
|
||||
_, _ = b.WriteString(directory)
|
||||
_, _ = b.WriteString("\n")
|
||||
}
|
||||
for _, s := range sections {
|
||||
if s.content == "" {
|
||||
continue
|
||||
}
|
||||
_, _ = b.WriteString("\nSource: ")
|
||||
_, _ = b.WriteString(s.source)
|
||||
if s.truncated {
|
||||
_, _ = b.WriteString(" (truncated to 64KiB)")
|
||||
}
|
||||
_, _ = b.WriteString("\n")
|
||||
_, _ = b.WriteString(s.content)
|
||||
_, _ = b.WriteString("\n")
|
||||
}
|
||||
_, _ = b.WriteString("</workspace-context>")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// instructionFileSection is a single instruction file's content and
|
||||
// source path for rendering inside <workspace-context>.
|
||||
type instructionFileSection struct {
|
||||
content string
|
||||
source string
|
||||
truncated bool
|
||||
}
|
||||
|
||||
// pwdInstructionFilePath returns the absolute path to the AGENTS.md
|
||||
// file in the given working directory, or empty if directory is empty.
|
||||
func pwdInstructionFilePath(directory string) string {
|
||||
if directory == "" {
|
||||
return ""
|
||||
}
|
||||
return path.Join(directory, coderHomeInstructionFile)
|
||||
}
|
||||
|
||||
func isCodersdkStatusCode(err error, statusCode int) bool {
|
||||
var sdkErr *codersdk.Error
|
||||
if !xerrors.As(err, &sdkErr) {
|
||||
|
||||
@@ -104,6 +104,58 @@ func TestReadHomeInstructionFileTruncates(t *testing.T) {
|
||||
require.Len(t, got, maxInstructionFileBytes)
|
||||
}
|
||||
|
||||
func TestReadInstructionFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/project/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(
|
||||
io.NopCloser(strings.NewReader("project rules")),
|
||||
"text/markdown",
|
||||
nil,
|
||||
)
|
||||
|
||||
content, source, truncated, err := readInstructionFile(
|
||||
context.Background(), conn, "/home/coder/project/AGENTS.md",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "project rules", content)
|
||||
require.Equal(t, "/home/coder/project/AGENTS.md", source)
|
||||
require.False(t, truncated)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/project/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(nil, "", codersdk.NewTestError(404, "GET", "/api/v0/read-file"))
|
||||
|
||||
content, source, truncated, err := readInstructionFile(
|
||||
context.Background(), conn, "/home/coder/project/AGENTS.md",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, content)
|
||||
require.Empty(t, source)
|
||||
require.False(t, truncated)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInsertSystemInstructionAfterSystemMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -132,3 +184,100 @@ func TestInsertSystemInstructionAfterSystemMessages(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "project rules", part.Text)
|
||||
}
|
||||
|
||||
func TestFormatSystemInstructions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("HomeAndPwdWithAgentContext", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := formatSystemInstructions("linux", "/home/coder/project", []instructionFileSection{
|
||||
{content: "home rules", source: "/home/coder/.coder/AGENTS.md"},
|
||||
{content: "project rules", source: "/home/coder/project/AGENTS.md"},
|
||||
})
|
||||
require.Contains(t, got, "Operating System: linux")
|
||||
require.Contains(t, got, "Working Directory: /home/coder/project")
|
||||
require.Contains(t, got, "Source: /home/coder/.coder/AGENTS.md")
|
||||
require.Contains(t, got, "home rules")
|
||||
require.Contains(t, got, "Source: /home/coder/project/AGENTS.md")
|
||||
require.Contains(t, got, "project rules")
|
||||
require.True(t, strings.HasPrefix(got, "<workspace-context>"))
|
||||
require.True(t, strings.HasSuffix(got, "</workspace-context>"))
|
||||
})
|
||||
|
||||
t.Run("OnlyPwdFile", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := formatSystemInstructions("", "/home/coder/project", []instructionFileSection{
|
||||
{content: "project rules", source: "/home/coder/project/AGENTS.md"},
|
||||
})
|
||||
require.Contains(t, got, "project rules")
|
||||
require.Contains(t, got, "Source: /home/coder/project/AGENTS.md")
|
||||
require.NotContains(t, got, ".coder/AGENTS.md")
|
||||
})
|
||||
|
||||
t.Run("OnlyAgentContext", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := formatSystemInstructions("darwin", "/Users/dev/repo", nil)
|
||||
require.Contains(t, got, "Operating System: darwin")
|
||||
require.Contains(t, got, "Working Directory: /Users/dev/repo")
|
||||
require.NotContains(t, got, "Source:")
|
||||
require.True(t, strings.HasPrefix(got, "<workspace-context>"))
|
||||
require.True(t, strings.HasSuffix(got, "</workspace-context>"))
|
||||
})
|
||||
|
||||
t.Run("OnlyHomeFile", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := formatSystemInstructions("", "", []instructionFileSection{
|
||||
{content: "home rules", source: "~/.coder/AGENTS.md"},
|
||||
})
|
||||
require.Contains(t, got, "Source: ~/.coder/AGENTS.md")
|
||||
require.Contains(t, got, "home rules")
|
||||
require.NotContains(t, got, "Operating System:")
|
||||
require.NotContains(t, got, "Working Directory:")
|
||||
})
|
||||
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := formatSystemInstructions("", "", nil)
|
||||
require.Empty(t, got)
|
||||
})
|
||||
|
||||
t.Run("TruncatedFile", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := formatSystemInstructions("windows", "", []instructionFileSection{
|
||||
{content: "rules", source: "/path/AGENTS.md", truncated: true},
|
||||
})
|
||||
require.Contains(t, got, "truncated to 64KiB")
|
||||
require.Contains(t, got, "Operating System: windows")
|
||||
})
|
||||
|
||||
t.Run("AgentContextBeforeFiles", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := formatSystemInstructions("linux", "/home/project", []instructionFileSection{
|
||||
{content: "home", source: "/home/.coder/AGENTS.md"},
|
||||
{content: "pwd", source: "/home/project/AGENTS.md"},
|
||||
})
|
||||
osIdx := strings.Index(got, "Operating System:")
|
||||
dirIdx := strings.Index(got, "Working Directory:")
|
||||
homeSourceIdx := strings.Index(got, "Source: /home/.coder/AGENTS.md")
|
||||
pwdSourceIdx := strings.Index(got, "Source: /home/project/AGENTS.md")
|
||||
require.Less(t, osIdx, homeSourceIdx)
|
||||
require.Less(t, dirIdx, homeSourceIdx)
|
||||
require.Less(t, homeSourceIdx, pwdSourceIdx)
|
||||
})
|
||||
|
||||
t.Run("EmptySectionsIgnored", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := formatSystemInstructions("linux", "", []instructionFileSection{
|
||||
{content: "", source: "/empty"},
|
||||
{content: "real", source: "/real/AGENTS.md"},
|
||||
})
|
||||
require.NotContains(t, got, "Source: /empty")
|
||||
require.Contains(t, got, "Source: /real/AGENTS.md")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPwdInstructionFilePath(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, "/home/coder/project/AGENTS.md", pwdInstructionFilePath("/home/coder/project"))
|
||||
require.Empty(t, pwdInstructionFilePath(""))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user