//===- Pass.cpp - MLIR pass registration generator ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// PassGen uses the description of passes to generate base classes for passes
// and command line registration.
//
//===----------------------------------------------------------------------===//
|
|
13
|
|
14 #include "mlir/TableGen/GenInfo.h"
|
|
15 #include "mlir/TableGen/Pass.h"
|
|
16 #include "llvm/ADT/StringExtras.h"
|
207
|
17 #include "llvm/Support/CommandLine.h"
|
173
|
18 #include "llvm/Support/FormatVariadic.h"
|
|
19 #include "llvm/TableGen/Error.h"
|
|
20 #include "llvm/TableGen/Record.h"
|
|
21
|
|
22 using namespace mlir;
|
|
23 using namespace mlir::tblgen;
|
|
24
|
207
|
25 static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
|
|
26 static llvm::cl::opt<std::string>
|
|
27 groupName("name", llvm::cl::desc("The name of this group of passes"),
|
|
28 llvm::cl::cat(passGenCat));
|
|
29
|
173
|
//===----------------------------------------------------------------------===//
// GEN: Pass base class generation
//===----------------------------------------------------------------------===//

/// The code snippet used to generate the start of a pass base class.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The dependent dialects registration.
///
/// The snippet is expanded with `llvm::formatv`, so braces that must appear
/// literally in the output are escaped as `{{` where required. It deliberately
/// stops after `protected:`, inside the class body, so that option and
/// statistic members can be appended before the class is closed. The
/// parameter of the generated `getDependentDialects` must be named `registry`,
/// because the per-dialect snippets substituted for {3} refer to it by name.
const char *const passDeclBegin = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//

template <typename DerivedT>
class {0}Base : public {1} {
public:
  using Base = {0}Base;

  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
  {0}Base(const {0}Base &) : {1}(::mlir::TypeID::get<DerivedT>()) {{}

  /// Returns the command-line argument attached to this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("{2}");
  }
  ::llvm::StringRef getArgument() const override { return "{2}"; }

  /// Returns the derived pass name.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("{0}");
  }
  ::llvm::StringRef getName() const override { return "{0}"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {{
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {{
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    {3}
  }

protected:
)";
|
|
82
|
207
|
/// Snippet emitted once per dependent dialect of a pass. The concatenation of
/// all such snippets fills the body of the generated `getDependentDialects`,
/// whose `registry` parameter the snippet uses.
///
/// {0}: The C++ class name of the dependent dialect.
const char *const dialectRegistrationTemplate =
    "\n"
    "  registry.insert<{0}>();\n";
|
|
88
|
173
|
89 /// Emit the declarations for each of the pass options.
|
|
90 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
|
|
91 for (const PassOption &opt : pass.getOptions()) {
|
207
|
92 os.indent(2) << "::mlir::Pass::"
|
|
93 << (opt.isListOption() ? "ListOption" : "Option");
|
173
|
94
|
207
|
95 os << llvm::formatv("<{0}> {1}{{*this, \"{2}\", ::llvm::cl::desc(\"{3}\")",
|
173
|
96 opt.getType(), opt.getCppVariableName(),
|
|
97 opt.getArgument(), opt.getDescription());
|
|
98 if (Optional<StringRef> defaultVal = opt.getDefaultValue())
|
207
|
99 os << ", ::llvm::cl::init(" << defaultVal << ")";
|
173
|
100 if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags())
|
|
101 os << ", " << *additionalFlags;
|
|
102 os << "};\n";
|
|
103 }
|
|
104 }
|
|
105
|
|
106 /// Emit the declarations for each of the pass statistics.
|
|
107 static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
|
|
108 for (const PassStatistic &stat : pass.getStatistics()) {
|
207
|
109 os << llvm::formatv(
|
|
110 " ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
|
|
111 stat.getCppVariableName(), stat.getName(), stat.getDescription());
|
173
|
112 }
|
|
113 }
|
|
114
|
|
115 static void emitPassDecl(const Pass &pass, raw_ostream &os) {
|
|
116 StringRef defName = pass.getDef()->getName();
|
207
|
117 std::string dependentDialectRegistrations;
|
|
118 {
|
|
119 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
|
|
120 for (StringRef dependentDialect : pass.getDependentDialects())
|
|
121 dialectsOs << llvm::formatv(dialectRegistrationTemplate,
|
|
122 dependentDialect);
|
|
123 }
|
173
|
124 os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
|
207
|
125 pass.getArgument(), dependentDialectRegistrations);
|
173
|
126 emitPassOptionDecls(pass, os);
|
|
127 emitPassStatisticDecls(pass, os);
|
|
128 os << "};\n";
|
|
129 }
|
|
130
|
|
131 /// Emit the code for registering each of the given passes with the global
|
|
132 /// PassRegistry.
|
|
133 static void emitPassDecls(ArrayRef<Pass> passes, raw_ostream &os) {
|
|
134 os << "#ifdef GEN_PASS_CLASSES\n";
|
|
135 for (const Pass &pass : passes)
|
|
136 emitPassDecl(pass, os);
|
|
137 os << "#undef GEN_PASS_CLASSES\n";
|
|
138 os << "#endif // GEN_PASS_CLASSES\n";
|
|
139 }
|
|
140
|
|
//===----------------------------------------------------------------------===//
// GEN: Pass registration generation
//===----------------------------------------------------------------------===//

