From 55e525fc287780e90be3e4cdca573b0ffddda3bf Mon Sep 17 00:00:00 2001 From: Ethan <39577870+ethanndickson@users.noreply.github.com> Date: Fri, 17 Apr 2026 00:07:30 +1000 Subject: [PATCH] ci: add InTx linter replacing ruleguard rule (#24422) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the old `InTx` ruleguard rule in `scripts/rules.go` with a custom in-tree `go/analysis` analyzer under `scripts/intxcheck/`. The new analyzer catches the same direct and pass-through misuse classes as before, plus two new classes the pattern-matcher couldn't reach: - **Indirect same-package helper misuse** — flags `p.someHelper(ctx)` inside `InTx` when the helper body uses the outer store (the PR #24369 bug class). - **Nested dangerous closures** — descends into `go func() { ... }()`, `defer func() { ... }()`, and immediately-invoked function literals. The analyzer uses semantic `types.Object` identity instead of raw expression string comparison, which avoids false positives from closure-local shadowing and catches simple aliases like `outer := s.db` and `alias := s`. This PR also fixes three real outer-store-inside-transaction bugs the new analyzer surfaced: - `coderd/wsbuilder/wsbuilder.go`: `FindMatchingPresetID` and `getWorkspaceTask` now use the inner transaction store instead of `b.store`. - `enterprise/dbcrypt/dbcrypt.go`: `ensureEncrypted` now calls `s.InsertDBCryptKey` (the tx-wrapped store) instead of `db.InsertDBCryptKey`. The `dbCrypt.InTx` method wraps the raw tx in a new `*dbCrypt`, so `s.InsertDBCryptKey` still dispatches through the encryption layer. Two call sites need `// intxcheck:ignore` suppressions. Both are one-off patterns that only look like misuse because the analyzer doesn't track assignments — proving them safe would require full dataflow analysis, which is well beyond what a targeted lint like this should attempt: - `coderd/database/dbfake/dbfake.go` — `b.db` is reassigned to `tx` on the preceding line, so `b.doInTX()` actually uses the transaction. The analyzer sees the original `b.db` identity and flags it. - `coderd/database/db_test.go` — test intentionally passes the outer store to `require.Equal` to assert that nested `InTx` returns the same handle. Suppressions use `// intxcheck:ignore` instead of `//nolint:intxcheck` because `intxcheck` runs as a standalone `go/analysis` tool outside golangci-lint. golangci-lint's `nolintlint` checker flags `//nolint` directives for linters it doesn't control, so we use a custom comment prefix to avoid that conflict. --- Makefile | 1 + coderd/database/db_test.go | 2 +- coderd/database/dbfake/dbfake.go | 2 +- coderd/wsbuilder/wsbuilder.go | 10 +- enterprise/dbcrypt/dbcrypt.go | 2 +- scripts/intxcheck/analyzer.go | 601 ++++++++++++++++++ scripts/intxcheck/analyzer_test.go | 13 + scripts/intxcheck/main.go | 7 + .../intxcheck/testdata/src/example/example.go | 155 +++++ scripts/rules.go | 46 -- 10 files changed, 785 insertions(+), 54 deletions(-) create mode 100644 scripts/intxcheck/analyzer.go create mode 100644 scripts/intxcheck/analyzer_test.go create mode 100644 scripts/intxcheck/main.go create mode 100644 scripts/intxcheck/testdata/src/example/example.go diff --git a/Makefile b/Makefile index a1afce77eb..c711d2c62f 100644 --- a/Makefile +++ b/Makefile @@ -721,6 +721,7 @@ lint/go: linter_ver=$$(grep -oE 'GOLANGCI_LINT_VERSION=\S+' dogfood/coder/ubuntu-26.04/Dockerfile | cut -d '=' -f 2) go run github.com/golangci/golangci-lint/cmd/golangci-lint@v$$linter_ver run go tool github.com/coder/paralleltestctx/cmd/paralleltestctx -custom-funcs="testutil.Context" ./... + go run ./scripts/intxcheck ./... .PHONY: lint/go lint/examples: | _gen/bin/examplegen diff --git a/coderd/database/db_test.go b/coderd/database/db_test.go index 68b60a788f..9941ef5ba3 100644 --- a/coderd/database/db_test.go +++ b/coderd/database/db_test.go @@ -60,7 +60,7 @@ func TestNestedInTx(t *testing.T) { err = db.InTx(func(outer database.Store) error { return outer.InTx(func(inner database.Store) error { //nolint:gocritic - require.Equal(t, outer, inner, "should be same transaction") + require.Equal(t, outer, inner, "should be same transaction") // intxcheck:ignore // intentional: test asserts nested InTx returns same store _, err := inner.InsertUser(context.Background(), database.InsertUserParams{ ID: uid, diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index e784e3121b..0b859a4fb1 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -274,7 +274,7 @@ func (b WorkspaceBuildBuilder) Do() WorkspaceResponse { err := b.db.InTx(func(tx database.Store) error { //nolint:revive // calls do on modified struct b.db = tx - resp = b.doInTX() + resp = b.doInTX() // intxcheck:ignore // b.db is reassigned to tx on the line above return nil }, nil) require.NoError(b.t, err) diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go index b87806863d..b1b46d3492 100644 --- a/coderd/wsbuilder/wsbuilder.go +++ b/coderd/wsbuilder/wsbuilder.go @@ -490,7 +490,7 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object } if b.templateVersionPresetID == uuid.Nil { - presetID, err := prebuilds.FindMatchingPresetID(b.ctx, b.store, templateVersionID, names, values) + presetID, err := prebuilds.FindMatchingPresetID(b.ctx, store, templateVersionID, names, values) if err != nil { return BuildError{http.StatusInternalServerError, "find matching preset", err} } @@ -528,7 +528,7 @@ func (b *Builder) buildTx(authFunc func(action policy.Action, object rbac.Object return BuildError{code, "insert workspace build", err} } - task, err := b.getWorkspaceTask() + task, err := b.getWorkspaceTask(store) if err != nil { return BuildError{http.StatusInternalServerError, "get task by workspace id", err} } @@ -677,11 +677,11 @@ func (b *Builder) getTemplateVersionID() (uuid.UUID, error) { // getWorkspaceTask returns the task associated with the workspace, if any. // If no task exists, it returns (nil, nil). -func (b *Builder) getWorkspaceTask() (*database.Task, error) { +func (b *Builder) getWorkspaceTask(store database.Store) (*database.Task, error) { if b.hasTask != nil { return b.task, nil } - t, err := b.store.GetTaskByWorkspaceID(b.ctx, b.workspace.ID) + t, err := store.GetTaskByWorkspaceID(b.ctx, b.workspace.ID) if err != nil { if xerrors.Is(err, sql.ErrNoRows) { b.hasTask = ptr.Ref(false) @@ -1382,7 +1382,7 @@ func (b *Builder) checkUsage() error { return BuildError{http.StatusInternalServerError, "Failed to fetch template version", err} } - task, err := b.getWorkspaceTask() + task, err := b.getWorkspaceTask(b.store) if err != nil { return BuildError{http.StatusInternalServerError, "Failed to fetch workspace task", err} } diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index f2d2133d07..a222de1607 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -875,7 +875,7 @@ func (db *dbCrypt) ensureEncrypted(ctx context.Context) error { } // If we get here, then we have a new key that we need to insert. - return db.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{ + return s.InsertDBCryptKey(ctx, database.InsertDBCryptKeyParams{ Number: highestNumber + 1, ActiveKeyDigest: db.primaryCipherDigest, Test: testValue, diff --git a/scripts/intxcheck/analyzer.go b/scripts/intxcheck/analyzer.go new file mode 100644 index 0000000000..2b72f14571 --- /dev/null +++ b/scripts/intxcheck/analyzer.go @@ -0,0 +1,601 @@ +package main + +import ( + "fmt" + "go/ast" + "go/token" + "go/types" + "reflect" + "strings" + + "golang.org/x/tools/go/analysis" +) + +// Analyzer reports outer store usage inside database.Store.InTx closures. +var Analyzer = &analysis.Analyzer{ + Name: "intxcheck", + Doc: "report unsafe outer-store usage inside database.Store.InTx closures", + Run: run, + // ResultType must be set so run can return a typed nil instead + // of nil, nil — which the nilnil linter forbids. No downstream + // analyzer depends on this result. + ResultType: reflect.TypeOf((*struct{})(nil)), +} + +type txContext struct { + outerStore outerStoreMatcher + txName string +} + +type outerStoreMatcher struct { + display string + fieldSuffix string + ownerForms []exprForm + storeForms []exprForm +} + +type exprForm struct { + text string + root types.Object + suffix string +} + +func run(pass *analysis.Pass) (any, error) { + decls := make(map[types.Object]*ast.FuncDecl) + for _, file := range pass.Files { + for _, decl := range file.Decls { + funcDecl, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + obj := pass.TypesInfo.Defs[funcDecl.Name] + if obj == nil { + continue + } + decls[obj] = funcDecl + } + } + + for _, file := range pass.Files { + suppressed := suppressedLines(pass.Fset, file) + ast.Inspect(file, func(n ast.Node) bool { + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + inTxSelector, ok := unparen(call.Fun).(*ast.SelectorExpr) + if !ok || inTxSelector.Sel.Name != "InTx" { + return true + } + if len(call.Args) == 0 { + return true + } + + funcLit, ok := unparen(call.Args[0]).(*ast.FuncLit) + if !ok { + return true + } + + outerStore, ok := newOuterStoreMatcher(pass, inTxSelector.X) + if !ok { + return true + } + + ctx := txContext{ + outerStore: outerStore, + txName: firstParamName(funcLit.Type), + } + + inspectInTxBody(pass, funcLit.Body, ctx, decls, suppressed) + return true + }) + } + + return (*struct{})(nil), nil +} + +func inspectInTxBody(pass *analysis.Pass, body *ast.BlockStmt, ctx txContext, decls map[types.Object]*ast.FuncDecl, suppressed map[int]bool) { + ctx = ctx.withAliases(pass, body) + + ast.Inspect(body, func(n ast.Node) bool { + switch n := n.(type) { + case *ast.FuncLit: + return false + case *ast.GoStmt: + if funcLit, ok := funcLitCall(n.Call); ok { + reportCallMisuse(pass, n.Call, ctx, suppressed) + inspectInTxBody(pass, funcLit.Body, ctx, decls, suppressed) + return false + } + return true + case *ast.DeferStmt: + if funcLit, ok := funcLitCall(n.Call); ok { + reportCallMisuse(pass, n.Call, ctx, suppressed) + inspectInTxBody(pass, funcLit.Body, ctx, decls, suppressed) + return false + } + return true + } + + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + reported := reportCallMisuse(pass, call, ctx, suppressed) + if funcLit, ok := funcLitCall(call); ok { + inspectInTxBody(pass, funcLit.Body, ctx, decls, suppressed) + return true + } + if reported { + return true + } + + callee, calleeOuterStore, ok := resolveSamePackageCallee(pass, call, ctx, decls) + if !ok || callee == nil || callee.Body == nil { + return true + } + if !bodyUsesOuterStore(pass, callee.Body, calleeOuterStore) { + return true + } + + reportIfNotSuppressed(pass, suppressed, call.Pos(), fmt.Sprintf( + "call to '%s' inside InTx uses outer store '%s'; pass '%s' through the helper or hoist the call", + exprString(call.Fun), + ctx.outerStore.display, + ctx.txName, + )) + return true + }) +} + +func reportCallMisuse(pass *analysis.Pass, call *ast.CallExpr, ctx txContext, suppressed map[int]bool) bool { + kind, pos := classifyCall(pass, call, ctx.outerStore) + switch kind { + case misuseDirect: + reportIfNotSuppressed(pass, suppressed, pos, fmt.Sprintf( + "outer store '%s' used inside InTx; use transaction store '%s' instead", + ctx.outerStore.display, + ctx.txName, + )) + return true + case misusePassThrough: + reportIfNotSuppressed(pass, suppressed, pos, fmt.Sprintf( + "outer store '%s' passed as argument inside InTx; use transaction store '%s' instead", + ctx.outerStore.display, + ctx.txName, + )) + return true + default: + return false + } +} + +func funcLitCall(call *ast.CallExpr) (*ast.FuncLit, bool) { + funcLit, ok := unparen(call.Fun).(*ast.FuncLit) + if !ok { + return nil, false + } + return funcLit, true +} + +func reportIfNotSuppressed(pass *analysis.Pass, suppressed map[int]bool, pos token.Pos, message string) { + if suppressedLine(pass.Fset, suppressed, pos) { + return + } + + pass.Report(analysis.Diagnostic{ + Pos: pos, + Message: message, + }) +} + +type misuseKind int + +const ( + misuseNone misuseKind = iota + misuseDirect + misusePassThrough +) + +func classifyCall(pass *analysis.Pass, call *ast.CallExpr, outerStore outerStoreMatcher) (misuseKind, token.Pos) { + if receiver := callReceiver(call); receiver != nil && outerStore.matches(pass, receiver) { + return misuseDirect, receiver.Pos() + } + + for _, arg := range call.Args { + if outerStore.matches(pass, arg) { + return misusePassThrough, arg.Pos() + } + } + + return misuseNone, token.NoPos +} + +func bodyUsesOuterStore(pass *analysis.Pass, body *ast.BlockStmt, outerStore outerStoreMatcher) bool { + outerStore = outerStore.withAliases(pass, body) + + found := false + ast.Inspect(body, func(n ast.Node) bool { + if found { + return false + } + + switch n := n.(type) { + case *ast.FuncLit: + return false + case *ast.GoStmt: + if kind, _ := classifyCall(pass, n.Call, outerStore); kind != misuseNone { + found = true + return false + } + if funcLit, ok := funcLitCall(n.Call); ok { + found = bodyUsesOuterStore(pass, funcLit.Body, outerStore) + return false + } + return true + case *ast.DeferStmt: + if kind, _ := classifyCall(pass, n.Call, outerStore); kind != misuseNone { + found = true + return false + } + if funcLit, ok := funcLitCall(n.Call); ok { + found = bodyUsesOuterStore(pass, funcLit.Body, outerStore) + return false + } + return true + } + + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + kind, _ := classifyCall(pass, call, outerStore) + if kind != misuseNone { + found = true + return false + } + if funcLit, ok := funcLitCall(call); ok { + found = bodyUsesOuterStore(pass, funcLit.Body, outerStore) + if found { + return false + } + } + return true + }) + return found +} + +func resolveSamePackageCallee(pass *analysis.Pass, call *ast.CallExpr, ctx txContext, decls map[types.Object]*ast.FuncDecl) (*ast.FuncDecl, outerStoreMatcher, bool) { + switch fun := unparen(call.Fun).(type) { + case *ast.Ident: + // Package-level helpers have their own parameter scope. The + // pass-through check already catches explicit outer-store + // arguments, so skip indirect analysis here. + return nil, outerStoreMatcher{}, false + case *ast.SelectorExpr: + selection := pass.TypesInfo.Selections[fun] + if selection == nil { + return nil, outerStoreMatcher{}, false + } + decl, ok := decls[selection.Obj()] + if !ok || decl == nil || decl.Recv == nil { + return nil, outerStoreMatcher{}, false + } + if !ctx.outerStore.matchesOwner(pass, fun.X) { + return nil, outerStoreMatcher{}, false + } + calleeOuterStore, ok := ctx.outerStore.withReceiver(pass, decl) + if !ok { + return nil, outerStoreMatcher{}, false + } + return decl, calleeOuterStore, true + default: + return nil, outerStoreMatcher{}, false + } +} + +func (ctx txContext) withAliases(pass *analysis.Pass, body *ast.BlockStmt) txContext { + ctx.outerStore = ctx.outerStore.withAliases(pass, body) + return ctx +} + +func newOuterStoreMatcher(pass *analysis.Pass, expr ast.Expr) (outerStoreMatcher, bool) { + display := exprString(expr) + if display == "" { + return outerStoreMatcher{}, false + } + + matcher := outerStoreMatcher{display: display} + matcher.addStoreForm(exprFormFor(pass, expr)) + + selector, ok := unparen(expr).(*ast.SelectorExpr) + if !ok { + return matcher, true + } + + matcher.fieldSuffix = "." + selector.Sel.Name + matcher.addOwnerForm(exprFormFor(pass, selector.X)) + return matcher, true +} + +func (m outerStoreMatcher) withAliases(pass *analysis.Pass, body *ast.BlockStmt) outerStoreMatcher { + base := m + derived := m + + ast.Inspect(body, func(n ast.Node) bool { + switch n := n.(type) { + case *ast.FuncLit: + return false + case *ast.AssignStmt: + if n.Tok != token.DEFINE { + return true + } + for i, lhs := range n.Lhs { + if i >= len(n.Rhs) { + break + } + derived.collectAlias(pass, base, lhs, n.Rhs[i]) + } + case *ast.DeclStmt: + genDecl, ok := n.Decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.VAR { + return true + } + for _, spec := range genDecl.Specs { + valueSpec, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + for i, name := range valueSpec.Names { + if i >= len(valueSpec.Values) { + break + } + derived.collectAlias(pass, base, name, valueSpec.Values[i]) + } + } + } + return true + }) + + return derived +} + +func (m *outerStoreMatcher) collectAlias(pass *analysis.Pass, base outerStoreMatcher, lhs ast.Expr, rhs ast.Expr) { + lhsForm, ok := declaredIdentForm(pass, lhs) + if !ok { + return + } + + switch { + case base.matches(pass, rhs): + m.addStoreForm(lhsForm) + case base.matchesOwner(pass, rhs): + m.addOwnerForm(lhsForm) + } +} + +func (m outerStoreMatcher) withReceiver(pass *analysis.Pass, decl *ast.FuncDecl) (outerStoreMatcher, bool) { + recvForm, ok := receiverForm(pass, decl) + if !ok { + return outerStoreMatcher{}, false + } + + rebound := outerStoreMatcher{ + display: m.display, + fieldSuffix: m.fieldSuffix, + ownerForms: []exprForm{recvForm}, + } + if m.fieldSuffix == "" { + rebound.storeForms = []exprForm{recvForm} + } + return rebound, true +} + +func (m outerStoreMatcher) matches(pass *analysis.Pass, expr ast.Expr) bool { + form := exprFormFor(pass, expr) + if form.text == "" { + return false + } + + for _, storeForm := range m.storeForms { + if sameExprForm(form, storeForm) { + return true + } + } + + if m.fieldSuffix == "" { + return false + } + + for _, ownerForm := range m.ownerForms { + if sameExprFormWithSuffix(form, ownerForm, m.fieldSuffix) { + return true + } + } + + return false +} + +func (m outerStoreMatcher) matchesOwner(pass *analysis.Pass, expr ast.Expr) bool { + if len(m.ownerForms) == 0 { + return false + } + + form := exprFormFor(pass, expr) + if form.text == "" { + return false + } + + for _, ownerForm := range m.ownerForms { + if sameExprForm(form, ownerForm) { + return true + } + } + return false +} + +func (m *outerStoreMatcher) addOwnerForm(form exprForm) { + if form.text == "" || containsExprForm(m.ownerForms, form) { + return + } + m.ownerForms = append(m.ownerForms, form) +} + +func (m *outerStoreMatcher) addStoreForm(form exprForm) { + if form.text == "" || containsExprForm(m.storeForms, form) { + return + } + m.storeForms = append(m.storeForms, form) +} + +func containsExprForm(forms []exprForm, want exprForm) bool { + for _, form := range forms { + if sameExprForm(form, want) { + return true + } + } + return false +} + +func sameExprForm(got, want exprForm) bool { + if got.root != nil && want.root != nil { + return got.root == want.root && got.suffix == want.suffix + } + return got.text == want.text +} + +func sameExprFormWithSuffix(got, base exprForm, suffix string) bool { + if got.root != nil && base.root != nil { + return got.root == base.root && got.suffix == base.suffix+suffix + } + return got.text == base.text+suffix +} + +func exprFormFor(pass *analysis.Pass, expr ast.Expr) exprForm { + text := exprString(expr) + if text == "" { + return exprForm{} + } + + ident, suffix, ok := rootIdentAndSuffix(expr) + if !ok { + return exprForm{text: text} + } + + return exprForm{ + text: text, + root: identObject(pass, ident), + suffix: suffix, + } +} + +func receiverForm(pass *analysis.Pass, decl *ast.FuncDecl) (exprForm, bool) { + if decl.Recv == nil || len(decl.Recv.List) == 0 { + return exprForm{}, false + } + if len(decl.Recv.List[0].Names) == 0 { + return exprForm{}, false + } + + ident := decl.Recv.List[0].Names[0] + obj := pass.TypesInfo.Defs[ident] + if obj == nil { + return exprForm{}, false + } + + return exprForm{text: ident.Name, root: obj}, true +} + +func declaredIdentForm(pass *analysis.Pass, expr ast.Expr) (exprForm, bool) { + ident, ok := unparen(expr).(*ast.Ident) + if !ok || ident.Name == "_" { + return exprForm{}, false + } + + obj := pass.TypesInfo.Defs[ident] + if obj == nil { + return exprForm{}, false + } + + return exprForm{text: ident.Name, root: obj}, true +} + +func identObject(pass *analysis.Pass, ident *ast.Ident) types.Object { + if ident == nil { + return nil + } + if obj := pass.TypesInfo.Uses[ident]; obj != nil { + return obj + } + return pass.TypesInfo.Defs[ident] +} + +func rootIdentAndSuffix(expr ast.Expr) (*ast.Ident, string, bool) { + switch expr := unparen(expr).(type) { + case *ast.Ident: + return expr, "", true + case *ast.SelectorExpr: + ident, suffix, ok := rootIdentAndSuffix(expr.X) + if !ok { + return nil, "", false + } + return ident, suffix + "." + expr.Sel.Name, true + default: + return nil, "", false + } +} + +func callReceiver(call *ast.CallExpr) ast.Expr { + selector, ok := unparen(call.Fun).(*ast.SelectorExpr) + if !ok { + return nil + } + return selector.X +} + +func suppressedLines(fset *token.FileSet, file *ast.File) map[int]bool { + lines := make(map[int]bool) + for _, group := range file.Comments { + for _, comment := range group.List { + if strings.Contains(comment.Text, "intxcheck:ignore") { + lines[fset.Position(comment.Pos()).Line] = true + } + } + } + return lines +} + +func suppressedLine(fset *token.FileSet, suppressed map[int]bool, pos token.Pos) bool { + return suppressed[fset.Position(pos).Line] +} + +func firstParamName(funcType *ast.FuncType) string { + if funcType == nil || funcType.Params == nil || len(funcType.Params.List) == 0 { + return "tx" + } + first := funcType.Params.List[0] + if len(first.Names) == 0 { + return "tx" + } + return first.Names[0].Name +} + +func exprString(expr ast.Expr) string { + if expr == nil { + return "" + } + return types.ExprString(unparen(expr)) +} + +func unparen(expr ast.Expr) ast.Expr { + for { + paren, ok := expr.(*ast.ParenExpr) + if !ok { + return expr + } + expr = paren.X + } +} diff --git a/scripts/intxcheck/analyzer_test.go b/scripts/intxcheck/analyzer_test.go new file mode 100644 index 0000000000..8cfd7b50cf --- /dev/null +++ b/scripts/intxcheck/analyzer_test.go @@ -0,0 +1,13 @@ +package main + +import ( + "testing" + + "golang.org/x/tools/go/analysis/analysistest" +) + +func TestAnalyzer(t *testing.T) { + t.Parallel() + + analysistest.Run(t, analysistest.TestData(), Analyzer, "example") +} diff --git a/scripts/intxcheck/main.go b/scripts/intxcheck/main.go new file mode 100644 index 0000000000..57d4993dec --- /dev/null +++ b/scripts/intxcheck/main.go @@ -0,0 +1,7 @@ +package main + +import "golang.org/x/tools/go/analysis/singlechecker" + +func main() { + singlechecker.Main(Analyzer) +} diff --git a/scripts/intxcheck/testdata/src/example/example.go b/scripts/intxcheck/testdata/src/example/example.go new file mode 100644 index 0000000000..1f3b3b4c30 --- /dev/null +++ b/scripts/intxcheck/testdata/src/example/example.go @@ -0,0 +1,155 @@ +package example + +import "context" + +type TxOptions struct{} + +type Store interface { + InTx(func(Store) error, *TxOptions) error + GetUser(context.Context) (string, error) + GetConfig(context.Context) (string, error) +} + +type Server struct { + db Store +} + +type wrapper struct { + db Store +} + +func helper(context.Context, Store) {} + +func helperWithDB(ctx context.Context, db Store) { + _, _ = db.GetUser(ctx) +} + +func shadowingOK(ctx context.Context, db Store) error { + return db.InTx(func(db Store) error { + _, _ = db.GetUser(ctx) + return nil + }, nil) +} + +func pkgFuncOK(ctx context.Context, db Store) error { + return db.InTx(func(tx Store) error { + helperWithDB(ctx, tx) + return nil + }, nil) +} + +func (s *Server) directMisuse(ctx context.Context) error { + return s.db.InTx(func(tx Store) error { + _, _ = s.db.GetUser(ctx) // want "outer store 's[.]db' used inside InTx; use transaction store 'tx' instead" + return nil + }, nil) +} + +func (s *Server) passThroughMisuse(ctx context.Context) error { + return s.db.InTx(func(tx Store) error { + helper(ctx, s.db) // want "outer store 's[.]db' passed as argument inside InTx; use transaction store 'tx' instead" + return nil + }, nil) +} + +func (s *Server) indirectMisuse(ctx context.Context) error { + return s.db.InTx(func(tx Store) error { + s.getConfig(ctx) // want "call to 's[.]getConfig' inside InTx uses outer store 's[.]db'; pass 'tx' through the helper or hoist the call" + return nil + }, nil) +} + +func (s *Server) shadowedLocalOK(ctx context.Context) error { + return s.db.InTx(func(tx Store) error { + s := wrapper{db: tx} + _, _ = s.db.GetUser(ctx) + return nil + }, nil) +} + +func (s *Server) aliasedStoreMisuse(ctx context.Context) error { + return s.db.InTx(func(tx Store) error { + outer := s.db + _, _ = outer.GetUser(ctx) // want "outer store 's[.]db' used inside InTx; use transaction store 'tx' instead" + return nil + }, nil) +} + +func (s *Server) aliasedHelperMisuse(ctx context.Context) error { + return s.db.InTx(func(tx Store) error { + alias := s + alias.getConfig(ctx) // want "call to 'alias[.]getConfig' inside InTx uses outer store 's[.]db'; pass 'tx' through the helper or hoist the call" + return nil + }, nil) +} + +func (s *Server) goFuncLiteralMisuse(ctx context.Context) error { + return s.db.InTx(func(tx Store) error { + go func() { + _, _ = s.db.GetUser(ctx) // want "outer store 's[.]db' used inside InTx; use transaction store 'tx' instead" + }() + return nil + }, nil) +} + +func (s *Server) goFuncLiteralArgMisuse(ctx context.Context) error { + return s.db.InTx(func(tx Store) error { + go func(db Store) { + _, _ = db.GetUser(ctx) + }(s.db) // want "outer store 's[.]db' passed as argument inside InTx; use transaction store 'tx' instead" + return nil + }, nil) +} + +func (s *Server) deferFuncLiteralMisuse(ctx context.Context) error { + return s.db.InTx(func(tx Store) error { + defer func() { + _, _ = s.db.GetUser(ctx) // want "outer store 's[.]db' used inside InTx; use transaction store 'tx' instead" + }() + return nil + }, nil) +} + +func (s *Server) immediateFuncLiteralMisuse(ctx context.Context) error { + return s.db.InTx(func(tx Store) error { + func() { + _, _ = s.db.GetUser(ctx) // want "outer store 's[.]db' used inside InTx; use transaction store 'tx' instead" + }() + return nil + }, nil) +} + +func (s *Server) suppressedCase(ctx context.Context) error { + return s.db.InTx(func(tx Store) error { + _, _ = s.db.GetUser(ctx) // intxcheck:ignore + return nil + }, nil) +} + +func (srv *Server) getConfig(ctx context.Context) string { + value, _ := srv.db.GetConfig(ctx) + return value +} + +func (s *Server) correctUsage(ctx context.Context) error { + return s.db.InTx(func(tx Store) error { + _, _ = tx.GetUser(ctx) + return nil + }, nil) +} + +func (s *Server) safeHelper(ctx context.Context) error { + return s.db.InTx(func(tx Store) error { + s.formatName("test") + return nil + }, nil) +} + +func (s *Server) formatName(name string) string { + return name +} + +func (s *Server) outsideInTx(ctx context.Context) error { + _, _ = s.db.GetUser(ctx) + return nil +} diff --git a/scripts/rules.go b/scripts/rules.go index ce9b473985..327c21dcd7 100644 --- a/scripts/rules.go +++ b/scripts/rules.go @@ -248,52 +248,6 @@ func useStandardTimeoutsAndDelaysInTests(m dsl.Matcher) { Report("Do not use magic numbers in test timeouts and delays. Use the standard testutil.Wait* or testutil.Interval* constants instead.") } -// InTx checks to ensure the database used inside the transaction closure is the transaction -// database, and not the original database that creates the tx. -func InTx(m dsl.Matcher) { - // ':=' and '=' are 2 different matches :( - m.Match(` - $x.InTx(func($y) error { - $*_ - $*_ = $x.$f($*_) - $*_ - }) - `, ` - $x.InTx(func($y) error { - $*_ - $*_ := $x.$f($*_) - $*_ - }) - `).Where(m["x"].Text != m["y"].Text). - At(m["f"]). - Report("Do not use the database directly within the InTx closure. Use '$y' instead of '$x'.") - - // When using a tx closure, ensure that if you pass the db to another - // function inside the closure, it is the tx. - // This will miss more complex cases such as passing the db as apart - // of another struct. - m.Match(` - $x.InTx(func($y database.Store) error { - $*_ - $*_ = $f($*_, $x, $*_) - $*_ - }) - `, ` - $x.InTx(func($y database.Store) error { - $*_ - $*_ := $f($*_, $x, $*_) - $*_ - }) - `, ` - $x.InTx(func($y database.Store) error { - $*_ - $f($*_, $x, $*_) - $*_ - }) - `).Where(m["x"].Text != m["y"].Text). - At(m["f"]).Report("Pass the tx database into the '$f' function inside the closure. Use '$y' over $x'") -} - // HttpAPIErrorMessage intends to enforce constructing proper sentences as // error messages for the api. A proper sentence includes proper capitalization // and ends with punctuation.