150
|
1 //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===//
|
|
2 //
|
|
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
4 // See https://llvm.org/LICENSE.txt for license information.
|
|
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
6 //
|
|
7 //===----------------------------------------------------------------------===//
|
|
8
|
|
9 #include "mlir/IR/AffineMap.h"
|
|
10 #include "AffineMapDetail.h"
|
|
11 #include "mlir/IR/Attributes.h"
|
|
12 #include "mlir/IR/StandardTypes.h"
|
|
13 #include "mlir/Support/LogicalResult.h"
|
|
14 #include "mlir/Support/MathExtras.h"
|
|
15 #include "llvm/ADT/StringRef.h"
|
|
16 #include "llvm/Support/raw_ostream.h"
|
|
17
|
|
18 using namespace mlir;
|
|
19
|
|
20 namespace {
|
|
21
|
|
22 // AffineExprConstantFolder evaluates an affine expression using constant
|
|
23 // operands passed in 'operandConsts'. Returns an IntegerAttr attribute
|
|
24 // representing the constant value of the affine expression evaluated on
|
|
25 // constant 'operandConsts', or nullptr if it can't be folded.
|
|
26 class AffineExprConstantFolder {
|
|
27 public:
|
|
28 AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
|
|
29 : numDims(numDims), operandConsts(operandConsts) {}
|
|
30
|
|
31 /// Attempt to constant fold the specified affine expr, or return null on
|
|
32 /// failure.
|
|
33 IntegerAttr constantFold(AffineExpr expr) {
|
|
34 if (auto result = constantFoldImpl(expr))
|
|
35 return IntegerAttr::get(IndexType::get(expr.getContext()), *result);
|
|
36 return nullptr;
|
|
37 }
|
|
38
|
|
39 private:
|
|
40 Optional<int64_t> constantFoldImpl(AffineExpr expr) {
|
|
41 switch (expr.getKind()) {
|
|
42 case AffineExprKind::Add:
|
|
43 return constantFoldBinExpr(
|
|
44 expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; });
|
|
45 case AffineExprKind::Mul:
|
|
46 return constantFoldBinExpr(
|
|
47 expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
|
|
48 case AffineExprKind::Mod:
|
|
49 return constantFoldBinExpr(
|
|
50 expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
|
|
51 case AffineExprKind::FloorDiv:
|
|
52 return constantFoldBinExpr(
|
|
53 expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
|
|
54 case AffineExprKind::CeilDiv:
|
|
55 return constantFoldBinExpr(
|
|
56 expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
|
|
57 case AffineExprKind::Constant:
|
|
58 return expr.cast<AffineConstantExpr>().getValue();
|
|
59 case AffineExprKind::DimId:
|
|
60 if (auto attr = operandConsts[expr.cast<AffineDimExpr>().getPosition()]
|
|
61 .dyn_cast_or_null<IntegerAttr>())
|
|
62 return attr.getInt();
|
|
63 return llvm::None;
|
|
64 case AffineExprKind::SymbolId:
|
|
65 if (auto attr = operandConsts[numDims +
|
|
66 expr.cast<AffineSymbolExpr>().getPosition()]
|
|
67 .dyn_cast_or_null<IntegerAttr>())
|
|
68 return attr.getInt();
|
|
69 return llvm::None;
|
|
70 }
|
|
71 llvm_unreachable("Unknown AffineExpr");
|
|
72 }
|
|
73
|
|
74 // TODO: Change these to operate on APInts too.
|
|
75 Optional<int64_t> constantFoldBinExpr(AffineExpr expr,
|
|
76 int64_t (*op)(int64_t, int64_t)) {
|
|
77 auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
|
|
78 if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
|
|
79 if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))
|
|
80 return op(*lhs, *rhs);
|
|
81 return llvm::None;
|
|
82 }
|
|
83
|
|
84 // The number of dimension operands in AffineMap containing this expression.
|
|
85 unsigned numDims;
|
|
86 // The constant valued operands used to evaluate this AffineExpr.
|
|
87 ArrayRef<Attribute> operandConsts;
|
|
88 };
|
|
89
|
|
90 } // end anonymous namespace
|
|
91
|
|
92 /// Returns a single constant result affine map.
|
|
93 AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
|
|
94 return get(/*dimCount=*/0, /*symbolCount=*/0,
|
|
95 {getAffineConstantExpr(val, context)});
|
|
96 }
|
|
97
|
173
|
98 /// Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most
|
|
99 /// minor dimensions.
|
|
100 AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
|
|
101 MLIRContext *context) {
|
|
102 assert(dims >= results && "Dimension mismatch");
|
|
103 auto id = AffineMap::getMultiDimIdentityMap(dims, context);
|
|
104 return AffineMap::get(dims, 0, id.getResults().take_back(results), context);
|
|
105 }
|
|
106
|
|
107 bool AffineMap::isMinorIdentity(AffineMap map) {
|
|
108 if (!map)
|
|
109 return false;
|
|
110 return map == getMinorIdentityMap(map.getNumDims(), map.getNumResults(),
|
|
111 map.getContext());
|
|
112 }
|
|
113
|
150
|
114 /// Returns an AffineMap representing a permutation.
|
|
115 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
|
|
116 MLIRContext *context) {
|
|
117 assert(!permutation.empty() &&
|
|
118 "Cannot create permutation map from empty permutation vector");
|
|
119 SmallVector<AffineExpr, 4> affExprs;
|
|
120 for (auto index : permutation)
|
|
121 affExprs.push_back(getAffineDimExpr(index, context));
|
|
122 auto m = std::max_element(permutation.begin(), permutation.end());
|
173
|
123 auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context);
|
150
|
124 assert(permutationMap.isPermutation() && "Invalid permutation vector");
|
|
125 return permutationMap;
|
|
126 }
|
|
127
|
|
128 template <typename AffineExprContainer>
|
|
129 static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
|
|
130 int64_t &maxDim, int64_t &maxSym) {
|
|
131 for (const auto &exprs : exprsList) {
|
|
132 for (auto expr : exprs) {
|
|
133 expr.walk([&maxDim, &maxSym](AffineExpr e) {
|
|
134 if (auto d = e.dyn_cast<AffineDimExpr>())
|
|
135 maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
|
|
136 if (auto s = e.dyn_cast<AffineSymbolExpr>())
|
|
137 maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
|
|
138 });
|
|
139 }
|
|
140 }
|
|
141 }
|
|
142
|
|
143 template <typename AffineExprContainer>
|
173
|
144 static SmallVector<AffineMap, 4>
|
150
|
145 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
|
173
|
146 assert(!exprsList.empty());
|
|
147 assert(!exprsList[0].empty());
|
|
148 auto context = exprsList[0][0].getContext();
|
150
|
149 int64_t maxDim = -1, maxSym = -1;
|
|
150 getMaxDimAndSymbol(exprsList, maxDim, maxSym);
|
|
151 SmallVector<AffineMap, 4> maps;
|
|
152 maps.reserve(exprsList.size());
|
|
153 for (const auto &exprs : exprsList)
|
|
154 maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1,
|
173
|
155 /*symbolCount=*/maxSym + 1, exprs, context));
|
150
|
156 return maps;
|
|
157 }
|
|
158
|
|
159 SmallVector<AffineMap, 4>
|
|
160 AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) {
|
|
161 return ::inferFromExprList(exprsList);
|
|
162 }
|
|
163
|
|
164 SmallVector<AffineMap, 4>
|
|
165 AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) {
|
|
166 return ::inferFromExprList(exprsList);
|
|
167 }
|
|
168
|
|
169 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
|
|
170 MLIRContext *context) {
|
|
171 SmallVector<AffineExpr, 4> dimExprs;
|
|
172 dimExprs.reserve(numDims);
|
|
173 for (unsigned i = 0; i < numDims; ++i)
|
|
174 dimExprs.push_back(mlir::getAffineDimExpr(i, context));
|
173
|
175 return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, context);
|
150
|
176 }
|
|
177
|
|
178 MLIRContext *AffineMap::getContext() const { return map->context; }
|
|
179
|
|
180 bool AffineMap::isIdentity() const {
|
|
181 if (getNumDims() != getNumResults())
|
|
182 return false;
|
|
183 ArrayRef<AffineExpr> results = getResults();
|
|
184 for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
|
|
185 auto expr = results[i].dyn_cast<AffineDimExpr>();
|
|
186 if (!expr || expr.getPosition() != i)
|
|
187 return false;
|
|
188 }
|
|
189 return true;
|
|
190 }
|
|
191
|
|
192 bool AffineMap::isEmpty() const {
|
|
193 return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0;
|
|
194 }
|
|
195
|
|
196 bool AffineMap::isSingleConstant() const {
|
|
197 return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
|
|
198 }
|
|
199
|
|
200 int64_t AffineMap::getSingleConstantResult() const {
|
|
201 assert(isSingleConstant() && "map must have a single constant result");
|
|
202 return getResult(0).cast<AffineConstantExpr>().getValue();
|
|
203 }
|
|
204
|
|
205 unsigned AffineMap::getNumDims() const {
|
|
206 assert(map && "uninitialized map storage");
|
|
207 return map->numDims;
|
|
208 }
|
|
209 unsigned AffineMap::getNumSymbols() const {
|
|
210 assert(map && "uninitialized map storage");
|
|
211 return map->numSymbols;
|
|
212 }
|
|
213 unsigned AffineMap::getNumResults() const {
|
|
214 assert(map && "uninitialized map storage");
|
|
215 return map->results.size();
|
|
216 }
|
|
217 unsigned AffineMap::getNumInputs() const {
|
|
218 assert(map && "uninitialized map storage");
|
|
219 return map->numDims + map->numSymbols;
|
|
220 }
|
|
221
|
|
222 ArrayRef<AffineExpr> AffineMap::getResults() const {
|
|
223 assert(map && "uninitialized map storage");
|
|
224 return map->results;
|
|
225 }
|
|
226 AffineExpr AffineMap::getResult(unsigned idx) const {
|
|
227 assert(map && "uninitialized map storage");
|
|
228 return map->results[idx];
|
|
229 }
|
|
230
|
|
231 /// Folds the results of the application of an affine map on the provided
|
|
232 /// operands to a constant if possible. Returns false if the folding happens,
|
|
233 /// true otherwise.
|
|
234 LogicalResult
|
|
235 AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
|
|
236 SmallVectorImpl<Attribute> &results) const {
|
173
|
237 // Attempt partial folding.
|
|
238 SmallVector<int64_t, 2> integers;
|
|
239 partialConstantFold(operandConstants, &integers);
|
|
240
|
|
241 // If all expressions folded to a constant, populate results with attributes
|
|
242 // containing those constants.
|
|
243 if (integers.empty())
|
|
244 return failure();
|
|
245
|
|
246 auto range = llvm::map_range(integers, [this](int64_t i) {
|
|
247 return IntegerAttr::get(IndexType::get(getContext()), i);
|
|
248 });
|
|
249 results.append(range.begin(), range.end());
|
|
250 return success();
|
|
251 }
|
|
252
|
|
253 AffineMap
|
|
254 AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
|
|
255 SmallVectorImpl<int64_t> *results) const {
|
150
|
256 assert(getNumInputs() == operandConstants.size());
|
|
257
|
|
258 // Fold each of the result expressions.
|
|
259 AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
|
173
|
260 SmallVector<AffineExpr, 4> exprs;
|
|
261 exprs.reserve(getNumResults());
|
|
262
|
150
|
263 for (auto expr : getResults()) {
|
|
264 auto folded = exprFolder.constantFold(expr);
|
173
|
265 // If did not fold to a constant, keep the original expression, and clear
|
|
266 // the integer results vector.
|
|
267 if (folded) {
|
|
268 exprs.push_back(
|
|
269 getAffineConstantExpr(folded.getInt(), folded.getContext()));
|
|
270 if (results)
|
|
271 results->push_back(folded.getInt());
|
|
272 } else {
|
|
273 exprs.push_back(expr);
|
|
274 if (results) {
|
|
275 results->clear();
|
|
276 results = nullptr;
|
|
277 }
|
|
278 }
|
|
279 }
|
150
|
280
|
173
|
281 return get(getNumDims(), getNumSymbols(), exprs, getContext());
|
150
|
282 }
|
|
283
|
|
284 /// Walk all of the AffineExpr's in this mapping. Each node in an expression
|
|
285 /// tree is visited in postorder.
|
|
286 void AffineMap::walkExprs(std::function<void(AffineExpr)> callback) const {
|
|
287 for (auto expr : getResults())
|
|
288 expr.walk(callback);
|
|
289 }
|
|
290
|
|
291 /// This method substitutes any uses of dimensions and symbols (e.g.
|
|
292 /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
|
|
293 /// expression mapping. Because this can be used to eliminate dims and
|
|
294 /// symbols, the client needs to specify the number of dims and symbols in
|
|
295 /// the result. The returned map always has the same number of results.
|
|
296 AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
|
|
297 ArrayRef<AffineExpr> symReplacements,
|
|
298 unsigned numResultDims,
|
173
|
299 unsigned numResultSyms) const {
|
150
|
300 SmallVector<AffineExpr, 8> results;
|
|
301 results.reserve(getNumResults());
|
|
302 for (auto expr : getResults())
|
|
303 results.push_back(
|
|
304 expr.replaceDimsAndSymbols(dimReplacements, symReplacements));
|
|
305
|
173
|
306 return get(numResultDims, numResultSyms, results, getContext());
|
150
|
307 }
|
|
308
|
|
309 AffineMap AffineMap::compose(AffineMap map) {
|
|
310 assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
|
|
311 // Prepare `map` by concatenating the symbols and rewriting its exprs.
|
|
312 unsigned numDims = map.getNumDims();
|
|
313 unsigned numSymbolsThisMap = getNumSymbols();
|
|
314 unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols();
|
|
315 SmallVector<AffineExpr, 8> newDims(numDims);
|
|
316 for (unsigned idx = 0; idx < numDims; ++idx) {
|
|
317 newDims[idx] = getAffineDimExpr(idx, getContext());
|
|
318 }
|
|
319 SmallVector<AffineExpr, 8> newSymbols(numSymbols);
|
|
320 for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
|
|
321 newSymbols[idx - numSymbolsThisMap] =
|
|
322 getAffineSymbolExpr(idx, getContext());
|
|
323 }
|
|
324 auto newMap =
|
|
325 map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols);
|
|
326 SmallVector<AffineExpr, 8> exprs;
|
|
327 exprs.reserve(getResults().size());
|
|
328 for (auto expr : getResults())
|
|
329 exprs.push_back(expr.compose(newMap));
|
173
|
330 return AffineMap::get(numDims, numSymbols, exprs, map.getContext());
|
150
|
331 }
|
|
332
|
|
333 bool AffineMap::isProjectedPermutation() {
|
|
334 if (getNumSymbols() > 0)
|
|
335 return false;
|
|
336 SmallVector<bool, 8> seen(getNumInputs(), false);
|
|
337 for (auto expr : getResults()) {
|
|
338 if (auto dim = expr.dyn_cast<AffineDimExpr>()) {
|
|
339 if (seen[dim.getPosition()])
|
|
340 return false;
|
|
341 seen[dim.getPosition()] = true;
|
|
342 continue;
|
|
343 }
|
|
344 return false;
|
|
345 }
|
|
346 return true;
|
|
347 }
|
|
348
|
|
349 bool AffineMap::isPermutation() {
|
|
350 if (getNumDims() != getNumResults())
|
|
351 return false;
|
|
352 return isProjectedPermutation();
|
|
353 }
|
|
354
|
|
355 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
|
|
356 SmallVector<AffineExpr, 4> exprs;
|
|
357 exprs.reserve(resultPos.size());
|
|
358 for (auto idx : resultPos) {
|
|
359 exprs.push_back(getResult(idx));
|
|
360 }
|
173
|
361 return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
|
150
|
362 }
|
|
363
|
|
364 AffineMap mlir::simplifyAffineMap(AffineMap map) {
|
|
365 SmallVector<AffineExpr, 8> exprs;
|
|
366 for (auto e : map.getResults()) {
|
|
367 exprs.push_back(
|
|
368 simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols()));
|
|
369 }
|
173
|
370 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs,
|
|
371 map.getContext());
|
|
372 }
|
|
373
|
|
374 AffineMap mlir::removeDuplicateExprs(AffineMap map) {
|
|
375 auto results = map.getResults();
|
|
376 SmallVector<AffineExpr, 4> uniqueExprs(results.begin(), results.end());
|
|
377 uniqueExprs.erase(std::unique(uniqueExprs.begin(), uniqueExprs.end()),
|
|
378 uniqueExprs.end());
|
|
379 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), uniqueExprs,
|
|
380 map.getContext());
|
150
|
381 }
|
|
382
|
|
383 AffineMap mlir::inversePermutation(AffineMap map) {
|
173
|
384 if (map.isEmpty())
|
150
|
385 return map;
|
|
386 assert(map.getNumSymbols() == 0 && "expected map without symbols");
|
|
387 SmallVector<AffineExpr, 4> exprs(map.getNumDims());
|
|
388 for (auto en : llvm::enumerate(map.getResults())) {
|
|
389 auto expr = en.value();
|
|
390 // Skip non-permutations.
|
|
391 if (auto d = expr.dyn_cast<AffineDimExpr>()) {
|
|
392 if (exprs[d.getPosition()])
|
|
393 continue;
|
|
394 exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
|
|
395 }
|
|
396 }
|
|
397 SmallVector<AffineExpr, 4> seenExprs;
|
|
398 seenExprs.reserve(map.getNumDims());
|
|
399 for (auto expr : exprs)
|
|
400 if (expr)
|
|
401 seenExprs.push_back(expr);
|
|
402 if (seenExprs.size() != map.getNumInputs())
|
|
403 return AffineMap();
|
173
|
404 return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
|
150
|
405 }
|
|
406
|
|
407 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
|
|
408 unsigned numResults = 0;
|
|
409 for (auto m : maps)
|
173
|
410 numResults += m.getNumResults();
|
150
|
411 unsigned numDims = 0;
|
|
412 SmallVector<AffineExpr, 8> results;
|
|
413 results.reserve(numResults);
|
|
414 for (auto m : maps) {
|
|
415 assert(m.getNumSymbols() == 0 && "expected map without symbols");
|
|
416 results.append(m.getResults().begin(), m.getResults().end());
|
|
417 numDims = std::max(m.getNumDims(), numDims);
|
|
418 }
|
173
|
419 return AffineMap::get(numDims, /*numSymbols=*/0, results,
|
|
420 maps.front().getContext());
|
150
|
421 }
|
|
422
|
|
//===----------------------------------------------------------------------===//
// MutableAffineMap.
//===----------------------------------------------------------------------===//
|
|
426
|
|
427 MutableAffineMap::MutableAffineMap(AffineMap map)
|
|
428 : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
|
173
|
429 context(map.getContext()) {
|
150
|
430 for (auto result : map.getResults())
|
|
431 results.push_back(result);
|
|
432 }
|
|
433
|
|
434 void MutableAffineMap::reset(AffineMap map) {
|
|
435 results.clear();
|
|
436 numDims = map.getNumDims();
|
|
437 numSymbols = map.getNumSymbols();
|
173
|
438 context = map.getContext();
|
150
|
439 for (auto result : map.getResults())
|
|
440 results.push_back(result);
|
|
441 }
|
|
442
|
|
443 bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
|
|
444 if (results[idx].isMultipleOf(factor))
|
|
445 return true;
|
|
446
|
|
447 // TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to
|
|
448 // complete this (for a more powerful analysis).
|
|
449 return false;
|
|
450 }
|
|
451
|
|
452 // Simplifies the result affine expressions of this map. The expressions have to
|
|
453 // be pure for the simplification implemented.
|
|
454 void MutableAffineMap::simplify() {
|
|
455 // Simplify each of the results if possible.
|
|
456 // TODO(ntv): functional-style map
|
|
457 for (unsigned i = 0, e = getNumResults(); i < e; i++) {
|
|
458 results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols);
|
|
459 }
|
|
460 }
|
|
461
|
|
462 AffineMap MutableAffineMap::getAffineMap() const {
|
173
|
463 return AffineMap::get(numDims, numSymbols, results, context);
|
150
|
464 }
|