feat: add task send and logs MCP tools (#20230)

Closes https://github.com/coder/internal/issues/776
This commit is contained in:
Asher
2025-10-15 13:21:20 -08:00
committed by GitHub
parent 9bef5de30d
commit 41de4ad91a
3 changed files with 380 additions and 8 deletions
+7 -3
View File
@@ -119,17 +119,21 @@ func (b WorkspaceBuildBuilder) WithAgent(mutations ...func([]*sdkproto.Agent) []
return b
}
func (b WorkspaceBuildBuilder) WithTask() WorkspaceBuildBuilder {
func (b WorkspaceBuildBuilder) WithTask(seed *sdkproto.App) WorkspaceBuildBuilder {
//nolint: revive // returns modified struct
b.taskAppID = uuid.New()
if seed == nil {
seed = &sdkproto.App{}
}
return b.Params(database.WorkspaceBuildParameter{
Name: codersdk.AITaskPromptParameterName,
Value: "list me",
}).WithAgent(func(a []*sdkproto.Agent) []*sdkproto.Agent {
a[0].Apps = []*sdkproto.App{
{
Id: b.taskAppID.String(),
Slug: "vcode",
Id: takeFirst(seed.Id, b.taskAppID.String()),
Slug: takeFirst(seed.Slug, "vcode"),
Url: takeFirst(seed.Url, ""),
},
}
return a
+107
View File
@@ -55,6 +55,8 @@ const (
ToolNameDeleteTask = "coder_delete_task"
ToolNameListTasks = "coder_list_tasks"
ToolNameGetTaskStatus = "coder_get_task_status"
ToolNameSendTaskInput = "coder_send_task_input"
ToolNameGetTaskLogs = "coder_get_task_logs"
)
func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) {
@@ -233,6 +235,8 @@ var All = []GenericTool{
DeleteTask.Generic(),
ListTasks.Generic(),
GetTaskStatus.Generic(),
SendTaskInput.Generic(),
GetTaskLogs.Generic(),
}
type ReportTaskArgs struct {
@@ -2033,6 +2037,97 @@ var GetTaskStatus = Tool[GetTaskStatusArgs, GetTaskStatusResponse]{
},
}
type SendTaskInputArgs struct {
TaskID string `json:"task_id"`
Input string `json:"input"`
}
var SendTaskInput = Tool[SendTaskInputArgs, codersdk.Response]{
Tool: aisdk.Tool{
Name: ToolNameSendTaskInput,
Description: `Send input to a running task.`,
Schema: aisdk.Schema{
Properties: map[string]any{
"task_id": map[string]any{
"type": "string",
"description": taskIDDescription("prompt"),
},
"input": map[string]any{
"type": "string",
"description": "The input to send to the task.",
},
},
Required: []string{"task_id", "input"},
},
},
UserClientOptional: true,
Handler: func(ctx context.Context, deps Deps, args SendTaskInputArgs) (codersdk.Response, error) {
if args.TaskID == "" {
return codersdk.Response{}, xerrors.New("task_id is required")
}
if args.Input == "" {
return codersdk.Response{}, xerrors.New("input is required")
}
expClient := codersdk.NewExperimentalClient(deps.coderClient)
id, owner, err := resolveTaskID(ctx, deps.coderClient, args.TaskID)
if err != nil {
return codersdk.Response{}, err
}
err = expClient.TaskSend(ctx, owner, id, codersdk.TaskSendRequest{
Input: args.Input,
})
if err != nil {
return codersdk.Response{}, xerrors.Errorf("send task input %q: %w", args.TaskID, err)
}
return codersdk.Response{
Message: "Input sent to task successfully.",
}, nil
},
}
type GetTaskLogsArgs struct {
TaskID string `json:"task_id"`
}
var GetTaskLogs = Tool[GetTaskLogsArgs, codersdk.TaskLogsResponse]{
Tool: aisdk.Tool{
Name: ToolNameGetTaskLogs,
Description: `Get the logs of a task.`,
Schema: aisdk.Schema{
Properties: map[string]any{
"task_id": map[string]any{
"type": "string",
"description": taskIDDescription("query"),
},
},
Required: []string{"task_id"},
},
},
UserClientOptional: true,
Handler: func(ctx context.Context, deps Deps, args GetTaskLogsArgs) (codersdk.TaskLogsResponse, error) {
if args.TaskID == "" {
return codersdk.TaskLogsResponse{}, xerrors.New("task_id is required")
}
expClient := codersdk.NewExperimentalClient(deps.coderClient)
id, owner, err := resolveTaskID(ctx, deps.coderClient, args.TaskID)
if err != nil {
return codersdk.TaskLogsResponse{}, err
}
logs, err := expClient.TaskLogs(ctx, owner, id)
if err != nil {
return codersdk.TaskLogsResponse{}, xerrors.Errorf("get task logs %q: %w", args.TaskID, err)
}
return logs, nil
},
}
// normalizedNamedWorkspace normalizes the workspace name before getting the
// workspace by name.
func normalizedNamedWorkspace(ctx context.Context, client *codersdk.Client, name string) (codersdk.Workspace, error) {
@@ -2110,3 +2205,15 @@ func taskIDDescription(action string) string {
func userDescription(action string) string {
return fmt.Sprintf("Username or ID of the user for which to %s. Omit or use the `me` keyword to %s for the authenticated user.", action, action)
}
func resolveTaskID(ctx context.Context, coderClient *codersdk.Client, taskID string) (uuid.UUID, string, error) {
id, err := uuid.Parse(taskID)
if err == nil {
return id, codersdk.Me, nil
}
ws, err := normalizedNamedWorkspace(ctx, coderClient, taskID)
if err != nil {
return uuid.UUID{}, codersdk.Me, xerrors.Errorf("get task workspace %q: %w", taskID, err)
}
return ws.ID, ws.OwnerName, nil
}
+266 -5
View File
@@ -5,6 +5,8 @@ import (
"database/sql"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
@@ -21,12 +23,14 @@ import (
"github.com/coder/aisdk-go"
agentapi "github.com/coder/agentapi-sdk-go"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/toolsdk"
@@ -896,7 +900,7 @@ func TestTools(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: member.ID,
TemplateID: aiTV.Template.ID,
}).WithTask().Do()
}).WithTask(nil).Do()
// nolint:gocritic // This is in a test package and does not end up in the build
_ = dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
@@ -904,7 +908,7 @@ func TestTools(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: member.ID,
TemplateID: aiTV.Template.ID,
}).WithTask().Do()
}).WithTask(nil).Do()
tests := []struct {
name string
@@ -991,7 +995,7 @@ func TestTools(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: member.ID,
TemplateID: aiTV.Template.ID,
}).WithTask().Do()
}).WithTask(nil).Do()
// These tasks should show up.
for i := range 5 {
@@ -1007,7 +1011,7 @@ func TestTools(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taskUser.ID,
TemplateID: aiTV.Template.ID,
}).Seed(database.WorkspaceBuild{Transition: transition}).WithTask().Do()
}).Seed(database.WorkspaceBuild{Transition: transition}).WithTask(nil).Do()
}
tests := []struct {
@@ -1079,7 +1083,7 @@ func TestTools(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: member.ID,
TemplateID: aiTV.Template.ID,
}).WithTask().Do()
}).WithTask(nil).Do()
tests := []struct {
name string
@@ -1287,6 +1291,263 @@ func TestTools(t *testing.T) {
})
}
})
t.Run("SendTaskInput", func(t *testing.T) {
t.Parallel()
// Start a fake AgentAPI that accepts GET /status and POST /message.
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet && r.URL.Path == "/status" {
httpapi.Write(r.Context(), rw, http.StatusOK, agentapi.GetStatusResponse{
Status: agentapi.StatusStable,
})
return
}
if r.Method == http.MethodPost && r.URL.Path == "/message" {
rw.Header().Set("Content-Type", "application/json")
var req agentapi.PostMessageParams
ok := httpapi.Read(r.Context(), rw, r, &req)
assert.True(t, ok, "failed to read request")
assert.Equal(t, req.Content, "frob the baz")
assert.Equal(t, req.Type, agentapi.MessageTypeUser)
httpapi.Write(r.Context(), rw, http.StatusOK, agentapi.PostMessageResponse{
Ok: true,
})
return
}
rw.WriteHeader(http.StatusInternalServerError)
}))
t.Cleanup(srv.Close)
// nolint:gocritic // This is in a test package and does not end up in the build
aiTV := dbfake.TemplateVersion(t, store).Seed(database.TemplateVersion{
OrganizationID: owner.OrganizationID,
CreatedBy: member.ID,
HasAITask: sql.NullBool{
Bool: true,
Valid: true,
},
}).Do()
// nolint:gocritic // This is in a test package and does not end up in the build
ws := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
Name: "send-task-input",
OrganizationID: owner.OrganizationID,
OwnerID: member.ID,
TemplateID: aiTV.Template.ID,
}).WithTask(&proto.App{Url: srv.URL}).Do()
_ = agenttest.New(t, client.URL, ws.AgentToken)
coderdtest.NewWorkspaceAgentWaiter(t, client, ws.Workspace.ID).Wait()
tests := []struct {
name string
args toolsdk.SendTaskInputArgs
error string
}{
{
name: "ByUUID",
args: toolsdk.SendTaskInputArgs{
TaskID: ws.Workspace.ID.String(),
Input: "frob the baz",
},
},
{
name: "ByWorkspaceIdentifier",
args: toolsdk.SendTaskInputArgs{
TaskID: "send-task-input",
Input: "frob the baz",
},
},
{
name: "NoID",
args: toolsdk.SendTaskInputArgs{},
error: "task_id is required",
},
{
name: "NoInput",
args: toolsdk.SendTaskInputArgs{
TaskID: "send-task-input",
},
error: "input is required",
},
{
name: "NoTaskByID",
args: toolsdk.SendTaskInputArgs{
TaskID: uuid.New().String(),
Input: "this is ignored",
},
error: "Resource not found",
},
{
name: "NoTaskByWorkspaceIdentifier",
args: toolsdk.SendTaskInputArgs{
TaskID: "non-existent",
Input: "this is ignored",
},
error: "Resource not found",
},
{
name: "ExistsButNotATask",
args: toolsdk.SendTaskInputArgs{
TaskID: r.Workspace.ID.String(),
Input: "this is ignored",
},
error: "Task is not configured with a sidebar app",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
tb, err := toolsdk.NewDeps(memberClient)
require.NoError(t, err)
_, err = testTool(t, toolsdk.SendTaskInput, tb, tt.args)
if tt.error != "" {
require.Error(t, err)
require.ErrorContains(t, err, tt.error)
} else {
require.NoError(t, err)
}
})
}
})
t.Run("GetTaskLogs", func(t *testing.T) {
t.Parallel()
messages := []agentapi.Message{
{
Id: 0,
Content: "welcome",
Role: agentapi.RoleAgent,
},
{
Id: 1,
Content: "frob the dazzle",
Role: agentapi.RoleUser,
},
{
Id: 2,
Content: "frob dazzled",
Role: agentapi.RoleAgent,
},
}
// Start a fake AgentAPI that returns some messages.
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet && r.URL.Path == "/messages" {
httpapi.Write(r.Context(), rw, http.StatusOK, agentapi.GetMessagesResponse{
Messages: messages,
})
return
}
rw.WriteHeader(http.StatusInternalServerError)
}))
t.Cleanup(srv.Close)
// nolint:gocritic // This is in a test package and does not end up in the build
aiTV := dbfake.TemplateVersion(t, store).Seed(database.TemplateVersion{
OrganizationID: owner.OrganizationID,
CreatedBy: member.ID,
HasAITask: sql.NullBool{
Bool: true,
Valid: true,
},
}).Do()
// nolint:gocritic // This is in a test package and does not end up in the build
ws := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
Name: "get-task-logs",
OrganizationID: owner.OrganizationID,
OwnerID: member.ID,
TemplateID: aiTV.Template.ID,
}).WithTask(&proto.App{Url: srv.URL}).Do()
_ = agenttest.New(t, client.URL, ws.AgentToken)
coderdtest.NewWorkspaceAgentWaiter(t, client, ws.Workspace.ID).Wait()
tests := []struct {
name string
args toolsdk.GetTaskLogsArgs
expected []agentapi.Message
error string
}{
{
name: "ByUUID",
args: toolsdk.GetTaskLogsArgs{
TaskID: ws.Workspace.ID.String(),
},
expected: messages,
},
{
name: "ByWorkspaceIdentifier",
args: toolsdk.GetTaskLogsArgs{
TaskID: "get-task-logs",
},
expected: messages,
},
{
name: "NoID",
args: toolsdk.GetTaskLogsArgs{},
error: "task_id is required",
},
{
name: "NoTaskByID",
args: toolsdk.GetTaskLogsArgs{
TaskID: uuid.New().String(),
},
error: "Resource not found",
},
{
name: "NoTaskByWorkspaceIdentifier",
args: toolsdk.GetTaskLogsArgs{
TaskID: "non-existent",
},
error: "Resource not found",
},
{
name: "ExistsButNotATask",
args: toolsdk.GetTaskLogsArgs{
TaskID: r.Workspace.ID.String(),
},
error: "Task is not configured with a sidebar app",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
tb, err := toolsdk.NewDeps(memberClient)
require.NoError(t, err)
res, err := testTool(t, toolsdk.GetTaskLogs, tb, tt.args)
if tt.error != "" {
require.Error(t, err)
require.ErrorContains(t, err, tt.error)
} else {
require.NoError(t, err)
require.Len(t, res.Logs, len(tt.expected))
for i, msg := range tt.expected {
require.Equal(t, msg.Id, int64(res.Logs[i].ID))
require.Equal(t, msg.Content, res.Logs[i].Content)
if msg.Role == agentapi.RoleUser {
require.Equal(t, codersdk.TaskLogTypeInput, res.Logs[i].Type)
} else {
require.Equal(t, codersdk.TaskLogTypeOutput, res.Logs[i].Type)
}
require.Equal(t, msg.Time, res.Logs[i].Time)
}
}
})
}
})
}
// TestedTools keeps track of which tools have been tested.