From bd6cc1aaf2f4f2479b591ebe7e99786d9dbab3cc Mon Sep 17 00:00:00 2001 From: Ethan <39577870+ethanndickson@users.noreply.github.com> Date: Mon, 11 May 2026 16:23:07 +1000 Subject: [PATCH] feat(coderd): add stop_workspace chatd tool and recovery classification (#24997) ## Summary Adds a `stop_workspace` tool to chatd so the model can recover from the "workspace running but agent dead" failure mode (e.g. an OOM that leaves the workspace running but the agent unreachable) by stopping and then starting the workspace. image ## What changed **New `stop_workspace` chatd tool** (`coderd/x/chatd/chattool/stopworkspace.go`). Mirrors `start_workspace`: shares `WorkspaceMu` to serialize with create/start, waits for any in-progress build before issuing a stop, and is idempotent only after a successful Stop transition. Failed stop builds re-attempt rather than reporting success. **New `chatStopWorkspace` coderd hook** (`coderd/exp_chats.go`). Mirrors `chatStartWorkspace` minus the `RequireActiveVersion` gate. Stop should not be blocked by template version policy. **Differentiated recovery sentinels** (`coderd/x/chatd/chatd.go`). `errChatAgentDisconnected` instructs the model to call `stop_workspace` then `start_workspace`. `errChatDialTimeout` instructs a single retry, then user escalation if it repeats. The previous single message conflated transient and persistent failures. **Two-signal recovery gate.** Recovery is only surfaced when a tool call times out *and* a fresh DB read of the latest workspace agent says `Disconnected`. The previous draft escalated on the DB read alone, which would fire on a 30-second heartbeat blip (e.g. agent respawn) and prompt a destructive stop/start unnecessarily. **Cache-hit disconnected handling** now clears the cache and retries a fresh dial before escalating, rather than returning the recovery sentinel immediately. 
Latest-agent classification uses `GetWorkspaceAgentsInLatestBuildByWorkspaceID` instead of the chat's bound `AgentID`, so stale bindings after a rebuild don't misclassify. **Shared chattool helpers** in `coderd/x/chatd/chattool/chattool.go`: `latestWorkspaceBuildAndJob`, `publishBuildBinding`, `provisionerJobTerminal`. Applied to both `start_workspace` and `stop_workspace`. ## Notes - Reverts an earlier draft that widened `ask_user_question` to root standard turns. Plan-mode-only behavior is restored. - The `stop_workspace` tool currently renders via the generic chat tool-call UI. A follow-up frontend PR will prettify the `stop_workspace` tool and style it like the `start_workspace` tool. - Never-connected (`Timeout` status) agents are intentionally excluded from recovery. They indicate template or startup failure, not the running-but-dead case this PR targets. Closes CODAGT-315 --- coderd/coderd.go | 1 + coderd/exp_chats.go | 47 ++ coderd/exp_chats_test.go | 37 ++ coderd/export_test.go | 3 + coderd/x/chatd/chatd.go | 83 +++- coderd/x/chatd/chatd_internal_test.go | 433 ++++++++++++++--- coderd/x/chatd/chatd_test.go | 14 +- coderd/x/chatd/chattool/chattool.go | 63 +++ coderd/x/chatd/chattool/startworkspace.go | 53 +-- coderd/x/chatd/chattool/stopworkspace.go | 181 +++++++ coderd/x/chatd/chattool/stopworkspace_test.go | 449 ++++++++++++++++++ 11 files changed, 1250 insertions(+), 114 deletions(-) create mode 100644 coderd/x/chatd/chattool/stopworkspace.go create mode 100644 coderd/x/chatd/chattool/stopworkspace_test.go diff --git a/coderd/coderd.go b/coderd/coderd.go index 619a91f7b0..3db9bea92d 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -805,6 +805,7 @@ func New(options *Options) *API { InstructionLookupTimeout: options.ChatdInstructionLookupTimeout, CreateWorkspace: api.chatCreateWorkspace, StartWorkspace: api.chatStartWorkspace, + StopWorkspace: api.chatStopWorkspace, Pubsub: options.Pubsub, WebpushDispatcher: options.WebPushDispatcher, UsageTracker: 
options.WorkspaceUsageTracker, diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 43b1f105a5..f7ed4b8498 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -3740,6 +3740,53 @@ func (api *API) chatStartWorkspace( return apiBuild, nil } +// chatStopWorkspace stops a workspace by creating a new build with the +// "stop" transition. It mirrors chatStartWorkspace, without start-only +// active-version behavior. +func (api *API) chatStopWorkspace( + ctx context.Context, + ownerID uuid.UUID, + workspaceID uuid.UUID, + req codersdk.CreateWorkspaceBuildRequest, +) (codersdk.WorkspaceBuild, error) { + actor, _, err := httpmw.UserRBACSubject(ctx, api.Database, ownerID, rbac.ScopeAll) + if err != nil { + return codersdk.WorkspaceBuild{}, xerrors.Errorf("load user authorization: %w", err) + } + ctx = dbauthz.As(ctx, actor) + + workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceID) + if err != nil { + return codersdk.WorkspaceBuild{}, xerrors.Errorf("get workspace: %w", err) + } + + req.Transition = codersdk.WorkspaceTransitionStop + + // Build a synthetic API key so postWorkspaceBuildsInternal can + // record the correct initiator. + syntheticKey := database.APIKey{ + UserID: ownerID, + } + + apiBuild, err := api.postWorkspaceBuildsInternal( + ctx, + syntheticKey, + workspace, + req, + func(action policy.Action, object rbac.Objecter) bool { + // Authorization is handled by dbauthz on the context. 
+ authErr := api.HTTPAuth.Authorizer.Authorize(ctx, actor, action, object.RBACObject()) + return authErr == nil + }, + audit.WorkspaceBuildBaggage{}, + ) + if err != nil { + return codersdk.WorkspaceBuild{}, xerrors.Errorf("create workspace build: %w", err) + } + + return apiBuild, nil +} + func rewriteChatStartWorkspaceManualUpdateResponse(resp codersdk.Response, fallbackDetail string, retryInstructions string) codersdk.Response { originalMessage := resp.Message resp.Message = retryInstructions diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index a8e944c016..4d20d50840 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -13310,6 +13310,43 @@ func TestChatStartWorkspace_RequireActiveVersion(t *testing.T) { require.Nil(t, build.TemplateVersionPresetID, "no preset must be applied") } +func TestChatStopWorkspace_BypassesRequireActiveVersion(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + rawClient, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{}) + var store dbauthz.AccessControlStore = requireActiveVersionStore{} + api.AccessControlStore.Store(&store) + db := api.Database + user := coderdtest.CreateFirstUser(t, rawClient) + + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.UserID, + OrganizationID: user.OrganizationID, + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStart, + }).Do() + v1ID := wsResp.Build.TemplateVersionID + tmplID := wsResp.Workspace.TemplateID + + v2Resp := dbfake.TemplateVersion(t, db).Seed(database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tmplID, Valid: true}, + OrganizationID: user.OrganizationID, + CreatedBy: user.UserID, + }).Do() + v2 := v2Resp.TemplateVersion + require.NotEqual(t, v1ID, v2.ID, "v2 must differ from v1") + + build, err := coderd.ChatStopWorkspace(api, ctx, user.UserID, wsResp.Workspace.ID, + codersdk.CreateWorkspaceBuildRequest{}) + + require.NoError(t, err) + require.Equal(t, 
codersdk.WorkspaceTransitionStop, build.Transition) + require.Equal(t, v1ID, build.TemplateVersionID, + "stop must not apply RequireActiveVersion start-only logic") + require.NotEqual(t, v2.ID, build.TemplateVersionID) +} + func TestGetChatMessages_Pagination(t *testing.T) { t.Parallel() diff --git a/coderd/export_test.go b/coderd/export_test.go index 44f24a09ba..475270b994 100644 --- a/coderd/export_test.go +++ b/coderd/export_test.go @@ -11,3 +11,6 @@ var InsertAgentChatTestModelConfig = insertAgentChatTestModelConfig // stubbing the entire DB layer. The proper fix is to extract a pure // request builder; tracked in CODAGT-292. var ChatStartWorkspace = (*API).chatStartWorkspace + +// ChatStopWorkspace exposes chatStopWorkspace for external tests. +var ChatStopWorkspace = (*API).chatStopWorkspace diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 54dfd43963..ee9b6f08c6 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -121,6 +121,13 @@ const ( // streamJanitorInterval. streamJanitorInterval = 30 * time.Second + // agentDisconnectedRecoveryThreshold is how long the latest + // workspace agent must be disconnected before chatd suggests + // destructive stop/start recovery. This is intentionally longer + // than the inactive-disconnect timeout so short heartbeat gaps do + // not prompt a workspace restart. + agentDisconnectedRecoveryThreshold = 90 * time.Second + // DefaultMaxChatsPerAcquire is the maximum number of chats to // acquire in a single processOnce call. Batching avoids // waiting a full polling interval between acquisitions @@ -139,12 +146,14 @@ const ( var ( errChatHasNoWorkspaceAgent = xerrors.New("workspace has no running agent: the workspace is likely stopped. Use the start_workspace tool to start it") errChatAgentDisconnected = xerrors.New( - "workspace agent is disconnected and cannot execute tools. 
" + - "The workspace may need to be restarted from the Coder dashboard", + "workspace agent has been disconnected for at least 90 seconds " + + "and cannot execute tools. To recover, call stop_workspace " + + "to stop the workspace, then start_workspace to start it " + + "again", ) errChatDialTimeout = xerrors.New( "connection to the workspace agent timed out. " + - "The workspace may need to be restarted from the Coder dashboard", + "The agent may still be reachable on the next attempt.", ) errChatExternalAgentUnavailable = xerrors.New("external workspace agent unavailable") ) @@ -187,6 +196,7 @@ type Server struct { instructionLookupTimeout time.Duration createWorkspaceFn chattool.CreateWorkspaceFn startWorkspaceFn chattool.StartWorkspaceFn + stopWorkspaceFn chattool.StopWorkspaceFn pubsub pubsub.Pubsub webpushDispatcher webpush.Dispatcher providerAPIKeys chatprovider.ProviderAPIKeys @@ -786,6 +796,45 @@ func isAgentUnreachable(now time.Time, agent database.WorkspaceAgent, inactiveTi status.Status == database.WorkspaceAgentStatusTimeout } +func agentDisconnectedFor(now time.Time, agent database.WorkspaceAgent, inactiveTimeout time.Duration) (time.Duration, bool) { + status := agent.Status(now, inactiveTimeout) + if status.Status != database.WorkspaceAgentStatusDisconnected || status.DisconnectedAt == nil { + return 0, false + } + + disconnectedFor := now.Sub(*status.DisconnectedAt) + if disconnectedFor < 0 { + disconnectedFor = 0 + } + return disconnectedFor, true +} + +func (c *turnWorkspaceContext) latestWorkspaceAgentNeedsRestart( + ctx context.Context, + workspaceID uuid.UUID, +) (bool, error) { + agentID, err := c.latestWorkspaceAgentID(ctx, workspaceID) + if err != nil { + if xerrors.Is(err, errChatHasNoWorkspaceAgent) { + return false, err + } + c.server.logger.Warn(ctx, "failed to resolve latest agent for timeout classification", slog.Error(err)) + return false, nil + } + + agent, err := c.server.db.GetWorkspaceAgentByID(ctx, agentID) + if err != nil { + 
c.server.logger.Warn(ctx, "failed to load latest agent for timeout classification", + slog.F("agent_id", agentID), + slog.Error(err), + ) + return false, nil + } + + disconnectedFor, disconnected := agentDisconnectedFor(c.server.clock.Now(), agent, c.server.agentInactiveDisconnectTimeout) + return disconnected && disconnectedFor >= agentDisconnectedRecoveryThreshold, nil +} + func (c *turnWorkspaceContext) externalAgentError( ctx context.Context, agent database.WorkspaceAgent, @@ -853,9 +902,13 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces ) // On DB error the check re-runs on the // next tool call. - } else if isAgentUnreachable(c.server.clock.Now(), freshAgent, c.server.agentInactiveDisconnectTimeout) { + } else if _, disconnected := agentDisconnectedFor( + c.server.clock.Now(), + freshAgent, + c.server.agentInactiveDisconnectTimeout, + ); disconnected { c.clearCachedWorkspaceState() - return nil, c.externalAgentError(ctx, freshAgent, errChatAgentDisconnected) + continue } } return currentConn, nil @@ -898,6 +951,14 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces // canceled (e.g. ErrInterrupted), its error must // propagate unchanged so the chatloop can detect it. 
if ctx.Err() == nil && errors.Is(context.Cause(dialCtx), errChatDialTimeout) { + c.clearCachedWorkspaceState() + needsRestart, statusErr := c.latestWorkspaceAgentNeedsRestart(ctx, chatSnapshot.WorkspaceID.UUID) + if statusErr != nil { + return nil, statusErr + } + if needsRestart { + return nil, c.externalAgentError(ctx, agent, errChatAgentDisconnected) + } return nil, c.externalAgentError(ctx, agent, errChatDialTimeout) } return nil, err @@ -3752,6 +3813,7 @@ type Config struct { InstructionLookupTimeout time.Duration CreateWorkspace chattool.CreateWorkspaceFn StartWorkspace chattool.StartWorkspaceFn + StopWorkspace chattool.StopWorkspaceFn Pubsub pubsub.Pubsub ProviderAPIKeys chatprovider.ProviderAPIKeys AlwaysEnableDebugLogs bool @@ -3820,6 +3882,7 @@ func New(cfg Config) *Server { instructionLookupTimeout: instructionLookupTimeout, createWorkspaceFn: cfg.CreateWorkspace, startWorkspaceFn: cfg.StartWorkspace, + stopWorkspaceFn: cfg.StopWorkspace, pubsub: cfg.Pubsub, webpushDispatcher: cfg.WebpushDispatcher, providerAPIKeys: cfg.ProviderAPIKeys, @@ -5899,7 +5962,7 @@ func builtinPlanToolAllowed(name string, isRootChat bool) bool { case "read_file", "execute", "process_output", "read_skill", "read_skill_file": return true case "write_file", "edit_files", "list_templates", "read_template", - "create_workspace", "start_workspace", "propose_plan", "spawn_agent", + "create_workspace", "start_workspace", "stop_workspace", "propose_plan", "spawn_agent", "spawn_explore_agent", "wait_agent", "ask_user_question": return isRootChat case "process_list", "process_signal", "message_agent", "close_agent", @@ -5979,6 +6042,7 @@ func allowedExploreToolNames(allTools []fantasy.AgentTool) []string { "read_template": false, "create_workspace": false, "start_workspace": false, + "stop_workspace": false, "propose_plan": false, "spawn_agent": false, "wait_agent": false, @@ -6198,6 +6262,13 @@ func (p *Server) appendRootChatTools( OnChatUpdated: onChatUpdated, Logger: p.logger, }), + 
chattool.StopWorkspace(p.db, opts.chat.ID, chattool.StopWorkspaceOptions{ + OwnerID: opts.chat.OwnerID, + StopFn: p.stopWorkspaceFn, + WorkspaceMu: opts.workspaceMu, + OnChatUpdated: onChatUpdated, + Logger: p.logger, + }), ) if opts.isPlanModeTurn { tools = append(tools, chattool.ProposePlan(chattool.ProposePlanOptions{ diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index b23ec23d18..1f80e6c81d 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -295,6 +295,26 @@ func TestFilterExternalMCPConfigsForTurn(t *testing.T) { }) } +func TestChatWorkspaceRecoveryErrorsDifferentiateSignalStrength(t *testing.T) { + t.Parallel() + + // Disconnected recovery is gated by a DB-confirmed duration + // threshold, so the message can give direct stop/start guidance + // without asking the user. + disconnected := errChatAgentDisconnected.Error() + require.Contains(t, disconnected, "90 seconds") + require.Contains(t, disconnected, "stop_workspace") + require.Contains(t, disconnected, "start_workspace") + require.NotContains(t, disconnected, "ask_user_question") + + // Dial timeout alone is a weak signal. The model should not + // escalate to lifecycle tools without DB-confirmed disconnect. 
+ dialTimeout := errChatDialTimeout.Error() + require.NotContains(t, dialTimeout, "ask_user_question") + require.NotContains(t, dialTimeout, "stop_workspace") + require.NotContains(t, dialTimeout, "start_workspace") +} + func TestActiveToolNamesForTurn(t *testing.T) { t.Parallel() @@ -344,6 +364,7 @@ func TestActiveToolNamesForTurn(t *testing.T) { "read_template", "create_workspace", "start_workspace", + "stop_workspace", "propose_plan", "spawn_agent", "wait_agent", @@ -364,6 +385,7 @@ func TestActiveToolNamesForTurn(t *testing.T) { "read_template", "create_workspace", "start_workspace", + "stop_workspace", "propose_plan", "spawn_agent", "wait_agent", @@ -386,6 +408,7 @@ func TestActiveToolNamesForTurn(t *testing.T) { "read_template", "create_workspace", "start_workspace", + "stop_workspace", "propose_plan", "spawn_agent", "wait_agent", @@ -405,6 +428,8 @@ func TestActiveToolNamesForTurn(t *testing.T) { require.NotContains(t, got, "edit_files") require.NotContains(t, got, "ask_user_question") require.NotContains(t, got, "propose_plan") + require.NotContains(t, got, "start_workspace") + require.NotContains(t, got, "stop_workspace") require.NotContains(t, got, "spawn_explore_agent") }) @@ -474,6 +499,8 @@ func TestAllowedExploreToolNames(t *testing.T) { newTestAgentTool("write_file"), newTestMCPAgentTool("external-mcp__echo", externalConfigID), newTestAgentTool("workspace-mcp__echo"), + newTestAgentTool("start_workspace"), + newTestAgentTool("stop_workspace"), newTestAgentTool("execute"), newTestAgentTool("process_output"), newTestAgentTool("process_list"), @@ -494,6 +521,9 @@ func TestAllowedExploreToolNames(t *testing.T) { "read_skill_file", }, got) require.NotContains(t, got, "workspace-mcp__echo") + require.NotContains(t, got, "start_workspace") + require.NotContains(t, got, "stop_workspace") + require.NotContains(t, got, "ask_user_question") } func TestAllowedBehaviorToolNames(t *testing.T) { @@ -580,19 +610,13 @@ func TestStopAfterBehaviorTools(t *testing.T) { 
)) }) - t.Run("RootPlanModeIncludesClarificationTool", func(t *testing.T) { + t.Run("PlanModeDelegatesToPlanTools", func(t *testing.T) { t.Parallel() - require.Equal(t, map[string]struct{}{ - "propose_plan": {}, - "ask_user_question": {}, - }, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{})) - }) - - t.Run("ChildPlanModeSkipsClarificationTool", func(t *testing.T) { - t.Parallel() - require.Equal(t, map[string]struct{}{ - "propose_plan": {}, - }, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{UUID: uuid.New(), Valid: true})) + require.Equal(t, stopAfterPlanTools(planMode, uuid.NullUUID{}), stopAfterBehaviorTools( + planMode, + database.NullChatMode{}, + uuid.NullUUID{}, + )) }) t.Run("ExploreModeReturnsNil", func(t *testing.T) { @@ -4229,46 +4253,28 @@ func TestGetWorkspaceConn_SameBuildAgentCrash(t *testing.T) { func TestGetWorkspaceConn_StatusCheck(t *testing.T) { // The cache-hit status check re-fetches the agent row for a fresh - // heartbeat timestamp. These tests verify that path detects - // disconnected or timed-out agents and that healthy or DB-error - // paths return the cached connection. + // heartbeat timestamp. Healthy, timed-out, and DB-error paths return + // the cached connection. Disconnected agents are covered separately + // because they now trigger a fresh dial before recovery. 
t.Parallel() type testCase struct { - name string - agent database.WorkspaceAgent - dbError bool - wantErr error - wantReleaseCalled bool + name string + agent database.WorkspaceAgent + dbError bool } tests := []testCase{ - { - name: "DisconnectedAgentCacheHit", - agent: database.WorkspaceAgent{ - FirstConnectedAt: sql.NullTime{ - Time: time.Now().Add(-10 * time.Minute), - Valid: true, - }, - LastConnectedAt: sql.NullTime{ - Time: time.Now().Add(-10 * time.Minute), - Valid: true, - }, - }, - wantErr: errChatAgentDisconnected, - wantReleaseCalled: true, - }, { // Agent never connected and the connection timeout - // has elapsed. This is the cache-hit timeout branch - // of isAgentUnreachable. + // has elapsed. This should not trigger lifecycle + // recovery because the agent did not connect and + // then disconnect. name: "TimedOutAgentCacheHit", agent: database.WorkspaceAgent{ CreatedAt: time.Now().Add(-10 * time.Minute), ConnectionTimeoutSeconds: 60, }, - wantErr: errChatAgentDisconnected, - wantReleaseCalled: true, }, { name: "CacheHitHealthyAgent", @@ -4371,28 +4377,274 @@ func TestGetWorkspaceConn_StatusCheck(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) gotConn, err := workspaceCtx.getWorkspaceConn(ctx) - - if tc.wantErr != nil { - require.Nil(t, gotConn) - require.ErrorIs(t, err, tc.wantErr) - } else { - require.NoError(t, err) - require.Same(t, cachedConn, gotConn) - } - - require.Equal(t, tc.wantReleaseCalled, releaseCalled, "release called") - - // For cache-hit disconnect, the cache should be cleared. 
- if tc.wantErr != nil { - workspaceCtx.mu.Lock() - defer workspaceCtx.mu.Unlock() - require.False(t, workspaceCtx.agentLoaded) - require.Nil(t, workspaceCtx.conn) - } + require.NoError(t, err) + require.Same(t, cachedConn, gotConn) + require.False(t, releaseCalled, "release called") }) } } +func TestGetWorkspaceConn_DialTimeoutDisconnectedRecoveryThreshold(t *testing.T) { + // The recovery sentinel requires a failed dial and a fresh + // disconnected status check past the recovery threshold. A + // disconnected DB row alone is not enough to trigger stop/start + // recovery. + t.Parallel() + + testCases := []struct { + name string + disconnectedFor time.Duration + wantErr error + wantRecovery bool + }{ + { + name: "RecentDisconnectReturnsDialTimeout", + disconnectedFor: agentDisconnectedRecoveryThreshold / 2, + wantErr: errChatDialTimeout, + wantRecovery: false, + }, + { + name: "PastThresholdEscalates", + disconnectedFor: agentDisconnectedRecoveryThreshold, + wantErr: errChatAgentDisconnected, + wantRecovery: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + agentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + } + + clock := quartz.NewMock(t) + now := clock.Now() + disconnectedAgent := database.WorkspaceAgent{ + ID: agentID, + FirstConnectedAt: sql.NullTime{ + Time: now.Add(-10 * time.Minute), + Valid: true, + }, + LastConnectedAt: sql.NullTime{ + Time: now.Add(-10 * time.Minute), + Valid: true, + }, + DisconnectedAt: sql.NullTime{ + Time: now.Add(-tc.disconnectedFor), + Valid: true, + }, + } + + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID). + Return(disconnectedAgent, nil). 
+ Times(2) + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID). + Return([]database.WorkspaceAgent{disconnectedAgent}, nil). + Times(1) + + server := &Server{ + db: db, + logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + clock: clock, + agentInactiveDisconnectTimeout: 30 * time.Second, + dialTimeout: 10 * time.Millisecond, + } + server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) { + <-ctx.Done() + return nil, nil, ctx.Err() + } + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: server, + chatStateMu: chatStateMu, + currentChat: &currentChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + defer workspaceCtx.close() + + ctx := testutil.Context(t, testutil.WaitShort) + gotConn, err := workspaceCtx.getWorkspaceConn(ctx) + require.Nil(t, gotConn) + require.ErrorIs(t, err, tc.wantErr) + if tc.wantRecovery { + require.ErrorIs(t, err, errChatAgentDisconnected) + } else { + require.NotErrorIs(t, err, errChatAgentDisconnected) + } + + workspaceCtx.mu.Lock() + defer workspaceCtx.mu.Unlock() + require.False(t, workspaceCtx.agentLoaded) + require.Nil(t, workspaceCtx.conn) + }) + } +} + +func TestGetWorkspaceConn_DisconnectedStatusDialSuccessDoesNotEscalate(t *testing.T) { + // A stale disconnected row must not prompt stop/start if the + // agent can still be dialed successfully.
+ t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + agentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + } + + disconnectedAgent := database.WorkspaceAgent{ + ID: agentID, + FirstConnectedAt: sql.NullTime{ + Time: time.Now().Add(-10 * time.Minute), + Valid: true, + }, + LastConnectedAt: sql.NullTime{ + Time: time.Now().Add(-10 * time.Minute), + Valid: true, + }, + } + + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID). + Return(disconnectedAgent, nil). + Times(1) + + server := &Server{ + db: db, + logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + clock: quartz.NewReal(), + agentInactiveDisconnectTimeout: 30 * time.Second, + dialTimeout: 10 * time.Millisecond, + } + conn := agentconnmock.NewMockAgentConn(ctrl) + conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1) + var dialCalled bool + server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) { + dialCalled = true + return conn, nil, nil + } + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: server, + chatStateMu: chatStateMu, + currentChat: &currentChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + defer workspaceCtx.close() + + ctx := testutil.Context(t, testutil.WaitShort) + gotConn, err := workspaceCtx.getWorkspaceConn(ctx) + require.NoError(t, err) + require.Same(t, conn, gotConn) + require.True(t, dialCalled, "dial called") +} + +func TestGetWorkspaceConn_CacheHitDisconnectedRetriesDialBeforeEscalating(t *testing.T) { + // A disconnected cached connection is discarded first. Recovery is + // only surfaced if the replacement dial also times out.
+ t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + agentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + } + disconnectedAgent := database.WorkspaceAgent{ + ID: agentID, + FirstConnectedAt: sql.NullTime{ + Time: time.Now().Add(-10 * time.Minute), + Valid: true, + }, + LastConnectedAt: sql.NullTime{ + Time: time.Now().Add(-10 * time.Minute), + Valid: true, + }, + } + + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID). + Return(disconnectedAgent, nil). + Times(2) + + server := &Server{ + db: db, + logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + clock: quartz.NewReal(), + agentInactiveDisconnectTimeout: 30 * time.Second, + dialTimeout: 10 * time.Millisecond, + } + newConn := agentconnmock.NewMockAgentConn(ctrl) + newConn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1) + var dialCalled bool + server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) { + dialCalled = true + return newConn, nil, nil + } + + var releaseCalled bool + chatStateMu := &sync.Mutex{} + currentChat := chat + oldConn := agentconnmock.NewMockAgentConn(ctrl) + workspaceCtx := turnWorkspaceContext{ + server: server, + chatStateMu: chatStateMu, + currentChat: &currentChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + agent: disconnectedAgent, + agentLoaded: true, + conn: oldConn, + releaseConn: func() { releaseCalled = true }, + cachedWorkspaceID: chat.WorkspaceID, + } + defer workspaceCtx.close() + + ctx := testutil.Context(t, testutil.WaitShort) + gotConn, err := workspaceCtx.getWorkspaceConn(ctx) + require.NoError(t, err) + require.Same(t, newConn, gotConn) + require.True(t, releaseCalled, "release called") + require.True(t, dialCalled, "dial called") +} + func
TestGetWorkspaceConn_DialTimeout(t *testing.T) { // When dialWithLazyValidation blocks beyond the dial // timeout, getWorkspaceConn should return @@ -4431,6 +4683,9 @@ func TestGetWorkspaceConn_DialTimeout(t *testing.T) { db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID). Return(connectedAgent, nil). + Times(2) + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID). + Return([]database.WorkspaceAgent{connectedAgent}, nil). Times(1) server := &Server{ @@ -4461,6 +4716,70 @@ func TestGetWorkspaceConn_DialTimeout(t *testing.T) { require.ErrorIs(t, err, errChatDialTimeout) } +func TestGetWorkspaceConn_DialTimeoutStatusTimeoutDoesNotEscalate(t *testing.T) { + // Agents that never connected are startup failures, not + // disconnected recovery cases. A dial timeout should stay a + // retry/escalation error rather than stop/start guidance. + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + agentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + } + + timedOutAgent := database.WorkspaceAgent{ + ID: agentID, + CreatedAt: time.Now().Add(-10 * time.Minute), + ConnectionTimeoutSeconds: 60, + } + + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID). + Return(timedOutAgent, nil). + Times(2) + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID). + Return([]database.WorkspaceAgent{timedOutAgent}, nil). 
+ Times(1) + + server := &Server{ + db: db, + clock: quartz.NewReal(), + agentInactiveDisconnectTimeout: 30 * time.Second, + dialTimeout: 10 * time.Millisecond, + } + server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) { + <-ctx.Done() + return nil, nil, ctx.Err() + } + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: server, + chatStateMu: chatStateMu, + currentChat: &currentChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + defer workspaceCtx.close() + + ctx := testutil.Context(t, testutil.WaitShort) + gotConn, err := workspaceCtx.getWorkspaceConn(ctx) + require.Nil(t, gotConn) + require.ErrorIs(t, err, errChatDialTimeout) + require.NotErrorIs(t, err, errChatAgentDisconnected) +} + func TestGetWorkspaceConn_DialTimeoutParentCanceled(t *testing.T) { // When the parent context is canceled, the parent's error // must propagate unchanged (not wrapped as a dial timeout). diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index 34ba65218e..a66c4d2c6a 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -344,7 +344,10 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) { require.GreaterOrEqual(t, len(recorded), 2, "expected at least 2 streamed LLM calls (root + subagent)") - workspaceTools := []string{"list_templates", "read_template", "create_workspace"} + workspaceTools := []string{ + "list_templates", "read_template", "create_workspace", + "start_workspace", "stop_workspace", + } subagentTools := []string{"spawn_agent", "wait_agent", "message_agent", "close_agent"} // Identify root and subagent calls. Root chat calls include @@ -375,7 +378,10 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) { "root chat should have subagent tool %q", tool) } - // Standard turns (no turn mode) should hide propose_plan.
+ // Standard turns (no turn mode) hide plan-only tools until + // plan mode. + require.NotContains(t, rootCalls[0], "ask_user_question", + "standard-turn root chat should NOT have ask_user_question") require.NotContains(t, rootCalls[0], "propose_plan", "standard-turn root chat should NOT have propose_plan") @@ -388,6 +394,8 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) { require.NotContains(t, childCalls[0], tool, "subagent chat should NOT have subagent tool %q", tool) } + require.NotContains(t, childCalls[0], "ask_user_question", + "subagent chat should NOT have ask_user_question") } func TestPlanModeSubagentChatExcludesAskUserQuestion(t *testing.T) { @@ -7030,7 +7038,7 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) { // 4. Verify workspace provisioning tools are NOT present. workspaceProvisioningTools := []string{ "list_templates", "read_template", - "create_workspace", "start_workspace", + "create_workspace", "start_workspace", "stop_workspace", } for _, tool := range workspaceProvisioningTools { require.NotContains(t, childTools, tool, diff --git a/coderd/x/chatd/chattool/chattool.go b/coderd/x/chatd/chattool/chattool.go index 69b65f3e10..6f7adadcdf 100644 --- a/coderd/x/chatd/chattool/chattool.go +++ b/coderd/x/chatd/chattool/chattool.go @@ -1,12 +1,16 @@ package chattool import ( + "context" "encoding/json" "unicode/utf8" "charm.land/fantasy" "github.com/google/uuid" + "golang.org/x/xerrors" + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/codersdk" ) @@ -54,6 +58,65 @@ func responseErrorResult(resp codersdk.Response) map[string]any { return result } +func latestWorkspaceBuildAndJob( + ctx context.Context, + db database.Store, + workspaceID uuid.UUID, +) (database.WorkspaceBuild, database.ProvisionerJob, error) { + build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID) + if err != nil { + return database.WorkspaceBuild{}, database.ProvisionerJob{}, 
xerrors.Errorf("get latest build: %w", err) + } + + job, err := db.GetProvisionerJobByID(ctx, build.JobID) + if err != nil { + return database.WorkspaceBuild{}, database.ProvisionerJob{}, xerrors.Errorf("get provisioner job: %w", err) + } + return build, job, nil +} + +func publishBuildBinding( + ctx context.Context, + db database.Store, + logger slog.Logger, + chatID uuid.UUID, + workspaceID uuid.UUID, + buildID uuid.UUID, + onChatUpdated func(database.Chat), +) { + updatedChat, bindErr := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{ + ID: chatID, + WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}, + BuildID: uuid.NullUUID{ + UUID: buildID, + Valid: buildID != uuid.Nil, + }, + AgentID: uuid.NullUUID{}, + }) + if bindErr != nil { + logger.Error(ctx, "failed to persist build ID on chat binding", + slog.F("chat_id", chatID), + slog.F("build_id", buildID), + slog.Error(bindErr), + ) + return + } + if onChatUpdated != nil { + onChatUpdated(updatedChat) + } +} + +func provisionerJobTerminal(status database.ProvisionerJobStatus) bool { + switch status { + case database.ProvisionerJobStatusSucceeded, + database.ProvisionerJobStatusFailed, + database.ProvisionerJobStatusCanceled: + return true + default: + return false + } +} + func truncateRunes(value string, maxLen int) string { if maxLen <= 0 || value == "" { return "" diff --git a/coderd/x/chatd/chattool/startworkspace.go b/coderd/x/chatd/chattool/startworkspace.go index 16d1d1f9be..24b55348e6 100644 --- a/coderd/x/chatd/chattool/startworkspace.go +++ b/coderd/x/chatd/chattool/startworkspace.go @@ -56,7 +56,7 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO return fantasy.NewTextErrorResponse("workspace starter is not configured"), nil } - // Serialize with create_workspace to prevent races. + // Serialize with create_workspace and stop_workspace to prevent races. 
if options.WorkspaceMu != nil { options.WorkspaceMu.Lock() defer options.WorkspaceMu.Unlock() @@ -86,18 +86,9 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO ), nil } - build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID) + build, job, err := latestWorkspaceBuildAndJob(ctx, db, ws.ID) if err != nil { - return fantasy.NewTextErrorResponse( - xerrors.Errorf("get latest build: %w", err).Error(), - ), nil - } - - job, err := db.GetProvisionerJobByID(ctx, build.JobID) - if err != nil { - return fantasy.NewTextErrorResponse( - xerrors.Errorf("get provisioner job: %w", err).Error(), - ), nil + return fantasy.NewTextErrorResponse(err.Error()), nil } // If a build is already in progress, wait for it. @@ -106,24 +97,7 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO database.ProvisionerJobStatusRunning: // Publish the build ID to the frontend so it // can start streaming logs immediately. - updatedChat, bindErr := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{ - ID: chatID, - WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, - BuildID: uuid.NullUUID{ - UUID: build.ID, - Valid: build.ID != uuid.Nil, - }, - AgentID: uuid.NullUUID{}, - }) - if bindErr != nil { - options.Logger.Error(ctx, "failed to persist build ID on chat binding", - slog.F("chat_id", chatID), - slog.F("build_id", build.ID), - slog.Error(bindErr), - ) - } else if options.OnChatUpdated != nil { - options.OnChatUpdated(updatedChat) - } + publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, build.ID, options.OnChatUpdated) if err := waitForBuild(ctx, db, build.ID); err != nil { // newBuildError returns via toolResponse (IsError: false) // rather than NewTextErrorResponse (IsError: true) so the @@ -199,24 +173,7 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO // Persist the build ID on the chat binding so the // frontend can stream logs without polling. 
- updatedChat, bindErr := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{ - ID: chatID, - WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, - BuildID: uuid.NullUUID{ - UUID: startBuild.ID, - Valid: startBuild.ID != uuid.Nil, - }, - AgentID: uuid.NullUUID{}, - }) - if bindErr != nil { - options.Logger.Error(ctx, "failed to persist build ID on chat binding", - slog.F("chat_id", chatID), - slog.F("build_id", startBuild.ID), - slog.Error(bindErr), - ) - } else if options.OnChatUpdated != nil { - options.OnChatUpdated(updatedChat) - } + publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, startBuild.ID, options.OnChatUpdated) if err := waitForBuild(ctx, db, startBuild.ID); err != nil { return buildFailureToolResponse( ctx, diff --git a/coderd/x/chatd/chattool/stopworkspace.go b/coderd/x/chatd/chattool/stopworkspace.go new file mode 100644 index 0000000000..1aea9ad836 --- /dev/null +++ b/coderd/x/chatd/chattool/stopworkspace.go @@ -0,0 +1,181 @@ +package chattool + +import ( + "context" + "sync" + + "charm.land/fantasy" + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/httpapi/httperror" + "github.com/coder/coder/v2/codersdk" +) + +// StopWorkspaceFn stops a workspace by creating a new build with +// the "stop" transition. +type StopWorkspaceFn func( + ctx context.Context, + ownerID uuid.UUID, + workspaceID uuid.UUID, + req codersdk.CreateWorkspaceBuildRequest, +) (codersdk.WorkspaceBuild, error) + +// StopWorkspaceOptions configures the stop_workspace tool. +type StopWorkspaceOptions struct { + OwnerID uuid.UUID + StopFn StopWorkspaceFn + WorkspaceMu *sync.Mutex + OnChatUpdated func(database.Chat) + Logger slog.Logger +} + +type stopWorkspaceArgs struct{} + +// StopWorkspace returns a tool that stops the workspace associated +// with the current chat. The tool is idempotent when the workspace is +// already stopped. 
db must not be nil and chatID must not be uuid.Nil. +func StopWorkspace(db database.Store, chatID uuid.UUID, options StopWorkspaceOptions) fantasy.AgentTool { + return fantasy.NewAgentTool( + "stop_workspace", + "Stop the chat's workspace and wait for the stop build to complete. "+ + "If another workspace build is already in progress, this waits "+ + "for that build first, then stops the workspace if needed. "+ + "After waiting, this tool is idempotent if the workspace is "+ + "already stopped or the in-progress build stopped it. Use "+ + "this when the "+ + "user explicitly asks to stop the workspace, or when a "+ + "workspace-agent error tells you to stop and then start the "+ + "workspace. Stopping a workspace terminates running processes "+ + "and may discard unsaved in-memory state. This tool does not "+ + "delete the workspace.", + func(ctx context.Context, _ stopWorkspaceArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + if options.StopFn == nil { + return fantasy.NewTextErrorResponse("workspace stopper is not configured"), nil + } + + // Serialize with create_workspace and start_workspace to + // prevent lifecycle races. 
+ if options.WorkspaceMu != nil { + options.WorkspaceMu.Lock() + defer options.WorkspaceMu.Unlock() + } + + chat, err := db.GetChatByID(ctx, chatID) + if err != nil { + return fantasy.NewTextErrorResponse( + xerrors.Errorf("load chat: %w", err).Error(), + ), nil + } + if !chat.WorkspaceID.Valid { + return fantasy.NewTextErrorResponse( + "chat has no workspace; use create_workspace first", + ), nil + } + + ws, err := db.GetWorkspaceByID(ctx, chat.WorkspaceID.UUID) + if err != nil { + return fantasy.NewTextErrorResponse( + xerrors.Errorf("load workspace: %w", err).Error(), + ), nil + } + if ws.Deleted { + return fantasy.NewTextErrorResponse( + "workspace was deleted; use create_workspace to make a new one", + ), nil + } + + build, job, err := latestWorkspaceBuildAndJob(ctx, db, ws.ID) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + + // If a build is already in progress, wait for it before + // deciding whether a stop build is still needed. + switch job.JobStatus { + case database.ProvisionerJobStatusPending, + database.ProvisionerJobStatusRunning, + database.ProvisionerJobStatusCanceling: + publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, build.ID, options.OnChatUpdated) + + waitErr := waitForBuild(ctx, db, build.ID) + // Re-read after waiting because another transition may + // have completed while this tool was blocked. + ws, err = db.GetWorkspaceByID(ctx, ws.ID) + if err != nil { + return fantasy.NewTextErrorResponse( + xerrors.Errorf("load workspace: %w", err).Error(), + ), nil + } + if ws.Deleted { + return fantasy.NewTextErrorResponse( + "workspace was deleted; use create_workspace to make a new one", + ), nil + } + build, job, err = latestWorkspaceBuildAndJob(ctx, db, ws.ID) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + // The fresh job row is authoritative. A wait error can + // be stale if the build reached a terminal state while the + // wait context was ending. 
+ if waitErr != nil && !provisionerJobTerminal(job.JobStatus) { + return buildToolResponse(newBuildError( + xerrors.Errorf("waiting for in-progress build: %w", waitErr).Error(), + build.ID, + )), nil + } + } + + if job.JobStatus == database.ProvisionerJobStatusSucceeded && + build.Transition == database.WorkspaceTransitionStop { + result := map[string]any{ + "stopped": true, + "workspace_name": ws.Name, + } + setNoBuild(result, uuid.Nil) + return toolResponse(result), nil + } + + ownerCtx, ownerErr := asOwner(ctx, db, options.OwnerID) + if ownerErr != nil { + return fantasy.NewTextErrorResponse(ownerErr.Error()), nil + } + + stopBuild, err := options.StopFn(ownerCtx, options.OwnerID, ws.ID, codersdk.CreateWorkspaceBuildRequest{ + Transition: codersdk.WorkspaceTransitionStop, + }) + if err != nil { + if responseErr, ok := httperror.IsResponder(err); ok { + _, resp := responseErr.Response() + return toolResponse(responseErrorResult(resp)), nil + } + return fantasy.NewTextErrorResponse( + xerrors.Errorf("stop workspace: %w", err).Error(), + ), nil + } + + publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, stopBuild.ID, options.OnChatUpdated) + if err := waitForBuild(ctx, db, stopBuild.ID); err != nil { + return buildToolResponse(newBuildError( + xerrors.Errorf("workspace stop build failed: %w", err).Error(), + stopBuild.ID, + )), nil + } + + if options.OnChatUpdated != nil { + if latest, err := db.GetChatByID(ctx, chatID); err == nil { + options.OnChatUpdated(latest) + } + } + + result := map[string]any{ + "stopped": true, + "workspace_name": ws.Name, + } + setBuildID(result, stopBuild.ID) + return toolResponse(result), nil + }) +} diff --git a/coderd/x/chatd/chattool/stopworkspace_test.go b/coderd/x/chatd/chattool/stopworkspace_test.go new file mode 100644 index 0000000000..4133ba223d --- /dev/null +++ b/coderd/x/chatd/chattool/stopworkspace_test.go @@ -0,0 +1,449 @@ +package chattool_test + +import ( + "context" + "database/sql" + "encoding/json" + "sync" + 
"sync/atomic" + "testing" + "time" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbfake" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/x/chatd/chattool" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +func TestStopWorkspace(t *testing.T) { + t.Parallel() + + t.Run("NoWorkspace", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: modelCfg.ID, + Title: "test-stop-no-workspace", + }) + + tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{ + StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + t.Fatal("StopFn should not be called") + return codersdk.WorkspaceBuild{}, nil + }, + WorkspaceMu: &sync.Mutex{}, + }) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + require.NoError(t, err) + require.Contains(t, resp.Content, "use create_workspace first") + }) + + t.Run("DeletedWorkspace", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, 
database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + Deleted: true, + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionDelete, + }).Do() + ws := wsResp.Workspace + + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: "test-stop-deleted-workspace", + }) + + tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{ + StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + t.Fatal("StopFn should not be called for deleted workspace") + return codersdk.WorkspaceBuild{}, nil + }, + WorkspaceMu: &sync.Mutex{}, + }) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + require.NoError(t, err) + require.Contains(t, resp.Content, "workspace was deleted") + require.Contains(t, resp.Content, "create_workspace") + }) + + t.Run("AlreadyStopped", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStop, + }).Do() + ws := wsResp.Workspace + + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: 
"test-stop-already-stopped", + }) + + tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{ + OwnerID: user.ID, + StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + t.Fatal("StopFn should not be called for already-stopped workspace") + return codersdk.WorkspaceBuild{}, nil + }, + WorkspaceMu: &sync.Mutex{}, + }) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + require.Equal(t, true, result["stopped"]) + require.Equal(t, ws.Name, result["workspace_name"]) + require.Equal(t, true, result["no_build"]) + require.Nil(t, result["build_id"]) + }) + + t.Run("RunningWorkspaceStops", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStart, + }).Do() + ws := wsResp.Workspace + + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: "test-stop-running-workspace", + }) + + var stopCalled atomic.Bool + var stopBuildID uuid.UUID + var seenBuildID uuid.UUID + var onChatUpdatedCalls atomic.Int32 + tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{ + OwnerID: user.ID, + StopFn: func(_ context.Context, _ uuid.UUID, wsID uuid.UUID, req 
codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + stopCalled.Store(true) + require.Equal(t, ws.ID, wsID) + require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition) + buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStop, + BuildNumber: 2, + }).Do() + stopBuildID = buildResp.Build.ID + return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil + }, + WorkspaceMu: &sync.Mutex{}, + OnChatUpdated: func(chat database.Chat) { + onChatUpdatedCalls.Add(1) + if chat.BuildID.Valid { + seenBuildID = chat.BuildID.UUID + } + }, + }) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + require.NoError(t, err) + require.True(t, stopCalled.Load()) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + require.Equal(t, true, result["stopped"]) + require.Equal(t, ws.Name, result["workspace_name"]) + require.Equal(t, stopBuildID.String(), result["build_id"]) + require.Nil(t, result["no_build"]) + + require.GreaterOrEqual(t, onChatUpdatedCalls.Load(), int32(1)) + require.Equal(t, stopBuildID, seenBuildID) + + updatedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.True(t, updatedChat.BuildID.Valid) + require.Equal(t, stopBuildID, updatedChat.BuildID.UUID) + }) + + t.Run("InProgressBuildWaitsThenStops", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + }).Seed(database.WorkspaceBuild{ + Transition: 
database.WorkspaceTransitionStart, + }).Starting().Do() + ws := wsResp.Workspace + + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: "test-stop-in-progress-build", + }) + + jobRead := make(chan struct{}, 1) + wrappedDB := &jobInterceptStore{Store: db, jobRead: jobRead} + var stopCalled atomic.Bool + var stopBuildID uuid.UUID + var onChatUpdatedCalled atomic.Bool + tool := chattool.StopWorkspace(wrappedDB, chat.ID, chattool.StopWorkspaceOptions{ + OwnerID: user.ID, + StopFn: func(_ context.Context, _ uuid.UUID, wsID uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + stopCalled.Store(true) + require.Equal(t, ws.ID, wsID) + require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition) + buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStop, + BuildNumber: 2, + }).Do() + stopBuildID = buildResp.Build.ID + return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil + }, + WorkspaceMu: &sync.Mutex{}, + Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + OnChatUpdated: func(_ database.Chat) { onChatUpdatedCalled.Store(true) }, + }) + + type toolResult struct { + resp fantasy.ToolResponse + err error + } + done := make(chan toolResult, 1) + go func() { + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + done <- toolResult{resp: resp, err: err} + }() + + testutil.TryReceive(ctx, t, jobRead) + require.False(t, stopCalled.Load(), "StopFn must wait for the in-progress build") + + now := time.Now().UTC() + require.NoError(t, db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: wsResp.Build.JobID, + UpdatedAt: now, + CompletedAt: sql.NullTime{Time: now, Valid: true}, + })) + + res := testutil.TryReceive(ctx, t, done) + 
require.NoError(t, res.err) + require.True(t, stopCalled.Load()) + require.True(t, onChatUpdatedCalled.Load()) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(res.resp.Content), &result)) + require.Equal(t, true, result["stopped"]) + require.Equal(t, stopBuildID.String(), result["build_id"]) + }) + + t.Run("FailedLatestStopBuildStillStops", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStop, + }).Do() + ws := wsResp.Workspace + now := time.Now().UTC() + require.NoError(t, db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: wsResp.Build.JobID, + UpdatedAt: now, + CompletedAt: sql.NullTime{Time: now, Valid: true}, + Error: sql.NullString{String: "latest build failed", Valid: true}, + })) + + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: "test-stop-failed-latest-build", + }) + + var stopCalled atomic.Bool + tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{ + OwnerID: user.ID, + StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + stopCalled.Store(true) + require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition) + buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{ + Transition: 
database.WorkspaceTransitionStop, + BuildNumber: 2, + }).Do() + return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil + }, + WorkspaceMu: &sync.Mutex{}, + }) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + require.NoError(t, err) + require.True(t, stopCalled.Load()) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + require.Equal(t, true, result["stopped"]) + }) + + t.Run("StopTriggeredBuildFailure", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStart, + }).Do() + ws := wsResp.Workspace + + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: "test-stop-triggered-build-failure", + }) + + var stopBuildJobID uuid.UUID + var stopBuildID uuid.UUID + stopFn := func(_ context.Context, _ uuid.UUID, wsID uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + require.Equal(t, ws.ID, wsID) + require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition) + buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStop, + BuildNumber: 2, + }).Starting().Do() + stopBuildJobID = buildResp.Build.JobID + stopBuildID = buildResp.Build.ID + return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil + } + + jobRead := make(chan struct{}, 
2) + wrappedDB := &jobInterceptStore{Store: db, jobRead: jobRead} + tool := chattool.StopWorkspace(wrappedDB, chat.ID, chattool.StopWorkspaceOptions{ + OwnerID: user.ID, + StopFn: stopFn, + WorkspaceMu: &sync.Mutex{}, + Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + }) + + type toolResult struct { + resp fantasy.ToolResponse + err error + } + done := make(chan toolResult, 1) + go func() { + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + done <- toolResult{resp: resp, err: err} + }() + + testutil.TryReceive(ctx, t, jobRead) + testutil.TryReceive(ctx, t, jobRead) + + now := time.Now().UTC() + require.NoError(t, db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: stopBuildJobID, + UpdatedAt: now, + CompletedAt: sql.NullTime{Time: now, Valid: true}, + Error: sql.NullString{String: "terraform destroy failed", Valid: true}, + })) + + res := testutil.TryReceive(ctx, t, done) + require.NoError(t, res.err) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(res.resp.Content), &result)) + require.Contains(t, result["error"], "workspace stop build failed") + require.Equal(t, stopBuildID.String(), result["build_id"]) + require.False(t, res.resp.IsError, + "buildToolResponse must not set IsError; chatprompt strips structured fields from error responses") + }) +}