mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd/x/chatd): prevent chat hang when workspace agent is unavailable (#23707)
## Problem Chats with a persisted `agent_id` binding hang indefinitely when the workspace is stopped. The stale agent row still exists in the DB, so `ensureWorkspaceAgent` succeeds, but the dial blocks forever in `AwaitReachable`. The MCP discovery goroutine used an unbounded context, so `g2.Wait()` never returned and the LLM never started. ## Fix Three targeted changes restore the pre-binding behavior where stopped workspaces degrade gracefully instead of blocking: 1. **`dialWithLazyValidation`**: "no agents in latest build" is now a terminal fast-fail — the hanging dial is canceled and `errChatHasNoWorkspaceAgent` returned immediately, instead of falling through to `waitForOriginalDial`. 2. **Pre-LLM workspace setup**: MCP discovery and instruction persistence gate on `workspaceAgentIDForConn` before attempting any dial. MCP discovery is bounded by a 5s timeout and checks the in-memory tool cache first (using the cheap cached agent from `ensureWorkspaceAgent`), so the common subsequent-turn path has zero DB queries. 3. **`persistInstructionFiles`**: tracks whether the workspace connection succeeded and skips sentinel persistence on failure, so the next turn retries if the workspace is restarted. ## Scenarios **Running workspace, subsequent turn (hot path):** MCP cache hit via in-memory cached agent. Zero DB queries, zero dials. Unchanged from #23274. **Stopped workspace, persisted binding (the bug):** MCP cache hit (stale descriptors, fine — they fail at invocation). Pre-LLM setup completes instantly. Tool invocation enters `dialWithLazyValidation`, dial fails or hangs, validation discovers no agents, returns `errChatHasNoWorkspaceAgent`. Model sees the error and can call `start_workspace`. **New chat, running workspace:** `ensureWorkspaceAgent` resolves via latest-build, persists binding. MCP discovery dials and caches tools. **New chat, stopped workspace:** `ensureWorkspaceAgent` finds no agents, returns `errChatHasNoWorkspaceAgent`. Pre-LLM setup skips. LLM starts with built-in tools only. **Rebuilt workspace (agent switched):** MCP cache hit with stale agent (harmless for one turn). Tool invocation dials stale agent, fails fast, `dialWithLazyValidation` switches to new agent, persists updated binding. **Workspace restarted after stop:** No sentinel was persisted during the stopped turn, so instruction persistence retries. Agent binding switches to the new agent via `workspaceAgentIDForConn`. **Transient DB error during validation:** Not `errChatHasNoWorkspaceAgent`, so `dialWithLazyValidation` falls through to `waitForOriginalDial` (cannot prove stale). No false positive. **Tool invocation on stopped workspace:** `getWorkspaceConn` calls `ensureWorkspaceAgent` (returns stale row), then `dialWithLazyValidation` validation discovers no agents, returns `errChatHasNoWorkspaceAgent`, cached state cleared, error returned to model.
This commit is contained in:
+128
-27
@@ -55,6 +55,7 @@ const (
|
||||
homeInstructionLookupTimeout = 5 * time.Second
|
||||
instructionCacheTTL = 5 * time.Minute
|
||||
workspaceDialValidationDelay = 5 * time.Second
|
||||
workspaceMCPDiscoveryTimeout = 5 * time.Second
|
||||
// DefaultChatHeartbeatInterval is the default time between chat
|
||||
// heartbeat updates while a chat is being processed.
|
||||
DefaultChatHeartbeatInterval = 30 * time.Second
|
||||
@@ -89,6 +90,8 @@ const (
|
||||
defaultSubagentInstruction = "You are running as a delegated sub-agent chat. Complete the delegated task and provide clear, concise assistant responses for the parent agent."
|
||||
)
|
||||
|
||||
var errChatHasNoWorkspaceAgent = xerrors.New("chat has no workspace agent")
|
||||
|
||||
// Server handles background processing of pending chats.
|
||||
type Server struct {
|
||||
cancel context.CancelFunc
|
||||
@@ -336,8 +339,14 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked(
|
||||
ctx,
|
||||
chatSnapshot.WorkspaceID.UUID,
|
||||
)
|
||||
if err != nil || len(agents) == 0 {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("chat has no workspace agent")
|
||||
if err != nil {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf(
|
||||
"get workspace agents in latest build: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
if len(agents) == 0 {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, errChatHasNoWorkspaceAgent
|
||||
}
|
||||
|
||||
build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID)
|
||||
@@ -372,6 +381,65 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked(
|
||||
)
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) latestWorkspaceAgentID(
|
||||
ctx context.Context,
|
||||
workspaceID uuid.UUID,
|
||||
) (uuid.UUID, error) {
|
||||
agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
ctx,
|
||||
workspaceID,
|
||||
)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf(
|
||||
"get workspace agents in latest build: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
if len(agents) == 0 {
|
||||
return uuid.Nil, errChatHasNoWorkspaceAgent
|
||||
}
|
||||
return agents[0].ID, nil
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) workspaceAgentIDForConn(
|
||||
ctx context.Context,
|
||||
) (database.Chat, uuid.UUID, error) {
|
||||
for attempt := 0; attempt < 2; attempt++ {
|
||||
chatSnapshot := c.currentChatSnapshot()
|
||||
if !chatSnapshot.WorkspaceID.Valid || !chatSnapshot.AgentID.Valid {
|
||||
updatedChat, agent, err := c.ensureWorkspaceAgent(ctx)
|
||||
if err != nil {
|
||||
return updatedChat, uuid.Nil, err
|
||||
}
|
||||
return updatedChat, agent.ID, nil
|
||||
}
|
||||
|
||||
currentAgentID, err := c.latestWorkspaceAgentID(
|
||||
ctx,
|
||||
chatSnapshot.WorkspaceID.UUID,
|
||||
)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, errChatHasNoWorkspaceAgent) {
|
||||
c.clearCachedWorkspaceState()
|
||||
}
|
||||
return chatSnapshot, uuid.Nil, err
|
||||
}
|
||||
|
||||
latestChat, workspaceMatches := c.currentWorkspaceMatches(
|
||||
chatSnapshot.WorkspaceID,
|
||||
)
|
||||
if !workspaceMatches {
|
||||
continue
|
||||
}
|
||||
return latestChat, currentAgentID, nil
|
||||
}
|
||||
|
||||
chatSnapshot := c.currentChatSnapshot()
|
||||
return chatSnapshot, uuid.Nil, xerrors.New(
|
||||
"chat workspace changed while resolving agent",
|
||||
)
|
||||
}
|
||||
|
||||
// getWorkspaceConnLocked returns the cached connection when it still matches
|
||||
// the current workspace. When the workspace changed, it clears the stale
|
||||
// cached state and returns the release func for the caller to run after
|
||||
@@ -422,15 +490,14 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces
|
||||
chatSnapshot.WorkspaceID.UUID,
|
||||
DialFunc(c.server.agentConnFn),
|
||||
func(ctx context.Context, workspaceID uuid.UUID) (uuid.UUID, error) {
|
||||
agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspaceID)
|
||||
if err != nil || len(agents) == 0 {
|
||||
return uuid.Nil, xerrors.New("chat has no workspace agent")
|
||||
}
|
||||
return agents[0].ID, nil
|
||||
return c.latestWorkspaceAgentID(ctx, workspaceID)
|
||||
},
|
||||
workspaceDialValidationDelay,
|
||||
)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, errChatHasNoWorkspaceAgent) {
|
||||
c.clearCachedWorkspaceState()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -3763,7 +3830,12 @@ func (p *Server) runChat(
|
||||
chat,
|
||||
modelConfig.ID,
|
||||
workspaceCtx.getWorkspaceAgent,
|
||||
workspaceCtx.getWorkspaceConn,
|
||||
func(instructionCtx context.Context) (workspacesdk.AgentConn, error) {
|
||||
if _, _, err := workspaceCtx.workspaceAgentIDForConn(instructionCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return workspaceCtx.getWorkspaceConn(instructionCtx)
|
||||
},
|
||||
)
|
||||
if persistErr != nil {
|
||||
p.logger.Warn(ctx, "failed to persist instruction files",
|
||||
@@ -3793,32 +3865,54 @@ func (p *Server) runChat(
|
||||
}
|
||||
if chat.WorkspaceID.Valid {
|
||||
g2.Go(func() error {
|
||||
// Check cache first. On subsequent turns with the same
|
||||
// agent, reuse cached tools to avoid a round-trip.
|
||||
if cached, ok := p.workspaceMCPToolsCache.Load(chat.ID); ok {
|
||||
entry, ok2 := cached.(*cachedWorkspaceMCPTools)
|
||||
if !ok2 {
|
||||
return nil
|
||||
}
|
||||
// Verify the agent hasn't changed.
|
||||
if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil && agent.ID == entry.agentID {
|
||||
for _, t := range entry.tools {
|
||||
workspaceMCPTools = append(workspaceMCPTools,
|
||||
chattool.NewWorkspaceMCPTool(t, workspaceCtx.getWorkspaceConn),
|
||||
)
|
||||
// Fast path: check cache using the in-memory cached
|
||||
// agent (ensureWorkspaceAgent is free when already
|
||||
// loaded). This avoids a per-turn latest-build DB
|
||||
// query on the common subsequent-turn path.
|
||||
if agent, err := workspaceCtx.getWorkspaceAgent(ctx); err == nil {
|
||||
if cached, ok := p.workspaceMCPToolsCache.Load(chat.ID); ok {
|
||||
entry, ok := cached.(*cachedWorkspaceMCPTools)
|
||||
if ok && entry.agentID == agent.ID {
|
||||
for _, t := range entry.tools {
|
||||
workspaceMCPTools = append(workspaceMCPTools,
|
||||
chattool.NewWorkspaceMCPTool(t, workspaceCtx.getWorkspaceConn),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Cache miss or agent changed — fetch fresh tools.
|
||||
conn, connErr := workspaceCtx.getWorkspaceConn(ctx)
|
||||
// Cache miss, agent changed, or no cache — validate
|
||||
// that the workspace still has a live agent before
|
||||
// attempting a dial.
|
||||
workspaceMCPCtx, cancel := context.WithTimeout(
|
||||
ctx,
|
||||
workspaceMCPDiscoveryTimeout,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
_, _, agentErr := workspaceCtx.workspaceAgentIDForConn(
|
||||
workspaceMCPCtx,
|
||||
)
|
||||
if agentErr != nil {
|
||||
if xerrors.Is(agentErr, errChatHasNoWorkspaceAgent) {
|
||||
p.workspaceMCPToolsCache.Delete(chat.ID)
|
||||
return nil
|
||||
}
|
||||
logger.Warn(ctx, "failed to resolve workspace agent for MCP tools",
|
||||
slog.Error(agentErr))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fetch fresh tools from the workspace agent.
|
||||
conn, connErr := workspaceCtx.getWorkspaceConn(workspaceMCPCtx)
|
||||
if connErr != nil {
|
||||
logger.Warn(ctx, "failed to get workspace conn for MCP tools",
|
||||
slog.Error(connErr))
|
||||
return nil
|
||||
}
|
||||
toolsResp, listErr := conn.ListMCPTools(ctx)
|
||||
toolsResp, listErr := conn.ListMCPTools(workspaceMCPCtx)
|
||||
if listErr != nil {
|
||||
logger.Warn(ctx, "failed to list workspace MCP tools",
|
||||
slog.Error(listErr))
|
||||
@@ -3831,7 +3925,7 @@ func (p *Server) runChat(
|
||||
// caching an empty list would hide tools
|
||||
// permanently.
|
||||
if len(toolsResp.Tools) > 0 {
|
||||
if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil {
|
||||
if agent, agentErr := workspaceCtx.getWorkspaceAgent(workspaceMCPCtx); agentErr == nil {
|
||||
p.workspaceMCPToolsCache.Store(chat.ID, &cachedWorkspaceMCPTools{
|
||||
agentID: agent.ID,
|
||||
tools: toolsResp.Tools,
|
||||
@@ -4709,7 +4803,10 @@ func (p *Server) persistInstructionFiles(
|
||||
}
|
||||
|
||||
// Read instruction files from the workspace agent.
|
||||
var sections []instructionFileSection
|
||||
var (
|
||||
sections []instructionFileSection
|
||||
workspaceConnOK bool
|
||||
)
|
||||
if getWorkspaceConn != nil {
|
||||
instructionCtx, cancel := context.WithTimeout(ctx, homeInstructionLookupTimeout)
|
||||
defer cancel()
|
||||
@@ -4721,6 +4818,7 @@ func (p *Server) persistInstructionFiles(
|
||||
slog.Error(connErr),
|
||||
)
|
||||
} else {
|
||||
workspaceConnOK = true
|
||||
if content, source, truncated, readErr := readHomeInstructionFile(instructionCtx, conn); readErr != nil {
|
||||
p.logger.Debug(ctx, "failed to load home instruction file",
|
||||
slog.F("chat_id", chat.ID), slog.Error(readErr))
|
||||
@@ -4740,6 +4838,9 @@ func (p *Server) persistInstructionFiles(
|
||||
}
|
||||
|
||||
if len(sections) == 0 {
|
||||
if !workspaceConnOK {
|
||||
return "", nil
|
||||
}
|
||||
// Persist a sentinel so subsequent turns skip the
|
||||
// workspace agent dial.
|
||||
parts := []codersdk.ChatMessagePart{{
|
||||
|
||||
@@ -363,6 +363,43 @@ func TestPersistInstructionFilesIncludesAgentMetadata(t *testing.T) {
|
||||
require.Contains(t, instruction, "Working Directory: /home/coder/project")
|
||||
}
|
||||
|
||||
func TestPersistInstructionFilesSkipsSentinelWhenWorkspaceUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
}
|
||||
|
||||
instruction, err := server.persistInstructionFiles(
|
||||
ctx,
|
||||
chat,
|
||||
uuid.New(),
|
||||
func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
Directory: "/home/coder/project",
|
||||
}, nil
|
||||
},
|
||||
func(context.Context) (workspacesdk.AgentConn, error) {
|
||||
return nil, errChatHasNoWorkspaceAgent
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, instruction)
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContext_BindingFirstPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -590,6 +627,64 @@ func TestTurnWorkspaceContextGetWorkspaceConnLazyValidationSwitchesWorkspaceAgen
|
||||
require.Equal(t, currentAgent, gotAgent)
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContextGetWorkspaceConnFastFailsWithoutCurrentAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
staleAgentID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: staleAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
staleAgent := database.WorkspaceAgent{ID: staleAgentID}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), staleAgentID).
|
||||
Return(staleAgent, nil).
|
||||
Times(1)
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
||||
Return([]database.WorkspaceAgent{}, nil).
|
||||
Times(1)
|
||||
|
||||
server := &Server{db: db}
|
||||
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
return nil, nil, xerrors.New("dial failed")
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
defer workspaceCtx.close()
|
||||
|
||||
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
||||
require.Nil(t, gotConn)
|
||||
require.ErrorIs(t, err, errChatHasNoWorkspaceAgent)
|
||||
|
||||
workspaceCtx.mu.Lock()
|
||||
defer workspaceCtx.mu.Unlock()
|
||||
require.Equal(t, database.WorkspaceAgent{}, workspaceCtx.agent)
|
||||
require.False(t, workspaceCtx.agentLoaded)
|
||||
require.Nil(t, workspaceCtx.conn)
|
||||
require.Nil(t, workspaceCtx.releaseConn)
|
||||
require.Equal(t, uuid.NullUUID{}, workspaceCtx.cachedWorkspaceID)
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContext_SelectWorkspaceClearsCachedState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
@@ -2152,6 +2153,162 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include start_workspace tool output")
|
||||
}
|
||||
|
||||
func TestStoppedWorkspaceWithPersistedAgentBindingDoesNotBlockChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
var streamedCallCount atomic.Int32
|
||||
var streamedCallsMu sync.Mutex
|
||||
streamedCalls := make([][]chattest.OpenAIMessage, 0, 2)
|
||||
toolsByCall := make([][]string, 0, 2)
|
||||
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("Stopped workspace regression")
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(req.Tools))
|
||||
for _, tool := range req.Tools {
|
||||
names = append(names, tool.Function.Name)
|
||||
}
|
||||
|
||||
streamedCallsMu.Lock()
|
||||
streamedCalls = append(streamedCalls, append([]chattest.OpenAIMessage(nil), req.Messages...))
|
||||
toolsByCall = append(toolsByCall, names)
|
||||
streamedCallsMu.Unlock()
|
||||
|
||||
if streamedCallCount.Add(1) == 1 {
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk("execute", `{"command":"echo hi"}`),
|
||||
)
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("The workspace is unavailable. Start it before retrying workspace tools.")...,
|
||||
)
|
||||
})
|
||||
|
||||
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
|
||||
ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID)
|
||||
|
||||
inactive := newTestServer(t, db, ps, uuid.New())
|
||||
chat, err := inactive.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "stopped-workspace-regression",
|
||||
ModelConfigID: model.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("Run echo hi in the workspace."),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID)
|
||||
require.NoError(t, err)
|
||||
chat, err = db.UpdateChatBuildAgentBinding(ctx, database.UpdateChatBuildAgentBindingParams{
|
||||
ID: chat.ID,
|
||||
BuildID: uuid.NullUUID{UUID: build.ID, Valid: true},
|
||||
AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStop,
|
||||
BuildNumber: 2,
|
||||
}).Do()
|
||||
|
||||
var dialCalls atomic.Int32
|
||||
_ = newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
||||
cfg.AgentConn = func(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
dialCalls.Add(1)
|
||||
require.Equal(t, dbAgent.ID, agentID)
|
||||
<-ctx.Done()
|
||||
return nil, nil, ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
var chatResult database.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
if chatResult.Status == database.ChatStatusError {
|
||||
require.FailNowf(t, "chat failed", "last_error=%q", chatResult.LastError.String)
|
||||
}
|
||||
|
||||
require.EqualValues(t, 1, dialCalls.Load())
|
||||
require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2))
|
||||
|
||||
streamedCallsMu.Lock()
|
||||
recordedCalls := append([][]chattest.OpenAIMessage(nil), streamedCalls...)
|
||||
recordedTools := append([][]string(nil), toolsByCall...)
|
||||
streamedCallsMu.Unlock()
|
||||
require.GreaterOrEqual(t, len(recordedCalls), 2)
|
||||
require.NotEmpty(t, recordedTools)
|
||||
require.Contains(t, recordedTools[0], "execute")
|
||||
require.Contains(t, recordedTools[0], "start_workspace")
|
||||
|
||||
var foundUnavailableToolResult bool
|
||||
for _, message := range recordedCalls[1] {
|
||||
if message.Role != "tool" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(message.Content, "chat has no workspace agent") {
|
||||
foundUnavailableToolResult = true
|
||||
break
|
||||
}
|
||||
if !json.Valid([]byte(message.Content)) {
|
||||
continue
|
||||
}
|
||||
var toolResult map[string]any
|
||||
if err := json.Unmarshal([]byte(message.Content), &toolResult); err != nil {
|
||||
continue
|
||||
}
|
||||
errMsg, _ := toolResult["error"].(string)
|
||||
outputMsg, _ := toolResult["output"].(string)
|
||||
if strings.Contains(errMsg, "chat has no workspace agent") ||
|
||||
strings.Contains(outputMsg, "chat has no workspace agent") {
|
||||
foundUnavailableToolResult = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundUnavailableToolResult,
|
||||
"expected the second streamed model call to include the unavailable workspace tool result")
|
||||
|
||||
var toolMessage *database.ChatMessage
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
for i := range messages {
|
||||
if messages[i].Role == database.ChatMessageRoleTool {
|
||||
toolMessage = &messages[i]
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.IntervalFast)
|
||||
require.NotNil(t, toolMessage)
|
||||
|
||||
parts, err := chatprompt.ParseContent(*toolMessage)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, parts, 1)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeToolResult, parts[0].Type)
|
||||
require.Equal(t, "execute", parts[0].ToolName)
|
||||
require.True(t, parts[0].IsError)
|
||||
require.Contains(t, string(parts[0].Result), "chat has no workspace agent")
|
||||
}
|
||||
|
||||
func TestHeartbeatBumpsWorkspaceUsage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -123,6 +123,9 @@ func dialWithLazyValidation(
|
||||
validateBinding := func() (uuid.UUID, error) {
|
||||
validatedAgentID, err := validateFn(ctx, workspaceID)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, errChatHasNoWorkspaceAgent) {
|
||||
return uuid.Nil, errChatHasNoWorkspaceAgent
|
||||
}
|
||||
return uuid.Nil, wrapErr(err)
|
||||
}
|
||||
return validatedAgentID, nil
|
||||
@@ -150,12 +153,21 @@ func dialWithLazyValidation(
|
||||
return resolveFastFailure()
|
||||
|
||||
case <-timer.C:
|
||||
validatedAgentID, validationErr := validateFn(ctx, workspaceID)
|
||||
if validationErr != nil || validatedAgentID == agentID {
|
||||
validatedAgentID, validationErr := validateBinding()
|
||||
if validationErr != nil {
|
||||
if xerrors.Is(validationErr, errChatHasNoWorkspaceAgent) {
|
||||
dialCancel()
|
||||
return DialResult{}, validationErr
|
||||
}
|
||||
// Validation could not prove the binding was stale, so keep waiting on
|
||||
// the original dial.
|
||||
return waitForOriginalDial(ctx)
|
||||
}
|
||||
if validatedAgentID == agentID {
|
||||
// Validation confirmed the current binding, so keep waiting on the
|
||||
// original dial.
|
||||
return waitForOriginalDial(ctx)
|
||||
}
|
||||
// The original dial is stale. Cancel it first, then let the deferred drain
|
||||
// release any late result while we dial the validated agent immediately.
|
||||
dialCancel()
|
||||
|
||||
@@ -108,6 +108,55 @@ func TestDialWithLazyValidation_SlowDialSameAgent(t *testing.T) {
|
||||
require.EqualValues(t, 1, releaseCalls.Load())
|
||||
}
|
||||
|
||||
func TestDialWithLazyValidation_SlowDialNoCurrentAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
staleAgentID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
dialStarted := make(chan struct{})
|
||||
resultCh := make(chan error, 1)
|
||||
|
||||
var dialCalls atomic.Int32
|
||||
var validateCalls atomic.Int32
|
||||
|
||||
go func() {
|
||||
_, err := dialWithLazyValidation(
|
||||
context.Background(),
|
||||
staleAgentID,
|
||||
workspaceID,
|
||||
func(ctx context.Context, id uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
if id != staleAgentID {
|
||||
return nil, nil, xerrors.Errorf("unexpected agent ID %q", id)
|
||||
}
|
||||
dialCalls.Add(1)
|
||||
close(dialStarted)
|
||||
<-ctx.Done()
|
||||
return nil, nil, ctx.Err()
|
||||
},
|
||||
func(_ context.Context, id uuid.UUID) (uuid.UUID, error) {
|
||||
if id != workspaceID {
|
||||
return uuid.Nil, xerrors.Errorf("unexpected workspace ID %q", id)
|
||||
}
|
||||
<-dialStarted
|
||||
validateCalls.Add(1)
|
||||
return uuid.Nil, errChatHasNoWorkspaceAgent
|
||||
},
|
||||
0,
|
||||
)
|
||||
resultCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-resultCh:
|
||||
require.ErrorIs(t, err, errChatHasNoWorkspaceAgent)
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("dialWithLazyValidation blocked after validation reported no current agent")
|
||||
}
|
||||
|
||||
require.EqualValues(t, 1, dialCalls.Load())
|
||||
require.EqualValues(t, 1, validateCalls.Load())
|
||||
}
|
||||
|
||||
func TestDialWithLazyValidation_SlowDialStaleAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user