feat: tag chat-originating agent logs with chat_id (#25019)

Workspace-agent logs emitted while serving chatd-driven requests were
not correlated with the originating chat, making agent logs hard to
attribute to the corresponding/originating chat.

This adds agent-side chat context middleware that parses `Coder-Chat-Id`
once, enriches agent access logs and structured handler/background logs,
and adds a chatd bridge log when chat headers are attached to an agent
connection.

Closes CODAGT-324
This commit is contained in:
Ethan
2026-05-08 13:25:30 +10:00
committed by GitHub
parent e9f0385198
commit 3a9080fff6
15 changed files with 380 additions and 95 deletions
@@ -1,4 +1,4 @@
package agentgit package agentchat
import ( import (
"encoding/json" "encoding/json"
@@ -9,9 +9,9 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/codersdk/workspacesdk"
) )
// ExtractChatContext reads chat identity headers from the request. // extractContext reads chat identity headers from the request.
// Returns zero values if headers are absent (non-chat request). // Returns zero values if headers are absent (non-chat request).
func ExtractChatContext(r *http.Request) (chatID uuid.UUID, ancestorIDs []uuid.UUID, ok bool) { func extractContext(r *http.Request) (chatID uuid.UUID, ancestorIDs []uuid.UUID, ok bool) {
raw := r.Header.Get(workspacesdk.CoderChatIDHeader) raw := r.Header.Get(workspacesdk.CoderChatIDHeader)
if raw == "" { if raw == "" {
return uuid.Nil, nil, false return uuid.Nil, nil, false
@@ -1,18 +1,19 @@
package agentgit_test package agentchat_test
import ( import (
"encoding/json" "encoding/json"
"net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentgit" "github.com/coder/coder/v2/agent/agentchat"
"github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/codersdk/workspacesdk"
) )
func TestExtractChatContext(t *testing.T) { func TestExtractContext(t *testing.T) {
t.Parallel() t.Parallel()
validID := uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") validID := uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee")
@@ -43,7 +44,7 @@ func TestExtractChatContext(t *testing.T) {
setChatID: true, setChatID: true,
setAncestors: false, setAncestors: false,
wantChatID: validID, wantChatID: validID,
wantAncestorIDs: nil, wantAncestorIDs: []uuid.UUID{},
wantOK: true, wantOK: true,
}, },
{ {
@@ -75,7 +76,7 @@ func TestExtractChatContext(t *testing.T) {
ancestors: `{this is not json}`, ancestors: `{this is not json}`,
setAncestors: true, setAncestors: true,
wantChatID: validID, wantChatID: validID,
wantAncestorIDs: nil, wantAncestorIDs: []uuid.UUID{},
wantOK: true, wantOK: true,
}, },
{ {
@@ -112,7 +113,7 @@ func TestExtractChatContext(t *testing.T) {
ancestors: "", ancestors: "",
setAncestors: true, setAncestors: true,
wantChatID: validID, wantChatID: validID,
wantAncestorIDs: nil, wantAncestorIDs: []uuid.UUID{},
wantOK: true, wantOK: true,
}, },
} }
@@ -130,7 +131,7 @@ func TestExtractChatContext(t *testing.T) {
r.Header.Set(workspacesdk.CoderAncestorChatIDsHeader, tt.ancestors) r.Header.Set(workspacesdk.CoderAncestorChatIDsHeader, tt.ancestors)
} }
chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r) chatID, ancestorIDs, ok := extractContextForTest(r)
require.Equal(t, tt.wantOK, ok, "ok mismatch") require.Equal(t, tt.wantOK, ok, "ok mismatch")
require.Equal(t, tt.wantChatID, chatID, "chatID mismatch") require.Equal(t, tt.wantChatID, chatID, "chatID mismatch")
@@ -139,6 +140,18 @@ func TestExtractChatContext(t *testing.T) {
} }
} }
func extractContextForTest(r *http.Request) (uuid.UUID, []uuid.UUID, bool) {
var chatContext agentchat.Context
var ok bool
agentchat.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
chatContext, ok = agentchat.FromContext(r.Context())
})).ServeHTTP(httptest.NewRecorder(), r)
if !ok {
return uuid.Nil, nil, false
}
return chatContext.ID, chatContext.AncestorIDs, true
}
// mustMarshalJSON marshals v to a JSON string, failing the test on error. // mustMarshalJSON marshals v to a JSON string, failing the test on error.
func mustMarshalJSON(t *testing.T, v any) string { func mustMarshalJSON(t *testing.T, v any) string {
t.Helper() t.Helper()
+85
View File
@@ -0,0 +1,85 @@
package agentchat
import (
"context"
"net/http"
"github.com/google/uuid"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
)
type chatContextKey struct{}
// Context carries the chat identity associated with an agent request.
type Context struct {
ID uuid.UUID
AncestorIDs []uuid.UUID
}
// FromContext returns the chat identity stored on the context.
func FromContext(ctx context.Context) (Context, bool) {
chatCtx, ok := ctx.Value(chatContextKey{}).(Context)
if !ok || chatCtx.ID == uuid.Nil {
return Context{}, false
}
return chatCtx, true
}
// WithContext stores chat identity on the context for downstream logs.
func WithContext(ctx context.Context, chatID uuid.UUID, ancestorIDs []uuid.UUID) context.Context {
if chatID == uuid.Nil {
return ctx
}
ancestors := make([]uuid.UUID, len(ancestorIDs))
copy(ancestors, ancestorIDs)
return context.WithValue(ctx, chatContextKey{}, Context{
ID: chatID,
AncestorIDs: ancestors,
})
}
// Fields returns structured log fields for the chat identity on ctx.
func Fields(ctx context.Context) []slog.Field {
chatCtx, ok := FromContext(ctx)
if !ok {
return nil
}
return chatFields(chatCtx.ID, chatCtx.AncestorIDs)
}
// Middleware tags agent logs for requests that originate from
// chatd. Agent log lines emitted while serving a request with Coder-Chat-Id,
// or by background work started by such a request, should include chat_id.
// Install after loggermw.Logger so access-log enrichment can run.
func Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
chatID, ancestorIDs, ok := extractContext(r)
if !ok {
next.ServeHTTP(rw, r)
return
}
fields := chatFields(chatID, ancestorIDs)
if requestLogger := loggermw.RequestLoggerFromContext(r.Context()); requestLogger != nil {
requestLogger.WithFields(fields...)
}
ctx := WithContext(r.Context(), chatID, ancestorIDs)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
func chatFields(chatID uuid.UUID, ancestorIDs []uuid.UUID) []slog.Field {
fields := []slog.Field{slog.F("chat_id", chatID.String())}
if len(ancestorIDs) == 0 {
return fields
}
ancestors := make([]string, 0, len(ancestorIDs))
for _, id := range ancestorIDs {
ancestors = append(ancestors, id.String())
}
return append(fields, slog.F("ancestor_chat_ids", ancestors))
}
+103
View File
@@ -0,0 +1,103 @@
package agentchat_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentchat"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/testutil"
)
func TestMiddlewareAccessLog(t *testing.T) {
t.Parallel()
chatID := uuid.New()
ancestorID := uuid.New()
sink := testutil.NewFakeSink(t)
handler := tracing.StatusWriterMiddleware(loggermw.Logger(sink.Logger())(
agentchat.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
rw.WriteHeader(http.StatusNoContent)
})),
))
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
req.Header.Set(workspacesdk.CoderAncestorChatIDsHeader, mustMarshalJSON(t, []string{ancestorID.String()}))
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
require.Equal(t, http.StatusNoContent, rw.Code)
entries := sink.Entries()
require.Len(t, entries, 1)
fields := fieldsByName(entries[0].Fields)
require.Equal(t, chatID.String(), fields["chat_id"])
require.Equal(t, []string{ancestorID.String()}, fields["ancestor_chat_ids"])
}
func TestMiddlewareWithoutChatHeader(t *testing.T) {
t.Parallel()
sink := testutil.NewFakeSink(t)
handler := tracing.StatusWriterMiddleware(loggermw.Logger(sink.Logger())(
agentchat.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
rw.WriteHeader(http.StatusNoContent)
})),
))
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, httptest.NewRequest(http.MethodGet, "/test", nil))
require.Equal(t, http.StatusNoContent, rw.Code)
entries := sink.Entries()
require.Len(t, entries, 1)
fields := fieldsByName(entries[0].Fields)
require.NotContains(t, fields, "chat_id")
require.NotContains(t, fields, "ancestor_chat_ids")
}
func TestMiddlewareContextFields(t *testing.T) {
t.Parallel()
chatID := uuid.New()
sink := testutil.NewFakeSink(t)
handler := tracing.StatusWriterMiddleware(loggermw.Logger(sink.Logger())(
agentchat.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
sink.Logger().With(agentchat.Fields(r.Context())...).Info(r.Context(), "handler log")
rw.WriteHeader(http.StatusNoContent)
})),
))
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
require.Equal(t, http.StatusNoContent, rw.Code)
entries := sink.Entries()
require.Len(t, entries, 2)
for _, entry := range entries {
if entry.Message != "handler log" {
continue
}
fields := fieldsByName(entry.Fields)
require.Equal(t, chatID.String(), fields["chat_id"])
return
}
t.Fatal("handler log entry not found")
}
func fieldsByName(fields []slog.Field) map[string]any {
byName := make(map[string]any, len(fields))
for _, field := range fields {
byName[field.Name] = field.Value
}
return byName
}
+12 -8
View File
@@ -18,7 +18,7 @@ import (
"golang.org/x/xerrors" "golang.org/x/xerrors"
"cdr.dev/slog/v3" "cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentgit" "github.com/coder/coder/v2/agent/agentchat"
"github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/codersdk/workspacesdk"
@@ -86,6 +86,8 @@ func (api *API) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
} }
func (api *API) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (HTTPResponseCode, error) { func (api *API) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (HTTPResponseCode, error) {
logger := api.logger.With(agentchat.Fields(ctx)...)
if !filepath.IsAbs(path) { if !filepath.IsAbs(path) {
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path) return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
} }
@@ -131,7 +133,7 @@ func (api *API) streamFile(ctx context.Context, rw http.ResponseWriter, path str
reader := io.NewSectionReader(f, offset, bytesToRead) reader := io.NewSectionReader(f, offset, bytesToRead)
_, err = io.Copy(rw, reader) _, err = io.Copy(rw, reader)
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil { if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
api.logger.Error(ctx, "workspace agent read file", slog.Error(err)) logger.Error(ctx, "workspace agent read file", slog.Error(err))
} }
return 0, nil return 0, nil
@@ -322,8 +324,8 @@ func (api *API) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
// Track edited path for git watch. // Track edited path for git watch.
if api.pathStore != nil { if api.pathStore != nil {
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok { if chatContext, ok := agentchat.FromContext(ctx); ok {
api.pathStore.AddPaths(append([]uuid.UUID{chatID}, ancestorIDs...), []string{path}) api.pathStore.AddPaths(append([]uuid.UUID{chatContext.ID}, chatContext.AncestorIDs...), []string{path})
} }
} }
@@ -458,12 +460,12 @@ func (api *API) HandleEditFiles(rw http.ResponseWriter, r *http.Request) {
// Track edited paths for git watch. // Track edited paths for git watch.
if api.pathStore != nil { if api.pathStore != nil {
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok { if chatContext, ok := agentchat.FromContext(ctx); ok {
filePaths := make([]string, 0, len(req.Files)) filePaths := make([]string, 0, len(req.Files))
for _, f := range req.Files { for _, f := range req.Files {
filePaths = append(filePaths, f.Path) filePaths = append(filePaths, f.Path)
} }
api.pathStore.AddPaths(append([]uuid.UUID{chatID}, ancestorIDs...), filePaths) api.pathStore.AddPaths(append([]uuid.UUID{chatContext.ID}, chatContext.AncestorIDs...), filePaths)
} }
} }
@@ -565,6 +567,8 @@ func (api *API) prepareFileEdit(path string, edits []workspacesdk.FileEdit) (int
// On failure the temp file is cleaned up and the original is // On failure the temp file is cleaned up and the original is
// untouched. // untouched.
func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode, r io.Reader) (int, error) { func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode, r io.Reader) (int, error) {
logger := api.logger.With(agentchat.Fields(ctx)...)
dir := filepath.Dir(path) dir := filepath.Dir(path)
tmpName := filepath.Join(dir, fmt.Sprintf(".%s.tmp.%s", filepath.Base(path), uuid.New().String()[:8])) tmpName := filepath.Join(dir, fmt.Sprintf(".%s.tmp.%s", filepath.Base(path), uuid.New().String()[:8]))
@@ -579,7 +583,7 @@ func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode,
cleanup := func() { cleanup := func() {
if err := api.filesystem.Remove(tmpName); err != nil { if err := api.filesystem.Remove(tmpName); err != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(err)) logger.Warn(ctx, "unable to clean up temp file", slog.Error(err))
} }
} }
@@ -601,7 +605,7 @@ func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode,
// no window where the target has wrong permissions. // no window where the target has wrong permissions.
if mode != nil { if mode != nil {
if err := api.filesystem.Chmod(tmpName, *mode); err != nil { if err := api.filesystem.Chmod(tmpName, *mode); err != nil {
api.logger.Warn(ctx, "unable to set file permissions", logger.Warn(ctx, "unable to set file permissions",
slog.F("path", path), slog.F("path", path),
slog.Error(err), slog.Error(err),
) )
+6 -5
View File
@@ -24,6 +24,7 @@ import (
"cdr.dev/slog/v3" "cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest" "cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentchat"
"github.com/coder/coder/v2/agent/agentfiles" "github.com/coder/coder/v2/agent/agentfiles"
"github.com/coder/coder/v2/agent/agentgit" "github.com/coder/coder/v2/agent/agentgit"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
@@ -1157,7 +1158,7 @@ func TestHandleWriteFile_ChatHeaders_UpdatesPathStore(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
r := chi.NewRouter() r := chi.NewRouter()
r.Post("/write-file", api.HandleWriteFile) r.Post("/write-file", api.HandleWriteFile)
r.ServeHTTP(rr, req) agentchat.Middleware(r).ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code) require.Equal(t, http.StatusOK, rr.Code)
@@ -1185,7 +1186,7 @@ func TestHandleWriteFile_NoChatHeaders_NoPathStoreUpdate(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
r := chi.NewRouter() r := chi.NewRouter()
r.Post("/write-file", api.HandleWriteFile) r.Post("/write-file", api.HandleWriteFile)
r.ServeHTTP(rr, req) agentchat.Middleware(r).ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code) require.Equal(t, http.StatusOK, rr.Code)
@@ -1211,7 +1212,7 @@ func TestHandleWriteFile_Failure_NoPathStoreUpdate(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
r := chi.NewRouter() r := chi.NewRouter()
r.Post("/write-file", api.HandleWriteFile) r.Post("/write-file", api.HandleWriteFile)
r.ServeHTTP(rr, req) agentchat.Middleware(r).ServeHTTP(rr, req)
require.Equal(t, http.StatusBadRequest, rr.Code) require.Equal(t, http.StatusBadRequest, rr.Code)
@@ -1252,7 +1253,7 @@ func TestHandleEditFiles_ChatHeaders_UpdatesPathStore(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
r := chi.NewRouter() r := chi.NewRouter()
r.Post("/edit-files", api.HandleEditFiles) r.Post("/edit-files", api.HandleEditFiles)
r.ServeHTTP(rr, req) agentchat.Middleware(r).ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code) require.Equal(t, http.StatusOK, rr.Code)
@@ -1289,7 +1290,7 @@ func TestHandleEditFiles_Failure_NoPathStoreUpdate(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
r := chi.NewRouter() r := chi.NewRouter()
r.Post("/edit-files", api.HandleEditFiles) r.Post("/edit-files", api.HandleEditFiles)
r.ServeHTTP(rr, req) agentchat.Middleware(r).ServeHTTP(rr, req)
require.NotEqual(t, http.StatusOK, rr.Code) require.NotEqual(t, http.StatusOK, rr.Code)
+28 -12
View File
@@ -8,6 +8,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"cdr.dev/slog/v3" "cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentchat"
"github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/codersdk/wsjson"
@@ -40,6 +41,25 @@ func (a *API) Routes() http.Handler {
func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) { func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
var watchChatID uuid.UUID
var hasWatchChatID bool
if chatIDStr := r.URL.Query().Get("chat_id"); chatIDStr != "" {
if parsedChatID, parseErr := uuid.Parse(chatIDStr); parseErr == nil {
watchChatID = parsedChatID
hasWatchChatID = true
// Reuse header-derived ancestors only when the query chat
// matches the header chat. Otherwise the ancestors belong
// to a different chat and would be misleading in logs.
var ancestors []uuid.UUID
if chatContext, ok := agentchat.FromContext(ctx); ok && chatContext.ID == watchChatID {
ancestors = chatContext.AncestorIDs
}
ctx = agentchat.WithContext(ctx, watchChatID, ancestors)
}
}
logger := a.logger.With(agentchat.Fields(ctx)...)
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionNoContextTakeover, CompressionMode: websocket.CompressionNoContextTakeover,
}) })
@@ -58,14 +78,14 @@ func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) {
stream := wsjson.NewStream[ stream := wsjson.NewStream[
codersdk.WorkspaceAgentGitClientMessage, codersdk.WorkspaceAgentGitClientMessage,
codersdk.WorkspaceAgentGitServerMessage, codersdk.WorkspaceAgentGitServerMessage,
](conn, websocket.MessageText, websocket.MessageText, a.logger) ](conn, websocket.MessageText, websocket.MessageText, logger)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
go httpapi.HeartbeatClose(ctx, a.logger, cancel, conn) go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
handler := NewHandler(a.logger, a.opts...) handler := NewHandler(logger, a.opts...)
// Scan returns nil only when no roots are subscribed; once any // Scan returns nil only when no roots are subscribed; once any
// root lands it returns either a delta or a heartbeat message. // root lands it returns either a delta or a heartbeat message.
@@ -75,28 +95,25 @@ func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) {
return return
} }
if err := stream.Send(*msg); err != nil { if err := stream.Send(*msg); err != nil {
a.logger.Debug(ctx, "failed to send changes", slog.Error(err)) logger.Debug(ctx, "failed to send changes", slog.Error(err))
cancel() cancel()
} }
} }
// If a chat_id query parameter is provided and the PathStore is // If a chat_id query parameter is provided and the PathStore is
// available, subscribe to path updates for this chat. // available, subscribe to path updates for this chat.
chatIDStr := r.URL.Query().Get("chat_id") if hasWatchChatID && a.pathStore != nil {
if chatIDStr != "" && a.pathStore != nil {
chatID, parseErr := uuid.Parse(chatIDStr)
if parseErr == nil {
// Subscribe to future path updates BEFORE reading // Subscribe to future path updates BEFORE reading
// existing paths. This ordering guarantees no // existing paths. This ordering guarantees no
// notification from AddPaths is lost: any call that // notification from AddPaths is lost: any call that
// lands before Subscribe is picked up by GetPaths // lands before Subscribe is picked up by GetPaths
// below, and any call after Subscribe delivers a // below, and any call after Subscribe delivers a
// notification on the channel. // notification on the channel.
notifyCh, unsubscribe := a.pathStore.Subscribe(chatID) notifyCh, unsubscribe := a.pathStore.Subscribe(watchChatID)
defer unsubscribe() defer unsubscribe()
// Load any paths that are already tracked for this chat. // Load any paths that are already tracked for this chat.
existingPaths := a.pathStore.GetPaths(chatID) existingPaths := a.pathStore.GetPaths(watchChatID)
if len(existingPaths) > 0 { if len(existingPaths) > 0 {
handler.Subscribe(existingPaths) handler.Subscribe(existingPaths)
handler.RequestScan() handler.RequestScan()
@@ -108,14 +125,13 @@ func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-notifyCh: case <-notifyCh:
paths := a.pathStore.GetPaths(chatID) paths := a.pathStore.GetPaths(watchChatID)
handler.Subscribe(paths) handler.Subscribe(paths)
handler.RequestScan() handler.RequestScan()
} }
} }
}() }()
} }
}
// Start the main run loop in a goroutine. // Start the main run loop in a goroutine.
go handler.RunLoop(ctx, scanAndSend) go handler.RunLoop(ctx, scanAndSend)
+13 -11
View File
@@ -13,6 +13,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"cdr.dev/slog/v3" "cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentchat"
"github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentgit" "github.com/coder/coder/v2/agent/agentgit"
"github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpapi"
@@ -80,8 +81,8 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
} }
var chatID string var chatID string
if id, _, ok := agentgit.ExtractChatContext(r); ok { if chatContext, ok := agentchat.FromContext(ctx); ok {
chatID = id.String() chatID = chatContext.ID.String()
} }
proc, err := api.manager.start(req, chatID) proc, err := api.manager.start(req, chatID)
@@ -97,8 +98,8 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
// file changes made by the command are visible in the scan. // file changes made by the command are visible in the scan.
// If a workdir is provided, track it as a path as well. // If a workdir is provided, track it as a path as well.
if api.pathStore != nil { if api.pathStore != nil {
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok { if chatContext, ok := agentchat.FromContext(ctx); ok {
allIDs := append([]uuid.UUID{chatID}, ancestorIDs...) allIDs := append([]uuid.UUID{chatContext.ID}, chatContext.AncestorIDs...)
go func() { go func() {
<-proc.done <-proc.done
if req.WorkDir != "" { if req.WorkDir != "" {
@@ -121,8 +122,8 @@ func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
var chatID string var chatID string
if id, _, ok := agentgit.ExtractChatContext(r); ok { if chatContext, ok := agentchat.FromContext(ctx); ok {
chatID = id.String() chatID = chatContext.ID.String()
} }
infos := api.manager.list(chatID) infos := api.manager.list(chatID)
@@ -150,6 +151,7 @@ func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) {
// handleProcessOutput returns the output of a process. // handleProcessOutput returns the output of a process.
func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) { func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
logger := api.logger.With(agentchat.Fields(ctx)...)
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
proc, ok := api.manager.get(id) proc, ok := api.manager.get(id)
@@ -163,8 +165,8 @@ func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
// Enforce chat ID isolation. If the request carries // Enforce chat ID isolation. If the request carries
// a chat context, only allow access to processes // a chat context, only allow access to processes
// belonging to that chat. // belonging to that chat.
if chatID, _, ok := agentgit.ExtractChatContext(r); ok { if chatContext, ok := agentchat.FromContext(ctx); ok {
if proc.chatID != "" && proc.chatID != chatID.String() { if proc.chatID != "" && proc.chatID != chatContext.ID.String() {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id), Message: fmt.Sprintf("Process %q not found.", id),
}) })
@@ -184,7 +186,7 @@ func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
// Add headroom beyond the wait timeout so there's time to // Add headroom beyond the wait timeout so there's time to
// write the response after the blocking wait completes. // write the response after the blocking wait completes.
if err := rc.SetWriteDeadline(time.Now().Add(maxWaitDuration + 30*time.Second)); err != nil { if err := rc.SetWriteDeadline(time.Now().Add(maxWaitDuration + 30*time.Second)); err != nil {
api.logger.Error(ctx, "extend write deadline for blocking process output", logger.Error(ctx, "extend write deadline for blocking process output",
slog.Error(err), slog.Error(err),
) )
} }
@@ -216,9 +218,9 @@ func (api *API) handleSignalProcess(rw http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
// Enforce chat ID isolation. // Enforce chat ID isolation.
if chatID, _, ok := agentgit.ExtractChatContext(r); ok { if chatContext, ok := agentchat.FromContext(ctx); ok {
proc, procOK := api.manager.get(id) proc, procOK := api.manager.get(id)
if procOK && proc.chatID != "" && proc.chatID != chatID.String() { if procOK && proc.chatID != "" && proc.chatID != chatContext.ID.String() {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id), Message: fmt.Sprintf("Process %q not found.", id),
}) })
+33 -2
View File
@@ -20,9 +20,12 @@ import (
"cdr.dev/slog/v3" "cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest" "cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentchat"
"github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentgit" "github.com/coder/coder/v2/agent/agentgit"
"github.com/coder/coder/v2/agent/agentproc" "github.com/coder/coder/v2/agent/agentproc"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
@@ -137,7 +140,35 @@ func newTestAPIWithOptions(t *testing.T, updateEnv func([]string) ([]string, err
t.Cleanup(func() { t.Cleanup(func() {
_ = api.Close() _ = api.Close()
}) })
return api.Routes() return agentchat.Middleware(api.Routes())
}
func TestAccessLogIncludesChatID(t *testing.T) {
t.Parallel()
sink := testutil.NewFakeSink(t)
logger := sink.Logger()
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, nil, nil, nil)
t.Cleanup(func() {
_ = api.Close()
})
handler := tracing.StatusWriterMiddleware(loggermw.Logger(logger)(
agentchat.Middleware(api.Routes()),
))
chatID := uuid.New().String()
w := getListWithChatHeader(t, handler, chatID)
require.Equal(t, http.StatusOK, w.Code)
entries := sink.Entries(func(entry slog.SinkEntry) bool {
return entry.Message == http.MethodGet
})
require.Len(t, entries, 1)
fields := make(map[string]any, len(entries[0].Fields))
for _, field := range entries[0].Fields {
fields[field.Name] = field.Value
}
require.Equal(t, chatID, fields["chat_id"])
} }
// waitForExit polls the output endpoint until the process is // waitForExit polls the output endpoint until the process is
@@ -1058,7 +1089,7 @@ func TestHandleStartProcess_ChatHeaders_EmptyWorkDir_StillNotifies(t *testing.T)
}, pathStore, nil) }, pathStore, nil)
defer api.Close() defer api.Close()
routes := api.Routes() routes := agentchat.Middleware(api.Routes())
body, err := json.Marshal(workspacesdk.StartProcessRequest{ body, err := json.Marshal(workspacesdk.StartProcessRequest{
Command: "echo hello", Command: "echo hello",
+8 -2
View File
@@ -38,6 +38,7 @@ type process struct {
cmd *exec.Cmd cmd *exec.Cmd
cancel context.CancelFunc cancel context.CancelFunc
buf *HeadTailBuffer buf *HeadTailBuffer
logger slog.Logger
running bool running bool
exitCode *int exitCode *int
startedAt int64 startedAt int64
@@ -105,6 +106,10 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
m.mu.Unlock() m.mu.Unlock()
id := uuid.New().String() id := uuid.New().String()
logger := m.logger
if chatID != "" {
logger = logger.With(slog.F("chat_id", chatID))
}
// Use a cancellable context so Close() can terminate // Use a cancellable context so Close() can terminate
// all processes. context.Background() is the parent so // all processes. context.Background() is the parent so
@@ -132,7 +137,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
if m.updateEnv != nil { if m.updateEnv != nil {
updated, err := m.updateEnv(baseEnv) updated, err := m.updateEnv(baseEnv)
if err != nil { if err != nil {
m.logger.Warn( logger.Warn(
context.Background(), context.Background(),
"failed to update command environment, falling back to os env", "failed to update command environment, falling back to os env",
slog.Error(err), slog.Error(err),
@@ -169,6 +174,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
cmd: cmd, cmd: cmd,
cancel: cancel, cancel: cancel,
buf: buf, buf: buf,
logger: logger,
running: true, running: true,
startedAt: now, startedAt: now,
done: make(chan struct{}), done: make(chan struct{}),
@@ -202,7 +208,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
} else { } else {
// Unknown error; use -1 as a sentinel. // Unknown error; use -1 as a sentinel.
code = -1 code = -1
m.logger.Warn( proc.logger.Warn(
context.Background(), context.Background(),
"process wait returned non-exit error", "process wait returned non-exit error",
slog.F("id", id), slog.F("id", id),
+2
View File
@@ -6,6 +6,7 @@ import (
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/coder/coder/v2/agent/agentchat"
"github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/coderd/tracing"
@@ -20,6 +21,7 @@ func (a *agent) apiHandler() http.Handler {
httpmw.Recover(a.logger), httpmw.Recover(a.logger),
tracing.StatusWriterMiddleware, tracing.StatusWriterMiddleware,
loggermw.Logger(a.logger), loggermw.Logger(a.logger),
agentchat.Middleware,
) )
r.Get("/", func(rw http.ResponseWriter, r *http.Request) { r.Get("/", func(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{ httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
+17 -13
View File
@@ -16,6 +16,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"cdr.dev/slog/v3" "cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentchat"
"github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
@@ -85,6 +86,7 @@ func (a *API) Routes() http.Handler {
func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) { func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
logger := a.logger.With(agentchat.Fields(ctx)...)
// Start the desktop session (idempotent). // Start the desktop session (idempotent).
_, err := a.desktop.Start(ctx) _, err := a.desktop.Start(ctx)
@@ -112,7 +114,7 @@ func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) {
CompressionMode: websocket.CompressionDisabled, CompressionMode: websocket.CompressionDisabled,
}) })
if err != nil { if err != nil {
a.logger.Error(ctx, "failed to accept websocket", slog.Error(err)) logger.Error(ctx, "failed to accept websocket", slog.Error(err))
return return
} }
@@ -128,6 +130,7 @@ func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) {
func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) { func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
logger := a.logger.With(agentchat.Fields(ctx)...)
handlerStart := a.clock.Now() handlerStart := a.clock.Now()
// Update last desktop action timestamp for idle recording monitor. // Update last desktop action timestamp for idle recording monitor.
@@ -136,7 +139,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
// Ensure the desktop is running and grab native dimensions. // Ensure the desktop is running and grab native dimensions.
cfg, err := a.desktop.Start(ctx) cfg, err := a.desktop.Start(ctx)
if err != nil { if err != nil {
a.logger.Warn(ctx, "handleAction: desktop.Start failed", logger.Warn(ctx, "handleAction: desktop.Start failed",
slog.Error(err), slog.Error(err),
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()), slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
) )
@@ -156,7 +159,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
return return
} }
a.logger.Info(ctx, "handleAction: started", logger.Info(ctx, "handleAction: started",
slog.F("action", action.Action), slog.F("action", action.Action),
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()), slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
) )
@@ -272,7 +275,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
x, y = scaleXY(x, y) x, y = scaleXY(x, y)
stepStart := a.clock.Now() stepStart := a.clock.Now()
if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil { if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil {
a.logger.Warn(ctx, "handleAction: Click failed", logger.Warn(ctx, "handleAction: Click failed",
slog.F("action", "left_click"), slog.F("action", "left_click"),
slog.F("step", "click"), slog.F("step", "click"),
slog.F("step_ms", time.Since(stepStart).Milliseconds()), slog.F("step_ms", time.Since(stepStart).Milliseconds()),
@@ -285,7 +288,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
}) })
return return
} }
a.logger.Debug(ctx, "handleAction: Click completed", logger.Debug(ctx, "handleAction: Click completed",
slog.F("action", "left_click"), slog.F("action", "left_click"),
slog.F("step_ms", time.Since(stepStart).Milliseconds()), slog.F("step_ms", time.Since(stepStart).Milliseconds()),
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()), slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
@@ -473,7 +476,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
case <-ctx.Done(): case <-ctx.Done():
// Context canceled; release the key immediately. // Context canceled; release the key immediately.
if err := a.desktop.KeyUp(ctx, *action.Text); err != nil { if err := a.desktop.KeyUp(ctx, *action.Text); err != nil {
a.logger.Warn(ctx, "handleAction: KeyUp after context cancel", slog.Error(err)) logger.Warn(ctx, "handleAction: KeyUp after context cancel", slog.Error(err))
} }
return return
case <-timer.C: case <-timer.C:
@@ -513,14 +516,14 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
elapsedMs := a.clock.Since(handlerStart).Milliseconds() elapsedMs := a.clock.Since(handlerStart).Milliseconds()
if ctx.Err() != nil { if ctx.Err() != nil {
a.logger.Error(ctx, "handleAction: context canceled before writing response", logger.Error(ctx, "handleAction: context canceled before writing response",
slog.F("action", action.Action), slog.F("action", action.Action),
slog.F("elapsed_ms", elapsedMs), slog.F("elapsed_ms", elapsedMs),
slog.Error(ctx.Err()), slog.Error(ctx.Err()),
) )
return return
} }
a.logger.Info(ctx, "handleAction: writing response", logger.Info(ctx, "handleAction: writing response",
slog.F("action", action.Action), slog.F("action", action.Action),
slog.F("elapsed_ms", elapsedMs), slog.F("elapsed_ms", elapsedMs),
) )
@@ -609,6 +612,7 @@ func (a *API) handleRecordingStart(rw http.ResponseWriter, r *http.Request) {
func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) { func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
logger := a.logger.With(agentchat.Fields(ctx)...)
recordingID, ok := a.decodeRecordingRequest(rw, r) recordingID, ok := a.decodeRecordingRequest(rw, r)
if !ok { if !ok {
@@ -661,7 +665,7 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
}() }()
if artifact.Size > workspacesdk.MaxRecordingSize { if artifact.Size > workspacesdk.MaxRecordingSize {
a.logger.Warn(ctx, "recording file exceeds maximum size", logger.Warn(ctx, "recording file exceeds maximum size",
slog.F("recording_id", recordingID), slog.F("recording_id", recordingID),
slog.F("size", artifact.Size), slog.F("size", artifact.Size),
slog.F("max_size", workspacesdk.MaxRecordingSize), slog.F("max_size", workspacesdk.MaxRecordingSize),
@@ -677,7 +681,7 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
// rejecting it here avoids streaming a large thumbnail over // rejecting it here avoids streaming a large thumbnail over
// the wire for nothing. // the wire for nothing.
if artifact.ThumbnailReader != nil && artifact.ThumbnailSize > workspacesdk.MaxThumbnailSize { if artifact.ThumbnailReader != nil && artifact.ThumbnailSize > workspacesdk.MaxThumbnailSize {
a.logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting", logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting",
slog.F("recording_id", recordingID), slog.F("recording_id", recordingID),
slog.F("size", artifact.ThumbnailSize), slog.F("size", artifact.ThumbnailSize),
slog.F("max_size", workspacesdk.MaxThumbnailSize), slog.F("max_size", workspacesdk.MaxThumbnailSize),
@@ -701,13 +705,13 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
"Content-Type": {"video/mp4"}, "Content-Type": {"video/mp4"},
}) })
if err != nil { if err != nil {
a.logger.Warn(ctx, "failed to create video multipart part", logger.Warn(ctx, "failed to create video multipart part",
slog.F("recording_id", recordingID), slog.F("recording_id", recordingID),
slog.Error(err)) slog.Error(err))
return return
} }
if _, err := io.Copy(videoPart, artifact.Reader); err != nil { if _, err := io.Copy(videoPart, artifact.Reader); err != nil {
a.logger.Warn(ctx, "failed to write video multipart part", logger.Warn(ctx, "failed to write video multipart part",
slog.F("recording_id", recordingID), slog.F("recording_id", recordingID),
slog.Error(err)) slog.Error(err))
return return
@@ -719,7 +723,7 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
"Content-Type": {"image/jpeg"}, "Content-Type": {"image/jpeg"},
}) })
if err != nil { if err != nil {
a.logger.Warn(ctx, "failed to create thumbnail multipart part", logger.Warn(ctx, "failed to create thumbnail multipart part",
slog.F("recording_id", recordingID), slog.F("recording_id", recordingID),
slog.Error(err)) slog.Error(err))
return return
+6 -4
View File
@@ -8,6 +8,7 @@ import (
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"cdr.dev/slog/v3" "cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentchat"
"github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/codersdk/workspacesdk"
@@ -52,6 +53,7 @@ func (api *API) Routes() http.Handler {
// independent of config changes. // independent of config changes.
func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) { func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
logger := api.logger.With(agentchat.Fields(ctx)...)
// Check config freshness and reload if changed. // Check config freshness and reload if changed.
var reloaded bool var reloaded bool
@@ -61,11 +63,11 @@ func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) {
// Categorize the error for operator debugging. // Categorize the error for operator debugging.
switch { switch {
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
api.logger.Warn(ctx, "mcp reload canceled by caller", slog.Error(err)) logger.Warn(ctx, "mcp reload canceled by caller", slog.Error(err))
case errors.Is(err, context.DeadlineExceeded): case errors.Is(err, context.DeadlineExceeded):
api.logger.Warn(ctx, "mcp reload timed out", slog.Error(err)) logger.Warn(ctx, "mcp reload timed out", slog.Error(err))
default: default:
api.logger.Warn(ctx, "mcp reload failed", slog.Error(err)) logger.Warn(ctx, "mcp reload failed", slog.Error(err))
} }
// Fall through to return whatever tools we have. // Fall through to return whatever tools we have.
} else { } else {
@@ -78,7 +80,7 @@ func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) {
// refreshes tools as part of the reload. // refreshes tools as part of the reload.
if r.URL.Query().Get("refresh") == "true" && !reloaded { if r.URL.Query().Get("refresh") == "true" && !reloaded {
if err := api.manager.RefreshTools(ctx); err != nil { if err := api.manager.RefreshTools(ctx); err != nil {
api.logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err)) logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err))
} }
} }
+15 -5
View File
@@ -22,6 +22,7 @@ import (
tailscalesingleflight "tailscale.com/util/singleflight" tailscalesingleflight "tailscale.com/util/singleflight"
"cdr.dev/slog/v3" "cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentchat"
"github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/usershell" "github.com/coder/coder/v2/agent/usershell"
"github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/buildinfo"
@@ -248,7 +249,8 @@ func (m *Manager) doReload(ctx context.Context, mcpConfigFiles []string) error {
// Refresh tools outside the lock to avoid blocking // Refresh tools outside the lock to avoid blocking
// concurrent reads during network I/O. // concurrent reads during network I/O.
if err := m.RefreshTools(ctx); err != nil { if err := m.RefreshTools(ctx); err != nil {
m.logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err)) logger := m.logger.With(agentchat.Fields(ctx)...)
logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err))
} }
return nil return nil
} }
@@ -257,6 +259,8 @@ func (m *Manager) doReload(ctx context.Context, mcpConfigFiles []string) error {
// list of server configs. Missing files are silently skipped; // list of server configs. Missing files are silently skipped;
// parse errors are logged and skipped. // parse errors are logged and skipped.
func (m *Manager) parseAndDedup(ctx context.Context, mcpConfigFiles []string) ([]ServerConfig, map[string]fileSnapshot) { func (m *Manager) parseAndDedup(ctx context.Context, mcpConfigFiles []string) ([]ServerConfig, map[string]fileSnapshot) {
logger := m.logger.With(agentchat.Fields(ctx)...)
// Stat before reading so the snapshot is conservatively old. // Stat before reading so the snapshot is conservatively old.
// If a file changes between stat and read, the snapshot // If a file changes between stat and read, the snapshot
// records the old mtime, SnapshotChanged detects a mismatch // records the old mtime, SnapshotChanged detects a mismatch
@@ -272,7 +276,7 @@ func (m *Manager) parseAndDedup(ctx context.Context, mcpConfigFiles []string) ([
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {
continue continue
} }
m.logger.Warn(ctx, "failed to parse MCP config", logger.Warn(ctx, "failed to parse MCP config",
slog.F("path", configPath), slog.F("path", configPath),
slog.Error(err), slog.Error(err),
) )
@@ -334,6 +338,8 @@ func (m *Manager) classifyServers(wanted map[string]ServerConfig) (*serverDiff,
// connectAll runs connectServer in parallel for the given configs. // connectAll runs connectServer in parallel for the given configs.
// Failed connects are logged and skipped. // Failed connects are logged and skipped.
func (m *Manager) connectAll(ctx context.Context, toConnect []ServerConfig) []connectedServer { func (m *Manager) connectAll(ctx context.Context, toConnect []ServerConfig) []connectedServer {
logger := m.logger.With(agentchat.Fields(ctx)...)
var ( var (
mu sync.Mutex mu sync.Mutex
connected []connectedServer connected []connectedServer
@@ -343,7 +349,7 @@ func (m *Manager) connectAll(ctx context.Context, toConnect []ServerConfig) []co
eg.Go(func() error { eg.Go(func() error {
c, err := m.connectServer(ctx, cfg) c, err := m.connectServer(ctx, cfg)
if err != nil { if err != nil {
m.logger.Warn(ctx, "skipping MCP server", logger.Warn(ctx, "skipping MCP server",
slog.F("server", cfg.Name), slog.F("server", cfg.Name),
slog.F("transport", cfg.Transport), slog.F("transport", cfg.Transport),
slog.Error(err), slog.Error(err),
@@ -481,6 +487,8 @@ func (m *Manager) CallTool(ctx context.Context, req workspacesdk.CallMCPToolRequ
// existing cached tools for servers that failed, so a single // existing cached tools for servers that failed, so a single
// dead server doesn't block updates from healthy ones. // dead server doesn't block updates from healthy ones.
func (m *Manager) RefreshTools(ctx context.Context) error { func (m *Manager) RefreshTools(ctx context.Context) error {
logger := m.logger.With(agentchat.Fields(ctx)...)
// Snapshot servers under read lock. // Snapshot servers under read lock.
m.mu.RLock() m.mu.RLock()
servers := make(map[string]*serverEntry, len(m.servers)) servers := make(map[string]*serverEntry, len(m.servers))
@@ -508,7 +516,7 @@ func (m *Manager) RefreshTools(ctx context.Context) error {
result, err := entry.client.ListTools(listCtx, mcp.ListToolsRequest{}) result, err := entry.client.ListTools(listCtx, mcp.ListToolsRequest{})
cancel() cancel()
if err != nil { if err != nil {
m.logger.Warn(ctx, "failed to list tools from MCP server", logger.Warn(ctx, "failed to list tools from MCP server",
slog.F("server", name), slog.F("server", name),
slog.Error(err), slog.Error(err),
) )
@@ -670,12 +678,14 @@ func (m *Manager) createTransport(ctx context.Context, cfg ServerConfig) (transp
// updateEnv callback, then merges explicit overrides from the // updateEnv callback, then merges explicit overrides from the
// server config on top. // server config on top.
func (m *Manager) buildEnv(ctx context.Context, explicit map[string]string) []string { func (m *Manager) buildEnv(ctx context.Context, explicit map[string]string) []string {
logger := m.logger.With(agentchat.Fields(ctx)...)
env := usershell.SystemEnvInfo{}.Environ() env := usershell.SystemEnvInfo{}.Environ()
if m.updateEnv != nil { if m.updateEnv != nil {
var err error var err error
env, err = m.updateEnv(env) env, err = m.updateEnv(env)
if err != nil { if err != nil {
m.logger.Warn(ctx, "failed to enrich MCP server environment", logger.Warn(ctx, "failed to enrich MCP server environment",
slog.Error(err), slog.Error(err),
) )
env = usershell.SystemEnvInfo{}.Environ() env = usershell.SystemEnvInfo{}.Environ()
+6
View File
@@ -905,6 +905,12 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces
}) })
c.mu.Unlock() c.mu.Unlock()
c.server.logger.Debug(ctx, "set chat headers on agent conn",
slog.F("chat_id", chatSnapshot.ID),
slog.F("ancestor_chat_ids", ancestorIDs),
slog.F("workspace_id", chatSnapshot.WorkspaceID.UUID),
slog.F("agent_id", dialResult.AgentID),
)
return agentConn, nil return agentConn, nil
} }
currentConn = c.conn currentConn = c.conn