diff --git a/agent/agentgit/chatheaders.go b/agent/agentchat/headers.go similarity index 80% rename from agent/agentgit/chatheaders.go rename to agent/agentchat/headers.go index d516173ec8..84db99bb25 100644 --- a/agent/agentgit/chatheaders.go +++ b/agent/agentchat/headers.go @@ -1,4 +1,4 @@ -package agentgit +package agentchat import ( "encoding/json" @@ -9,9 +9,9 @@ import ( "github.com/coder/coder/v2/codersdk/workspacesdk" ) -// ExtractChatContext reads chat identity headers from the request. +// extractContext reads chat identity headers from the request. // Returns zero values if headers are absent (non-chat request). -func ExtractChatContext(r *http.Request) (chatID uuid.UUID, ancestorIDs []uuid.UUID, ok bool) { +func extractContext(r *http.Request) (chatID uuid.UUID, ancestorIDs []uuid.UUID, ok bool) { raw := r.Header.Get(workspacesdk.CoderChatIDHeader) if raw == "" { return uuid.Nil, nil, false diff --git a/agent/agentgit/chatheaders_test.go b/agent/agentchat/headers_test.go similarity index 84% rename from agent/agentgit/chatheaders_test.go rename to agent/agentchat/headers_test.go index 3242c7b40a..90599eab28 100644 --- a/agent/agentgit/chatheaders_test.go +++ b/agent/agentchat/headers_test.go @@ -1,18 +1,19 @@ -package agentgit_test +package agentchat_test import ( "encoding/json" + "net/http" "net/http/httptest" "testing" "github.com/google/uuid" "github.com/stretchr/testify/require" - "github.com/coder/coder/v2/agent/agentgit" + "github.com/coder/coder/v2/agent/agentchat" "github.com/coder/coder/v2/codersdk/workspacesdk" ) -func TestExtractChatContext(t *testing.T) { +func TestExtractContext(t *testing.T) { t.Parallel() validID := uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") @@ -43,7 +44,7 @@ func TestExtractChatContext(t *testing.T) { setChatID: true, setAncestors: false, wantChatID: validID, - wantAncestorIDs: nil, + wantAncestorIDs: []uuid.UUID{}, wantOK: true, }, { @@ -75,7 +76,7 @@ func TestExtractChatContext(t *testing.T) { ancestors: `{this is not json}`, setAncestors: true, wantChatID: validID, - wantAncestorIDs: nil, + wantAncestorIDs: []uuid.UUID{}, wantOK: true, }, { @@ -112,7 +113,7 @@ func TestExtractChatContext(t *testing.T) { ancestors: "", setAncestors: true, wantChatID: validID, - wantAncestorIDs: nil, + wantAncestorIDs: []uuid.UUID{}, wantOK: true, }, } @@ -130,7 +131,7 @@ func TestExtractChatContext(t *testing.T) { r.Header.Set(workspacesdk.CoderAncestorChatIDsHeader, tt.ancestors) } - chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r) + chatID, ancestorIDs, ok := extractContextForTest(r) require.Equal(t, tt.wantOK, ok, "ok mismatch") require.Equal(t, tt.wantChatID, chatID, "chatID mismatch") @@ -139,6 +140,18 @@ func TestExtractChatContext(t *testing.T) { } } +func extractContextForTest(r *http.Request) (uuid.UUID, []uuid.UUID, bool) { + var chatContext agentchat.Context + var ok bool + agentchat.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + chatContext, ok = agentchat.FromContext(r.Context()) + })).ServeHTTP(httptest.NewRecorder(), r) + if !ok { + return uuid.Nil, nil, false + } + return chatContext.ID, chatContext.AncestorIDs, true +} + // mustMarshalJSON marshals v to a JSON string, failing the test on error. func mustMarshalJSON(t *testing.T, v any) string { t.Helper() diff --git a/agent/agentchat/log.go b/agent/agentchat/log.go new file mode 100644 index 0000000000..319f6a79b6 --- /dev/null +++ b/agent/agentchat/log.go @@ -0,0 +1,85 @@ +package agentchat + +import ( + "context" + "net/http" + + "github.com/google/uuid" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" +) + +type chatContextKey struct{} + +// Context carries the chat identity associated with an agent request. +type Context struct { + ID uuid.UUID + AncestorIDs []uuid.UUID +} + +// FromContext returns the chat identity stored on the context. +func FromContext(ctx context.Context) (Context, bool) { + chatCtx, ok := ctx.Value(chatContextKey{}).(Context) + if !ok || chatCtx.ID == uuid.Nil { + return Context{}, false + } + return chatCtx, true +} + +// WithContext stores chat identity on the context for downstream logs. +func WithContext(ctx context.Context, chatID uuid.UUID, ancestorIDs []uuid.UUID) context.Context { + if chatID == uuid.Nil { + return ctx + } + ancestors := make([]uuid.UUID, len(ancestorIDs)) + copy(ancestors, ancestorIDs) + return context.WithValue(ctx, chatContextKey{}, Context{ + ID: chatID, + AncestorIDs: ancestors, + }) +} + +// Fields returns structured log fields for the chat identity on ctx. +func Fields(ctx context.Context) []slog.Field { + chatCtx, ok := FromContext(ctx) + if !ok { + return nil + } + return chatFields(chatCtx.ID, chatCtx.AncestorIDs) +} + +// Middleware tags agent logs for requests that originate from +// chatd. Agent log lines emitted while serving a request with Coder-Chat-Id, +// or by background work started by such a request, should include chat_id. +// Install after loggermw.Logger so access-log enrichment can run. +func Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + chatID, ancestorIDs, ok := extractContext(r) + if !ok { + next.ServeHTTP(rw, r) + return + } + + fields := chatFields(chatID, ancestorIDs) + if requestLogger := loggermw.RequestLoggerFromContext(r.Context()); requestLogger != nil { + requestLogger.WithFields(fields...) + } + + ctx := WithContext(r.Context(), chatID, ancestorIDs) + next.ServeHTTP(rw, r.WithContext(ctx)) + }) +} + +func chatFields(chatID uuid.UUID, ancestorIDs []uuid.UUID) []slog.Field { + fields := []slog.Field{slog.F("chat_id", chatID.String())} + if len(ancestorIDs) == 0 { + return fields + } + + ancestors := make([]string, 0, len(ancestorIDs)) + for _, id := range ancestorIDs { + ancestors = append(ancestors, id.String()) + } + return append(fields, slog.F("ancestor_chat_ids", ancestors)) +} diff --git a/agent/agentchat/log_test.go b/agent/agentchat/log_test.go new file mode 100644 index 0000000000..c9fb1fc49a --- /dev/null +++ b/agent/agentchat/log_test.go @@ -0,0 +1,103 @@ +package agentchat_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/agent/agentchat" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" + "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/testutil" +) + +func TestMiddlewareAccessLog(t *testing.T) { + t.Parallel() + + chatID := uuid.New() + ancestorID := uuid.New() + sink := testutil.NewFakeSink(t) + handler := tracing.StatusWriterMiddleware(loggermw.Logger(sink.Logger())( + agentchat.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusNoContent) + })), + )) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String()) + req.Header.Set(workspacesdk.CoderAncestorChatIDsHeader, mustMarshalJSON(t, []string{ancestorID.String()})) + rw := httptest.NewRecorder() + handler.ServeHTTP(rw, req) + require.Equal(t, http.StatusNoContent, rw.Code) + + entries := sink.Entries() + require.Len(t, entries, 1) + fields := fieldsByName(entries[0].Fields) + require.Equal(t, chatID.String(), fields["chat_id"]) + require.Equal(t, []string{ancestorID.String()}, fields["ancestor_chat_ids"]) +} + +func TestMiddlewareWithoutChatHeader(t *testing.T) { + t.Parallel() + + sink := testutil.NewFakeSink(t) + handler := tracing.StatusWriterMiddleware(loggermw.Logger(sink.Logger())( + agentchat.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusNoContent) + })), + )) + + rw := httptest.NewRecorder() + handler.ServeHTTP(rw, httptest.NewRequest(http.MethodGet, "/test", nil)) + require.Equal(t, http.StatusNoContent, rw.Code) + + entries := sink.Entries() + require.Len(t, entries, 1) + fields := fieldsByName(entries[0].Fields) + require.NotContains(t, fields, "chat_id") + require.NotContains(t, fields, "ancestor_chat_ids") +} + +func TestMiddlewareContextFields(t *testing.T) { + t.Parallel() + + chatID := uuid.New() + sink := testutil.NewFakeSink(t) + handler := tracing.StatusWriterMiddleware(loggermw.Logger(sink.Logger())( + agentchat.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + sink.Logger().With(agentchat.Fields(r.Context())...).Info(r.Context(), "handler log") + rw.WriteHeader(http.StatusNoContent) + })), + )) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String()) + rw := httptest.NewRecorder() + handler.ServeHTTP(rw, req) + require.Equal(t, http.StatusNoContent, rw.Code) + + entries := sink.Entries() + require.Len(t, entries, 2) + for _, entry := range entries { + if entry.Message != "handler log" { + continue + } + fields := fieldsByName(entry.Fields) + require.Equal(t, chatID.String(), fields["chat_id"]) + return + } + t.Fatal("handler log entry not found") +} + +func fieldsByName(fields []slog.Field) map[string]any { + byName := make(map[string]any, len(fields)) + for _, field := range fields { + byName[field.Name] = field.Value + } + return byName +} diff --git a/agent/agentfiles/files.go b/agent/agentfiles/files.go index 868b4e5fb1..79602dbc17 100644 --- a/agent/agentfiles/files.go +++ b/agent/agentfiles/files.go @@ -18,7 +18,7 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog/v3" - "github.com/coder/coder/v2/agent/agentgit" + "github.com/coder/coder/v2/agent/agentchat" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" @@ -86,6 +86,8 @@ func (api *API) HandleReadFile(rw http.ResponseWriter, r *http.Request) { } func (api *API) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (HTTPResponseCode, error) { + logger := api.logger.With(agentchat.Fields(ctx)...) + if !filepath.IsAbs(path) { return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path) } @@ -131,7 +133,7 @@ func (api *API) streamFile(ctx context.Context, rw http.ResponseWriter, path str reader := io.NewSectionReader(f, offset, bytesToRead) _, err = io.Copy(rw, reader) if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil { - api.logger.Error(ctx, "workspace agent read file", slog.Error(err)) + logger.Error(ctx, "workspace agent read file", slog.Error(err)) } return 0, nil @@ -322,8 +324,8 @@ func (api *API) HandleWriteFile(rw http.ResponseWriter, r *http.Request) { // Track edited path for git watch. if api.pathStore != nil { - if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok { - api.pathStore.AddPaths(append([]uuid.UUID{chatID}, ancestorIDs...), []string{path}) + if chatContext, ok := agentchat.FromContext(ctx); ok { + api.pathStore.AddPaths(append([]uuid.UUID{chatContext.ID}, chatContext.AncestorIDs...), []string{path}) } } @@ -458,12 +460,12 @@ func (api *API) HandleEditFiles(rw http.ResponseWriter, r *http.Request) { // Track edited paths for git watch. if api.pathStore != nil { - if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok { + if chatContext, ok := agentchat.FromContext(ctx); ok { filePaths := make([]string, 0, len(req.Files)) for _, f := range req.Files { filePaths = append(filePaths, f.Path) } - api.pathStore.AddPaths(append([]uuid.UUID{chatID}, ancestorIDs...), filePaths) + api.pathStore.AddPaths(append([]uuid.UUID{chatContext.ID}, chatContext.AncestorIDs...), filePaths) } } @@ -565,6 +567,8 @@ func (api *API) prepareFileEdit(path string, edits []workspacesdk.FileEdit) (int // On failure the temp file is cleaned up and the original is // untouched. func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode, r io.Reader) (int, error) { + logger := api.logger.With(agentchat.Fields(ctx)...) + dir := filepath.Dir(path) tmpName := filepath.Join(dir, fmt.Sprintf(".%s.tmp.%s", filepath.Base(path), uuid.New().String()[:8])) @@ -579,7 +583,7 @@ func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode, cleanup := func() { if err := api.filesystem.Remove(tmpName); err != nil { - api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(err)) + logger.Warn(ctx, "unable to clean up temp file", slog.Error(err)) } } @@ -601,7 +605,7 @@ func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode, // no window where the target has wrong permissions. if mode != nil { if err := api.filesystem.Chmod(tmpName, *mode); err != nil { - api.logger.Warn(ctx, "unable to set file permissions", + logger.Warn(ctx, "unable to set file permissions", slog.F("path", path), slog.Error(err), ) diff --git a/agent/agentfiles/files_test.go b/agent/agentfiles/files_test.go index cc0df0c96a..19f3187d88 100644 --- a/agent/agentfiles/files_test.go +++ b/agent/agentfiles/files_test.go @@ -24,6 +24,7 @@ import ( "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/agent/agentchat" "github.com/coder/coder/v2/agent/agentfiles" "github.com/coder/coder/v2/agent/agentgit" "github.com/coder/coder/v2/codersdk" @@ -1157,7 +1158,7 @@ func TestHandleWriteFile_ChatHeaders_UpdatesPathStore(t *testing.T) { rr := httptest.NewRecorder() r := chi.NewRouter() r.Post("/write-file", api.HandleWriteFile) - r.ServeHTTP(rr, req) + agentchat.Middleware(r).ServeHTTP(rr, req) require.Equal(t, http.StatusOK, rr.Code) @@ -1185,7 +1186,7 @@ func TestHandleWriteFile_NoChatHeaders_NoPathStoreUpdate(t *testing.T) { rr := httptest.NewRecorder() r := chi.NewRouter() r.Post("/write-file", api.HandleWriteFile) - r.ServeHTTP(rr, req) + agentchat.Middleware(r).ServeHTTP(rr, req) require.Equal(t, http.StatusOK, rr.Code) @@ -1211,7 +1212,7 @@ func TestHandleWriteFile_Failure_NoPathStoreUpdate(t *testing.T) { rr := httptest.NewRecorder() r := chi.NewRouter() r.Post("/write-file", api.HandleWriteFile) - r.ServeHTTP(rr, req) + agentchat.Middleware(r).ServeHTTP(rr, req) require.Equal(t, http.StatusBadRequest, rr.Code) @@ -1252,7 +1253,7 @@ func TestHandleEditFiles_ChatHeaders_UpdatesPathStore(t *testing.T) { rr := httptest.NewRecorder() r := chi.NewRouter() r.Post("/edit-files", api.HandleEditFiles) - r.ServeHTTP(rr, req) + agentchat.Middleware(r).ServeHTTP(rr, req) require.Equal(t, http.StatusOK, rr.Code) @@ -1289,7 +1290,7 @@ func TestHandleEditFiles_Failure_NoPathStoreUpdate(t *testing.T) { rr := httptest.NewRecorder() r := chi.NewRouter() r.Post("/edit-files", api.HandleEditFiles) - r.ServeHTTP(rr, req) + agentchat.Middleware(r).ServeHTTP(rr, req) require.NotEqual(t, http.StatusOK, rr.Code) diff --git a/agent/agentgit/api.go b/agent/agentgit/api.go index 5e31e6c0e8..ea9ac11132 100644 --- a/agent/agentgit/api.go +++ b/agent/agentgit/api.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/agent/agentchat" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/wsjson" @@ -40,6 +41,25 @@ func (a *API) Routes() http.Handler { func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + var watchChatID uuid.UUID + var hasWatchChatID bool + if chatIDStr := r.URL.Query().Get("chat_id"); chatIDStr != "" { + if parsedChatID, parseErr := uuid.Parse(chatIDStr); parseErr == nil { + watchChatID = parsedChatID + hasWatchChatID = true + + // Reuse header-derived ancestors only when the query chat + // matches the header chat. Otherwise the ancestors belong + // to a different chat and would be misleading in logs. + var ancestors []uuid.UUID + if chatContext, ok := agentchat.FromContext(ctx); ok && chatContext.ID == watchChatID { + ancestors = chatContext.AncestorIDs + } + ctx = agentchat.WithContext(ctx, watchChatID, ancestors) + } + } + logger := a.logger.With(agentchat.Fields(ctx)...) + conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ CompressionMode: websocket.CompressionNoContextTakeover, }) @@ -58,14 +78,14 @@ func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) { stream := wsjson.NewStream[ codersdk.WorkspaceAgentGitClientMessage, codersdk.WorkspaceAgentGitServerMessage, - ](conn, websocket.MessageText, websocket.MessageText, a.logger) + ](conn, websocket.MessageText, websocket.MessageText, logger) ctx, cancel := context.WithCancel(ctx) defer cancel() - go httpapi.HeartbeatClose(ctx, a.logger, cancel, conn) + go httpapi.HeartbeatClose(ctx, logger, cancel, conn) - handler := NewHandler(a.logger, a.opts...) + handler := NewHandler(logger, a.opts...) // Scan returns nil only when no roots are subscribed; once any // root lands it returns either a delta or a heartbeat message. @@ -75,46 +95,42 @@ func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) { return } if err := stream.Send(*msg); err != nil { - a.logger.Debug(ctx, "failed to send changes", slog.Error(err)) + logger.Debug(ctx, "failed to send changes", slog.Error(err)) cancel() } } // If a chat_id query parameter is provided and the PathStore is // available, subscribe to path updates for this chat. - chatIDStr := r.URL.Query().Get("chat_id") - if chatIDStr != "" && a.pathStore != nil { - chatID, parseErr := uuid.Parse(chatIDStr) - if parseErr == nil { - // Subscribe to future path updates BEFORE reading - // existing paths. This ordering guarantees no - // notification from AddPaths is lost: any call that - // lands before Subscribe is picked up by GetPaths - // below, and any call after Subscribe delivers a - // notification on the channel. - notifyCh, unsubscribe := a.pathStore.Subscribe(chatID) - defer unsubscribe() + if hasWatchChatID && a.pathStore != nil { + // Subscribe to future path updates BEFORE reading + // existing paths. This ordering guarantees no + // notification from AddPaths is lost: any call that + // lands before Subscribe is picked up by GetPaths + // below, and any call after Subscribe delivers a + // notification on the channel. + notifyCh, unsubscribe := a.pathStore.Subscribe(watchChatID) + defer unsubscribe() - // Load any paths that are already tracked for this chat. - existingPaths := a.pathStore.GetPaths(chatID) - if len(existingPaths) > 0 { - handler.Subscribe(existingPaths) - handler.RequestScan() - } - - go func() { - for { - select { - case <-ctx.Done(): - return - case <-notifyCh: - paths := a.pathStore.GetPaths(chatID) - handler.Subscribe(paths) - handler.RequestScan() - } - } - }() + // Load any paths that are already tracked for this chat. + existingPaths := a.pathStore.GetPaths(watchChatID) + if len(existingPaths) > 0 { + handler.Subscribe(existingPaths) + handler.RequestScan() } + + go func() { + for { + select { + case <-ctx.Done(): + return + case <-notifyCh: + paths := a.pathStore.GetPaths(watchChatID) + handler.Subscribe(paths) + handler.RequestScan() + } + } + }() } // Start the main run loop in a goroutine. diff --git a/agent/agentproc/api.go b/agent/agentproc/api.go index c2b8d072c1..4dcc07b541 100644 --- a/agent/agentproc/api.go +++ b/agent/agentproc/api.go @@ -13,6 +13,7 @@ import ( "github.com/google/uuid" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/agent/agentchat" "github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/agent/agentgit" "github.com/coder/coder/v2/coderd/httpapi" @@ -80,8 +81,8 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) { } var chatID string - if id, _, ok := agentgit.ExtractChatContext(r); ok { - chatID = id.String() + if chatContext, ok := agentchat.FromContext(ctx); ok { + chatID = chatContext.ID.String() } proc, err := api.manager.start(req, chatID) @@ -97,8 +98,8 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) { // file changes made by the command are visible in the scan. // If a workdir is provided, track it as a path as well. if api.pathStore != nil { - if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok { - allIDs := append([]uuid.UUID{chatID}, ancestorIDs...) + if chatContext, ok := agentchat.FromContext(ctx); ok { + allIDs := append([]uuid.UUID{chatContext.ID}, chatContext.AncestorIDs...) go func() { <-proc.done if req.WorkDir != "" { @@ -121,8 +122,8 @@ func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() var chatID string - if id, _, ok := agentgit.ExtractChatContext(r); ok { - chatID = id.String() + if chatContext, ok := agentchat.FromContext(ctx); ok { + chatID = chatContext.ID.String() } infos := api.manager.list(chatID) @@ -150,6 +151,7 @@ func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) { // handleProcessOutput returns the output of a process. func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + logger := api.logger.With(agentchat.Fields(ctx)...) id := chi.URLParam(r, "id") proc, ok := api.manager.get(id) @@ -163,8 +165,8 @@ func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) { // Enforce chat ID isolation. If the request carries // a chat context, only allow access to processes // belonging to that chat. - if chatID, _, ok := agentgit.ExtractChatContext(r); ok { - if proc.chatID != "" && proc.chatID != chatID.String() { + if chatContext, ok := agentchat.FromContext(ctx); ok { + if proc.chatID != "" && proc.chatID != chatContext.ID.String() { httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ Message: fmt.Sprintf("Process %q not found.", id), }) @@ -184,7 +186,7 @@ func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) { // Add headroom beyond the wait timeout so there's time to // write the response after the blocking wait completes. if err := rc.SetWriteDeadline(time.Now().Add(maxWaitDuration + 30*time.Second)); err != nil { - api.logger.Error(ctx, "extend write deadline for blocking process output", + logger.Error(ctx, "extend write deadline for blocking process output", slog.Error(err), ) } @@ -216,9 +218,9 @@ func (api *API) handleSignalProcess(rw http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") // Enforce chat ID isolation. - if chatID, _, ok := agentgit.ExtractChatContext(r); ok { + if chatContext, ok := agentchat.FromContext(ctx); ok { proc, procOK := api.manager.get(id) - if procOK && proc.chatID != "" && proc.chatID != chatID.String() { + if procOK && proc.chatID != "" && proc.chatID != chatContext.ID.String() { httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ Message: fmt.Sprintf("Process %q not found.", id), }) diff --git a/agent/agentproc/api_test.go b/agent/agentproc/api_test.go index eddbe2d6f9..704d968899 100644 --- a/agent/agentproc/api_test.go +++ b/agent/agentproc/api_test.go @@ -20,9 +20,12 @@ import ( "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/agent/agentchat" "github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/agent/agentgit" "github.com/coder/coder/v2/agent/agentproc" + "github.com/coder/coder/v2/coderd/httpmw/loggermw" + "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/testutil" @@ -137,7 +140,35 @@ func newTestAPIWithOptions(t *testing.T, updateEnv func([]string) ([]string, err t.Cleanup(func() { _ = api.Close() }) - return api.Routes() + return agentchat.Middleware(api.Routes()) +} + +func TestAccessLogIncludesChatID(t *testing.T) { + t.Parallel() + + sink := testutil.NewFakeSink(t) + logger := sink.Logger() + api := agentproc.NewAPI(logger, agentexec.DefaultExecer, nil, nil, nil) + t.Cleanup(func() { + _ = api.Close() + }) + handler := tracing.StatusWriterMiddleware(loggermw.Logger(logger)( + agentchat.Middleware(api.Routes()), + )) + + chatID := uuid.New().String() + w := getListWithChatHeader(t, handler, chatID) + require.Equal(t, http.StatusOK, w.Code) + + entries := sink.Entries(func(entry slog.SinkEntry) bool { + return entry.Message == http.MethodGet + }) + require.Len(t, entries, 1) + fields := make(map[string]any, len(entries[0].Fields)) + for _, field := range entries[0].Fields { + fields[field.Name] = field.Value + } + require.Equal(t, chatID, fields["chat_id"]) } // waitForExit polls the output endpoint until the process is @@ -1058,7 +1089,7 @@ func TestHandleStartProcess_ChatHeaders_EmptyWorkDir_StillNotifies(t *testing.T) }, pathStore, nil) defer api.Close() - routes := api.Routes() + routes := agentchat.Middleware(api.Routes()) body, err := json.Marshal(workspacesdk.StartProcessRequest{ Command: "echo hello", diff --git a/agent/agentproc/process.go b/agent/agentproc/process.go index c172195b8b..d4cecdff9b 100644 --- a/agent/agentproc/process.go +++ b/agent/agentproc/process.go @@ -38,6 +38,7 @@ type process struct { cmd *exec.Cmd cancel context.CancelFunc buf *HeadTailBuffer + logger slog.Logger running bool exitCode *int startedAt int64 @@ -105,6 +106,10 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p m.mu.Unlock() id := uuid.New().String() + logger := m.logger + if chatID != "" { + logger = logger.With(slog.F("chat_id", chatID)) + } // Use a cancellable context so Close() can terminate // all processes. context.Background() is the parent so @@ -132,7 +137,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p if m.updateEnv != nil { updated, err := m.updateEnv(baseEnv) if err != nil { - m.logger.Warn( + logger.Warn( context.Background(), "failed to update command environment, falling back to os env", slog.Error(err), @@ -169,6 +174,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p cmd: cmd, cancel: cancel, buf: buf, + logger: logger, running: true, startedAt: now, done: make(chan struct{}), @@ -202,7 +208,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p } else { // Unknown error; use -1 as a sentinel. code = -1 - m.logger.Warn( + proc.logger.Warn( context.Background(), "process wait returned non-exit error", slog.F("id", id), diff --git a/agent/api.go b/agent/api.go index 0258d410cd..0346805528 100644 --- a/agent/api.go +++ b/agent/api.go @@ -6,6 +6,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/google/uuid" + "github.com/coder/coder/v2/agent/agentchat" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw/loggermw" "github.com/coder/coder/v2/coderd/tracing" @@ -20,6 +21,7 @@ func (a *agent) apiHandler() http.Handler { httpmw.Recover(a.logger), tracing.StatusWriterMiddleware, loggermw.Logger(a.logger), + agentchat.Middleware, ) r.Get("/", func(rw http.ResponseWriter, r *http.Request) { httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{ diff --git a/agent/x/agentdesktop/api.go b/agent/x/agentdesktop/api.go index fc7686b072..73890c55ed 100644 --- a/agent/x/agentdesktop/api.go +++ b/agent/x/agentdesktop/api.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/agent/agentchat" "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/codersdk" @@ -85,6 +86,7 @@ func (a *API) Routes() http.Handler { func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + logger := a.logger.With(agentchat.Fields(ctx)...) // Start the desktop session (idempotent). _, err := a.desktop.Start(ctx) @@ -112,7 +114,7 @@ func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) { CompressionMode: websocket.CompressionDisabled, }) if err != nil { - a.logger.Error(ctx, "failed to accept websocket", slog.Error(err)) + logger.Error(ctx, "failed to accept websocket", slog.Error(err)) return } @@ -128,6 +130,7 @@ func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) { func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + logger := a.logger.With(agentchat.Fields(ctx)...) handlerStart := a.clock.Now() // Update last desktop action timestamp for idle recording monitor. @@ -136,7 +139,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) { // Ensure the desktop is running and grab native dimensions. cfg, err := a.desktop.Start(ctx) if err != nil { - a.logger.Warn(ctx, "handleAction: desktop.Start failed", + logger.Warn(ctx, "handleAction: desktop.Start failed", slog.Error(err), slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()), ) @@ -156,7 +159,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) { return } - a.logger.Info(ctx, "handleAction: started", + logger.Info(ctx, "handleAction: started", slog.F("action", action.Action), slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()), ) @@ -272,7 +275,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) { x, y = scaleXY(x, y) stepStart := a.clock.Now() if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil { - a.logger.Warn(ctx, "handleAction: Click failed", + logger.Warn(ctx, "handleAction: Click failed", slog.F("action", "left_click"), slog.F("step", "click"), slog.F("step_ms", time.Since(stepStart).Milliseconds()), @@ -285,7 +288,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) { }) return } - a.logger.Debug(ctx, "handleAction: Click completed", + logger.Debug(ctx, "handleAction: Click completed", slog.F("action", "left_click"), slog.F("step_ms", time.Since(stepStart).Milliseconds()), slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()), @@ -473,7 +476,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) { case <-ctx.Done(): // Context canceled; release the key immediately. if err := a.desktop.KeyUp(ctx, *action.Text); err != nil { - a.logger.Warn(ctx, "handleAction: KeyUp after context cancel", slog.Error(err)) + logger.Warn(ctx, "handleAction: KeyUp after context cancel", slog.Error(err)) } return case <-timer.C: @@ -513,14 +516,14 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) { elapsedMs := a.clock.Since(handlerStart).Milliseconds() if ctx.Err() != nil { - a.logger.Error(ctx, "handleAction: context canceled before writing response", + logger.Error(ctx, "handleAction: context canceled before writing response", slog.F("action", action.Action), slog.F("elapsed_ms", elapsedMs), slog.Error(ctx.Err()), ) return } - a.logger.Info(ctx, "handleAction: writing response", + logger.Info(ctx, "handleAction: writing response", slog.F("action", action.Action), slog.F("elapsed_ms", elapsedMs), ) @@ -609,6 +612,7 @@ func (a *API) handleRecordingStart(rw http.ResponseWriter, r *http.Request) { func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + logger := a.logger.With(agentchat.Fields(ctx)...) recordingID, ok := a.decodeRecordingRequest(rw, r) if !ok { @@ -661,7 +665,7 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) { }() if artifact.Size > workspacesdk.MaxRecordingSize { - a.logger.Warn(ctx, "recording file exceeds maximum size", + logger.Warn(ctx, "recording file exceeds maximum size", slog.F("recording_id", recordingID), slog.F("size", artifact.Size), slog.F("max_size", workspacesdk.MaxRecordingSize), @@ -677,7 +681,7 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) { // rejecting it here avoids streaming a large thumbnail over // the wire for nothing. if artifact.ThumbnailReader != nil && artifact.ThumbnailSize > workspacesdk.MaxThumbnailSize { - a.logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting", + logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting", slog.F("recording_id", recordingID), slog.F("size", artifact.ThumbnailSize), slog.F("max_size", workspacesdk.MaxThumbnailSize), @@ -701,13 +705,13 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) { "Content-Type": {"video/mp4"}, }) if err != nil { - a.logger.Warn(ctx, "failed to create video multipart part", + logger.Warn(ctx, "failed to create video multipart part", slog.F("recording_id", recordingID), slog.Error(err)) return } if _, err := io.Copy(videoPart, artifact.Reader); err != nil { - a.logger.Warn(ctx, "failed to write video multipart part", + logger.Warn(ctx, "failed to write video multipart part", slog.F("recording_id", recordingID), slog.Error(err)) return @@ -719,7 +723,7 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) { "Content-Type": {"image/jpeg"}, }) if err != nil { - a.logger.Warn(ctx, "failed to create thumbnail multipart part", + logger.Warn(ctx, "failed to create thumbnail multipart part", slog.F("recording_id", recordingID), slog.Error(err)) return diff --git a/agent/x/agentmcp/api.go b/agent/x/agentmcp/api.go index 9b632f8b9b..d291f7a03d 100644 --- a/agent/x/agentmcp/api.go +++ b/agent/x/agentmcp/api.go @@ -8,6 +8,7 @@ import ( "github.com/go-chi/chi/v5" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/agent/agentchat" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" @@ -52,6 +53,7 @@ func (api *API) Routes() http.Handler { // independent of config changes. func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() + logger := api.logger.With(agentchat.Fields(ctx)...) // Check config freshness and reload if changed. var reloaded bool @@ -61,11 +63,11 @@ func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) { // Categorize the error for operator debugging. switch { case errors.Is(err, context.Canceled): - api.logger.Warn(ctx, "mcp reload canceled by caller", slog.Error(err)) + logger.Warn(ctx, "mcp reload canceled by caller", slog.Error(err)) case errors.Is(err, context.DeadlineExceeded): - api.logger.Warn(ctx, "mcp reload timed out", slog.Error(err)) + logger.Warn(ctx, "mcp reload timed out", slog.Error(err)) default: - api.logger.Warn(ctx, "mcp reload failed", slog.Error(err)) + logger.Warn(ctx, "mcp reload failed", slog.Error(err)) } // Fall through to return whatever tools we have. } else { @@ -78,7 +80,7 @@ func (api *API) handleListTools(rw http.ResponseWriter, r *http.Request) { // refreshes tools as part of the reload. if r.URL.Query().Get("refresh") == "true" && !reloaded { if err := api.manager.RefreshTools(ctx); err != nil { - api.logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err)) + logger.Warn(ctx, "failed to refresh MCP tools", slog.Error(err)) } } diff --git a/agent/x/agentmcp/manager.go b/agent/x/agentmcp/manager.go index 94fc1bf0e3..d1ecab31b6 100644 --- a/agent/x/agentmcp/manager.go +++ b/agent/x/agentmcp/manager.go @@ -22,6 +22,7 @@ import ( tailscalesingleflight "tailscale.com/util/singleflight" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/agent/agentchat" "github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/agent/usershell" "github.com/coder/coder/v2/buildinfo" @@ -248,7 +249,8 @@ func (m *Manager) doReload(ctx context.Context, mcpConfigFiles []string) error { // Refresh tools outside the lock to avoid blocking // concurrent reads during network I/O. if err := m.RefreshTools(ctx); err != nil { - m.logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err)) + logger := m.logger.With(agentchat.Fields(ctx)...) + logger.Warn(ctx, "failed to refresh MCP tools after connect", slog.Error(err)) } return nil } @@ -257,6 +259,8 @@ func (m *Manager) doReload(ctx context.Context, mcpConfigFiles []string) error { // list of server configs. Missing files are silently skipped; // parse errors are logged and skipped. func (m *Manager) parseAndDedup(ctx context.Context, mcpConfigFiles []string) ([]ServerConfig, map[string]fileSnapshot) { + logger := m.logger.With(agentchat.Fields(ctx)...) + // Stat before reading so the snapshot is conservatively old. // If a file changes between stat and read, the snapshot // records the old mtime, SnapshotChanged detects a mismatch @@ -272,7 +276,7 @@ func (m *Manager) parseAndDedup(ctx context.Context, mcpConfigFiles []string) ([ if errors.Is(err, fs.ErrNotExist) { continue } - m.logger.Warn(ctx, "failed to parse MCP config", + logger.Warn(ctx, "failed to parse MCP config", slog.F("path", configPath), slog.Error(err), ) @@ -334,6 +338,8 @@ func (m *Manager) classifyServers(wanted map[string]ServerConfig) (*serverDiff, // connectAll runs connectServer in parallel for the given configs. // Failed connects are logged and skipped. func (m *Manager) connectAll(ctx context.Context, toConnect []ServerConfig) []connectedServer { + logger := m.logger.With(agentchat.Fields(ctx)...) + var ( mu sync.Mutex connected []connectedServer @@ -343,7 +349,7 @@ func (m *Manager) connectAll(ctx context.Context, toConnect []ServerConfig) []co eg.Go(func() error { c, err := m.connectServer(ctx, cfg) if err != nil { - m.logger.Warn(ctx, "skipping MCP server", + logger.Warn(ctx, "skipping MCP server", slog.F("server", cfg.Name), slog.F("transport", cfg.Transport), slog.Error(err), @@ -481,6 +487,8 @@ func (m *Manager) CallTool(ctx context.Context, req workspacesdk.CallMCPToolRequ // existing cached tools for servers that failed, so a single // dead server doesn't block updates from healthy ones. func (m *Manager) RefreshTools(ctx context.Context) error { + logger := m.logger.With(agentchat.Fields(ctx)...) + // Snapshot servers under read lock. m.mu.RLock() servers := make(map[string]*serverEntry, len(m.servers)) @@ -508,7 +516,7 @@ func (m *Manager) RefreshTools(ctx context.Context) error { result, err := entry.client.ListTools(listCtx, mcp.ListToolsRequest{}) cancel() if err != nil { - m.logger.Warn(ctx, "failed to list tools from MCP server", + logger.Warn(ctx, "failed to list tools from MCP server", slog.F("server", name), slog.Error(err), ) @@ -670,12 +678,14 @@ func (m *Manager) createTransport(ctx context.Context, cfg ServerConfig) (transp // updateEnv callback, then merges explicit overrides from the // server config on top. func (m *Manager) buildEnv(ctx context.Context, explicit map[string]string) []string { + logger := m.logger.With(agentchat.Fields(ctx)...) + env := usershell.SystemEnvInfo{}.Environ() if m.updateEnv != nil { var err error env, err = m.updateEnv(env) if err != nil { - m.logger.Warn(ctx, "failed to enrich MCP server environment", + logger.Warn(ctx, "failed to enrich MCP server environment", slog.Error(err), ) env = usershell.SystemEnvInfo{}.Environ() diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 92837fd28a..e6bc005b8a 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -905,6 +905,12 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces }) c.mu.Unlock() + c.server.logger.Debug(ctx, "set chat headers on agent conn", + slog.F("chat_id", chatSnapshot.ID), + slog.F("ancestor_chat_ids", ancestorIDs), + slog.F("workspace_id", chatSnapshot.WorkspaceID.UUID), + slog.F("agent_id", dialResult.AgentID), + ) return agentConn, nil } currentConn = c.conn