173
|
1 //===- Pass.cpp - MLIR pass registration generator ------------------------===//
|
|
2 //
|
|
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
4 // See https://llvm.org/LICENSE.txt for license information.
|
|
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
6 //
|
|
7 //===----------------------------------------------------------------------===//
|
|
8 //
|
|
9 // PassGen uses the description of passes to generate base classes for passes
|
|
10 // and command line registration.
|
|
11 //
|
|
12 //===----------------------------------------------------------------------===//
|
|
13
|
|
14 #include "mlir/TableGen/GenInfo.h"
|
|
15 #include "mlir/TableGen/Pass.h"
|
|
16 #include "llvm/ADT/StringExtras.h"
|
|
17 #include "llvm/Support/FormatVariadic.h"
|
|
18 #include "llvm/TableGen/Error.h"
|
|
19 #include "llvm/TableGen/Record.h"
|
|
20
|
|
21 using namespace mlir;
|
|
22 using namespace mlir::tblgen;
|
|
23
|
|
24 //===----------------------------------------------------------------------===//
|
|
25 // GEN: Pass base class generation
|
|
26 //===----------------------------------------------------------------------===//
|
|
27
|
|
/// The code snippet used to generate the start of a pass base class.
///
/// This is an `llvm::formatv` format string. Literal braces that would
/// otherwise be parsed as replacement fields are escaped as `{{`.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
///
/// The snippet deliberately ends right after `protected:`. The caller appends
/// the option/statistic members and the closing `};` of the class.
const char *const passDeclBegin = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//

template <typename DerivedT>
class {0}Base : public {1} {
public:
  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
  {0}Base(const {0}Base &) : {1}(::mlir::TypeID::get<DerivedT>()) {{}

  /// Returns the command-line argument attached to this pass.
  llvm::StringRef getArgument() const override { return "{2}"; }

  /// Returns the derived pass name.
  llvm::StringRef getName() const override { return "{0}"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {{
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<Pass> clonePass() const override {{
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

protected:
)";
|
|
62
|
|
63 /// Emit the declarations for each of the pass options.
|
|
64 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
|
|
65 for (const PassOption &opt : pass.getOptions()) {
|
|
66 os.indent(2) << "Pass::" << (opt.isListOption() ? "ListOption" : "Option");
|
|
67
|
|
68 os << llvm::formatv("<{0}> {1}{{*this, \"{2}\", llvm::cl::desc(\"{3}\")",
|
|
69 opt.getType(), opt.getCppVariableName(),
|
|
70 opt.getArgument(), opt.getDescription());
|
|
71 if (Optional<StringRef> defaultVal = opt.getDefaultValue())
|
|
72 os << ", llvm::cl::init(" << defaultVal << ")";
|
|
73 if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags())
|
|
74 os << ", " << *additionalFlags;
|
|
75 os << "};\n";
|
|
76 }
|
|
77 }
|
|
78
|
|
79 /// Emit the declarations for each of the pass statistics.
|
|
80 static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
|
|
81 for (const PassStatistic &stat : pass.getStatistics()) {
|
|
82 os << llvm::formatv(" Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
|
|
83 stat.getCppVariableName(), stat.getName(),
|
|
84 stat.getDescription());
|
|
85 }
|
|
86 }
|
|
87
|
|
88 static void emitPassDecl(const Pass &pass, raw_ostream &os) {
|
|
89 StringRef defName = pass.getDef()->getName();
|
|
90 os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
|
|
91 pass.getArgument());
|
|
92 emitPassOptionDecls(pass, os);
|
|
93 emitPassStatisticDecls(pass, os);
|
|
94 os << "};\n";
|
|
95 }
|
|
96
|
|
97 /// Emit the code for registering each of the given passes with the global
|
|
98 /// PassRegistry.
|
|
99 static void emitPassDecls(ArrayRef<Pass> passes, raw_ostream &os) {
|
|
100 os << "#ifdef GEN_PASS_CLASSES\n";
|
|
101 for (const Pass &pass : passes)
|
|
102 emitPassDecl(pass, os);
|
|
103 os << "#undef GEN_PASS_CLASSES\n";
|
|
104 os << "#endif // GEN_PASS_CLASSES\n";
|
|
105 }
|
|
106
|
|
107 //===----------------------------------------------------------------------===//
|
|
108 // GEN: Pass registration generation
|
|
109 //===----------------------------------------------------------------------===//
|
|
110
|
|
111 /// Emit the code for registering each of the given passes with the global
|
|
112 /// PassRegistry.
|
|
113 static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
|
|
114 os << "#ifdef GEN_PASS_REGISTRATION\n";
|
|
115 for (const Pass &pass : passes) {
|
|
116 os << llvm::formatv("#define GEN_PASS_REGISTRATION_{0}\n",
|
|
117 pass.getDef()->getName());
|
|
118 }
|
|
119 os << "#endif // GEN_PASS_REGISTRATION\n";
|
|
120
|
|
121 for (const Pass &pass : passes) {
|
|
122 os << llvm::formatv("#ifdef GEN_PASS_REGISTRATION_{0}\n",
|
|
123 pass.getDef()->getName());
|
|
124 os << llvm::formatv("::mlir::registerPass(\"{0}\", \"{1}\", []() -> "
|
|
125 "std::unique_ptr<Pass> {{ return {2}; });\n",
|
|
126 pass.getArgument(), pass.getSummary(),
|
|
127 pass.getConstructor());
|
|
128 os << llvm::formatv("#endif // GEN_PASS_REGISTRATION_{0}\n",
|
|
129 pass.getDef()->getName());
|
|
130 os << llvm::formatv("#undef GEN_PASS_REGISTRATION_{0}\n",
|
|
131 pass.getDef()->getName());
|
|
132 }
|
|
133
|
|
134 os << "#ifdef GEN_PASS_REGISTRATION\n";
|
|
135 for (const Pass &pass : passes) {
|
|
136 os << llvm::formatv("#undef GEN_PASS_REGISTRATION_{0}\n",
|
|
137 pass.getDef()->getName());
|
|
138 }
|
|
139 os << "#endif // GEN_PASS_REGISTRATION\n";
|
|
140 os << "#undef GEN_PASS_REGISTRATION\n";
|
|
141 }
|
|
142
|
|
143 //===----------------------------------------------------------------------===//
|
|
144 // GEN: Registration hooks
|
|
145 //===----------------------------------------------------------------------===//
|
|
146
|
|
147 static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
|
|
148 os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
|
|
149 std::vector<Pass> passes;
|
|
150 for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
|
|
151 passes.push_back(Pass(def));
|
|
152
|
|
153 emitPassDecls(passes, os);
|
|
154 emitRegistration(passes, os);
|
|
155 }
|
|
156
|
|
157 static mlir::GenRegistration
|
|
158 genRegister("gen-pass-decls", "Generate operation documentation",
|
|
159 [](const llvm::RecordKeeper &records, raw_ostream &os) {
|
|
160 emitDecls(records, os);
|
|
161 return false;
|
|
162 });
|