mirror of
https://github.com/coder/coder.git
synced 2026-06-03 21:18:24 +00:00
b779c9ee33
## Problem

The chat listing endpoint (`GetChatsByOwnerID`) was using `fetchWithPostFilter`, which fetches N rows from the database and then filters them in Go memory using RBAC checks. This causes a pagination bug: if the user requests `limit=25` but some rows fail the auth check, fewer than 25 rows are returned even though more authorized rows exist in the database. The client may incorrectly assume it has reached the end of the list.

## Solution

Switch to the same pattern used by `GetWorkspaces`, `GetTemplates`, and `GetUsers`: `prepareSQLFilter` + `GetAuthorized*` variant. The RBAC filter is compiled to a SQL WHERE clause and injected into the query before `ORDER BY`/`LIMIT`, so the database returns exactly the requested number of authorized rows.

Additionally, `GetChatsByOwnerID` is renamed to `GetChats` with `OwnerID` as an optional (nullable) filter parameter, matching the `GetWorkspaces` naming convention.

## Changes

| File | Change |
|------|--------|
| `queries/chats.sql` | Renamed to `GetChats`, `owner_id` now optional via CASE/NULL, added `-- @authorize_filter` |
| `queries.sql.go` | Renamed constant, params struct (`GetChatsParams`), and method |
| `querier.go` | Interface method renamed |
| `modelqueries.go` | Added `chatQuerier` interface + `GetAuthorizedChats` impl |
| `dbauthz/dbauthz.go` | `GetChats` now uses `prepareSQLFilter` instead of `fetchWithPostFilter` |
| `dbauthz/dbauthz_test.go` | Updated tests for SQL filter pattern |
| `dbmock/dbmock.go` | Renamed + added mock for `GetAuthorizedChats` |
| `dbmetrics/querymetrics.go` | Renamed + added metrics wrapper |
| `rbac/regosql/configs.go` | Added `ChatConverter` (maps `org_owner` to empty string literal since `chats` has no `organization_id` column) |
| `rbac/authz.go` | Added `ConfigChats()` |
| `chats.go` | Handler uses renamed method with `uuid.NullUUID` |
| `searchquery/search.go` | Updated return type |
| `gitsync/worker.go` | Updated interface and call site |
| Various test files | Updated for renamed types |
963 lines
30 KiB
Go
963 lines
30 KiB
Go
package gitsync_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbgen"
|
|
"github.com/coder/coder/v2/coderd/database/dbmock"
|
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
|
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
|
|
"github.com/coder/coder/v2/coderd/gitsync"
|
|
"github.com/coder/coder/v2/coderd/util/ptr"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
// testRefresherCfg configures newTestRefresher.
|
|
type testRefresherCfg struct {
|
|
resolveBranchPR func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error)
|
|
fetchPRStatus func(context.Context, string, gitprovider.PRRef) (*gitprovider.PRStatus, error)
|
|
refresherOpts []gitsync.RefresherOption
|
|
}
|
|
|
|
type testRefresherOpt func(*testRefresherCfg)
|
|
|
|
func withResolveBranchPR(f func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error)) testRefresherOpt {
|
|
return func(c *testRefresherCfg) { c.resolveBranchPR = f }
|
|
}
|
|
|
|
func withRefresherOpts(opts ...gitsync.RefresherOption) testRefresherOpt {
|
|
return func(c *testRefresherCfg) { c.refresherOpts = opts }
|
|
}
|
|
|
|
// newTestRefresher creates a Refresher backed by mock
|
|
// provider/token resolvers. The provider recognises any origin,
|
|
// resolves branches to a canned PR, and returns a canned PRStatus.
|
|
func newTestRefresher(t *testing.T, clk quartz.Clock, opts ...testRefresherOpt) *gitsync.Refresher {
|
|
t.Helper()
|
|
|
|
cfg := testRefresherCfg{
|
|
resolveBranchPR: func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) {
|
|
return &gitprovider.PRRef{Owner: "o", Repo: "r", Number: 1}, nil
|
|
},
|
|
fetchPRStatus: func(context.Context, string, gitprovider.PRRef) (*gitprovider.PRStatus, error) {
|
|
return &gitprovider.PRStatus{
|
|
State: gitprovider.PRStateOpen,
|
|
DiffStats: gitprovider.DiffStats{
|
|
Additions: 10,
|
|
Deletions: 3,
|
|
ChangedFiles: 2,
|
|
},
|
|
}, nil
|
|
},
|
|
}
|
|
for _, o := range opts {
|
|
o(&cfg)
|
|
}
|
|
|
|
prov := &mockProvider{
|
|
parseRepositoryOrigin: func(string) (string, string, string, bool) {
|
|
return "owner", "repo", "https://github.com/owner/repo", true
|
|
},
|
|
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
|
|
return gitprovider.PRRef{Owner: "owner", Repo: "repo", Number: 1}, raw != ""
|
|
},
|
|
resolveBranchPR: cfg.resolveBranchPR,
|
|
fetchPullRequestStatus: cfg.fetchPRStatus,
|
|
buildPullRequestURL: func(ref gitprovider.PRRef) string {
|
|
return fmt.Sprintf("https://github.com/%s/%s/pull/%d", ref.Owner, ref.Repo, ref.Number)
|
|
},
|
|
}
|
|
|
|
providers := func(string) gitprovider.Provider { return prov }
|
|
tokens := func(context.Context, uuid.UUID, string) (*string, error) {
|
|
return ptr.Ref("tok"), nil
|
|
}
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
return gitsync.NewRefresher(providers, tokens, logger, clk, cfg.refresherOpts...)
|
|
}
|
|
|
|
// makeAcquiredRowWithBranch returns an AcquireStaleChatDiffStatusesRow with
|
|
// the given branch and a non-empty origin so the Refresher goes through the
|
|
// branch-resolution path.
|
|
func makeAcquiredRowWithBranch(chatID, ownerID uuid.UUID, branch string) database.AcquireStaleChatDiffStatusesRow {
|
|
return database.AcquireStaleChatDiffStatusesRow{
|
|
ChatID: chatID,
|
|
GitBranch: branch,
|
|
GitRemoteOrigin: "https://github.com/owner/repo",
|
|
StaleAt: time.Now().Add(-time.Minute),
|
|
OwnerID: ownerID,
|
|
}
|
|
}
|
|
|
|
// tickOnce traps the worker's NewTicker call, starts the worker,
|
|
// fires one tick, waits for it to finish by observing the given
|
|
// tickDone channel, then shuts the worker down. The tickDone
|
|
// channel must be closed when the last expected operation in the
|
|
// tick completes. For tests where the tick does nothing (e.g. 0
|
|
// stale rows or store error), tickDone should be closed inside
|
|
// acquireStaleChatDiffStatuses.
|
|
func tickOnce(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
mClock *quartz.Mock,
|
|
worker *gitsync.Worker,
|
|
tickDone <-chan struct{},
|
|
) {
|
|
t.Helper()
|
|
|
|
trap := mClock.Trap().NewTicker("gitsync", "worker")
|
|
defer trap.Close()
|
|
|
|
workerCtx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
go worker.Start(workerCtx)
|
|
|
|
// Wait for the worker to create its ticker.
|
|
trap.MustWait(ctx).MustRelease(ctx)
|
|
|
|
// Fire one tick. The waiter resolves when the channel receive
|
|
// completes, not when w.tick() returns, so we use tickDone to
|
|
// know when to proceed.
|
|
_, w := mClock.AdvanceNext()
|
|
w.MustWait(ctx)
|
|
|
|
// Wait for the tick's business logic to finish.
|
|
select {
|
|
case <-tickDone:
|
|
case <-ctx.Done():
|
|
t.Fatal("timed out waiting for tick to complete")
|
|
}
|
|
|
|
cancel()
|
|
<-worker.Done()
|
|
}
|
|
|
|
func TestWorker_SkipsFreshRows(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
tickDone := make(chan struct{})
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
|
|
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
|
DoAndReturn(func(context.Context, int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
|
|
// No stale rows — tick returns immediately.
|
|
close(tickDone)
|
|
return nil, nil
|
|
})
|
|
|
|
mClock := quartz.NewMock(t)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
refresher := newTestRefresher(t, mClock)
|
|
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
|
|
|
|
tickOnce(ctx, t, mClock, worker, tickDone)
|
|
}
|
|
|
|
func TestWorker_LimitsToNRows(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
var capturedLimit atomic.Int32
|
|
var upsertCount atomic.Int32
|
|
ownerID := uuid.New()
|
|
const numRows = 5
|
|
tickDone := make(chan struct{})
|
|
|
|
rows := make([]database.AcquireStaleChatDiffStatusesRow, numRows)
|
|
for i := range rows {
|
|
rows[i] = makeAcquiredRowWithBranch(uuid.New(), ownerID, "feature")
|
|
}
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
|
|
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
|
DoAndReturn(func(_ context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
|
|
capturedLimit.Store(limitVal)
|
|
return rows, nil
|
|
})
|
|
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
|
|
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
|
upsertCount.Add(1)
|
|
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
|
}).Times(numRows)
|
|
|
|
pub := func(_ context.Context, _ uuid.UUID) error {
|
|
if upsertCount.Load() == numRows {
|
|
close(tickDone)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
mClock := quartz.NewMock(t)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
refresher := newTestRefresher(t, mClock)
|
|
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
|
|
|
tickOnce(ctx, t, mClock, worker, tickDone)
|
|
|
|
// The default batch size is 50.
|
|
assert.Equal(t, int32(50), capturedLimit.Load())
|
|
assert.Equal(t, int32(numRows), upsertCount.Load())
|
|
}
|
|
|
|
func TestWorker_RefresherReturnsNilNil_SkipsUpsert(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
|
|
// When the Refresher returns (nil, nil) the worker skips the
|
|
// upsert and publish. We signal tickDone from the refresher
|
|
// mock since that is the last operation before the tick
|
|
// returns.
|
|
tickDone := make(chan struct{})
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
|
|
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
|
Return([]database.AcquireStaleChatDiffStatusesRow{makeAcquiredRowWithBranch(chatID, ownerID, "feature")}, nil)
|
|
|
|
mClock := quartz.NewMock(t)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
|
|
// ResolveBranchPullRequest returns nil → Refresher returns
|
|
// (nil, nil).
|
|
refresher := newTestRefresher(t, mClock, withResolveBranchPR(
|
|
func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) {
|
|
close(tickDone)
|
|
return nil, nil
|
|
},
|
|
))
|
|
|
|
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
|
|
|
|
tickOnce(ctx, t, mClock, worker, tickDone)
|
|
}
|
|
|
|
// TestWorker_RefresherError_BacksOffRow checks that a Refresher error for
// one row makes the worker call BackoffChatDiffStatus for that row, with
// stale_at pushed out by DiffStatusTTL. The other row in the same batch
// is still upserted and published.
func TestWorker_RefresherError_BacksOffRow(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)

	chat1 := uuid.New()
	chat2 := uuid.New()
	ownerID := uuid.New()

	var upsertCount atomic.Int32
	var publishCount atomic.Int32
	var backoffCount atomic.Int32
	// mu guards backoffArgs. The BackoffChatDiffStatus mock appends to it
	// and the assertions at the end read it.
	var mu sync.Mutex
	var backoffArgs []database.BackoffChatDiffStatusParams
	tickDone := make(chan struct{})
	var closeOnce sync.Once

	// Two rows processed: one fails (backoff), one succeeds
	// (upsert+publish). Both must finish before we close tickDone.
	var terminalOps atomic.Int32
	signalIfDone := func() {
		if terminalOps.Add(1) == 2 {
			closeOnce.Do(func() { close(tickDone) })
		}
	}

	mClock := quartz.NewMock(t)

	ctrl := gomock.NewController(t)
	store := dbmock.NewMockStore(ctrl)

	store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
		Return([]database.AcquireStaleChatDiffStatusesRow{
			makeAcquiredRowWithBranch(chat1, ownerID, "fail-branch"),
			makeAcquiredRowWithBranch(chat2, ownerID, "success-branch"),
		}, nil)
	// Terminal event for the failing row (chat1 on "fail-branch").
	store.EXPECT().BackoffChatDiffStatus(gomock.Any(), gomock.Any()).
		DoAndReturn(func(_ context.Context, arg database.BackoffChatDiffStatusParams) error {
			backoffCount.Add(1)
			mu.Lock()
			backoffArgs = append(backoffArgs, arg)
			mu.Unlock()
			signalIfDone()
			return nil
		})
	store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
		DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
			upsertCount.Add(1)
			return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
		})

	pub := func(_ context.Context, _ uuid.UUID) error {
		// Only the successful row publishes.
		publishCount.Add(1)
		signalIfDone()
		return nil
	}

	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})

	// Fail ResolveBranchPullRequest based on the branch name
	// so the behavior is deterministic regardless of execution
	// order.
	refresher := newTestRefresher(t, mClock, withResolveBranchPR(
		func(_ context.Context, _ string, ref gitprovider.BranchRef) (*gitprovider.PRRef, error) {
			if ref.Branch == "fail-branch" {
				return nil, fmt.Errorf("simulated provider error")
			}
			return &gitprovider.PRRef{Owner: "o", Repo: "r", Number: 1}, nil
		},
	))

	worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)

	tickOnce(ctx, t, mClock, worker, tickDone)

	// BackoffChatDiffStatus was called for the failed row.
	assert.Equal(t, int32(1), backoffCount.Load())
	mu.Lock()
	require.Len(t, backoffArgs, 1)
	assert.Equal(t, chat1, backoffArgs[0].ChatID)
	// stale_at should be approximately clock.Now() + DiffStatusTTL (120s).
	expectedStaleAt := mClock.Now().UTC().Add(gitsync.DiffStatusTTL)
	assert.WithinDuration(t, expectedStaleAt, backoffArgs[0].StaleAt, time.Second)
	mu.Unlock()

	// UpsertChatDiffStatus was called for the successful row.
	assert.Equal(t, int32(1), upsertCount.Load())
	// PublishDiffStatusChange was called only for the successful row.
	assert.Equal(t, int32(1), publishCount.Load())
}
|
|
|
|
// TestWorker_UpsertError_ContinuesNextRow checks that an
// UpsertChatDiffStatus failure for one row does not stop the worker from
// upserting and publishing the other row in the same batch.
func TestWorker_UpsertError_ContinuesNextRow(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)

	chat1 := uuid.New()
	chat2 := uuid.New()
	ownerID := uuid.New()

	var publishCount atomic.Int32
	tickDone := make(chan struct{})
	var closeOnce sync.Once
	// mu guards upsertedChatIDs. The UpsertChatDiffStatus mock writes it
	// and the assertions at the end read it.
	var mu sync.Mutex
	upsertedChatIDs := make(map[uuid.UUID]struct{})

	// We have 2 rows. The upsert for chat1 fails; the upsert
	// for chat2 succeeds and publishes. Because goroutines run
	// concurrently we don't know which finishes last, so we
	// track the total number of "terminal" events (upsert error
	// + publish success) and close tickDone when both have
	// occurred.
	var terminalOps atomic.Int32
	signalIfDone := func() {
		if terminalOps.Add(1) == 2 {
			closeOnce.Do(func() { close(tickDone) })
		}
	}

	ctrl := gomock.NewController(t)
	store := dbmock.NewMockStore(ctrl)

	store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
		Return([]database.AcquireStaleChatDiffStatusesRow{
			makeAcquiredRowWithBranch(chat1, ownerID, "feature"),
			makeAcquiredRowWithBranch(chat2, ownerID, "feature"),
		}, nil)
	store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
		DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
			if arg.ChatID == chat1 {
				// Terminal event for the failing row.
				signalIfDone()
				return database.ChatDiffStatus{}, fmt.Errorf("db write error")
			}
			mu.Lock()
			upsertedChatIDs[arg.ChatID] = struct{}{}
			mu.Unlock()
			return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
		}).Times(2)

	pub := func(_ context.Context, _ uuid.UUID) error {
		publishCount.Add(1)
		// Terminal event for the successful row.
		signalIfDone()
		return nil
	}

	mClock := quartz.NewMock(t)
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	refresher := newTestRefresher(t, mClock)
	worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)

	tickOnce(ctx, t, mClock, worker, tickDone)

	mu.Lock()
	_, gotChat2 := upsertedChatIDs[chat2]
	mu.Unlock()
	assert.True(t, gotChat2, "chat2 should have been upserted")
	// Only chat2's successful upsert is followed by a publish.
	assert.Equal(t, int32(1), publishCount.Load())
}
|
|
|
|
func TestWorker_RespectsShutdown(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
|
|
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
|
Return(nil, nil).AnyTimes()
|
|
|
|
mClock := quartz.NewMock(t)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
refresher := newTestRefresher(t, mClock)
|
|
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
|
|
|
|
trap := mClock.Trap().NewTicker("gitsync", "worker")
|
|
defer trap.Close()
|
|
|
|
workerCtx, cancel := context.WithCancel(ctx)
|
|
go worker.Start(workerCtx)
|
|
|
|
// Wait for ticker creation so the worker is running.
|
|
trap.MustWait(ctx).MustRelease(ctx)
|
|
|
|
// Cancel immediately.
|
|
cancel()
|
|
|
|
select {
|
|
case <-worker.Done():
|
|
// Success — worker shut down.
|
|
case <-ctx.Done():
|
|
t.Fatal("timed out waiting for worker to shut down")
|
|
}
|
|
}
|
|
|
|
func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
workspaceID := uuid.New()
|
|
ownerID := uuid.New()
|
|
chat1 := uuid.New()
|
|
chat2 := uuid.New()
|
|
chatOther := uuid.New()
|
|
|
|
var mu sync.Mutex
|
|
var upsertRefCalls []database.UpsertChatDiffStatusReferenceParams
|
|
var publishedIDs []uuid.UUID
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
|
|
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
|
|
DoAndReturn(func(_ context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
|
require.Equal(t, ownerID, arg.OwnerID)
|
|
return []database.Chat{
|
|
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
|
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
|
{ID: chatOther, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
|
}, nil
|
|
})
|
|
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
|
mu.Lock()
|
|
upsertRefCalls = append(upsertRefCalls, arg)
|
|
mu.Unlock()
|
|
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
|
}).Times(2)
|
|
|
|
pub := func(_ context.Context, chatID uuid.UUID) error {
|
|
mu.Lock()
|
|
publishedIDs = append(publishedIDs, chatID)
|
|
mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
mClock := quartz.NewMock(t)
|
|
now := mClock.Now()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
refresher := newTestRefresher(t, mClock)
|
|
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
|
|
|
worker.MarkStale(ctx, workspaceID, ownerID, "feature", "https://github.com/owner/repo")
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
require.Len(t, upsertRefCalls, 2)
|
|
for _, call := range upsertRefCalls {
|
|
assert.Equal(t, "feature", call.GitBranch)
|
|
assert.Equal(t, "https://github.com/owner/repo", call.GitRemoteOrigin)
|
|
assert.True(t, call.StaleAt.Before(now),
|
|
"stale_at should be in the past, got %v vs now %v", call.StaleAt, now)
|
|
assert.Equal(t, sql.NullString{}, call.Url)
|
|
}
|
|
|
|
require.Len(t, publishedIDs, 2)
|
|
assert.ElementsMatch(t, []uuid.UUID{chat1, chat2}, publishedIDs)
|
|
}
|
|
|
|
func TestWorker_MarkStale_NoMatchingChats(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
workspaceID := uuid.New()
|
|
ownerID := uuid.New()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
|
|
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
|
|
Return([]database.Chat{
|
|
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
|
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
|
}, nil)
|
|
|
|
mClock := quartz.NewMock(t)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
refresher := newTestRefresher(t, mClock)
|
|
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
|
|
|
|
worker.MarkStale(ctx, workspaceID, ownerID, "main", "https://github.com/x/y")
|
|
}
|
|
|
|
func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
workspaceID := uuid.New()
|
|
ownerID := uuid.New()
|
|
chat1 := uuid.New()
|
|
chat2 := uuid.New()
|
|
|
|
var publishCount atomic.Int32
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
|
|
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
|
|
Return([]database.Chat{
|
|
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
|
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
|
}, nil)
|
|
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).
|
|
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
|
if arg.ChatID == chat1 {
|
|
return database.ChatDiffStatus{}, fmt.Errorf("upsert ref error")
|
|
}
|
|
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
|
}).Times(2)
|
|
|
|
pub := func(_ context.Context, _ uuid.UUID) error {
|
|
publishCount.Add(1)
|
|
return nil
|
|
}
|
|
|
|
mClock := quartz.NewMock(t)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
refresher := newTestRefresher(t, mClock)
|
|
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
|
|
|
worker.MarkStale(ctx, workspaceID, ownerID, "dev", "https://github.com/a/b")
|
|
|
|
assert.Equal(t, int32(1), publishCount.Load())
|
|
}
|
|
|
|
func TestWorker_MarkStale_GetChatsFails(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
|
|
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
|
|
Return(nil, fmt.Errorf("db error"))
|
|
|
|
mClock := quartz.NewMock(t)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
refresher := newTestRefresher(t, mClock)
|
|
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
|
|
|
|
worker.MarkStale(ctx, uuid.New(), uuid.New(), "main", "https://github.com/x/y")
|
|
}
|
|
|
|
func TestWorker_TickStoreError(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
tickDone := make(chan struct{})
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
|
|
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
|
DoAndReturn(func(context.Context, int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
|
|
close(tickDone)
|
|
return nil, fmt.Errorf("database unavailable")
|
|
})
|
|
|
|
mClock := quartz.NewMock(t)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
refresher := newTestRefresher(t, mClock)
|
|
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
|
|
|
|
tickOnce(ctx, t, mClock, worker, tickDone)
|
|
}
|
|
|
|
func TestWorker_MarkStale_EmptyBranchOrOrigin(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
branch string
|
|
origin string
|
|
}{
|
|
{"both empty", "", ""},
|
|
{"branch empty", "", "https://github.com/x/y"},
|
|
{"origin empty", "main", ""},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
|
|
mClock := quartz.NewMock(t)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
refresher := newTestRefresher(t, mClock)
|
|
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
|
|
|
|
worker.MarkStale(ctx, uuid.New(), uuid.New(), tc.branch, tc.origin)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestWorker exercises the worker tick against a
// real PostgreSQL database to verify that the SQL queries, foreign key
// constraints, and upsert logic work end-to-end.
//
// Only the git provider side is faked (via newTestRefresher). The store
// is the database returned by dbtestutil.NewDB, so the seeded row must be
// picked up by the real AcquireStaleChatDiffStatuses query. The refreshed
// status is then read back with GetChatDiffStatusByChatID.
func TestWorker(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitLong)

	// 1. Real database store.
	db, _ := dbtestutil.NewDB(t)

	// 2. Create a user (FK for chats).
	user := dbgen.User(t, db, database.User{})

	// 3. Set up FK chain: chat_providers -> chat_model_configs -> chats.
	// The model config names the "openai" provider inserted first, and
	// the chat references the model config's ID, so order matters here.
	_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
		Provider:    "openai",
		DisplayName: "OpenAI",
		Enabled:     true,
	})
	require.NoError(t, err)

	modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
		Provider:             "openai",
		Model:                "test-model",
		DisplayName:          "Test Model",
		Enabled:              true,
		ContextLimit:         100000,
		CompressionThreshold: 70,
		Options:              json.RawMessage("{}"),
	})
	require.NoError(t, err)

	chat, err := db.InsertChat(ctx, database.InsertChatParams{
		OwnerID:           user.ID,
		LastModelConfigID: modelCfg.ID,
		Title:             "integration-test",
	})
	require.NoError(t, err)

	// 4. Seed a stale diff status row so the worker picks it up.
	// stale_at is a minute in the past (wall clock) so the row is due.
	_, err = db.UpsertChatDiffStatusReference(ctx, database.UpsertChatDiffStatusReferenceParams{
		ChatID:          chat.ID,
		GitBranch:       "feature",
		GitRemoteOrigin: "https://github.com/o/r",
		StaleAt:         time.Now().Add(-time.Minute),
		Url:             sql.NullString{},
	})
	require.NoError(t, err)

	// 5. Mock refresher returns a canned PR status.
	mClock := quartz.NewMock(t)
	refresher := newTestRefresher(t, mClock)

	// 6. Track publish calls. Publishing for the seeded chat is the
	// tick's last expected operation, so the first publish closes
	// tickDone.
	var publishCount atomic.Int32
	tickDone := make(chan struct{})
	pub := func(_ context.Context, chatID uuid.UUID) error {
		assert.Equal(t, chat.ID, chatID)
		if publishCount.Add(1) == 1 {
			close(tickDone)
		}
		return nil
	}

	// 7. Create and run the worker for one tick.
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
	worker := gitsync.NewWorker(db, refresher, pub, mClock, logger)

	tickOnce(ctx, t, mClock, worker, tickDone)

	// 8. Assert publisher was called.
	require.Equal(t, int32(1), publishCount.Load())

	// 9. Read back and verify persisted fields.
	status, err := db.GetChatDiffStatusByChatID(ctx, chat.ID)
	require.NoError(t, err)

	// The mock resolveBranchPR returns PRRef{Owner: "o", Repo: "r", Number: 1}
	// and buildPullRequestURL formats it as https://github.com/o/r/pull/1.
	assert.Equal(t, "https://github.com/o/r/pull/1", status.Url.String)
	assert.True(t, status.Url.Valid)
	// State and diff stats come from newTestRefresher's default
	// fetchPRStatus (open, +10/-3, 2 files).
	assert.Equal(t, string(gitprovider.PRStateOpen), status.PullRequestState.String)
	assert.True(t, status.PullRequestState.Valid)
	assert.Equal(t, int32(10), status.Additions)
	assert.Equal(t, int32(3), status.Deletions)
	assert.Equal(t, int32(2), status.ChangedFiles)
	assert.True(t, status.RefreshedAt.Valid, "refreshed_at should be set")
	// The mock clock's Now() + DiffStatusTTL determines stale_at.
	expectedStaleAt := mClock.Now().Add(gitsync.DiffStatusTTL)
	assert.WithinDuration(t, expectedStaleAt, status.StaleAt, time.Second)
}
|
|
|
|
func TestRefreshChat_Success(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
|
|
row := database.ChatDiffStatus{
|
|
ChatID: chatID,
|
|
GitBranch: "feature",
|
|
GitRemoteOrigin: "https://github.com/owner/repo",
|
|
}
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
|
|
upsertedStatus := database.ChatDiffStatus{
|
|
ChatID: chatID,
|
|
Url: sql.NullString{String: "https://github.com/o/r/pull/1", Valid: true},
|
|
Additions: 10,
|
|
Deletions: 3,
|
|
ChangedFiles: 2,
|
|
}
|
|
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
|
|
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
|
assert.Equal(t, chatID, arg.ChatID)
|
|
return upsertedStatus, nil
|
|
})
|
|
|
|
var publishCalled atomic.Bool
|
|
pub := func(_ context.Context, id uuid.UUID) error {
|
|
assert.Equal(t, chatID, id)
|
|
publishCalled.Store(true)
|
|
return nil
|
|
}
|
|
|
|
mClock := quartz.NewMock(t)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
refresher := newTestRefresher(t, mClock)
|
|
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
|
|
|
result, err := worker.RefreshChat(ctx, row, ownerID)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, chatID, result.ChatID)
|
|
assert.Equal(t, upsertedStatus.Url, result.Url)
|
|
assert.True(t, publishCalled.Load(), "publish should have been called")
|
|
}
|
|
|
|
func TestRefreshChat_NoPR(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
|
|
row := database.ChatDiffStatus{
|
|
ChatID: chatID,
|
|
GitBranch: "feature",
|
|
GitRemoteOrigin: "https://github.com/owner/repo",
|
|
}
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
// UpsertChatDiffStatus should NOT be called.
|
|
|
|
var publishCalled atomic.Bool
|
|
pub := func(_ context.Context, _ uuid.UUID) error {
|
|
publishCalled.Store(true)
|
|
return nil
|
|
}
|
|
|
|
mClock := quartz.NewMock(t)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
|
|
// ResolveBranchPullRequest returns nil → no PR exists yet.
|
|
refresher := newTestRefresher(t, mClock, withResolveBranchPR(
|
|
func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) {
|
|
return nil, nil
|
|
},
|
|
))
|
|
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
|
|
|
result, err := worker.RefreshChat(ctx, row, ownerID)
|
|
require.NoError(t, err)
|
|
assert.Nil(t, result, "result should be nil when no PR exists")
|
|
assert.False(t, publishCalled.Load(), "publish should not be called when no PR exists")
|
|
}
|
|
|
|
func TestRefreshChat_RefreshError(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
|
|
row := database.ChatDiffStatus{
|
|
ChatID: chatID,
|
|
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
|
|
GitBranch: "feature",
|
|
GitRemoteOrigin: "https://github.com/owner/repo",
|
|
}
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
// UpsertChatDiffStatus should NOT be called.
|
|
|
|
// Provider resolver returns nil → "no provider" error.
|
|
providers := func(string) gitprovider.Provider { return nil }
|
|
tokens := func(context.Context, uuid.UUID, string) (*string, error) {
|
|
return ptr.Ref("tok"), nil
|
|
}
|
|
|
|
mClock := quartz.NewMock(t)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
refresher := gitsync.NewRefresher(providers, tokens, logger, mClock)
|
|
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
|
|
|
|
result, err := worker.RefreshChat(ctx, row, ownerID)
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "no provider")
|
|
assert.Nil(t, result)
|
|
}
|
|
|
|
func TestRefreshChat_UpsertError(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
|
|
row := database.ChatDiffStatus{
|
|
ChatID: chatID,
|
|
GitBranch: "feature",
|
|
GitRemoteOrigin: "https://github.com/owner/repo",
|
|
}
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
|
|
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
|
|
Return(database.ChatDiffStatus{}, fmt.Errorf("db write error"))
|
|
|
|
var publishCalled atomic.Bool
|
|
pub := func(_ context.Context, _ uuid.UUID) error {
|
|
publishCalled.Store(true)
|
|
return nil
|
|
}
|
|
|
|
mClock := quartz.NewMock(t)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
refresher := newTestRefresher(t, mClock)
|
|
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
|
|
|
result, err := worker.RefreshChat(ctx, row, ownerID)
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "upsert chat diff status")
|
|
assert.Nil(t, result)
|
|
assert.False(t, publishCalled.Load(), "publish should not be called when upsert fails")
|
|
}
|
|
|
|
func TestWorker_NoTokenBackoff(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
|
|
var mu sync.Mutex
|
|
var backoffArgs []database.BackoffChatDiffStatusParams
|
|
tickDone := make(chan struct{})
|
|
|
|
mClock := quartz.NewMock(t)
|
|
|
|
ctrl := gomock.NewController(t)
|
|
store := dbmock.NewMockStore(ctrl)
|
|
|
|
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
|
Return([]database.AcquireStaleChatDiffStatusesRow{
|
|
makeAcquiredRowWithBranch(chatID, ownerID, "feature"),
|
|
}, nil)
|
|
store.EXPECT().BackoffChatDiffStatus(gomock.Any(), gomock.Any()).
|
|
DoAndReturn(func(_ context.Context, arg database.BackoffChatDiffStatusParams) error {
|
|
mu.Lock()
|
|
backoffArgs = append(backoffArgs, arg)
|
|
mu.Unlock()
|
|
close(tickDone)
|
|
return nil
|
|
})
|
|
|
|
// Token resolver returns empty token → ErrNoTokenAvailable.
|
|
// Provider methods should never be called.
|
|
prov := &mockProvider{}
|
|
providers := func(string) gitprovider.Provider { return prov }
|
|
tokens := func(context.Context, uuid.UUID, string) (*string, error) {
|
|
return ptr.Ref(""), nil
|
|
}
|
|
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
refresher := gitsync.NewRefresher(providers, tokens, logger, mClock)
|
|
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)
|
|
|
|
tickOnce(ctx, t, mClock, worker, tickDone)
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
require.Len(t, backoffArgs, 1)
|
|
assert.Equal(t, chatID, backoffArgs[0].ChatID)
|
|
|
|
// The backoff should use NoTokenBackoff (10min), not
|
|
// DiffStatusTTL (2min).
|
|
expectedStaleAt := mClock.Now().UTC().Add(gitsync.NoTokenBackoff)
|
|
assert.WithinDuration(t, expectedStaleAt, backoffArgs[0].StaleAt, time.Second)
|
|
}
|