# Quickstart tutorial to adding MLIR graph rewrite

This document will present a quickstart to adding graph rewrites. We shall
start by defining an operation, then show multiple ways to define the rewrite
using patterns, as well as how to define the rewrite using a graph walker
(note: using patterns and the rewrite engine is preferred; the walker is shown
for demonstration purposes).

See [MLIR specification](LangRef.md) for more information about MLIR, the
structure of the IR, operations, etc. See
[Table-driven Operation Definition](OpDefinitions.md) and
[Declarative Rewrite Rule](DeclarativeRewrites.md) for the detailed explanation
of all available mechanisms for defining operations and rewrites in a
table-driven manner.

## Adding operation

An operation in MLIR is specified using a definition in a
[TableGen](https://llvm.org/docs/TableGen/LangIntro.html) file. TableGen is a
modeling tool: the ops are specified in TableGen, and the C++ code to interact
with these operations is generated from that specification. To define an
operation one needs to specify:

*   The operation name. This name is a unique identifier of the operation
    within MLIR. Most operations are within a dialect, so for example one
    could have `tfl.add` to represent the add operation in the TensorFlow Lite
    dialect. Instead of repeating the dialect in the op definition, a base
    class for the op dialect is commonly created that prepends the dialect
    namespace given an op name.
*   The traits of the operation. These allow you to specify traits of the
    operation, such as whether it has side effects or whether it should be
    verified that the operands and result types are the same. These are backed
    by C++ traits that perform the verification.
*   The arguments of the operation. These are the input operands (values at
    runtime produced by other ops) and attributes (compile-time-known constant
    values that affect the behavior of the op) that are the inputs of/define
    the behavior of the operation. The input operands may be named; the
    attributes must be named.
*   The result(s) of the operation. These may again be named or not.
*   Documentation of the operation. This includes a one-line summary as well
    as a longer human-readable description of the operation.
*   Dialect specific information. Additional information could be added to the
    operation definition that is only used by dialect specific drivers. This
    information is ignored by the main op and doc generators, but could be
    used in, say, the translation from a dialect to another representation.

```tablegen
def TFL_LeakyReluOp: TFL_Op<TFL_Dialect, "leaky_relu",
                            [NoSideEffect, SameValueType]>,
                     Results<(outs Tensor)> {
  let arguments = (ins
    F32Tensor:$x,
    // Slope of the activation function at x < 0.
    F32Attr:$alpha
  );

  let summary = "Leaky ReLU operator";
  let description = [{
    Element-wise Leaky ReLU operator
      x -> x >= 0 ? x : (alpha * x)
  }];

  // TFLite specific attribute that is used when generating the output
  // flatbuffer.
  let hasOptions = 1;
}
```

Note in the above that the result types and inputs are specified in different
ways, one by way of trait and the other by way of `let`. It is possible to
specify both in either way, as shown in the sketch below.
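
For example, the results of the same op could equally be specified inside the
body with a `let` instead of the `Results<...>` template. A minimal sketch
(this is not the style the definition above actually uses):

```tablegen
def TFL_LeakyReluOp: TFL_Op<TFL_Dialect, "leaky_relu",
                            [NoSideEffect, SameValueType]> {
  let arguments = (ins F32Tensor:$x, F32Attr:$alpha);
  // Results specified via `let` rather than via the `Results<...>` template.
  let results = (outs Tensor:$y);
}
```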

<!-- TODO: Define a style convention. -->

Operations can also have a custom parser, printer, builder, verifier, constant
folder, or canonicalizer. These require specifying additional C++ methods to
invoke for additional functionality. For example, if an operation is marked as
having a folder, the constant folder also needs to be added, e.g.:

```c++
OpFoldResult SpecificOp::fold(ArrayRef<Attribute> constOperands) {
  // Return a null OpFoldResult to signal that the op could not be folded.
  if (unable_to_fold)
    return {};
  ....
  // Return the folded result: either an Attribute (for a constant result) or
  // an existing Value.
  return val;
}
```
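
To make this concrete, here is a hypothetical fold for the leaky ReLU op above
that recognizes when the op acts as the identity (`alpha == 1`). This is a
sketch only, not a fold the TFLite dialect necessarily defines:

```c++
OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> constOperands) {
  // leaky_relu(x) with alpha == 1 computes x for every input, so the op can
  // be folded away to its operand.
  auto alpha = getAttrOfType<FloatAttr>("alpha");
  if (alpha && alpha.getValueAsDouble() == 1.0)
    return x();
  // A null OpFoldResult signals that no folding happened.
  return {};
}
```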

## Adding patterns

There are multiple forms of graph rewrite that can be performed in MLIR. One of
the most common is the DAG tile to DAG tile rewrite. Patterns provide a concise
way to express this transformation as a pair consisting of a source pattern to
match and a resultant pattern. There are both C++ classes to represent this
transformation, as well as patterns in TableGen from which these can be
generated.

### TableGen patterns

Let us continue with LeakyRelu. To map from TensorFlow's `LeakyRelu` to
TensorFlow Lite's `LeakyRelu`:

```tablegen
def : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), (TFL_LeakyReluOp $arg, $a)>;
```

The pattern is specified by instantiating a `Pat` with a source and result DAG.
The arguments in the source pattern are captured and can be used in the result
pattern. This is a simple pattern as we have a 1:1 mapping and the attribute
does not need to be transformed (e.g., both have a floating point attribute for
alpha). The names of the attributes specified in the pattern are for
matching/referencing and need not match the original attribute name in the op
definition (here `$a` binds the op's `alpha` attribute), but the order of
arguments of the dags does need to match.

To specify a pattern, both the source and resultant ops need to be defined
using TableGen.

If this were a more advanced pattern that the current framework could not
express as a destination pattern, then one could use a general native code
fallback method. This consists of defining a pattern as well as adding a C++
function to perform the replacement:

```tablegen
def createTFLLeakyRelu : NativeCodeCall<
    "createTFLLeakyRelu($_builder, $0.getDefiningOp(), $1, $2)">;

def : Pat<(TF_LeakyReluOp:$old_value, $arg, F32Attr:$a),
          (createTFLLeakyRelu $old_value, $arg, $a)>;
```

```c++
static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op,
                                Value operand, Attribute attr) {
  return rewriter.create<mlir::TFL::LeakyReluOp>(
      op->getLoc(), operand.getType(), /*arg=*/operand,
      /*alpha=*/attr.cast<FloatAttr>());
}
```

This allows for arbitrarily complex builders. On the input pattern side, one
can express multi-op patterns with constraints on input operands and
attributes, but input patterns cannot yet express constraints across multiple
operands/attributes. A sketch of a multi-op source pattern follows.
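
For example, assuming a `TF_IdentityOp` in the source dialect, a two-op source
pattern could match a leaky ReLU fed through an identity op and rewrite both
into a single TFLite op (a sketch, not a pattern the TFLite converter
necessarily contains):

```tablegen
def : Pat<(TF_LeakyReluOp (TF_IdentityOp $arg), F32Attr:$a),
          (TFL_LeakyReluOp $arg, $a)>;
```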

### Register the pattern

The file containing the patterns needs to be processed using `mlir-tblgen`
`-gen-rewriters` during compilation time. It can be invoked with the following
configuration in CMake:

```cmake
set(LLVM_TARGET_DEFINITIONS <name-of-the-td-file>)
mlir_tablegen(<name-of-the-generated-inc-file> -gen-rewriters)
add_public_tablegen_target(<name-of-the-cmake-target>)
```

Then you can `#include` the generated file in any C++ implementation file you
like. (You will also need to make sure the library depends on the CMake target
defined above.) The generated file will have a
`populateWithGenerated(MLIRContext *context, OwningRewritePatternList *patterns)`
function that you can use to collect all the generated patterns inside
`patterns` and then use `patterns` in any pass you would like.
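
For instance, a pass might collect and apply the patterns like this (a minimal
sketch; the include path and pass name are hypothetical, and
`applyPatternsGreedily` is assumed as the driver):

```c++
#include "MyRewritePatterns.inc"  // Hypothetical name of the generated file.

void MyLegalizePass::runOnFunction() {
  OwningRewritePatternList patterns;
  // Collect all the patterns generated from the TableGen definitions.
  populateWithGenerated(&getContext(), &patterns);
  // Repeatedly apply the collected patterns until a fixed point is reached.
  applyPatternsGreedily(getFunction(), std::move(patterns));
}
```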

### C++ rewrite specification

In case patterns are not sufficient, there is also the fully C++ way of
expressing a rewrite:

```c++
/// Multi-step rewrite using "match" and "rewrite". This allows for separating
/// the concerns of matching and rewriting.
struct ConvertTFLeakyRelu : public RewritePattern {
  ConvertTFLeakyRelu(MLIRContext *context)
      : RewritePattern("tf.LeakyRelu", 1, context) {}

  PatternMatchResult match(Operation *op) const override {
    return matchSuccess();
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
        op, op->getResult(0).getType(), op->getOperand(0),
        /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
  }
};

/// Single-step rewrite with "matchAndRewrite". This allows for performing the
/// rewrite immediately upon a successful match. (This is an alternative to
/// the multi-step pattern above; a given pattern defines one style or the
/// other, not both.)
struct ConvertTFLeakyRelu : public RewritePattern {
  ConvertTFLeakyRelu(MLIRContext *context)
      : RewritePattern("tf.LeakyRelu", 1, context) {}

  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
        op, op->getResult(0).getType(), op->getOperand(0),
        /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
    return matchSuccess();
  }
};
```

In the C++ rewrite the static benefit of the rewrite pattern is specified at
construction, while in the pattern generator a simple heuristic is currently
employed based on the number of ops matched and replaced.

The above rule did not capture the matching operands/attributes, but in general
the `match` function in a multi-step rewrite may populate and return a
`PatternState` (or a class derived from one) to pass information extracted
during matching to the rewrite. A single-step rewrite with the
`matchAndRewrite` function has the benefit of being able to directly use any
values created while matching, removing the need for a `PatternState`.

## Testing

MLIR uses the [lit](https://llvm.org/docs/CommandGuide/lit.html) (LLVM
Integrated Testing) tool for performing testing. Testing is performed by way of
creating the input IR file, running a transformation and then verifying the
output IR. C++ unit tests are the exception, with the IR transformation serving
as the core testing mechanism. This results in fewer binaries that need to be
built (and linked) and forces the tests to focus on the representation as an
important piece.

For the legalization transform above we would have a test (probably as part of
the legalization pass test in TensorFlow Lite) such as:

```mlir
// RUN: mlir-opt -tfl-legalize-tf %s | FileCheck %s

func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  %2 = "tf.LeakyRelu"(%arg0) {alpha: 0.1} : (tensor<1xf32>) -> tensor<1xf32>
  return %2: tensor<1xf32>

// CHECK-LABEL: LeakyRelu
// CHECK: %0 = "tfl.leaky_relu"(%arg0) {alpha: 1.000000e-01} : (tensor<1xf32>) -> tensor<1xf32>
}
```

The RUN command at the top results in running the `mlir-opt` binary (which is a
compiler-writer tool to exercise different registered passes) to invoke, on the
current file, the optimization pass that this transform was added to, and to
verify its output using `FileCheck`. `FileCheck` is a textual output verifier.
In particular it uses the CHECK expressions to verify the given output is
produced.

There can be multiple RUN commands with different corresponding CHECK prefixes.
In addition, multiple independent tests can be separated by `// -----`, with
`mlir-opt` invoked with the `-split-input-file` flag. This is especially useful
for error testing; a sketch of the layout follows.
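
Assuming the same legalization pass, a split-input test file might look like
this (a structural sketch; the function bodies are elided):

```mlir
// RUN: mlir-opt -split-input-file -tfl-legalize-tf %s | FileCheck %s

// Each block delimited by `// -----` is processed as an independent input.
func @first_case(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  ...
}

// -----

func @second_case(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  ...
}
```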

This results in very simple, directed testing without the need to work around
constant propagation or other, unrelated, optimization passes.

## Adding optimization pass

Optimization passes that do not fit/are difficult to specify in the above
structure can be specified as general iterations across modules/functions; a
minimal sketch of such a pass is shown below. See
[Writing a Pass](WritingAPass.md) for a general overview and introduction to
optimization passes in MLIR.
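
For instance, such a pass might walk every operation in a function directly (a
minimal sketch under the pass infrastructure described in
[Writing a Pass](WritingAPass.md); the pass name is hypothetical):

```c++
struct MyGraphOptimization : public FunctionPass<MyGraphOptimization> {
  void runOnFunction() override {
    // Walk every operation nested in the current function.
    getFunction().walk([](Operation *op) {
      // Inspect or rewrite `op` here.
    });
  }
};

// Make the pass available to tools such as `mlir-opt`.
static PassRegistration<MyGraphOptimization> pass(
    "my-graph-optimization", "Hypothetical example optimization pass");
```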
|