comparison mlir/lib/Transforms/PipelineDataTransfer.cpp @ 173:0572611fdcc8 llvm10 llvm12

reorganization done
author Shinji KONO <kono@ie.u-ryukyu.ac.jp>
date Mon, 25 May 2020 11:55:54 +0900
parents 1d019706d866
children 2e18cbf3894f
comparison 172:9fbae9c8bf63 -> 173:0572611fdcc8
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -8,50 +8,50 @@
 //
 // This file implements a pass to pipeline data transfers.
 //
 //===----------------------------------------------------------------------===//
 
+#include "PassDetail.h"
 #include "mlir/Transforms/Passes.h"
 
 #include "mlir/Analysis/AffineAnalysis.h"
 #include "mlir/Analysis/LoopAnalysis.h"
 #include "mlir/Analysis/Utils.h"
-#include "mlir/Dialect/AffineOps/AffineOps.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/Debug.h"
+
 #define DEBUG_TYPE "affine-pipeline-data-transfer"
 
 using namespace mlir;
 
 namespace {
-
-struct PipelineDataTransfer : public FunctionPass<PipelineDataTransfer> {
+struct PipelineDataTransfer
+    : public AffinePipelineDataTransferBase<PipelineDataTransfer> {
   void runOnFunction() override;
   void runOnAffineForOp(AffineForOp forOp);
 
   std::vector<AffineForOp> forOps;
 };
 
 } // end anonymous namespace
 
 /// Creates a pass to pipeline explicit movement of data across levels of the
 /// memory hierarchy.
-std::unique_ptr<OpPassBase<FuncOp>> mlir::createPipelineDataTransferPass() {
+std::unique_ptr<OperationPass<FuncOp>> mlir::createPipelineDataTransferPass() {
  return std::make_unique<PipelineDataTransfer>();
 }
 
 // Returns the position of the tag memref operand given a DMA operation.
 // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
 // added. TODO(b/117228571)
-static unsigned getTagMemRefPos(Operation &dmaInst) {
-  assert(isa<AffineDmaStartOp>(dmaInst) || isa<AffineDmaWaitOp>(dmaInst));
-  if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaInst)) {
+static unsigned getTagMemRefPos(Operation &dmaOp) {
+  assert(isa<AffineDmaStartOp>(dmaOp) || isa<AffineDmaWaitOp>(dmaOp));
+  if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
     return dmaStartOp.getTagMemRefOperandIndex();
   }
   // First operand for a dma finish operation.
   return 0;
 }
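Editor's note on the base-class change in this hunk: it follows MLIR's move to declaratively defined passes. The command-line registration that used to live at the bottom of this file (deleted in the last hunk below) migrates into a tablegen'd Passes.td entry, and the new PassDetail.h include pulls in the generated base class. A rough sketch of what that generated base looks like, assuming the MLIR conventions of this era; the exact generated code differs in detail:

// Sketch only: approximately what mlir-tblgen -gen-pass-decls emits into
// PassDetail.h for this pass; illustrative, not the actual generated code.
template <typename DerivedT>
class AffinePipelineDataTransferBase
    : public ::mlir::OperationPass<::mlir::FuncOp> {
public:
  AffinePipelineDataTransferBase()
      : ::mlir::OperationPass<::mlir::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
  // The pass argument and description ("affine-pipeline-data-transfer", ...)
  // now come from the declarative Passes.td entry instead of the static
  // PassRegistration object this change deletes.
};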
@@ -78,31 +78,30 @@
   };
 
   auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
   auto newMemRefType = doubleShape(oldMemRefType);
 
-  // The double buffer is allocated right before 'forInst'.
-  auto *forInst = forOp.getOperation();
-  OpBuilder bOuter(forInst);
+  // The double buffer is allocated right before 'forOp'.
+  OpBuilder bOuter(forOp);
   // Put together alloc operands for any dynamic dimensions of the memref.
   SmallVector<Value, 4> allocOperands;
   unsigned dynamicDimCount = 0;
   for (auto dimSize : oldMemRefType.getShape()) {
     if (dimSize == -1)
-      allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef,
-                                                   dynamicDimCount++));
+      allocOperands.push_back(
+          bOuter.create<DimOp>(forOp.getLoc(), oldMemRef, dynamicDimCount++));
   }
 
   // Create and place the alloc right before the 'affine.for' operation.
   Value newMemRef =
-      bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);
+      bOuter.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands);
 
   // Create 'iv mod 2' value to index the leading dimension.
   auto d0 = bInner.getAffineDimExpr(0);
   int64_t step = forOp.getStep();
-  auto modTwoMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
-                                  {d0.floorDiv(step) % 2});
+  auto modTwoMap =
+      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
   auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
                                                  forOp.getInductionVar());
 
   // replaceAllMemRefUsesWith will succeed unless the forOp body has
   // non-dereferencing uses of the memref (dealloc's are fine though).
@@ -117,12 +116,12 @@
         forOp.emitError("memref replacement for double buffering failed"));
     ivModTwoOp.erase();
     return false;
   }
   // Insert the dealloc op right after the for loop.
-  bOuter.setInsertionPointAfter(forInst);
-  bOuter.create<DeallocOp>(forInst->getLoc(), newMemRef);
+  bOuter.setInsertionPointAfter(forOp);
+  bOuter.create<DeallocOp>(forOp.getLoc(), newMemRef);
 
   return true;
 }
 
 /// Returns success if the IR is in a valid state.
@@ -218,15 +217,15 @@
     if (!escapingUses)
       dmaStartInsts.push_back(&op);
   }
 
   // For each start operation, we look for a matching finish operation.
-  for (auto *dmaStartInst : dmaStartInsts) {
-    for (auto *dmaFinishInst : dmaFinishInsts) {
-      if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartInst),
-                        cast<AffineDmaWaitOp>(dmaFinishInst))) {
-        startWaitPairs.push_back({dmaStartInst, dmaFinishInst});
+  for (auto *dmaStartOp : dmaStartInsts) {
+    for (auto *dmaFinishOp : dmaFinishInsts) {
+      if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
+                        cast<AffineDmaWaitOp>(dmaFinishOp))) {
+        startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
         break;
       }
     }
   }
 }
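checkTagMatch itself lives in a region this comparison elides (old lines 129-217 are unchanged). As a hedged sketch of the pairing criterion implied by the call site above, a dma start and a dma wait are matched through their tag memref; the helper below is hypothetical, not the elided code:

// Hypothetical sketch: the real checkTagMatch also compares the tag index
// operands, not just the tag memref itself.
static bool tagsMatchSketch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
  if (startOp.getTagMemRef() != waitOp.getTagMemRef())
    return false;
  // ... per-operand comparison of the tag subscripts elided ...
  return true;
}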
@@ -235,12 +234,11 @@
 /// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are
 /// inserted right before where it was.
 void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
   auto mayBeConstTripCount = getConstantTripCount(forOp);
   if (!mayBeConstTripCount.hasValue()) {
-    LLVM_DEBUG(
-        forOp.emitRemark("won't pipeline due to unknown trip count loop"));
+    LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
     return;
   }
 
   SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
   findMatchingStartFinishInsts(forOp, startWaitPairs);
@@ -257,57 +255,56 @@
   // TODO(bondhugula): check whether double-buffering is even necessary.
   // TODO(bondhugula): make this work with different layouts: assuming here that
   // the dimension we are adding here for the double buffering is the outermost
   // dimension.
   for (auto &pair : startWaitPairs) {
-    auto *dmaStartInst = pair.first;
-    Value oldMemRef = dmaStartInst->getOperand(
-        cast<AffineDmaStartOp>(dmaStartInst).getFasterMemPos());
+    auto *dmaStartOp = pair.first;
+    Value oldMemRef = dmaStartOp->getOperand(
+        cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
     if (!doubleBuffer(oldMemRef, forOp)) {
       // Normally, double buffering should not fail because we already checked
       // that there are no uses outside.
       LLVM_DEBUG(llvm::dbgs()
-                 << "double buffering failed for" << dmaStartInst << "\n";);
+                 << "double buffering failed for" << dmaStartOp << "\n";);
       // IR still valid and semantically correct.
       return;
     }
     // If the old memref has no more uses, remove its 'dead' alloc if it was
     // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
     // operation could have been used on it if it was dynamically shaped in
     // order to create the double buffer above.)
     // '-canonicalize' does this in a more general way, but we'll anyway do the
     // simple/common case so that the output / test cases looks clear.
-    if (auto *allocInst = oldMemRef.getDefiningOp()) {
+    if (auto *allocOp = oldMemRef.getDefiningOp()) {
       if (oldMemRef.use_empty()) {
-        allocInst->erase();
+        allocOp->erase();
       } else if (oldMemRef.hasOneUse()) {
         if (auto dealloc = dyn_cast<DeallocOp>(*oldMemRef.user_begin())) {
           dealloc.erase();
-          allocInst->erase();
+          allocOp->erase();
         }
       }
     }
   }
 
   // Double the buffers for tag memrefs.
   for (auto &pair : startWaitPairs) {
-    auto *dmaFinishInst = pair.second;
-    Value oldTagMemRef =
-        dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst));
+    auto *dmaFinishOp = pair.second;
+    Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
     if (!doubleBuffer(oldTagMemRef, forOp)) {
       LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
       return;
     }
     // If the old tag has no uses or a single dealloc use, remove it.
     // (canonicalization handles more complex cases).
-    if (auto *tagAllocInst = oldTagMemRef.getDefiningOp()) {
+    if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
       if (oldTagMemRef.use_empty()) {
-        tagAllocInst->erase();
+        tagAllocOp->erase();
       } else if (oldTagMemRef.hasOneUse()) {
         if (auto dealloc = dyn_cast<DeallocOp>(*oldTagMemRef.user_begin())) {
           dealloc.erase();
-          tagAllocInst->erase();
+          tagAllocOp->erase();
         }
       }
     }
   }
 
@@ -316,63 +313,56 @@
   findMatchingStartFinishInsts(forOp, startWaitPairs);
 
   // Store shift for operation for later lookup for AffineApplyOp's.
   DenseMap<Operation *, unsigned> instShiftMap;
   for (auto &pair : startWaitPairs) {
-    auto *dmaStartInst = pair.first;
-    assert(isa<AffineDmaStartOp>(dmaStartInst));
-    instShiftMap[dmaStartInst] = 0;
+    auto *dmaStartOp = pair.first;
+    assert(isa<AffineDmaStartOp>(dmaStartOp));
+    instShiftMap[dmaStartOp] = 0;
     // Set shifts for DMA start op's affine operand computation slices to 0.
     SmallVector<AffineApplyOp, 4> sliceOps;
-    mlir::createAffineComputationSlice(dmaStartInst, &sliceOps);
+    mlir::createAffineComputationSlice(dmaStartOp, &sliceOps);
     if (!sliceOps.empty()) {
       for (auto sliceOp : sliceOps) {
         instShiftMap[sliceOp.getOperation()] = 0;
       }
     } else {
       // If a slice wasn't created, the reachable affine.apply op's from its
       // operands are the ones that go with it.
       SmallVector<Operation *, 4> affineApplyInsts;
-      SmallVector<Value, 4> operands(dmaStartInst->getOperands());
+      SmallVector<Value, 4> operands(dmaStartOp->getOperands());
       getReachableAffineApplyOps(operands, affineApplyInsts);
       for (auto *op : affineApplyInsts) {
         instShiftMap[op] = 0;
       }
     }
   }
   // Everything else (including compute ops and dma finish) are shifted by one.
-  for (auto &op : *forOp.getBody()) {
-    if (instShiftMap.find(&op) == instShiftMap.end()) {
+  for (auto &op : forOp.getBody()->without_terminator())
+    if (instShiftMap.find(&op) == instShiftMap.end())
       instShiftMap[&op] = 1;
-    }
-  }
 
   // Get shifts stored in map.
-  std::vector<uint64_t> shifts(forOp.getBody()->getOperations().size());
+  SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
   unsigned s = 0;
-  for (auto &op : *forOp.getBody()) {
+  for (auto &op : forOp.getBody()->without_terminator()) {
     assert(instShiftMap.find(&op) != instShiftMap.end());
     shifts[s++] = instShiftMap[&op];
 
     // Tagging operations with shifts for debugging purposes.
     LLVM_DEBUG({
       OpBuilder b(&op);
       op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
     });
   }
 
-  if (!isInstwiseShiftValid(forOp, shifts)) {
+  if (!isOpwiseShiftValid(forOp, shifts)) {
     // Violates dependences.
     LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
     return;
   }
 
-  if (failed(instBodySkew(forOp, shifts))) {
+  if (failed(affineForOpBodySkew(forOp, shifts))) {
     LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
     return;
   }
 }
-
-static PassRegistration<PipelineDataTransfer> pass(
-    "affine-pipeline-data-transfer",
-    "Pipeline non-blocking data transfers between explicitly managed levels of "
-    "the memory hierarchy");