121
|
1 //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
|
|
// intrinsics
|
|
3 //
|
|
4 // The LLVM Compiler Infrastructure
|
|
5 //
|
|
6 // This file is distributed under the University of Illinois Open Source
|
|
7 // License. See LICENSE.TXT for details.
|
|
8 //
|
|
9 //===----------------------------------------------------------------------===//
|
|
10 //
|
|
11 // This pass replaces masked memory intrinsics - when unsupported by the target
|
|
12 // - with a chain of basic blocks, that deal with the elements one-by-one if the
|
|
13 // appropriate mask bit is set.
|
|
14 //
|
|
15 //===----------------------------------------------------------------------===//
|
|
16
|
|
17 #include "llvm/ADT/Twine.h"
|
|
18 #include "llvm/Analysis/TargetTransformInfo.h"
|
134
|
19 #include "llvm/CodeGen/TargetSubtargetInfo.h"
|
121
|
20 #include "llvm/IR/BasicBlock.h"
|
|
21 #include "llvm/IR/Constant.h"
|
|
22 #include "llvm/IR/Constants.h"
|
|
23 #include "llvm/IR/DerivedTypes.h"
|
|
24 #include "llvm/IR/Function.h"
|
|
25 #include "llvm/IR/IRBuilder.h"
|
|
26 #include "llvm/IR/InstrTypes.h"
|
|
27 #include "llvm/IR/Instruction.h"
|
|
28 #include "llvm/IR/Instructions.h"
|
|
29 #include "llvm/IR/IntrinsicInst.h"
|
|
30 #include "llvm/IR/Intrinsics.h"
|
|
31 #include "llvm/IR/Type.h"
|
|
32 #include "llvm/IR/Value.h"
|
|
33 #include "llvm/Pass.h"
|
|
34 #include "llvm/Support/Casting.h"
|
|
35 #include <algorithm>
|
|
36 #include <cassert>
|
|
37
|
|
38 using namespace llvm;
|
|
39
|
|
40 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
|
|
41
|
|
namespace {

/// Legacy function pass that rewrites masked load/store/gather/scatter
/// intrinsic calls into explicit per-element control flow when the target
/// (per TargetTransformInfo) cannot lower them natively.
class ScalarizeMaskedMemIntrin : public FunctionPass {
  // Target legality oracle; initialized in runOnFunction before any block
  // is visited, and consulted by optimizeCallInst.
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  // Scan one block for masked memory intrinsics to scalarize. ModifiedDT is
  // set to true when scalarization split blocks (invalidating iteration and
  // the dominator tree); the caller must then restart its walk.
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace
|
|
70
|
|
char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

/// Public factory used by the pass pipeline to instantiate this pass.
FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}
|
|
79
|
|
80 // Translate a masked load intrinsic like
|
|
81 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
|
|
82 // <16 x i1> %mask, <16 x i32> %passthru)
|
|
83 // to a chain of basic blocks, with loading element one-by-one if
|
|
84 // the appropriate mask bit is set
|
|
85 //
|
|
86 // %1 = bitcast i8* %addr to i32*
|
|
87 // %2 = extractelement <16 x i1> %mask, i32 0
|
|
88 // %3 = icmp eq i1 %2, true
|
|
89 // br i1 %3, label %cond.load, label %else
|
|
90 //
|
|
91 // cond.load: ; preds = %0
|
|
92 // %4 = getelementptr i32* %1, i32 0
|
|
93 // %5 = load i32* %4
|
|
94 // %6 = insertelement <16 x i32> undef, i32 %5, i32 0
|
|
95 // br label %else
|
|
96 //
|
|
97 // else: ; preds = %0, %cond.load
|
|
98 // %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
|
|
99 // %7 = extractelement <16 x i1> %mask, i32 1
|
|
100 // %8 = icmp eq i1 %7, true
|
|
101 // br i1 %8, label %cond.load1, label %else2
|
|
102 //
|
|
103 // cond.load1: ; preds = %else
|
|
104 // %9 = getelementptr i32* %1, i32 1
|
|
105 // %10 = load i32* %9
|
|
106 // %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
|
|
107 // br label %else2
|
|
108 //
|
|
109 // else2: ; preds = %else, %cond.load1
|
|
110 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
|
|
111 // %12 = extractelement <16 x i1> %mask, i32 2
|
|
112 // %13 = icmp eq i1 %12, true
|
|
113 // br i1 %13, label %cond.load4, label %else5
|
|
114 //
|
|
/// Scalarize an @llvm.masked.load call (see the block comment above for the
/// emitted IR shape): each lane becomes a conditional scalar load, and the
/// final result is selected against the passthru operand \p Src0.
static void scalarizeMaskedLoad(CallInst *CI) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = dyn_cast<VectorType>(CI->getType());
  assert(VecType && "Unexpected return type of masked load intrinsic");

