Mercurial > hg > CbC > CbC_llvm
comparison mlir/tools/mlir-tblgen/PassGen.cpp @ 207:2e18cbf3894f
LLVM12
author | Shinji KONO <kono@ie.u-ryukyu.ac.jp> |
---|---|
date | Tue, 08 Jun 2021 06:07:14 +0900 |
parents | 0572611fdcc8 |
children | 5f17cb93ff66 |
comparison
equal
deleted
inserted
replaced
173:0572611fdcc8 | 207:2e18cbf3894f |
---|---|
12 //===----------------------------------------------------------------------===// | 12 //===----------------------------------------------------------------------===// |
13 | 13 |
14 #include "mlir/TableGen/GenInfo.h" | 14 #include "mlir/TableGen/GenInfo.h" |
15 #include "mlir/TableGen/Pass.h" | 15 #include "mlir/TableGen/Pass.h" |
16 #include "llvm/ADT/StringExtras.h" | 16 #include "llvm/ADT/StringExtras.h" |
17 #include "llvm/Support/CommandLine.h" | |
17 #include "llvm/Support/FormatVariadic.h" | 18 #include "llvm/Support/FormatVariadic.h" |
18 #include "llvm/TableGen/Error.h" | 19 #include "llvm/TableGen/Error.h" |
19 #include "llvm/TableGen/Record.h" | 20 #include "llvm/TableGen/Record.h" |
20 | 21 |
21 using namespace mlir; | 22 using namespace mlir; |
22 using namespace mlir::tblgen; | 23 using namespace mlir::tblgen; |
24 | |
25 static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls"); | |
26 static llvm::cl::opt<std::string> | |
27 groupName("name", llvm::cl::desc("The name of this group of passes"), | |
28 llvm::cl::cat(passGenCat)); | |
23 | 29 |
24 //===----------------------------------------------------------------------===// | 30 //===----------------------------------------------------------------------===// |
25 // GEN: Pass base class generation | 31 // GEN: Pass base class generation |
26 //===----------------------------------------------------------------------===// | 32 //===----------------------------------------------------------------------===// |
27 | 33 |
28 /// The code snippet used to generate the start of a pass base class. | 34 /// The code snippet used to generate the start of a pass base class. |
29 /// | 35 /// |
30 /// {0}: The def name of the pass record. | 36 /// {0}: The def name of the pass record. |
31 /// {1}: The base class for the pass. | 37 /// {1}: The base class for the pass. |
32 /// {2}: The command line argument for the pass. | 38 /// {2}: The command line argument for the pass. |
39 /// {3}: The dependent dialects registration. | |
33 const char *const passDeclBegin = R"( | 40 const char *const passDeclBegin = R"( |
34 //===----------------------------------------------------------------------===// | 41 //===----------------------------------------------------------------------===// |
35 // {0} | 42 // {0} |
36 //===----------------------------------------------------------------------===// | 43 //===----------------------------------------------------------------------===// |
37 | 44 |
38 template <typename DerivedT> | 45 template <typename DerivedT> |
39 class {0}Base : public {1} { | 46 class {0}Base : public {1} { |
40 public: | 47 public: |
48 using Base = {0}Base; | |
49 | |
41 {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{} | 50 {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{} |
42 {0}Base(const {0}Base &) : {1}(::mlir::TypeID::get<DerivedT>()) {{} | 51 {0}Base(const {0}Base &) : {1}(::mlir::TypeID::get<DerivedT>()) {{} |
43 | 52 |
44 /// Returns the command-line argument attached to this pass. | 53 /// Returns the command-line argument attached to this pass. |
45 llvm::StringRef getArgument() const override { return "{2}"; } | 54 static constexpr ::llvm::StringLiteral getArgumentName() { |
55 return ::llvm::StringLiteral("{2}"); | |
56 } | |
57 ::llvm::StringRef getArgument() const override { return "{2}"; } | |
46 | 58 |
47 /// Returns the derived pass name. | 59 /// Returns the derived pass name. |
48 llvm::StringRef getName() const override { return "{0}"; } | 60 static constexpr ::llvm::StringLiteral getPassName() { |
61 return ::llvm::StringLiteral("{0}"); | |
62 } | |
63 ::llvm::StringRef getName() const override { return "{0}"; } | |
49 | 64 |
50 /// Support isa/dyn_cast functionality for the derived pass class. | 65 /// Support isa/dyn_cast functionality for the derived pass class. |
51 static bool classof(const ::mlir::Pass *pass) {{ | 66 static bool classof(const ::mlir::Pass *pass) {{ |
52 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); | 67 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
53 } | 68 } |
54 | 69 |
55 /// A clone method to create a copy of this pass. | 70 /// A clone method to create a copy of this pass. |
56 std::unique_ptr<Pass> clonePass() const override {{ | 71 std::unique_ptr<::mlir::Pass> clonePass() const override {{ |
57 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); | 72 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
58 } | 73 } |
59 | 74 |
75 /// Return the dialect that must be loaded in the context before this pass. | |
76 void getDependentDialects(::mlir::DialectRegistry &registry) const override { | |
77 {3} | |
78 } | |
79 | |
60 protected: | 80 protected: |
81 )"; | |
82 | |
83 /// Registration for a single dependent dialect, to be inserted for each | |
84 /// dependent dialect in the `getDependentDialects` above. | |
85 const char *const dialectRegistrationTemplate = R"( | |
86 registry.insert<{0}>(); | |
61 )"; | 87 )"; |
62 | 88 |
63 /// Emit the declarations for each of the pass options. | 89 /// Emit the declarations for each of the pass options. |
64 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) { | 90 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) { |
65 for (const PassOption &opt : pass.getOptions()) { | 91 for (const PassOption &opt : pass.getOptions()) { |
66 os.indent(2) << "Pass::" << (opt.isListOption() ? "ListOption" : "Option"); | 92 os.indent(2) << "::mlir::Pass::" |
67 | 93 << (opt.isListOption() ? "ListOption" : "Option"); |
68 os << llvm::formatv("<{0}> {1}{{*this, \"{2}\", llvm::cl::desc(\"{3}\")", | 94 |
95 os << llvm::formatv("<{0}> {1}{{*this, \"{2}\", ::llvm::cl::desc(\"{3}\")", | |
69 opt.getType(), opt.getCppVariableName(), | 96 opt.getType(), opt.getCppVariableName(), |
70 opt.getArgument(), opt.getDescription()); | 97 opt.getArgument(), opt.getDescription()); |
71 if (Optional<StringRef> defaultVal = opt.getDefaultValue()) | 98 if (Optional<StringRef> defaultVal = opt.getDefaultValue()) |
72 os << ", llvm::cl::init(" << defaultVal << ")"; | 99 os << ", ::llvm::cl::init(" << defaultVal << ")"; |
73 if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags()) | 100 if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags()) |
74 os << ", " << *additionalFlags; | 101 os << ", " << *additionalFlags; |
75 os << "};\n"; | 102 os << "};\n"; |
76 } | 103 } |
77 } | 104 } |
78 | 105 |
79 /// Emit the declarations for each of the pass statistics. | 106 /// Emit the declarations for each of the pass statistics. |
80 static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) { | 107 static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) { |
81 for (const PassStatistic &stat : pass.getStatistics()) { | 108 for (const PassStatistic &stat : pass.getStatistics()) { |
82 os << llvm::formatv(" Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n", | 109 os << llvm::formatv( |
83 stat.getCppVariableName(), stat.getName(), | 110 " ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n", |
84 stat.getDescription()); | 111 stat.getCppVariableName(), stat.getName(), stat.getDescription()); |
85 } | 112 } |
86 } | 113 } |
87 | 114 |
88 static void emitPassDecl(const Pass &pass, raw_ostream &os) { | 115 static void emitPassDecl(const Pass &pass, raw_ostream &os) { |
89 StringRef defName = pass.getDef()->getName(); | 116 StringRef defName = pass.getDef()->getName(); |
117 std::string dependentDialectRegistrations; | |
118 { | |
119 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); | |
120 for (StringRef dependentDialect : pass.getDependentDialects()) | |
121 dialectsOs << llvm::formatv(dialectRegistrationTemplate, | |
122 dependentDialect); | |
123 } | |
90 os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(), | 124 os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(), |
91 pass.getArgument()); | 125 pass.getArgument(), dependentDialectRegistrations); |
92 emitPassOptionDecls(pass, os); | 126 emitPassOptionDecls(pass, os); |
93 emitPassStatisticDecls(pass, os); | 127 emitPassStatisticDecls(pass, os); |
94 os << "};\n"; | 128 os << "};\n"; |
95 } | 129 } |
96 | 130 |
106 | 140 |
107 //===----------------------------------------------------------------------===// | 141 //===----------------------------------------------------------------------===// |
108 // GEN: Pass registration generation | 142 // GEN: Pass registration generation |
109 //===----------------------------------------------------------------------===// | 143 //===----------------------------------------------------------------------===// |
110 | 144 |
145 /// The code snippet used to generate the registration function for a single pass. | |
146 /// | |
147 /// {0}: The def name of the pass record. | |
148 /// {1}: The argument of the pass. | |
149 /// {2}: The summary of the pass. | |
150 /// {3}: The code for constructing the pass. | |
151 const char *const passRegistrationCode = R"( | |
152 //===----------------------------------------------------------------------===// | |
153 // {0} Registration | |
154 //===----------------------------------------------------------------------===// | |
155 | |
156 inline void register{0}Pass() {{ | |
157 ::mlir::registerPass("{1}", "{2}", []() -> std::unique_ptr<::mlir::Pass> {{ | |
158 return {3}; | |
159 }); | |
160 } | |
161 )"; | |
162 | |
163 /// {0}: The name of the pass group. | |
164 const char *const passGroupRegistrationCode = R"( | |
165 //===----------------------------------------------------------------------===// | |
166 // {0} Registration | |
167 //===----------------------------------------------------------------------===// | |
168 | |
169 inline void register{0}Passes() {{ | |
170 )"; | |
171 | |
111 /// Emit the code for registering each of the given passes with the global | 172 /// Emit the code for registering each of the given passes with the global |
112 /// PassRegistry. | 173 /// PassRegistry. |
113 static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) { | 174 static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) { |
114 os << "#ifdef GEN_PASS_REGISTRATION\n"; | 175 os << "#ifdef GEN_PASS_REGISTRATION\n"; |
115 for (const Pass &pass : passes) { | 176 for (const Pass &pass : passes) { |
116 os << llvm::formatv("#define GEN_PASS_REGISTRATION_{0}\n", | 177 os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(), |
117 pass.getDef()->getName()); | |
118 } | |
119 os << "#endif // GEN_PASS_REGISTRATION\n"; | |
120 | |
121 for (const Pass &pass : passes) { | |
122 os << llvm::formatv("#ifdef GEN_PASS_REGISTRATION_{0}\n", | |
123 pass.getDef()->getName()); | |
124 os << llvm::formatv("::mlir::registerPass(\"{0}\", \"{1}\", []() -> " | |
125 "std::unique_ptr<Pass> {{ return {2}; });\n", | |
126 pass.getArgument(), pass.getSummary(), | 178 pass.getArgument(), pass.getSummary(), |
127 pass.getConstructor()); | 179 pass.getConstructor()); |
128 os << llvm::formatv("#endif // GEN_PASS_REGISTRATION_{0}\n", | 180 } |
129 pass.getDef()->getName()); | 181 |
130 os << llvm::formatv("#undef GEN_PASS_REGISTRATION_{0}\n", | 182 os << llvm::formatv(passGroupRegistrationCode, groupName); |
131 pass.getDef()->getName()); | 183 for (const Pass &pass : passes) |
132 } | 184 os << " register" << pass.getDef()->getName() << "Pass();\n"; |
133 | 185 os << "}\n"; |
134 os << "#ifdef GEN_PASS_REGISTRATION\n"; | 186 os << "#undef GEN_PASS_REGISTRATION\n"; |
135 for (const Pass &pass : passes) { | |
136 os << llvm::formatv("#undef GEN_PASS_REGISTRATION_{0}\n", | |
137 pass.getDef()->getName()); | |
138 } | |
139 os << "#endif // GEN_PASS_REGISTRATION\n"; | 187 os << "#endif // GEN_PASS_REGISTRATION\n"; |
140 os << "#undef GEN_PASS_REGISTRATION\n"; | |
141 } | 188 } |
142 | 189 |
143 //===----------------------------------------------------------------------===// | 190 //===----------------------------------------------------------------------===// |
144 // GEN: Registration hooks | 191 // GEN: Registration hooks |
145 //===----------------------------------------------------------------------===// | 192 //===----------------------------------------------------------------------===// |