From 678617ab388c6f96998b751ade2c91f228071db1 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Tue, 2 Jun 2026 14:11:05 +0000 Subject: [PATCH] feat(agent/agentcontext): add agent-side context resolution package --- agent/agentcontext/api.go | 204 ++++++++++ agent/agentcontext/api_test.go | 177 +++++++++ agent/agentcontext/doc.go | 25 ++ agent/agentcontext/manager.go | 480 ++++++++++++++++++++++ agent/agentcontext/manager_test.go | 260 ++++++++++++ agent/agentcontext/mcp.go | 24 ++ agent/agentcontext/paths.go | 121 ++++++ agent/agentcontext/paths_test.go | 112 ++++++ agent/agentcontext/push.go | 182 +++++++++ agent/agentcontext/push_test.go | 197 +++++++++ agent/agentcontext/resolve.go | 613 +++++++++++++++++++++++++++++ agent/agentcontext/resolve_test.go | 276 +++++++++++++ agent/agentcontext/types.go | 222 +++++++++++ agent/agentcontext/types_test.go | 99 +++++ agent/agentcontext/watch.go | 346 ++++++++++++++++ agent/agentcontext/watch_test.go | 96 +++++ 16 files changed, 3434 insertions(+) create mode 100644 agent/agentcontext/api.go create mode 100644 agent/agentcontext/api_test.go create mode 100644 agent/agentcontext/doc.go create mode 100644 agent/agentcontext/manager.go create mode 100644 agent/agentcontext/manager_test.go create mode 100644 agent/agentcontext/mcp.go create mode 100644 agent/agentcontext/paths.go create mode 100644 agent/agentcontext/paths_test.go create mode 100644 agent/agentcontext/push.go create mode 100644 agent/agentcontext/push_test.go create mode 100644 agent/agentcontext/resolve.go create mode 100644 agent/agentcontext/resolve_test.go create mode 100644 agent/agentcontext/types.go create mode 100644 agent/agentcontext/types_test.go create mode 100644 agent/agentcontext/watch.go create mode 100644 agent/agentcontext/watch_test.go diff --git a/agent/agentcontext/api.go b/agent/agentcontext/api.go new file mode 100644 index 0000000000..2ffa329350 --- /dev/null +++ b/agent/agentcontext/api.go @@ -0,0 +1,204 @@ +package agentcontext + 
+import ( + "context" + "encoding/hex" + "errors" + "net/http" + "net/url" + "strconv" + + "github.com/go-chi/chi/v5" + + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/codersdk" +) + +// SourceResponse is the on-wire representation of a Source. +// Matches the path-only RFC schema; future additions (tags, +// labels) can land additively without breaking clients. +type SourceResponse struct { + Path string `json:"path"` +} + +// SourceRequest is the request body for POST /sources. +type SourceRequest struct { + Path string `json:"path"` +} + +// SnapshotResource is the on-wire representation of a Resource. +// PayloadBase64 is omitted from list responses; clients that +// need the bytes hit GET /sources/{path}. +type SnapshotResource struct { + ID string `json:"id"` + Kind string `json:"kind"` + Source string `json:"source"` + SourcePath string `json:"source_path,omitempty"` + ContentHash string `json:"content_hash"` + SizeBytes uint64 `json:"size_bytes"` + Status string `json:"status"` + Error string `json:"error,omitempty"` + Description string `json:"description,omitempty"` +} + +// SnapshotResponse is the on-wire representation of a Snapshot +// returned by the resync endpoint. +type SnapshotResponse struct { + Version uint64 `json:"version"` + SchemaVersion uint64 `json:"schema_version"` + AggregateHash string `json:"aggregate_hash"` + Resources []SnapshotResource `json:"resources"` + PayloadBytes uint64 `json:"payload_bytes"` + SnapshotError string `json:"snapshot_error,omitempty"` +} + +// API exposes the Manager over HTTP. The routes match the RFC: +// +// GET /api/v0/context/sources +// POST /api/v0/context/sources { path } +// GET /api/v0/context/sources/{path} +// DELETE /api/v0/context/sources/{path} +// POST /api/v0/context/resync +// +// {path} is URL-encoded canonical path. Callers pass either the +// canonical or original path; the handler canonicalizes before +// matching. 
+type API struct { + manager *Manager +} + +// NewAPI wraps the supplied Manager. +func NewAPI(m *Manager) *API { + return &API{manager: m} +} + +// Routes returns the chi handler for /api/v0/context/*. Mount +// it at "/api/v0/context". +func (a *API) Routes() http.Handler { + r := chi.NewRouter() + r.Route("/sources", func(r chi.Router) { + r.Get("/", a.handleListSources) + r.Post("/", a.handleAddSource) + r.Get("/{path}", a.handleGetSource) + r.Delete("/{path}", a.handleRemoveSource) + }) + r.Post("/resync", a.handleResync) + return r +} + +func (a *API) handleListSources(rw http.ResponseWriter, r *http.Request) { + sources := a.manager.Sources() + out := make([]SourceResponse, 0, len(sources)) + for _, s := range sources { + out = append(out, SourceResponse(s)) + } + httpapi.Write(r.Context(), rw, http.StatusOK, out) +} + +func (a *API) handleAddSource(rw http.ResponseWriter, r *http.Request) { + var req SourceRequest + if !httpapi.Read(r.Context(), rw, r, &req) { + return + } + s, err := a.manager.AddSource(Source(req)) + if err != nil { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Could not add context source.", + Detail: err.Error(), + }) + return + } + httpapi.Write(r.Context(), rw, http.StatusCreated, SourceResponse(s)) +} + +func (a *API) handleGetSource(rw http.ResponseWriter, r *http.Request) { + raw := chi.URLParam(r, "path") + decoded, err := url.PathUnescape(raw) + if err != nil { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid context source path.", + Detail: err.Error(), + }) + return + } + canonical, ok := a.manager.HasSource(decoded) + if !ok { + httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{ + Message: "Context source not found.", + Detail: "No source registered for path " + strconv.Quote(decoded) + ".", + }) + return + } + httpapi.Write(r.Context(), rw, http.StatusOK, SourceResponse{Path: canonical}) +} + +func (a *API) 
handleRemoveSource(rw http.ResponseWriter, r *http.Request) { + raw := chi.URLParam(r, "path") + decoded, err := url.PathUnescape(raw) + if err != nil { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid context source path.", + Detail: err.Error(), + }) + return + } + if err := a.manager.RemoveSource(decoded); err != nil { + if errors.Is(err, ErrSourceNotFound) { + httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{ + Message: "Context source not found.", + Detail: err.Error(), + }) + return + } + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Could not remove context source.", + Detail: err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} + +func (a *API) handleResync(rw http.ResponseWriter, r *http.Request) { + snap, err := a.manager.Resync(r.Context()) + if err != nil { + status := http.StatusInternalServerError + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + status = http.StatusGatewayTimeout + } + httpapi.Write(r.Context(), rw, status, codersdk.Response{ + Message: "Resync failed.", + Detail: err.Error(), + }) + return + } + httpapi.Write(r.Context(), rw, http.StatusOK, snapshotResponse(snap)) +} + +// snapshotResponse converts a Snapshot to its on-wire form for +// the resync endpoint. Payloads are omitted; the per-resource +// payload bytes ship via the drpc PushContextState path. 
+func snapshotResponse(s Snapshot) SnapshotResponse { + out := SnapshotResponse{ + Version: s.Version, + SchemaVersion: s.SchemaVersion, + AggregateHash: hex.EncodeToString(s.AggregateHash[:]), + Resources: make([]SnapshotResource, 0, len(s.Resources)), + PayloadBytes: s.PayloadBytes, + SnapshotError: s.SnapshotError, + } + for _, r := range s.Resources { + out.Resources = append(out.Resources, SnapshotResource{ + ID: r.ID, + Kind: r.Kind.String(), + Source: r.Source, + SourcePath: r.SourcePath, + ContentHash: hex.EncodeToString(r.ContentHash[:]), + SizeBytes: r.SizeBytes, + Status: r.Status.String(), + Error: r.Error, + Description: r.Description, + }) + } + return out +} diff --git a/agent/agentcontext/api_test.go b/agent/agentcontext/api_test.go new file mode 100644 index 0000000000..1563345d9e --- /dev/null +++ b/agent/agentcontext/api_test.go @@ -0,0 +1,177 @@ +package agentcontext_test + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/agent/agentcontext" + "github.com/coder/coder/v2/testutil" +) + +func newAPITestServer(t *testing.T, opts agentcontext.ManagerOptions) (*httptest.Server, *agentcontext.Manager) { + t.Helper() + m := newTestManager(t, opts) + api := agentcontext.NewAPI(m) + srv := httptest.NewServer(api.Routes()) + t.Cleanup(srv.Close) + return srv, m +} + +// doRequest issues an HTTP request bounded by testutil.WaitShort +// and returns the status code and response body. The response +// body is closed before doRequest returns. 
+func doRequest(t *testing.T, method, requrl string, body io.Reader) (int, []byte) { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + req, err := http.NewRequestWithContext(ctx, method, requrl, body) + require.NoError(t, err) + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + res, err := http.DefaultClient.Do(req) //nolint:bodyclose // closed below. + require.NoError(t, err) + defer res.Body.Close() + bodyBytes, err := io.ReadAll(res.Body) + require.NoError(t, err) + return res.StatusCode, bodyBytes +} + +func TestAPI_ListSourcesEmpty(t *testing.T) { + t.Parallel() + srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return t.TempDir() }, + }) + + status, body := doRequest(t, http.MethodGet, srv.URL+"/sources", nil) + require.Equal(t, http.StatusOK, status) + + var got []agentcontext.SourceResponse + require.NoError(t, json.Unmarshal(body, &got)) + require.Empty(t, got) +} + +func TestAPI_AddAndListSource(t *testing.T) { + t.Parallel() + wd := t.TempDir() + src := t.TempDir() + + srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return wd }, + AllowedRoots: []string{wd, src}, + }) + + body, _ := json.Marshal(agentcontext.SourceRequest{Path: src}) + status, addBody := doRequest(t, http.MethodPost, srv.URL+"/sources", bytes.NewReader(body)) + require.Equal(t, http.StatusCreated, status) + + var created agentcontext.SourceResponse + require.NoError(t, json.Unmarshal(addBody, &created)) + require.Equal(t, src, created.Path) + + // List should show the new source. 
+ listStatus, listBody := doRequest(t, http.MethodGet, srv.URL+"/sources", nil) + require.Equal(t, http.StatusOK, listStatus) + var list []agentcontext.SourceResponse + require.NoError(t, json.Unmarshal(listBody, &list)) + require.Len(t, list, 1) + require.Equal(t, src, list[0].Path) +} + +func TestAPI_AddSourceRejected(t *testing.T) { + t.Parallel() + wd := t.TempDir() + outside := t.TempDir() + + srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return wd }, + AllowedRoots: []string{wd}, + }) + + body, _ := json.Marshal(agentcontext.SourceRequest{Path: outside}) + status, _ := doRequest(t, http.MethodPost, srv.URL+"/sources", bytes.NewReader(body)) + require.Equal(t, http.StatusBadRequest, status) +} + +func TestAPI_GetAndDeleteSource(t *testing.T) { + t.Parallel() + wd := t.TempDir() + src := t.TempDir() + + srv, m := newAPITestServer(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return wd }, + AllowedRoots: []string{wd, src}, + }) + + _, err := m.AddSource(agentcontext.Source{Path: src}) + require.NoError(t, err) + + status, body := doRequest(t, http.MethodGet, srv.URL+"/sources/"+url.PathEscape(src), nil) + require.Equal(t, http.StatusOK, status) + + var got agentcontext.SourceResponse + require.NoError(t, json.Unmarshal(body, &got)) + require.Equal(t, src, got.Path) + + delStatus, _ := doRequest(t, http.MethodDelete, srv.URL+"/sources/"+url.PathEscape(src), nil) + require.Equal(t, http.StatusNoContent, delStatus) + require.Empty(t, m.Sources()) +} + +func TestAPI_GetSourceNotFound(t *testing.T) { + t.Parallel() + srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return t.TempDir() }, + }) + + status, _ := doRequest(t, http.MethodGet, srv.URL+"/sources/"+url.PathEscape("/never-added"), nil) + require.Equal(t, http.StatusNotFound, status) +} + +func TestAPI_DeleteSourceNotFound(t *testing.T) { + t.Parallel() + srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{ 
+ WorkingDir: func() string { return t.TempDir() }, + }) + + status, _ := doRequest(t, http.MethodDelete, srv.URL+"/sources/"+url.PathEscape("/never-added"), nil) + require.Equal(t, http.StatusNotFound, status) +} + +func TestAPI_Resync(t *testing.T) { + t.Parallel() + wd := t.TempDir() + mustWriteFile(t, filepath.Join(wd, "AGENTS.md"), "hello") + + srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return wd }, + }) + + status, body := doRequest(t, http.MethodPost, srv.URL+"/resync", nil) + require.Equal(t, http.StatusOK, status) + + var snap agentcontext.SnapshotResponse + require.NoError(t, json.Unmarshal(body, &snap)) + require.Equal(t, uint64(1), snap.SchemaVersion) + require.NotEmpty(t, snap.AggregateHash) + require.Len(t, snap.Resources, 1) + require.Equal(t, "instruction_file", snap.Resources[0].Kind) + require.Equal(t, "ok", snap.Resources[0].Status) +} + +func TestAPI_AddSourceMalformedBody(t *testing.T) { + t.Parallel() + srv, _ := newAPITestServer(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return t.TempDir() }, + }) + + status, _ := doRequest(t, http.MethodPost, srv.URL+"/sources", bytes.NewReader([]byte("{not json"))) + require.Equal(t, http.StatusBadRequest, status) +} diff --git a/agent/agentcontext/doc.go b/agent/agentcontext/doc.go new file mode 100644 index 0000000000..31929f6938 --- /dev/null +++ b/agent/agentcontext/doc.go @@ -0,0 +1,25 @@ +// Package agentcontext consolidates the agent-side plumbing that +// resolves, watches, and pushes workspace context (instruction +// files, skills, and MCP configuration) to coderd. +// +// This is the agent half of the design described in +// "RFC: Workspace Context Sources for Coder Agents". It owns: +// +// - User-declared scan roots (Sources) layered on top of +// built-in defaults. +// - A resolver that classifies files under each scan root into +// typed Resources (instruction files, skills, MCP configs, +// MCP servers). 
+// - A unified recursive fsnotify watcher that signals a +// re-resolve when any recognized file changes. +// - An HTTP API at /api/v0/context/sources for source CRUD +// and /api/v0/context/resync for synchronous push barriers. +// - A Pusher abstraction so the latest Snapshot can be shipped +// to coderd without coupling this package to any particular +// drpc client version. +// +// The package is purely additive: existing agent code paths +// (agent/agentcontextconfig and agent/x/agentmcp) continue to +// operate unchanged. Wiring the Manager into the agent's HTTP +// router and the drpc client lives in a follow-up change. +package agentcontext diff --git a/agent/agentcontext/manager.go b/agent/agentcontext/manager.go new file mode 100644 index 0000000000..e6f82e9eb4 --- /dev/null +++ b/agent/agentcontext/manager.go @@ -0,0 +1,480 @@ +package agentcontext + +import ( + "context" + "sort" + "strings" + "sync" + "time" + + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/quartz" +) + +// CurrentSchemaVersion is the on-wire shape version. Bump +// whenever the resource format changes in a way that requires +// coderd-side awareness. +const CurrentSchemaVersion uint64 = 1 + +// ManagerOptions configures a Manager. Zero values get sensible +// defaults. +type ManagerOptions struct { + // Logger receives diagnostic messages. Required. + Logger slog.Logger + // Clock is the time source used for the watcher's + // debounce timer. Optional; defaults to quartz.NewReal(). + Clock quartz.Clock + // WorkingDir is evaluated on every resolve, mirroring the + // existing agent convention. The result is used as a + // scan root. + WorkingDir func() string + // BuiltinRoots are scan roots layered before user-added + // sources. Typically: the working directory, ~/.coder, + // ~/.coder/skills, .agents/skills, + // ~/.claude/plugins/cache. + BuiltinRoots []string + // InitialSources seeds the Manager's source list at boot + // time. 
Sources from CODER_AGENT_EXP_*_DIRS env vars or + // startup scripts are layered here. + InitialSources []Source + // AllowedRoots restricts which paths may be added as + // sources at runtime. Defaults to [~, ~/.coder, ~/.claude, + // workingDir]. Empty disables validation. + AllowedRoots []string + // Resolver, when non-nil, replaces the default resolver. + // Tests use this to inject MCP providers and tighten + // caps. + Resolver *Resolver + // Debounce overrides the watcher's debounce window. + Debounce time.Duration + // SchemaVersion is the version stamped on each Snapshot. + // Use CurrentSchemaVersion (the default) unless rolling + // out a schema change. + SchemaVersion uint64 +} + +// Manager orchestrates source CRUD, resolution, watching, and +// Pusher fan-out. Construct with NewManager; start its lifecycle +// goroutines with Run; tear down with Close. +type Manager struct { + logger slog.Logger + clock quartz.Clock + workingDir func() string + builtinRoots []string + allowedRoots []string + resolver *Resolver + debounce time.Duration + schemaVersion uint64 + + mu sync.Mutex + sources []Source + // sourceIndex maps canonical path -> position in sources + // for O(1) lookups during AddSource / RemoveSource. + sourceIndex map[string]int + + // snapshot is the latest result of a resolver pass. It is + // replaced atomically under mu. + snapshot Snapshot + // version monotonically increases per resolve pass. + version uint64 + + // subscribers receive a non-blocking signal whenever the + // snapshot changes. Subscribers must drain their channel + // promptly; the Manager drops sends to full channels. + subscribers map[chan struct{}]struct{} + + // trigger fires when AddSource / RemoveSource / watcher + // observe a change. + trigger chan struct{} + + // running tracks Run lifetime. 
+ running bool + closed bool + closedCh chan struct{} + runDoneCh chan struct{} + + watcher *Watcher +} + +// NewManager validates options, canonicalizes initial sources, +// performs the first resolver pass synchronously, and returns +// the resulting Manager. Run must be called separately to start +// the watcher and re-resolve goroutine. +func NewManager(opts ManagerOptions) (*Manager, error) { + clock := opts.Clock + if clock == nil { + clock = quartz.NewReal() + } + debounce := opts.Debounce + if debounce <= 0 { + debounce = DefaultWatchDebounce + } + schemaVersion := opts.SchemaVersion + if schemaVersion == 0 { + schemaVersion = CurrentSchemaVersion + } + resolver := opts.Resolver + if resolver == nil { + resolver = &Resolver{} + } + + m := &Manager{ + logger: opts.Logger, + clock: clock, + workingDir: opts.WorkingDir, + builtinRoots: append([]string(nil), opts.BuiltinRoots...), + allowedRoots: append([]string(nil), opts.AllowedRoots...), + resolver: resolver, + debounce: debounce, + schemaVersion: schemaVersion, + sources: make([]Source, 0), + sourceIndex: make(map[string]int), + subscribers: make(map[chan struct{}]struct{}), + trigger: make(chan struct{}, 1), + closedCh: make(chan struct{}), + runDoneCh: make(chan struct{}), + } + + for _, s := range opts.InitialSources { + canonical, err := CanonicalizePath(s.Path) + if err != nil { + // Initial sources may not exist yet at boot + // time; log and skip rather than abort the + // agent. + m.logger.Warn(context.Background(), + "agentcontext: skipping invalid initial source", + slog.F("path", s.Path), + slog.Error(err)) + continue + } + if _, ok := m.sourceIndex[canonical]; ok { + continue + } + m.sourceIndex[canonical] = len(m.sources) + m.sources = append(m.sources, Source{Path: canonical}) + } + + // First snapshot is computed eagerly. 
The push protocol + // requires a snapshot to be present before the agent signals + // lifecycle = ready, so callers can rely on Snapshot() being + // populated immediately after NewManager returns. + m.resolveLocked() + + return m, nil +} + +// Run starts the watcher and the re-resolve goroutine. Run +// blocks until ctx is canceled or Close is called. It is safe +// to call Run at most once per Manager. +func (m *Manager) Run(ctx context.Context) error { + m.mu.Lock() + if m.running { + m.mu.Unlock() + return xerrors.New("agentcontext: Manager.Run called more than once") + } + if m.closed { + m.mu.Unlock() + return xerrors.New("agentcontext: Manager already closed") + } + m.running = true + m.mu.Unlock() + + watcher, err := NewWatcher(WatcherOptions{ + Logger: m.logger.Named("watcher"), + Clock: m.clock, + Debounce: m.debounce, + OnChange: m.signal, + }) + if err != nil { + // NewWatcher already falls back to degraded mode on + // init failure, so an actual error here is + // exceptional. + return xerrors.Errorf("create watcher: %w", err) + } + m.mu.Lock() + m.watcher = watcher + roots := m.scanRootsLocked() + m.mu.Unlock() + watcher.Sync(ctx, roots) + + defer close(m.runDoneCh) + defer watcher.Close() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.closedCh: + return nil + case <-m.trigger: + m.mu.Lock() + roots := m.scanRootsLocked() + m.mu.Unlock() + watcher.Sync(ctx, roots) + m.resolveAndBroadcast(ctx) + } + } +} + +// Close stops the Manager. Close is idempotent; subsequent +// calls block until Run exits. +func (m *Manager) Close() error { + m.mu.Lock() + if m.closed { + running := m.running + m.mu.Unlock() + if running { + <-m.runDoneCh + } + return nil + } + m.closed = true + running := m.running + close(m.closedCh) + m.mu.Unlock() + if running { + <-m.runDoneCh + } + return nil +} + +// Sources returns a defensive copy of the current source list. +// The returned slice is safe to mutate. 
+func (m *Manager) Sources() []Source { + m.mu.Lock() + defer m.mu.Unlock() + out := make([]Source, len(m.sources)) + copy(out, m.sources) + return out +} + +// HasSource reports whether path matches an existing source +// after canonicalization. Returns the canonical path on +// success. +func (m *Manager) HasSource(path string) (canonical string, ok bool) { + c, err := CanonicalizePath(path) + if err != nil { + return "", false + } + m.mu.Lock() + defer m.mu.Unlock() + _, ok = m.sourceIndex[c] + return c, ok +} + +// AddSource adds a new source. The path is canonicalized and +// validated against the AllowedRoots set. AddSource is +// idempotent. +func (m *Manager) AddSource(s Source) (Source, error) { + canonical, err := CanonicalizePath(s.Path) + if err != nil { + return Source{}, xerrors.Errorf("canonicalize: %w", err) + } + if err := ValidateSourcePath(canonical, m.effectiveAllowedRoots()); err != nil { + return Source{}, err + } + + m.mu.Lock() + if _, ok := m.sourceIndex[canonical]; ok { + out := m.sources[m.sourceIndex[canonical]] + m.mu.Unlock() + return out, nil + } + m.sourceIndex[canonical] = len(m.sources) + m.sources = append(m.sources, Source{Path: canonical}) + m.mu.Unlock() + + m.signal() + return Source{Path: canonical}, nil +} + +// RemoveSource removes the source matching path. Path is +// canonicalized before matching. Returns ErrSourceNotFound when +// no such source exists. +func (m *Manager) RemoveSource(path string) error { + canonical, err := CanonicalizePath(path) + if err != nil { + return xerrors.Errorf("canonicalize: %w", err) + } + + m.mu.Lock() + idx, ok := m.sourceIndex[canonical] + if !ok { + m.mu.Unlock() + return ErrSourceNotFound + } + // O(n) compaction is fine for the typical handful of + // user-added sources. + m.sources = append(m.sources[:idx], m.sources[idx+1:]...) 
+ delete(m.sourceIndex, canonical) + for i := idx; i < len(m.sources); i++ { + m.sourceIndex[m.sources[i].Path] = i + } + m.mu.Unlock() + + m.signal() + return nil +} + +// Snapshot returns the latest Snapshot. The returned value is +// safe to share but shares the same Resources slice as the +// internal state; callers must not mutate it. +func (m *Manager) Snapshot() Snapshot { + m.mu.Lock() + defer m.mu.Unlock() + return m.snapshot +} + +// SubscribeChanges returns a buffered channel that receives a +// signal whenever the snapshot changes. The unsubscribe +// callback is safe to call from any goroutine and is +// idempotent. +func (m *Manager) SubscribeChanges() (<-chan struct{}, func()) { + ch := make(chan struct{}, 1) + m.mu.Lock() + m.subscribers[ch] = struct{}{} + m.mu.Unlock() + + var once sync.Once + unsub := func() { + once.Do(func() { + m.mu.Lock() + delete(m.subscribers, ch) + m.mu.Unlock() + // Don't close ch: readers may still be in flight. + }) + } + return ch, unsub +} + +// Resync forces an immediate re-resolve and returns the new +// Snapshot. Resync is safe to call regardless of whether Run +// is active; the work is done synchronously under the +// Manager's mutex either way. +func (m *Manager) Resync(ctx context.Context) (Snapshot, error) { + if ctxErr := ctx.Err(); ctxErr != nil { + return m.Snapshot(), ctxErr + } + m.mu.Lock() + if m.closed { + m.mu.Unlock() + return m.Snapshot(), ErrManagerClosed + } + m.resolveLocked() + snap := m.snapshot + subs := make([]chan struct{}, 0, len(m.subscribers)) + for ch := range m.subscribers { + subs = append(subs, ch) + } + m.mu.Unlock() + for _, ch := range subs { + select { + case ch <- struct{}{}: + default: + } + } + return snap, nil +} + +// signal triggers a re-resolve. Sends are non-blocking; the +// trigger channel has a depth of 1, which coalesces bursts. 
+func (m *Manager) signal() { + select { + case m.trigger <- struct{}{}: + default: + } +} + +// scanRootsLocked returns the list of ScanRoots to feed the +// resolver and watcher. The Manager's mutex must be held. +func (m *Manager) scanRootsLocked() []ScanRoot { + out := make([]ScanRoot, 0, 1+len(m.builtinRoots)+len(m.sources)) + if m.workingDir != nil { + if wd := strings.TrimSpace(m.workingDir()); wd != "" { + out = append(out, ScanRoot{Path: wd}) + } + } + for _, r := range m.builtinRoots { + canonical, err := CanonicalizePath(r) + if err != nil { + continue + } + out = append(out, ScanRoot{Path: canonical}) + } + for _, s := range m.sources { + out = append(out, ScanRoot{Path: s.Path, UserSource: s.Path}) + } + return out +} + +// effectiveAllowedRoots returns the AllowedRoots augmented with +// a sensible fallback (~ and the working directory) when the +// caller did not configure any. +func (m *Manager) effectiveAllowedRoots() []string { + if len(m.allowedRoots) > 0 { + return append([]string{}, m.allowedRoots...) + } + roots := []string{"~"} + if m.workingDir != nil { + if wd := strings.TrimSpace(m.workingDir()); wd != "" { + roots = append(roots, wd) + } + } + return roots +} + +// resolveAndBroadcast computes a fresh snapshot and notifies +// subscribers if the aggregate hash changed. +func (m *Manager) resolveAndBroadcast(ctx context.Context) { + m.mu.Lock() + m.resolveLocked() + subs := make([]chan struct{}, 0, len(m.subscribers)) + for ch := range m.subscribers { + subs = append(subs, ch) + } + m.mu.Unlock() + + // The broadcast is unconditional: Resync waiters that + // triggered the pass without an actual content change + // still need to wake up. Subscribers compare snapshots via + // AggregateHash if they want to filter. + for _, ch := range subs { + select { + case ch <- struct{}{}: + default: + } + } + _ = ctx +} + +// resolveLocked runs the resolver and stamps the snapshot. +// The Manager's mutex must be held. 
+func (m *Manager) resolveLocked() { + roots := m.scanRootsLocked() + snap := m.resolver.Resolve(roots) + m.version++ + snap.Version = m.version + snap.SchemaVersion = m.schemaVersion + // Surface watcher degradation as a snapshot-level error + // when the resolver did not already emit one. + if snap.SnapshotError == "" && m.watcher != nil { + if d := m.watcher.Degraded(); d != "" { + snap.SnapshotError = d + } + } + // Ensure resources are stable-sorted by ID even when the + // resolver did not run them through caps. + sort.Slice(snap.Resources, func(i, j int) bool { + return snap.Resources[i].ID < snap.Resources[j].ID + }) + m.snapshot = snap +} + +// ErrSourceNotFound is returned by RemoveSource when the +// requested path is not in the source list. +var ErrSourceNotFound = xerrors.New("source not found") + +// ErrManagerClosed is returned by methods called after Close. +var ErrManagerClosed = xerrors.New("agentcontext: manager closed") diff --git a/agent/agentcontext/manager_test.go b/agent/agentcontext/manager_test.go new file mode 100644 index 0000000000..64a4540a7f --- /dev/null +++ b/agent/agentcontext/manager_test.go @@ -0,0 +1,260 @@ +package agentcontext_test + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/agent/agentcontext" + "github.com/coder/coder/v2/testutil" +) + +func newTestManager(t *testing.T, opts agentcontext.ManagerOptions) *agentcontext.Manager { + t.Helper() + opts.Logger = testutil.Logger(t).Named("agentcontext-test") + m, err := agentcontext.NewManager(opts) + require.NoError(t, err) + t.Cleanup(func() { _ = m.Close() }) + return m +} + +func TestManager_InitialSnapshotIsPopulated(t *testing.T) { + t.Parallel() + dir := t.TempDir() + mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "boot snapshot") + + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return dir }, + }) + + snap := m.Snapshot() + 
require.Equal(t, uint64(1), snap.Version) + require.Equal(t, agentcontext.CurrentSchemaVersion, snap.SchemaVersion) + require.Len(t, snap.Resources, 1) +} + +func TestManager_AddSourceTriggersResolve(t *testing.T) { + t.Parallel() + wd := t.TempDir() + src := t.TempDir() + mustWriteFile(t, filepath.Join(src, "AGENTS.md"), "from source") + + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return wd }, + AllowedRoots: []string{wd, src}, + }) + + ctx := testutil.Context(t, testutil.WaitLong) + go func() { _ = m.Run(ctx) }() + + t.Cleanup(func() { _ = m.Close() }) + + // Subscribe before mutating so we observe the broadcast. + ch, unsub := m.SubscribeChanges() + defer unsub() + + added, err := m.AddSource(agentcontext.Source{Path: src}) + require.NoError(t, err) + require.Equal(t, src, added.Path) + + select { + case <-ch: + case <-time.After(testutil.WaitShort): + t.Fatalf("expected a change broadcast after AddSource") + } + + snap := m.Snapshot() + require.Greater(t, snap.Version, uint64(1)) + + found := false + for _, r := range snap.Resources { + if r.Kind == agentcontext.KindInstructionFile && r.SourcePath == src { + found = true + } + } + require.True(t, found, "expected AGENTS.md attributed to the user source") +} + +func TestManager_AddSourceRejectsOutsideAllowedRoots(t *testing.T) { + t.Parallel() + wd := t.TempDir() + outside := t.TempDir() + + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return wd }, + AllowedRoots: []string{wd}, + }) + + _, err := m.AddSource(agentcontext.Source{Path: outside}) + require.Error(t, err) +} + +func TestManager_AddSourceIsIdempotent(t *testing.T) { + t.Parallel() + wd := t.TempDir() + src := t.TempDir() + + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return wd }, + AllowedRoots: []string{wd, src}, + }) + + added1, err := m.AddSource(agentcontext.Source{Path: src}) + require.NoError(t, err) + added2, err := 
m.AddSource(agentcontext.Source{Path: src}) + require.NoError(t, err) + require.Equal(t, added1.Path, added2.Path) + + sources := m.Sources() + require.Len(t, sources, 1) +} + +func TestManager_RemoveSource(t *testing.T) { + t.Parallel() + wd := t.TempDir() + src := t.TempDir() + + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return wd }, + AllowedRoots: []string{wd, src}, + }) + + _, err := m.AddSource(agentcontext.Source{Path: src}) + require.NoError(t, err) + require.NoError(t, m.RemoveSource(src)) + require.Empty(t, m.Sources()) + + err = m.RemoveSource(src) + require.ErrorIs(t, err, agentcontext.ErrSourceNotFound) +} + +func TestManager_HasSource(t *testing.T) { + t.Parallel() + wd := t.TempDir() + src := t.TempDir() + + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return wd }, + AllowedRoots: []string{wd, src}, + }) + + canonical, ok := m.HasSource(src) + require.False(t, ok) + require.Equal(t, src, canonical) + + _, err := m.AddSource(agentcontext.Source{Path: src}) + require.NoError(t, err) + + canonical, ok = m.HasSource(src) + require.True(t, ok) + require.Equal(t, src, canonical) +} + +func TestManager_ResyncReturnsLatestSnapshot(t *testing.T) { + t.Parallel() + wd := t.TempDir() + mustWriteFile(t, filepath.Join(wd, "AGENTS.md"), "first") + + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return wd }, + }) + + ctx := testutil.Context(t, testutil.WaitLong) + runDone := make(chan struct{}) + go func() { + defer close(runDone) + _ = m.Run(ctx) + }() + t.Cleanup(func() { + _ = m.Close() + <-runDone + }) + + // Mutate AGENTS.md and call Resync. The returned + // snapshot must reflect the new content. 
+ require.NoError(t, os.WriteFile(filepath.Join(wd, "AGENTS.md"), []byte("second content edit"), 0o600)) + + snap, err := m.Resync(ctx) + require.NoError(t, err) + + require.Len(t, snap.Resources, 1) + require.Equal(t, "second content edit", string(snap.Resources[0].Payload)) +} + +func TestManager_InitialSourcesSeeded(t *testing.T) { + t.Parallel() + wd := t.TempDir() + src := t.TempDir() + mustWriteFile(t, filepath.Join(src, "AGENTS.md"), "from initial") + + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return wd }, + AllowedRoots: []string{wd, src}, + InitialSources: []agentcontext.Source{{Path: src}}, + }) + + sources := m.Sources() + require.Len(t, sources, 1) + require.Equal(t, src, sources[0].Path) + + snap := m.Snapshot() + require.Len(t, snap.Resources, 1) + require.Equal(t, src, snap.Resources[0].SourcePath) +} + +func TestManager_CloseIsIdempotent(t *testing.T) { + t.Parallel() + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return t.TempDir() }, + }) + require.NoError(t, m.Close()) + require.NoError(t, m.Close()) +} + +func TestManager_RunOnce(t *testing.T) { + t.Parallel() + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return t.TempDir() }, + }) + ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) + defer cancel() + go func() { _ = m.Run(ctx) }() + // Brief wait so Run has a chance to set running=true. 
+ time.Sleep(50 * time.Millisecond) + + err := m.Run(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "more than once") + cancel() + _ = m.Close() +} + +func TestManager_SubscribeBroadcastOnChange(t *testing.T) { + t.Parallel() + wd := t.TempDir() + src := t.TempDir() + + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return wd }, + AllowedRoots: []string{wd, src}, + }) + + ctx := testutil.Context(t, testutil.WaitLong) + go func() { _ = m.Run(ctx) }() + + ch, unsub := m.SubscribeChanges() + defer unsub() + + _, err := m.AddSource(agentcontext.Source{Path: src}) + require.NoError(t, err) + + select { + case <-ch: + case <-time.After(testutil.WaitShort): + t.Fatal("expected subscriber to be notified") + } +} diff --git a/agent/agentcontext/mcp.go b/agent/agentcontext/mcp.go new file mode 100644 index 0000000000..cfce2bce3f --- /dev/null +++ b/agent/agentcontext/mcp.go @@ -0,0 +1,24 @@ +package agentcontext + +// MCPProvider supplies the live MCP server portion of a +// snapshot. Implementations typically wrap an existing MCP +// manager (e.g. agent/x/agentmcp.Manager) and translate each +// server's tool list into a KindMCPServer resource. +// +// The interface is intentionally minimal so the existing MCP +// lifecycle code can be reused without refactoring; a follow-up +// change absorbs the lifecycle into this package. +type MCPProvider interface { + // MCPResources returns one Resource per MCP server known + // to the provider. Each Resource must: + // + // - Have Kind == KindMCPServer. + // - Use the server name as Source. + // - Populate ContentHash over the canonical encoding + // of the tool list so changes flip the dirty bit. + // - Carry a Description summarizing the server. + // + // Implementations should never block; the resolver calls + // this on every re-resolve. 
+ MCPResources() []Resource +} diff --git a/agent/agentcontext/paths.go b/agent/agentcontext/paths.go new file mode 100644 index 0000000000..518d9d5e62 --- /dev/null +++ b/agent/agentcontext/paths.go @@ -0,0 +1,121 @@ +package agentcontext + +import ( + "os" + "path/filepath" + "strings" + + "golang.org/x/xerrors" +) + +// CanonicalizePath produces the canonical form of a user- +// supplied path. The result is absolute, has ~ expanded, has +// path-traversal segments collapsed, and has symlinks resolved +// when the target exists. The path is left lexically clean if +// it does not yet exist (so adding a not-yet-created directory +// remains possible). +// +// CanonicalizePath returns an error when the input is empty or relative. +func CanonicalizePath(raw string) (string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", xerrors.New("path is empty") + } + + // Expand ~ and ~/ prefixes against the current user's home + // directory. Other ~user forms are not supported on + // purpose; the agent runs as a known user. + if raw == "~" || strings.HasPrefix(raw, "~/") { + home, err := os.UserHomeDir() + if err != nil { + return "", xerrors.Errorf("expand home dir: %w", err) + } + if raw == "~" { + raw = home + } else { + raw = filepath.Join(home, raw[2:]) + } + } + + if !filepath.IsAbs(raw) { + // Fail closed: relative paths could mean different + // things depending on the agent's working directory at + // add-time, so require the caller to absolutize first. + return "", xerrors.Errorf("path %q is not absolute", raw) + } + + cleaned := filepath.Clean(raw) + if resolved, err := filepath.EvalSymlinks(cleaned); err == nil { + return resolved, nil + } + return cleaned, nil +} + +// ValidateSourcePath enforces the path-validation rules from +// the RFC's Authorization section. It rejects: +// +// - Paths containing ".." segments after expansion. +// - Paths resolving outside the supplied allowedRoots, unless +// allowedRoots is empty (which disables the check). 
+// +// allowedRoots are canonicalized lazily; missing roots are +// silently skipped so a workspace with no $HOME does not break +// validation for project-relative roots. +func ValidateSourcePath(canonical string, allowedRoots []string) error { + if canonical == "" { + return xerrors.New("path is empty") + } + // filepath.Clean drops "." but leaves ".." when no parent + // is available. Reject defensively. + for _, part := range strings.Split(canonical, string(os.PathSeparator)) { + if part == ".." { + return xerrors.Errorf("path %q contains parent traversal segments", canonical) + } + } + + if len(allowedRoots) == 0 { + return nil + } + + // Build canonical, deduplicated allowed roots. Missing + // roots (e.g. an unconfigured ~/.claude/) are skipped. + roots := make([]string, 0, len(allowedRoots)) + seen := make(map[string]struct{}, len(allowedRoots)) + for _, raw := range allowedRoots { + c, err := CanonicalizePath(raw) + if err != nil { + continue + } + if _, ok := seen[c]; ok { + continue + } + seen[c] = struct{}{} + roots = append(roots, c) + } + if len(roots) == 0 { + // All configured roots were invalid; treat as "deny + // everything" so misconfiguration fails closed. + return xerrors.Errorf("path %q is not inside any allowed root", canonical) + } + + for _, root := range roots { + if pathHasPrefix(canonical, root) { + return nil + } + } + return xerrors.Errorf("path %q is not inside any allowed root", canonical) +} + +// pathHasPrefix reports whether path is equal to or a +// descendant of prefix. Both arguments must already be clean, +// absolute paths. 
+func pathHasPrefix(path, prefix string) bool { + if path == prefix { + return true + } + withSep := prefix + if !strings.HasSuffix(withSep, string(os.PathSeparator)) { + withSep += string(os.PathSeparator) + } + return strings.HasPrefix(path, withSep) +} diff --git a/agent/agentcontext/paths_test.go b/agent/agentcontext/paths_test.go new file mode 100644 index 0000000000..7268edf1b9 --- /dev/null +++ b/agent/agentcontext/paths_test.go @@ -0,0 +1,112 @@ +package agentcontext_test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/agent/agentcontext" +) + +func TestCanonicalizePath_AbsoluteCleansAndResolves(t *testing.T) { + t.Parallel() + dir := t.TempDir() + got, err := agentcontext.CanonicalizePath(filepath.Join(dir, "a", "..", "b")) + require.NoError(t, err) + // Path does not exist; EvalSymlinks fails. Result is + // lexically cleaned: filepath.Clean drops the "..". + require.Equal(t, filepath.Join(dir, "b"), got) +} + +func TestCanonicalizePath_RelativeRejected(t *testing.T) { + t.Parallel() + _, err := agentcontext.CanonicalizePath("relative/path") + require.Error(t, err) +} + +//nolint:paralleltest,tparallel // Uses t.Setenv. +func TestCanonicalizePath_TildeExpansion(t *testing.T) { + t.Setenv("HOME", "/tmp/home") + got, err := agentcontext.CanonicalizePath("~/.coder") + require.NoError(t, err) + require.Equal(t, "/tmp/home/.coder", got) +} + +//nolint:paralleltest,tparallel // Uses t.Setenv. 
+func TestCanonicalizePath_BareTildeExpandsToHome(t *testing.T) { + t.Setenv("HOME", "/tmp/home") + got, err := agentcontext.CanonicalizePath("~") + require.NoError(t, err) + require.Equal(t, "/tmp/home", got) +} + +func TestCanonicalizePath_FollowsSymlinks(t *testing.T) { + t.Parallel() + dir := t.TempDir() + realDir := filepath.Join(dir, "real") + link := filepath.Join(dir, "link") + require.NoError(t, os.MkdirAll(realDir, 0o755)) + require.NoError(t, os.Symlink(realDir, link)) + + got, err := agentcontext.CanonicalizePath(link) + require.NoError(t, err) + // On macOS the temp dir is itself symlinked; both realDir and got + // pass through the same EvalSymlinks so they line up. + want, err := filepath.EvalSymlinks(realDir) + require.NoError(t, err) + require.Equal(t, want, got) +} + +func TestValidateSourcePath_RejectsParentSegments(t *testing.T) { + t.Parallel() + err := agentcontext.ValidateSourcePath("/a/../b", []string{"/a"}) + require.Error(t, err) + require.Contains(t, err.Error(), "parent traversal") +} + +func TestValidateSourcePath_AllowsInsideRoot(t *testing.T) { + t.Parallel() + dir := t.TempDir() + child := filepath.Join(dir, "child") + require.NoError(t, os.MkdirAll(child, 0o755)) + + require.NoError(t, agentcontext.ValidateSourcePath(child, []string{dir})) + require.NoError(t, agentcontext.ValidateSourcePath(dir, []string{dir})) +} + +func TestValidateSourcePath_RejectsOutsideRoot(t *testing.T) { + t.Parallel() + root := t.TempDir() + other := t.TempDir() + err := agentcontext.ValidateSourcePath(other, []string{root}) + require.Error(t, err) + require.Contains(t, err.Error(), "not inside any allowed root") +} + +func TestValidateSourcePath_EmptyAllowedRootsBypass(t *testing.T) { + t.Parallel() + require.NoError(t, agentcontext.ValidateSourcePath("/anywhere", nil)) +} + +func TestValidateSourcePath_InvalidRootsFailClosed(t *testing.T) { + t.Parallel() + // All allowed roots are relative and therefore invalid; + // validation must fail closed. 
+ err := agentcontext.ValidateSourcePath("/anywhere", []string{"relative-only"}) + require.Error(t, err) +} + +func TestValidateSourcePath_PathPrefixIsPathAware(t *testing.T) { + t.Parallel() + // "/a-prefix" is not inside "/a", even though it starts + // with the same bytes. + dir := t.TempDir() + sibling := strings.TrimRight(dir, string(os.PathSeparator)) + "-sibling" + require.NoError(t, os.MkdirAll(sibling, 0o755)) + t.Cleanup(func() { _ = os.RemoveAll(sibling) }) + err := agentcontext.ValidateSourcePath(sibling, []string{dir}) + require.Error(t, err) +} diff --git a/agent/agentcontext/push.go b/agent/agentcontext/push.go new file mode 100644 index 0000000000..8f43066f0d --- /dev/null +++ b/agent/agentcontext/push.go @@ -0,0 +1,182 @@ +package agentcontext + +import ( + "context" + "errors" + "time" + + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" +) + +// PushRequest is the wire-format-independent payload the +// Manager hands to a Pusher. It mirrors the protobuf +// PushContextStateRequest message reserved in the RFC. +// +// Keeping the shape in plain Go lets this package compile +// without bumping the drpc proto version. The follow-up +// integration change can add a thin adapter that converts +// PushRequest to proto and back. +type PushRequest struct { + Version uint64 + AggregateHash [32]byte + Resources []Resource + Initial bool + SchemaVersion uint64 + SnapshotError string +} + +// PushResponse is the wire-format-independent return value of +// a push. +type PushResponse struct { + Accepted bool +} + +// Pusher delivers snapshots to coderd. Concrete implementations +// wrap a drpc client (proto v30 and later) or, in tests, a +// recording in-memory fake. +// +// PushContextState must respect ctx cancellation; the Manager +// retries on transient errors with backoff but stops on +// ErrPushUnimplemented. 
+type Pusher interface { + PushContextState(ctx context.Context, req *PushRequest) (*PushResponse, error) +} + +// ErrPushUnimplemented signals that the coderd peer does not +// implement PushContextState. RunPush stops pushing for the +// remainder of the connection. +var ErrPushUnimplemented = xerrors.New("agentcontext: PushContextState unimplemented") + +// PushOptions parameterizes RunPush. +type PushOptions struct { + // Logger receives push success/failure diagnostics. + Logger slog.Logger + // InitialBackoff is the wait before the first retry. + // Default 250ms. + InitialBackoff time.Duration + // MaxBackoff caps the retry wait. Default 30s. + MaxBackoff time.Duration +} + +// RunPush ships the current snapshot to the Pusher, then ships +// every subsequent snapshot whenever the Manager broadcasts a +// change. RunPush returns when ctx is canceled, when the +// Manager is closed, or when the Pusher signals +// ErrPushUnimplemented. +// +// The first push is always sent with Initial=true so coderd can +// distinguish a fresh boot from a drift event. +func (m *Manager) RunPush(ctx context.Context, p Pusher, opts PushOptions) error { + if p == nil { + return xerrors.New("agentcontext: Pusher is required") + } + logger := opts.Logger + initialBackoff := opts.InitialBackoff + if initialBackoff <= 0 { + initialBackoff = 250 * time.Millisecond + } + maxBackoff := opts.MaxBackoff + if maxBackoff <= 0 { + maxBackoff = 30 * time.Second + } + + changes, unsub := m.SubscribeChanges() + defer unsub() + + // First push uses the snapshot computed by NewManager. 
+ initial := true + for { + snap := m.Snapshot() + req := snapshotToPushRequest(snap, initial) + + err := pushWithRetry(ctx, p, req, initialBackoff, maxBackoff, logger) + switch { + case err == nil: + initial = false + case errors.Is(err, ErrPushUnimplemented): + logger.Warn(ctx, "agentcontext: coderd peer does not implement PushContextState; stopping") + return nil + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): + return ctx.Err() + default: + // Should be unreachable: pushWithRetry only + // returns terminal errors. Log and continue. + logger.Warn(ctx, "agentcontext: push terminated with non-retried error", slog.Error(err)) + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.closedCh: + return nil + case _, ok := <-changes: + if !ok { + return nil + } + } + } +} + +// pushWithRetry retries transient errors with exponential +// backoff capped at maxBackoff. The retry loop exits when: +// +// - ctx is canceled (returns ctx.Err()). +// - The Pusher returns nil (success). +// - The Pusher returns ErrPushUnimplemented (propagated). +func pushWithRetry( + ctx context.Context, + p Pusher, + req *PushRequest, + initialBackoff, maxBackoff time.Duration, + logger slog.Logger, +) error { + backoff := initialBackoff + for { + resp, err := p.PushContextState(ctx, req) + if err == nil { + if resp != nil && !resp.Accepted { + // Out-of-order or replayed push. Do not + // retry; the next change will redeliver + // the snapshot with a higher version. 
+ logger.Debug(ctx, "agentcontext: push rejected, awaiting next change", + slog.F("version", req.Version)) + } + return nil + } + if errors.Is(err, ErrPushUnimplemented) { + return err + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return err + } + logger.Warn(ctx, "agentcontext: push failed, retrying", + slog.F("version", req.Version), + slog.F("backoff", backoff), + slog.Error(err)) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(backoff): + } + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + } +} + +// snapshotToPushRequest copies the Snapshot into the wire +// representation. The Resources slice is reused; callers must +// not mutate it. +func snapshotToPushRequest(s Snapshot, initial bool) *PushRequest { + return &PushRequest{ + Version: s.Version, + AggregateHash: s.AggregateHash, + Resources: s.Resources, + Initial: initial, + SchemaVersion: s.SchemaVersion, + SnapshotError: s.SnapshotError, + } +} diff --git a/agent/agentcontext/push_test.go b/agent/agentcontext/push_test.go new file mode 100644 index 0000000000..f533879e1f --- /dev/null +++ b/agent/agentcontext/push_test.go @@ -0,0 +1,197 @@ +package agentcontext_test + +import ( + "context" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/agent/agentcontext" + "github.com/coder/coder/v2/testutil" +) + +// fakePusher records every push and lets the test control the +// returned response and error. +type fakePusher struct { + mu sync.Mutex + requests []*agentcontext.PushRequest + resp *agentcontext.PushResponse + err error + // errOnce is non-nil to simulate a single transient + // failure followed by success. 
+ errOnce error + signal chan struct{} +} + +func newFakePusher() *fakePusher { + return &fakePusher{ + resp: &agentcontext.PushResponse{Accepted: true}, + signal: make(chan struct{}, 16), + } +} + +func (p *fakePusher) PushContextState(_ context.Context, req *agentcontext.PushRequest) (*agentcontext.PushResponse, error) { + p.mu.Lock() + defer p.mu.Unlock() + p.requests = append(p.requests, req) + if p.errOnce != nil { + err := p.errOnce + p.errOnce = nil + return nil, err + } + select { + case p.signal <- struct{}{}: + default: + } + return p.resp, p.err +} + +func (p *fakePusher) snapshot() []*agentcontext.PushRequest { + p.mu.Lock() + defer p.mu.Unlock() + out := make([]*agentcontext.PushRequest, len(p.requests)) + copy(out, p.requests) + return out +} + +func TestRunPush_FirstPushIsInitial(t *testing.T) { + t.Parallel() + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v1"), 0o600)) + + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return dir }, + }) + + p := newFakePusher() + ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) + defer cancel() + + pushDone := make(chan error, 1) + go func() { + pushDone <- m.RunPush(ctx, p, agentcontext.PushOptions{ + Logger: testutil.Logger(t).Named("push"), + }) + }() + + // Wait for the first push. 
+ select { + case <-p.signal: + case <-time.After(testutil.WaitShort): + t.Fatalf("expected initial push") + } + + requests := p.snapshot() + require.Len(t, requests, 1) + require.True(t, requests[0].Initial, "first push must be initial") + require.Equal(t, uint64(1), requests[0].Version) + + cancel() + require.ErrorIs(t, <-pushDone, context.Canceled) +} + +func TestRunPush_SubsequentPushOnChange(t *testing.T) { + t.Parallel() + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v1"), 0o600)) + + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return dir }, + }) + + p := newFakePusher() + ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) + defer cancel() + + pushDone := make(chan error, 1) + go func() { + pushDone <- m.RunPush(ctx, p, agentcontext.PushOptions{ + Logger: testutil.Logger(t).Named("push"), + }) + }() + + // Initial push. + <-p.signal + + // Trigger a resync via Resync. + require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v2"), 0o600)) + _, err := m.Resync(ctx) + require.NoError(t, err) + + // Second push. 
+ select { + case <-p.signal: + case <-time.After(testutil.WaitShort): + t.Fatalf("expected second push after resync") + } + + requests := p.snapshot() + require.GreaterOrEqual(t, len(requests), 2) + require.False(t, requests[1].Initial, "subsequent pushes must not be Initial") + + cancel() + require.ErrorIs(t, <-pushDone, context.Canceled) +} + +func TestRunPush_StopsOnUnimplemented(t *testing.T) { + t.Parallel() + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return t.TempDir() }, + }) + + p := newFakePusher() + p.err = agentcontext.ErrPushUnimplemented + + ctx := testutil.Context(t, testutil.WaitShort) + err := m.RunPush(ctx, p, agentcontext.PushOptions{ + Logger: testutil.Logger(t).Named("push"), + }) + require.NoError(t, err, "Unimplemented must stop the loop cleanly") +} + +func TestRunPush_RetriesTransientError(t *testing.T) { + t.Parallel() + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return t.TempDir() }, + }) + + p := newFakePusher() + p.errOnce = xerrors.New("transient") + + ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort)) + defer cancel() + pushDone := make(chan error, 1) + go func() { + pushDone <- m.RunPush(ctx, p, agentcontext.PushOptions{ + Logger: testutil.Logger(t).Named("push"), + InitialBackoff: 10 * time.Millisecond, + }) + }() + + // First push hits transient, second succeeds. 
+ select { + case <-p.signal: + case <-time.After(testutil.WaitShort): + t.Fatalf("expected push after transient error") + } + require.GreaterOrEqual(t, len(p.snapshot()), 2) + + cancel() + <-pushDone +} + +func TestRunPush_NilPusherErrors(t *testing.T) { + t.Parallel() + m := newTestManager(t, agentcontext.ManagerOptions{ + WorkingDir: func() string { return t.TempDir() }, + }) + err := m.RunPush(context.Background(), nil, agentcontext.PushOptions{ + Logger: testutil.Logger(t).Named("push"), + }) + require.Error(t, err) +} diff --git a/agent/agentcontext/resolve.go b/agent/agentcontext/resolve.go new file mode 100644 index 0000000000..80eb12fcfd --- /dev/null +++ b/agent/agentcontext/resolve.go @@ -0,0 +1,613 @@ +package agentcontext + +import ( + "crypto/sha256" + "errors" + "fmt" + "io" + "io/fs" + "math" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + + "github.com/coder/coder/v2/codersdk/workspacesdk" +) + +// Default caps. Copied from the RFC. The Manager exposes +// overrides via Options. +const ( + // DefaultMaxResourceBytes is the per-resource payload cap. + // Resources whose payload exceeds this size are emitted + // with Status == StatusOversize and an empty Payload. + DefaultMaxResourceBytes = 64 * 1024 + // DefaultMaxSnapshotBytes is the aggregate payload cap. + // Resources past this cap are emitted with Status == + // StatusExcluded. + DefaultMaxSnapshotBytes = 2 * 1024 * 1024 + // DefaultMaxResources is the resource count cap. Resources + // past this cap are emitted with Status == StatusExcluded. + DefaultMaxResources = 500 + // DefaultMaxScanDepth bounds how deep the recursive walk + // descends from each scan root. The default avoids runaway + // scans in node_modules / vendor / .git trees while still + // covering realistic monorepo layouts. + DefaultMaxScanDepth = 8 +) + +// File-name conventions recognized by the v1 resolver. +var ( + // instructionFileNames are picked up from any scan root. 
+ // Matching is case-insensitive on the basename. + instructionFileNames = []string{ + "AGENTS.md", + "CLAUDE.md", + ".cursorrules", + } + // mcpConfigFileName is recognized at any depth under a + // scan root. + mcpConfigFileName = ".mcp.json" + // skillMetaFileName is the file inside a skill directory + // that carries the skill front-matter. + skillMetaFileName = "SKILL.md" +) + +// skipDirNames are directory basenames that the recursive walk +// never descends into. The list mirrors what most language +// tool-chains treat as opaque. +var skipDirNames = map[string]struct{}{ + ".git": {}, + ".hg": {}, + ".svn": {}, + "node_modules": {}, + "vendor": {}, + "target": {}, + "dist": {}, + "build": {}, + ".venv": {}, + "__pycache__": {}, +} + +// skillsParentNames are directory basenames that signal a +// skills container; their immediate children are scanned for +// SKILL.md files. +var skillsParentNames = map[string]struct{}{ + "skills": {}, + ".agents": {}, // covers ".agents/skills/" + "agents": {}, + "plugins": {}, // claude code plugin cache layout + "cache": {}, + ".coder": {}, + ".claude": {}, + "skills-dir": {}, +} + +// recognizedInstructionFile reports whether name is one of the +// instruction-file conventions, case-insensitively. +func recognizedInstructionFile(name string) bool { + for _, candidate := range instructionFileNames { + if strings.EqualFold(name, candidate) { + return true + } + } + return false +} + +// Resolver walks one or more scan roots and produces a snapshot +// of every recognized resource it finds. The Resolver is +// stateless; the Manager owns the scan-root list and orchestrates +// successive resolves. +type Resolver struct { + // MaxResourceBytes caps the per-resource payload size. Use + // DefaultMaxResourceBytes if zero. + MaxResourceBytes uint64 + // MaxSnapshotBytes caps the aggregate payload size. Use + // DefaultMaxSnapshotBytes if zero. + MaxSnapshotBytes uint64 + // MaxResources caps the resource count. 
Use + // DefaultMaxResources if zero. + MaxResources int + // MaxDepth caps the directory walk depth. Use + // DefaultMaxScanDepth if zero. + MaxDepth int + // MCP, when non-nil, is consulted after the filesystem + // pass and contributes any KindMCPServer resources for + // live MCP servers. + MCP MCPProvider +} + +// ScanRoot describes a single directory or file the resolver +// should examine. +type ScanRoot struct { + // Path is the absolute path. Symlinks should already be + // resolved. + Path string + // UserSource is the canonical source path the user + // declared, when this root came from a user-added Source. + // Empty for built-in roots. + UserSource string +} + +// Resolve walks the supplied scan roots and returns a Snapshot. +// The version and schemaVersion fields are stamped by the +// caller; Resolve fills everything else. +func (r *Resolver) Resolve(roots []ScanRoot) Snapshot { + res := r.normalize() + resources, snapErrs := res.walk(roots) + resources = res.applyCaps(resources) + + // Append MCP server resources after the filesystem caps + // are applied so a runaway MCP server cannot crowd out + // instruction files. + if r.MCP != nil { + mcp := r.MCP.MCPResources() + resources = append(resources, mcp...) + // MCP resources may push the aggregate over the cap. + // Re-apply count and size limits to MCP entries only. + resources, snapErrs = res.applyMCPCaps(resources, snapErrs) + } + + // Deterministic order by ID for stable IDs and hashes. + sort.Slice(resources, func(i, j int) bool { + return resources[i].ID < resources[j].ID + }) + + var payloadBytes uint64 + for _, r := range resources { + payloadBytes += uint64(len(r.Payload)) + } + + hash := ComputeAggregateHash(resources) + + snap := Snapshot{ + Resources: resources, + AggregateHash: hash, + PayloadBytes: payloadBytes, + } + if len(snapErrs) > 0 { + // Pick the most severe single error. 
Today every + // snapshot-level problem is "warning equivalent" so + // the first one wins; the design reserves the field + // for a singular message. + snap.SnapshotError = snapErrs[0] + } + return snap +} + +func (r *Resolver) normalize() *Resolver { + out := *r + if out.MaxResourceBytes == 0 { + out.MaxResourceBytes = DefaultMaxResourceBytes + } + if out.MaxSnapshotBytes == 0 { + out.MaxSnapshotBytes = DefaultMaxSnapshotBytes + } + if out.MaxResources == 0 { + out.MaxResources = DefaultMaxResources + } + if out.MaxDepth == 0 { + out.MaxDepth = DefaultMaxScanDepth + } + return &out +} + +// walk traverses every scan root and produces an unordered +// resource list. Aggregate caps are applied separately. +func (r *Resolver) walk(roots []ScanRoot) (resources []Resource, snapErrs []string) { + // Dedup roots by canonical path. The first occurrence + // wins so user-added roots that overlap with a built-in + // root attribute resources to the built-in. + seenRoot := make(map[string]struct{}, len(roots)) + dedup := make([]ScanRoot, 0, len(roots)) + for _, root := range roots { + if root.Path == "" { + continue + } + if _, ok := seenRoot[root.Path]; ok { + continue + } + seenRoot[root.Path] = struct{}{} + dedup = append(dedup, root) + } + + // Deduplicate resources across roots by ID. Without this, + // a built-in root and a user root that both cover the + // same project tree would double-count AGENTS.md. + seenID := make(map[string]struct{}) + + for _, root := range dedup { + info, err := os.Stat(root.Path) + if err != nil { + // Missing roots silently fall through. The user + // either added a path that does not exist yet or + // removed it later. The watcher will surface + // re-creation as a change event. + continue + } + if !info.IsDir() { + // Single-file roots are classified directly. 
+ if res, ok := r.classifyFile(root.Path, info, root.UserSource); ok { + if _, dup := seenID[res.ID]; !dup { + seenID[res.ID] = struct{}{} + resources = append(resources, res) + } + } + continue + } + walkErr := r.walkDir(root, &resources, seenID) + if walkErr != nil { + snapErrs = append(snapErrs, fmt.Sprintf("walk %q: %s", root.Path, walkErr)) + } + } + return resources, snapErrs +} + +// walkDir performs the recursive descent for a single scan +// directory. It honors r.MaxDepth and skipDirNames. +func (r *Resolver) walkDir(root ScanRoot, out *[]Resource, seenID map[string]struct{}) error { + rootDepth := strings.Count(filepath.Clean(root.Path), string(os.PathSeparator)) + maxDepth := rootDepth + r.MaxDepth + + return filepath.WalkDir(root.Path, func(path string, d fs.DirEntry, err error) error { + if err != nil { + // Surface the error as Unreadable when we can + // associate it with a single recognized file; + // otherwise let the walk continue. + if d != nil && !d.IsDir() { + if recognizedInstructionFile(d.Name()) || + d.Name() == mcpConfigFileName || + d.Name() == skillMetaFileName { + res := Resource{ + ID: resourceID(KindInstructionFile, path), + Kind: KindInstructionFile, + Source: path, + SizeBytes: 0, + Status: StatusUnreadable, + Error: err.Error(), + SourcePath: root.UserSource, + } + if _, dup := seenID[res.ID]; !dup { + seenID[res.ID] = struct{}{} + *out = append(*out, res) + } + } + } + if errors.Is(err, fs.ErrPermission) { + // Permission errors on a directory: skip the + // subtree but continue walking siblings. 
+ if d != nil && d.IsDir() { + return fs.SkipDir + } + } + return nil + } + + if d.IsDir() { + if strings.Count(path, string(os.PathSeparator)) > maxDepth { + return fs.SkipDir + } + if _, skip := skipDirNames[d.Name()]; skip && path != root.Path { + return fs.SkipDir + } + // If we are entering a "skills container" + // directory (".agents/skills", "~/.coder/skills", + // "plugins/*/skills"), eagerly emit skill + // resources for its immediate subdirectories. + if isSkillsContainer(path) { + r.emitSkillsFromContainer(path, root, out, seenID) + } + return nil + } + + // Regular file. + info, statErr := d.Info() + if statErr != nil { + return nil + } + if res, ok := r.classifyFile(path, info, root.UserSource); ok { + if _, dup := seenID[res.ID]; dup { + return nil + } + seenID[res.ID] = struct{}{} + *out = append(*out, res) + } + return nil + }) +} + +// classifyFile inspects a single file path and produces a +// Resource when the basename matches a recognized convention. +func (r *Resolver) classifyFile(path string, info fs.FileInfo, userSource string) (Resource, bool) { + name := info.Name() + switch { + case recognizedInstructionFile(name): + return r.readInstructionFile(path, info, userSource), true + case name == mcpConfigFileName: + return r.readMCPConfig(path, info, userSource), true + case name == skillMetaFileName: + // SKILL.md outside a skills container is still a + // valid skill if its parent directory name matches + // the front-matter name. emitSkillsFromContainer + // already handles the common case; here we cover + // "user adds a single SKILL.md file as a source". + res, ok := r.readSkillMeta(path, info, userSource) + return res, ok + default: + return Resource{}, false + } +} + +// readInstructionFile reads an instruction file and produces a +// KindInstructionFile resource. The file is read into memory +// with the per-resource cap applied. 
+func (r *Resolver) readInstructionFile(path string, info fs.FileInfo, userSource string) Resource { + res := Resource{ + ID: resourceID(KindInstructionFile, path), + Kind: KindInstructionFile, + Source: path, + SizeBytes: safeUint64(info.Size()), + SourcePath: userSource, + } + if safeUint64(info.Size()) > r.MaxResourceBytes { + res.Status = StatusOversize + res.Error = fmt.Sprintf("file size %d exceeds per-resource cap of %d bytes", info.Size(), r.MaxResourceBytes) + // Still hash the (capped) content so a fix is + // detectable. + if data, err := readFileCapped(path, safeInt64(r.MaxResourceBytes)); err == nil { + res.ContentHash = sha256.Sum256(data) + } + return res + } + data, err := os.ReadFile(path) + if err != nil { + res.Status = StatusUnreadable + res.Error = err.Error() + return res + } + res.Payload = data + res.ContentHash = sha256.Sum256(data) + // Description is just the first non-empty line. + res.Description = firstLine(string(data)) + return res +} + +// readMCPConfig reads a .mcp.json file and produces a +// KindMCPConfig resource. Parsing is left to consumers; the +// resolver only enforces JSON shape lightly via size and Unix +// newline conversion. Future work: detect malformed JSON and +// surface StatusInvalid. 
+func (r *Resolver) readMCPConfig(path string, info fs.FileInfo, userSource string) Resource { + res := Resource{ + ID: resourceID(KindMCPConfig, path), + Kind: KindMCPConfig, + Source: path, + SizeBytes: safeUint64(info.Size()), + SourcePath: userSource, + } + if safeUint64(info.Size()) > r.MaxResourceBytes { + res.Status = StatusOversize + res.Error = fmt.Sprintf("file size %d exceeds per-resource cap of %d bytes", info.Size(), r.MaxResourceBytes) + if data, err := readFileCapped(path, safeInt64(r.MaxResourceBytes)); err == nil { + res.ContentHash = sha256.Sum256(data) + } + return res + } + data, err := os.ReadFile(path) + if err != nil { + res.Status = StatusUnreadable + res.Error = err.Error() + return res + } + res.Payload = data + res.ContentHash = sha256.Sum256(data) + return res +} + +// readSkillMeta reads a SKILL.md file, parses its front-matter, +// and emits a KindSkill resource. The name encoded in the +// front-matter must match the parent directory's basename to +// be considered valid; otherwise Status is StatusInvalid. 
+func (r *Resolver) readSkillMeta(path string, info fs.FileInfo, userSource string) (Resource, bool) { + parent := filepath.Base(filepath.Dir(path)) + res := Resource{ + ID: resourceID(KindSkill, filepath.Dir(path)), + Kind: KindSkill, + Source: filepath.Dir(path), + SizeBytes: safeUint64(info.Size()), + SourcePath: userSource, + } + if safeUint64(info.Size()) > r.MaxResourceBytes { + res.Status = StatusOversize + res.Error = fmt.Sprintf("file size %d exceeds per-resource cap of %d bytes", info.Size(), r.MaxResourceBytes) + return res, true + } + data, err := os.ReadFile(path) + if err != nil { + res.Status = StatusUnreadable + res.Error = err.Error() + return res, true + } + res.ContentHash = sha256.Sum256(data) + name, description, _, err := workspacesdk.ParseSkillFrontmatter(string(data)) + if err != nil { + res.Status = StatusInvalid + res.Error = err.Error() + return res, true + } + if name != parent { + res.Status = StatusInvalid + res.Error = fmt.Sprintf("front-matter name %q does not match directory %q", name, parent) + return res, true + } + if !workspacesdk.SkillNamePattern.MatchString(name) { + res.Status = StatusInvalid + res.Error = fmt.Sprintf("skill name %q is not kebab-case", name) + return res, true + } + res.Description = description + res.Payload = data + return res, true +} + +// emitSkillsFromContainer scans the immediate children of a +// recognized skills-container directory and emits one Skill +// resource per subdirectory whose SKILL.md parses cleanly. 
+func (r *Resolver) emitSkillsFromContainer(container string, root ScanRoot, out *[]Resource, seenID map[string]struct{}) { + entries, err := os.ReadDir(container) + if err != nil { + return + } + for _, e := range entries { + if !e.IsDir() { + continue + } + meta := filepath.Join(container, e.Name(), skillMetaFileName) + info, err := os.Stat(meta) + if err != nil { + continue + } + res, ok := r.readSkillMeta(meta, info, root.UserSource) + if !ok { + continue + } + if _, dup := seenID[res.ID]; dup { + continue + } + seenID[res.ID] = struct{}{} + *out = append(*out, res) + } +} + +// applyCaps enforces the resource-count cap and aggregate +// payload cap. Resources past either cap have their Status set +// to StatusExcluded and their Payload cleared. +func (r *Resolver) applyCaps(resources []Resource) []Resource { + // Stable sort by (Kind asc, Source asc) so excluded + // resources are deterministic. + sort.SliceStable(resources, func(i, j int) bool { + if resources[i].Kind != resources[j].Kind { + return resources[i].Kind < resources[j].Kind + } + return resources[i].Source < resources[j].Source + }) + + var total uint64 + for i := range resources { + if i >= r.MaxResources { + resources[i] = excluded(resources[i], + fmt.Sprintf("dropped to fit %d-resource snapshot count cap", r.MaxResources)) + continue + } + if resources[i].Status != StatusOK { + continue + } + size := uint64(len(resources[i].Payload)) + if total+size > r.MaxSnapshotBytes { + resources[i] = excluded(resources[i], + fmt.Sprintf("dropped to fit %d-byte aggregate cap", r.MaxSnapshotBytes)) + continue + } + total += size + } + return resources +} + +// applyMCPCaps re-applies the count cap after MCP resources are +// appended. MCP payloads are typically small JSON descriptors, +// so we treat the aggregate budget as already consumed by the +// filesystem pass. 
+func (r *Resolver) applyMCPCaps(resources []Resource, snapErrs []string) ([]Resource, []string) { + if len(resources) <= r.MaxResources { + return resources, snapErrs + } + for i := r.MaxResources; i < len(resources); i++ { + resources[i] = excluded(resources[i], + fmt.Sprintf("dropped to fit %d-resource snapshot count cap", r.MaxResources)) + } + snapErrs = append(snapErrs, fmt.Sprintf("snapshot exceeds %d-resource count cap", r.MaxResources)) + return resources, snapErrs +} + +// excluded mutates and returns the supplied resource with the +// StatusExcluded outcome. +func excluded(r Resource, reason string) Resource { + r.Status = StatusExcluded + r.Error = reason + r.Payload = nil + return r +} + +// isSkillsContainer reports whether dir is a recognized skills +// container directory whose immediate children carry SKILL.md +// files. +func isSkillsContainer(dir string) bool { + base := filepath.Base(dir) + _, ok := skillsParentNames[base] + if ok && base == "skills" { + return true + } + // "/skills" form (e.g. ".agents/skills", "plugins/foo/skills"). + if strings.HasSuffix(filepath.ToSlash(dir), "/skills") { + return true + } + return false +} + +// resourceID builds a stable resource ID. Kind plus canonical +// source path is enough; sources never collide across kinds for +// v1 because each kind owns a distinct file-name pattern. +func resourceID(kind ResourceKind, source string) string { + return kind.String() + ":" + source +} + +// readFileCapped reads up to maxBytes from path. It returns the +// truncated payload on success. +func readFileCapped(path string, maxBytes int64) ([]byte, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + return io.ReadAll(io.LimitReader(f, maxBytes)) +} + +// firstLine returns the first non-empty trimmed line of s, used +// as a short description fallback. 
+func firstLine(s string) string { + for _, line := range strings.Split(s, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + // Strip leading markdown heading markers for prettier + // descriptions. + return strings.TrimSpace(headingPrefixRegex.ReplaceAllString(line, "")) + } + return "" +} + +var headingPrefixRegex = regexp.MustCompile(`^#+\s*`) + +// safeUint64 converts a non-negative int64 to uint64. Negative +// inputs are clamped to 0, which is safe for the size-tracking +// fields that use it; a negative os.FileInfo size is pathological +// and never indicates real content. +func safeUint64(n int64) uint64 { + if n < 0 { + return 0 + } + return uint64(n) +} + +// safeInt64 converts a uint64 to int64, clamping to math.MaxInt64 +// when the input would overflow. The caps configured on the +// resolver never approach 2^63 bytes, so the clamp only guards +// against pathological caller input. +func safeInt64(n uint64) int64 { + if n > math.MaxInt64 { + return math.MaxInt64 + } + return int64(n) +} diff --git a/agent/agentcontext/resolve_test.go b/agent/agentcontext/resolve_test.go new file mode 100644 index 0000000000..8c62b600b8 --- /dev/null +++ b/agent/agentcontext/resolve_test.go @@ -0,0 +1,276 @@ +package agentcontext_test + +import ( + "crypto/sha256" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/agent/agentcontext" +) + +func mustWriteFile(t *testing.T, path, content string) { + t.Helper() + require.NoError(t, os.MkdirAll(filepath.Dir(path), 0o755)) + require.NoError(t, os.WriteFile(path, []byte(content), 0o600)) +} + +func mustWriteSkill(t *testing.T, dir, name, description string) { + t.Helper() + require.NoError(t, os.MkdirAll(filepath.Join(dir, name), 0o755)) + mustWriteFile(t, filepath.Join(dir, name, "SKILL.md"), + "---\nname: "+name+"\ndescription: "+description+"\n---\nSkill body for "+name) +} + +func findResource(t *testing.T, resources 
[]agentcontext.Resource, kind agentcontext.ResourceKind, source string) agentcontext.Resource { + t.Helper() + for _, r := range resources { + if r.Kind == kind && r.Source == source { + return r + } + } + t.Fatalf("resource not found: kind=%s source=%s", kind, source) + return agentcontext.Resource{} +} + +func TestResolver_ProjectAGENTSFile(t *testing.T) { + t.Parallel() + dir := t.TempDir() + mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "# Project rules\n\nDo the thing.") + + r := &agentcontext.Resolver{} + snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}}) + + require.Len(t, snap.Resources, 1) + got := snap.Resources[0] + require.Equal(t, agentcontext.KindInstructionFile, got.Kind) + require.Equal(t, agentcontext.StatusOK, got.Status) + require.Equal(t, filepath.Join(dir, "AGENTS.md"), got.Source) + require.Contains(t, string(got.Payload), "Do the thing.") + require.Equal(t, "Project rules", got.Description) + require.NotEqual(t, [32]byte{}, got.ContentHash) +} + +func TestResolver_CaseInsensitiveInstructionNames(t *testing.T) { + t.Parallel() + dir := t.TempDir() + mustWriteFile(t, filepath.Join(dir, "agents.md"), "lower\n") + mustWriteFile(t, filepath.Join(dir, "CLAUDE.md"), "claude\n") + + r := &agentcontext.Resolver{} + snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}}) + + require.Len(t, snap.Resources, 2) +} + +func TestResolver_SkillsContainerEmitsEachSubdir(t *testing.T) { + t.Parallel() + dir := t.TempDir() + mustWriteSkill(t, filepath.Join(dir, ".agents", "skills"), "make-coffee", "Coffee skill") + mustWriteSkill(t, filepath.Join(dir, ".agents", "skills"), "fold-laundry", "Laundry skill") + + r := &agentcontext.Resolver{} + snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}}) + + var kinds []string + for _, res := range snap.Resources { + kinds = append(kinds, res.Kind.String()+":"+filepath.Base(res.Source)) + } + require.ElementsMatch(t, []string{ + "skill:make-coffee", + "skill:fold-laundry", + }, kinds) +} + +func 
TestResolver_SkillNameMismatchInvalid(t *testing.T) { + t.Parallel() + dir := t.TempDir() + skillsDir := filepath.Join(dir, ".agents", "skills", "make-coffee") + require.NoError(t, os.MkdirAll(skillsDir, 0o755)) + mustWriteFile(t, filepath.Join(skillsDir, "SKILL.md"), + "---\nname: drink-tea\ndescription: oops\n---\nBody") + + r := &agentcontext.Resolver{} + snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}}) + + require.Len(t, snap.Resources, 1) + got := snap.Resources[0] + require.Equal(t, agentcontext.KindSkill, got.Kind) + require.Equal(t, agentcontext.StatusInvalid, got.Status) + require.Contains(t, got.Error, "does not match directory") +} + +func TestResolver_MCPConfigEmitted(t *testing.T) { + t.Parallel() + dir := t.TempDir() + mustWriteFile(t, filepath.Join(dir, ".mcp.json"), `{"mcpServers": {}}`) + + r := &agentcontext.Resolver{} + snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}}) + + require.Len(t, snap.Resources, 1) + require.Equal(t, agentcontext.KindMCPConfig, snap.Resources[0].Kind) + require.Equal(t, agentcontext.StatusOK, snap.Resources[0].Status) +} + +func TestResolver_OversizeInstructionFile(t *testing.T) { + t.Parallel() + dir := t.TempDir() + // Write a file larger than the per-resource cap. + big := make([]byte, 200) + for i := range big { + big[i] = 'a' + } + mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), string(big)) + + r := &agentcontext.Resolver{MaxResourceBytes: 100} + snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}}) + + require.Len(t, snap.Resources, 1) + got := snap.Resources[0] + require.Equal(t, agentcontext.StatusOversize, got.Status) + require.Empty(t, got.Payload) + require.Equal(t, uint64(200), got.SizeBytes) + // Hash over capped slice is still populated so callers + // can detect "still oversize but content changed". 
+ require.NotEqual(t, [32]byte{}, got.ContentHash) +} + +func TestResolver_AggregateCapExcludes(t *testing.T) { + t.Parallel() + dir := t.TempDir() + mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "small") + subA := filepath.Join(dir, "a") + subB := filepath.Join(dir, "b") + mustWriteFile(t, filepath.Join(subA, "AGENTS.md"), "AAAA") + mustWriteFile(t, filepath.Join(subB, "AGENTS.md"), "BBBB") + + // Aggregate cap of 9 bytes lets the first two through but + // excludes the third regardless of which order they + // appear. + r := &agentcontext.Resolver{MaxSnapshotBytes: 9} + snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}}) + + var excluded int + for _, res := range snap.Resources { + if res.Status == agentcontext.StatusExcluded { + excluded++ + } + } + require.Equal(t, 1, excluded) +} + +func TestResolver_CountCapExcludes(t *testing.T) { + t.Parallel() + dir := t.TempDir() + for i := 0; i < 5; i++ { + sub := filepath.Join(dir, "dir", string('a'+rune(i))) + mustWriteFile(t, filepath.Join(sub, "AGENTS.md"), "x") + } + + r := &agentcontext.Resolver{MaxResources: 3} + snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}}) + + require.Len(t, snap.Resources, 5) + var excluded int + for _, res := range snap.Resources { + if res.Status == agentcontext.StatusExcluded { + excluded++ + } + } + require.Equal(t, 2, excluded) +} + +func TestResolver_SkipsVendorAndNodeModules(t *testing.T) { + t.Parallel() + dir := t.TempDir() + mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "root") + mustWriteFile(t, filepath.Join(dir, "node_modules", "deep", "AGENTS.md"), "should not appear") + mustWriteFile(t, filepath.Join(dir, "vendor", "AGENTS.md"), "should not appear either") + + r := &agentcontext.Resolver{} + snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}}) + + require.Len(t, snap.Resources, 1) + require.Equal(t, filepath.Join(dir, "AGENTS.md"), snap.Resources[0].Source) +} + +func TestResolver_UserSourceAttribution(t *testing.T) { + t.Parallel() + dir := t.TempDir() + 
mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "user-added") + + r := &agentcontext.Resolver{} + snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir, UserSource: dir}}) + + require.Len(t, snap.Resources, 1) + require.Equal(t, dir, snap.Resources[0].SourcePath) +} + +func TestResolver_MissingRootSilentlyIgnored(t *testing.T) { + t.Parallel() + r := &agentcontext.Resolver{} + snap := r.Resolve([]agentcontext.ScanRoot{{Path: "/nonexistent/path"}}) + require.Empty(t, snap.Resources) + require.Empty(t, snap.SnapshotError) +} + +func TestResolver_SingleFileRootClassified(t *testing.T) { + t.Parallel() + dir := t.TempDir() + path := filepath.Join(dir, "AGENTS.md") + mustWriteFile(t, path, "x") + + r := &agentcontext.Resolver{} + snap := r.Resolve([]agentcontext.ScanRoot{{Path: path}}) + + require.Len(t, snap.Resources, 1) + require.Equal(t, agentcontext.KindInstructionFile, snap.Resources[0].Kind) +} + +func TestResolver_DuplicateRootsDeduplicated(t *testing.T) { + t.Parallel() + dir := t.TempDir() + mustWriteFile(t, filepath.Join(dir, "AGENTS.md"), "x") + + r := &agentcontext.Resolver{} + snap := r.Resolve([]agentcontext.ScanRoot{ + {Path: dir}, + {Path: dir}, + {Path: dir}, + }) + require.Len(t, snap.Resources, 1) +} + +func TestResolver_MCPProviderResources(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + mcpRes := agentcontext.Resource{ + ID: "mcp_server:github", + Kind: agentcontext.KindMCPServer, + Source: "github", + Status: agentcontext.StatusOK, + Payload: []byte("tool-list-json"), + ContentHash: sha256.Sum256([]byte("tool-list-json")), + Description: "GitHub MCP server", + } + r := &agentcontext.Resolver{ + MCP: &fakeMCPProvider{resources: []agentcontext.Resource{mcpRes}}, + } + + snap := r.Resolve([]agentcontext.ScanRoot{{Path: dir}}) + got := findResource(t, snap.Resources, agentcontext.KindMCPServer, "github") + require.Equal(t, agentcontext.StatusOK, got.Status) + require.Equal(t, "GitHub MCP server", got.Description) +} + +type fakeMCPProvider 
struct { + resources []agentcontext.Resource +} + +func (f *fakeMCPProvider) MCPResources() []agentcontext.Resource { + return f.resources +} diff --git a/agent/agentcontext/types.go b/agent/agentcontext/types.go new file mode 100644 index 0000000000..cf90a12b46 --- /dev/null +++ b/agent/agentcontext/types.go @@ -0,0 +1,222 @@ +package agentcontext + +import ( + "crypto/sha256" + "encoding/hex" + "sort" + "strconv" +) + +// ResourceKind describes the category of a resolved context +// resource. The values mirror the proto ContextResource.Kind +// enum reserved in the RFC; future kinds (PLUGIN, HOOK, +// SUBAGENT, COMMAND) are defined here so callers can switch +// exhaustively, but no v1 resolver emits them. +type ResourceKind int + +const ( + KindUnspecified ResourceKind = iota + // KindInstructionFile covers AGENTS.md, CLAUDE.md, + // .cursorrules, and similar plain-text rule files that + // inject content into the model prompt. + KindInstructionFile + // KindSkill is a directory containing SKILL.md and any + // supporting files. Only the meta file is read at + // resolve time; bodies are fetched on demand. + KindSkill + // KindMCPConfig is a .mcp.json fragment declaring one or + // more MCP servers. + KindMCPConfig + // KindMCPServer is a live MCP server's resolved tool list, + // populated by an MCPProvider after the server has been + // connected. + KindMCPServer + // KindPlugin is reserved for Claude Code plugin manifests. + // Not emitted by v1. + KindPlugin + // KindHook is reserved for plugin hooks. Not emitted by v1. + KindHook + // KindSubagent is reserved for plugin-declared subagents. + // Not emitted by v1. + KindSubagent + // KindCommand is reserved for plugin slash commands. + // Not emitted by v1. + KindCommand +) + +// String returns the lower-snake-case name used in IDs and +// metrics. Unknown values stringify to "unknown". 
+func (k ResourceKind) String() string { + switch k { + case KindInstructionFile: + return "instruction_file" + case KindSkill: + return "skill" + case KindMCPConfig: + return "mcp_config" + case KindMCPServer: + return "mcp_server" + case KindPlugin: + return "plugin" + case KindHook: + return "hook" + case KindSubagent: + return "subagent" + case KindCommand: + return "command" + default: + return "unknown" + } +} + +// ResourceStatus describes whether a resource was successfully +// read and whether its payload survived the per-resource and +// aggregate caps. +type ResourceStatus int + +const ( + // StatusOK indicates the payload was populated. + StatusOK ResourceStatus = iota + // StatusOversize indicates the resource exceeded the + // per-resource size cap; payload is omitted. + StatusOversize + // StatusUnreadable indicates an IO error reading the + // resource (permission denied, broken symlink, etc.). + StatusUnreadable + // StatusInvalid indicates the resource was structurally + // malformed (bad JSON, missing front-matter, etc.). + StatusInvalid + // StatusExcluded indicates the resource was dropped to fit + // the aggregate snapshot or count cap. + StatusExcluded +) + +// String returns the lower-snake-case name used in IDs and +// metrics. Unknown values stringify to "unknown". +func (s ResourceStatus) String() string { + switch s { + case StatusOK: + return "ok" + case StatusOversize: + return "oversize" + case StatusUnreadable: + return "unreadable" + case StatusInvalid: + return "invalid" + case StatusExcluded: + return "excluded" + default: + return "unknown" + } +} + +// Source is a user-declared scan root added to the agent's +// in-memory list via the HTTP API or boot-time env seeding. +// Identity is the canonical absolute path. +type Source struct { + // Path is the canonical absolute path (symlinks resolved, + // ~ expanded). Empty means the zero value. 
+ Path string +} + +// Resource is what the resolver emits for each recognized file +// or live server it discovers under a scan root. +type Resource struct { + // ID is stable across pushes for the same logical + // resource. The current scheme is ":". + ID string + // Kind classifies the resource for snapshot consumers. + Kind ResourceKind + // Source is the file path or MCP server name. + Source string + // ContentHash is sha256 over the resource's original + // bytes (or transport-encoded server tool list). + ContentHash [32]byte + // Payload is the full bytes when Status == StatusOK; the + // per-resource and aggregate caps may leave it empty. + Payload []byte + // SizeBytes is the original payload size, populated + // regardless of Status. + SizeBytes uint64 + // Status records OK or a reason the payload is absent. + Status ResourceStatus + // Error is populated whenever Status != StatusOK; may + // also carry a non-fatal warning when Status == StatusOK. + Error string + // Description is a short human-readable summary (skill + // front-matter description, MCP server description, etc.). + Description string + // SourcePath is the user-declared source that contributed + // the resource; empty for built-in scan roots. + SourcePath string +} + +// Snapshot is the immutable bundle of resources produced by a +// single resolver pass. +type Snapshot struct { + // Version is monotonically increasing per Manager + // instance; resets when the agent process restarts. + Version uint64 + // SchemaVersion is bumped if the resource shape on the + // wire changes. + SchemaVersion uint64 + // AggregateHash is sha256 over a canonical encoding of + // (ID, Kind, Source, ContentHash, Status) for every + // resource. Identical inputs always produce identical + // hashes; see ComputeAggregateHash. + AggregateHash [32]byte + // Resources is sorted by ID for deterministic encoding. 
+ Resources []Resource + // PayloadBytes is the sum of len(Resource.Payload) across + // emitted resources after caps were applied. + PayloadBytes uint64 + // SnapshotError carries a single snapshot-level error + // string when present (count cap exceeded, watcher + // degraded, ENOSPC, etc.). Empty when healthy. + SnapshotError string +} + +// ComputeAggregateHash produces the deterministic snapshot +// aggregate hash for the supplied resources. The caller does +// not need to pre-sort; the function sorts a copy of the slice +// to keep its inputs side-effect free. +// +// The encoding is a newline-delimited stream of fields. The +// resource boundary is a single NUL byte. The boundary scheme +// is internal to the agent and coderd, but it is stable across +// platforms because every field is encoded as either a UTF-8 +// string with a length prefix or a fixed-width integer. +func ComputeAggregateHash(resources []Resource) [32]byte { + indexed := make([]Resource, len(resources)) + copy(indexed, resources) + sort.Slice(indexed, func(i, j int) bool { + return indexed[i].ID < indexed[j].ID + }) + + h := sha256.New() + for _, r := range indexed { + writeLengthPrefixed(h, r.ID) + writeLengthPrefixed(h, r.Kind.String()) + writeLengthPrefixed(h, r.Source) + _, _ = h.Write(r.ContentHash[:]) + writeLengthPrefixed(h, r.Status.String()) + _, _ = h.Write([]byte{0}) + } + var out [32]byte + copy(out[:], h.Sum(nil)) + return out +} + +// AggregateHashHex returns the hex-encoded aggregate hash. +// Convenience for log lines and HTTP responses. +func (s Snapshot) AggregateHashHex() string { + return hex.EncodeToString(s.AggregateHash[:]) +} + +// writeLengthPrefixed writes a uvarint length prefix followed +// by the raw bytes of s. 
+func writeLengthPrefixed(h interface{ Write([]byte) (int, error) }, s string) { + _, _ = h.Write([]byte(strconv.Itoa(len(s)))) + _, _ = h.Write([]byte{':'}) + _, _ = h.Write([]byte(s)) +} diff --git a/agent/agentcontext/types_test.go b/agent/agentcontext/types_test.go new file mode 100644 index 0000000000..585dac0c69 --- /dev/null +++ b/agent/agentcontext/types_test.go @@ -0,0 +1,99 @@ +package agentcontext_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/agent/agentcontext" +) + +func TestResourceKindString(t *testing.T) { + t.Parallel() + tests := []struct { + kind agentcontext.ResourceKind + want string + }{ + {agentcontext.KindUnspecified, "unknown"}, + {agentcontext.KindInstructionFile, "instruction_file"}, + {agentcontext.KindSkill, "skill"}, + {agentcontext.KindMCPConfig, "mcp_config"}, + {agentcontext.KindMCPServer, "mcp_server"}, + {agentcontext.KindPlugin, "plugin"}, + {agentcontext.KindHook, "hook"}, + {agentcontext.KindSubagent, "subagent"}, + {agentcontext.KindCommand, "command"}, + {agentcontext.ResourceKind(999), "unknown"}, + } + for _, tt := range tests { + require.Equal(t, tt.want, tt.kind.String()) + } +} + +func TestResourceStatusString(t *testing.T) { + t.Parallel() + tests := []struct { + status agentcontext.ResourceStatus + want string + }{ + {agentcontext.StatusOK, "ok"}, + {agentcontext.StatusOversize, "oversize"}, + {agentcontext.StatusUnreadable, "unreadable"}, + {agentcontext.StatusInvalid, "invalid"}, + {agentcontext.StatusExcluded, "excluded"}, + {agentcontext.ResourceStatus(999), "unknown"}, + } + for _, tt := range tests { + require.Equal(t, tt.want, tt.status.String()) + } +} + +func TestComputeAggregateHash_DeterministicAcrossOrder(t *testing.T) { + t.Parallel() + a := agentcontext.Resource{ + ID: "instruction_file:/a/AGENTS.md", + Kind: agentcontext.KindInstructionFile, + Source: "/a/AGENTS.md", + Status: agentcontext.StatusOK, + } + b := agentcontext.Resource{ + ID: 
"instruction_file:/b/AGENTS.md", + Kind: agentcontext.KindInstructionFile, + Source: "/b/AGENTS.md", + Status: agentcontext.StatusOK, + } + got1 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{a, b}) + got2 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{b, a}) + require.Equal(t, got1, got2) +} + +func TestComputeAggregateHash_ChangesOnContent(t *testing.T) { + t.Parallel() + base := agentcontext.Resource{ + ID: "instruction_file:/a/AGENTS.md", + Kind: agentcontext.KindInstructionFile, + Source: "/a/AGENTS.md", + Status: agentcontext.StatusOK, + } + hash1 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{base}) + + withContent := base + withContent.ContentHash = [32]byte{0x01} + hash2 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{withContent}) + require.NotEqual(t, hash1, hash2) + + withStatus := base + withStatus.Status = agentcontext.StatusOversize + hash3 := agentcontext.ComputeAggregateHash([]agentcontext.Resource{withStatus}) + require.NotEqual(t, hash1, hash3) +} + +func TestSnapshotAggregateHashHex(t *testing.T) { + t.Parallel() + snap := agentcontext.Snapshot{ + AggregateHash: [32]byte{0xde, 0xad, 0xbe, 0xef}, + } + require.Equal(t, + "deadbeef0000000000000000000000000000000000000000000000000000000000000000"[:64], + snap.AggregateHashHex()) +} diff --git a/agent/agentcontext/watch.go b/agent/agentcontext/watch.go new file mode 100644 index 0000000000..02118fbe91 --- /dev/null +++ b/agent/agentcontext/watch.go @@ -0,0 +1,346 @@ +package agentcontext + +import ( + "context" + "errors" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + + "github.com/fsnotify/fsnotify" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/quartz" +) + +// DefaultWatchDebounce coalesces editor-style multi-event writes +// (truncate plus rename plus chmod) into a single re-resolve. 
+// Mirrors the debounce window the existing MCP config watcher
+// uses so behavior is consistent across the agent.
+const DefaultWatchDebounce = 250 * time.Millisecond
+
+// WatcherOptions parameterizes the recursive watcher.
+type WatcherOptions struct {
+	// Logger receives watcher diagnostics.
+	Logger slog.Logger
+	// Clock drives the debounce timer. A nil Clock falls back
+	// to the real clock.
+	Clock quartz.Clock
+	// Debounce is the quiet period before OnChange runs. Values
+	// <= 0 fall back to DefaultWatchDebounce.
+	Debounce time.Duration
+	// OnChange runs at most once per debounce window. The
+	// caller must not block; the recommended pattern is a
+	// non-blocking send on a re-resolve trigger channel.
+	// OnChange is required.
+	OnChange func()
+}
+
+// Watcher is a recursive fsnotify wrapper. fsnotify does not
+// support recursive watches natively on Linux, so we walk every
+// scan root at sync time and register each subdirectory
+// individually. When fsnotify cannot be initialized the watcher
+// degrades into a poll-only mode that still re-resolves on Sync
+// calls. Inotify ENOSPC during Sync leaves the directories that
+// did not fit unwatched and is reported through Degraded.
+type Watcher struct {
+	logger   slog.Logger
+	clock    quartz.Clock
+	debounce time.Duration
+	onChange func()
+
+	mu        sync.Mutex
+	watcher   *fsnotify.Watcher
+	watched   map[string]struct{}
+	timer     *quartz.Timer
+	degraded  string // non-empty when fsnotify init failed or inotify watches ran out
+	closed    bool
+	closedCh  chan struct{}
+	runDoneCh chan struct{}
+}
+
+// NewWatcher constructs a recursive watcher. The watcher does
+// nothing until Sync is called.
+//
+// It returns an error only when opts.OnChange is nil. A failure
+// to initialize fsnotify is not an error: the returned Watcher
+// runs in degraded mode and reports the reason via Degraded.
+func NewWatcher(opts WatcherOptions) (*Watcher, error) {
+	if opts.OnChange == nil {
+		return nil, xerrors.New("OnChange callback is required")
+	}
+	debounce := opts.Debounce
+	if debounce <= 0 {
+		debounce = DefaultWatchDebounce
+	}
+	clock := opts.Clock
+	if clock == nil {
+		clock = quartz.NewReal()
+	}
+
+	w, err := fsnotify.NewWatcher()
+	if err != nil {
+		// On Linux this usually means inotify initialization
+		// failed because the per-user inotify instance limit
+		// or a file descriptor limit was hit. (The watch limit
+		// only surfaces later, as ENOSPC from Add.) Surface a
+		// Watcher in "degraded" mode so the caller can still
+		// rely on explicit Sync triggers.
+		degraded := &Watcher{
+			logger:    opts.Logger,
+			clock:     clock,
+			debounce:  debounce,
+			onChange:  opts.OnChange,
+			watched:   make(map[string]struct{}),
+			degraded:  "fsnotify init failed: " + err.Error(),
+			closedCh:  make(chan struct{}),
+			runDoneCh: closedChan(),
+		}
+		return degraded, nil
+	}
+
+	cw := &Watcher{
+		logger:    opts.Logger,
+		clock:     clock,
+		debounce:  debounce,
+		onChange:  opts.OnChange,
+		watcher:   w,
+		watched:   make(map[string]struct{}),
+		closedCh:  make(chan struct{}),
+		runDoneCh: make(chan struct{}),
+	}
+	go cw.run()
+	return cw, nil
+}
+
+// closedChan returns an already-closed channel for the
+// degraded-watcher case where there is no run goroutine, so
+// Close does not block waiting for one.
+func closedChan() chan struct{} {
+	c := make(chan struct{})
+	close(c)
+	return c
+}
+
+// Degraded returns a non-empty string when the watcher is
+// running with reduced functionality (fsnotify failed to
+// initialize, or inotify returned ENOSPC while adding a
+// watch). The string is suitable for use as a snapshot-level
+// error message.
+func (w *Watcher) Degraded() string {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	return w.degraded
+}
+
+// Sync replaces the set of watched directories with a fresh
+// recursive walk of every scan root. Files are not watched
+// directly; watching the parent directory catches creates,
+// renames, removes, and writes that touch any recognized
+// basename. Files that are themselves scan roots are handled by
+// watching their parent.
+//
+// When there is no fsnotify watcher (init failed), Sync only
+// schedules the change callback. After Close it is a no-op.
+//
+// Sync is idempotent and safe to call repeatedly.
+func (w *Watcher) Sync(ctx context.Context, roots []ScanRoot) {
+	w.mu.Lock()
+	if w.closed {
+		w.mu.Unlock()
+		return
+	}
+	if w.watcher == nil {
+		// Degraded mode: nothing to wire up; fire the callback
+		// so the caller still gets a fresh resolve. schedule
+		// takes w.mu itself, so release it first.
+		w.mu.Unlock()
+		w.schedule()
+		return
+	}
+	desired := w.collectDirs(roots)
+
+	// Remove directories no longer wanted.
+	for path := range w.watched {
+		if _, ok := desired[path]; ok {
+			continue
+		}
+		// Best effort: an error here (for example the watch
+		// already being gone) does not change what we do next.
+		_ = w.watcher.Remove(path)
+		delete(w.watched, path)
+	}
+	// Add directories that are new.
+	for path := range desired {
+		if _, ok := w.watched[path]; ok {
+			continue
+		}
+		if err := w.watcher.Add(path); err != nil {
+			// ENOSPC means the kernel's per-user inotify
+			// watch budget is exhausted. Mark the watcher
+			// degraded and stop adding. Directories that did
+			// not fit stay out of w.watched, so a later Sync
+			// retries them.
+			if errors.Is(err, syscall.ENOSPC) {
+				w.degraded = "inotify watch limit exceeded (ENOSPC)"
+				w.logger.Warn(ctx, "context watcher degraded: inotify watch limit exceeded",
+					slog.F("dir", path))
+				break
+			}
+			w.logger.Debug(ctx, "context watcher could not add dir",
+				slog.F("dir", path), slog.Error(err))
+			continue
+		}
+		w.watched[path] = struct{}{}
+	}
+	w.mu.Unlock()
+}
+
+// Close stops the watcher and releases all kernel watch slots.
+// It stops the pending debounce timer, closes the fsnotify
+// watcher, and waits for the run goroutine to exit.
+// Close is idempotent.
+func (w *Watcher) Close() error {
+	w.mu.Lock()
+	if w.closed {
+		w.mu.Unlock()
+		return nil
+	}
+	w.closed = true
+	close(w.closedCh)
+	timer := w.timer
+	watcher := w.watcher
+	w.timer = nil
+	w.watcher = nil
+	w.mu.Unlock()
+
+	if timer != nil {
+		timer.Stop()
+	}
+	if watcher != nil {
+		_ = watcher.Close()
+	}
+	<-w.runDoneCh
+	return nil
+}
+
+// run forwards fsnotify events into the debounce timer. It exits
+// when Close is called or the underlying watcher is closed.
+func (w *Watcher) run() {
+	defer close(w.runDoneCh)
+	// Capture the watcher reference once. Close may set the
+	// field to nil concurrently; reading the captured local
+	// keeps the event loop safe through the race window.
+	w.mu.Lock()
+	fsw := w.watcher
+	w.mu.Unlock()
+	if fsw == nil {
+		return
+	}
+	for {
+		select {
+		case <-w.closedCh:
+			return
+		case ev, ok := <-fsw.Events:
+			if !ok {
+				return
+			}
+			if !w.eventRelevant(ev) {
+				continue
+			}
+			w.schedule()
+		case err, ok := <-fsw.Errors:
+			if !ok {
+				return
+			}
+			if err == nil {
+				continue
+			}
+			if errors.Is(err, fsnotify.ErrEventOverflow) {
+				// The event queue overflowed, so events were
+				// dropped and we cannot know what changed.
+				// Force a re-resolve instead of leaving the
+				// caller with a stale snapshot.
+				w.logger.Warn(context.Background(), "context watcher event queue overflowed; forcing re-resolve",
+					slog.Error(err))
+				w.schedule()
+				continue
+			}
+			w.logger.Debug(context.Background(), "context watcher error", slog.Error(err))
+		}
+	}
+}
+
+// eventRelevant filters out events that cannot affect any
+// recognized resource. The check is conservative: an event
+// does not say whether its path is a directory, so every
+// create, remove, or rename triggers a re-resolve. That way
+// newly created subtrees are picked up. Other operations
+// (such as writes) only count when they touch a recognized
+// basename.
+func (*Watcher) eventRelevant(ev fsnotify.Event) bool {
+	name := filepath.Base(ev.Name)
+	if recognizedInstructionFile(name) || name == mcpConfigFileName || name == skillMetaFileName {
+		return true
+	}
+	// Directory create/remove flips re-resolve so new subtrees
+	// arm watches and removed subtrees stop arming them.
+	if ev.Has(fsnotify.Create) || ev.Has(fsnotify.Remove) || ev.Has(fsnotify.Rename) {
+		return true
+	}
+	return false
+}
+
+// schedule arms or resets the debounce timer. It is a no-op
+// once the watcher is closed.
+func (w *Watcher) schedule() {
+	w.mu.Lock()
+	if w.closed {
+		w.mu.Unlock()
+		return
+	}
+	cb := w.onChange
+	if w.timer != nil {
+		w.timer.Reset(w.debounce)
+		w.mu.Unlock()
+		return
+	}
+	w.timer = w.clock.AfterFunc(w.debounce, func() {
+		w.mu.Lock()
+		w.timer = nil
+		closed := w.closed
+		w.mu.Unlock()
+		// Close may have run after the timer fired but before
+		// this function got the lock. Do not call back into a
+		// caller that has already shut the watcher down.
+		if closed {
+			return
+		}
+		cb()
+	})
+	w.mu.Unlock()
+}
+
+// collectDirs walks every scan root and returns the set of
+// directories to watch. The maximum depth mirrors the
+// resolver's MaxDepth so we never watch more than we'd scan.
+//
+// Roots with an empty path are skipped. A root that does not
+// exist contributes its deepest existing ancestor, and a root
+// that is a file contributes its parent directory.
+func (*Watcher) collectDirs(roots []ScanRoot) map[string]struct{} {
+	out := make(map[string]struct{})
+	for _, root := range roots {
+		if root.Path == "" {
+			continue
+		}
+		info, err := os.Stat(root.Path)
+		if err != nil {
+			// Watch the deepest existing ancestor so the
+			// root being created later still fires.
+			if anc := existingAncestor(root.Path); anc != "" {
+				out[anc] = struct{}{}
+			}
+			continue
+		}
+		if !info.IsDir() {
+			out[filepath.Dir(root.Path)] = struct{}{}
+			continue
+		}
+		// Walk the directory and collect every descendant
+		// directory up to the depth cap. Depth is measured by
+		// counting path separators relative to the root.
+		rootDepth := strings.Count(filepath.Clean(root.Path), string(os.PathSeparator))
+		// The callback only ever returns nil or fs.SkipDir, so
+		// WalkDir has no error to report.
+		_ = filepath.WalkDir(root.Path, func(path string, d fs.DirEntry, err error) error {
+			if err != nil {
+				// Unreadable entries are skipped, not fatal.
+				return nil
+			}
+			if !d.IsDir() {
+				return nil
+			}
+			// A root whose own name is in skipDirNames is
+			// still walked; only descendants are pruned.
+			if _, skip := skipDirNames[d.Name()]; skip && path != root.Path {
+				return fs.SkipDir
+			}
+			if strings.Count(path, string(os.PathSeparator))-rootDepth > DefaultMaxScanDepth {
+				return fs.SkipDir
+			}
+			out[path] = struct{}{}
+			return nil
+		})
+	}
+	return out
+}
+
+// existingAncestor returns the deepest existing ancestor of
+// path, or "" if no ancestor exists (e.g. an entirely missing
+// drive on Windows). The path itself is never returned; the
+// search starts at its parent.
+func existingAncestor(path string) string {
+	cur := filepath.Dir(path)
+	for {
+		if cur == "" || cur == "." {
+			return ""
+		}
+		info, err := os.Stat(cur)
+		if err == nil && info.IsDir() {
+			return cur
+		}
+		parent := filepath.Dir(cur)
+		if parent == cur {
+			// Reached the top without finding a directory.
+			return ""
+		}
+		cur = parent
+	}
+}
diff --git a/agent/agentcontext/watch_test.go b/agent/agentcontext/watch_test.go
new file mode 100644
index 0000000000..f46a47b753
--- /dev/null
+++ b/agent/agentcontext/watch_test.go
@@ -0,0 +1,96 @@
+package agentcontext_test
+
+import (
+	"context"
+	"os"
+	"path/filepath"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/coder/quartz"
+
+	"github.com/coder/coder/v2/agent/agentcontext"
+	"github.com/coder/coder/v2/testutil"
+)
+
+func TestWatcher_FiresOnAgentsMdEdit(t *testing.T) {
+	t.Parallel()
+	dir := t.TempDir()
+	require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v1"), 0o600))
+
+	var fires int32
+	w, err := agentcontext.NewWatcher(agentcontext.WatcherOptions{
+		Logger:   testutil.Logger(t).Named("watcher"),
+		Clock:    quartz.NewReal(),
+		Debounce: 10 * time.Millisecond,
+		OnChange: func() { atomic.AddInt32(&fires, 1) },
+	})
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = w.Close() })
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+	w.Sync(ctx, []agentcontext.ScanRoot{{Path: dir}})
+
+	// Edit the file. Use a slight delay so fsnotify is ready.
+	time.Sleep(50 * time.Millisecond)
+	require.NoError(t, os.WriteFile(filepath.Join(dir, "AGENTS.md"), []byte("v2"), 0o600))
+
+	require.Eventually(t, func() bool {
+		return atomic.LoadInt32(&fires) >= 1
+	}, testutil.WaitShort, testutil.IntervalFast, "expected at least one fire after AGENTS.md edit")
+}
+
+func TestWatcher_FiresOnNewSkillFile(t *testing.T) {
+	t.Parallel()
+	dir := t.TempDir()
+	skillsRoot := filepath.Join(dir, ".agents", "skills")
+	require.NoError(t, os.MkdirAll(skillsRoot, 0o755))
+
+	var fires int32
+	w, err := agentcontext.NewWatcher(agentcontext.WatcherOptions{
+		Logger:   testutil.Logger(t).Named("watcher"),
+		Debounce: 10 * time.Millisecond,
+		OnChange: func() { atomic.AddInt32(&fires, 1) },
+	})
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = w.Close() })
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+	w.Sync(ctx, []agentcontext.ScanRoot{{Path: dir}})
+
+	time.Sleep(50 * time.Millisecond)
+	skillDir := filepath.Join(skillsRoot, "foo")
+	require.NoError(t, os.MkdirAll(skillDir, 0o755))
+	require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("---\nname: foo\ndescription: bar\n---\nbody"), 0o600))
+
+	require.Eventually(t, func() bool {
+		return atomic.LoadInt32(&fires) >= 1
+	}, testutil.WaitShort, testutil.IntervalFast, "expected fire after SKILL.md create")
+}
+
+func TestWatcher_CloseIsIdempotent(t *testing.T) {
+	t.Parallel()
+	w, err := agentcontext.NewWatcher(agentcontext.WatcherOptions{
+		Logger:   testutil.Logger(t).Named("watcher"),
+		OnChange: func() {},
+	})
+	require.NoError(t, err)
+	require.NoError(t, w.Close())
+	require.NoError(t, w.Close())
+}
+
+func TestWatcher_SyncAfterCloseNoop(t *testing.T) {
+	t.Parallel()
+	w, err := agentcontext.NewWatcher(agentcontext.WatcherOptions{
+		Logger:   testutil.Logger(t).Named("watcher"),
+		OnChange: func() {},
+	})
+	require.NoError(t, err)
+	require.NoError(t, w.Close())
+
+	// Must not panic.
+	w.Sync(context.Background(), []agentcontext.ScanRoot{{Path: t.TempDir()}})
+}