diff --git a/coderd/coderd.go b/coderd/coderd.go index 619a91f7b0..3db9bea92d 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -805,6 +805,7 @@ func New(options *Options) *API { InstructionLookupTimeout: options.ChatdInstructionLookupTimeout, CreateWorkspace: api.chatCreateWorkspace, StartWorkspace: api.chatStartWorkspace, + StopWorkspace: api.chatStopWorkspace, Pubsub: options.Pubsub, WebpushDispatcher: options.WebPushDispatcher, UsageTracker: options.WorkspaceUsageTracker, diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 43b1f105a5..f7ed4b8498 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -3740,6 +3740,53 @@ func (api *API) chatStartWorkspace( return apiBuild, nil } +// chatStopWorkspace stops a workspace by creating a new build with the +// "stop" transition. It mirrors chatStartWorkspace, without start-only +// active-version behavior. +func (api *API) chatStopWorkspace( + ctx context.Context, + ownerID uuid.UUID, + workspaceID uuid.UUID, + req codersdk.CreateWorkspaceBuildRequest, +) (codersdk.WorkspaceBuild, error) { + actor, _, err := httpmw.UserRBACSubject(ctx, api.Database, ownerID, rbac.ScopeAll) + if err != nil { + return codersdk.WorkspaceBuild{}, xerrors.Errorf("load user authorization: %w", err) + } + ctx = dbauthz.As(ctx, actor) + + workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceID) + if err != nil { + return codersdk.WorkspaceBuild{}, xerrors.Errorf("get workspace: %w", err) + } + + req.Transition = codersdk.WorkspaceTransitionStop + + // Build a synthetic API key so postWorkspaceBuildsInternal can + // record the correct initiator. + syntheticKey := database.APIKey{ + UserID: ownerID, + } + + apiBuild, err := api.postWorkspaceBuildsInternal( + ctx, + syntheticKey, + workspace, + req, + func(action policy.Action, object rbac.Objecter) bool { + // Authorization is handled by dbauthz on the context. 
+ authErr := api.HTTPAuth.Authorizer.Authorize(ctx, actor, action, object.RBACObject()) + return authErr == nil + }, + audit.WorkspaceBuildBaggage{}, + ) + if err != nil { + return codersdk.WorkspaceBuild{}, xerrors.Errorf("create workspace build: %w", err) + } + + return apiBuild, nil +} + func rewriteChatStartWorkspaceManualUpdateResponse(resp codersdk.Response, fallbackDetail string, retryInstructions string) codersdk.Response { originalMessage := resp.Message resp.Message = retryInstructions diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index a8e944c016..4d20d50840 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -13310,6 +13310,43 @@ func TestChatStartWorkspace_RequireActiveVersion(t *testing.T) { require.Nil(t, build.TemplateVersionPresetID, "no preset must be applied") } +func TestChatStopWorkspace_BypassesRequireActiveVersion(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + rawClient, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{}) + var store dbauthz.AccessControlStore = requireActiveVersionStore{} + api.AccessControlStore.Store(&store) + db := api.Database + user := coderdtest.CreateFirstUser(t, rawClient) + + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.UserID, + OrganizationID: user.OrganizationID, + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStart, + }).Do() + v1ID := wsResp.Build.TemplateVersionID + tmplID := wsResp.Workspace.TemplateID + + v2Resp := dbfake.TemplateVersion(t, db).Seed(database.TemplateVersion{ + TemplateID: uuid.NullUUID{UUID: tmplID, Valid: true}, + OrganizationID: user.OrganizationID, + CreatedBy: user.UserID, + }).Do() + v2 := v2Resp.TemplateVersion + require.NotEqual(t, v1ID, v2.ID, "v2 must differ from v1") + + build, err := coderd.ChatStopWorkspace(api, ctx, user.UserID, wsResp.Workspace.ID, + codersdk.CreateWorkspaceBuildRequest{}) + + require.NoError(t, err) + require.Equal(t, 
codersdk.WorkspaceTransitionStop, build.Transition) + require.Equal(t, v1ID, build.TemplateVersionID, + "stop must not apply RequireActiveVersion start-only logic") + require.NotEqual(t, v2.ID, build.TemplateVersionID) +} + func TestGetChatMessages_Pagination(t *testing.T) { t.Parallel() diff --git a/coderd/export_test.go b/coderd/export_test.go index 44f24a09ba..475270b994 100644 --- a/coderd/export_test.go +++ b/coderd/export_test.go @@ -11,3 +11,6 @@ var InsertAgentChatTestModelConfig = insertAgentChatTestModelConfig // stubbing the entire DB layer. The proper fix is to extract a pure // request builder; tracked in CODAGT-292. var ChatStartWorkspace = (*API).chatStartWorkspace + +// ChatStopWorkspace exposes chatStopWorkspace for external tests. +var ChatStopWorkspace = (*API).chatStopWorkspace diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 54dfd43963..ee9b6f08c6 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -121,6 +121,13 @@ const ( // streamJanitorInterval. streamJanitorInterval = 30 * time.Second + // agentDisconnectedRecoveryThreshold is how long the latest + // workspace agent must be disconnected before chatd suggests + // destructive stop/start recovery. This is intentionally longer + // than the inactive-disconnect timeout so short heartbeat gaps do + // not prompt a workspace restart. + agentDisconnectedRecoveryThreshold = 90 * time.Second + // DefaultMaxChatsPerAcquire is the maximum number of chats to // acquire in a single processOnce call. Batching avoids // waiting a full polling interval between acquisitions @@ -139,12 +146,14 @@ const ( var ( errChatHasNoWorkspaceAgent = xerrors.New("workspace has no running agent: the workspace is likely stopped. Use the start_workspace tool to start it") errChatAgentDisconnected = xerrors.New( - "workspace agent is disconnected and cannot execute tools. 
" + - "The workspace may need to be restarted from the Coder dashboard", + "workspace agent has been disconnected for at least 90 seconds " + + "and cannot execute tools. To recover, call stop_workspace " + + "to stop the workspace, then start_workspace to start it " + + "again", ) errChatDialTimeout = xerrors.New( "connection to the workspace agent timed out. " + - "The workspace may need to be restarted from the Coder dashboard", + "The agent may still be reachable on the next attempt.", ) errChatExternalAgentUnavailable = xerrors.New("external workspace agent unavailable") ) @@ -187,6 +196,7 @@ type Server struct { instructionLookupTimeout time.Duration createWorkspaceFn chattool.CreateWorkspaceFn startWorkspaceFn chattool.StartWorkspaceFn + stopWorkspaceFn chattool.StopWorkspaceFn pubsub pubsub.Pubsub webpushDispatcher webpush.Dispatcher providerAPIKeys chatprovider.ProviderAPIKeys @@ -786,6 +796,45 @@ func isAgentUnreachable(now time.Time, agent database.WorkspaceAgent, inactiveTi status.Status == database.WorkspaceAgentStatusTimeout } +func agentDisconnectedFor(now time.Time, agent database.WorkspaceAgent, inactiveTimeout time.Duration) (time.Duration, bool) { + status := agent.Status(now, inactiveTimeout) + if status.Status != database.WorkspaceAgentStatusDisconnected || status.DisconnectedAt == nil { + return 0, false + } + + disconnectedFor := now.Sub(*status.DisconnectedAt) + if disconnectedFor < 0 { + disconnectedFor = 0 + } + return disconnectedFor, true +} + +func (c *turnWorkspaceContext) latestWorkspaceAgentNeedsRestart( + ctx context.Context, + workspaceID uuid.UUID, +) (bool, error) { + agentID, err := c.latestWorkspaceAgentID(ctx, workspaceID) + if err != nil { + if xerrors.Is(err, errChatHasNoWorkspaceAgent) { + return false, err + } + c.server.logger.Warn(ctx, "failed to resolve latest agent for timeout classification", slog.Error(err)) + return false, nil + } + + agent, err := c.server.db.GetWorkspaceAgentByID(ctx, agentID) + if err != nil { + 
c.server.logger.Warn(ctx, "failed to load latest agent for timeout classification", + slog.F("agent_id", agentID), + slog.Error(err), + ) + return false, nil + } + + disconnectedFor, disconnected := agentDisconnectedFor(c.server.clock.Now(), agent, c.server.agentInactiveDisconnectTimeout) + return disconnected && disconnectedFor >= agentDisconnectedRecoveryThreshold, nil +} + func (c *turnWorkspaceContext) externalAgentError( ctx context.Context, agent database.WorkspaceAgent, @@ -853,9 +902,13 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces ) // On DB error the check re-runs on the // next tool call. - } else if isAgentUnreachable(c.server.clock.Now(), freshAgent, c.server.agentInactiveDisconnectTimeout) { + } else if _, disconnected := agentDisconnectedFor( + c.server.clock.Now(), + freshAgent, + c.server.agentInactiveDisconnectTimeout, + ); disconnected { c.clearCachedWorkspaceState() - return nil, c.externalAgentError(ctx, freshAgent, errChatAgentDisconnected) + continue } } return currentConn, nil @@ -898,6 +951,14 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces // canceled (e.g. ErrInterrupted), its error must // propagate unchanged so the chatloop can detect it. 
if ctx.Err() == nil && errors.Is(context.Cause(dialCtx), errChatDialTimeout) { + c.clearCachedWorkspaceState() + needsRestart, statusErr := c.latestWorkspaceAgentNeedsRestart(ctx, chatSnapshot.WorkspaceID.UUID) + if statusErr != nil { + return nil, statusErr + } + if needsRestart { + return nil, c.externalAgentError(ctx, agent, errChatAgentDisconnected) + } return nil, c.externalAgentError(ctx, agent, errChatDialTimeout) } return nil, err @@ -3752,6 +3813,7 @@ type Config struct { InstructionLookupTimeout time.Duration CreateWorkspace chattool.CreateWorkspaceFn StartWorkspace chattool.StartWorkspaceFn + StopWorkspace chattool.StopWorkspaceFn Pubsub pubsub.Pubsub ProviderAPIKeys chatprovider.ProviderAPIKeys AlwaysEnableDebugLogs bool @@ -3820,6 +3882,7 @@ func New(cfg Config) *Server { instructionLookupTimeout: instructionLookupTimeout, createWorkspaceFn: cfg.CreateWorkspace, startWorkspaceFn: cfg.StartWorkspace, + stopWorkspaceFn: cfg.StopWorkspace, pubsub: cfg.Pubsub, webpushDispatcher: cfg.WebpushDispatcher, providerAPIKeys: cfg.ProviderAPIKeys, @@ -5899,7 +5962,7 @@ func builtinPlanToolAllowed(name string, isRootChat bool) bool { case "read_file", "execute", "process_output", "read_skill", "read_skill_file": return true case "write_file", "edit_files", "list_templates", "read_template", - "create_workspace", "start_workspace", "propose_plan", "spawn_agent", + "create_workspace", "start_workspace", "stop_workspace", "propose_plan", "spawn_agent", "spawn_explore_agent", "wait_agent", "ask_user_question": return isRootChat case "process_list", "process_signal", "message_agent", "close_agent", @@ -5979,6 +6042,7 @@ func allowedExploreToolNames(allTools []fantasy.AgentTool) []string { "read_template": false, "create_workspace": false, "start_workspace": false, + "stop_workspace": false, "propose_plan": false, "spawn_agent": false, "wait_agent": false, @@ -6198,6 +6262,13 @@ func (p *Server) appendRootChatTools( OnChatUpdated: onChatUpdated, Logger: p.logger, }), + 
chattool.StopWorkspace(p.db, opts.chat.ID, chattool.StopWorkspaceOptions{ + OwnerID: opts.chat.OwnerID, + StopFn: p.stopWorkspaceFn, + WorkspaceMu: opts.workspaceMu, + OnChatUpdated: onChatUpdated, + Logger: p.logger, + }), ) if opts.isPlanModeTurn { tools = append(tools, chattool.ProposePlan(chattool.ProposePlanOptions{ diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index b23ec23d18..1f80e6c81d 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -295,6 +295,26 @@ func TestFilterExternalMCPConfigsForTurn(t *testing.T) { }) } +func TestChatWorkspaceRecoveryErrorsDifferentiateSignalStrength(t *testing.T) { + t.Parallel() + + // Disconnected recovery is gated by a DB-confirmed duration + // threshold, so the message can give direct stop/start guidance + // without asking the user. + disconnected := errChatAgentDisconnected.Error() + require.Contains(t, disconnected, "90 seconds") + require.Contains(t, disconnected, "stop_workspace") + require.Contains(t, disconnected, "start_workspace") + require.NotContains(t, disconnected, "ask_user_question") + + // Dial timeout alone is a weak signal. The model should not + // escalate to lifecycle tools without DB-confirmed disconnect. 
+ dialTimeout := errChatDialTimeout.Error() + require.NotContains(t, dialTimeout, "ask_user_question") + require.NotContains(t, dialTimeout, "stop_workspace") + require.NotContains(t, dialTimeout, "start_workspace") +} + func TestActiveToolNamesForTurn(t *testing.T) { t.Parallel() @@ -344,6 +364,7 @@ func TestActiveToolNamesForTurn(t *testing.T) { "read_template", "create_workspace", "start_workspace", + "stop_workspace", "propose_plan", "spawn_agent", "wait_agent", @@ -364,6 +385,7 @@ func TestActiveToolNamesForTurn(t *testing.T) { "read_template", "create_workspace", "start_workspace", + "stop_workspace", "propose_plan", "spawn_agent", "wait_agent", @@ -386,6 +408,7 @@ func TestActiveToolNamesForTurn(t *testing.T) { "read_template", "create_workspace", "start_workspace", + "stop_workspace", "propose_plan", "spawn_agent", "wait_agent", @@ -405,6 +428,8 @@ func TestActiveToolNamesForTurn(t *testing.T) { require.NotContains(t, got, "edit_files") require.NotContains(t, got, "ask_user_question") require.NotContains(t, got, "propose_plan") + require.NotContains(t, got, "start_workspace") + require.NotContains(t, got, "stop_workspace") require.NotContains(t, got, "spawn_explore_agent") }) @@ -474,6 +499,8 @@ func TestAllowedExploreToolNames(t *testing.T) { newTestAgentTool("write_file"), newTestMCPAgentTool("external-mcp__echo", externalConfigID), newTestAgentTool("workspace-mcp__echo"), + newTestAgentTool("start_workspace"), + newTestAgentTool("stop_workspace"), newTestAgentTool("execute"), newTestAgentTool("process_output"), newTestAgentTool("process_list"), @@ -494,6 +521,9 @@ func TestAllowedExploreToolNames(t *testing.T) { "read_skill_file", }, got) require.NotContains(t, got, "workspace-mcp__echo") + require.NotContains(t, got, "start_workspace") + require.NotContains(t, got, "stop_workspace") + require.NotContains(t, got, "ask_user_question") } func TestAllowedBehaviorToolNames(t *testing.T) { @@ -580,19 +610,13 @@ func TestStopAfterBehaviorTools(t *testing.T) { 
)) }) - t.Run("RootPlanModeIncludesClarificationTool", func(t *testing.T) { + t.Run("PlanModeDelegatesToPlanTools", func(t *testing.T) { t.Parallel() - require.Equal(t, map[string]struct{}{ - "propose_plan": {}, - "ask_user_question": {}, - }, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{})) - }) - - t.Run("ChildPlanModeSkipsClarificationTool", func(t *testing.T) { - t.Parallel() - require.Equal(t, map[string]struct{}{ - "propose_plan": {}, - }, stopAfterBehaviorTools(planMode, database.NullChatMode{}, uuid.NullUUID{UUID: uuid.New(), Valid: true})) + require.Equal(t, stopAfterPlanTools(planMode, uuid.NullUUID{}), stopAfterBehaviorTools( + planMode, + database.NullChatMode{}, + uuid.NullUUID{}, + )) }) t.Run("ExploreModeReturnsNil", func(t *testing.T) { @@ -4229,46 +4253,28 @@ func TestGetWorkspaceConn_SameBuildAgentCrash(t *testing.T) { func TestGetWorkspaceConn_StatusCheck(t *testing.T) { // The cache-hit status check re-fetches the agent row for a fresh - // heartbeat timestamp. These tests verify that path detects - // disconnected or timed-out agents and that healthy or DB-error - // paths return the cached connection. + // heartbeat timestamp. Healthy, timed-out, and DB-error paths return + // the cached connection. Disconnected agents are covered separately + // because they now trigger a fresh dial before recovery. 
t.Parallel() type testCase struct { - name string - agent database.WorkspaceAgent - dbError bool - wantErr error - wantReleaseCalled bool + name string + agent database.WorkspaceAgent + dbError bool } tests := []testCase{ - { - name: "DisconnectedAgentCacheHit", - agent: database.WorkspaceAgent{ - FirstConnectedAt: sql.NullTime{ - Time: time.Now().Add(-10 * time.Minute), - Valid: true, - }, - LastConnectedAt: sql.NullTime{ - Time: time.Now().Add(-10 * time.Minute), - Valid: true, - }, - }, - wantErr: errChatAgentDisconnected, - wantReleaseCalled: true, - }, { // Agent never connected and the connection timeout - // has elapsed. This is the cache-hit timeout branch - // of isAgentUnreachable. + // has elapsed. This should not trigger lifecycle + // recovery because the agent did not connect and + // then disconnect. name: "TimedOutAgentCacheHit", agent: database.WorkspaceAgent{ CreatedAt: time.Now().Add(-10 * time.Minute), ConnectionTimeoutSeconds: 60, }, - wantErr: errChatAgentDisconnected, - wantReleaseCalled: true, }, { name: "CacheHitHealthyAgent", @@ -4371,28 +4377,274 @@ func TestGetWorkspaceConn_StatusCheck(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) gotConn, err := workspaceCtx.getWorkspaceConn(ctx) - - if tc.wantErr != nil { - require.Nil(t, gotConn) - require.ErrorIs(t, err, tc.wantErr) - } else { - require.NoError(t, err) - require.Same(t, cachedConn, gotConn) - } - - require.Equal(t, tc.wantReleaseCalled, releaseCalled, "release called") - - // For cache-hit disconnect, the cache should be cleared. 
- if tc.wantErr != nil { - workspaceCtx.mu.Lock() - defer workspaceCtx.mu.Unlock() - require.False(t, workspaceCtx.agentLoaded) - require.Nil(t, workspaceCtx.conn) - } + require.NoError(t, err) + require.Same(t, cachedConn, gotConn) + require.False(t, releaseCalled, "release called") }) } } +func TestGetWorkspaceConn_DialTimeoutDisconnectedRecoveryThreshold(t *testing.T) { + // The recovery sentinel requires a failed dial and a fresh + // disconnected status check past the recovery threshold. A + // disconnected DB row alone is not enough to trigger stop/start + // recovery. + t.Parallel() + + testCases := []struct { + name string + disconnectedFor time.Duration + wantErr error + wantRecovery bool + }{ + { + name: "RecentDisconnectReturnsDialTimeout", + disconnectedFor: agentDisconnectedRecoveryThreshold / 2, + wantErr: errChatDialTimeout, + wantRecovery: false, + }, + { + name: "PastThresholdEscalates", + disconnectedFor: agentDisconnectedRecoveryThreshold, + wantErr: errChatAgentDisconnected, + wantRecovery: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + agentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + } + + clock := quartz.NewMock(t) + now := clock.Now() + disconnectedAgent := database.WorkspaceAgent{ + ID: agentID, + FirstConnectedAt: sql.NullTime{ + Time: now.Add(-10 * time.Minute), + Valid: true, + }, + LastConnectedAt: sql.NullTime{ + Time: now.Add(-10 * time.Minute), + Valid: true, + }, + DisconnectedAt: sql.NullTime{ + Time: now.Add(-tc.disconnectedFor), + Valid: true, + }, + } + + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID). + Return(disconnectedAgent, nil). 
+ Times(2) + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID). + Return([]database.WorkspaceAgent{disconnectedAgent}, nil). + Times(1) + + server := &Server{ + db: db, + logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + clock: clock, + agentInactiveDisconnectTimeout: 30 * time.Second, + dialTimeout: 10 * time.Millisecond, + } + server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) { + <-ctx.Done() + return nil, nil, ctx.Err() + } + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: server, + chatStateMu: chatStateMu, + currentChat: &currentChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + defer workspaceCtx.close() + + ctx := testutil.Context(t, testutil.WaitShort) + gotConn, err := workspaceCtx.getWorkspaceConn(ctx) + require.Nil(t, gotConn) + require.ErrorIs(t, err, tc.wantErr) + if tc.wantRecovery { + require.ErrorIs(t, err, errChatAgentDisconnected) + } else { + require.NotErrorIs(t, err, errChatAgentDisconnected) + } + + workspaceCtx.mu.Lock() + defer workspaceCtx.mu.Unlock() + require.False(t, workspaceCtx.agentLoaded) + require.Nil(t, workspaceCtx.conn) + }) + } +} + +func TestGetWorkspaceConn_DisconnectedStatusDialSuccessDoesNotEscalate(t *testing.T) { + // A stale disconnected row must not prompt stop/start if the + // agent can still be dialed successfully. 
+ t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + agentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + } + + disconnectedAgent := database.WorkspaceAgent{ + ID: agentID, + FirstConnectedAt: sql.NullTime{ + Time: time.Now().Add(-10 * time.Minute), + Valid: true, + }, + LastConnectedAt: sql.NullTime{ + Time: time.Now().Add(-10 * time.Minute), + Valid: true, + }, + } + + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID). + Return(disconnectedAgent, nil). + Times(1) + + server := &Server{ + db: db, + logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + clock: quartz.NewReal(), + agentInactiveDisconnectTimeout: 30 * time.Second, + dialTimeout: 10 * time.Millisecond, + } + conn := agentconnmock.NewMockAgentConn(ctrl) + conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1) + var dialCalled bool + server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) { + dialCalled = true + return conn, nil, nil + } + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: server, + chatStateMu: chatStateMu, + currentChat: &currentChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + defer workspaceCtx.close() + + ctx := testutil.Context(t, testutil.WaitShort) + gotConn, err := workspaceCtx.getWorkspaceConn(ctx) + require.NoError(t, err) + require.Same(t, conn, gotConn) + require.True(t, dialCalled, "dial called") +} + +func TestGetWorkspaceConn_CacheHitDisconnectedRetriesDialBeforeEscalating(t *testing.T) { + // A disconnected cached connection is discarded first. Recovery is + // only surfaced if the replacement dial also times out. 
+ t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + agentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + } + disconnectedAgent := database.WorkspaceAgent{ + ID: agentID, + FirstConnectedAt: sql.NullTime{ + Time: time.Now().Add(-10 * time.Minute), + Valid: true, + }, + LastConnectedAt: sql.NullTime{ + Time: time.Now().Add(-10 * time.Minute), + Valid: true, + }, + } + + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID). + Return(disconnectedAgent, nil). + Times(2) + + server := &Server{ + db: db, + logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + clock: quartz.NewReal(), + agentInactiveDisconnectTimeout: 30 * time.Second, + dialTimeout: 10 * time.Millisecond, + } + newConn := agentconnmock.NewMockAgentConn(ctrl) + newConn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1) + var dialCalled bool + server.agentConnFn = func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) { + dialCalled = true + return newConn, nil, nil + } + + var releaseCalled bool + chatStateMu := &sync.Mutex{} + currentChat := chat + oldConn := agentconnmock.NewMockAgentConn(ctrl) + workspaceCtx := turnWorkspaceContext{ + server: server, + chatStateMu: chatStateMu, + currentChat: &currentChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + agent: disconnectedAgent, + agentLoaded: true, + conn: oldConn, + releaseConn: func() { releaseCalled = true }, + cachedWorkspaceID: chat.WorkspaceID, + } + defer workspaceCtx.close() + + ctx := testutil.Context(t, testutil.WaitShort) + gotConn, err := workspaceCtx.getWorkspaceConn(ctx) + require.NoError(t, err) + require.Same(t, newConn, gotConn) + require.True(t, releaseCalled, "release called") + require.True(t, dialCalled, "dial called") +} + func 
TestGetWorkspaceConn_DialTimeout(t *testing.T) { // When dialWithLazyValidation blocks beyond the dial // timeout, getWorkspaceConn should return @@ -4431,6 +4683,9 @@ func TestGetWorkspaceConn_DialTimeout(t *testing.T) { db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID). Return(connectedAgent, nil). + Times(2) + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID). + Return([]database.WorkspaceAgent{connectedAgent}, nil). Times(1) server := &Server{ @@ -4461,6 +4716,70 @@ func TestGetWorkspaceConn_DialTimeout(t *testing.T) { require.ErrorIs(t, err, errChatDialTimeout) } +func TestGetWorkspaceConn_DialTimeoutStatusTimeoutDoesNotEscalate(t *testing.T) { + // Agents that never connected are startup failures, not + // disconnected recovery cases. A dial timeout should stay a + // retry/escalation error rather than stop/start guidance. + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + workspaceID := uuid.New() + agentID := uuid.New() + chat := database.Chat{ + ID: uuid.New(), + WorkspaceID: uuid.NullUUID{ + UUID: workspaceID, + Valid: true, + }, + AgentID: uuid.NullUUID{ + UUID: agentID, + Valid: true, + }, + } + + timedOutAgent := database.WorkspaceAgent{ + ID: agentID, + CreatedAt: time.Now().Add(-10 * time.Minute), + ConnectionTimeoutSeconds: 60, + } + + db.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID). + Return(timedOutAgent, nil). + Times(2) + db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID). + Return([]database.WorkspaceAgent{timedOutAgent}, nil). 
+ Times(1) + + server := &Server{ + db: db, + clock: quartz.NewReal(), + agentInactiveDisconnectTimeout: 30 * time.Second, + dialTimeout: 10 * time.Millisecond, + } + server.agentConnFn = func(ctx context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) { + <-ctx.Done() + return nil, nil, ctx.Err() + } + + chatStateMu := &sync.Mutex{} + currentChat := chat + workspaceCtx := turnWorkspaceContext{ + server: server, + chatStateMu: chatStateMu, + currentChat: &currentChat, + loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil }, + } + defer workspaceCtx.close() + + ctx := testutil.Context(t, testutil.WaitShort) + gotConn, err := workspaceCtx.getWorkspaceConn(ctx) + require.Nil(t, gotConn) + require.ErrorIs(t, err, errChatDialTimeout) + require.NotErrorIs(t, err, errChatAgentDisconnected) +} + func TestGetWorkspaceConn_DialTimeoutParentCanceled(t *testing.T) { // When the parent context is canceled, the parent's error // must propagate unchanged (not wrapped as a dial timeout). diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index 34ba65218e..a66c4d2c6a 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -344,7 +344,10 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) { require.GreaterOrEqual(t, len(recorded), 2, "expected at least 2 streamed LLM calls (root + subagent)") - workspaceTools := []string{"list_templates", "read_template", "create_workspace"} + workspaceTools := []string{ + "list_templates", "read_template", "create_workspace", + "start_workspace", "stop_workspace", + } subagentTools := []string{"spawn_agent", "wait_agent", "message_agent", "close_agent"} // Identify root and subagent calls. Root chat calls include @@ -375,7 +378,10 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) { "root chat should have subagent tool %q", tool) } - // Standard turns (no turn mode) should hide propose_plan. 
+ // Standard turns (no turn mode) hide plan-only tools until + // plan mode. + require.NotContains(t, rootCalls[0], "ask_user_question", + "standard-turn root chat should NOT have ask_user_question") require.NotContains(t, rootCalls[0], "propose_plan", "standard-turn root chat should NOT have propose_plan") @@ -388,6 +394,8 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) { require.NotContains(t, childCalls[0], tool, "subagent chat should NOT have subagent tool %q", tool) } + require.NotContains(t, childCalls[0], "ask_user_question", + "subagent chat should NOT have ask_user_question") } func TestPlanModeSubagentChatExcludesAskUserQuestion(t *testing.T) { @@ -7030,7 +7038,7 @@ func TestComputerUseSubagentToolsAndModel(t *testing.T) { // 4. Verify workspace provisioning tools are NOT present. workspaceProvisioningTools := []string{ "list_templates", "read_template", - "create_workspace", "start_workspace", + "create_workspace", "start_workspace", "stop_workspace", } for _, tool := range workspaceProvisioningTools { require.NotContains(t, childTools, tool, diff --git a/coderd/x/chatd/chattool/chattool.go b/coderd/x/chatd/chattool/chattool.go index 69b65f3e10..6f7adadcdf 100644 --- a/coderd/x/chatd/chattool/chattool.go +++ b/coderd/x/chatd/chattool/chattool.go @@ -1,12 +1,16 @@ package chattool import ( + "context" "encoding/json" "unicode/utf8" "charm.land/fantasy" "github.com/google/uuid" + "golang.org/x/xerrors" + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/codersdk" ) @@ -54,6 +58,65 @@ func responseErrorResult(resp codersdk.Response) map[string]any { return result } +func latestWorkspaceBuildAndJob( + ctx context.Context, + db database.Store, + workspaceID uuid.UUID, +) (database.WorkspaceBuild, database.ProvisionerJob, error) { + build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID) + if err != nil { + return database.WorkspaceBuild{}, database.ProvisionerJob{}, 
xerrors.Errorf("get latest build: %w", err) + } + + job, err := db.GetProvisionerJobByID(ctx, build.JobID) + if err != nil { + return database.WorkspaceBuild{}, database.ProvisionerJob{}, xerrors.Errorf("get provisioner job: %w", err) + } + return build, job, nil +} + +func publishBuildBinding( + ctx context.Context, + db database.Store, + logger slog.Logger, + chatID uuid.UUID, + workspaceID uuid.UUID, + buildID uuid.UUID, + onChatUpdated func(database.Chat), +) { + updatedChat, bindErr := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{ + ID: chatID, + WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}, + BuildID: uuid.NullUUID{ + UUID: buildID, + Valid: buildID != uuid.Nil, + }, + AgentID: uuid.NullUUID{}, + }) + if bindErr != nil { + logger.Error(ctx, "failed to persist build ID on chat binding", + slog.F("chat_id", chatID), + slog.F("build_id", buildID), + slog.Error(bindErr), + ) + return + } + if onChatUpdated != nil { + onChatUpdated(updatedChat) + } +} + +func provisionerJobTerminal(status database.ProvisionerJobStatus) bool { + switch status { + case database.ProvisionerJobStatusSucceeded, + database.ProvisionerJobStatusFailed, + database.ProvisionerJobStatusCanceled: + return true + default: + return false + } +} + func truncateRunes(value string, maxLen int) string { if maxLen <= 0 || value == "" { return "" diff --git a/coderd/x/chatd/chattool/startworkspace.go b/coderd/x/chatd/chattool/startworkspace.go index 16d1d1f9be..24b55348e6 100644 --- a/coderd/x/chatd/chattool/startworkspace.go +++ b/coderd/x/chatd/chattool/startworkspace.go @@ -56,7 +56,7 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO return fantasy.NewTextErrorResponse("workspace starter is not configured"), nil } - // Serialize with create_workspace to prevent races. + // Serialize with create_workspace and stop_workspace to prevent races. 
if options.WorkspaceMu != nil { options.WorkspaceMu.Lock() defer options.WorkspaceMu.Unlock() @@ -86,18 +86,9 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO ), nil } - build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID) + build, job, err := latestWorkspaceBuildAndJob(ctx, db, ws.ID) if err != nil { - return fantasy.NewTextErrorResponse( - xerrors.Errorf("get latest build: %w", err).Error(), - ), nil - } - - job, err := db.GetProvisionerJobByID(ctx, build.JobID) - if err != nil { - return fantasy.NewTextErrorResponse( - xerrors.Errorf("get provisioner job: %w", err).Error(), - ), nil + return fantasy.NewTextErrorResponse(err.Error()), nil } // If a build is already in progress, wait for it. @@ -106,24 +97,7 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO database.ProvisionerJobStatusRunning: // Publish the build ID to the frontend so it // can start streaming logs immediately. - updatedChat, bindErr := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{ - ID: chatID, - WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, - BuildID: uuid.NullUUID{ - UUID: build.ID, - Valid: build.ID != uuid.Nil, - }, - AgentID: uuid.NullUUID{}, - }) - if bindErr != nil { - options.Logger.Error(ctx, "failed to persist build ID on chat binding", - slog.F("chat_id", chatID), - slog.F("build_id", build.ID), - slog.Error(bindErr), - ) - } else if options.OnChatUpdated != nil { - options.OnChatUpdated(updatedChat) - } + publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, build.ID, options.OnChatUpdated) if err := waitForBuild(ctx, db, build.ID); err != nil { // newBuildError returns via toolResponse (IsError: false) // rather than NewTextErrorResponse (IsError: true) so the @@ -199,24 +173,7 @@ func StartWorkspace(db database.Store, chatID uuid.UUID, options StartWorkspaceO // Persist the build ID on the chat binding so the // frontend can stream logs without polling. 
- updatedChat, bindErr := db.UpdateChatWorkspaceBinding(ctx, database.UpdateChatWorkspaceBindingParams{ - ID: chatID, - WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, - BuildID: uuid.NullUUID{ - UUID: startBuild.ID, - Valid: startBuild.ID != uuid.Nil, - }, - AgentID: uuid.NullUUID{}, - }) - if bindErr != nil { - options.Logger.Error(ctx, "failed to persist build ID on chat binding", - slog.F("chat_id", chatID), - slog.F("build_id", startBuild.ID), - slog.Error(bindErr), - ) - } else if options.OnChatUpdated != nil { - options.OnChatUpdated(updatedChat) - } + publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, startBuild.ID, options.OnChatUpdated) if err := waitForBuild(ctx, db, startBuild.ID); err != nil { return buildFailureToolResponse( ctx, diff --git a/coderd/x/chatd/chattool/stopworkspace.go b/coderd/x/chatd/chattool/stopworkspace.go new file mode 100644 index 0000000000..1aea9ad836 --- /dev/null +++ b/coderd/x/chatd/chattool/stopworkspace.go @@ -0,0 +1,181 @@ +package chattool + +import ( + "context" + "sync" + + "charm.land/fantasy" + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/httpapi/httperror" + "github.com/coder/coder/v2/codersdk" +) + +// StopWorkspaceFn stops a workspace by creating a new build with +// the "stop" transition. +type StopWorkspaceFn func( + ctx context.Context, + ownerID uuid.UUID, + workspaceID uuid.UUID, + req codersdk.CreateWorkspaceBuildRequest, +) (codersdk.WorkspaceBuild, error) + +// StopWorkspaceOptions configures the stop_workspace tool. +type StopWorkspaceOptions struct { + OwnerID uuid.UUID + StopFn StopWorkspaceFn + WorkspaceMu *sync.Mutex + OnChatUpdated func(database.Chat) + Logger slog.Logger +} + +type stopWorkspaceArgs struct{} + +// StopWorkspace returns a tool that stops the workspace associated +// with the current chat. The tool is idempotent when the workspace is +// already stopped. 
db must not be nil and chatID must not be uuid.Nil. +func StopWorkspace(db database.Store, chatID uuid.UUID, options StopWorkspaceOptions) fantasy.AgentTool { + return fantasy.NewAgentTool( + "stop_workspace", + "Stop the chat's workspace and wait for the stop build to complete. "+ + "If another workspace build is already in progress, this waits "+ + "for that build first, then stops the workspace if needed. "+ + "After waiting, this tool is idempotent if the workspace is "+ + "already stopped or the in-progress build stopped it. Use "+ + "this when the "+ + "user explicitly asks to stop the workspace, or when a "+ + "workspace-agent error tells you to stop and then start the "+ + "workspace. Stopping a workspace terminates running processes "+ + "and may discard unsaved in-memory state. This tool does not "+ + "delete the workspace.", + func(ctx context.Context, _ stopWorkspaceArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + if options.StopFn == nil { + return fantasy.NewTextErrorResponse("workspace stopper is not configured"), nil + } + + // Serialize with create_workspace and start_workspace to + // prevent lifecycle races. 
+ if options.WorkspaceMu != nil { + options.WorkspaceMu.Lock() + defer options.WorkspaceMu.Unlock() + } + + chat, err := db.GetChatByID(ctx, chatID) + if err != nil { + return fantasy.NewTextErrorResponse( + xerrors.Errorf("load chat: %w", err).Error(), + ), nil + } + if !chat.WorkspaceID.Valid { + return fantasy.NewTextErrorResponse( + "chat has no workspace; use create_workspace first", + ), nil + } + + ws, err := db.GetWorkspaceByID(ctx, chat.WorkspaceID.UUID) + if err != nil { + return fantasy.NewTextErrorResponse( + xerrors.Errorf("load workspace: %w", err).Error(), + ), nil + } + if ws.Deleted { + return fantasy.NewTextErrorResponse( + "workspace was deleted; use create_workspace to make a new one", + ), nil + } + + build, job, err := latestWorkspaceBuildAndJob(ctx, db, ws.ID) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + + // If a build is already in progress, wait for it before + // deciding whether a stop build is still needed. + switch job.JobStatus { + case database.ProvisionerJobStatusPending, + database.ProvisionerJobStatusRunning, + database.ProvisionerJobStatusCanceling: + publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, build.ID, options.OnChatUpdated) + + waitErr := waitForBuild(ctx, db, build.ID) + // Re-read after waiting because another transition may + // have completed while this tool was blocked. + ws, err = db.GetWorkspaceByID(ctx, ws.ID) + if err != nil { + return fantasy.NewTextErrorResponse( + xerrors.Errorf("load workspace: %w", err).Error(), + ), nil + } + if ws.Deleted { + return fantasy.NewTextErrorResponse( + "workspace was deleted; use create_workspace to make a new one", + ), nil + } + build, job, err = latestWorkspaceBuildAndJob(ctx, db, ws.ID) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + // The fresh job row is authoritative. A wait error can + // be stale if the build reached a terminal state while the + // wait context was ending. 
+ if waitErr != nil && !provisionerJobTerminal(job.JobStatus) { + return buildToolResponse(newBuildError( + xerrors.Errorf("waiting for in-progress build: %w", waitErr).Error(), + build.ID, + )), nil + } + } + + if job.JobStatus == database.ProvisionerJobStatusSucceeded && + build.Transition == database.WorkspaceTransitionStop { + result := map[string]any{ + "stopped": true, + "workspace_name": ws.Name, + } + setNoBuild(result, uuid.Nil) + return toolResponse(result), nil + } + + ownerCtx, ownerErr := asOwner(ctx, db, options.OwnerID) + if ownerErr != nil { + return fantasy.NewTextErrorResponse(ownerErr.Error()), nil + } + + stopBuild, err := options.StopFn(ownerCtx, options.OwnerID, ws.ID, codersdk.CreateWorkspaceBuildRequest{ + Transition: codersdk.WorkspaceTransitionStop, + }) + if err != nil { + if responseErr, ok := httperror.IsResponder(err); ok { + _, resp := responseErr.Response() + return toolResponse(responseErrorResult(resp)), nil + } + return fantasy.NewTextErrorResponse( + xerrors.Errorf("stop workspace: %w", err).Error(), + ), nil + } + + publishBuildBinding(ctx, db, options.Logger, chatID, ws.ID, stopBuild.ID, options.OnChatUpdated) + if err := waitForBuild(ctx, db, stopBuild.ID); err != nil { + return buildToolResponse(newBuildError( + xerrors.Errorf("workspace stop build failed: %w", err).Error(), + stopBuild.ID, + )), nil + } + + if options.OnChatUpdated != nil { + if latest, err := db.GetChatByID(ctx, chatID); err == nil { + options.OnChatUpdated(latest) + } + } + + result := map[string]any{ + "stopped": true, + "workspace_name": ws.Name, + } + setBuildID(result, stopBuild.ID) + return toolResponse(result), nil + }) +} diff --git a/coderd/x/chatd/chattool/stopworkspace_test.go b/coderd/x/chatd/chattool/stopworkspace_test.go new file mode 100644 index 0000000000..4133ba223d --- /dev/null +++ b/coderd/x/chatd/chattool/stopworkspace_test.go @@ -0,0 +1,449 @@ +package chattool_test + +import ( + "context" + "database/sql" + "encoding/json" + "sync" + 
"sync/atomic" + "testing" + "time" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbfake" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/x/chatd/chattool" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +func TestStopWorkspace(t *testing.T) { + t.Parallel() + + t.Run("NoWorkspace", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: modelCfg.ID, + Title: "test-stop-no-workspace", + }) + + tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{ + StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + t.Fatal("StopFn should not be called") + return codersdk.WorkspaceBuild{}, nil + }, + WorkspaceMu: &sync.Mutex{}, + }) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + require.NoError(t, err) + require.Contains(t, resp.Content, "use create_workspace first") + }) + + t.Run("DeletedWorkspace", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, 
database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + Deleted: true, + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionDelete, + }).Do() + ws := wsResp.Workspace + + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: "test-stop-deleted-workspace", + }) + + tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{ + StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + t.Fatal("StopFn should not be called for deleted workspace") + return codersdk.WorkspaceBuild{}, nil + }, + WorkspaceMu: &sync.Mutex{}, + }) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + require.NoError(t, err) + require.Contains(t, resp.Content, "workspace was deleted") + require.Contains(t, resp.Content, "create_workspace") + }) + + t.Run("AlreadyStopped", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStop, + }).Do() + ws := wsResp.Workspace + + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: 
"test-stop-already-stopped", + }) + + tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{ + OwnerID: user.ID, + StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + t.Fatal("StopFn should not be called for already-stopped workspace") + return codersdk.WorkspaceBuild{}, nil + }, + WorkspaceMu: &sync.Mutex{}, + }) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + require.Equal(t, true, result["stopped"]) + require.Equal(t, ws.Name, result["workspace_name"]) + require.Equal(t, true, result["no_build"]) + require.Nil(t, result["build_id"]) + }) + + t.Run("RunningWorkspaceStops", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStart, + }).Do() + ws := wsResp.Workspace + + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: "test-stop-running-workspace", + }) + + var stopCalled atomic.Bool + var stopBuildID uuid.UUID + var seenBuildID uuid.UUID + var onChatUpdatedCalls atomic.Int32 + tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{ + OwnerID: user.ID, + StopFn: func(_ context.Context, _ uuid.UUID, wsID uuid.UUID, req 
codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + stopCalled.Store(true) + require.Equal(t, ws.ID, wsID) + require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition) + buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStop, + BuildNumber: 2, + }).Do() + stopBuildID = buildResp.Build.ID + return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil + }, + WorkspaceMu: &sync.Mutex{}, + OnChatUpdated: func(chat database.Chat) { + onChatUpdatedCalls.Add(1) + if chat.BuildID.Valid { + seenBuildID = chat.BuildID.UUID + } + }, + }) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + require.NoError(t, err) + require.True(t, stopCalled.Load()) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + require.Equal(t, true, result["stopped"]) + require.Equal(t, ws.Name, result["workspace_name"]) + require.Equal(t, stopBuildID.String(), result["build_id"]) + require.Nil(t, result["no_build"]) + + require.GreaterOrEqual(t, onChatUpdatedCalls.Load(), int32(1)) + require.Equal(t, stopBuildID, seenBuildID) + + updatedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.True(t, updatedChat.BuildID.Valid) + require.Equal(t, stopBuildID, updatedChat.BuildID.UUID) + }) + + t.Run("InProgressBuildWaitsThenStops", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + }).Seed(database.WorkspaceBuild{ + Transition: 
database.WorkspaceTransitionStart, + }).Starting().Do() + ws := wsResp.Workspace + + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: "test-stop-in-progress-build", + }) + + jobRead := make(chan struct{}, 1) + wrappedDB := &jobInterceptStore{Store: db, jobRead: jobRead} + var stopCalled atomic.Bool + var stopBuildID uuid.UUID + var onChatUpdatedCalled atomic.Bool + tool := chattool.StopWorkspace(wrappedDB, chat.ID, chattool.StopWorkspaceOptions{ + OwnerID: user.ID, + StopFn: func(_ context.Context, _ uuid.UUID, wsID uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + stopCalled.Store(true) + require.Equal(t, ws.ID, wsID) + require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition) + buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStop, + BuildNumber: 2, + }).Do() + stopBuildID = buildResp.Build.ID + return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil + }, + WorkspaceMu: &sync.Mutex{}, + Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + OnChatUpdated: func(_ database.Chat) { onChatUpdatedCalled.Store(true) }, + }) + + type toolResult struct { + resp fantasy.ToolResponse + err error + } + done := make(chan toolResult, 1) + go func() { + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + done <- toolResult{resp: resp, err: err} + }() + + testutil.TryReceive(ctx, t, jobRead) + require.False(t, stopCalled.Load(), "StopFn must wait for the in-progress build") + + now := time.Now().UTC() + require.NoError(t, db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: wsResp.Build.JobID, + UpdatedAt: now, + CompletedAt: sql.NullTime{Time: now, Valid: true}, + })) + + res := testutil.TryReceive(ctx, t, done) + 
require.NoError(t, res.err) + require.True(t, stopCalled.Load()) + require.True(t, onChatUpdatedCalled.Load()) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(res.resp.Content), &result)) + require.Equal(t, true, result["stopped"]) + require.Equal(t, stopBuildID.String(), result["build_id"]) + }) + + t.Run("FailedLatestStopBuildStillStops", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStop, + }).Do() + ws := wsResp.Workspace + now := time.Now().UTC() + require.NoError(t, db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: wsResp.Build.JobID, + UpdatedAt: now, + CompletedAt: sql.NullTime{Time: now, Valid: true}, + Error: sql.NullString{String: "latest build failed", Valid: true}, + })) + + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: "test-stop-failed-latest-build", + }) + + var stopCalled atomic.Bool + tool := chattool.StopWorkspace(db, chat.ID, chattool.StopWorkspaceOptions{ + OwnerID: user.ID, + StopFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + stopCalled.Store(true) + require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition) + buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{ + Transition: 
database.WorkspaceTransitionStop, + BuildNumber: 2, + }).Do() + return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil + }, + WorkspaceMu: &sync.Mutex{}, + }) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + require.NoError(t, err) + require.True(t, stopCalled.Load()) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + require.Equal(t, true, result["stopped"]) + }) + + t.Run("StopTriggeredBuildFailure", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(t, db) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStart, + }).Do() + ws := wsResp.Workspace + + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: "test-stop-triggered-build-failure", + }) + + var stopBuildJobID uuid.UUID + var stopBuildID uuid.UUID + stopFn := func(_ context.Context, _ uuid.UUID, wsID uuid.UUID, req codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + require.Equal(t, ws.ID, wsID) + require.Equal(t, codersdk.WorkspaceTransitionStop, req.Transition) + buildResp := dbfake.WorkspaceBuild(t, db, ws).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStop, + BuildNumber: 2, + }).Starting().Do() + stopBuildJobID = buildResp.Build.JobID + stopBuildID = buildResp.Build.ID + return codersdk.WorkspaceBuild{ID: buildResp.Build.ID}, nil + } + + jobRead := make(chan struct{}, 
2) + wrappedDB := &jobInterceptStore{Store: db, jobRead: jobRead} + tool := chattool.StopWorkspace(wrappedDB, chat.ID, chattool.StopWorkspaceOptions{ + OwnerID: user.ID, + StopFn: stopFn, + WorkspaceMu: &sync.Mutex{}, + Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + }) + + type toolResult struct { + resp fantasy.ToolResponse + err error + } + done := make(chan toolResult, 1) + go func() { + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "stop_workspace", Input: "{}"}) + done <- toolResult{resp: resp, err: err} + }() + + testutil.TryReceive(ctx, t, jobRead) + testutil.TryReceive(ctx, t, jobRead) + + now := time.Now().UTC() + require.NoError(t, db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ + ID: stopBuildJobID, + UpdatedAt: now, + CompletedAt: sql.NullTime{Time: now, Valid: true}, + Error: sql.NullString{String: "terraform destroy failed", Valid: true}, + })) + + res := testutil.TryReceive(ctx, t, done) + require.NoError(t, res.err) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(res.resp.Content), &result)) + require.Contains(t, result["error"], "workspace stop build failed") + require.Equal(t, stopBuildID.String(), result["build_id"]) + require.False(t, res.resp.IsError, + "buildToolResponse must not set IsError; chatprompt strips structured fields from error responses") + }) +}