Mercurial > hg > CbC > CbC_llvm
diff mlir/tools/mlir-tblgen/PassGen.cpp @ 207:2e18cbf3894f
LLVM12
author | Shinji KONO <kono@ie.u-ryukyu.ac.jp> |
---|---|
date | Tue, 08 Jun 2021 06:07:14 +0900 |
parents | 0572611fdcc8 |
children | 5f17cb93ff66 |
line wrap: on
line diff
--- a/mlir/tools/mlir-tblgen/PassGen.cpp Mon May 25 11:55:54 2020 +0900 +++ b/mlir/tools/mlir-tblgen/PassGen.cpp Tue Jun 08 06:07:14 2021 +0900 @@ -14,6 +14,7 @@ #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Pass.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -21,6 +22,11 @@ using namespace mlir; using namespace mlir::tblgen; +static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls"); +static llvm::cl::opt<std::string> + groupName("name", llvm::cl::desc("The name of this group of passes"), + llvm::cl::cat(passGenCat)); + //===----------------------------------------------------------------------===// // GEN: Pass base class generation //===----------------------------------------------------------------------===// @@ -30,6 +36,7 @@ /// {0}: The def name of the pass record. /// {1}: The base class for the pass. /// {2): The command line argument for the pass. +/// {3}: The dependent dialects registration. const char *const passDeclBegin = R"( //===----------------------------------------------------------------------===// // {0} @@ -38,14 +45,22 @@ template <typename DerivedT> class {0}Base : public {1} { public: + using Base = {0}Base; + {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{} {0}Base(const {0}Base &) : {1}(::mlir::TypeID::get<DerivedT>()) {{} /// Returns the command-line argument attached to this pass. - llvm::StringRef getArgument() const override { return "{2}"; } + static constexpr ::llvm::StringLiteral getArgumentName() { + return ::llvm::StringLiteral("{2}"); + } + ::llvm::StringRef getArgument() const override { return "{2}"; } /// Returns the derived pass name. - llvm::StringRef getName() const override { return "{0}"; } + static constexpr ::llvm::StringLiteral getPassName() { + return ::llvm::StringLiteral("{0}"); + } + ::llvm::StringRef getName() const override { return "{0}"; } /// Support isa/dyn_cast functionality for the derived pass class. static bool classof(const ::mlir::Pass *pass) {{ @@ -53,23 +68,35 @@ } /// A clone method to create a copy of this pass. - std::unique_ptr<Pass> clonePass() const override {{ + std::unique_ptr<::mlir::Pass> clonePass() const override {{ return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); } + /// Return the dialect that must be loaded in the context before this pass. + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + {3} + } + protected: )"; +/// Registration for a single dependent dialect, to be inserted for each +/// dependent dialect in the `getDependentDialects` above. +const char *const dialectRegistrationTemplate = R"( + registry.insert<{0}>(); +)"; + /// Emit the declarations for each of the pass options. static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) { for (const PassOption &opt : pass.getOptions()) { - os.indent(2) << "Pass::" << (opt.isListOption() ? "ListOption" : "Option"); + os.indent(2) << "::mlir::Pass::" + << (opt.isListOption() ? "ListOption" : "Option"); - os << llvm::formatv("<{0}> {1}{{*this, \"{2}\", llvm::cl::desc(\"{3}\")", + os << llvm::formatv("<{0}> {1}{{*this, \"{2}\", ::llvm::cl::desc(\"{3}\")", opt.getType(), opt.getCppVariableName(), opt.getArgument(), opt.getDescription()); if (Optional<StringRef> defaultVal = opt.getDefaultValue()) - os << ", llvm::cl::init(" << defaultVal << ")"; + os << ", ::llvm::cl::init(" << defaultVal << ")"; if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags()) os << ", " << *additionalFlags; os << "};\n"; @@ -79,16 +106,23 @@ /// Emit the declarations for each of the pass statistics. static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) { for (const PassStatistic &stat : pass.getStatistics()) { - os << llvm::formatv(" Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n", - stat.getCppVariableName(), stat.getName(), - stat.getDescription()); + os << llvm::formatv( + " ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n", + stat.getCppVariableName(), stat.getName(), stat.getDescription()); } } static void emitPassDecl(const Pass &pass, raw_ostream &os) { StringRef defName = pass.getDef()->getName(); + std::string dependentDialectRegistrations; + { + llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); + for (StringRef dependentDialect : pass.getDependentDialects()) + dialectsOs << llvm::formatv(dialectRegistrationTemplate, + dependentDialect); + } os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(), - pass.getArgument()); + pass.getArgument(), dependentDialectRegistrations); emitPassOptionDecls(pass, os); emitPassStatisticDecls(pass, os); os << "};\n"; @@ -108,36 +142,49 @@ // GEN: Pass registration generation //===----------------------------------------------------------------------===// +/// The code snippet used to generate the start of a pass base class. +/// +/// {0}: The def name of the pass record. +/// {1}: The argument of the pass. +/// {2): The summary of the pass. +/// {3}: The code for constructing the pass. +const char *const passRegistrationCode = R"( +//===----------------------------------------------------------------------===// +// {0} Registration +//===----------------------------------------------------------------------===// + +inline void register{0}Pass() {{ + ::mlir::registerPass("{1}", "{2}", []() -> std::unique_ptr<::mlir::Pass> {{ + return {3}; + }); +} +)"; + +/// {0}: The name of the pass group. +const char *const passGroupRegistrationCode = R"( +//===----------------------------------------------------------------------===// +// {0} Registration +//===----------------------------------------------------------------------===// + +inline void register{0}Passes() {{ +)"; + /// Emit the code for registering each of the given passes with the global /// PassRegistry. static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) { os << "#ifdef GEN_PASS_REGISTRATION\n"; for (const Pass &pass : passes) { - os << llvm::formatv("#define GEN_PASS_REGISTRATION_{0}\n", - pass.getDef()->getName()); - } - os << "#endif // GEN_PASS_REGISTRATION\n"; - - for (const Pass &pass : passes) { - os << llvm::formatv("#ifdef GEN_PASS_REGISTRATION_{0}\n", - pass.getDef()->getName()); - os << llvm::formatv("::mlir::registerPass(\"{0}\", \"{1}\", []() -> " - "std::unique_ptr<Pass> {{ return {2}; });\n", + os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(), pass.getArgument(), pass.getSummary(), pass.getConstructor()); - os << llvm::formatv("#endif // GEN_PASS_REGISTRATION_{0}\n", - pass.getDef()->getName()); - os << llvm::formatv("#undef GEN_PASS_REGISTRATION_{0}\n", - pass.getDef()->getName()); } - os << "#ifdef GEN_PASS_REGISTRATION\n"; - for (const Pass &pass : passes) { - os << llvm::formatv("#undef GEN_PASS_REGISTRATION_{0}\n", - pass.getDef()->getName()); - } + os << llvm::formatv(passGroupRegistrationCode, groupName); + for (const Pass &pass : passes) + os << " register" << pass.getDef()->getName() << "Pass();\n"; + os << "}\n"; + os << "#undef GEN_PASS_REGISTRATION\n"; os << "#endif // GEN_PASS_REGISTRATION\n"; - os << "#undef GEN_PASS_REGISTRATION\n"; } //===----------------------------------------------------------------------===//