fix: scope git askpass diff status updates to initiating chat (#22534)

## Problem

When the git askpass flow triggered diff status refreshes, it updated
**every chat** connected to the workspace. This was wasteful and could
cause confusing status updates on unrelated chats.

## Solution

Thread the chat ID through the entire git askpass flow so only the chat
that initiated the git operation gets updated:

1. **`coderd/chatd/chattool/execute.go`** — Sets `CODER_CHAT_ID` env var
on spawned processes (alongside the existing `CODER_CHAT_AGENT`)
2. **`cli/gitaskpass.go`** — Reads `CODER_CHAT_ID` from the environment
and sends it as a `chat_id` query parameter in the `ExternalAuthRequest`
3. **`codersdk/agentsdk/agentsdk.go`** — Adds `ChatID` field to
`ExternalAuthRequest` and encodes it as a query param
4. **`coderd/workspaceagents.go`** — Parses `chat_id` query param and
passes it through to `storeChatGitRef` and
`triggerWorkspaceChatDiffStatusRefresh`
5. **`coderd/chats.go`** — `storeChatGitRef` and
`refreshWorkspaceChatDiffStatuses` now scope updates to just the
initiating chat when a chat ID is provided, falling back to
all-workspace-chats behavior for backwards compatibility (non-chat git
operations)
This commit is contained in:
Kyle Carberry
2026-03-02 22:52:39 -05:00
committed by GitHub
parent 2bdf80d452
commit 56f95a3e6d
7 changed files with 96 additions and 36 deletions
+1
View File
@@ -78,6 +78,7 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
Match: host,
GitBranch: gitBranch,
GitRemoteOrigin: gitRemoteOrigin,
ChatID: inv.Environ.Get("CODER_CHAT_ID"),
})
if err != nil {
var apiError *codersdk.Error
+1
View File
@@ -2183,6 +2183,7 @@ func (p *Server) runChat(
}),
chattool.Execute(chattool.ExecuteOptions{
GetWorkspaceConn: getWorkspaceConn,
ChatID: chat.ID.String(),
}),
chattool.ProcessOutput(chattool.ProcessToolOptions{
GetWorkspaceConn: getWorkspaceConn,
+6 -1
View File
@@ -65,6 +65,7 @@ type ExecuteResult struct {
type ExecuteOptions struct {
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
DefaultTimeout time.Duration
ChatID string
}
// ProcessToolOptions configures a process management tool
@@ -96,7 +97,7 @@ func Execute(options ExecuteOptions) fantasy.AgentTool {
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return executeTool(ctx, conn, args, options.DefaultTimeout), nil
return executeTool(ctx, conn, args, options.DefaultTimeout, options.ChatID), nil
},
)
}
@@ -106,6 +107,7 @@ func executeTool(
conn workspacesdk.AgentConn,
args ExecuteArgs,
optTimeout time.Duration,
chatID string,
) fantasy.ToolResponse {
if args.Command == "" {
return fantasy.NewTextErrorResponse("command is required")
@@ -114,6 +116,9 @@ func executeTool(
// Build the environment map for the process request.
env := make(map[string]string, len(nonInteractiveEnvVars)+1)
env["CODER_CHAT_AGENT"] = "true"
if chatID != "" {
env["CODER_CHAT_ID"] = chatID
}
for k, v := range nonInteractiveEnvVars {
env[k] = v
}
+64 -30
View File
@@ -1003,12 +1003,12 @@ func shouldRefreshChatDiffStatus(status database.ChatDiffStatus, now time.Time,
return chatDiffStatusIsStale(status, now)
}
func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspace, gitRef chatGitRef) {
func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspace, chatID uuid.NullUUID, gitRef chatGitRef) {
if workspace.ID == uuid.Nil || workspace.OwnerID == uuid.Nil {
return
}
go func(workspaceID, workspaceOwnerID uuid.UUID, gitRef chatGitRef) {
go func(workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) {
ctx := api.ctx
if ctx == nil {
ctx = context.Background()
@@ -1019,7 +1019,7 @@ func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspa
// Always store the git ref so the data is persisted even
// before a PR exists. The frontend can show branch info
// and the refresh loop can resolve a PR later.
api.storeChatGitRef(ctx, workspaceID, workspaceOwnerID, gitRef)
api.storeChatGitRef(ctx, workspaceID, workspaceOwnerID, chatID, gitRef)
for _, delay := range chatDiffRefreshBackoffSchedule {
t := api.Clock.NewTimer(delay, "chat_diff_refresh")
@@ -1033,26 +1033,44 @@ func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspa
// Refresh and publish status on every iteration.
// Stop the loop once a PR is discovered — there's
// nothing more to wait for after that.
if api.refreshWorkspaceChatDiffStatuses(ctx, workspaceID, workspaceOwnerID) {
if api.refreshWorkspaceChatDiffStatuses(ctx, workspaceID, workspaceOwnerID, chatID) {
return
}
}
}(workspace.ID, workspace.OwnerID, gitRef)
}(workspace.ID, workspace.OwnerID, chatID, gitRef)
}
// storeChatGitRef persists the git branch and remote origin reported
// by the workspace agent on all chats associated with the workspace.
func (api *API) storeChatGitRef(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, gitRef chatGitRef) {
chats, err := api.Database.GetChatsByOwnerID(ctx, workspaceOwnerID)
if err != nil {
api.Logger.Warn(ctx, "failed to list chats for git ref storage",
slog.F("workspace_id", workspaceID),
slog.Error(err),
)
return
// by the workspace agent on the chat that initiated the git operation.
// When chatID is set, only that specific chat is updated; otherwise all
// chats associated with the workspace are updated (legacy fallback).
func (api *API) storeChatGitRef(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) {
var chatsToUpdate []database.Chat
if chatID.Valid {
chat, err := api.Database.GetChatByID(ctx, chatID.UUID)
if err != nil {
api.Logger.Warn(ctx, "failed to get chat for git ref storage",
slog.F("chat_id", chatID.UUID),
slog.F("workspace_id", workspaceID),
slog.Error(err),
)
return
}
chatsToUpdate = []database.Chat{chat}
} else {
chats, err := api.Database.GetChatsByOwnerID(ctx, workspaceOwnerID)
if err != nil {
api.Logger.Warn(ctx, "failed to list chats for git ref storage",
slog.F("workspace_id", workspaceID),
slog.Error(err),
)
return
}
chatsToUpdate = filterChatsByWorkspaceID(chats, workspaceID)
}
for _, chat := range filterChatsByWorkspaceID(chats, workspaceID) {
for _, chat := range chatsToUpdate {
_, err := api.Database.UpsertChatDiffStatusReference(ctx, database.UpsertChatDiffStatusReferenceParams{
ChatID: chat.ID,
GitBranch: gitRef.Branch,
@@ -1072,22 +1090,38 @@ func (api *API) storeChatGitRef(ctx context.Context, workspaceID, workspaceOwner
}
}
// refreshWorkspaceChatDiffStatuses refreshes the diff status for all
// chats associated with the given workspace. It returns true when
// every chat has a PR URL resolved, signaling that the caller can
// stop polling.
func (api *API) refreshWorkspaceChatDiffStatuses(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID) bool {
chats, err := api.Database.GetChatsByOwnerID(ctx, workspaceOwnerID)
if err != nil {
api.Logger.Warn(ctx, "failed to list workspace owner chats for diff refresh",
slog.F("workspace_id", workspaceID),
slog.F("workspace_owner_id", workspaceOwnerID),
slog.Error(err),
)
return false
}
// refreshWorkspaceChatDiffStatuses refreshes the diff status for chats
// associated with the given workspace. When chatID is set, only that
// specific chat is refreshed; otherwise all chats for the workspace
// are refreshed (legacy fallback). It returns true when every
// refreshed chat has a PR URL resolved, signaling that the caller
// can stop polling.
func (api *API) refreshWorkspaceChatDiffStatuses(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID) bool {
var filtered []database.Chat
filtered := filterChatsByWorkspaceID(chats, workspaceID)
if chatID.Valid {
chat, err := api.Database.GetChatByID(ctx, chatID.UUID)
if err != nil {
api.Logger.Warn(ctx, "failed to get chat for diff refresh",
slog.F("chat_id", chatID.UUID),
slog.F("workspace_id", workspaceID),
slog.Error(err),
)
return false
}
filtered = []database.Chat{chat}
} else {
chats, err := api.Database.GetChatsByOwnerID(ctx, workspaceOwnerID)
if err != nil {
api.Logger.Warn(ctx, "failed to list workspace owner chats for diff refresh",
slog.F("workspace_id", workspaceID),
slog.F("workspace_owner_id", workspaceOwnerID),
slog.Error(err),
)
return false
}
filtered = filterChatsByWorkspaceID(chats, workspaceID)
}
if len(filtered) == 0 {
return false
}
+17 -5
View File
@@ -1835,6 +1835,18 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
Branch: strings.TrimSpace(query.Get("git_branch")),
RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")),
}
var chatID uuid.NullUUID
if rawChatID := query.Get("chat_id"); rawChatID != "" {
parsed, err := uuid.Parse(rawChatID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid chat_id.",
Detail: err.Error(),
})
return
}
chatID = uuid.NullUUID{UUID: parsed, Valid: true}
}
// Either match or configID must be provided!
match := query.Get("match")
if match == "" {
@@ -1932,7 +1944,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
// context is retained even if the flow requires an out-of-band login.
if gitRef.Branch != "" || gitRef.RemoteOrigin != "" {
//nolint:gocritic // System context required to persist chat git refs.
api.storeChatGitRef(dbauthz.AsSystemRestricted(ctx), workspace.ID, workspace.OwnerID, gitRef)
api.storeChatGitRef(dbauthz.AsSystemRestricted(ctx), workspace.ID, workspace.OwnerID, chatID, gitRef)
}
var previousToken *database.ExternalAuthLink
@@ -1948,7 +1960,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
return
}
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, gitRef)
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, chatID, gitRef)
}
// This is the URL that will redirect the user with a state token.
@@ -2006,11 +2018,11 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
})
return
}
api.triggerWorkspaceChatDiffStatusRefresh(workspace, gitRef)
api.triggerWorkspaceChatDiffStatusRefresh(workspace, chatID, gitRef)
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, gitRef chatGitRef) {
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, chatID uuid.NullUUID, gitRef chatGitRef) {
// Since we're ticking frequently and this sign-in operation is rare,
// we are OK with polling to avoid the complexity of pubsub.
ticker, done := api.NewTicker(time.Second)
@@ -2080,7 +2092,7 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
})
return
}
api.triggerWorkspaceChatDiffStatusRefresh(workspace, gitRef)
api.triggerWorkspaceChatDiffStatusRefresh(workspace, chatID, gitRef)
httpapi.Write(ctx, rw, http.StatusOK, resp)
return
}
+5
View File
@@ -646,6 +646,8 @@ type ExternalAuthRequest struct {
// Sent by the agent so the control plane can resolve diffs
// without SSHing into the workspace.
GitRemoteOrigin string
// ChatID identifies which chat initiated the git operation.
ChatID string
// Listen indicates that the request should be long-lived and listen for
// a new token to be requested.
Listen bool
@@ -667,6 +669,9 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext
if req.GitRemoteOrigin != "" {
q.Set("git_remote_origin", req.GitRemoteOrigin)
}
if req.ChatID != "" {
q.Set("chat_id", req.ChatID)
}
reqURL := "/api/v2/workspaceagents/me/external-auth?" + q.Encode()
res, err := c.SDK.Request(ctx, http.MethodGet, reqURL, nil)
if err != nil {
+2
View File
@@ -165,6 +165,7 @@ func TestExternalAuthRequestQuery(t *testing.T) {
require.Equal(t, "true", r.URL.Query().Get("listen"))
require.Equal(t, "main", r.URL.Query().Get("git_branch"))
require.Equal(t, "https://github.com/coder/coder.git", r.URL.Query().Get("git_remote_origin"))
require.Equal(t, "test-chat-id", r.URL.Query().Get("chat_id"))
require.False(t, r.URL.Query().Has("workdir"))
_, _ = w.Write([]byte(`{"type":"github","access_token":"token"}`))
}))
@@ -179,6 +180,7 @@ func TestExternalAuthRequestQuery(t *testing.T) {
Listen: true,
GitBranch: "main",
GitRemoteOrigin: "https://github.com/coder/coder.git",
ChatID: "test-chat-id",
})
require.NoError(t, err)
})