refactor: create tasks in coderd instead of frontend (#19280)

Instead of creating tasks with a specialized call to `CreateWorkspace`
on the frontend, we instead lift this to the backend and allow the
frontend to simply call `CreateAITask`.
This commit is contained in:
Danielle Maywood
2025-08-12 11:23:55 +01:00
committed by GitHub
parent cda1a3a593
commit f349edcc3c
14 changed files with 362 additions and 4 deletions
+110
View File
@@ -1,13 +1,20 @@
package coderd
import (
"database/sql"
"errors"
"fmt"
"net/http"
"slices"
"strings"
"github.com/google/uuid"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/codersdk"
)
@@ -61,3 +68,106 @@ func (api *API) aiTasksPrompts(rw http.ResponseWriter, r *http.Request) {
Prompts: promptsByBuildID,
})
}
// This endpoint is experimental and not guaranteed to be stable, so we're not
// generating public-facing documentation for it.
func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
apiKey = httpmw.APIKey(r)
auditor = api.Auditor.Load()
mems = httpmw.OrganizationMembersParam(r)
)
var req codersdk.CreateTaskRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
hasAITask, err := api.Database.GetTemplateVersionHasAITask(ctx, req.TemplateVersionID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching whether the template version has an AI task.",
Detail: err.Error(),
})
return
}
if !hasAITask {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf(`Template does not have required parameter %q`, codersdk.AITaskPromptParameterName),
})
return
}
createReq := codersdk.CreateWorkspaceRequest{
Name: req.Name,
TemplateVersionID: req.TemplateVersionID,
TemplateVersionPresetID: req.TemplateVersionPresetID,
RichParameterValues: []codersdk.WorkspaceBuildParameter{
{Name: codersdk.AITaskPromptParameterName, Value: req.Prompt},
},
}
var owner workspaceOwner
if mems.User != nil {
// This user fetch is an optimization path for the most common case of creating a
// task for 'Me'.
//
// This is also required to allow `owners` to create workspaces for users
// that are not in an organization.
owner = workspaceOwner{
ID: mems.User.ID,
Username: mems.User.Username,
AvatarURL: mems.User.AvatarURL,
}
} else {
// A task can still be created if the caller can read the organization
// member. The organization is required, which can be sourced from the
// template.
//
// TODO: This code gets called twice for each workspace build request.
// This is inefficient and costs at most 2 extra RTTs to the DB.
// This can be optimized. It exists as it is now for code simplicity.
// The most common case is to create a workspace for 'Me'. Which does
// not enter this code branch.
template, ok := requestTemplate(ctx, rw, createReq, api.Database)
if !ok {
return
}
// If the caller can find the organization membership in the same org
// as the template, then they can continue.
orgIndex := slices.IndexFunc(mems.Memberships, func(mem httpmw.OrganizationMember) bool {
return mem.OrganizationID == template.OrganizationID
})
if orgIndex == -1 {
httpapi.ResourceNotFound(rw)
return
}
member := mems.Memberships[orgIndex]
owner = workspaceOwner{
ID: member.UserID,
Username: member.Username,
AvatarURL: member.AvatarURL,
}
}
aReq, commitAudit := audit.InitRequest[database.WorkspaceTable](rw, &audit.RequestParams{
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
AdditionalFields: audit.AdditionalFields{
WorkspaceOwner: owner.Username,
},
})
defer commitAudit()
createWorkspace(ctx, aReq, apiKey.UserID, api, owner, createReq, rw, r)
}
+124
View File
@@ -1,9 +1,11 @@
package coderd_test
import (
"net/http"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
@@ -139,3 +141,125 @@ func TestAITasksPrompts(t *testing.T) {
require.Empty(t, prompts.Prompts)
})
}
func TestTaskCreate(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
taskName = "task-foo-bar-baz"
taskPrompt = "Some task prompt"
)
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
// Given: A template with an "AI Prompt" parameter
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionApply: echo.ApplyComplete,
ProvisionPlan: []*proto.Response{
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
Parameters: []*proto.RichParameter{{Name: "AI Prompt", Type: "string"}},
HasAiTasks: true,
}}},
},
})
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
expClient := codersdk.NewExperimentalClient(client)
// When: We attempt to create a Task.
workspace, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
Name: taskName,
TemplateVersionID: template.ActiveVersionID,
Prompt: taskPrompt,
})
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
// Then: We expect a workspace to have been created.
assert.Equal(t, taskName, workspace.Name)
assert.Equal(t, template.ID, workspace.TemplateID)
// And: We expect it to have the "AI Prompt" parameter correctly set.
parameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID)
require.NoError(t, err)
require.Len(t, parameters, 1)
assert.Equal(t, codersdk.AITaskPromptParameterName, parameters[0].Name)
assert.Equal(t, taskPrompt, parameters[0].Value)
})
t.Run("FailsOnNonTaskTemplate", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
taskName = "task-foo-bar-baz"
taskPrompt = "Some task prompt"
)
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
// Given: A template without an "AI Prompt" parameter
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
expClient := codersdk.NewExperimentalClient(client)
// When: We attempt to create a Task.
_, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
Name: taskName,
TemplateVersionID: template.ActiveVersionID,
Prompt: taskPrompt,
})
// Then: We expect it to fail.
var sdkErr *codersdk.Error
require.Error(t, err)
require.ErrorAsf(t, err, &sdkErr, "error should be of type *codersdk.Error")
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("FailsOnInvalidTemplate", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
taskName = "task-foo-bar-baz"
taskPrompt = "Some task prompt"
)
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
// Given: A template
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
_ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
expClient := codersdk.NewExperimentalClient(client)
// When: We attempt to create a Task with an invalid template version ID.
_, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
Name: taskName,
TemplateVersionID: uuid.New(),
Prompt: taskPrompt,
})
// Then: We expect it to fail.
var sdkErr *codersdk.Error
require.Error(t, err)
require.ErrorAsf(t, err, &sdkErr, "error should be of type *codersdk.Error")
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
})
}
+9
View File
@@ -995,6 +995,15 @@ func New(options *Options) *API {
r.Route("/aitasks", func(r chi.Router) {
r.Get("/prompts", api.aiTasksPrompts)
})
r.Route("/tasks", func(r chi.Router) {
r.Use(apiRateLimiter)
r.Route("/{user}", func(r chi.Router) {
r.Use(httpmw.ExtractOrganizationMembersParam(options.Database, api.HTTPAuth.Authorize))
r.Post("/", api.tasksCreate)
})
})
r.Route("/mcp", func(r chi.Router) {
r.Use(
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP),
+11
View File
@@ -2863,6 +2863,17 @@ func (q *querier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg
return tv, nil
}
func (q *querier) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) {
// If we can successfully call `GetTemplateVersionByID`, then
// we know the actor has sufficient permissions to know if the
// template has an AI task.
if _, err := q.GetTemplateVersionByID(ctx, id); err != nil {
return false, err
}
return q.db.GetTemplateVersionHasAITask(ctx, id)
}
func (q *querier) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) {
// An actor can read template version parameters if they can read the related template.
tv, err := q.db.GetTemplateVersionByID(ctx, templateVersionID)
+14
View File
@@ -1443,6 +1443,20 @@ func (s *MethodTestSuite) TestTemplate() {
})
check.Args(now.Add(-time.Hour)).Asserts(rbac.ResourceTemplate.All(), policy.ActionRead)
}))
s.Run("GetTemplateVersionHasAITask", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
u := dbgen.User(s.T(), db, database.User{})
t := dbgen.Template(s.T(), db, database.Template{
OrganizationID: o.ID,
CreatedBy: u.ID,
})
tv := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{
OrganizationID: o.ID,
TemplateID: uuid.NullUUID{UUID: t.ID, Valid: true},
CreatedBy: u.ID,
})
check.Args(tv.ID).Asserts(t, policy.ActionRead)
}))
s.Run("GetTemplatesWithFilter", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
u := dbgen.User(s.T(), db, database.User{})
@@ -1531,6 +1531,13 @@ func (m queryMetricsStore) GetTemplateVersionByTemplateIDAndName(ctx context.Con
return version, err
}
func (m queryMetricsStore) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) {
start := time.Now()
r0, r1 := m.s.GetTemplateVersionHasAITask(ctx, id)
m.queryLatencies.WithLabelValues("GetTemplateVersionHasAITask").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) {
start := time.Now()
parameters, err := m.s.GetTemplateVersionParameters(ctx, templateVersionID)
+15
View File
@@ -3256,6 +3256,21 @@ func (mr *MockStoreMockRecorder) GetTemplateVersionByTemplateIDAndName(ctx, arg
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemplateVersionByTemplateIDAndName", reflect.TypeOf((*MockStore)(nil).GetTemplateVersionByTemplateIDAndName), ctx, arg)
}
// GetTemplateVersionHasAITask mocks base method.
func (m *MockStore) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTemplateVersionHasAITask", ctx, id)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTemplateVersionHasAITask indicates an expected call of GetTemplateVersionHasAITask.
func (mr *MockStoreMockRecorder) GetTemplateVersionHasAITask(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemplateVersionHasAITask", reflect.TypeOf((*MockStore)(nil).GetTemplateVersionHasAITask), ctx, id)
}
// GetTemplateVersionParameters mocks base method.
func (m *MockStore) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) {
m.ctrl.T.Helper()
+1
View File
@@ -354,6 +354,7 @@ type sqlcQuerier interface {
GetTemplateVersionByID(ctx context.Context, id uuid.UUID) (TemplateVersion, error)
GetTemplateVersionByJobID(ctx context.Context, jobID uuid.UUID) (TemplateVersion, error)
GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg GetTemplateVersionByTemplateIDAndNameParams) (TemplateVersion, error)
GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error)
GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]TemplateVersionParameter, error)
GetTemplateVersionTerraformValues(ctx context.Context, templateVersionID uuid.UUID) (TemplateVersionTerraformValue, error)
GetTemplateVersionVariables(ctx context.Context, templateVersionID uuid.UUID) ([]TemplateVersionVariable, error)
+15
View File
@@ -12870,6 +12870,21 @@ func (q *sqlQuerier) GetTemplateVersionByTemplateIDAndName(ctx context.Context,
return i, err
}
const getTemplateVersionHasAITask = `-- name: GetTemplateVersionHasAITask :one
SELECT EXISTS (
SELECT 1
FROM template_versions
WHERE id = $1 AND has_ai_task = TRUE
)
`
func (q *sqlQuerier) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) {
row := q.db.QueryRowContext(ctx, getTemplateVersionHasAITask, id)
var exists bool
err := row.Scan(&exists)
return exists, err
}
const getTemplateVersionsByIDs = `-- name: GetTemplateVersionsByIDs :many
SELECT
id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, has_ai_task, created_by_avatar_url, created_by_username, created_by_name
@@ -234,3 +234,10 @@ FROM
WHERE
template_versions.id IN (archived_versions.id)
RETURNING template_versions.id;
-- name: GetTemplateVersionHasAITask :one
SELECT EXISTS (
SELECT 1
FROM template_versions
WHERE id = $1 AND has_ai_task = TRUE
);
+27
View File
@@ -3,6 +3,7 @@ package codersdk
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
@@ -44,3 +45,29 @@ func (c *ExperimentalClient) AITaskPrompts(ctx context.Context, buildIDs []uuid.
var prompts AITasksPromptsResponse
return prompts, json.NewDecoder(res.Body).Decode(&prompts)
}
type CreateTaskRequest struct {
Name string `json:"name"`
TemplateVersionID uuid.UUID `json:"template_version_id" format:"uuid"`
TemplateVersionPresetID uuid.UUID `json:"template_version_preset_id,omitempty" format:"uuid"`
Prompt string `json:"prompt"`
}
func (c *ExperimentalClient) CreateTask(ctx context.Context, user string, request CreateTaskRequest) (Workspace, error) {
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/tasks/%s", user), request)
if err != nil {
return Workspace{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return Workspace{}, ReadBodyAsError(res)
}
var workspace Workspace
if err := json.NewDecoder(res.Body).Decode(&workspace); err != nil {
return Workspace{}, err
}
return workspace, nil
}
+12
View File
@@ -2665,6 +2665,18 @@ class ExperimentalApiMethods {
return response.data;
};
createTask = async (
user: string,
req: TypesGen.CreateTaskRequest,
): Promise<TypesGen.Workspace> => {
const response = await this.axios.post<TypesGen.Workspace>(
`/api/experimental/tasks/${user}`,
req,
);
return response.data;
};
}
// This is a hard coded CSRF token/cookie pair for local development. In prod,
+8
View File
@@ -476,6 +476,14 @@ export interface CreateProvisionerKeyResponse {
readonly key: string;
}
// From codersdk/aitasks.go
export interface CreateTaskRequest {
readonly name: string;
readonly template_version_id: string;
readonly template_version_preset_id?: string;
readonly prompt: string;
}
// From codersdk/organizations.go
export interface CreateTemplateRequest {
readonly name: string;
+2 -4
View File
@@ -741,13 +741,11 @@ export const data = {
}
}
const workspace = await API.createWorkspace(userId, {
const workspace = await API.experimental.createTask(userId, {
name: `task-${generateWorkspaceName()}`,
template_version_id: templateVersionId,
template_version_preset_id: preset_id || undefined,
rich_parameter_values: [
{ name: AI_PROMPT_PARAMETER_NAME, value: prompt },
],
prompt,
});
return {