diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index ef4ea259a2..b038e818af 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -39,6 +39,7 @@ import ( "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" "github.com/coder/coder/v2/coderd/x/chatd/chatretry" "github.com/coder/coder/v2/coderd/x/chatd/chattool" + "github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect" "github.com/coder/coder/v2/coderd/x/chatd/mcpclient" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" @@ -379,6 +380,13 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked( if len(agents) == 0 { return chatSnapshot, database.WorkspaceAgent{}, errChatHasNoWorkspaceAgent } + selected, err := agentselect.FindChatAgent(agents) + if err != nil { + return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf( + "find chat agent: %w", + err, + ) + } build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID) if err != nil { @@ -389,7 +397,7 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked( ctx, chatSnapshot, build.ID, - agents[0].ID, + selected.ID, ) if err != nil { return chatSnapshot, database.WorkspaceAgent{}, err @@ -401,7 +409,7 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked( chatSnapshot = latestChat continue } - c.agent = agents[0] + c.agent = selected c.agentLoaded = true c.cachedWorkspaceID = chatSnapshot.WorkspaceID return chatSnapshot, c.agent, nil @@ -429,7 +437,14 @@ func (c *turnWorkspaceContext) latestWorkspaceAgentID( if len(agents) == 0 { return uuid.Nil, errChatHasNoWorkspaceAgent } - return agents[0].ID, nil + selected, err := agentselect.FindChatAgent(agents) + if err != nil { + return uuid.Nil, xerrors.Errorf( + "find chat agent: %w", + err, + ) + } + return selected.ID, nil } func (c *turnWorkspaceContext) workspaceAgentIDForConn( diff --git a/coderd/x/chatd/chattool/createworkspace.go b/coderd/x/chatd/chattool/createworkspace.go index 9f00c2108f..e6b61cad2b 100644 --- a/coderd/x/chatd/chattool/createworkspace.go +++ b/coderd/x/chatd/chattool/createworkspace.go @@ -15,6 +15,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/util/namesgenerator" + "github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" ) @@ -203,12 +204,28 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool { } } - // Look up the first agent so we can link it to the chat. + result := map[string]any{ + "created": true, + "workspace_name": workspace.FullName(), + } + + // Select the chat agent so follow-up tools wait on the + // intended workspace agent. workspaceAgentID := uuid.Nil if options.DB != nil { agents, agentErr := options.DB.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID) - if agentErr == nil && len(agents) > 0 { - workspaceAgentID = agents[0].ID + if agentErr == nil { + if len(agents) == 0 { + result["agent_status"] = "no_agent" + } else { + selected, selectErr := agentselect.FindChatAgent(agents) + if selectErr != nil { + result["agent_status"] = "selection_error" + result["agent_error"] = selectErr.Error() + } else { + workspaceAgentID = selected.ID + } + } } } @@ -241,20 +258,12 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool { // Wait for the agent to come online and startup scripts to finish. if workspaceAgentID != uuid.Nil { agentStatus := waitForAgentReady(ctx, options.DB, workspaceAgentID, options.AgentConnFn) - result := map[string]any{ - "created": true, - "workspace_name": workspace.FullName(), - } for k, v := range agentStatus { result[k] = v } - return toolResponse(result), nil } - return toolResponse(map[string]any{ - "created": true, - "workspace_name": workspace.FullName(), - }), nil + return toolResponse(result), nil }) } @@ -322,7 +331,15 @@ func (o CreateWorkspaceOptions) checkExistingWorkspace( } agents, agentsErr := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, ws.ID) if agentsErr == nil && len(agents) > 0 { - for k, v := range waitForAgentReady(ctx, db, agents[0].ID, agentConnFn) { + selected, selectErr := agentselect.FindChatAgent(agents) + if selectErr != nil { + o.Logger.Debug(ctx, "agent selection failed, falling back to first agent for readiness check", + slog.F("workspace_id", ws.ID), + slog.Error(selectErr), + ) + selected = agents[0] + } + for k, v := range waitForAgentReady(ctx, db, selected.ID, agentConnFn) { result[k] = v } } @@ -345,7 +362,15 @@ func (o CreateWorkspaceOptions) checkExistingWorkspace( // still usable. agents, agentsErr := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, ws.ID) if agentsErr == nil && len(agents) > 0 { - status := agents[0].Status(agentInactiveDisconnectTimeout) + selected, selectErr := agentselect.FindChatAgent(agents) + if selectErr != nil { + o.Logger.Debug(ctx, "agent selection failed, falling back to first agent for status check", + slog.F("workspace_id", ws.ID), + slog.Error(selectErr), + ) + selected = agents[0] + } + status := selected.Status(agentInactiveDisconnectTimeout) result := map[string]any{ "created": false, "workspace_name": ws.Name, @@ -355,19 +380,19 @@ func (o CreateWorkspaceOptions) checkExistingWorkspace( switch status.Status { case database.WorkspaceAgentStatusConnected: result["message"] = "workspace is already running and recently connected" - for k, v := range waitForAgentReady(ctx, db, agents[0].ID, nil) { + for k, v := range waitForAgentReady(ctx, db, selected.ID, nil) { result[k] = v } return result, true, nil case database.WorkspaceAgentStatusConnecting: result["message"] = "workspace exists and the agent is still connecting" - for k, v := range waitForAgentReady(ctx, db, agents[0].ID, agentConnFn) { + for k, v := range waitForAgentReady(ctx, db, selected.ID, agentConnFn) { result[k] = v } return result, true, nil case database.WorkspaceAgentStatusDisconnected, database.WorkspaceAgentStatusTimeout: - // Agent is offline or never became ready — allow + // Agent is offline or never became ready - allow // creation. } } diff --git a/coderd/x/chatd/chattool/createworkspace_test.go b/coderd/x/chatd/chattool/createworkspace_test.go index 6257393c67..060f9c4b7e 100644 --- a/coderd/x/chatd/chattool/createworkspace_test.go +++ b/coderd/x/chatd/chattool/createworkspace_test.go @@ -3,6 +3,7 @@ package chattool //nolint:testpackage // Uses internal symbols. import ( "context" "database/sql" + "encoding/json" "fmt" "sync" "testing" @@ -118,6 +119,180 @@ func TestWaitForAgentReady(t *testing.T) { }) } +func TestCreateWorkspace_PrefersChatSuffixAgent(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + ownerID := uuid.New() + templateID := uuid.New() + workspaceID := uuid.New() + jobID := uuid.New() + fallbackAgentID := uuid.New() + chatAgentID := uuid.New() + + db.EXPECT(). + GetAuthorizationUserRoles(gomock.Any(), ownerID). + Return(database.GetAuthorizationUserRolesRow{ + ID: ownerID, + Roles: []string{}, + Groups: []string{}, + Status: database.UserStatusActive, + }, nil) + + db.EXPECT(). + GetChatWorkspaceTTL(gomock.Any()). + Return("0s", nil) + + db.EXPECT(). + GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID). + Return(database.WorkspaceBuild{ + WorkspaceID: workspaceID, + JobID: jobID, + }, nil) + db.EXPECT(). + GetProvisionerJobByID(gomock.Any(), jobID). + Return(database.ProvisionerJob{ + ID: jobID, + JobStatus: database.ProvisionerJobStatusSucceeded, + }, nil) + db.EXPECT(). + GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID). + Return([]database.WorkspaceAgent{ + {ID: fallbackAgentID, Name: "dev", DisplayOrder: 0}, + {ID: chatAgentID, Name: "dev-coderd-chat", DisplayOrder: 1}, + }, nil) + db.EXPECT(). + GetWorkspaceAgentLifecycleStateByID(gomock.Any(), chatAgentID). + Return(database.GetWorkspaceAgentLifecycleStateByIDRow{ + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + }, nil) + + var connectedAgentID uuid.UUID + createFn := func(_ context.Context, _ uuid.UUID, req codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) { + return codersdk.Workspace{ + ID: workspaceID, + Name: req.Name, + OwnerName: "testuser", + }, nil + } + agentConnFn := func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + connectedAgentID = agentID + return nil, func() {}, nil + } + + tool := CreateWorkspace(CreateWorkspaceOptions{ + DB: db, + OwnerID: ownerID, + CreateFn: createFn, + AgentConnFn: agentConnFn, + WorkspaceMu: &sync.Mutex{}, + Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + }) + + input := fmt.Sprintf(`{"template_id":%q,"name":"test-chat-agent"}`, templateID.String()) + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "create_workspace", + Input: input, + }) + require.NoError(t, err) + require.NotEmpty(t, resp.Content) + require.Equal(t, chatAgentID, connectedAgentID) +} + +func TestCreateWorkspace_ReturnsSelectionErrorImmediately(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + ownerID := uuid.New() + chatID := uuid.New() + templateID := uuid.New() + workspaceID := uuid.New() + jobID := uuid.New() + + db.EXPECT(). + GetChatByID(gomock.Any(), chatID). + Return(database.Chat{ID: chatID}, nil) + db.EXPECT(). + GetAuthorizationUserRoles(gomock.Any(), ownerID). + Return(database.GetAuthorizationUserRolesRow{ + ID: ownerID, + Roles: []string{}, + Groups: []string{}, + Status: database.UserStatusActive, + }, nil) + db.EXPECT(). + GetChatWorkspaceTTL(gomock.Any()). + Return("0s", nil) + db.EXPECT(). + GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID). + Return(database.WorkspaceBuild{ + WorkspaceID: workspaceID, + JobID: jobID, + }, nil) + db.EXPECT(). + GetProvisionerJobByID(gomock.Any(), jobID). + Return(database.ProvisionerJob{ + ID: jobID, + JobStatus: database.ProvisionerJobStatusSucceeded, + }, nil) + db.EXPECT(). + UpdateChatWorkspaceBinding(gomock.Any(), database.UpdateChatWorkspaceBindingParams{ + ID: chatID, + WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}, + BuildID: uuid.NullUUID{}, + AgentID: uuid.NullUUID{}, + }). + Return(database.Chat{ + ID: chatID, + WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}, + }, nil) + db.EXPECT(). + GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID). + Return([]database.WorkspaceAgent{ + {ID: uuid.New(), Name: "alpha-coderd-chat", DisplayOrder: 0}, + {ID: uuid.New(), Name: "beta-coderd-chat", DisplayOrder: 1}, + }, nil) + + tool := CreateWorkspace(CreateWorkspaceOptions{ + DB: db, + OwnerID: ownerID, + ChatID: chatID, + CreateFn: func(_ context.Context, _ uuid.UUID, req codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) { + return codersdk.Workspace{ + ID: workspaceID, + Name: req.Name, + OwnerName: "testuser", + }, nil + }, + AgentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) { + t.Fatal("AgentConnFn should not be called when agent selection fails") + return nil, nil, xerrors.New("unexpected agent dial") + }, + WorkspaceMu: &sync.Mutex{}, + Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + }) + + input := fmt.Sprintf(`{"template_id":%q,"name":"test-selection-error"}`, templateID.String()) + resp, err := tool.Run(context.Background(), fantasy.ToolCall{ + ID: "call-1", + Name: "create_workspace", + Input: input, + }) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + require.Equal(t, true, result["created"]) + require.Equal(t, "testuser/test-selection-error", result["workspace_name"]) + require.Equal(t, "selection_error", result["agent_status"]) + require.Contains(t, result["agent_error"], "multiple agents match the chat suffix") +} + func TestCreateWorkspace_GlobalTTL(t *testing.T) { t.Parallel() @@ -253,6 +428,7 @@ func TestCheckExistingWorkspace_ConnectedAgent(t *testing.T) { GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID). Return([]database.WorkspaceAgent{{ ID: agentID, + Name: "dev", CreatedAt: now.Add(-time.Minute), FirstConnectedAt: validNullTime(now.Add(-45 * time.Second)), LastConnectedAt: validNullTime(now.Add(-5 * time.Second)), @@ -302,6 +478,7 @@ func TestCheckExistingWorkspace_ConnectingAgentWaits(t *testing.T) { GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID). Return([]database.WorkspaceAgent{{ ID: agentID, + Name: "dev", CreatedAt: now, ConnectionTimeoutSeconds: 60, }}, nil) @@ -336,6 +513,7 @@ func TestCheckExistingWorkspace_DeadAgentAllowsCreation(t *testing.T) { name: "Disconnected", agent: database.WorkspaceAgent{ ID: uuid.New(), + Name: "disconnected", CreatedAt: time.Now().UTC().Add(-2 * time.Minute), FirstConnectedAt: validNullTime(time.Now().UTC().Add(-2 * time.Minute)), LastConnectedAt: validNullTime(time.Now().UTC().Add(-time.Minute)), @@ -345,6 +523,7 @@ func TestCheckExistingWorkspace_DeadAgentAllowsCreation(t *testing.T) { name: "TimedOut", agent: database.WorkspaceAgent{ ID: uuid.New(), + Name: "timed-out", CreatedAt: time.Now().UTC().Add(-2 * time.Second), ConnectionTimeoutSeconds: 1, }, diff --git a/coderd/x/chatd/chattool/startworkspace.go b/coderd/x/chatd/chattool/startworkspace.go index bc19a8cd77..bd6473d183 100644 --- a/coderd/x/chatd/chattool/startworkspace.go +++ b/coderd/x/chatd/chattool/startworkspace.go @@ -9,6 +9,7 @@ import ( "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect" "github.com/coder/coder/v2/codersdk" ) @@ -144,7 +145,7 @@ func StartWorkspace(options StartWorkspaceOptions) fantasy.AgentTool { ) } -// waitForAgentAndRespond looks up the first agent in the workspace's +// waitForAgentAndRespond selects the chat agent from the workspace's // latest build, waits for it to become reachable, and returns a // success response. func waitForAgentAndRespond( @@ -155,7 +156,7 @@ func waitForAgentAndRespond( ) (fantasy.ToolResponse, error) { agents, err := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, ws.ID) if err != nil || len(agents) == 0 { - // Workspace started but no agent found — still report + // Workspace started but no agent found - still report // success so the model knows the workspace is up. return toolResponse(map[string]any{ "started": true, @@ -164,11 +165,21 @@ func waitForAgentAndRespond( }), nil } + selected, err := agentselect.FindChatAgent(agents) + if err != nil { + return toolResponse(map[string]any{ + "started": true, + "workspace_name": ws.Name, + "agent_status": "selection_error", + "agent_error": err.Error(), + }), nil + } + result := map[string]any{ "started": true, "workspace_name": ws.Name, } - for k, v := range waitForAgentReady(ctx, db, agents[0].ID, agentConnFn) { + for k, v := range waitForAgentReady(ctx, db, selected.ID, agentConnFn) { result[k] = v } return toolResponse(result), nil diff --git a/coderd/x/chatd/chattool/startworkspace_test.go b/coderd/x/chatd/chattool/startworkspace_test.go index 74629cc402..a9d593ae49 100644 --- a/coderd/x/chatd/chattool/startworkspace_test.go +++ b/coderd/x/chatd/chattool/startworkspace_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "sync" "testing" + "time" "charm.land/fantasy" "github.com/google/uuid" @@ -18,6 +19,7 @@ import ( "github.com/coder/coder/v2/coderd/x/chatd/chattool" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" + sdkproto "github.com/coder/coder/v2/provisionersdk/proto" "github.com/coder/coder/v2/testutil" ) @@ -108,6 +110,206 @@ func TestStartWorkspace(t *testing.T) { require.True(t, started) }) + t.Run("AlreadyRunningPrefersChatSuffixAgent", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(ctx, t, db, user.ID) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + }).WithAgent(func(agents []*sdkproto.Agent) []*sdkproto.Agent { + agents[0].Name = "dev" + return append(agents, &sdkproto.Agent{ + Id: uuid.NewString(), + Name: "dev-coderd-chat", + Auth: &sdkproto.Agent_Token{Token: uuid.NewString()}, + Env: map[string]string{}, + }) + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStart, + }).Do() + ws := wsResp.Workspace + + now := time.Now().UTC() + preferredAgentID := uuid.Nil + for _, agent := range wsResp.Agents { + if agent.Name == "dev-coderd-chat" { + preferredAgentID = agent.ID + } + err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agent.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + StartedAt: sql.NullTime{Time: now, Valid: true}, + ReadyAt: sql.NullTime{Time: now, Valid: true}, + }) + require.NoError(t, err) + } + require.NotEqual(t, uuid.Nil, preferredAgentID) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: "test-running-preferred-agent", + }) + require.NoError(t, err) + + var connectedAgentID uuid.UUID + agentConnFn := func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + connectedAgentID = agentID + return nil, func() {}, nil + } + + tool := chattool.StartWorkspace(chattool.StartWorkspaceOptions{ + DB: db, + OwnerID: user.ID, + ChatID: chat.ID, + AgentConnFn: agentConnFn, + StartFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + t.Fatal("StartFn should not be called for already-running workspace") + return codersdk.WorkspaceBuild{}, nil + }, + WorkspaceMu: &sync.Mutex{}, + }) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "start_workspace", Input: "{}"}) + require.NoError(t, err) + require.Equal(t, preferredAgentID, connectedAgentID) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + started, ok := result["started"].(bool) + require.True(t, ok) + require.True(t, started) + }) + + t.Run("AlreadyRunningWithoutAgentsReturnsNoAgent", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(ctx, t, db, user.ID) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + }).WithAgent(func(_ []*sdkproto.Agent) []*sdkproto.Agent { + return nil + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStart, + }).Do() + ws := wsResp.Workspace + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: "test-running-no-agent", + }) + require.NoError(t, err) + + tool := chattool.StartWorkspace(chattool.StartWorkspaceOptions{ + DB: db, + OwnerID: user.ID, + ChatID: chat.ID, + AgentConnFn: func(_ context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) { + t.Fatal("AgentConnFn should not be called when no agents exist") + return nil, func() {}, nil + }, + StartFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + t.Fatal("StartFn should not be called for already-running workspace") + return codersdk.WorkspaceBuild{}, nil + }, + WorkspaceMu: &sync.Mutex{}, + }) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "start_workspace", Input: "{}"}) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + started, ok := result["started"].(bool) + require.True(t, ok) + require.True(t, started) + require.Equal(t, "no_agent", result["agent_status"]) + }) + + t.Run("AlreadyRunningPreservesAgentSelectionError", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + + user := dbgen.User(t, db, database.User{}) + modelCfg := seedModelConfig(ctx, t, db, user.ID) + org := dbgen.Organization(t, db, database.Organization{}) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + }).WithAgent(func(agents []*sdkproto.Agent) []*sdkproto.Agent { + agents[0].Name = "alpha-coderd-chat" + return append(agents, &sdkproto.Agent{ + Id: uuid.NewString(), + Name: "beta-coderd-chat", + Auth: &sdkproto.Agent_Token{Token: uuid.NewString()}, + Env: map[string]string{}, + }) + }).Seed(database.WorkspaceBuild{ + Transition: database.WorkspaceTransitionStart, + }).Do() + ws := wsResp.Workspace + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + LastModelConfigID: modelCfg.ID, + Title: "test-running-selection-error", + }) + require.NoError(t, err) + + tool := chattool.StartWorkspace(chattool.StartWorkspaceOptions{ + DB: db, + OwnerID: user.ID, + ChatID: chat.ID, + AgentConnFn: func(_ context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) { + t.Fatal("AgentConnFn should not be called when agent selection fails") + return nil, func() {}, nil + }, + StartFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) { + t.Fatal("StartFn should not be called for already-running workspace") + return codersdk.WorkspaceBuild{}, nil + }, + WorkspaceMu: &sync.Mutex{}, + }) + + resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "start_workspace", Input: "{}"}) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal([]byte(resp.Content), &result)) + started, ok := result["started"].(bool) + require.True(t, ok) + require.True(t, started) + require.Equal(t, "selection_error", result["agent_status"]) + require.Contains(t, result["agent_error"], "multiple agents match the chat suffix") + }) + t.Run("StoppedWorkspace", func(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) diff --git a/coderd/x/chatd/internal/agentselect/agentselect.go b/coderd/x/chatd/internal/agentselect/agentselect.go new file mode 100644 index 0000000000..4d5530523a --- /dev/null +++ b/coderd/x/chatd/internal/agentselect/agentselect.go @@ -0,0 +1,86 @@ +package agentselect + +import ( + "cmp" + "slices" + "strings" + + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" +) + +// Suffix marks chat-designated agents during the current PoC. This naming +// convention is an implementation detail, not a stable contract. +const Suffix = "-coderd-chat" + +// IsChatAgent reports whether name uses the chat-agent suffix convention. +func IsChatAgent(name string) bool { + return strings.HasSuffix(strings.ToLower(name), Suffix) +} + +// FindChatAgent picks the best workspace agent for a chat session from the +// provided candidates. It applies these rules in order: +// 1. Filter to root agents only (ParentID is null). +// 2. Sort stably and deterministically by DisplayOrder ASC, then Name ASC +// (case-insensitive), then Name ASC, then ID ASC. +// 3. If exactly one root agent name ends with Suffix (case-insensitive), +// return it. +// 4. If zero root agents match the suffix, return the first root agent after +// sorting (deterministic fallback). +// 5. If more than one root agent matches the suffix, return an error with an +// actionable message. +// 6. If no root agents exist at all, return an error. +func FindChatAgent( + agents []database.WorkspaceAgent, +) (database.WorkspaceAgent, error) { + rootAgents := make([]database.WorkspaceAgent, 0, len(agents)) + matchingAgents := make([]database.WorkspaceAgent, 0, 1) + for _, agent := range agents { + if agent.ParentID.Valid { + continue + } + rootAgents = append(rootAgents, agent) + if IsChatAgent(agent.Name) { + matchingAgents = append(matchingAgents, agent) + } + } + + if len(rootAgents) == 0 { + return database.WorkspaceAgent{}, xerrors.New( + "no eligible workspace agents found", + ) + } + + compareAgents := func(a, b database.WorkspaceAgent) int { + if order := cmp.Compare(a.DisplayOrder, b.DisplayOrder); order != 0 { + return order + } + if order := cmp.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name)); order != 0 { + return order + } + if order := cmp.Compare(a.Name, b.Name); order != 0 { + return order + } + return cmp.Compare(a.ID.String(), b.ID.String()) + } + slices.SortStableFunc(rootAgents, compareAgents) + slices.SortStableFunc(matchingAgents, compareAgents) + + switch len(matchingAgents) { + case 0: + return rootAgents[0], nil + case 1: + return matchingAgents[0], nil + default: + names := make([]string, 0, len(matchingAgents)) + for _, agent := range matchingAgents { + names = append(names, agent.Name) + } + return database.WorkspaceAgent{}, xerrors.Errorf( + "multiple agents match the chat suffix %q: %s; only one agent should use this suffix", + Suffix, + strings.Join(names, ", "), + ) + } +} diff --git a/coderd/x/chatd/internal/agentselect/agentselect_test.go b/coderd/x/chatd/internal/agentselect/agentselect_test.go new file mode 100644 index 0000000000..84bbb5bee8 --- /dev/null +++ b/coderd/x/chatd/internal/agentselect/agentselect_test.go @@ -0,0 +1,231 @@ +package agentselect_test + +import ( + "fmt" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect" +) + +func TestFindChatAgent(t *testing.T) { + t.Parallel() + + newRootAgentWithID := func(id, name string, displayOrder int32) database.WorkspaceAgent { + return database.WorkspaceAgent{ + ID: uuid.MustParse(id), + Name: name, + DisplayOrder: displayOrder, + } + } + + newRootAgent := func(name string, displayOrder int32) database.WorkspaceAgent { + return newRootAgentWithID(uuid.NewString(), name, displayOrder) + } + + newChildAgent := func(name string, displayOrder int32) database.WorkspaceAgent { + agent := newRootAgent(name, displayOrder) + agent.ParentID = uuid.NullUUID{UUID: uuid.New(), Valid: true} + return agent + } + + tests := []struct { + name string + agents []database.WorkspaceAgent + wantIndex int + wantErrContains []string + }{ + { + name: "SingleSuffixMatch", + agents: []database.WorkspaceAgent{ + newRootAgent("alpha", 0), + newRootAgent("dev-coderd-chat", 2), + newRootAgent("zeta", 1), + }, + wantIndex: 1, + }, + { + name: "SuffixMatchCaseInsensitive", + agents: []database.WorkspaceAgent{ + newRootAgent("alpha", 0), + newRootAgent("Dev-Coderd-Chat", 2), + newRootAgent("zeta", 1), + }, + wantIndex: 1, + }, + { + name: "NoSuffixMatchFallbackDeterministic", + agents: []database.WorkspaceAgent{ + newRootAgent("zeta", 2), + newRootAgent("bravo", 1), + newRootAgent("alpha", 1), + }, + wantIndex: 2, + }, + { + name: "NoSuffixMatchFallbackByName", + agents: []database.WorkspaceAgent{ + newRootAgent("Bravo", 3), + newRootAgent("alpha", 3), + newRootAgent("charlie", 3), + }, + wantIndex: 1, + }, + { + name: "CaseOnlyNameTieFallbackDeterministic", + agents: []database.WorkspaceAgent{ + newRootAgent("Dev", 0), + newRootAgent("dev", 0), + }, + wantIndex: 0, + }, + { + name: "ExactNameTieFallbackByID", + agents: []database.WorkspaceAgent{ + newRootAgentWithID("00000000-0000-0000-0000-000000000002", "dev", 0), + newRootAgentWithID("00000000-0000-0000-0000-000000000001", "dev", 0), + }, + wantIndex: 1, + }, + { + name: "MultipleSuffixMatchesError", + agents: []database.WorkspaceAgent{ + newRootAgent("alpha-coderd-chat", 2), + newRootAgent("beta-coderd-chat", 1), + newRootAgent("gamma", 0), + }, + wantErrContains: []string{ + fmt.Sprintf( + "multiple agents match the chat suffix %q", + agentselect.Suffix, + ), + "alpha-coderd-chat", + "beta-coderd-chat", + "only one agent should use this suffix", + }, + }, + { + name: "ChildAgentSuffixIgnored", + agents: []database.WorkspaceAgent{ + newRootAgent("alpha", 1), + newChildAgent("child-coderd-chat", 0), + newRootAgent("bravo", 0), + }, + wantIndex: 2, + }, + { + name: "ChildAgentSuffixIgnoredWithRootMatch", + agents: []database.WorkspaceAgent{ + newRootAgent("alpha", 0), + newChildAgent("child-coderd-chat", 1), + newRootAgent("root-coderd-chat", 2), + }, + wantIndex: 2, + }, + { + name: "EmptyAgentList", + agents: []database.WorkspaceAgent{}, + wantErrContains: []string{ + "no eligible workspace agents found", + }, + }, + { + name: "OnlyChildAgents", + agents: []database.WorkspaceAgent{ + newChildAgent("alpha", 0), + newChildAgent("beta-coderd-chat", 1), + }, + wantErrContains: []string{ + "no eligible workspace agents found", + }, + }, + { + name: "SingleRootAgent", + agents: []database.WorkspaceAgent{ + newRootAgent("solo", 5), + }, + wantIndex: 0, + }, + { + name: "SuffixAgentWinsRegardlessOfOrder", + agents: []database.WorkspaceAgent{ + newRootAgent("alpha", 0), + newRootAgent("zeta", 1), + newRootAgent("preferred-coderd-chat", 99), + }, + wantIndex: 2, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := agentselect.FindChatAgent(tt.agents) + if len(tt.wantErrContains) > 0 { + require.Error(t, err) + for _, wantErr := range tt.wantErrContains { + require.ErrorContains(t, err, wantErr) + } + return + } + + require.NoError(t, err) + require.Equal(t, tt.agents[tt.wantIndex], got) + }) + } +} + +func TestIsChatAgent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want bool + }{ + { + name: "ExactSuffix", + input: "agent-coderd-chat", + want: true, + }, + { + name: "UppercaseSuffix", + input: "agent-CODERD-CHAT", + want: true, + }, + { + name: "MixedCaseSuffix", + input: "agent-Coderd-Chat", + want: true, + }, + { + name: "NoSuffix", + input: "my-agent", + want: false, + }, + { + name: "SuffixOnly", + input: "-coderd-chat", + want: true, + }, + { + name: "PartialSuffix", + input: "agent-coderd", + want: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + require.Equal(t, tt.want, agentselect.IsChatAgent(tt.input)) + }) + } +}