/// The code snippet used to generate the registration function of a single
/// pass, `register<Def>Pass()`. It registers the pass under its command-line
/// argument, using a factory lambda that builds a new pass instance.
///
/// {0}: The def name of the pass record.
/// {1}: The argument of the pass.
/// {2}: The summary of the pass.
/// {3}: The code for constructing the pass.
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}Pass() {{
  ::mlir::registerPass("{1}", "{2}", []() -> std::unique_ptr<::mlir::Pass> {{
    return {3};
  });
}
)";
|
|
162
|
|
/// Snippet that opens `register<Group>Passes()`. Only the header and the
/// opening brace come from here. The generator emits the per-pass calls and
/// the closing brace separately.
///
/// {0}: The name of the pass group.
const char *const passGroupRegistrationCode =
    "\n"
    "//===----------------------------------------------------------------------===//\n"
    "// {0} Registration\n"
    "//===----------------------------------------------------------------------===//\n"
    "\n"
    "inline void register{0}Passes() {{\n";
|
|
171
|
173
|
172 /// Emit the code for registering each of the given passes with the global
|
|
173 /// PassRegistry.
|
|
174 static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
|
|
175 os << "#ifdef GEN_PASS_REGISTRATION\n";
|
|
176 for (const Pass &pass : passes) {
|
207
|
177 os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
|
173
|
178 pass.getArgument(), pass.getSummary(),
|
|
179 pass.getConstructor());
|
|
180 }
|
|
181
|
207
|
182 os << llvm::formatv(passGroupRegistrationCode, groupName);
|
|
183 for (const Pass &pass : passes)
|
|
184 os << " register" << pass.getDef()->getName() << "Pass();\n";
|
|
185 os << "}\n";
|
|
186 os << "#undef GEN_PASS_REGISTRATION\n";
|
173
|
187 os << "#endif // GEN_PASS_REGISTRATION\n";
|
|
188 }
|
|
189
|
|
190 //===----------------------------------------------------------------------===//
|
|
191 // GEN: Registration hooks
|
|
192 //===----------------------------------------------------------------------===//
|
|
193
|
|
194 static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
|
|
195 os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
|
|
196 std::vector<Pass> passes;
|
|
197 for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
|
|
198 passes.push_back(Pass(def));
|
|
199
|
|
200 emitPassDecls(passes, os);
|
|
201 emitRegistration(passes, os);
|
|
202 }
|
|
203
|
|
204 static mlir::GenRegistration
|
|
205 genRegister("gen-pass-decls", "Generate operation documentation",
|
|
206 [](const llvm::RecordKeeper &records, raw_ostream &os) {
|
|
207 emitDecls(records, os);
|
|
208 return false;
|
|
209 });
|