diff --git a/cli/exp_task_delete_test.go b/cli/exp_task_delete_test.go index e90ee8c5b1..04bad3ea5f 100644 --- a/cli/exp_task_delete_test.go +++ b/cli/exp_task_delete_test.go @@ -56,19 +56,14 @@ func TestExpTaskDelete(t *testing.T) { taskID := uuid.MustParse(id1) return func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"": + case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks/me/exists": c.nameResolves.Add(1) - httpapi.Write(r.Context(), w, http.StatusOK, struct { - Tasks []codersdk.Task `json:"tasks"` - Count int `json:"count"` - }{ - Tasks: []codersdk.Task{{ + httpapi.Write(r.Context(), w, http.StatusOK, + codersdk.Task{ ID: taskID, Name: "exists", OwnerName: "me", - }}, - Count: 1, - }) + }) case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id1: c.deleteCalls.Add(1) w.WriteHeader(http.StatusAccepted) @@ -107,27 +102,21 @@ func TestExpTaskDelete(t *testing.T) { name: "Multiple_YesFlag", args: []string{"--yes", "first", id4}, buildHandler: func(c *testCounters) http.HandlerFunc { - firstID := uuid.MustParse(id3) return func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"": + case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks/me/first": c.nameResolves.Add(1) - httpapi.Write(r.Context(), w, http.StatusOK, struct { - Tasks []codersdk.Task `json:"tasks"` - Count int `json:"count"` - }{ - Tasks: []codersdk.Task{{ - ID: firstID, - Name: "first", - OwnerName: "me", - }}, - Count: 1, + httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Task{ + ID: uuid.MustParse(id3), + Name: "first", + OwnerName: "me", }) case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks/me/"+id4: + c.nameResolves.Add(1) httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Task{ ID: uuid.MustParse(id4), OwnerName: "me", - Name: "uuid-task-2", + Name: "uuid-task-4", }) case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id3: c.deleteCalls.Add(1) @@ -141,7 +130,7 @@ func TestExpTaskDelete(t *testing.T) { } }, wantDeleteCalls: 2, - wantNameResolves: 1, + wantNameResolves: 2, wantDeletedMessage: 2, }, { @@ -174,20 +163,14 @@ func TestExpTaskDelete(t *testing.T) { taskID := uuid.MustParse(id5) return func(w http.ResponseWriter, r *http.Request) { switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"": + case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks/me/bad": c.nameResolves.Add(1) - httpapi.Write(r.Context(), w, http.StatusOK, struct { - Tasks []codersdk.Task `json:"tasks"` - Count int `json:"count"` - }{ - Tasks: []codersdk.Task{{ - ID: taskID, - Name: "bad", - OwnerName: "me", - }}, - Count: 1, + httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Task{ + ID: taskID, + Name: "bad", + OwnerName: "me", }) - case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id5: + case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/bad": httpapi.InternalServerError(w, xerrors.New("boom")) default: httpapi.InternalServerError(w, xerrors.New("unwanted path: "+r.Method+" "+r.URL.Path)) diff --git a/cli/exp_task_status_test.go b/cli/exp_task_status_test.go index f15222d51b..c3e19f94f7 100644 --- a/cli/exp_task_status_test.go +++ b/cli/exp_task_status_test.go @@ -36,17 +36,9 @@ func Test_TaskStatus(t *testing.T) { hf: func(ctx context.Context, _ time.Time) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { - case "/api/experimental/tasks": - if r.URL.Query().Get("q") == "owner:\"me\"" { - httpapi.Write(ctx, w, http.StatusOK, struct { - Tasks []codersdk.Task `json:"tasks"` - Count int `json:"count"` - }{ - Tasks: []codersdk.Task{}, - Count: 0, - }) - return - } + case "/api/experimental/tasks/me/doesnotexist": + httpapi.ResourceNotFound(w) + return default: t.Errorf("unexpected path: %s", r.URL.Path) } @@ -60,35 +52,7 @@ func Test_TaskStatus(t *testing.T) { hf: func(ctx context.Context, now time.Time) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { - case "/api/experimental/tasks": - if r.URL.Query().Get("q") == "owner:\"me\"" { - httpapi.Write(ctx, w, http.StatusOK, struct { - Tasks []codersdk.Task `json:"tasks"` - Count int `json:"count"` - }{ - Tasks: []codersdk.Task{{ - ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"), - Name: "exists", - OwnerName: "me", - WorkspaceStatus: codersdk.WorkspaceStatusRunning, - CreatedAt: now, - UpdatedAt: now, - CurrentState: &codersdk.TaskStateEntry{ - State: codersdk.TaskStateWorking, - Timestamp: now, - Message: "Thinking furiously...", - }, - WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{ - Healthy: true, - }, - WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady), - Status: codersdk.TaskStatusActive, - }}, - Count: 1, - }) - return - } - case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111": + case "/api/experimental/tasks/me/exists": httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{ ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"), WorkspaceStatus: codersdk.WorkspaceStatusRunning, @@ -124,30 +88,21 @@ func Test_TaskStatus(t *testing.T) { var calls atomic.Int64 return func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { - case "/api/experimental/tasks": - if r.URL.Query().Get("q") == "owner:\"me\"" { - // Return initial task state for --watch test - httpapi.Write(ctx, w, http.StatusOK, struct { - Tasks []codersdk.Task `json:"tasks"` - Count int `json:"count"` - }{ - Tasks: []codersdk.Task{{ - ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"), - Name: "exists", - OwnerName: "me", - WorkspaceStatus: codersdk.WorkspaceStatusPending, - CreatedAt: now.Add(-5 * time.Second), - UpdatedAt: now.Add(-5 * time.Second), - WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{ - Healthy: true, - }, - WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady), - Status: codersdk.TaskStatusPending, - }}, - Count: 1, - }) - return - } + case "/api/experimental/tasks/me/exists": + httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{ + ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"), + Name: "exists", + OwnerName: "me", + WorkspaceStatus: codersdk.WorkspaceStatusPending, + CreatedAt: now.Add(-5 * time.Second), + UpdatedAt: now.Add(-5 * time.Second), + WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{ + Healthy: true, + }, + WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady), + Status: codersdk.TaskStatusPending, + }) + return case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111": defer calls.Add(1) switch calls.Load() { @@ -263,40 +218,18 @@ func Test_TaskStatus(t *testing.T) { ts := time.Date(2025, 8, 26, 12, 34, 56, 0, time.UTC) return func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { - case "/api/experimental/tasks": - if r.URL.Query().Get("q") == "owner:\"me\"" { - httpapi.Write(ctx, w, http.StatusOK, struct { - Tasks []codersdk.Task `json:"tasks"` - Count int `json:"count"` - }{ - Tasks: []codersdk.Task{{ - ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"), - Name: "exists", - OwnerName: "me", - WorkspaceStatus: codersdk.WorkspaceStatusRunning, - CreatedAt: ts, - UpdatedAt: ts, - CurrentState: &codersdk.TaskStateEntry{ - State: codersdk.TaskStateWorking, - Timestamp: ts.Add(time.Second), - Message: "Thinking furiously...", - }, - WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{ - Healthy: true, - }, - WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady), - Status: codersdk.TaskStatusActive, - }}, - Count: 1, - }) - return - } - case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111": + case "/api/experimental/tasks/me/exists": httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{ - ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"), - WorkspaceStatus: codersdk.WorkspaceStatusRunning, - CreatedAt: ts, - UpdatedAt: ts, + ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"), + Name: "exists", + OwnerName: "me", + WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{ + Healthy: true, + }, + WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady), + WorkspaceStatus: codersdk.WorkspaceStatusRunning, + CreatedAt: ts, + UpdatedAt: ts, CurrentState: &codersdk.TaskStateEntry{ State: codersdk.TaskStateWorking, Timestamp: ts.Add(time.Second), diff --git a/coderd/aitasks_test.go b/coderd/aitasks_test.go index 34f6dd4a07..0151d77c19 100644 --- a/coderd/aitasks_test.go +++ b/coderd/aitasks_test.go @@ -156,12 +156,13 @@ func TestTasks(t *testing.T) { t.Parallel() var ( - client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) - ctx = testutil.Context(t, testutil.WaitLong) - user = coderdtest.CreateFirstUser(t, client) - template = createAITemplate(t, client, user) - wantPrompt = "review my code" - exp = codersdk.NewExperimentalClient(client) + client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + ctx = testutil.Context(t, testutil.WaitLong) + user = coderdtest.CreateFirstUser(t, client) + anotherUser, _ = coderdtest.CreateAnotherUser(t, client, user.OrganizationID) + template = createAITemplate(t, client, user) + wantPrompt = "review my code" + exp = codersdk.NewExperimentalClient(client) ) task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{ @@ -211,6 +212,24 @@ func TestTasks(t *testing.T) { assert.Equal(t, taskAppID, updated.WorkspaceAppID.UUID, "workspace app id should match") assert.NotEmpty(t, updated.WorkspaceStatus, "task status should not be empty") + // Fetch the task by name and verify the same result + byName, err := exp.TaskByOwnerAndName(ctx, codersdk.Me, task.Name) + require.NoError(t, err) + require.Equal(t, byName, updated) + + // Another member user should not be able to fetch the task + otherClient := codersdk.NewExperimentalClient(anotherUser) + _, err = otherClient.TaskByID(ctx, task.ID) + require.Error(t, err, "fetching task should fail by ID for another member user") + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + // Also test by name + _, err = otherClient.TaskByOwnerAndName(ctx, task.OwnerName, task.Name) + require.Error(t, err, "fetching task should fail by name for another member user") + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + // Stop the workspace coderdtest.MustTransitionWorkspace(t, client, task.WorkspaceID.UUID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop) @@ -654,7 +673,7 @@ func TestTasks(t *testing.T) { coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID) // Fetch the task by ID via experimental API and verify fields. - task, err = exp.TaskByID(ctx, task.ID) + task, err = exp.TaskByIdentifier(ctx, task.ID.String()) require.NoError(t, err) require.NotZero(t, task.WorkspaceBuildNumber) require.True(t, task.WorkspaceAgentID.Valid) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 8066ebd047..87b5de3600 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2989,6 +2989,10 @@ func (q *querier) GetTaskByID(ctx context.Context, id uuid.UUID) (database.Task, return fetch(q.log, q.auth, q.db.GetTaskByID)(ctx, id) } +func (q *querier) GetTaskByOwnerIDAndName(ctx context.Context, arg database.GetTaskByOwnerIDAndNameParams) (database.Task, error) { + return fetch(q.log, q.auth, q.db.GetTaskByOwnerIDAndName)(ctx, arg) +} + func (q *querier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.Task, error) { return fetch(q.log, q.auth, q.db.GetTaskByWorkspaceID)(ctx, workspaceID) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 32c951fb5c..7d7c136eb5 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -2375,6 +2375,17 @@ func (s *MethodTestSuite) TestTasks() { dbm.EXPECT().GetTaskByID(gomock.Any(), task.ID).Return(task, nil).AnyTimes() check.Args(task.ID).Asserts(task, policy.ActionRead).Returns(task) })) + s.Run("GetTaskByOwnerIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + task := testutil.Fake(s.T(), faker, database.Task{}) + dbm.EXPECT().GetTaskByOwnerIDAndName(gomock.Any(), database.GetTaskByOwnerIDAndNameParams{ + OwnerID: task.OwnerID, + Name: task.Name, + }).Return(task, nil).AnyTimes() + check.Args(database.GetTaskByOwnerIDAndNameParams{ + OwnerID: task.OwnerID, + Name: task.Name, + }).Asserts(task, policy.ActionRead).Returns(task) + })) s.Run("DeleteTask", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { task := testutil.Fake(s.T(), faker, database.Task{}) arg := database.DeleteTaskParams{ diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 252f6f9b5a..d841315924 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1530,6 +1530,13 @@ func (m queryMetricsStore) GetTaskByID(ctx context.Context, id uuid.UUID) (datab return r0, r1 } +func (m queryMetricsStore) GetTaskByOwnerIDAndName(ctx context.Context, arg database.GetTaskByOwnerIDAndNameParams) (database.Task, error) { + start := time.Now() + r0, r1 := m.s.GetTaskByOwnerIDAndName(ctx, arg) + m.queryLatencies.WithLabelValues("GetTaskByOwnerIDAndName").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.Task, error) { start := time.Now() r0, r1 := m.s.GetTaskByWorkspaceID(ctx, workspaceID) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index af89a987a3..313bb98897 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -3237,6 +3237,21 @@ func (mr *MockStoreMockRecorder) GetTaskByID(ctx, id any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskByID", reflect.TypeOf((*MockStore)(nil).GetTaskByID), ctx, id) } +// GetTaskByOwnerIDAndName mocks base method. +func (m *MockStore) GetTaskByOwnerIDAndName(ctx context.Context, arg database.GetTaskByOwnerIDAndNameParams) (database.Task, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTaskByOwnerIDAndName", ctx, arg) + ret0, _ := ret[0].(database.Task) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTaskByOwnerIDAndName indicates an expected call of GetTaskByOwnerIDAndName. +func (mr *MockStoreMockRecorder) GetTaskByOwnerIDAndName(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskByOwnerIDAndName", reflect.TypeOf((*MockStore)(nil).GetTaskByOwnerIDAndName), ctx, arg) +} + // GetTaskByWorkspaceID mocks base method. func (m *MockStore) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.Task, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 2739cb7430..3e5771f96d 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -343,6 +343,7 @@ type sqlcQuerier interface { GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerBindingsRow, error) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerIDsRow, error) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error) + GetTaskByOwnerIDAndName(ctx context.Context, arg GetTaskByOwnerIDAndNameParams) (Task, error) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (Task, error) GetTelemetryItem(ctx context.Context, key string) (TelemetryItem, error) GetTelemetryItems(ctx context.Context) ([]TelemetryItem, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 65fac4733b..5c9a2f499c 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -13093,6 +13093,44 @@ func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error return i, err } +const getTaskByOwnerIDAndName = `-- name: GetTaskByOwnerIDAndName :one +SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status +WHERE + owner_id = $1::uuid + AND deleted_at IS NULL + AND LOWER(name) = LOWER($2::text) +` + +type GetTaskByOwnerIDAndNameParams struct { + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + Name string `db:"name" json:"name"` +} + +func (q *sqlQuerier) GetTaskByOwnerIDAndName(ctx context.Context, arg GetTaskByOwnerIDAndNameParams) (Task, error) { + row := q.db.QueryRowContext(ctx, getTaskByOwnerIDAndName, arg.OwnerID, arg.Name) + var i Task + err := row.Scan( + &i.ID, + &i.OrganizationID, + &i.OwnerID, + &i.Name, + &i.WorkspaceID, + &i.TemplateVersionID, + &i.TemplateParameters, + &i.Prompt, + &i.CreatedAt, + &i.DeletedAt, + &i.Status, + &i.WorkspaceBuildNumber, + &i.WorkspaceAgentID, + &i.WorkspaceAppID, + &i.OwnerUsername, + &i.OwnerName, + &i.OwnerAvatarUrl, + ) + return i, err +} + const getTaskByWorkspaceID = `-- name: GetTaskByWorkspaceID :one SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE workspace_id = $1::uuid ` diff --git a/coderd/database/queries/tasks.sql b/coderd/database/queries/tasks.sql index d0617ad39f..5cbbefd458 100644 --- a/coderd/database/queries/tasks.sql +++ b/coderd/database/queries/tasks.sql @@ -41,6 +41,13 @@ SELECT * FROM tasks_with_status WHERE id = @id::uuid; -- name: GetTaskByWorkspaceID :one SELECT * FROM tasks_with_status WHERE workspace_id = @workspace_id::uuid; +-- name: GetTaskByOwnerIDAndName :one +SELECT * FROM tasks_with_status +WHERE + owner_id = @owner_id::uuid + AND deleted_at IS NULL + AND LOWER(name) = LOWER(@name::text); + -- name: ListTasks :many SELECT * FROM tasks_with_status tws WHERE tws.deleted_at IS NULL diff --git a/coderd/httpmw/taskparam.go b/coderd/httpmw/taskparam.go index 6ecc888b37..1e6051eb03 100644 --- a/coderd/httpmw/taskparam.go +++ b/coderd/httpmw/taskparam.go @@ -2,8 +2,14 @@ package httpmw import ( "context" + "database/sql" + "errors" "net/http" + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "golang.org/x/xerrors" + "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" @@ -23,16 +29,34 @@ func TaskParam(r *http.Request) database.Task { return task } -// ExtractTaskParam grabs a task from the "task" URL parameter by UUID. +// ExtractTaskParam grabs a task from the "task" URL parameter. +// It supports two lookup strategies: +// 1. Task UUID (primary) +// 2. Task name scoped to owner (secondary) +// +// This middleware depends on ExtractOrganizationMembersParam being in the chain +// to provide the owner context for name-based lookups. func ExtractTaskParam(db database.Store) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - taskID, parsed := ParseUUIDParam(rw, r, "task") - if !parsed { + + // Get the task parameter value. We can't use ParseUUIDParam here because + // we need to support non-UUID values (task names) and + // attempt all lookup strategies. + taskParam := chi.URLParam(r, "task") + if taskParam == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "\"task\" must be provided.", + }) return } - task, err := db.GetTaskByID(ctx, taskID) + + // Get owner from OrganizationMembersParam middleware for name-based lookups + members := OrganizationMembersParam(r) + ownerID := members.UserID() + + task, err := fetchTaskWithFallback(ctx, db, taskParam, ownerID) if err != nil { if httpapi.Is404Error(err) { httpapi.ResourceNotFound(rw) @@ -48,10 +72,38 @@ func ExtractTaskParam(db database.Store) func(http.Handler) http.Handler { ctx = context.WithValue(ctx, taskParamContextKey{}, task) if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil { - rlogger.WithFields(slog.F("task_id", task.ID), slog.F("task_name", task.Name)) + rlogger.WithFields( + slog.F("task_id", task.ID), + slog.F("task_name", task.Name), + ) } next.ServeHTTP(rw, r.WithContext(ctx)) }) } } + +func fetchTaskWithFallback(ctx context.Context, db database.Store, taskParam string, ownerID uuid.UUID) (database.Task, error) { + // Attempt to first lookup the task by UUID. + taskID, err := uuid.Parse(taskParam) + if err == nil { + task, err := db.GetTaskByID(ctx, taskID) + if err == nil { + return task, nil + } + // There may be a task named with a valid UUID. Fall back to name lookup in this case. + if !errors.Is(err, sql.ErrNoRows) { + return database.Task{}, xerrors.Errorf("fetch task by uuid: %w", err) + } + } + + // taskParam not a valid UUID, OR valid UUID but not found, so attempt lookup by name. + task, err := db.GetTaskByOwnerIDAndName(ctx, database.GetTaskByOwnerIDAndNameParams{ + OwnerID: ownerID, + Name: taskParam, + }) + if err != nil { + return database.Task{}, xerrors.Errorf("fetch task by name: %w", err) + } + return task, nil +} diff --git a/coderd/httpmw/taskparam_test.go b/coderd/httpmw/taskparam_test.go index 559ccc2a2d..7430785f33 100644 --- a/coderd/httpmw/taskparam_test.go +++ b/coderd/httpmw/taskparam_test.go @@ -4,35 +4,119 @@ import ( "context" "net/http" "net/http/httptest" + "strings" "testing" "github.com/go-chi/chi/v5" "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/codersdk" ) func TestTaskParam(t *testing.T) { t.Parallel() - setup := func(db database.Store) (*http.Request, database.User) { - user := dbgen.User(t, db, database.User{}) - _, token := dbgen.APIKey(t, db, database.APIKey{ - UserID: user.ID, - }) + // Create all fixtures once - they're only read, never modified + db, _ := dbtestutil.NewDB(t) + user := dbgen.User(t, db, database.User{}) + _, token := dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + }) + org := dbgen.Organization(t, db, database.Organization{}) + tpl := dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + CreatedBy: user.ID, + }) + tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{ + UUID: tpl.ID, + Valid: true, + }, + OrganizationID: org.ID, + CreatedBy: user.ID, + }) + workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + TemplateID: tpl.ID, + }) + task := dbgen.Task(t, db, database.TaskTable{ + OrganizationID: org.ID, + OwnerID: user.ID, + TemplateVersionID: tv.ID, + WorkspaceID: uuid.NullUUID{UUID: workspace.ID, Valid: true}, + Prompt: "test prompt", + }) + workspaceNoTask := dbgen.Workspace(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + TemplateID: tpl.ID, + }) + taskFoundByUUID := dbgen.Task(t, db, database.TaskTable{ + Name: "found-by-uuid", + OrganizationID: org.ID, + OwnerID: user.ID, + TemplateVersionID: tv.ID, + WorkspaceID: uuid.NullUUID{UUID: workspace.ID, Valid: true}, + Prompt: "test prompt", + }) + // To test precedence of UUID over name, we create another task with the same name as the UUID task + _ = dbgen.Task(t, db, database.TaskTable{ + Name: taskFoundByUUID.ID.String(), + OrganizationID: org.ID, + OwnerID: user.ID, + TemplateVersionID: tv.ID, + WorkspaceID: uuid.NullUUID{UUID: workspace.ID, Valid: true}, + Prompt: "test prompt", + }) + workspaceSharedName := dbgen.Workspace(t, db, database.WorkspaceTable{ + Name: "shared-name", + OwnerID: user.ID, + OrganizationID: org.ID, + TemplateID: tpl.ID, + }) + // We create a task with the same name as the workspace shared name. + _ = dbgen.Task(t, db, database.TaskTable{ + Name: "task-different-name", + OrganizationID: org.ID, + OwnerID: user.ID, + TemplateVersionID: tv.ID, + WorkspaceID: uuid.NullUUID{UUID: workspaceSharedName.ID, Valid: true}, + Prompt: "test prompt", + }) + makeRequest := func(userID uuid.UUID, sessionToken string) *http.Request { r := httptest.NewRequest("GET", "/", nil) - r.Header.Set(codersdk.SessionTokenHeader, token) + r.Header.Set(codersdk.SessionTokenHeader, sessionToken) ctx := chi.NewRouteContext() - ctx.URLParams.Add("user", "me") + ctx.URLParams.Add("user", userID.String()) r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx)) - return r, user + return r + } + + makeRouter := func(handler http.HandlerFunc) chi.Router { + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + RedirectToLogin: false, + }), + httpmw.ExtractOrganizationMembersParam(db, func(r *http.Request, _ policy.Action, _ rbac.Objecter) bool { + return true + }), + httpmw.ExtractTaskParam(db), + ) + rtr.Get("/", handler) + return rtr } t.Run("None", func(t *testing.T) { @@ -40,8 +124,11 @@ func TestTaskParam(t *testing.T) { db, _ := dbtestutil.NewDB(t) rtr := chi.NewRouter() rtr.Use(httpmw.ExtractTaskParam(db)) - rtr.Get("/", nil) - r, _ := setup(db) + rtr.Get("/", func(w http.ResponseWriter, r *http.Request) { + assert.Fail(t, "this should never get called") + }) + r := httptest.NewRequest("GET", "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext())) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) @@ -52,11 +139,10 @@ func TestTaskParam(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - db, _ := dbtestutil.NewDB(t) - rtr := chi.NewRouter() - rtr.Use(httpmw.ExtractTaskParam(db)) - rtr.Get("/", nil) - r, _ := setup(db) + rtr := makeRouter(func(w http.ResponseWriter, r *http.Request) { + assert.Fail(t, "this should never get called") + }) + r := makeRequest(user.ID, token) chi.RouteContext(r.Context()).URLParams.Add("task", uuid.NewString()) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) @@ -68,47 +154,11 @@ func TestTaskParam(t *testing.T) { t.Run("Found", func(t *testing.T) { t.Parallel() - db, _ := dbtestutil.NewDB(t) - rtr := chi.NewRouter() - rtr.Use( - httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ - DB: db, - RedirectToLogin: false, - }), - httpmw.ExtractTaskParam(db), - ) - rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { - _ = httpmw.TaskParam(r) - rw.WriteHeader(http.StatusOK) - }) - r, user := setup(db) - org := dbgen.Organization(t, db, database.Organization{}) - tpl := dbgen.Template(t, db, database.Template{ - OrganizationID: org.ID, - CreatedBy: user.ID, - }) - tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{ - UUID: tpl.ID, - Valid: true, - }, - OrganizationID: org.ID, - CreatedBy: user.ID, - }) - workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ - OwnerID: user.ID, - Name: "test-workspace", - OrganizationID: org.ID, - TemplateID: tpl.ID, - }) - task := dbgen.Task(t, db, database.TaskTable{ - Name: "test-task", - OrganizationID: org.ID, - OwnerID: user.ID, - TemplateVersionID: tv.ID, - WorkspaceID: uuid.NullUUID{UUID: workspace.ID, Valid: true}, - Prompt: "test prompt", + rtr := makeRouter(func(w http.ResponseWriter, r *http.Request) { + foundTask := httpmw.TaskParam(r) + assert.Equal(t, task.ID.String(), foundTask.ID.String()) }) + r := makeRequest(user.ID, token) chi.RouteContext(r.Context()).URLParams.Add("task", task.ID.String()) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) @@ -117,4 +167,100 @@ func TestTaskParam(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) }) + + t.Run("FoundByTaskName", func(t *testing.T) { + t.Parallel() + rtr := makeRouter(func(w http.ResponseWriter, r *http.Request) { + foundTask := httpmw.TaskParam(r) + assert.Equal(t, task.ID.String(), foundTask.ID.String()) + }) + r := makeRequest(user.ID, token) + chi.RouteContext(r.Context()).URLParams.Add("task", task.Name) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) + + t.Run("NotFoundByWorkspaceName", func(t *testing.T) { + t.Parallel() + rtr := makeRouter(func(w http.ResponseWriter, r *http.Request) { + assert.Fail(t, "this should never get called") + }) + r := makeRequest(user.ID, token) + chi.RouteContext(r.Context()).URLParams.Add("task", workspace.Name) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusNotFound, res.StatusCode) + }) + + t.Run("CaseInsensitiveTaskName", func(t *testing.T) { + t.Parallel() + rtr := makeRouter(func(w http.ResponseWriter, r *http.Request) { + foundTask := httpmw.TaskParam(r) + assert.Equal(t, task.ID.String(), foundTask.ID.String()) + }) + r := makeRequest(user.ID, token) + // Look up with different case + chi.RouteContext(r.Context()).URLParams.Add("task", strings.ToUpper(task.Name)) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) + + t.Run("UUIDTakesPrecedence", func(t *testing.T) { + t.Parallel() + rtr := makeRouter(func(w http.ResponseWriter, r *http.Request) { + foundTask := httpmw.TaskParam(r) + assert.Equal(t, taskFoundByUUID.ID.String(), foundTask.ID.String()) + }) + r := makeRequest(user.ID, token) + // Look up by UUID - should find the first task, not the one named with the UUID + chi.RouteContext(r.Context()).URLParams.Add("task", taskFoundByUUID.ID.String()) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) + + t.Run("NotFoundWhenNoMatch", func(t *testing.T) { + t.Parallel() + rtr := makeRouter(func(w http.ResponseWriter, r *http.Request) { + assert.Fail(t, "this should never get called") + }) + r := makeRequest(user.ID, token) + chi.RouteContext(r.Context()).URLParams.Add("task", "nonexistent-name") + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusNotFound, res.StatusCode) + }) + + t.Run("WorkspaceWithoutTask", func(t *testing.T) { + t.Parallel() + rtr := makeRouter(func(w http.ResponseWriter, r *http.Request) { + assert.Fail(t, "this should never get called") + }) + r := makeRequest(user.ID, token) + // Look up by workspace name, but workspace has no task + chi.RouteContext(r.Context()).URLParams.Add("task", workspaceNoTask.Name) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusNotFound, res.StatusCode) + }) } diff --git a/codersdk/aitasks.go b/codersdk/aitasks.go index 1f1e9758e9..db8db8abca 100644 --- a/codersdk/aitasks.go +++ b/codersdk/aitasks.go @@ -231,6 +231,7 @@ func (c *ExperimentalClient) Tasks(ctx context.Context, filter *TasksFilter) ([] } // TaskByID fetches a single experimental task by its ID. +// Only tasks owned by codersdk.Me are supported. // // Experimental: This method is experimental and may change in the future. func (c *ExperimentalClient) TaskByID(ctx context.Context, id uuid.UUID) (Task, error) { @@ -251,6 +252,30 @@ func (c *ExperimentalClient) TaskByID(ctx context.Context, id uuid.UUID) (Task, return task, nil } +// TaskByOwnerAndName fetches a single experimental task by its owner and name. +// +// Experimental: This method is experimental and may change in the future. +func (c *ExperimentalClient) TaskByOwnerAndName(ctx context.Context, owner, ident string) (Task, error) { + if owner == "" { + owner = Me + } + res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/tasks/%s/%s", owner, ident), nil) + if err != nil { + return Task{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return Task{}, ReadBodyAsError(res) + } + + var task Task + if err := json.NewDecoder(res.Body).Decode(&task); err != nil { + return Task{}, err + } + + return task, nil +} + func splitTaskIdentifier(identifier string) (owner string, taskName string, err error) { parts := strings.Split(identifier, "/") @@ -287,34 +312,7 @@ func (c *ExperimentalClient) TaskByIdentifier(ctx context.Context, identifier st return Task{}, err } - tasks, err := c.Tasks(ctx, &TasksFilter{ - Owner: owner, - }) - if err != nil { - return Task{}, xerrors.Errorf("list tasks for owner %q: %w", owner, err) - } - - if taskID, err := uuid.Parse(taskName); err == nil { - // Find task by ID. - for _, task := range tasks { - if task.ID == taskID { - return task, nil - } - } - } else { - // Find task by name. - for _, task := range tasks { - if task.Name == taskName { - return task, nil - } - } - } - - // Mimic resource not found from API. - var notFoundErr error = &Error{ - Response: Response{Message: "Resource not found or you do not have access to this resource"}, - } - return Task{}, xerrors.Errorf("task %q not found for owner %q: %w", taskName, owner, notFoundErr) + return c.TaskByOwnerAndName(ctx, owner, taskName) } // DeleteTask deletes a task by its ID.