mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat(agent/agentcontext): add agent-side context resolution package
This commit is contained in:
@@ -0,0 +1,204 @@
|
||||
package agentcontext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// SourceResponse is the on-wire representation of a Source.
|
||||
// Matches the path-only RFC schema; future additions (tags,
|
||||
// labels) can land additively without breaking clients.
|
||||
type SourceResponse struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// SourceRequest is the request body for POST /sources.
|
||||
type SourceRequest struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// SnapshotResource is the on-wire representation of a Resource.
|
||||
// PayloadBase64 is omitted from list responses; clients that
|
||||
// need the bytes hit GET /sources/{path}.
|
||||
type SnapshotResource struct {
|
||||
ID string `json:"id"`
|
||||
Kind string `json:"kind"`
|
||||
Source string `json:"source"`
|
||||
SourcePath string `json:"source_path,omitempty"`
|
||||
ContentHash string `json:"content_hash"`
|
||||
SizeBytes uint64 `json:"size_bytes"`
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// SnapshotResponse is the on-wire representation of a Snapshot
|
||||
// returned by the resync endpoint.
|
||||
type SnapshotResponse struct {
|
||||
Version uint64 `json:"version"`
|
||||
SchemaVersion uint64 `json:"schema_version"`
|
||||
AggregateHash string `json:"aggregate_hash"`
|
||||
Resources []SnapshotResource `json:"resources"`
|
||||
PayloadBytes uint64 `json:"payload_bytes"`
|
||||
SnapshotError string `json:"snapshot_error,omitempty"`
|
||||
}
|
||||
|
||||
// API exposes the Manager over HTTP. The routes match the RFC:
|
||||
//
|
||||
// GET /api/v0/context/sources
|
||||
// POST /api/v0/context/sources { path }
|
||||
// GET /api/v0/context/sources/{path}
|
||||
// DELETE /api/v0/context/sources/{path}
|
||||
// POST /api/v0/context/resync
|
||||
//
|
||||
// {path} is URL-encoded canonical path. Callers pass either the
|
||||
// canonical or original path; the handler canonicalizes before
|
||||
// matching.
|
||||
type API struct {
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
// NewAPI wraps the supplied Manager.
|
||||
func NewAPI(m *Manager) *API {
|
||||
return &API{manager: m}
|
||||
}
|
||||
|
||||
// Routes returns the chi handler for /api/v0/context/*. Mount
|
||||
// it at "/api/v0/context".
|
||||
func (a *API) Routes() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Route("/sources", func(r chi.Router) {
|
||||
r.Get("/", a.handleListSources)
|
||||
r.Post("/", a.handleAddSource)
|
||||
r.Get("/{path}", a.handleGetSource)
|
||||
r.Delete("/{path}", a.handleRemoveSource)
|
||||
})
|
||||
r.Post("/resync", a.handleResync)
|
||||
return r
|
||||
}
|
||||
|
||||
func (a *API) handleListSources(rw http.ResponseWriter, r *http.Request) {
|
||||
sources := a.manager.Sources()
|
||||
out := make([]SourceResponse, 0, len(sources))
|
||||
for _, s := range sources {
|
||||
out = append(out, SourceResponse(s))
|
||||
}
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, out)
|
||||
}
|
||||
|
||||
func (a *API) handleAddSource(rw http.ResponseWriter, r *http.Request) {
|
||||
var req SourceRequest
|
||||
if !httpapi.Read(r.Context(), rw, r, &req) {
|
||||
return
|
||||
}
|
||||
s, err := a.manager.AddSource(Source(req))
|
||||
if err != nil {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Could not add context source.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(r.Context(), rw, http.StatusCreated, SourceResponse(s))
|
||||
}
|
||||
|
||||
func (a *API) handleGetSource(rw http.ResponseWriter, r *http.Request) {
|
||||
raw := chi.URLParam(r, "path")
|
||||
decoded, err := url.PathUnescape(raw)
|
||||
if err != nil {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid context source path.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
canonical, ok := a.manager.HasSource(decoded)
|
||||
if !ok {
|
||||
httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Context source not found.",
|
||||
Detail: "No source registered for path " + strconv.Quote(decoded) + ".",
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, SourceResponse{Path: canonical})
|
||||
}
|
||||
|
||||
func (a *API) handleRemoveSource(rw http.ResponseWriter, r *http.Request) {
|
||||
raw := chi.URLParam(r, "path")
|
||||
decoded, err := url.PathUnescape(raw)
|
||||
if err != nil {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid context source path.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := a.manager.RemoveSource(decoded); err != nil {
|
||||
if errors.Is(err, ErrSourceNotFound) {
|
||||
httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Context source not found.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Could not remove context source.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (a *API) handleResync(rw http.ResponseWriter, r *http.Request) {
|
||||
snap, err := a.manager.Resync(r.Context())
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
status = http.StatusGatewayTimeout
|
||||
}
|
||||
httpapi.Write(r.Context(), rw, status, codersdk.Response{
|
||||
Message: "Resync failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, snapshotResponse(snap))
|
||||
}
|
||||
|
||||
// snapshotResponse converts a Snapshot to its on-wire form for
|
||||
// the resync endpoint. Payloads are omitted; the per-resource
|
||||
// payload bytes ship via the drpc PushContextState path.
|
||||
func snapshotResponse(s Snapshot) SnapshotResponse {
|
||||
out := SnapshotResponse{
|
||||
Version: s.Version,
|
||||
SchemaVersion: s.SchemaVersion,
|
||||
AggregateHash: hex.EncodeToString(s.AggregateHash[:]),
|
||||
Resources: make([]SnapshotResource, 0, len(s.Resources)),
|
||||
PayloadBytes: s.PayloadBytes,
|
||||
SnapshotError: s.SnapshotError,
|
||||
}
|
||||
for _, r := range s.Resources {
|
||||
out.Resources = append(out.Resources, SnapshotResource{
|
||||
ID: r.ID,
|
||||
Kind: r.Kind.String(),
|
||||
Source: r.Source,
|
||||
SourcePath: r.SourcePath,
|
||||
ContentHash: hex.EncodeToString(r.ContentHash[:]),
|
||||
SizeBytes: r.SizeBytes,
|
||||
Status: r.Status.String(),
|
||||
Error: r.Error,
|
||||
Description: r.Description,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
package agentcontext_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontext"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func newAPITestServer(t *testing.T, opts agentcontext.ManagerOptions) (*httptest.Server, *agentcontext.Manager) {
|
||||
t.Helper()
|
||||
m := newTestManager(t, opts)
|
||||
api := agentcontext.NewAPI(m)
|
||||
srv := httptest.NewServer(api.Routes())
|
||||
t.Cleanup(srv.Close)
|
||||
return srv, m
|
||||
}
|
||||
|
||||
// doRequest issues an HTTP request bounded by testutil.WaitShort
|
||||
// and returns the status code and response body. The response
|
||||
// body is closed before doRequest returns.
|
||||
func doRequest(t *testing.T, method, requrl string, body io.Reader) (int, []byte) {
|
||||
t.Helper()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(ctx, method, requrl, body)
|
||||
require.NoError(t, err)
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
res, err := http.DefaultClient.Do(req) //nolint:bodyclose // closed below.
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
bodyBytes, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
return res.StatusCode, bodyBytes
|
||||
}
|
||||
|
||||
func TestAPI_ListSourcesEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return t.TempDir() },
|
||||
})
|
||||
|
||||
status, body := doRequest(t, http.MethodGet, srv.URL+"/sources", nil)
|
||||
require.Equal(t, http.StatusOK, status)
|
||||
|
||||
var got []agentcontext.SourceResponse
|
||||
require.NoError(t, json.Unmarshal(body, &got))
|
||||
require.Empty(t, got)
|
||||
}
|
||||
|
||||
func TestAPI_AddAndListSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
wd := t.TempDir()
|
||||
src := t.TempDir()
|
||||
|
||||
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return wd },
|
||||
AllowedRoots: []string{wd, src},
|
||||
})
|
||||
|
||||
body, _ := json.Marshal(agentcontext.SourceRequest{Path: src})
|
||||
status, addBody := doRequest(t, http.MethodPost, srv.URL+"/sources", bytes.NewReader(body))
|
||||
require.Equal(t, http.StatusCreated, status)
|
||||
|
||||
var created agentcontext.SourceResponse
|
||||
require.NoError(t, json.Unmarshal(addBody, &created))
|
||||
require.Equal(t, src, created.Path)
|
||||
|
||||
// List should show the new source.
|
||||
listStatus, listBody := doRequest(t, http.MethodGet, srv.URL+"/sources", nil)
|
||||
require.Equal(t, http.StatusOK, listStatus)
|
||||
var list []agentcontext.SourceResponse
|
||||
require.NoError(t, json.Unmarshal(listBody, &list))
|
||||
require.Len(t, list, 1)
|
||||
require.Equal(t, src, list[0].Path)
|
||||
}
|
||||
|
||||
func TestAPI_AddSourceRejected(t *testing.T) {
|
||||
t.Parallel()
|
||||
wd := t.TempDir()
|
||||
outside := t.TempDir()
|
||||
|
||||
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return wd },
|
||||
AllowedRoots: []string{wd},
|
||||
})
|
||||
|
||||
body, _ := json.Marshal(agentcontext.SourceRequest{Path: outside})
|
||||
status, _ := doRequest(t, http.MethodPost, srv.URL+"/sources", bytes.NewReader(body))
|
||||
require.Equal(t, http.StatusBadRequest, status)
|
||||
}
|
||||
|
||||
func TestAPI_GetAndDeleteSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
wd := t.TempDir()
|
||||
src := t.TempDir()
|
||||
|
||||
srv, m := newAPITestServer(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return wd },
|
||||
AllowedRoots: []string{wd, src},
|
||||
})
|
||||
|
||||
_, err := m.AddSource(agentcontext.Source{Path: src})
|
||||
require.NoError(t, err)
|
||||
|
||||
status, body := doRequest(t, http.MethodGet, srv.URL+"/sources/"+url.PathEscape(src), nil)
|
||||
require.Equal(t, http.StatusOK, status)
|
||||
|
||||
var got agentcontext.SourceResponse
|
||||
require.NoError(t, json.Unmarshal(body, &got))
|
||||
require.Equal(t, src, got.Path)
|
||||
|
||||
delStatus, _ := doRequest(t, http.MethodDelete, srv.URL+"/sources/"+url.PathEscape(src), nil)
|
||||
require.Equal(t, http.StatusNoContent, delStatus)
|
||||
require.Empty(t, m.Sources())
|
||||
}
|
||||
|
||||
func TestAPI_GetSourceNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return t.TempDir() },
|
||||
})
|
||||
|
||||
status, _ := doRequest(t, http.MethodGet, srv.URL+"/sources/"+url.PathEscape("/never-added"), nil)
|
||||
require.Equal(t, http.StatusNotFound, status)
|
||||
}
|
||||
|
||||
func TestAPI_DeleteSourceNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return t.TempDir() },
|
||||
})
|
||||
|
||||
status, _ := doRequest(t, http.MethodDelete, srv.URL+"/sources/"+url.PathEscape("/never-added"), nil)
|
||||
require.Equal(t, http.StatusNotFound, status)
|
||||
}
|
||||
|
||||
func TestAPI_Resync(t *testing.T) {
|
||||
t.Parallel()
|
||||
wd := t.TempDir()
|
||||
mustWriteFile(t, filepath.Join(wd, "AGENTS.md"), "hello")
|
||||
|
||||
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return wd },
|
||||
})
|
||||
|
||||
status, body := doRequest(t, http.MethodPost, srv.URL+"/resync", nil)
|
||||
require.Equal(t, http.StatusOK, status)
|
||||
|
||||
var snap agentcontext.SnapshotResponse
|
||||
require.NoError(t, json.Unmarshal(body, &snap))
|
||||
require.Equal(t, uint64(1), snap.SchemaVersion)
|
||||
require.NotEmpty(t, snap.AggregateHash)
|
||||
require.Len(t, snap.Resources, 1)
|
||||
require.Equal(t, "instruction_file", snap.Resources[0].Kind)
|
||||
require.Equal(t, "ok", snap.Resources[0].Status)
|
||||
}
|
||||
|
||||
func TestAPI_AddSourceMalformedBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return t.TempDir() },
|
||||
})
|
||||
|
||||
status, _ := doRequest(t, http.MethodPost, srv.URL+"/sources", bytes.NewReader([]byte("{not json")))
|
||||
require.Equal(t, http.StatusBadRequest, status)
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
// Package agentcontext consolidates the agent-side plumbing that
|
||||
// resolves, watches, and pushes workspace context (instruction
|
||||
// files, skills, and MCP configuration) to coderd.
|
||||
//
|
||||
// This is the agent half of the design described in
|
||||
// "RFC: Workspace Context Sources for Coder Agents". It owns:
|
||||
//
|
||||
// - User-declared scan roots (Sources) layered on top of
|
||||
// built-in defaults.
|
||||
// - A resolver that classifies files under each scan root into
|
||||
// typed Resources (instruction files, skills, MCP configs,
|
||||
// MCP servers).
|
||||
// - A unified recursive fsnotify watcher that signals a
|
||||
// re-resolve when any recognized file changes.
|
||||
// - An HTTP API at /api/v0/context/sources for source CRUD
|
||||
// and /api/v0/context/resync for synchronous push barriers.
|
||||
// - A Pusher abstraction so the latest Snapshot can be shipped
|
||||
// to coderd without coupling this package to any particular
|
||||
// drpc client version.
|
||||
//
|
||||
// The package is purely additive: existing agent code paths
|
||||
// (agent/agentcontextconfig and agent/x/agentmcp) continue to
|
||||
// operate unchanged. Wiring the Manager into the agent's HTTP
|
||||
// router and the drpc client lives in a follow-up change.
|
||||
package agentcontext
|
||||
@@ -0,0 +1,480 @@
|
||||
package agentcontext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// CurrentSchemaVersion is the on-wire shape version. Bump
|
||||
// whenever the resource format changes in a way that requires
|
||||
// coderd-side awareness.
|
||||
const CurrentSchemaVersion uint64 = 1
|
||||
|
||||
// ManagerOptions configures a Manager. Zero values get sensible
|
||||
// defaults.
|
||||
type ManagerOptions struct {
|
||||
// Logger receives diagnostic messages. Required.
|
||||
Logger slog.Logger
|
||||
// Clock is the time source used for the watcher's
|
||||
// debounce timer. Optional; defaults to quartz.NewReal().
|
||||
Clock quartz.Clock
|
||||
// WorkingDir is evaluated on every resolve, mirroring the
|
||||
// existing agent convention. The result is used as a
|
||||
// scan root.
|
||||
WorkingDir func() string
|
||||
// BuiltinRoots are scan roots layered before user-added
|
||||
// sources. Typically: the working directory, ~/.coder,
|
||||
// ~/.coder/skills, .agents/skills,
|
||||
// ~/.claude/plugins/cache.
|
||||
BuiltinRoots []string
|
||||
// InitialSources seeds the Manager's source list at boot
|
||||
// time. Sources from CODER_AGENT_EXP_*_DIRS env vars or
|
||||
// startup scripts are layered here.
|
||||
InitialSources []Source
|
||||
// AllowedRoots restricts which paths may be added as
|
||||
// sources at runtime. Defaults to [~, ~/.coder, ~/.claude,
|
||||
// workingDir]. Empty disables validation.
|
||||
AllowedRoots []string
|
||||
// Resolver, when non-nil, replaces the default resolver.
|
||||
// Tests use this to inject MCP providers and tighten
|
||||
// caps.
|
||||
Resolver *Resolver
|
||||
// Debounce overrides the watcher's debounce window.
|
||||
Debounce time.Duration
|
||||
// SchemaVersion is the version stamped on each Snapshot.
|
||||
// Use CurrentSchemaVersion (the default) unless rolling
|
||||
// out a schema change.
|
||||
SchemaVersion uint64
|
||||
}
|
||||
|
||||
// Manager orchestrates source CRUD, resolution, watching, and
|
||||
// Pusher fan-out. Construct with NewManager; start its lifecycle
|
||||
// goroutines with Run; tear down with Close.
|
||||
type Manager struct {
|
||||
logger slog.Logger
|
||||
clock quartz.Clock
|
||||
workingDir func() string
|
||||
builtinRoots []string
|
||||
allowedRoots []string
|
||||
resolver *Resolver
|
||||
debounce time.Duration
|
||||
schemaVersion uint64
|
||||
|
||||
mu sync.Mutex
|
||||
sources []Source
|
||||
// sourceIndex maps canonical path -> position in sources
|
||||
// for O(1) lookups during AddSource / RemoveSource.
|
||||
sourceIndex map[string]int
|
||||
|
||||
// snapshot is the latest result of a resolver pass. It is
|
||||
// replaced atomically under mu.
|
||||
snapshot Snapshot
|
||||
// version monotonically increases per resolve pass.
|
||||
version uint64
|
||||
|
||||
// subscribers receive a non-blocking signal whenever the
|
||||
// snapshot changes. Subscribers must drain their channel
|
||||
// promptly; the Manager drops sends to full channels.
|
||||
subscribers map[chan struct{}]struct{}
|
||||
|
||||
// trigger fires when AddSource / RemoveSource / watcher
|
||||
// observe a change.
|
||||
trigger chan struct{}
|
||||
|
||||
// running tracks Run lifetime.
|
||||
running bool
|
||||
closed bool
|
||||
closedCh chan struct{}
|
||||
runDoneCh chan struct{}
|
||||
|
||||
watcher *Watcher
|
||||
}
|
||||
|
||||
// NewManager validates options, canonicalizes initial sources,
|
||||
// performs the first resolver pass synchronously, and returns
|
||||
// the resulting Manager. Run must be called separately to start
|
||||
// the watcher and re-resolve goroutine.
|
||||
func NewManager(opts ManagerOptions) (*Manager, error) {
|
||||
clock := opts.Clock
|
||||
if clock == nil {
|
||||
clock = quartz.NewReal()
|
||||
}
|
||||
debounce := opts.Debounce
|
||||
if debounce <= 0 {
|
||||
debounce = DefaultWatchDebounce
|
||||
}
|
||||
schemaVersion := opts.SchemaVersion
|
||||
if schemaVersion == 0 {
|
||||
schemaVersion = CurrentSchemaVersion
|
||||
}
|
||||
resolver := opts.Resolver
|
||||
if resolver == nil {
|
||||
resolver = &Resolver{}
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
logger: opts.Logger,
|
||||
clock: clock,
|
||||
workingDir: opts.WorkingDir,
|
||||
builtinRoots: append([]string(nil), opts.BuiltinRoots...),
|
||||
allowedRoots: append([]string(nil), opts.AllowedRoots...),
|
||||
resolver: resolver,
|
||||
debounce: debounce,
|
||||
schemaVersion: schemaVersion,
|
||||
sources: make([]Source, 0),
|
||||
sourceIndex: make(map[string]int),
|
||||
subscribers: make(map[chan struct{}]struct{}),
|
||||
trigger: make(chan struct{}, 1),
|
||||
closedCh: make(chan struct{}),
|
||||
runDoneCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
for _, s := range opts.InitialSources {
|
||||
canonical, err := CanonicalizePath(s.Path)
|
||||
if err != nil {
|
||||
// Initial sources may not exist yet at boot
|
||||
// time; log and skip rather than abort the
|
||||
// agent.
|
||||
m.logger.Warn(context.Background(),
|
||||
"agentcontext: skipping invalid initial source",
|
||||
slog.F("path", s.Path),
|
||||
slog.Error(err))
|
||||
continue
|
||||
}
|
||||
if _, ok := m.sourceIndex[canonical]; ok {
|
||||
continue
|
||||
}
|
||||
m.sourceIndex[canonical] = len(m.sources)
|
||||
m.sources = append(m.sources, Source{Path: canonical})
|
||||
}
|
||||
|
||||
// First snapshot is computed eagerly. The push protocol
|
||||
// requires a snapshot to be present before the agent signals
|
||||
// lifecycle = ready, so callers can rely on Snapshot() being
|
||||
// populated immediately after NewManager returns.
|
||||
m.resolveLocked()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Run starts the watcher and the re-resolve goroutine. Run
|
||||
// blocks until ctx is canceled or Close is called. It is safe
|
||||
// to call Run at most once per Manager.
|
||||
func (m *Manager) Run(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
if m.running {
|
||||
m.mu.Unlock()
|
||||
return xerrors.New("agentcontext: Manager.Run called more than once")
|
||||
}
|
||||
if m.closed {
|
||||
m.mu.Unlock()
|
||||
return xerrors.New("agentcontext: Manager already closed")
|
||||
}
|
||||
m.running = true
|
||||
m.mu.Unlock()
|
||||
|
||||
watcher, err := NewWatcher(WatcherOptions{
|
||||
Logger: m.logger.Named("watcher"),
|
||||
Clock: m.clock,
|
||||
Debounce: m.debounce,
|
||||
OnChange: m.signal,
|
||||
})
|
||||
if err != nil {
|
||||
// NewWatcher already falls back to degraded mode on
|
||||
// init failure, so an actual error here is
|
||||
// exceptional.
|
||||
return xerrors.Errorf("create watcher: %w", err)
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.watcher = watcher
|
||||
roots := m.scanRootsLocked()
|
||||
m.mu.Unlock()
|
||||
watcher.Sync(ctx, roots)
|
||||
|
||||
defer close(m.runDoneCh)
|
||||
defer watcher.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.closedCh:
|
||||
return nil
|
||||
case <-m.trigger:
|
||||
m.mu.Lock()
|
||||
roots := m.scanRootsLocked()
|
||||
m.mu.Unlock()
|
||||
watcher.Sync(ctx, roots)
|
||||
m.resolveAndBroadcast(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the Manager. Close is idempotent; subsequent
|
||||
// calls block until Run exits.
|
||||
func (m *Manager) Close() error {
|
||||
m.mu.Lock()
|
||||
if m.closed {
|
||||
running := m.running
|
||||
m.mu.Unlock()
|
||||
if running {
|
||||
<-m.runDoneCh
|
||||
}
|
||||
return nil
|
||||
}
|
||||
m.closed = true
|
||||
running := m.running
|
||||
close(m.closedCh)
|
||||
m.mu.Unlock()
|
||||
if running {
|
||||
<-m.runDoneCh
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sources returns a defensive copy of the current source list.
|
||||
// The returned slice is safe to mutate.
|
||||
func (m *Manager) Sources() []Source {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
out := make([]Source, len(m.sources))
|
||||
copy(out, m.sources)
|
||||
return out
|
||||
}
|
||||
|
||||
// HasSource reports whether path matches an existing source
|
||||
// after canonicalization. Returns the canonical path on
|
||||
// success.
|
||||
func (m *Manager) HasSource(path string) (canonical string, ok bool) {
|
||||
c, err := CanonicalizePath(path)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
_, ok = m.sourceIndex[c]
|
||||
return c, ok
|
||||
}
|
||||
|
||||
// AddSource adds a new source. The path is canonicalized and
|
||||
// validated against the AllowedRoots set. AddSource is
|
||||
// idempotent.
|
||||
func (m *Manager) AddSource(s Source) (Source, error) {
|
||||
canonical, err := CanonicalizePath(s.Path)
|
||||
if err != nil {
|
||||
return Source{}, xerrors.Errorf("canonicalize: %w", err)
|
||||
}
|
||||
if err := ValidateSourcePath(canonical, m.effectiveAllowedRoots()); err != nil {
|
||||
return Source{}, err
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
if _, ok := m.sourceIndex[canonical]; ok {
|
||||
out := m.sources[m.sourceIndex[canonical]]
|
||||
m.mu.Unlock()
|
||||
return out, nil
|
||||
}
|
||||
m.sourceIndex[canonical] = len(m.sources)
|
||||
m.sources = append(m.sources, Source{Path: canonical})
|
||||
m.mu.Unlock()
|
||||
|
||||
m.signal()
|
||||
return Source{Path: canonical}, nil
|
||||
}
|
||||
|
||||
// RemoveSource removes the source matching path. Path is
|
||||
// canonicalized before matching. Returns ErrSourceNotFound when
|
||||
// no such source exists.
|
||||
func (m *Manager) RemoveSource(path string) error {
|
||||
canonical, err := CanonicalizePath(path)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("canonicalize: %w", err)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
idx, ok := m.sourceIndex[canonical]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
return ErrSourceNotFound
|
||||
}
|
||||
// O(n) compaction is fine for the typical handful of
|
||||
// user-added sources.
|
||||
m.sources = append(m.sources[:idx], m.sources[idx+1:]...)
|
||||
delete(m.sourceIndex, canonical)
|
||||
for i := idx; i < len(m.sources); i++ {
|
||||
m.sourceIndex[m.sources[i].Path] = i
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.signal()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Snapshot returns the latest Snapshot. The returned value is
|
||||
// safe to share but shares the same Resources slice as the
|
||||
// internal state; callers must not mutate it.
|
||||
func (m *Manager) Snapshot() Snapshot {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.snapshot
|
||||
}
|
||||
|
||||
// SubscribeChanges returns a buffered channel that receives a
|
||||
// signal whenever the snapshot changes. The unsubscribe
|
||||
// callback is safe to call from any goroutine and is
|
||||
// idempotent.
|
||||
func (m *Manager) SubscribeChanges() (<-chan struct{}, func()) {
|
||||
ch := make(chan struct{}, 1)
|
||||
m.mu.Lock()
|
||||
m.subscribers[ch] = struct{}{}
|
||||
m.mu.Unlock()
|
||||
|
||||
var once sync.Once
|
||||
unsub := func() {
|
||||
once.Do(func() {
|
||||
m.mu.Lock()
|
||||
delete(m.subscribers, ch)
|
||||
m.mu.Unlock()
|
||||
// Don't close ch: readers may still be in flight.
|
||||
})
|
||||
}
|
||||
return ch, unsub
|
||||
}
|
||||
|
||||
// Resync forces an immediate re-resolve and returns the new
|
||||
// Snapshot. Resync is safe to call regardless of whether Run
|
||||
// is active; the work is done synchronously under the
|
||||
// Manager's mutex either way.
|
||||
func (m *Manager) Resync(ctx context.Context) (Snapshot, error) {
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
return m.Snapshot(), ctxErr
|
||||
}
|
||||
m.mu.Lock()
|
||||
if m.closed {
|
||||
m.mu.Unlock()
|
||||
return m.Snapshot(), ErrManagerClosed
|
||||
}
|
||||
m.resolveLocked()
|
||||
snap := m.snapshot
|
||||
subs := make([]chan struct{}, 0, len(m.subscribers))
|
||||
for ch := range m.subscribers {
|
||||
subs = append(subs, ch)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
for _, ch := range subs {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return snap, nil
|
||||
}
|
||||
|
||||
// signal triggers a re-resolve. Sends are non-blocking; the
|
||||
// trigger channel has a depth of 1, which coalesces bursts.
|
||||
func (m *Manager) signal() {
|
||||
select {
|
||||
case m.trigger <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// scanRootsLocked returns the list of ScanRoots to feed the
|
||||
// resolver and watcher. The Manager's mutex must be held.
|
||||
func (m *Manager) scanRootsLocked() []ScanRoot {
|
||||
out := make([]ScanRoot, 0, 1+len(m.builtinRoots)+len(m.sources))
|
||||
if m.workingDir != nil {
|
||||
if wd := strings.TrimSpace(m.workingDir()); wd != "" {
|
||||
out = append(out, ScanRoot{Path: wd})
|
||||
}
|
||||
}
|
||||
for _, r := range m.builtinRoots {
|
||||
canonical, err := CanonicalizePath(r)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, ScanRoot{Path: canonical})
|
||||
}
|
||||
for _, s := range m.sources {
|
||||
out = append(out, ScanRoot{Path: s.Path, UserSource: s.Path})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// effectiveAllowedRoots returns the AllowedRoots augmented with
|
||||
// a sensible fallback (~ and the working directory) when the
|
||||
// caller did not configure any.
|
||||
func (m *Manager) effectiveAllowedRoots() []string {
|
||||
if len(m.allowedRoots) > 0 {
|
||||
return append([]string{}, m.allowedRoots...)
|
||||
}
|
||||
roots := []string{"~"}
|
||||
if m.workingDir != nil {
|
||||
if wd := strings.TrimSpace(m.workingDir()); wd != "" {
|
||||
roots = append(roots, wd)
|
||||
}
|
||||
}
|
||||
return roots
|
||||
}
|
||||
|
||||
// resolveAndBroadcast computes a fresh snapshot and notifies
|
||||
// subscribers if the aggregate hash changed.
|
||||
func (m *Manager) resolveAndBroadcast(ctx context.Context) {
|
||||
m.mu.Lock()
|
||||
m.resolveLocked()
|
||||
subs := make([]chan struct{}, 0, len(m.subscribers))
|
||||
for ch := range m.subscribers {
|
||||
subs = append(subs, ch)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// The broadcast is unconditional: Resync waiters that
|
||||
// triggered the pass without an actual content change
|
||||
// still need to wake up. Subscribers compare snapshots via
|
||||
// AggregateHash if they want to filter.
|
||||
for _, ch := range subs {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
_ = ctx
|
||||
}
|
||||
|
||||
// resolveLocked runs the resolver and stamps the snapshot.
|
||||
// The Manager's mutex must be held.
|
||||
func (m *Manager) resolveLocked() {
|
||||
roots := m.scanRootsLocked()
|
||||
snap := m.resolver.Resolve(roots)
|
||||
m.version++
|
||||
snap.Version = m.version
|
||||
snap.SchemaVersion = m.schemaVersion
|
||||
// Surface watcher degradation as a snapshot-level error
|
||||
// when the resolver did not already emit one.
|
||||
if snap.SnapshotError == "" && m.watcher != nil {
|
||||
if d := m.watcher.Degraded(); d != "" {
|
||||
snap.SnapshotError = d
|
||||
}
|
||||
}
|
||||
// Ensure resources are stable-sorted by ID even when the
|
||||
// resolver did not run them through caps.
|
||||
sort.Slice(snap.Resources, func(i, j int) bool {
|
||||
return snap.Resources[i].ID < snap.Resources[j].ID
|
||||
})
|
||||
m.snapshot = snap
|
||||
}
|
||||
|
||||
// ErrSourceNotFound is returned by RemoveSource when the
|
||||
// requested path is not in the source list.
|
||||
var ErrSourceNotFound = xerrors.New("source not found")
|
||||
|
||||
// ErrManagerClosed is returned by methods called after Close.
|
||||
var ErrManagerClosed = xerrors.New("agentcontext: manager closed")
|
||||
@@ -0,0 +1,260 @@
|
||||
package agentcontext_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontext"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func newTestManager(t *testing.T, opts agentcontext.ManagerOptions) *agentcontext.Manager {
|
||||
t.Helper()
|
||||
opts.Logger = testutil.Logger(t).Named("agentcontext-test")
|
||||
m, err := agentcontext.NewManager(opts)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
return m
|
||||
}
|
||||
|
||||
func TestManager_InitialSnapshotIsPopulated(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "boot snapshot")
|
||||
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return dir },
|
||||
})
|
||||
|
||||
snap := m.Snapshot()
|
||||
require.Equal(t, uint64(1), snap.Version)
|
||||
require.Equal(t, agentcontext.CurrentSchemaVersion, snap.SchemaVersion)
|
||||
require.Len(t, snap.Resources, 1)
|
||||
}
|
||||
|
||||
func TestManager_AddSourceTriggersResolve(t *testing.T) {
|
||||
t.Parallel()
|
||||
wd := t.TempDir()
|
||||
src := t.TempDir()
|
||||
mustWriteFile(t, filepath.Join(src, "AGENTS.md"), "from source")
|
||||
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return wd },
|
||||
AllowedRoots: []string{wd, src},
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
go func() { _ = m.Run(ctx) }()
|
||||
|
||||
t.Cleanup(func() { _ = m.Close() })
|
||||
|
||||
// Subscribe before mutating so we observe the broadcast.
|
||||
ch, unsub := m.SubscribeChanges()
|
||||
defer unsub()
|
||||
|
||||
added, err := m.AddSource(agentcontext.Source{Path: src})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, src, added.Path)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatalf("expected a change broadcast after AddSource")
|
||||
}
|
||||
|
||||
snap := m.Snapshot()
|
||||
require.Greater(t, snap.Version, uint64(1))
|
||||
|
||||
found := false
|
||||
for _, r := range snap.Resources {
|
||||
if r.Kind == agentcontext.KindInstructionFile && r.SourcePath == src {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
require.True(t, found, "expected AGENTS.md attributed to the user source")
|
||||
}
|
||||
|
||||
func TestManager_AddSourceRejectsOutsideAllowedRoots(t *testing.T) {
|
||||
t.Parallel()
|
||||
wd := t.TempDir()
|
||||
outside := t.TempDir()
|
||||
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return wd },
|
||||
AllowedRoots: []string{wd},
|
||||
})
|
||||
|
||||
_, err := m.AddSource(agentcontext.Source{Path: outside})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestManager_AddSourceIsIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
wd := t.TempDir()
|
||||
src := t.TempDir()
|
||||
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return wd },
|
||||
AllowedRoots: []string{wd, src},
|
||||
})
|
||||
|
||||
added1, err := m.AddSource(agentcontext.Source{Path: src})
|
||||
require.NoError(t, err)
|
||||
added2, err := m.AddSource(agentcontext.Source{Path: src})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, added1.Path, added2.Path)
|
||||
|
||||
sources := m.Sources()
|
||||
require.Len(t, sources, 1)
|
||||
}
|
||||
|
||||
func TestManager_RemoveSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
wd := t.TempDir()
|
||||
src := t.TempDir()
|
||||
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return wd },
|
||||
AllowedRoots: []string{wd, src},
|
||||
})
|
||||
|
||||
_, err := m.AddSource(agentcontext.Source{Path: src})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, m.RemoveSource(src))
|
||||
require.Empty(t, m.Sources())
|
||||
|
||||
err = m.RemoveSource(src)
|
||||
require.ErrorIs(t, err, agentcontext.ErrSourceNotFound)
|
||||
}
|
||||
|
||||
func TestManager_HasSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
wd := t.TempDir()
|
||||
src := t.TempDir()
|
||||
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return wd },
|
||||
AllowedRoots: []string{wd, src},
|
||||
})
|
||||
|
||||
canonical, ok := m.HasSource(src)
|
||||
require.False(t, ok)
|
||||
require.Equal(t, src, canonical)
|
||||
|
||||
_, err := m.AddSource(agentcontext.Source{Path: src})
|
||||
require.NoError(t, err)
|
||||
|
||||
canonical, ok = m.HasSource(src)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, src, canonical)
|
||||
}
|
||||
|
||||
func TestManager_ResyncReturnsLatestSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
wd := t.TempDir()
|
||||
mustWriteFile(t, filepath.Join(wd, "AGENTS.md"), "first")
|
||||
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return wd },
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
runDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(runDone)
|
||||
_ = m.Run(ctx)
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
_ = m.Close()
|
||||
<-runDone
|
||||
})
|
||||
|
||||
// Mutate AGENTS.md and call Resync. The returned
|
||||
// snapshot must reflect the new content.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(wd, "AGENTS.md"), []byte("second content edit"), 0o600))
|
||||
|
||||
snap, err := m.Resync(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, snap.Resources, 1)
|
||||
require.Equal(t, "second content edit", string(snap.Resources[0].Payload))
|
||||
}
|
||||
|
||||
func TestManager_InitialSourcesSeeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
wd := t.TempDir()
|
||||
src := t.TempDir()
|
||||
mustWriteFile(t, filepath.Join(src, "AGENTS.md"), "from initial")
|
||||
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return wd },
|
||||
AllowedRoots: []string{wd, src},
|
||||
InitialSources: []agentcontext.Source{{Path: src}},
|
||||
})
|
||||
|
||||
sources := m.Sources()
|
||||
require.Len(t, sources, 1)
|
||||
require.Equal(t, src, sources[0].Path)
|
||||
|
||||
snap := m.Snapshot()
|
||||
require.Len(t, snap.Resources, 1)
|
||||
require.Equal(t, src, snap.Resources[0].SourcePath)
|
||||
}
|
||||
|
||||
func TestManager_CloseIsIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return t.TempDir() },
|
||||
})
|
||||
require.NoError(t, m.Close())
|
||||
require.NoError(t, m.Close())
|
||||
}
|
||||
|
||||
func TestManager_RunOnce(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return t.TempDir() },
|
||||
})
|
||||
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
|
||||
defer cancel()
|
||||
go func() { _ = m.Run(ctx) }()
|
||||
// Brief wait so Run has a chance to set running=true.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
err := m.Run(ctx)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "more than once")
|
||||
cancel()
|
||||
_ = m.Close()
|
||||
}
|
||||
|
||||
func TestManager_SubscribeBroadcastOnChange(t *testing.T) {
|
||||
t.Parallel()
|
||||
wd := t.TempDir()
|
||||
src := t.TempDir()
|
||||
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return wd },
|
||||
AllowedRoots: []string{wd, src},
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
go func() { _ = m.Run(ctx) }()
|
||||
|
||||
ch, unsub := m.SubscribeChanges()
|
||||
defer unsub()
|
||||
|
||||
_, err := m.AddSource(agentcontext.Source{Path: src})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("expected subscriber to be notified")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package agentcontext
|
||||
|
||||
// MCPProvider supplies the live MCP server portion of a
|
||||
// snapshot. Implementations typically wrap an existing MCP
|
||||
// manager (e.g. agent/x/agentmcp.Manager) and translate each
|
||||
// server's tool list into a KindMCPServer resource.
|
||||
//
|
||||
// The interface is intentionally minimal so the existing MCP
|
||||
// lifecycle code can be reused without refactoring; a follow-up
|
||||
// change absorbs the lifecycle into this package.
|
||||
type MCPProvider interface {
|
||||
// MCPResources returns one Resource per MCP server known
|
||||
// to the provider. Each Resource must:
|
||||
//
|
||||
// - Have Kind == KindMCPServer.
|
||||
// - Use the server name as Source.
|
||||
// - Populate ContentHash over the canonical encoding
|
||||
// of the tool list so changes flip the dirty bit.
|
||||
// - Carry a Description summarizing the server.
|
||||
//
|
||||
// Implementations should never block; the resolver calls
|
||||
// this on every re-resolve.
|
||||
MCPResources() []Resource
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
package agentcontext
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// CanonicalizePath produces the canonical form of a user-
|
||||
// supplied path. The result is absolute, has ~ expanded, has
|
||||
// path-traversal segments collapsed, and has symlinks resolved
|
||||
// when the target exists. The path is left lexically clean if
|
||||
// it does not yet exist (so adding a not-yet-created directory
|
||||
// remains possible).
|
||||
//
|
||||
// CanonicalizePath returns the original input when it is empty.
|
||||
func CanonicalizePath(raw string) (string, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", xerrors.New("path is empty")
|
||||
}
|
||||
|
||||
// Expand ~ and ~/ prefixes against the current user's home
|
||||
// directory. Other ~user forms are not supported on
|
||||
// purpose; the agent runs as a known user.
|
||||
if raw == "~" || strings.HasPrefix(raw, "~/") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("expand home dir: %w", err)
|
||||
}
|
||||
if raw == "~" {
|
||||
raw = home
|
||||
} else {
|
||||
raw = filepath.Join(home, raw[2:])
|
||||
}
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(raw) {
|
||||
// Fail closed: relative paths could mean different
|
||||
// things depending on the agent's working directory at
|
||||
// add-time, so require the caller to absolutize first.
|
||||
return "", xerrors.Errorf("path %q is not absolute", raw)
|
||||
}
|
||||
|
||||
cleaned := filepath.Clean(raw)
|
||||
if resolved, err := filepath.EvalSymlinks(cleaned); err == nil {
|
||||
return resolved, nil
|
||||
}
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
// ValidateSourcePath enforces the path-validation rules from
|
||||
// the RFC's Authorization section. It rejects:
|
||||
//
|
||||
// - Paths containing ".." segments after expansion.
|
||||
// - Paths resolving outside the supplied allowedRoots, unless
|
||||
// allowedRoots is empty (which disables the check).
|
||||
//
|
||||
// allowedRoots are canonicalized lazily; missing roots are
|
||||
// silently skipped so a workspace with no $HOME does not break
|
||||
// validation for project-relative roots.
|
||||
func ValidateSourcePath(canonical string, allowedRoots []string) error {
|
||||
if canonical == "" {
|
||||
return xerrors.New("path is empty")
|
||||
}
|
||||
// filepath.Clean drops "." but leaves ".." when no parent
|
||||
// is available. Reject defensively.
|
||||
for _, part := range strings.Split(canonical, string(os.PathSeparator)) {
|
||||
if part == ".." {
|
||||
return xerrors.Errorf("path %q contains parent traversal segments", canonical)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowedRoots) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build canonical, deduplicated allowed roots. Missing
|
||||
// roots (e.g. an unconfigured ~/.claude/) are skipped.
|
||||
roots := make([]string, 0, len(allowedRoots))
|
||||
seen := make(map[string]struct{}, len(allowedRoots))
|
||||
for _, raw := range allowedRoots {
|
||||
c, err := CanonicalizePath(raw)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[c]; ok {
|
||||
continue
|
||||
}
|
||||
seen[c] = struct{}{}
|
||||
roots = append(roots, c)
|
||||
}
|
||||
if len(roots) == 0 {
|
||||
// All configured roots were invalid; treat as "deny
|
||||
// everything" so misconfiguration fails closed.
|
||||
return xerrors.Errorf("path %q is not inside any allowed root", canonical)
|
||||
}
|
||||
|
||||
for _, root := range roots {
|
||||
if pathHasPrefix(canonical, root) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return xerrors.Errorf("path %q is not inside any allowed root", canonical)
|
||||
}
|
||||
|
||||
// pathHasPrefix reports whether path is equal to or a
|
||||
// descendant of prefix. Both arguments must already be clean,
|
||||
// absolute paths.
|
||||
func pathHasPrefix(path, prefix string) bool {
|
||||
if path == prefix {
|
||||
return true
|
||||
}
|
||||
withSep := prefix
|
||||
if !strings.HasSuffix(withSep, string(os.PathSeparator)) {
|
||||
withSep += string(os.PathSeparator)
|
||||
}
|
||||
return strings.HasPrefix(path, withSep)
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
package agentcontext_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontext"
|
||||
)
|
||||
|
||||
func TestCanonicalizePath_AbsoluteCleansAndResolves(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
got, err := agentcontext.CanonicalizePath(filepath.Join(dir, "a", "..", "b"))
|
||||
require.NoError(t, err)
|
||||
// Path does not exist; EvalSymlinks fails. Result is
|
||||
// lexically cleaned: filepath.Clean drops the "..".
|
||||
require.Equal(t, filepath.Join(dir, "b"), got)
|
||||
}
|
||||
|
||||
func TestCanonicalizePath_RelativeRejected(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := agentcontext.CanonicalizePath("relative/path")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
//nolint:paralleltest,tparallel // Uses t.Setenv.
|
||||
func TestCanonicalizePath_TildeExpansion(t *testing.T) {
|
||||
t.Setenv("HOME", "/tmp/home")
|
||||
got, err := agentcontext.CanonicalizePath("~/.coder")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "/tmp/home/.coder", got)
|
||||
}
|
||||
|
||||
//nolint:paralleltest,tparallel // Uses t.Setenv.
|
||||
func TestCanonicalizePath_BareTildeExpandsToHome(t *testing.T) {
|
||||
t.Setenv("HOME", "/tmp/home")
|
||||
got, err := agentcontext.CanonicalizePath("~")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "/tmp/home", got)
|
||||
}
|
||||
|
||||
func TestCanonicalizePath_FollowsSymlinks(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
realDir := filepath.Join(dir, "real")
|
||||
link := filepath.Join(dir, "link")
|
||||
require.NoError(t, os.MkdirAll(realDir, 0o755))
|
||||
require.NoError(t, os.Symlink(realDir, link))
|
||||
|
||||
got, err := agentcontext.CanonicalizePath(link)
|
||||
require.NoError(t, err)
|
||||
// On macOS the temp dir is itself symlinked; both realDir and got
|
||||
// pass through the same EvalSymlinks so they line up.
|
||||
want, err := filepath.EvalSymlinks(realDir)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestValidateSourcePath_RejectsParentSegments(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := agentcontext.ValidateSourcePath("/a/../b", []string{"/a"})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "parent traversal")
|
||||
}
|
||||
|
||||
func TestValidateSourcePath_AllowsInsideRoot(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
child := filepath.Join(dir, "child")
|
||||
require.NoError(t, os.MkdirAll(child, 0o755))
|
||||
|
||||
require.NoError(t, agentcontext.ValidateSourcePath(child, []string{dir}))
|
||||
require.NoError(t, agentcontext.ValidateSourcePath(dir, []string{dir}))
|
||||
}
|
||||
|
||||
func TestValidateSourcePath_RejectsOutsideRoot(t *testing.T) {
|
||||
t.Parallel()
|
||||
root := t.TempDir()
|
||||
other := t.TempDir()
|
||||
err := agentcontext.ValidateSourcePath(other, []string{root})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not inside any allowed root")
|
||||
}
|
||||
|
||||
func TestValidateSourcePath_EmptyAllowedRootsBypass(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.NoError(t, agentcontext.ValidateSourcePath("/anywhere", nil))
|
||||
}
|
||||
|
||||
func TestValidateSourcePath_InvalidRootsFailClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
// All allowed roots are relative and therefore invalid;
|
||||
// validation must fail closed.
|
||||
err := agentcontext.ValidateSourcePath("/anywhere", []string{"relative-only"})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateSourcePath_PathPrefixIsPathAware(t *testing.T) {
|
||||
t.Parallel()
|
||||
// "/a-prefix" is not inside "/a", even though it starts
|
||||
// with the same bytes.
|
||||
dir := t.TempDir()
|
||||
sibling := strings.TrimRight(dir, string(os.PathSeparator)) + "-sibling"
|
||||
require.NoError(t, os.MkdirAll(sibling, 0o755))
|
||||
t.Cleanup(func() { _ = os.RemoveAll(sibling) })
|
||||
err := agentcontext.ValidateSourcePath(sibling, []string{dir})
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
package agentcontext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
// PushRequest is the wire-format-independent payload the
|
||||
// Manager hands to a Pusher. It mirrors the protobuf
|
||||
// PushContextStateRequest message reserved in the RFC.
|
||||
//
|
||||
// Keeping the shape in plain Go lets this package compile
|
||||
// without bumping the drpc proto version. The follow-up
|
||||
// integration change can add a thin adapter that converts
|
||||
// PushRequest to proto and back.
|
||||
type PushRequest struct {
|
||||
Version uint64
|
||||
AggregateHash [32]byte
|
||||
Resources []Resource
|
||||
Initial bool
|
||||
SchemaVersion uint64
|
||||
SnapshotError string
|
||||
}
|
||||
|
||||
// PushResponse is the wire-format-independent return value of
|
||||
// a push.
|
||||
type PushResponse struct {
|
||||
Accepted bool
|
||||
}
|
||||
|
||||
// Pusher delivers snapshots to coderd. Concrete implementations
|
||||
// wrap a drpc client (proto v30 and later) or, in tests, a
|
||||
// recording in-memory fake.
|
||||
//
|
||||
// PushContextState must respect ctx cancellation; the Manager
|
||||
// retries on transient errors with backoff but stops on
|
||||
// ErrPushUnimplemented.
|
||||
type Pusher interface {
|
||||
PushContextState(ctx context.Context, req *PushRequest) (*PushResponse, error)
|
||||
}
|
||||
|
||||
// ErrPushUnimplemented signals that the coderd peer does not
|
||||
// implement PushContextState. RunPush stops pushing for the
|
||||
// remainder of the connection.
|
||||
var ErrPushUnimplemented = xerrors.New("agentcontext: PushContextState unimplemented")
|
||||
|
||||
// PushOptions parameterizes RunPush.
|
||||
type PushOptions struct {
|
||||
// Logger receives push success/failure diagnostics.
|
||||
Logger slog.Logger
|
||||
// InitialBackoff is the wait before the first retry.
|
||||
// Default 250ms.
|
||||
InitialBackoff time.Duration
|
||||
// MaxBackoff caps the retry wait. Default 30s.
|
||||
MaxBackoff time.Duration
|
||||
}
|
||||
|
||||
// RunPush ships the current snapshot to the Pusher, then ships
|
||||
// every subsequent snapshot whenever the Manager broadcasts a
|
||||
// change. RunPush returns when ctx is canceled, when the
|
||||
// Manager is closed, or when the Pusher signals
|
||||
// ErrPushUnimplemented.
|
||||
//
|
||||
// The first push is always sent with Initial=true so coderd can
|
||||
// distinguish a fresh boot from a drift event.
|
||||
func (m *Manager) RunPush(ctx context.Context, p Pusher, opts PushOptions) error {
|
||||
if p == nil {
|
||||
return xerrors.New("agentcontext: Pusher is required")
|
||||
}
|
||||
logger := opts.Logger
|
||||
initialBackoff := opts.InitialBackoff
|
||||
if initialBackoff <= 0 {
|
||||
initialBackoff = 250 * time.Millisecond
|
||||
}
|
||||
maxBackoff := opts.MaxBackoff
|
||||
if maxBackoff <= 0 {
|
||||
maxBackoff = 30 * time.Second
|
||||
}
|
||||
|
||||
changes, unsub := m.SubscribeChanges()
|
||||
defer unsub()
|
||||
|
||||
// First push uses the snapshot computed by NewManager.
|
||||
initial := true
|
||||
for {
|
||||
snap := m.Snapshot()
|
||||
req := snapshotToPushRequest(snap, initial)
|
||||
|
||||
err := pushWithRetry(ctx, p, req, initialBackoff, maxBackoff, logger)
|
||||
switch {
|
||||
case err == nil:
|
||||
initial = false
|
||||
case errors.Is(err, ErrPushUnimplemented):
|
||||
logger.Warn(ctx, "agentcontext: coderd peer does not implement PushContextState; stopping")
|
||||
return nil
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
|
||||
return ctx.Err()
|
||||
default:
|
||||
// Should be unreachable: pushWithRetry only
|
||||
// returns terminal errors. Log and continue.
|
||||
logger.Warn(ctx, "agentcontext: push terminated with non-retried error", slog.Error(err))
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.closedCh:
|
||||
return nil
|
||||
case _, ok := <-changes:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pushWithRetry retries transient errors with exponential
|
||||
// backoff capped at maxBackoff. The retry loop exits when:
|
||||
//
|
||||
// - ctx is canceled (returns ctx.Err()).
|
||||
// - The Pusher returns nil (success).
|
||||
// - The Pusher returns ErrPushUnimplemented (propagated).
|
||||
func pushWithRetry(
|
||||
ctx context.Context,
|
||||
p Pusher,
|
||||
req *PushRequest,
|
||||
initialBackoff, maxBackoff time.Duration,
|
||||
logger slog.Logger,
|
||||
) error {
|
||||
backoff := initialBackoff
|
||||
for {
|
||||
resp, err := p.PushContextState(ctx, req)
|
||||
if err == nil {
|
||||
if resp != nil && !resp.Accepted {
|
||||
// Out-of-order or replayed push. Do not
|
||||
// retry; the next change will redeliver
|
||||
// the snapshot with a higher version.
|
||||
logger.Debug(ctx, "agentcontext: push rejected, awaiting next change",
|
||||
slog.F("version", req.Version))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, ErrPushUnimplemented) {
|
||||
return err
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return err
|
||||
}
|
||||
logger.Warn(ctx, "agentcontext: push failed, retrying",
|
||||
slog.F("version", req.Version),
|
||||
slog.F("backoff", backoff),
|
||||
slog.Error(err))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// snapshotToPushRequest copies the Snapshot into the wire
|
||||
// representation. The Resources slice is reused; callers must
|
||||
// not mutate it.
|
||||
func snapshotToPushRequest(s Snapshot, initial bool) *PushRequest {
|
||||
return &PushRequest{
|
||||
Version: s.Version,
|
||||
AggregateHash: s.AggregateHash,
|
||||
Resources: s.Resources,
|
||||
Initial: initial,
|
||||
SchemaVersion: s.SchemaVersion,
|
||||
SnapshotError: s.SnapshotError,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
package agentcontext_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontext"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// fakePusher records every push and lets the test control the
|
||||
// returned response and error.
|
||||
type fakePusher struct {
|
||||
mu sync.Mutex
|
||||
requests []*agentcontext.PushRequest
|
||||
resp *agentcontext.PushResponse
|
||||
err error
|
||||
// errOnce is non-nil to simulate a single transient
|
||||
// failure followed by success.
|
||||
errOnce error
|
||||
signal chan struct{}
|
||||
}
|
||||
|
||||
func newFakePusher() *fakePusher {
|
||||
return &fakePusher{
|
||||
resp: &agentcontext.PushResponse{Accepted: true},
|
||||
signal: make(chan struct{}, 16),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *fakePusher) PushContextState(_ context.Context, req *agentcontext.PushRequest) (*agentcontext.PushResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.requests = append(p.requests, req)
|
||||
if p.errOnce != nil {
|
||||
err := p.errOnce
|
||||
p.errOnce = nil
|
||||
return nil, err
|
||||
}
|
||||
select {
|
||||
case p.signal <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return p.resp, p.err
|
||||
}
|
||||
|
||||
func (p *fakePusher) snapshot() []*agentcontext.PushRequest {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
out := make([]*agentcontext.PushRequest, len(p.requests))
|
||||
copy(out, p.requests)
|
||||
return out
|
||||
}
|
||||
|
||||
func TestRunPush_FirstPushIsInitial(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v1"), 0o600))
|
||||
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return dir },
|
||||
})
|
||||
|
||||
p := newFakePusher()
|
||||
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
|
||||
defer cancel()
|
||||
|
||||
pushDone := make(chan error, 1)
|
||||
go func() {
|
||||
pushDone <- m.RunPush(ctx, p, agentcontext.PushOptions{
|
||||
Logger: testutil.Logger(t).Named("push"),
|
||||
})
|
||||
}()
|
||||
|
||||
// Wait for the first push.
|
||||
select {
|
||||
case <-p.signal:
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatalf("expected initial push")
|
||||
}
|
||||
|
||||
requests := p.snapshot()
|
||||
require.Len(t, requests, 1)
|
||||
require.True(t, requests[0].Initial, "first push must be initial")
|
||||
require.Equal(t, uint64(1), requests[0].Version)
|
||||
|
||||
cancel()
|
||||
require.ErrorIs(t, <-pushDone, context.Canceled)
|
||||
}
|
||||
|
||||
func TestRunPush_SubsequentPushOnChange(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v1"), 0o600))
|
||||
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return dir },
|
||||
})
|
||||
|
||||
p := newFakePusher()
|
||||
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
|
||||
defer cancel()
|
||||
|
||||
pushDone := make(chan error, 1)
|
||||
go func() {
|
||||
pushDone <- m.RunPush(ctx, p, agentcontext.PushOptions{
|
||||
Logger: testutil.Logger(t).Named("push"),
|
||||
})
|
||||
}()
|
||||
|
||||
// Initial push.
|
||||
<-p.signal
|
||||
|
||||
// Trigger a resync via Resync.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v2"), 0o600))
|
||||
_, err := m.Resync(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second push.
|
||||
select {
|
||||
case <-p.signal:
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatalf("expected second push after resync")
|
||||
}
|
||||
|
||||
requests := p.snapshot()
|
||||
require.GreaterOrEqual(t, len(requests), 2)
|
||||
require.False(t, requests[1].Initial, "subsequent pushes must not be Initial")
|
||||
|
||||
cancel()
|
||||
require.ErrorIs(t, <-pushDone, context.Canceled)
|
||||
}
|
||||
|
||||
func TestRunPush_StopsOnUnimplemented(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return t.TempDir() },
|
||||
})
|
||||
|
||||
p := newFakePusher()
|
||||
p.err = agentcontext.ErrPushUnimplemented
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
err := m.RunPush(ctx, p, agentcontext.PushOptions{
|
||||
Logger: testutil.Logger(t).Named("push"),
|
||||
})
|
||||
require.NoError(t, err, "Unimplemented must stop the loop cleanly")
|
||||
}
|
||||
|
||||
func TestRunPush_RetriesTransientError(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return t.TempDir() },
|
||||
})
|
||||
|
||||
p := newFakePusher()
|
||||
p.errOnce = xerrors.New("transient")
|
||||
|
||||
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
|
||||
defer cancel()
|
||||
pushDone := make(chan error, 1)
|
||||
go func() {
|
||||
pushDone <- m.RunPush(ctx, p, agentcontext.PushOptions{
|
||||
Logger: testutil.Logger(t).Named("push"),
|
||||
InitialBackoff: 10 * time.Millisecond,
|
||||
})
|
||||
}()
|
||||
|
||||
// First push hits transient, second succeeds.
|
||||
select {
|
||||
case <-p.signal:
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatalf("expected push after transient error")
|
||||
}
|
||||
require.GreaterOrEqual(t, len(p.snapshot()), 2)
|
||||
|
||||
cancel()
|
||||
<-pushDone
|
||||
}
|
||||
|
||||
func TestRunPush_NilPusherErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := newTestManager(t, agentcontext.ManagerOptions{
|
||||
WorkingDir: func() string { return t.TempDir() },
|
||||
})
|
||||
err := m.RunPush(context.Background(), nil, agentcontext.PushOptions{
|
||||
Logger: testutil.Logger(t).Named("push"),
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -0,0 +1,613 @@
|
||||
package agentcontext
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// Default caps. Copied from the RFC. The Manager exposes
|
||||
// overrides via Options.
|
||||
const (
|
||||
// DefaultMaxResourceBytes is the per-resource payload cap.
|
||||
// Resources whose payload exceeds this size are emitted
|
||||
// with Status == StatusOversize and an empty Payload.
|
||||
DefaultMaxResourceBytes = 64 * 1024
|
||||
// DefaultMaxSnapshotBytes is the aggregate payload cap.
|
||||
// Resources past this cap are emitted with Status ==
|
||||
// StatusExcluded.
|
||||
DefaultMaxSnapshotBytes = 2 * 1024 * 1024
|
||||
// DefaultMaxResources is the resource count cap. Resources
|
||||
// past this cap are emitted with Status == StatusExcluded.
|
||||
DefaultMaxResources = 500
|
||||
// DefaultMaxScanDepth bounds how deep the recursive walk
|
||||
// descends from each scan root. The default avoids runaway
|
||||
// scans in node_modules / vendor / .git trees while still
|
||||
// covering realistic monorepo layouts.
|
||||
DefaultMaxScanDepth = 8
|
||||
)
|
||||
|
||||
// File-name conventions recognized by the v1 resolver.
|
||||
var (
|
||||
// instructionFileNames are picked up from any scan root.
|
||||
// Matching is case-insensitive on the basename.
|
||||
instructionFileNames = []string{
|
||||
"AGENTS.md",
|
||||
"CLAUDE.md",
|
||||
".cursorrules",
|
||||
}
|
||||
// mcpConfigFileName is recognized at any depth under a
|
||||
// scan root.
|
||||
mcpConfigFileName = ".mcp.json"
|
||||
// skillMetaFileName is the file inside a skill directory
|
||||
// that carries the skill front-matter.
|
||||
skillMetaFileName = "SKILL.md"
|
||||
)
|
||||
|
||||
// skipDirNames are directory basenames that the recursive walk
|
||||
// never descends into. The list mirrors what most language
|
||||
// tool-chains treat as opaque.
|
||||
var skipDirNames = map[string]struct{}{
|
||||
".git": {},
|
||||
".hg": {},
|
||||
".svn": {},
|
||||
"node_modules": {},
|
||||
"vendor": {},
|
||||
"target": {},
|
||||
"dist": {},
|
||||
"build": {},
|
||||
".venv": {},
|
||||
"__pycache__": {},
|
||||
}
|
||||
|
||||
// skillsParentNames are directory basenames that signal a
|
||||
// skills container; their immediate children are scanned for
|
||||
// SKILL.md files.
|
||||
var skillsParentNames = map[string]struct{}{
|
||||
"skills": {},
|
||||
".agents": {}, // covers ".agents/skills/<name>"
|
||||
"agents": {},
|
||||
"plugins": {}, // claude code plugin cache layout
|
||||
"cache": {},
|
||||
".coder": {},
|
||||
".claude": {},
|
||||
"skills-dir": {},
|
||||
}
|
||||
|
||||
// recognizedInstructionFile reports whether name is one of the
|
||||
// instruction-file conventions, case-insensitively.
|
||||
func recognizedInstructionFile(name string) bool {
|
||||
for _, candidate := range instructionFileNames {
|
||||
if strings.EqualFold(name, candidate) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Resolver walks one or more scan roots and produces a snapshot
|
||||
// of every recognized resource it finds. The Resolver is
|
||||
// stateless; the Manager owns the scan-root list and orchestrates
|
||||
// successive resolves.
|
||||
type Resolver struct {
|
||||
// MaxResourceBytes caps the per-resource payload size. Use
|
||||
// DefaultMaxResourceBytes if zero.
|
||||
MaxResourceBytes uint64
|
||||
// MaxSnapshotBytes caps the aggregate payload size. Use
|
||||
// DefaultMaxSnapshotBytes if zero.
|
||||
MaxSnapshotBytes uint64
|
||||
// MaxResources caps the resource count. Use
|
||||
// DefaultMaxResources if zero.
|
||||
MaxResources int
|
||||
// MaxDepth caps the directory walk depth. Use
|
||||
// DefaultMaxScanDepth if zero.
|
||||
MaxDepth int
|
||||
// MCP, when non-nil, is consulted after the filesystem
|
||||
// pass and contributes any KindMCPServer resources for
|
||||
// live MCP servers.
|
||||
MCP MCPProvider
|
||||
}
|
||||
|
||||
// ScanRoot describes a single directory or file the resolver
|
||||
// should examine.
|
||||
type ScanRoot struct {
|
||||
// Path is the absolute path. Symlinks should already be
|
||||
// resolved.
|
||||
Path string
|
||||
// UserSource is the canonical source path the user
|
||||
// declared, when this root came from a user-added Source.
|
||||
// Empty for built-in roots.
|
||||
UserSource string
|
||||
}
|
||||
|
||||
// Resolve walks the supplied scan roots and returns a Snapshot.
|
||||
// The version and schemaVersion fields are stamped by the
|
||||
// caller; Resolve fills everything else.
|
||||
func (r *Resolver) Resolve(roots []ScanRoot) Snapshot {
|
||||
res := r.normalize()
|
||||
resources, snapErrs := res.walk(roots)
|
||||
resources = res.applyCaps(resources)
|
||||
|
||||
// Append MCP server resources after the filesystem caps
|
||||
// are applied so a runaway MCP server cannot crowd out
|
||||
// instruction files.
|
||||
if r.MCP != nil {
|
||||
mcp := r.MCP.MCPResources()
|
||||
resources = append(resources, mcp...)
|
||||
// MCP resources may push the aggregate over the cap.
|
||||
// Re-apply count and size limits to MCP entries only.
|
||||
resources, snapErrs = res.applyMCPCaps(resources, snapErrs)
|
||||
}
|
||||
|
||||
// Deterministic order by ID for stable IDs and hashes.
|
||||
sort.Slice(resources, func(i, j int) bool {
|
||||
return resources[i].ID < resources[j].ID
|
||||
})
|
||||
|
||||
var payloadBytes uint64
|
||||
for _, r := range resources {
|
||||
payloadBytes += uint64(len(r.Payload))
|
||||
}
|
||||
|
||||
hash := ComputeAggregateHash(resources)
|
||||
|
||||
snap := Snapshot{
|
||||
Resources: resources,
|
||||
AggregateHash: hash,
|
||||
PayloadBytes: payloadBytes,
|
||||
}
|
||||
if len(snapErrs) > 0 {
|
||||
// Pick the most severe single error. Today every
|
||||
// snapshot-level problem is "warning equivalent" so
|
||||
// the first one wins; the design reserves the field
|
||||
// for a singular message.
|
||||
snap.SnapshotError = snapErrs[0]
|
||||
}
|
||||
return snap
|
||||
}
|
||||
|
||||
func (r *Resolver) normalize() *Resolver {
|
||||
out := *r
|
||||
if out.MaxResourceBytes == 0 {
|
||||
out.MaxResourceBytes = DefaultMaxResourceBytes
|
||||
}
|
||||
if out.MaxSnapshotBytes == 0 {
|
||||
out.MaxSnapshotBytes = DefaultMaxSnapshotBytes
|
||||
}
|
||||
if out.MaxResources == 0 {
|
||||
out.MaxResources = DefaultMaxResources
|
||||
}
|
||||
if out.MaxDepth == 0 {
|
||||
out.MaxDepth = DefaultMaxScanDepth
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
// walk traverses every scan root and produces an unordered
|
||||
// resource list. Aggregate caps are applied separately.
|
||||
func (r *Resolver) walk(roots []ScanRoot) (resources []Resource, snapErrs []string) {
|
||||
// Dedup roots by canonical path. The first occurrence
|
||||
// wins so user-added roots that overlap with a built-in
|
||||
// root attribute resources to the built-in.
|
||||
seenRoot := make(map[string]struct{}, len(roots))
|
||||
dedup := make([]ScanRoot, 0, len(roots))
|
||||
for _, root := range roots {
|
||||
if root.Path == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seenRoot[root.Path]; ok {
|
||||
continue
|
||||
}
|
||||
seenRoot[root.Path] = struct{}{}
|
||||
dedup = append(dedup, root)
|
||||
}
|
||||
|
||||
// Deduplicate resources across roots by ID. Without this,
|
||||
// a built-in root and a user root that both cover the
|
||||
// same project tree would double-count AGENTS.md.
|
||||
seenID := make(map[string]struct{})
|
||||
|
||||
for _, root := range dedup {
|
||||
info, err := os.Stat(root.Path)
|
||||
if err != nil {
|
||||
// Missing roots silently fall through. The user
|
||||
// either added a path that does not exist yet or
|
||||
// removed it later. The watcher will surface
|
||||
// re-creation as a change event.
|
||||
continue
|
||||
}
|
||||
if !info.IsDir() {
|
||||
// Single-file roots are classified directly.
|
||||
if res, ok := r.classifyFile(root.Path, info, root.UserSource); ok {
|
||||
if _, dup := seenID[res.ID]; !dup {
|
||||
seenID[res.ID] = struct{}{}
|
||||
resources = append(resources, res)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
walkErr := r.walkDir(root, &resources, seenID)
|
||||
if walkErr != nil {
|
||||
snapErrs = append(snapErrs, fmt.Sprintf("walk %q: %s", root.Path, walkErr))
|
||||
}
|
||||
}
|
||||
return resources, snapErrs
|
||||
}
|
||||
|
||||
// walkDir performs the recursive descent for a single scan
|
||||
// directory. It honors r.MaxDepth and skipDirNames.
|
||||
func (r *Resolver) walkDir(root ScanRoot, out *[]Resource, seenID map[string]struct{}) error {
|
||||
rootDepth := strings.Count(filepath.Clean(root.Path), string(os.PathSeparator))
|
||||
maxDepth := rootDepth + r.MaxDepth
|
||||
|
||||
return filepath.WalkDir(root.Path, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
// Surface the error as Unreadable when we can
|
||||
// associate it with a single recognized file;
|
||||
// otherwise let the walk continue.
|
||||
if d != nil && !d.IsDir() {
|
||||
if recognizedInstructionFile(d.Name()) ||
|
||||
d.Name() == mcpConfigFileName ||
|
||||
d.Name() == skillMetaFileName {
|
||||
res := Resource{
|
||||
ID: resourceID(KindInstructionFile, path),
|
||||
Kind: KindInstructionFile,
|
||||
Source: path,
|
||||
SizeBytes: 0,
|
||||
Status: StatusUnreadable,
|
||||
Error: err.Error(),
|
||||
SourcePath: root.UserSource,
|
||||
}
|
||||
if _, dup := seenID[res.ID]; !dup {
|
||||
seenID[res.ID] = struct{}{}
|
||||
*out = append(*out, res)
|
||||
}
|
||||
}
|
||||
}
|
||||
if errors.Is(err, fs.ErrPermission) {
|
||||
// Permission errors on a directory: skip the
|
||||
// subtree but continue walking siblings.
|
||||
if d != nil && d.IsDir() {
|
||||
return fs.SkipDir
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if d.IsDir() {
|
||||
if strings.Count(path, string(os.PathSeparator)) > maxDepth {
|
||||
return fs.SkipDir
|
||||
}
|
||||
if _, skip := skipDirNames[d.Name()]; skip && path != root.Path {
|
||||
return fs.SkipDir
|
||||
}
|
||||
// If we are entering a "skills container"
|
||||
// directory (".agents/skills", "~/.coder/skills",
|
||||
// "plugins/<plugin>/skills"), eagerly emit skill
|
||||
// resources for its immediate subdirectories.
|
||||
if isSkillsContainer(path) {
|
||||
r.emitSkillsFromContainer(path, root, out, seenID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Regular file.
|
||||
info, statErr := d.Info()
|
||||
if statErr != nil {
|
||||
return nil
|
||||
}
|
||||
if res, ok := r.classifyFile(path, info, root.UserSource); ok {
|
||||
if _, dup := seenID[res.ID]; dup {
|
||||
return nil
|
||||
}
|
||||
seenID[res.ID] = struct{}{}
|
||||
*out = append(*out, res)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// classifyFile inspects a single file path and produces a
|
||||
// Resource when the basename matches a recognized convention.
|
||||
func (r *Resolver) classifyFile(path string, info fs.FileInfo, userSource string) (Resource, bool) {
|
||||
name := info.Name()
|
||||
switch {
|
||||
case recognizedInstructionFile(name):
|
||||
return r.readInstructionFile(path, info, userSource), true
|
||||
case name == mcpConfigFileName:
|
||||
return r.readMCPConfig(path, info, userSource), true
|
||||
case name == skillMetaFileName:
|
||||
// SKILL.md outside a skills container is still a
|
||||
// valid skill if its parent directory name matches
|
||||
// the front-matter name. emitSkillsFromContainer
|
||||
// already handles the common case; here we cover
|
||||
// "user adds a single SKILL.md file as a source".
|
||||
res, ok := r.readSkillMeta(path, info, userSource)
|
||||
return res, ok
|
||||
default:
|
||||
return Resource{}, false
|
||||
}
|
||||
}
|
||||
|
||||
// readInstructionFile reads an instruction file and produces a
|
||||
// KindInstructionFile resource. The file is read into memory
|
||||
// with the per-resource cap applied.
|
||||
func (r *Resolver) readInstructionFile(path string, info fs.FileInfo, userSource string) Resource {
|
||||
res := Resource{
|
||||
ID: resourceID(KindInstructionFile, path),
|
||||
Kind: KindInstructionFile,
|
||||
Source: path,
|
||||
SizeBytes: safeUint64(info.Size()),
|
||||
SourcePath: userSource,
|
||||
}
|
||||
if safeUint64(info.Size()) > r.MaxResourceBytes {
|
||||
res.Status = StatusOversize
|
||||
res.Error = fmt.Sprintf("file size %d exceeds per-resource cap of %d bytes", info.Size(), r.MaxResourceBytes)
|
||||
// Still hash the (capped) content so a fix is
|
||||
// detectable.
|
||||
if data, err := readFileCapped(path, safeInt64(r.MaxResourceBytes)); err == nil {
|
||||
res.ContentHash = sha256.Sum256(data)
|
||||
}
|
||||
return res
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
res.Status = StatusUnreadable
|
||||
res.Error = err.Error()
|
||||
return res
|
||||
}
|
||||
res.Payload = data
|
||||
res.ContentHash = sha256.Sum256(data)
|
||||
// Description is just the first non-empty line.
|
||||
res.Description = firstLine(string(data))
|
||||
return res
|
||||
}
|
||||
|
||||
// readMCPConfig reads a .mcp.json file and produces a
|
||||
// KindMCPConfig resource. Parsing is left to consumers; the
|
||||
// resolver only enforces JSON shape lightly via size and Unix
|
||||
// newline conversion. Future work: detect malformed JSON and
|
||||
// surface StatusInvalid.
|
||||
func (r *Resolver) readMCPConfig(path string, info fs.FileInfo, userSource string) Resource {
|
||||
res := Resource{
|
||||
ID: resourceID(KindMCPConfig, path),
|
||||
Kind: KindMCPConfig,
|
||||
Source: path,
|
||||
SizeBytes: safeUint64(info.Size()),
|
||||
SourcePath: userSource,
|
||||
}
|
||||
if safeUint64(info.Size()) > r.MaxResourceBytes {
|
||||
res.Status = StatusOversize
|
||||
res.Error = fmt.Sprintf("file size %d exceeds per-resource cap of %d bytes", info.Size(), r.MaxResourceBytes)
|
||||
if data, err := readFileCapped(path, safeInt64(r.MaxResourceBytes)); err == nil {
|
||||
res.ContentHash = sha256.Sum256(data)
|
||||
}
|
||||
return res
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
res.Status = StatusUnreadable
|
||||
res.Error = err.Error()
|
||||
return res
|
||||
}
|
||||
res.Payload = data
|
||||
res.ContentHash = sha256.Sum256(data)
|
||||
return res
|
||||
}
|
||||
|
||||
// readSkillMeta reads a SKILL.md file, parses its front-matter,
|
||||
// and emits a KindSkill resource. The name encoded in the
|
||||
// front-matter must match the parent directory's basename to
|
||||
// be considered valid; otherwise Status is StatusInvalid.
|
||||
func (r *Resolver) readSkillMeta(path string, info fs.FileInfo, userSource string) (Resource, bool) {
|
||||
parent := filepath.Base(filepath.Dir(path))
|
||||
res := Resource{
|
||||
ID: resourceID(KindSkill, filepath.Dir(path)),
|
||||
Kind: KindSkill,
|
||||
Source: filepath.Dir(path),
|
||||
SizeBytes: safeUint64(info.Size()),
|
||||
SourcePath: userSource,
|
||||
}
|
||||
if safeUint64(info.Size()) > r.MaxResourceBytes {
|
||||
res.Status = StatusOversize
|
||||
res.Error = fmt.Sprintf("file size %d exceeds per-resource cap of %d bytes", info.Size(), r.MaxResourceBytes)
|
||||
return res, true
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
res.Status = StatusUnreadable
|
||||
res.Error = err.Error()
|
||||
return res, true
|
||||
}
|
||||
res.ContentHash = sha256.Sum256(data)
|
||||
name, description, _, err := workspacesdk.ParseSkillFrontmatter(string(data))
|
||||
if err != nil {
|
||||
res.Status = StatusInvalid
|
||||
res.Error = err.Error()
|
||||
return res, true
|
||||
}
|
||||
if name != parent {
|
||||
res.Status = StatusInvalid
|
||||
res.Error = fmt.Sprintf("front-matter name %q does not match directory %q", name, parent)
|
||||
return res, true
|
||||
}
|
||||
if !workspacesdk.SkillNamePattern.MatchString(name) {
|
||||
res.Status = StatusInvalid
|
||||
res.Error = fmt.Sprintf("skill name %q is not kebab-case", name)
|
||||
return res, true
|
||||
}
|
||||
res.Description = description
|
||||
res.Payload = data
|
||||
return res, true
|
||||
}
|
||||
|
||||
// emitSkillsFromContainer scans the immediate children of a
|
||||
// recognized skills-container directory and emits one Skill
|
||||
// resource per subdirectory whose SKILL.md parses cleanly.
|
||||
func (r *Resolver) emitSkillsFromContainer(container string, root ScanRoot, out *[]Resource, seenID map[string]struct{}) {
|
||||
entries, err := os.ReadDir(container)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
meta := filepath.Join(container, e.Name(), skillMetaFileName)
|
||||
info, err := os.Stat(meta)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res, ok := r.readSkillMeta(meta, info, root.UserSource)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, dup := seenID[res.ID]; dup {
|
||||
continue
|
||||
}
|
||||
seenID[res.ID] = struct{}{}
|
||||
*out = append(*out, res)
|
||||
}
|
||||
}
|
||||
|
||||
// applyCaps enforces the resource-count cap and aggregate
|
||||
// payload cap. Resources past either cap have their Status set
|
||||
// to StatusExcluded and their Payload cleared.
|
||||
func (r *Resolver) applyCaps(resources []Resource) []Resource {
|
||||
// Stable sort by (Kind asc, Source asc) so excluded
|
||||
// resources are deterministic.
|
||||
sort.SliceStable(resources, func(i, j int) bool {
|
||||
if resources[i].Kind != resources[j].Kind {
|
||||
return resources[i].Kind < resources[j].Kind
|
||||
}
|
||||
return resources[i].Source < resources[j].Source
|
||||
})
|
||||
|
||||
var total uint64
|
||||
for i := range resources {
|
||||
if i >= r.MaxResources {
|
||||
resources[i] = excluded(resources[i],
|
||||
fmt.Sprintf("dropped to fit %d-resource snapshot count cap", r.MaxResources))
|
||||
continue
|
||||
}
|
||||
if resources[i].Status != StatusOK {
|
||||
continue
|
||||
}
|
||||
size := uint64(len(resources[i].Payload))
|
||||
if total+size > r.MaxSnapshotBytes {
|
||||
resources[i] = excluded(resources[i],
|
||||
fmt.Sprintf("dropped to fit %d-byte aggregate cap", r.MaxSnapshotBytes))
|
||||
continue
|
||||
}
|
||||
total += size
|
||||
}
|
||||
return resources
|
||||
}
|
||||
|
||||
// applyMCPCaps re-applies the count cap after MCP resources are
|
||||
// appended. MCP payloads are typically small JSON descriptors,
|
||||
// so we treat the aggregate budget as already consumed by the
|
||||
// filesystem pass.
|
||||
func (r *Resolver) applyMCPCaps(resources []Resource, snapErrs []string) ([]Resource, []string) {
|
||||
if len(resources) <= r.MaxResources {
|
||||
return resources, snapErrs
|
||||
}
|
||||
for i := r.MaxResources; i < len(resources); i++ {
|
||||
resources[i] = excluded(resources[i],
|
||||
fmt.Sprintf("dropped to fit %d-resource snapshot count cap", r.MaxResources))
|
||||
}
|
||||
snapErrs = append(snapErrs, fmt.Sprintf("snapshot exceeds %d-resource count cap", r.MaxResources))
|
||||
return resources, snapErrs
|
||||
}
|
||||
|
||||
// excluded mutates and returns the supplied resource with the
|
||||
// StatusExcluded outcome.
|
||||
func excluded(r Resource, reason string) Resource {
|
||||
r.Status = StatusExcluded
|
||||
r.Error = reason
|
||||
r.Payload = nil
|
||||
return r
|
||||
}
|
||||
|
||||
// isSkillsContainer reports whether dir is a recognized skills
|
||||
// container directory whose immediate children carry SKILL.md
|
||||
// files.
|
||||
func isSkillsContainer(dir string) bool {
|
||||
base := filepath.Base(dir)
|
||||
_, ok := skillsParentNames[base]
|
||||
if ok && base == "skills" {
|
||||
return true
|
||||
}
|
||||
// "<x>/skills" form (e.g. ".agents/skills", "plugins/foo/skills").
|
||||
if strings.HasSuffix(filepath.ToSlash(dir), "/skills") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// resourceID builds a stable resource ID. Kind plus canonical
|
||||
// source path is enough; sources never collide across kinds for
|
||||
// v1 because each kind owns a distinct file-name pattern.
|
||||
func resourceID(kind ResourceKind, source string) string {
|
||||
return kind.String() + ":" + source
|
||||
}
|
||||
|
||||
// readFileCapped reads up to maxBytes from path. It returns the
|
||||
// truncated payload on success.
|
||||
func readFileCapped(path string, maxBytes int64) ([]byte, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
return io.ReadAll(io.LimitReader(f, maxBytes))
|
||||
}
|
||||
|
||||
// firstLine returns the first non-empty trimmed line of s, used
|
||||
// as a short description fallback.
|
||||
func firstLine(s string) string {
|
||||
for _, line := range strings.Split(s, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
// Strip leading markdown heading markers for prettier
|
||||
// descriptions.
|
||||
return strings.TrimSpace(headingPrefixRegex.ReplaceAllString(line, ""))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var headingPrefixRegex = regexp.MustCompile(`^#+\s*`)
|
||||
|
||||
// safeUint64 converts a non-negative int64 to uint64. Negative
|
||||
// inputs are clamped to 0, which is safe for the size-tracking
|
||||
// fields that use it; a negative os.FileInfo size is pathological
|
||||
// and never indicates real content.
|
||||
func safeUint64(n int64) uint64 {
|
||||
if n < 0 {
|
||||
return 0
|
||||
}
|
||||
return uint64(n)
|
||||
}
|
||||
|
||||
// safeInt64 converts a uint64 to int64, clamping to math.MaxInt64
|
||||
// when the input would overflow. The caps configured on the
|
||||
// resolver never approach 2^63 bytes, so the clamp only guards
|
||||
// against pathological caller input.
|
||||
func safeInt64(n uint64) int64 {
|
||||
if n > math.MaxInt64 {
|
||||
return math.MaxInt64
|
||||
}
|
||||
return int64(n)
|
||||
}
|
||||
@@ -0,0 +1,276 @@
|
||||
package agentcontext_test
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontext"
|
||||
)
|
||||
|
||||
func mustWriteFile(t *testing.T, path, content string) {
|
||||
t.Helper()
|
||||
require.NoError(t, os.MkdirAll(filepath.Dir(path), 0o755))
|
||||
require.NoError(t, os.WriteFile(path, []byte(content), 0o600))
|
||||
}
|
||||
|
||||
func mustWriteSkill(t *testing.T, dir, name, description string) {
|
||||
t.Helper()
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(dir, name), 0o755))
|
||||
mustWriteFile(t, filepath.Join(dir, name, "SKILL.md"),
|
||||
"---\nname: "+name+"\ndescription: "+description+"\n---\nSkill body for "+name)
|
||||
}
|
||||
|
||||
func findResource(t *testing.T, resources []agentcontext.Resource, kind agentcontext.ResourceKind, source string) agentcontext.Resource {
|
||||
t.Helper()
|
||||
for _, r := range resources {
|
||||
if r.Kind == kind && r.Source == source {
|
||||
return r
|
||||
}
|
||||
}
|
||||
t.Fatalf("resource not found: kind=%s source=%s", kind, source)
|
||||
return agentcontext.Resource{}
|
||||
}
|
||||
|
||||
func TestResolver_ProjectAGENTSFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "# Project rules\n\nDo the thing.")
|
||||
|
||||
r := &agentcontext.Resolver{}
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
|
||||
|
||||
require.Len(t, snap.Resources, 1)
|
||||
got := snap.Resources[0]
|
||||
require.Equal(t, agentcontext.KindInstructionFile, got.Kind)
|
||||
require.Equal(t, agentcontext.StatusOK, got.Status)
|
||||
require.Equal(t, filepath.Join(dir, "AGENTS.md"), got.Source)
|
||||
require.Contains(t, string(got.Payload), "Do the thing.")
|
||||
require.Equal(t, "Project rules", got.Description)
|
||||
require.NotEqual(t, [32]byte{}, got.ContentHash)
|
||||
}
|
||||
|
||||
func TestResolver_CaseInsensitiveInstructionNames(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
mustWriteFile(t, filepath.Join(dir, "agents.md"), "lower\n")
|
||||
mustWriteFile(t, filepath.Join(dir, "CLAUDE.md"), "claude\n")
|
||||
|
||||
r := &agentcontext.Resolver{}
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
|
||||
|
||||
require.Len(t, snap.Resources, 2)
|
||||
}
|
||||
|
||||
func TestResolver_SkillsContainerEmitsEachSubdir(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
mustWriteSkill(t, filepath.Join(dir, ".agents", "skills"), "make-coffee", "Coffee skill")
|
||||
mustWriteSkill(t, filepath.Join(dir, ".agents", "skills"), "fold-laundry", "Laundry skill")
|
||||
|
||||
r := &agentcontext.Resolver{}
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
|
||||
|
||||
var kinds []string
|
||||
for _, res := range snap.Resources {
|
||||
kinds = append(kinds, res.Kind.String()+":"+filepath.Base(res.Source))
|
||||
}
|
||||
require.ElementsMatch(t, []string{
|
||||
"skill:make-coffee",
|
||||
"skill:fold-laundry",
|
||||
}, kinds)
|
||||
}
|
||||
|
||||
func TestResolver_SkillNameMismatchInvalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
skillsDir := filepath.Join(dir, ".agents", "skills", "make-coffee")
|
||||
require.NoError(t, os.MkdirAll(skillsDir, 0o755))
|
||||
mustWriteFile(t, filepath.Join(skillsDir, "SKILL.md"),
|
||||
"---\nname: drink-tea\ndescription: oops\n---\nBody")
|
||||
|
||||
r := &agentcontext.Resolver{}
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
|
||||
|
||||
require.Len(t, snap.Resources, 1)
|
||||
got := snap.Resources[0]
|
||||
require.Equal(t, agentcontext.KindSkill, got.Kind)
|
||||
require.Equal(t, agentcontext.StatusInvalid, got.Status)
|
||||
require.Contains(t, got.Error, "does not match directory")
|
||||
}
|
||||
|
||||
func TestResolver_MCPConfigEmitted(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
mustWriteFile(t, filepath.Join(dir, ".mcp.json"), `{"mcpServers": {}}`)
|
||||
|
||||
r := &agentcontext.Resolver{}
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
|
||||
|
||||
require.Len(t, snap.Resources, 1)
|
||||
require.Equal(t, agentcontext.KindMCPConfig, snap.Resources[0].Kind)
|
||||
require.Equal(t, agentcontext.StatusOK, snap.Resources[0].Status)
|
||||
}
|
||||
|
||||
func TestResolver_OversizeInstructionFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
// Write a file larger than the per-resource cap.
|
||||
big := make([]byte, 200)
|
||||
for i := range big {
|
||||
big[i] = 'a'
|
||||
}
|
||||
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), string(big))
|
||||
|
||||
r := &agentcontext.Resolver{MaxResourceBytes: 100}
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
|
||||
|
||||
require.Len(t, snap.Resources, 1)
|
||||
got := snap.Resources[0]
|
||||
require.Equal(t, agentcontext.StatusOversize, got.Status)
|
||||
require.Empty(t, got.Payload)
|
||||
require.Equal(t, uint64(200), got.SizeBytes)
|
||||
// Hash over capped slice is still populated so callers
|
||||
// can detect "still oversize but content changed".
|
||||
require.NotEqual(t, [32]byte{}, got.ContentHash)
|
||||
}
|
||||
|
||||
func TestResolver_AggregateCapExcludes(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "small")
|
||||
subA := filepath.Join(dir, "a")
|
||||
subB := filepath.Join(dir, "b")
|
||||
mustWriteFile(t, filepath.Join(subA, "AGENTS.md"), "AAAA")
|
||||
mustWriteFile(t, filepath.Join(subB, "AGENTS.md"), "BBBB")
|
||||
|
||||
// Aggregate cap of 9 bytes lets the first two through but
|
||||
// excludes the third regardless of which order they
|
||||
// appear.
|
||||
r := &agentcontext.Resolver{MaxSnapshotBytes: 9}
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
|
||||
|
||||
var excluded int
|
||||
for _, res := range snap.Resources {
|
||||
if res.Status == agentcontext.StatusExcluded {
|
||||
excluded++
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, excluded)
|
||||
}
|
||||
|
||||
func TestResolver_CountCapExcludes(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
for i := 0; i < 5; i++ {
|
||||
sub := filepath.Join(dir, "dir", string('a'+rune(i)))
|
||||
mustWriteFile(t, filepath.Join(sub, "AGENTS.md"), "x")
|
||||
}
|
||||
|
||||
r := &agentcontext.Resolver{MaxResources: 3}
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
|
||||
|
||||
require.Len(t, snap.Resources, 5)
|
||||
var excluded int
|
||||
for _, res := range snap.Resources {
|
||||
if res.Status == agentcontext.StatusExcluded {
|
||||
excluded++
|
||||
}
|
||||
}
|
||||
require.Equal(t, 2, excluded)
|
||||
}
|
||||
|
||||
func TestResolver_SkipsVendorAndNodeModules(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "root")
|
||||
mustWriteFile(t, filepath.Join(dir, "node_modules", "deep", "AGENTS.md"), "should not appear")
|
||||
mustWriteFile(t, filepath.Join(dir, "vendor", "AGENTS.md"), "should not appear either")
|
||||
|
||||
r := &agentcontext.Resolver{}
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
|
||||
|
||||
require.Len(t, snap.Resources, 1)
|
||||
require.Equal(t, filepath.Join(dir, "AGENTS.md"), snap.Resources[0].Source)
|
||||
}
|
||||
|
||||
func TestResolver_UserSourceAttribution(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "user-added")
|
||||
|
||||
r := &agentcontext.Resolver{}
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir, UserSource: dir}})
|
||||
|
||||
require.Len(t, snap.Resources, 1)
|
||||
require.Equal(t, dir, snap.Resources[0].SourcePath)
|
||||
}
|
||||
|
||||
func TestResolver_MissingRootSilentlyIgnored(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &agentcontext.Resolver{}
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{{Path: "/nonexistent/path"}})
|
||||
require.Empty(t, snap.Resources)
|
||||
require.Empty(t, snap.SnapshotError)
|
||||
}
|
||||
|
||||
func TestResolver_SingleFileRootClassified(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "AGENTS.md")
|
||||
mustWriteFile(t, path, "x")
|
||||
|
||||
r := &agentcontext.Resolver{}
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{{Path: path}})
|
||||
|
||||
require.Len(t, snap.Resources, 1)
|
||||
require.Equal(t, agentcontext.KindInstructionFile, snap.Resources[0].Kind)
|
||||
}
|
||||
|
||||
func TestResolver_DuplicateRootsDeduplicated(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "x")
|
||||
|
||||
r := &agentcontext.Resolver{}
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{
|
||||
{Path: dir},
|
||||
{Path: dir},
|
||||
{Path: dir},
|
||||
})
|
||||
require.Len(t, snap.Resources, 1)
|
||||
}
|
||||
|
||||
func TestResolver_MCPProviderResources(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
|
||||
mcpRes := agentcontext.Resource{
|
||||
ID: "mcp_server:github",
|
||||
Kind: agentcontext.KindMCPServer,
|
||||
Source: "github",
|
||||
Status: agentcontext.StatusOK,
|
||||
Payload: []byte("tool-list-json"),
|
||||
ContentHash: sha256.Sum256([]byte("tool-list-json")),
|
||||
Description: "GitHub MCP server",
|
||||
}
|
||||
r := &agentcontext.Resolver{
|
||||
MCP: &fakeMCPProvider{resources: []agentcontext.Resource{mcpRes}},
|
||||
}
|
||||
|
||||
snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}})
|
||||
got := findResource(t, snap.Resources, agentcontext.KindMCPServer, "github")
|
||||
require.Equal(t, agentcontext.StatusOK, got.Status)
|
||||
require.Equal(t, "GitHub MCP server", got.Description)
|
||||
}
|
||||
|
||||
type fakeMCPProvider struct {
|
||||
resources []agentcontext.Resource
|
||||
}
|
||||
|
||||
func (f *fakeMCPProvider) MCPResources() []agentcontext.Resource {
|
||||
return f.resources
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
package agentcontext
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sort"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// ResourceKind describes the category of a resolved context
|
||||
// resource. The values mirror the proto ContextResource.Kind
|
||||
// enum reserved in the RFC; future kinds (PLUGIN, HOOK,
|
||||
// SUBAGENT, COMMAND) are defined here so callers can switch
|
||||
// exhaustively, but no v1 resolver emits them.
|
||||
type ResourceKind int
|
||||
|
||||
const (
|
||||
KindUnspecified ResourceKind = iota
|
||||
// KindInstructionFile covers AGENTS.md, CLAUDE.md,
|
||||
// .cursorrules, and similar plain-text rule files that
|
||||
// inject content into the model prompt.
|
||||
KindInstructionFile
|
||||
// KindSkill is a directory containing SKILL.md and any
|
||||
// supporting files. Only the meta file is read at
|
||||
// resolve time; bodies are fetched on demand.
|
||||
KindSkill
|
||||
// KindMCPConfig is a .mcp.json fragment declaring one or
|
||||
// more MCP servers.
|
||||
KindMCPConfig
|
||||
// KindMCPServer is a live MCP server's resolved tool list,
|
||||
// populated by an MCPProvider after the server has been
|
||||
// connected.
|
||||
KindMCPServer
|
||||
// KindPlugin is reserved for Claude Code plugin manifests.
|
||||
// Not emitted by v1.
|
||||
KindPlugin
|
||||
// KindHook is reserved for plugin hooks. Not emitted by v1.
|
||||
KindHook
|
||||
// KindSubagent is reserved for plugin-declared subagents.
|
||||
// Not emitted by v1.
|
||||
KindSubagent
|
||||
// KindCommand is reserved for plugin slash commands.
|
||||
// Not emitted by v1.
|
||||
KindCommand
|
||||
)
|
||||
|
||||
// String returns the lower-snake-case name used in IDs and
|
||||
// metrics. Unknown values stringify to "unknown".
|
||||
func (k ResourceKind) String() string {
|
||||
switch k {
|
||||
case KindInstructionFile:
|
||||
return "instruction_file"
|
||||
case KindSkill:
|
||||
return "skill"
|
||||
case KindMCPConfig:
|
||||
return "mcp_config"
|
||||
case KindMCPServer:
|
||||
return "mcp_server"
|
||||
case KindPlugin:
|
||||
return "plugin"
|
||||
case KindHook:
|
||||
return "hook"
|
||||
case KindSubagent:
|
||||
return "subagent"
|
||||
case KindCommand:
|
||||
return "command"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// ResourceStatus describes whether a resource was successfully
|
||||
// read and whether its payload survived the per-resource and
|
||||
// aggregate caps.
|
||||
type ResourceStatus int
|
||||
|
||||
const (
|
||||
// StatusOK indicates the payload was populated.
|
||||
StatusOK ResourceStatus = iota
|
||||
// StatusOversize indicates the resource exceeded the
|
||||
// per-resource size cap; payload is omitted.
|
||||
StatusOversize
|
||||
// StatusUnreadable indicates an IO error reading the
|
||||
// resource (permission denied, broken symlink, etc.).
|
||||
StatusUnreadable
|
||||
// StatusInvalid indicates the resource was structurally
|
||||
// malformed (bad JSON, missing front-matter, etc.).
|
||||
StatusInvalid
|
||||
// StatusExcluded indicates the resource was dropped to fit
|
||||
// the aggregate snapshot or count cap.
|
||||
StatusExcluded
|
||||
)
|
||||
|
||||
// String returns the lower-snake-case name used in IDs and
|
||||
// metrics. Unknown values stringify to "unknown".
|
||||
func (s ResourceStatus) String() string {
|
||||
switch s {
|
||||
case StatusOK:
|
||||
return "ok"
|
||||
case StatusOversize:
|
||||
return "oversize"
|
||||
case StatusUnreadable:
|
||||
return "unreadable"
|
||||
case StatusInvalid:
|
||||
return "invalid"
|
||||
case StatusExcluded:
|
||||
return "excluded"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Source is a user-declared scan root added to the agent's
|
||||
// in-memory list via the HTTP API or boot-time env seeding.
|
||||
// Identity is the canonical absolute path.
|
||||
type Source struct {
|
||||
// Path is the canonical absolute path (symlinks resolved,
|
||||
// ~ expanded). Empty means the zero value.
|
||||
Path string
|
||||
}
|
||||
|
||||
// Resource is what the resolver emits for each recognized file
|
||||
// or live server it discovers under a scan root.
|
||||
type Resource struct {
|
||||
// ID is stable across pushes for the same logical
|
||||
// resource. The current scheme is "<kind>:<source>".
|
||||
ID string
|
||||
// Kind classifies the resource for snapshot consumers.
|
||||
Kind ResourceKind
|
||||
// Source is the file path or MCP server name.
|
||||
Source string
|
||||
// ContentHash is sha256 over the resource's original
|
||||
// bytes (or transport-encoded server tool list).
|
||||
ContentHash [32]byte
|
||||
// Payload is the full bytes when Status == StatusOK; the
|
||||
// per-resource and aggregate caps may leave it empty.
|
||||
Payload []byte
|
||||
// SizeBytes is the original payload size, populated
|
||||
// regardless of Status.
|
||||
SizeBytes uint64
|
||||
// Status records OK or a reason the payload is absent.
|
||||
Status ResourceStatus
|
||||
// Error is populated whenever Status != StatusOK; may
|
||||
// also carry a non-fatal warning when Status == StatusOK.
|
||||
Error string
|
||||
// Description is a short human-readable summary (skill
|
||||
// front-matter description, MCP server description, etc.).
|
||||
Description string
|
||||
// SourcePath is the user-declared source that contributed
|
||||
// the resource; empty for built-in scan roots.
|
||||
SourcePath string
|
||||
}
|
||||
|
||||
// Snapshot is the immutable bundle of resources produced by a
|
||||
// single resolver pass.
|
||||
type Snapshot struct {
|
||||
// Version is monotonically increasing per Manager
|
||||
// instance; resets when the agent process restarts.
|
||||
Version uint64
|
||||
// SchemaVersion is bumped if the resource shape on the
|
||||
// wire changes.
|
||||
SchemaVersion uint64
|
||||
// AggregateHash is sha256 over a canonical encoding of
|
||||
// (ID, Kind, Source, ContentHash, Status) for every
|
||||
// resource. Identical inputs always produce identical
|
||||
// hashes; see ComputeAggregateHash.
|
||||
AggregateHash [32]byte
|
||||
// Resources is sorted by ID for deterministic encoding.
|
||||
Resources []Resource
|
||||
// PayloadBytes is the sum of len(Resource.Payload) across
|
||||
// emitted resources after caps were applied.
|
||||
PayloadBytes uint64
|
||||
// SnapshotError carries a single snapshot-level error
|
||||
// string when present (count cap exceeded, watcher
|
||||
// degraded, ENOSPC, etc.). Empty when healthy.
|
||||
SnapshotError string
|
||||
}
|
||||
|
||||
// ComputeAggregateHash produces the deterministic snapshot
|
||||
// aggregate hash for the supplied resources. The caller does
|
||||
// not need to pre-sort; the function sorts a copy of the slice
|
||||
// to keep its inputs side-effect free.
|
||||
//
|
||||
// The encoding is a newline-delimited stream of fields. The
|
||||
// resource boundary is a single NUL byte. The boundary scheme
|
||||
// is internal to the agent and coderd, but it is stable across
|
||||
// platforms because every field is encoded as either a UTF-8
|
||||
// string with a length prefix or a fixed-width integer.
|
||||
func ComputeAggregateHash(resources []Resource) [32]byte {
|
||||
indexed := make([]Resource, len(resources))
|
||||
copy(indexed, resources)
|
||||
sort.Slice(indexed, func(i, j int) bool {
|
||||
return indexed[i].ID < indexed[j].ID
|
||||
})
|
||||
|
||||
h := sha256.New()
|
||||
for _, r := range indexed {
|
||||
writeLengthPrefixed(h, r.ID)
|
||||
writeLengthPrefixed(h, r.Kind.String())
|
||||
writeLengthPrefixed(h, r.Source)
|
||||
_, _ = h.Write(r.ContentHash[:])
|
||||
writeLengthPrefixed(h, r.Status.String())
|
||||
_, _ = h.Write([]byte{0})
|
||||
}
|
||||
var out [32]byte
|
||||
copy(out[:], h.Sum(nil))
|
||||
return out
|
||||
}
|
||||
|
||||
// AggregateHashHex returns the hex-encoded aggregate hash.
|
||||
// Convenience for log lines and HTTP responses.
|
||||
func (s Snapshot) AggregateHashHex() string {
|
||||
return hex.EncodeToString(s.AggregateHash[:])
|
||||
}
|
||||
|
||||
// writeLengthPrefixed writes a uvarint length prefix followed
|
||||
// by the raw bytes of s.
|
||||
func writeLengthPrefixed(h interface{ Write([]byte) (int, error) }, s string) {
|
||||
_, _ = h.Write([]byte(strconv.Itoa(len(s))))
|
||||
_, _ = h.Write([]byte{':'})
|
||||
_, _ = h.Write([]byte(s))
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package agentcontext_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontext"
|
||||
)
|
||||
|
||||
func TestResourceKindString(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
kind agentcontext.ResourceKind
|
||||
want string
|
||||
}{
|
||||
{agentcontext.KindUnspecified, "unknown"},
|
||||
{agentcontext.KindInstructionFile, "instruction_file"},
|
||||
{agentcontext.KindSkill, "skill"},
|
||||
{agentcontext.KindMCPConfig, "mcp_config"},
|
||||
{agentcontext.KindMCPServer, "mcp_server"},
|
||||
{agentcontext.KindPlugin, "plugin"},
|
||||
{agentcontext.KindHook, "hook"},
|
||||
{agentcontext.KindSubagent, "subagent"},
|
||||
{agentcontext.KindCommand, "command"},
|
||||
{agentcontext.ResourceKind(999), "unknown"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
require.Equal(t, tt.want, tt.kind.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceStatusString(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
status agentcontext.ResourceStatus
|
||||
want string
|
||||
}{
|
||||
{agentcontext.StatusOK, "ok"},
|
||||
{agentcontext.StatusOversize, "oversize"},
|
||||
{agentcontext.StatusUnreadable, "unreadable"},
|
||||
{agentcontext.StatusInvalid, "invalid"},
|
||||
{agentcontext.StatusExcluded, "excluded"},
|
||||
{agentcontext.ResourceStatus(999), "unknown"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
require.Equal(t, tt.want, tt.status.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeAggregateHash_DeterministicAcrossOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
a := agentcontext.Resource{
|
||||
ID: "instruction_file:/a/AGENTS.md",
|
||||
Kind: agentcontext.KindInstructionFile,
|
||||
Source: "/a/AGENTS.md",
|
||||
Status: agentcontext.StatusOK,
|
||||
}
|
||||
b := agentcontext.Resource{
|
||||
ID: "instruction_file:/b/AGENTS.md",
|
||||
Kind: agentcontext.KindInstructionFile,
|
||||
Source: "/b/AGENTS.md",
|
||||
Status: agentcontext.StatusOK,
|
||||
}
|
||||
got1 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{a, b})
|
||||
got2 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{b, a})
|
||||
require.Equal(t, got1, got2)
|
||||
}
|
||||
|
||||
func TestComputeAggregateHash_ChangesOnContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
base := agentcontext.Resource{
|
||||
ID: "instruction_file:/a/AGENTS.md",
|
||||
Kind: agentcontext.KindInstructionFile,
|
||||
Source: "/a/AGENTS.md",
|
||||
Status: agentcontext.StatusOK,
|
||||
}
|
||||
hash1 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{base})
|
||||
|
||||
withContent := base
|
||||
withContent.ContentHash = [32]byte{0x01}
|
||||
hash2 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{withContent})
|
||||
require.NotEqual(t, hash1, hash2)
|
||||
|
||||
withStatus := base
|
||||
withStatus.Status = agentcontext.StatusOversize
|
||||
hash3 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{withStatus})
|
||||
require.NotEqual(t, hash1, hash3)
|
||||
}
|
||||
|
||||
func TestSnapshotAggregateHashHex(t *testing.T) {
|
||||
t.Parallel()
|
||||
snap := agentcontext.Snapshot{
|
||||
AggregateHash: [32]byte{0xde, 0xad, 0xbe, 0xef},
|
||||
}
|
||||
require.Equal(t,
|
||||
"deadbeef0000000000000000000000000000000000000000000000000000000000000000"[:64],
|
||||
snap.AggregateHashHex())
|
||||
}
|
||||
@@ -0,0 +1,346 @@
|
||||
package agentcontext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// DefaultWatchDebounce coalesces editor-style multi-event writes
|
||||
// (truncate plus rename plus chmod) into a single re-resolve.
|
||||
// Mirrors the debounce window the existing MCP config watcher
|
||||
// uses so behavior is consistent across the agent.
|
||||
const DefaultWatchDebounce = 250 * time.Millisecond
|
||||
|
||||
// WatcherOptions parameterizes the recursive watcher.
|
||||
type WatcherOptions struct {
|
||||
Logger slog.Logger
|
||||
Clock quartz.Clock
|
||||
Debounce time.Duration
|
||||
// OnChange runs at most once per debounce window. The
|
||||
// caller must not block; the recommended pattern is a
|
||||
// non-blocking send on a re-resolve trigger channel.
|
||||
OnChange func()
|
||||
}
|
||||
|
||||
// Watcher is a recursive fsnotify wrapper. fsnotify does not
|
||||
// support recursive watches natively on Linux, so we walk every
|
||||
// scan root at sync time and register each subdirectory
|
||||
// individually. Inotify ENOSPC degrades the watcher into a
|
||||
// poll-only mode that still re-resolves on Sync calls.
|
||||
type Watcher struct {
|
||||
logger slog.Logger
|
||||
clock quartz.Clock
|
||||
debounce time.Duration
|
||||
onChange func()
|
||||
|
||||
mu sync.Mutex
|
||||
watcher *fsnotify.Watcher
|
||||
watched map[string]struct{}
|
||||
timer *quartz.Timer
|
||||
degraded string // non-empty when the watcher dropped events
|
||||
closed bool
|
||||
closedCh chan struct{}
|
||||
runDoneCh chan struct{}
|
||||
}
|
||||
|
||||
// NewWatcher constructs a recursive watcher. The watcher does
|
||||
// nothing until Sync is called.
|
||||
func NewWatcher(opts WatcherOptions) (*Watcher, error) {
|
||||
if opts.OnChange == nil {
|
||||
return nil, xerrors.New("OnChange callback is required")
|
||||
}
|
||||
debounce := opts.Debounce
|
||||
if debounce <= 0 {
|
||||
debounce = DefaultWatchDebounce
|
||||
}
|
||||
clock := opts.Clock
|
||||
if clock == nil {
|
||||
clock = quartz.NewReal()
|
||||
}
|
||||
|
||||
w, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
// On Linux, fsnotify.NewWatcher only fails when the
|
||||
// inotify subsystem is at the system-wide watch
|
||||
// limit. Surface a Watcher in "degraded" mode so the
|
||||
// caller can still rely on explicit Sync triggers.
|
||||
degraded := &Watcher{
|
||||
logger: opts.Logger,
|
||||
clock: clock,
|
||||
debounce: debounce,
|
||||
onChange: opts.OnChange,
|
||||
watched: make(map[string]struct{}),
|
||||
degraded: "fsnotify init failed: " + err.Error(),
|
||||
closedCh: make(chan struct{}),
|
||||
runDoneCh: closedChan(),
|
||||
}
|
||||
return degraded, nil
|
||||
}
|
||||
|
||||
cw := &Watcher{
|
||||
logger: opts.Logger,
|
||||
clock: clock,
|
||||
debounce: debounce,
|
||||
onChange: opts.OnChange,
|
||||
watcher: w,
|
||||
watched: make(map[string]struct{}),
|
||||
closedCh: make(chan struct{}),
|
||||
runDoneCh: make(chan struct{}),
|
||||
}
|
||||
go cw.run()
|
||||
return cw, nil
|
||||
}
|
||||
|
||||
// closedChan returns an already-closed channel for the
|
||||
// degraded-watcher case where there is no run goroutine.
|
||||
func closedChan() chan struct{} {
|
||||
c := make(chan struct{})
|
||||
close(c)
|
||||
return c
|
||||
}
|
||||
|
||||
// Degraded returns a non-empty string when the watcher is
|
||||
// running with reduced functionality (typically inotify
|
||||
// ENOSPC). The string is suitable for use as a snapshot-level
|
||||
// error message.
|
||||
func (w *Watcher) Degraded() string {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.degraded
|
||||
}
|
||||
|
||||
// Sync replaces the set of watched directories with a fresh
|
||||
// recursive walk of every scan root. Files are not watched
|
||||
// directly; watching the parent directory catches creates,
|
||||
// renames, removes, and writes that touch any recognized
|
||||
// basename. Files that are themselves scan roots are handled by
|
||||
// watching their parent.
|
||||
//
|
||||
// Sync is idempotent and safe to call repeatedly.
|
||||
func (w *Watcher) Sync(ctx context.Context, roots []ScanRoot) {
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if w.watcher == nil {
|
||||
// Degraded mode: nothing to wire up; fire the callback
|
||||
// so the caller still gets a fresh resolve.
|
||||
w.mu.Unlock()
|
||||
w.schedule()
|
||||
return
|
||||
}
|
||||
desired := w.collectDirs(roots)
|
||||
|
||||
// Remove directories no longer wanted.
|
||||
for path := range w.watched {
|
||||
if _, ok := desired[path]; ok {
|
||||
continue
|
||||
}
|
||||
_ = w.watcher.Remove(path)
|
||||
delete(w.watched, path)
|
||||
}
|
||||
// Add directories that are new.
|
||||
for path := range desired {
|
||||
if _, ok := w.watched[path]; ok {
|
||||
continue
|
||||
}
|
||||
if err := w.watcher.Add(path); err != nil {
|
||||
// ENOSPC means the kernel's per-user inotify
|
||||
// watch budget is exhausted. Mark the watcher
|
||||
// degraded; subsequent Sync calls still fire
|
||||
// the change callback so resync still works.
|
||||
if errors.Is(err, syscall.ENOSPC) {
|
||||
w.degraded = "inotify watch limit exceeded (ENOSPC)"
|
||||
w.logger.Warn(ctx, "context watcher degraded: inotify watch limit exceeded",
|
||||
slog.F("dir", path))
|
||||
break
|
||||
}
|
||||
w.logger.Debug(ctx, "context watcher could not add dir",
|
||||
slog.F("dir", path), slog.Error(err))
|
||||
continue
|
||||
}
|
||||
w.watched[path] = struct{}{}
|
||||
}
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
// Close stops the watcher and releases all kernel watch slots.
|
||||
// Close is idempotent.
|
||||
func (w *Watcher) Close() error {
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
w.closed = true
|
||||
close(w.closedCh)
|
||||
timer := w.timer
|
||||
watcher := w.watcher
|
||||
w.timer = nil
|
||||
w.watcher = nil
|
||||
w.mu.Unlock()
|
||||
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
if watcher != nil {
|
||||
_ = watcher.Close()
|
||||
}
|
||||
<-w.runDoneCh
|
||||
return nil
|
||||
}
|
||||
|
||||
// run forwards fsnotify events into the debounce timer. It exits
|
||||
// when Close is called or the underlying watcher is closed.
|
||||
func (w *Watcher) run() {
|
||||
defer close(w.runDoneCh)
|
||||
// Capture the watcher reference once. Close may set the
|
||||
// field to nil concurrently; reading the captured local
|
||||
// keeps the event loop safe through the race window.
|
||||
w.mu.Lock()
|
||||
fsw := w.watcher
|
||||
w.mu.Unlock()
|
||||
if fsw == nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-w.closedCh:
|
||||
return
|
||||
case ev, ok := <-fsw.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if !w.eventRelevant(ev) {
|
||||
continue
|
||||
}
|
||||
w.schedule()
|
||||
case err, ok := <-fsw.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
w.logger.Debug(context.Background(), "context watcher error", slog.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// eventRelevant filters out events that cannot affect any
|
||||
// recognized resource. The check is conservative: any event on
|
||||
// a directory triggers a re-resolve so newly created subtrees
|
||||
// are picked up.
|
||||
func (*Watcher) eventRelevant(ev fsnotify.Event) bool {
|
||||
name := filepath.Base(ev.Name)
|
||||
if recognizedInstructionFile(name) || name == mcpConfigFileName || name == skillMetaFileName {
|
||||
return true
|
||||
}
|
||||
// Directory create/remove flips re-resolve so new subtrees
|
||||
// arm watches and removed subtrees stop arming them.
|
||||
if ev.Has(fsnotify.Create) || ev.Has(fsnotify.Remove) || ev.Has(fsnotify.Rename) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// schedule arms or resets the debounce timer.
|
||||
func (w *Watcher) schedule() {
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
cb := w.onChange
|
||||
if w.timer != nil {
|
||||
w.timer.Reset(w.debounce)
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
w.timer = w.clock.AfterFunc(w.debounce, func() {
|
||||
w.mu.Lock()
|
||||
w.timer = nil
|
||||
w.mu.Unlock()
|
||||
cb()
|
||||
})
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
// collectDirs walks every scan root and returns the set of
|
||||
// directories to watch. The maximum depth mirrors the
|
||||
// resolver's MaxDepth so we never watch more than we'd scan.
|
||||
func (*Watcher) collectDirs(roots []ScanRoot) map[string]struct{} {
|
||||
out := make(map[string]struct{})
|
||||
for _, root := range roots {
|
||||
if root.Path == "" {
|
||||
continue
|
||||
}
|
||||
info, err := os.Stat(root.Path)
|
||||
if err != nil {
|
||||
// Watch the deepest existing ancestor so the
|
||||
// root being created later still fires.
|
||||
if anc := existingAncestor(root.Path); anc != "" {
|
||||
out[anc] = struct{}{}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !info.IsDir() {
|
||||
out[filepath.Dir(root.Path)] = struct{}{}
|
||||
continue
|
||||
}
|
||||
// Walk the directory and collect every descendant
|
||||
// directory up to the depth cap.
|
||||
rootDepth := strings.Count(filepath.Clean(root.Path), string(os.PathSeparator))
|
||||
_ = filepath.WalkDir(root.Path, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if !d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if _, skip := skipDirNames[d.Name()]; skip && path != root.Path {
|
||||
return fs.SkipDir
|
||||
}
|
||||
if strings.Count(path, string(os.PathSeparator))-rootDepth > DefaultMaxScanDepth {
|
||||
return fs.SkipDir
|
||||
}
|
||||
out[path] = struct{}{}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// existingAncestor returns the deepest existing ancestor of
|
||||
// path, or "" if no ancestor exists (e.g. an entirely missing
|
||||
// drive on Windows).
|
||||
func existingAncestor(path string) string {
|
||||
cur := filepath.Dir(path)
|
||||
for {
|
||||
if cur == "" || cur == "." {
|
||||
return ""
|
||||
}
|
||||
info, err := os.Stat(cur)
|
||||
if err == nil && info.IsDir() {
|
||||
return cur
|
||||
}
|
||||
parent := filepath.Dir(cur)
|
||||
if parent == cur {
|
||||
return ""
|
||||
}
|
||||
cur = parent
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package agentcontext_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/quartz"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontext"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestWatcher_FiresOnAgentsMdEdit(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v1"), 0o600))
|
||||
|
||||
var fires int32
|
||||
w, err := agentcontext.NewWatcher(agentcontext.WatcherOptions{
|
||||
Logger: testutil.Logger(t).Named("watcher"),
|
||||
Clock: quartz.NewReal(),
|
||||
Debounce: 10 * time.Millisecond,
|
||||
OnChange: func() { atomic.AddInt32(&fires, 1) },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = w.Close() })
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
w.Sync(ctx, []agentcontext.ScanRoot{{Path: dir}})
|
||||
|
||||
// Edit the file. Use a slight delay so fsnotify is ready.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v2"), 0o600))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt32(&fires) >= 1
|
||||
}, testutil.WaitShort, testutil.IntervalFast, "expected at least one fire after AGENTS.md edit")
|
||||
}
|
||||
|
||||
func TestWatcher_FiresOnNewSkillFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
skillsRoot := filepath.Join(dir, ".agents", "skills")
|
||||
require.NoError(t, os.MkdirAll(skillsRoot, 0o755))
|
||||
|
||||
var fires int32
|
||||
w, err := agentcontext.NewWatcher(agentcontext.WatcherOptions{
|
||||
Logger: testutil.Logger(t).Named("watcher"),
|
||||
Debounce: 10 * time.Millisecond,
|
||||
OnChange: func() { atomic.AddInt32(&fires, 1) },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = w.Close() })
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
w.Sync(ctx, []agentcontext.ScanRoot{{Path: dir}})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
skillDir := filepath.Join(skillsRoot, "foo")
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("---\nname: foo\ndescription: bar\n---\nbody"), 0o600))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt32(&fires) >= 1
|
||||
}, testutil.WaitShort, testutil.IntervalFast, "expected fire after SKILL.md create")
|
||||
}
|
||||
|
||||
func TestWatcher_CloseIsIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
w, err := agentcontext.NewWatcher(agentcontext.WatcherOptions{
|
||||
Logger: testutil.Logger(t).Named("watcher"),
|
||||
OnChange: func() {},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, w.Close())
|
||||
require.NoError(t, w.Close())
|
||||
}
|
||||
|
||||
func TestWatcher_SyncAfterCloseNoop(t *testing.T) {
|
||||
t.Parallel()
|
||||
w, err := agentcontext.NewWatcher(agentcontext.WatcherOptions{
|
||||
Logger: testutil.Logger(t).Named("watcher"),
|
||||
OnChange: func() {},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, w.Close())
|
||||
|
||||
// Must not panic.
|
||||
w.Sync(context.Background(), []agentcontext.ScanRoot{{Path: t.TempDir()}})
|
||||
}
|
||||
Reference in New Issue
Block a user