mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: scope git askpass diff status updates to initiating chat (#22534)
## Problem When the git askpass flow triggered diff status refreshes, it updated **every chat** connected to the workspace. This was wasteful and could cause confusing status updates on unrelated chats. ## Solution Thread the chat ID through the entire git askpass flow so only the chat that initiated the git operation gets updated: 1. **`coderd/chatd/chattool/execute.go`** — Sets `CODER_CHAT_ID` env var on spawned processes (alongside the existing `CODER_CHAT_AGENT`) 2. **`cli/gitaskpass.go`** — Reads `CODER_CHAT_ID` from the environment and sends it as a `chat_id` query parameter in the `ExternalAuthRequest` 3. **`codersdk/agentsdk/agentsdk.go`** — Adds `ChatID` field to `ExternalAuthRequest` and encodes it as a query param 4. **`coderd/workspaceagents.go`** — Parses `chat_id` query param and passes it through to `storeChatGitRef` and `triggerWorkspaceChatDiffStatusRefresh` 5. **`coderd/chats.go`** — `storeChatGitRef` and `refreshWorkspaceChatDiffStatuses` now scope updates to just the initiating chat when a chat ID is provided, falling back to all-workspace-chats behavior for backwards compatibility (non-chat git operations)
This commit is contained in:
@@ -78,6 +78,7 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
Match: host,
|
||||
GitBranch: gitBranch,
|
||||
GitRemoteOrigin: gitRemoteOrigin,
|
||||
ChatID: inv.Environ.Get("CODER_CHAT_ID"),
|
||||
})
|
||||
if err != nil {
|
||||
var apiError *codersdk.Error
|
||||
|
||||
@@ -2183,6 +2183,7 @@ func (p *Server) runChat(
|
||||
}),
|
||||
chattool.Execute(chattool.ExecuteOptions{
|
||||
GetWorkspaceConn: getWorkspaceConn,
|
||||
ChatID: chat.ID.String(),
|
||||
}),
|
||||
chattool.ProcessOutput(chattool.ProcessToolOptions{
|
||||
GetWorkspaceConn: getWorkspaceConn,
|
||||
|
||||
@@ -65,6 +65,7 @@ type ExecuteResult struct {
|
||||
type ExecuteOptions struct {
|
||||
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
|
||||
DefaultTimeout time.Duration
|
||||
ChatID string
|
||||
}
|
||||
|
||||
// ProcessToolOptions configures a process management tool
|
||||
@@ -96,7 +97,7 @@ func Execute(options ExecuteOptions) fantasy.AgentTool {
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return executeTool(ctx, conn, args, options.DefaultTimeout), nil
|
||||
return executeTool(ctx, conn, args, options.DefaultTimeout, options.ChatID), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -106,6 +107,7 @@ func executeTool(
|
||||
conn workspacesdk.AgentConn,
|
||||
args ExecuteArgs,
|
||||
optTimeout time.Duration,
|
||||
chatID string,
|
||||
) fantasy.ToolResponse {
|
||||
if args.Command == "" {
|
||||
return fantasy.NewTextErrorResponse("command is required")
|
||||
@@ -114,6 +116,9 @@ func executeTool(
|
||||
// Build the environment map for the process request.
|
||||
env := make(map[string]string, len(nonInteractiveEnvVars)+1)
|
||||
env["CODER_CHAT_AGENT"] = "true"
|
||||
if chatID != "" {
|
||||
env["CODER_CHAT_ID"] = chatID
|
||||
}
|
||||
for k, v := range nonInteractiveEnvVars {
|
||||
env[k] = v
|
||||
}
|
||||
|
||||
+64
-30
@@ -1003,12 +1003,12 @@ func shouldRefreshChatDiffStatus(status database.ChatDiffStatus, now time.Time,
|
||||
return chatDiffStatusIsStale(status, now)
|
||||
}
|
||||
|
||||
func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspace, gitRef chatGitRef) {
|
||||
func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspace, chatID uuid.NullUUID, gitRef chatGitRef) {
|
||||
if workspace.ID == uuid.Nil || workspace.OwnerID == uuid.Nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func(workspaceID, workspaceOwnerID uuid.UUID, gitRef chatGitRef) {
|
||||
go func(workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) {
|
||||
ctx := api.ctx
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
@@ -1019,7 +1019,7 @@ func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspa
|
||||
// Always store the git ref so the data is persisted even
|
||||
// before a PR exists. The frontend can show branch info
|
||||
// and the refresh loop can resolve a PR later.
|
||||
api.storeChatGitRef(ctx, workspaceID, workspaceOwnerID, gitRef)
|
||||
api.storeChatGitRef(ctx, workspaceID, workspaceOwnerID, chatID, gitRef)
|
||||
|
||||
for _, delay := range chatDiffRefreshBackoffSchedule {
|
||||
t := api.Clock.NewTimer(delay, "chat_diff_refresh")
|
||||
@@ -1033,26 +1033,44 @@ func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspa
|
||||
// Refresh and publish status on every iteration.
|
||||
// Stop the loop once a PR is discovered — there's
|
||||
// nothing more to wait for after that.
|
||||
if api.refreshWorkspaceChatDiffStatuses(ctx, workspaceID, workspaceOwnerID) {
|
||||
if api.refreshWorkspaceChatDiffStatuses(ctx, workspaceID, workspaceOwnerID, chatID) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}(workspace.ID, workspace.OwnerID, gitRef)
|
||||
}(workspace.ID, workspace.OwnerID, chatID, gitRef)
|
||||
}
|
||||
|
||||
// storeChatGitRef persists the git branch and remote origin reported
|
||||
// by the workspace agent on all chats associated with the workspace.
|
||||
func (api *API) storeChatGitRef(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, gitRef chatGitRef) {
|
||||
chats, err := api.Database.GetChatsByOwnerID(ctx, workspaceOwnerID)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to list chats for git ref storage",
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
// by the workspace agent on the chat that initiated the git operation.
|
||||
// When chatID is set, only that specific chat is updated; otherwise all
|
||||
// chats associated with the workspace are updated (legacy fallback).
|
||||
func (api *API) storeChatGitRef(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) {
|
||||
var chatsToUpdate []database.Chat
|
||||
|
||||
if chatID.Valid {
|
||||
chat, err := api.Database.GetChatByID(ctx, chatID.UUID)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to get chat for git ref storage",
|
||||
slog.F("chat_id", chatID.UUID),
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
chatsToUpdate = []database.Chat{chat}
|
||||
} else {
|
||||
chats, err := api.Database.GetChatsByOwnerID(ctx, workspaceOwnerID)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to list chats for git ref storage",
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
chatsToUpdate = filterChatsByWorkspaceID(chats, workspaceID)
|
||||
}
|
||||
|
||||
for _, chat := range filterChatsByWorkspaceID(chats, workspaceID) {
|
||||
for _, chat := range chatsToUpdate {
|
||||
_, err := api.Database.UpsertChatDiffStatusReference(ctx, database.UpsertChatDiffStatusReferenceParams{
|
||||
ChatID: chat.ID,
|
||||
GitBranch: gitRef.Branch,
|
||||
@@ -1072,22 +1090,38 @@ func (api *API) storeChatGitRef(ctx context.Context, workspaceID, workspaceOwner
|
||||
}
|
||||
}
|
||||
|
||||
// refreshWorkspaceChatDiffStatuses refreshes the diff status for all
|
||||
// chats associated with the given workspace. It returns true when
|
||||
// every chat has a PR URL resolved, signaling that the caller can
|
||||
// stop polling.
|
||||
func (api *API) refreshWorkspaceChatDiffStatuses(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID) bool {
|
||||
chats, err := api.Database.GetChatsByOwnerID(ctx, workspaceOwnerID)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to list workspace owner chats for diff refresh",
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.F("workspace_owner_id", workspaceOwnerID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return false
|
||||
}
|
||||
// refreshWorkspaceChatDiffStatuses refreshes the diff status for chats
|
||||
// associated with the given workspace. When chatID is set, only that
|
||||
// specific chat is refreshed; otherwise all chats for the workspace
|
||||
// are refreshed (legacy fallback). It returns true when every
|
||||
// refreshed chat has a PR URL resolved, signaling that the caller
|
||||
// can stop polling.
|
||||
func (api *API) refreshWorkspaceChatDiffStatuses(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID) bool {
|
||||
var filtered []database.Chat
|
||||
|
||||
filtered := filterChatsByWorkspaceID(chats, workspaceID)
|
||||
if chatID.Valid {
|
||||
chat, err := api.Database.GetChatByID(ctx, chatID.UUID)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to get chat for diff refresh",
|
||||
slog.F("chat_id", chatID.UUID),
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return false
|
||||
}
|
||||
filtered = []database.Chat{chat}
|
||||
} else {
|
||||
chats, err := api.Database.GetChatsByOwnerID(ctx, workspaceOwnerID)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to list workspace owner chats for diff refresh",
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.F("workspace_owner_id", workspaceOwnerID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return false
|
||||
}
|
||||
filtered = filterChatsByWorkspaceID(chats, workspaceID)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1835,6 +1835,18 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
Branch: strings.TrimSpace(query.Get("git_branch")),
|
||||
RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")),
|
||||
}
|
||||
var chatID uuid.NullUUID
|
||||
if rawChatID := query.Get("chat_id"); rawChatID != "" {
|
||||
parsed, err := uuid.Parse(rawChatID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid chat_id.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
chatID = uuid.NullUUID{UUID: parsed, Valid: true}
|
||||
}
|
||||
// Either match or configID must be provided!
|
||||
match := query.Get("match")
|
||||
if match == "" {
|
||||
@@ -1932,7 +1944,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
// context is retained even if the flow requires an out-of-band login.
|
||||
if gitRef.Branch != "" || gitRef.RemoteOrigin != "" {
|
||||
//nolint:gocritic // System context required to persist chat git refs.
|
||||
api.storeChatGitRef(dbauthz.AsSystemRestricted(ctx), workspace.ID, workspace.OwnerID, gitRef)
|
||||
api.storeChatGitRef(dbauthz.AsSystemRestricted(ctx), workspace.ID, workspace.OwnerID, chatID, gitRef)
|
||||
}
|
||||
|
||||
var previousToken *database.ExternalAuthLink
|
||||
@@ -1948,7 +1960,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, gitRef)
|
||||
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, chatID, gitRef)
|
||||
}
|
||||
|
||||
// This is the URL that will redirect the user with a state token.
|
||||
@@ -2006,11 +2018,11 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
})
|
||||
return
|
||||
}
|
||||
api.triggerWorkspaceChatDiffStatusRefresh(workspace, gitRef)
|
||||
api.triggerWorkspaceChatDiffStatusRefresh(workspace, chatID, gitRef)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, gitRef chatGitRef) {
|
||||
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, chatID uuid.NullUUID, gitRef chatGitRef) {
|
||||
// Since we're ticking frequently and this sign-in operation is rare,
|
||||
// we are OK with polling to avoid the complexity of pubsub.
|
||||
ticker, done := api.NewTicker(time.Second)
|
||||
@@ -2080,7 +2092,7 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
|
||||
})
|
||||
return
|
||||
}
|
||||
api.triggerWorkspaceChatDiffStatusRefresh(workspace, gitRef)
|
||||
api.triggerWorkspaceChatDiffStatusRefresh(workspace, chatID, gitRef)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -646,6 +646,8 @@ type ExternalAuthRequest struct {
|
||||
// Sent by the agent so the control plane can resolve diffs
|
||||
// without SSHing into the workspace.
|
||||
GitRemoteOrigin string
|
||||
// ChatID identifies which chat initiated the git operation.
|
||||
ChatID string
|
||||
// Listen indicates that the request should be long-lived and listen for
|
||||
// a new token to be requested.
|
||||
Listen bool
|
||||
@@ -667,6 +669,9 @@ func (c *Client) ExternalAuth(ctx context.Context, req ExternalAuthRequest) (Ext
|
||||
if req.GitRemoteOrigin != "" {
|
||||
q.Set("git_remote_origin", req.GitRemoteOrigin)
|
||||
}
|
||||
if req.ChatID != "" {
|
||||
q.Set("chat_id", req.ChatID)
|
||||
}
|
||||
reqURL := "/api/v2/workspaceagents/me/external-auth?" + q.Encode()
|
||||
res, err := c.SDK.Request(ctx, http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -165,6 +165,7 @@ func TestExternalAuthRequestQuery(t *testing.T) {
|
||||
require.Equal(t, "true", r.URL.Query().Get("listen"))
|
||||
require.Equal(t, "main", r.URL.Query().Get("git_branch"))
|
||||
require.Equal(t, "https://github.com/coder/coder.git", r.URL.Query().Get("git_remote_origin"))
|
||||
require.Equal(t, "test-chat-id", r.URL.Query().Get("chat_id"))
|
||||
require.False(t, r.URL.Query().Has("workdir"))
|
||||
_, _ = w.Write([]byte(`{"type":"github","access_token":"token"}`))
|
||||
}))
|
||||
@@ -179,6 +180,7 @@ func TestExternalAuthRequestQuery(t *testing.T) {
|
||||
Listen: true,
|
||||
GitBranch: "main",
|
||||
GitRemoteOrigin: "https://github.com/coder/coder.git",
|
||||
ChatID: "test-chat-id",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user