comparison mlir/tools/mlir-tblgen/PassGen.cpp @ 207:2e18cbf3894f

LLVM12
author Shinji KONO <kono@ie.u-ryukyu.ac.jp>
date Tue, 08 Jun 2021 06:07:14 +0900
parents 0572611fdcc8
children 5f17cb93ff66
comparison
equal deleted inserted replaced
173:0572611fdcc8 207:2e18cbf3894f
12 //===----------------------------------------------------------------------===// 12 //===----------------------------------------------------------------------===//
13 13
14 #include "mlir/TableGen/GenInfo.h" 14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Pass.h" 15 #include "mlir/TableGen/Pass.h"
16 #include "llvm/ADT/StringExtras.h" 16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/Support/CommandLine.h"
17 #include "llvm/Support/FormatVariadic.h" 18 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/TableGen/Error.h" 19 #include "llvm/TableGen/Error.h"
19 #include "llvm/TableGen/Record.h" 20 #include "llvm/TableGen/Record.h"
20 21
21 using namespace mlir; 22 using namespace mlir;
22 using namespace mlir::tblgen; 23 using namespace mlir::tblgen;
24
25 static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
26 static llvm::cl::opt<std::string>
27 groupName("name", llvm::cl::desc("The name of this group of passes"),
28 llvm::cl::cat(passGenCat));
23 29
24 //===----------------------------------------------------------------------===// 30 //===----------------------------------------------------------------------===//
25 // GEN: Pass base class generation 31 // GEN: Pass base class generation
26 //===----------------------------------------------------------------------===// 32 //===----------------------------------------------------------------------===//
27 33
28 /// The code snippet used to generate the start of a pass base class. 34 /// The code snippet used to generate the start of a pass base class.
29 /// 35 ///
30 /// {0}: The def name of the pass record. 36 /// {0}: The def name of the pass record.
31 /// {1}: The base class for the pass. 37 /// {1}: The base class for the pass.
32 /// {2): The command line argument for the pass. 38 /// {2): The command line argument for the pass.
39 /// {3}: The dependent dialects registration.
33 const char *const passDeclBegin = R"( 40 const char *const passDeclBegin = R"(
34 //===----------------------------------------------------------------------===// 41 //===----------------------------------------------------------------------===//
35 // {0} 42 // {0}
36 //===----------------------------------------------------------------------===// 43 //===----------------------------------------------------------------------===//
37 44
38 template <typename DerivedT> 45 template <typename DerivedT>
39 class {0}Base : public {1} { 46 class {0}Base : public {1} {
40 public: 47 public:
48 using Base = {0}Base;
49
41 {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{} 50 {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
42 {0}Base(const {0}Base &) : {1}(::mlir::TypeID::get<DerivedT>()) {{} 51 {0}Base(const {0}Base &) : {1}(::mlir::TypeID::get<DerivedT>()) {{}
43 52
44 /// Returns the command-line argument attached to this pass. 53 /// Returns the command-line argument attached to this pass.
45 llvm::StringRef getArgument() const override { return "{2}"; } 54 static constexpr ::llvm::StringLiteral getArgumentName() {
55 return ::llvm::StringLiteral("{2}");
56 }
57 ::llvm::StringRef getArgument() const override { return "{2}"; }
46 58
47 /// Returns the derived pass name. 59 /// Returns the derived pass name.
48 llvm::StringRef getName() const override { return "{0}"; } 60 static constexpr ::llvm::StringLiteral getPassName() {
61 return ::llvm::StringLiteral("{0}");
62 }
63 ::llvm::StringRef getName() const override { return "{0}"; }
49 64
50 /// Support isa/dyn_cast functionality for the derived pass class. 65 /// Support isa/dyn_cast functionality for the derived pass class.
51 static bool classof(const ::mlir::Pass *pass) {{ 66 static bool classof(const ::mlir::Pass *pass) {{
52 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); 67 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
53 } 68 }
54 69
55 /// A clone method to create a copy of this pass. 70 /// A clone method to create a copy of this pass.
56 std::unique_ptr<Pass> clonePass() const override {{ 71 std::unique_ptr<::mlir::Pass> clonePass() const override {{
57 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); 72 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
58 } 73 }
59 74
75 /// Return the dialect that must be loaded in the context before this pass.
76 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
77 {3}
78 }
79
60 protected: 80 protected:
81 )";
82
83 /// Registration for a single dependent dialect, to be inserted for each
84 /// dependent dialect in the `getDependentDialects` above.
85 const char *const dialectRegistrationTemplate = R"(
86 registry.insert<{0}>();
61 )"; 87 )";
62 88
63 /// Emit the declarations for each of the pass options. 89 /// Emit the declarations for each of the pass options.
64 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) { 90 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
65 for (const PassOption &opt : pass.getOptions()) { 91 for (const PassOption &opt : pass.getOptions()) {
66 os.indent(2) << "Pass::" << (opt.isListOption() ? "ListOption" : "Option"); 92 os.indent(2) << "::mlir::Pass::"
67 93 << (opt.isListOption() ? "ListOption" : "Option");
68 os << llvm::formatv("<{0}> {1}{{*this, \"{2}\", llvm::cl::desc(\"{3}\")", 94
95 os << llvm::formatv("<{0}> {1}{{*this, \"{2}\", ::llvm::cl::desc(\"{3}\")",
69 opt.getType(), opt.getCppVariableName(), 96 opt.getType(), opt.getCppVariableName(),
70 opt.getArgument(), opt.getDescription()); 97 opt.getArgument(), opt.getDescription());
71 if (Optional<StringRef> defaultVal = opt.getDefaultValue()) 98 if (Optional<StringRef> defaultVal = opt.getDefaultValue())
72 os << ", llvm::cl::init(" << defaultVal << ")"; 99 os << ", ::llvm::cl::init(" << defaultVal << ")";
73 if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags()) 100 if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags())
74 os << ", " << *additionalFlags; 101 os << ", " << *additionalFlags;
75 os << "};\n"; 102 os << "};\n";
76 } 103 }
77 } 104 }
78 105
79 /// Emit the declarations for each of the pass statistics. 106 /// Emit the declarations for each of the pass statistics.
80 static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) { 107 static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
81 for (const PassStatistic &stat : pass.getStatistics()) { 108 for (const PassStatistic &stat : pass.getStatistics()) {
82 os << llvm::formatv(" Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n", 109 os << llvm::formatv(
83 stat.getCppVariableName(), stat.getName(), 110 " ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
84 stat.getDescription()); 111 stat.getCppVariableName(), stat.getName(), stat.getDescription());
85 } 112 }
86 } 113 }
87 114
88 static void emitPassDecl(const Pass &pass, raw_ostream &os) { 115 static void emitPassDecl(const Pass &pass, raw_ostream &os) {
89 StringRef defName = pass.getDef()->getName(); 116 StringRef defName = pass.getDef()->getName();
117 std::string dependentDialectRegistrations;
118 {
119 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
120 for (StringRef dependentDialect : pass.getDependentDialects())
121 dialectsOs << llvm::formatv(dialectRegistrationTemplate,
122 dependentDialect);
123 }
90 os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(), 124 os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
91 pass.getArgument()); 125 pass.getArgument(), dependentDialectRegistrations);
92 emitPassOptionDecls(pass, os); 126 emitPassOptionDecls(pass, os);
93 emitPassStatisticDecls(pass, os); 127 emitPassStatisticDecls(pass, os);
94 os << "};\n"; 128 os << "};\n";
95 } 129 }
96 130
106 140
107 //===----------------------------------------------------------------------===// 141 //===----------------------------------------------------------------------===//
108 // GEN: Pass registration generation 142 // GEN: Pass registration generation
109 //===----------------------------------------------------------------------===// 143 //===----------------------------------------------------------------------===//
110 144
145 /// The code snippet used to generate the start of a pass base class.
146 ///
147 /// {0}: The def name of the pass record.
148 /// {1}: The argument of the pass.
149 /// {2): The summary of the pass.
150 /// {3}: The code for constructing the pass.
151 const char *const passRegistrationCode = R"(
152 //===----------------------------------------------------------------------===//
153 // {0} Registration
154 //===----------------------------------------------------------------------===//
155
156 inline void register{0}Pass() {{
157 ::mlir::registerPass("{1}", "{2}", []() -> std::unique_ptr<::mlir::Pass> {{
158 return {3};
159 });
160 }
161 )";
162
163 /// {0}: The name of the pass group.
164 const char *const passGroupRegistrationCode = R"(
165 //===----------------------------------------------------------------------===//
166 // {0} Registration
167 //===----------------------------------------------------------------------===//
168
169 inline void register{0}Passes() {{
170 )";
171
111 /// Emit the code for registering each of the given passes with the global 172 /// Emit the code for registering each of the given passes with the global
112 /// PassRegistry. 173 /// PassRegistry.
113 static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) { 174 static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
114 os << "#ifdef GEN_PASS_REGISTRATION\n"; 175 os << "#ifdef GEN_PASS_REGISTRATION\n";
115 for (const Pass &pass : passes) { 176 for (const Pass &pass : passes) {
116 os << llvm::formatv("#define GEN_PASS_REGISTRATION_{0}\n", 177 os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
117 pass.getDef()->getName());
118 }
119 os << "#endif // GEN_PASS_REGISTRATION\n";
120
121 for (const Pass &pass : passes) {
122 os << llvm::formatv("#ifdef GEN_PASS_REGISTRATION_{0}\n",
123 pass.getDef()->getName());
124 os << llvm::formatv("::mlir::registerPass(\"{0}\", \"{1}\", []() -> "
125 "std::unique_ptr<Pass> {{ return {2}; });\n",
126 pass.getArgument(), pass.getSummary(), 178 pass.getArgument(), pass.getSummary(),
127 pass.getConstructor()); 179 pass.getConstructor());
128 os << llvm::formatv("#endif // GEN_PASS_REGISTRATION_{0}\n", 180 }
129 pass.getDef()->getName()); 181
130 os << llvm::formatv("#undef GEN_PASS_REGISTRATION_{0}\n", 182 os << llvm::formatv(passGroupRegistrationCode, groupName);
131 pass.getDef()->getName()); 183 for (const Pass &pass : passes)
132 } 184 os << " register" << pass.getDef()->getName() << "Pass();\n";
133 185 os << "}\n";
134 os << "#ifdef GEN_PASS_REGISTRATION\n"; 186 os << "#undef GEN_PASS_REGISTRATION\n";
135 for (const Pass &pass : passes) {
136 os << llvm::formatv("#undef GEN_PASS_REGISTRATION_{0}\n",
137 pass.getDef()->getName());
138 }
139 os << "#endif // GEN_PASS_REGISTRATION\n"; 187 os << "#endif // GEN_PASS_REGISTRATION\n";
140 os << "#undef GEN_PASS_REGISTRATION\n";
141 } 188 }
142 189
143 //===----------------------------------------------------------------------===// 190 //===----------------------------------------------------------------------===//
144 // GEN: Registration hooks 191 // GEN: Registration hooks
145 //===----------------------------------------------------------------------===// 192 //===----------------------------------------------------------------------===//