mirror of
https://github.com/coder/coder.git
synced 2026-06-03 04:58:23 +00:00
b779c9ee33
## Problem

The chat listing endpoint (`GetChatsByOwnerID`) was using `fetchWithPostFilter`, which fetches N rows from the database and then filters them in Go memory using RBAC checks. This causes a pagination bug: if the user requests `limit=25` but some rows fail the auth check, fewer than 25 rows are returned even though more authorized rows exist in the database. The client may incorrectly assume it has reached the end of the list.

## Solution

Switch to the same pattern used by `GetWorkspaces`, `GetTemplates`, and `GetUsers`: `prepareSQLFilter` + `GetAuthorized*` variant. The RBAC filter is compiled to a SQL WHERE clause and injected into the query before `ORDER BY`/`LIMIT`, so the database returns exactly the requested number of authorized rows.

Additionally, `GetChatsByOwnerID` is renamed to `GetChats` with `OwnerID` as an optional (nullable) filter parameter, matching the `GetWorkspaces` naming convention.

## Changes

| File | Change |
|------|--------|
| `queries/chats.sql` | Renamed to `GetChats`, `owner_id` now optional via CASE/NULL, added `-- @authorize_filter` |
| `queries.sql.go` | Renamed constant, params struct (`GetChatsParams`), and method |
| `querier.go` | Interface method renamed |
| `modelqueries.go` | Added `chatQuerier` interface + `GetAuthorizedChats` impl |
| `dbauthz/dbauthz.go` | `GetChats` now uses `prepareSQLFilter` instead of `fetchWithPostFilter` |
| `dbauthz/dbauthz_test.go` | Updated tests for SQL filter pattern |
| `dbmock/dbmock.go` | Renamed + added mock for `GetAuthorizedChats` |
| `dbmetrics/querymetrics.go` | Renamed + added metrics wrapper |
| `rbac/regosql/configs.go` | Added `ChatConverter` (maps `org_owner` to empty string literal since `chats` has no `organization_id` column) |
| `rbac/authz.go` | Added `ConfigChats()` |
| `chats.go` | Handler uses renamed method with `uuid.NullUUID` |
| `searchquery/search.go` | Updated return type |
| `gitsync/worker.go` | Updated interface and call site |
| Various test files | Updated for renamed types |
352 lines
9.9 KiB
Go
352 lines
9.9 KiB
Go
package gitsync
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
// Defaults that NewWorker installs on every Worker.
const (
	// defaultBatchSize caps how many stale rows one tick requests
	// from the store.
	defaultBatchSize int32 = 50

	// defaultInterval is the period of the polling ticker.
	defaultInterval = 10 * time.Second

	// defaultTickTimeout bounds the duration of a single tick. It
	// is deliberately independent of the polling interval so that
	// a batch of concurrent HTTP calls has enough headroom to
	// complete.
	defaultTickTimeout = 30 * time.Second
)

