diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 916a4e5a84..fec76b6a2d 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -14175,6 +14175,9 @@ const docTemplate = `{ }, "count": { "type": "integer" + }, + "count_cap": { + "type": "integer" } } }, @@ -14496,6 +14499,9 @@ const docTemplate = `{ }, "count": { "type": "integer" + }, + "count_cap": { + "type": "integer" } } }, diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 1fb7b92d03..fd22aa91b9 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -12739,6 +12739,9 @@ }, "count": { "type": "integer" + }, + "count_cap": { + "type": "integer" } } }, @@ -13039,6 +13042,9 @@ }, "count": { "type": "integer" + }, + "count_cap": { + "type": "integer" } } }, diff --git a/coderd/audit.go b/coderd/audit.go index f1fd7668f7..3d8aed3005 100644 --- a/coderd/audit.go +++ b/coderd/audit.go @@ -26,6 +26,11 @@ import ( "github.com/coder/coder/v2/codersdk" ) +// Limit the count query to avoid a slow sequential scan due to joins +// on a large table. Set to 0 to disable capping (but also see the note +// in the SQL query). +const auditLogCountCap = 2000 + // @Summary Get audit logs // @ID get-audit-logs // @Security CoderSessionToken @@ -66,7 +71,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) { countFilter.Username = "" } - // Use the same filters to count the number of audit logs + countFilter.CountCap = auditLogCountCap count, err := api.Database.CountAuditLogs(ctx, countFilter) if dbauthz.IsNotAuthorizedError(err) { httpapi.Forbidden(rw) @@ -81,6 +86,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{ AuditLogs: []codersdk.AuditLog{}, Count: 0, + CountCap: auditLogCountCap, }) return } @@ -98,6 +104,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{ AuditLogs: api.convertAuditLogs(ctx, dblogs), Count: count, + CountCap: auditLogCountCap, }) } diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index d0d08609ca..169e46b3d4 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -584,6 +584,7 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi arg.DateTo, arg.BuildReason, arg.RequestID, + arg.CountCap, ) if err != nil { return 0, err @@ -720,6 +721,7 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun arg.WorkspaceID, arg.ConnectionID, arg.Status, + arg.CountCap, ) if err != nil { return 0, err diff --git a/coderd/database/modelqueries_internal_test.go b/coderd/database/modelqueries_internal_test.go index 9e84324b72..3f425b4347 100644 --- a/coderd/database/modelqueries_internal_test.go +++ b/coderd/database/modelqueries_internal_test.go @@ -145,5 +145,13 @@ func extractWhereClause(query string) string { // Remove SQL comments whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "") + // Normalize indentation so subquery wrapping doesn't cause + // mismatches. + lines := strings.Split(whereClause, "\n") + for i, line := range lines { + lines[i] = strings.TrimLeft(line, " \t") + } + whereClause = strings.Join(lines, "\n") + return strings.TrimSpace(whereClause) } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 58ae9bc037..7e1c7f0009 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2275,93 +2275,105 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP } const countAuditLogs = `-- name: CountAuditLogs :one -SELECT COUNT(*) -FROM audit_logs - LEFT JOIN users ON audit_logs.user_id = users.id - LEFT JOIN organizations ON audit_logs.organization_id = organizations.id - -- First join on workspaces to get the initial workspace create - -- to workspace build 1 id. This is because the first create is - -- is a different audit log than subsequent starts. - LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace' - AND audit_logs.resource_id = workspaces.id - -- Get the reason from the build if the resource type - -- is a workspace_build - LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build' - AND audit_logs.resource_id = wb_build.id - -- Get the reason from the build #1 if this is the first - -- workspace create. - LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace' - AND audit_logs.action = 'create' - AND workspaces.id = wb_workspace.workspace_id - AND wb_workspace.build_number = 1 -WHERE - -- Filter resource_type - CASE - WHEN $1::text != '' THEN resource_type = $1::resource_type - ELSE true - END - -- Filter resource_id - AND CASE - WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2 - ELSE true - END - -- Filter organization_id - AND CASE - WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3 - ELSE true - END - -- Filter by resource_target - AND CASE - WHEN $4::text != '' THEN resource_target = $4 - ELSE true - END - -- Filter action - AND CASE - WHEN $5::text != '' THEN action = $5::audit_action - ELSE true - END - -- Filter by user_id - AND CASE - WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6 - ELSE true - END - -- Filter by username - AND CASE - WHEN $7::text != '' THEN user_id = ( - SELECT id - FROM users - WHERE lower(username) = lower($7) - AND deleted = false - ) - ELSE true - END - -- Filter by user_email - AND CASE - WHEN $8::text != '' THEN users.email = $8 - ELSE true - END - -- Filter by date_from - AND CASE - WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9 - ELSE true - END - -- Filter by date_to - AND CASE - WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10 - ELSE true - END - -- Filter by build_reason - AND CASE - WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11 - ELSE true - END - -- Filter request_id - AND CASE - WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12 - ELSE true - END - -- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs - -- @authorize_filter +SELECT COUNT(*) FROM ( + SELECT 1 + FROM audit_logs + LEFT JOIN users ON audit_logs.user_id = users.id + LEFT JOIN organizations ON audit_logs.organization_id = organizations.id + -- First join on workspaces to get the initial workspace create + -- to workspace build 1 id. This is because the first create is + -- is a different audit log than subsequent starts. + LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace' + AND audit_logs.resource_id = workspaces.id + -- Get the reason from the build if the resource type + -- is a workspace_build + LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build' + AND audit_logs.resource_id = wb_build.id + -- Get the reason from the build #1 if this is the first + -- workspace create. + LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace' + AND audit_logs.action = 'create' + AND workspaces.id = wb_workspace.workspace_id + AND wb_workspace.build_number = 1 + WHERE + -- Filter resource_type + CASE + WHEN $1::text != '' THEN resource_type = $1::resource_type + ELSE true + END + -- Filter resource_id + AND CASE + WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2 + ELSE true + END + -- Filter organization_id + AND CASE + WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3 + ELSE true + END + -- Filter by resource_target + AND CASE + WHEN $4::text != '' THEN resource_target = $4 + ELSE true + END + -- Filter action + AND CASE + WHEN $5::text != '' THEN action = $5::audit_action + ELSE true + END + -- Filter by user_id + AND CASE + WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6 + ELSE true + END + -- Filter by username + AND CASE + WHEN $7::text != '' THEN user_id = ( + SELECT id + FROM users + WHERE lower(username) = lower($7) + AND deleted = false + ) + ELSE true + END + -- Filter by user_email + AND CASE + WHEN $8::text != '' THEN users.email = $8 + ELSE true + END + -- Filter by date_from + AND CASE + WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9 + ELSE true + END + -- Filter by date_to + AND CASE + WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10 + ELSE true + END + -- Filter by build_reason + AND CASE + WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11 + ELSE true + END + -- Filter request_id + AND CASE + WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12 + ELSE true + END + -- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs + -- @authorize_filter + -- Avoid a slow scan on a large table with joins. The caller + -- passes the count cap and we add 1 so the frontend can detect + -- capping and show "... of N+". A cap of 0 means no limit (NULLIF + -- -> NULL + 1 = NULL). + -- NOTE: Parameterizing this so that we can easily change from, + -- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT) + -- here if disabling the capping on a large table permanently. + -- This way the PG planner can plan parallel execution for + -- potential large wins. + LIMIT NULLIF($13::int, 0) + 1 +) AS limited_count ` type CountAuditLogsParams struct { @@ -2377,6 +2389,7 @@ type CountAuditLogsParams struct { DateTo time.Time `db:"date_to" json:"date_to"` BuildReason string `db:"build_reason" json:"build_reason"` RequestID uuid.UUID `db:"request_id" json:"request_id"` + CountCap int32 `db:"count_cap" json:"count_cap"` } func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) { @@ -2393,6 +2406,7 @@ func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParam arg.DateTo, arg.BuildReason, arg.RequestID, + arg.CountCap, ) var count int64 err := row.Scan(&count) @@ -7571,110 +7585,113 @@ func (q *sqlQuerier) BatchUpsertConnectionLogs(ctx context.Context, arg BatchUps } const countConnectionLogs = `-- name: CountConnectionLogs :one -SELECT - COUNT(*) AS count -FROM - connection_logs -JOIN users AS workspace_owner ON - connection_logs.workspace_owner_id = workspace_owner.id -LEFT JOIN users ON - connection_logs.user_id = users.id -JOIN organizations ON - connection_logs.organization_id = organizations.id -WHERE - -- Filter organization_id - CASE - WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - connection_logs.organization_id = $1 - ELSE true - END - -- Filter by workspace owner username - AND CASE - WHEN $2 :: text != '' THEN - workspace_owner_id = ( - SELECT id FROM users - WHERE lower(username) = lower($2) AND deleted = false - ) - ELSE true - END - -- Filter by workspace_owner_id - AND CASE - WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - workspace_owner_id = $3 - ELSE true - END - -- Filter by workspace_owner_email - AND CASE - WHEN $4 :: text != '' THEN - workspace_owner_id = ( - SELECT id FROM users - WHERE email = $4 AND deleted = false - ) - ELSE true - END - -- Filter by type - AND CASE - WHEN $5 :: text != '' THEN - type = $5 :: connection_type - ELSE true - END - -- Filter by user_id - AND CASE - WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - user_id = $6 - ELSE true - END - -- Filter by username - AND CASE - WHEN $7 :: text != '' THEN - user_id = ( - SELECT id FROM users - WHERE lower(username) = lower($7) AND deleted = false - ) - ELSE true - END - -- Filter by user_email - AND CASE - WHEN $8 :: text != '' THEN - users.email = $8 - ELSE true - END - -- Filter by connected_after - AND CASE - WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN - connect_time >= $9 - ELSE true - END - -- Filter by connected_before - AND CASE - WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN - connect_time <= $10 - ELSE true - END - -- Filter by workspace_id - AND CASE - WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - connection_logs.workspace_id = $11 - ELSE true - END - -- Filter by connection_id - AND CASE - WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - connection_logs.connection_id = $12 - ELSE true - END - -- Filter by whether the session has a disconnect_time - AND CASE - WHEN $13 :: text != '' THEN - (($13 = 'ongoing' AND disconnect_time IS NULL) OR - ($13 = 'completed' AND disconnect_time IS NOT NULL)) AND - -- Exclude web events, since we don't know their close time. - "type" NOT IN ('workspace_app', 'port_forwarding') - ELSE true - END - -- Authorize Filter clause will be injected below in - -- CountAuthorizedConnectionLogs - -- @authorize_filter +SELECT COUNT(*) AS count FROM ( + SELECT 1 + FROM + connection_logs + JOIN users AS workspace_owner ON + connection_logs.workspace_owner_id = workspace_owner.id + LEFT JOIN users ON + connection_logs.user_id = users.id + JOIN organizations ON + connection_logs.organization_id = organizations.id + WHERE + -- Filter organization_id + CASE + WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.organization_id = $1 + ELSE true + END + -- Filter by workspace owner username + AND CASE + WHEN $2 :: text != '' THEN + workspace_owner_id = ( + SELECT id FROM users + WHERE lower(username) = lower($2) AND deleted = false + ) + ELSE true + END + -- Filter by workspace_owner_id + AND CASE + WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + workspace_owner_id = $3 + ELSE true + END + -- Filter by workspace_owner_email + AND CASE + WHEN $4 :: text != '' THEN + workspace_owner_id = ( + SELECT id FROM users + WHERE email = $4 AND deleted = false + ) + ELSE true + END + -- Filter by type + AND CASE + WHEN $5 :: text != '' THEN + type = $5 :: connection_type + ELSE true + END + -- Filter by user_id + AND CASE + WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_id = $6 + ELSE true + END + -- Filter by username + AND CASE + WHEN $7 :: text != '' THEN + user_id = ( + SELECT id FROM users + WHERE lower(username) = lower($7) AND deleted = false + ) + ELSE true + END + -- Filter by user_email + AND CASE + WHEN $8 :: text != '' THEN + users.email = $8 + ELSE true + END + -- Filter by connected_after + AND CASE + WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + connect_time >= $9 + ELSE true + END + -- Filter by connected_before + AND CASE + WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + connect_time <= $10 + ELSE true + END + -- Filter by workspace_id + AND CASE + WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.workspace_id = $11 + ELSE true + END + -- Filter by connection_id + AND CASE + WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.connection_id = $12 + ELSE true + END + -- Filter by whether the session has a disconnect_time + AND CASE + WHEN $13 :: text != '' THEN + (($13 = 'ongoing' AND disconnect_time IS NULL) OR + ($13 = 'completed' AND disconnect_time IS NOT NULL)) AND + -- Exclude web events, since we don't know their close time. + "type" NOT IN ('workspace_app', 'port_forwarding') + ELSE true + END + -- Authorize Filter clause will be injected below in + -- CountAuthorizedConnectionLogs + -- @authorize_filter + -- NOTE: See the CountAuditLogs LIMIT note. + LIMIT NULLIF($14::int, 0) + 1 +) AS limited_count ` type CountConnectionLogsParams struct { @@ -7691,6 +7708,7 @@ type CountConnectionLogsParams struct { WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"` ConnectionID uuid.UUID `db:"connection_id" json:"connection_id"` Status string `db:"status" json:"status"` + CountCap int32 `db:"count_cap" json:"count_cap"` } func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) { @@ -7708,6 +7726,7 @@ func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectio arg.WorkspaceID, arg.ConnectionID, arg.Status, + arg.CountCap, ) var count int64 err := row.Scan(&count) diff --git a/coderd/database/queries/auditlogs.sql b/coderd/database/queries/auditlogs.sql index a1c219e702..5a2f9a31e8 100644 --- a/coderd/database/queries/auditlogs.sql +++ b/coderd/database/queries/auditlogs.sql @@ -149,94 +149,105 @@ VALUES ( RETURNING *; -- name: CountAuditLogs :one -SELECT COUNT(*) -FROM audit_logs - LEFT JOIN users ON audit_logs.user_id = users.id - LEFT JOIN organizations ON audit_logs.organization_id = organizations.id - -- First join on workspaces to get the initial workspace create - -- to workspace build 1 id. This is because the first create is - -- is a different audit log than subsequent starts. - LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace' - AND audit_logs.resource_id = workspaces.id - -- Get the reason from the build if the resource type - -- is a workspace_build - LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build' - AND audit_logs.resource_id = wb_build.id - -- Get the reason from the build #1 if this is the first - -- workspace create. - LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace' - AND audit_logs.action = 'create' - AND workspaces.id = wb_workspace.workspace_id - AND wb_workspace.build_number = 1 -WHERE - -- Filter resource_type - CASE - WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type - ELSE true - END - -- Filter resource_id - AND CASE - WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id - ELSE true - END - -- Filter organization_id - AND CASE - WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id - ELSE true - END - -- Filter by resource_target - AND CASE - WHEN @resource_target::text != '' THEN resource_target = @resource_target - ELSE true - END - -- Filter action - AND CASE - WHEN @action::text != '' THEN action = @action::audit_action - ELSE true - END - -- Filter by user_id - AND CASE - WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id - ELSE true - END - -- Filter by username - AND CASE - WHEN @username::text != '' THEN user_id = ( - SELECT id - FROM users - WHERE lower(username) = lower(@username) - AND deleted = false - ) - ELSE true - END - -- Filter by user_email - AND CASE - WHEN @email::text != '' THEN users.email = @email - ELSE true - END - -- Filter by date_from - AND CASE - WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from - ELSE true - END - -- Filter by date_to - AND CASE - WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to - ELSE true - END - -- Filter by build_reason - AND CASE - WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason - ELSE true - END - -- Filter request_id - AND CASE - WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id - ELSE true - END - -- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs - -- @authorize_filter -; +SELECT COUNT(*) FROM ( + SELECT 1 + FROM audit_logs + LEFT JOIN users ON audit_logs.user_id = users.id + LEFT JOIN organizations ON audit_logs.organization_id = organizations.id + -- First join on workspaces to get the initial workspace create + -- to workspace build 1 id. This is because the first create is + -- is a different audit log than subsequent starts. + LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace' + AND audit_logs.resource_id = workspaces.id + -- Get the reason from the build if the resource type + -- is a workspace_build + LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build' + AND audit_logs.resource_id = wb_build.id + -- Get the reason from the build #1 if this is the first + -- workspace create. + LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace' + AND audit_logs.action = 'create' + AND workspaces.id = wb_workspace.workspace_id + AND wb_workspace.build_number = 1 + WHERE + -- Filter resource_type + CASE + WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type + ELSE true + END + -- Filter resource_id + AND CASE + WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id + ELSE true + END + -- Filter organization_id + AND CASE + WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id + ELSE true + END + -- Filter by resource_target + AND CASE + WHEN @resource_target::text != '' THEN resource_target = @resource_target + ELSE true + END + -- Filter action + AND CASE + WHEN @action::text != '' THEN action = @action::audit_action + ELSE true + END + -- Filter by user_id + AND CASE + WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id + ELSE true + END + -- Filter by username + AND CASE + WHEN @username::text != '' THEN user_id = ( + SELECT id + FROM users + WHERE lower(username) = lower(@username) + AND deleted = false + ) + ELSE true + END + -- Filter by user_email + AND CASE + WHEN @email::text != '' THEN users.email = @email + ELSE true + END + -- Filter by date_from + AND CASE + WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from + ELSE true + END + -- Filter by date_to + AND CASE + WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to + ELSE true + END + -- Filter by build_reason + AND CASE + WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason + ELSE true + END + -- Filter request_id + AND CASE + WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id + ELSE true + END + -- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs + -- @authorize_filter + -- Avoid a slow scan on a large table with joins. The caller + -- passes the count cap and we add 1 so the frontend can detect + -- capping and show "... of N+". A cap of 0 means no limit (NULLIF + -- -> NULL + 1 = NULL). + -- NOTE: Parameterizing this so that we can easily change from, + -- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT) + -- here if disabling the capping on a large table permanently. + -- This way the PG planner can plan parallel execution for + -- potential large wins. + LIMIT NULLIF(@count_cap::int, 0) + 1 +) AS limited_count; -- name: DeleteOldAuditLogConnectionEvents :exec DELETE FROM audit_logs diff --git a/coderd/database/queries/connectionlogs.sql b/coderd/database/queries/connectionlogs.sql index 63e0023dcc..7e5fb63a37 100644 --- a/coderd/database/queries/connectionlogs.sql +++ b/coderd/database/queries/connectionlogs.sql @@ -133,111 +133,113 @@ OFFSET @offset_opt; -- name: CountConnectionLogs :one -SELECT - COUNT(*) AS count -FROM - connection_logs -JOIN users AS workspace_owner ON - connection_logs.workspace_owner_id = workspace_owner.id -LEFT JOIN users ON - connection_logs.user_id = users.id -JOIN organizations ON - connection_logs.organization_id = organizations.id -WHERE - -- Filter organization_id - CASE - WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - connection_logs.organization_id = @organization_id - ELSE true - END - -- Filter by workspace owner username - AND CASE - WHEN @workspace_owner :: text != '' THEN - workspace_owner_id = ( - SELECT id FROM users - WHERE lower(username) = lower(@workspace_owner) AND deleted = false - ) - ELSE true - END - -- Filter by workspace_owner_id - AND CASE - WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - workspace_owner_id = @workspace_owner_id - ELSE true - END - -- Filter by workspace_owner_email - AND CASE - WHEN @workspace_owner_email :: text != '' THEN - workspace_owner_id = ( - SELECT id FROM users - WHERE email = @workspace_owner_email AND deleted = false - ) - ELSE true - END - -- Filter by type - AND CASE - WHEN @type :: text != '' THEN - type = @type :: connection_type - ELSE true - END - -- Filter by user_id - AND CASE - WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - user_id = @user_id - ELSE true - END - -- Filter by username - AND CASE - WHEN @username :: text != '' THEN - user_id = ( - SELECT id FROM users - WHERE lower(username) = lower(@username) AND deleted = false - ) - ELSE true - END - -- Filter by user_email - AND CASE - WHEN @user_email :: text != '' THEN - users.email = @user_email - ELSE true - END - -- Filter by connected_after - AND CASE - WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN - connect_time >= @connected_after - ELSE true - END - -- Filter by connected_before - AND CASE - WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN - connect_time <= @connected_before - ELSE true - END - -- Filter by workspace_id - AND CASE - WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - connection_logs.workspace_id = @workspace_id - ELSE true - END - -- Filter by connection_id - AND CASE - WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN - connection_logs.connection_id = @connection_id - ELSE true - END - -- Filter by whether the session has a disconnect_time - AND CASE - WHEN @status :: text != '' THEN - ((@status = 'ongoing' AND disconnect_time IS NULL) OR - (@status = 'completed' AND disconnect_time IS NOT NULL)) AND - -- Exclude web events, since we don't know their close time. - "type" NOT IN ('workspace_app', 'port_forwarding') - ELSE true - END - -- Authorize Filter clause will be injected below in - -- CountAuthorizedConnectionLogs - -- @authorize_filter -; +SELECT COUNT(*) AS count FROM ( + SELECT 1 + FROM + connection_logs + JOIN users AS workspace_owner ON + connection_logs.workspace_owner_id = workspace_owner.id + LEFT JOIN users ON + connection_logs.user_id = users.id + JOIN organizations ON + connection_logs.organization_id = organizations.id + WHERE + -- Filter organization_id + CASE + WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.organization_id = @organization_id + ELSE true + END + -- Filter by workspace owner username + AND CASE + WHEN @workspace_owner :: text != '' THEN + workspace_owner_id = ( + SELECT id FROM users + WHERE lower(username) = lower(@workspace_owner) AND deleted = false + ) + ELSE true + END + -- Filter by workspace_owner_id + AND CASE + WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + workspace_owner_id = @workspace_owner_id + ELSE true + END + -- Filter by workspace_owner_email + AND CASE + WHEN @workspace_owner_email :: text != '' THEN + workspace_owner_id = ( + SELECT id FROM users + WHERE email = @workspace_owner_email AND deleted = false + ) + ELSE true + END + -- Filter by type + AND CASE + WHEN @type :: text != '' THEN + type = @type :: connection_type + ELSE true + END + -- Filter by user_id + AND CASE + WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_id = @user_id + ELSE true + END + -- Filter by username + AND CASE + WHEN @username :: text != '' THEN + user_id = ( + SELECT id FROM users + WHERE lower(username) = lower(@username) AND deleted = false + ) + ELSE true + END + -- Filter by user_email + AND CASE + WHEN @user_email :: text != '' THEN + users.email = @user_email + ELSE true + END + -- Filter by connected_after + AND CASE + WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + connect_time >= @connected_after + ELSE true + END + -- Filter by connected_before + AND CASE + WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + connect_time <= @connected_before + ELSE true + END + -- Filter by workspace_id + AND CASE + WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.workspace_id = @workspace_id + ELSE true + END + -- Filter by connection_id + AND CASE + WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.connection_id = @connection_id + ELSE true + END + -- Filter by whether the session has a disconnect_time + AND CASE + WHEN @status :: text != '' THEN + ((@status = 'ongoing' AND disconnect_time IS NULL) OR + (@status = 'completed' AND disconnect_time IS NOT NULL)) AND + -- Exclude web events, since we don't know their close time. + "type" NOT IN ('workspace_app', 'port_forwarding') + ELSE true + END + -- Authorize Filter clause will be injected below in + -- CountAuthorizedConnectionLogs + -- @authorize_filter + -- NOTE: See the CountAuditLogs LIMIT note. + LIMIT NULLIF(@count_cap::int, 0) + 1 +) AS limited_count; -- name: DeleteOldConnectionLogs :execrows WITH old_logs AS ( diff --git a/coderd/rbac/regosql/compile_test.go b/coderd/rbac/regosql/compile_test.go index 9249e890ad..63f9302d3a 100644 --- a/coderd/rbac/regosql/compile_test.go +++ b/coderd/rbac/regosql/compile_test.go @@ -298,6 +298,40 @@ neq(input.object.owner, ""); ExpectedSQL: p("'' = 'org-id'"), VariableConverter: regosql.ChatConverter(), }, + { + Name: "AuditLogUUID", + Queries: []string{ + `"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`, + `input.object.org_owner != ""`, + `neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`, + `input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708", "05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}`, + `"read" in input.object.acl_group_list[input.object.org_owner]`, + }, + ExpectedSQL: p( + p("audit_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " + + p("audit_logs.organization_id IS NOT NULL") + " OR " + + p("audit_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " + + p("audit_logs.organization_id = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'::uuid,'8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " + + "(false)"), + VariableConverter: regosql.AuditLogConverter(), + }, + { + Name: "ConnectionLogUUID", + Queries: []string{ + `"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`, + `input.object.org_owner != ""`, + `neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`, + `input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708"}`, + `"read" in input.object.acl_group_list[input.object.org_owner]`, + }, + ExpectedSQL: p( + p("connection_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " + + p("connection_logs.organization_id IS NOT NULL") + " OR " + + p("connection_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " + + p("connection_logs.organization_id = ANY(ARRAY ['8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " + + "(false)"), + VariableConverter: regosql.ConnectionLogConverter(), + }, } for _, tc := range testCases { diff --git a/coderd/rbac/regosql/configs.go b/coderd/rbac/regosql/configs.go index 4f156e8a26..2066d93473 100644 --- a/coderd/rbac/regosql/configs.go +++ b/coderd/rbac/regosql/configs.go @@ -53,7 +53,7 @@ func WorkspaceConverter() *sqltypes.VariableConverter { func AuditLogConverter() *sqltypes.VariableConverter { matcher := sqltypes.NewVariableConverter().RegisterMatcher( resourceIDMatcher(), - sqltypes.StringVarMatcher("COALESCE(audit_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}), + sqltypes.UUIDVarMatcher("audit_logs.organization_id", []string{"input", "object", "org_owner"}), // Audit logs have no user owner, only owner by an organization. sqltypes.AlwaysFalse(userOwnerMatcher()), ) @@ -67,7 +67,7 @@ func AuditLogConverter() *sqltypes.VariableConverter { func ConnectionLogConverter() *sqltypes.VariableConverter { matcher := sqltypes.NewVariableConverter().RegisterMatcher( resourceIDMatcher(), - sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}), + sqltypes.UUIDVarMatcher("connection_logs.organization_id", []string{"input", "object", "org_owner"}), // Connection logs have no user owner, only owner by an organization. sqltypes.AlwaysFalse(userOwnerMatcher()), ) diff --git a/coderd/rbac/regosql/sqltypes/uuid.go b/coderd/rbac/regosql/sqltypes/uuid.go new file mode 100644 index 0000000000..bcf95c8411 --- /dev/null +++ b/coderd/rbac/regosql/sqltypes/uuid.go @@ -0,0 +1,114 @@ +package sqltypes + +import ( + "fmt" + "strings" + + "github.com/open-policy-agent/opa/ast" + "golang.org/x/xerrors" +) + +var ( + _ VariableMatcher = astUUIDVar{} + _ Node = astUUIDVar{} + _ SupportsEquality = astUUIDVar{} +) + +// astUUIDVar is a variable that represents a UUID column. Unlike +// astStringVar it emits native UUID comparisons (column = 'val'::uuid) +// instead of text-based ones (COALESCE(column::text, ”) = 'val'). +// This allows PostgreSQL to use indexes on UUID columns. +type astUUIDVar struct { + Source RegoSource + FieldPath []string + ColumnString string +} + +func UUIDVarMatcher(sqlColumn string, regoPath []string) VariableMatcher { + return astUUIDVar{FieldPath: regoPath, ColumnString: sqlColumn} +} + +func (astUUIDVar) UseAs() Node { return astUUIDVar{} } + +func (u astUUIDVar) ConvertVariable(rego ast.Ref) (Node, bool) { + left, err := RegoVarPath(u.FieldPath, rego) + if err == nil && len(left) == 0 { + return astUUIDVar{ + Source: RegoSource(rego.String()), + FieldPath: u.FieldPath, + ColumnString: u.ColumnString, + }, true + } + + return nil, false +} + +func (u astUUIDVar) SQLString(_ *SQLGenerator) string { + return u.ColumnString +} + +// EqualsSQLString handles equality comparisons for UUID columns. +// Rego always produces string literals, so we accept AstString and +// cast the literal to ::uuid in the output SQL. This lets PG use +// native UUID indexes instead of falling back to text comparisons. +// nolint:revive +func (u astUUIDVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) { + switch other.UseAs().(type) { + case AstString: + // The other side is a rego string literal like + // "8c0b9bdc-a013-4b14-a49b-5747bc335708". Emit a comparison + // that casts the literal to uuid so PG can use indexes: + // column = 'val'::uuid + // instead of the text-based: + // 'val' = COALESCE(column::text, '') + s, ok := other.(AstString) + if !ok { + return "", xerrors.Errorf("expected AstString, got %T", other) + } + if s.Value == "" { + // Empty string in rego means "no value". Compare the + // column against NULL since UUID columns represent + // absent values as NULL, not empty strings. + op := "IS NULL" + if not { + op = "IS NOT NULL" + } + return fmt.Sprintf("%s %s", u.ColumnString, op), nil + } + return fmt.Sprintf("%s %s '%s'::uuid", + u.ColumnString, equalsOp(not), s.Value), nil + case astUUIDVar: + return basicSQLEquality(cfg, not, u, other), nil + default: + return "", xerrors.Errorf("unsupported equality: %T %s %T", + u, equalsOp(not), other) + } +} + +// ContainedInSQL implements SupportsContainedIn so that a UUID column +// can appear in membership checks like `col = ANY(ARRAY[...])`. The +// array elements are rego strings, so we cast each to ::uuid. +func (u astUUIDVar) ContainedInSQL(_ *SQLGenerator, haystack Node) (string, error) { + arr, ok := haystack.(ASTArray) + if !ok { + return "", xerrors.Errorf("unsupported containedIn: %T in %T", u, haystack) + } + + if len(arr.Value) == 0 { + return "false", nil + } + + // Build ARRAY['uuid1'::uuid, 'uuid2'::uuid, ...] + values := make([]string, 0, len(arr.Value)) + for _, v := range arr.Value { + s, ok := v.(AstString) + if !ok { + return "", xerrors.Errorf("expected AstString array element, got %T", v) + } + values = append(values, fmt.Sprintf("'%s'::uuid", s.Value)) + } + + return fmt.Sprintf("%s = ANY(ARRAY [%s])", + u.ColumnString, + strings.Join(values, ",")), nil +} diff --git a/coderd/searchquery/search.go b/coderd/searchquery/search.go index 330c2e6eb4..260ba792fc 100644 --- a/coderd/searchquery/search.go +++ b/coderd/searchquery/search.go @@ -66,7 +66,7 @@ func AuditLogs(ctx context.Context, db database.Store, query string) (database.G } // Prepare the count filter, which uses the same parameters as the GetAuditLogsOffsetParams. - // nolint:exhaustruct // UserID is not obtained from the query parameters. + // nolint:exhaustruct // UserID and CountCap are not obtained from the query parameters. countFilter := database.CountAuditLogsParams{ RequestID: filter.RequestID, ResourceID: filter.ResourceID, @@ -123,6 +123,7 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey } // This MUST be kept in sync with the above + // nolint:exhaustruct // CountCap is not obtained from the query parameters. countFilter := database.CountConnectionLogsParams{ OrganizationID: filter.OrganizationID, WorkspaceOwner: filter.WorkspaceOwner, diff --git a/codersdk/audit.go b/codersdk/audit.go index 5018982c6c..ac0b4e908f 100644 --- a/codersdk/audit.go +++ b/codersdk/audit.go @@ -212,6 +212,7 @@ type AuditLogsRequest struct { type AuditLogResponse struct { AuditLogs []AuditLog `json:"audit_logs"` Count int64 `json:"count"` + CountCap int64 `json:"count_cap"` } type CreateTestAuditLogRequest struct { diff --git a/codersdk/connectionlog.go b/codersdk/connectionlog.go index 3e2acec6df..61e1ccbb30 100644 --- a/codersdk/connectionlog.go +++ b/codersdk/connectionlog.go @@ -96,6 +96,7 @@ type ConnectionLogsRequest struct { type ConnectionLogResponse struct { ConnectionLogs []ConnectionLog `json:"connection_logs"` Count int64 `json:"count"` + CountCap int64 `json:"count_cap"` } func (c *Client) ConnectionLogs(ctx context.Context, req ConnectionLogsRequest) (ConnectionLogResponse, error) { diff --git a/docs/reference/api/audit.md b/docs/reference/api/audit.md index bfdc1a259e..8ae32c1295 100644 --- a/docs/reference/api/audit.md +++ b/docs/reference/api/audit.md @@ -90,7 +90,8 @@ curl -X GET http://coder-server:8080/api/v2/audit?limit=0 \ "user_agent": "string" } ], - "count": 0 + "count": 0, + "count_cap": 0 } ``` diff --git a/docs/reference/api/enterprise.md b/docs/reference/api/enterprise.md index 7b16911c5c..439de03cd3 100644 --- a/docs/reference/api/enterprise.md +++ b/docs/reference/api/enterprise.md @@ -291,7 +291,8 @@ curl -X GET http://coder-server:8080/api/v2/connectionlog?limit=0 \ "workspace_owner_username": "string" } ], - "count": 0 + "count": 0, + "count_cap": 0 } ``` diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index b10bf485cb..7b37a35624 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -1740,7 +1740,8 @@ "user_agent": "string" } ], - "count": 0 + "count": 0, + "count_cap": 0 } ``` @@ -1750,6 +1751,7 @@ |--------------|-------------------------------------------------|----------|--------------|-------------| | `audit_logs` | array of [codersdk.AuditLog](#codersdkauditlog) | false | | | | `count` | integer | false | | | +| `count_cap` | integer | false | | | ## codersdk.AuthMethod @@ -2173,7 +2175,8 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in "workspace_owner_username": "string" } ], - "count": 0 + "count": 0, + "count_cap": 0 } ``` @@ -2183,6 +2186,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in |-------------------|-----------------------------------------------------------|----------|--------------|-------------| | `connection_logs` | array of [codersdk.ConnectionLog](#codersdkconnectionlog) | false | | | | `count` | integer | false | | | +| `count_cap` | integer | false | | | ## codersdk.ConnectionLogSSHInfo diff --git a/enterprise/coderd/connectionlog.go b/enterprise/coderd/connectionlog.go index 05e3a40b2d..c37e2ce497 100644 --- a/enterprise/coderd/connectionlog.go +++ b/enterprise/coderd/connectionlog.go @@ -16,6 +16,9 @@ import ( "github.com/coder/coder/v2/codersdk" ) +// NOTE: See the auditLogCountCap note. +const connectionLogCountCap = 2000 + // @Summary Get connection logs // @ID get-connection-logs // @Security CoderSessionToken @@ -49,6 +52,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) { // #nosec G115 - Safe conversion as pagination limit is expected to be within int32 range filter.LimitOpt = int32(page.Limit) + countFilter.CountCap = connectionLogCountCap count, err := api.Database.CountConnectionLogs(ctx, countFilter) if dbauthz.IsNotAuthorizedError(err) { httpapi.Forbidden(rw) @@ -63,6 +67,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{ ConnectionLogs: []codersdk.ConnectionLog{}, Count: 0, + CountCap: connectionLogCountCap, }) return } @@ -80,6 +85,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{ ConnectionLogs: convertConnectionLogs(dblogs), Count: count, + CountCap: connectionLogCountCap, }) } diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 6df4b9b6a9..3fe98238fe 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -913,6 +913,7 @@ export interface AuditLog { export interface AuditLogResponse { readonly audit_logs: readonly AuditLog[]; readonly count: number; + readonly count_cap: number; } // From codersdk/audit.go @@ -2269,6 +2270,7 @@ export interface ConnectionLog { export interface ConnectionLogResponse { readonly connection_logs: readonly ConnectionLog[]; readonly count: number; + readonly count_cap: number; } // From codersdk/connectionlog.go diff --git a/site/src/components/PaginationWidget/PaginationAmount.tsx b/site/src/components/PaginationWidget/PaginationAmount.tsx index 5e9f62b3af..a8d53da276 100644 --- a/site/src/components/PaginationWidget/PaginationAmount.tsx +++ b/site/src/components/PaginationWidget/PaginationAmount.tsx @@ -7,6 +7,7 @@ type PaginationHeaderProps = { limit: number; totalRecords: number | undefined; currentOffsetStart: number | undefined; + countIsCapped?: boolean; // Temporary escape hatch until Workspaces can be switched over to using // PaginationContainer @@ -18,6 +19,7 @@ export const PaginationAmount: FC = ({ limit, totalRecords, currentOffsetStart, + countIsCapped, className, }) => { const theme = useTheme(); @@ -52,10 +54,16 @@ export const PaginationAmount: FC = ({ {( currentOffsetStart + - Math.min(limit - 1, totalRecords - currentOffsetStart) + (countIsCapped + ? limit - 1 + : Math.min(limit - 1, totalRecords - currentOffsetStart)) ).toLocaleString()} {" "} - of {totalRecords.toLocaleString()}{" "} + of{" "} + + {totalRecords.toLocaleString()} + {countIsCapped && "+"} + {" "} {paginationUnitLabel} )} diff --git a/site/src/components/PaginationWidget/PaginationContainer.mocks.ts b/site/src/components/PaginationWidget/PaginationContainer.mocks.ts index e638e1e3db..7466529af5 100644 --- a/site/src/components/PaginationWidget/PaginationContainer.mocks.ts +++ b/site/src/components/PaginationWidget/PaginationContainer.mocks.ts @@ -18,6 +18,7 @@ export const mockPaginationResultBase: ResultBase = { limit: 25, hasNextPage: false, hasPreviousPage: false, + countIsCapped: false, goToPreviousPage: () => {}, goToNextPage: () => {}, goToFirstPage: () => {}, @@ -33,6 +34,7 @@ export const mockInitialRenderResult: PaginationResult = { hasPreviousPage: false, totalRecords: undefined, totalPages: undefined, + countIsCapped: false, }; export const mockSuccessResult: PaginationResult = { diff --git a/site/src/components/PaginationWidget/PaginationContainer.stories.tsx b/site/src/components/PaginationWidget/PaginationContainer.stories.tsx index 23ea7700b7..1cf3904b3a 100644 --- a/site/src/components/PaginationWidget/PaginationContainer.stories.tsx +++ b/site/src/components/PaginationWidget/PaginationContainer.stories.tsx @@ -94,7 +94,7 @@ export const FirstPageWithTonsOfData: Story = { currentPage: 2, currentOffsetStart: 1000, totalRecords: 123_456, - totalPages: 1235, + totalPages: 4939, hasPreviousPage: false, hasNextPage: true, isPlaceholderData: false, @@ -135,3 +135,54 @@ export const SecondPageWithData: Story = { children:
New data for page 2
, }, }; + +export const CappedCountFirstPage: Story = { + args: { + query: { + ...mockPaginationResultBase, + isSuccess: true, + currentPage: 1, + currentOffsetStart: 1, + totalRecords: 2000, + totalPages: 80, + hasPreviousPage: false, + hasNextPage: true, + isPlaceholderData: false, + countIsCapped: true, + }, + }, +}; + +export const CappedCountMiddlePage: Story = { + args: { + query: { + ...mockPaginationResultBase, + isSuccess: true, + currentPage: 3, + currentOffsetStart: 51, + totalRecords: 2000, + totalPages: 80, + hasPreviousPage: true, + hasNextPage: true, + isPlaceholderData: false, + countIsCapped: true, + }, + }, +}; + +export const CappedCountBeyondKnownPages: Story = { + args: { + query: { + ...mockPaginationResultBase, + isSuccess: true, + currentPage: 85, + currentOffsetStart: 2101, + totalRecords: 2000, + totalPages: 85, + hasPreviousPage: true, + hasNextPage: true, + isPlaceholderData: false, + countIsCapped: true, + }, + }, +}; diff --git a/site/src/components/PaginationWidget/PaginationContainer.tsx b/site/src/components/PaginationWidget/PaginationContainer.tsx index b8ebd7d7b7..ce43d98e47 100644 --- a/site/src/components/PaginationWidget/PaginationContainer.tsx +++ b/site/src/components/PaginationWidget/PaginationContainer.tsx @@ -27,12 +27,14 @@ export const PaginationContainer: FC = ({ totalRecords={query.totalRecords} currentOffsetStart={query.currentOffsetStart} paginationUnitLabel={paginationUnitLabel} + countIsCapped={query.countIsCapped} className="justify-end" /> {query.isSuccess && ( = ({ @@ -21,8 +25,9 @@ export const PaginationWidgetBase: FC = ({ onPageChange, hasPreviousPage, hasNextPage, + totalPages: totalPagesProp, }) => { - const totalPages = Math.ceil(totalRecords / pageSize); + const totalPages = totalPagesProp ?? Math.ceil(totalRecords / pageSize); if (totalPages < 2) { return null; diff --git a/site/src/hooks/usePaginatedQuery.test.ts b/site/src/hooks/usePaginatedQuery.test.ts index 060e44e07c..46044f495f 100644 --- a/site/src/hooks/usePaginatedQuery.test.ts +++ b/site/src/hooks/usePaginatedQuery.test.ts @@ -258,6 +258,78 @@ describe(usePaginatedQuery.name, () => { }); }); + describe("Capped count behavior", () => { + const mockQueryKey = vi.fn(() => ["mock"]); + + // Returns count 2001 (capped) with items on pages up to page 84 + // (84 * 25 = 2100 items total). + const mockCappedQueryFn = vi.fn(({ pageNumber, limit }) => { + const totalItems = 2100; + const offset = (pageNumber - 1) * limit; + // Returns 0 items when the requested page is past the end, simulating + // an empty server response. + const itemsOnPage = Math.max(0, Math.min(limit, totalItems - offset)); + return Promise.resolve({ + data: new Array(itemsOnPage).fill(pageNumber), + count: 2001, + count_cap: 2000, + }); + }); + + it("Caps totalRecords at 2000 when count exceeds cap", async () => { + const { result } = await render({ + queryKey: mockQueryKey, + queryFn: mockCappedQueryFn, + }); + + await waitFor(() => expect(result.current.isSuccess).toBe(true)); + expect(result.current.totalRecords).toBe(2000); + }); + + it("hasNextPage is true when count is capped", async () => { + const { result } = await render( + { queryKey: mockQueryKey, queryFn: mockCappedQueryFn }, + "/?page=80", + ); + + await waitFor(() => expect(result.current.isSuccess).toBe(true)); + expect(result.current.hasNextPage).toBe(true); + }); + + it("hasPreviousPage is true when count is capped and page is beyond cap", async () => { + const { result } = await render( + { queryKey: mockQueryKey, queryFn: mockCappedQueryFn }, + "/?page=83", + ); + + await waitFor(() => expect(result.current.isSuccess).toBe(true)); + expect(result.current.hasPreviousPage).toBe(true); + }); + + it("Does not redirect to last page when count is capped and page is valid", async () => { + const { result } = await render( + { queryKey: mockQueryKey, queryFn: mockCappedQueryFn }, + "/?page=83", + ); + + await waitFor(() => expect(result.current.isSuccess).toBe(true)); + // Should stay on page 83 — not redirect to page 80. + expect(result.current.currentPage).toBe(83); + }); + + it("Redirects to last known page when navigating beyond actual data", async () => { + const { result } = await render( + { queryKey: mockQueryKey, queryFn: mockCappedQueryFn }, + "/?page=999", + ); + + // Page 999 has no items. Should redirect to page 81 + // (ceil(2001 / 25) = 81), the last page guaranteed to + // have data. + await waitFor(() => expect(result.current.currentPage).toBe(81)); + }); + }); + describe("Passing in searchParams property", () => { const mockQueryKey = vi.fn(() => ["mock"]); const mockQueryFn = vi.fn(({ pageNumber, limit }) => diff --git a/site/src/hooks/usePaginatedQuery.ts b/site/src/hooks/usePaginatedQuery.ts index 200674d69c..1ad03272a7 100644 --- a/site/src/hooks/usePaginatedQuery.ts +++ b/site/src/hooks/usePaginatedQuery.ts @@ -144,16 +144,44 @@ export function usePaginatedQuery< placeholderData: keepPreviousData, }); - const totalRecords = query.data?.count; - const totalPages = - totalRecords !== undefined ? Math.ceil(totalRecords / limit) : undefined; + const count = query.data?.count; + const countCap = query.data?.count_cap; + const countIsCapped = + countCap !== undefined && + countCap > 0 && + count !== undefined && + count > countCap; + const totalRecords = countIsCapped ? countCap : count; + let totalPages = + totalRecords !== undefined + ? Math.max( + Math.ceil(totalRecords / limit), + // True count is not known; let them navigate forward + // until they hit an empty page (checked below). + countIsCapped ? currentPage : 0, + ) + : undefined; + + // When the true count is unknown, the user can navigate past + // all actual data. If that happens, we need to redirect (via + // updatePageIfInvalid) to the last page guaranteed to be not + // empty. + const pageIsEmpty = + query.data != null && + !Object.values(query.data).some((v) => Array.isArray(v) && v.length > 0); + if (pageIsEmpty) { + totalPages = count !== undefined ? Math.ceil(count / limit) : 1; + } const hasNextPage = - totalRecords !== undefined && limit + currentPageOffset < totalRecords; + totalRecords !== undefined && + ((countIsCapped && !pageIsEmpty) || + limit + currentPageOffset < totalRecords); const hasPreviousPage = totalRecords !== undefined && currentPage > 1 && - currentPageOffset - limit < totalRecords; + ((countIsCapped && !pageIsEmpty) || + currentPageOffset - limit < totalRecords); const queryClient = useQueryClient(); const prefetchPage = useEffectEvent((newPage: number) => { @@ -224,10 +252,14 @@ export function usePaginatedQuery< }); useEffect(() => { - if (!query.isFetching && totalPages !== undefined) { + if ( + !query.isFetching && + totalPages !== undefined && + currentPage > totalPages + ) { void updatePageIfInvalid(totalPages); } - }, [updatePageIfInvalid, query.isFetching, totalPages]); + }, [updatePageIfInvalid, query.isFetching, totalPages, currentPage]); const onPageChange = (newPage: number) => { // Page 1 is the only page that can be safely navigated to without knowing @@ -236,7 +268,12 @@ export function usePaginatedQuery< return; } - const cleanedInput = clamp(Math.trunc(newPage), 1, totalPages ?? 1); + // If the true count is unknown, we allow navigating past the + // known page range. + const upperBound = countIsCapped + ? Number.MAX_SAFE_INTEGER + : (totalPages ?? 1); + const cleanedInput = clamp(Math.trunc(newPage), 1, upperBound); if (Number.isNaN(cleanedInput)) { return; } @@ -274,6 +311,7 @@ export function usePaginatedQuery< totalRecords: totalRecords as number, totalPages: totalPages as number, currentOffsetStart: currentPageOffset + 1, + countIsCapped, } : { isSuccess: false, @@ -282,6 +320,7 @@ export function usePaginatedQuery< totalRecords: undefined, totalPages: undefined, currentOffsetStart: undefined, + countIsCapped: false as const, }), }; @@ -323,6 +362,7 @@ export type PaginationResultInfo = { totalRecords: undefined; totalPages: undefined; currentOffsetStart: undefined; + countIsCapped: false; } | { isSuccess: true; @@ -331,6 +371,7 @@ export type PaginationResultInfo = { totalRecords: number; totalPages: number; currentOffsetStart: number; + countIsCapped: boolean; } ); @@ -417,6 +458,7 @@ type QueryPageParamsWithPayload = QueryPageParams & { */ export type PaginatedData = { count: number; + count_cap?: number; }; /** diff --git a/site/src/pages/AuditPage/AuditPage.test.tsx b/site/src/pages/AuditPage/AuditPage.test.tsx index 75d91e7acb..fb0d82a3f3 100644 --- a/site/src/pages/AuditPage/AuditPage.test.tsx +++ b/site/src/pages/AuditPage/AuditPage.test.tsx @@ -71,6 +71,7 @@ describe("AuditPage", () => { const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({ audit_logs: [MockAuditLog, MockAuditLog2], count: 2, + count_cap: 0, }); // When @@ -90,6 +91,7 @@ describe("AuditPage", () => { vi.spyOn(API, "getAuditLogs").mockResolvedValue({ audit_logs: [MockAuditLog], count: 1, + count_cap: 0, }); await renderPage(); @@ -114,6 +116,7 @@ describe("AuditPage", () => { vi.spyOn(API, "getAuditLogs").mockResolvedValue({ audit_logs: [MockAuditLog], count: 1, + count_cap: 0, }); await renderPage(); @@ -140,9 +143,11 @@ describe("AuditPage", () => { describe("Filtering", () => { it("filters by URL", async () => { - const getAuditLogsSpy = vi - .spyOn(API, "getAuditLogs") - .mockResolvedValue({ audit_logs: [MockAuditLog], count: 1 }); + const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({ + audit_logs: [MockAuditLog], + count: 1, + count_cap: 0, + }); const query = "resource_type:workspace action:create"; await renderPage({ filter: query }); @@ -173,4 +178,29 @@ describe("AuditPage", () => { ); }); }); + + describe("Capped count", () => { + it("shows capped count indicator and navigates to next page with correct offset", async () => { + vi.spyOn(API, "getAuditLogs").mockResolvedValue({ + audit_logs: [MockAuditLog, MockAuditLog2], + count: 2001, + count_cap: 2000, + }); + + const user = userEvent.setup(); + await renderPage(); + + await screen.findByText(/2,000\+/); + + await user.click(screen.getByRole("button", { name: /next page/i })); + + await waitFor(() => + expect(API.getAuditLogs).toHaveBeenLastCalledWith<[AuditLogsRequest]>({ + limit: DEFAULT_RECORDS_PER_PAGE, + offset: DEFAULT_RECORDS_PER_PAGE, + q: "", + }), + ); + }); + }); }); diff --git a/site/src/pages/ConnectionLogPage/ConnectionLogPage.test.tsx b/site/src/pages/ConnectionLogPage/ConnectionLogPage.test.tsx index c555839ab9..3fff5aa896 100644 --- a/site/src/pages/ConnectionLogPage/ConnectionLogPage.test.tsx +++ b/site/src/pages/ConnectionLogPage/ConnectionLogPage.test.tsx @@ -69,6 +69,7 @@ describe("ConnectionLogPage", () => { MockDisconnectedSSHConnectionLog, ], count: 2, + count_cap: 0, }); // When @@ -95,6 +96,7 @@ describe("ConnectionLogPage", () => { .mockResolvedValue({ connection_logs: [MockConnectedSSHConnectionLog], count: 1, + count_cap: 0, }); const query = "type:ssh status:ongoing";