mirror of
https://github.com/coder/coder.git
synced 2026-06-03 04:58:23 +00:00
b779c9ee33
## Problem

The chat listing endpoint (`GetChatsByOwnerID`) was using `fetchWithPostFilter`, which fetches N rows from the database and then filters them in Go memory using RBAC checks. This causes a pagination bug: if the user requests `limit=25` but some rows fail the auth check, fewer than 25 rows are returned even though more authorized rows exist in the database. The client may incorrectly assume it has reached the end of the list.

## Solution

Switch to the same pattern used by `GetWorkspaces`, `GetTemplates`, and `GetUsers`: `prepareSQLFilter` + `GetAuthorized*` variant. The RBAC filter is compiled to a SQL WHERE clause and injected into the query before `ORDER BY`/`LIMIT`, so the database returns exactly the requested number of authorized rows.

Additionally, `GetChatsByOwnerID` is renamed to `GetChats` with `OwnerID` as an optional (nullable) filter parameter, matching the `GetWorkspaces` naming convention.

## Changes

| File | Change |
|------|--------|
| `queries/chats.sql` | Renamed to `GetChats`, `owner_id` now optional via CASE/NULL, added `-- @authorize_filter` |
| `queries.sql.go` | Renamed constant, params struct (`GetChatsParams`), and method |
| `querier.go` | Interface method renamed |
| `modelqueries.go` | Added `chatQuerier` interface + `GetAuthorizedChats` impl |
| `dbauthz/dbauthz.go` | `GetChats` now uses `prepareSQLFilter` instead of `fetchWithPostFilter` |
| `dbauthz/dbauthz_test.go` | Updated tests for SQL filter pattern |
| `dbmock/dbmock.go` | Renamed + added mock for `GetAuthorizedChats` |
| `dbmetrics/querymetrics.go` | Renamed + added metrics wrapper |
| `rbac/regosql/configs.go` | Added `ChatConverter` (maps `org_owner` to empty string literal since `chats` has no `organization_id` column) |
| `rbac/authz.go` | Added `ConfigChats()` |
| `chats.go` | Handler uses renamed method with `uuid.NullUUID` |
| `searchquery/search.go` | Updated return type |
| `gitsync/worker.go` | Updated interface and call site |
| Various test files | Updated for renamed types |
184 lines
5.2 KiB
Go
184 lines
5.2 KiB
Go
package gentest_test
|
|
|
|
import (
|
|
"fmt"
|
|
"go/ast"
|
|
"go/parser"
|
|
"go/token"
|
|
"slices"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestCustomQueriesSynced makes sure the manual custom queries in modelqueries.go
|
|
// are synced with the autogenerated queries.sql.go. This should probably be
|
|
// autogenerated, but it's not atm and this is easy to throw in to elevate a better
|
|
// error message.
|
|
//
|
|
// If this breaks, and is hard to fix, you can t.Skip() it. It is not a critical
|
|
// test. Ping @Emyrk to fix it again.
|
|
func TestCustomQueriesSyncedRowScan(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
funcsToTrack := map[string]string{
|
|
"GetTemplatesWithFilter": "GetAuthorizedTemplates",
|
|
"GetWorkspaces": "GetAuthorizedWorkspaces",
|
|
"GetUsers": "GetAuthorizedUsers",
|
|
"GetChats": "GetAuthorizedChats",
|
|
}
|
|
|
|
// Scan custom
|
|
var custom []string
|
|
for _, fn := range funcsToTrack {
|
|
custom = append(custom, fn)
|
|
}
|
|
|
|
customFns := parseFile(t, "../modelqueries.go", func(name string) bool {
|
|
return slices.Contains(custom, name)
|
|
})
|
|
generatedFns := parseFile(t, "../queries.sql.go", func(name string) bool {
|
|
_, ok := funcsToTrack[name]
|
|
return ok
|
|
})
|
|
merged := customFns
|
|
for k, v := range generatedFns {
|
|
merged[k] = v
|
|
}
|
|
|
|
for a, b := range funcsToTrack {
|
|
a, b := a, b
|
|
if !compareFns(t, a, b, merged[a], merged[b]) {
|
|
//nolint:revive
|
|
defer func() {
|
|
// Run this at the end so the suggested fix is the last thing printed.
|
|
t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' "+
|
|
"and 'db.QueryContext()' arguments in their function bodies. "+
|
|
"Make sure to copy the function body from the autogenerated %q body. "+
|
|
"Specifically the parameters for 'rows.Scan()' and 'db.QueryContext()'.", a, b, a)
|
|
}()
|
|
}
|
|
}
|
|
}
|
|
|
|
// parsedFunc captures the call arguments of interest from one query function.
// RowScanArgs holds the arguments handed to rows.Scan() inside the row loop.
// QueryArgs holds the arguments handed to db.QueryContext().
type parsedFunc struct {
	RowScanArgs, QueryArgs []ast.Expr
}
|
|
|
|
func parseFile(t *testing.T, filename string, trackFunc func(name string) bool) map[string]*parsedFunc {
|
|
fset := token.NewFileSet()
|
|
f, err := parser.ParseFile(fset, filename, nil, parser.SkipObjectResolution)
|
|
require.NoErrorf(t, err, "failed to parse file %q", filename)
|
|
|
|
parsed := make(map[string]*parsedFunc)
|
|
for _, decl := range f.Decls {
|
|
if fn, ok := decl.(*ast.FuncDecl); ok {
|
|
if trackFunc(fn.Name.Name) {
|
|
parsed[fn.Name.String()] = &parsedFunc{
|
|
RowScanArgs: pullRowScanArgs(fn),
|
|
QueryArgs: pullQueryArgs(fn),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return parsed
|
|
}
|
|
|
|
func compareFns(t *testing.T, aName, bName string, a, b *parsedFunc) bool {
|
|
if a == nil {
|
|
t.Errorf("The function %q is missing", aName)
|
|
return false
|
|
}
|
|
if b == nil {
|
|
t.Errorf("The function %q is missing", bName)
|
|
return false
|
|
}
|
|
r := compareArgs(t, "rows.Scan() arguments", aName, bName, a.RowScanArgs, b.RowScanArgs)
|
|
if len(a.QueryArgs) > 2 && len(b.QueryArgs) > 2 {
|
|
// This is because the actual query param name is different. One uses the
|
|
// const, the other uses a variable that is a mutation of the original query.
|
|
a.QueryArgs[1] = b.QueryArgs[1]
|
|
}
|
|
q := compareArgs(t, "db.QueryContext() arguments", aName, bName, a.QueryArgs, b.QueryArgs)
|
|
return r && q
|
|
}
|
|
|
|
func compareArgs(t *testing.T, argType string, aName, bName string, a, b []ast.Expr) bool {
|
|
return assert.Equal(t, argList(t, a), argList(t, b), "mismatched %s for %s and %s", argType, aName, bName)
|
|
}
|
|
|
|
// argList renders each argument expression to a short name, so that two
// argument lists can be compared structurally:
//
//	&i.Field  -> "Field"   (address of a selector, e.g. scan targets)
//	&v        -> "v"       (address of an identifier)
//	v         -> "v"
//	x.Field   -> "Field"
//	pkg.Fn(a) -> "Fn(1)"   (selector calls: callee name + argument count)
//	fn(a, b)  -> "call(2)" (other calls: argument count only)
//
// Any other expression is rendered as "unknown" and reported as a test error.
// A nil or empty args yields a nil slice.
func argList(t *testing.T, args []ast.Expr) []string {
	t.Helper()

	var argNames []string
	for _, arg := range args {
		argname := exprName(arg)
		if argname == "unknown" {
			t.Errorf("Unknown arg, cannot parse: %T", arg)
		}
		argNames = append(argNames, argname)
	}
	return argNames
}

// exprName returns the comparable name for a single argument expression, or
// "unknown" if the expression shape is not recognized. It never panics.
func exprName(arg ast.Expr) string {
	switch e := arg.(type) {
	case *ast.UnaryExpr:
		// "&i.Arg" style scan targets, or a plain "&v".
		switch x := e.X.(type) {
		case *ast.SelectorExpr:
			return x.Sel.Name
		case *ast.Ident:
			return x.Name
		}
	case *ast.Ident:
		return e.Name
	case *ast.SelectorExpr:
		return e.Sel.Name
	case *ast.CallExpr:
		// Eh, this is pg.Array style stuff. Do a best effort.
		if fn, ok := e.Fun.(*ast.SelectorExpr); ok {
			return fmt.Sprintf("%s(%d)", fn.Sel.Name, len(e.Args))
		}
		return fmt.Sprintf("call(%d)", len(e.Args))
	}
	return "unknown"
}
|
|
|
|
// pullQueryArgs returns the arguments of the QueryContext() call in the first
// top-level statement of fn's body shaped like
//
//	rows, err := <x>.QueryContext(...)
//
// Plain "=" assignments are matched too. It returns nil if fn has no body or no
// such statement exists. Unexpected statement shapes are skipped, not panicked on.
func pullQueryArgs(fn *ast.FuncDecl) []ast.Expr {
	if fn.Body == nil {
		return nil
	}
	for _, stmt := range fn.Body.List {
		assign, ok := stmt.(*ast.AssignStmt)
		if !ok || len(assign.Lhs) != 2 || len(assign.Rhs) == 0 {
			continue
		}
		if id, ok := assign.Lhs[0].(*ast.Ident); !ok || id.Name != "rows" {
			continue
		}
		call, ok := assign.Rhs[0].(*ast.CallExpr)
		if !ok {
			continue
		}
		if sel, ok := call.Fun.(*ast.SelectorExpr); ok && sel.Sel.Name == "QueryContext" {
			return call.Args
		}
	}
	return nil
}
|
|
|
|
// pullRowScanArgs returns the arguments passed to Scan() in the row loop of
// fn's body. It walks each top-level for loop looking for a statement shaped like
//
//	if err := rows.Scan(...); err != nil { ... }
//
// It returns the Scan() arguments of the first match, or nil if fn has no body
// or no such statement exists. Unexpected shapes are skipped rather than
// causing a panic.
func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr {
	if fn.Body == nil {
		return nil
	}
	for _, stmt := range fn.Body.List {
		loop, ok := stmt.(*ast.ForStmt)
		if !ok || loop.Body == nil {
			continue
		}
		for _, inner := range loop.Body.List {
			if args, ok := ifScanArgs(inner); ok {
				return args
			}
		}
	}
	return nil
}

// ifScanArgs returns the arguments of a "<x>.Scan(...)" call used as the init
// statement of an if statement. It reports false for any other statement shape.
func ifScanArgs(stmt ast.Stmt) ([]ast.Expr, bool) {
	ifStmt, ok := stmt.(*ast.IfStmt)
	if !ok {
		return nil, false
	}
	assign, ok := ifStmt.Init.(*ast.AssignStmt)
	if !ok || len(assign.Rhs) == 0 {
		return nil, false
	}
	call, ok := assign.Rhs[0].(*ast.CallExpr)
	if !ok {
		return nil, false
	}
	sel, ok := call.Fun.(*ast.SelectorExpr)
	if !ok || sel.Sel.Name != "Scan" {
		return nil, false
	}
	return call.Args, true
}
|