mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: suffix-based chat agent selection (#23741)
Adds suffix-based agent selection for chatd. Template authors can direct chat traffic to a specific root workspace agent by naming it with the `-coderd-chat` suffix (for example, `coder_agent "dev-coderd-chat"`). When no suffix match exists, chatd falls back to the first root agent by `DisplayOrder`, then `Name`. Multiple suffix matches return an error. The selection logic lives in `coderd/x/chatd/internal/agentselect` and is shared by chatd core plus the workspace chat tools so all chat entry points pick the same agent deterministically. No database migrations, API contract changes, or provider changes. The experimental sandbox template was split out to #23777.
This commit is contained in:
+18
-3
@@ -39,6 +39,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/mcpclient"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
@@ -379,6 +380,13 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked(
|
||||
if len(agents) == 0 {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, errChatHasNoWorkspaceAgent
|
||||
}
|
||||
selected, err := agentselect.FindChatAgent(agents)
|
||||
if err != nil {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf(
|
||||
"find chat agent: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
@@ -389,7 +397,7 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked(
|
||||
ctx,
|
||||
chatSnapshot,
|
||||
build.ID,
|
||||
agents[0].ID,
|
||||
selected.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return chatSnapshot, database.WorkspaceAgent{}, err
|
||||
@@ -401,7 +409,7 @@ func (c *turnWorkspaceContext) loadWorkspaceAgentLocked(
|
||||
chatSnapshot = latestChat
|
||||
continue
|
||||
}
|
||||
c.agent = agents[0]
|
||||
c.agent = selected
|
||||
c.agentLoaded = true
|
||||
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
|
||||
return chatSnapshot, c.agent, nil
|
||||
@@ -429,7 +437,14 @@ func (c *turnWorkspaceContext) latestWorkspaceAgentID(
|
||||
if len(agents) == 0 {
|
||||
return uuid.Nil, errChatHasNoWorkspaceAgent
|
||||
}
|
||||
return agents[0].ID, nil
|
||||
selected, err := agentselect.FindChatAgent(agents)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf(
|
||||
"find chat agent: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
return selected.ID, nil
|
||||
}
|
||||
|
||||
func (c *turnWorkspaceContext) workspaceAgentIDForConn(
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/util/namesgenerator"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
@@ -203,12 +204,28 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
}
|
||||
}
|
||||
|
||||
// Look up the first agent so we can link it to the chat.
|
||||
result := map[string]any{
|
||||
"created": true,
|
||||
"workspace_name": workspace.FullName(),
|
||||
}
|
||||
|
||||
// Select the chat agent so follow-up tools wait on the
|
||||
// intended workspace agent.
|
||||
workspaceAgentID := uuid.Nil
|
||||
if options.DB != nil {
|
||||
agents, agentErr := options.DB.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspace.ID)
|
||||
if agentErr == nil && len(agents) > 0 {
|
||||
workspaceAgentID = agents[0].ID
|
||||
if agentErr == nil {
|
||||
if len(agents) == 0 {
|
||||
result["agent_status"] = "no_agent"
|
||||
} else {
|
||||
selected, selectErr := agentselect.FindChatAgent(agents)
|
||||
if selectErr != nil {
|
||||
result["agent_status"] = "selection_error"
|
||||
result["agent_error"] = selectErr.Error()
|
||||
} else {
|
||||
workspaceAgentID = selected.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,20 +258,12 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
// Wait for the agent to come online and startup scripts to finish.
|
||||
if workspaceAgentID != uuid.Nil {
|
||||
agentStatus := waitForAgentReady(ctx, options.DB, workspaceAgentID, options.AgentConnFn)
|
||||
result := map[string]any{
|
||||
"created": true,
|
||||
"workspace_name": workspace.FullName(),
|
||||
}
|
||||
for k, v := range agentStatus {
|
||||
result[k] = v
|
||||
}
|
||||
return toolResponse(result), nil
|
||||
}
|
||||
|
||||
return toolResponse(map[string]any{
|
||||
"created": true,
|
||||
"workspace_name": workspace.FullName(),
|
||||
}), nil
|
||||
return toolResponse(result), nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -322,7 +331,15 @@ func (o CreateWorkspaceOptions) checkExistingWorkspace(
|
||||
}
|
||||
agents, agentsErr := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, ws.ID)
|
||||
if agentsErr == nil && len(agents) > 0 {
|
||||
for k, v := range waitForAgentReady(ctx, db, agents[0].ID, agentConnFn) {
|
||||
selected, selectErr := agentselect.FindChatAgent(agents)
|
||||
if selectErr != nil {
|
||||
o.Logger.Debug(ctx, "agent selection failed, falling back to first agent for readiness check",
|
||||
slog.F("workspace_id", ws.ID),
|
||||
slog.Error(selectErr),
|
||||
)
|
||||
selected = agents[0]
|
||||
}
|
||||
for k, v := range waitForAgentReady(ctx, db, selected.ID, agentConnFn) {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
@@ -345,7 +362,15 @@ func (o CreateWorkspaceOptions) checkExistingWorkspace(
|
||||
// still usable.
|
||||
agents, agentsErr := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, ws.ID)
|
||||
if agentsErr == nil && len(agents) > 0 {
|
||||
status := agents[0].Status(agentInactiveDisconnectTimeout)
|
||||
selected, selectErr := agentselect.FindChatAgent(agents)
|
||||
if selectErr != nil {
|
||||
o.Logger.Debug(ctx, "agent selection failed, falling back to first agent for status check",
|
||||
slog.F("workspace_id", ws.ID),
|
||||
slog.Error(selectErr),
|
||||
)
|
||||
selected = agents[0]
|
||||
}
|
||||
status := selected.Status(agentInactiveDisconnectTimeout)
|
||||
result := map[string]any{
|
||||
"created": false,
|
||||
"workspace_name": ws.Name,
|
||||
@@ -355,19 +380,19 @@ func (o CreateWorkspaceOptions) checkExistingWorkspace(
|
||||
switch status.Status {
|
||||
case database.WorkspaceAgentStatusConnected:
|
||||
result["message"] = "workspace is already running and recently connected"
|
||||
for k, v := range waitForAgentReady(ctx, db, agents[0].ID, nil) {
|
||||
for k, v := range waitForAgentReady(ctx, db, selected.ID, nil) {
|
||||
result[k] = v
|
||||
}
|
||||
return result, true, nil
|
||||
case database.WorkspaceAgentStatusConnecting:
|
||||
result["message"] = "workspace exists and the agent is still connecting"
|
||||
for k, v := range waitForAgentReady(ctx, db, agents[0].ID, agentConnFn) {
|
||||
for k, v := range waitForAgentReady(ctx, db, selected.ID, agentConnFn) {
|
||||
result[k] = v
|
||||
}
|
||||
return result, true, nil
|
||||
case database.WorkspaceAgentStatusDisconnected,
|
||||
database.WorkspaceAgentStatusTimeout:
|
||||
// Agent is offline or never became ready — allow
|
||||
// Agent is offline or never became ready - allow
|
||||
// creation.
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package chattool //nolint:testpackage // Uses internal symbols.
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -118,6 +119,180 @@ func TestWaitForAgentReady(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateWorkspace_PrefersChatSuffixAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
ownerID := uuid.New()
|
||||
templateID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
jobID := uuid.New()
|
||||
fallbackAgentID := uuid.New()
|
||||
chatAgentID := uuid.New()
|
||||
|
||||
db.EXPECT().
|
||||
GetAuthorizationUserRoles(gomock.Any(), ownerID).
|
||||
Return(database.GetAuthorizationUserRolesRow{
|
||||
ID: ownerID,
|
||||
Roles: []string{},
|
||||
Groups: []string{},
|
||||
Status: database.UserStatusActive,
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().
|
||||
GetChatWorkspaceTTL(gomock.Any()).
|
||||
Return("0s", nil)
|
||||
|
||||
db.EXPECT().
|
||||
GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).
|
||||
Return(database.WorkspaceBuild{
|
||||
WorkspaceID: workspaceID,
|
||||
JobID: jobID,
|
||||
}, nil)
|
||||
db.EXPECT().
|
||||
GetProvisionerJobByID(gomock.Any(), jobID).
|
||||
Return(database.ProvisionerJob{
|
||||
ID: jobID,
|
||||
JobStatus: database.ProvisionerJobStatusSucceeded,
|
||||
}, nil)
|
||||
db.EXPECT().
|
||||
GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
||||
Return([]database.WorkspaceAgent{
|
||||
{ID: fallbackAgentID, Name: "dev", DisplayOrder: 0},
|
||||
{ID: chatAgentID, Name: "dev-coderd-chat", DisplayOrder: 1},
|
||||
}, nil)
|
||||
db.EXPECT().
|
||||
GetWorkspaceAgentLifecycleStateByID(gomock.Any(), chatAgentID).
|
||||
Return(database.GetWorkspaceAgentLifecycleStateByIDRow{
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
}, nil)
|
||||
|
||||
var connectedAgentID uuid.UUID
|
||||
createFn := func(_ context.Context, _ uuid.UUID, req codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
||||
return codersdk.Workspace{
|
||||
ID: workspaceID,
|
||||
Name: req.Name,
|
||||
OwnerName: "testuser",
|
||||
}, nil
|
||||
}
|
||||
agentConnFn := func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
connectedAgentID = agentID
|
||||
return nil, func() {}, nil
|
||||
}
|
||||
|
||||
tool := CreateWorkspace(CreateWorkspaceOptions{
|
||||
DB: db,
|
||||
OwnerID: ownerID,
|
||||
CreateFn: createFn,
|
||||
AgentConnFn: agentConnFn,
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
})
|
||||
|
||||
input := fmt.Sprintf(`{"template_id":%q,"name":"test-chat-agent"}`, templateID.String())
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "create_workspace",
|
||||
Input: input,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, resp.Content)
|
||||
require.Equal(t, chatAgentID, connectedAgentID)
|
||||
}
|
||||
|
||||
func TestCreateWorkspace_ReturnsSelectionErrorImmediately(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
ownerID := uuid.New()
|
||||
chatID := uuid.New()
|
||||
templateID := uuid.New()
|
||||
workspaceID := uuid.New()
|
||||
jobID := uuid.New()
|
||||
|
||||
db.EXPECT().
|
||||
GetChatByID(gomock.Any(), chatID).
|
||||
Return(database.Chat{ID: chatID}, nil)
|
||||
db.EXPECT().
|
||||
GetAuthorizationUserRoles(gomock.Any(), ownerID).
|
||||
Return(database.GetAuthorizationUserRolesRow{
|
||||
ID: ownerID,
|
||||
Roles: []string{},
|
||||
Groups: []string{},
|
||||
Status: database.UserStatusActive,
|
||||
}, nil)
|
||||
db.EXPECT().
|
||||
GetChatWorkspaceTTL(gomock.Any()).
|
||||
Return("0s", nil)
|
||||
db.EXPECT().
|
||||
GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), workspaceID).
|
||||
Return(database.WorkspaceBuild{
|
||||
WorkspaceID: workspaceID,
|
||||
JobID: jobID,
|
||||
}, nil)
|
||||
db.EXPECT().
|
||||
GetProvisionerJobByID(gomock.Any(), jobID).
|
||||
Return(database.ProvisionerJob{
|
||||
ID: jobID,
|
||||
JobStatus: database.ProvisionerJobStatusSucceeded,
|
||||
}, nil)
|
||||
db.EXPECT().
|
||||
UpdateChatWorkspaceBinding(gomock.Any(), database.UpdateChatWorkspaceBindingParams{
|
||||
ID: chatID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
|
||||
BuildID: uuid.NullUUID{},
|
||||
AgentID: uuid.NullUUID{},
|
||||
}).
|
||||
Return(database.Chat{
|
||||
ID: chatID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
|
||||
}, nil)
|
||||
db.EXPECT().
|
||||
GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
||||
Return([]database.WorkspaceAgent{
|
||||
{ID: uuid.New(), Name: "alpha-coderd-chat", DisplayOrder: 0},
|
||||
{ID: uuid.New(), Name: "beta-coderd-chat", DisplayOrder: 1},
|
||||
}, nil)
|
||||
|
||||
tool := CreateWorkspace(CreateWorkspaceOptions{
|
||||
DB: db,
|
||||
OwnerID: ownerID,
|
||||
ChatID: chatID,
|
||||
CreateFn: func(_ context.Context, _ uuid.UUID, req codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) {
|
||||
return codersdk.Workspace{
|
||||
ID: workspaceID,
|
||||
Name: req.Name,
|
||||
OwnerName: "testuser",
|
||||
}, nil
|
||||
},
|
||||
AgentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
t.Fatal("AgentConnFn should not be called when agent selection fails")
|
||||
return nil, nil, xerrors.New("unexpected agent dial")
|
||||
},
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
})
|
||||
|
||||
input := fmt.Sprintf(`{"template_id":%q,"name":"test-selection-error"}`, templateID.String())
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "create_workspace",
|
||||
Input: input,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
require.Equal(t, true, result["created"])
|
||||
require.Equal(t, "testuser/test-selection-error", result["workspace_name"])
|
||||
require.Equal(t, "selection_error", result["agent_status"])
|
||||
require.Contains(t, result["agent_error"], "multiple agents match the chat suffix")
|
||||
}
|
||||
|
||||
func TestCreateWorkspace_GlobalTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -253,6 +428,7 @@ func TestCheckExistingWorkspace_ConnectedAgent(t *testing.T) {
|
||||
GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
||||
Return([]database.WorkspaceAgent{{
|
||||
ID: agentID,
|
||||
Name: "dev",
|
||||
CreatedAt: now.Add(-time.Minute),
|
||||
FirstConnectedAt: validNullTime(now.Add(-45 * time.Second)),
|
||||
LastConnectedAt: validNullTime(now.Add(-5 * time.Second)),
|
||||
@@ -302,6 +478,7 @@ func TestCheckExistingWorkspace_ConnectingAgentWaits(t *testing.T) {
|
||||
GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
|
||||
Return([]database.WorkspaceAgent{{
|
||||
ID: agentID,
|
||||
Name: "dev",
|
||||
CreatedAt: now,
|
||||
ConnectionTimeoutSeconds: 60,
|
||||
}}, nil)
|
||||
@@ -336,6 +513,7 @@ func TestCheckExistingWorkspace_DeadAgentAllowsCreation(t *testing.T) {
|
||||
name: "Disconnected",
|
||||
agent: database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
Name: "disconnected",
|
||||
CreatedAt: time.Now().UTC().Add(-2 * time.Minute),
|
||||
FirstConnectedAt: validNullTime(time.Now().UTC().Add(-2 * time.Minute)),
|
||||
LastConnectedAt: validNullTime(time.Now().UTC().Add(-time.Minute)),
|
||||
@@ -345,6 +523,7 @@ func TestCheckExistingWorkspace_DeadAgentAllowsCreation(t *testing.T) {
|
||||
name: "TimedOut",
|
||||
agent: database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
Name: "timed-out",
|
||||
CreatedAt: time.Now().UTC().Add(-2 * time.Second),
|
||||
ConnectionTimeoutSeconds: 1,
|
||||
},
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -144,7 +145,7 @@ func StartWorkspace(options StartWorkspaceOptions) fantasy.AgentTool {
|
||||
)
|
||||
}
|
||||
|
||||
// waitForAgentAndRespond looks up the first agent in the workspace's
|
||||
// waitForAgentAndRespond selects the chat agent from the workspace's
|
||||
// latest build, waits for it to become reachable, and returns a
|
||||
// success response.
|
||||
func waitForAgentAndRespond(
|
||||
@@ -155,7 +156,7 @@ func waitForAgentAndRespond(
|
||||
) (fantasy.ToolResponse, error) {
|
||||
agents, err := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, ws.ID)
|
||||
if err != nil || len(agents) == 0 {
|
||||
// Workspace started but no agent found — still report
|
||||
// Workspace started but no agent found - still report
|
||||
// success so the model knows the workspace is up.
|
||||
return toolResponse(map[string]any{
|
||||
"started": true,
|
||||
@@ -164,11 +165,21 @@ func waitForAgentAndRespond(
|
||||
}), nil
|
||||
}
|
||||
|
||||
selected, err := agentselect.FindChatAgent(agents)
|
||||
if err != nil {
|
||||
return toolResponse(map[string]any{
|
||||
"started": true,
|
||||
"workspace_name": ws.Name,
|
||||
"agent_status": "selection_error",
|
||||
"agent_error": err.Error(),
|
||||
}), nil
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"started": true,
|
||||
"workspace_name": ws.Name,
|
||||
}
|
||||
for k, v := range waitForAgentReady(ctx, db, agents[0].ID, agentConnFn) {
|
||||
for k, v := range waitForAgentReady(ctx, db, selected.ID, agentConnFn) {
|
||||
result[k] = v
|
||||
}
|
||||
return toolResponse(result), nil
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
@@ -108,6 +110,206 @@ func TestStartWorkspace(t *testing.T) {
|
||||
require.True(t, started)
|
||||
})
|
||||
|
||||
t.Run("AlreadyRunningPrefersChatSuffixAgent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
modelCfg := seedModelConfig(ctx, t, db, user.ID)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
}).WithAgent(func(agents []*sdkproto.Agent) []*sdkproto.Agent {
|
||||
agents[0].Name = "dev"
|
||||
return append(agents, &sdkproto.Agent{
|
||||
Id: uuid.NewString(),
|
||||
Name: "dev-coderd-chat",
|
||||
Auth: &sdkproto.Agent_Token{Token: uuid.NewString()},
|
||||
Env: map[string]string{},
|
||||
})
|
||||
}).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
}).Do()
|
||||
ws := wsResp.Workspace
|
||||
|
||||
now := time.Now().UTC()
|
||||
preferredAgentID := uuid.Nil
|
||||
for _, agent := range wsResp.Agents {
|
||||
if agent.Name == "dev-coderd-chat" {
|
||||
preferredAgentID = agent.ID
|
||||
}
|
||||
err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
||||
ID: agent.ID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
StartedAt: sql.NullTime{Time: now, Valid: true},
|
||||
ReadyAt: sql.NullTime{Time: now, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.NotEqual(t, uuid.Nil, preferredAgentID)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-running-preferred-agent",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var connectedAgentID uuid.UUID
|
||||
agentConnFn := func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
connectedAgentID = agentID
|
||||
return nil, func() {}, nil
|
||||
}
|
||||
|
||||
tool := chattool.StartWorkspace(chattool.StartWorkspaceOptions{
|
||||
DB: db,
|
||||
OwnerID: user.ID,
|
||||
ChatID: chat.ID,
|
||||
AgentConnFn: agentConnFn,
|
||||
StartFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
|
||||
t.Fatal("StartFn should not be called for already-running workspace")
|
||||
return codersdk.WorkspaceBuild{}, nil
|
||||
},
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "start_workspace", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, preferredAgentID, connectedAgentID)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
started, ok := result["started"].(bool)
|
||||
require.True(t, ok)
|
||||
require.True(t, started)
|
||||
})
|
||||
|
||||
t.Run("AlreadyRunningWithoutAgentsReturnsNoAgent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
modelCfg := seedModelConfig(ctx, t, db, user.ID)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
}).WithAgent(func(_ []*sdkproto.Agent) []*sdkproto.Agent {
|
||||
return nil
|
||||
}).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
}).Do()
|
||||
ws := wsResp.Workspace
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-running-no-agent",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tool := chattool.StartWorkspace(chattool.StartWorkspaceOptions{
|
||||
DB: db,
|
||||
OwnerID: user.ID,
|
||||
ChatID: chat.ID,
|
||||
AgentConnFn: func(_ context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
t.Fatal("AgentConnFn should not be called when no agents exist")
|
||||
return nil, func() {}, nil
|
||||
},
|
||||
StartFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
|
||||
t.Fatal("StartFn should not be called for already-running workspace")
|
||||
return codersdk.WorkspaceBuild{}, nil
|
||||
},
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "start_workspace", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
started, ok := result["started"].(bool)
|
||||
require.True(t, ok)
|
||||
require.True(t, started)
|
||||
require.Equal(t, "no_agent", result["agent_status"])
|
||||
})
|
||||
|
||||
t.Run("AlreadyRunningPreservesAgentSelectionError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
modelCfg := seedModelConfig(ctx, t, db, user.ID)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
}).WithAgent(func(agents []*sdkproto.Agent) []*sdkproto.Agent {
|
||||
agents[0].Name = "alpha-coderd-chat"
|
||||
return append(agents, &sdkproto.Agent{
|
||||
Id: uuid.NewString(),
|
||||
Name: "beta-coderd-chat",
|
||||
Auth: &sdkproto.Agent_Token{Token: uuid.NewString()},
|
||||
Env: map[string]string{},
|
||||
})
|
||||
}).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
}).Do()
|
||||
ws := wsResp.Workspace
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: user.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "test-running-selection-error",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tool := chattool.StartWorkspace(chattool.StartWorkspaceOptions{
|
||||
DB: db,
|
||||
OwnerID: user.ID,
|
||||
ChatID: chat.ID,
|
||||
AgentConnFn: func(_ context.Context, _ uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
t.Fatal("AgentConnFn should not be called when agent selection fails")
|
||||
return nil, func() {}, nil
|
||||
},
|
||||
StartFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
|
||||
t.Fatal("StartFn should not be called for already-running workspace")
|
||||
return codersdk.WorkspaceBuild{}, nil
|
||||
},
|
||||
WorkspaceMu: &sync.Mutex{},
|
||||
})
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "start_workspace", Input: "{}"})
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
started, ok := result["started"].(bool)
|
||||
require.True(t, ok)
|
||||
require.True(t, started)
|
||||
require.Equal(t, "selection_error", result["agent_status"])
|
||||
require.Contains(t, result["agent_error"], "multiple agents match the chat suffix")
|
||||
})
|
||||
|
||||
t.Run("StoppedWorkspace", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
package agentselect
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
// Suffix marks chat-designated agents during the current PoC. This naming
|
||||
// convention is an implementation detail, not a stable contract.
|
||||
const Suffix = "-coderd-chat"
|
||||
|
||||
// IsChatAgent reports whether name uses the chat-agent suffix convention.
|
||||
func IsChatAgent(name string) bool {
|
||||
return strings.HasSuffix(strings.ToLower(name), Suffix)
|
||||
}
|
||||
|
||||
// FindChatAgent picks the best workspace agent for a chat session from the
|
||||
// provided candidates. It applies these rules in order:
|
||||
// 1. Filter to root agents only (ParentID is null).
|
||||
// 2. Sort stably and deterministically by DisplayOrder ASC, then Name ASC
|
||||
// (case-insensitive), then Name ASC, then ID ASC.
|
||||
// 3. If exactly one root agent name ends with Suffix (case-insensitive),
|
||||
// return it.
|
||||
// 4. If zero root agents match the suffix, return the first root agent after
|
||||
// sorting (deterministic fallback).
|
||||
// 5. If more than one root agent matches the suffix, return an error with an
|
||||
// actionable message.
|
||||
// 6. If no root agents exist at all, return an error.
|
||||
func FindChatAgent(
|
||||
agents []database.WorkspaceAgent,
|
||||
) (database.WorkspaceAgent, error) {
|
||||
rootAgents := make([]database.WorkspaceAgent, 0, len(agents))
|
||||
matchingAgents := make([]database.WorkspaceAgent, 0, 1)
|
||||
for _, agent := range agents {
|
||||
if agent.ParentID.Valid {
|
||||
continue
|
||||
}
|
||||
rootAgents = append(rootAgents, agent)
|
||||
if IsChatAgent(agent.Name) {
|
||||
matchingAgents = append(matchingAgents, agent)
|
||||
}
|
||||
}
|
||||
|
||||
if len(rootAgents) == 0 {
|
||||
return database.WorkspaceAgent{}, xerrors.New(
|
||||
"no eligible workspace agents found",
|
||||
)
|
||||
}
|
||||
|
||||
compareAgents := func(a, b database.WorkspaceAgent) int {
|
||||
if order := cmp.Compare(a.DisplayOrder, b.DisplayOrder); order != 0 {
|
||||
return order
|
||||
}
|
||||
if order := cmp.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name)); order != 0 {
|
||||
return order
|
||||
}
|
||||
if order := cmp.Compare(a.Name, b.Name); order != 0 {
|
||||
return order
|
||||
}
|
||||
return cmp.Compare(a.ID.String(), b.ID.String())
|
||||
}
|
||||
slices.SortStableFunc(rootAgents, compareAgents)
|
||||
slices.SortStableFunc(matchingAgents, compareAgents)
|
||||
|
||||
switch len(matchingAgents) {
|
||||
case 0:
|
||||
return rootAgents[0], nil
|
||||
case 1:
|
||||
return matchingAgents[0], nil
|
||||
default:
|
||||
names := make([]string, 0, len(matchingAgents))
|
||||
for _, agent := range matchingAgents {
|
||||
names = append(names, agent.Name)
|
||||
}
|
||||
return database.WorkspaceAgent{}, xerrors.Errorf(
|
||||
"multiple agents match the chat suffix %q: %s; only one agent should use this suffix",
|
||||
Suffix,
|
||||
strings.Join(names, ", "),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
package agentselect_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect"
|
||||
)
|
||||
|
||||
func TestFindChatAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newRootAgentWithID := func(id, name string, displayOrder int32) database.WorkspaceAgent {
|
||||
return database.WorkspaceAgent{
|
||||
ID: uuid.MustParse(id),
|
||||
Name: name,
|
||||
DisplayOrder: displayOrder,
|
||||
}
|
||||
}
|
||||
|
||||
newRootAgent := func(name string, displayOrder int32) database.WorkspaceAgent {
|
||||
return newRootAgentWithID(uuid.NewString(), name, displayOrder)
|
||||
}
|
||||
|
||||
newChildAgent := func(name string, displayOrder int32) database.WorkspaceAgent {
|
||||
agent := newRootAgent(name, displayOrder)
|
||||
agent.ParentID = uuid.NullUUID{UUID: uuid.New(), Valid: true}
|
||||
return agent
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
agents []database.WorkspaceAgent
|
||||
wantIndex int
|
||||
wantErrContains []string
|
||||
}{
|
||||
{
|
||||
name: "SingleSuffixMatch",
|
||||
agents: []database.WorkspaceAgent{
|
||||
newRootAgent("alpha", 0),
|
||||
newRootAgent("dev-coderd-chat", 2),
|
||||
newRootAgent("zeta", 1),
|
||||
},
|
||||
wantIndex: 1,
|
||||
},
|
||||
{
|
||||
name: "SuffixMatchCaseInsensitive",
|
||||
agents: []database.WorkspaceAgent{
|
||||
newRootAgent("alpha", 0),
|
||||
newRootAgent("Dev-Coderd-Chat", 2),
|
||||
newRootAgent("zeta", 1),
|
||||
},
|
||||
wantIndex: 1,
|
||||
},
|
||||
{
|
||||
name: "NoSuffixMatchFallbackDeterministic",
|
||||
agents: []database.WorkspaceAgent{
|
||||
newRootAgent("zeta", 2),
|
||||
newRootAgent("bravo", 1),
|
||||
newRootAgent("alpha", 1),
|
||||
},
|
||||
wantIndex: 2,
|
||||
},
|
||||
{
|
||||
name: "NoSuffixMatchFallbackByName",
|
||||
agents: []database.WorkspaceAgent{
|
||||
newRootAgent("Bravo", 3),
|
||||
newRootAgent("alpha", 3),
|
||||
newRootAgent("charlie", 3),
|
||||
},
|
||||
wantIndex: 1,
|
||||
},
|
||||
{
|
||||
name: "CaseOnlyNameTieFallbackDeterministic",
|
||||
agents: []database.WorkspaceAgent{
|
||||
newRootAgent("Dev", 0),
|
||||
newRootAgent("dev", 0),
|
||||
},
|
||||
wantIndex: 0,
|
||||
},
|
||||
{
|
||||
name: "ExactNameTieFallbackByID",
|
||||
agents: []database.WorkspaceAgent{
|
||||
newRootAgentWithID("00000000-0000-0000-0000-000000000002", "dev", 0),
|
||||
newRootAgentWithID("00000000-0000-0000-0000-000000000001", "dev", 0),
|
||||
},
|
||||
wantIndex: 1,
|
||||
},
|
||||
{
|
||||
name: "MultipleSuffixMatchesError",
|
||||
agents: []database.WorkspaceAgent{
|
||||
newRootAgent("alpha-coderd-chat", 2),
|
||||
newRootAgent("beta-coderd-chat", 1),
|
||||
newRootAgent("gamma", 0),
|
||||
},
|
||||
wantErrContains: []string{
|
||||
fmt.Sprintf(
|
||||
"multiple agents match the chat suffix %q",
|
||||
agentselect.Suffix,
|
||||
),
|
||||
"alpha-coderd-chat",
|
||||
"beta-coderd-chat",
|
||||
"only one agent should use this suffix",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ChildAgentSuffixIgnored",
|
||||
agents: []database.WorkspaceAgent{
|
||||
newRootAgent("alpha", 1),
|
||||
newChildAgent("child-coderd-chat", 0),
|
||||
newRootAgent("bravo", 0),
|
||||
},
|
||||
wantIndex: 2,
|
||||
},
|
||||
{
|
||||
name: "ChildAgentSuffixIgnoredWithRootMatch",
|
||||
agents: []database.WorkspaceAgent{
|
||||
newRootAgent("alpha", 0),
|
||||
newChildAgent("child-coderd-chat", 1),
|
||||
newRootAgent("root-coderd-chat", 2),
|
||||
},
|
||||
wantIndex: 2,
|
||||
},
|
||||
{
|
||||
name: "EmptyAgentList",
|
||||
agents: []database.WorkspaceAgent{},
|
||||
wantErrContains: []string{
|
||||
"no eligible workspace agents found",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OnlyChildAgents",
|
||||
agents: []database.WorkspaceAgent{
|
||||
newChildAgent("alpha", 0),
|
||||
newChildAgent("beta-coderd-chat", 1),
|
||||
},
|
||||
wantErrContains: []string{
|
||||
"no eligible workspace agents found",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SingleRootAgent",
|
||||
agents: []database.WorkspaceAgent{
|
||||
newRootAgent("solo", 5),
|
||||
},
|
||||
wantIndex: 0,
|
||||
},
|
||||
{
|
||||
name: "SuffixAgentWinsRegardlessOfOrder",
|
||||
agents: []database.WorkspaceAgent{
|
||||
newRootAgent("alpha", 0),
|
||||
newRootAgent("zeta", 1),
|
||||
newRootAgent("preferred-coderd-chat", 99),
|
||||
},
|
||||
wantIndex: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := agentselect.FindChatAgent(tt.agents)
|
||||
if len(tt.wantErrContains) > 0 {
|
||||
require.Error(t, err)
|
||||
for _, wantErr := range tt.wantErrContains {
|
||||
require.ErrorContains(t, err, wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.agents[tt.wantIndex], got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsChatAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "ExactSuffix",
|
||||
input: "agent-coderd-chat",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "UppercaseSuffix",
|
||||
input: "agent-CODERD-CHAT",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "MixedCaseSuffix",
|
||||
input: "agent-Coderd-Chat",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "NoSuffix",
|
||||
input: "my-agent",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "SuffixOnly",
|
||||
input: "-coderd-chat",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "PartialSuffix",
|
||||
input: "agent-coderd",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, tt.want, agentselect.IsChatAgent(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user