//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to pipeline data transfers.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Transforms/Passes.h"

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "affine-pipeline-data-transfer"

using namespace mlir;

namespace {
struct PipelineDataTransfer
    : public AffinePipelineDataTransferBase<PipelineDataTransfer> {
  void runOnFunction() override;
  void runOnAffineForOp(AffineForOp forOp);

  std::vector<AffineForOp> forOps;
};
} // end anonymous namespace

/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
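/// (The pass is typically exercised via its registered command-line flag;
/// given the DEBUG_TYPE above, presumably '-affine-pipeline-data-transfer'.)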
std::unique_ptr<OperationPass<FuncOp>> mlir::createPipelineDataTransferPass() {
  return std::make_unique<PipelineDataTransfer>();
}

// Returns the position of the tag memref operand given a DMA operation.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
static unsigned getTagMemRefPos(Operation &dmaOp) {
  assert(isa<AffineDmaStartOp>(dmaOp) || isa<AffineDmaWaitOp>(dmaOp));
  if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
    return dmaStartOp.getTagMemRefOperandIndex();
  }
  // First operand for a dma finish operation.
  return 0;
}

/// Doubles the buffer of the supplied memref on the specified 'affine.for'
/// operation by adding a leading dimension of size two to the memref.
/// Replaces all uses of the old memref by the new one while indexing the newly
/// added dimension by the loop IV of the specified 'affine.for' operation
/// modulo 2. Returns false if such a replacement cannot be performed.
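///
/// Illustrative sketch (types and names are hypothetical): for
///
///   %buf = alloc() : memref<256xf32, 1>
///
/// used inside 'affine.for %i = 0 to %N step 4', the buffer becomes
///
///   %buf_db = alloc() : memref<2x256xf32, 1>
///
/// and each access '%buf[%idx]' is rewritten to
/// '%buf_db[(%i floordiv 4) mod 2, %idx]'.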
static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
  auto *forBody = forOp.getBody();
  OpBuilder bInner(forBody, forBody->begin());

  // Doubles the shape with a leading dimension extent of 2.
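  // For example, a hypothetical memref<128x64xf32, 1> would become
  // memref<2x128x64xf32, 1>, with the layout reset to the default identity
  // map via setAffineMaps({}).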
  auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
    // Add the leading dimension in the shape for the double buffer.
    ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
    SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
    newShape[0] = 2;
    std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
    return MemRefType::Builder(oldMemRefType)
        .setShape(newShape)
        .setAffineMaps({});
  };

  auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
  auto newMemRefType = doubleShape(oldMemRefType);

  // The double buffer is allocated right before 'forOp'.
  OpBuilder bOuter(forOp);
  // Put together alloc operands for any dynamic dimensions of the memref.
  SmallVector<Value, 4> allocOperands;
  unsigned dynamicDimCount = 0;
  for (auto dimSize : oldMemRefType.getShape()) {
    if (dimSize == -1)
      allocOperands.push_back(
          bOuter.create<DimOp>(forOp.getLoc(), oldMemRef, dynamicDimCount++));
  }

  // Create and place the alloc right before the 'affine.for' operation.
  Value newMemRef =
      bOuter.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands);

  // Create 'iv mod 2' value to index the leading dimension.
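  // Note: the IV advances by 'step' each iteration, so 'iv floordiv step'
  // increases by exactly one per iteration and the mod-2 expression below
  // alternates between 0 and 1 on successive iterations.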
  auto d0 = bInner.getAffineDimExpr(0);
  int64_t step = forOp.getStep();
  auto modTwoMap =
      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
  auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
                                                 forOp.getInductionVar());

  // replaceAllMemRefUsesWith will succeed unless the forOp body has
  // non-dereferencing uses of the memref (dealloc's are fine though).
  if (failed(replaceAllMemRefUsesWith(
          oldMemRef, newMemRef,
          /*extraIndices=*/{ivModTwoOp},
          /*indexRemap=*/AffineMap(),
          /*extraOperands=*/{},
          /*symbolOperands=*/{},
          /*domInstFilter=*/&*forOp.getBody()->begin()))) {
    LLVM_DEBUG(
        forOp.emitError("memref replacement for double buffering failed"));
    ivModTwoOp.erase();
    return false;
  }
  // Insert the dealloc op right after the for loop.
  bOuter.setInsertionPointAfter(forOp);
  bOuter.create<DeallocOp>(forOp.getLoc(), newMemRef);

  return true;
}

/// Pipelines data transfers within all 'affine.for' ops in the function.
void PipelineDataTransfer::runOnFunction() {
  // Do a post order walk so that inner loop DMAs are processed first. This is
  // necessary since 'affine.for' operations nested within would otherwise
  // become invalid (erased) when the outer loop is pipelined (the pipelined
  // one gets deleted and replaced by a prologue, a new steady-state loop and
  // an epilogue).
  forOps.clear();
  getFunction().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
  for (auto forOp : forOps)
    runOnAffineForOp(forOp);
}

// Check if tags of the dma start op and dma wait op match.
static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
  if (startOp.getTagMemRef() != waitOp.getTagMemRef())
    return false;
  auto startIndices = startOp.getTagIndices();
  auto waitIndices = waitOp.getTagIndices();
  // Both of these have the same number of indices since they correspond to
  // the same tag memref.
  for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
            e = startIndices.end();
       it != e; ++it, ++wIt) {
    // Keep it simple for now, just checking if indices match.
    // TODO(mlir-team): this would in general need to check if there is no
    // intervening write writing to the same tag location, i.e., memory last
    // write/data flow analysis. This is however sufficient/powerful enough
    // for now since the DMA generation pass or the input for it will always
    // have start/wait with matching tags (same SSA operand indices).
    if (*it != *wIt)
      return false;
  }
  return true;
}

// Identify matching DMA start/finish operations to overlap computation with.
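// Only incoming DMA starts (into the faster memory space) whose buffer is not
// live out of the loop are candidates; each is paired with the first DMA wait
// op whose tag memref and tag indices match (see checkTagMatch above).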
static void findMatchingStartFinishInsts(
    AffineForOp forOp,
    SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {

  // Collect outgoing DMA operations - needed to check for dependences below.
  SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
  for (auto &op : *forOp.getBody()) {
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
      outgoingDmaOps.push_back(dmaStartOp);
  }

  SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
  for (auto &op : *forOp.getBody()) {
    // Collect DMA finish operations.
    if (isa<AffineDmaWaitOp>(op)) {
      dmaFinishInsts.push_back(&op);
      continue;
    }
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (!dmaStartOp)
      continue;

    // Only DMAs incoming into higher memory spaces are pipelined for now.
    // TODO(bondhugula): handle outgoing DMA pipelining.
    if (!dmaStartOp.isDestMemorySpaceFaster())
      continue;

    // Check for dependence with outgoing DMAs. Doing this conservatively.
    // TODO(andydavis,bondhugula): use the dependence analysis to check for
    // dependences between an incoming and outgoing DMA in the same iteration.
    auto it = outgoingDmaOps.begin();
    for (; it != outgoingDmaOps.end(); ++it) {
      if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
        break;
    }
    if (it != outgoingDmaOps.end())
      continue;

    // We only double buffer if the buffer is not live out of the loop.
    auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
    bool escapingUses = false;
    for (auto *user : memref.getUsers()) {
      // We can double buffer regardless of dealloc's outside the loop.
      if (isa<DeallocOp>(user))
        continue;
      if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "can't pipeline: buffer is live out of loop\n";);
        escapingUses = true;
        break;
      }
    }
    if (!escapingUses)
      dmaStartInsts.push_back(&op);
  }

  // For each start operation, we look for a matching finish operation.
  for (auto *dmaStartOp : dmaStartInsts) {
    for (auto *dmaFinishOp : dmaFinishInsts) {
      if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
                        cast<AffineDmaWaitOp>(dmaFinishOp))) {
        startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
        break;
      }
    }
  }
}

/// Overlap DMA transfers with computation in this loop. If successful,
/// 'forOp' is deleted, and a prologue, a new pipelined loop, and an epilogue
/// are inserted right before where it was.
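///
/// Illustrative sketch (pseudo-IR, not actual pass output): with the DMA
/// start at shift 0 and the rest of the body at shift 1, a loop
///
///   affine.for %i = 0 to %N {
///     dma_start(%i); dma_wait(%i); compute(%i);
///   }
///
/// is skewed into a prologue 'dma_start(0)', a steady-state loop in which
/// 'dma_start(%i + 1)' overlaps 'dma_wait(%i); compute(%i)' on the two halves
/// of the doubled buffer, and an epilogue 'dma_wait(%N - 1); compute(%N - 1)'.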
void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
  auto mayBeConstTripCount = getConstantTripCount(forOp);
  if (!mayBeConstTripCount.hasValue()) {
    LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
    return;
  }

  SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  if (startWaitPairs.empty()) {
    LLVM_DEBUG(forOp.emitRemark("no DMA start/finish pairs"));
    return;
  }

  // Double the buffers for the higher memory space memref's.
  // Identify memref's to replace by scanning through all DMA start
  // operations. A DMA start operation has two memref's - the one from the
  // higher level of memory hierarchy is the one to double buffer.
  // TODO(bondhugula): check whether double-buffering is even necessary.
  // TODO(bondhugula): make this work with different layouts: assuming here
  // that the dimension we are adding for double buffering is the outermost
  // dimension.
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    Value oldMemRef = dmaStartOp->getOperand(
        cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
    if (!doubleBuffer(oldMemRef, forOp)) {
      // Normally, double buffering should not fail because we already checked
      // that there are no uses outside the loop.
      LLVM_DEBUG(llvm::dbgs()
                 << "double buffering failed for " << *dmaStartOp << "\n";);
      // IR still valid and semantically correct.
      return;
    }
    // If the old memref has no more uses, remove its 'dead' alloc if it was
    // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
    // operation could have been used on it if it was dynamically shaped in
    // order to create the double buffer above.)
    // '-canonicalize' does this in a more general way, but we'll anyway do the
    // simple/common case here so that the output / test cases look clear.
    if (auto *allocOp = oldMemRef.getDefiningOp()) {
      if (oldMemRef.use_empty()) {
        allocOp->erase();
      } else if (oldMemRef.hasOneUse()) {
        if (auto dealloc = dyn_cast<DeallocOp>(*oldMemRef.user_begin())) {
          dealloc.erase();
          allocOp->erase();
        }
      }
    }
  }

  // Double the buffers for tag memrefs.
  for (auto &pair : startWaitPairs) {
    auto *dmaFinishOp = pair.second;
    Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
    if (!doubleBuffer(oldTagMemRef, forOp)) {
      LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
      return;
    }
    // If the old tag has no uses or a single dealloc use, remove it.
    // (canonicalization handles more complex cases).
    if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
      if (oldTagMemRef.use_empty()) {
        tagAllocOp->erase();
      } else if (oldTagMemRef.hasOneUse()) {
        if (auto dealloc = dyn_cast<DeallocOp>(*oldTagMemRef.user_begin())) {
          dealloc.erase();
          tagAllocOp->erase();
        }
      }
    }
  }

  // Double buffering would have invalidated all the old DMA start/wait insts.
  startWaitPairs.clear();
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  // Store the shift for each operation for later lookup; this includes the
  // AffineApplyOp's created below.
  DenseMap<Operation *, unsigned> instShiftMap;
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    assert(isa<AffineDmaStartOp>(dmaStartOp));
    instShiftMap[dmaStartOp] = 0;
    // Set shifts for DMA start op's affine operand computation slices to 0.
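    // createAffineComputationSlice rewrites the DMA's index computation into
    // affine.apply ops whose results are used only by this DMA start, so
    // those ops can share its shift of 0 without affecting other users of
    // the original values.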
    SmallVector<AffineApplyOp, 4> sliceOps;
    mlir::createAffineComputationSlice(dmaStartOp, &sliceOps);
    if (!sliceOps.empty()) {
      for (auto sliceOp : sliceOps) {
        instShiftMap[sliceOp.getOperation()] = 0;
      }
    } else {
      // If a slice wasn't created, the reachable affine.apply op's from its
      // operands are the ones that go with it.
      SmallVector<Operation *, 4> affineApplyInsts;
      SmallVector<Value, 4> operands(dmaStartOp->getOperands());
      getReachableAffineApplyOps(operands, affineApplyInsts);
      for (auto *op : affineApplyInsts) {
        instShiftMap[op] = 0;
      }
    }
  }
  // Everything else (including compute ops and dma finish) is shifted by one.
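  // With shift 0 on the DMA starts (and their index computations) and shift 1
  // on everything else, the body skewing below overlaps iteration i+1's DMA
  // issue with iteration i's wait and compute.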
  for (auto &op : forOp.getBody()->without_terminator())
    if (instShiftMap.find(&op) == instShiftMap.end())
      instShiftMap[&op] = 1;

  // Get the shifts stored in the map.
  SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
  unsigned s = 0;
  for (auto &op : forOp.getBody()->without_terminator()) {
    assert(instShiftMap.find(&op) != instShiftMap.end());
    shifts[s++] = instShiftMap[&op];

    // Tag operations with shifts for debugging purposes.
    LLVM_DEBUG({
      OpBuilder b(&op);
      op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
    });
  }

  if (!isOpwiseShiftValid(forOp, shifts)) {
    // Violates dependences.
    LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
    return;
  }

  if (failed(affineForOpBodySkew(forOp, shifts))) {
    LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
    return;
  }
}