feat(agent/agentcontext): add agent-side context resolution package

This commit is contained in:
Kyle Carberry
2026-06-02 14:11:05 +00:00
parent eea427f288
commit 678617ab38
16 changed files with 3434 additions and 0 deletions
+204
View File
@@ -0,0 +1,204 @@
package agentcontext
import (
"context"
"encoding/hex"
"errors"
"net/http"
"net/url"
"strconv"
"github.com/go-chi/chi/v5"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
)
// SourceResponse is the on-wire representation of a Source.
// Matches the path-only RFC schema; future additions (tags,
// labels) can land additively without breaking clients.
type SourceResponse struct {
Path string `json:"path"`
}
// SourceRequest is the request body for POST /sources.
type SourceRequest struct {
Path string `json:"path"`
}
// SnapshotResource is the on-wire representation of a Resource.
// PayloadBase64 is omitted from list responses; clients that
// need the bytes hit GET /sources/{path}.
type SnapshotResource struct {
ID string `json:"id"`
Kind string `json:"kind"`
Source string `json:"source"`
SourcePath string `json:"source_path,omitempty"`
ContentHash string `json:"content_hash"`
SizeBytes uint64 `json:"size_bytes"`
Status string `json:"status"`
Error string `json:"error,omitempty"`
Description string `json:"description,omitempty"`
}
// SnapshotResponse is the on-wire representation of a Snapshot
// returned by the resync endpoint.
type SnapshotResponse struct {
Version uint64 `json:"version"`
SchemaVersion uint64 `json:"schema_version"`
AggregateHash string `json:"aggregate_hash"`
Resources []SnapshotResource `json:"resources"`
PayloadBytes uint64 `json:"payload_bytes"`
SnapshotError string `json:"snapshot_error,omitempty"`
}
// API exposes the Manager over HTTP. The routes match the RFC:
//
// GET /api/v0/context/sources
// POST /api/v0/context/sources { path }
// GET /api/v0/context/sources/{path}
// DELETE /api/v0/context/sources/{path}
// POST /api/v0/context/resync
//
// {path} is URL-encoded canonical path. Callers pass either the
// canonical or original path; the handler canonicalizes before
// matching.
type API struct {
manager *Manager
}
// NewAPI wraps the supplied Manager.
func NewAPI(m *Manager) *API {
return &API{manager: m}
}
// Routes returns the chi handler for /api/v0/context/*. Mount
// it at "/api/v0/context".
func (a *API) Routes() http.Handler {
r := chi.NewRouter()
r.Route("/sources", func(r chi.Router) {
r.Get("/", a.handleListSources)
r.Post("/", a.handleAddSource)
r.Get("/{path}", a.handleGetSource)
r.Delete("/{path}", a.handleRemoveSource)
})
r.Post("/resync", a.handleResync)
return r
}
func (a *API) handleListSources(rw http.ResponseWriter, r *http.Request) {
sources := a.manager.Sources()
out := make([]SourceResponse, 0, len(sources))
for _, s := range sources {
out = append(out, SourceResponse(s))
}
httpapi.Write(r.Context(), rw, http.StatusOK, out)
}
func (a *API) handleAddSource(rw http.ResponseWriter, r *http.Request) {
var req SourceRequest
if !httpapi.Read(r.Context(), rw, r, &req) {
return
}
s, err := a.manager.AddSource(Source(req))
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Could not add context source.",
Detail: err.Error(),
})
return
}
httpapi.Write(r.Context(), rw, http.StatusCreated, SourceResponse(s))
}
func (a *API) handleGetSource(rw http.ResponseWriter, r *http.Request) {
raw := chi.URLParam(r, "path")
decoded, err := url.PathUnescape(raw)
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid context source path.",
Detail: err.Error(),
})
return
}
canonical, ok := a.manager.HasSource(decoded)
if !ok {
httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{
Message: "Context source not found.",
Detail: "No source registered for path " + strconv.Quote(decoded) + ".",
})
return
}
httpapi.Write(r.Context(), rw, http.StatusOK, SourceResponse{Path: canonical})
}
func (a *API) handleRemoveSource(rw http.ResponseWriter, r *http.Request) {
raw := chi.URLParam(r, "path")
decoded, err := url.PathUnescape(raw)
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid context source path.",
Detail: err.Error(),
})
return
}
if err := a.manager.RemoveSource(decoded); err != nil {
if errors.Is(err, ErrSourceNotFound) {
httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{
Message: "Context source not found.",
Detail: err.Error(),
})
return
}
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Could not remove context source.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
func (a *API) handleResync(rw http.ResponseWriter, r *http.Request) {
snap, err := a.manager.Resync(r.Context())
if err != nil {
status := http.StatusInternalServerError
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
status = http.StatusGatewayTimeout
}
httpapi.Write(r.Context(), rw, status, codersdk.Response{
Message: "Resync failed.",
Detail: err.Error(),
})
return
}
httpapi.Write(r.Context(), rw, http.StatusOK, snapshotResponse(snap))
}
// snapshotResponse converts a Snapshot to its on-wire form for
// the resync endpoint. Payloads are omitted; the per-resource
// payload bytes ship via the drpc PushContextState path.
func snapshotResponse(s Snapshot) SnapshotResponse {
out := SnapshotResponse{
Version: s.Version,
SchemaVersion: s.SchemaVersion,
AggregateHash: hex.EncodeToString(s.AggregateHash[:]),
Resources: make([]SnapshotResource, 0, len(s.Resources)),
PayloadBytes: s.PayloadBytes,
SnapshotError: s.SnapshotError,
}
for _, r := range s.Resources {
out.Resources = append(out.Resources, SnapshotResource{
ID: r.ID,
Kind: r.Kind.String(),
Source: r.Source,
SourcePath: r.SourcePath,
ContentHash: hex.EncodeToString(r.ContentHash[:]),
SizeBytes: r.SizeBytes,
Status: r.Status.String(),
Error: r.Error,
Description: r.Description,
})
}
return out
}
+177
View File
@@ -0,0 +1,177 @@
package agentcontext_test
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentcontext"
"github.com/coder/coder/v2/testutil"
)
func newAPITestServer(t *testing.T, opts agentcontext.ManagerOptions) (*httptest.Server, *agentcontext.Manager) {
t.Helper()
m := newTestManager(t, opts)
api := agentcontext.NewAPI(m)
srv := httptest.NewServer(api.Routes())
t.Cleanup(srv.Close)
return srv, m
}
// doRequest issues an HTTP request bounded by testutil.WaitShort
// and returns the status code and response body. The response
// body is closed before doRequest returns.
func doRequest(t *testing.T, method, requrl string, body io.Reader) (int, []byte) {
t.Helper()
ctx := testutil.Context(t, testutil.WaitShort)
req, err := http.NewRequestWithContext(ctx, method, requrl, body)
require.NoError(t, err)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
res, err := http.DefaultClient.Do(req) //nolint:bodyclose // closed below.
require.NoError(t, err)
defer res.Body.Close()
bodyBytes, err := io.ReadAll(res.Body)
require.NoError(t, err)
return res.StatusCode, bodyBytes
}
func TestAPI_ListSourcesEmpty(t *testing.T) {
t.Parallel()
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return t.TempDir() },
})
status, body := doRequest(t, http.MethodGet, srv.URL+"/sources", nil)
require.Equal(t, http.StatusOK, status)
var got []agentcontext.SourceResponse
require.NoError(t, json.Unmarshal(body, &got))
require.Empty(t, got)
}
func TestAPI_AddAndListSource(t *testing.T) {
t.Parallel()
wd := t.TempDir()
src := t.TempDir()
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return wd },
AllowedRoots: []string{wd, src},
})
body, _ := json.Marshal(agentcontext.SourceRequest{Path: src})
status, addBody := doRequest(t, http.MethodPost, srv.URL+"/sources", bytes.NewReader(body))
require.Equal(t, http.StatusCreated, status)
var created agentcontext.SourceResponse
require.NoError(t, json.Unmarshal(addBody, &created))
require.Equal(t, src, created.Path)
// List should show the new source.
listStatus, listBody := doRequest(t, http.MethodGet, srv.URL+"/sources", nil)
require.Equal(t, http.StatusOK, listStatus)
var list []agentcontext.SourceResponse
require.NoError(t, json.Unmarshal(listBody, &list))
require.Len(t, list, 1)
require.Equal(t, src, list[0].Path)
}
func TestAPI_AddSourceRejected(t *testing.T) {
t.Parallel()
wd := t.TempDir()
outside := t.TempDir()
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return wd },
AllowedRoots: []string{wd},
})
body, _ := json.Marshal(agentcontext.SourceRequest{Path: outside})
status, _ := doRequest(t, http.MethodPost, srv.URL+"/sources", bytes.NewReader(body))
require.Equal(t, http.StatusBadRequest, status)
}
func TestAPI_GetAndDeleteSource(t *testing.T) {
t.Parallel()
wd := t.TempDir()
src := t.TempDir()
srv, m := newAPITestServer(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return wd },
AllowedRoots: []string{wd, src},
})
_, err := m.AddSource(agentcontext.Source{Path: src})
require.NoError(t, err)
status, body := doRequest(t, http.MethodGet, srv.URL+"/sources/"+url.PathEscape(src), nil)
require.Equal(t, http.StatusOK, status)
var got agentcontext.SourceResponse
require.NoError(t, json.Unmarshal(body, &got))
require.Equal(t, src, got.Path)
delStatus, _ := doRequest(t, http.MethodDelete, srv.URL+"/sources/"+url.PathEscape(src), nil)
require.Equal(t, http.StatusNoContent, delStatus)
require.Empty(t, m.Sources())
}
func TestAPI_GetSourceNotFound(t *testing.T) {
t.Parallel()
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return t.TempDir() },
})
status, _ := doRequest(t, http.MethodGet, srv.URL+"/sources/"+url.PathEscape("/never-added"), nil)
require.Equal(t, http.StatusNotFound, status)
}
func TestAPI_DeleteSourceNotFound(t *testing.T) {
t.Parallel()
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return t.TempDir() },
})
status, _ := doRequest(t, http.MethodDelete, srv.URL+"/sources/"+url.PathEscape("/never-added"), nil)
require.Equal(t, http.StatusNotFound, status)
}
func TestAPI_Resync(t *testing.T) {
t.Parallel()
wd := t.TempDir()
mustWriteFile(t, filepath.Join(wd, "AGENTS.md"), "hello")
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return wd },
})
status, body := doRequest(t, http.MethodPost, srv.URL+"/resync", nil)
require.Equal(t, http.StatusOK, status)
var snap agentcontext.SnapshotResponse
require.NoError(t, json.Unmarshal(body, &snap))
require.Equal(t, uint64(1), snap.SchemaVersion)
require.NotEmpty(t, snap.AggregateHash)
require.Len(t, snap.Resources, 1)
require.Equal(t, "instruction_file", snap.Resources[0].Kind)
require.Equal(t, "ok", snap.Resources[0].Status)
}
func TestAPI_AddSourceMalformedBody(t *testing.T) {
t.Parallel()
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return t.TempDir() },
})
status, _ := doRequest(t, http.MethodPost, srv.URL+"/sources", bytes.NewReader([]byte("{not json")))
require.Equal(t, http.StatusBadRequest, status)
}
+25
View File
@@ -0,0 +1,25 @@
// Package agentcontext consolidates the agent-side plumbing that
// resolves, watches, and pushes workspace context (instruction
// files, skills, and MCP configuration) to coderd.
//
// This is the agent half of the design described in
// "RFC: Workspace Context Sources for Coder Agents". It owns:
//
// - User-declared scan roots (Sources) layered on top of
// built-in defaults.
// - A resolver that classifies files under each scan root into
// typed Resources (instruction files, skills, MCP configs,
// MCP servers).
// - A unified recursive fsnotify watcher that signals a
// re-resolve when any recognized file changes.
// - An HTTP API at /api/v0/context/sources for source CRUD
// and /api/v0/context/resync for synchronous push barriers.
// - A Pusher abstraction so the latest Snapshot can be shipped
// to coderd without coupling this package to any particular
// drpc client version.
//
// The package is purely additive: existing agent code paths
// (agent/agentcontextconfig and agent/x/agentmcp) continue to
// operate unchanged. Wiring the Manager into the agent's HTTP
// router and the drpc client lives in a follow-up change.
package agentcontext
+480
View File
@@ -0,0 +1,480 @@
package agentcontext
import (
"context"
"sort"
"strings"
"sync"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/quartz"
)
// CurrentSchemaVersion is the on-wire shape version. Bump
// whenever the resource format changes in a way that requires
// coderd-side awareness.
const CurrentSchemaVersion uint64 = 1
// ManagerOptions configures a Manager. Zero values get sensible
// defaults.
type ManagerOptions struct {
// Logger receives diagnostic messages. Required.
Logger slog.Logger
// Clock is the time source used for the watcher's
// debounce timer. Optional; defaults to quartz.NewReal().
Clock quartz.Clock
// WorkingDir is evaluated on every resolve, mirroring the
// existing agent convention. The result is used as a
// scan root.
WorkingDir func() string
// BuiltinRoots are scan roots layered before user-added
// sources. Typically: the working directory, ~/.coder,
// ~/.coder/skills, .agents/skills,
// ~/.claude/plugins/cache.
BuiltinRoots []string
// InitialSources seeds the Manager's source list at boot
// time. Sources from CODER_AGENT_EXP_*_DIRS env vars or
// startup scripts are layered here.
InitialSources []Source
// AllowedRoots restricts which paths may be added as
// sources at runtime. Defaults to [~, ~/.coder, ~/.claude,
// workingDir]. Empty disables validation.
AllowedRoots []string
// Resolver, when non-nil, replaces the default resolver.
// Tests use this to inject MCP providers and tighten
// caps.
Resolver *Resolver
// Debounce overrides the watcher's debounce window.
Debounce time.Duration
// SchemaVersion is the version stamped on each Snapshot.
// Use CurrentSchemaVersion (the default) unless rolling
// out a schema change.
SchemaVersion uint64
}
// Manager orchestrates source CRUD, resolution, watching, and
// Pusher fan-out. Construct with NewManager; start its lifecycle
// goroutines with Run; tear down with Close.
type Manager struct {
logger slog.Logger
clock quartz.Clock
workingDir func() string
builtinRoots []string
allowedRoots []string
resolver *Resolver
debounce time.Duration
schemaVersion uint64
mu sync.Mutex
sources []Source
// sourceIndex maps canonical path -> position in sources
// for O(1) lookups during AddSource / RemoveSource.
sourceIndex map[string]int
// snapshot is the latest result of a resolver pass. It is
// replaced atomically under mu.
snapshot Snapshot
// version monotonically increases per resolve pass.
version uint64
// subscribers receive a non-blocking signal whenever the
// snapshot changes. Subscribers must drain their channel
// promptly; the Manager drops sends to full channels.
subscribers map[chan struct{}]struct{}
// trigger fires when AddSource / RemoveSource / watcher
// observe a change.
trigger chan struct{}
// running tracks Run lifetime.
running bool
closed bool
closedCh chan struct{}
runDoneCh chan struct{}
watcher *Watcher
}
// NewManager validates options, canonicalizes initial sources,
// performs the first resolver pass synchronously, and returns
// the resulting Manager. Run must be called separately to start
// the watcher and re-resolve goroutine.
func NewManager(opts ManagerOptions) (*Manager, error) {
clock := opts.Clock
if clock == nil {
clock = quartz.NewReal()
}
debounce := opts.Debounce
if debounce <= 0 {
debounce = DefaultWatchDebounce
}
schemaVersion := opts.SchemaVersion
if schemaVersion == 0 {
schemaVersion = CurrentSchemaVersion
}
resolver := opts.Resolver
if resolver == nil {
resolver = &Resolver{}
}
m := &Manager{
logger: opts.Logger,
clock: clock,
workingDir: opts.WorkingDir,
builtinRoots: append([]string(nil), opts.BuiltinRoots...),
allowedRoots: append([]string(nil), opts.AllowedRoots...),
resolver: resolver,
debounce: debounce,
schemaVersion: schemaVersion,
sources: make([]Source, 0),
sourceIndex: make(map[string]int),
subscribers: make(map[chan struct{}]struct{}),
trigger: make(chan struct{}, 1),
closedCh: make(chan struct{}),
runDoneCh: make(chan struct{}),
}
for _, s := range opts.InitialSources {
canonical, err := CanonicalizePath(s.Path)
if err != nil {
// Initial sources may not exist yet at boot
// time; log and skip rather than abort the
// agent.
m.logger.Warn(context.Background(),
"agentcontext: skipping invalid initial source",
slog.F("path", s.Path),
slog.Error(err))
continue
}
if _, ok := m.sourceIndex[canonical]; ok {
continue
}
m.sourceIndex[canonical] = len(m.sources)
m.sources = append(m.sources, Source{Path: canonical})
}
// First snapshot is computed eagerly. The push protocol
// requires a snapshot to be present before the agent signals
// lifecycle = ready, so callers can rely on Snapshot() being
// populated immediately after NewManager returns.
m.resolveLocked()
return m, nil
}
// Run starts the watcher and the re-resolve goroutine. Run
// blocks until ctx is canceled or Close is called. It is safe
// to call Run at most once per Manager.
func (m *Manager) Run(ctx context.Context) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
return xerrors.New("agentcontext: Manager.Run called more than once")
}
if m.closed {
m.mu.Unlock()
return xerrors.New("agentcontext: Manager already closed")
}
m.running = true
m.mu.Unlock()
watcher, err := NewWatcher(WatcherOptions{
Logger: m.logger.Named("watcher"),
Clock: m.clock,
Debounce: m.debounce,
OnChange: m.signal,
})
if err != nil {
// NewWatcher already falls back to degraded mode on
// init failure, so an actual error here is
// exceptional.
return xerrors.Errorf("create watcher: %w", err)
}
m.mu.Lock()
m.watcher = watcher
roots := m.scanRootsLocked()
m.mu.Unlock()
watcher.Sync(ctx, roots)
defer close(m.runDoneCh)
defer watcher.Close()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-m.closedCh:
return nil
case <-m.trigger:
m.mu.Lock()
roots := m.scanRootsLocked()
m.mu.Unlock()
watcher.Sync(ctx, roots)
m.resolveAndBroadcast(ctx)
}
}
}
// Close stops the Manager. Close is idempotent; subsequent
// calls block until Run exits.
func (m *Manager) Close() error {
m.mu.Lock()
if m.closed {
running := m.running
m.mu.Unlock()
if running {
<-m.runDoneCh
}
return nil
}
m.closed = true
running := m.running
close(m.closedCh)
m.mu.Unlock()
if running {
<-m.runDoneCh
}
return nil
}
// Sources returns a defensive copy of the current source list.
// The returned slice is safe to mutate.
func (m *Manager) Sources() []Source {
m.mu.Lock()
defer m.mu.Unlock()
out := make([]Source, len(m.sources))
copy(out, m.sources)
return out
}
// HasSource reports whether path matches an existing source
// after canonicalization. Returns the canonical path on
// success.
func (m *Manager) HasSource(path string) (canonical string, ok bool) {
c, err := CanonicalizePath(path)
if err != nil {
return "", false
}
m.mu.Lock()
defer m.mu.Unlock()
_, ok = m.sourceIndex[c]
return c, ok
}
// AddSource adds a new source. The path is canonicalized and
// validated against the AllowedRoots set. AddSource is
// idempotent.
func (m *Manager) AddSource(s Source) (Source, error) {
canonical, err := CanonicalizePath(s.Path)
if err != nil {
return Source{}, xerrors.Errorf("canonicalize: %w", err)
}
if err := ValidateSourcePath(canonical, m.effectiveAllowedRoots()); err != nil {
return Source{}, err
}
m.mu.Lock()
if _, ok := m.sourceIndex[canonical]; ok {
out := m.sources[m.sourceIndex[canonical]]
m.mu.Unlock()
return out, nil
}
m.sourceIndex[canonical] = len(m.sources)
m.sources = append(m.sources, Source{Path: canonical})
m.mu.Unlock()
m.signal()
return Source{Path: canonical}, nil
}
// RemoveSource removes the source matching path. Path is
// canonicalized before matching. Returns ErrSourceNotFound when
// no such source exists.
func (m *Manager) RemoveSource(path string) error {
canonical, err := CanonicalizePath(path)
if err != nil {
return xerrors.Errorf("canonicalize: %w", err)
}
m.mu.Lock()
idx, ok := m.sourceIndex[canonical]
if !ok {
m.mu.Unlock()
return ErrSourceNotFound
}
// O(n) compaction is fine for the typical handful of
// user-added sources.
m.sources = append(m.sources[:idx], m.sources[idx+1:]...)
delete(m.sourceIndex, canonical)
for i := idx; i < len(m.sources); i++ {
m.sourceIndex[m.sources[i].Path] = i
}
m.mu.Unlock()
m.signal()
return nil
}
// Snapshot returns the latest Snapshot. The returned value is
// safe to share but shares the same Resources slice as the
// internal state; callers must not mutate it.
func (m *Manager) Snapshot() Snapshot {
m.mu.Lock()
defer m.mu.Unlock()
return m.snapshot
}
// SubscribeChanges returns a buffered channel that receives a
// signal whenever the snapshot changes. The unsubscribe
// callback is safe to call from any goroutine and is
// idempotent.
func (m *Manager) SubscribeChanges() (<-chan struct{}, func()) {
ch := make(chan struct{}, 1)
m.mu.Lock()
m.subscribers[ch] = struct{}{}
m.mu.Unlock()
var once sync.Once
unsub := func() {
once.Do(func() {
m.mu.Lock()
delete(m.subscribers, ch)
m.mu.Unlock()
// Don't close ch: readers may still be in flight.
})
}
return ch, unsub
}
// Resync forces an immediate re-resolve and returns the new
// Snapshot. Resync is safe to call regardless of whether Run
// is active; the work is done synchronously under the
// Manager's mutex either way.
func (m *Manager) Resync(ctx context.Context) (Snapshot, error) {
if ctxErr := ctx.Err(); ctxErr != nil {
return m.Snapshot(), ctxErr
}
m.mu.Lock()
if m.closed {
m.mu.Unlock()
return m.Snapshot(), ErrManagerClosed
}
m.resolveLocked()
snap := m.snapshot
subs := make([]chan struct{}, 0, len(m.subscribers))
for ch := range m.subscribers {
subs = append(subs, ch)
}
m.mu.Unlock()
for _, ch := range subs {
select {
case ch <- struct{}{}:
default:
}
}
return snap, nil
}
// signal triggers a re-resolve. Sends are non-blocking; the
// trigger channel has a depth of 1, which coalesces bursts.
func (m *Manager) signal() {
select {
case m.trigger <- struct{}{}:
default:
}
}
// scanRootsLocked returns the list of ScanRoots to feed the
// resolver and watcher. The Manager's mutex must be held.
func (m *Manager) scanRootsLocked() []ScanRoot {
out := make([]ScanRoot, 0, 1+len(m.builtinRoots)+len(m.sources))
if m.workingDir != nil {
if wd := strings.TrimSpace(m.workingDir()); wd != "" {
out = append(out, ScanRoot{Path: wd})
}
}
for _, r := range m.builtinRoots {
canonical, err := CanonicalizePath(r)
if err != nil {
continue
}
out = append(out, ScanRoot{Path: canonical})
}
for _, s := range m.sources {
out = append(out, ScanRoot{Path: s.Path, UserSource: s.Path})
}
return out
}
// effectiveAllowedRoots returns the AllowedRoots augmented with
// a sensible fallback (~ and the working directory) when the
// caller did not configure any.
func (m *Manager) effectiveAllowedRoots() []string {
if len(m.allowedRoots) > 0 {
return append([]string{}, m.allowedRoots...)
}
roots := []string{"~"}
if m.workingDir != nil {
if wd := strings.TrimSpace(m.workingDir()); wd != "" {
roots = append(roots, wd)
}
}
return roots
}
// resolveAndBroadcast computes a fresh snapshot and notifies
// subscribers if the aggregate hash changed.
func (m *Manager) resolveAndBroadcast(ctx context.Context) {
m.mu.Lock()
m.resolveLocked()
subs := make([]chan struct{}, 0, len(m.subscribers))
for ch := range m.subscribers {
subs = append(subs, ch)
}
m.mu.Unlock()
// The broadcast is unconditional: Resync waiters that
// triggered the pass without an actual content change
// still need to wake up. Subscribers compare snapshots via
// AggregateHash if they want to filter.
for _, ch := range subs {
select {
case ch <- struct{}{}:
default:
}
}
_ = ctx
}
// resolveLocked runs the resolver and stamps the snapshot.
// The Manager's mutex must be held.
func (m *Manager) resolveLocked() {
roots := m.scanRootsLocked()
snap := m.resolver.Resolve(roots)
m.version++
snap.Version = m.version
snap.SchemaVersion = m.schemaVersion
// Surface watcher degradation as a snapshot-level error
// when the resolver did not already emit one.
if snap.SnapshotError == "" && m.watcher != nil {
if d := m.watcher.Degraded(); d != "" {
snap.SnapshotError = d
}
}
// Ensure resources are stable-sorted by ID even when the
// resolver did not run them through caps.
sort.Slice(snap.Resources, func(i, j int) bool {
return snap.Resources[i].ID < snap.Resources[j].ID
})
m.snapshot = snap
}
// ErrSourceNotFound is returned by RemoveSource when the
// requested path is not in the source list.
var ErrSourceNotFound = xerrors.New("source not found")
// ErrManagerClosed is returned by methods called after Close.
var ErrManagerClosed = xerrors.New("agentcontext: manager closed")
+260
View File
@@ -0,0 +1,260 @@
package agentcontext_test
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentcontext"
"github.com/coder/coder/v2/testutil"
)
func newTestManager(t *testing.T, opts agentcontext.ManagerOptions) *agentcontext.Manager {
t.Helper()
opts.Logger = testutil.Logger(t).Named("agentcontext-test")
m, err := agentcontext.NewManager(opts)
require.NoError(t, err)
t.Cleanup(func() { _ = m.Close() })
return m
}
func TestManager_InitialSnapshotIsPopulated(t *testing.T) {
t.Parallel()
dir := t.TempDir()
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "boot snapshot")
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return dir },
})
snap := m.Snapshot()
require.Equal(t, uint64(1), snap.Version)
require.Equal(t, agentcontext.CurrentSchemaVersion, snap.SchemaVersion)
require.Len(t, snap.Resources, 1)
}
func TestManager_AddSourceTriggersResolve(t *testing.T) {
t.Parallel()
wd := t.TempDir()
src := t.TempDir()
mustWriteFile(t, filepath.Join(src, "AGENTS.md"), "from source")
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return wd },
AllowedRoots: []string{wd, src},
})
ctx := testutil.Context(t, testutil.WaitLong)
go func() { _ = m.Run(ctx) }()
t.Cleanup(func() { _ = m.Close() })
// Subscribe before mutating so we observe the broadcast.
ch, unsub := m.SubscribeChanges()
defer unsub()
added, err := m.AddSource(agentcontext.Source{Path: src})
require.NoError(t, err)
require.Equal(t, src, added.Path)
select {
case <-ch:
case <-time.After(testutil.WaitShort):
t.Fatalf("expected a change broadcast after AddSource")
}
snap := m.Snapshot()
require.Greater(t, snap.Version, uint64(1))
found := false
for _, r := range snap.Resources {
if r.Kind == agentcontext.KindInstructionFile && r.SourcePath == src {
found = true
}
}
require.True(t, found, "expected AGENTS.md attributed to the user source")
}
func TestManager_AddSourceRejectsOutsideAllowedRoots(t *testing.T) {
t.Parallel()
wd := t.TempDir()
outside := t.TempDir()
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return wd },
AllowedRoots: []string{wd},
})
_, err := m.AddSource(agentcontext.Source{Path: outside})
require.Error(t, err)
}
func TestManager_AddSourceIsIdempotent(t *testing.T) {
t.Parallel()
wd := t.TempDir()
src := t.TempDir()
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return wd },
AllowedRoots: []string{wd, src},
})
added1, err := m.AddSource(agentcontext.Source{Path: src})
require.NoError(t, err)
added2, err := m.AddSource(agentcontext.Source{Path: src})
require.NoError(t, err)
require.Equal(t, added1.Path, added2.Path)
sources := m.Sources()
require.Len(t, sources, 1)
}
func TestManager_RemoveSource(t *testing.T) {
t.Parallel()
wd := t.TempDir()
src := t.TempDir()
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return wd },
AllowedRoots: []string{wd, src},
})
_, err := m.AddSource(agentcontext.Source{Path: src})
require.NoError(t, err)
require.NoError(t, m.RemoveSource(src))
require.Empty(t, m.Sources())
err = m.RemoveSource(src)
require.ErrorIs(t, err, agentcontext.ErrSourceNotFound)
}
func TestManager_HasSource(t *testing.T) {
t.Parallel()
wd := t.TempDir()
src := t.TempDir()
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return wd },
AllowedRoots: []string{wd, src},
})
canonical, ok := m.HasSource(src)
require.False(t, ok)
require.Equal(t, src, canonical)
_, err := m.AddSource(agentcontext.Source{Path: src})
require.NoError(t, err)
canonical, ok = m.HasSource(src)
require.True(t, ok)
require.Equal(t, src, canonical)
}
func TestManager_ResyncReturnsLatestSnapshot(t *testing.T) {
t.Parallel()
wd := t.TempDir()
mustWriteFile(t, filepath.Join(wd, "AGENTS.md"), "first")
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return wd },
})
ctx := testutil.Context(t, testutil.WaitLong)
runDone := make(chan struct{})
go func() {
defer close(runDone)
_ = m.Run(ctx)
}()
t.Cleanup(func() {
_ = m.Close()
<-runDone
})
// Mutate AGENTS.md and call Resync. The returned
// snapshot must reflect the new content.
require.NoError(t, os.WriteFile(filepath.Join(wd, "AGENTS.md"), []byte("second content edit"), 0o600))
snap, err := m.Resync(ctx)
require.NoError(t, err)
require.Len(t, snap.Resources, 1)
require.Equal(t, "second content edit", string(snap.Resources[0].Payload))
}
func TestManager_InitialSourcesSeeded(t *testing.T) {
t.Parallel()
wd := t.TempDir()
src := t.TempDir()
mustWriteFile(t, filepath.Join(src, "AGENTS.md"), "from initial")
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return wd },
AllowedRoots: []string{wd, src},
InitialSources: []agentcontext.Source{{Path: src}},
})
sources := m.Sources()
require.Len(t, sources, 1)
require.Equal(t, src, sources[0].Path)
snap := m.Snapshot()
require.Len(t, snap.Resources, 1)
require.Equal(t, src, snap.Resources[0].SourcePath)
}
func TestManager_CloseIsIdempotent(t *testing.T) {
t.Parallel()
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return t.TempDir() },
})
require.NoError(t, m.Close())
require.NoError(t, m.Close())
}
func TestManager_RunOnce(t *testing.T) {
t.Parallel()
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return t.TempDir() },
})
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
defer cancel()
go func() { _ = m.Run(ctx) }()
// Brief wait so Run has a chance to set running=true.
time.Sleep(50 * time.Millisecond)
err := m.Run(ctx)
require.Error(t, err)
require.Contains(t, err.Error(), "more than once")
cancel()
_ = m.Close()
}
func TestManager_SubscribeBroadcastOnChange(t *testing.T) {
t.Parallel()
wd := t.TempDir()
src := t.TempDir()
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return wd },
AllowedRoots: []string{wd, src},
})
ctx := testutil.Context(t, testutil.WaitLong)
go func() { _ = m.Run(ctx) }()
ch, unsub := m.SubscribeChanges()
defer unsub()
_, err := m.AddSource(agentcontext.Source{Path: src})
require.NoError(t, err)
select {
case <-ch:
case <-time.After(testutil.WaitShort):
t.Fatal("expected subscriber to be notified")
}
}
+24
View File
@@ -0,0 +1,24 @@
package agentcontext
// MCPProvider supplies the live MCP server portion of a
// snapshot. Implementations typically wrap an existing MCP
// manager (e.g. agent/x/agentmcp.Manager) and translate each
// server's tool list into a KindMCPServer resource.
//
// The interface is intentionally minimal so the existing MCP
// lifecycle code can be reused without refactoring; a follow-up
// change absorbs the lifecycle into this package.
type MCPProvider interface {
// MCPResources returns one Resource per MCP server known
// to the provider. Each Resource must:
//
// - Have Kind == KindMCPServer.
// - Use the server name as Source.
// - Populate ContentHash over the canonical encoding
// of the tool list so changes flip the dirty bit.
// - Carry a Description summarizing the server.
//
// Implementations should never block; the resolver calls
// this on every re-resolve.
MCPResources() []Resource
}
+121
View File
@@ -0,0 +1,121 @@
package agentcontext
import (
"os"
"path/filepath"
"strings"
"golang.org/x/xerrors"
)
// CanonicalizePath produces the canonical form of a user-
// supplied path. The result is absolute, has ~ expanded, has
// path-traversal segments collapsed, and has symlinks resolved
// when the target exists. The path is left lexically clean if
// it does not yet exist (so adding a not-yet-created directory
// remains possible).
//
// CanonicalizePath returns the original input when it is empty.
func CanonicalizePath(raw string) (string, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", xerrors.New("path is empty")
}
// Expand ~ and ~/ prefixes against the current user's home
// directory. Other ~user forms are not supported on
// purpose; the agent runs as a known user.
if raw == "~" || strings.HasPrefix(raw, "~/") {
home, err := os.UserHomeDir()
if err != nil {
return "", xerrors.Errorf("expand home dir: %w", err)
}
if raw == "~" {
raw = home
} else {
raw = filepath.Join(home, raw[2:])
}
}
if !filepath.IsAbs(raw) {
// Fail closed: relative paths could mean different
// things depending on the agent's working directory at
// add-time, so require the caller to absolutize first.
return "", xerrors.Errorf("path %q is not absolute", raw)
}
cleaned := filepath.Clean(raw)
if resolved, err := filepath.EvalSymlinks(cleaned); err == nil {
return resolved, nil
}
return cleaned, nil
}
// ValidateSourcePath enforces the path-validation rules from
// the RFC's Authorization section. It rejects:
//
// - Paths containing ".." segments after expansion.
// - Paths resolving outside the supplied allowedRoots, unless
// allowedRoots is empty (which disables the check).
//
// allowedRoots are canonicalized lazily; missing roots are
// silently skipped so a workspace with no $HOME does not break
// validation for project-relative roots.
func ValidateSourcePath(canonical string, allowedRoots []string) error {
if canonical == "" {
return xerrors.New("path is empty")
}
// filepath.Clean drops "." but leaves ".." when no parent
// is available. Reject defensively.
for _, part := range strings.Split(canonical, string(os.PathSeparator)) {
if part == ".." {
return xerrors.Errorf("path %q contains parent traversal segments", canonical)
}
}
if len(allowedRoots) == 0 {
return nil
}
// Build canonical, deduplicated allowed roots. Missing
// roots (e.g. an unconfigured ~/.claude/) are skipped.
roots := make([]string, 0, len(allowedRoots))
seen := make(map[string]struct{}, len(allowedRoots))
for _, raw := range allowedRoots {
c, err := CanonicalizePath(raw)
if err != nil {
continue
}
if _, ok := seen[c]; ok {
continue
}
seen[c] = struct{}{}
roots = append(roots, c)
}
if len(roots) == 0 {
// All configured roots were invalid; treat as "deny
// everything" so misconfiguration fails closed.
return xerrors.Errorf("path %q is not inside any allowed root", canonical)
}
for _, root := range roots {
if pathHasPrefix(canonical, root) {
return nil
}
}
return xerrors.Errorf("path %q is not inside any allowed root", canonical)
}
// pathHasPrefix reports whether path is equal to or a
// descendant of prefix. Both arguments must already be clean,
// absolute paths.
func pathHasPrefix(path, prefix string) bool {
if path == prefix {
return true
}
withSep := prefix
if !strings.HasSuffix(withSep, string(os.PathSeparator)) {
withSep += string(os.PathSeparator)
}
return strings.HasPrefix(path, withSep)
}
+112
View File
@@ -0,0 +1,112 @@
package agentcontext_test
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentcontext"
)
func TestCanonicalizePath_AbsoluteCleansAndResolves(t *testing.T) {
t.Parallel()
dir := t.TempDir()
got, err := agentcontext.CanonicalizePath(filepath.Join(dir, "a", "..", "b"))
require.NoError(t, err)
// Path does not exist; EvalSymlinks fails. Result is
// lexically cleaned: filepath.Clean drops the "..".
require.Equal(t, filepath.Join(dir, "b"), got)
}
func TestCanonicalizePath_RelativeRejected(t *testing.T) {
t.Parallel()
_, err := agentcontext.CanonicalizePath("relative/path")
require.Error(t, err)
}
//nolint:paralleltest,tparallel // Uses t.Setenv.
func TestCanonicalizePath_TildeExpansion(t *testing.T) {
t.Setenv("HOME", "/tmp/home")
got, err := agentcontext.CanonicalizePath("~/.coder")
require.NoError(t, err)
require.Equal(t, "/tmp/home/.coder", got)
}
//nolint:paralleltest,tparallel // Uses t.Setenv.
func TestCanonicalizePath_BareTildeExpandsToHome(t *testing.T) {
t.Setenv("HOME", "/tmp/home")
got, err := agentcontext.CanonicalizePath("~")
require.NoError(t, err)
require.Equal(t, "/tmp/home", got)
}
func TestCanonicalizePath_FollowsSymlinks(t *testing.T) {
t.Parallel()
dir := t.TempDir()
realDir := filepath.Join(dir, "real")
link := filepath.Join(dir, "link")
require.NoError(t, os.MkdirAll(realDir, 0o755))
require.NoError(t, os.Symlink(realDir, link))
got, err := agentcontext.CanonicalizePath(link)
require.NoError(t, err)
// On macOS the temp dir is itself symlinked; both realDir and got
// pass through the same EvalSymlinks so they line up.
want, err := filepath.EvalSymlinks(realDir)
require.NoError(t, err)
require.Equal(t, want, got)
}
func TestValidateSourcePath_RejectsParentSegments(t *testing.T) {
t.Parallel()
err := agentcontext.ValidateSourcePath("/a/../b", []string{"/a"})
require.Error(t, err)
require.Contains(t, err.Error(), "parent traversal")
}
func TestValidateSourcePath_AllowsInsideRoot(t *testing.T) {
t.Parallel()
dir := t.TempDir()
child := filepath.Join(dir, "child")
require.NoError(t, os.MkdirAll(child, 0o755))
require.NoError(t, agentcontext.ValidateSourcePath(child, []string{dir}))
require.NoError(t, agentcontext.ValidateSourcePath(dir, []string{dir}))
}
func TestValidateSourcePath_RejectsOutsideRoot(t *testing.T) {
t.Parallel()
root := t.TempDir()
other := t.TempDir()
err := agentcontext.ValidateSourcePath(other, []string{root})
require.Error(t, err)
require.Contains(t, err.Error(), "not inside any allowed root")
}
func TestValidateSourcePath_EmptyAllowedRootsBypass(t *testing.T) {
t.Parallel()
require.NoError(t, agentcontext.ValidateSourcePath("/anywhere", nil))
}
func TestValidateSourcePath_InvalidRootsFailClosed(t *testing.T) {
t.Parallel()
// All allowed roots are relative and therefore invalid;
// validation must fail closed.
err := agentcontext.ValidateSourcePath("/anywhere", []string{"relative-only"})
require.Error(t, err)
}
func TestValidateSourcePath_PathPrefixIsPathAware(t *testing.T) {
t.Parallel()
// "/a-prefix" is not inside "/a", even though it starts
// with the same bytes.
dir := t.TempDir()
sibling := strings.TrimRight(dir, string(os.PathSeparator)) + "-sibling"
require.NoError(t, os.MkdirAll(sibling, 0o755))
t.Cleanup(func() { _ = os.RemoveAll(sibling) })
err := agentcontext.ValidateSourcePath(sibling, []string{dir})
require.Error(t, err)
}
+182
View File
@@ -0,0 +1,182 @@
package agentcontext
import (
"context"
"errors"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
)
// PushRequest is the wire-format-independent payload the
// Manager hands to a Pusher. It mirrors the protobuf
// PushContextStateRequest message reserved in the RFC.
//
// Keeping the shape in plain Go lets this package compile
// without bumping the drpc proto version. The follow-up
// integration change can add a thin adapter that converts
// PushRequest to proto and back.
type PushRequest struct {
Version uint64
AggregateHash [32]byte
Resources []Resource
Initial bool
SchemaVersion uint64
SnapshotError string
}
// PushResponse is the wire-format-independent return value of
// a push.
type PushResponse struct {
Accepted bool
}
// Pusher delivers snapshots to coderd. Concrete implementations
// wrap a drpc client (proto v30 and later) or, in tests, a
// recording in-memory fake.
//
// PushContextState must respect ctx cancellation; the Manager
// retries on transient errors with backoff but stops on
// ErrPushUnimplemented.
type Pusher interface {
PushContextState(ctx context.Context, req *PushRequest) (*PushResponse, error)
}
// ErrPushUnimplemented signals that the coderd peer does not
// implement PushContextState. RunPush stops pushing for the
// remainder of the connection.
var ErrPushUnimplemented = xerrors.New("agentcontext: PushContextState unimplemented")
// PushOptions parameterizes RunPush.
type PushOptions struct {
// Logger receives push success/failure diagnostics.
Logger slog.Logger
// InitialBackoff is the wait before the first retry.
// Default 250ms.
InitialBackoff time.Duration
// MaxBackoff caps the retry wait. Default 30s.
MaxBackoff time.Duration
}
// RunPush ships the current snapshot to the Pusher, then ships
// every subsequent snapshot whenever the Manager broadcasts a
// change. RunPush returns when ctx is canceled, when the
// Manager is closed, or when the Pusher signals
// ErrPushUnimplemented.
//
// The first push is always sent with Initial=true so coderd can
// distinguish a fresh boot from a drift event.
func (m *Manager) RunPush(ctx context.Context, p Pusher, opts PushOptions) error {
if p == nil {
return xerrors.New("agentcontext: Pusher is required")
}
logger := opts.Logger
initialBackoff := opts.InitialBackoff
if initialBackoff <= 0 {
initialBackoff = 250 * time.Millisecond
}
maxBackoff := opts.MaxBackoff
if maxBackoff <= 0 {
maxBackoff = 30 * time.Second
}
changes, unsub := m.SubscribeChanges()
defer unsub()
// First push uses the snapshot computed by NewManager.
initial := true
for {
snap := m.Snapshot()
req := snapshotToPushRequest(snap, initial)
err := pushWithRetry(ctx, p, req, initialBackoff, maxBackoff, logger)
switch {
case err == nil:
initial = false
case errors.Is(err, ErrPushUnimplemented):
logger.Warn(ctx, "agentcontext: coderd peer does not implement PushContextState; stopping")
return nil
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
return ctx.Err()
default:
// Should be unreachable: pushWithRetry only
// returns terminal errors. Log and continue.
logger.Warn(ctx, "agentcontext: push terminated with non-retried error", slog.Error(err))
}
select {
case <-ctx.Done():
return ctx.Err()
case <-m.closedCh:
return nil
case _, ok := <-changes:
if !ok {
return nil
}
}
}
}
// pushWithRetry retries transient errors with exponential
// backoff capped at maxBackoff. The retry loop exits when:
//
// - ctx is canceled (returns ctx.Err()).
// - The Pusher returns nil (success).
// - The Pusher returns ErrPushUnimplemented (propagated).
func pushWithRetry(
ctx context.Context,
p Pusher,
req *PushRequest,
initialBackoff, maxBackoff time.Duration,
logger slog.Logger,
) error {
backoff := initialBackoff
for {
resp, err := p.PushContextState(ctx, req)
if err == nil {
if resp != nil && !resp.Accepted {
// Out-of-order or replayed push. Do not
// retry; the next change will redeliver
// the snapshot with a higher version.
logger.Debug(ctx, "agentcontext: push rejected, awaiting next change",
slog.F("version", req.Version))
}
return nil
}
if errors.Is(err, ErrPushUnimplemented) {
return err
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return err
}
logger.Warn(ctx, "agentcontext: push failed, retrying",
slog.F("version", req.Version),
slog.F("backoff", backoff),
slog.Error(err))
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(backoff):
}
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
}
}
// snapshotToPushRequest copies the Snapshot into the wire
// representation. The Resources slice is reused; callers must
// not mutate it.
func snapshotToPushRequest(s Snapshot, initial bool) *PushRequest {
return &PushRequest{
Version: s.Version,
AggregateHash: s.AggregateHash,
Resources: s.Resources,
Initial: initial,
SchemaVersion: s.SchemaVersion,
SnapshotError: s.SnapshotError,
}
}
+197
View File
@@ -0,0 +1,197 @@
package agentcontext_test
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/agent/agentcontext"
"github.com/coder/coder/v2/testutil"
)
// fakePusher records every push and lets the test control the
// returned response and error.
type fakePusher struct {
mu sync.Mutex
requests []*agentcontext.PushRequest
resp *agentcontext.PushResponse
err error
// errOnce is non-nil to simulate a single transient
// failure followed by success.
errOnce error
signal chan struct{}
}
func newFakePusher() *fakePusher {
return &fakePusher{
resp: &agentcontext.PushResponse{Accepted: true},
signal: make(chan struct{}, 16),
}
}
func (p *fakePusher) PushContextState(_ context.Context, req *agentcontext.PushRequest) (*agentcontext.PushResponse, error) {
p.mu.Lock()
defer p.mu.Unlock()
p.requests = append(p.requests, req)
if p.errOnce != nil {
err := p.errOnce
p.errOnce = nil
return nil, err
}
select {
case p.signal <- struct{}{}:
default:
}
return p.resp, p.err
}
func (p *fakePusher) snapshot() []*agentcontext.PushRequest {
p.mu.Lock()
defer p.mu.Unlock()
out := make([]*agentcontext.PushRequest, len(p.requests))
copy(out, p.requests)
return out
}
func TestRunPush_FirstPushIsInitial(t *testing.T) {
t.Parallel()
dir := t.TempDir()
require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v1"), 0o600))
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return dir },
})
p := newFakePusher()
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
defer cancel()
pushDone := make(chan error, 1)
go func() {
pushDone <- m.RunPush(ctx, p, agentcontext.PushOptions{
Logger: testutil.Logger(t).Named("push"),
})
}()
// Wait for the first push.
select {
case <-p.signal:
case <-time.After(testutil.WaitShort):
t.Fatalf("expected initial push")
}
requests := p.snapshot()
require.Len(t, requests, 1)
require.True(t, requests[0].Initial, "first push must be initial")
require.Equal(t, uint64(1), requests[0].Version)
cancel()
require.ErrorIs(t, <-pushDone, context.Canceled)
}
func TestRunPush_SubsequentPushOnChange(t *testing.T) {
t.Parallel()
dir := t.TempDir()
require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v1"), 0o600))
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return dir },
})
p := newFakePusher()
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
defer cancel()
pushDone := make(chan error, 1)
go func() {
pushDone <- m.RunPush(ctx, p, agentcontext.PushOptions{
Logger: testutil.Logger(t).Named("push"),
})
}()
// Initial push.
<-p.signal
// Trigger a resync via Resync.
require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v2"), 0o600))
_, err := m.Resync(ctx)
require.NoError(t, err)
// Second push.
select {
case <-p.signal:
case <-time.After(testutil.WaitShort):
t.Fatalf("expected second push after resync")
}
requests := p.snapshot()
require.GreaterOrEqual(t, len(requests), 2)
require.False(t, requests[1].Initial, "subsequent pushes must not be Initial")
cancel()
require.ErrorIs(t, <-pushDone, context.Canceled)
}
func TestRunPush_StopsOnUnimplemented(t *testing.T) {
t.Parallel()
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return t.TempDir() },
})
p := newFakePusher()
p.err = agentcontext.ErrPushUnimplemented
ctx := testutil.Context(t, testutil.WaitShort)
err := m.RunPush(ctx, p, agentcontext.PushOptions{
Logger: testutil.Logger(t).Named("push"),
})
require.NoError(t, err, "Unimplemented must stop the loop cleanly")
}
func TestRunPush_RetriesTransientError(t *testing.T) {
t.Parallel()
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return t.TempDir() },
})
p := newFakePusher()
p.errOnce = xerrors.New("transient")
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
defer cancel()
pushDone := make(chan error, 1)
go func() {
pushDone <- m.RunPush(ctx, p, agentcontext.PushOptions{
Logger: testutil.Logger(t).Named("push"),
InitialBackoff: 10 * time.Millisecond,
})
}()
// First push hits transient, second succeeds.
select {
case <-p.signal:
case <-time.After(testutil.WaitShort):
t.Fatalf("expected push after transient error")
}
require.GreaterOrEqual(t, len(p.snapshot()), 2)
cancel()
<-pushDone
}
func TestRunPush_NilPusherErrors(t *testing.T) {
t.Parallel()
m := newTestManager(t, agentcontext.ManagerOptions{
WorkingDir: func() string { return t.TempDir() },
})
err := m.RunPush(context.Background(), nil, agentcontext.PushOptions{
Logger: testutil.Logger(t).Named("push"),
})
require.Error(t, err)
}
+613
View File
@@ -0,0 +1,613 @@
package agentcontext
import (
"crypto/sha256"
"errors"
"fmt"
"io"
"io/fs"
"math"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// Default caps. Copied from the RFC. The Manager exposes
// overrides via Options.
const (
// DefaultMaxResourceBytes is the per-resource payload cap.
// Resources whose payload exceeds this size are emitted
// with Status == StatusOversize and an empty Payload.
DefaultMaxResourceBytes = 64 * 1024
// DefaultMaxSnapshotBytes is the aggregate payload cap.
// Resources past this cap are emitted with Status ==
// StatusExcluded.
DefaultMaxSnapshotBytes = 2 * 1024 * 1024
// DefaultMaxResources is the resource count cap. Resources
// past this cap are emitted with Status == StatusExcluded.
DefaultMaxResources = 500
// DefaultMaxScanDepth bounds how deep the recursive walk
// descends from each scan root. The default avoids runaway
// scans in node_modules / vendor / .git trees while still
// covering realistic monorepo layouts.
DefaultMaxScanDepth = 8
)
// File-name conventions recognized by the v1 resolver.
var (
// instructionFileNames are picked up from any scan root.
// Matching is case-insensitive on the basename.
instructionFileNames = []string{
"AGENTS.md",
"CLAUDE.md",
".cursorrules",
}
// mcpConfigFileName is recognized at any depth under a
// scan root.
mcpConfigFileName = ".mcp.json"
// skillMetaFileName is the file inside a skill directory
// that carries the skill front-matter.
skillMetaFileName = "SKILL.md"
)
// skipDirNames are directory basenames that the recursive walk
// never descends into. The list mirrors what most language
// tool-chains treat as opaque.
var skipDirNames = map[string]struct{}{
".git": {},
".hg": {},
".svn": {},
"node_modules": {},
"vendor": {},
"target": {},
"dist": {},
"build": {},
".venv": {},
"__pycache__": {},
}
// skillsParentNames are directory basenames that signal a
// skills container; their immediate children are scanned for
// SKILL.md files.
var skillsParentNames = map[string]struct{}{
"skills": {},
".agents": {}, // covers ".agents/skills/<name>"
"agents": {},
"plugins": {}, // claude code plugin cache layout
"cache": {},
".coder": {},
".claude": {},
"skills-dir": {},
}
// recognizedInstructionFile reports whether name is one of the
// instruction-file conventions, case-insensitively.
func recognizedInstructionFile(name string) bool {
for _, candidate := range instructionFileNames {
if strings.EqualFold(name, candidate) {
return true
}
}
return false
}
// Resolver walks one or more scan roots and produces a snapshot
// of every recognized resource it finds. The Resolver is
// stateless; the Manager owns the scan-root list and orchestrates
// successive resolves.
type Resolver struct {
// MaxResourceBytes caps the per-resource payload size. Use
// DefaultMaxResourceBytes if zero.
MaxResourceBytes uint64
// MaxSnapshotBytes caps the aggregate payload size. Use
// DefaultMaxSnapshotBytes if zero.
MaxSnapshotBytes uint64
// MaxResources caps the resource count. Use
// DefaultMaxResources if zero.
MaxResources int
// MaxDepth caps the directory walk depth. Use
// DefaultMaxScanDepth if zero.
MaxDepth int
// MCP, when non-nil, is consulted after the filesystem
// pass and contributes any KindMCPServer resources for
// live MCP servers.
MCP MCPProvider
}
// ScanRoot describes a single directory or file the resolver
// should examine.
type ScanRoot struct {
// Path is the absolute path. Symlinks should already be
// resolved.
Path string
// UserSource is the canonical source path the user
// declared, when this root came from a user-added Source.
// Empty for built-in roots.
UserSource string
}
// Resolve walks the supplied scan roots and returns a Snapshot.
// The version and schemaVersion fields are stamped by the
// caller; Resolve fills everything else.
func (r *Resolver) Resolve(roots []ScanRoot) Snapshot {
res := r.normalize()
resources, snapErrs := res.walk(roots)
resources = res.applyCaps(resources)
// Append MCP server resources after the filesystem caps
// are applied so a runaway MCP server cannot crowd out
// instruction files.
if r.MCP != nil {
mcp := r.MCP.MCPResources()
resources = append(resources, mcp...)
// MCP resources may push the aggregate over the cap.
// Re-apply count and size limits to MCP entries only.
resources, snapErrs = res.applyMCPCaps(resources, snapErrs)
}
// Deterministic order by ID for stable IDs and hashes.
sort.Slice(resources, func(i, j int) bool {
return resources[i].ID < resources[j].ID
})
var payloadBytes uint64
for _, r := range resources {
payloadBytes += uint64(len(r.Payload))
}
hash := ComputeAggregateHash(resources)
snap := Snapshot{
Resources: resources,
AggregateHash: hash,
PayloadBytes: payloadBytes,
}
if len(snapErrs) > 0 {
// Pick the most severe single error. Today every
// snapshot-level problem is "warning equivalent" so
// the first one wins; the design reserves the field
// for a singular message.
snap.SnapshotError = snapErrs[0]
}
return snap
}
func (r *Resolver) normalize() *Resolver {
out := *r
if out.MaxResourceBytes == 0 {
out.MaxResourceBytes = DefaultMaxResourceBytes
}
if out.MaxSnapshotBytes == 0 {
out.MaxSnapshotBytes = DefaultMaxSnapshotBytes
}
if out.MaxResources == 0 {
out.MaxResources = DefaultMaxResources
}
if out.MaxDepth == 0 {
out.MaxDepth = DefaultMaxScanDepth
}
return &out
}
// walk traverses every scan root and produces an unordered
// resource list. Aggregate caps are applied separately.
func (r *Resolver) walk(roots []ScanRoot) (resources []Resource, snapErrs []string) {
// Dedup roots by canonical path. The first occurrence
// wins so user-added roots that overlap with a built-in
// root attribute resources to the built-in.
seenRoot := make(map[string]struct{}, len(roots))
dedup := make([]ScanRoot, 0, len(roots))
for _, root := range roots {
if root.Path == "" {
continue
}
if _, ok := seenRoot[root.Path]; ok {
continue
}
seenRoot[root.Path] = struct{}{}
dedup = append(dedup, root)
}
// Deduplicate resources across roots by ID. Without this,
// a built-in root and a user root that both cover the
// same project tree would double-count AGENTS.md.
seenID := make(map[string]struct{})
for _, root := range dedup {
info, err := os.Stat(root.Path)
if err != nil {
// Missing roots silently fall through. The user
// either added a path that does not exist yet or
// removed it later. The watcher will surface
// re-creation as a change event.
continue
}
if !info.IsDir() {
// Single-file roots are classified directly.
if res, ok := r.classifyFile(root.Path, info, root.UserSource); ok {
if _, dup := seenID[res.ID]; !dup {
seenID[res.ID] = struct{}{}
resources = append(resources, res)
}
}
continue
}
walkErr := r.walkDir(root, &resources, seenID)
if walkErr != nil {
snapErrs = append(snapErrs, fmt.Sprintf("walk %q: %s", root.Path, walkErr))
}
}
return resources, snapErrs
}
// walkDir performs the recursive descent for a single scan
// directory. It honors r.MaxDepth and skipDirNames.
func (r *Resolver) walkDir(root ScanRoot, out *[]Resource, seenID map[string]struct{}) error {
rootDepth := strings.Count(filepath.Clean(root.Path), string(os.PathSeparator))
maxDepth := rootDepth + r.MaxDepth
return filepath.WalkDir(root.Path, func(path string, d fs.DirEntry, err error) error {
if err != nil {
// Surface the error as Unreadable when we can
// associate it with a single recognized file;
// otherwise let the walk continue.
if d != nil && !d.IsDir() {
if recognizedInstructionFile(d.Name()) ||
d.Name() == mcpConfigFileName ||
d.Name() == skillMetaFileName {
res := Resource{
ID: resourceID(KindInstructionFile, path),
Kind: KindInstructionFile,
Source: path,
SizeBytes: 0,
Status: StatusUnreadable,
Error: err.Error(),
SourcePath: root.UserSource,
}
if _, dup := seenID[res.ID]; !dup {
seenID[res.ID] = struct{}{}
*out = append(*out, res)
}
}
}
if errors.Is(err, fs.ErrPermission) {
// Permission errors on a directory: skip the
// subtree but continue walking siblings.
if d != nil && d.IsDir() {
return fs.SkipDir
}
}
return nil
}
if d.IsDir() {
if strings.Count(path, string(os.PathSeparator)) > maxDepth {
return fs.SkipDir
}
if _, skip := skipDirNames[d.Name()]; skip && path != root.Path {
return fs.SkipDir
}
// If we are entering a "skills container"
// directory (".agents/skills", "~/.coder/skills",
// "plugins/<plugin>/skills"), eagerly emit skill
// resources for its immediate subdirectories.
if isSkillsContainer(path) {
r.emitSkillsFromContainer(path, root, out, seenID)
}
return nil
}
// Regular file.
info, statErr := d.Info()
if statErr != nil {
return nil
}
if res, ok := r.classifyFile(path, info, root.UserSource); ok {
if _, dup := seenID[res.ID]; dup {
return nil
}
seenID[res.ID] = struct{}{}
*out = append(*out, res)
}
return nil
})
}
// classifyFile inspects a single file path and produces a
// Resource when the basename matches a recognized convention.
func (r *Resolver) classifyFile(path string, info fs.FileInfo, userSource string) (Resource, bool) {
name := info.Name()
switch {
case recognizedInstructionFile(name):
return r.readInstructionFile(path, info, userSource), true
case name == mcpConfigFileName:
return r.readMCPConfig(path, info, userSource), true
case name == skillMetaFileName:
// SKILL.md outside a skills container is still a
// valid skill if its parent directory name matches
// the front-matter name. emitSkillsFromContainer
// already handles the common case; here we cover
// "user adds a single SKILL.md file as a source".
res, ok := r.readSkillMeta(path, info, userSource)
return res, ok
default:
return Resource{}, false
}
}
// readInstructionFile reads an instruction file and produces a
// KindInstructionFile resource. The file is read into memory
// with the per-resource cap applied.
func (r *Resolver) readInstructionFile(path string, info fs.FileInfo, userSource string) Resource {
res := Resource{
ID: resourceID(KindInstructionFile, path),
Kind: KindInstructionFile,
Source: path,
SizeBytes: safeUint64(info.Size()),
SourcePath: userSource,
}
if safeUint64(info.Size()) > r.MaxResourceBytes {
res.Status = StatusOversize
res.Error = fmt.Sprintf("file size %d exceeds per-resource cap of %d bytes", info.Size(), r.MaxResourceBytes)
// Still hash the (capped) content so a fix is
// detectable.
if data, err := readFileCapped(path, safeInt64(r.MaxResourceBytes)); err == nil {
res.ContentHash = sha256.Sum256(data)
}
return res
}
data, err := os.ReadFile(path)
if err != nil {
res.Status = StatusUnreadable
res.Error = err.Error()
return res
}
res.Payload = data
res.ContentHash = sha256.Sum256(data)
// Description is just the first non-empty line.
res.Description = firstLine(string(data))
return res
}
// readMCPConfig reads a .mcp.json file and produces a
// KindMCPConfig resource. Parsing is left to consumers; the
// resolver only enforces JSON shape lightly via size and Unix
// newline conversion. Future work: detect malformed JSON and
// surface StatusInvalid.
func (r *Resolver) readMCPConfig(path string, info fs.FileInfo, userSource string) Resource {
res := Resource{
ID: resourceID(KindMCPConfig, path),
Kind: KindMCPConfig,
Source: path,
SizeBytes: safeUint64(info.Size()),
SourcePath: userSource,
}
if safeUint64(info.Size()) > r.MaxResourceBytes {
res.Status = StatusOversize
res.Error = fmt.Sprintf("file size %d exceeds per-resource cap of %d bytes", info.Size(), r.MaxResourceBytes)
if data, err := readFileCapped(path, safeInt64(r.MaxResourceBytes)); err == nil {
res.ContentHash = sha256.Sum256(data)
}
return res
}
data, err := os.ReadFile(path)
if err != nil {
res.Status = StatusUnreadable
res.Error = err.Error()
return res
}
res.Payload = data
res.ContentHash = sha256.Sum256(data)
return res
}
// readSkillMeta reads a SKILL.md file, parses its front-matter,
// and emits a KindSkill resource. The name encoded in the
// front-matter must match the parent directory's basename to
// be considered valid; otherwise Status is StatusInvalid.
func (r *Resolver) readSkillMeta(path string, info fs.FileInfo, userSource string) (Resource, bool) {
parent := filepath.Base(filepath.Dir(path))
res := Resource{
ID: resourceID(KindSkill, filepath.Dir(path)),
Kind: KindSkill,
Source: filepath.Dir(path),
SizeBytes: safeUint64(info.Size()),
SourcePath: userSource,
}
if safeUint64(info.Size()) > r.MaxResourceBytes {
res.Status = StatusOversize
res.Error = fmt.Sprintf("file size %d exceeds per-resource cap of %d bytes", info.Size(), r.MaxResourceBytes)
return res, true
}
data, err := os.ReadFile(path)
if err != nil {
res.Status = StatusUnreadable
res.Error = err.Error()
return res, true
}
res.ContentHash = sha256.Sum256(data)
name, description, _, err := workspacesdk.ParseSkillFrontmatter(string(data))
if err != nil {
res.Status = StatusInvalid
res.Error = err.Error()
return res, true
}
if name != parent {
res.Status = StatusInvalid
res.Error = fmt.Sprintf("front-matter name %q does not match directory %q", name, parent)
return res, true
}
if !workspacesdk.SkillNamePattern.MatchString(name) {
res.Status = StatusInvalid
res.Error = fmt.Sprintf("skill name %q is not kebab-case", name)
return res, true
}
res.Description = description
res.Payload = data
return res, true
}
// emitSkillsFromContainer scans the immediate children of a
// recognized skills-container directory and emits one Skill
// resource per subdirectory whose SKILL.md parses cleanly.
func (r *Resolver) emitSkillsFromContainer(container string, root ScanRoot, out *[]Resource, seenID map[string]struct{}) {
entries, err := os.ReadDir(container)
if err != nil {
return
}
for _, e := range entries {
if !e.IsDir() {
continue
}
meta := filepath.Join(container, e.Name(), skillMetaFileName)
info, err := os.Stat(meta)
if err != nil {
continue
}
res, ok := r.readSkillMeta(meta, info, root.UserSource)
if !ok {
continue
}
if _, dup := seenID[res.ID]; dup {
continue
}
seenID[res.ID] = struct{}{}
*out = append(*out, res)
}
}
// applyCaps enforces the resource-count cap and aggregate
// payload cap. Resources past either cap have their Status set
// to StatusExcluded and their Payload cleared.
func (r *Resolver) applyCaps(resources []Resource) []Resource {
// Stable sort by (Kind asc, Source asc) so excluded
// resources are deterministic.
sort.SliceStable(resources, func(i, j int) bool {
if resources[i].Kind != resources[j].Kind {
return resources[i].Kind < resources[j].Kind
}
return resources[i].Source < resources[j].Source
})
var total uint64
for i := range resources {
if i >= r.MaxResources {
resources[i] = excluded(resources[i],
fmt.Sprintf("dropped to fit %d-resource snapshot count cap", r.MaxResources))
continue
}
if resources[i].Status != StatusOK {
continue
}
size := uint64(len(resources[i].Payload))
if total+size > r.MaxSnapshotBytes {
resources[i] = excluded(resources[i],
fmt.Sprintf("dropped to fit %d-byte aggregate cap", r.MaxSnapshotBytes))
continue
}
total += size
}
return resources
}
// applyMCPCaps re-applies the count cap after MCP resources are
// appended. MCP payloads are typically small JSON descriptors,
// so we treat the aggregate budget as already consumed by the
// filesystem pass.
func (r *Resolver) applyMCPCaps(resources []Resource, snapErrs []string) ([]Resource, []string) {
if len(resources) <= r.MaxResources {
return resources, snapErrs
}
for i := r.MaxResources; i < len(resources); i++ {
resources[i] = excluded(resources[i],
fmt.Sprintf("dropped to fit %d-resource snapshot count cap", r.MaxResources))
}
snapErrs = append(snapErrs, fmt.Sprintf("snapshot exceeds %d-resource count cap", r.MaxResources))
return resources, snapErrs
}
// excluded mutates and returns the supplied resource with the
// StatusExcluded outcome.
func excluded(r Resource, reason string) Resource {
r.Status = StatusExcluded
r.Error = reason
r.Payload = nil
return r
}
// isSkillsContainer reports whether dir is a recognized skills
// container directory whose immediate children carry SKILL.md
// files.
func isSkillsContainer(dir string) bool {
base := filepath.Base(dir)
_, ok := skillsParentNames[base]
if ok && base == "skills" {
return true
}
// "<x>/skills" form (e.g. ".agents/skills", "plugins/foo/skills").
if strings.HasSuffix(filepath.ToSlash(dir), "/skills") {
return true
}
return false
}
// resourceID builds a stable resource ID. Kind plus canonical
// source path is enough; sources never collide across kinds for
// v1 because each kind owns a distinct file-name pattern.
func resourceID(kind ResourceKind, source string) string {
return kind.String() + ":" + source
}
// readFileCapped reads up to maxBytes from path. It returns the
// truncated payload on success.
func readFileCapped(path string, maxBytes int64) ([]byte, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
return io.ReadAll(io.LimitReader(f, maxBytes))
}
// firstLine returns the first non-empty trimmed line of s, used
// as a short description fallback.
func firstLine(s string) string {
for _, line := range strings.Split(s, "\n") {
line = strings.TrimSpace(line)
if line == "" {
continue
}
// Strip leading markdown heading markers for prettier
// descriptions.
return strings.TrimSpace(headingPrefixRegex.ReplaceAllString(line, ""))
}
return ""
}
var headingPrefixRegex = regexp.MustCompile(`^#+\s*`)
// safeUint64 converts a non-negative int64 to uint64. Negative
// inputs are clamped to 0, which is safe for the size-tracking
// fields that use it; a negative os.FileInfo size is pathological
// and never indicates real content.
func safeUint64(n int64) uint64 {
if n < 0 {
return 0
}
return uint64(n)
}
// safeInt64 converts a uint64 to int64, clamping to math.MaxInt64
// when the input would overflow. The caps configured on the
// resolver never approach 2^63 bytes, so the clamp only guards
// against pathological caller input.
func safeInt64(n uint64) int64 {
if n > math.MaxInt64 {
return math.MaxInt64
}
return int64(n)
}
+276
View File
@@ -0,0 +1,276 @@
package agentcontext_test
import (
"crypto/sha256"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentcontext"
)
func mustWriteFile(t *testing.T, path, content string) {
t.Helper()
require.NoError(t, os.MkdirAll(filepath.Dir(path), 0o755))
require.NoError(t, os.WriteFile(path, []byte(content), 0o600))
}
func mustWriteSkill(t *testing.T, dir, name, description string) {
t.Helper()
require.NoError(t, os.MkdirAll(filepath.Join(dir, name), 0o755))
mustWriteFile(t, filepath.Join(dir, name, "SKILL.md"),
"---\nname: "+name+"\ndescription: "+description+"\n---\nSkill body for "+name)
}
func findResource(t *testing.T, resources []agentcontext.Resource, kind agentcontext.ResourceKind, source string) agentcontext.Resource {
t.Helper()
for _, r := range resources {
if r.Kind == kind && r.Source == source {
return r
}
}
t.Fatalf("resource not found: kind=%s source=%s", kind, source)
return agentcontext.Resource{}
}
func TestResolver_ProjectAGENTSFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "# Project rules\n\nDo the thing.")
r := &agentcontext.Resolver{}
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
require.Len(t, snap.Resources, 1)
got := snap.Resources[0]
require.Equal(t, agentcontext.KindInstructionFile, got.Kind)
require.Equal(t, agentcontext.StatusOK, got.Status)
require.Equal(t, filepath.Join(dir, "AGENTS.md"), got.Source)
require.Contains(t, string(got.Payload), "Do the thing.")
require.Equal(t, "Project rules", got.Description)
require.NotEqual(t, [32]byte{}, got.ContentHash)
}
func TestResolver_CaseInsensitiveInstructionNames(t *testing.T) {
t.Parallel()
dir := t.TempDir()
mustWriteFile(t, filepath.Join(dir, "agents.md"), "lower\n")
mustWriteFile(t, filepath.Join(dir, "CLAUDE.md"), "claude\n")
r := &agentcontext.Resolver{}
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
require.Len(t, snap.Resources, 2)
}
func TestResolver_SkillsContainerEmitsEachSubdir(t *testing.T) {
t.Parallel()
dir := t.TempDir()
mustWriteSkill(t, filepath.Join(dir, ".agents", "skills"), "make-coffee", "Coffee skill")
mustWriteSkill(t, filepath.Join(dir, ".agents", "skills"), "fold-laundry", "Laundry skill")
r := &agentcontext.Resolver{}
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
var kinds []string
for _, res := range snap.Resources {
kinds = append(kinds, res.Kind.String()+":"+filepath.Base(res.Source))
}
require.ElementsMatch(t, []string{
"skill:make-coffee",
"skill:fold-laundry",
}, kinds)
}
func TestResolver_SkillNameMismatchInvalid(t *testing.T) {
t.Parallel()
dir := t.TempDir()
skillsDir := filepath.Join(dir, ".agents", "skills", "make-coffee")
require.NoError(t, os.MkdirAll(skillsDir, 0o755))
mustWriteFile(t, filepath.Join(skillsDir, "SKILL.md"),
"---\nname: drink-tea\ndescription: oops\n---\nBody")
r := &agentcontext.Resolver{}
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
require.Len(t, snap.Resources, 1)
got := snap.Resources[0]
require.Equal(t, agentcontext.KindSkill, got.Kind)
require.Equal(t, agentcontext.StatusInvalid, got.Status)
require.Contains(t, got.Error, "does not match directory")
}
func TestResolver_MCPConfigEmitted(t *testing.T) {
t.Parallel()
dir := t.TempDir()
mustWriteFile(t, filepath.Join(dir, ".mcp.json"), `{"mcpServers": {}}`)
r := &agentcontext.Resolver{}
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
require.Len(t, snap.Resources, 1)
require.Equal(t, agentcontext.KindMCPConfig, snap.Resources[0].Kind)
require.Equal(t, agentcontext.StatusOK, snap.Resources[0].Status)
}
func TestResolver_OversizeInstructionFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
// Write a file larger than the per-resource cap.
big := make([]byte, 200)
for i := range big {
big[i] = 'a'
}
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), string(big))
r := &agentcontext.Resolver{MaxResourceBytes: 100}
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
require.Len(t, snap.Resources, 1)
got := snap.Resources[0]
require.Equal(t, agentcontext.StatusOversize, got.Status)
require.Empty(t, got.Payload)
require.Equal(t, uint64(200), got.SizeBytes)
// Hash over capped slice is still populated so callers
// can detect "still oversize but content changed".
require.NotEqual(t, [32]byte{}, got.ContentHash)
}
func TestResolver_AggregateCapExcludes(t *testing.T) {
t.Parallel()
dir := t.TempDir()
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "small")
subA := filepath.Join(dir, "a")
subB := filepath.Join(dir, "b")
mustWriteFile(t, filepath.Join(subA, "AGENTS.md"), "AAAA")
mustWriteFile(t, filepath.Join(subB, "AGENTS.md"), "BBBB")
// Aggregate cap of 9 bytes lets the first two through but
// excludes the third regardless of which order they
// appear.
r := &agentcontext.Resolver{MaxSnapshotBytes: 9}
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
var excluded int
for _, res := range snap.Resources {
if res.Status == agentcontext.StatusExcluded {
excluded++
}
}
require.Equal(t, 1, excluded)
}
func TestResolver_CountCapExcludes(t *testing.T) {
t.Parallel()
dir := t.TempDir()
for i := 0; i < 5; i++ {
sub := filepath.Join(dir, "dir", string('a'+rune(i)))
mustWriteFile(t, filepath.Join(sub, "AGENTS.md"), "x")
}
r := &agentcontext.Resolver{MaxResources: 3}
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
require.Len(t, snap.Resources, 5)
var excluded int
for _, res := range snap.Resources {
if res.Status == agentcontext.StatusExcluded {
excluded++
}
}
require.Equal(t, 2, excluded)
}
func TestResolver_SkipsVendorAndNodeModules(t *testing.T) {
t.Parallel()
dir := t.TempDir()
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "root")
mustWriteFile(t, filepath.Join(dir, "node_modules", "deep", "AGENTS.md"), "should not appear")
mustWriteFile(t, filepath.Join(dir, "vendor", "AGENTS.md"), "should not appear either")
r := &agentcontext.Resolver{}
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
require.Len(t, snap.Resources, 1)
require.Equal(t, filepath.Join(dir, "AGENTS.md"), snap.Resources[0].Source)
}
func TestResolver_UserSourceAttribution(t *testing.T) {
t.Parallel()
dir := t.TempDir()
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "user-added")
r := &agentcontext.Resolver{}
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir, UserSource: dir}})
require.Len(t, snap.Resources, 1)
require.Equal(t, dir, snap.Resources[0].SourcePath)
}
func TestResolver_MissingRootSilentlyIgnored(t *testing.T) {
t.Parallel()
r := &agentcontext.Resolver{}
snap := r.Resolve([]agentcontext.ScanRoot{{Path: "/nonexistent/path"}})
require.Empty(t, snap.Resources)
require.Empty(t, snap.SnapshotError)
}
func TestResolver_SingleFileRootClassified(t *testing.T) {
t.Parallel()
dir := t.TempDir()
path := filepath.Join(dir, "AGENTS.md")
mustWriteFile(t, path, "x")
r := &agentcontext.Resolver{}
snap := r.Resolve([]agentcontext.ScanRoot{{Path: path}})
require.Len(t, snap.Resources, 1)
require.Equal(t, agentcontext.KindInstructionFile, snap.Resources[0].Kind)
}
func TestResolver_DuplicateRootsDeduplicated(t *testing.T) {
t.Parallel()
dir := t.TempDir()
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "x")
r := &agentcontext.Resolver{}
snap := r.Resolve([]agentcontext.ScanRoot{
{Path: dir},
{Path: dir},
{Path: dir},
})
require.Len(t, snap.Resources, 1)
}
func TestResolver_MCPProviderResources(t *testing.T) {
t.Parallel()
dir := t.TempDir()
mcpRes := agentcontext.Resource{
ID: "mcp_server:github",
Kind: agentcontext.KindMCPServer,
Source: "github",
Status: agentcontext.StatusOK,
Payload: []byte("tool-list-json"),
ContentHash: sha256.Sum256([]byte("tool-list-json")),
Description: "GitHub MCP server",
}
r := &agentcontext.Resolver{
MCP: &fakeMCPProvider{resources: []agentcontext.Resource{mcpRes}},
}
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
got := findResource(t, snap.Resources, agentcontext.KindMCPServer, "github")
require.Equal(t, agentcontext.StatusOK, got.Status)
require.Equal(t, "GitHub MCP server", got.Description)
}
type fakeMCPProvider struct {
resources []agentcontext.Resource
}
func (f *fakeMCPProvider) MCPResources() []agentcontext.Resource {
return f.resources
}
+222
View File
@@ -0,0 +1,222 @@
package agentcontext
import (
"crypto/sha256"
"encoding/hex"
"sort"
"strconv"
)
// ResourceKind describes the category of a resolved context
// resource. The values mirror the proto ContextResource.Kind
// enum reserved in the RFC; future kinds (PLUGIN, HOOK,
// SUBAGENT, COMMAND) are defined here so callers can switch
// exhaustively, but no v1 resolver emits them.
type ResourceKind int
const (
KindUnspecified ResourceKind = iota
// KindInstructionFile covers AGENTS.md, CLAUDE.md,
// .cursorrules, and similar plain-text rule files that
// inject content into the model prompt.
KindInstructionFile
// KindSkill is a directory containing SKILL.md and any
// supporting files. Only the meta file is read at
// resolve time; bodies are fetched on demand.
KindSkill
// KindMCPConfig is a .mcp.json fragment declaring one or
// more MCP servers.
KindMCPConfig
// KindMCPServer is a live MCP server's resolved tool list,
// populated by an MCPProvider after the server has been
// connected.
KindMCPServer
// KindPlugin is reserved for Claude Code plugin manifests.
// Not emitted by v1.
KindPlugin
// KindHook is reserved for plugin hooks. Not emitted by v1.
KindHook
// KindSubagent is reserved for plugin-declared subagents.
// Not emitted by v1.
KindSubagent
// KindCommand is reserved for plugin slash commands.
// Not emitted by v1.
KindCommand
)
// String returns the lower-snake-case name used in IDs and
// metrics. Unknown values stringify to "unknown".
func (k ResourceKind) String() string {
switch k {
case KindInstructionFile:
return "instruction_file"
case KindSkill:
return "skill"
case KindMCPConfig:
return "mcp_config"
case KindMCPServer:
return "mcp_server"
case KindPlugin:
return "plugin"
case KindHook:
return "hook"
case KindSubagent:
return "subagent"
case KindCommand:
return "command"
default:
return "unknown"
}
}
// ResourceStatus describes whether a resource was successfully
// read and whether its payload survived the per-resource and
// aggregate caps.
type ResourceStatus int
const (
// StatusOK indicates the payload was populated.
StatusOK ResourceStatus = iota
// StatusOversize indicates the resource exceeded the
// per-resource size cap; payload is omitted.
StatusOversize
// StatusUnreadable indicates an IO error reading the
// resource (permission denied, broken symlink, etc.).
StatusUnreadable
// StatusInvalid indicates the resource was structurally
// malformed (bad JSON, missing front-matter, etc.).
StatusInvalid
// StatusExcluded indicates the resource was dropped to fit
// the aggregate snapshot or count cap.
StatusExcluded
)
// String returns the lower-snake-case name used in IDs and
// metrics. Unknown values stringify to "unknown".
func (s ResourceStatus) String() string {
switch s {
case StatusOK:
return "ok"
case StatusOversize:
return "oversize"
case StatusUnreadable:
return "unreadable"
case StatusInvalid:
return "invalid"
case StatusExcluded:
return "excluded"
default:
return "unknown"
}
}
// Source is a user-declared scan root added to the agent's
// in-memory list via the HTTP API or boot-time env seeding.
// Identity is the canonical absolute path.
type Source struct {
// Path is the canonical absolute path (symlinks resolved,
// ~ expanded). Empty means the zero value.
Path string
}
// Resource is what the resolver emits for each recognized file
// or live server it discovers under a scan root.
type Resource struct {
// ID is stable across pushes for the same logical
// resource. The current scheme is "<kind>:<source>".
ID string
// Kind classifies the resource for snapshot consumers.
Kind ResourceKind
// Source is the file path or MCP server name.
Source string
// ContentHash is sha256 over the resource's original
// bytes (or transport-encoded server tool list).
ContentHash [32]byte
// Payload is the full bytes when Status == StatusOK; the
// per-resource and aggregate caps may leave it empty.
Payload []byte
// SizeBytes is the original payload size, populated
// regardless of Status.
SizeBytes uint64
// Status records OK or a reason the payload is absent.
Status ResourceStatus
// Error is populated whenever Status != StatusOK; may
// also carry a non-fatal warning when Status == StatusOK.
Error string
// Description is a short human-readable summary (skill
// front-matter description, MCP server description, etc.).
Description string
// SourcePath is the user-declared source that contributed
// the resource; empty for built-in scan roots.
SourcePath string
}
// Snapshot is the immutable bundle of resources produced by a
// single resolver pass.
type Snapshot struct {
// Version is monotonically increasing per Manager
// instance; resets when the agent process restarts.
Version uint64
// SchemaVersion is bumped if the resource shape on the
// wire changes.
SchemaVersion uint64
// AggregateHash is sha256 over a canonical encoding of
// (ID, Kind, Source, ContentHash, Status) for every
// resource. Identical inputs always produce identical
// hashes; see ComputeAggregateHash.
AggregateHash [32]byte
// Resources is sorted by ID for deterministic encoding.
Resources []Resource
// PayloadBytes is the sum of len(Resource.Payload) across
// emitted resources after caps were applied.
PayloadBytes uint64
// SnapshotError carries a single snapshot-level error
// string when present (count cap exceeded, watcher
// degraded, ENOSPC, etc.). Empty when healthy.
SnapshotError string
}
// ComputeAggregateHash produces the deterministic snapshot
// aggregate hash for the supplied resources. The caller does
// not need to pre-sort; the function sorts a copy of the slice
// to keep its inputs side-effect free.
//
// The encoding is a newline-delimited stream of fields. The
// resource boundary is a single NUL byte. The boundary scheme
// is internal to the agent and coderd, but it is stable across
// platforms because every field is encoded as either a UTF-8
// string with a length prefix or a fixed-width integer.
func ComputeAggregateHash(resources []Resource) [32]byte {
indexed := make([]Resource, len(resources))
copy(indexed, resources)
sort.Slice(indexed, func(i, j int) bool {
return indexed[i].ID < indexed[j].ID
})
h := sha256.New()
for _, r := range indexed {
writeLengthPrefixed(h, r.ID)
writeLengthPrefixed(h, r.Kind.String())
writeLengthPrefixed(h, r.Source)
_, _ = h.Write(r.ContentHash[:])
writeLengthPrefixed(h, r.Status.String())
_, _ = h.Write([]byte{0})
}
var out [32]byte
copy(out[:], h.Sum(nil))
return out
}
// AggregateHashHex returns the hex-encoded aggregate hash.
// Convenience for log lines and HTTP responses.
func (s Snapshot) AggregateHashHex() string {
return hex.EncodeToString(s.AggregateHash[:])
}
// writeLengthPrefixed writes a uvarint length prefix followed
// by the raw bytes of s.
func writeLengthPrefixed(h interface{ Write([]byte) (int, error) }, s string) {
_, _ = h.Write([]byte(strconv.Itoa(len(s))))
_, _ = h.Write([]byte{':'})
_, _ = h.Write([]byte(s))
}
+99
View File
@@ -0,0 +1,99 @@
package agentcontext_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentcontext"
)
func TestResourceKindString(t *testing.T) {
t.Parallel()
tests := []struct {
kind agentcontext.ResourceKind
want string
}{
{agentcontext.KindUnspecified, "unknown"},
{agentcontext.KindInstructionFile, "instruction_file"},
{agentcontext.KindSkill, "skill"},
{agentcontext.KindMCPConfig, "mcp_config"},
{agentcontext.KindMCPServer, "mcp_server"},
{agentcontext.KindPlugin, "plugin"},
{agentcontext.KindHook, "hook"},
{agentcontext.KindSubagent, "subagent"},
{agentcontext.KindCommand, "command"},
{agentcontext.ResourceKind(999), "unknown"},
}
for _, tt := range tests {
require.Equal(t, tt.want, tt.kind.String())
}
}
func TestResourceStatusString(t *testing.T) {
t.Parallel()
tests := []struct {
status agentcontext.ResourceStatus
want string
}{
{agentcontext.StatusOK, "ok"},
{agentcontext.StatusOversize, "oversize"},
{agentcontext.StatusUnreadable, "unreadable"},
{agentcontext.StatusInvalid, "invalid"},
{agentcontext.StatusExcluded, "excluded"},
{agentcontext.ResourceStatus(999), "unknown"},
}
for _, tt := range tests {
require.Equal(t, tt.want, tt.status.String())
}
}
func TestComputeAggregateHash_DeterministicAcrossOrder(t *testing.T) {
t.Parallel()
a := agentcontext.Resource{
ID: "instruction_file:/a/AGENTS.md",
Kind: agentcontext.KindInstructionFile,
Source: "/a/AGENTS.md",
Status: agentcontext.StatusOK,
}
b := agentcontext.Resource{
ID: "instruction_file:/b/AGENTS.md",
Kind: agentcontext.KindInstructionFile,
Source: "/b/AGENTS.md",
Status: agentcontext.StatusOK,
}
got1 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{a, b})
got2 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{b, a})
require.Equal(t, got1, got2)
}
func TestComputeAggregateHash_ChangesOnContent(t *testing.T) {
t.Parallel()
base := agentcontext.Resource{
ID: "instruction_file:/a/AGENTS.md",
Kind: agentcontext.KindInstructionFile,
Source: "/a/AGENTS.md",
Status: agentcontext.StatusOK,
}
hash1 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{base})
withContent := base
withContent.ContentHash = [32]byte{0x01}
hash2 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{withContent})
require.NotEqual(t, hash1, hash2)
withStatus := base
withStatus.Status = agentcontext.StatusOversize
hash3 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{withStatus})
require.NotEqual(t, hash1, hash3)
}
func TestSnapshotAggregateHashHex(t *testing.T) {
t.Parallel()
snap := agentcontext.Snapshot{
AggregateHash: [32]byte{0xde, 0xad, 0xbe, 0xef},
}
require.Equal(t,
"deadbeef0000000000000000000000000000000000000000000000000000000000000000"[:64],
snap.AggregateHashHex())
}
+346
View File
@@ -0,0 +1,346 @@
package agentcontext
import (
"context"
"errors"
"io/fs"
"os"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
"github.com/fsnotify/fsnotify"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/quartz"
)
// DefaultWatchDebounce coalesces editor-style multi-event writes
// (truncate plus rename plus chmod) into a single re-resolve.
// Mirrors the debounce window the existing MCP config watcher
// uses so behavior is consistent across the agent.
const DefaultWatchDebounce = 250 * time.Millisecond
// WatcherOptions parameterizes the recursive watcher.
type WatcherOptions struct {
Logger slog.Logger
Clock quartz.Clock
Debounce time.Duration
// OnChange runs at most once per debounce window. The
// caller must not block; the recommended pattern is a
// non-blocking send on a re-resolve trigger channel.
OnChange func()
}
// Watcher is a recursive fsnotify wrapper. fsnotify does not
// support recursive watches natively on Linux, so we walk every
// scan root at sync time and register each subdirectory
// individually. Inotify ENOSPC degrades the watcher into a
// poll-only mode that still re-resolves on Sync calls.
type Watcher struct {
logger slog.Logger
clock quartz.Clock
debounce time.Duration
onChange func()
mu sync.Mutex
watcher *fsnotify.Watcher
watched map[string]struct{}
timer *quartz.Timer
degraded string // non-empty when the watcher dropped events
closed bool
closedCh chan struct{}
runDoneCh chan struct{}
}
// NewWatcher constructs a recursive watcher. The watcher does
// nothing until Sync is called.
func NewWatcher(opts WatcherOptions) (*Watcher, error) {
if opts.OnChange == nil {
return nil, xerrors.New("OnChange callback is required")
}
debounce := opts.Debounce
if debounce <= 0 {
debounce = DefaultWatchDebounce
}
clock := opts.Clock
if clock == nil {
clock = quartz.NewReal()
}
w, err := fsnotify.NewWatcher()
if err != nil {
// On Linux, fsnotify.NewWatcher only fails when the
// inotify subsystem is at the system-wide watch
// limit. Surface a Watcher in "degraded" mode so the
// caller can still rely on explicit Sync triggers.
degraded := &Watcher{
logger: opts.Logger,
clock: clock,
debounce: debounce,
onChange: opts.OnChange,
watched: make(map[string]struct{}),
degraded: "fsnotify init failed: " + err.Error(),
closedCh: make(chan struct{}),
runDoneCh: closedChan(),
}
return degraded, nil
}
cw := &Watcher{
logger: opts.Logger,
clock: clock,
debounce: debounce,
onChange: opts.OnChange,
watcher: w,
watched: make(map[string]struct{}),
closedCh: make(chan struct{}),
runDoneCh: make(chan struct{}),
}
go cw.run()
return cw, nil
}
// closedChan returns an already-closed channel for the
// degraded-watcher case where there is no run goroutine.
func closedChan() chan struct{} {
c := make(chan struct{})
close(c)
return c
}
// Degraded returns a non-empty string when the watcher is
// running with reduced functionality (typically inotify
// ENOSPC). The string is suitable for use as a snapshot-level
// error message.
func (w *Watcher) Degraded() string {
w.mu.Lock()
defer w.mu.Unlock()
return w.degraded
}
// Sync replaces the set of watched directories with a fresh
// recursive walk of every scan root. Files are not watched
// directly; watching the parent directory catches creates,
// renames, removes, and writes that touch any recognized
// basename. Files that are themselves scan roots are handled by
// watching their parent.
//
// Sync is idempotent and safe to call repeatedly.
func (w *Watcher) Sync(ctx context.Context, roots []ScanRoot) {
w.mu.Lock()
if w.closed {
w.mu.Unlock()
return
}
if w.watcher == nil {
// Degraded mode: nothing to wire up; fire the callback
// so the caller still gets a fresh resolve.
w.mu.Unlock()
w.schedule()
return
}
desired := w.collectDirs(roots)
// Remove directories no longer wanted.
for path := range w.watched {
if _, ok := desired[path]; ok {
continue
}
_ = w.watcher.Remove(path)
delete(w.watched, path)
}
// Add directories that are new.
for path := range desired {
if _, ok := w.watched[path]; ok {
continue
}
if err := w.watcher.Add(path); err != nil {
// ENOSPC means the kernel's per-user inotify
// watch budget is exhausted. Mark the watcher
// degraded; subsequent Sync calls still fire
// the change callback so resync still works.
if errors.Is(err, syscall.ENOSPC) {
w.degraded = "inotify watch limit exceeded (ENOSPC)"
w.logger.Warn(ctx, "context watcher degraded: inotify watch limit exceeded",
slog.F("dir", path))
break
}
w.logger.Debug(ctx, "context watcher could not add dir",
slog.F("dir", path), slog.Error(err))
continue
}
w.watched[path] = struct{}{}
}
w.mu.Unlock()
}
// Close stops the watcher and releases all kernel watch slots.
// Close is idempotent.
func (w *Watcher) Close() error {
w.mu.Lock()
if w.closed {
w.mu.Unlock()
return nil
}
w.closed = true
close(w.closedCh)
timer := w.timer
watcher := w.watcher
w.timer = nil
w.watcher = nil
w.mu.Unlock()
if timer != nil {
timer.Stop()
}
if watcher != nil {
_ = watcher.Close()
}
<-w.runDoneCh
return nil
}
// run forwards fsnotify events into the debounce timer. It exits
// when Close is called or the underlying watcher is closed.
func (w *Watcher) run() {
defer close(w.runDoneCh)
// Capture the watcher reference once. Close may set the
// field to nil concurrently; reading the captured local
// keeps the event loop safe through the race window.
w.mu.Lock()
fsw := w.watcher
w.mu.Unlock()
if fsw == nil {
return
}
for {
select {
case <-w.closedCh:
return
case ev, ok := <-fsw.Events:
if !ok {
return
}
if !w.eventRelevant(ev) {
continue
}
w.schedule()
case err, ok := <-fsw.Errors:
if !ok {
return
}
if err != nil {
w.logger.Debug(context.Background(), "context watcher error", slog.Error(err))
}
}
}
}
// eventRelevant filters out events that cannot affect any
// recognized resource. The check is conservative: any event on
// a directory triggers a re-resolve so newly created subtrees
// are picked up.
func (*Watcher) eventRelevant(ev fsnotify.Event) bool {
name := filepath.Base(ev.Name)
if recognizedInstructionFile(name) || name == mcpConfigFileName || name == skillMetaFileName {
return true
}
// Directory create/remove flips re-resolve so new subtrees
// arm watches and removed subtrees stop arming them.
if ev.Has(fsnotify.Create) || ev.Has(fsnotify.Remove) || ev.Has(fsnotify.Rename) {
return true
}
return false
}
// schedule arms or resets the debounce timer.
func (w *Watcher) schedule() {
w.mu.Lock()
if w.closed {
w.mu.Unlock()
return
}
cb := w.onChange
if w.timer != nil {
w.timer.Reset(w.debounce)
w.mu.Unlock()
return
}
w.timer = w.clock.AfterFunc(w.debounce, func() {
w.mu.Lock()
w.timer = nil
w.mu.Unlock()
cb()
})
w.mu.Unlock()
}
// collectDirs walks every scan root and returns the set of
// directories to watch. The maximum depth mirrors the
// resolver's MaxDepth so we never watch more than we'd scan.
func (*Watcher) collectDirs(roots []ScanRoot) map[string]struct{} {
out := make(map[string]struct{})
for _, root := range roots {
if root.Path == "" {
continue
}
info, err := os.Stat(root.Path)
if err != nil {
// Watch the deepest existing ancestor so the
// root being created later still fires.
if anc := existingAncestor(root.Path); anc != "" {
out[anc] = struct{}{}
}
continue
}
if !info.IsDir() {
out[filepath.Dir(root.Path)] = struct{}{}
continue
}
// Walk the directory and collect every descendant
// directory up to the depth cap.
rootDepth := strings.Count(filepath.Clean(root.Path), string(os.PathSeparator))
_ = filepath.WalkDir(root.Path, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return nil
}
if !d.IsDir() {
return nil
}
if _, skip := skipDirNames[d.Name()]; skip && path != root.Path {
return fs.SkipDir
}
if strings.Count(path, string(os.PathSeparator))-rootDepth > DefaultMaxScanDepth {
return fs.SkipDir
}
out[path] = struct{}{}
return nil
})
}
return out
}
// existingAncestor returns the deepest existing ancestor of
// path, or "" if no ancestor exists (e.g. an entirely missing
// drive on Windows).
func existingAncestor(path string) string {
cur := filepath.Dir(path)
for {
if cur == "" || cur == "." {
return ""
}
info, err := os.Stat(cur)
if err == nil && info.IsDir() {
return cur
}
parent := filepath.Dir(cur)
if parent == cur {
return ""
}
cur = parent
}
}
+96
View File
@@ -0,0 +1,96 @@
package agentcontext_test
import (
"context"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/coder/quartz"
"github.com/coder/coder/v2/agent/agentcontext"
"github.com/coder/coder/v2/testutil"
)
func TestWatcher_FiresOnAgentsMdEdit(t *testing.T) {
t.Parallel()
dir := t.TempDir()
require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v1"), 0o600))
var fires int32
w, err := agentcontext.NewWatcher(agentcontext.WatcherOptions{
Logger: testutil.Logger(t).Named("watcher"),
Clock: quartz.NewReal(),
Debounce: 10 * time.Millisecond,
OnChange: func() { atomic.AddInt32(&fires, 1) },
})
require.NoError(t, err)
t.Cleanup(func() { _ = w.Close() })
ctx := testutil.Context(t, testutil.WaitShort)
w.Sync(ctx, []agentcontext.ScanRoot{{Path: dir}})
// Edit the file. Use a slight delay so fsnotify is ready.
time.Sleep(50 * time.Millisecond)
require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v2"), 0o600))
require.Eventually(t, func() bool {
return atomic.LoadInt32(&fires) >= 1
}, testutil.WaitShort, testutil.IntervalFast, "expected at least one fire after AGENTS.md edit")
}
func TestWatcher_FiresOnNewSkillFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
skillsRoot := filepath.Join(dir, ".agents", "skills")
require.NoError(t, os.MkdirAll(skillsRoot, 0o755))
var fires int32
w, err := agentcontext.NewWatcher(agentcontext.WatcherOptions{
Logger: testutil.Logger(t).Named("watcher"),
Debounce: 10 * time.Millisecond,
OnChange: func() { atomic.AddInt32(&fires, 1) },
})
require.NoError(t, err)
t.Cleanup(func() { _ = w.Close() })
ctx := testutil.Context(t, testutil.WaitShort)
w.Sync(ctx, []agentcontext.ScanRoot{{Path: dir}})
time.Sleep(50 * time.Millisecond)
skillDir := filepath.Join(skillsRoot, "foo")
require.NoError(t, os.MkdirAll(skillDir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("---\nname: foo\ndescription: bar\n---\nbody"), 0o600))
require.Eventually(t, func() bool {
return atomic.LoadInt32(&fires) >= 1
}, testutil.WaitShort, testutil.IntervalFast, "expected fire after SKILL.md create")
}
func TestWatcher_CloseIsIdempotent(t *testing.T) {
t.Parallel()
w, err := agentcontext.NewWatcher(agentcontext.WatcherOptions{
Logger: testutil.Logger(t).Named("watcher"),
OnChange: func() {},
})
require.NoError(t, err)
require.NoError(t, w.Close())
require.NoError(t, w.Close())
}
func TestWatcher_SyncAfterCloseNoop(t *testing.T) {
t.Parallel()
w, err := agentcontext.NewWatcher(agentcontext.WatcherOptions{
Logger: testutil.Logger(t).Named("watcher"),
OnChange: func() {},
})
require.NoError(t, err)
require.NoError(t, w.Close())
// Must not panic.
w.Sync(context.Background(), []agentcontext.ScanRoot{{Path: t.TempDir()}})
}