  Type *EltTy = CI->getType()->getVectorElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true: the masked load degenerates to a
  // plain vector load and the passthru is irrelevant.
  bool IsAllOnesMask =
      isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();

  if (IsAllOnesMask) {
    Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction: an element access is only
  // guaranteed the smaller of the vector alignment and the element size.
  AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector
  Value *VResult = UndefVal;

  // Constant mask: no control flow needed; emit unconditional loads for the
  // statically-enabled lanes only.
  if (isa<ConstantVector>(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Gep =
          Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
      LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
      VResult =
          Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
    }
    // The select fills the disabled lanes from the passthru operand.
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    // %to_load = icmp eq i1 %mask_1, true
    // br i1 %to_load, label %cond.load, label %else
    //
    if (Idx > 0) {
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate =
        Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1));

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep =
        Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    // splitBasicBlock left an unconditional branch; replace it with the
    // conditional branch keyed on this lane's mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  // Final merge phi, then select against the passthru for masked-off lanes.
  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}
|
|
232
|
|
233 // Translate a masked store intrinsic, like
|
|
234 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
|
|
235 // <16 x i1> %mask)
|
|
236 // to a chain of basic blocks, that stores element one-by-one if
|
|
237 // the appropriate mask bit is set
|
|
238 //
|
|
239 // %1 = bitcast i8* %addr to i32*
|
|
240 // %2 = extractelement <16 x i1> %mask, i32 0
|
|
241 // %3 = icmp eq i1 %2, true
|
|
242 // br i1 %3, label %cond.store, label %else
|
|
243 //
|
|
244 // cond.store: ; preds = %0
|
|
245 // %4 = extractelement <16 x i32> %val, i32 0
|
|
246 // %5 = getelementptr i32* %1, i32 0
|
|
247 // store i32 %4, i32* %5
|
|
248 // br label %else
|
|
249 //
|
|
250 // else: ; preds = %0, %cond.store
|
|
251 // %6 = extractelement <16 x i1> %mask, i32 1
|
|
252 // %7 = icmp eq i1 %6, true
|
|
253 // br i1 %7, label %cond.store1, label %else2
|
|
254 //
|
|
255 // cond.store1: ; preds = %else
|
|
256 // %8 = extractelement <16 x i32> %val, i32 1
|
|
257 // %9 = getelementptr i32* %1, i32 1
|
|
258 // store i32 %8, i32* %9
|
|
259 // br label %else2
|
|
260 // . . .
|
|
261 static void scalarizeMaskedStore(CallInst *CI) {
|
|
262 Value *Src = CI->getArgOperand(0);
|
|
263 Value *Ptr = CI->getArgOperand(1);
|
|
264 Value *Alignment = CI->getArgOperand(2);
|
|
265 Value *Mask = CI->getArgOperand(3);
|
|
266
|
|
267 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
|
|
268 VectorType *VecType = dyn_cast<VectorType>(Src->getType());
|
|
269 assert(VecType && "Unexpected data type in masked store intrinsic");
|
|
270
|
|
271 Type *EltTy = VecType->getElementType();
|
|
272
|
|
273 IRBuilder<> Builder(CI->getContext());
|
|
274 Instruction *InsertPt = CI;
|
|
275 BasicBlock *IfBlock = CI->getParent();
|
|
276 Builder.SetInsertPoint(InsertPt);
|
|
277 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
|
|
278
|
|
279 // Short-cut if the mask is all-true.
|
|
280 bool IsAllOnesMask =
|
|
281 isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
|
|
282
|
|
283 if (IsAllOnesMask) {
|
|
284 Builder.CreateAlignedStore(Src, Ptr, AlignVal);
|
|
285 CI->eraseFromParent();
|
|
286 return;
|
|
287 }
|
|
288
|
|
289 // Adjust alignment for the scalar instruction.
|
|
290 AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8);
|
|
291 // Bitcast %addr fron i8* to EltTy*
|
|
292 Type *NewPtrType =
|
|
293 EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
|
|
294 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
|
|
295 unsigned VectorWidth = VecType->getNumElements();
|
|
296
|
|
297 if (isa<ConstantVector>(Mask)) {
|
|
298 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
299 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
|
|
300 continue;
|
|
301 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
|
|
302 Value *Gep =
|
|
303 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
|
|
304 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
|
|
305 }
|
|
306 CI->eraseFromParent();
|
|
307 return;
|
|
308 }
|
|
309
|
|
310 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
311 // Fill the "else" block, created in the previous iteration
|
|
312 //
|
|
313 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
|
|
314 // %to_store = icmp eq i1 %mask_1, true
|
|
315 // br i1 %to_store, label %cond.store, label %else
|
|
316 //
|
|
317 Value *Predicate =
|
|
318 Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
|
|
319 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
|
|
320 ConstantInt::get(Predicate->getType(), 1));
|
|
321
|
|
322 // Create "cond" block
|
|
323 //
|
|
324 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
|
|
325 // %EltAddr = getelementptr i32* %1, i32 0
|
|
326 // %store i32 %OneElt, i32* %EltAddr
|
|
327 //
|
|
328 BasicBlock *CondBlock =
|
|
329 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
|
|
330 Builder.SetInsertPoint(InsertPt);
|
|
331
|
|
332 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
|
|
333 Value *Gep =
|
|
334 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
|
|
335 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
|
|
336
|
|
337 // Create "else" block, fill it in the next iteration
|
|
338 BasicBlock *NewIfBlock =
|
|
339 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
|
|
340 Builder.SetInsertPoint(InsertPt);
|
|
341 Instruction *OldBr = IfBlock->getTerminator();
|
|
342 BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
|
|
343 OldBr->eraseFromParent();
|
|
344 IfBlock = NewIfBlock;
|
|
345 }
|
|
346 CI->eraseFromParent();
|
|
347 }
|
|
348
|
|
349 // Translate a masked gather intrinsic like
|
|
350 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
|
|
351 // <16 x i1> %Mask, <16 x i32> %Src)
|
|
352 // to a chain of basic blocks, with loading element one-by-one if
|
|
353 // the appropriate mask bit is set
|
|
354 //
|
|
355 // % Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
|
|
356 // % Mask0 = extractelement <16 x i1> %Mask, i32 0
|
|
357 // % ToLoad0 = icmp eq i1 % Mask0, true
|
|
358 // br i1 % ToLoad0, label %cond.load, label %else
|
|
359 //
|
|
360 // cond.load:
|
|
361 // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
|
|
362 // % Load0 = load i32, i32* % Ptr0, align 4
|
|
363 // % Res0 = insertelement <16 x i32> undef, i32 % Load0, i32 0
|
|
364 // br label %else
|
|
365 //
|
|
366 // else:
|
|
367 // %res.phi.else = phi <16 x i32>[% Res0, %cond.load], [undef, % 0]
|
|
368 // % Mask1 = extractelement <16 x i1> %Mask, i32 1
|
|
369 // % ToLoad1 = icmp eq i1 % Mask1, true
|
|
370 // br i1 % ToLoad1, label %cond.load1, label %else2
|
|
371 //
|
|
372 // cond.load1:
|
|
373 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
|
|
374 // % Load1 = load i32, i32* % Ptr1, align 4
|
|
375 // % Res1 = insertelement <16 x i32> %res.phi.else, i32 % Load1, i32 1
|
|
376 // br label %else2
|
|
377 // . . .
|
|
378 // % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
|
|
379 // ret <16 x i32> %Result
|
|
/// Scalarize an @llvm.masked.gather call (see the block comment above for the
/// emitted IR shape): each lane extracts its pointer and performs a
/// conditional scalar load; the result is selected against passthru \p Src0.
static void scalarizeMaskedGather(CallInst *CI) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = dyn_cast<VectorType>(CI->getType());

  assert(VecType && "Unexpected return type of masked load intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector
  Value *VResult = UndefVal;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  bool IsConstMask = isa<ConstantVector>(Mask);

  if (IsConstMask) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      // Skip lanes that are statically known to be masked off.
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
      VResult = Builder.CreateInsertElement(
          VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
    }
    // The select fills the disabled lanes from the passthru operand.
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %Mask1 = extractelement <16 x i1> %Mask, i32 1
    // %ToLoad1 = icmp eq i1 %Mask1, true
    // br i1 %ToLoad1, label %cond.load, label %else
    //
    if (Idx > 0) {
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToLoad" + Twine(Idx));

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
                                          "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    // splitBasicBlock left an unconditional branch; replace it with the
    // conditional branch keyed on this lane's mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  // Final merge phi, then select against the passthru for masked-off lanes.
  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}
|
|
483
|
|
484 // Translate a masked scatter intrinsic, like
|
|
485 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
|
|
486 // <16 x i1> %Mask)
|
|
487 // to a chain of basic blocks, that stores element one-by-one if
|
|
488 // the appropriate mask bit is set.
|
|
489 //
|
|
490 // % Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
|
|
491 // % Mask0 = extractelement <16 x i1> % Mask, i32 0
|
|
492 // % ToStore0 = icmp eq i1 % Mask0, true
|
|
493 // br i1 %ToStore0, label %cond.store, label %else
|
|
494 //
|
|
495 // cond.store:
|
|
496 // % Elt0 = extractelement <16 x i32> %Src, i32 0
|
|
497 // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
|
|
498 // store i32 %Elt0, i32* % Ptr0, align 4
|
|
499 // br label %else
|
|
500 //
|
|
501 // else:
|
|
502 // % Mask1 = extractelement <16 x i1> % Mask, i32 1
|
|
503 // % ToStore1 = icmp eq i1 % Mask1, true
|
|
504 // br i1 % ToStore1, label %cond.store1, label %else2
|
|
505 //
|
|
506 // cond.store1:
|
|
507 // % Elt1 = extractelement <16 x i32> %Src, i32 1
|
|
508 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
|
|
509 // store i32 % Elt1, i32* % Ptr1, align 4
|
|
510 // br label %else2
|
|
511 // . . .
|
|
/// Scalarize an @llvm.masked.scatter call (see the block comment above for
/// the emitted IR shape): each lane extracts its value and pointer and
/// performs a conditional scalar store.
static void scalarizeMaskedScatter(CallInst *CI) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants.
  bool IsConstMask = isa<ConstantVector>(Mask);

  if (IsConstMask) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      // Skip lanes that are statically known to be masked off.
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                   "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // % Mask1 = extractelement <16 x i1> % Mask, i32 Idx
    // % ToStore = icmp eq i1 % Mask1, true
    // br i1 % ToStore, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToStore" + Twine(Idx));

    // Create "cond" block
    //
    // % Elt1 = extractelement <16 x i32> %Src, i32 1
    // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    // %store i32 % Elt1, i32* % Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                 "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    // splitBasicBlock left an unconditional branch; replace it with the
    // conditional branch keyed on this lane's mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();
}
|
|
587
|
|
588 bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
|
|
589 if (skipFunction(F))
|
|
590 return false;
|
|
591
|
|
592 bool EverMadeChange = false;
|
|
593
|
|
594 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
|
|
595
|
|
596 bool MadeChange = true;
|
|
597 while (MadeChange) {
|
|
598 MadeChange = false;
|
|
599 for (Function::iterator I = F.begin(); I != F.end();) {
|
|
600 BasicBlock *BB = &*I++;
|
|
601 bool ModifiedDTOnIteration = false;
|
|
602 MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
|
|
603
|
|
604 // Restart BB iteration if the dominator tree of the Function was changed
|
|
605 if (ModifiedDTOnIteration)
|
|
606 break;
|
|
607 }
|
|
608
|
|
609 EverMadeChange |= MadeChange;
|
|
610 }
|
|
611
|
|
612 return EverMadeChange;
|
|
613 }
|
|
614
|
|
615 bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
|
|
616 bool MadeChange = false;
|
|
617
|
|
618 BasicBlock::iterator CurInstIterator = BB.begin();
|
|
619 while (CurInstIterator != BB.end()) {
|
|
620 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
|
|
621 MadeChange |= optimizeCallInst(CI, ModifiedDT);
|
|
622 if (ModifiedDT)
|
|
623 return true;
|
|
624 }
|
|
625
|
|
626 return MadeChange;
|
|
627 }
|
|
628
|
|
629 bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
|
|
630 bool &ModifiedDT) {
|
|
631 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
|
|
632 if (II) {
|
|
633 switch (II->getIntrinsicID()) {
|
|
634 default:
|
|
635 break;
|
|
636 case Intrinsic::masked_load:
|
|
637 // Scalarize unsupported vector masked load
|
|
638 if (!TTI->isLegalMaskedLoad(CI->getType())) {
|
|
639 scalarizeMaskedLoad(CI);
|
|
640 ModifiedDT = true;
|
|
641 return true;
|
|
642 }
|
|
643 return false;
|
|
644 case Intrinsic::masked_store:
|
|
645 if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
|
|
646 scalarizeMaskedStore(CI);
|
|
647 ModifiedDT = true;
|
|
648 return true;
|
|
649 }
|
|
650 return false;
|
|
651 case Intrinsic::masked_gather:
|
|
652 if (!TTI->isLegalMaskedGather(CI->getType())) {
|
|
653 scalarizeMaskedGather(CI);
|
|
654 ModifiedDT = true;
|
|
655 return true;
|
|
656 }
|
|
657 return false;
|
|
658 case Intrinsic::masked_scatter:
|
|
659 if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
|
|
660 scalarizeMaskedScatter(CI);
|
|
661 ModifiedDT = true;
|
|
662 return true;
|
|
663 }
|
|
664 return false;
|
|
665 }
|
|
666 }
|
|
667
|
|
668 return false;
|
|
669 }
|