diff --git a/coderd/aitasks.go b/coderd/aitasks.go index 1d06daeae9..8601b50196 100644 --- a/coderd/aitasks.go +++ b/coderd/aitasks.go @@ -241,6 +241,7 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) { // Create task record in the database before creating the workspace so that // we can request that the workspace be linked to it after creation. dbTaskTable, err = tx.InsertTask(ctx, database.InsertTaskParams{ + ID: uuid.New(), OrganizationID: templateVersion.OrganizationID, OwnerID: owner.ID, Name: taskName, diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 12c1a28d98..b812be6e16 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -41,6 +41,7 @@ type WorkspaceResponse struct { Build database.WorkspaceBuild AgentToken string TemplateVersionResponse + Task database.Task } // WorkspaceBuildBuilder generates workspace builds and associated @@ -57,6 +58,7 @@ type WorkspaceBuildBuilder struct { agentToken string jobStatus database.ProvisionerJobStatus taskAppID uuid.UUID + taskSeed database.TaskTable } // WorkspaceBuild generates a workspace build for the provided workspace. @@ -115,25 +117,28 @@ func (b WorkspaceBuildBuilder) WithAgent(mutations ...func([]*sdkproto.Agent) [] return b } -func (b WorkspaceBuildBuilder) WithTask(seed *sdkproto.App) WorkspaceBuildBuilder { - if seed == nil { - seed = &sdkproto.App{} +func (b WorkspaceBuildBuilder) WithTask(taskSeed database.TaskTable, appSeed *sdkproto.App) WorkspaceBuildBuilder { + //nolint:revive // returns modified struct + b.taskSeed = taskSeed + + if appSeed == nil { + appSeed = &sdkproto.App{} } var err error //nolint: revive // returns modified struct - b.taskAppID, err = uuid.Parse(takeFirst(seed.Id, uuid.NewString())) + b.taskAppID, err = uuid.Parse(takeFirst(appSeed.Id, uuid.NewString())) require.NoError(b.t, err) return b.Params(database.WorkspaceBuildParameter{ Name: codersdk.AITaskPromptParameterName, - Value: "list me", + Value: b.taskSeed.Prompt, }).WithAgent(func(a []*sdkproto.Agent) []*sdkproto.Agent { a[0].Apps = []*sdkproto.App{ { Id: b.taskAppID.String(), - Slug: takeFirst(seed.Slug, "task-app"), - Url: takeFirst(seed.Url, ""), + Slug: takeFirst(appSeed.Slug, "task-app"), + Url: takeFirst(appSeed.Url, ""), }, } return a @@ -225,6 +230,37 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse { b.seed.WorkspaceID = b.ws.ID b.seed.InitiatorID = takeFirst(b.seed.InitiatorID, b.ws.OwnerID) + // If a task was requested, ensure it exists and is associated with this + // workspace. + if b.taskAppID != uuid.Nil { + b.logger.Debug(context.Background(), "creating or updating task", "task_id", b.taskSeed.ID) + b.taskSeed.OrganizationID = takeFirst(b.taskSeed.OrganizationID, b.ws.OrganizationID) + b.taskSeed.OwnerID = takeFirst(b.taskSeed.OwnerID, b.ws.OwnerID) + b.taskSeed.Name = takeFirst(b.taskSeed.Name, b.ws.Name) + b.taskSeed.WorkspaceID = uuid.NullUUID{UUID: takeFirst(b.taskSeed.WorkspaceID.UUID, b.ws.ID), Valid: true} + b.taskSeed.TemplateVersionID = takeFirst(b.taskSeed.TemplateVersionID, b.seed.TemplateVersionID) + + // Try to fetch existing task and update its workspace ID. + if task, err := b.db.GetTaskByID(ownerCtx, b.taskSeed.ID); err == nil { + if !task.WorkspaceID.Valid { + b.logger.Info(context.Background(), "updating task workspace id", "task_id", b.taskSeed.ID, "workspace_id", b.ws.ID) + _, err = b.db.UpdateTaskWorkspaceID(ownerCtx, database.UpdateTaskWorkspaceIDParams{ + ID: b.taskSeed.ID, + WorkspaceID: uuid.NullUUID{UUID: b.ws.ID, Valid: true}, + }) + require.NoError(b.t, err, "update task workspace id") + } else if task.WorkspaceID.UUID != b.ws.ID { + require.Fail(b.t, "task already has a workspace id, mismatch", task.WorkspaceID.UUID, b.ws.ID) + } + } else if errors.Is(err, sql.ErrNoRows) { + task := dbgen.Task(b.t, b.db, b.taskSeed) + b.taskSeed.ID = task.ID + b.logger.Info(context.Background(), "created new task", "task_id", b.taskSeed.ID) + } else { + require.NoError(b.t, err, "get task by id") + } + } + // Create a provisioner job for the build! payload, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{ WorkspaceBuildID: b.seed.ID, @@ -337,6 +373,11 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse { b.logger.Debug(context.Background(), "linked task to workspace build", slog.F("task_id", task.ID), slog.F("build_number", resp.Build.BuildNumber)) + + // Update task after linking. + task, err = b.db.GetTaskByID(ownerCtx, task.ID) + require.NoError(b.t, err, "get task by id") + resp.Task = task } for i := range b.params { diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 0a62911223..532460700a 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -1576,6 +1576,7 @@ func Task(t testing.TB, db database.Store, orig database.TaskTable) database.Tas } task, err := db.InsertTask(genCtx, database.InsertTaskParams{ + ID: takeFirst(orig.ID, uuid.New()), OrganizationID: orig.OrganizationID, OwnerID: orig.OwnerID, Name: takeFirst(orig.Name, taskname.GenerateFallback()), diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index e1b6cbd7ad..773f944756 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -7248,7 +7248,9 @@ func TestTaskNameUniqueness(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) + taskID := uuid.New() task, err := db.InsertTask(ctx, database.InsertTaskParams{ + ID: taskID, OrganizationID: org.ID, OwnerID: tt.ownerID, Name: tt.taskName, @@ -7263,6 +7265,7 @@ func TestTaskNameUniqueness(t *testing.T) { require.NoError(t, err) require.NotEqual(t, uuid.Nil, task.ID) require.NotEqual(t, task1.ID, task.ID) + require.Equal(t, taskID, task.ID) } }) } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index e6dfa9afd0..8e24ab04c9 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -13019,11 +13019,12 @@ const insertTask = `-- name: InsertTask :one INSERT INTO tasks (id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at) VALUES - (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8) + ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at ` type InsertTaskParams struct { + ID uuid.UUID `db:"id" json:"id"` OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` Name string `db:"name" json:"name"` @@ -13036,6 +13037,7 @@ type InsertTaskParams struct { func (q *sqlQuerier) InsertTask(ctx context.Context, arg InsertTaskParams) (TaskTable, error) { row := q.db.QueryRowContext(ctx, insertTask, + arg.ID, arg.OrganizationID, arg.OwnerID, arg.Name, diff --git a/coderd/database/queries/tasks.sql b/coderd/database/queries/tasks.sql index 6c076b8cca..d0617ad39f 100644 --- a/coderd/database/queries/tasks.sql +++ b/coderd/database/queries/tasks.sql @@ -2,7 +2,7 @@ INSERT INTO tasks (id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at) VALUES - (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8) + ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *; -- name: UpdateTaskWorkspaceID :one diff --git a/codersdk/toolsdk/toolsdk_test.go b/codersdk/toolsdk/toolsdk_test.go index 44da500400..749105f5b8 100644 --- a/codersdk/toolsdk/toolsdk_test.go +++ b/codersdk/toolsdk/toolsdk_test.go @@ -895,37 +895,27 @@ func TestTools(t *testing.T) { }, }).Do() - ws1Table := dbgen.Workspace(t, store, database.WorkspaceTable{ + build1 := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{ Name: "delete-task-workspace-1", OrganizationID: owner.OrganizationID, OwnerID: member.ID, TemplateID: aiTV.Template.ID, - }) - task1 := dbgen.Task(t, store, database.TaskTable{ - OrganizationID: owner.OrganizationID, - OwnerID: member.ID, - Name: ws1Table.Name, - WorkspaceID: uuid.NullUUID{UUID: ws1Table.ID, Valid: true}, - TemplateVersionID: aiTV.TemplateVersion.ID, - Prompt: "delete task 1", - }) - _ = dbfake.WorkspaceBuild(t, store, ws1Table).WithTask(nil).Do() + }).WithTask(database.TaskTable{ + Name: "delete-task-1", + Prompt: "delete task 1", + }, nil).Do() + task1 := build1.Task - ws2Table := dbgen.Workspace(t, store, database.WorkspaceTable{ + build2 := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{ Name: "delete-task-workspace-2", OrganizationID: owner.OrganizationID, OwnerID: member.ID, TemplateID: aiTV.Template.ID, - }) - task2 := dbgen.Task(t, store, database.TaskTable{ - OrganizationID: owner.OrganizationID, - OwnerID: member.ID, - Name: ws2Table.Name, - WorkspaceID: uuid.NullUUID{UUID: ws2Table.ID, Valid: true}, - TemplateVersionID: aiTV.TemplateVersion.ID, - Prompt: "delete task 2", - }) - _ = dbfake.WorkspaceBuild(t, store, ws2Table).WithTask(nil).Do() + }).WithTask(database.TaskTable{ + Name: "delete-task-2", + Prompt: "delete task 2", + }, nil).Do() + task2 := build2.Task tests := []struct { name string @@ -1113,21 +1103,16 @@ func TestTools(t *testing.T) { }, }).Do() - ws1Table := dbgen.Workspace(t, store, database.WorkspaceTable{ + build := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{ Name: "get-task-workspace-1", OrganizationID: owner.OrganizationID, OwnerID: member.ID, TemplateID: aiTV.Template.ID, - }) - task := dbgen.Task(t, store, database.TaskTable{ - OrganizationID: owner.OrganizationID, - OwnerID: member.ID, - Name: "get-task-1", - WorkspaceID: uuid.NullUUID{UUID: ws1Table.ID, Valid: true}, - TemplateVersionID: aiTV.TemplateVersion.ID, - Prompt: "get task", - }) - _ = dbfake.WorkspaceBuild(t, store, ws1Table).WithTask(nil).Do() + }).WithTask(database.TaskTable{ + Name: "get-task-1", + Prompt: "get task", + }, nil).Do() + task := build.Task tests := []struct { name string @@ -1376,21 +1361,16 @@ func TestTools(t *testing.T) { }, }).Do() - wsTable := dbgen.Workspace(t, store, database.WorkspaceTable{ + ws := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{ Name: "send-task-input-ws", OrganizationID: owner.OrganizationID, OwnerID: member.ID, TemplateID: aiTV.Template.ID, - }) - task := dbgen.Task(t, store, database.TaskTable{ - OrganizationID: owner.OrganizationID, - OwnerID: member.ID, - Name: "send-task-input", - WorkspaceID: uuid.NullUUID{UUID: wsTable.ID, Valid: true}, - TemplateVersionID: aiTV.TemplateVersion.ID, - Prompt: "send task input", - }) - ws := dbfake.WorkspaceBuild(t, store, wsTable).WithTask(&proto.App{Url: srv.URL}).Do() + }).WithTask(database.TaskTable{ + Name: "send-task-input", + Prompt: "send task input", + }, &proto.App{Url: srv.URL}).Do() + task := ws.Task _ = agenttest.New(t, client.URL, ws.AgentToken) coderdtest.NewWorkspaceAgentWaiter(t, client, ws.Workspace.ID).Wait() @@ -1513,21 +1493,16 @@ func TestTools(t *testing.T) { }, }).Do() - wsTable := dbgen.Workspace(t, store, database.WorkspaceTable{ + ws := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{ Name: "get-task-logs-ws", OrganizationID: owner.OrganizationID, OwnerID: member.ID, TemplateID: aiTV.Template.ID, - }) - task := dbgen.Task(t, store, database.TaskTable{ - OrganizationID: owner.OrganizationID, - OwnerID: member.ID, - Name: "get-task-logs", - WorkspaceID: uuid.NullUUID{UUID: wsTable.ID, Valid: true}, - TemplateVersionID: aiTV.TemplateVersion.ID, - Prompt: "get task logs", - }) - ws := dbfake.WorkspaceBuild(t, store, wsTable).WithTask(&proto.App{Url: srv.URL}).Do() + }).WithTask(database.TaskTable{ + Name: "get-task-logs", + Prompt: "get task logs", + }, &proto.App{Url: srv.URL}).Do() + task := ws.Task _ = agenttest.New(t, client.URL, ws.AgentToken) coderdtest.NewWorkspaceAgentWaiter(t, client, ws.Workspace.ID).Wait()