// NoTokenBackoff is how far stale_at is pushed out for rows whose
// owner has no linked external-auth token. It is much longer than
// DiffStatusTTL because retrying only becomes useful once the user
// has manually linked their account.
const NoTokenBackoff = 10 * time.Minute
|
|
|
|
// Store is the narrow DB interface the Worker needs.
|
|
type Store interface {
|
|
AcquireStaleChatDiffStatuses(
|
|
ctx context.Context, limitVal int32,
|
|
) ([]database.AcquireStaleChatDiffStatusesRow, error)
|
|
BackoffChatDiffStatus(
|
|
ctx context.Context, arg database.BackoffChatDiffStatusParams,
|
|
) error
|
|
UpsertChatDiffStatus(
|
|
ctx context.Context, arg database.UpsertChatDiffStatusParams,
|
|
) (database.ChatDiffStatus, error)
|
|
UpsertChatDiffStatusReference(
|
|
ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams,
|
|
) (database.ChatDiffStatus, error)
|
|
GetChats(
|
|
ctx context.Context, arg database.GetChatsParams,
|
|
) ([]database.Chat, error)
|
|
}
|
|
|
|
// PublishDiffStatusChangeFunc notifies the frontend that the diff
// status of the chat identified by chatID has changed. The Worker
// skips publishing when the func is nil and only logs (at debug
// level) any error it returns.
|
|
type PublishDiffStatusChangeFunc func(ctx context.Context, chatID uuid.UUID) error
|
|
|
|
// Worker is a background loop that periodically refreshes stale
|
|
// chat diff statuses by delegating to a Refresher.
|
|
type Worker struct {
|
|
store Store
|
|
refresher *Refresher
|
|
publishDiffStatusChangeFn PublishDiffStatusChangeFunc
|
|
clock quartz.Clock
|
|
logger slog.Logger
|
|
batchSize int32
|
|
interval time.Duration
|
|
tickTimeout time.Duration
|
|
done chan struct{}
|
|
}
|
|
|
|
// WorkerOption configures a Worker. NewWorker applies options in
// the order given, after the package defaults have been set, so an
// option can override any default.
|
|
type WorkerOption func(*Worker)
|
|
|
|
// WithTickTimeout sets the maximum duration for a single tick.
|
|
func WithTickTimeout(d time.Duration) WorkerOption {
|
|
return func(w *Worker) {
|
|
if d > 0 {
|
|
w.tickTimeout = d
|
|
}
|
|
}
|
|
}
|
|
|
|
// NewWorker creates a Worker with default batch size and interval.
|
|
func NewWorker(
|
|
store Store,
|
|
refresher *Refresher,
|
|
publisher PublishDiffStatusChangeFunc,
|
|
clock quartz.Clock,
|
|
logger slog.Logger,
|
|
opts ...WorkerOption,
|
|
) *Worker {
|
|
w := &Worker{
|
|
store: store,
|
|
refresher: refresher,
|
|
publishDiffStatusChangeFn: publisher,
|
|
clock: clock,
|
|
logger: logger,
|
|
batchSize: defaultBatchSize,
|
|
interval: defaultInterval,
|
|
tickTimeout: defaultTickTimeout,
|
|
done: make(chan struct{}),
|
|
}
|
|
for _, o := range opts {
|
|
o(w)
|
|
}
|
|
return w
|
|
}
|
|
|
|
// Start launches the background loop. It blocks until ctx is
|
|
// cancelled, then closes w.done.
|
|
func (w *Worker) Start(ctx context.Context) {
|
|
defer close(w.done)
|
|
|
|
ticker := w.clock.NewTicker(w.interval, "gitsync", "worker")
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
w.tick(ctx)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Done returns a channel that is closed when the worker exits.
|
|
func (w *Worker) Done() <-chan struct{} {
|
|
return w.done
|
|
}
|
|
|
|
func chatDiffStatusFromRow(row database.AcquireStaleChatDiffStatusesRow) database.ChatDiffStatus {
|
|
return database.ChatDiffStatus{
|
|
ChatID: row.ChatID,
|
|
Url: row.Url,
|
|
PullRequestState: row.PullRequestState,
|
|
ChangesRequested: row.ChangesRequested,
|
|
Additions: row.Additions,
|
|
Deletions: row.Deletions,
|
|
ChangedFiles: row.ChangedFiles,
|
|
AuthorLogin: row.AuthorLogin,
|
|
AuthorAvatarUrl: row.AuthorAvatarUrl,
|
|
BaseBranch: row.BaseBranch,
|
|
HeadBranch: row.HeadBranch,
|
|
PrNumber: row.PrNumber,
|
|
Commits: row.Commits,
|
|
Approved: row.Approved,
|
|
ReviewerCount: row.ReviewerCount,
|
|
RefreshedAt: row.RefreshedAt,
|
|
StaleAt: row.StaleAt,
|
|
CreatedAt: row.CreatedAt,
|
|
UpdatedAt: row.UpdatedAt,
|
|
GitBranch: row.GitBranch,
|
|
GitRemoteOrigin: row.GitRemoteOrigin,
|
|
PullRequestTitle: row.PullRequestTitle,
|
|
PullRequestDraft: row.PullRequestDraft,
|
|
}
|
|
}
|
|
|
|
func (w *Worker) tick(ctx context.Context) {
|
|
// Use a dedicated tick timeout that is longer than the
|
|
// polling interval. This gives concurrent HTTP calls enough
|
|
// headroom without stalling the next tick excessively.
|
|
ctx, cancel := context.WithTimeout(ctx, w.tickTimeout)
|
|
defer cancel()
|
|
|
|
acquiredRows, err := w.store.AcquireStaleChatDiffStatuses(ctx, w.batchSize)
|
|
if err != nil {
|
|
w.logger.Warn(ctx, "acquire stale chat diff statuses",
|
|
slog.Error(err))
|
|
return
|
|
}
|
|
if len(acquiredRows) == 0 {
|
|
return
|
|
}
|
|
|
|
// Build refresh requests directly from acquired rows.
|
|
requests := make([]RefreshRequest, 0, len(acquiredRows))
|
|
for _, row := range acquiredRows {
|
|
requests = append(requests, RefreshRequest{
|
|
Row: chatDiffStatusFromRow(row),
|
|
OwnerID: row.OwnerID,
|
|
})
|
|
}
|
|
|
|
results, err := w.refresher.Refresh(ctx, requests)
|
|
if err != nil {
|
|
w.logger.Warn(ctx, "batch refresh chat diff statuses",
|
|
slog.Error(err))
|
|
return
|
|
}
|
|
|
|
for _, res := range results {
|
|
if res.Error != nil {
|
|
w.logger.Debug(ctx, "refresh chat diff status",
|
|
slog.F("chat_id", res.Request.Row.ChatID),
|
|
slog.Error(res.Error))
|
|
// Apply a longer backoff for rows whose owner has
|
|
// no linked token — retrying every 2 minutes is
|
|
// pointless until the user links their account.
|
|
backoff := DiffStatusTTL
|
|
if errors.Is(res.Error, ErrNoTokenAvailable) {
|
|
backoff = NoTokenBackoff
|
|
}
|
|
// Back off so the row isn't retried immediately.
|
|
if err := w.store.BackoffChatDiffStatus(ctx,
|
|
database.BackoffChatDiffStatusParams{
|
|
ChatID: res.Request.Row.ChatID,
|
|
StaleAt: w.clock.Now().UTC().Add(backoff),
|
|
},
|
|
); err != nil {
|
|
w.logger.Warn(ctx, "backoff failed chat diff status",
|
|
slog.F("chat_id", res.Request.Row.ChatID),
|
|
slog.Error(err))
|
|
}
|
|
continue
|
|
}
|
|
if res.Params == nil {
|
|
// No PR yet — skip.
|
|
continue
|
|
}
|
|
if _, err := w.store.UpsertChatDiffStatus(ctx, *res.Params); err != nil {
|
|
w.logger.Warn(ctx, "upsert refreshed chat diff status",
|
|
slog.F("chat_id", res.Request.Row.ChatID),
|
|
slog.Error(err))
|
|
continue
|
|
}
|
|
if w.publishDiffStatusChangeFn != nil {
|
|
if err := w.publishDiffStatusChangeFn(ctx, res.Request.Row.ChatID); err != nil {
|
|
w.logger.Debug(ctx, "publish diff status change",
|
|
slog.F("chat_id", res.Request.Row.ChatID),
|
|
slog.Error(err))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// MarkStale persists the git ref on all chats for a workspace,
|
|
// setting stale_at to the past so the next tick picks them up.
|
|
// Publishes a diff status event for each affected chat.
|
|
// Called from workspaceagents handlers. No goroutines spawned.
|
|
func (w *Worker) MarkStale(
|
|
ctx context.Context,
|
|
workspaceID, ownerID uuid.UUID,
|
|
branch, origin string,
|
|
) {
|
|
if branch == "" || origin == "" {
|
|
return
|
|
}
|
|
|
|
chats, err := w.store.GetChats(ctx, database.GetChatsParams{
|
|
OwnerID: ownerID,
|
|
})
|
|
if err != nil {
|
|
w.logger.Warn(ctx, "list chats for git ref storage",
|
|
slog.F("workspace_id", workspaceID),
|
|
slog.Error(err))
|
|
return
|
|
}
|
|
|
|
for _, chat := range filterChatsByWorkspaceID(chats, workspaceID) {
|
|
_, err := w.store.UpsertChatDiffStatusReference(ctx,
|
|
database.UpsertChatDiffStatusReferenceParams{
|
|
ChatID: chat.ID,
|
|
GitBranch: branch,
|
|
GitRemoteOrigin: origin,
|
|
StaleAt: w.clock.Now().Add(-time.Second),
|
|
Url: sql.NullString{},
|
|
},
|
|
)
|
|
if err != nil {
|
|
w.logger.Warn(ctx, "store git ref on chat diff status",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("workspace_id", workspaceID),
|
|
slog.Error(err))
|
|
continue
|
|
}
|
|
// Notify the frontend immediately so the UI shows the
|
|
// branch info even before the worker refreshes PR data.
|
|
if w.publishDiffStatusChangeFn != nil {
|
|
if pubErr := w.publishDiffStatusChangeFn(ctx, chat.ID); pubErr != nil {
|
|
w.logger.Debug(ctx, "publish diff status after mark stale",
|
|
slog.F("chat_id", chat.ID), slog.Error(pubErr))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// RefreshChat synchronously refreshes a single chat's diff
|
|
// status using the same Refresher pipeline as the background
|
|
// worker. Returns nil, nil when no PR exists yet for the
|
|
// branch. Called from HTTP handlers for instant feedback.
|
|
func (w *Worker) RefreshChat(
|
|
ctx context.Context,
|
|
row database.ChatDiffStatus,
|
|
ownerID uuid.UUID,
|
|
) (*database.ChatDiffStatus, error) {
|
|
requests := []RefreshRequest{{
|
|
Row: row,
|
|
OwnerID: ownerID,
|
|
}}
|
|
|
|
results, err := w.refresher.Refresh(ctx, requests)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("refresh chat diff status: %w", err)
|
|
}
|
|
|
|
if len(results) == 0 {
|
|
return nil, nil
|
|
}
|
|
res := results[0]
|
|
if res.Error != nil {
|
|
return nil, xerrors.Errorf("refresh chat diff status: %w", res.Error)
|
|
}
|
|
if res.Params == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
upserted, err := w.store.UpsertChatDiffStatus(ctx, *res.Params)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("upsert chat diff status: %w", err)
|
|
}
|
|
|
|
if w.publishDiffStatusChangeFn != nil {
|
|
if err := w.publishDiffStatusChangeFn(ctx, row.ChatID); err != nil {
|
|
w.logger.Debug(ctx, "publish diff status change",
|
|
slog.F("chat_id", row.ChatID),
|
|
slog.Error(err))
|
|
}
|
|
}
|
|
|
|
return &upserted, nil
|
|
}
|
|
|
|
// filterChatsByWorkspaceID returns only chats associated with
|
|
// the given workspace.
|
|
func filterChatsByWorkspaceID(
|
|
chats []database.Chat,
|
|
workspaceID uuid.UUID,
|
|
) []database.Chat {
|
|
filtered := make([]database.Chat, 0, len(chats))
|
|
for _, chat := range chats {
|
|
if !chat.WorkspaceID.Valid || chat.WorkspaceID.UUID != workspaceID {
|
|
continue
|
|
}
|
|
filtered = append(filtered, chat)
|
|
}
|
|
return filtered
|
|
}
|