mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat(coderd): add stop_workspace chatd tool and recovery classification (#24997)
## Summary Adds a `stop_workspace` tool to chatd so the model can recover from the "workspace running but agent dead" failure mode (e.g. an OOM that leaves the workspace running but the agent unreachable) by stopping and then starting the workspace. <img width="924" height="742" alt="image" src="https://github.com/user-attachments/assets/279dedb6-6e29-4fe1-8754-3a1f01e538bf" /> ## What changed **New `stop_workspace` chatd tool** (`coderd/x/chatd/chattool/stopworkspace.go`). Mirrors `start_workspace`: shares `WorkspaceMu` to serialize with create/start, waits for any in-progress build before issuing a stop, and is idempotent only after a successful Stop transition. Failed stop builds re-attempt rather than reporting success. **New `chatStopWorkspace` coderd hook** (`coderd/exp_chats.go`). Mirrors `chatStartWorkspace` minus the `RequireActiveVersion` gate. Stop should not be blocked by template version policy. **Differentiated recovery sentinels** (`coderd/x/chatd/chatd.go`). `errChatAgentDisconnected` instructs the model to call `stop_workspace` then `start_workspace`. `errChatDialTimeout` instructs a single retry, then user escalation if it repeats. The previous single message conflated transient and persistent failures. **Two-signal recovery gate.** Recovery is only surfaced when a tool call times out *and* a fresh DB read of the latest workspace agent says `Disconnected`. The previous draft escalated on the DB read alone, which would fire on a 30-second heartbeat blip (e.g. agent respawn) and prompt a destructive stop/start unnecessarily. **Cache-hit disconnected handling** now clears the cache and retries a fresh dial before escalating, rather than returning the recovery sentinel immediately. Latest-agent classification uses `GetWorkspaceAgentsInLatestBuildByWorkspaceID` instead of the chat's bound `AgentID`, so stale bindings after a rebuild don't misclassify. **Shared chattool helpers** in `coderd/x/chatd/chattool/chattool.go`: `latestWorkspaceBuildAndJob`, `publishBuildBinding`, `provisionerJobTerminal`. Applied to both `start_workspace` and `stop_workspace`. ## Notes - Reverts an earlier draft that widened `ask_user_question` to root standard turns. Plan-mode-only behavior is restored. - The `stop_workspace` tool currently renders via the generic chat tool-call UI. A follow-up frontend PR will prettify the `stop_workspace` tool and style it like the `start_workspace` tool. - Never-connected (`Timeout` status) agents are intentionally excluded from recovery. They indicate template or startup failure, not the running-but-dead case this PR targets. Closes CODAGT-315
This commit is contained in:
@@ -805,6 +805,7 @@ func New(options *Options) *API {
|
||||
InstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
StartWorkspace: api.chatStartWorkspace,
|
||||
StopWorkspace: api.chatStopWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
UsageTracker: options.WorkspaceUsageTracker,
|
||||
|
||||
@@ -3740,6 +3740,53 @@ func (api *API) chatStartWorkspace(
|
||||
return apiBuild, nil
|
||||
}
|
||||
|
||||
// chatStopWorkspace stops a workspace by creating a new build with the
|
||||
// "stop" transition. It mirrors chatStartWorkspace, without start-only
|
||||
// active-version behavior.
|
||||
func (api *API) chatStopWorkspace(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
workspaceID uuid.UUID,
|
||||
req codersdk.CreateWorkspaceBuildRequest,
|
||||
) (codersdk.WorkspaceBuild, error) {
|
||||
actor, _, err := httpmw.UserRBACSubject(ctx, api.Database, ownerID, rbac.ScopeAll)
|
||||
if err != nil {
|
||||
return codersdk.WorkspaceBuild{}, xerrors.Errorf("load user authorization: %w", err)
|
||||
}
|
||||
ctx = dbauthz.As(ctx, actor)
|
||||
|
||||
workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceID)
|
||||
if err != nil {
|
||||
return codersdk.WorkspaceBuild{}, xerrors.Errorf("get workspace: %w", err)
|
||||
}
|
||||
|
||||
req.Transition = codersdk.WorkspaceTransitionStop
|
||||
|
||||
// Build a synthetic API key so postWorkspaceBuildsInternal can
|
||||
// record the correct initiator.
|
||||
syntheticKey := database.APIKey{
|
||||
UserID: ownerID,
|
||||
}
|
||||
|
||||
apiBuild, err := api.postWorkspaceBuildsInternal(
|
||||
ctx,
|
||||
syntheticKey,
|
||||
workspace,
|
||||
req,
|
||||
func(action policy.Action, object rbac.Objecter) bool {
|
||||
// Authorization is handled by dbauthz on the context.
|
||||
authErr := api.HTTPAuth.Authorizer.Authorize(ctx, actor, action, object.RBACObject())
|
||||
return authErr == nil
|
||||
},
|
||||
audit.WorkspaceBuildBaggage{},
|
||||
)
|
||||
if err != nil {
|
||||
return codersdk.WorkspaceBuild{}, xerrors.Errorf("create workspace build: %w", err)
|
||||
}
|
||||
|
||||
return apiBuild, nil
|
||||
}
|
||||
|
||||
func rewriteChatStartWorkspaceManualUpdateResponse(resp codersdk.Response, fallbackDetail string, retryInstructions string) codersdk.Response {
|
||||
originalMessage := resp.Message
|
||||
resp.Message = retryInstructions
|
||||
|
||||
@@ -13310,6 +13310,43 @@ func TestChatStartWorkspace_RequireActiveVersion(t *testing.T) {
|
||||
require.Nil(t, build.TemplateVersionPresetID, "no preset must be applied")
|
||||
}
|
||||
|
||||
func TestChatStopWorkspace_BypassesRequireActiveVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
rawClient, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{})
|
||||
var store dbauthz.AccessControlStore = requireActiveVersionStore{}
|
||||
api.AccessControlStore.Store(&store)
|
||||
db := api.Database
|
||||
user := coderdtest.CreateFirstUser(t, rawClient)
|
||||
|
||||
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: user.UserID,
|
||||
OrganizationID: user.OrganizationID,
|
||||
}).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
}).Do()
|
||||
v1ID := wsResp.Build.TemplateVersionID
|
||||
tmplID := wsResp.Workspace.TemplateID
|
||||
|
||||
v2Resp := dbfake.TemplateVersion(t, db).Seed(database.TemplateVersion{
|
||||
TemplateID: uuid.NullUUID{UUID: tmplID, Valid: true},
|
||||
OrganizationID: user.OrganizationID,
|
||||
CreatedBy: user.UserID,
|
||||
}).Do()
|
||||
v2 := v2Resp.TemplateVersion
|
||||
require.NotEqual(t, v1ID, v2.ID, "v2 must differ from v1")
|
||||
|
||||
build, err := coderd.ChatStopWorkspace(api, ctx, user.UserID, wsResp.Workspace.ID,
|
||||
codersdk.CreateWorkspaceBuildRequest{})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.WorkspaceTransitionStop, build.Transition)
|
||||
require.Equal(t, v1ID, build.TemplateVersionID,
|
||||
"stop must not apply RequireActiveVersion start-only logic")
|
||||
require.NotEqual(t, v2.ID, build.TemplateVersionID)
|
||||
}
|
||||
|
||||
func TestGetChatMessages_Pagination(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -11,3 +11,6 @@ var InsertAgentChatTestModelConfig = insertAgentChatTestModelConfig
|
||||
// stubbing the entire DB layer. The proper fix is to extract a pure
|
||||
// request builder; tracked in CODAGT-292.
|
||||
var ChatStartWorkspace = (*API).chatStartWorkspace
|
||||
|
||||
// ChatStopWorkspace exposes chatStopWorkspace for external tests.
|
||||
var ChatStopWorkspace = (*API).chatStopWorkspace
|
||||
|
||||
+77
-6
@@ -121,6 +121,13 @@ const (
|
||||
// streamJanitorInterval.
|
||||
streamJanitorInterval = 30 * time.Second
|
||||
|
||||
// agentDisconnectedRecoveryThreshold is how long the latest
|
||||
// workspace agent must be disconnected before chatd suggests
|
||||
// destructive stop/start recovery. This is intentionally longer
|
||||
// than the inactive-disconnect timeout so short heartbeat gaps do
|
||||
// not prompt a workspace restart.
|
||||
agentDisconnectedRecoveryThreshold = 90 * time.Second
|
||||
|
||||
// DefaultMaxChatsPerAcquire is the maximum number of chats to
|
||||
// acquire in a single processOnce call. Batching avoids
|
||||
// waiting a full polling interval between acquisitions
|
||||
@@ -139,12 +146,14 @@ const (
|
||||
var (
|
||||
errChatHasNoWorkspaceAgent = xerrors.New("workspace has no running agent: the workspace is likely stopped. Use the start_workspace tool to start it")
|
||||
errChatAgentDisconnected = xerrors.New(
|
||||
"workspace agent is disconnected and cannot execute tools. " +
|
||||
"The workspace may need to be restarted from the Coder dashboard",
|
||||
"workspace agent has been disconnected for at least 90 seconds " +
|
||||
"and cannot execute tools. To recover, call stop_workspace " +
|
||||
"to stop the workspace, then start_workspace to start it " +
|
||||
"again",
|
||||
)
|
||||
errChatDialTimeout = xerrors.New(
|
||||
"connection to the workspace agent timed out. " +
|
||||
"The workspace may need to be restarted from the Coder dashboard",
|
||||
"The agent may still be reachable on the next attempt.",
|
||||
)
|
||||
errChatExternalAgentUnavailable = xerrors.New("external workspace agent unavailable")
|
||||
)
|
||||
@@ -187,6 +196,7 @@ type Server struct {
|
||||
instructionLookupTimeout time.Duration
|
||||
createWorkspaceFn chattool.CreateWorkspaceFn
|
||||
startWorkspaceFn chattool.StartWorkspaceFn
|
||||
stopWorkspaceFn chattool.StopWorkspaceFn
|
||||
pubsub pubsub.Pubsub
|
||||
webpushDispatcher webpush.Dispatcher
|
||||
providerAPIKeys chatprovider.ProviderAPIKeys
|
||||
@@ -786,6 +796,45 @@ func isAgentUnreachable(now time.Time, agent database.WorkspaceAgent, inactiveTi
|
||||
status.Status == database.WorkspaceAgentStatusTimeout
|
||||
}
|
||||
|
||||
func agentDisconnectedFor(now time.Time, agent database.WorkspaceAgent, inactiveTimeout time.Duration) (time.Duration, bool) {
|
||||
status := agent.Status(now, inactiveTimeout)
|
||||
if status.Status != database.WorkspaceAgentStatusDisconnected || status.DisconnectedAt == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
disconnectedFor := now.Sub(*status.DisconnectedAt)
|
||||
if disconnectedFor < 0 {
|
||||
disconnectedFor = 0
|
||||
}
|
||||
return disconnectedFor, true
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) latestWorkspaceAgentNeedsRestart(
|
||||
ctx context.Context,
|
||||
workspaceID uuid.UUID,
|
||||
) (bool, error) {
|
||||
agentID, err := c.latestWorkspaceAgentID(ctx, workspaceID)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, errChatHasNoWorkspaceAgent) {
|
||||
return false, err
|
||||
}
|
||||
c.server.logger.Warn(ctx, "failed to resolve latest agent for timeout classification", slog.Error(err))
|
||||
return false, nil
|
||||
}
|
||||
|
||||
agent, err := c.server.db.GetWorkspaceAgentByID(ctx, agentID)
|
||||
if err != nil {
|
||||
c.server.logger.Warn(ctx, "failed to load latest agent for timeout classification",
|
||||
slog.F("agent_id", agentID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
disconnectedFor, disconnected := agentDisconnectedFor(c.server.clock.Now(), agent, c.server.agentInactiveDisconnectTimeout)
|
||||
return disconnected && disconnectedFor >= agentDisconnectedRecoveryThreshold, nil
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) externalAgentError(
|
||||
ctx context.Context,
|
||||
agent database.WorkspaceAgent,
|
||||
@@ -853,9 +902,13 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces
|
||||
)
|
||||
// On DB error the check re-runs on the
|
||||
// next tool call.
|
||||
} else if isAgentUnreachable(c.server.clock.Now(), freshAgent, c.server.agentInactiveDisconnectTimeout) {
|
||||
} else if _, disconnected := agentDisconnectedFor(
|
||||
c.server.clock.Now(),
|
||||
freshAgent,
|
||||
c.server.agentInactiveDisconnectTimeout,
|
||||
); disconnected {
|
||||
c.clearCachedWorkspaceState()
|
||||
return nil, c.externalAgentError(ctx, freshAgent, errChatAgentDisconnected)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return currentConn, nil
|
||||
@@ -898,6 +951,14 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces
|
||||
// canceled (e.g. ErrInterrupted), its error must
|
||||
// propagate unchanged so the chatloop can detect it.
|
||||
if ctx.Err() == nil && errors.Is(context.Cause(dialCtx), errChatDialTimeout) {
|
||||
c.clearCachedWorkspaceState()
|
||||
needsRestart, statusErr := c.latestWorkspaceAgentNeedsRestart(ctx, chatSnapshot.WorkspaceID.UUID)
|
||||
if statusErr != nil {
|
||||
return nil, statusErr
|
||||
}
|
||||
if needsRestart {
|
||||
return nil, c.externalAgentError(ctx, agent, errChatAgentDisconnected)
|
||||
}
|
||||
return nil, c.externalAgentError(ctx, agent, errChatDialTimeout)
|
||||
}
|
||||
return nil, err
|
||||
@@ -3752,6 +3813,7 @@ type Config struct {
|
||||
InstructionLookupTimeout time.Duration
|
||||
CreateWorkspace chattool.CreateWorkspaceFn
|
||||
StartWorkspace chattool.StartWorkspaceFn
|
||||
StopWorkspace chattool.StopWorkspaceFn
|
||||
Pubsub pubsub.Pubsub
|
||||
ProviderAPIKeys chatprovider.ProviderAPIKeys
|
||||
AlwaysEnableDebugLogs bool
|
||||
@@ -3820,6 +3882,7 @@ func New(cfg Config) *Server {
|
||||
instructionLookupTimeout: instructionLookupTimeout,
|
||||
createWorkspaceFn: cfg.CreateWorkspace,
|
||||
startWorkspaceFn: cfg.StartWorkspace,
|
||||
stopWorkspaceFn: cfg.StopWorkspace,
|
||||
pubsub: cfg.Pubsub,
|
||||
webpushDispatcher: cfg.WebpushDispatcher,
|
||||
providerAPIKeys: cfg.ProviderAPIKeys,
|
||||
@@ -5899,7 +5962,7 @@ func builtinPlanToolAllowed(name string, isRootChat bool) bool {
|
||||
case "read_file", "execute", "process_output", "read_skill", "read_skill_file":
|
||||
return true
|
||||
case "write_file", "edit_files", "list_templates", "read_template",
|
||||
"create_workspace", "start_workspace", "propose_plan", "spawn_agent",
|
||||
"create_workspace", "start_workspace", "stop_workspace", "propose_plan", "spawn_agent",
|
||||
"spawn_explore_agent", "wait_agent", "ask_user_question":
|
||||
return isRootChat
|
||||
case "process_list", "process_signal", "message_agent", "close_agent",
|
||||
@@ -5979,6 +6042,7 @@ func allowedExploreToolNames(allTools []fantasy.AgentTool) []string {
|
||||
"read_template": false,
|
||||
"create_workspace": false,
|
||||
"start_workspace": false,
|
||||
"stop_workspace": false,
|
||||
"propose_plan": false,
|
||||
"spawn_agent": false,
|
||||
"wait_agent": false,
|
||||
@@ -6198,6 +6262,13 @@ func (p *Server) appendRootChatTools(
|
||||
OnChatUpdated: onChatUpdated,
|
||||
Logger: p.logger,
|
||||
}),
|
||||
chattool.StopWorkspace(p.db, opts.chat.ID, chattool.StopWorkspaceOptions{
|
||||
OwnerID: opts.chat.OwnerID,
|
||||
StopFn: p.stopWorkspaceFn,
|
||||
WorkspaceMu: opts.workspaceMu,
|
||||
OnChatUpdated: onChatUpdated,
|
||||
Logger: p.logger,
|
||||
}),
|
||||
)
|
||||
if opts.isPlanModeTurn {
|
||||
tools = append(tools, chattool.ProposePlan(chattool.ProposePlanOptions{
|
||||
|
||||
@@ -295,6 +295,26 @@ func TestFilterExternalMCPConfigsForTurn(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatWorkspaceRecoveryErrorsDifferentiateSignalStrength(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Disconnected recovery is gated by a DB-confirmed duration
|
||||
// threshold, so the message can give direct stop/start guidance
|
||||
// without asking the user.
|
||||
disconnected := errChatAgentDisconnected.Error()
|
||||
require.Contains(t, disconnected, "90 seconds")
|
||||
require.Contains(t, disconnected, "stop_workspace")
|
||||
require.Contains(t, disconnected, "start_workspace")
|
||||
require.NotContains(t, disconnected, "ask_user_question")
|
||||
|
||||
// Dial timeout alone is a weak signal. The model should not
|
||||
// escalate to lifecycle tools without DB-confirmed disconnect.
|
||||
dialTimeout := errChatDialTimeout.Error()
|
||||
require.NotContains(t, dialTimeout, "ask_user_question")
|
||||
require.NotContains(t, dialTimeout, "stop_workspace")
|
||||
require.NotContains(t, dialTimeout, "start_workspace")
|
||||
}
|
||||
|
||||
func TestActiveToolNamesForTurn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -344,6 +364,7 @@ func TestActiveToolNamesForTurn(t *testing.T) {
|
||||
"read_template",
|
||||
"create_workspace",
|
||||
"start_workspace",
|
||||
"stop_workspace",
|
||||
"propose_plan",
|
||||
"spawn_agent",
|
||||
"wait_agent",
|
||||
@@ -364,6 +385,7 @@ func TestActiveToolNamesForTurn(t *testing.T) {
|
||||
"read_template",
|
||||
"create_workspace",
|
||||
"start_workspace",
|
||||
"stop_workspace",
|
||||
"propose_plan",
|
||||
"spawn_agent",
|
||||
"wait_agent",
|
||||
@@ -386,6 +408,7 @@ func TestActiveToolNamesForTurn(t *testing.T) {
|
||||
"read_template",
|
||||
"create_workspace",
|
||||
"start_workspace",
|
||||
"stop_workspace",
|
||||
"propose_plan",
|
||||
"spawn_agent",
|
||||
"wait_agent",
|
||||
@@ -405,6 +428,8 @@ func TestActiveToolNamesForTurn(t *testing.T) {
|
||||
require.NotContains(t, got, "edit_files")
|
||||
require.NotContains(t, got, "ask_user_question")
|
||||
require.NotContains(t, got, "propose_plan")
|
||||
require.NotContains(t, got, "start_workspace")
|
||||
require.NotContains(t, got, "stop_workspace")
|
||||
require.NotContains(t, got, "spawn_explore_agent")
|
||||
})
|
||||
|
||||
@@ -474,6 +499,8 @@ func TestAllowedExploreToolNames(t *testing.T) {
|
||||
newTestAgentTool("write_file"),
|
||||
newTestMCPAgentTool("external-mcp__echo", externalConfigID),
|
||||
newTestAgentTool("workspace-mcp__echo"),
|
||||
newTestAgentTool("start_workspace"),
|
||||
newTestAgentTool("stop_workspace"),
|
||||
newTestAgentTool("execute"),
|
||||
newTestAgentTool("process_output"),
|
||||
newTestAgentTool("process_list"),
|
||||
@@ -494,6 +521,9 @@ func TestAllowedExploreToolNames(t *testing.T) {
|
||||
"read_skill_file",
|
||||
}, got)
|
||||
require.NotContains(t, got, "workspace-mcp__echo")
|
||||
require.NotContains(t, got, "start_workspace")
|
||||
require.NotContains(t, got, "stop_workspace")
|
||||
require.NotContains(t, got, "ask_user_question")
|
||||
}
|
||||
|
||||
func TestAllowedBehaviorToolNames(t *testing.T) {
|
||||
@@ -580,19 +610,13 @@ func TestStopAfterBehaviorTools(t *testing.T) {
|
||||
))
|
||||
})
|
||||
|
||||
t.Run("RootPlanModeIncludesClarificationTool", func(t *testing.T) {
|
||||
t.Run("PlanModeDelegatesToPlanTools", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, map[string]struct{}{
|
||||
"propose_plan": {},
|
||||
"ask_user_question": {},
|
||||
}, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{}))
|
||||
})
|
||||
|
||||
t.Run("ChildPlanModeSkipsClarificationTool", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, map[string]struct{}{
|
||||
"propose_plan": {},
|
||||
}, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{UUID: uuid.New(), Valid: true}))
|
||||
require.Equal(t, stopAfterPlanTools(planMode, uuid.NullUUID{}), stopAfterBehaviorTools(
|
||||
planMode,
|
||||
database.NullChatMode{},
|
||||
uuid.NullUUID{},
|
||||
))
|
||||
})
|
||||
|
||||
t.Run("ExploreModeReturnsNil", func(t *testing.T) {
|
||||
@@ -4229,46 +4253,28 @@ func TestGetWorkspaceConn_SameBuildAgentCrash(t *testing.T) {
|
||||
|
||||
func TestGetWorkspaceConn_StatusCheck(t *testing.T) {
|
||||
// The cache-hit status check re-fetches the agent row for a fresh
|
||||
// heartbeat timestamp. These tests verify that path detects
|
||||
// disconnected or timed-out agents and that healthy or DB-error
|
||||
// paths return the cached connection.
|
||||
// heartbeat timestamp. Healthy, timed-out, and DB-error paths return
|
||||
// the cached connection. Disconnected agents are covered separately
|
||||
// because they now trigger a fresh dial before recovery.
|
||||
t.Parallel()
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
agent database.WorkspaceAgent
|
||||
dbError bool
|
||||
wantErr error
|
||||
wantReleaseCalled bool
|
||||
name string
|
||||
agent database.WorkspaceAgent
|
||||
dbError bool
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "DisconnectedAgentCacheHit",
|
||||
agent: database.WorkspaceAgent{
|
||||
FirstConnectedAt: sql.NullTime{
|
||||
Time: time.Now().Add(-10 * time.Minute),
|
||||
Valid: true,
|
||||
},
|
||||
LastConnectedAt: sql.NullTime{
|
||||
Time: time.Now().Add(-10 * time.Minute),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
wantErr: errChatAgentDisconnected,
|
||||
wantReleaseCalled: true,
|
||||
},
|
||||
{
|
||||
// Agent never connected and the connection timeout
|
||||
// has elapsed. This is the cache-hit timeout branch
|
||||
// of isAgentUnreachable.
|
||||
// has elapsed. This should not trigger lifecycle
|
||||
// recovery because the agent did not connect and
|
||||
// then disconnect.
|
||||
name: "TimedOutAgentCacheHit",
|
||||
agent: database.WorkspaceAgent{
|
||||
CreatedAt: time.Now().Add(-10 * time.Minute),
|
||||
ConnectionTimeoutSeconds: 60,
|
||||
},
|
||||
wantErr: errChatAgentDisconnected,
|
||||
wantReleaseCalled: true,
|
||||
},
|
||||
{
|
||||
name: "CacheHitHealthyAgent",
|
||||
@@ -4371,28 +4377,274 @@ func TestGetWorkspaceConn_StatusCheck(t *testing.T) {
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
||||
|
||||
if tc.wantErr != nil {
|
||||
require.Nil(t, gotConn)
|
||||
require.ErrorIs(t, err, tc.wantErr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Same(t, cachedConn, gotConn)
|
||||
}
|
||||
|
||||
require.Equal(t, tc.wantReleaseCalled, releaseCalled, "release called")
|
||||
|
||||
// For cache-hit disconnect, the cache should be cleared.
|
||||
if tc.wantErr != nil {
|
||||
workspaceCtx.mu.Lock()
|
||||
defer workspaceCtx.mu.Unlock()
|
||||
require.False(t, workspaceCtx.agentLoaded)
|
||||
require.Nil(t, workspaceCtx.conn)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Same(t, cachedConn, gotConn)
|
||||
require.False(t, releaseCalled, "release called")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetWorkspaceConn_DialTimeoutDisconnectedRecoveryThreshold(t *testing.T) {
|
||||
// The recovery sentinel requires a failed dial and a fresh
|
||||
// disconnected status check past the recovery threshold. A
|
||||
// disconnected DB row alone is not enough to trigger stop/start
|
||||
// recovery.
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
disconnectedFor time.Duration
|
||||
wantErr error
|
||||
wantRecovery bool
|
||||
}{
|
||||
{
|
||||
name: "RecentDisconnectReturnsDialTimeout",
|
||||
disconnectedFor: agentDisconnectedRecoveryThreshold / 2,
|
||||
wantErr: errChatDialTimeout,
|
||||
wantRecovery: false,
|
||||
},
|
||||
{
|
||||
name: "PastThresholdEscalates",
|
||||
disconnectedFor: agentDisconnectedRecoveryThreshold,
|
||||
wantErr: errChatAgentDisconnected,
|
||||
wantRecovery: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
clock := quartz.NewMock(t)
|
||||
now := clock.Now()
|
||||
disconnectedAgent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
FirstConnectedAt: sql.NullTime{
|
||||
Time: now.Add(-10 * time.Minute),
|
||||
Valid: true,
|
||||
},
|
||||
LastConnectedAt: sql.NullTime{
|
||||
Time: now.Add(-10 * time.Minute),
|
||||
Valid: true,
|
||||
},
|
||||
DisconnectedAt: sql.NullTime{
|
||||
Time: now.Add(-tc.disconnectedFor),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
||||
Return(disconnectedAgent, nil).
|
||||
Times(2)
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
||||
Return([]database.WorkspaceAgent{disconnectedAgent}, nil).
|
||||
Times(1)
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
clock: clock,
|
||||
agentInactiveDisconnectTimeout: 30 * time.Second,
|
||||
dialTimeout: 10 * time.Millisecond,
|
||||
}
|
||||
server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
<-ctx.Done()
|
||||
return nil, nil, ctx.Err()
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
defer workspaceCtx.close()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
||||
require.Nil(t, gotConn)
|
||||
require.ErrorIs(t, err, tc.wantErr)
|
||||
if tc.wantRecovery {
|
||||
require.ErrorIs(t, err, errChatAgentDisconnected)
|
||||
} else {
|
||||
require.NotErrorIs(t, err, errChatAgentDisconnected)
|
||||
}
|
||||
|
||||
workspaceCtx.mu.Lock()
|
||||
defer workspaceCtx.mu.Unlock()
|
||||
require.False(t, workspaceCtx.agentLoaded)
|
||||
require.Nil(t, workspaceCtx.conn)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetWorkspaceConn_DisconnectedStatusDialSuccessDoesNotEscalate(t *testing.T) {
|
||||
// A stale disconnected row must not prompt stop/start if the
|
||||
// agent can still be dialed successfully.
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
disconnectedAgent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
FirstConnectedAt: sql.NullTime{
|
||||
Time: time.Now().Add(-10 * time.Minute),
|
||||
Valid: true,
|
||||
},
|
||||
LastConnectedAt: sql.NullTime{
|
||||
Time: time.Now().Add(-10 * time.Minute),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
||||
Return(disconnectedAgent, nil).
|
||||
Times(1)
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
clock: quartz.NewReal(),
|
||||
agentInactiveDisconnectTimeout: 30 * time.Second,
|
||||
dialTimeout: 10 * time.Millisecond,
|
||||
}
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
var dialCalled bool
|
||||
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
dialCalled = true
|
||||
return conn, nil, nil
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
defer workspaceCtx.close()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, conn, gotConn)
|
||||
require.True(t, dialCalled, "dial called")
|
||||
}
|
||||
|
||||
func TestGetWorkspaceConn_CacheHitDisconnectedRetriesDialBeforeEscalating(t *testing.T) {
|
||||
// A disconnected cached connection is discarded first. Recovery is
|
||||
// only surfaced if the replacement dial also times out.
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
disconnectedAgent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
FirstConnectedAt: sql.NullTime{
|
||||
Time: time.Now().Add(-10 * time.Minute),
|
||||
Valid: true,
|
||||
},
|
||||
LastConnectedAt: sql.NullTime{
|
||||
Time: time.Now().Add(-10 * time.Minute),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
||||
Return(disconnectedAgent, nil).
|
||||
Times(2)
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
clock: quartz.NewReal(),
|
||||
agentInactiveDisconnectTimeout: 30 * time.Second,
|
||||
dialTimeout: 10 * time.Millisecond,
|
||||
}
|
||||
newConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
newConn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
var dialCalled bool
|
||||
server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
dialCalled = true
|
||||
return newConn, nil, nil
|
||||
}
|
||||
|
||||
var releaseCalled bool
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
oldConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
agent: disconnectedAgent,
|
||||
agentLoaded: true,
|
||||
conn: oldConn,
|
||||
releaseConn: func() { releaseCalled = true },
|
||||
cachedWorkspaceID: chat.WorkspaceID,
|
||||
}
|
||||
defer workspaceCtx.close()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, newConn, gotConn)
|
||||
require.True(t, releaseCalled, "release called")
|
||||
require.True(t, dialCalled, "dial called")
|
||||
}
|
||||
|
||||
func TestGetWorkspaceConn_DialTimeout(t *testing.T) {
|
||||
// When dialWithLazyValidation blocks beyond the dial
|
||||
// timeout, getWorkspaceConn should return
|
||||
@@ -4431,6 +4683,9 @@ func TestGetWorkspaceConn_DialTimeout(t *testing.T) {
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
||||
Return(connectedAgent, nil).
|
||||
Times(2)
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
||||
Return([]database.WorkspaceAgent{connectedAgent}, nil).
|
||||
Times(1)
|
||||
|
||||
server := &Server{
|
||||
@@ -4461,6 +4716,70 @@ func TestGetWorkspaceConn_DialTimeout(t *testing.T) {
|
||||
require.ErrorIs(t, err, errChatDialTimeout)
|
||||
}
|
||||
|
||||
func TestGetWorkspaceConn_DialTimeoutStatusTimeoutDoesNotEscalate(t *testing.T) {
|
||||
// Agents that never connected are startup failures, not
|
||||
// disconnected recovery cases. A dial timeout should stay a
|
||||
// retry/escalation error rather than stop/start guidance.
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
agentID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
timedOutAgent := database.WorkspaceAgent{
|
||||
ID: agentID,
|
||||
CreatedAt: time.Now().Add(-10 * time.Minute),
|
||||
ConnectionTimeoutSeconds: 60,
|
||||
}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).
|
||||
Return(timedOutAgent, nil).
|
||||
Times(2)
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
||||
Return([]database.WorkspaceAgent{timedOutAgent}, nil).
|
||||
Times(1)
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
clock: quartz.NewReal(),
|
||||
agentInactiveDisconnectTimeout: 30 * time.Second,
|
||||
dialTimeout: 10 * time.Millisecond,
|
||||
}
|
||||
server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
<-ctx.Done()
|
||||
return nil, nil, ctx.Err()
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
defer workspaceCtx.close()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
||||
require.Nil(t, gotConn)
|
||||
require.ErrorIs(t, err, errChatDialTimeout)
|
||||
require.NotErrorIs(t, err, errChatAgentDisconnected)
|
||||
}
|
||||
|
||||
func TestGetWorkspaceConn_DialTimeoutParentCanceled(t *testing.T) {
|
||||
// When the parent context is canceled, the parent's error
|
||||
// must propagate unchanged (not wrapped as a dial timeout).
|
||||
|
||||
@@ -344,7 +344,10 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
|
||||
require.GreaterOrEqual(t, len(recorded), 2,
|
||||
"expected at least 2 streamed LLM calls (root + subagent)")
|
||||
|
||||
workspaceTools := []string{"list_templates", "read_template", "create_workspace"}
|
||||
workspaceTools := []string{
|
||||
"list_templates", "read_template", "create_workspace",
|
||||
"start_workspace", "stop_workspace",
|
||||
}
|
||||
subagentTools := []string{"spawn_agent", "wait_agent", "message_agent", "close_agent"}
|
||||
|
||||
// Identify root and subagent calls. Root chat calls include
|
||||
@@ -375,7 +378,10 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
|
||||
"root chat should have subagent tool %q", tool)
|
||||
}
|
||||
|
||||
// Standard turns (no turn mode) should hide propose_plan.
|
||||
// Standard turns (no turn mode) hide plan-only tools until
|
||||
// plan mode.
|
||||
require.NotContains(t, rootCalls[0], "ask_user_question",
|
||||
"standard-turn root chat should NOT have ask_user_question")
|
||||
require.NotContains(t, rootCalls[0], "propose_plan",
|
||||
"standard-turn root chat should NOT have propose_plan")
|
||||
|
||||
@@ -388,6 +394,8 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
|
||||
require.NotContains(t, childCalls[0], tool,
|
||||
"subagent chat should NOT have subagent tool %q", tool)
|
||||
}
|
||||
require.NotContains(t, childCalls[0], "ask_user_question",
|
||||
"subagent chat should NOT have ask_user_question")
|
||||
}
|
||||
|
||||
func TestPlanModeSubagentChatExcludesAskUserQuestion(t *testing.T) {
|
||||
@@ -7030,7 +7038,7 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) {
|
||||
// 4. Verify workspace provisioning tools are NOT present.
|
||||
workspaceProvisioningTools := []string{
|
||||
"list_templates", "read_template",
|
||||
"create_workspace", "start_workspace",
|
||||
"create_workspace", "start_workspace", "stop_workspace",
|
||||
}
|
||||
for _, tool := range workspaceProvisioningTools {
|
||||
require.NotContains(t, childTools, tool,
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"unicode/utf8"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -54,6 +58,65 @@ func responseErrorResult(resp codersdk.Response) map[string]any {
|
||||
return result
|
||||
}
|
||||
|
||||
func latestWorkspaceBuildAndJob(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
workspaceID uuid.UUID,
|
||||
) (database.WorkspaceBuild, database.ProvisionerJob, error) {
|
||||
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID)
|
||||
if err != nil {
|
||||
return database.WorkspaceBuild{}, database.ProvisionerJob{}, xerrors.Errorf("get latest build: %w", err)
|
||||
}
|
||||
|
||||
job, err := db.GetProvisionerJobByID(ctx, build.JobID)
|
||||
if err != nil {
|
||||
return database.WorkspaceBuild{}, database.ProvisionerJob{}, xerrors.Errorf("get provisioner job: %w", err)
|
||||
}
|
||||
return build, job, nil
|
||||
}
|
||||
|
||||
func publishBuildBinding(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
logger slog.Logger,
|
||||
chatID uuid.UUID,
|
||||
workspaceID uuid.UUID,
|
||||
buildID uuid.UUID,
|
||||
onChatUpdated func(database.Chat),
|
||||
) {
|
||||
updatedChat, bindErr := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
|
||||
ID: chatID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
|
||||
BuildID: uuid.NullUUID{
|
||||
UUID: buildID,
|
||||
Valid: buildID != uuid.Nil,
|
||||
},
|
||||
AgentID: uuid.NullUUID{},
|
||||
})
|
||||
if bindErr != nil {
|
||||
logger.Error(ctx, "failed to persist build ID on chat binding",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("build_id", buildID),
|
||||
slog.Error(bindErr),
|
||||
)
|
||||
return
|
||||
}
|
||||
if onChatUpdated != nil {
|
||||
onChatUpdated(updatedChat)
|
||||
}
|
||||
}
|
||||
|
||||
func provisionerJobTerminal(status database.ProvisionerJobStatus) bool {
|
||||
switch status {
|
||||
case database.ProvisionerJobStatusSucceeded,
|
||||
database.ProvisionerJobStatusFailed,
|
||||
database.ProvisionerJobStatusCanceled:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func truncateRunes(value string, maxLen int) string {
|
||||
if maxLen <= 0 || value == "" {
|
||||
return ""
|
||||
|
||||
@@ -56,7 +56,7 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO
|
||||
return fantasy.NewTextErrorResponse("workspace starter is not configured"), nil
|
||||
}
|
||||
|
||||
// Serialize with create_workspace to prevent races.
|
||||
// Serialize with create_workspace and stop_workspace to prevent races.
|
||||
if options.WorkspaceMu != nil {
|
||||
options.WorkspaceMu.Lock()
|
||||
defer options.WorkspaceMu.Unlock()
|
||||
@@ -86,18 +86,9 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO
|
||||
), nil
|
||||
}
|
||||
|
||||
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID)
|
||||
build, job, err := latestWorkspaceBuildAndJob(ctx, db, ws.ID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("get latest build: %w", err).Error(),
|
||||
), nil
|
||||
}
|
||||
|
||||
job, err := db.GetProvisionerJobByID(ctx, build.JobID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("get provisioner job: %w", err).Error(),
|
||||
), nil
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// If a build is already in progress, wait for it.
|
||||
@@ -106,24 +97,7 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO
|
||||
database.ProvisionerJobStatusRunning:
|
||||
// Publish the build ID to the frontend so it
|
||||
// can start streaming logs immediately.
|
||||
updatedChat, bindErr := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
|
||||
ID: chatID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
BuildID: uuid.NullUUID{
|
||||
UUID: build.ID,
|
||||
Valid: build.ID != uuid.Nil,
|
||||
},
|
||||
AgentID: uuid.NullUUID{},
|
||||
})
|
||||
if bindErr != nil {
|
||||
options.Logger.Error(ctx, "failed to persist build ID on chat binding",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("build_id", build.ID),
|
||||
slog.Error(bindErr),
|
||||
)
|
||||
} else if options.OnChatUpdated != nil {
|
||||
options.OnChatUpdated(updatedChat)
|
||||
}
|
||||
publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, build.ID, options.OnChatUpdated)
|
||||
if err := waitForBuild(ctx, db, build.ID); err != nil {
|
||||
// newBuildError returns via toolResponse (IsError: false)
|
||||
// rather than NewTextErrorResponse (IsError: true) so the
|
||||
@@ -199,24 +173,7 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO
|
||||
|
||||
// Persist the build ID on the chat binding so the
|
||||
// frontend can stream logs without polling.
|
||||
updatedChat, bindErr := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{
|
||||
ID: chatID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
BuildID: uuid.NullUUID{
|
||||
UUID: startBuild.ID,
|
||||
Valid: startBuild.ID != uuid.Nil,
|
||||
},
|
||||
AgentID: uuid.NullUUID{},
|
||||
})
|
||||
if bindErr != nil {
|
||||
options.Logger.Error(ctx, "failed to persist build ID on chat binding",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("build_id", startBuild.ID),
|
||||
slog.Error(bindErr),
|
||||
)
|
||||
} else if options.OnChatUpdated != nil {
|
||||
options.OnChatUpdated(updatedChat)
|
||||
}
|
||||
publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, startBuild.ID, options.OnChatUpdated)
|
||||
if err := waitForBuild(ctx, db, startBuild.ID); err != nil {
|
||||
return buildFailureToolResponse(
|
||||
ctx,
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/httpapi/httperror"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// StopWorkspaceFn stops a workspace by creating a new build with
|
||||
// the "stop" transition.
|
||||
type StopWorkspaceFn func(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
workspaceID uuid.UUID,
|
||||
req codersdk.CreateWorkspaceBuildRequest,
|
||||
) (codersdk.WorkspaceBuild, error)
|
||||
|
||||
// StopWorkspaceOptions configures the stop_workspace tool.
|
||||
type StopWorkspaceOptions struct {
|
||||
OwnerID uuid.UUID
|
||||
StopFn StopWorkspaceFn
|
||||
WorkspaceMu *sync.Mutex
|
||||
OnChatUpdated func(database.Chat)
|
||||
Logger slog.Logger
|
||||
}
|
||||
|
||||
type stopWorkspaceArgs struct{}
|
||||
|
||||
// StopWorkspace returns a tool that stops the workspace associated
|
||||
// with the current chat. The tool is idempotent when the workspace is
|
||||
// already stopped. db must not be nil and chatID must not be uuid.Nil.
|
||||
func StopWorkspace(db database.Store, chatID uuid.UUID, options StopWorkspaceOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"stop_workspace",
|
||||
"Stop the chat's workspace and wait for the stop build to complete. "+
|
||||
"If another workspace build is already in progress, this waits "+
|
||||
"for that build first, then stops the workspace if needed. "+
|
||||
"After waiting, this tool is idempotent if the workspace is "+
|
||||
"already stopped or the in-progress build stopped it. Use "+
|
||||
"this when the "+
|
||||
"user explicitly asks to stop the workspace, or when a "+
|
||||
"workspace-agent error tells you to stop and then start the "+
|
||||
"workspace. Stopping a workspace terminates running processes "+
|
||||
"and may discard unsaved in-memory state. This tool does not "+
|
||||
"delete the workspace.",
|
||||
func(ctx context.Context, _ stopWorkspaceArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.StopFn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace stopper is not configured"), nil
|
||||
}
|
||||
|
||||
// Serialize with create_workspace and start_workspace to
|
||||
// prevent lifecycle races.
|
||||
if options.WorkspaceMu != nil {
|
||||
options.WorkspaceMu.Lock()
|
||||
defer options.WorkspaceMu.Unlock()
|
||||
}
|
||||
|
||||
chat, err := db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("load chat: %w", err).Error(),
|
||||
), nil
|
||||
}
|
||||
if !chat.WorkspaceID.Valid {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
"chat has no workspace; use create_workspace first",
|
||||
), nil
|
||||
}
|
||||
|
||||
ws, err := db.GetWorkspaceByID(ctx, chat.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("load workspace: %w", err).Error(),
|
||||
), nil
|
||||
}
|
||||
if ws.Deleted {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
"workspace was deleted; use create_workspace to make a new one",
|
||||
), nil
|
||||
}
|
||||
|
||||
build, job, err := latestWorkspaceBuildAndJob(ctx, db, ws.ID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// If a build is already in progress, wait for it before
|
||||
// deciding whether a stop build is still needed.
|
||||
switch job.JobStatus {
|
||||
case database.ProvisionerJobStatusPending,
|
||||
database.ProvisionerJobStatusRunning,
|
||||
database.ProvisionerJobStatusCanceling:
|
||||
publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, build.ID, options.OnChatUpdated)
|
||||
|
||||
waitErr := waitForBuild(ctx, db, build.ID)
|
||||
// Re-read after waiting because another transition may
|
||||
// have completed while this tool was blocked.
|
||||
ws, err = db.GetWorkspaceByID(ctx, ws.ID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("load workspace: %w", err).Error(),
|
||||
), nil
|
||||
}
|
||||
if ws.Deleted {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
"workspace was deleted; use create_workspace to make a new one",
|
||||
), nil
|
||||
}
|
||||
build, job, err = latestWorkspaceBuildAndJob(ctx, db, ws.ID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
// The fresh job row is authoritative. A wait error can
|
||||
// be stale if the build reached a terminal state while the
|
||||
// wait context was ending.
|
||||
if waitErr != nil && !provisionerJobTerminal(job.JobStatus) {
|
||||
return buildToolResponse(newBuildError(
|
||||
xerrors.Errorf("waiting for in-progress build: %w", waitErr).Error(),
|
||||
build.ID,
|
||||
)), nil
|
||||
}
|
||||
}
|
||||
|
||||
if job.JobStatus == database.ProvisionerJobStatusSucceeded &&
|
||||
build.Transition == database.WorkspaceTransitionStop {
|
||||
result := map[string]any{
|
||||
"stopped": true,
|
||||
"workspace_name": ws.Name,
|
||||
}
|
||||
setNoBuild(result, uuid.Nil)
|
||||
return toolResponse(result), nil
|
||||
}
|
||||
|
||||
ownerCtx, ownerErr := asOwner(ctx, db, options.OwnerID)
|
||||
if ownerErr != nil {
|
||||
return fantasy.NewTextErrorResponse(ownerErr.Error()), nil
|
||||
}
|
||||
|
||||
stopBuild, err := options.StopFn(ownerCtx, options.OwnerID, ws.ID, codersdk.CreateWorkspaceBuildRequest{
|
||||
Transition: codersdk.WorkspaceTransitionStop,
|
||||
})
|
||||
if err != nil {
|
||||
if responseErr, ok := httperror.IsResponder(err); ok {
|
||||
_, resp := responseErr.Response()
|
||||
return toolResponse(responseErrorResult(resp)), nil
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(
|
||||
xerrors.Errorf("stop workspace: %w", err).Error(),
|
||||
), nil
|
||||
}
|
||||
|
||||
publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, stopBuild.ID, options.OnChatUpdated)
|
||||
if err := waitForBuild(ctx, db, stopBuild.ID); err != nil {
|
||||
return buildToolResponse(newBuildError(
|
||||
xerrors.Errorf("workspace stop build failed: %w", err).Error(),
|
||||
stopBuild.ID,
|
||||
)), nil
|
||||
}
|
||||
|
||||
if options.OnChatUpdated != nil {
|
||||
if latest, err := db.GetChatByID(ctx, chatID); err == nil {
|
||||
options.OnChatUpdated(latest)
|
||||
}
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"stopped": true,
|
||||
"workspace_name": ws.Name,
|
||||
}
|
||||
setBuildID(result, stopBuild.ID)
|
||||
return toolResponse(result), nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,449 @@
|
||||
package chattool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestStopWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NoWorkspace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
modelCfg := seedModelConfig(t, db)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
chat := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-stop-no-workspace",
|
||||
})
|
||||
|
||||
tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{
|
||||
StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
|
||||
t.Fatal("StopFn should not be called")
|
||||
return codersdk.WorkspaceBuild{}, nil
|
||||
},
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, resp.Content, "use create_workspace first")
|
||||
})
|
||||
|
||||
t.Run("DeletedWorkspace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
modelCfg := seedModelConfig(t, db)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
Deleted: true,
|
||||
}).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionDelete,
|
||||
}).Do()
|
||||
ws := wsResp.Workspace
|
||||
|
||||
chat := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-stop-deleted-workspace",
|
||||
})
|
||||
|
||||
tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{
|
||||
StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
|
||||
t.Fatal("StopFn should not be called for deleted workspace")
|
||||
return codersdk.WorkspaceBuild{}, nil
|
||||
},
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, resp.Content, "workspace was deleted")
|
||||
require.Contains(t, resp.Content, "create_workspace")
|
||||
})
|
||||
|
||||
t.Run("AlreadyStopped", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
modelCfg := seedModelConfig(t, db)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
}).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStop,
|
||||
}).Do()
|
||||
ws := wsResp.Workspace
|
||||
|
||||
chat := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-stop-already-stopped",
|
||||
})
|
||||
|
||||
tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{
|
||||
OwnerID: user.ID,
|
||||
StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
|
||||
t.Fatal("StopFn should not be called for already-stopped workspace")
|
||||
return codersdk.WorkspaceBuild{}, nil
|
||||
},
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
require.Equal(t, true, result["stopped"])
|
||||
require.Equal(t, ws.Name, result["workspace_name"])
|
||||
require.Equal(t, true, result["no_build"])
|
||||
require.Nil(t, result["build_id"])
|
||||
})
|
||||
|
||||
t.Run("RunningWorkspaceStops", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
modelCfg := seedModelConfig(t, db)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
}).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
}).Do()
|
||||
ws := wsResp.Workspace
|
||||
|
||||
chat := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-stop-running-workspace",
|
||||
})
|
||||
|
||||
var stopCalled atomic.Bool
|
||||
var stopBuildID uuid.UUID
|
||||
var seenBuildID uuid.UUID
|
||||
var onChatUpdatedCalls atomic.Int32
|
||||
tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{
|
||||
OwnerID: user.ID,
|
||||
StopFn: func(_ context.Context, _ uuid.UUID, wsID uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
|
||||
stopCalled.Store(true)
|
||||
require.Equal(t, ws.ID, wsID)
|
||||
require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition)
|
||||
buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStop,
|
||||
BuildNumber: 2,
|
||||
}).Do()
|
||||
stopBuildID = buildResp.Build.ID
|
||||
return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil
|
||||
},
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
OnChatUpdated: func(chat database.Chat) {
|
||||
onChatUpdatedCalls.Add(1)
|
||||
if chat.BuildID.Valid {
|
||||
seenBuildID = chat.BuildID.UUID
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
require.True(t, stopCalled.Load())
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
require.Equal(t, true, result["stopped"])
|
||||
require.Equal(t, ws.Name, result["workspace_name"])
|
||||
require.Equal(t, stopBuildID.String(), result["build_id"])
|
||||
require.Nil(t, result["no_build"])
|
||||
|
||||
require.GreaterOrEqual(t, onChatUpdatedCalls.Load(), int32(1))
|
||||
require.Equal(t, stopBuildID, seenBuildID)
|
||||
|
||||
updatedChat, err := db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, updatedChat.BuildID.Valid)
|
||||
require.Equal(t, stopBuildID, updatedChat.BuildID.UUID)
|
||||
})
|
||||
|
||||
t.Run("InProgressBuildWaitsThenStops", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
modelCfg := seedModelConfig(t, db)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
}).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
}).Starting().Do()
|
||||
ws := wsResp.Workspace
|
||||
|
||||
chat := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-stop-in-progress-build",
|
||||
})
|
||||
|
||||
jobRead := make(chan struct{}, 1)
|
||||
wrappedDB := &jobInterceptStore{Store: db, jobRead: jobRead}
|
||||
var stopCalled atomic.Bool
|
||||
var stopBuildID uuid.UUID
|
||||
var onChatUpdatedCalled atomic.Bool
|
||||
tool := chattool.StopWorkspace(wrappedDB, chat.ID, chattool.StopWorkspaceOptions{
|
||||
OwnerID: user.ID,
|
||||
StopFn: func(_ context.Context, _ uuid.UUID, wsID uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
|
||||
stopCalled.Store(true)
|
||||
require.Equal(t, ws.ID, wsID)
|
||||
require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition)
|
||||
buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStop,
|
||||
BuildNumber: 2,
|
||||
}).Do()
|
||||
stopBuildID = buildResp.Build.ID
|
||||
return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil
|
||||
},
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
OnChatUpdated: func(_ database.Chat) { onChatUpdatedCalled.Store(true) },
|
||||
})
|
||||
|
||||
type toolResult struct {
|
||||
resp fantasy.ToolResponse
|
||||
err error
|
||||
}
|
||||
done := make(chan toolResult, 1)
|
||||
go func() {
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
|
||||
done <- toolResult{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
testutil.TryReceive(ctx, t, jobRead)
|
||||
require.False(t, stopCalled.Load(), "StopFn must wait for the in-progress build")
|
||||
|
||||
now := time.Now().UTC()
|
||||
require.NoError(t, db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
ID: wsResp.Build.JobID,
|
||||
UpdatedAt: now,
|
||||
CompletedAt: sql.NullTime{Time: now, Valid: true},
|
||||
}))
|
||||
|
||||
res := testutil.TryReceive(ctx, t, done)
|
||||
require.NoError(t, res.err)
|
||||
require.True(t, stopCalled.Load())
|
||||
require.True(t, onChatUpdatedCalled.Load())
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(res.resp.Content), &result))
|
||||
require.Equal(t, true, result["stopped"])
|
||||
require.Equal(t, stopBuildID.String(), result["build_id"])
|
||||
})
|
||||
|
||||
t.Run("FailedLatestStopBuildStillStops", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
modelCfg := seedModelConfig(t, db)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
}).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStop,
|
||||
}).Do()
|
||||
ws := wsResp.Workspace
|
||||
now := time.Now().UTC()
|
||||
require.NoError(t, db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
ID: wsResp.Build.JobID,
|
||||
UpdatedAt: now,
|
||||
CompletedAt: sql.NullTime{Time: now, Valid: true},
|
||||
Error: sql.NullString{String: "latest build failed", Valid: true},
|
||||
}))
|
||||
|
||||
chat := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-stop-failed-latest-build",
|
||||
})
|
||||
|
||||
var stopCalled atomic.Bool
|
||||
tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{
|
||||
OwnerID: user.ID,
|
||||
StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
|
||||
stopCalled.Store(true)
|
||||
require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition)
|
||||
buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStop,
|
||||
BuildNumber: 2,
|
||||
}).Do()
|
||||
return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil
|
||||
},
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
require.True(t, stopCalled.Load())
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
require.Equal(t, true, result["stopped"])
|
||||
})
|
||||
|
||||
t.Run("StopTriggeredBuildFailure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
modelCfg := seedModelConfig(t, db)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
}).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
}).Do()
|
||||
ws := wsResp.Workspace
|
||||
|
||||
chat := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-stop-triggered-build-failure",
|
||||
})
|
||||
|
||||
var stopBuildJobID uuid.UUID
|
||||
var stopBuildID uuid.UUID
|
||||
stopFn := func(_ context.Context, _ uuid.UUID, wsID uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
|
||||
require.Equal(t, ws.ID, wsID)
|
||||
require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition)
|
||||
buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStop,
|
||||
BuildNumber: 2,
|
||||
}).Starting().Do()
|
||||
stopBuildJobID = buildResp.Build.JobID
|
||||
stopBuildID = buildResp.Build.ID
|
||||
return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil
|
||||
}
|
||||
|
||||
jobRead := make(chan struct{}, 2)
|
||||
wrappedDB := &jobInterceptStore{Store: db, jobRead: jobRead}
|
||||
tool := chattool.StopWorkspace(wrappedDB, chat.ID, chattool.StopWorkspaceOptions{
|
||||
OwnerID: user.ID,
|
||||
StopFn: stopFn,
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
})
|
||||
|
||||
type toolResult struct {
|
||||
resp fantasy.ToolResponse
|
||||
err error
|
||||
}
|
||||
done := make(chan toolResult, 1)
|
||||
go func() {
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"})
|
||||
done <- toolResult{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
testutil.TryReceive(ctx, t, jobRead)
|
||||
testutil.TryReceive(ctx, t, jobRead)
|
||||
|
||||
now := time.Now().UTC()
|
||||
require.NoError(t, db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
ID: stopBuildJobID,
|
||||
UpdatedAt: now,
|
||||
CompletedAt: sql.NullTime{Time: now, Valid: true},
|
||||
Error: sql.NullString{String: "terraform destroy failed", Valid: true},
|
||||
}))
|
||||
|
||||
res := testutil.TryReceive(ctx, t, done)
|
||||
require.NoError(t, res.err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(res.resp.Content), &result))
|
||||
require.Contains(t, result["error"], "workspace stop build failed")
|
||||
require.Equal(t, stopBuildID.String(), result["build_id"])
|
||||
require.False(t, res.resp.IsError,
|
||||
"buildToolResponse must not set IsError; chatprompt strips structured fields from error responses")
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user