feat(coderd): add stop_workspace chatd tool and recovery classification (#24997)

## Summary

Adds a `stop_workspace` tool to chatd so the model can recover from the
"workspace running but agent dead" failure mode (e.g. an OOM that leaves
the workspace running but the agent unreachable) by stopping and then
starting the workspace.

<img width="924" height="742" alt="image"
src="https://github.com/user-attachments/assets/279dedb6-6e29-4fe1-8754-3a1f01e538bf"
/>



## What changed

**New `stop_workspace` chatd tool**
(`coderd/x/chatd/chattool/stopworkspace.go`). Mirrors `start_workspace`:
shares `WorkspaceMu` to serialize with create/start, waits for any
in-progress build before issuing a stop, and is idempotent only after a
successful Stop transition. Failed stop builds re-attempt rather than
reporting success.

**New `chatStopWorkspace` coderd hook** (`coderd/exp_chats.go`). Mirrors
`chatStartWorkspace` minus the `RequireActiveVersion` gate. Stop should
not be blocked by template version policy.

**Differentiated recovery sentinels** (`coderd/x/chatd/chatd.go`).
`errChatAgentDisconnected` instructs the model to call `stop_workspace`
then `start_workspace`. `errChatDialTimeout` instructs a single retry,
then user escalation if it repeats. The previous single message
conflated transient and persistent failures.

**Two-signal recovery gate.** Recovery is only surfaced when a tool call
times out *and* a fresh DB read of the latest workspace agent says
`Disconnected`. The previous draft escalated on the DB read alone, which
would fire on a 30-second heartbeat blip (e.g. agent respawn) and prompt
a destructive stop/start unnecessarily.

**Cache-hit disconnected handling** now clears the cache and retries a
fresh dial before escalating, rather than returning the recovery
sentinel immediately. Latest-agent classification uses
`GetWorkspaceAgentsInLatestBuildByWorkspaceID` instead of the chat's
bound `AgentID`, so stale bindings after a rebuild don't misclassify.

**Shared chattool helpers** in `coderd/x/chatd/chattool/chattool.go`:
`latestWorkspaceBuildAndJob`, `publishBuildBinding`,
`provisionerJobTerminal`. Applied to both `start_workspace` and
`stop_workspace`.

## Notes

- Reverts an earlier draft that widened `ask_user_question` to root
standard turns. Plan-mode-only behavior is restored.
- The `stop_workspace` tool currently renders via the generic chat
tool-call UI. A follow-up frontend PR will prettify the `stop_workspace`
tool and style it like the `start_workspace` tool.
- Never-connected (`Timeout` status) agents are intentionally excluded
from recovery. They indicate template or startup failure, not the
running-but-dead case this PR targets.

Closes CODAGT-315
This commit is contained in:
Ethan
2026-05-11 16:23:07 +10:00
committed by GitHub
parent cee504e8a0
commit bd6cc1aaf2
11 changed files with 1250 additions and 114 deletions
+1
View File
@@ -805,6 +805,7 @@ func New(options *Options) *API {
InstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
StopWorkspace: api.chatStopWorkspace,
Pubsub: options.Pubsub,
WebpushDispatcher: options.WebPushDispatcher,
UsageTracker: options.WorkspaceUsageTracker,
+47
View File
@@ -3740,6 +3740,53 @@ func (api *API) chatStartWorkspace(
return apiBuild, nil
}
// chatStopWorkspace stops a workspace by creating a new build with the
// "stop" transition. It mirrors chatStartWorkspace, without start-only
// active-version behavior.
func (api *API) chatStopWorkspace(
ctx context.Context,
ownerID uuid.UUID,
workspaceID uuid.UUID,
req codersdk.CreateWorkspaceBuildRequest,
) (codersdk.WorkspaceBuild, error) {
actor, _, err := httpmw.UserRBACSubject(ctx, api.Database, ownerID, rbac.ScopeAll)
if err != nil {
return codersdk.WorkspaceBuild{}, xerrors.Errorf("load user authorization: %w", err)
}
ctx = dbauthz.As(ctx, actor)
workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceID)
if err != nil {
return codersdk.WorkspaceBuild{}, xerrors.Errorf("get workspace: %w", err)
}
req.Transition = codersdk.WorkspaceTransitionStop
// Build a synthetic API key so postWorkspaceBuildsInternal can
// record the correct initiator.
syntheticKey := database.APIKey{
UserID: ownerID,
}
apiBuild, err := api.postWorkspaceBuildsInternal(
ctx,
syntheticKey,
workspace,
req,
func(action policy.Action, object rbac.Objecter) bool {
// Authorization is handled by dbauthz on the context.
authErr := api.HTTPAuth.Authorizer.Authorize(ctx, actor, action, object.RBACObject())
return authErr == nil
},
audit.WorkspaceBuildBaggage{},
)
if err != nil {
return codersdk.WorkspaceBuild{}, xerrors.Errorf("create workspace build: %w", err)
}
return apiBuild, nil
}
func rewriteChatStartWorkspaceManualUpdateResponse(resp codersdk.Response, fallbackDetail string, retryInstructions string) codersdk.Response {
originalMessage := resp.Message
resp.Message = retryInstructions
+37
View File
@@ -13310,6 +13310,43 @@ func TestChatStartWorkspace_RequireActiveVersion(t *testing.T) {
require.Nil(t, build.TemplateVersionPresetID, "no preset must be applied")
}
func TestChatStopWorkspace_BypassesRequireActiveVersion(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
rawClient, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{})
var store dbauthz.AccessControlStore = requireActiveVersionStore{}
api.AccessControlStore.Store(&store)
db := api.Database
user := coderdtest.CreateFirstUser(t, rawClient)
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: user.UserID,
OrganizationID: user.OrganizationID,
}).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStart,
}).Do()
v1ID := wsResp.Build.TemplateVersionID
tmplID := wsResp.Workspace.TemplateID
v2Resp := dbfake.TemplateVersion(t, db).Seed(database.TemplateVersion{
TemplateID: uuid.NullUUID{UUID: tmplID, Valid: true},
OrganizationID: user.OrganizationID,
CreatedBy: user.UserID,
}).Do()
v2 := v2Resp.TemplateVersion
require.NotEqual(t, v1ID, v2.ID, "v2 must differ from v1")
build, err := coderd.ChatStopWorkspace(api, ctx, user.UserID, wsResp.Workspace.ID,
codersdk.CreateWorkspaceBuildRequest{})
require.NoError(t, err)
require.Equal(t, codersdk.WorkspaceTransitionStop, build.Transition)
require.Equal(t, v1ID, build.TemplateVersionID,
"stop must not apply RequireActiveVersion start-only logic")
require.NotEqual(t, v2.ID, build.TemplateVersionID)
}
func TestGetChatMessages_Pagination(t *testing.T) {
t.Parallel()
+3
View File
@@ -11,3 +11,6 @@ var InsertAgentChatTestModelConfig = insertAgentChatTestModelConfig
// stubbing the entire DB layer. The proper fix is to extract a pure
// request builder; tracked in CODAGT-292.
var ChatStartWorkspace = (*API).chatStartWorkspace
// ChatStopWorkspace exposes chatStopWorkspace for external tests.
var ChatStopWorkspace = (*API).chatStopWorkspace
+77 -6
View File
@@ -121,6 +121,13 @@ const (
// streamJanitorInterval.
streamJanitorInterval = 30 * time.Second
// agentDisconnectedRecoveryThreshold is how long the latest
// workspace agent must be disconnected before chatd suggests
// destructive stop/start recovery. This is intentionally longer
// than the inactive-disconnect timeout so short heartbeat gaps do
// not prompt a workspace restart.
agentDisconnectedRecoveryThreshold = 90 * time.Second
// DefaultMaxChatsPerAcquire is the maximum number of chats to
// acquire in a single processOnce call. Batching avoids
// waiting a full polling interval between acquisitions
@@ -139,12 +146,14 @@ const (
var (
errChatHasNoWorkspaceAgent = xerrors.New("workspace has no running agent: the workspace is likely stopped. Use the start_workspace tool to start it")
errChatAgentDisconnected = xerrors.New(
"workspace agent is disconnected and cannot execute tools. " +
"The workspace may need to be restarted from the Coder dashboard",
"workspace agent has been disconnected for at least 90 seconds " +
"and cannot execute tools. To recover, call stop_workspace " +
"to stop the workspace, then start_workspace to start it " +
"again",
)
errChatDialTimeout = xerrors.New(
"connection to the workspace agent timed out. " +
"The workspace may need to be restarted from the Coder dashboard",
"The agent may still be reachable on the next attempt.",
)
errChatExternalAgentUnavailable = xerrors.New("external workspace agent unavailable")
)
@@ -187,6 +196,7 @@ type Server struct {
instructionLookupTimeout time.Duration
createWorkspaceFn chattool.CreateWorkspaceFn
startWorkspaceFn chattool.StartWorkspaceFn
stopWorkspaceFn chattool.StopWorkspaceFn
pubsub pubsub.Pubsub
webpushDispatcher webpush.Dispatcher
providerAPIKeys chatprovider.ProviderAPIKeys
@@ -786,6 +796,45 @@ func isAgentUnreachable(now time.Time, agent database.WorkspaceAgent, inactiveTi
status.Status == database.WorkspaceAgentStatusTimeout
}
func agentDisconnectedFor(now time.Time, agent database.WorkspaceAgent, inactiveTimeout time.Duration) (time.Duration, bool) {
status := agent.Status(now, inactiveTimeout)
if status.Status != database.WorkspaceAgentStatusDisconnected || status.DisconnectedAt == nil {
return 0, false
}
disconnectedFor := now.Sub(*status.DisconnectedAt)
if disconnectedFor < 0 {
disconnectedFor = 0
}
return disconnectedFor, true
}
func (c *turnWorkspaceContext) latestWorkspaceAgentNeedsRestart(
ctx context.Context,
workspaceID uuid.UUID,
) (bool, error) {
agentID, err := c.latestWorkspaceAgentID(ctx, workspaceID)
if err != nil {
if xerrors.Is(err, errChatHasNoWorkspaceAgent) {
return false, err
}
c.server.logger.Warn(ctx, "failed to resolve latest agent for timeout classification", slog.Error(err))
return false, nil
}
agent, err := c.server.db.GetWorkspaceAgentByID(ctx, agentID)
if err != nil {
c.server.logger.Warn(ctx, "failed to load latest agent for timeout classification",
slog.F("agent_id", agentID),
slog.Error(err),
)
return false, nil
}
disconnectedFor, disconnected := agentDisconnectedFor(c.server.clock.Now(), agent, c.server.agentInactiveDisconnectTimeout)
return disconnected && disconnectedFor >= agentDisconnectedRecoveryThreshold, nil
}
func (c *turnWorkspaceContext) externalAgentError(
ctx context.Context,
agent database.WorkspaceAgent,
@@ -853,9 +902,13 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces
)
// On DB error the check re-runs on the
// next tool call.
} else if isAgentUnreachable(c.server.clock.Now(), freshAgent, c.server.agentInactiveDisconnectTimeout) {
} else if _, disconnected := agentDisconnectedFor(
c.server.clock.Now(),
freshAgent,
c.server.agentInactiveDisconnectTimeout,
); disconnected {
c.clearCachedWorkspaceState()
return nil, c.externalAgentError(ctx, freshAgent, errChatAgentDisconnected)
continue
}
}
return currentConn, nil
@@ -898,6 +951,14 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces
// canceled (e.g. ErrInterrupted), its error must
// propagate unchanged so the chatloop can detect it.
if ctx.Err() == nil && errors.Is(context.Cause(dialCtx), errChatDialTimeout) {
c.clearCachedWorkspaceState()
needsRestart, statusErr := c.latestWorkspaceAgentNeedsRestart(ctx, chatSnapshot.WorkspaceID.UUID)
if statusErr != nil {
return nil, statusErr
}
if needsRestart {
return nil, c.externalAgentError(ctx, agent, errChatAgentDisconnected)
}
return nil, c.externalAgentError(ctx, agent, errChatDialTimeout)
}
return nil, err
@@ -3752,6 +3813,7 @@ type Config struct {
InstructionLookupTimeout time.Duration
CreateWorkspace chattool.CreateWorkspaceFn
StartWorkspace chattool.StartWorkspaceFn
StopWorkspace chattool.StopWorkspaceFn
Pubsub pubsub.Pubsub
ProviderAPIKeys chatprovider.ProviderAPIKeys
AlwaysEnableDebugLogs bool
@@ -3820,6 +3882,7 @@ func New(cfg Config) *Server {
instructionLookupTimeout: instructionLookupTimeout,
createWorkspaceFn: cfg.CreateWorkspace,
startWorkspaceFn: cfg.StartWorkspace,
stopWorkspaceFn: cfg.StopWorkspace,
pubsub: cfg.Pubsub,
webpushDispatcher: cfg.WebpushDispatcher,
providerAPIKeys: cfg.ProviderAPIKeys,
@@ -5899,7 +5962,7 @@ func builtinPlanToolAllowed(name string, isRootChat bool) bool {
case "read_file", "execute", "process_output", "read_skill", "read_skill_file":
return true
case "write_file", "edit_files", "list_templates", "read_template",
"create_workspace", "start_workspace", "propose_plan", "spawn_agent",
"create_workspace", "start_workspace", "stop_workspace", "propose_plan", "spawn_agent",
"spawn_explore_agent", "wait_agent", "ask_user_question":
return isRootChat
case "process_list", "process_signal", "message_agent", "close_agent",
@@ -5979,6 +6042,7 @@ func allowedExploreToolNames(allTools []fantasy.AgentTool) []string {
"read_template": false,
"create_workspace": false,
"start_workspace": false,
"stop_workspace": false,
"propose_plan": false,
"spawn_agent": false,
"wait_agent": false,
@@ -6198,6 +6262,13 @@ func (p *Server) appendRootChatTools(
OnChatUpdated: onChatUpdated,
Logger: p.logger,
}),
chattool.StopWorkspace(p.db, opts.chat.ID, chattool.StopWorkspaceOptions{
OwnerID: opts.chat.OwnerID,
StopFn: p.stopWorkspaceFn,
WorkspaceMu: opts.workspaceMu,
OnChatUpdated: onChatUpdated,
Logger: p.logger,
}),
)
if opts.isPlanModeTurn {
tools = append(tools, chattool.ProposePlan(chattool.ProposePlanOptions{
+376 -57
View File
@@ -295,6 +295,26 @@ func TestFilterExternalMCPConfigsForTurn(t *testing.T) {
})
}
func TestChatWorkspaceRecoveryErrorsDifferentiateSignalStrength(t *testing.T) {
t.Parallel()
// Disconnected recovery is gated by a DB-confirmed duration
// threshold, so the message can give direct stop/start guidance
// without asking the user.
disconnected := errChatAgentDisconnected.Error()
require.Contains(t, disconnected, "90 seconds")
require.Contains(t, disconnected, "stop_workspace")
require.Contains(t, disconnected, "start_workspace")
require.NotContains(t, disconnected, "ask_user_question")
// Dial timeout alone is a weak signal. The model should not
// escalate to lifecycle tools without DB-confirmed disconnect.
dialTimeout := errChatDialTimeout.Error()
require.NotContains(t, dialTimeout, "ask_user_question")
require.NotContains(t, dialTimeout, "stop_workspace")
require.NotContains(t, dialTimeout, "start_workspace")
}
func TestActiveToolNamesForTurn(t *testing.T) {
t.Parallel()
@@ -344,6 +364,7 @@ func TestActiveToolNamesForTurn(t *testing.T) {
"read_template",
"create_workspace",
"start_workspace",
"stop_workspace",
"propose_plan",
"spawn_agent",
"wait_agent",
@@ -364,6 +385,7 @@ func TestActiveToolNamesForTurn(t *testing.T) {
"read_template",
"create_workspace",
"start_workspace",
"stop_workspace",
"propose_plan",
"spawn_agent",
"wait_agent",
@@ -386,6 +408,7 @@ func TestActiveToolNamesForTurn(t *testing.T) {
"read_template",
"create_workspace",
"start_workspace",
"stop_workspace",
"propose_plan",
"spawn_agent",
"wait_agent",
@@ -405,6 +428,8 @@ func TestActiveToolNamesForTurn(t *testing.T) {
require.NotContains(t, got, "edit_files")
require.NotContains(t, got, "ask_user_question")
require.NotContains(t, got, "propose_plan")
require.NotContains(t, got, "start_workspace")
require.NotContains(t, got, "stop_workspace")
require.NotContains(t, got, "spawn_explore_agent")
})
@@ -474,6 +499,8 @@ func TestAllowedExploreToolNames(t *testing.T) {
newTestAgentTool("write_file"),
newTestMCPAgentTool("external-mcp__echo", externalConfigID),
newTestAgentTool("workspace-mcp__echo"),
newTestAgentTool("start_workspace"),
newTestAgentTool("stop_workspace"),
newTestAgentTool("execute"),
newTestAgentTool("process_output"),
newTestAgentTool("process_list"),
@@ -494,6 +521,9 @@ func TestAllowedExploreToolNames(t *testing.T) {
"read_skill_file",
}, got)
require.NotContains(t, got, "workspace-mcp__echo")
require.NotContains(t, got, "start_workspace")
require.NotContains(t, got, "stop_workspace")
require.NotContains(t, got, "ask_user_question")
}
func TestAllowedBehaviorToolNames(t *testing.T) {
@@ -580,19 +610,13 @@ func TestStopAfterBehaviorTools(t *testing.T) {
))
})
t.Run("RootPlanModeIncludesClarificationTool", func(t *testing.T) {
t.Run("PlanModeDelegatesToPlanTools", func(t *testing.T) {
t.Parallel()
require.Equal(t, map[string]struct{}{
"propose_plan": {},
"ask_user_question": {},
}, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{}))
})
t.Run("ChildPlanModeSkipsClarificationTool", func(t *testing.T) {
t.Parallel()
require.Equal(t, map[string]struct{}{
"propose_plan": {},
}, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{UUID: uuid.New(), Valid: true}))
require.Equal(t, stopAfterPlanTools(planMode, uuid.NullUUID{}), stopAfterBehaviorTools(
planMode,
database.NullChatMode{},
uuid.NullUUID{},
))
})
t.Run("ExploreModeReturnsNil", func(t *testing.T) {
@@ -4229,46 +4253,28 @@ func TestGetWorkspaceConn_SameBuildAgentCrash(t *testing.T) {
func TestGetWorkspaceConn_StatusCheck(t *testing.T) {
// The cache-hit status check re-fetches the agent row for a fresh
// heartbeat timestamp. These tests verify that path detects
// disconnected or timed-out agents and that healthy or DB-error
// paths return the cached connection.
// heartbeat timestamp. Healthy, timed-out, and DB-error paths return
// the cached connection. Disconnected agents are covered separately
// because they now trigger a fresh dial before recovery.
t.Parallel()
type testCase struct {
name string
agent database.WorkspaceAgent
dbError bool
wantErr error
wantReleaseCalled bool
name string
agent database.WorkspaceAgent
dbError bool
}
tests := []testCase{
{
name: "DisconnectedAgentCacheHit",
agent: database.WorkspaceAgent{
FirstConnectedAt: sql.NullTime{
Time: time.Now().Add(-10 * time.Minute),
Valid: true,
},
LastConnectedAt: sql.NullTime{
Time: time.Now().Add(-10 * time.Minute),
Valid: true,
},
},
wantErr: errChatAgentDisconnected,
wantReleaseCalled: true,
},
{
// Agent never connected and the connection timeout
// has elapsed. This is the cache-hit timeout branch
// of isAgentUnreachable.
// has elapsed. This should not trigger lifecycle
// recovery because the agent did not connect and
// then disconnect.
name: "TimedOutAgentCacheHit",
agent: database.WorkspaceAgent{
CreatedAt: time.Now().Add(-10 * time.Minute),
ConnectionTimeoutSeconds: 60,
},
wantErr: errChatAgentDisconnected,
wantReleaseCalled: true,
},
{
name: "CacheHitHealthyAgent",
@@ -4371,28 +4377,274 @@ func TestGetWorkspaceConn_StatusCheck(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
if tc.wantErr != nil {
require.Nil(t, gotConn)
require.ErrorIs(t, err, tc.wantErr)
} else {
require.NoError(t, err)
require.Same(t, cachedConn, gotConn)
}
require.Equal(t, tc.wantReleaseCalled, releaseCalled, "release called")
// For cache-hit disconnect, the cache should be cleared.
if tc.wantErr != nil {
workspaceCtx.mu.Lock()
defer workspaceCtx.mu.Unlock()
require.False(t, workspaceCtx.agentLoaded)
require.Nil(t, workspaceCtx.conn)
}
require.NoError(t, err)
require.Same(t, cachedConn, gotConn)
require.False(t, releaseCalled, "release called")
})
}
}
func TestGetWorkspaceConn_DialTimeoutDisconnectedRecoveryThreshold(t *testing.T) {
// The recovery sentinel requires a failed dial and a fresh
// disconnected status check past the recovery threshold. A
// disconnected DB row alone is not enough to trigger stop/start
// recovery.
t.Parallel()
testCases := []struct {
name string
disconnectedFor time.Duration
wantErr error
wantRecovery bool
}{
{
name: "RecentDisconnectReturnsDialTimeout",
disconnectedFor: agentDisconnectedRecoveryThreshold / 2,
wantErr: errChatDialTimeout,
wantRecovery: false,
},
{
name: "PastThresholdEscalates",
disconnectedFor: agentDisconnectedRecoveryThreshold,
wantErr: errChatAgentDisconnected,
wantRecovery: true,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
clock := quartz.NewMock(t)
now := clock.Now()
disconnectedAgent := database.WorkspaceAgent{
ID: agentID,
FirstConnectedAt: sql.NullTime{
Time: now.Add(-10 * time.Minute),
Valid: true,
},
LastConnectedAt: sql.NullTime{
Time: now.Add(-10 * time.Minute),
Valid: true,
},
DisconnectedAt: sql.NullTime{
Time: now.Add(-tc.disconnectedFor),
Valid: true,
},
}
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
Return(disconnectedAgent, nil).
Times(2)
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{disconnectedAgent}, nil).
Times(1)
server := &Server{
db: db,
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
clock: clock,
agentInactiveDisconnectTimeout: 30 * time.Second,
dialTimeout: 10 * time.Millisecond,
}
server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
<-ctx.Done()
return nil, nil, ctx.Err()
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
defer workspaceCtx.close()
ctx := testutil.Context(t, testutil.WaitShort)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.Nil(t, gotConn)
require.ErrorIs(t, err, tc.wantErr)
if tc.wantRecovery {
require.ErrorIs(t, err, errChatAgentDisconnected)
} else {
require.NotErrorIs(t, err, errChatAgentDisconnected)
}
workspaceCtx.mu.Lock()
defer workspaceCtx.mu.Unlock()
require.False(t, workspaceCtx.agentLoaded)
require.Nil(t, workspaceCtx.conn)
})
}
}
func TestGetWorkspaceConn_DisconnectedStatusDialSuccessDoesNotEscalate(t *testing.T) {
// A stale disconnected row must not prompt stop/start if the
// agent can still be dialed successfully.
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
disconnectedAgent := database.WorkspaceAgent{
ID: agentID,
FirstConnectedAt: sql.NullTime{
Time: time.Now().Add(-10 * time.Minute),
Valid: true,
},
LastConnectedAt: sql.NullTime{
Time: time.Now().Add(-10 * time.Minute),
Valid: true,
},
}
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
Return(disconnectedAgent, nil).
Times(1)
server := &Server{
db: db,
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
clock: quartz.NewReal(),
agentInactiveDisconnectTimeout: 30 * time.Second,
dialTimeout: 10 * time.Millisecond,
}
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
var dialCalled bool
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialCalled = true
return conn, nil, nil
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
defer workspaceCtx.close()
ctx := testutil.Context(t, testutil.WaitShort)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.NoError(t, err)
require.Same(t, conn, gotConn)
require.True(t, dialCalled, "dial called")
}
func TestGetWorkspaceConn_CacheHitDisconnectedRetriesDialBeforeEscalating(t *testing.T) {
// A disconnected cached connection is discarded first. Recovery is
// only surfaced if the replacement dial also times out.
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
disconnectedAgent := database.WorkspaceAgent{
ID: agentID,
FirstConnectedAt: sql.NullTime{
Time: time.Now().Add(-10 * time.Minute),
Valid: true,
},
LastConnectedAt: sql.NullTime{
Time: time.Now().Add(-10 * time.Minute),
Valid: true,
},
}
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
Return(disconnectedAgent, nil).
Times(2)
server := &Server{
db: db,
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
clock: quartz.NewReal(),
agentInactiveDisconnectTimeout: 30 * time.Second,
dialTimeout: 10 * time.Millisecond,
}
newConn := agentconnmock.NewMockAgentConn(ctrl)
newConn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
var dialCalled bool
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialCalled = true
return newConn, nil, nil
}
var releaseCalled bool
chatStateMu := &sync.Mutex{}
currentChat := chat
oldConn := agentconnmock.NewMockAgentConn(ctrl)
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
agent: disconnectedAgent,
agentLoaded: true,
conn: oldConn,
releaseConn: func() { releaseCalled = true },
cachedWorkspaceID: chat.WorkspaceID,
}
defer workspaceCtx.close()
ctx := testutil.Context(t, testutil.WaitShort)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.NoError(t, err)
require.Same(t, newConn, gotConn)
require.True(t, releaseCalled, "release called")
require.True(t, dialCalled, "dial called")
}
func TestGetWorkspaceConn_DialTimeout(t *testing.T) {
// When dialWithLazyValidation blocks beyond the dial
// timeout, getWorkspaceConn should return
@@ -4431,6 +4683,9 @@ func TestGetWorkspaceConn_DialTimeout(t *testing.T) {
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
Return(connectedAgent, nil).
Times(2)
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{connectedAgent}, nil).
Times(1)
server := &Server{
@@ -4461,6 +4716,70 @@ func TestGetWorkspaceConn_DialTimeout(t *testing.T) {
require.ErrorIs(t, err, errChatDialTimeout)
}
func TestGetWorkspaceConn_DialTimeoutStatusTimeoutDoesNotEscalate(t *testing.T) {
// Agents that never connected are startup failures, not
// disconnected recovery cases. A dial timeout should stay a
// retry/escalation error rather than stop/start guidance.
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
agentID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
AgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}
timedOutAgent := database.WorkspaceAgent{
ID: agentID,
CreatedAt: time.Now().Add(-10 * time.Minute),
ConnectionTimeoutSeconds: 60,
}
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
Return(timedOutAgent, nil).
Times(2)
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{timedOutAgent}, nil).
Times(1)
server := &Server{
db: db,
clock: quartz.NewReal(),
agentInactiveDisconnectTimeout: 30 * time.Second,
dialTimeout: 10 * time.Millisecond,
}
server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
<-ctx.Done()
return nil, nil, ctx.Err()
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
defer workspaceCtx.close()
ctx := testutil.Context(t, testutil.WaitShort)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.Nil(t, gotConn)
require.ErrorIs(t, err, errChatDialTimeout)
require.NotErrorIs(t, err, errChatAgentDisconnected)
}
func TestGetWorkspaceConn_DialTimeoutParentCanceled(t *testing.T) {
// When the parent context is canceled, the parent's error
// must propagate unchanged (not wrapped as a dial timeout).
+11 -3
View File
@@ -344,7 +344,10 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
require.GreaterOrEqual(t, len(recorded), 2,
"expected at least 2 streamed LLM calls (root + subagent)")
workspaceTools := []string{"list_templates", "read_template", "create_workspace"}
workspaceTools := []string{
"list_templates", "read_template", "create_workspace",
"start_workspace", "stop_workspace",
}
subagentTools := []string{"spawn_agent", "wait_agent", "message_agent", "close_agent"}
// Identify root and subagent calls. Root chat calls include
@@ -375,7 +378,10 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
"root chat should have subagent tool %q", tool)
}
// Standard turns (no turn mode) should hide propose_plan.
// Standard turns (no turn mode) hide plan-only tools until
// plan mode.
require.NotContains(t, rootCalls[0], "ask_user_question",
"standard-turn root chat should NOT have ask_user_question")
require.NotContains(t, rootCalls[0], "propose_plan",
"standard-turn root chat should NOT have propose_plan")
@@ -388,6 +394,8 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
require.NotContains(t, childCalls[0], tool,
"subagent chat should NOT have subagent tool %q", tool)
}
require.NotContains(t, childCalls[0], "ask_user_question",
"subagent chat should NOT have ask_user_question")
}
func TestPlanModeSubagentChatExcludesAskUserQuestion(t *testing.T) {
@@ -7030,7 +7038,7 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) {
// 4. Verify workspace provisioning tools are NOT present.
workspaceProvisioningTools := []string{
"list_templates", "read_template",
"create_workspace", "start_workspace",
"create_workspace", "start_workspace", "stop_workspace",
}
for _, tool := range workspaceProvisioningTools {
require.NotContains(t, childTools, tool,
+63
View File
@@ -1,12 +1,16 @@
package chattool
import (
"context"
"encoding/json"
"unicode/utf8"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk"
)
@@ -54,6 +58,65 @@ func responseErrorResult(resp codersdk.Response) map[string]any {
return result
}
func latestWorkspaceBuildAndJob(
ctx context.Context,
db database.Store,
workspaceID uuid.UUID,
) (database.WorkspaceBuild, database.ProvisionerJob, error) {
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID)
if err != nil {
return database.WorkspaceBuild{}, database.ProvisionerJob{}, xerrors.Errorf("get latest build: %w", err)
}
job, err := db.GetProvisionerJobByID(ctx, build.JobID)
if err != nil {
return database.WorkspaceBuild{}, database.ProvisionerJob{}, xerrors.Errorf("get provisioner job: %w", err)
}
return build, job, nil
}
func publishBuildBinding(
ctx context.Context,
db database.Store,
logger slog.Logger,
chatID uuid.UUID,
workspaceID uuid.UUID,
buildID uuid.UUID,
onChatUpdated func(database.Chat),
) {
updatedChat, bindErr := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
ID: chatID,
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
BuildID: uuid.NullUUID{
UUID: buildID,
Valid: buildID != uuid.Nil,
},
AgentID: uuid.NullUUID{},
})
if bindErr != nil {
logger.Error(ctx, "failed to persist build ID on chat binding",
slog.F("chat_id", chatID),
slog.F("build_id", buildID),
slog.Error(bindErr),
)
return
}
if onChatUpdated != nil {
onChatUpdated(updatedChat)
}
}
func provisionerJobTerminal(status database.ProvisionerJobStatus) bool {
switch status {
case database.ProvisionerJobStatusSucceeded,
database.ProvisionerJobStatusFailed,
database.ProvisionerJobStatusCanceled:
return true
default:
return false
}
}
func truncateRunes(value string, maxLen int) string {
if maxLen <= 0 || value == "" {
return ""
+5 -48
View File
@@ -56,7 +56,7 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO
return fantasy.NewTextErrorResponse("workspace starter is not configured"), nil
}
// Serialize with create_workspace to prevent races.
// Serialize with create_workspace and stop_workspace to prevent races.
if options.WorkspaceMu != nil {
options.WorkspaceMu.Lock()
defer options.WorkspaceMu.Unlock()
@@ -86,18 +86,9 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO
), nil
}
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID)
build, job, err := latestWorkspaceBuildAndJob(ctx, db, ws.ID)
if err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("get latest build: %w", err).Error(),
), nil
}
job, err := db.GetProvisionerJobByID(ctx, build.JobID)
if err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("get provisioner job: %w", err).Error(),
), nil
return fantasy.NewTextErrorResponse(err.Error()), nil
}
// If a build is already in progress, wait for it.
@@ -106,24 +97,7 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO
database.ProvisionerJobStatusRunning:
// Publish the build ID to the frontend so it
// can start streaming logs immediately.
updatedChat, bindErr := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
ID: chatID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
BuildID: uuid.NullUUID{
UUID: build.ID,
Valid: build.ID != uuid.Nil,
},
AgentID: uuid.NullUUID{},
})
if bindErr != nil {
options.Logger.Error(ctx, "failed to persist build ID on chat binding",
slog.F("chat_id", chatID),
slog.F("build_id", build.ID),
slog.Error(bindErr),
)
} else if options.OnChatUpdated != nil {
options.OnChatUpdated(updatedChat)
}
publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, build.ID, options.OnChatUpdated)
if err := waitForBuild(ctx, db, build.ID); err != nil {
// newBuildError returns via toolResponse (IsError: false)
// rather than NewTextErrorResponse (IsError: true) so the
@@ -199,24 +173,7 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO
// Persist the build ID on the chat binding so the
// frontend can stream logs without polling.
updatedChat, bindErr := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
ID: chatID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
BuildID: uuid.NullUUID{
UUID: startBuild.ID,
Valid: startBuild.ID != uuid.Nil,
},
AgentID: uuid.NullUUID{},
})
if bindErr != nil {
options.Logger.Error(ctx, "failed to persist build ID on chat binding",
slog.F("chat_id", chatID),
slog.F("build_id", startBuild.ID),
slog.Error(bindErr),
)
} else if options.OnChatUpdated != nil {
options.OnChatUpdated(updatedChat)
}
publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, startBuild.ID, options.OnChatUpdated)
if err := waitForBuild(ctx, db, startBuild.ID); err != nil {
return buildFailureToolResponse(
ctx,
+181
View File
@@ -0,0 +1,181 @@
package chattool
import (
"context"
"sync"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/httpapi/httperror"
"github.com/coder/coder/v2/codersdk"
)
// StopWorkspaceFn stops a workspace by creating a new build with
// the "stop" transition.
type StopWorkspaceFn func(
ctx context.Context,
ownerID uuid.UUID,
workspaceID uuid.UUID,
req codersdk.CreateWorkspaceBuildRequest,
) (codersdk.WorkspaceBuild, error)
// StopWorkspaceOptions configures the stop_workspace tool.
type StopWorkspaceOptions struct {
OwnerID uuid.UUID
StopFn StopWorkspaceFn
WorkspaceMu *sync.Mutex
OnChatUpdated func(database.Chat)
Logger slog.Logger
}
type stopWorkspaceArgs struct{}
// StopWorkspace returns a tool that stops the workspace associated
// with the current chat. The tool is idempotent when the workspace is
// already stopped. db must not be nil and chatID must not be uuid.Nil.
func StopWorkspace(db database.Store, chatID uuid.UUID, options StopWorkspaceOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"stop_workspace",
"Stop the chat's workspace and wait for the stop build to complete. "+
"If another workspace build is already in progress, this waits "+
"for that build first, then stops the workspace if needed. "+
"After waiting, this tool is idempotent if the workspace is "+
"already stopped or the in-progress build stopped it. Use "+
"this when the "+
"user explicitly asks to stop the workspace, or when a "+
"workspace-agent error tells you to stop and then start the "+
"workspace. Stopping a workspace terminates running processes "+
"and may discard unsaved in-memory state. This tool does not "+
"delete the workspace.",
func(ctx context.Context, _ stopWorkspaceArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.StopFn == nil {
return fantasy.NewTextErrorResponse("workspace stopper is not configured"), nil
}
// Serialize with create_workspace and start_workspace to
// prevent lifecycle races.
if options.WorkspaceMu != nil {
options.WorkspaceMu.Lock()
defer options.WorkspaceMu.Unlock()
}
chat, err := db.GetChatByID(ctx, chatID)
if err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("load chat: %w", err).Error(),
), nil
}
if !chat.WorkspaceID.Valid {
return fantasy.NewTextErrorResponse(
"chat has no workspace; use create_workspace first",
), nil
}
ws, err := db.GetWorkspaceByID(ctx, chat.WorkspaceID.UUID)
if err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("load workspace: %w", err).Error(),
), nil
}
if ws.Deleted {
return fantasy.NewTextErrorResponse(
"workspace was deleted; use create_workspace to make a new one",
), nil
}
build, job, err := latestWorkspaceBuildAndJob(ctx, db, ws.ID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
// If a build is already in progress, wait for it before
// deciding whether a stop build is still needed.
switch job.JobStatus {
case database.ProvisionerJobStatusPending,
database.ProvisionerJobStatusRunning,
database.ProvisionerJobStatusCanceling:
publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, build.ID, options.OnChatUpdated)
waitErr := waitForBuild(ctx, db, build.ID)
// Re-read after waiting because another transition may
// have completed while this tool was blocked.
ws, err = db.GetWorkspaceByID(ctx, ws.ID)
if err != nil {
return fantasy.NewTextErrorResponse(
xerrors.Errorf("load workspace: %w", err).Error(),
), nil
}
if ws.Deleted {
return fantasy.NewTextErrorResponse(
"workspace was deleted; use create_workspace to make a new one",
), nil
}
build, job, err = latestWorkspaceBuildAndJob(ctx, db, ws.ID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
// The fresh job row is authoritative. A wait error can
// be stale if the build reached a terminal state while the
// wait context was ending.
if waitErr != nil && !provisionerJobTerminal(job.JobStatus) {
return buildToolResponse(newBuildError(
xerrors.Errorf("waiting for in-progress build: %w", waitErr).Error(),
build.ID,
)), nil
}
}
if job.JobStatus == database.ProvisionerJobStatusSucceeded &&
build.Transition == database.WorkspaceTransitionStop {
result := map[string]any{
"stopped": true,
"workspace_name": ws.Name,
}
setNoBuild(result, uuid.Nil)
return toolResponse(result), nil
}
ownerCtx, ownerErr := asOwner(ctx, db, options.OwnerID)
if ownerErr != nil {
return fantasy.NewTextErrorResponse(ownerErr.Error()), nil
}
stopBuild, err := options.StopFn(ownerCtx, options.OwnerID, ws.ID, codersdk.CreateWorkspaceBuildRequest{
Transition: codersdk.WorkspaceTransitionStop,
})
if err != nil {
if responseErr, ok := httperror.IsResponder(err); ok {
_, resp := responseErr.Response()
return toolResponse(responseErrorResult(resp)), nil
}
return fantasy.NewTextErrorResponse(
xerrors.Errorf("stop workspace: %w", err).Error(),
), nil
}
publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, stopBuild.ID, options.OnChatUpdated)
if err := waitForBuild(ctx, db, stopBuild.ID); err != nil {
return buildToolResponse(newBuildError(
xerrors.Errorf("workspace stop build failed: %w", err).Error(),
stopBuild.ID,
)), nil
}
if options.OnChatUpdated != nil {
if latest, err := db.GetChatByID(ctx, chatID); err == nil {
options.OnChatUpdated(latest)
}
}
result := map[string]any{
"stopped": true,
"workspace_name": ws.Name,
}
setBuildID(result, stopBuild.ID)
return toolResponse(result), nil
})
}
@@ -0,0 +1,449 @@
package chattool_test
import (
"context"
"database/sql"
"encoding/json"
"sync"
"sync/atomic"
"testing"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
func TestStopWorkspace(t *testing.T) {
t.Parallel()
t.Run("NoWorkspace", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
modelCfg := seedModelConfig(t, db)
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
chat := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
OwnerID: user.ID,
LastModelConfigID: modelCfg.ID,
Title: "test-stop-no-workspace",
})
tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{
StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
t.Fatal("StopFn should not be called")
return codersdk.WorkspaceBuild{}, nil
},
WorkspaceMu: &sync.Mutex{},
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
require.NoError(t, err)
require.Contains(t, resp.Content, "use create_workspace first")
})
t.Run("DeletedWorkspace", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
modelCfg := seedModelConfig(t, db)
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: user.ID,
OrganizationID: org.ID,
Deleted: true,
}).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionDelete,
}).Do()
ws := wsResp.Workspace
chat := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
LastModelConfigID: modelCfg.ID,
Title: "test-stop-deleted-workspace",
})
tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{
StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
t.Fatal("StopFn should not be called for deleted workspace")
return codersdk.WorkspaceBuild{}, nil
},
WorkspaceMu: &sync.Mutex{},
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
require.NoError(t, err)
require.Contains(t, resp.Content, "workspace was deleted")
require.Contains(t, resp.Content, "create_workspace")
})
t.Run("AlreadyStopped", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
modelCfg := seedModelConfig(t, db)
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: user.ID,
OrganizationID: org.ID,
}).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStop,
}).Do()
ws := wsResp.Workspace
chat := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
LastModelConfigID: modelCfg.ID,
Title: "test-stop-already-stopped",
})
tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{
OwnerID: user.ID,
StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
t.Fatal("StopFn should not be called for already-stopped workspace")
return codersdk.WorkspaceBuild{}, nil
},
WorkspaceMu: &sync.Mutex{},
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
require.Equal(t, true, result["stopped"])
require.Equal(t, ws.Name, result["workspace_name"])
require.Equal(t, true, result["no_build"])
require.Nil(t, result["build_id"])
})
t.Run("RunningWorkspaceStops", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
modelCfg := seedModelConfig(t, db)
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: user.ID,
OrganizationID: org.ID,
}).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStart,
}).Do()
ws := wsResp.Workspace
chat := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
LastModelConfigID: modelCfg.ID,
Title: "test-stop-running-workspace",
})
var stopCalled atomic.Bool
var stopBuildID uuid.UUID
var seenBuildID uuid.UUID
var onChatUpdatedCalls atomic.Int32
tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{
OwnerID: user.ID,
StopFn: func(_ context.Context, _ uuid.UUID, wsID uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
stopCalled.Store(true)
require.Equal(t, ws.ID, wsID)
require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition)
buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStop,
BuildNumber: 2,
}).Do()
stopBuildID = buildResp.Build.ID
return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil
},
WorkspaceMu: &sync.Mutex{},
OnChatUpdated: func(chat database.Chat) {
onChatUpdatedCalls.Add(1)
if chat.BuildID.Valid {
seenBuildID = chat.BuildID.UUID
}
},
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
require.NoError(t, err)
require.True(t, stopCalled.Load())
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
require.Equal(t, true, result["stopped"])
require.Equal(t, ws.Name, result["workspace_name"])
require.Equal(t, stopBuildID.String(), result["build_id"])
require.Nil(t, result["no_build"])
require.GreaterOrEqual(t, onChatUpdatedCalls.Load(), int32(1))
require.Equal(t, stopBuildID, seenBuildID)
updatedChat, err := db.GetChatByID(ctx, chat.ID)
require.NoError(t, err)
require.True(t, updatedChat.BuildID.Valid)
require.Equal(t, stopBuildID, updatedChat.BuildID.UUID)
})
t.Run("InProgressBuildWaitsThenStops", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
modelCfg := seedModelConfig(t, db)
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: user.ID,
OrganizationID: org.ID,
}).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStart,
}).Starting().Do()
ws := wsResp.Workspace
chat := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
LastModelConfigID: modelCfg.ID,
Title: "test-stop-in-progress-build",
})
jobRead := make(chan struct{}, 1)
wrappedDB := &jobInterceptStore{Store: db, jobRead: jobRead}
var stopCalled atomic.Bool
var stopBuildID uuid.UUID
var onChatUpdatedCalled atomic.Bool
tool := chattool.StopWorkspace(wrappedDB, chat.ID, chattool.StopWorkspaceOptions{
OwnerID: user.ID,
StopFn: func(_ context.Context, _ uuid.UUID, wsID uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
stopCalled.Store(true)
require.Equal(t, ws.ID, wsID)
require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition)
buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStop,
BuildNumber: 2,
}).Do()
stopBuildID = buildResp.Build.ID
return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil
},
WorkspaceMu: &sync.Mutex{},
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
OnChatUpdated: func(_ database.Chat) { onChatUpdatedCalled.Store(true) },
})
type toolResult struct {
resp fantasy.ToolResponse
err error
}
done := make(chan toolResult, 1)
go func() {
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
done <- toolResult{resp: resp, err: err}
}()
testutil.TryReceive(ctx, t, jobRead)
require.False(t, stopCalled.Load(), "StopFn must wait for the in-progress build")
now := time.Now().UTC()
require.NoError(t, db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: wsResp.Build.JobID,
UpdatedAt: now,
CompletedAt: sql.NullTime{Time: now, Valid: true},
}))
res := testutil.TryReceive(ctx, t, done)
require.NoError(t, res.err)
require.True(t, stopCalled.Load())
require.True(t, onChatUpdatedCalled.Load())
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(res.resp.Content), &result))
require.Equal(t, true, result["stopped"])
require.Equal(t, stopBuildID.String(), result["build_id"])
})
t.Run("FailedLatestStopBuildStillStops", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
modelCfg := seedModelConfig(t, db)
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: user.ID,
OrganizationID: org.ID,
}).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStop,
}).Do()
ws := wsResp.Workspace
now := time.Now().UTC()
require.NoError(t, db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: wsResp.Build.JobID,
UpdatedAt: now,
CompletedAt: sql.NullTime{Time: now, Valid: true},
Error: sql.NullString{String: "latest build failed", Valid: true},
}))
chat := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
LastModelConfigID: modelCfg.ID,
Title: "test-stop-failed-latest-build",
})
var stopCalled atomic.Bool
tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{
OwnerID: user.ID,
StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
stopCalled.Store(true)
require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition)
buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStop,
BuildNumber: 2,
}).Do()
return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil
},
WorkspaceMu: &sync.Mutex{},
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
require.NoError(t, err)
require.True(t, stopCalled.Load())
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
require.Equal(t, true, result["stopped"])
})
t.Run("StopTriggeredBuildFailure", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
modelCfg := seedModelConfig(t, db)
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: user.ID,
OrganizationID: org.ID,
}).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStart,
}).Do()
ws := wsResp.Workspace
chat := dbgen.Chat(t, db, database.Chat{
OrganizationID: org.ID,
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
LastModelConfigID: modelCfg.ID,
Title: "test-stop-triggered-build-failure",
})
var stopBuildJobID uuid.UUID
var stopBuildID uuid.UUID
stopFn := func(_ context.Context, _ uuid.UUID, wsID uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
require.Equal(t, ws.ID, wsID)
require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition)
buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStop,
BuildNumber: 2,
}).Starting().Do()
stopBuildJobID = buildResp.Build.JobID
stopBuildID = buildResp.Build.ID
return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil
}
jobRead := make(chan struct{}, 2)
wrappedDB := &jobInterceptStore{Store: db, jobRead: jobRead}
tool := chattool.StopWorkspace(wrappedDB, chat.ID, chattool.StopWorkspaceOptions{
OwnerID: user.ID,
StopFn: stopFn,
WorkspaceMu: &sync.Mutex{},
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
})
type toolResult struct {
resp fantasy.ToolResponse
err error
}
done := make(chan toolResult, 1)
go func() {
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
done <- toolResult{resp: resp, err: err}
}()
testutil.TryReceive(ctx, t, jobRead)
testutil.TryReceive(ctx, t, jobRead)
now := time.Now().UTC()
require.NoError(t, db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: stopBuildJobID,
UpdatedAt: now,
CompletedAt: sql.NullTime{Time: now, Valid: true},
Error: sql.NullString{String: "terraform destroy failed", Valid: true},
}))
res := testutil.TryReceive(ctx, t, done)
require.NoError(t, res.err)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(res.resp.Content), &result))
require.Contains(t, result["error"], "workspace stop build failed")
require.Equal(t, stopBuildID.String(), result["build_id"])
require.False(t, res.resp.IsError,
"buildToolResponse must not set IsError; chatprompt strips structured fields from error responses")
})
}