fix: optimize queue position sql query (#17974)

Use only `online provisioner daemons` for
`GetProvisionerJobsByIDsWithQueuePosition` query. It should improve
performance of the query.
This commit is contained in:
Yevhenii Shcherbina
2025-05-28 08:21:16 -04:00
committed by GitHub
parent 2bcbd9bdbd
commit 110102a60a
11 changed files with 88 additions and 36 deletions
+1 -1
View File
@@ -2341,7 +2341,7 @@ func (q *querier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID)
return provisionerJobs, nil
}
func (q *querier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
func (q *querier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
// TODO: Remove this once we have a proper rbac check for provisioner jobs.
// Details in https://github.com/coder/coder/issues/16160
return q.db.GetProvisionerJobsByIDsWithQueuePosition(ctx, ids)
+1 -1
View File
@@ -4345,7 +4345,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
check.Args([]uuid.UUID{uuid.New()}).Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetProvisionerJobsByIDsWithQueuePosition", s.Subtest(func(db database.Store, check *expects) {
check.Args([]uuid.UUID{}).Asserts()
check.Args(database.GetProvisionerJobsByIDsWithQueuePositionParams{}).Asserts()
}))
s.Run("GetReplicaByID", s.Subtest(func(db database.Store, check *expects) {
check.Args(uuid.New()).Asserts(rbac.ResourceSystem, policy.ActionRead).Errors(sql.ErrNoRows)
+4 -4
View File
@@ -4684,14 +4684,14 @@ func (q *FakeQuerier) GetProvisionerJobsByIDs(_ context.Context, ids []uuid.UUID
return jobs, nil
}
func (q *FakeQuerier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
func (q *FakeQuerier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
if ids == nil {
ids = []uuid.UUID{}
if arg.IDs == nil {
arg.IDs = []uuid.UUID{}
}
return q.getProvisionerJobsByIDsWithQueuePositionLockedTagBasedQueue(ctx, ids)
return q.getProvisionerJobsByIDsWithQueuePositionLockedTagBasedQueue(ctx, arg.IDs)
}
func (q *FakeQuerier) GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner(ctx context.Context, arg database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerParams) ([]database.GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow, error) {
+1 -1
View File
@@ -1215,7 +1215,7 @@ func (m queryMetricsStore) GetProvisionerJobsByIDs(ctx context.Context, ids []uu
return jobs, err
}
func (m queryMetricsStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
func (m queryMetricsStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
start := time.Now()
r0, r1 := m.s.GetProvisionerJobsByIDsWithQueuePosition(ctx, ids)
m.queryLatencies.WithLabelValues("GetProvisionerJobsByIDsWithQueuePosition").Observe(time.Since(start).Seconds())
+4 -4
View File
@@ -2494,18 +2494,18 @@ func (mr *MockStoreMockRecorder) GetProvisionerJobsByIDs(ctx, ids any) *gomock.C
}
// GetProvisionerJobsByIDsWithQueuePosition mocks base method.
func (m *MockStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
func (m *MockStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProvisionerJobsByIDsWithQueuePosition", ctx, ids)
ret := m.ctrl.Call(m, "GetProvisionerJobsByIDsWithQueuePosition", ctx, arg)
ret0, _ := ret[0].([]database.GetProvisionerJobsByIDsWithQueuePositionRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProvisionerJobsByIDsWithQueuePosition indicates an expected call of GetProvisionerJobsByIDsWithQueuePosition.
func (mr *MockStoreMockRecorder) GetProvisionerJobsByIDsWithQueuePosition(ctx, ids any) *gomock.Call {
func (mr *MockStoreMockRecorder) GetProvisionerJobsByIDsWithQueuePosition(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobsByIDsWithQueuePosition", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobsByIDsWithQueuePosition), ctx, ids)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobsByIDsWithQueuePosition", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobsByIDsWithQueuePosition), ctx, arg)
}
// GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner mocks base method.
+1 -1
View File
@@ -278,7 +278,7 @@ type sqlcQuerier interface {
GetProvisionerJobByIDForUpdate(ctx context.Context, id uuid.UUID) (ProvisionerJob, error)
GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uuid.UUID) ([]ProvisionerJobTiming, error)
GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error)
GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error)
GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg GetProvisionerJobsByIDsWithQueuePositionParams) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error)
GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisioner(ctx context.Context, arg GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerParams) ([]GetProvisionerJobsByOrganizationAndStatusWithQueuePositionAndProvisionerRow, error)
GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error)
// To avoid repeatedly attempting to reap the same jobs, we randomly order and limit to @max_jobs.
+21 -6
View File
@@ -15,7 +15,6 @@ import (
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
@@ -27,6 +26,7 @@ import (
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/prebuilds"
"github.com/coder/coder/v2/coderd/provisionerdserver"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/provisionersdk"
@@ -1268,7 +1268,10 @@ func TestQueuePosition(t *testing.T) {
Tags: database.StringMap{},
})
queued, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, jobIDs)
queued, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: jobIDs,
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
require.NoError(t, err)
require.Len(t, queued, jobCount)
sort.Slice(queued, func(i, j int) bool {
@@ -1296,7 +1299,10 @@ func TestQueuePosition(t *testing.T) {
require.NoError(t, err)
require.Equal(t, jobs[0].ID, job.ID)
queued, err = db.GetProvisionerJobsByIDsWithQueuePosition(ctx, jobIDs)
queued, err = db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: jobIDs,
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
require.NoError(t, err)
require.Len(t, queued, jobCount)
sort.Slice(queued, func(i, j int) bool {
@@ -2550,7 +2556,10 @@ func TestGetProvisionerJobsByIDsWithQueuePosition(t *testing.T) {
}
// When: we fetch the jobs by their IDs
actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, filteredJobIDs)
actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: filteredJobIDs,
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
require.NoError(t, err)
require.Len(t, actualJobs, len(filteredJobs), "should return all unskipped jobs")
@@ -2693,7 +2702,10 @@ func TestGetProvisionerJobsByIDsWithQueuePosition_MixedStatuses(t *testing.T) {
}
// When: we fetch the jobs by their IDs
actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, jobIDs)
actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: jobIDs,
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
require.NoError(t, err)
require.Len(t, actualJobs, len(allJobs), "should return all jobs")
@@ -2788,7 +2800,10 @@ func TestGetProvisionerJobsByIDsWithQueuePosition_OrderValidation(t *testing.T)
}
// When: we fetch the jobs by their IDs
actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, jobIDs)
actualJobs, err := db.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: jobIDs,
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
require.NoError(t, err)
require.Len(t, actualJobs, len(allJobs), "should return all jobs")
+15 -6
View File
@@ -7663,17 +7663,21 @@ pending_jobs AS (
WHERE
job_status = 'pending'
),
online_provisioner_daemons AS (
SELECT id, tags FROM provisioner_daemons pd
WHERE pd.last_seen_at IS NOT NULL AND pd.last_seen_at >= (NOW() - ($2::bigint || ' ms')::interval)
),
ranked_jobs AS (
-- Step 3: Rank only pending jobs based on provisioner availability
SELECT
pj.id,
pj.created_at,
ROW_NUMBER() OVER (PARTITION BY pd.id ORDER BY pj.created_at ASC) AS queue_position,
COUNT(*) OVER (PARTITION BY pd.id) AS queue_size
ROW_NUMBER() OVER (PARTITION BY opd.id ORDER BY pj.created_at ASC) AS queue_position,
COUNT(*) OVER (PARTITION BY opd.id) AS queue_size
FROM
pending_jobs pj
INNER JOIN provisioner_daemons pd
ON provisioner_tagset_contains(pd.tags, pj.tags) -- Join only on the small pending set
INNER JOIN online_provisioner_daemons opd
ON provisioner_tagset_contains(opd.tags, pj.tags) -- Join only on the small pending set
),
final_jobs AS (
-- Step 4: Compute best queue position and max queue size per job
@@ -7705,6 +7709,11 @@ ORDER BY
fj.created_at
`
type GetProvisionerJobsByIDsWithQueuePositionParams struct {
IDs []uuid.UUID `db:"ids" json:"ids"`
StaleIntervalMS int64 `db:"stale_interval_ms" json:"stale_interval_ms"`
}
type GetProvisionerJobsByIDsWithQueuePositionRow struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
@@ -7713,8 +7722,8 @@ type GetProvisionerJobsByIDsWithQueuePositionRow struct {
QueueSize int64 `db:"queue_size" json:"queue_size"`
}
func (q *sqlQuerier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error) {
rows, err := q.db.QueryContext(ctx, getProvisionerJobsByIDsWithQueuePosition, pq.Array(ids))
func (q *sqlQuerier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg GetProvisionerJobsByIDsWithQueuePositionParams) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error) {
rows, err := q.db.QueryContext(ctx, getProvisionerJobsByIDsWithQueuePosition, pq.Array(arg.IDs), arg.StaleIntervalMS)
if err != nil {
return nil, err
}
+8 -4
View File
@@ -80,17 +80,21 @@ pending_jobs AS (
WHERE
job_status = 'pending'
),
online_provisioner_daemons AS (
SELECT id, tags FROM provisioner_daemons pd
WHERE pd.last_seen_at IS NOT NULL AND pd.last_seen_at >= (NOW() - (@stale_interval_ms::bigint || ' ms')::interval)
),
ranked_jobs AS (
-- Step 3: Rank only pending jobs based on provisioner availability
SELECT
pj.id,
pj.created_at,
ROW_NUMBER() OVER (PARTITION BY pd.id ORDER BY pj.created_at ASC) AS queue_position,
COUNT(*) OVER (PARTITION BY pd.id) AS queue_size
ROW_NUMBER() OVER (PARTITION BY opd.id ORDER BY pj.created_at ASC) AS queue_position,
COUNT(*) OVER (PARTITION BY opd.id) AS queue_size
FROM
pending_jobs pj
INNER JOIN provisioner_daemons pd
ON provisioner_tagset_contains(pd.tags, pj.tags) -- Join only on the small pending set
INNER JOIN online_provisioner_daemons opd
ON provisioner_tagset_contains(opd.tags, pj.tags) -- Join only on the small pending set
),
final_jobs AS (
-- Step 4: Compute best queue position and max queue size per job
+28 -7
View File
@@ -53,7 +53,10 @@ func (api *API) templateVersion(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
templateVersion := httpmw.TemplateVersionParam(r)
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, []uuid.UUID{templateVersion.JobID})
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: []uuid.UUID{templateVersion.JobID},
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
if err != nil || len(jobs) == 0 {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job.",
@@ -182,7 +185,10 @@ func (api *API) patchTemplateVersion(rw http.ResponseWriter, r *http.Request) {
return
}
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, []uuid.UUID{templateVersion.JobID})
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: []uuid.UUID{templateVersion.JobID},
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
if err != nil || len(jobs) == 0 {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job.",
@@ -733,7 +739,10 @@ func (api *API) fetchTemplateVersionDryRunJob(rw http.ResponseWriter, r *http.Re
return database.GetProvisionerJobsByIDsWithQueuePositionRow{}, false
}
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, []uuid.UUID{jobUUID})
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: []uuid.UUID{jobUUID},
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Provisioner job %q not found.", jobUUID),
@@ -865,7 +874,10 @@ func (api *API) templateVersionsByTemplate(rw http.ResponseWriter, r *http.Reque
for _, version := range versions {
jobIDs = append(jobIDs, version.JobID)
}
jobs, err := store.GetProvisionerJobsByIDsWithQueuePosition(ctx, jobIDs)
jobs, err := store.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: jobIDs,
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job.",
@@ -933,7 +945,10 @@ func (api *API) templateVersionByName(rw http.ResponseWriter, r *http.Request) {
})
return
}
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, []uuid.UUID{templateVersion.JobID})
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: []uuid.UUID{templateVersion.JobID},
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
if err != nil || len(jobs) == 0 {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job.",
@@ -1013,7 +1028,10 @@ func (api *API) templateVersionByOrganizationTemplateAndName(rw http.ResponseWri
})
return
}
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, []uuid.UUID{templateVersion.JobID})
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: []uuid.UUID{templateVersion.JobID},
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
if err != nil || len(jobs) == 0 {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job.",
@@ -1115,7 +1133,10 @@ func (api *API) previousTemplateVersionByOrganizationTemplateAndName(rw http.Res
return
}
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, []uuid.UUID{previousTemplateVersion.JobID})
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: []uuid.UUID{previousTemplateVersion.JobID},
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
if err != nil || len(jobs) == 0 {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner job.",
+4 -1
View File
@@ -797,7 +797,10 @@ func (api *API) workspaceBuildsData(ctx context.Context, workspaceBuilds []datab
for _, build := range workspaceBuilds {
jobIDs = append(jobIDs, build.JobID)
}
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, jobIDs)
jobs, err := api.Database.GetProvisionerJobsByIDsWithQueuePosition(ctx, database.GetProvisionerJobsByIDsWithQueuePositionParams{
IDs: jobIDs,
StaleIntervalMS: provisionerdserver.StaleInterval.Milliseconds(),
})
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return workspaceBuildsData{}, xerrors.Errorf("get provisioner jobs: %w", err)
}