comparison mlir/lib/Transforms/MemRefDataFlowOpt.cpp @ 173:0572611fdcc8 llvm10 llvm12

reorganization done
author Shinji KONO <kono@ie.u-ryukyu.ac.jp>
date Mon, 25 May 2020 11:55:54 +0900
parents 1d019706d866
children 2e18cbf3894f
comparing 172:9fbae9c8bf63 with 173:0572611fdcc8
 // TODO(mlir-team): In the future, similar techniques could be used to eliminate
 // dead memref store's and perform more complex forwarding when support for
 // SSA scalars live out of 'affine.for'/'affine.if' statements is available.
 //===----------------------------------------------------------------------===//

+#include "PassDetail.h"
 #include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/Analysis/Dominance.h"
 #include "mlir/Analysis/Utils.h"
-#include "mlir/Dialect/AffineOps/AffineOps.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/Pass/Pass.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include <algorithm>

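Note on the include changes: this revision tracks MLIR's early-2020 source-tree reorganization. The affine dialect moved from mlir/Dialect/AffineOps to mlir/Dialect/Affine/IR, the standard dialect headers gained an IR/ subdirectory, and DominanceInfo moved from Analysis/ to IR/. The new PassDetail.h is the per-library header carrying the tablegen-generated pass base classes; it includes mlir/Pass/Pass.h itself, which is presumably why the direct include could be dropped.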
 #define DEBUG_TYPE "memref-dataflow-opt"

 using namespace mlir;

 namespace {
-
 // The store to load forwarding relies on three conditions:
 //
 // 1) they need to have mathematically equivalent affine access functions
 // (checked after full composition of load/store operands); this implies that
 // they access the same single memref element for all iterations of the common
...
 // loop/conditional live-out SSA values is available.
 // TODO(mlir-team): do general dead store elimination for memref's. This pass
 // currently only eliminates the stores only if no other loads/uses (other
 // than dealloc) remain.
 //
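To make these conditions concrete, here is a minimal before/after sketch of the rewrite (hypothetical IR, not taken from this patch; %A, %v, and @use are illustrative names). The store and the load have identical access functions, the store dominates the load, and it trivially postdominates every reaching store, being the only one:

  // Before:
  func @example(%v: f32) {
    %A = alloc() : memref<8xf32>
    affine.for %i = 0 to 8 {
      affine.store %v, %A[%i] : memref<8xf32>
      %x = affine.load %A[%i] : memref<8xf32>   // forwardable: same element, dominated
      call @use(%x) : (f32) -> ()
    }
    dealloc %A : memref<8xf32>
    return
  }

  // After forwarding and the dead-memref sweep:
  func @example(%v: f32) {
    affine.for %i = 0 to 8 {
      call @use(%v) : (f32) -> ()
    }
    return
  }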
-struct MemRefDataFlowOpt : public FunctionPass<MemRefDataFlowOpt> {
+struct MemRefDataFlowOpt : public MemRefDataFlowOptBase<MemRefDataFlowOpt> {
   void runOnFunction() override;

   void forwardStoreToLoad(AffineLoadOp loadOp);

   // A list of memref's that are potentially dead / could be eliminated.
...

 } // end anonymous namespace

 /// Creates a pass to perform optimizations relying on memref dataflow such as
 /// store to load forwarding, elimination of dead stores, and dead allocs.
-std::unique_ptr<OpPassBase<FuncOp>> mlir::createMemRefDataFlowOptPass() {
+std::unique_ptr<OperationPass<FuncOp>> mlir::createMemRefDataFlowOptPass() {
   return std::make_unique<MemRefDataFlowOpt>();
 }
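Two API shifts are visible here: function passes now derive from a tablegen-generated base class (MemRefDataFlowOptBase; compare the deleted PassRegistration at the end of the file), and the factory's return type follows the OpPassBase-to-OperationPass rename in this MLIR version. Pipelines keep constructing the pass through createMemRefDataFlowOptPass(), and it should remain reachable from mlir-opt under the -memref-dataflow-opt flag that DEBUG_TYPE and the old registration string both spell out.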

 // This is a straightforward implementation not optimized for speed. Optimize
 // if needed.
 void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
-  Operation *loadOpInst = loadOp.getOperation();
-
-  // First pass over the use list to get minimum number of surrounding
+  // First pass over the use list to get the minimum number of surrounding
   // loops common between the load op and the store op, with min taken across
   // all store ops.
   SmallVector<Operation *, 8> storeOps;
-  unsigned minSurroundingLoops = getNestingDepth(*loadOpInst);
+  unsigned minSurroundingLoops = getNestingDepth(loadOp);
   for (auto *user : loadOp.getMemRef().getUsers()) {
     auto storeOp = dyn_cast<AffineStoreOp>(user);
     if (!storeOp)
       continue;
-    auto *storeOpInst = storeOp.getOperation();
-    unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst);
+    unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
     minSurroundingLoops = std::min(nsLoops, minSurroundingLoops);
-    storeOps.push_back(storeOpInst);
+    storeOps.push_back(storeOp);
   }

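A sketch of what this first pass computes (hypothetical IR; assume %A, %v0, and %v1 are defined in the enclosing function). The load below has nesting depth 2; the first store shares both loops with it (nsLoops = 2), the second shares only the outer loop (nsLoops = 1), so minSurroundingLoops ends up 1:

  affine.for %i = 0 to 16 {
    affine.for %j = 0 to 16 {
      affine.store %v0, %A[%i, %j] : memref<16x16xf32>  // two common surrounding loops
      %x = affine.load %A[%i, %j] : memref<16x16xf32>   // nesting depth 2
    }
    affine.store %v1, %A[%i, 0] : memref<16x16xf32>     // one common surrounding loop
  }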
   // The list of store op candidates for forwarding that satisfy conditions
   // (1) and (2) above - they will be filtered later when checking (3).
   SmallVector<Operation *, 8> fwdingCandidates;
...
   // Store ops that have a dependence into the load (even if they aren't
   // forwarding candidates). Each forwarding candidate will be checked for a
   // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores.
   SmallVector<Operation *, 8> depSrcStores;

-  for (auto *storeOpInst : storeOps) {
-    MemRefAccess srcAccess(storeOpInst);
-    MemRefAccess destAccess(loadOpInst);
+  for (auto *storeOp : storeOps) {
+    MemRefAccess srcAccess(storeOp);
+    MemRefAccess destAccess(loadOp);
     // Find stores that may be reaching the load.
     FlatAffineConstraints dependenceConstraints;
-    unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst);
+    unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
     unsigned d;
     // Dependences at loop depth <= minSurroundingLoops do NOT matter.
     for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
       DependenceResult result = checkMemrefAccessDependence(
           srcAccess, destAccess, d, &dependenceConstraints,
...
     }
     if (d == minSurroundingLoops)
       continue;

     // Stores that *may* be reaching the load.
-    depSrcStores.push_back(storeOpInst);
+    depSrcStores.push_back(storeOp);

     // 1. Check if the store and the load have mathematically equivalent
     // affine access functions; this implies that they statically refer to the
     // same single memref element. As an example this filters out cases like:
     //     store %A[%i0 + 1]
...
     // Use the AffineValueMap difference based memref access equality checking.
     if (srcAccess != destAccess)
       continue;

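Condition (1) in IR terms, in the spirit of the truncated example in the comment above (a hypothetical pair; %A, %v, %i0 are illustrative): the first store/load below have mathematically equivalent access functions after composition; the second pair touches a different element in every iteration and is filtered out by srcAccess != destAccess:

  // Equivalent accesses: candidate for forwarding.
  affine.store %v, %A[%i0 + 1] : memref<8xf32>
  %x = affine.load %A[%i0 + 1] : memref<8xf32>

  // Non-equivalent accesses: skipped.
  affine.store %v, %A[%i0 + 1] : memref<8xf32>
  %y = affine.load %A[%i0] : memref<8xf32>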
     // 2. The store has to dominate the load op to be candidate.
-    if (!domInfo->dominates(storeOpInst, loadOpInst))
+    if (!domInfo->dominates(storeOp, loadOp))
       continue;

     // We now have a candidate for forwarding.
-    fwdingCandidates.push_back(storeOpInst);
+    fwdingCandidates.push_back(storeOp);
   }

   // 3. Of all the store op's that meet the above criteria, the store that
   // postdominates all 'depSrcStores' (if one exists) is the unique store
   // providing the value to the load, i.e., provably the last writer to that
   // memref loc.
   // Note: this can be implemented in a cleaner way with postdominator tree
   // traversals. Consider this for the future if needed.
   Operation *lastWriteStoreOp = nullptr;
-  for (auto *storeOpInst : fwdingCandidates) {
+  for (auto *storeOp : fwdingCandidates) {
     if (llvm::all_of(depSrcStores, [&](Operation *depStore) {
-          return postDomInfo->postDominates(storeOpInst, depStore);
+          return postDomInfo->postDominates(storeOp, depStore);
         })) {
-      lastWriteStoreOp = storeOpInst;
+      lastWriteStoreOp = storeOp;
       break;
     }
   }
   if (!lastWriteStoreOp)
     return;
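Why the postdominance filter in (3) matters, on a hypothetical snippet (assume #set and %i are defined): the first store dominates the load and matches its access function, so it passes (1) and (2), but the conditional store may also reach the load. The candidate does not postdominate that store, so lastWriteStoreOp stays null and the pass correctly refuses to forward %v0 (the load may observe %v1):

  affine.store %v0, %A[0] : memref<1xf32>
  affine.if #set(%i) {
    affine.store %v1, %A[0] : memref<1xf32>
  }
  %x = affine.load %A[0] : memref<1xf32>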
...
   Value storeVal = cast<AffineStoreOp>(lastWriteStoreOp).getValueToStore();
   loadOp.replaceAllUsesWith(storeVal);
   // Record the memref for a later sweep to optimize away.
   memrefsToErase.insert(loadOp.getMemRef());
   // Record this to erase later.
-  loadOpsToErase.push_back(loadOpInst);
+  loadOpsToErase.push_back(loadOp);
 }

 void MemRefDataFlowOpt::runOnFunction() {
   // Only supports single block functions at the moment.
   FuncOp f = getFunction();
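The single-block restriction (presumably enforced in the lines elided just below) matches the machinery this pass leans on: dominance within a block plus affine dependence analysis over affine.for/affine.if nests. Arbitrary CFG branching between blocks is out of scope, as the file-header TODO about live-out SSA scalars also hints.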
...
   postDomInfo = &getAnalysis<PostDominanceInfo>();

   loadOpsToErase.clear();
   memrefsToErase.clear();

-  // Walk all load's and perform load/store forwarding.
+  // Walk all load's and perform store to load forwarding.
   f.walk([&](AffineLoadOp loadOp) { forwardStoreToLoad(loadOp); });

   // Erase all load op's whose results were replaced with store fwd'ed ones.
-  for (auto *loadOp : loadOpsToErase) {
+  for (auto *loadOp : loadOpsToErase)
     loadOp->erase();
-  }

   // Check if the store fwd'ed memrefs are now left with only stores and can
   // thus be completely deleted. Note: the canonicalize pass should be able
   // to do this as well, but we'll do it here since we collected these anyway.
   for (auto memref : memrefsToErase) {
     // If the memref hasn't been alloc'ed in this function, skip.
-    Operation *defInst = memref.getDefiningOp();
-    if (!defInst || !isa<AllocOp>(defInst))
+    Operation *defOp = memref.getDefiningOp();
+    if (!defOp || !isa<AllocOp>(defOp))
       // TODO(mlir-team): if the memref was returned by a 'call' operation, we
       // could still erase it if the call had no side-effects.
       continue;
-    if (llvm::any_of(memref.getUsers(), [&](Operation *ownerInst) {
-          return (!isa<AffineStoreOp>(ownerInst) && !isa<DeallocOp>(ownerInst));
+    if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
+          return (!isa<AffineStoreOp>(ownerOp) && !isa<DeallocOp>(ownerOp));
         }))
       continue;

     // Erase all stores, the dealloc, and the alloc on the memref.
     for (auto *user : llvm::make_early_inc_range(memref.getUsers()))
       user->erase();
-    defInst->erase();
+    defOp->erase();
   }
 }
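The sweep in IR terms, on a hypothetical memref whose loads were all forwarded away: once only affine.store and dealloc users remain, every user and then the defining alloc are erased:

  // Before the sweep (loads already forwarded and erased):
  %A = alloc() : memref<8xf32>
  affine.for %i = 0 to 8 {
    affine.store %v, %A[%i] : memref<8xf32>
  }
  dealloc %A : memref<8xf32>

  // After the sweep: %A, its stores, and the dealloc are all gone.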
-
-static PassRegistration<MemRefDataFlowOpt>
-    pass("memref-dataflow-opt", "Perform store/load forwarding for memrefs");
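The deleted PassRegistration is not functionality lost: with the move to MemRefDataFlowOptBase and PassDetail.h, registration, including the memref-dataflow-opt flag and its one-line description, presumably comes from the declarative, tablegen-based pass declarations introduced in this MLIR version.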