From c4ef94aacf6852089f0845f279b7e94d626d93c1 Mon Sep 17 00:00:00 2001 From: Ethan <39577870+ethanndickson@users.noreply.github.com> Date: Fri, 27 Mar 2026 18:47:39 +1100 Subject: [PATCH] fix(coderd/x/chatd): prevent chat hang when workspace agent is unavailable (#23707) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Problem Chats with a persisted `agent_id` binding hang indefinitely when the workspace is stopped. The stale agent row still exists in the DB, so `ensureWorkspaceAgent` succeeds, but the dial blocks forever in `AwaitReachable`. The MCP discovery goroutine used an unbounded context, so `g2.Wait()` never returned and the LLM never started. ## Fix Three targeted changes restore the pre-binding behavior where stopped workspaces degrade gracefully instead of blocking: 1. **`dialWithLazyValidation`**: "no agents in latest build" is now a terminal fast-fail — the hanging dial is canceled and `errChatHasNoWorkspaceAgent` returned immediately, instead of falling through to `waitForOriginalDial`. 2. **Pre-LLM workspace setup**: MCP discovery and instruction persistence gate on `workspaceAgentIDForConn` before attempting any dial. MCP discovery is bounded by a 5s timeout and checks the in-memory tool cache first (using the cheap cached agent from `ensureWorkspaceAgent`), so the common subsequent-turn path has zero DB queries. 3. **`persistInstructionFiles`**: tracks whether the workspace connection succeeded and skips sentinel persistence on failure, so the next turn retries if the workspace is restarted. ## Scenarios **Running workspace, subsequent turn (hot path):** MCP cache hit via in-memory cached agent. Zero DB queries, zero dials. Unchanged from #23274. **Stopped workspace, persisted binding (the bug):** MCP cache hit (stale descriptors, fine — they fail at invocation). Pre-LLM setup completes instantly. 
Tool invocation enters `dialWithLazyValidation`, dial fails or hangs, validation discovers no agents, returns `errChatHasNoWorkspaceAgent`. Model sees the error and can call `start_workspace`. **New chat, running workspace:** `ensureWorkspaceAgent` resolves via latest-build, persists binding. MCP discovery dials and caches tools. **New chat, stopped workspace:** `ensureWorkspaceAgent` finds no agents, returns `errChatHasNoWorkspaceAgent`. Pre-LLM setup skips. LLM starts with built-in tools only. **Rebuilt workspace (agent switched):** MCP cache hit with stale agent (harmless for one turn). Tool invocation dials stale agent, fails fast, `dialWithLazyValidation` switches to new agent, persists updated binding. **Workspace restarted after stop:** No sentinel was persisted during the stopped turn, so instruction persistence retries. Agent binding switches to the new agent via `workspaceAgentIDForConn`. **Transient DB error during validation:** Not `errChatHasNoWorkspaceAgent`, so `dialWithLazyValidation` falls through to `waitForOriginalDial` (cannot prove stale). No false positive. **Tool invocation on stopped workspace:** `getWorkspaceConn` calls `ensureWorkspaceAgent` (returns stale row), then `dialWithLazyValidation` validation discovers no agents, returns `errChatHasNoWorkspaceAgent`, cached state cleared, error returned to model. 
--- coderd/x/chatd/chatd.go | 155 ++++++++++++++++++++----- coderd/x/chatd/chatd_internal_test.go | 95 ++++++++++++++++ coderd/x/chatd/chatd_test.go | 157 ++++++++++++++++++++++++++ coderd/x/chatd/dialvalidation.go | 16 ++- coderd/x/chatd/dialvalidation_test.go | 49 ++++++++ 5 files changed, 443 insertions(+), 29 deletions(-) diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 21e2b4af66..c1d9f963f3 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -55,6 +55,7 @@ const ( homeInstructionLookupTimeout = 5 * time.Second instructionCacheTTL = 5 * time.Minute workspaceDialValidationDelay = 5 * time.Second + workspaceMCPDiscoveryTimeout = 5 * time.Second // DefaultChatHeartbeatInterval is the default time between chat // heartbeat updates while a chat is being processed. DefaultChatHeartbeatInterval = 30 * time.Second @@ -89,6 +90,8 @@ const ( defaultSubagentInstruction = "You are running as a delegated sub-agent chat. Complete the delegated task and provide clear, concise assistant responses for the parent agent." ) +var errChatHasNoWorkspaceAgent = xerrors.New("chat has no workspace agent") + // Server handles background processing of pending chats. 
type Server struct { cancel context.CancelFunc @@ -336,8 +339,14 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked( ctx, chatSnapshot.WorkspaceID.UUID, ) - if err != nil || len(agents) == 0 { - return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace agent") + if err != nil { + return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf( + "get workspace agents in latest build: %w", + err, + ) + } + if len(agents) == 0 { + return chatSnapshot, database.WorkspaceAgent{}, errChatHasNoWorkspaceAgent } build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID) @@ -372,6 +381,65 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked( ) } +func (c *turnWorkspaceContext) latestWorkspaceAgentID( + ctx context.Context, + workspaceID uuid.UUID, +) (uuid.UUID, error) { + agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID( + ctx, + workspaceID, + ) + if err != nil { + return uuid.Nil, xerrors.Errorf( + "get workspace agents in latest build: %w", + err, + ) + } + if len(agents) == 0 { + return uuid.Nil, errChatHasNoWorkspaceAgent + } + return agents[0].ID, nil +} + +func (c *turnWorkspaceContext) workspaceAgentIDForConn( + ctx context.Context, +) (database.Chat, uuid.UUID, error) { + for attempt := 0; attempt < 2; attempt++ { + chatSnapshot := c.currentChatSnapshot() + if !chatSnapshot.WorkspaceID.Valid || !chatSnapshot.AgentID.Valid { + updatedChat, agent, err := c.ensureWorkspaceAgent(ctx) + if err != nil { + return updatedChat, uuid.Nil, err + } + return updatedChat, agent.ID, nil + } + + currentAgentID, err := c.latestWorkspaceAgentID( + ctx, + chatSnapshot.WorkspaceID.UUID, + ) + if err != nil { + if xerrors.Is(err, errChatHasNoWorkspaceAgent) { + c.clearCachedWorkspaceState() + } + return chatSnapshot, uuid.Nil, err + } + + latestChat, workspaceMatches := c.currentWorkspaceMatches( + chatSnapshot.WorkspaceID, + ) + if !workspaceMatches { + continue + } + return 
latestChat, currentAgentID, nil + } + + chatSnapshot := c.currentChatSnapshot() + return chatSnapshot, uuid.Nil, xerrors.New( + "chat workspace changed while resolving agent", + ) +} + // getWorkspaceConnLocked returns the cached connection when it still matches // the current workspace. When the workspace changed, it clears the stale // cached state and returns the release func for the caller to run after @@ -422,15 +490,14 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces chatSnapshot.WorkspaceID.UUID, DialFunc(c.server.agentConnFn), func(ctx context.Context, workspaceID uuid.UUID) (uuid.UUID, error) { - agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspaceID) - if err != nil || len(agents) == 0 { - return uuid.Nil, xerrors.New("chat has no workspace agent") - } - return agents[0].ID, nil + return c.latestWorkspaceAgentID(ctx, workspaceID) }, workspaceDialValidationDelay, ) if err != nil { + if xerrors.Is(err, errChatHasNoWorkspaceAgent) { + c.clearCachedWorkspaceState() + } return nil, err } @@ -3763,7 +3830,12 @@ func (p *Server) runChat( chat, modelConfig.ID, workspaceCtx.getWorkspaceAgent, - workspaceCtx.getWorkspaceConn, + func(instructionCtx context.Context) (workspacesdk.AgentConn, error) { + if _, _, err := workspaceCtx.workspaceAgentIDForConn(instructionCtx); err != nil { + return nil, err + } + return workspaceCtx.getWorkspaceConn(instructionCtx) + }, ) if persistErr != nil { p.logger.Warn(ctx, "failed to persist instruction files", @@ -3793,32 +3865,54 @@ func (p *Server) runChat( } if chat.WorkspaceID.Valid { g2.Go(func() error { - // Check cache first. On subsequent turns with the same - // agent, reuse cached tools to avoid a round-trip. - if cached, ok := p.workspaceMCPToolsCache.Load(chat.ID); ok { - entry, ok2 := cached.(*cachedWorkspaceMCPTools) - if !ok2 { - return nil - } - // Verify the agent hasn't changed. 
- if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil && agent.ID == entry.agentID { - for _, t := range entry.tools { - workspaceMCPTools = append(workspaceMCPTools, - chattool.NewWorkspaceMCPTool(t, workspaceCtx.getWorkspaceConn), - ) + // Fast path: check cache using the in-memory cached + // agent (ensureWorkspaceAgent is free when already + // loaded). This avoids a per-turn latest-build DB + // query on the common subsequent-turn path. + if agent, err := workspaceCtx.getWorkspaceAgent(ctx); err == nil { + if cached, ok := p.workspaceMCPToolsCache.Load(chat.ID); ok { + entry, ok := cached.(*cachedWorkspaceMCPTools) + if ok && entry.agentID == agent.ID { + for _, t := range entry.tools { + workspaceMCPTools = append(workspaceMCPTools, + chattool.NewWorkspaceMCPTool(t, workspaceCtx.getWorkspaceConn), + ) + } + return nil } - return nil } } - // Cache miss or agent changed — fetch fresh tools. - conn, connErr := workspaceCtx.getWorkspaceConn(ctx) + // Cache miss, agent changed, or no cache — validate + // that the workspace still has a live agent before + // attempting a dial. + workspaceMCPCtx, cancel := context.WithTimeout( + ctx, + workspaceMCPDiscoveryTimeout, + ) + defer cancel() + + _, _, agentErr := workspaceCtx.workspaceAgentIDForConn( + workspaceMCPCtx, + ) + if agentErr != nil { + if xerrors.Is(agentErr, errChatHasNoWorkspaceAgent) { + p.workspaceMCPToolsCache.Delete(chat.ID) + return nil + } + logger.Warn(ctx, "failed to resolve workspace agent for MCP tools", + slog.Error(agentErr)) + return nil + } + + // Fetch fresh tools from the workspace agent. 
+ conn, connErr := workspaceCtx.getWorkspaceConn(workspaceMCPCtx) if connErr != nil { logger.Warn(ctx, "failed to get workspace conn for MCP tools", slog.Error(connErr)) return nil } - toolsResp, listErr := conn.ListMCPTools(ctx) + toolsResp, listErr := conn.ListMCPTools(workspaceMCPCtx) if listErr != nil { logger.Warn(ctx, "failed to list workspace MCP tools", slog.Error(listErr)) @@ -3831,7 +3925,7 @@ func (p *Server) runChat( // caching an empty list would hide tools // permanently. if len(toolsResp.Tools) > 0 { - if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil { + if agent, agentErr := workspaceCtx.getWorkspaceAgent(workspaceMCPCtx); agentErr == nil { p.workspaceMCPToolsCache.Store(chat.ID, &cachedWorkspaceMCPTools{ agentID: agent.ID, tools: toolsResp.Tools, @@ -4709,7 +4803,10 @@ func (p *Server) persistInstructionFiles( } // Read instruction files from the workspace agent. - var sections []instructionFileSection + var ( + sections []instructionFileSection + workspaceConnOK bool + ) if getWorkspaceConn != nil { instructionCtx, cancel := context.WithTimeout(ctx, homeInstructionLookupTimeout) defer cancel() @@ -4721,6 +4818,7 @@ func (p *Server) persistInstructionFiles( slog.Error(connErr), ) } else { + workspaceConnOK = true if content, source, truncated, readErr := readHomeInstructionFile(instructionCtx, conn); readErr != nil { p.logger.Debug(ctx, "failed to load home instruction file", slog.F("chat_id", chat.ID), slog.Error(readErr)) @@ -4740,6 +4838,9 @@ func (p *Server) persistInstructionFiles( } if len(sections) == 0 { + if !workspaceConnOK { + return "", nil + } // Persist a sentinel so subsequent turns skip the // workspace agent dial. 
parts := []codersdk.ChatMessagePart{{ diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index f13b9b8b50..c71b2edd09 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -363,6 +363,43 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) { require.Contains(t, instruction, "Working Directory: /home/coder/project") } +func TestPersistInstructionFilesSkipsSentinelWhenWorkspaceUnavailable(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + } + server := &Server{ + db: db, + logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + } + + instruction, err := server.persistInstructionFiles( + ctx, + chat, + uuid.New(), + func(context.Context) (database.WorkspaceAgent, error) { + return database.WorkspaceAgent{ + ID: uuid.New(), + Directory: "/home/coder/project", + }, nil + }, + func(context.Context) (workspacesdk.AgentConn, error) { + return nil, errChatHasNoWorkspaceAgent + }, + ) + require.NoError(t, err) + require.Empty(t, instruction) +} + func TestTurnWorkspaceContext_BindingFirstPath(t *testing.T) { t.Parallel() @@ -590,6 +627,64 @@ func TestTurnWorkspaceContextGetWorkspaceConnLazyValidationSwitchesWorkspaceAgen require.Equal(t, currentAgent, gotAgent) } +func TestTurnWorkspaceContextGetWorkspaceConnFastFailsWithoutCurrentAgent(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + staleAgentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: staleAgentID, + Valid: true, + }, + } + + staleAgent := database.WorkspaceAgent{ID: staleAgentID} + + 
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID). + Return(staleAgent, nil). + Times(1) + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID). + Return([]database.WorkspaceAgent{}, nil). + Times(1) + + server := &Server{db: db} + server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) { + return nil, nil, xerrors.New("dial failed") + } + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: server, + chatStateMu: chatStateMu, + currentChat: &currentChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + defer workspaceCtx.close() + + gotConn, err := workspaceCtx.getWorkspaceConn(ctx) + require.Nil(t, gotConn) + require.ErrorIs(t, err, errChatHasNoWorkspaceAgent) + + workspaceCtx.mu.Lock() + defer workspaceCtx.mu.Unlock() + require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent) + require.False(t, workspaceCtx.agentLoaded) + require.Nil(t, workspaceCtx.conn) + require.Nil(t, workspaceCtx.releaseConn) + require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID) +} + func TestTurnWorkspaceContext_SelectWorkspaceClearsCachedState(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index 783d268c22..0f9d366773 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -30,6 +30,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" @@ -2152,6 +2153,162 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) { require.True(t, foundToolResultInSecondCall, "expected second streamed model call 
to include start_workspace tool output") } +func TestStoppedWorkspaceWithPersistedAgentBindingDoesNotBlockChat(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + var streamedCallCount atomic.Int32 + var streamedCallsMu sync.Mutex + streamedCalls := make([][]chattest.OpenAIMessage, 0, 2) + toolsByCall := make([][]string, 0, 2) + + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("Stopped workspace regression") + } + + names := make([]string, 0, len(req.Tools)) + for _, tool := range req.Tools { + names = append(names, tool.Function.Name) + } + + streamedCallsMu.Lock() + streamedCalls = append(streamedCalls, append([]chattest.OpenAIMessage(nil), req.Messages...)) + toolsByCall = append(toolsByCall, names) + streamedCallsMu.Unlock() + + if streamedCallCount.Add(1) == 1 { + return chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk("execute", `{"command":"echo hi"}`), + ) + } + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("The workspace is unavailable. 
Start it before retrying workspace tools.")..., + ) + }) + + user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + inactive := newTestServer(t, db, ps, uuid.New()) + chat, err := inactive.CreateChat(ctx, chatd.CreateOptions{ + OwnerID: user.ID, + Title: "stopped-workspace-regression", + ModelConfigID: model.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("Run echo hi in the workspace."), + }, + }) + require.NoError(t, err) + + build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID) + require.NoError(t, err) + chat, err = db.UpdateChatBuildAgentBinding(ctx, database.UpdateChatBuildAgentBindingParams{ + ID: chat.ID, + BuildID: uuid.NullUUID{UUID: build.ID, Valid: true}, + AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true}, + }) + require.NoError(t, err) + + dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStop, + BuildNumber: 2, + }).Do() + + var dialCalls atomic.Int32 + _ = newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + dialCalls.Add(1) + require.Equal(t, dbAgent.ID, agentID) + <-ctx.Done() + return nil, nil, ctx.Err() + } + }) + + var chatResult database.Chat + require.Eventually(t, func() bool { + got, getErr := db.GetChatByID(ctx, chat.ID) + if getErr != nil { + return false + } + chatResult = got + return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError + }, testutil.WaitLong, testutil.IntervalFast) + + if chatResult.Status == database.ChatStatusError { + require.FailNowf(t, "chat failed", "last_error=%q", chatResult.LastError.String) + } + + require.EqualValues(t, 1, dialCalls.Load()) + require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2)) + + 
streamedCallsMu.Lock() + recordedCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...) + recordedTools := append([][]string(nil), toolsByCall...) + streamedCallsMu.Unlock() + require.GreaterOrEqual(t, len(recordedCalls), 2) + require.NotEmpty(t, recordedTools) + require.Contains(t, recordedTools[0], "execute") + require.Contains(t, recordedTools[0], "start_workspace") + + var foundUnavailableToolResult bool + for _, message := range recordedCalls[1] { + if message.Role != "tool" { + continue + } + if strings.Contains(message.Content, "chat has no workspace agent") { + foundUnavailableToolResult = true + break + } + if !json.Valid([]byte(message.Content)) { + continue + } + var toolResult map[string]any + if err := json.Unmarshal([]byte(message.Content), &toolResult); err != nil { + continue + } + errMsg, _ := toolResult["error"].(string) + outputMsg, _ := toolResult["output"].(string) + if strings.Contains(errMsg, "chat has no workspace agent") || + strings.Contains(outputMsg, "chat has no workspace agent") { + foundUnavailableToolResult = true + break + } + } + require.True(t, foundUnavailableToolResult, + "expected the second streamed model call to include the unavailable workspace tool result") + + var toolMessage *database.ChatMessage + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + if dbErr != nil { + return false + } + for i := range messages { + if messages[i].Role == database.ChatMessageRoleTool { + toolMessage = &messages[i] + return true + } + } + return false + }, testutil.IntervalFast) + require.NotNil(t, toolMessage) + + parts, err := chatprompt.ParseContent(*toolMessage) + require.NoError(t, err) + require.Len(t, parts, 1) + require.Equal(t, codersdk.ChatMessagePartTypeToolResult, parts[0].Type) + require.Equal(t, "execute", parts[0].ToolName) + require.True(t, parts[0].IsError) + 
require.Contains(t, string(parts[0].Result), "chat has no workspace agent") +} + func TestHeartbeatBumpsWorkspaceUsage(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/dialvalidation.go b/coderd/x/chatd/dialvalidation.go index 06c1c536af..88c035c4c6 100644 --- a/coderd/x/chatd/dialvalidation.go +++ b/coderd/x/chatd/dialvalidation.go @@ -123,6 +123,9 @@ func dialWithLazyValidation( validateBinding := func() (uuid.UUID, error) { validatedAgentID, err := validateFn(ctx, workspaceID) if err != nil { + if xerrors.Is(err, errChatHasNoWorkspaceAgent) { + return uuid.Nil, errChatHasNoWorkspaceAgent + } return uuid.Nil, wrapErr(err) } return validatedAgentID, nil @@ -150,12 +153,21 @@ func dialWithLazyValidation( return resolveFastFailure() case <-timer.C: - validatedAgentID, validationErr := validateFn(ctx, workspaceID) - if validationErr != nil || validatedAgentID == agentID { + validatedAgentID, validationErr := validateBinding() + if validationErr != nil { + if xerrors.Is(validationErr, errChatHasNoWorkspaceAgent) { + dialCancel() + return DialResult{}, validationErr + } // Validation could not prove the binding was stale, so keep waiting on // the original dial. return waitForOriginalDial(ctx) } + if validatedAgentID == agentID { + // Validation confirmed the current binding, so keep waiting on the + // original dial. + return waitForOriginalDial(ctx) + } // The original dial is stale. Cancel it first, then let the deferred drain // release any late result while we dial the validated agent immediately. 
dialCancel() diff --git a/coderd/x/chatd/dialvalidation_test.go b/coderd/x/chatd/dialvalidation_test.go index c2fea03acb..da2b639d98 100644 --- a/coderd/x/chatd/dialvalidation_test.go +++ b/coderd/x/chatd/dialvalidation_test.go @@ -108,6 +108,55 @@ func TestDialWithLazyValidation_SlowDialSameAgent(t *testing.T) { require.EqualValues(t, 1, releaseCalls.Load()) } +func TestDialWithLazyValidation_SlowDialNoCurrentAgent(t *testing.T) { + t.Parallel() + + staleAgentID := uuid.New() + workspaceID := uuid.New() + dialStarted := make(chan struct{}) + resultCh := make(chan error, 1) + + var dialCalls atomic.Int32 + var validateCalls atomic.Int32 + + go func() { + _, err := dialWithLazyValidation( + context.Background(), + staleAgentID, + workspaceID, + func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) { + if id != staleAgentID { + return nil, nil, xerrors.Errorf("unexpected agent ID %q", id) + } + dialCalls.Add(1) + close(dialStarted) + <-ctx.Done() + return nil, nil, ctx.Err() + }, + func(_ context.Context, id uuid.UUID) (uuid.UUID, error) { + if id != workspaceID { + return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id) + } + <-dialStarted + validateCalls.Add(1) + return uuid.Nil, errChatHasNoWorkspaceAgent + }, + 0, + ) + resultCh <- err + }() + + select { + case err := <-resultCh: + require.ErrorIs(t, err, errChatHasNoWorkspaceAgent) + case <-time.After(testutil.WaitShort): + t.Fatal("dialWithLazyValidation blocked after validation reported no current agent") + } + + require.EqualValues(t, 1, dialCalls.Load()) + require.EqualValues(t, 1, validateCalls.Load()) +} + func TestDialWithLazyValidation_SlowDialStaleAgent(t *testing.T) { t.Parallel()