comparison mlir/lib/Transforms/PipelineDataTransfer.cpp @ 173:0572611fdcc8 llvm10 llvm12
reorganization done
| author | Shinji KONO <kono@ie.u-ryukyu.ac.jp> |
|---|---|
| date | Mon, 25 May 2020 11:55:54 +0900 |
| parents | 1d019706d866 |
| children | 2e18cbf3894f |
| 172:9fbae9c8bf63 | 173:0572611fdcc8 |
|---|---|
8 // | 8 // |
9 // This file implements a pass to pipeline data transfers. | 9 // This file implements a pass to pipeline data transfers. |
10 // | 10 // |
11 //===----------------------------------------------------------------------===// | 11 //===----------------------------------------------------------------------===// |
12 | 12 |
| 13 #include "PassDetail.h" |
13 #include "mlir/Transforms/Passes.h" | 14 #include "mlir/Transforms/Passes.h" |
14 | 15 |
15 #include "mlir/Analysis/AffineAnalysis.h" | 16 #include "mlir/Analysis/AffineAnalysis.h" |
16 #include "mlir/Analysis/LoopAnalysis.h" | 17 #include "mlir/Analysis/LoopAnalysis.h" |
17 #include "mlir/Analysis/Utils.h" | 18 #include "mlir/Analysis/Utils.h" |
18 #include "mlir/Dialect/AffineOps/AffineOps.h" | 19 #include "mlir/Dialect/Affine/IR/AffineOps.h" |
19 #include "mlir/Dialect/StandardOps/Ops.h" | |
20 #include "mlir/IR/Builders.h" | 20 #include "mlir/IR/Builders.h" |
21 #include "mlir/Pass/Pass.h" | |
22 #include "mlir/Transforms/LoopUtils.h" | 21 #include "mlir/Transforms/LoopUtils.h" |
23 #include "mlir/Transforms/Utils.h" | 22 #include "mlir/Transforms/Utils.h" |
24 #include "llvm/ADT/DenseMap.h" | 23 #include "llvm/ADT/DenseMap.h" |
25 #include "llvm/Support/Debug.h" | 24 #include "llvm/Support/Debug.h" |
| 25 |
26 #define DEBUG_TYPE "affine-pipeline-data-transfer" | 26 #define DEBUG_TYPE "affine-pipeline-data-transfer" |
27 | 27 |
28 using namespace mlir; | 28 using namespace mlir; |
29 | 29 |
30 namespace { | 30 namespace { |
31 | 31 struct PipelineDataTransfer |
32 struct PipelineDataTransfer : public FunctionPass<PipelineDataTransfer> { | 32 : public AffinePipelineDataTransferBase<PipelineDataTransfer> { |
33 void runOnFunction() override; | 33 void runOnFunction() override; |
34 void runOnAffineForOp(AffineForOp forOp); | 34 void runOnAffineForOp(AffineForOp forOp); |
35 | 35 |
36 std::vector<AffineForOp> forOps; | 36 std::vector<AffineForOp> forOps; |
37 }; | 37 }; |
38 | 38 |
39 } // end anonymous namespace | 39 } // end anonymous namespace |
40 | 40 |
41 /// Creates a pass to pipeline explicit movement of data across levels of the | 41 /// Creates a pass to pipeline explicit movement of data across levels of the |
42 /// memory hierarchy. | 42 /// memory hierarchy. |
43 std::unique_ptr<OpPassBase<FuncOp>> mlir::createPipelineDataTransferPass() { | 43 std::unique_ptr<OperationPass<FuncOp>> mlir::createPipelineDataTransferPass() { |
44 return std::make_unique<PipelineDataTransfer>(); | 44 return std::make_unique<PipelineDataTransfer>(); |
45 } | 45 } |
46 | 46 |
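Editor's note: the factory's return type changes from `OpPassBase<FuncOp>` to `OperationPass<FuncOp>`, matching the pass-infrastructure reorganization visible throughout this diff (the new `PassDetail.h` include, the `AffinePipelineDataTransferBase` base class, and the static `PassRegistration` deleted at the bottom). A hedged sketch of how the factory would be used, assuming a standard `PassManager` setup of this era; header locations and `addNestedPass` are assumptions, not something this change shows:

```cpp
// Hypothetical usage sketch (not part of this diff): schedule the pass so
// that the OperationPass<FuncOp> created above runs on every function.
#include "mlir/IR/Function.h"        // FuncOp (location as of LLVM 10/11-era MLIR)
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"  // declares createPipelineDataTransferPass

void addPipelining(mlir::PassManager &pm) {
  // Nest under the module so the pass runs on each FuncOp;
  // pm.run(module) would then pipeline eligible affine.for loops.
  pm.addNestedPass<mlir::FuncOp>(mlir::createPipelineDataTransferPass());
}
```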
47 // Returns the position of the tag memref operand given a DMA operation. | 47 // Returns the position of the tag memref operand given a DMA operation. |
48 // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are | 48 // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are |
49 // added. TODO(b/117228571) | 49 // added. TODO(b/117228571) |
50 static unsigned getTagMemRefPos(Operation &dmaInst) { | 50 static unsigned getTagMemRefPos(Operation &dmaOp) { |
51 assert(isa<AffineDmaStartOp>(dmaInst) || isa<AffineDmaWaitOp>(dmaInst)); | 51 assert(isa<AffineDmaStartOp>(dmaOp) || isa<AffineDmaWaitOp>(dmaOp)); |
52 if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaInst)) { | 52 if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) { |
53 return dmaStartOp.getTagMemRefOperandIndex(); | 53 return dmaStartOp.getTagMemRefOperandIndex(); |
54 } | 54 } |
55 // First operand for a dma finish operation. | 55 // First operand for a dma finish operation. |
56 return 0; | 56 return 0; |
57 } | 57 } |
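Editor's note: only the operand position is returned here; the value itself is fetched by the callers, as in the tag double-buffering loop further down. A one-helper sketch using only names from this file:

```cpp
// Minimal sketch: fetch the tag memref of a DMA start/wait op via the
// operand position computed by getTagMemRefPos() above.
mlir::Value getTagMemRef(mlir::Operation &dmaOp) {
  return dmaOp.getOperand(getTagMemRefPos(dmaOp));
}
```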
78 }; | 78 }; |
79 | 79 |
80 auto oldMemRefType = oldMemRef.getType().cast<MemRefType>(); | 80 auto oldMemRefType = oldMemRef.getType().cast<MemRefType>(); |
81 auto newMemRefType = doubleShape(oldMemRefType); | 81 auto newMemRefType = doubleShape(oldMemRefType); |
82 | 82 |
83 // The double buffer is allocated right before 'forInst'. | 83 // The double buffer is allocated right before 'forOp'. |
84 auto *forInst = forOp.getOperation(); | 84 OpBuilder bOuter(forOp); |
85 OpBuilder bOuter(forInst); | |
86 // Put together alloc operands for any dynamic dimensions of the memref. | 85 // Put together alloc operands for any dynamic dimensions of the memref. |
87 SmallVector<Value, 4> allocOperands; | 86 SmallVector<Value, 4> allocOperands; |
88 unsigned dynamicDimCount = 0; | 87 unsigned dynamicDimCount = 0; |
89 for (auto dimSize : oldMemRefType.getShape()) { | 88 for (auto dimSize : oldMemRefType.getShape()) { |
90 if (dimSize == -1) | 89 if (dimSize == -1) |
91 allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef, | 90 allocOperands.push_back( |
92 dynamicDimCount++)); | 91 bOuter.create<DimOp>(forOp.getLoc(), oldMemRef, dynamicDimCount++)); |
93 } | 92 } |
94 | 93 |
95 // Create and place the alloc right before the 'affine.for' operation. | 94 // Create and place the alloc right before the 'affine.for' operation. |
96 Value newMemRef = | 95 Value newMemRef = |
97 bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands); | 96 bOuter.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands); |
98 | 97 |
99 // Create 'iv mod 2' value to index the leading dimension. | 98 // Create 'iv mod 2' value to index the leading dimension. |
100 auto d0 = bInner.getAffineDimExpr(0); | 99 auto d0 = bInner.getAffineDimExpr(0); |
101 int64_t step = forOp.getStep(); | 100 int64_t step = forOp.getStep(); |
102 auto modTwoMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, | 101 auto modTwoMap = |
103 {d0.floorDiv(step) % 2}); | 102 AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2); |
104 auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap, | 103 auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap, |
105 forOp.getInductionVar()); | 104 forOp.getInductionVar()); |
106 | 105 |
107 // replaceAllMemRefUsesWith will succeed unless the forOp body has | 106 // replaceAllMemRefUsesWith will succeed unless the forOp body has |
108 // non-dereferencing uses of the memref (dealloc's are fine though). | 107 // non-dereferencing uses of the memref (dealloc's are fine though). |
117 forOp.emitError("memref replacement for double buffering failed")); | 116 forOp.emitError("memref replacement for double buffering failed")); |
118 ivModTwoOp.erase(); | 117 ivModTwoOp.erase(); |
119 return false; | 118 return false; |
120 } | 119 } |
121 // Insert the dealloc op right after the for loop. | 120 // Insert the dealloc op right after the for loop. |
122 bOuter.setInsertionPointAfter(forInst); | 121 bOuter.setInsertionPointAfter(forOp); |
123 bOuter.create<DeallocOp>(forInst->getLoc(), newMemRef); | 122 bOuter.create<DeallocOp>(forOp.getLoc(), newMemRef); |
124 | 123 |
125 return true; | 124 return true; |
126 } | 125 } |
127 | 126 |
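Editor's note: two pieces of arithmetic drive `doubleBuffer()`: `doubleShape()` prepends a leading dimension of 2 (e.g. `memref<256xf32>` becomes `memref<2x256xf32>`), and the affine map `{d0.floorDiv(step) % 2}` picks which half a given iteration addresses. A plain-C++ sketch of the index logic, independent of the MLIR API (the loop bounds are made-up values):

```cpp
// Buffer-selection arithmetic behind the {d0.floorDiv(step) % 2} affine map:
// the normalized iteration number alternates between the two halves of the
// doubled leading dimension, so iteration i fills one half while iteration
// i-1's data is consumed from the other.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t lb = 0, ub = 12, step = 3; // made-up loop bounds
  for (int64_t iv = lb; iv < ub; iv += step) {
    // For nonnegative iv this matches floorDiv(iv, step) % 2.
    int64_t bufferIndex = (iv / step) % 2;
    std::printf("iv=%lld -> buffer half %lld\n", (long long)iv,
                (long long)bufferIndex);
  }
  return 0;
}
```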
128 /// Returns success if the IR is in a valid state. | 127 /// Returns success if the IR is in a valid state. |
218 if (!escapingUses) | 217 if (!escapingUses) |
219 dmaStartInsts.push_back(&op); | 218 dmaStartInsts.push_back(&op); |
220 } | 219 } |
221 | 220 |
222 // For each start operation, we look for a matching finish operation. | 221 // For each start operation, we look for a matching finish operation. |
223 for (auto *dmaStartInst : dmaStartInsts) { | 222 for (auto *dmaStartOp : dmaStartInsts) { |
224 for (auto *dmaFinishInst : dmaFinishInsts) { | 223 for (auto *dmaFinishOp : dmaFinishInsts) { |
225 if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartInst), | 224 if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp), |
226 cast<AffineDmaWaitOp>(dmaFinishInst))) { | 225 cast<AffineDmaWaitOp>(dmaFinishOp))) { |
227 startWaitPairs.push_back({dmaStartInst, dmaFinishInst}); | 226 startWaitPairs.push_back({dmaStartOp, dmaFinishOp}); |
228 break; | 227 break; |
229 } | 228 } |
230 } | 229 } |
231 } | 230 } |
232 } | 231 } |
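Editor's note: `checkTagMatch` itself falls outside the hunks shown here. Conceptually, a wait completes a start only if both ops signal through the same tag buffer; the real predicate also compares the tag index operands pairwise. A simplified, hedged stand-in:

```cpp
// Simplified stand-in for checkTagMatch (the full version, elsewhere in this
// file, additionally verifies that the tag index operands match one by one).
static bool tagMemRefsMatch(mlir::AffineDmaStartOp startOp,
                            mlir::AffineDmaWaitOp waitOp) {
  // A wait can only complete a start that signals through the same tag buffer.
  return startOp.getTagMemRef() == waitOp.getTagMemRef();
}
```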
235 /// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are | 234 /// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are |
236 /// inserted right before where it was. | 235 /// inserted right before where it was. |
237 void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { | 236 void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) { |
238 auto mayBeConstTripCount = getConstantTripCount(forOp); | 237 auto mayBeConstTripCount = getConstantTripCount(forOp); |
239 if (!mayBeConstTripCount.hasValue()) { | 238 if (!mayBeConstTripCount.hasValue()) { |
240 LLVM_DEBUG( | 239 LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count")); |
241 forOp.emitRemark("won't pipeline due to unknown trip count loop")); | |
242 return; | 240 return; |
243 } | 241 } |
244 | 242 |
245 SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs; | 243 SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs; |
246 findMatchingStartFinishInsts(forOp, startWaitPairs); | 244 findMatchingStartFinishInsts(forOp, startWaitPairs); |
257 // TODO(bondhugula): check whether double-buffering is even necessary. | 255 // TODO(bondhugula): check whether double-buffering is even necessary. |
258 // TODO(bondhugula): make this work with different layouts: assuming here that | 256 // TODO(bondhugula): make this work with different layouts: assuming here that |
259 // the dimension we are adding here for the double buffering is the outermost | 257 // the dimension we are adding here for the double buffering is the outermost |
260 // dimension. | 258 // dimension. |
261 for (auto &pair : startWaitPairs) { | 259 for (auto &pair : startWaitPairs) { |
262 auto *dmaStartInst = pair.first; | 260 auto *dmaStartOp = pair.first; |
263 Value oldMemRef = dmaStartInst->getOperand( | 261 Value oldMemRef = dmaStartOp->getOperand( |
264 cast<AffineDmaStartOp>(dmaStartInst).getFasterMemPos()); | 262 cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos()); |
265 if (!doubleBuffer(oldMemRef, forOp)) { | 263 if (!doubleBuffer(oldMemRef, forOp)) { |
266 // Normally, double buffering should not fail because we already checked | 264 // Normally, double buffering should not fail because we already checked |
267 // that there are no uses outside. | 265 // that there are no uses outside. |
268 LLVM_DEBUG(llvm::dbgs() | 266 LLVM_DEBUG(llvm::dbgs() |
269 << "double buffering failed for" << dmaStartInst << "\n";); | 267 << "double buffering failed for" << dmaStartOp << "\n";); |
270 // IR still valid and semantically correct. | 268 // IR still valid and semantically correct. |
271 return; | 269 return; |
272 } | 270 } |
273 // If the old memref has no more uses, remove its 'dead' alloc if it was | 271 // If the old memref has no more uses, remove its 'dead' alloc if it was |
274 // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim' | 272 // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim' |
275 // operation could have been used on it if it was dynamically shaped in | 273 // operation could have been used on it if it was dynamically shaped in |
276 // order to create the double buffer above.) | 274 // order to create the double buffer above.) |
277 // '-canonicalize' does this in a more general way, but we'll anyway do the | 275 // '-canonicalize' does this in a more general way, but we'll anyway do the |
278 // simple/common case so that the output / test cases looks clear. | 276 // simple/common case so that the output / test cases looks clear. |
279 if (auto *allocInst = oldMemRef.getDefiningOp()) { | 277 if (auto *allocOp = oldMemRef.getDefiningOp()) { |
280 if (oldMemRef.use_empty()) { | 278 if (oldMemRef.use_empty()) { |
281 allocInst->erase(); | 279 allocOp->erase(); |
282 } else if (oldMemRef.hasOneUse()) { | 280 } else if (oldMemRef.hasOneUse()) { |
283 if (auto dealloc = dyn_cast<DeallocOp>(*oldMemRef.user_begin())) { | 281 if (auto dealloc = dyn_cast<DeallocOp>(*oldMemRef.user_begin())) { |
284 dealloc.erase(); | 282 dealloc.erase(); |
285 allocInst->erase(); | 283 allocOp->erase(); |
286 } | 284 } |
287 } | 285 } |
288 } | 286 } |
289 } | 287 } |
290 | 288 |
291 // Double the buffers for tag memrefs. | 289 // Double the buffers for tag memrefs. |
292 for (auto &pair : startWaitPairs) { | 290 for (auto &pair : startWaitPairs) { |
293 auto *dmaFinishInst = pair.second; | 291 auto *dmaFinishOp = pair.second; |
294 Value oldTagMemRef = | 292 Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp)); |
295 dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst)); | |
296 if (!doubleBuffer(oldTagMemRef, forOp)) { | 293 if (!doubleBuffer(oldTagMemRef, forOp)) { |
297 LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); | 294 LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); |
298 return; | 295 return; |
299 } | 296 } |
300 // If the old tag has no uses or a single dealloc use, remove it. | 297 // If the old tag has no uses or a single dealloc use, remove it. |
301 // (canonicalization handles more complex cases). | 298 // (canonicalization handles more complex cases). |
302 if (auto *tagAllocInst = oldTagMemRef.getDefiningOp()) { | 299 if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) { |
303 if (oldTagMemRef.use_empty()) { | 300 if (oldTagMemRef.use_empty()) { |
304 tagAllocInst->erase(); | 301 tagAllocOp->erase(); |
305 } else if (oldTagMemRef.hasOneUse()) { | 302 } else if (oldTagMemRef.hasOneUse()) { |
306 if (auto dealloc = dyn_cast<DeallocOp>(*oldTagMemRef.user_begin())) { | 303 if (auto dealloc = dyn_cast<DeallocOp>(*oldTagMemRef.user_begin())) { |
307 dealloc.erase(); | 304 dealloc.erase(); |
308 tagAllocInst->erase(); | 305 tagAllocOp->erase(); |
309 } | 306 } |
310 } | 307 } |
311 } | 308 } |
312 } | 309 } |
313 | 310 |
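Editor's note: the dead-allocation cleanup above runs twice, once for the data memref and once for the tag memref. Factored out, the pattern looks like this (a sketch built only from calls appearing in this diff):

```cpp
// Erase the alloc when the old memref has no users left, or when its sole
// remaining user is the matching dealloc; leave anything else to
// '-canonicalize', as the comments above note.
static void eraseIfDead(mlir::Value oldMemRef) {
  auto *allocOp = oldMemRef.getDefiningOp();
  if (!allocOp)
    return; // e.g. a function argument: nothing to erase
  if (oldMemRef.use_empty()) {
    allocOp->erase();
  } else if (oldMemRef.hasOneUse()) {
    if (auto dealloc =
            llvm::dyn_cast<mlir::DeallocOp>(*oldMemRef.user_begin())) {
      dealloc.erase();
      allocOp->erase();
    }
  }
}
```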
316 findMatchingStartFinishInsts(forOp, startWaitPairs); | 313 findMatchingStartFinishInsts(forOp, startWaitPairs); |
317 | 314 |
318 // Store shift for operation for later lookup for AffineApplyOp's. | 315 // Store shift for operation for later lookup for AffineApplyOp's. |
319 DenseMap<Operation *, unsigned> instShiftMap; | 316 DenseMap<Operation *, unsigned> instShiftMap; |
320 for (auto &pair : startWaitPairs) { | 317 for (auto &pair : startWaitPairs) { |
321 auto *dmaStartInst = pair.first; | 318 auto *dmaStartOp = pair.first; |
322 assert(isa<AffineDmaStartOp>(dmaStartInst)); | 319 assert(isa<AffineDmaStartOp>(dmaStartOp)); |
323 instShiftMap[dmaStartInst] = 0; | 320 instShiftMap[dmaStartOp] = 0; |
324 // Set shifts for DMA start op's affine operand computation slices to 0. | 321 // Set shifts for DMA start op's affine operand computation slices to 0. |
325 SmallVector<AffineApplyOp, 4> sliceOps; | 322 SmallVector<AffineApplyOp, 4> sliceOps; |
326 mlir::createAffineComputationSlice(dmaStartInst, &sliceOps); | 323 mlir::createAffineComputationSlice(dmaStartOp, &sliceOps); |
327 if (!sliceOps.empty()) { | 324 if (!sliceOps.empty()) { |
328 for (auto sliceOp : sliceOps) { | 325 for (auto sliceOp : sliceOps) { |
329 instShiftMap[sliceOp.getOperation()] = 0; | 326 instShiftMap[sliceOp.getOperation()] = 0; |
330 } | 327 } |
331 } else { | 328 } else { |
332 // If a slice wasn't created, the reachable affine.apply op's from its | 329 // If a slice wasn't created, the reachable affine.apply op's from its |
333 // operands are the ones that go with it. | 330 // operands are the ones that go with it. |
334 SmallVector<Operation *, 4> affineApplyInsts; | 331 SmallVector<Operation *, 4> affineApplyInsts; |
335 SmallVector<Value, 4> operands(dmaStartInst->getOperands()); | 332 SmallVector<Value, 4> operands(dmaStartOp->getOperands()); |
336 getReachableAffineApplyOps(operands, affineApplyInsts); | 333 getReachableAffineApplyOps(operands, affineApplyInsts); |
337 for (auto *op : affineApplyInsts) { | 334 for (auto *op : affineApplyInsts) { |
338 instShiftMap[op] = 0; | 335 instShiftMap[op] = 0; |
339 } | 336 } |
340 } | 337 } |
341 } | 338 } |
342 // Everything else (including compute ops and dma finish) are shifted by one. | 339 // Everything else (including compute ops and dma finish) are shifted by one. |
343 for (auto &op : *forOp.getBody()) { | 340 for (auto &op : forOp.getBody()->without_terminator()) |
344 if (instShiftMap.find(&op) == instShiftMap.end()) { | 341 if (instShiftMap.find(&op) == instShiftMap.end()) |
345 instShiftMap[&op] = 1; | 342 instShiftMap[&op] = 1; |
346 } | |
347 } | |
348 | 343 |
349 // Get shifts stored in map. | 344 // Get shifts stored in map. |
350 std::vector<uint64_t> shifts(forOp.getBody()->getOperations().size()); | 345 SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size()); |
351 unsigned s = 0; | 346 unsigned s = 0; |
352 for (auto &op : *forOp.getBody()) { | 347 for (auto &op : forOp.getBody()->without_terminator()) { |
353 assert(instShiftMap.find(&op) != instShiftMap.end()); | 348 assert(instShiftMap.find(&op) != instShiftMap.end()); |
354 shifts[s++] = instShiftMap[&op]; | 349 shifts[s++] = instShiftMap[&op]; |
355 | 350 |
356 // Tagging operations with shifts for debugging purposes. | 351 // Tagging operations with shifts for debugging purposes. |
357 LLVM_DEBUG({ | 352 LLVM_DEBUG({ |
358 OpBuilder b(&op); | 353 OpBuilder b(&op); |
359 op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1])); | 354 op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1])); |
360 }); | 355 }); |
361 } | 356 } |
362 | 357 |
363 if (!isInstwiseShiftValid(forOp, shifts)) { | 358 if (!isOpwiseShiftValid(forOp, shifts)) { |
364 // Violates dependences. | 359 // Violates dependences. |
365 LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";); | 360 LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";); |
366 return; | 361 return; |
367 } | 362 } |
368 | 363 |
369 if (failed(instBodySkew(forOp, shifts))) { | 364 if (failed(affineForOpBodySkew(forOp, shifts))) { |
370 LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";); | 365 LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";); |
371 return; | 366 return; |
372 } | 367 } |
373 } | 368 } |
374 | |
375 static PassRegistration<PipelineDataTransfer> pass( | |
376 "affine-pipeline-data-transfer", | |
377 "Pipeline non-blocking data transfers between explicitly managed levels of " | |
378 "the memory hierarchy"); |