package main import ( "bufio" "bytes" "fmt" "os" "path/filepath" "regexp" "strings" "golang.org/x/tools/imports" "golang.org/x/xerrors" ) type constraintType string const ( constraintTypeUnique constraintType = "unique" constraintTypeForeignKey constraintType = "foreign_key" constraintTypeCheck constraintType = "check" ) func (c constraintType) goType() string { switch c { case constraintTypeUnique: return "UniqueConstraint" case constraintTypeForeignKey: return "ForeignKeyConstraint" case constraintTypeCheck: return "CheckConstraint" default: panic(fmt.Sprintf("unknown constraint type: %s", c)) } } func (c constraintType) goTypeDescriptionPart() string { switch c { case constraintTypeUnique: return "unique" case constraintTypeForeignKey: return "foreign key" case constraintTypeCheck: return "check" default: panic(fmt.Sprintf("unknown constraint type: %s", c)) } } func (c constraintType) goEnumNamePrefix() string { switch c { case constraintTypeUnique: return "Unique" case constraintTypeForeignKey: return "ForeignKey" case constraintTypeCheck: return "Check" default: panic(fmt.Sprintf("unknown constraint type: %s", c)) } } type constraint struct { name string // comment is typically the full constraint, but for check constraints it's // instead the table name. comment string } // queryToConstraintsFn is a function that takes a query and returns zero or // more constraints if the query matches the wanted constraint type. If the // query does not match the wanted constraint type, the function should return // no constraints. type queryToConstraintsFn func(query string) ([]constraint, error) // generateConstraints does the following: // 1. Read the dump.sql file // 2. Parse the file into each query // 3. Pass each query to the constraintFn function // 4. Generate the enum from the returned constraints // 5. Write the generated code to the output path func generateConstraints(dumpPath, outputPath string, outputConstraintType constraintType, fn queryToConstraintsFn) error { dump, err := os.Open(dumpPath) if err != nil { return err } defer dump.Close() var allConstraints []constraint dumpScanner := bufio.NewScanner(dump) query := "" for dumpScanner.Scan() { line := strings.TrimSpace(dumpScanner.Text()) switch { case strings.HasPrefix(line, "--"): case line == "": case strings.HasSuffix(line, ";"): query += line newConstraints, err := fn(query) query = "" if err != nil { return xerrors.Errorf("process query %q: %w", query, err) } allConstraints = append(allConstraints, newConstraints...) default: query += line + " " } } if err = dumpScanner.Err(); err != nil { return err } s := &bytes.Buffer{} _, _ = fmt.Fprintf(s, `// Code generated by scripts/dbgen/main.go. DO NOT EDIT. package database // %[1]s represents a named %[2]s constraint on a table. type %[1]s string // %[1]s enums. const ( `, outputConstraintType.goType(), outputConstraintType.goTypeDescriptionPart()) for _, c := range allConstraints { constName := outputConstraintType.goEnumNamePrefix() + nameFromSnakeCase(c.name) _, _ = fmt.Fprintf(s, "\t%[1]s %[2]s = %[3]q // %[4]s\n", constName, outputConstraintType.goType(), c.name, c.comment) } _, _ = fmt.Fprint(s, ")\n") data, err := imports.Process(outputPath, s.Bytes(), &imports.Options{ Comments: true, }) if err != nil { return err } return os.WriteFile(outputPath, data, 0o600) } // generateUniqueConstraints generates the UniqueConstraint enum. func generateUniqueConstraints() error { localPath, err := localFilePath() if err != nil { return err } databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database") dumpPath := filepath.Join(databasePath, "dump.sql") outputPath := filepath.Join(databasePath, "unique_constraint.go") fn := func(query string) ([]constraint, error) { if strings.Contains(query, "UNIQUE") || strings.Contains(query, "PRIMARY KEY") { name := "" switch { case strings.Contains(query, "ALTER TABLE") && strings.Contains(query, "ADD CONSTRAINT"): name = strings.Split(query, " ")[6] case strings.Contains(query, "CREATE UNIQUE INDEX"): name = strings.Split(query, " ")[3] default: return nil, xerrors.Errorf("unknown unique constraint format: %s", query) } return []constraint{ { name: name, comment: query, }, }, nil } return nil, nil } return generateConstraints(dumpPath, outputPath, constraintTypeUnique, fn) } // generateForeignKeyConstraints generates the ForeignKeyConstraint enum. func generateForeignKeyConstraints() error { localPath, err := localFilePath() if err != nil { return err } databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database") dumpPath := filepath.Join(databasePath, "dump.sql") outputPath := filepath.Join(databasePath, "foreign_key_constraint.go") fn := func(query string) ([]constraint, error) { if strings.Contains(query, "FOREIGN KEY") { name := "" switch { case strings.Contains(query, "ALTER TABLE") && strings.Contains(query, "ADD CONSTRAINT"): name = strings.Split(query, " ")[6] default: return nil, xerrors.Errorf("unknown foreign key constraint format: %s", query) } return []constraint{ { name: name, comment: query, }, }, nil } return []constraint{}, nil } return generateConstraints(dumpPath, outputPath, constraintTypeForeignKey, fn) } // generateCheckConstraints generates the CheckConstraint enum. func generateCheckConstraints() error { localPath, err := localFilePath() if err != nil { return err } databasePath := filepath.Join(localPath, "..", "..", "..", "coderd", "database") dumpPath := filepath.Join(databasePath, "dump.sql") outputPath := filepath.Join(databasePath, "check_constraint.go") var ( tableRegex = regexp.MustCompile(`CREATE TABLE\s+([^\s]+)`) checkRegex = regexp.MustCompile(`CONSTRAINT\s+([^\s]+)\s+CHECK`) ) fn := func(query string) ([]constraint, error) { constraints := []constraint{} tableMatches := tableRegex.FindStringSubmatch(query) if len(tableMatches) > 0 { table := tableMatches[1] // Find every CONSTRAINT xxx CHECK occurrence. matches := checkRegex.FindAllStringSubmatch(query, -1) for _, match := range matches { constraints = append(constraints, constraint{ name: match[1], comment: table, }) } } return constraints, nil } return generateConstraints(dumpPath, outputPath, constraintTypeCheck, fn) }