Files
Kyle Carberry b779c9ee33 fix: use SQL-level auth filtering for chat listing (#23159)
## Problem

The chat listing endpoint (`GetChatsByOwnerID`) was using
`fetchWithPostFilter`, which fetches N rows from the database and then
filters them in Go memory using RBAC checks. This causes a pagination
bug: if the user requests `limit=25` but some rows fail the auth check,
fewer than 25 rows are returned even though more authorized rows exist
in the database. The client may incorrectly assume it has reached the
end of the list.

## Solution

Switch to the same pattern used by `GetWorkspaces`, `GetTemplates`, and
`GetUsers`: `prepareSQLFilter` + `GetAuthorized*` variant. The RBAC
filter is compiled to a SQL WHERE clause and injected into the query
before `ORDER BY`/`LIMIT`, so the database returns exactly the requested
number of authorized rows.

Additionally, `GetChatsByOwnerID` is renamed to `GetChats` with
`OwnerID` as an optional (nullable) filter parameter, matching the
`GetWorkspaces` naming convention.

## Changes

| File | Change |
|------|--------|
| `queries/chats.sql` | Renamed to `GetChats`, `owner_id` now optional via CASE/NULL, added `-- @authorize_filter` |
| `queries.sql.go` | Renamed constant, params struct (`GetChatsParams`), and method |
| `querier.go` | Interface method renamed |
| `modelqueries.go` | Added `chatQuerier` interface + `GetAuthorizedChats` impl |
| `dbauthz/dbauthz.go` | `GetChats` now uses `prepareSQLFilter` instead of `fetchWithPostFilter` |
| `dbauthz/dbauthz_test.go` | Updated tests for SQL filter pattern |
| `dbmock/dbmock.go` | Renamed + added mock for `GetAuthorizedChats` |
| `dbmetrics/querymetrics.go` | Renamed + added metrics wrapper |
| `rbac/regosql/configs.go` | Added `ChatConverter` (maps `org_owner` to an empty string literal since `chats` has no `organization_id` column) |
| `rbac/authz.go` | Added `ConfigChats()` |
| `chats.go` | Handler uses renamed method with `uuid.NullUUID` |
| `searchquery/search.go` | Updated return type |
| `gitsync/worker.go` | Updated interface and call site |
| Various test files | Updated for renamed types |
2026-03-17 12:46:24 -04:00

184 lines
5.2 KiB
Go

package gentest_test
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"slices"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestCustomQueriesSyncedRowScan makes sure the manual custom queries in
// modelqueries.go are synced with the autogenerated queries.sql.go. This should
// probably be autogenerated, but it's not atm and this is easy to throw in to
// surface a better error message.
//
// For every tracked generated query, the hand-written "authorized" variant must
// pass identical arguments to 'rows.Scan()' and 'db.QueryContext()' (the query
// string argument itself is allowed to differ).
//
// If this breaks, and is hard to fix, you can t.Skip() it. It is not a critical
// test. Ping @Emyrk to fix it again.
func TestCustomQueriesSyncedRowScan(t *testing.T) {
	t.Parallel()

	// Generated query name (queries.sql.go) -> hand-written authorized variant
	// (modelqueries.go).
	funcsToTrack := map[string]string{
		"GetTemplatesWithFilter": "GetAuthorizedTemplates",
		"GetWorkspaces":          "GetAuthorizedWorkspaces",
		"GetUsers":               "GetAuthorizedUsers",
		"GetChats":               "GetAuthorizedChats",
	}

	// Walk the pairs in a fixed order so failure output is deterministic
	// (map iteration order is randomized).
	generatedNames := make([]string, 0, len(funcsToTrack))
	customNames := make(map[string]struct{}, len(funcsToTrack))
	for generated, custom := range funcsToTrack {
		generatedNames = append(generatedNames, generated)
		customNames[custom] = struct{}{}
	}
	slices.Sort(generatedNames)

	customFns := parseFile(t, "../modelqueries.go", func(name string) bool {
		_, ok := customNames[name]
		return ok
	})
	generatedFns := parseFile(t, "../queries.sql.go", func(name string) bool {
		_, ok := funcsToTrack[name]
		return ok
	})

	// Build a fresh map instead of writing generated entries into customFns.
	merged := make(map[string]*parsedFunc, len(customFns)+len(generatedFns))
	for name, fn := range customFns {
		merged[name] = fn
	}
	for name, fn := range generatedFns {
		merged[name] = fn
	}

	var mismatched [][2]string
	for _, generated := range generatedNames {
		custom := funcsToTrack[generated]
		if !compareFns(t, generated, custom, merged[generated], merged[custom]) {
			mismatched = append(mismatched, [2]string{generated, custom})
		}
	}

	// Report at the end so the suggested fix is the last thing printed.
	for _, pair := range mismatched {
		a, b := pair[0], pair[1]
		t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' "+
			"and 'db.QueryContext()' arguments in their function bodies. "+
			"Make sure to copy the function body from the autogenerated %q body. "+
			"Specifically the parameters for 'rows.Scan()' and 'db.QueryContext()'.", a, b, a)
	}
}
// parsedFunc holds the call arguments extracted from a single function
// declaration, used to compare a generated query against its custom variant.
type parsedFunc struct {
	// RowScanArgs are the arguments passed to rows.Scan() inside the row loop.
	RowScanArgs []ast.Expr
	// QueryArgs are the arguments passed to QueryContext().
	QueryArgs []ast.Expr
}
// parseFile parses the Go source file at filename and, for every top-level
// function or method whose name satisfies trackFunc, extracts the arguments
// passed to its 'rows.Scan()' and 'QueryContext()' calls.
//
// The result is keyed by function name; if several declarations share a name,
// the last one wins. Declarations without a body are skipped, so they show up
// as "missing" to callers instead of crashing the extraction helpers.
func parseFile(t *testing.T, filename string, trackFunc func(name string) bool) map[string]*parsedFunc {
	t.Helper()

	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, filename, nil, parser.SkipObjectResolution)
	require.NoErrorf(t, err, "failed to parse file %q", filename)

	parsed := make(map[string]*parsedFunc)
	for _, decl := range f.Decls {
		fn, ok := decl.(*ast.FuncDecl)
		if !ok || fn.Body == nil || !trackFunc(fn.Name.Name) {
			continue
		}
		parsed[fn.Name.Name] = &parsedFunc{
			RowScanArgs: pullRowScanArgs(fn),
			QueryArgs:   pullQueryArgs(fn),
		}
	}
	return parsed
}
// compareFns reports whether a and b pass identical arguments to
// 'rows.Scan()' and 'db.QueryContext()'. Missing functions and any mismatches
// are reported through t. Neither a nor b (nor the AST they reference) is
// modified.
func compareFns(t *testing.T, aName, bName string, a, b *parsedFunc) bool {
	t.Helper()

	if a == nil {
		t.Errorf("The function %q is missing", aName)
		return false
	}
	if b == nil {
		t.Errorf("The function %q is missing", bName)
		return false
	}

	r := compareArgs(t, "rows.Scan() arguments", aName, bName, a.RowScanArgs, b.RowScanArgs)

	// The query argument (index 1, right after ctx) legitimately differs: one
	// side passes the SQL const, the other a variable holding a mutation of the
	// original query. Substitute b's query argument into a *copy* of a's args
	// so that position always matches. Index 1 exists whenever len >= 2, which
	// also covers queries that take no parameters.
	aQueryArgs := a.QueryArgs
	if len(a.QueryArgs) > 1 && len(b.QueryArgs) > 1 {
		aQueryArgs = slices.Clone(a.QueryArgs)
		aQueryArgs[1] = b.QueryArgs[1]
	}
	q := compareArgs(t, "db.QueryContext() arguments", aName, bName, aQueryArgs, b.QueryArgs)
	return r && q
}
// compareArgs asserts that the argument lists a and b render to the same names
// (see argList). argType, aName and bName are used only in the failure
// message. It returns true when the lists match.
func compareArgs(t *testing.T, argType string, aName, bName string, a, b []ast.Expr) bool {
	t.Helper()
	return assert.Equal(t, argList(t, a), argList(t, b), "mismatched %s for %s and %s", argType, aName, bName)
}
// argList renders each argument expression as a short name so that two
// argument lists can be compared:
//
//   - &i.Field    -> "Field"
//   - &x          -> "x"
//   - x           -> "x"
//   - x.Field     -> "Field"
//   - pkg.Fn(...) -> "Fn(<number of args>)"
//   - fn(...)     -> "call(<number of args>)"
//
// Any other expression is rendered as "unknown" and reported via t.Errorf.
// Every input element produces exactly one output element.
func argList(t *testing.T, args []ast.Expr) []string {
	t.Helper()

	var argNames []string
	for _, arg := range args {
		argname := "unknown"
		switch expr := arg.(type) {
		case *ast.UnaryExpr:
			// This is "&i.Arg" (or "&x") style stuff. Checked assertions avoid
			// a panic on unexpected operands.
			switch x := expr.X.(type) {
			case *ast.SelectorExpr:
				argname = x.Sel.Name
			case *ast.Ident:
				argname = x.Name
			}
		case *ast.Ident:
			argname = expr.Name
		case *ast.SelectorExpr:
			argname = expr.Sel.Name
		case *ast.CallExpr:
			// Eh, this is pq.Array style stuff. Do a best effort.
			argname = fmt.Sprintf("call(%d)", len(expr.Args))
			if fnSel, ok := expr.Fun.(*ast.SelectorExpr); ok {
				argname = fmt.Sprintf("%s(%d)", fnSel.Sel.Name, len(expr.Args))
			}
		}
		if argname == "unknown" {
			t.Errorf("Unknown arg, cannot parse: %T", arg)
		}
		argNames = append(argNames, argname)
	}
	return argNames
}
// pullQueryArgs returns the arguments of the first top-level statement in fn
// of the form `rows, err := <x>.QueryContext(...)` (or `=`), or nil if there is
// none. Only the function body's top-level statements are inspected. All type
// assertions are checked, so unexpected shapes are skipped rather than
// panicking.
func pullQueryArgs(fn *ast.FuncDecl) []ast.Expr {
	if fn.Body == nil {
		return nil
	}
	for _, stmt := range fn.Body.List {
		assign, ok := stmt.(*ast.AssignStmt)
		if !ok || len(assign.Lhs) != 2 || len(assign.Rhs) == 0 {
			continue
		}
		// Looking for "rows, err :=".
		if id, ok := assign.Lhs[0].(*ast.Ident); !ok || id.Name != "rows" {
			continue
		}
		query, ok := assign.Rhs[0].(*ast.CallExpr)
		if !ok {
			continue
		}
		if sel, ok := query.Fun.(*ast.SelectorExpr); ok && sel.Sel.Name == "QueryContext" {
			return query.Args
		}
	}
	return nil
}
// pullRowScanArgs returns the arguments of the scan call in the first
// top-level for loop of fn that has the expected row-loop shape:
//
//	for rows.Next() {
//		var i T
//		if err := rows.Scan(&i.A, &i.B, ...); err != nil {
//			...
//		}
//		...
//	}
//
// Concretely: the loop body's second statement is an if statement whose init
// statement assigns the result of a call; that call's arguments are returned.
// For loops that do not match this shape are skipped (instead of panicking),
// and nil is returned if no loop matches.
func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr {
	if fn.Body == nil {
		return nil
	}
	for _, stmt := range fn.Body.List {
		forStmt, ok := stmt.(*ast.ForStmt)
		if !ok || forStmt.Body == nil || len(forStmt.Body.List) < 2 {
			continue
		}
		// Second statement in the for loop is the if statement with rows.Scan.
		ifStmt, ok := forStmt.Body.List[1].(*ast.IfStmt)
		if !ok {
			continue
		}
		// This is the err := rows.Scan().
		assign, ok := ifStmt.Init.(*ast.AssignStmt)
		if !ok || len(assign.Rhs) == 0 {
			continue
		}
		// Rhs is the rows.Scan part.
		rowScan, ok := assign.Rhs[0].(*ast.CallExpr)
		if !ok {
			continue
		}
		return rowScan.Args
	}
	return nil
}