mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: tag chat-originating agent logs with chat_id (#25019)
Workspace-agent logs emitted while serving chatd-driven requests were not correlated with the originating chat, making agent logs hard to attribute to the corresponding/originating chat. This adds agent-side chat context middleware that parses `Coder-Chat-Id` once, enriches agent access logs and structured handler/background logs, and adds a chatd bridge log when chat headers are attached to an agent connection. Closes CODAGT-324
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
package agentgit
|
||||
package agentchat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// ExtractChatContext reads chat identity headers from the request.
|
||||
// extractContext reads chat identity headers from the request.
|
||||
// Returns zero values if headers are absent (non-chat request).
|
||||
func ExtractChatContext(r *http.Request) (chatID uuid.UUID, ancestorIDs []uuid.UUID, ok bool) {
|
||||
func extractContext(r *http.Request) (chatID uuid.UUID, ancestorIDs []uuid.UUID, ok bool) {
|
||||
raw := r.Header.Get(workspacesdk.CoderChatIDHeader)
|
||||
if raw == "" {
|
||||
return uuid.Nil, nil, false
|
||||
@@ -1,18 +1,19 @@
|
||||
package agentgit_test
|
||||
package agentchat_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/agent/agentchat"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
func TestExtractChatContext(t *testing.T) {
|
||||
func TestExtractContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
validID := uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee")
|
||||
@@ -43,7 +44,7 @@ func TestExtractChatContext(t *testing.T) {
|
||||
setChatID: true,
|
||||
setAncestors: false,
|
||||
wantChatID: validID,
|
||||
wantAncestorIDs: nil,
|
||||
wantAncestorIDs: []uuid.UUID{},
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
@@ -75,7 +76,7 @@ func TestExtractChatContext(t *testing.T) {
|
||||
ancestors: `{this is not json}`,
|
||||
setAncestors: true,
|
||||
wantChatID: validID,
|
||||
wantAncestorIDs: nil,
|
||||
wantAncestorIDs: []uuid.UUID{},
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
@@ -112,7 +113,7 @@ func TestExtractChatContext(t *testing.T) {
|
||||
ancestors: "",
|
||||
setAncestors: true,
|
||||
wantChatID: validID,
|
||||
wantAncestorIDs: nil,
|
||||
wantAncestorIDs: []uuid.UUID{},
|
||||
wantOK: true,
|
||||
},
|
||||
}
|
||||
@@ -130,7 +131,7 @@ func TestExtractChatContext(t *testing.T) {
|
||||
r.Header.Set(workspacesdk.CoderAncestorChatIDsHeader, tt.ancestors)
|
||||
}
|
||||
|
||||
chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r)
|
||||
chatID, ancestorIDs, ok := extractContextForTest(r)
|
||||
|
||||
require.Equal(t, tt.wantOK, ok, "ok mismatch")
|
||||
require.Equal(t, tt.wantChatID, chatID, "chatID mismatch")
|
||||
@@ -139,6 +140,18 @@ func TestExtractChatContext(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func extractContextForTest(r *http.Request) (uuid.UUID, []uuid.UUID, bool) {
|
||||
var chatContext agentchat.Context
|
||||
var ok bool
|
||||
agentchat.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
chatContext, ok = agentchat.FromContext(r.Context())
|
||||
})).ServeHTTP(httptest.NewRecorder(), r)
|
||||
if !ok {
|
||||
return uuid.Nil, nil, false
|
||||
}
|
||||
return chatContext.ID, chatContext.AncestorIDs, true
|
||||
}
|
||||
|
||||
// mustMarshalJSON marshals v to a JSON string, failing the test on error.
|
||||
func mustMarshalJSON(t *testing.T, v any) string {
|
||||
t.Helper()
|
||||
@@ -0,0 +1,85 @@
|
||||
package agentchat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
|
||||
)
|
||||
|
||||
type chatContextKey struct{}
|
||||
|
||||
// Context carries the chat identity associated with an agent request.
|
||||
type Context struct {
|
||||
ID uuid.UUID
|
||||
AncestorIDs []uuid.UUID
|
||||
}
|
||||
|
||||
// FromContext returns the chat identity stored on the context.
|
||||
func FromContext(ctx context.Context) (Context, bool) {
|
||||
chatCtx, ok := ctx.Value(chatContextKey{}).(Context)
|
||||
if !ok || chatCtx.ID == uuid.Nil {
|
||||
return Context{}, false
|
||||
}
|
||||
return chatCtx, true
|
||||
}
|
||||
|
||||
// WithContext stores chat identity on the context for downstream logs.
|
||||
func WithContext(ctx context.Context, chatID uuid.UUID, ancestorIDs []uuid.UUID) context.Context {
|
||||
if chatID == uuid.Nil {
|
||||
return ctx
|
||||
}
|
||||
ancestors := make([]uuid.UUID, len(ancestorIDs))
|
||||
copy(ancestors, ancestorIDs)
|
||||
return context.WithValue(ctx, chatContextKey{}, Context{
|
||||
ID: chatID,
|
||||
AncestorIDs: ancestors,
|
||||
})
|
||||
}
|
||||
|
||||
// Fields returns structured log fields for the chat identity on ctx.
|
||||
func Fields(ctx context.Context) []slog.Field {
|
||||
chatCtx, ok := FromContext(ctx)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return chatFields(chatCtx.ID, chatCtx.AncestorIDs)
|
||||
}
|
||||
|
||||
// Middleware tags agent logs for requests that originate from
|
||||
// chatd. Agent log lines emitted while serving a request with Coder-Chat-Id,
|
||||
// or by background work started by such a request, should include chat_id.
|
||||
// Install after loggermw.Logger so access-log enrichment can run.
|
||||
func Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
chatID, ancestorIDs, ok := extractContext(r)
|
||||
if !ok {
|
||||
next.ServeHTTP(rw, r)
|
||||
return
|
||||
}
|
||||
|
||||
fields := chatFields(chatID, ancestorIDs)
|
||||
if requestLogger := loggermw.RequestLoggerFromContext(r.Context()); requestLogger != nil {
|
||||
requestLogger.WithFields(fields...)
|
||||
}
|
||||
|
||||
ctx := WithContext(r.Context(), chatID, ancestorIDs)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func chatFields(chatID uuid.UUID, ancestorIDs []uuid.UUID) []slog.Field {
|
||||
fields := []slog.Field{slog.F("chat_id", chatID.String())}
|
||||
if len(ancestorIDs) == 0 {
|
||||
return fields
|
||||
}
|
||||
|
||||
ancestors := make([]string, 0, len(ancestorIDs))
|
||||
for _, id := range ancestorIDs {
|
||||
ancestors = append(ancestors, id.String())
|
||||
}
|
||||
return append(fields, slog.F("ancestor_chat_ids", ancestors))
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package agentchat_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentchat"
|
||||
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestMiddlewareAccessLog(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
ancestorID := uuid.New()
|
||||
sink := testutil.NewFakeSink(t)
|
||||
handler := tracing.StatusWriterMiddleware(loggermw.Logger(sink.Logger())(
|
||||
agentchat.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
})),
|
||||
))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
|
||||
req.Header.Set(workspacesdk.CoderAncestorChatIDsHeader, mustMarshalJSON(t, []string{ancestorID.String()}))
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req)
|
||||
require.Equal(t, http.StatusNoContent, rw.Code)
|
||||
|
||||
entries := sink.Entries()
|
||||
require.Len(t, entries, 1)
|
||||
fields := fieldsByName(entries[0].Fields)
|
||||
require.Equal(t, chatID.String(), fields["chat_id"])
|
||||
require.Equal(t, []string{ancestorID.String()}, fields["ancestor_chat_ids"])
|
||||
}
|
||||
|
||||
func TestMiddlewareWithoutChatHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sink := testutil.NewFakeSink(t)
|
||||
handler := tracing.StatusWriterMiddleware(loggermw.Logger(sink.Logger())(
|
||||
agentchat.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
})),
|
||||
))
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, httptest.NewRequest(http.MethodGet, "/test", nil))
|
||||
require.Equal(t, http.StatusNoContent, rw.Code)
|
||||
|
||||
entries := sink.Entries()
|
||||
require.Len(t, entries, 1)
|
||||
fields := fieldsByName(entries[0].Fields)
|
||||
require.NotContains(t, fields, "chat_id")
|
||||
require.NotContains(t, fields, "ancestor_chat_ids")
|
||||
}
|
||||
|
||||
func TestMiddlewareContextFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
sink := testutil.NewFakeSink(t)
|
||||
handler := tracing.StatusWriterMiddleware(loggermw.Logger(sink.Logger())(
|
||||
agentchat.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
sink.Logger().With(agentchat.Fields(r.Context())...).Info(r.Context(), "handler log")
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
})),
|
||||
))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req)
|
||||
require.Equal(t, http.StatusNoContent, rw.Code)
|
||||
|
||||
entries := sink.Entries()
|
||||
require.Len(t, entries, 2)
|
||||
for _, entry := range entries {
|
||||
if entry.Message != "handler log" {
|
||||
continue
|
||||
}
|
||||
fields := fieldsByName(entry.Fields)
|
||||
require.Equal(t, chatID.String(), fields["chat_id"])
|
||||
return
|
||||
}
|
||||
t.Fatal("handler log entry not found")
|
||||
}
|
||||
|
||||
func fieldsByName(fields []slog.Field) map[string]any {
|
||||
byName := make(map[string]any, len(fields))
|
||||
for _, field := range fields {
|
||||
byName[field.Name] = field.Value
|
||||
}
|
||||
return byName
|
||||
}
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/agent/agentchat"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
@@ -86,6 +86,8 @@ func (api *API) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (api *API) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (HTTPResponseCode, error) {
|
||||
logger := api.logger.With(agentchat.Fields(ctx)...)
|
||||
|
||||
if !filepath.IsAbs(path) {
|
||||
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
|
||||
}
|
||||
@@ -131,7 +133,7 @@ func (api *API) streamFile(ctx context.Context, rw http.ResponseWriter, path str
|
||||
reader := io.NewSectionReader(f, offset, bytesToRead)
|
||||
_, err = io.Copy(rw, reader)
|
||||
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
|
||||
api.logger.Error(ctx, "workspace agent read file", slog.Error(err))
|
||||
logger.Error(ctx, "workspace agent read file", slog.Error(err))
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
@@ -322,8 +324,8 @@ func (api *API) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Track edited path for git watch.
|
||||
if api.pathStore != nil {
|
||||
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok {
|
||||
api.pathStore.AddPaths(append([]uuid.UUID{chatID}, ancestorIDs...), []string{path})
|
||||
if chatContext, ok := agentchat.FromContext(ctx); ok {
|
||||
api.pathStore.AddPaths(append([]uuid.UUID{chatContext.ID}, chatContext.AncestorIDs...), []string{path})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -458,12 +460,12 @@ func (api *API) HandleEditFiles(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Track edited paths for git watch.
|
||||
if api.pathStore != nil {
|
||||
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok {
|
||||
if chatContext, ok := agentchat.FromContext(ctx); ok {
|
||||
filePaths := make([]string, 0, len(req.Files))
|
||||
for _, f := range req.Files {
|
||||
filePaths = append(filePaths, f.Path)
|
||||
}
|
||||
api.pathStore.AddPaths(append([]uuid.UUID{chatID}, ancestorIDs...), filePaths)
|
||||
api.pathStore.AddPaths(append([]uuid.UUID{chatContext.ID}, chatContext.AncestorIDs...), filePaths)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -565,6 +567,8 @@ func (api *API) prepareFileEdit(path string, edits []workspacesdk.FileEdit) (int
|
||||
// On failure the temp file is cleaned up and the original is
|
||||
// untouched.
|
||||
func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode, r io.Reader) (int, error) {
|
||||
logger := api.logger.With(agentchat.Fields(ctx)...)
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
tmpName := filepath.Join(dir, fmt.Sprintf(".%s.tmp.%s", filepath.Base(path), uuid.New().String()[:8]))
|
||||
|
||||
@@ -579,7 +583,7 @@ func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode,
|
||||
|
||||
cleanup := func() {
|
||||
if err := api.filesystem.Remove(tmpName); err != nil {
|
||||
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(err))
|
||||
logger.Warn(ctx, "unable to clean up temp file", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -601,7 +605,7 @@ func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode,
|
||||
// no window where the target has wrong permissions.
|
||||
if mode != nil {
|
||||
if err := api.filesystem.Chmod(tmpName, *mode); err != nil {
|
||||
api.logger.Warn(ctx, "unable to set file permissions",
|
||||
logger.Warn(ctx, "unable to set file permissions",
|
||||
slog.F("path", path),
|
||||
slog.Error(err),
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentchat"
|
||||
"github.com/coder/coder/v2/agent/agentfiles"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -1157,7 +1158,7 @@ func TestHandleWriteFile_ChatHeaders_UpdatesPathStore(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/write-file", api.HandleWriteFile)
|
||||
r.ServeHTTP(rr, req)
|
||||
agentchat.Middleware(r).ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
@@ -1185,7 +1186,7 @@ func TestHandleWriteFile_NoChatHeaders_NoPathStoreUpdate(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/write-file", api.HandleWriteFile)
|
||||
r.ServeHTTP(rr, req)
|
||||
agentchat.Middleware(r).ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
@@ -1211,7 +1212,7 @@ func TestHandleWriteFile_Failure_NoPathStoreUpdate(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/write-file", api.HandleWriteFile)
|
||||
r.ServeHTTP(rr, req)
|
||||
agentchat.Middleware(r).ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
|
||||
@@ -1252,7 +1253,7 @@ func TestHandleEditFiles_ChatHeaders_UpdatesPathStore(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/edit-files", api.HandleEditFiles)
|
||||
r.ServeHTTP(rr, req)
|
||||
agentchat.Middleware(r).ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
@@ -1289,7 +1290,7 @@ func TestHandleEditFiles_Failure_NoPathStoreUpdate(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/edit-files", api.HandleEditFiles)
|
||||
r.ServeHTTP(rr, req)
|
||||
agentchat.Middleware(r).ServeHTTP(rr, req)
|
||||
|
||||
require.NotEqual(t, http.StatusOK, rr.Code)
|
||||
|
||||
|
||||
+51
-35
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentchat"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/wsjson"
|
||||
@@ -40,6 +41,25 @@ func (a *API) Routes() http.Handler {
|
||||
func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var watchChatID uuid.UUID
|
||||
var hasWatchChatID bool
|
||||
if chatIDStr := r.URL.Query().Get("chat_id"); chatIDStr != "" {
|
||||
if parsedChatID, parseErr := uuid.Parse(chatIDStr); parseErr == nil {
|
||||
watchChatID = parsedChatID
|
||||
hasWatchChatID = true
|
||||
|
||||
// Reuse header-derived ancestors only when the query chat
|
||||
// matches the header chat. Otherwise the ancestors belong
|
||||
// to a different chat and would be misleading in logs.
|
||||
var ancestors []uuid.UUID
|
||||
if chatContext, ok := agentchat.FromContext(ctx); ok && chatContext.ID == watchChatID {
|
||||
ancestors = chatContext.AncestorIDs
|
||||
}
|
||||
ctx = agentchat.WithContext(ctx, watchChatID, ancestors)
|
||||
}
|
||||
}
|
||||
logger := a.logger.With(agentchat.Fields(ctx)...)
|
||||
|
||||
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionNoContextTakeover,
|
||||
})
|
||||
@@ -58,14 +78,14 @@ func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) {
|
||||
stream := wsjson.NewStream[
|
||||
codersdk.WorkspaceAgentGitClientMessage,
|
||||
codersdk.WorkspaceAgentGitServerMessage,
|
||||
](conn, websocket.MessageText, websocket.MessageText, a.logger)
|
||||
](conn, websocket.MessageText, websocket.MessageText, logger)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, a.logger, cancel, conn)
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
handler := NewHandler(a.logger, a.opts...)
|
||||
handler := NewHandler(logger, a.opts...)
|
||||
|
||||
// Scan returns nil only when no roots are subscribed; once any
|
||||
// root lands it returns either a delta or a heartbeat message.
|
||||
@@ -75,46 +95,42 @@ func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if err := stream.Send(*msg); err != nil {
|
||||
a.logger.Debug(ctx, "failed to send changes", slog.Error(err))
|
||||
logger.Debug(ctx, "failed to send changes", slog.Error(err))
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// If a chat_id query parameter is provided and the PathStore is
|
||||
// available, subscribe to path updates for this chat.
|
||||
chatIDStr := r.URL.Query().Get("chat_id")
|
||||
if chatIDStr != "" && a.pathStore != nil {
|
||||
chatID, parseErr := uuid.Parse(chatIDStr)
|
||||
if parseErr == nil {
|
||||
// Subscribe to future path updates BEFORE reading
|
||||
// existing paths. This ordering guarantees no
|
||||
// notification from AddPaths is lost: any call that
|
||||
// lands before Subscribe is picked up by GetPaths
|
||||
// below, and any call after Subscribe delivers a
|
||||
// notification on the channel.
|
||||
notifyCh, unsubscribe := a.pathStore.Subscribe(chatID)
|
||||
defer unsubscribe()
|
||||
if hasWatchChatID && a.pathStore != nil {
|
||||
// Subscribe to future path updates BEFORE reading
|
||||
// existing paths. This ordering guarantees no
|
||||
// notification from AddPaths is lost: any call that
|
||||
// lands before Subscribe is picked up by GetPaths
|
||||
// below, and any call after Subscribe delivers a
|
||||
// notification on the channel.
|
||||
notifyCh, unsubscribe := a.pathStore.Subscribe(watchChatID)
|
||||
defer unsubscribe()
|
||||
|
||||
// Load any paths that are already tracked for this chat.
|
||||
existingPaths := a.pathStore.GetPaths(chatID)
|
||||
if len(existingPaths) > 0 {
|
||||
handler.Subscribe(existingPaths)
|
||||
handler.RequestScan()
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-notifyCh:
|
||||
paths := a.pathStore.GetPaths(chatID)
|
||||
handler.Subscribe(paths)
|
||||
handler.RequestScan()
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Load any paths that are already tracked for this chat.
|
||||
existingPaths := a.pathStore.GetPaths(watchChatID)
|
||||
if len(existingPaths) > 0 {
|
||||
handler.Subscribe(existingPaths)
|
||||
handler.RequestScan()
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-notifyCh:
|
||||
paths := a.pathStore.GetPaths(watchChatID)
|
||||
handler.Subscribe(paths)
|
||||
handler.RequestScan()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Start the main run loop in a goroutine.
|
||||
|
||||
+13
-11
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentchat"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
@@ -80,8 +81,8 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var chatID string
|
||||
if id, _, ok := agentgit.ExtractChatContext(r); ok {
|
||||
chatID = id.String()
|
||||
if chatContext, ok := agentchat.FromContext(ctx); ok {
|
||||
chatID = chatContext.ID.String()
|
||||
}
|
||||
|
||||
proc, err := api.manager.start(req, chatID)
|
||||
@@ -97,8 +98,8 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
|
||||
// file changes made by the command are visible in the scan.
|
||||
// If a workdir is provided, track it as a path as well.
|
||||
if api.pathStore != nil {
|
||||
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok {
|
||||
allIDs := append([]uuid.UUID{chatID}, ancestorIDs...)
|
||||
if chatContext, ok := agentchat.FromContext(ctx); ok {
|
||||
allIDs := append([]uuid.UUID{chatContext.ID}, chatContext.AncestorIDs...)
|
||||
go func() {
|
||||
<-proc.done
|
||||
if req.WorkDir != "" {
|
||||
@@ -121,8 +122,8 @@ func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var chatID string
|
||||
if id, _, ok := agentgit.ExtractChatContext(r); ok {
|
||||
chatID = id.String()
|
||||
if chatContext, ok := agentchat.FromContext(ctx); ok {
|
||||
chatID = chatContext.ID.String()
|
||||
}
|
||||
|
||||
infos := api.manager.list(chatID)
|
||||
@@ -150,6 +151,7 @@ func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) {
|
||||
// handleProcessOutput returns the output of a process.
|
||||
func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
logger := api.logger.With(agentchat.Fields(ctx)...)
|
||||
|
||||
id := chi.URLParam(r, "id")
|
||||
proc, ok := api.manager.get(id)
|
||||
@@ -163,8 +165,8 @@ func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
|
||||
// Enforce chat ID isolation. If the request carries
|
||||
// a chat context, only allow access to processes
|
||||
// belonging to that chat.
|
||||
if chatID, _, ok := agentgit.ExtractChatContext(r); ok {
|
||||
if proc.chatID != "" && proc.chatID != chatID.String() {
|
||||
if chatContext, ok := agentchat.FromContext(ctx); ok {
|
||||
if proc.chatID != "" && proc.chatID != chatContext.ID.String() {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: fmt.Sprintf("Process %q not found.", id),
|
||||
})
|
||||
@@ -184,7 +186,7 @@ func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
|
||||
// Add headroom beyond the wait timeout so there's time to
|
||||
// write the response after the blocking wait completes.
|
||||
if err := rc.SetWriteDeadline(time.Now().Add(maxWaitDuration + 30*time.Second)); err != nil {
|
||||
api.logger.Error(ctx, "extend write deadline for blocking process output",
|
||||
logger.Error(ctx, "extend write deadline for blocking process output",
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
@@ -216,9 +218,9 @@ func (api *API) handleSignalProcess(rw http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
// Enforce chat ID isolation.
|
||||
if chatID, _, ok := agentgit.ExtractChatContext(r); ok {
|
||||
if chatContext, ok := agentchat.FromContext(ctx); ok {
|
||||
proc, procOK := api.manager.get(id)
|
||||
if procOK && proc.chatID != "" && proc.chatID != chatID.String() {
|
||||
if procOK && proc.chatID != "" && proc.chatID != chatContext.ID.String() {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: fmt.Sprintf("Process %q not found.", id),
|
||||
})
|
||||
|
||||
@@ -20,9 +20,12 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentchat"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/agent/agentproc"
|
||||
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -137,7 +140,35 @@ func newTestAPIWithOptions(t *testing.T, updateEnv func([]string) ([]string, err
|
||||
t.Cleanup(func() {
|
||||
_ = api.Close()
|
||||
})
|
||||
return api.Routes()
|
||||
return agentchat.Middleware(api.Routes())
|
||||
}
|
||||
|
||||
func TestAccessLogIncludesChatID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sink := testutil.NewFakeSink(t)
|
||||
logger := sink.Logger()
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, nil, nil, nil)
|
||||
t.Cleanup(func() {
|
||||
_ = api.Close()
|
||||
})
|
||||
handler := tracing.StatusWriterMiddleware(loggermw.Logger(logger)(
|
||||
agentchat.Middleware(api.Routes()),
|
||||
))
|
||||
|
||||
chatID := uuid.New().String()
|
||||
w := getListWithChatHeader(t, handler, chatID)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
entries := sink.Entries(func(entry slog.SinkEntry) bool {
|
||||
return entry.Message == http.MethodGet
|
||||
})
|
||||
require.Len(t, entries, 1)
|
||||
fields := make(map[string]any, len(entries[0].Fields))
|
||||
for _, field := range entries[0].Fields {
|
||||
fields[field.Name] = field.Value
|
||||
}
|
||||
require.Equal(t, chatID, fields["chat_id"])
|
||||
}
|
||||
|
||||
// waitForExit polls the output endpoint until the process is
|
||||
@@ -1058,7 +1089,7 @@ func TestHandleStartProcess_ChatHeaders_EmptyWorkDir_StillNotifies(t *testing.T)
|
||||
}, pathStore, nil)
|
||||
defer api.Close()
|
||||
|
||||
routes := api.Routes()
|
||||
routes := agentchat.Middleware(api.Routes())
|
||||
|
||||
body, err := json.Marshal(workspacesdk.StartProcessRequest{
|
||||
Command: "echo hello",
|
||||
|
||||
@@ -38,6 +38,7 @@ type process struct {
|
||||
cmd *exec.Cmd
|
||||
cancel context.CancelFunc
|
||||
buf *HeadTailBuffer
|
||||
logger slog.Logger
|
||||
running bool
|
||||
exitCode *int
|
||||
startedAt int64
|
||||
@@ -105,6 +106,10 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
|
||||
m.mu.Unlock()
|
||||
|
||||
id := uuid.New().String()
|
||||
logger := m.logger
|
||||
if chatID != "" {
|
||||
logger = logger.With(slog.F("chat_id", chatID))
|
||||
}
|
||||
|
||||
// Use a cancellable context so Close() can terminate
|
||||
// all processes. context.Background() is the parent so
|
||||
@@ -132,7 +137,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
|
||||
if m.updateEnv != nil {
|
||||
updated, err := m.updateEnv(baseEnv)
|
||||
if err != nil {
|
||||
m.logger.Warn(
|
||||
logger.Warn(
|
||||
context.Background(),
|
||||
"failed to update command environment, falling back to os env",
|
||||
slog.Error(err),
|
||||
@@ -169,6 +174,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
|
||||
cmd: cmd,
|
||||
cancel: cancel,
|
||||
buf: buf,
|
||||
logger: logger,
|
||||
running: true,
|
||||
startedAt: now,
|
||||
done: make(chan struct{}),
|
||||
@@ -202,7 +208,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
|
||||
} else {
|
||||
// Unknown error; use -1 as a sentinel.
|
||||
code = -1
|
||||
m.logger.Warn(
|
||||
proc.logger.Warn(
|
||||
context.Background(),
|
||||
"process wait returned non-exit error",
|
||||
slog.F("id", id),
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentchat"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
@@ -20,6 +21,7 @@ func (a *agent) apiHandler() http.Handler {
|
||||
httpmw.Recover(a.logger),
|
||||
tracing.StatusWriterMiddleware,
|
||||
loggermw.Logger(a.logger),
|
||||
agentchat.Middleware,
|
||||
)
|
||||
r.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
|
||||
|
||||
+17
-13
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentchat"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -85,6 +86,7 @@ func (a *API) Routes() http.Handler {
|
||||
|
||||
func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
logger := a.logger.With(agentchat.Fields(ctx)...)
|
||||
|
||||
// Start the desktop session (idempotent).
|
||||
_, err := a.desktop.Start(ctx)
|
||||
@@ -112,7 +114,7 @@ func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) {
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Error(ctx, "failed to accept websocket", slog.Error(err))
|
||||
logger.Error(ctx, "failed to accept websocket", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -128,6 +130,7 @@ func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
logger := a.logger.With(agentchat.Fields(ctx)...)
|
||||
handlerStart := a.clock.Now()
|
||||
|
||||
// Update last desktop action timestamp for idle recording monitor.
|
||||
@@ -136,7 +139,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
// Ensure the desktop is running and grab native dimensions.
|
||||
cfg, err := a.desktop.Start(ctx)
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "handleAction: desktop.Start failed",
|
||||
logger.Warn(ctx, "handleAction: desktop.Start failed",
|
||||
slog.Error(err),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
)
|
||||
@@ -156,7 +159,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
a.logger.Info(ctx, "handleAction: started",
|
||||
logger.Info(ctx, "handleAction: started",
|
||||
slog.F("action", action.Action),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
)
|
||||
@@ -272,7 +275,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
x, y = scaleXY(x, y)
|
||||
stepStart := a.clock.Now()
|
||||
if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil {
|
||||
a.logger.Warn(ctx, "handleAction: Click failed",
|
||||
logger.Warn(ctx, "handleAction: Click failed",
|
||||
slog.F("action", "left_click"),
|
||||
slog.F("step", "click"),
|
||||
slog.F("step_ms", time.Since(stepStart).Milliseconds()),
|
||||
@@ -285,7 +288,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
a.logger.Debug(ctx, "handleAction: Click completed",
|
||||
logger.Debug(ctx, "handleAction: Click completed",
|
||||
slog.F("action", "left_click"),
|
||||
slog.F("step_ms", time.Since(stepStart).Milliseconds()),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
@@ -473,7 +476,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
case <-ctx.Done():
|
||||
// Context canceled; release the key immediately.
|
||||
if err := a.desktop.KeyUp(ctx, *action.Text); err != nil {
|
||||
a.logger.Warn(ctx, "handleAction: KeyUp after context cancel", slog.Error(err))
|
||||
logger.Warn(ctx, "handleAction: KeyUp after context cancel", slog.Error(err))
|
||||
}
|
||||
return
|
||||
case <-timer.C:
|
||||
@@ -513,14 +516,14 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
elapsedMs := a.clock.Since(handlerStart).Milliseconds()
|
||||
if ctx.Err() != nil {
|
||||
a.logger.Error(ctx, "handleAction: context canceled before writing response",
|
||||
logger.Error(ctx, "handleAction: context canceled before writing response",
|
||||
slog.F("action", action.Action),
|
||||
slog.F("elapsed_ms", elapsedMs),
|
||||
slog.Error(ctx.Err()),
|
||||
)
|
||||
return
|
||||
}
|
||||
a.logger.Info(ctx, "handleAction: writing response",
|
||||
logger.Info(ctx, "handleAction: writing response",
|
||||
slog.F("action", action.Action),
|
||||
slog.F("elapsed_ms", elapsedMs),
|
||||
)
|
||||
@@ -609,6 +612,7 @@ func (a *API) handleRecordingStart(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
logger := a.logger.With(agentchat.Fields(ctx)...)
|
||||
|
||||
recordingID, ok := a.decodeRecordingRequest(rw, r)
|
||||
if !ok {
|
||||
@@ -661,7 +665,7 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
}()
|
||||
|
||||
if artifact.Size > workspacesdk.MaxRecordingSize {
|
||||
a.logger.Warn(ctx, "recording file exceeds maximum size",
|
||||
logger.Warn(ctx, "recording file exceeds maximum size",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("size", artifact.Size),
|
||||
slog.F("max_size", workspacesdk.MaxRecordingSize),
|
||||
@@ -677,7 +681,7 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
// rejecting it here avoids streaming a large thumbnail over
|
||||
// the wire for nothing.
|
||||
if artifact.ThumbnailReader != nil && artifact.ThumbnailSize > workspacesdk.MaxThumbnailSize {
|
||||
a.logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting",
|
||||
logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("size", artifact.ThumbnailSize),
|
||||
slog.F("max_size", workspacesdk.MaxThumbnailSize),
|
||||
@@ -701,13 +705,13 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
"Content-Type": {"video/mp4"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create video multipart part",
|
||||
logger.Warn(ctx, "failed to create video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
if _, err := io.Copy(videoPart, artifact.Reader); err != nil {
|
||||
a.logger.Warn(ctx, "failed to write video multipart part",
|
||||
logger.Warn(ctx, "failed to write video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
@@ -719,7 +723,7 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
"Content-Type": {"image/jpeg"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create thumbnail multipart part",
|
||||
logger.Warn(ctx, "failed to create thumbnail multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentchat"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
@@ -52,6 +53,7 @@ func (api *API) Routes() http.Handler {
|
||||
// independent of config changes.
|
||||
func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
logger := api.logger.With(agentchat.Fields(ctx)...)
|
||||
|
||||
// Check config freshness and reload if changed.
|
||||
var reloaded bool
|
||||
@@ -61,11 +63,11 @@ func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) {
|
||||
// Categorize the error for operator debugging.
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
api.logger.Warn(ctx, "mcp reload canceled by caller", slog.Error(err))
|
||||
logger.Warn(ctx, "mcp reload canceled by caller", slog.Error(err))
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
api.logger.Warn(ctx, "mcp reload timed out", slog.Error(err))
|
||||
logger.Warn(ctx, "mcp reload timed out", slog.Error(err))
|
||||
default:
|
||||
api.logger.Warn(ctx, "mcp reload failed", slog.Error(err))
|
||||
logger.Warn(ctx, "mcp reload failed", slog.Error(err))
|
||||
}
|
||||
// Fall through to return whatever tools we have.
|
||||
} else {
|
||||
@@ -78,7 +80,7 @@ func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) {
|
||||
// refreshes tools as part of the reload.
|
||||
if r.URL.Query().Get("refresh") == "true" && !reloaded {
|
||||
if err := api.manager.RefreshTools(ctx); err != nil {
|
||||
api.logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err))
|
||||
logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
tailscalesingleflight "tailscale.com/util/singleflight"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentchat"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/usershell"
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
@@ -248,7 +249,8 @@ func (m *Manager) doReload(ctx context.Context, mcpConfigFiles []string) error {
|
||||
// Refresh tools outside the lock to avoid blocking
|
||||
// concurrent reads during network I/O.
|
||||
if err := m.RefreshTools(ctx); err != nil {
|
||||
m.logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err))
|
||||
logger := m.logger.With(agentchat.Fields(ctx)...)
|
||||
logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -257,6 +259,8 @@ func (m *Manager) doReload(ctx context.Context, mcpConfigFiles []string) error {
|
||||
// list of server configs. Missing files are silently skipped;
|
||||
// parse errors are logged and skipped.
|
||||
func (m *Manager) parseAndDedup(ctx context.Context, mcpConfigFiles []string) ([]ServerConfig, map[string]fileSnapshot) {
|
||||
logger := m.logger.With(agentchat.Fields(ctx)...)
|
||||
|
||||
// Stat before reading so the snapshot is conservatively old.
|
||||
// If a file changes between stat and read, the snapshot
|
||||
// records the old mtime, SnapshotChanged detects a mismatch
|
||||
@@ -272,7 +276,7 @@ func (m *Manager) parseAndDedup(ctx context.Context, mcpConfigFiles []string) ([
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
continue
|
||||
}
|
||||
m.logger.Warn(ctx, "failed to parse MCP config",
|
||||
logger.Warn(ctx, "failed to parse MCP config",
|
||||
slog.F("path", configPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
@@ -334,6 +338,8 @@ func (m *Manager) classifyServers(wanted map[string]ServerConfig) (*serverDiff,
|
||||
// connectAll runs connectServer in parallel for the given configs.
|
||||
// Failed connects are logged and skipped.
|
||||
func (m *Manager) connectAll(ctx context.Context, toConnect []ServerConfig) []connectedServer {
|
||||
logger := m.logger.With(agentchat.Fields(ctx)...)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
connected []connectedServer
|
||||
@@ -343,7 +349,7 @@ func (m *Manager) connectAll(ctx context.Context, toConnect []ServerConfig) []co
|
||||
eg.Go(func() error {
|
||||
c, err := m.connectServer(ctx, cfg)
|
||||
if err != nil {
|
||||
m.logger.Warn(ctx, "skipping MCP server",
|
||||
logger.Warn(ctx, "skipping MCP server",
|
||||
slog.F("server", cfg.Name),
|
||||
slog.F("transport", cfg.Transport),
|
||||
slog.Error(err),
|
||||
@@ -481,6 +487,8 @@ func (m *Manager) CallTool(ctx context.Context, req workspacesdk.CallMCPToolRequ
|
||||
// existing cached tools for servers that failed, so a single
|
||||
// dead server doesn't block updates from healthy ones.
|
||||
func (m *Manager) RefreshTools(ctx context.Context) error {
|
||||
logger := m.logger.With(agentchat.Fields(ctx)...)
|
||||
|
||||
// Snapshot servers under read lock.
|
||||
m.mu.RLock()
|
||||
servers := make(map[string]*serverEntry, len(m.servers))
|
||||
@@ -508,7 +516,7 @@ func (m *Manager) RefreshTools(ctx context.Context) error {
|
||||
result, err := entry.client.ListTools(listCtx, mcp.ListToolsRequest{})
|
||||
cancel()
|
||||
if err != nil {
|
||||
m.logger.Warn(ctx, "failed to list tools from MCP server",
|
||||
logger.Warn(ctx, "failed to list tools from MCP server",
|
||||
slog.F("server", name),
|
||||
slog.Error(err),
|
||||
)
|
||||
@@ -670,12 +678,14 @@ func (m *Manager) createTransport(ctx context.Context, cfg ServerConfig) (transp
|
||||
// updateEnv callback, then merges explicit overrides from the
|
||||
// server config on top.
|
||||
func (m *Manager) buildEnv(ctx context.Context, explicit map[string]string) []string {
|
||||
logger := m.logger.With(agentchat.Fields(ctx)...)
|
||||
|
||||
env := usershell.SystemEnvInfo{}.Environ()
|
||||
if m.updateEnv != nil {
|
||||
var err error
|
||||
env, err = m.updateEnv(env)
|
||||
if err != nil {
|
||||
m.logger.Warn(ctx, "failed to enrich MCP server environment",
|
||||
logger.Warn(ctx, "failed to enrich MCP server environment",
|
||||
slog.Error(err),
|
||||
)
|
||||
env = usershell.SystemEnvInfo{}.Environ()
|
||||
|
||||
@@ -905,6 +905,12 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces
|
||||
})
|
||||
|
||||
c.mu.Unlock()
|
||||
c.server.logger.Debug(ctx, "set chat headers on agent conn",
|
||||
slog.F("chat_id", chatSnapshot.ID),
|
||||
slog.F("ancestor_chat_ids", ancestorIDs),
|
||||
slog.F("workspace_id", chatSnapshot.WorkspaceID.UUID),
|
||||
slog.F("agent_id", dialResult.AgentID),
|
||||
)
|
||||
return agentConn, nil
|
||||
}
|
||||
currentConn = c.conn
|
||||
|
||||
Reference in New Issue
Block a user