feat: suffix-based chat agent selection (#23741)

Adds suffix-based agent selection for chatd. Template authors can direct
chat traffic to a specific root workspace agent by naming it with the
`-coderd-chat` suffix (for example, `coder_agent "dev-coderd-chat"`).
When no suffix match exists, chatd falls back to the first root agent by
`DisplayOrder`, then `Name`. Multiple suffix matches return an error.

The selection logic lives in `coderd/x/chatd/internal/agentselect` and
is shared by chatd core plus the workspace chat tools so all chat entry
points pick the same agent deterministically.

No database migrations, API contract changes, or provider changes. The
experimental sandbox template was split out to #23777.
This commit is contained in:
Michael Suchacz
2026-03-30 13:43:59 +02:00
committed by GitHub
parent 4c97b63d79
commit 73f6cd8169
7 changed files with 772 additions and 23 deletions
+18 -3
View File
@@ -39,6 +39,7 @@ import (
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect"
"github.com/coder/coder/v2/coderd/x/chatd/mcpclient"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
@@ -379,6 +380,13 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked(
if len(agents) == 0 {
return chatSnapshot, database.WorkspaceAgent{}, errChatHasNoWorkspaceAgent
}
selected, err := agentselect.FindChatAgent(agents)
if err != nil {
return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf(
"find chat agent: %w",
err,
)
}
build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID)
if err != nil {
@@ -389,7 +397,7 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked(
ctx,
chatSnapshot,
build.ID,
agents[0].ID,
selected.ID,
)
if err != nil {
return chatSnapshot, database.WorkspaceAgent{}, err
@@ -401,7 +409,7 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked(
chatSnapshot = latestChat
continue
}
c.agent = agents[0]
c.agent = selected
c.agentLoaded = true
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
return chatSnapshot, c.agent, nil
@@ -429,7 +437,14 @@ func (c *turnWorkspaceContext) latestWorkspaceAgentID(
if len(agents) == 0 {
return uuid.Nil, errChatHasNoWorkspaceAgent
}
return agents[0].ID, nil
selected, err := agentselect.FindChatAgent(agents)
if err != nil {
return uuid.Nil, xerrors.Errorf(
"find chat agent: %w",
err,
)
}
return selected.ID, nil
}
func (c *turnWorkspaceContext) workspaceAgentIDForConn(
+42 -17
View File
@@ -15,6 +15,7 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/util/namesgenerator"
"github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
@@ -203,12 +204,28 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
}
}
// Look up the first agent so we can link it to the chat.
result := map[string]any{
"created": true,
"workspace_name": workspace.FullName(),
}
// Select the chat agent so follow-up tools wait on the
// intended workspace agent.
workspaceAgentID := uuid.Nil
if options.DB != nil {
agents, agentErr := options.DB.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID)
if agentErr == nil && len(agents) > 0 {
workspaceAgentID = agents[0].ID
if agentErr == nil {
if len(agents) == 0 {
result["agent_status"] = "no_agent"
} else {
selected, selectErr := agentselect.FindChatAgent(agents)
if selectErr != nil {
result["agent_status"] = "selection_error"
result["agent_error"] = selectErr.Error()
} else {
workspaceAgentID = selected.ID
}
}
}
}
@@ -241,20 +258,12 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
// Wait for the agent to come online and startup scripts to finish.
if workspaceAgentID != uuid.Nil {
agentStatus := waitForAgentReady(ctx, options.DB, workspaceAgentID, options.AgentConnFn)
result := map[string]any{
"created": true,
"workspace_name": workspace.FullName(),
}
for k, v := range agentStatus {
result[k] = v
}
return toolResponse(result), nil
}
return toolResponse(map[string]any{
"created": true,
"workspace_name": workspace.FullName(),
}), nil
return toolResponse(result), nil
})
}
@@ -322,7 +331,15 @@ func (o CreateWorkspaceOptions) checkExistingWorkspace(
}
agents, agentsErr := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, ws.ID)
if agentsErr == nil && len(agents) > 0 {
for k, v := range waitForAgentReady(ctx, db, agents[0].ID, agentConnFn) {
selected, selectErr := agentselect.FindChatAgent(agents)
if selectErr != nil {
o.Logger.Debug(ctx, "agent selection failed, falling back to first agent for readiness check",
slog.F("workspace_id", ws.ID),
slog.Error(selectErr),
)
selected = agents[0]
}
for k, v := range waitForAgentReady(ctx, db, selected.ID, agentConnFn) {
result[k] = v
}
}
@@ -345,7 +362,15 @@ func (o CreateWorkspaceOptions) checkExistingWorkspace(
// still usable.
agents, agentsErr := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, ws.ID)
if agentsErr == nil && len(agents) > 0 {
status := agents[0].Status(agentInactiveDisconnectTimeout)
selected, selectErr := agentselect.FindChatAgent(agents)
if selectErr != nil {
o.Logger.Debug(ctx, "agent selection failed, falling back to first agent for status check",
slog.F("workspace_id", ws.ID),
slog.Error(selectErr),
)
selected = agents[0]
}
status := selected.Status(agentInactiveDisconnectTimeout)
result := map[string]any{
"created": false,
"workspace_name": ws.Name,
@@ -355,19 +380,19 @@ func (o CreateWorkspaceOptions) checkExistingWorkspace(
switch status.Status {
case database.WorkspaceAgentStatusConnected:
result["message"] = "workspace is already running and recently connected"
for k, v := range waitForAgentReady(ctx, db, agents[0].ID, nil) {
for k, v := range waitForAgentReady(ctx, db, selected.ID, nil) {
result[k] = v
}
return result, true, nil
case database.WorkspaceAgentStatusConnecting:
result["message"] = "workspace exists and the agent is still connecting"
for k, v := range waitForAgentReady(ctx, db, agents[0].ID, agentConnFn) {
for k, v := range waitForAgentReady(ctx, db, selected.ID, agentConnFn) {
result[k] = v
}
return result, true, nil
case database.WorkspaceAgentStatusDisconnected,
database.WorkspaceAgentStatusTimeout:
// Agent is offline or never became ready allow
// Agent is offline or never became ready - allow
// creation.
}
}
@@ -3,6 +3,7 @@ package chattool //nolint:testpackage // Uses internal symbols.
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"sync"
"testing"
@@ -118,6 +119,180 @@ func TestWaitForAgentReady(t *testing.T) {
})
}
func TestCreateWorkspace_PrefersChatSuffixAgent(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
templateID := uuid.New()
workspaceID := uuid.New()
jobID := uuid.New()
fallbackAgentID := uuid.New()
chatAgentID := uuid.New()
db.EXPECT().
GetAuthorizationUserRoles(gomock.Any(), ownerID).
Return(database.GetAuthorizationUserRolesRow{
ID: ownerID,
Roles: []string{},
Groups: []string{},
Status: database.UserStatusActive,
}, nil)
db.EXPECT().
GetChatWorkspaceTTL(gomock.Any()).
Return("0s", nil)
db.EXPECT().
GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).
Return(database.WorkspaceBuild{
WorkspaceID: workspaceID,
JobID: jobID,
}, nil)
db.EXPECT().
GetProvisionerJobByID(gomock.Any(), jobID).
Return(database.ProvisionerJob{
ID: jobID,
JobStatus: database.ProvisionerJobStatusSucceeded,
}, nil)
db.EXPECT().
GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{
{ID: fallbackAgentID, Name: "dev", DisplayOrder: 0},
{ID: chatAgentID, Name: "dev-coderd-chat", DisplayOrder: 1},
}, nil)
db.EXPECT().
GetWorkspaceAgentLifecycleStateByID(gomock.Any(), chatAgentID).
Return(database.GetWorkspaceAgentLifecycleStateByIDRow{
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}, nil)
var connectedAgentID uuid.UUID
createFn := func(_ context.Context, _ uuid.UUID, req codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
return codersdk.Workspace{
ID: workspaceID,
Name: req.Name,
OwnerName: "testuser",
}, nil
}
agentConnFn := func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
connectedAgentID = agentID
return nil, func() {}, nil
}
tool := CreateWorkspace(CreateWorkspaceOptions{
DB: db,
OwnerID: ownerID,
CreateFn: createFn,
AgentConnFn: agentConnFn,
WorkspaceMu: &sync.Mutex{},
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
})
input := fmt.Sprintf(`{"template_id":%q,"name":"test-chat-agent"}`, templateID.String())
resp, err := tool.Run(context.Background(), fantasy.ToolCall{
ID: "call-1",
Name: "create_workspace",
Input: input,
})
require.NoError(t, err)
require.NotEmpty(t, resp.Content)
require.Equal(t, chatAgentID, connectedAgentID)
}
func TestCreateWorkspace_ReturnsSelectionErrorImmediately(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
chatID := uuid.New()
templateID := uuid.New()
workspaceID := uuid.New()
jobID := uuid.New()
db.EXPECT().
GetChatByID(gomock.Any(), chatID).
Return(database.Chat{ID: chatID}, nil)
db.EXPECT().
GetAuthorizationUserRoles(gomock.Any(), ownerID).
Return(database.GetAuthorizationUserRolesRow{
ID: ownerID,
Roles: []string{},
Groups: []string{},
Status: database.UserStatusActive,
}, nil)
db.EXPECT().
GetChatWorkspaceTTL(gomock.Any()).
Return("0s", nil)
db.EXPECT().
GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).
Return(database.WorkspaceBuild{
WorkspaceID: workspaceID,
JobID: jobID,
}, nil)
db.EXPECT().
GetProvisionerJobByID(gomock.Any(), jobID).
Return(database.ProvisionerJob{
ID: jobID,
JobStatus: database.ProvisionerJobStatusSucceeded,
}, nil)
db.EXPECT().
UpdateChatWorkspaceBinding(gomock.Any(), database.UpdateChatWorkspaceBindingParams{
ID: chatID,
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
BuildID: uuid.NullUUID{},
AgentID: uuid.NullUUID{},
}).
Return(database.Chat{
ID: chatID,
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
}, nil)
db.EXPECT().
GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{
{ID: uuid.New(), Name: "alpha-coderd-chat", DisplayOrder: 0},
{ID: uuid.New(), Name: "beta-coderd-chat", DisplayOrder: 1},
}, nil)
tool := CreateWorkspace(CreateWorkspaceOptions{
DB: db,
OwnerID: ownerID,
ChatID: chatID,
CreateFn: func(_ context.Context, _ uuid.UUID, req codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
return codersdk.Workspace{
ID: workspaceID,
Name: req.Name,
OwnerName: "testuser",
}, nil
},
AgentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
t.Fatal("AgentConnFn should not be called when agent selection fails")
return nil, nil, xerrors.New("unexpected agent dial")
},
WorkspaceMu: &sync.Mutex{},
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
})
input := fmt.Sprintf(`{"template_id":%q,"name":"test-selection-error"}`, templateID.String())
resp, err := tool.Run(context.Background(), fantasy.ToolCall{
ID: "call-1",
Name: "create_workspace",
Input: input,
})
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
require.Equal(t, true, result["created"])
require.Equal(t, "testuser/test-selection-error", result["workspace_name"])
require.Equal(t, "selection_error", result["agent_status"])
require.Contains(t, result["agent_error"], "multiple agents match the chat suffix")
}
func TestCreateWorkspace_GlobalTTL(t *testing.T) {
t.Parallel()
@@ -253,6 +428,7 @@ func TestCheckExistingWorkspace_ConnectedAgent(t *testing.T) {
GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{{
ID: agentID,
Name: "dev",
CreatedAt: now.Add(-time.Minute),
FirstConnectedAt: validNullTime(now.Add(-45 * time.Second)),
LastConnectedAt: validNullTime(now.Add(-5 * time.Second)),
@@ -302,6 +478,7 @@ func TestCheckExistingWorkspace_ConnectingAgentWaits(t *testing.T) {
GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{{
ID: agentID,
Name: "dev",
CreatedAt: now,
ConnectionTimeoutSeconds: 60,
}}, nil)
@@ -336,6 +513,7 @@ func TestCheckExistingWorkspace_DeadAgentAllowsCreation(t *testing.T) {
name: "Disconnected",
agent: database.WorkspaceAgent{
ID: uuid.New(),
Name: "disconnected",
CreatedAt: time.Now().UTC().Add(-2 * time.Minute),
FirstConnectedAt: validNullTime(time.Now().UTC().Add(-2 * time.Minute)),
LastConnectedAt: validNullTime(time.Now().UTC().Add(-time.Minute)),
@@ -345,6 +523,7 @@ func TestCheckExistingWorkspace_DeadAgentAllowsCreation(t *testing.T) {
name: "TimedOut",
agent: database.WorkspaceAgent{
ID: uuid.New(),
Name: "timed-out",
CreatedAt: time.Now().UTC().Add(-2 * time.Second),
ConnectionTimeoutSeconds: 1,
},
+14 -3
View File
@@ -9,6 +9,7 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect"
"github.com/coder/coder/v2/codersdk"
)
@@ -144,7 +145,7 @@ func StartWorkspace(options StartWorkspaceOptions) fantasy.AgentTool {
)
}
// waitForAgentAndRespond looks up the first agent in the workspace's
// waitForAgentAndRespond selects the chat agent from the workspace's
// latest build, waits for it to become reachable, and returns a
// success response.
func waitForAgentAndRespond(
@@ -155,7 +156,7 @@ func waitForAgentAndRespond(
) (fantasy.ToolResponse, error) {
agents, err := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, ws.ID)
if err != nil || len(agents) == 0 {
// Workspace started but no agent found still report
// Workspace started but no agent found - still report
// success so the model knows the workspace is up.
return toolResponse(map[string]any{
"started": true,
@@ -164,11 +165,21 @@ func waitForAgentAndRespond(
}), nil
}
selected, err := agentselect.FindChatAgent(agents)
if err != nil {
return toolResponse(map[string]any{
"started": true,
"workspace_name": ws.Name,
"agent_status": "selection_error",
"agent_error": err.Error(),
}), nil
}
result := map[string]any{
"started": true,
"workspace_name": ws.Name,
}
for k, v := range waitForAgentReady(ctx, db, agents[0].ID, agentConnFn) {
for k, v := range waitForAgentReady(ctx, db, selected.ID, agentConnFn) {
result[k] = v
}
return toolResponse(result), nil
@@ -6,6 +6,7 @@ import (
"encoding/json"
"sync"
"testing"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
@@ -18,6 +19,7 @@ import (
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
)
@@ -108,6 +110,206 @@ func TestStartWorkspace(t *testing.T) {
require.True(t, started)
})
t.Run("AlreadyRunningPrefersChatSuffixAgent", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
modelCfg := seedModelConfig(ctx, t, db, user.ID)
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: user.ID,
OrganizationID: org.ID,
}).WithAgent(func(agents []*sdkproto.Agent) []*sdkproto.Agent {
agents[0].Name = "dev"
return append(agents, &sdkproto.Agent{
Id: uuid.NewString(),
Name: "dev-coderd-chat",
Auth: &sdkproto.Agent_Token{Token: uuid.NewString()},
Env: map[string]string{},
})
}).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStart,
}).Do()
ws := wsResp.Workspace
now := time.Now().UTC()
preferredAgentID := uuid.Nil
for _, agent := range wsResp.Agents {
if agent.Name == "dev-coderd-chat" {
preferredAgentID = agent.ID
}
err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: agent.ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
StartedAt: sql.NullTime{Time: now, Valid: true},
ReadyAt: sql.NullTime{Time: now, Valid: true},
})
require.NoError(t, err)
}
require.NotEqual(t, uuid.Nil, preferredAgentID)
chat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
LastModelConfigID: modelCfg.ID,
Title: "test-running-preferred-agent",
})
require.NoError(t, err)
var connectedAgentID uuid.UUID
agentConnFn := func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
connectedAgentID = agentID
return nil, func() {}, nil
}
tool := chattool.StartWorkspace(chattool.StartWorkspaceOptions{
DB: db,
OwnerID: user.ID,
ChatID: chat.ID,
AgentConnFn: agentConnFn,
StartFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
t.Fatal("StartFn should not be called for already-running workspace")
return codersdk.WorkspaceBuild{}, nil
},
WorkspaceMu: &sync.Mutex{},
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "start_workspace", Input: "{}"})
require.NoError(t, err)
require.Equal(t, preferredAgentID, connectedAgentID)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
started, ok := result["started"].(bool)
require.True(t, ok)
require.True(t, started)
})
t.Run("AlreadyRunningWithoutAgentsReturnsNoAgent", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
modelCfg := seedModelConfig(ctx, t, db, user.ID)
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: user.ID,
OrganizationID: org.ID,
}).WithAgent(func(_ []*sdkproto.Agent) []*sdkproto.Agent {
return nil
}).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStart,
}).Do()
ws := wsResp.Workspace
chat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
LastModelConfigID: modelCfg.ID,
Title: "test-running-no-agent",
})
require.NoError(t, err)
tool := chattool.StartWorkspace(chattool.StartWorkspaceOptions{
DB: db,
OwnerID: user.ID,
ChatID: chat.ID,
AgentConnFn: func(_ context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
t.Fatal("AgentConnFn should not be called when no agents exist")
return nil, func() {}, nil
},
StartFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
t.Fatal("StartFn should not be called for already-running workspace")
return codersdk.WorkspaceBuild{}, nil
},
WorkspaceMu: &sync.Mutex{},
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "start_workspace", Input: "{}"})
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
started, ok := result["started"].(bool)
require.True(t, ok)
require.True(t, started)
require.Equal(t, "no_agent", result["agent_status"])
})
t.Run("AlreadyRunningPreservesAgentSelectionError", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
modelCfg := seedModelConfig(ctx, t, db, user.ID)
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: user.ID,
OrganizationID: org.ID,
}).WithAgent(func(agents []*sdkproto.Agent) []*sdkproto.Agent {
agents[0].Name = "alpha-coderd-chat"
return append(agents, &sdkproto.Agent{
Id: uuid.NewString(),
Name: "beta-coderd-chat",
Auth: &sdkproto.Agent_Token{Token: uuid.NewString()},
Env: map[string]string{},
})
}).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStart,
}).Do()
ws := wsResp.Workspace
chat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
LastModelConfigID: modelCfg.ID,
Title: "test-running-selection-error",
})
require.NoError(t, err)
tool := chattool.StartWorkspace(chattool.StartWorkspaceOptions{
DB: db,
OwnerID: user.ID,
ChatID: chat.ID,
AgentConnFn: func(_ context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
t.Fatal("AgentConnFn should not be called when agent selection fails")
return nil, func() {}, nil
},
StartFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
t.Fatal("StartFn should not be called for already-running workspace")
return codersdk.WorkspaceBuild{}, nil
},
WorkspaceMu: &sync.Mutex{},
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "start_workspace", Input: "{}"})
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
started, ok := result["started"].(bool)
require.True(t, ok)
require.True(t, started)
require.Equal(t, "selection_error", result["agent_status"])
require.Contains(t, result["agent_error"], "multiple agents match the chat suffix")
})
t.Run("StoppedWorkspace", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
@@ -0,0 +1,86 @@
package agentselect
import (
"cmp"
"slices"
"strings"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
)
// Suffix marks chat-designated agents during the current PoC. This naming
// convention is an implementation detail, not a stable contract.
const Suffix = "-coderd-chat"
// IsChatAgent reports whether name uses the chat-agent suffix convention.
func IsChatAgent(name string) bool {
return strings.HasSuffix(strings.ToLower(name), Suffix)
}
// FindChatAgent picks the best workspace agent for a chat session from the
// provided candidates. It applies these rules in order:
// 1. Filter to root agents only (ParentID is null).
// 2. Sort stably and deterministically by DisplayOrder ASC, then Name ASC
// (case-insensitive), then Name ASC, then ID ASC.
// 3. If exactly one root agent name ends with Suffix (case-insensitive),
// return it.
// 4. If zero root agents match the suffix, return the first root agent after
// sorting (deterministic fallback).
// 5. If more than one root agent matches the suffix, return an error with an
// actionable message.
// 6. If no root agents exist at all, return an error.
func FindChatAgent(
agents []database.WorkspaceAgent,
) (database.WorkspaceAgent, error) {
rootAgents := make([]database.WorkspaceAgent, 0, len(agents))
matchingAgents := make([]database.WorkspaceAgent, 0, 1)
for _, agent := range agents {
if agent.ParentID.Valid {
continue
}
rootAgents = append(rootAgents, agent)
if IsChatAgent(agent.Name) {
matchingAgents = append(matchingAgents, agent)
}
}
if len(rootAgents) == 0 {
return database.WorkspaceAgent{}, xerrors.New(
"no eligible workspace agents found",
)
}
compareAgents := func(a, b database.WorkspaceAgent) int {
if order := cmp.Compare(a.DisplayOrder, b.DisplayOrder); order != 0 {
return order
}
if order := cmp.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name)); order != 0 {
return order
}
if order := cmp.Compare(a.Name, b.Name); order != 0 {
return order
}
return cmp.Compare(a.ID.String(), b.ID.String())
}
slices.SortStableFunc(rootAgents, compareAgents)
slices.SortStableFunc(matchingAgents, compareAgents)
switch len(matchingAgents) {
case 0:
return rootAgents[0], nil
case 1:
return matchingAgents[0], nil
default:
names := make([]string, 0, len(matchingAgents))
for _, agent := range matchingAgents {
names = append(names, agent.Name)
}
return database.WorkspaceAgent{}, xerrors.Errorf(
"multiple agents match the chat suffix %q: %s; only one agent should use this suffix",
Suffix,
strings.Join(names, ", "),
)
}
}
@@ -0,0 +1,231 @@
package agentselect_test
import (
"fmt"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect"
)
func TestFindChatAgent(t *testing.T) {
t.Parallel()
newRootAgentWithID := func(id, name string, displayOrder int32) database.WorkspaceAgent {
return database.WorkspaceAgent{
ID: uuid.MustParse(id),
Name: name,
DisplayOrder: displayOrder,
}
}
newRootAgent := func(name string, displayOrder int32) database.WorkspaceAgent {
return newRootAgentWithID(uuid.NewString(), name, displayOrder)
}
newChildAgent := func(name string, displayOrder int32) database.WorkspaceAgent {
agent := newRootAgent(name, displayOrder)
agent.ParentID = uuid.NullUUID{UUID: uuid.New(), Valid: true}
return agent
}
tests := []struct {
name string
agents []database.WorkspaceAgent
wantIndex int
wantErrContains []string
}{
{
name: "SingleSuffixMatch",
agents: []database.WorkspaceAgent{
newRootAgent("alpha", 0),
newRootAgent("dev-coderd-chat", 2),
newRootAgent("zeta", 1),
},
wantIndex: 1,
},
{
name: "SuffixMatchCaseInsensitive",
agents: []database.WorkspaceAgent{
newRootAgent("alpha", 0),
newRootAgent("Dev-Coderd-Chat", 2),
newRootAgent("zeta", 1),
},
wantIndex: 1,
},
{
name: "NoSuffixMatchFallbackDeterministic",
agents: []database.WorkspaceAgent{
newRootAgent("zeta", 2),
newRootAgent("bravo", 1),
newRootAgent("alpha", 1),
},
wantIndex: 2,
},
{
name: "NoSuffixMatchFallbackByName",
agents: []database.WorkspaceAgent{
newRootAgent("Bravo", 3),
newRootAgent("alpha", 3),
newRootAgent("charlie", 3),
},
wantIndex: 1,
},
{
name: "CaseOnlyNameTieFallbackDeterministic",
agents: []database.WorkspaceAgent{
newRootAgent("Dev", 0),
newRootAgent("dev", 0),
},
wantIndex: 0,
},
{
name: "ExactNameTieFallbackByID",
agents: []database.WorkspaceAgent{
newRootAgentWithID("00000000-0000-0000-0000-000000000002", "dev", 0),
newRootAgentWithID("00000000-0000-0000-0000-000000000001", "dev", 0),
},
wantIndex: 1,
},
{
name: "MultipleSuffixMatchesError",
agents: []database.WorkspaceAgent{
newRootAgent("alpha-coderd-chat", 2),
newRootAgent("beta-coderd-chat", 1),
newRootAgent("gamma", 0),
},
wantErrContains: []string{
fmt.Sprintf(
"multiple agents match the chat suffix %q",
agentselect.Suffix,
),
"alpha-coderd-chat",
"beta-coderd-chat",
"only one agent should use this suffix",
},
},
{
name: "ChildAgentSuffixIgnored",
agents: []database.WorkspaceAgent{
newRootAgent("alpha", 1),
newChildAgent("child-coderd-chat", 0),
newRootAgent("bravo", 0),
},
wantIndex: 2,
},
{
name: "ChildAgentSuffixIgnoredWithRootMatch",
agents: []database.WorkspaceAgent{
newRootAgent("alpha", 0),
newChildAgent("child-coderd-chat", 1),
newRootAgent("root-coderd-chat", 2),
},
wantIndex: 2,
},
{
name: "EmptyAgentList",
agents: []database.WorkspaceAgent{},
wantErrContains: []string{
"no eligible workspace agents found",
},
},
{
name: "OnlyChildAgents",
agents: []database.WorkspaceAgent{
newChildAgent("alpha", 0),
newChildAgent("beta-coderd-chat", 1),
},
wantErrContains: []string{
"no eligible workspace agents found",
},
},
{
name: "SingleRootAgent",
agents: []database.WorkspaceAgent{
newRootAgent("solo", 5),
},
wantIndex: 0,
},
{
name: "SuffixAgentWinsRegardlessOfOrder",
agents: []database.WorkspaceAgent{
newRootAgent("alpha", 0),
newRootAgent("zeta", 1),
newRootAgent("preferred-coderd-chat", 99),
},
wantIndex: 2,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := agentselect.FindChatAgent(tt.agents)
if len(tt.wantErrContains) > 0 {
require.Error(t, err)
for _, wantErr := range tt.wantErrContains {
require.ErrorContains(t, err, wantErr)
}
return
}
require.NoError(t, err)
require.Equal(t, tt.agents[tt.wantIndex], got)
})
}
}
func TestIsChatAgent(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want bool
}{
{
name: "ExactSuffix",
input: "agent-coderd-chat",
want: true,
},
{
name: "UppercaseSuffix",
input: "agent-CODERD-CHAT",
want: true,
},
{
name: "MixedCaseSuffix",
input: "agent-Coderd-Chat",
want: true,
},
{
name: "NoSuffix",
input: "my-agent",
want: false,
},
{
name: "SuffixOnly",
input: "-coderd-chat",
want: true,
},
{
name: "PartialSuffix",
input: "agent-coderd",
want: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.want, agentselect.IsChatAgent(tt.input))
})
}
}