mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add task send and logs MCP tools (#20230)
Closes https://github.com/coder/internal/issues/776
This commit is contained in:
@@ -119,17 +119,21 @@ func (b WorkspaceBuildBuilder) WithAgent(mutations ...func([]*sdkproto.Agent) []
|
||||
return b
|
||||
}
|
||||
|
||||
func (b WorkspaceBuildBuilder) WithTask() WorkspaceBuildBuilder {
|
||||
func (b WorkspaceBuildBuilder) WithTask(seed *sdkproto.App) WorkspaceBuildBuilder {
|
||||
//nolint: revive // returns modified struct
|
||||
b.taskAppID = uuid.New()
|
||||
if seed == nil {
|
||||
seed = &sdkproto.App{}
|
||||
}
|
||||
return b.Params(database.WorkspaceBuildParameter{
|
||||
Name: codersdk.AITaskPromptParameterName,
|
||||
Value: "list me",
|
||||
}).WithAgent(func(a []*sdkproto.Agent) []*sdkproto.Agent {
|
||||
a[0].Apps = []*sdkproto.App{
|
||||
{
|
||||
Id: b.taskAppID.String(),
|
||||
Slug: "vcode",
|
||||
Id: takeFirst(seed.Id, b.taskAppID.String()),
|
||||
Slug: takeFirst(seed.Slug, "vcode"),
|
||||
Url: takeFirst(seed.Url, ""),
|
||||
},
|
||||
}
|
||||
return a
|
||||
|
||||
@@ -55,6 +55,8 @@ const (
|
||||
ToolNameDeleteTask = "coder_delete_task"
|
||||
ToolNameListTasks = "coder_list_tasks"
|
||||
ToolNameGetTaskStatus = "coder_get_task_status"
|
||||
ToolNameSendTaskInput = "coder_send_task_input"
|
||||
ToolNameGetTaskLogs = "coder_get_task_logs"
|
||||
)
|
||||
|
||||
func NewDeps(client *codersdk.Client, opts ...func(*Deps)) (Deps, error) {
|
||||
@@ -233,6 +235,8 @@ var All = []GenericTool{
|
||||
DeleteTask.Generic(),
|
||||
ListTasks.Generic(),
|
||||
GetTaskStatus.Generic(),
|
||||
SendTaskInput.Generic(),
|
||||
GetTaskLogs.Generic(),
|
||||
}
|
||||
|
||||
type ReportTaskArgs struct {
|
||||
@@ -2033,6 +2037,97 @@ var GetTaskStatus = Tool[GetTaskStatusArgs, GetTaskStatusResponse]{
|
||||
},
|
||||
}
|
||||
|
||||
// SendTaskInputArgs are the arguments accepted by the SendTaskInput tool.
type SendTaskInputArgs struct {
	// TaskID identifies the task to send input to. It is serialized as
	// "task_id" and must be non-empty.
	TaskID string `json:"task_id"`
	// Input is the text delivered to the task. It is serialized as "input"
	// and must be non-empty.
	Input string `json:"input"`
}
|
||||
|
||||
var SendTaskInput = Tool[SendTaskInputArgs, codersdk.Response]{
|
||||
Tool: aisdk.Tool{
|
||||
Name: ToolNameSendTaskInput,
|
||||
Description: `Send input to a running task.`,
|
||||
Schema: aisdk.Schema{
|
||||
Properties: map[string]any{
|
||||
"task_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": taskIDDescription("prompt"),
|
||||
},
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The input to send to the task.",
|
||||
},
|
||||
},
|
||||
Required: []string{"task_id", "input"},
|
||||
},
|
||||
},
|
||||
UserClientOptional: true,
|
||||
Handler: func(ctx context.Context, deps Deps, args SendTaskInputArgs) (codersdk.Response, error) {
|
||||
if args.TaskID == "" {
|
||||
return codersdk.Response{}, xerrors.New("task_id is required")
|
||||
}
|
||||
|
||||
if args.Input == "" {
|
||||
return codersdk.Response{}, xerrors.New("input is required")
|
||||
}
|
||||
|
||||
expClient := codersdk.NewExperimentalClient(deps.coderClient)
|
||||
id, owner, err := resolveTaskID(ctx, deps.coderClient, args.TaskID)
|
||||
if err != nil {
|
||||
return codersdk.Response{}, err
|
||||
}
|
||||
|
||||
err = expClient.TaskSend(ctx, owner, id, codersdk.TaskSendRequest{
|
||||
Input: args.Input,
|
||||
})
|
||||
if err != nil {
|
||||
return codersdk.Response{}, xerrors.Errorf("send task input %q: %w", args.TaskID, err)
|
||||
}
|
||||
|
||||
return codersdk.Response{
|
||||
Message: "Input sent to task successfully.",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
// GetTaskLogsArgs are the arguments accepted by the GetTaskLogs tool.
type GetTaskLogsArgs struct {
	// TaskID identifies the task whose logs are requested. It is serialized
	// as "task_id" and must be non-empty.
	TaskID string `json:"task_id"`
}
|
||||
|
||||
var GetTaskLogs = Tool[GetTaskLogsArgs, codersdk.TaskLogsResponse]{
|
||||
Tool: aisdk.Tool{
|
||||
Name: ToolNameGetTaskLogs,
|
||||
Description: `Get the logs of a task.`,
|
||||
Schema: aisdk.Schema{
|
||||
Properties: map[string]any{
|
||||
"task_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": taskIDDescription("query"),
|
||||
},
|
||||
},
|
||||
Required: []string{"task_id"},
|
||||
},
|
||||
},
|
||||
UserClientOptional: true,
|
||||
Handler: func(ctx context.Context, deps Deps, args GetTaskLogsArgs) (codersdk.TaskLogsResponse, error) {
|
||||
if args.TaskID == "" {
|
||||
return codersdk.TaskLogsResponse{}, xerrors.New("task_id is required")
|
||||
}
|
||||
|
||||
expClient := codersdk.NewExperimentalClient(deps.coderClient)
|
||||
id, owner, err := resolveTaskID(ctx, deps.coderClient, args.TaskID)
|
||||
if err != nil {
|
||||
return codersdk.TaskLogsResponse{}, err
|
||||
}
|
||||
|
||||
logs, err := expClient.TaskLogs(ctx, owner, id)
|
||||
if err != nil {
|
||||
return codersdk.TaskLogsResponse{}, xerrors.Errorf("get task logs %q: %w", args.TaskID, err)
|
||||
}
|
||||
|
||||
return logs, nil
|
||||
},
|
||||
}
|
||||
|
||||
// normalizedNamedWorkspace normalizes the workspace name before getting the
|
||||
// workspace by name.
|
||||
func normalizedNamedWorkspace(ctx context.Context, client *codersdk.Client, name string) (codersdk.Workspace, error) {
|
||||
@@ -2110,3 +2205,15 @@ func taskIDDescription(action string) string {
|
||||
// userDescription builds the schema description for a tool's user parameter.
// action is a verb phrase (for example "list tasks") that is inserted twice
// into the sentence: once for the named user and once for the `me` default.
func userDescription(action string) string {
	return fmt.Sprintf("Username or ID of the user for which to %s. Omit or use the `me` keyword to %s for the authenticated user.", action, action)
}
|
||||
|
||||
func resolveTaskID(ctx context.Context, coderClient *codersdk.Client, taskID string) (uuid.UUID, string, error) {
|
||||
id, err := uuid.Parse(taskID)
|
||||
if err == nil {
|
||||
return id, codersdk.Me, nil
|
||||
}
|
||||
ws, err := normalizedNamedWorkspace(ctx, coderClient, taskID)
|
||||
if err != nil {
|
||||
return uuid.UUID{}, codersdk.Me, xerrors.Errorf("get task workspace %q: %w", taskID, err)
|
||||
}
|
||||
return ws.ID, ws.OwnerName, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -21,12 +23,14 @@ import (
|
||||
|
||||
"github.com/coder/aisdk-go"
|
||||
|
||||
agentapi "github.com/coder/agentapi-sdk-go"
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/codersdk/toolsdk"
|
||||
@@ -896,7 +900,7 @@ func TestTools(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: member.ID,
|
||||
TemplateID: aiTV.Template.ID,
|
||||
}).WithTask().Do()
|
||||
}).WithTask(nil).Do()
|
||||
|
||||
// nolint:gocritic // This is in a test package and does not end up in the build
|
||||
_ = dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
|
||||
@@ -904,7 +908,7 @@ func TestTools(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: member.ID,
|
||||
TemplateID: aiTV.Template.ID,
|
||||
}).WithTask().Do()
|
||||
}).WithTask(nil).Do()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -991,7 +995,7 @@ func TestTools(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: member.ID,
|
||||
TemplateID: aiTV.Template.ID,
|
||||
}).WithTask().Do()
|
||||
}).WithTask(nil).Do()
|
||||
|
||||
// These tasks should show up.
|
||||
for i := range 5 {
|
||||
@@ -1007,7 +1011,7 @@ func TestTools(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: taskUser.ID,
|
||||
TemplateID: aiTV.Template.ID,
|
||||
}).Seed(database.WorkspaceBuild{Transition: transition}).WithTask().Do()
|
||||
}).Seed(database.WorkspaceBuild{Transition: transition}).WithTask(nil).Do()
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
@@ -1079,7 +1083,7 @@ func TestTools(t *testing.T) {
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: member.ID,
|
||||
TemplateID: aiTV.Template.ID,
|
||||
}).WithTask().Do()
|
||||
}).WithTask(nil).Do()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1287,6 +1291,263 @@ func TestTools(t *testing.T) {
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SendTaskInput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Start a fake AgentAPI that accepts GET /status and POST /message.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/status" {
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, agentapi.GetStatusResponse{
|
||||
Status: agentapi.StatusStable,
|
||||
})
|
||||
return
|
||||
}
|
||||
if r.Method == http.MethodPost && r.URL.Path == "/message" {
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var req agentapi.PostMessageParams
|
||||
ok := httpapi.Read(r.Context(), rw, r, &req)
|
||||
assert.True(t, ok, "failed to read request")
|
||||
|
||||
assert.Equal(t, req.Content, "frob the baz")
|
||||
assert.Equal(t, req.Type, agentapi.MessageTypeUser)
|
||||
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, agentapi.PostMessageResponse{
|
||||
Ok: true,
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
// nolint:gocritic // This is in a test package and does not end up in the build
|
||||
aiTV := dbfake.TemplateVersion(t, store).Seed(database.TemplateVersion{
|
||||
OrganizationID: owner.OrganizationID,
|
||||
CreatedBy: member.ID,
|
||||
HasAITask: sql.NullBool{
|
||||
Bool: true,
|
||||
Valid: true,
|
||||
},
|
||||
}).Do()
|
||||
|
||||
// nolint:gocritic // This is in a test package and does not end up in the build
|
||||
ws := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
|
||||
Name: "send-task-input",
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: member.ID,
|
||||
TemplateID: aiTV.Template.ID,
|
||||
}).WithTask(&proto.App{Url: srv.URL}).Do()
|
||||
|
||||
_ = agenttest.New(t, client.URL, ws.AgentToken)
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, client, ws.Workspace.ID).Wait()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args toolsdk.SendTaskInputArgs
|
||||
error string
|
||||
}{
|
||||
{
|
||||
name: "ByUUID",
|
||||
args: toolsdk.SendTaskInputArgs{
|
||||
TaskID: ws.Workspace.ID.String(),
|
||||
Input: "frob the baz",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ByWorkspaceIdentifier",
|
||||
args: toolsdk.SendTaskInputArgs{
|
||||
TaskID: "send-task-input",
|
||||
Input: "frob the baz",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NoID",
|
||||
args: toolsdk.SendTaskInputArgs{},
|
||||
error: "task_id is required",
|
||||
},
|
||||
{
|
||||
name: "NoInput",
|
||||
args: toolsdk.SendTaskInputArgs{
|
||||
TaskID: "send-task-input",
|
||||
},
|
||||
error: "input is required",
|
||||
},
|
||||
{
|
||||
name: "NoTaskByID",
|
||||
args: toolsdk.SendTaskInputArgs{
|
||||
TaskID: uuid.New().String(),
|
||||
Input: "this is ignored",
|
||||
},
|
||||
error: "Resource not found",
|
||||
},
|
||||
{
|
||||
name: "NoTaskByWorkspaceIdentifier",
|
||||
args: toolsdk.SendTaskInputArgs{
|
||||
TaskID: "non-existent",
|
||||
Input: "this is ignored",
|
||||
},
|
||||
error: "Resource not found",
|
||||
},
|
||||
{
|
||||
name: "ExistsButNotATask",
|
||||
args: toolsdk.SendTaskInputArgs{
|
||||
TaskID: r.Workspace.ID.String(),
|
||||
Input: "this is ignored",
|
||||
},
|
||||
error: "Task is not configured with a sidebar app",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tb, err := toolsdk.NewDeps(memberClient)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = testTool(t, toolsdk.SendTaskInput, tb, tt.args)
|
||||
if tt.error != "" {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, tt.error)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetTaskLogs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
messages := []agentapi.Message{
|
||||
{
|
||||
Id: 0,
|
||||
Content: "welcome",
|
||||
Role: agentapi.RoleAgent,
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Content: "frob the dazzle",
|
||||
Role: agentapi.RoleUser,
|
||||
},
|
||||
{
|
||||
Id: 2,
|
||||
Content: "frob dazzled",
|
||||
Role: agentapi.RoleAgent,
|
||||
},
|
||||
}
|
||||
|
||||
// Start a fake AgentAPI that returns some messages.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/messages" {
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, agentapi.GetMessagesResponse{
|
||||
Messages: messages,
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
// nolint:gocritic // This is in a test package and does not end up in the build
|
||||
aiTV := dbfake.TemplateVersion(t, store).Seed(database.TemplateVersion{
|
||||
OrganizationID: owner.OrganizationID,
|
||||
CreatedBy: member.ID,
|
||||
HasAITask: sql.NullBool{
|
||||
Bool: true,
|
||||
Valid: true,
|
||||
},
|
||||
}).Do()
|
||||
|
||||
// nolint:gocritic // This is in a test package and does not end up in the build
|
||||
ws := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
|
||||
Name: "get-task-logs",
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: member.ID,
|
||||
TemplateID: aiTV.Template.ID,
|
||||
}).WithTask(&proto.App{Url: srv.URL}).Do()
|
||||
|
||||
_ = agenttest.New(t, client.URL, ws.AgentToken)
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, client, ws.Workspace.ID).Wait()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args toolsdk.GetTaskLogsArgs
|
||||
expected []agentapi.Message
|
||||
error string
|
||||
}{
|
||||
{
|
||||
name: "ByUUID",
|
||||
args: toolsdk.GetTaskLogsArgs{
|
||||
TaskID: ws.Workspace.ID.String(),
|
||||
},
|
||||
expected: messages,
|
||||
},
|
||||
{
|
||||
name: "ByWorkspaceIdentifier",
|
||||
args: toolsdk.GetTaskLogsArgs{
|
||||
TaskID: "get-task-logs",
|
||||
},
|
||||
expected: messages,
|
||||
},
|
||||
{
|
||||
name: "NoID",
|
||||
args: toolsdk.GetTaskLogsArgs{},
|
||||
error: "task_id is required",
|
||||
},
|
||||
{
|
||||
name: "NoTaskByID",
|
||||
args: toolsdk.GetTaskLogsArgs{
|
||||
TaskID: uuid.New().String(),
|
||||
},
|
||||
error: "Resource not found",
|
||||
},
|
||||
{
|
||||
name: "NoTaskByWorkspaceIdentifier",
|
||||
args: toolsdk.GetTaskLogsArgs{
|
||||
TaskID: "non-existent",
|
||||
},
|
||||
error: "Resource not found",
|
||||
},
|
||||
{
|
||||
name: "ExistsButNotATask",
|
||||
args: toolsdk.GetTaskLogsArgs{
|
||||
TaskID: r.Workspace.ID.String(),
|
||||
},
|
||||
error: "Task is not configured with a sidebar app",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tb, err := toolsdk.NewDeps(memberClient)
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := testTool(t, toolsdk.GetTaskLogs, tb, tt.args)
|
||||
if tt.error != "" {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, tt.error)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Logs, len(tt.expected))
|
||||
for i, msg := range tt.expected {
|
||||
require.Equal(t, msg.Id, int64(res.Logs[i].ID))
|
||||
require.Equal(t, msg.Content, res.Logs[i].Content)
|
||||
if msg.Role == agentapi.RoleUser {
|
||||
require.Equal(t, codersdk.TaskLogTypeInput, res.Logs[i].Type)
|
||||
} else {
|
||||
require.Equal(t, codersdk.TaskLogTypeOutput, res.Logs[i].Type)
|
||||
}
|
||||
require.Equal(t, msg.Time, res.Logs[i].Time)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestedTools keeps track of which tools have been tested.
|
||||
|
||||
Reference in New Issue
Block a user