llvm.org GIT mirror llvm / 46b13dd
[InstSimplify] Teach InstSimplify how to simplify extractelement git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@242008 91177308-0d34-0410-b5e6-96231b3b80d8 David Majnemer 4 years ago
7 changed file(s) with 142 addition(s) and 58 deletion(s). Raw diff Collapse all Expand all
7777 Constant *ConstantFoldExtractValueInstruction(Constant *Agg,
7878 ArrayRef Idxs);
7979
80 /// \brief Attempt to constant fold an extractelement instruction with the
81 /// specified operands and indices. The constant result is returned if
82 /// successful; if not, null is returned.
83 Constant *ConstantFoldExtractElementInstruction(Constant *Val, Constant *Idx);
84
8085 /// ConstantFoldLoadFromConstPtr - Return the value that a load from C would
8186 /// produce if it is constant and determinable. If this is not determinable,
8287 /// return null.
252252 AssumptionCache *AC = nullptr,
253253 const Instruction *CxtI = nullptr);
254254
255 /// \brief Given operands for an ExtractElementInst, see if we can fold the
256 /// result. If not, this returns null.
257 Value *SimplifyExtractElementInst(Value *Vec, Value *Idx,
258 const DataLayout &DL,
259 const TargetLibraryInfo *TLI = nullptr,
260 const DominatorTree *DT = nullptr,
261 AssumptionCache *AC = nullptr,
262 const Instruction *CxtI = nullptr);
263
255264 /// SimplifyTruncInst - Given operands for an TruncInst, see if we can fold
256265 /// the result. If not, this returns null.
257266 Value *SimplifyTruncInst(Value *Op, Type *Ty, const DataLayout &DL,
7373 /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
7474 Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp);
7575
76 /// \brief Given a vector and an element number, see if the scalar value is
77 /// already around as a register, for example if it were inserted then extracted
78 /// from the vector.
79 Value *findScalarElement(Value *V, unsigned EltNo);
80
7681 } // llvm namespace
7782
7883 #endif
2323 #include "llvm/Analysis/ConstantFolding.h"
2424 #include "llvm/Analysis/MemoryBuiltins.h"
2525 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/Analysis/VectorUtils.h"
2627 #include "llvm/IR/ConstantRange.h"
2728 #include "llvm/IR/DataLayout.h"
2829 #include "llvm/IR/Dominators.h"
35543555 RecursionLimit);
35553556 }
35563557
3558 /// SimplifyExtractElementInst - Given operands for an ExtractElementInst, see if we
3559 /// can fold the result. If not, this returns null.
3560 static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const Query &,
3561 unsigned) {
3562 if (auto *CVec = dyn_cast(Vec)) {
3563 if (auto *CIdx = dyn_cast(Idx))
3564 return ConstantFoldExtractElementInstruction(CVec, CIdx);
3565
3566 // The index is not relevant if our vector is a splat.
3567 if (auto *Splat = CVec->getSplatValue())
3568 return Splat;
3569
3570 if (isa(Vec))
3571 return UndefValue::get(Vec->getType()->getVectorElementType());
3572 }
3573
3574 // If extracting a specified index from the vector, see if we can recursively
3575 // find a previously computed scalar that was inserted into the vector.
3576 if (auto *IdxC = dyn_cast(Idx)) {
3577 unsigned IndexVal = IdxC->getZExtValue();
3578 unsigned VectorWidth = Vec->getType()->getVectorNumElements();
3579
3580 // If this is extracting an invalid index, turn this into undef, to avoid
3581 // crashing the code below.
3582 if (IndexVal >= VectorWidth)
3583 return UndefValue::get(Vec->getType()->getVectorElementType());
3584
3585 if (Value *Elt = findScalarElement(Vec, IndexVal))
3586 return Elt;
3587 }
3588
3589 return nullptr;
3590 }
3591
3592 Value *llvm::SimplifyExtractElementInst(
3593 Value *Vec, Value *Idx, const DataLayout &DL, const TargetLibraryInfo *TLI,
3594 const DominatorTree *DT, AssumptionCache *AC, const Instruction *CxtI) {
3595 return ::SimplifyExtractElementInst(Vec, Idx, Query(DL, TLI, DT, AC, CxtI),
3596 RecursionLimit);
3597 }
3598
35573599 /// SimplifyPHINode - See if we can fold the given phi. If not, returns null.
35583600 static Value *SimplifyPHINode(PHINode *PN, const Query &Q) {
35593601 // If all of the PHI's incoming values are the same then replace the PHI node
39694011 EVI->getIndices(), DL, TLI, DT, AC, I);
39704012 break;
39714013 }
4014 case Instruction::ExtractElement: {
4015 auto *EEI = cast(I);
4016 Result = SimplifyExtractElementInst(
4017 EEI->getVectorOperand(), EEI->getIndexOperand(), DL, TLI, DT, AC, I);
4018 break;
4019 }
39724020 case Instruction::PHI:
39734021 Result = SimplifyPHINode(cast(I), Query(DL, TLI, DT, AC, I));
39744022 break;
356356
357357 return Stride;
358358 }
359
360 /// \brief Given a vector and an element number, see if the scalar value is
361 /// already around as a register, for example if it were inserted then extracted
362 /// from the vector.
363 llvm::Value *llvm::findScalarElement(llvm::Value *V, unsigned EltNo) {
364 assert(V->getType()->isVectorTy() && "Not looking at a vector?");
365 VectorType *VTy = cast(V->getType());
366 unsigned Width = VTy->getNumElements();
367 if (EltNo >= Width) // Out of range access.
368 return UndefValue::get(VTy->getElementType());
369
370 if (Constant *C = dyn_cast(V))
371 return C->getAggregateElement(EltNo);
372
373 if (InsertElementInst *III = dyn_cast(V)) {
374 // If this is an insert to a variable element, we don't know what it is.
375 if (!isa(III->getOperand(2)))
376 return nullptr;
377 unsigned IIElt = cast(III->getOperand(2))->getZExtValue();
378
379 // If this is an insert to the element we are looking for, return the
380 // inserted value.
381 if (EltNo == IIElt)
382 return III->getOperand(1);
383
384 // Otherwise, the insertelement doesn't modify the value, recurse on its
385 // vector input.
386 return findScalarElement(III->getOperand(0), EltNo);
387 }
388
389 if (ShuffleVectorInst *SVI = dyn_cast(V)) {
390 unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
391 int InEl = SVI->getMaskValue(EltNo);
392 if (InEl < 0)
393 return UndefValue::get(VTy->getElementType());
394 if (InEl < (int)LHSWidth)
395 return findScalarElement(SVI->getOperand(0), InEl);
396 return findScalarElement(SVI->getOperand(1), InEl - LHSWidth);
397 }
398
399 // Extract a value from a vector add operation with a constant zero.
400 Value *Val = nullptr; Constant *Con = nullptr;
401 if (match(V,
402 llvm::PatternMatch::m_Add(llvm::PatternMatch::m_Value(Val),
403 llvm::PatternMatch::m_Constant(Con)))) {
404 if (Con->getAggregateElement(EltNo)->isNullValue())
405 return findScalarElement(Val, EltNo);
406 }
407
408 // Otherwise, we don't know.
409 return nullptr;
410 }
1313
1414 #include "InstCombineInternal.h"
1515 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/Analysis/InstructionSimplify.h"
17 #include "llvm/Analysis/VectorUtils.h"
1618 #include "llvm/IR/PatternMatch.h"
1719 using namespace llvm;
1820 using namespace PatternMatch;
5759 return true;
5860
5961 return false;
60 }
61
62 /// FindScalarElement - Given a vector and an element number, see if the scalar
63 /// value is already around as a register, for example if it were inserted then
64 /// extracted from the vector.
65 static Value *FindScalarElement(Value *V, unsigned EltNo) {
66 assert(V->getType()->isVectorTy() && "Not looking at a vector?");
67 VectorType *VTy = cast(V->getType());
68 unsigned Width = VTy->getNumElements();
69 if (EltNo >= Width) // Out of range access.
70 return UndefValue::get(VTy->getElementType());
71
72 if (Constant *C = dyn_cast(V))
73 return C->getAggregateElement(EltNo);
74
75 if (InsertElementInst *III = dyn_cast(V)) {
76 // If this is an insert to a variable element, we don't know what it is.
77 if (!isa(III->getOperand(2)))
78 return nullptr;
79 unsigned IIElt = cast(III->getOperand(2))->getZExtValue();
80
81 // If this is an insert to the element we are looking for, return the
82 // inserted value.
83 if (EltNo == IIElt)
84 return III->getOperand(1);
85
86 // Otherwise, the insertelement doesn't modify the value, recurse on its
87 // vector input.
88 return FindScalarElement(III->getOperand(0), EltNo);
89 }
90
91 if (ShuffleVectorInst *SVI = dyn_cast(V)) {
92 unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
93 int InEl = SVI->getMaskValue(EltNo);
94 if (InEl < 0)
95 return UndefValue::get(VTy->getElementType());
96 if (InEl < (int)LHSWidth)
97 return FindScalarElement(SVI->getOperand(0), InEl);
98 return FindScalarElement(SVI->getOperand(1), InEl - LHSWidth);
99 }
100
101 // Extract a value from a vector add operation with a constant zero.
102 Value *Val = nullptr; Constant *Con = nullptr;
103 if (match(V, m_Add(m_Value(Val), m_Constant(Con)))) {
104 if (Con->getAggregateElement(EltNo)->isNullValue())
105 return FindScalarElement(Val, EltNo);
106 }
107
108 // Otherwise, we don't know.
109 return nullptr;
11062 }
11163
11264 // If we have a PHI node with a vector type that has only 2 uses: feed
177129 }
178130
179131 Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
132 if (Value *V = SimplifyExtractElementInst(
133 EI.getVectorOperand(), EI.getIndexOperand(), DL, TLI, DT, AC))
134 return ReplaceInstUsesWith(EI, V);
135
180136 // If vector val is constant with all elements the same, replace EI with
181137 // that element. We handle a known element # below.
182138 if (Constant *C = dyn_cast(EI.getOperand(0)))
189145 unsigned IndexVal = IdxC->getZExtValue();
190146 unsigned VectorWidth = EI.getVectorOperandType()->getNumElements();
191147
192 // If this is extracting an invalid index, turn this into undef, to avoid
193 // crashing the code below.
194 if (IndexVal >= VectorWidth)
195 return ReplaceInstUsesWith(EI, UndefValue::get(EI.getType()));
148 // InstSimplify handles cases where the index is invalid.
149 assert(IndexVal < VectorWidth);
196150
197151 // This instruction only demands the single element from the input vector.
198152 // If the input vector has a single use, simplify it based on this use
208162 }
209163 }
210164
211 if (Value *Elt = FindScalarElement(EI.getOperand(0), IndexVal))
212 return ReplaceInstUsesWith(EI, Elt);
213
214165 // If the this extractelement is directly using a bitcast from a vector of
215166 // the same number of elements, see if we can find the source element from
216167 // it. In this case, we will end up needing to bitcast the scalars.
217168 if (BitCastInst *BCI = dyn_cast(EI.getOperand(0))) {
218169 if (VectorType *VT = dyn_cast(BCI->getOperand(0)->getType()))
219170 if (VT->getNumElements() == VectorWidth)
220 if (Value *Elt = FindScalarElement(BCI->getOperand(0), IndexVal))
171 if (Value *Elt = findScalarElement(BCI->getOperand(0), IndexVal))
221172 return new BitCastInst(Elt, EI.getType());
222173 }
223174
264264 %b = lshr i32 undef, 0
265265 ret i32 %b
266266 }
267
268 ; CHECK-LABEL: @test35
269 ; CHECK: ret i32 undef
270 define i32 @test35(<4 x i32> %V) {
271 %b = extractelement <4 x i32> %V, i32 4
272 ret i32 %b
273 }
274
275 ; CHECK-LABEL: @test36
276 ; CHECK: ret i32 undef
277 define i32 @test36(i32 %V) {
278 %b = extractelement <4 x i32> undef, i32 %V
279 ret i32 %b
280 }