1 | //===- InstCombineCasts.cpp -----------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the visit functions for cast operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "InstCombineInternal.h" |
14 | #include "llvm/ADT/SetVector.h" |
15 | #include "llvm/Analysis/ConstantFolding.h" |
16 | #include "llvm/IR/DataLayout.h" |
17 | #include "llvm/IR/DebugInfo.h" |
18 | #include "llvm/IR/PatternMatch.h" |
19 | #include "llvm/Support/KnownBits.h" |
20 | #include "llvm/Transforms/InstCombine/InstCombiner.h" |
21 | #include <optional> |
22 | |
23 | using namespace llvm; |
24 | using namespace PatternMatch; |
25 | |
26 | #define DEBUG_TYPE "instcombine" |
27 | |
28 | /// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns |
29 | /// true for, actually insert the code to evaluate the expression. |
30 | Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, |
31 | bool isSigned) { |
32 | if (Constant *C = dyn_cast<Constant>(Val: V)) |
33 | return ConstantFoldIntegerCast(C, DestTy: Ty, IsSigned: isSigned, DL); |
34 | |
35 | // Otherwise, it must be an instruction. |
36 | Instruction *I = cast<Instruction>(Val: V); |
37 | Instruction *Res = nullptr; |
38 | unsigned Opc = I->getOpcode(); |
39 | switch (Opc) { |
40 | case Instruction::Add: |
41 | case Instruction::Sub: |
42 | case Instruction::Mul: |
43 | case Instruction::And: |
44 | case Instruction::Or: |
45 | case Instruction::Xor: |
46 | case Instruction::AShr: |
47 | case Instruction::LShr: |
48 | case Instruction::Shl: |
49 | case Instruction::UDiv: |
50 | case Instruction::URem: { |
51 | Value *LHS = EvaluateInDifferentType(V: I->getOperand(i: 0), Ty, isSigned); |
52 | Value *RHS = EvaluateInDifferentType(V: I->getOperand(i: 1), Ty, isSigned); |
53 | Res = BinaryOperator::Create(Op: (Instruction::BinaryOps)Opc, S1: LHS, S2: RHS); |
54 | break; |
55 | } |
56 | case Instruction::Trunc: |
57 | case Instruction::ZExt: |
58 | case Instruction::SExt: |
59 | // If the source type of the cast is the type we're trying for then we can |
60 | // just return the source. There's no need to insert it because it is not |
61 | // new. |
62 | if (I->getOperand(i: 0)->getType() == Ty) |
63 | return I->getOperand(i: 0); |
64 | |
65 | // Otherwise, must be the same type of cast, so just reinsert a new one. |
66 | // This also handles the case of zext(trunc(x)) -> zext(x). |
67 | Res = CastInst::CreateIntegerCast(S: I->getOperand(i: 0), Ty, |
68 | isSigned: Opc == Instruction::SExt); |
69 | break; |
70 | case Instruction::Select: { |
71 | Value *True = EvaluateInDifferentType(V: I->getOperand(i: 1), Ty, isSigned); |
72 | Value *False = EvaluateInDifferentType(V: I->getOperand(i: 2), Ty, isSigned); |
73 | Res = SelectInst::Create(C: I->getOperand(i: 0), S1: True, S2: False); |
74 | break; |
75 | } |
76 | case Instruction::PHI: { |
77 | PHINode *OPN = cast<PHINode>(Val: I); |
78 | PHINode *NPN = PHINode::Create(Ty, NumReservedValues: OPN->getNumIncomingValues()); |
79 | for (unsigned i = 0, e = OPN->getNumIncomingValues(); i != e; ++i) { |
80 | Value *V = |
81 | EvaluateInDifferentType(V: OPN->getIncomingValue(i), Ty, isSigned); |
82 | NPN->addIncoming(V, BB: OPN->getIncomingBlock(i)); |
83 | } |
84 | Res = NPN; |
85 | break; |
86 | } |
87 | case Instruction::FPToUI: |
88 | case Instruction::FPToSI: |
89 | Res = CastInst::Create( |
90 | static_cast<Instruction::CastOps>(Opc), S: I->getOperand(i: 0), Ty); |
91 | break; |
92 | case Instruction::Call: |
93 | if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: I)) { |
94 | switch (II->getIntrinsicID()) { |
95 | default: |
96 | llvm_unreachable("Unsupported call!" ); |
97 | case Intrinsic::vscale: { |
98 | Function *Fn = |
99 | Intrinsic::getDeclaration(M: I->getModule(), Intrinsic::id: vscale, Tys: {Ty}); |
100 | Res = CallInst::Create(Ty: Fn->getFunctionType(), F: Fn); |
101 | break; |
102 | } |
103 | } |
104 | } |
105 | break; |
106 | case Instruction::ShuffleVector: { |
107 | auto *ScalarTy = cast<VectorType>(Val: Ty)->getElementType(); |
108 | auto *VTy = cast<VectorType>(Val: I->getOperand(i: 0)->getType()); |
109 | auto *FixedTy = VectorType::get(ElementType: ScalarTy, EC: VTy->getElementCount()); |
110 | Value *Op0 = EvaluateInDifferentType(V: I->getOperand(i: 0), Ty: FixedTy, isSigned); |
111 | Value *Op1 = EvaluateInDifferentType(V: I->getOperand(i: 1), Ty: FixedTy, isSigned); |
112 | Res = new ShuffleVectorInst(Op0, Op1, |
113 | cast<ShuffleVectorInst>(Val: I)->getShuffleMask()); |
114 | break; |
115 | } |
116 | default: |
117 | // TODO: Can handle more cases here. |
118 | llvm_unreachable("Unreachable!" ); |
119 | } |
120 | |
121 | Res->takeName(V: I); |
122 | return InsertNewInstWith(New: Res, Old: I->getIterator()); |
123 | } |
124 | |
125 | Instruction::CastOps |
126 | InstCombinerImpl::isEliminableCastPair(const CastInst *CI1, |
127 | const CastInst *CI2) { |
128 | Type *SrcTy = CI1->getSrcTy(); |
129 | Type *MidTy = CI1->getDestTy(); |
130 | Type *DstTy = CI2->getDestTy(); |
131 | |
132 | Instruction::CastOps firstOp = CI1->getOpcode(); |
133 | Instruction::CastOps secondOp = CI2->getOpcode(); |
134 | Type *SrcIntPtrTy = |
135 | SrcTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(SrcTy) : nullptr; |
136 | Type *MidIntPtrTy = |
137 | MidTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(MidTy) : nullptr; |
138 | Type *DstIntPtrTy = |
139 | DstTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(DstTy) : nullptr; |
140 | unsigned Res = CastInst::isEliminableCastPair(firstOpcode: firstOp, secondOpcode: secondOp, SrcTy, MidTy, |
141 | DstTy, SrcIntPtrTy, MidIntPtrTy, |
142 | DstIntPtrTy); |
143 | |
144 | // We don't want to form an inttoptr or ptrtoint that converts to an integer |
145 | // type that differs from the pointer size. |
146 | if ((Res == Instruction::IntToPtr && SrcTy != DstIntPtrTy) || |
147 | (Res == Instruction::PtrToInt && DstTy != SrcIntPtrTy)) |
148 | Res = 0; |
149 | |
150 | return Instruction::CastOps(Res); |
151 | } |
152 | |
153 | /// Implement the transforms common to all CastInst visitors. |
154 | Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) { |
155 | Value *Src = CI.getOperand(i_nocapture: 0); |
156 | Type *Ty = CI.getType(); |
157 | |
158 | if (auto *SrcC = dyn_cast<Constant>(Val: Src)) |
159 | if (Constant *Res = ConstantFoldCastOperand(Opcode: CI.getOpcode(), C: SrcC, DestTy: Ty, DL)) |
160 | return replaceInstUsesWith(I&: CI, V: Res); |
161 | |
162 | // Try to eliminate a cast of a cast. |
163 | if (auto *CSrc = dyn_cast<CastInst>(Val: Src)) { // A->B->C cast |
164 | if (Instruction::CastOps NewOpc = isEliminableCastPair(CI1: CSrc, CI2: &CI)) { |
165 | // The first cast (CSrc) is eliminable so we need to fix up or replace |
166 | // the second cast (CI). CSrc will then have a good chance of being dead. |
167 | auto *Res = CastInst::Create(NewOpc, S: CSrc->getOperand(i_nocapture: 0), Ty); |
168 | // Point debug users of the dying cast to the new one. |
169 | if (CSrc->hasOneUse()) |
170 | replaceAllDbgUsesWith(From&: *CSrc, To&: *Res, DomPoint&: CI, DT); |
171 | return Res; |
172 | } |
173 | } |
174 | |
175 | if (auto *Sel = dyn_cast<SelectInst>(Val: Src)) { |
176 | // We are casting a select. Try to fold the cast into the select if the |
177 | // select does not have a compare instruction with matching operand types |
178 | // or the select is likely better done in a narrow type. |
179 | // Creating a select with operands that are different sizes than its |
180 | // condition may inhibit other folds and lead to worse codegen. |
181 | auto *Cmp = dyn_cast<CmpInst>(Val: Sel->getCondition()); |
182 | if (!Cmp || Cmp->getOperand(i_nocapture: 0)->getType() != Sel->getType() || |
183 | (CI.getOpcode() == Instruction::Trunc && |
184 | shouldChangeType(From: CI.getSrcTy(), To: CI.getType()))) { |
185 | |
186 | // If it's a bitcast involving vectors, make sure it has the same number |
187 | // of elements on both sides. |
188 | if (CI.getOpcode() != Instruction::BitCast || |
189 | match(V: &CI, P: m_ElementWiseBitCast(Op: m_Value()))) { |
190 | if (Instruction *NV = FoldOpIntoSelect(Op&: CI, SI: Sel)) { |
191 | replaceAllDbgUsesWith(From&: *Sel, To&: *NV, DomPoint&: CI, DT); |
192 | return NV; |
193 | } |
194 | } |
195 | } |
196 | } |
197 | |
198 | // If we are casting a PHI, then fold the cast into the PHI. |
199 | if (auto *PN = dyn_cast<PHINode>(Val: Src)) { |
200 | // Don't do this if it would create a PHI node with an illegal type from a |
201 | // legal type. |
202 | if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() || |
203 | shouldChangeType(From: CI.getSrcTy(), To: CI.getType())) |
204 | if (Instruction *NV = foldOpIntoPhi(I&: CI, PN)) |
205 | return NV; |
206 | } |
207 | |
208 | // Canonicalize a unary shuffle after the cast if neither operation changes |
209 | // the size or element size of the input vector. |
210 | // TODO: We could allow size-changing ops if that doesn't harm codegen. |
211 | // cast (shuffle X, Mask) --> shuffle (cast X), Mask |
212 | Value *X; |
213 | ArrayRef<int> Mask; |
214 | if (match(V: Src, P: m_OneUse(SubPattern: m_Shuffle(v1: m_Value(V&: X), v2: m_Undef(), mask: m_Mask(Mask))))) { |
215 | // TODO: Allow scalable vectors? |
216 | auto *SrcTy = dyn_cast<FixedVectorType>(Val: X->getType()); |
217 | auto *DestTy = dyn_cast<FixedVectorType>(Val: Ty); |
218 | if (SrcTy && DestTy && |
219 | SrcTy->getNumElements() == DestTy->getNumElements() && |
220 | SrcTy->getPrimitiveSizeInBits() == DestTy->getPrimitiveSizeInBits()) { |
221 | Value *CastX = Builder.CreateCast(Op: CI.getOpcode(), V: X, DestTy); |
222 | return new ShuffleVectorInst(CastX, Mask); |
223 | } |
224 | } |
225 | |
226 | return nullptr; |
227 | } |
228 | |
229 | /// Constants and extensions/truncates from the destination type are always |
230 | /// free to be evaluated in that type. This is a helper for canEvaluate*. |
231 | static bool canAlwaysEvaluateInType(Value *V, Type *Ty) { |
232 | if (isa<Constant>(Val: V)) |
233 | return match(V, P: m_ImmConstant()); |
234 | |
235 | Value *X; |
236 | if ((match(V, P: m_ZExtOrSExt(Op: m_Value(V&: X))) || match(V, P: m_Trunc(Op: m_Value(V&: X)))) && |
237 | X->getType() == Ty) |
238 | return true; |
239 | |
240 | return false; |
241 | } |
242 | |
243 | /// Filter out values that we can not evaluate in the destination type for free. |
244 | /// This is a helper for canEvaluate*. |
245 | static bool canNotEvaluateInType(Value *V, Type *Ty) { |
246 | if (!isa<Instruction>(Val: V)) |
247 | return true; |
248 | // We don't extend or shrink something that has multiple uses -- doing so |
249 | // would require duplicating the instruction which isn't profitable. |
250 | if (!V->hasOneUse()) |
251 | return true; |
252 | |
253 | return false; |
254 | } |
255 | |
256 | /// Return true if we can evaluate the specified expression tree as type Ty |
257 | /// instead of its larger type, and arrive with the same value. |
258 | /// This is used by code that tries to eliminate truncates. |
259 | /// |
260 | /// Ty will always be a type smaller than V. We should return true if trunc(V) |
261 | /// can be computed by computing V in the smaller type. If V is an instruction, |
262 | /// then trunc(inst(x,y)) can be computed as inst(trunc(x),trunc(y)), which only |
263 | /// makes sense if x and y can be efficiently truncated. |
264 | /// |
265 | /// This function works on both vectors and scalars. |
266 | /// |
267 | static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, |
268 | Instruction *CxtI) { |
269 | if (canAlwaysEvaluateInType(V, Ty)) |
270 | return true; |
271 | if (canNotEvaluateInType(V, Ty)) |
272 | return false; |
273 | |
274 | auto *I = cast<Instruction>(Val: V); |
275 | Type *OrigTy = V->getType(); |
276 | switch (I->getOpcode()) { |
277 | case Instruction::Add: |
278 | case Instruction::Sub: |
279 | case Instruction::Mul: |
280 | case Instruction::And: |
281 | case Instruction::Or: |
282 | case Instruction::Xor: |
283 | // These operators can all arbitrarily be extended or truncated. |
284 | return canEvaluateTruncated(V: I->getOperand(i: 0), Ty, IC, CxtI) && |
285 | canEvaluateTruncated(V: I->getOperand(i: 1), Ty, IC, CxtI); |
286 | |
287 | case Instruction::UDiv: |
288 | case Instruction::URem: { |
289 | // UDiv and URem can be truncated if all the truncated bits are zero. |
290 | uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); |
291 | uint32_t BitWidth = Ty->getScalarSizeInBits(); |
292 | assert(BitWidth < OrigBitWidth && "Unexpected bitwidths!" ); |
293 | APInt Mask = APInt::getBitsSetFrom(numBits: OrigBitWidth, loBit: BitWidth); |
294 | if (IC.MaskedValueIsZero(V: I->getOperand(i: 0), Mask, Depth: 0, CxtI) && |
295 | IC.MaskedValueIsZero(V: I->getOperand(i: 1), Mask, Depth: 0, CxtI)) { |
296 | return canEvaluateTruncated(V: I->getOperand(i: 0), Ty, IC, CxtI) && |
297 | canEvaluateTruncated(V: I->getOperand(i: 1), Ty, IC, CxtI); |
298 | } |
299 | break; |
300 | } |
301 | case Instruction::Shl: { |
302 | // If we are truncating the result of this SHL, and if it's a shift of an |
303 | // inrange amount, we can always perform a SHL in a smaller type. |
304 | uint32_t BitWidth = Ty->getScalarSizeInBits(); |
305 | KnownBits AmtKnownBits = |
306 | llvm::computeKnownBits(V: I->getOperand(i: 1), DL: IC.getDataLayout()); |
307 | if (AmtKnownBits.getMaxValue().ult(RHS: BitWidth)) |
308 | return canEvaluateTruncated(V: I->getOperand(i: 0), Ty, IC, CxtI) && |
309 | canEvaluateTruncated(V: I->getOperand(i: 1), Ty, IC, CxtI); |
310 | break; |
311 | } |
312 | case Instruction::LShr: { |
313 | // If this is a truncate of a logical shr, we can truncate it to a smaller |
314 | // lshr iff we know that the bits we would otherwise be shifting in are |
315 | // already zeros. |
316 | // TODO: It is enough to check that the bits we would be shifting in are |
317 | // zero - use AmtKnownBits.getMaxValue(). |
318 | uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); |
319 | uint32_t BitWidth = Ty->getScalarSizeInBits(); |
320 | KnownBits AmtKnownBits = |
321 | llvm::computeKnownBits(V: I->getOperand(i: 1), DL: IC.getDataLayout()); |
322 | APInt ShiftedBits = APInt::getBitsSetFrom(numBits: OrigBitWidth, loBit: BitWidth); |
323 | if (AmtKnownBits.getMaxValue().ult(RHS: BitWidth) && |
324 | IC.MaskedValueIsZero(V: I->getOperand(i: 0), Mask: ShiftedBits, Depth: 0, CxtI)) { |
325 | return canEvaluateTruncated(V: I->getOperand(i: 0), Ty, IC, CxtI) && |
326 | canEvaluateTruncated(V: I->getOperand(i: 1), Ty, IC, CxtI); |
327 | } |
328 | break; |
329 | } |
330 | case Instruction::AShr: { |
331 | // If this is a truncate of an arithmetic shr, we can truncate it to a |
332 | // smaller ashr iff we know that all the bits from the sign bit of the |
333 | // original type and the sign bit of the truncate type are similar. |
334 | // TODO: It is enough to check that the bits we would be shifting in are |
335 | // similar to sign bit of the truncate type. |
336 | uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); |
337 | uint32_t BitWidth = Ty->getScalarSizeInBits(); |
338 | KnownBits AmtKnownBits = |
339 | llvm::computeKnownBits(V: I->getOperand(i: 1), DL: IC.getDataLayout()); |
340 | unsigned ShiftedBits = OrigBitWidth - BitWidth; |
341 | if (AmtKnownBits.getMaxValue().ult(RHS: BitWidth) && |
342 | ShiftedBits < IC.ComputeNumSignBits(Op: I->getOperand(i: 0), Depth: 0, CxtI)) |
343 | return canEvaluateTruncated(V: I->getOperand(i: 0), Ty, IC, CxtI) && |
344 | canEvaluateTruncated(V: I->getOperand(i: 1), Ty, IC, CxtI); |
345 | break; |
346 | } |
347 | case Instruction::Trunc: |
348 | // trunc(trunc(x)) -> trunc(x) |
349 | return true; |
350 | case Instruction::ZExt: |
351 | case Instruction::SExt: |
352 | // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest |
353 | // trunc(ext(x)) -> trunc(x) if the source type is larger than the new dest |
354 | return true; |
355 | case Instruction::Select: { |
356 | SelectInst *SI = cast<SelectInst>(Val: I); |
357 | return canEvaluateTruncated(V: SI->getTrueValue(), Ty, IC, CxtI) && |
358 | canEvaluateTruncated(V: SI->getFalseValue(), Ty, IC, CxtI); |
359 | } |
360 | case Instruction::PHI: { |
361 | // We can change a phi if we can change all operands. Note that we never |
362 | // get into trouble with cyclic PHIs here because we only consider |
363 | // instructions with a single use. |
364 | PHINode *PN = cast<PHINode>(Val: I); |
365 | for (Value *IncValue : PN->incoming_values()) |
366 | if (!canEvaluateTruncated(V: IncValue, Ty, IC, CxtI)) |
367 | return false; |
368 | return true; |
369 | } |
370 | case Instruction::FPToUI: |
371 | case Instruction::FPToSI: { |
372 | // If the integer type can hold the max FP value, it is safe to cast |
373 | // directly to that type. Otherwise, we may create poison via overflow |
374 | // that did not exist in the original code. |
375 | Type *InputTy = I->getOperand(i: 0)->getType()->getScalarType(); |
376 | const fltSemantics &Semantics = InputTy->getFltSemantics(); |
377 | uint32_t MinBitWidth = |
378 | APFloatBase::semanticsIntSizeInBits(Semantics, |
379 | I->getOpcode() == Instruction::FPToSI); |
380 | return Ty->getScalarSizeInBits() >= MinBitWidth; |
381 | } |
382 | case Instruction::ShuffleVector: |
383 | return canEvaluateTruncated(V: I->getOperand(i: 0), Ty, IC, CxtI) && |
384 | canEvaluateTruncated(V: I->getOperand(i: 1), Ty, IC, CxtI); |
385 | default: |
386 | // TODO: Can handle more cases here. |
387 | break; |
388 | } |
389 | |
390 | return false; |
391 | } |
392 | |
393 | /// Given a vector that is bitcast to an integer, optionally logically |
394 | /// right-shifted, and truncated, convert it to an extractelement. |
395 | /// Example (big endian): |
396 | /// trunc (lshr (bitcast <4 x i32> %X to i128), 32) to i32 |
397 | /// ---> |
398 | /// extractelement <4 x i32> %X, 1 |
399 | static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, |
400 | InstCombinerImpl &IC) { |
401 | Value *TruncOp = Trunc.getOperand(i_nocapture: 0); |
402 | Type *DestType = Trunc.getType(); |
403 | if (!TruncOp->hasOneUse() || !isa<IntegerType>(Val: DestType)) |
404 | return nullptr; |
405 | |
406 | Value *VecInput = nullptr; |
407 | ConstantInt *ShiftVal = nullptr; |
408 | if (!match(V: TruncOp, P: m_CombineOr(L: m_BitCast(Op: m_Value(V&: VecInput)), |
409 | R: m_LShr(L: m_BitCast(Op: m_Value(V&: VecInput)), |
410 | R: m_ConstantInt(CI&: ShiftVal)))) || |
411 | !isa<VectorType>(Val: VecInput->getType())) |
412 | return nullptr; |
413 | |
414 | VectorType *VecType = cast<VectorType>(Val: VecInput->getType()); |
415 | unsigned VecWidth = VecType->getPrimitiveSizeInBits(); |
416 | unsigned DestWidth = DestType->getPrimitiveSizeInBits(); |
417 | unsigned ShiftAmount = ShiftVal ? ShiftVal->getZExtValue() : 0; |
418 | |
419 | if ((VecWidth % DestWidth != 0) || (ShiftAmount % DestWidth != 0)) |
420 | return nullptr; |
421 | |
422 | // If the element type of the vector doesn't match the result type, |
423 | // bitcast it to a vector type that we can extract from. |
424 | unsigned NumVecElts = VecWidth / DestWidth; |
425 | if (VecType->getElementType() != DestType) { |
426 | VecType = FixedVectorType::get(ElementType: DestType, NumElts: NumVecElts); |
427 | VecInput = IC.Builder.CreateBitCast(V: VecInput, DestTy: VecType, Name: "bc" ); |
428 | } |
429 | |
430 | unsigned Elt = ShiftAmount / DestWidth; |
431 | if (IC.getDataLayout().isBigEndian()) |
432 | Elt = NumVecElts - 1 - Elt; |
433 | |
434 | return ExtractElementInst::Create(Vec: VecInput, Idx: IC.Builder.getInt32(C: Elt)); |
435 | } |
436 | |
437 | /// Funnel/Rotate left/right may occur in a wider type than necessary because of |
438 | /// type promotion rules. Try to narrow the inputs and convert to funnel shift. |
439 | Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) { |
440 | assert((isa<VectorType>(Trunc.getSrcTy()) || |
441 | shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) && |
442 | "Don't narrow to an illegal scalar type" ); |
443 | |
444 | // Bail out on strange types. It is possible to handle some of these patterns |
445 | // even with non-power-of-2 sizes, but it is not a likely scenario. |
446 | Type *DestTy = Trunc.getType(); |
447 | unsigned NarrowWidth = DestTy->getScalarSizeInBits(); |
448 | unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits(); |
449 | if (!isPowerOf2_32(Value: NarrowWidth)) |
450 | return nullptr; |
451 | |
452 | // First, find an or'd pair of opposite shifts: |
453 | // trunc (or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)) |
454 | BinaryOperator *Or0, *Or1; |
455 | if (!match(V: Trunc.getOperand(i_nocapture: 0), P: m_OneUse(SubPattern: m_Or(L: m_BinOp(I&: Or0), R: m_BinOp(I&: Or1))))) |
456 | return nullptr; |
457 | |
458 | Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1; |
459 | if (!match(V: Or0, P: m_OneUse(SubPattern: m_LogicalShift(L: m_Value(V&: ShVal0), R: m_Value(V&: ShAmt0)))) || |
460 | !match(V: Or1, P: m_OneUse(SubPattern: m_LogicalShift(L: m_Value(V&: ShVal1), R: m_Value(V&: ShAmt1)))) || |
461 | Or0->getOpcode() == Or1->getOpcode()) |
462 | return nullptr; |
463 | |
464 | // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)). |
465 | if (Or0->getOpcode() == BinaryOperator::LShr) { |
466 | std::swap(a&: Or0, b&: Or1); |
467 | std::swap(a&: ShVal0, b&: ShVal1); |
468 | std::swap(a&: ShAmt0, b&: ShAmt1); |
469 | } |
470 | assert(Or0->getOpcode() == BinaryOperator::Shl && |
471 | Or1->getOpcode() == BinaryOperator::LShr && |
472 | "Illegal or(shift,shift) pair" ); |
473 | |
474 | // Match the shift amount operands for a funnel/rotate pattern. This always |
475 | // matches a subtraction on the R operand. |
476 | auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * { |
477 | // The shift amounts may add up to the narrow bit width: |
478 | // (shl ShVal0, L) | (lshr ShVal1, Width - L) |
479 | // If this is a funnel shift (different operands are shifted), then the |
480 | // shift amount can not over-shift (create poison) in the narrow type. |
481 | unsigned MaxShiftAmountWidth = Log2_32(Value: NarrowWidth); |
482 | APInt HiBitMask = ~APInt::getLowBitsSet(numBits: WideWidth, loBitsSet: MaxShiftAmountWidth); |
483 | if (ShVal0 == ShVal1 || MaskedValueIsZero(V: L, Mask: HiBitMask)) |
484 | if (match(V: R, P: m_OneUse(SubPattern: m_Sub(L: m_SpecificInt(V: Width), R: m_Specific(V: L))))) |
485 | return L; |
486 | |
487 | // The following patterns currently only work for rotation patterns. |
488 | // TODO: Add more general funnel-shift compatible patterns. |
489 | if (ShVal0 != ShVal1) |
490 | return nullptr; |
491 | |
492 | // The shift amount may be masked with negation: |
493 | // (shl ShVal0, (X & (Width - 1))) | (lshr ShVal1, ((-X) & (Width - 1))) |
494 | Value *X; |
495 | unsigned Mask = Width - 1; |
496 | if (match(V: L, P: m_And(L: m_Value(V&: X), R: m_SpecificInt(V: Mask))) && |
497 | match(V: R, P: m_And(L: m_Neg(V: m_Specific(V: X)), R: m_SpecificInt(V: Mask)))) |
498 | return X; |
499 | |
500 | // Same as above, but the shift amount may be extended after masking: |
501 | if (match(V: L, P: m_ZExt(Op: m_And(L: m_Value(V&: X), R: m_SpecificInt(V: Mask)))) && |
502 | match(V: R, P: m_ZExt(Op: m_And(L: m_Neg(V: m_Specific(V: X)), R: m_SpecificInt(V: Mask))))) |
503 | return X; |
504 | |
505 | return nullptr; |
506 | }; |
507 | |
508 | Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, NarrowWidth); |
509 | bool IsFshl = true; // Sub on LSHR. |
510 | if (!ShAmt) { |
511 | ShAmt = matchShiftAmount(ShAmt1, ShAmt0, NarrowWidth); |
512 | IsFshl = false; // Sub on SHL. |
513 | } |
514 | if (!ShAmt) |
515 | return nullptr; |
516 | |
517 | // The right-shifted value must have high zeros in the wide type (for example |
518 | // from 'zext', 'and' or 'shift'). High bits of the left-shifted value are |
519 | // truncated, so those do not matter. |
520 | APInt HiBitMask = APInt::getHighBitsSet(numBits: WideWidth, hiBitsSet: WideWidth - NarrowWidth); |
521 | if (!MaskedValueIsZero(V: ShVal1, Mask: HiBitMask, Depth: 0, CxtI: &Trunc)) |
522 | return nullptr; |
523 | |
524 | // Adjust the width of ShAmt for narrowed funnel shift operation: |
525 | // - Zero-extend if ShAmt is narrower than the destination type. |
526 | // - Truncate if ShAmt is wider, discarding non-significant high-order bits. |
527 | // This prepares ShAmt for llvm.fshl.i8(trunc(ShVal), trunc(ShVal), |
528 | // zext/trunc(ShAmt)). |
529 | Value *NarrowShAmt = Builder.CreateZExtOrTrunc(V: ShAmt, DestTy); |
530 | |
531 | Value *X, *Y; |
532 | X = Y = Builder.CreateTrunc(V: ShVal0, DestTy); |
533 | if (ShVal0 != ShVal1) |
534 | Y = Builder.CreateTrunc(V: ShVal1, DestTy); |
535 | Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; |
536 | Function *F = Intrinsic::getDeclaration(M: Trunc.getModule(), id: IID, Tys: DestTy); |
537 | return CallInst::Create(Func: F, Args: {X, Y, NarrowShAmt}); |
538 | } |
539 | |
540 | /// Try to narrow the width of math or bitwise logic instructions by pulling a |
541 | /// truncate ahead of binary operators. |
542 | Instruction *InstCombinerImpl::narrowBinOp(TruncInst &Trunc) { |
543 | Type *SrcTy = Trunc.getSrcTy(); |
544 | Type *DestTy = Trunc.getType(); |
545 | unsigned SrcWidth = SrcTy->getScalarSizeInBits(); |
546 | unsigned DestWidth = DestTy->getScalarSizeInBits(); |
547 | |
548 | if (!isa<VectorType>(Val: SrcTy) && !shouldChangeType(From: SrcTy, To: DestTy)) |
549 | return nullptr; |
550 | |
551 | BinaryOperator *BinOp; |
552 | if (!match(V: Trunc.getOperand(i_nocapture: 0), P: m_OneUse(SubPattern: m_BinOp(I&: BinOp)))) |
553 | return nullptr; |
554 | |
555 | Value *BinOp0 = BinOp->getOperand(i_nocapture: 0); |
556 | Value *BinOp1 = BinOp->getOperand(i_nocapture: 1); |
557 | switch (BinOp->getOpcode()) { |
558 | case Instruction::And: |
559 | case Instruction::Or: |
560 | case Instruction::Xor: |
561 | case Instruction::Add: |
562 | case Instruction::Sub: |
563 | case Instruction::Mul: { |
564 | Constant *C; |
565 | if (match(V: BinOp0, P: m_Constant(C))) { |
566 | // trunc (binop C, X) --> binop (trunc C', X) |
567 | Constant *NarrowC = ConstantExpr::getTrunc(C, Ty: DestTy); |
568 | Value *TruncX = Builder.CreateTrunc(V: BinOp1, DestTy); |
569 | return BinaryOperator::Create(Op: BinOp->getOpcode(), S1: NarrowC, S2: TruncX); |
570 | } |
571 | if (match(V: BinOp1, P: m_Constant(C))) { |
572 | // trunc (binop X, C) --> binop (trunc X, C') |
573 | Constant *NarrowC = ConstantExpr::getTrunc(C, Ty: DestTy); |
574 | Value *TruncX = Builder.CreateTrunc(V: BinOp0, DestTy); |
575 | return BinaryOperator::Create(Op: BinOp->getOpcode(), S1: TruncX, S2: NarrowC); |
576 | } |
577 | Value *X; |
578 | if (match(V: BinOp0, P: m_ZExtOrSExt(Op: m_Value(V&: X))) && X->getType() == DestTy) { |
579 | // trunc (binop (ext X), Y) --> binop X, (trunc Y) |
580 | Value *NarrowOp1 = Builder.CreateTrunc(V: BinOp1, DestTy); |
581 | return BinaryOperator::Create(Op: BinOp->getOpcode(), S1: X, S2: NarrowOp1); |
582 | } |
583 | if (match(V: BinOp1, P: m_ZExtOrSExt(Op: m_Value(V&: X))) && X->getType() == DestTy) { |
584 | // trunc (binop Y, (ext X)) --> binop (trunc Y), X |
585 | Value *NarrowOp0 = Builder.CreateTrunc(V: BinOp0, DestTy); |
586 | return BinaryOperator::Create(Op: BinOp->getOpcode(), S1: NarrowOp0, S2: X); |
587 | } |
588 | break; |
589 | } |
590 | case Instruction::LShr: |
591 | case Instruction::AShr: { |
592 | // trunc (*shr (trunc A), C) --> trunc(*shr A, C) |
593 | Value *A; |
594 | Constant *C; |
595 | if (match(V: BinOp0, P: m_Trunc(Op: m_Value(V&: A))) && match(V: BinOp1, P: m_Constant(C))) { |
596 | unsigned MaxShiftAmt = SrcWidth - DestWidth; |
597 | // If the shift is small enough, all zero/sign bits created by the shift |
598 | // are removed by the trunc. |
599 | if (match(V: C, P: m_SpecificInt_ICMP(Predicate: ICmpInst::ICMP_ULE, |
600 | Threshold: APInt(SrcWidth, MaxShiftAmt)))) { |
601 | auto *OldShift = cast<Instruction>(Val: Trunc.getOperand(i_nocapture: 0)); |
602 | bool IsExact = OldShift->isExact(); |
603 | if (Constant *ShAmt = ConstantFoldIntegerCast(C, DestTy: A->getType(), |
604 | /*IsSigned*/ true, DL)) { |
605 | ShAmt = Constant::mergeUndefsWith(C: ShAmt, Other: C); |
606 | Value *Shift = |
607 | OldShift->getOpcode() == Instruction::AShr |
608 | ? Builder.CreateAShr(LHS: A, RHS: ShAmt, Name: OldShift->getName(), isExact: IsExact) |
609 | : Builder.CreateLShr(LHS: A, RHS: ShAmt, Name: OldShift->getName(), isExact: IsExact); |
610 | return CastInst::CreateTruncOrBitCast(S: Shift, Ty: DestTy); |
611 | } |
612 | } |
613 | } |
614 | break; |
615 | } |
616 | default: break; |
617 | } |
618 | |
619 | if (Instruction *NarrowOr = narrowFunnelShift(Trunc)) |
620 | return NarrowOr; |
621 | |
622 | return nullptr; |
623 | } |
624 | |
625 | /// Try to narrow the width of a splat shuffle. This could be generalized to any |
626 | /// shuffle with a constant operand, but we limit the transform to avoid |
627 | /// creating a shuffle type that targets may not be able to lower effectively. |
628 | static Instruction *shrinkSplatShuffle(TruncInst &Trunc, |
629 | InstCombiner::BuilderTy &Builder) { |
630 | auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: Trunc.getOperand(i_nocapture: 0)); |
631 | if (Shuf && Shuf->hasOneUse() && match(V: Shuf->getOperand(i_nocapture: 1), P: m_Undef()) && |
632 | all_equal(Range: Shuf->getShuffleMask()) && |
633 | Shuf->getType() == Shuf->getOperand(i_nocapture: 0)->getType()) { |
634 | // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask |
635 | // trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask |
636 | Value *NarrowOp = Builder.CreateTrunc(V: Shuf->getOperand(i_nocapture: 0), DestTy: Trunc.getType()); |
637 | return new ShuffleVectorInst(NarrowOp, Shuf->getShuffleMask()); |
638 | } |
639 | |
640 | return nullptr; |
641 | } |
642 | |
643 | /// Try to narrow the width of an insert element. This could be generalized for |
644 | /// any vector constant, but we limit the transform to insertion into undef to |
645 | /// avoid potential backend problems from unsupported insertion widths. This |
646 | /// could also be extended to handle the case of inserting a scalar constant |
647 | /// into a vector variable. |
648 | static Instruction *shrinkInsertElt(CastInst &Trunc, |
649 | InstCombiner::BuilderTy &Builder) { |
650 | Instruction::CastOps Opcode = Trunc.getOpcode(); |
651 | assert((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) && |
652 | "Unexpected instruction for shrinking" ); |
653 | |
654 | auto *InsElt = dyn_cast<InsertElementInst>(Val: Trunc.getOperand(i_nocapture: 0)); |
655 | if (!InsElt || !InsElt->hasOneUse()) |
656 | return nullptr; |
657 | |
658 | Type *DestTy = Trunc.getType(); |
659 | Type *DestScalarTy = DestTy->getScalarType(); |
660 | Value *VecOp = InsElt->getOperand(i_nocapture: 0); |
661 | Value *ScalarOp = InsElt->getOperand(i_nocapture: 1); |
662 | Value *Index = InsElt->getOperand(i_nocapture: 2); |
663 | |
664 | if (match(V: VecOp, P: m_Undef())) { |
665 | // trunc (inselt undef, X, Index) --> inselt undef, (trunc X), Index |
666 | // fptrunc (inselt undef, X, Index) --> inselt undef, (fptrunc X), Index |
667 | UndefValue *NarrowUndef = UndefValue::get(T: DestTy); |
668 | Value *NarrowOp = Builder.CreateCast(Op: Opcode, V: ScalarOp, DestTy: DestScalarTy); |
669 | return InsertElementInst::Create(Vec: NarrowUndef, NewElt: NarrowOp, Idx: Index); |
670 | } |
671 | |
672 | return nullptr; |
673 | } |
674 | |
/// Visit a trunc instruction and try the full set of trunc folds: narrowing
/// the whole source expression tree, i1-destination icmp conversions,
/// shuffle/insertelement shrinking, extractelement canonicalization, and
/// finally nuw/nsw flag inference. Returns a replacement instruction, the
/// (modified) trunc itself, or nullptr if nothing changed.
Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
  // First apply the transforms common to every cast opcode.
  if (Instruction *Result = commonCastTransforms(CI&: Trunc))
    return Result;

  Value *Src = Trunc.getOperand(i_nocapture: 0);
  Type *DestTy = Trunc.getType(), *SrcTy = Src->getType();
  unsigned DestWidth = DestTy->getScalarSizeInBits();
  unsigned SrcWidth = SrcTy->getScalarSizeInBits();

  // Attempt to truncate the entire input expression tree to the destination
  // type. Only do this if the dest type is a simple type, don't convert the
  // expression tree to something weird like i93 unless the source is also
  // strange.
  if ((DestTy->isVectorTy() || shouldChangeType(From: SrcTy, To: DestTy)) &&
      canEvaluateTruncated(V: Src, Ty: DestTy, IC&: *this, CxtI: &Trunc)) {

    // If this cast is a truncate, evaluating in a different type always
    // eliminates the cast, so it is always a win.
    LLVM_DEBUG(
        dbgs() << "ICE: EvaluateInDifferentType converting expression type"
                  " to avoid cast: "
               << Trunc << '\n');
    Value *Res = EvaluateInDifferentType(V: Src, Ty: DestTy, isSigned: false);
    assert(Res->getType() == DestTy);
    return replaceInstUsesWith(I&: Trunc, V: Res);
  }

  // For integer types, check if we can shorten the entire input expression to
  // DestWidth * 2, which won't allow removing the truncate, but reducing the
  // width may enable further optimizations, e.g. allowing for larger
  // vectorization factors.
  if (auto *DestITy = dyn_cast<IntegerType>(Val: DestTy)) {
    if (DestWidth * 2 < SrcWidth) {
      auto *NewDestTy = DestITy->getExtendedType();
      if (shouldChangeType(From: SrcTy, To: NewDestTy) &&
          canEvaluateTruncated(V: Src, Ty: NewDestTy, IC&: *this, CxtI: &Trunc)) {
        LLVM_DEBUG(
            dbgs() << "ICE: EvaluateInDifferentType converting expression type"
                      " to reduce the width of operand of"
                   << Trunc << '\n');
        // A (narrower) trunc is still needed, so create it on the shrunk tree.
        Value *Res = EvaluateInDifferentType(V: Src, Ty: NewDestTy, isSigned: false);
        return new TruncInst(Res, DestTy);
      }
    }
  }

  // Test if the trunc is the user of a select which is part of a
  // minimum or maximum operation. If so, don't do any more simplification.
  // Even simplifying demanded bits can break the canonical form of a
  // min/max.
  Value *LHS, *RHS;
  if (SelectInst *Sel = dyn_cast<SelectInst>(Val: Src))
    if (matchSelectPattern(V: Sel, LHS, RHS).Flavor != SPF_UNKNOWN)
      return nullptr;

  // See if we can simplify any instructions used by the input whose sole
  // purpose is to compute bits we don't care about.
  if (SimplifyDemandedInstructionBits(Inst&: Trunc))
    return &Trunc;

  // Truncation to i1 keeps only the low bit, so several shift-based patterns
  // can be expressed as integer compares instead.
  if (DestWidth == 1) {
    Value *Zero = Constant::getNullValue(Ty: SrcTy);

    Value *X;
    const APInt *C1;
    Constant *C2;
    if (match(V: Src, P: m_OneUse(SubPattern: m_Shr(L: m_Shl(L: m_Power2(V&: C1), R: m_Value(V&: X)),
                               R: m_ImmConstant(C&: C2))))) {
      // trunc ((C1 << X) >> C2) to i1 --> X == (C2-cttz(C1)), where C1 is pow2
      Constant *Log2C1 = ConstantInt::get(Ty: SrcTy, V: C1->exactLogBase2());
      Constant *CmpC = ConstantExpr::getSub(C1: C2, C2: Log2C1);
      return new ICmpInst(ICmpInst::ICMP_EQ, X, CmpC);
    }

    Constant *C;
    if (match(V: Src, P: m_OneUse(SubPattern: m_LShr(L: m_Value(V&: X), R: m_Constant(C))))) {
      // trunc (lshr X, C) to i1 --> icmp ne (and X, C'), 0
      Constant *One = ConstantInt::get(Ty: SrcTy, V: APInt(SrcWidth, 1));
      Constant *MaskC = ConstantExpr::getShl(C1: One, C2: C);
      Value *And = Builder.CreateAnd(LHS: X, RHS: MaskC);
      return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
    }
    if (match(V: Src, P: m_OneUse(SubPattern: m_c_Or(L: m_LShr(L: m_Value(V&: X), R: m_ImmConstant(C)),
                                R: m_Deferred(V: X))))) {
      // trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0
      // The mask also includes bit 0 because of the or'd-in X itself.
      Constant *One = ConstantInt::get(Ty: SrcTy, V: APInt(SrcWidth, 1));
      Constant *MaskC = ConstantExpr::getShl(C1: One, C2: C);
      Value *And = Builder.CreateAnd(LHS: X, RHS: Builder.CreateOr(LHS: MaskC, RHS: One));
      return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
    }

    {
      const APInt *C;
      if (match(V: Src, P: m_Shl(L: m_APInt(Res&: C), R: m_Value(V&: X))) && (*C)[0] == 1) {
        // trunc (C << X) to i1 --> X == 0, where C is odd
        // (An odd C has bit 0 set, so the low bit survives only when X == 0.)
        return new ICmpInst(ICmpInst::Predicate::ICMP_EQ, X, Zero);
      }
    }
  }

  // Narrow a logical shift of a sign-extended value into an arithmetic shift
  // on the narrow value, when the shifted-in zeros are discarded by the trunc.
  Value *A, *B;
  Constant *C;
  if (match(V: Src, P: m_LShr(L: m_SExt(Op: m_Value(V&: A)), R: m_Constant(C)))) {
    unsigned AWidth = A->getType()->getScalarSizeInBits();
    unsigned MaxShiftAmt = SrcWidth - std::max(a: DestWidth, b: AWidth);
    auto *OldSh = cast<Instruction>(Val: Src);
    bool IsExact = OldSh->isExact();

    // If the shift is small enough, all zero bits created by the shift are
    // removed by the trunc.
    if (match(V: C, P: m_SpecificInt_ICMP(Predicate: ICmpInst::ICMP_ULE,
                                      Threshold: APInt(SrcWidth, MaxShiftAmt)))) {
      // Clamp the shift amount to Width-1 so the narrow ashr is not poison,
      // then truncate it to the narrow type.
      auto GetNewShAmt = [&](unsigned Width) {
        Constant *MaxAmt = ConstantInt::get(Ty: SrcTy, V: Width - 1, IsSigned: false);
        Constant *Cmp =
            ConstantFoldCompareInstOperands(Predicate: ICmpInst::ICMP_ULT, LHS: C, RHS: MaxAmt, DL);
        Constant *ShAmt = ConstantFoldSelectInstruction(Cond: Cmp, V1: C, V2: MaxAmt);
        return ConstantFoldCastOperand(Opcode: Instruction::Trunc, C: ShAmt, DestTy: A->getType(),
                                       DL);
      };

      // trunc (lshr (sext A), C) --> ashr A, C
      if (A->getType() == DestTy) {
        Constant *ShAmt = GetNewShAmt(DestWidth);
        ShAmt = Constant::mergeUndefsWith(C: ShAmt, Other: C);
        return IsExact ? BinaryOperator::CreateExactAShr(V1: A, V2: ShAmt)
                       : BinaryOperator::CreateAShr(V1: A, V2: ShAmt);
      }
      // The types are mismatched, so create a cast after shifting:
      // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C)
      if (Src->hasOneUse()) {
        Constant *ShAmt = GetNewShAmt(AWidth);
        Value *Shift = Builder.CreateAShr(LHS: A, RHS: ShAmt, Name: "" , isExact: IsExact);
        return CastInst::CreateIntegerCast(S: Shift, Ty: DestTy, isSigned: true);
      }
    }
    // TODO: Mask high bits with 'and'.
  }

  if (Instruction *I = narrowBinOp(Trunc))
    return I;

  if (Instruction *I = shrinkSplatShuffle(Trunc, Builder))
    return I;

  if (Instruction *I = shrinkInsertElt(Trunc, Builder))
    return I;

  if (Src->hasOneUse() &&
      (isa<VectorType>(Val: SrcTy) || shouldChangeType(From: SrcTy, To: DestTy))) {
    // Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the
    // dest type is native and cst < dest size.
    if (match(V: Src, P: m_Shl(L: m_Value(V&: A), R: m_Constant(C))) &&
        !match(V: A, P: m_Shr(L: m_Value(), R: m_Constant()))) {
      // Skip shifts of shift by constants. It undoes a combine in
      // FoldShiftByConstant and is the extend in reg pattern.
      APInt Threshold = APInt(C->getType()->getScalarSizeInBits(), DestWidth);
      if (match(V: C, P: m_SpecificInt_ICMP(Predicate: ICmpInst::ICMP_ULT, Threshold))) {
        Value *NewTrunc = Builder.CreateTrunc(V: A, DestTy, Name: A->getName() + ".tr" );
        return BinaryOperator::Create(Op: Instruction::Shl, S1: NewTrunc,
                                      S2: ConstantExpr::getTrunc(C, Ty: DestTy));
      }
    }
  }

  if (Instruction *I = foldVecTruncToExtElt(Trunc, IC&: *this))
    return I;

  // Whenever an element is extracted from a vector, and then truncated,
  // canonicalize by converting it to a bitcast followed by an
  // extractelement.
  //
  // Example (little endian):
  //   trunc (extractelement <4 x i64> %X, 0) to i32
  //   --->
  //   extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
  Value *VecOp;
  ConstantInt *Cst;
  if (match(V: Src, P: m_OneUse(SubPattern: m_ExtractElt(Val: m_Value(V&: VecOp), Idx: m_ConstantInt(CI&: Cst))))) {
    auto *VecOpTy = cast<VectorType>(Val: VecOp->getType());
    auto VecElts = VecOpTy->getElementCount();

    // A badly fit destination size would result in an invalid cast.
    if (SrcWidth % DestWidth == 0) {
      uint64_t TruncRatio = SrcWidth / DestWidth;
      uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
      uint64_t VecOpIdx = Cst->getZExtValue();
      // On big-endian targets the low bits of the wide element are at the
      // highest-numbered narrow sub-element; on little-endian, the lowest.
      uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1
                                         : VecOpIdx * TruncRatio;
      assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() &&
             "overflow 32-bits" );

      auto *BitCastTo =
          VectorType::get(ElementType: DestTy, NumElements: BitCastNumElts, Scalable: VecElts.isScalable());
      Value *BitCast = Builder.CreateBitCast(V: VecOp, DestTy: BitCastTo);
      return ExtractElementInst::Create(Vec: BitCast, Idx: Builder.getInt32(C: NewIdx));
    }
  }

  // trunc (ctlz_i32(zext(A), B) --> add(ctlz_i16(A, B), C)
  if (match(Src, m_OneUse(m_Intrinsic<Intrinsic::ctlz>(m_ZExt(m_Value(A)),
                                                       m_Value(B))))) {
    unsigned AWidth = A->getType()->getScalarSizeInBits();
    // The narrow ctlz differs from the wide one by exactly the width delta,
    // provided the narrow type can still represent the count.
    if (AWidth == DestWidth && AWidth > Log2_32(Value: SrcWidth)) {
      Value *WidthDiff = ConstantInt::get(Ty: A->getType(), V: SrcWidth - AWidth);
      Value *NarrowCtlz =
          Builder.CreateIntrinsic(Intrinsic::ctlz, {Trunc.getType()}, {A, B});
      return BinaryOperator::CreateAdd(V1: NarrowCtlz, V2: WidthDiff);
    }
  }

  // trunc (vscale) --> vscale, when vscale_range proves the value fits in the
  // destination type.
  if (match(V: Src, P: m_VScale())) {
    if (Trunc.getFunction() &&
        Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
      Attribute Attr =
          Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange);
      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
        if (Log2_32(Value: *MaxVScale) < DestWidth) {
          Value *VScale = Builder.CreateVScale(Scaling: ConstantInt::get(Ty: DestTy, V: 1));
          return replaceInstUsesWith(I&: Trunc, V: VScale);
        }
      }
    }
  }

  // No fold applied; as a last step, infer nuw/nsw flags on the trunc itself
  // from known bits of the source.
  bool Changed = false;
  if (!Trunc.hasNoSignedWrap() &&
      ComputeMaxSignificantBits(Op: Src, /*Depth=*/0, CxtI: &Trunc) <= DestWidth) {
    Trunc.setHasNoSignedWrap(true);
    Changed = true;
  }
  if (!Trunc.hasNoUnsignedWrap() &&
      MaskedValueIsZero(V: Src, Mask: APInt::getBitsSetFrom(numBits: SrcWidth, loBit: DestWidth),
                        /*Depth=*/0, CxtI: &Trunc)) {
    Trunc.setHasNoUnsignedWrap(true);
    Changed = true;
  }

  return Changed ? &Trunc : nullptr;
}
915 | |
/// Transform (zext icmp) into bitwise/integer operations that avoid the icmp,
/// e.g. sign-bit tests become lshr and single-known-bit equality tests become
/// shift/xor sequences. Returns a replacement for \p Zext or nullptr.
Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp,
                                                 ZExtInst &Zext) {
  // If we are just checking for a icmp eq of a single bit and zext'ing it
  // to an integer, then shift the bit to the appropriate place and then
  // cast to integer to avoid the comparison.

  // FIXME: This set of transforms does not check for extra uses and/or creates
  //        an extra instruction (an optional final cast is not included
  //        in the transform comments). We may also want to favor icmp over
  //        shifts in cases of equal instructions because icmp has better
  //        analysis in general (invert the transform).

  const APInt *Op1CV;
  if (match(V: Cmp->getOperand(i_nocapture: 1), P: m_APInt(Res&: Op1CV))) {

    // zext (x <s  0) to i32 --> x>>u31      true if signbit set.
    if (Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isZero()) {
      Value *In = Cmp->getOperand(i_nocapture: 0);
      // Shift the sign bit down into bit 0.
      Value *Sh = ConstantInt::get(Ty: In->getType(),
                                   V: In->getType()->getScalarSizeInBits() - 1);
      In = Builder.CreateLShr(LHS: In, RHS: Sh, Name: In->getName() + ".lobit" );
      if (In->getType() != Zext.getType())
        In = Builder.CreateIntCast(V: In, DestTy: Zext.getType(), isSigned: false /*ZExt*/);

      return replaceInstUsesWith(I&: Zext, V: In);
    }

    // zext (X == 0) to i32 --> X^1      iff X has only the low bit set.
    // zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set.
    // zext (X != 0) to i32 --> X        iff X has only the low bit set.
    // zext (X != 0) to i32 --> X>>1     iff X has only the 2nd bit set.

    if (Op1CV->isZero() && Cmp->isEquality()) {
      // Exactly 1 possible 1? But not the high-bit because that is
      // canonicalized to this form.
      // KnownZeroMask is the set of bits NOT known to be zero; if it is a
      // power of 2, the icmp result is fully determined by that single bit.
      KnownBits Known = computeKnownBits(V: Cmp->getOperand(i_nocapture: 0), Depth: 0, CxtI: &Zext);
      APInt KnownZeroMask(~Known.Zero);
      uint32_t ShAmt = KnownZeroMask.logBase2();
      bool IsExpectShAmt = KnownZeroMask.isPowerOf2() &&
                           (Zext.getType()->getScalarSizeInBits() != ShAmt + 1);
      if (IsExpectShAmt &&
          (Cmp->getOperand(i_nocapture: 0)->getType() == Zext.getType() ||
           Cmp->getPredicate() == ICmpInst::ICMP_NE || ShAmt == 0)) {
        Value *In = Cmp->getOperand(i_nocapture: 0);
        if (ShAmt) {
          // Perform a logical shr by shiftamt.
          // Insert the shift to put the result in the low bit.
          In = Builder.CreateLShr(LHS: In, RHS: ConstantInt::get(Ty: In->getType(), V: ShAmt),
                                  Name: In->getName() + ".lobit" );
        }

        // Toggle the low bit for "X == 0".
        if (Cmp->getPredicate() == ICmpInst::ICMP_EQ)
          In = Builder.CreateXor(LHS: In, RHS: ConstantInt::get(Ty: In->getType(), V: 1));

        if (Zext.getType() == In->getType())
          return replaceInstUsesWith(I&: Zext, V: In);

        Value *IntCast = Builder.CreateIntCast(V: In, DestTy: Zext.getType(), isSigned: false);
        return replaceInstUsesWith(I&: Zext, V: IntCast);
      }
    }
  }

  if (Cmp->isEquality() && Zext.getType() == Cmp->getOperand(i_nocapture: 0)->getType()) {
    // Test if a bit is clear/set using a shifted-one mask:
    // zext (icmp eq (and X, (1 << ShAmt)), 0) --> and (lshr (not X), ShAmt), 1
    // zext (icmp ne (and X, (1 << ShAmt)), 0) --> and (lshr X, ShAmt), 1
    Value *X, *ShAmt;
    if (Cmp->hasOneUse() && match(V: Cmp->getOperand(i_nocapture: 1), P: m_ZeroInt()) &&
        match(V: Cmp->getOperand(i_nocapture: 0),
              P: m_OneUse(SubPattern: m_c_And(L: m_Shl(L: m_One(), R: m_Value(V&: ShAmt)), R: m_Value(V&: X))))) {
      // For the 'eq' form, invert X so the tested bit is flipped before
      // extraction.
      if (Cmp->getPredicate() == ICmpInst::ICMP_EQ)
        X = Builder.CreateNot(V: X);
      Value *Lshr = Builder.CreateLShr(LHS: X, RHS: ShAmt);
      Value *And1 = Builder.CreateAnd(LHS: Lshr, RHS: ConstantInt::get(Ty: X->getType(), V: 1));
      return replaceInstUsesWith(I&: Zext, V: And1);
    }
  }

  return nullptr;
}
998 | |
999 | /// Determine if the specified value can be computed in the specified wider type |
1000 | /// and produce the same low bits. If not, return false. |
1001 | /// |
1002 | /// If this function returns true, it can also return a non-zero number of bits |
1003 | /// (in BitsToClear) which indicates that the value it computes is correct for |
1004 | /// the zero extend, but that the additional BitsToClear bits need to be zero'd |
1005 | /// out. For example, to promote something like: |
1006 | /// |
1007 | /// %B = trunc i64 %A to i32 |
1008 | /// %C = lshr i32 %B, 8 |
1009 | /// %E = zext i32 %C to i64 |
1010 | /// |
1011 | /// CanEvaluateZExtd for the 'lshr' will return true, and BitsToClear will be |
1012 | /// set to 8 to indicate that the promoted value needs to have bits 24-31 |
1013 | /// cleared in addition to bits 32-63. Since an 'and' will be generated to |
1014 | /// clear the top bits anyway, doing this has no extra cost. |
1015 | /// |
1016 | /// This function works on both vectors and scalars. |
static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
                             InstCombinerImpl &IC, Instruction *CxtI) {
  BitsToClear = 0;
  // Constants and values already wide enough need no further analysis.
  if (canAlwaysEvaluateInType(V, Ty))
    return true;
  if (canNotEvaluateInType(V, Ty))
    return false;

  auto *I = cast<Instruction>(Val: V);
  unsigned Tmp;
  switch (I->getOpcode()) {
  case Instruction::ZExt:  // zext(zext(x)) -> zext(x).
  case Instruction::SExt:  // zext(sext(x)) -> sext(x).
  case Instruction::Trunc: // zext(trunc(x)) -> trunc(x) or zext(x)
    return true;
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
    // Both operands must be promotable; BitsToClear tracks the LHS and Tmp
    // tracks the RHS.
    if (!canEvaluateZExtd(V: I->getOperand(i: 0), Ty, BitsToClear, IC, CxtI) ||
        !canEvaluateZExtd(V: I->getOperand(i: 1), Ty, BitsToClear&: Tmp, IC, CxtI))
      return false;
    // These can all be promoted if neither operand has 'bits to clear'.
    if (BitsToClear == 0 && Tmp == 0)
      return true;

    // If the operation is an AND/OR/XOR and the bits to clear are zero in the
    // other side, BitsToClear is ok.
    if (Tmp == 0 && I->isBitwiseLogicOp()) {
      // We use MaskedValueIsZero here for generality, but the case we care
      // about the most is constant RHS.
      unsigned VSize = V->getType()->getScalarSizeInBits();
      if (IC.MaskedValueIsZero(V: I->getOperand(i: 1),
                               Mask: APInt::getHighBitsSet(numBits: VSize, hiBitsSet: BitsToClear),
                               Depth: 0, CxtI)) {
        // If this is an And instruction and all of the BitsToClear are
        // known to be zero we can reset BitsToClear.
        if (I->getOpcode() == Instruction::And)
          BitsToClear = 0;
        return true;
      }
    }

    // Otherwise, we don't know how to analyze this BitsToClear case yet.
    return false;

  case Instruction::Shl: {
    // We can promote shl(x, cst) if we can promote x.  Since shl overwrites the
    // upper bits we can reduce BitsToClear by the shift amount.
    const APInt *Amt;
    if (match(V: I->getOperand(i: 1), P: m_APInt(Res&: Amt))) {
      if (!canEvaluateZExtd(V: I->getOperand(i: 0), Ty, BitsToClear, IC, CxtI))
        return false;
      uint64_t ShiftAmt = Amt->getZExtValue();
      BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0;
      return true;
    }
    // Variable shift amounts are not analyzable.
    return false;
  }
  case Instruction::LShr: {
    // We can promote lshr(x, cst) if we can promote x.  This requires the
    // ultimate 'and' to clear out the high zero bits we're clearing out though.
    const APInt *Amt;
    if (match(V: I->getOperand(i: 1), P: m_APInt(Res&: Amt))) {
      if (!canEvaluateZExtd(V: I->getOperand(i: 0), Ty, BitsToClear, IC, CxtI))
        return false;
      // The shift pulls BitsToClear more unknown bits down; cap at the width.
      BitsToClear += Amt->getZExtValue();
      if (BitsToClear > V->getType()->getScalarSizeInBits())
        BitsToClear = V->getType()->getScalarSizeInBits();
      return true;
    }
    // Cannot promote variable LSHR.
    return false;
  }
  case Instruction::Select:
    // Both arms must agree on BitsToClear since either may be chosen.
    if (!canEvaluateZExtd(V: I->getOperand(i: 1), Ty, BitsToClear&: Tmp, IC, CxtI) ||
        !canEvaluateZExtd(V: I->getOperand(i: 2), Ty, BitsToClear, IC, CxtI) ||
        // TODO: If important, we could handle the case when the BitsToClear are
        // known zero in the disagreeing side.
        Tmp != BitsToClear)
      return false;
    return true;

  case Instruction::PHI: {
    // We can change a phi if we can change all operands.  Note that we never
    // get into trouble with cyclic PHIs here because we only consider
    // instructions with a single use.
    PHINode *PN = cast<PHINode>(Val: I);
    if (!canEvaluateZExtd(V: PN->getIncomingValue(i: 0), Ty, BitsToClear, IC, CxtI))
      return false;
    for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i)
      if (!canEvaluateZExtd(V: PN->getIncomingValue(i), Ty, BitsToClear&: Tmp, IC, CxtI) ||
          // TODO: If important, we could handle the case when the BitsToClear
          // are known zero in the disagreeing input.
          Tmp != BitsToClear)
        return false;
    return true;
  }
  case Instruction::Call:
    // llvm.vscale() can always be executed in larger type, because the
    // value is automatically zero-extended.
    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: I))
      if (II->getIntrinsicID() == Intrinsic::vscale)
        return true;
    return false;
  default:
    // TODO: Can handle more cases here.
    return false;
  }
}
1129 | |
/// Visit a zext instruction and try the zext folds: widening the whole source
/// expression tree, trunc+zext -> and, zext(icmp) transforms, masked-trunc
/// patterns, vscale simplification, and nneg flag inference. Returns a
/// replacement instruction, the (modified) zext itself, or nullptr.
Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
  // If this zero extend is only used by a truncate, let the truncate be
  // eliminated before we try to optimize this zext.
  if (Zext.hasOneUse() && isa<TruncInst>(Val: Zext.user_back()) &&
      !isa<Constant>(Val: Zext.getOperand(i_nocapture: 0)))
    return nullptr;

  // If one of the common conversion will work, do it.
  if (Instruction *Result = commonCastTransforms(CI&: Zext))
    return Result;

  Value *Src = Zext.getOperand(i_nocapture: 0);
  Type *SrcTy = Src->getType(), *DestTy = Zext.getType();

  // zext nneg bool x -> 0
  // (An i1 value that is non-negative must be 0.)
  if (SrcTy->isIntOrIntVectorTy(BitWidth: 1) && Zext.hasNonNeg())
    return replaceInstUsesWith(I&: Zext, V: Constant::getNullValue(Ty: Zext.getType()));

  // Try to extend the entire expression tree to the wide destination type.
  unsigned BitsToClear;
  if (shouldChangeType(From: SrcTy, To: DestTy) &&
      canEvaluateZExtd(V: Src, Ty: DestTy, BitsToClear, IC&: *this, CxtI: &Zext)) {
    assert(BitsToClear <= SrcTy->getScalarSizeInBits() &&
           "Can't clear more bits than in SrcTy" );

    // Okay, we can transform this!  Insert the new expression now.
    LLVM_DEBUG(
        dbgs() << "ICE: EvaluateInDifferentType converting expression type"
                  " to avoid zero extend: "
               << Zext << '\n');
    Value *Res = EvaluateInDifferentType(V: Src, Ty: DestTy, isSigned: false);
    assert(Res->getType() == DestTy);

    // Preserve debug values referring to Src if the zext is its last use.
    if (auto *SrcOp = dyn_cast<Instruction>(Val: Src))
      if (SrcOp->hasOneUse())
        replaceAllDbgUsesWith(From&: *SrcOp, To&: *Res, DomPoint&: Zext, DT);

    uint32_t SrcBitsKept = SrcTy->getScalarSizeInBits() - BitsToClear;
    uint32_t DestBitSize = DestTy->getScalarSizeInBits();

    // If the high bits are already filled with zeros, just replace this
    // cast with the result.
    if (MaskedValueIsZero(V: Res,
                          Mask: APInt::getHighBitsSet(numBits: DestBitSize,
                                                hiBitsSet: DestBitSize - SrcBitsKept),
                          Depth: 0, CxtI: &Zext))
      return replaceInstUsesWith(I&: Zext, V: Res);

    // We need to emit an AND to clear the high bits.
    Constant *C = ConstantInt::get(Ty: Res->getType(),
                               V: APInt::getLowBitsSet(numBits: DestBitSize, loBitsSet: SrcBitsKept));
    return BinaryOperator::CreateAnd(V1: Res, V2: C);
  }

  // If this is a TRUNC followed by a ZEXT then we are dealing with integral
  // types and if the sizes are just right we can convert this into a logical
  // 'and' which will be much cheaper than the pair of casts.
  if (auto *CSrc = dyn_cast<TruncInst>(Val: Src)) {   // A->B->C cast
    // TODO: Subsume this into EvaluateInDifferentType.

    // Get the sizes of the types involved.  We know that the intermediate type
    // will be smaller than A or C, but don't know the relation between A and C.
    Value *A = CSrc->getOperand(i_nocapture: 0);
    unsigned SrcSize = A->getType()->getScalarSizeInBits();
    unsigned MidSize = CSrc->getType()->getScalarSizeInBits();
    unsigned DstSize = DestTy->getScalarSizeInBits();
    // If we're actually extending zero bits, then if
    // SrcSize <  DstSize: zext(a & mask)
    // SrcSize == DstSize: a & mask
    // SrcSize  > DstSize: trunc(a) & mask
    if (SrcSize < DstSize) {
      APInt AndValue(APInt::getLowBitsSet(numBits: SrcSize, loBitsSet: MidSize));
      Constant *AndConst = ConstantInt::get(Ty: A->getType(), V: AndValue);
      Value *And = Builder.CreateAnd(LHS: A, RHS: AndConst, Name: CSrc->getName() + ".mask" );
      return new ZExtInst(And, DestTy);
    }

    if (SrcSize == DstSize) {
      APInt AndValue(APInt::getLowBitsSet(numBits: SrcSize, loBitsSet: MidSize));
      return BinaryOperator::CreateAnd(V1: A, V2: ConstantInt::get(Ty: A->getType(),
                                                           V: AndValue));
    }
    if (SrcSize > DstSize) {
      Value *Trunc = Builder.CreateTrunc(V: A, DestTy);
      APInt AndValue(APInt::getLowBitsSet(numBits: DstSize, loBitsSet: MidSize));
      return BinaryOperator::CreateAnd(V1: Trunc,
                                       V2: ConstantInt::get(Ty: Trunc->getType(),
                                                       V: AndValue));
    }
  }

  if (auto *Cmp = dyn_cast<ICmpInst>(Val: Src))
    return transformZExtICmp(Cmp, Zext);

  // zext(trunc(X) & C) -> (X & zext(C)).
  Constant *C;
  Value *X;
  if (match(V: Src, P: m_OneUse(SubPattern: m_And(L: m_Trunc(Op: m_Value(V&: X)), R: m_Constant(C)))) &&
      X->getType() == DestTy)
    return BinaryOperator::CreateAnd(V1: X, V2: Builder.CreateZExt(V: C, DestTy));

  // zext((trunc(X) & C) ^ C) -> ((X & zext(C)) ^ zext(C)).
  Value *And;
  if (match(V: Src, P: m_OneUse(SubPattern: m_Xor(L: m_Value(V&: And), R: m_Constant(C)))) &&
      match(V: And, P: m_OneUse(SubPattern: m_And(L: m_Trunc(Op: m_Value(V&: X)), R: m_Specific(V: C)))) &&
      X->getType() == DestTy) {
    Value *ZC = Builder.CreateZExt(V: C, DestTy);
    return BinaryOperator::CreateXor(V1: Builder.CreateAnd(LHS: X, RHS: ZC), V2: ZC);
  }

  // If we are truncating, masking, and then zexting back to the original type,
  // that's just a mask. This is not handled by canEvaluateZextd if the
  // intermediate values have extra uses. This could be generalized further for
  // a non-constant mask operand.
  // zext (and (trunc X), C) --> and X, (zext C)
  if (match(V: Src, P: m_And(L: m_Trunc(Op: m_Value(V&: X)), R: m_Constant(C))) &&
      X->getType() == DestTy) {
    Value *ZextC = Builder.CreateZExt(V: C, DestTy);
    return BinaryOperator::CreateAnd(V1: X, V2: ZextC);
  }

  // zext (vscale) --> vscale, when vscale_range proves the value already fits
  // in the source type (the extension adds no information).
  if (match(V: Src, P: m_VScale())) {
    if (Zext.getFunction() &&
        Zext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
      Attribute Attr =
          Zext.getFunction()->getFnAttribute(Attribute::VScaleRange);
      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
        unsigned TypeWidth = Src->getType()->getScalarSizeInBits();
        if (Log2_32(Value: *MaxVScale) < TypeWidth) {
          Value *VScale = Builder.CreateVScale(Scaling: ConstantInt::get(Ty: DestTy, V: 1));
          return replaceInstUsesWith(I&: Zext, V: VScale);
        }
      }
    }
  }

  // Try to annotate the zext as nneg (source known non-negative).
  if (!Zext.hasNonNeg()) {
    // If this zero extend is only used by a shift, add nneg flag.
    if (Zext.hasOneUse() &&
        SrcTy->getScalarSizeInBits() >
            Log2_64_Ceil(Value: DestTy->getScalarSizeInBits()) &&
        match(V: Zext.user_back(), P: m_Shift(L: m_Value(), R: m_Specific(V: &Zext)))) {
      Zext.setNonNeg();
      return &Zext;
    }

    if (isKnownNonNegative(V: Src, SQ: SQ.getWithInstruction(I: &Zext))) {
      Zext.setNonNeg();
      return &Zext;
    }
  }

  return nullptr;
}
1285 | |
/// Transform (sext icmp) to bitwise / integer operations to eliminate the
/// icmp. Returns a replacement for \p Sext or nullptr if no fold applies.
Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp,
                                                 SExtInst &Sext) {
  Value *Op0 = Cmp->getOperand(i_nocapture: 0), *Op1 = Cmp->getOperand(i_nocapture: 1);
  ICmpInst::Predicate Pred = Cmp->getPredicate();

  // Don't bother if Op1 isn't of vector or integer type.
  if (!Op1->getType()->isIntOrIntVectorTy())
    return nullptr;

  if (Pred == ICmpInst::ICMP_SLT && match(V: Op1, P: m_ZeroInt())) {
    // sext (x <s 0) --> ashr x, 31   (all ones if negative)
    // The arithmetic shift broadcasts the sign bit across the whole width.
    Value *Sh = ConstantInt::get(Ty: Op0->getType(),
                                 V: Op0->getType()->getScalarSizeInBits() - 1);
    Value *In = Builder.CreateAShr(LHS: Op0, RHS: Sh, Name: Op0->getName() + ".lobit" );
    if (In->getType() != Sext.getType())
      In = Builder.CreateIntCast(V: In, DestTy: Sext.getType(), isSigned: true /*SExt*/);

    return replaceInstUsesWith(I&: Sext, V: In);
  }

  if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Val: Op1)) {
    // If we know that only one bit of the LHS of the icmp can be set and we
    // have an equality comparison with zero or a power of 2, we can transform
    // the icmp and sext into bitwise/integer operations.
    if (Cmp->hasOneUse() &&
        Cmp->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){
      KnownBits Known = computeKnownBits(V: Op0, Depth: 0, CxtI: &Sext);

      // KnownZeroMask is the set of bits NOT known to be zero; a power of 2
      // means the compare depends on exactly one bit of Op0.
      APInt KnownZeroMask(~Known.Zero);
      if (KnownZeroMask.isPowerOf2()) {
        Value *In = Cmp->getOperand(i_nocapture: 0);

        // If the icmp tests for a known zero bit we can constant fold it.
        if (!Op1C->isZero() && Op1C->getValue() != KnownZeroMask) {
          Value *V = Pred == ICmpInst::ICMP_NE ?
                       ConstantInt::getAllOnesValue(Ty: Sext.getType()) :
                       ConstantInt::getNullValue(Ty: Sext.getType());
          return replaceInstUsesWith(I&: Sext, V);
        }

        if (!Op1C->isZero() == (Pred == ICmpInst::ICMP_NE)) {
          // sext ((x & 2^n) == 0)   -> (x >> n) - 1
          // sext ((x & 2^n) != 2^n) -> (x >> n) - 1
          unsigned ShiftAmt = KnownZeroMask.countr_zero();
          // Perform a right shift to place the desired bit in the LSB.
          if (ShiftAmt)
            In = Builder.CreateLShr(LHS: In,
                                    RHS: ConstantInt::get(Ty: In->getType(), V: ShiftAmt));

          // At this point "In" is either 1 or 0. Subtract 1 to turn
          // {1, 0} -> {0, -1}.
          In = Builder.CreateAdd(LHS: In,
                                 RHS: ConstantInt::getAllOnesValue(Ty: In->getType()),
                                 Name: "sext" );
        } else {
          // sext ((x & 2^n) != 0)   -> (x << bitwidth-n) a>> bitwidth-1
          // sext ((x & 2^n) == 2^n) -> (x << bitwidth-n) a>> bitwidth-1
          unsigned ShiftAmt = KnownZeroMask.countl_zero();
          // Perform a left shift to place the desired bit in the MSB.
          if (ShiftAmt)
            In = Builder.CreateShl(LHS: In,
                                   RHS: ConstantInt::get(Ty: In->getType(), V: ShiftAmt));

          // Distribute the bit over the whole bit width.
          In = Builder.CreateAShr(LHS: In, RHS: ConstantInt::get(Ty: In->getType(),
                                  V: KnownZeroMask.getBitWidth() - 1), Name: "sext" );
        }

        if (Sext.getType() == In->getType())
          return replaceInstUsesWith(I&: Sext, V: In);
        return CastInst::CreateIntegerCast(S: In, Ty: Sext.getType(), isSigned: true/*SExt*/);
      }
    }
  }

  return nullptr;
}
1364 | |
/// Return true if we can take the specified value and return it as type Ty
/// without inserting any new casts and without changing the value of the common
/// low bits. This is used by code that tries to promote integer operations to
/// a wider type, which will allow us to eliminate the extension.
///
/// This function works on both vectors and scalars.
///
static bool canEvaluateSExtd(Value *V, Type *Ty) {
  assert(V->getType()->getScalarSizeInBits() < Ty->getScalarSizeInBits() &&
         "Can't sign extend type to a smaller type" );
  // Trivially evaluable values (e.g. constants) need no further analysis.
  if (canAlwaysEvaluateInType(V, Ty))
    return true;
  // Values the helper rules out entirely cannot be evaluated in Ty.
  if (canNotEvaluateInType(V, Ty))
    return false;

  auto *I = cast<Instruction>(Val: V);
  switch (I->getOpcode()) {
  case Instruction::SExt: // sext(sext(x)) -> sext(x)
  case Instruction::ZExt: // sext(zext(x)) -> zext(x)
  case Instruction::Trunc: // sext(trunc(x)) -> trunc(x) or sext(x)
    return true;
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
    // These operators can all arbitrarily be extended if their inputs can.
    return canEvaluateSExtd(V: I->getOperand(i: 0), Ty) &&
           canEvaluateSExtd(V: I->getOperand(i: 1), Ty);

  //case Instruction::Shl: TODO
  //case Instruction::LShr: TODO

  case Instruction::Select:
    // Both select arms must be evaluable; the condition is left untouched.
    return canEvaluateSExtd(V: I->getOperand(i: 1), Ty) &&
           canEvaluateSExtd(V: I->getOperand(i: 2), Ty);

  case Instruction::PHI: {
    // We can change a phi if we can change all operands. Note that we never
    // get into trouble with cyclic PHIs here because we only consider
    // instructions with a single use.
    PHINode *PN = cast<PHINode>(Val: I);
    for (Value *IncValue : PN->incoming_values())
      if (!canEvaluateSExtd(V: IncValue, Ty)) return false;
    return true;
  }
  default:
    // TODO: Can handle more cases here.
    break;
  }

  return false;
}
1419 | |
Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
  // If this sign extend is only used by a truncate, let the truncate be
  // eliminated before we try to optimize this sext.
  if (Sext.hasOneUse() && isa<TruncInst>(Val: Sext.user_back()))
    return nullptr;

  if (Instruction *I = commonCastTransforms(CI&: Sext))
    return I;

  Value *Src = Sext.getOperand(i_nocapture: 0);
  Type *SrcTy = Src->getType(), *DestTy = Sext.getType();
  unsigned SrcBitSize = SrcTy->getScalarSizeInBits();
  unsigned DestBitSize = DestTy->getScalarSizeInBits();

  // If the value being extended is zero or positive, use a zext instead.
  // The nneg flag records that the zext is equivalent to a sext here.
  if (isKnownNonNegative(V: Src, SQ: SQ.getWithInstruction(I: &Sext))) {
    auto CI = CastInst::Create(Instruction::ZExt, S: Src, Ty: DestTy);
    CI->setNonNeg(true);
    return CI;
  }

  // Try to extend the entire expression tree to the wide destination type.
  if (shouldChangeType(From: SrcTy, To: DestTy) && canEvaluateSExtd(V: Src, Ty: DestTy)) {
    // Okay, we can transform this!  Insert the new expression now.
    LLVM_DEBUG(
        dbgs() << "ICE: EvaluateInDifferentType converting expression type"
                  " to avoid sign extend: "
               << Sext << '\n');
    Value *Res = EvaluateInDifferentType(V: Src, Ty: DestTy, isSigned: true);
    assert(Res->getType() == DestTy);

    // If the high bits are already filled with sign bit, just replace this
    // cast with the result.
    if (ComputeNumSignBits(Op: Res, Depth: 0, CxtI: &Sext) > DestBitSize - SrcBitSize)
      return replaceInstUsesWith(I&: Sext, V: Res);

    // We need to emit a shl + ashr to do the sign extend.
    Value *ShAmt = ConstantInt::get(Ty: DestTy, V: DestBitSize-SrcBitSize);
    return BinaryOperator::CreateAShr(V1: Builder.CreateShl(LHS: Res, RHS: ShAmt, Name: "sext" ),
                                      V2: ShAmt);
  }

  Value *X;
  if (match(V: Src, P: m_Trunc(Op: m_Value(V&: X)))) {
    // If the input has more sign bits than bits truncated, then convert
    // directly to final type.
    unsigned XBitSize = X->getType()->getScalarSizeInBits();
    if (ComputeNumSignBits(Op: X, Depth: 0, CxtI: &Sext) > XBitSize - SrcBitSize)
      return CastInst::CreateIntegerCast(S: X, Ty: DestTy, /* isSigned */ true);

    // If input is a trunc from the destination type, then convert into shifts.
    if (Src->hasOneUse() && X->getType() == DestTy) {
      // sext (trunc X) --> ashr (shl X, C), C
      Constant *ShAmt = ConstantInt::get(Ty: DestTy, V: DestBitSize - SrcBitSize);
      return BinaryOperator::CreateAShr(V1: Builder.CreateShl(LHS: X, RHS: ShAmt), V2: ShAmt);
    }

    // If we are replacing shifted-in high zero bits with sign bits, convert
    // the logic shift to arithmetic shift and eliminate the cast to
    // intermediate type:
    // sext (trunc (lshr Y, C)) --> sext/trunc (ashr Y, C)
    Value *Y;
    if (Src->hasOneUse() &&
        match(V: X, P: m_LShr(L: m_Value(V&: Y),
                          R: m_SpecificIntAllowPoison(V: XBitSize - SrcBitSize)))) {
      Value *Ashr = Builder.CreateAShr(LHS: Y, RHS: XBitSize - SrcBitSize);
      return CastInst::CreateIntegerCast(S: Ashr, Ty: DestTy, /* isSigned */ true);
    }
  }

  // sext of an icmp has dedicated folds (bool splatting etc.).
  if (auto *Cmp = dyn_cast<ICmpInst>(Val: Src))
    return transformSExtICmp(Cmp, Sext);

  // If the input is a shl/ashr pair of a same constant, then this is a sign
  // extension from a smaller value. If we could trust arbitrary bitwidth
  // integers, we could turn this into a truncate to the smaller bit and then
  // use a sext for the whole extension. Since we don't, look deeper and check
  // for a truncate. If the source and dest are the same type, eliminate the
  // trunc and extend and just do shifts. For example, turn:
  //   %a = trunc i32 %i to i8
  //   %b = shl i8 %a, C
  //   %c = ashr i8 %b, C
  //   %d = sext i8 %c to i32
  // into:
  //   %a = shl i32 %i, 32-(8-C)
  //   %d = ashr i32 %a, 32-(8-C)
  Value *A = nullptr;
  // TODO: Eventually this could be subsumed by EvaluateInDifferentType.
  Constant *BA = nullptr, *CA = nullptr;
  if (match(V: Src, P: m_AShr(L: m_Shl(L: m_Trunc(Op: m_Value(V&: A)), R: m_Constant(C&: BA)),
                          R: m_ImmConstant(C&: CA))) &&
      BA->isElementWiseEqual(Y: CA) && A->getType() == DestTy) {
    Constant *WideCurrShAmt =
        ConstantFoldCastOperand(Opcode: Instruction::SExt, C: CA, DestTy, DL);
    assert(WideCurrShAmt && "Constant folding of ImmConstant cannot fail" );
    // Recompute the shift amount in the wide type so the same low bits
    // survive the shl/ashr round trip.
    Constant *NumLowbitsLeft = ConstantExpr::getSub(
        C1: ConstantInt::get(Ty: DestTy, V: SrcTy->getScalarSizeInBits()), C2: WideCurrShAmt);
    Constant *NewShAmt = ConstantExpr::getSub(
        C1: ConstantInt::get(Ty: DestTy, V: DestTy->getScalarSizeInBits()),
        C2: NumLowbitsLeft);
    NewShAmt =
        Constant::mergeUndefsWith(C: Constant::mergeUndefsWith(C: NewShAmt, Other: BA), Other: CA);
    A = Builder.CreateShl(LHS: A, RHS: NewShAmt, Name: Sext.getName());
    return BinaryOperator::CreateAShr(V1: A, V2: NewShAmt);
  }

  // Splatting a bit of constant-index across a value:
  // sext (ashr (trunc iN X to iM), M-1) to iN --> ashr (shl X, N-M), N-1
  // If the dest type is different, use a cast (adjust use check).
  if (match(V: Src, P: m_OneUse(SubPattern: m_AShr(L: m_Trunc(Op: m_Value(V&: X)),
                                         R: m_SpecificInt(V: SrcBitSize - 1))))) {
    Type *XTy = X->getType();
    unsigned XBitSize = XTy->getScalarSizeInBits();
    Constant *ShlAmtC = ConstantInt::get(Ty: XTy, V: XBitSize - SrcBitSize);
    Constant *AshrAmtC = ConstantInt::get(Ty: XTy, V: XBitSize - 1);
    if (XTy == DestTy)
      return BinaryOperator::CreateAShr(V1: Builder.CreateShl(LHS: X, RHS: ShlAmtC),
                                        V2: AshrAmtC);
    if (cast<BinaryOperator>(Val: Src)->getOperand(i_nocapture: 0)->hasOneUse()) {
      Value *Ashr = Builder.CreateAShr(LHS: Builder.CreateShl(LHS: X, RHS: ShlAmtC), RHS: AshrAmtC);
      return CastInst::CreateIntegerCast(S: Ashr, Ty: DestTy, /* isSigned */ true);
    }
  }

  // sext(vscale) can be replaced by a wide vscale when the vscale_range
  // attribute bounds the value below the narrow type's sign bit.
  if (match(V: Src, P: m_VScale())) {
    if (Sext.getFunction() &&
        Sext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
      Attribute Attr =
          Sext.getFunction()->getFnAttribute(Attribute::VScaleRange);
      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
        if (Log2_32(Value: *MaxVScale) < (SrcBitSize - 1)) {
          Value *VScale = Builder.CreateVScale(Scaling: ConstantInt::get(Ty: DestTy, V: 1));
          return replaceInstUsesWith(I&: Sext, V: VScale);
        }
      }
    }
  }

  return nullptr;
}
1560 | |
/// Return true if the specified floating-point constant can be converted to
/// the FP semantics \p Sem without changing its value (i.e. the conversion
/// loses no information).
static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
  bool losesInfo;
  APFloat F = CFP->getValueAPF();
  // Round-trip probe: convert into Sem and check the loss flag.
  (void)F.convert(ToSemantics: Sem, RM: APFloat::rmNearestTiesToEven, losesInfo: &losesInfo);
  return !losesInfo;
}
1569 | |
1570 | static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) { |
1571 | if (CFP->getType() == Type::getPPC_FP128Ty(C&: CFP->getContext())) |
1572 | return nullptr; // No constant folding of this. |
1573 | // See if the value can be truncated to bfloat and then reextended. |
1574 | if (PreferBFloat && fitsInFPType(CFP, Sem: APFloat::BFloat())) |
1575 | return Type::getBFloatTy(C&: CFP->getContext()); |
1576 | // See if the value can be truncated to half and then reextended. |
1577 | if (!PreferBFloat && fitsInFPType(CFP, Sem: APFloat::IEEEhalf())) |
1578 | return Type::getHalfTy(C&: CFP->getContext()); |
1579 | // See if the value can be truncated to float and then reextended. |
1580 | if (fitsInFPType(CFP, Sem: APFloat::IEEEsingle())) |
1581 | return Type::getFloatTy(C&: CFP->getContext()); |
1582 | if (CFP->getType()->isDoubleTy()) |
1583 | return nullptr; // Won't shrink. |
1584 | if (fitsInFPType(CFP, Sem: APFloat::IEEEdouble())) |
1585 | return Type::getDoubleTy(C&: CFP->getContext()); |
1586 | // Don't try to shrink to various long double types. |
1587 | return nullptr; |
1588 | } |
1589 | |
// Determine if this is a vector of ConstantFPs and if so, return the minimal
// type we can safely truncate all elements to.
static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) {
  auto *CV = dyn_cast<Constant>(Val: V);
  auto *CVVTy = dyn_cast<FixedVectorType>(Val: V->getType());
  // Only fixed-width constant vectors are handled; anything else (including
  // scalable vectors) has no per-element constants to inspect.
  if (!CV || !CVVTy)
    return nullptr;

  Type *MinType = nullptr;

  unsigned NumElts = CVVTy->getNumElements();

  // For fixed-width vectors we find the minimal type by looking
  // through the constant values of the vector.
  for (unsigned i = 0; i != NumElts; ++i) {
    // Undef elements place no constraint on the element type.
    if (isa<UndefValue>(Val: CV->getAggregateElement(Elt: i)))
      continue;

    auto *CFP = dyn_cast_or_null<ConstantFP>(Val: CV->getAggregateElement(Elt: i));
    if (!CFP)
      return nullptr;

    Type *T = shrinkFPConstant(CFP, PreferBFloat);
    if (!T)
      return nullptr;

    // If we haven't found a type yet or this type has a larger mantissa than
    // our previous type, this is our new minimal type.
    if (!MinType || T->getFPMantissaWidth() > MinType->getFPMantissaWidth())
      MinType = T;
  }

  // Make a vector type from the minimal type.
  return MinType ? FixedVectorType::get(ElementType: MinType, NumElts) : nullptr;
}
1625 | |
/// Find the minimum FP type we can safely truncate to.
static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
  // A value produced by fpext is exactly representable in the extend's
  // source type.
  if (auto *FPExt = dyn_cast<FPExtInst>(Val: V))
    return FPExt->getOperand(i_nocapture: 0)->getType();

  // If this value is a constant, return the constant in the smallest FP type
  // that can accurately represent it. This allows us to turn
  // (float)((double)X+2.0) into x+2.0f.
  if (auto *CFP = dyn_cast<ConstantFP>(Val: V))
    if (Type *T = shrinkFPConstant(CFP, PreferBFloat))
      return T;

  // We can only correctly find a minimum type for a scalable vector when it is
  // a splat. For splats of constant values the fpext is wrapped up as a
  // ConstantExpr.
  if (auto *FPCExt = dyn_cast<ConstantExpr>(Val: V))
    if (FPCExt->getOpcode() == Instruction::FPExt)
      return FPCExt->getOperand(i_nocapture: 0)->getType();

  // Try to shrink a vector of FP constants. This returns nullptr on scalable
  // vectors
  if (Type *T = shrinkFPConstantVector(V, PreferBFloat))
    return T;

  // Nothing shrinkable found; the value's own type is the minimum.
  return V->getType();
}
1652 | |
/// Return true if the cast from integer to FP can be proven to be exact for all
/// possible inputs (the conversion does not lose any precision).
static bool isKnownExactCastIntToFP(CastInst &I, InstCombinerImpl &IC) {
  CastInst::CastOps Opcode = I.getOpcode();
  assert((Opcode == CastInst::SIToFP || Opcode == CastInst::UIToFP) &&
         "Unexpected cast" );
  Value *Src = I.getOperand(i_nocapture: 0);
  Type *SrcTy = Src->getType();
  Type *FPTy = I.getType();
  bool IsSigned = Opcode == Instruction::SIToFP;
  // For signed sources the sign bit does not contribute magnitude bits.
  int SrcSize = (int)SrcTy->getScalarSizeInBits() - IsSigned;

  // Easy case - if the source integer type has less bits than the FP mantissa,
  // then the cast must be exact.
  int DestNumSigBits = FPTy->getFPMantissaWidth();
  if (SrcSize <= DestNumSigBits)
    return true;

  // Cast from FP to integer and back to FP is independent of the intermediate
  // integer width because of poison on overflow.
  Value *F;
  if (match(V: Src, P: m_FPToSI(Op: m_Value(V&: F))) || match(V: Src, P: m_FPToUI(Op: m_Value(V&: F)))) {
    // If this is uitofp (fptosi F), the source needs an extra bit to avoid
    // potential rounding of negative FP input values.
    int SrcNumSigBits = F->getType()->getFPMantissaWidth();
    if (!IsSigned && match(V: Src, P: m_FPToSI(Op: m_Value())))
      SrcNumSigBits++;

    // [su]itofp (fpto[su]i F) --> exact if the source type has less or equal
    // significant bits than the destination (and make sure neither type is
    // weird -- ppc_fp128).
    if (SrcNumSigBits > 0 && DestNumSigBits > 0 &&
        SrcNumSigBits <= DestNumSigBits)
      return true;
  }

  // TODO:
  // Try harder to find if the source integer type has less significant bits.
  // For example, compute number of sign bits.
  // Bound the significant bit span using the known leading/trailing zeros.
  KnownBits SrcKnown = IC.computeKnownBits(V: Src, Depth: 0, CxtI: &I);
  int SigBits = (int)SrcTy->getScalarSizeInBits() -
                SrcKnown.countMinLeadingZeros() -
                SrcKnown.countMinTrailingZeros();
  if (SigBits <= DestNumSigBits)
    return true;

  return false;
}
1701 | |
Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
  if (Instruction *I = commonCastTransforms(CI&: FPT))
    return I;

  // If we have fptrunc(OpI (fpextend x), (fpextend y)), we would like to
  // simplify this expression to avoid one or more of the trunc/extend
  // operations if we can do so without changing the numerical results.
  //
  // The exact manner in which the widths of the operands interact to limit
  // what we can and cannot do safely varies from operation to operation, and
  // is explained below in the various case statements.
  Type *Ty = FPT.getType();
  auto *BO = dyn_cast<BinaryOperator>(Val: FPT.getOperand(i_nocapture: 0));
  if (BO && BO->hasOneUse()) {
    // Mantissa widths of the narrowest types that can hold each operand, plus
    // the operation and destination widths, drive the safety checks below.
    Type *LHSMinType =
        getMinimumFPType(V: BO->getOperand(i_nocapture: 0), /*PreferBFloat=*/Ty->isBFloatTy());
    Type *RHSMinType =
        getMinimumFPType(V: BO->getOperand(i_nocapture: 1), /*PreferBFloat=*/Ty->isBFloatTy());
    unsigned OpWidth = BO->getType()->getFPMantissaWidth();
    unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
    unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
    unsigned SrcWidth = std::max(a: LHSWidth, b: RHSWidth);
    unsigned DstWidth = Ty->getFPMantissaWidth();
    switch (BO->getOpcode()) {
    default: break;
    case Instruction::FAdd:
    case Instruction::FSub:
      // For addition and subtraction, the infinitely precise result can
      // essentially be arbitrarily wide; proving that double rounding
      // will not occur because the result of OpI is exact (as we will for
      // FMul, for example) is hopeless. However, we *can* nonetheless
      // frequently know that double rounding cannot occur (or that it is
      // innocuous) by taking advantage of the specific structure of
      // infinitely-precise results that admit double rounding.
      //
      // Specifically, if OpWidth >= 2*DstWidth+1 and DstWidth is sufficient
      // to represent both sources, we can guarantee that the double
      // rounding is innocuous (See p50 of Figueroa's 2000 PhD thesis,
      // "A Rigorous Framework for Fully Supporting the IEEE Standard ..."
      // for proof of this fact).
      //
      // Note: Figueroa does not consider the case where DstFormat !=
      // SrcFormat. It's possible (likely even!) that this analysis
      // could be tightened for those cases, but they are rare (the main
      // case of interest here is (float)((double)float + float)).
      if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) {
        Value *LHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 0), DestTy: Ty);
        Value *RHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 1), DestTy: Ty);
        Instruction *RI = BinaryOperator::Create(Op: BO->getOpcode(), S1: LHS, S2: RHS);
        RI->copyFastMathFlags(I: BO);
        return RI;
      }
      break;
    case Instruction::FMul:
      // For multiplication, the infinitely precise result has at most
      // LHSWidth + RHSWidth significant bits; if OpWidth is sufficient
      // that such a value can be exactly represented, then no double
      // rounding can possibly occur; we can safely perform the operation
      // in the destination format if it can represent both sources.
      if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) {
        Value *LHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 0), DestTy: Ty);
        Value *RHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 1), DestTy: Ty);
        return BinaryOperator::CreateFMulFMF(V1: LHS, V2: RHS, FMFSource: BO);
      }
      break;
    case Instruction::FDiv:
      // For division, we again use the bound from Figueroa's
      // dissertation. I am entirely certain that this bound can be
      // tightened in the unbalanced operand case by an analysis based on
      // the diophantine rational approximation bound, but the well-known
      // condition used here is a good conservative first pass.
      // TODO: Tighten bound via rigorous analysis of the unbalanced case.
      if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) {
        Value *LHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 0), DestTy: Ty);
        Value *RHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 1), DestTy: Ty);
        return BinaryOperator::CreateFDivFMF(V1: LHS, V2: RHS, FMFSource: BO);
      }
      break;
    case Instruction::FRem: {
      // Remainder is straightforward. Remainder is always exact, so the
      // type of OpI doesn't enter into things at all. We simply evaluate
      // in whichever source type is larger, then convert to the
      // destination type.
      if (SrcWidth == OpWidth)
        break;
      Value *LHS, *RHS;
      if (LHSWidth == SrcWidth) {
        LHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 0), DestTy: LHSMinType);
        RHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 1), DestTy: LHSMinType);
      } else {
        LHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 0), DestTy: RHSMinType);
        RHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 1), DestTy: RHSMinType);
      }

      Value *ExactResult = Builder.CreateFRemFMF(L: LHS, R: RHS, FMFSource: BO);
      return CastInst::CreateFPCast(S: ExactResult, Ty);
    }
    }
  }

  // (fptrunc (fneg x)) -> (fneg (fptrunc x))
  Value *X;
  Instruction *Op = dyn_cast<Instruction>(Val: FPT.getOperand(i_nocapture: 0));
  if (Op && Op->hasOneUse()) {
    // FIXME: The FMF should propagate from the fptrunc, not the source op.
    IRBuilder<>::FastMathFlagGuard FMFG(Builder);
    if (isa<FPMathOperator>(Val: Op))
      Builder.setFastMathFlags(Op->getFastMathFlags());

    if (match(V: Op, P: m_FNeg(X: m_Value(V&: X)))) {
      Value *InnerTrunc = Builder.CreateFPTrunc(V: X, DestTy: Ty);

      return UnaryOperator::CreateFNegFMF(Op: InnerTrunc, FMFSource: Op);
    }

    // If we are truncating a select that has an extended operand, we can
    // narrow the other operand and do the select as a narrow op.
    Value *Cond, *X, *Y; // NB: this X deliberately shadows the one above.
    if (match(V: Op, P: m_Select(C: m_Value(V&: Cond), L: m_FPExt(Op: m_Value(V&: X)), R: m_Value(V&: Y))) &&
        X->getType() == Ty) {
      // fptrunc (select Cond, (fpext X), Y) --> select Cond, X, (fptrunc Y)
      Value *NarrowY = Builder.CreateFPTrunc(V: Y, DestTy: Ty);
      Value *Sel = Builder.CreateSelect(C: Cond, True: X, False: NarrowY, Name: "narrow.sel" , MDFrom: Op);
      return replaceInstUsesWith(I&: FPT, V: Sel);
    }
    if (match(V: Op, P: m_Select(C: m_Value(V&: Cond), L: m_Value(V&: Y), R: m_FPExt(Op: m_Value(V&: X)))) &&
        X->getType() == Ty) {
      // fptrunc (select Cond, Y, (fpext X)) --> select Cond, (fptrunc Y), X
      Value *NarrowY = Builder.CreateFPTrunc(V: Y, DestTy: Ty);
      Value *Sel = Builder.CreateSelect(C: Cond, True: NarrowY, False: X, Name: "narrow.sel" , MDFrom: Op);
      return replaceInstUsesWith(I&: FPT, V: Sel);
    }
  }

  if (auto *II = dyn_cast<IntrinsicInst>(Val: FPT.getOperand(i_nocapture: 0))) {
    switch (II->getIntrinsicID()) {
    default: break;
    case Intrinsic::ceil:
    case Intrinsic::fabs:
    case Intrinsic::floor:
    case Intrinsic::nearbyint:
    case Intrinsic::rint:
    case Intrinsic::round:
    case Intrinsic::roundeven:
    case Intrinsic::trunc: {
      Value *Src = II->getArgOperand(i: 0);
      if (!Src->hasOneUse())
        break;

      // Except for fabs, this transformation requires the input of the unary FP
      // operation to be itself an fpext from the type to which we're
      // truncating.
      if (II->getIntrinsicID() != Intrinsic::fabs) {
        FPExtInst *FPExtSrc = dyn_cast<FPExtInst>(Val: Src);
        if (!FPExtSrc || FPExtSrc->getSrcTy() != Ty)
          break;
      }

      // Do unary FP operation on smaller type.
      // (fptrunc (fabs x)) -> (fabs (fptrunc x))
      Value *InnerTrunc = Builder.CreateFPTrunc(V: Src, DestTy: Ty);
      Function *Overload = Intrinsic::getDeclaration(M: FPT.getModule(),
                                                     id: II->getIntrinsicID(), Tys: Ty);
      // Preserve any operand bundles and fast-math flags from the original
      // intrinsic call.
      SmallVector<OperandBundleDef, 1> OpBundles;
      II->getOperandBundlesAsDefs(Defs&: OpBundles);
      CallInst *NewCI =
          CallInst::Create(Func: Overload, Args: {InnerTrunc}, Bundles: OpBundles, NameStr: II->getName());
      NewCI->copyFastMathFlags(I: II);
      return NewCI;
    }
    }
  }

  if (Instruction *I = shrinkInsertElt(Trunc&: FPT, Builder))
    return I;

  // fptrunc ([su]itofp X) --> [su]itofp X when the int->FP cast is exact.
  Value *Src = FPT.getOperand(i_nocapture: 0);
  if (isa<SIToFPInst>(Val: Src) || isa<UIToFPInst>(Val: Src)) {
    auto *FPCast = cast<CastInst>(Val: Src);
    if (isKnownExactCastIntToFP(I&: *FPCast, IC&: *this))
      return CastInst::Create(FPCast->getOpcode(), S: FPCast->getOperand(i_nocapture: 0), Ty);
  }

  return nullptr;
}
1887 | |
1888 | Instruction *InstCombinerImpl::visitFPExt(CastInst &FPExt) { |
1889 | // If the source operand is a cast from integer to FP and known exact, then |
1890 | // cast the integer operand directly to the destination type. |
1891 | Type *Ty = FPExt.getType(); |
1892 | Value *Src = FPExt.getOperand(i_nocapture: 0); |
1893 | if (isa<SIToFPInst>(Val: Src) || isa<UIToFPInst>(Val: Src)) { |
1894 | auto *FPCast = cast<CastInst>(Val: Src); |
1895 | if (isKnownExactCastIntToFP(I&: *FPCast, IC&: *this)) |
1896 | return CastInst::Create(FPCast->getOpcode(), S: FPCast->getOperand(i_nocapture: 0), Ty); |
1897 | } |
1898 | |
1899 | return commonCastTransforms(CI&: FPExt); |
1900 | } |
1901 | |
/// fpto{s/u}i({u/s}itofp(X)) --> X or zext(X) or sext(X) or trunc(X)
/// This is safe if the intermediate type has enough bits in its mantissa to
/// accurately represent all values of X. For example, this won't work with
/// i64 -> float -> i64.
Instruction *InstCombinerImpl::foldItoFPtoI(CastInst &FI) {
  // Only the int -> FP -> int round-trip pattern is handled here.
  if (!isa<UIToFPInst>(Val: FI.getOperand(i_nocapture: 0)) && !isa<SIToFPInst>(Val: FI.getOperand(i_nocapture: 0)))
    return nullptr;

  auto *OpI = cast<CastInst>(Val: FI.getOperand(i_nocapture: 0));
  Value *X = OpI->getOperand(i_nocapture: 0);
  Type *XType = X->getType();
  Type *DestType = FI.getType();
  bool IsOutputSigned = isa<FPToSIInst>(Val: FI);

  // Since we can assume the conversion won't overflow, our decision as to
  // whether the input will fit in the float should depend on the minimum
  // of the input range and output range.

  // This means this is also safe for a signed input and unsigned output, since
  // a negative input would lead to undefined behavior.
  if (!isKnownExactCastIntToFP(I&: *OpI, IC&: *this)) {
    // The first cast may not round exactly based on the source integer width
    // and FP width, but the overflow UB rules can still allow this to fold.
    // If the destination type is narrow, that means the intermediate FP value
    // must be large enough to hold the source value exactly.
    // For example, (uint8_t)((float)(uint32_t 16777217) is undefined behavior.
    int OutputSize = (int)DestType->getScalarSizeInBits();
    if (OutputSize > OpI->getType()->getFPMantissaWidth())
      return nullptr;
  }

  // Pick the integer cast that matches the width relationship and signedness.
  if (DestType->getScalarSizeInBits() > XType->getScalarSizeInBits()) {
    bool IsInputSigned = isa<SIToFPInst>(Val: OpI);
    if (IsInputSigned && IsOutputSigned)
      return new SExtInst(X, DestType);
    return new ZExtInst(X, DestType);
  }
  if (DestType->getScalarSizeInBits() < XType->getScalarSizeInBits())
    return new TruncInst(X, DestType);

  assert(XType == DestType && "Unexpected types for int to FP to int casts" );
  return replaceInstUsesWith(I&: FI, V: X);
}
1945 | |
1946 | static Instruction *foldFPtoI(Instruction &FI, InstCombiner &IC) { |
1947 | // fpto{u/s}i non-norm --> 0 |
1948 | FPClassTest Mask = |
1949 | FI.getOpcode() == Instruction::FPToUI ? fcPosNormal : fcNormal; |
1950 | KnownFPClass FPClass = |
1951 | computeKnownFPClass(V: FI.getOperand(i: 0), InterestedClasses: Mask, /*Depth=*/0, |
1952 | SQ: IC.getSimplifyQuery().getWithInstruction(I: &FI)); |
1953 | if (FPClass.isKnownNever(Mask)) |
1954 | return IC.replaceInstUsesWith(I&: FI, V: ConstantInt::getNullValue(Ty: FI.getType())); |
1955 | |
1956 | return nullptr; |
1957 | } |
1958 | |
Instruction *InstCombinerImpl::visitFPToUI(FPToUIInst &FI) {
  // First try to eliminate an int -> FP -> int round trip.
  if (Instruction *I = foldItoFPtoI(FI))
    return I;

  // Then try folding to zero based on the operand's known FP class.
  if (Instruction *I = foldFPtoI(FI, IC&: *this))
    return I;

  return commonCastTransforms(CI&: FI);
}
1968 | |
Instruction *InstCombinerImpl::visitFPToSI(FPToSIInst &FI) {
  // First try to eliminate an int -> FP -> int round trip.
  if (Instruction *I = foldItoFPtoI(FI))
    return I;

  // Then try folding to zero based on the operand's known FP class.
  if (Instruction *I = foldFPtoI(FI, IC&: *this))
    return I;

  return commonCastTransforms(CI&: FI);
}
1978 | |
1979 | Instruction *InstCombinerImpl::visitUIToFP(CastInst &CI) { |
1980 | if (Instruction *R = commonCastTransforms(CI)) |
1981 | return R; |
1982 | if (!CI.hasNonNeg() && isKnownNonNegative(V: CI.getOperand(i_nocapture: 0), SQ)) { |
1983 | CI.setNonNeg(); |
1984 | return &CI; |
1985 | } |
1986 | return nullptr; |
1987 | } |
1988 | |
1989 | Instruction *InstCombinerImpl::visitSIToFP(CastInst &CI) { |
1990 | if (Instruction *R = commonCastTransforms(CI)) |
1991 | return R; |
1992 | if (isKnownNonNegative(V: CI.getOperand(i_nocapture: 0), SQ)) { |
1993 | auto *UI = |
1994 | CastInst::Create(Instruction::UIToFP, S: CI.getOperand(i_nocapture: 0), Ty: CI.getType()); |
1995 | UI->setNonNeg(true); |
1996 | return UI; |
1997 | } |
1998 | return nullptr; |
1999 | } |
2000 | |
Instruction *InstCombinerImpl::visitIntToPtr(IntToPtrInst &CI) {
  // If the source integer type is not the intptr_t type for this target, do a
  // trunc or zext to the intptr_t type, then inttoptr of it. This allows the
  // cast to be exposed to other transforms.
  unsigned AS = CI.getAddressSpace();
  if (CI.getOperand(i_nocapture: 0)->getType()->getScalarSizeInBits() !=
      DL.getPointerSizeInBits(AS)) {
    // getWithNewType preserves any vector shape while swapping in the
    // pointer-sized integer element type for this address space.
    Type *Ty = CI.getOperand(i_nocapture: 0)->getType()->getWithNewType(
        EltTy: DL.getIntPtrType(C&: CI.getContext(), AddressSpace: AS));
    Value *P = Builder.CreateZExtOrTrunc(V: CI.getOperand(i_nocapture: 0), DestTy: Ty);
    return new IntToPtrInst(P, CI.getType());
  }

  if (Instruction *I = commonCastTransforms(CI))
    return I;

  return nullptr;
}
2019 | |
Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) {
  // If the destination integer type is not the intptr_t type for this target,
  // do a ptrtoint to intptr_t then do a trunc or zext. This allows the cast
  // to be exposed to other transforms.
  Value *SrcOp = CI.getPointerOperand();
  Type *SrcTy = SrcOp->getType();
  Type *Ty = CI.getType();
  unsigned AS = CI.getPointerAddressSpace();
  unsigned TySize = Ty->getScalarSizeInBits();
  unsigned PtrSize = DL.getPointerSizeInBits(AS);
  if (TySize != PtrSize) {
    // getWithNewType preserves a vector-of-pointers shape while swapping in
    // the pointer-sized integer element type.
    Type *IntPtrTy =
        SrcTy->getWithNewType(EltTy: DL.getIntPtrType(C&: CI.getContext(), AddressSpace: AS));
    Value *P = Builder.CreatePtrToInt(V: SrcOp, DestTy: IntPtrTy);
    // Resizing the pointer-sized value is a plain zext/trunc (unsigned cast).
    return CastInst::CreateIntegerCast(S: P, Ty, /*isSigned=*/false);
  }

  // (ptrtoint (ptrmask P, M))
  // -> (and (ptrtoint P), M)
  // This is generally beneficial as `and` is better supported than `ptrmask`.
  Value *Ptr, *Mask;
  if (match(SrcOp, m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(Ptr),
                                                            m_Value(Mask)))) &&
      Mask->getType() == Ty)
    return BinaryOperator::CreateAnd(V1: Builder.CreatePtrToInt(V: Ptr, DestTy: Ty), V2: Mask);

  if (auto *GEP = dyn_cast<GetElementPtrInst>(Val: SrcOp)) {
    // Fold ptrtoint(gep null, x) to multiply + constant if the GEP has one use.
    // While this can increase the number of instructions it doesn't actually
    // increase the overall complexity since the arithmetic is just part of
    // the GEP otherwise.
    if (GEP->hasOneUse() &&
        isa<ConstantPointerNull>(Val: GEP->getPointerOperand())) {
      // With a null base, the pointer value *is* the GEP offset expression.
      return replaceInstUsesWith(I&: CI,
                                 V: Builder.CreateIntCast(V: EmitGEPOffset(GEP), DestTy: Ty,
                                                        /*isSigned=*/false));
    }
  }

  Value *Vec, *Scalar, *Index;
  if (match(V: SrcOp, P: m_OneUse(SubPattern: m_InsertElt(Val: m_IntToPtr(Op: m_Value(V&: Vec)),
                                       Elt: m_Value(V&: Scalar), Idx: m_Value(V&: Index)))) &&
      Vec->getType() == Ty) {
    assert(Vec->getType()->getScalarSizeInBits() == PtrSize && "Wrong type" );
    // Convert the scalar to int followed by insert to eliminate one cast:
    // p2i (ins (i2p Vec), Scalar, Index --> ins Vec, (p2i Scalar), Index
    Value *NewCast = Builder.CreatePtrToInt(V: Scalar, DestTy: Ty->getScalarType());
    return InsertElementInst::Create(Vec, NewElt: NewCast, Idx: Index);
  }

  return commonCastTransforms(CI);
}
2072 | |
2073 | /// This input value (which is known to have vector type) is being zero extended |
2074 | /// or truncated to the specified vector type. Since the zext/trunc is done |
2075 | /// using an integer type, we have a (bitcast(cast(bitcast))) pattern, |
2076 | /// endianness will impact which end of the vector that is extended or |
2077 | /// truncated. |
2078 | /// |
2079 | /// A vector is always stored with index 0 at the lowest address, which |
2080 | /// corresponds to the most significant bits for a big endian stored integer and |
2081 | /// the least significant bits for little endian. A trunc/zext of an integer |
2082 | /// impacts the big end of the integer. Thus, we need to add/remove elements at |
2083 | /// the front of the vector for big endian targets, and the back of the vector |
2084 | /// for little endian targets. |
2085 | /// |
2086 | /// Try to replace it with a shuffle (and vector/vector bitcast) if possible. |
2087 | /// |
2088 | /// The source and destination vector types may have different element types. |
static Instruction *
optimizeVectorResizeWithIntegerBitCasts(Value *InVal, VectorType *DestTy,
                                        InstCombinerImpl &IC) {
  // We can only do this optimization if the output is a multiple of the input
  // element size, or the input is a multiple of the output element size.
  // Convert the input type to have the same element type as the output.
  VectorType *SrcTy = cast<VectorType>(Val: InVal->getType());

  if (SrcTy->getElementType() != DestTy->getElementType()) {
    // The input types don't need to be identical, but for now they must be the
    // same size. There is no specific reason we couldn't handle things like
    // <4 x i16> -> <4 x i32> by bitcasting to <2 x i32> but haven't gotten
    // there yet.
    if (SrcTy->getElementType()->getPrimitiveSizeInBits() !=
        DestTy->getElementType()->getPrimitiveSizeInBits())
      return nullptr;

    // Re-express the source as a vector of the destination's element type
    // (same element count and total width) so the shuffle below type-checks.
    SrcTy =
        FixedVectorType::get(ElementType: DestTy->getElementType(),
                             NumElts: cast<FixedVectorType>(Val: SrcTy)->getNumElements());
    InVal = IC.Builder.CreateBitCast(V: InVal, DestTy: SrcTy);
  }

  bool IsBigEndian = IC.getDataLayout().isBigEndian();
  unsigned SrcElts = cast<FixedVectorType>(Val: SrcTy)->getNumElements();
  unsigned DestElts = cast<FixedVectorType>(Val: DestTy)->getNumElements();

  assert(SrcElts != DestElts && "Element counts should be different." );

  // Now that the element types match, get the shuffle mask and RHS of the
  // shuffle to use, which depends on whether we're increasing or decreasing the
  // size of the input.
  // Start from the identity mask [0, 1, ..., SrcElts-1] and trim or pad it.
  auto ShuffleMaskStorage = llvm::to_vector<16>(Range: llvm::seq<int>(Begin: 0, End: SrcElts));
  ArrayRef<int> ShuffleMask;
  Value *V2;

  if (SrcElts > DestElts) {
    // If we're shrinking the number of elements (rewriting an integer
    // truncate), just shuffle in the elements corresponding to the least
    // significant bits from the input and use poison as the second shuffle
    // input.
    V2 = PoisonValue::get(T: SrcTy);
    // Make sure the shuffle mask selects the "least significant bits" by
    // keeping elements from back of the src vector for big endian, and from the
    // front for little endian.
    ShuffleMask = ShuffleMaskStorage;
    if (IsBigEndian)
      ShuffleMask = ShuffleMask.take_back(N: DestElts);
    else
      ShuffleMask = ShuffleMask.take_front(N: DestElts);
  } else {
    // If we're increasing the number of elements (rewriting an integer zext),
    // shuffle in all of the elements from InVal. Fill the rest of the result
    // elements with zeros from a constant zero.
    V2 = Constant::getNullValue(Ty: SrcTy);
    // Use first elt from V2 when indicating zero in the shuffle mask.
    uint32_t NullElt = SrcElts;
    // Extend with null values in the "most significant bits" by adding elements
    // in front of the src vector for big endian, and at the back for little
    // endian.
    unsigned DeltaElts = DestElts - SrcElts;
    if (IsBigEndian)
      ShuffleMaskStorage.insert(I: ShuffleMaskStorage.begin(), NumToInsert: DeltaElts, Elt: NullElt);
    else
      ShuffleMaskStorage.append(NumInputs: DeltaElts, Elt: NullElt);
    ShuffleMask = ShuffleMaskStorage;
  }

  return new ShuffleVectorInst(InVal, V2, ShuffleMask);
}
2159 | |
2160 | static bool isMultipleOfTypeSize(unsigned Value, Type *Ty) { |
2161 | return Value % Ty->getPrimitiveSizeInBits() == 0; |
2162 | } |
2163 | |
2164 | static unsigned getTypeSizeIndex(unsigned Value, Type *Ty) { |
2165 | return Value / Ty->getPrimitiveSizeInBits(); |
2166 | } |
2167 | |
2168 | /// V is a value which is inserted into a vector of VecEltTy. |
2169 | /// Look through the value to see if we can decompose it into |
2170 | /// insertions into the vector. See the example in the comment for |
2171 | /// OptimizeIntegerToVectorInsertions for the pattern this handles. |
2172 | /// The type of V is always a non-zero multiple of VecEltTy's size. |
2173 | /// Shift is the number of bits between the lsb of V and the lsb of |
2174 | /// the vector. |
2175 | /// |
2176 | /// This returns false if the pattern can't be matched or true if it can, |
2177 | /// filling in Elements with the elements found here. |
static bool collectInsertionElements(Value *V, unsigned Shift,
                                     SmallVectorImpl<Value *> &Elements,
                                     Type *VecEltTy, bool isBigEndian) {
  assert(isMultipleOfTypeSize(Shift, VecEltTy) &&
         "Shift should be a multiple of the element type size" );

  // Undef values never contribute useful bits to the result.
  if (isa<UndefValue>(Val: V)) return true;

  // If we got down to a value of the right type, we win, try inserting into the
  // right element.
  if (V->getType() == VecEltTy) {
    // Inserting null doesn't actually insert any elements.
    if (Constant *C = dyn_cast<Constant>(Val: V))
      if (C->isNullValue())
        return true;

    // Shift (in bits) divided by the element size picks the slot; big endian
    // stores element 0 in the most significant bits, so mirror the index.
    unsigned ElementIndex = getTypeSizeIndex(Value: Shift, Ty: VecEltTy);
    if (isBigEndian)
      ElementIndex = Elements.size() - ElementIndex - 1;

    // Fail if multiple elements are inserted into this slot.
    if (Elements[ElementIndex])
      return false;

    Elements[ElementIndex] = V;
    return true;
  }

  if (Constant *C = dyn_cast<Constant>(Val: V)) {
    // Figure out the # elements this provides, and bitcast it or slice it up
    // as required.
    unsigned NumElts = getTypeSizeIndex(Value: C->getType()->getPrimitiveSizeInBits(),
                                        Ty: VecEltTy);
    // If the constant is the size of a vector element, we just need to bitcast
    // it to the right type so it gets properly inserted.
    if (NumElts == 1)
      return collectInsertionElements(V: ConstantExpr::getBitCast(C, Ty: VecEltTy),
                                      Shift, Elements, VecEltTy, isBigEndian);

    // Okay, this is a constant that covers multiple elements. Slice it up into
    // pieces and insert each element-sized piece into the vector.
    if (!isa<IntegerType>(Val: C->getType()))
      C = ConstantExpr::getBitCast(C, Ty: IntegerType::get(C&: V->getContext(),
                                       NumBits: C->getType()->getPrimitiveSizeInBits()));
    unsigned ElementSize = VecEltTy->getPrimitiveSizeInBits();
    Type *ElementIntTy = IntegerType::get(C&: C->getContext(), NumBits: ElementSize);

    for (unsigned i = 0; i != NumElts; ++i) {
      // Extract bits [ShiftI, ShiftI + ElementSize) by lshr + trunc; give up
      // if the shift does not constant-fold to a Constant.
      unsigned ShiftI = i * ElementSize;
      Constant *Piece = ConstantFoldBinaryInstruction(
          Opcode: Instruction::LShr, V1: C, V2: ConstantInt::get(Ty: C->getType(), V: ShiftI));
      if (!Piece)
        return false;

      Piece = ConstantExpr::getTrunc(C: Piece, Ty: ElementIntTy);
      if (!collectInsertionElements(V: Piece, Shift: ShiftI + Shift, Elements, VecEltTy,
                                    isBigEndian))
        return false;
    }
    return true;
  }

  // Non-constant values must be single-use instructions we can look through;
  // otherwise ripping the pattern apart would duplicate work.
  if (!V->hasOneUse()) return false;

  Instruction *I = dyn_cast<Instruction>(Val: V);
  if (!I) return false;
  switch (I->getOpcode()) {
  default: return false; // Unhandled case.
  case Instruction::BitCast:
    // Vector sources are not decomposed here; bail out.
    if (I->getOperand(i: 0)->getType()->isVectorTy())
      return false;
    return collectInsertionElements(V: I->getOperand(i: 0), Shift, Elements, VecEltTy,
                                    isBigEndian);
  case Instruction::ZExt:
    // A zext contributes only the low bits; they must cover whole elements.
    if (!isMultipleOfTypeSize(
            Value: I->getOperand(i: 0)->getType()->getPrimitiveSizeInBits(),
            Ty: VecEltTy))
      return false;
    return collectInsertionElements(V: I->getOperand(i: 0), Shift, Elements, VecEltTy,
                                    isBigEndian);
  case Instruction::Or:
    // An 'or' merges two contributions; both sides must decompose cleanly
    // (overlapping slots are rejected by the per-slot check above).
    return collectInsertionElements(V: I->getOperand(i: 0), Shift, Elements, VecEltTy,
                                    isBigEndian) &&
           collectInsertionElements(V: I->getOperand(i: 1), Shift, Elements, VecEltTy,
                                    isBigEndian);
  case Instruction::Shl: {
    // Must be shifting by a constant that is a multiple of the element size.
    ConstantInt *CI = dyn_cast<ConstantInt>(Val: I->getOperand(i: 1));
    if (!CI) return false;
    Shift += CI->getZExtValue();
    if (!isMultipleOfTypeSize(Value: Shift, Ty: VecEltTy)) return false;
    return collectInsertionElements(V: I->getOperand(i: 0), Shift, Elements, VecEltTy,
                                    isBigEndian);
  }

  }
}
2276 | |
2277 | |
2278 | /// If the input is an 'or' instruction, we may be doing shifts and ors to |
2279 | /// assemble the elements of the vector manually. |
2280 | /// Try to rip the code out and replace it with insertelements. This is to |
2281 | /// optimize code like this: |
2282 | /// |
2283 | /// %tmp37 = bitcast float %inc to i32 |
2284 | /// %tmp38 = zext i32 %tmp37 to i64 |
2285 | /// %tmp31 = bitcast float %inc5 to i32 |
2286 | /// %tmp32 = zext i32 %tmp31 to i64 |
2287 | /// %tmp33 = shl i64 %tmp32, 32 |
2288 | /// %ins35 = or i64 %tmp33, %tmp38 |
2289 | /// %tmp43 = bitcast i64 %ins35 to <2 x float> |
2290 | /// |
2291 | /// Into two insertelements that do "buildvector{%inc, %inc5}". |
static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI,
                                                InstCombinerImpl &IC) {
  auto *DestVecTy = cast<FixedVectorType>(Val: CI.getType());
  Value *IntInput = CI.getOperand(i_nocapture: 0);

  // One slot per destination element; a null entry means "element is zero".
  SmallVector<Value*, 8> Elements(DestVecTy->getNumElements());
  if (!collectInsertionElements(V: IntInput, Shift: 0, Elements,
                                VecEltTy: DestVecTy->getElementType(),
                                isBigEndian: IC.getDataLayout().isBigEndian()))
    return nullptr;

  // If we succeeded, we know that all of the element are specified by Elements
  // or are zero if Elements has a null entry. Recast this as a set of
  // insertions.
  Value *Result = Constant::getNullValue(Ty: CI.getType());
  for (unsigned i = 0, e = Elements.size(); i != e; ++i) {
    if (!Elements[i]) continue; // Unset element.

    Result = IC.Builder.CreateInsertElement(Vec: Result, NewElt: Elements[i],
                                            Idx: IC.Builder.getInt32(C: i));
  }

  return Result;
}
2316 | |
2317 | /// Canonicalize scalar bitcasts of extracted elements into a bitcast of the |
2318 | /// vector followed by extract element. The backend tends to handle bitcasts of |
2319 | /// vectors better than bitcasts of scalars because vector registers are |
2320 | /// usually not type-specific like scalar integer or scalar floating-point. |
static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast,
                                              InstCombinerImpl &IC) {
  // Match bitcast (extractelement VecOp, Index) where the extract has a
  // single use, so the transform does not duplicate the extract.
  Value *VecOp, *Index;
  if (!match(V: BitCast.getOperand(i_nocapture: 0),
             P: m_OneUse(SubPattern: m_ExtractElt(Val: m_Value(V&: VecOp), Idx: m_Value(V&: Index)))))
    return nullptr;

  // The bitcast must be to a vectorizable type, otherwise we can't make a new
  // type to extract from.
  Type *DestType = BitCast.getType();
  VectorType *VecType = cast<VectorType>(Val: VecOp->getType());
  if (VectorType::isValidElementType(ElemTy: DestType)) {
    // bitcast (extractelement V, I) --> extractelement (bitcast V), I
    auto *NewVecType = VectorType::get(ElementType: DestType, Other: VecType);
    auto *NewBC = IC.Builder.CreateBitCast(V: VecOp, DestTy: NewVecType, Name: "bc" );
    return ExtractElementInst::Create(Vec: NewBC, Idx: Index);
  }

  // Only solve DestType is vector to avoid inverse transform in visitBitCast.
  // bitcast (extractelement <1 x elt>, dest) -> bitcast(<1 x elt>, dest)
  auto *FixedVType = dyn_cast<FixedVectorType>(Val: VecType);
  if (DestType->isVectorTy() && FixedVType && FixedVType->getNumElements() == 1)
    return CastInst::Create(Instruction::BitCast, S: VecOp, Ty: DestType);

  return nullptr;
}
2346 | |
2347 | /// Change the type of a bitwise logic operation if we can eliminate a bitcast. |
static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast,
                                            InstCombiner::BuilderTy &Builder) {
  Type *DestTy = BitCast.getType();
  BinaryOperator *BO;

  // Only look through a single-use and/or/xor feeding the bitcast.
  if (!match(V: BitCast.getOperand(i_nocapture: 0), P: m_OneUse(SubPattern: m_BinOp(I&: BO))) ||
      !BO->isBitwiseLogicOp())
    return nullptr;

  // FIXME: This transform is restricted to vector types to avoid backend
  // problems caused by creating potentially illegal operations. If a fix-up is
  // added to handle that situation, we can remove this check.
  if (!DestTy->isVectorTy() || !BO->getType()->isVectorTy())
    return nullptr;

  if (DestTy->isFPOrFPVectorTy()) {
    Value *X, *Y;
    // bitcast(logic(bitcast(X), bitcast(Y))) -> bitcast'(logic(bitcast'(X), Y))
    if (match(V: BO->getOperand(i_nocapture: 0), P: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: X)))) &&
        match(V: BO->getOperand(i_nocapture: 1), P: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: Y))))) {
      if (X->getType()->isFPOrFPVectorTy() &&
          Y->getType()->isIntOrIntVectorTy()) {
        // Perform the logic in Y's integer type, then cast the result back.
        Value *CastedOp =
            Builder.CreateBitCast(V: BO->getOperand(i_nocapture: 0), DestTy: Y->getType());
        Value *NewBO = Builder.CreateBinOp(Opc: BO->getOpcode(), LHS: CastedOp, RHS: Y);
        return CastInst::CreateBitOrPointerCast(S: NewBO, Ty: DestTy);
      }
      if (X->getType()->isIntOrIntVectorTy() &&
          Y->getType()->isFPOrFPVectorTy()) {
        // Symmetric case: perform the logic in X's integer type.
        Value *CastedOp =
            Builder.CreateBitCast(V: BO->getOperand(i_nocapture: 1), DestTy: X->getType());
        Value *NewBO = Builder.CreateBinOp(Opc: BO->getOpcode(), LHS: CastedOp, RHS: X);
        return CastInst::CreateBitOrPointerCast(S: NewBO, Ty: DestTy);
      }
    }
    return nullptr;
  }

  if (!DestTy->isIntOrIntVectorTy())
    return nullptr;

  Value *X;
  if (match(V: BO->getOperand(i_nocapture: 0), P: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: X)))) &&
      X->getType() == DestTy && !isa<Constant>(Val: X)) {
    // bitcast(logic(bitcast(X), Y)) --> logic'(X, bitcast(Y))
    Value *CastedOp1 = Builder.CreateBitCast(V: BO->getOperand(i_nocapture: 1), DestTy);
    return BinaryOperator::Create(Op: BO->getOpcode(), S1: X, S2: CastedOp1);
  }

  if (match(V: BO->getOperand(i_nocapture: 1), P: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: X)))) &&
      X->getType() == DestTy && !isa<Constant>(Val: X)) {
    // bitcast(logic(Y, bitcast(X))) --> logic'(bitcast(Y), X)
    Value *CastedOp0 = Builder.CreateBitCast(V: BO->getOperand(i_nocapture: 0), DestTy);
    return BinaryOperator::Create(Op: BO->getOpcode(), S1: CastedOp0, S2: X);
  }

  // Canonicalize vector bitcasts to come before vector bitwise logic with a
  // constant. This eases recognition of special constants for later ops.
  // Example:
  // icmp u/s (a ^ signmask), (b ^ signmask) --> icmp s/u a, b
  Constant *C;
  if (match(V: BO->getOperand(i_nocapture: 1), P: m_Constant(C))) {
    // bitcast (logic X, C) --> logic (bitcast X, C')
    Value *CastedOp0 = Builder.CreateBitCast(V: BO->getOperand(i_nocapture: 0), DestTy);
    Value *CastedC = Builder.CreateBitCast(V: C, DestTy);
    return BinaryOperator::Create(Op: BO->getOpcode(), S1: CastedOp0, S2: CastedC);
  }

  return nullptr;
}
2418 | |
2419 | /// Change the type of a select if we can eliminate a bitcast. |
static Instruction *foldBitCastSelect(BitCastInst &BitCast,
                                      InstCombiner::BuilderTy &Builder) {
  // Match bitcast (select Cond, TVal, FVal) with a single-use select.
  Value *Cond, *TVal, *FVal;
  if (!match(V: BitCast.getOperand(i_nocapture: 0),
             P: m_OneUse(SubPattern: m_Select(C: m_Value(V&: Cond), L: m_Value(V&: TVal), R: m_Value(V&: FVal)))))
    return nullptr;

  // A vector select must maintain the same number of elements in its operands.
  Type *CondTy = Cond->getType();
  Type *DestTy = BitCast.getType();
  if (auto *CondVTy = dyn_cast<VectorType>(Val: CondTy))
    if (!DestTy->isVectorTy() ||
        CondVTy->getElementCount() !=
            cast<VectorType>(Val: DestTy)->getElementCount())
      return nullptr;

  // FIXME: This transform is restricted from changing the select between
  // scalars and vectors to avoid backend problems caused by creating
  // potentially illegal operations. If a fix-up is added to handle that
  // situation, we can remove this check.
  if (DestTy->isVectorTy() != TVal->getType()->isVectorTy())
    return nullptr;

  // The new select copies metadata (MDFrom) from the original select.
  auto *Sel = cast<Instruction>(Val: BitCast.getOperand(i_nocapture: 0));
  Value *X;
  if (match(V: TVal, P: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: X)))) && X->getType() == DestTy &&
      !isa<Constant>(Val: X)) {
    // bitcast(select(Cond, bitcast(X), Y)) --> select'(Cond, X, bitcast(Y))
    Value *CastedVal = Builder.CreateBitCast(V: FVal, DestTy);
    return SelectInst::Create(C: Cond, S1: X, S2: CastedVal, NameStr: "" , InsertBefore: nullptr, MDFrom: Sel);
  }

  if (match(V: FVal, P: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: X)))) && X->getType() == DestTy &&
      !isa<Constant>(Val: X)) {
    // bitcast(select(Cond, Y, bitcast(X))) --> select'(Cond, bitcast(Y), X)
    Value *CastedVal = Builder.CreateBitCast(V: TVal, DestTy);
    return SelectInst::Create(C: Cond, S1: CastedVal, S2: X, NameStr: "" , InsertBefore: nullptr, MDFrom: Sel);
  }

  return nullptr;
}
2461 | |
2462 | /// Check if all users of CI are StoreInsts. |
2463 | static bool hasStoreUsersOnly(CastInst &CI) { |
2464 | for (User *U : CI.users()) { |
2465 | if (!isa<StoreInst>(Val: U)) |
2466 | return false; |
2467 | } |
2468 | return true; |
2469 | } |
2470 | |
2471 | /// This function handles following case |
2472 | /// |
2473 | /// A -> B cast |
2474 | /// PHI |
2475 | /// B -> A cast |
2476 | /// |
2477 | /// All the related PHI nodes can be replaced by new PHI nodes with type A. |
2478 | /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN. |
Instruction *InstCombinerImpl::optimizeBitCastFromPhi(CastInst &CI,
                                                      PHINode *PN) {
  // BitCast used by Store can be handled in InstCombineLoadStoreAlloca.cpp.
  if (hasStoreUsersOnly(CI))
    return nullptr;

  Value *Src = CI.getOperand(i_nocapture: 0);
  Type *SrcTy = Src->getType();   // Type B
  Type *DestTy = CI.getType();    // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes, before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(Elt: PN);
  OldPhiNodes.insert(X: PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (Value *IncValue : OldPN->incoming_values()) {
      // Constants can always be re-cast to type A (see below).
      if (isa<Constant>(Val: IncValue))
        continue;

      if (auto *LI = dyn_cast<LoadInst>(Val: IncValue)) {
        // If there is a sequence of one or more load instructions, each loaded
        // value is used as address of later load instruction, bitcast is
        // necessary to change the value type, don't optimize it. For
        // simplicity we give up if the load address comes from another load.
        Value *Addr = LI->getOperand(i_nocapture: 0);
        if (Addr == &CI || isa<LoadInst>(Val: Addr))
          return nullptr;
        // Don't transform "load <256 x i32>, <256 x i32>*" to
        // "load x86_amx, x86_amx*", because x86_amx* is invalid.
        // TODO: Remove this check when bitcast between vector and x86_amx
        // is replaced with a specific intrinsic.
        if (DestTy->isX86_AMXTy())
          return nullptr;
        if (LI->hasOneUse() && LI->isSimple())
          continue;
        // If a LoadInst has more than one use, changing the type of loaded
        // value may create another bitcast.
        return nullptr;
      }

      if (auto *PNode = dyn_cast<PHINode>(Val: IncValue)) {
        // New member of the PHI web: queue it for inspection.
        if (OldPhiNodes.insert(X: PNode))
          PhiWorklist.push_back(Elt: PNode);
        continue;
      }

      auto *BCI = dyn_cast<BitCastInst>(Val: IncValue);
      // We can't handle other instructions.
      if (!BCI)
        return nullptr;

      // Verify it's a A->B cast.
      Type *TyA = BCI->getOperand(i_nocapture: 0)->getType();
      Type *TyB = BCI->getType();
      if (TyA != DestTy || TyB != SrcTy)
        return nullptr;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      if (auto *SI = dyn_cast<StoreInst>(Val: V)) {
        // Only simple stores of the PHI value itself can be retyped.
        if (!SI->isSimple() || SI->getOperand(i_nocapture: 0) != OldPN)
          return nullptr;
      } else if (auto *BCI = dyn_cast<BitCastInst>(Val: V)) {
        // Verify it's a B->A cast.
        Type *TyB = BCI->getOperand(i_nocapture: 0)->getType();
        Type *TyA = BCI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return nullptr;
      } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        if (!OldPhiNodes.contains(key: PHI))
          return nullptr;
      } else {
        return nullptr;
      }
    }
  }

  // For each old PHI node, create a corresponding new PHI node with a type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(Ty: DestTy, NumReservedValues: OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(i_nocapture: j);
      Value *NewV = nullptr;
      if (auto *C = dyn_cast<Constant>(Val: V)) {
        NewV = ConstantExpr::getBitCast(C, Ty: DestTy);
      } else if (auto *LI = dyn_cast<LoadInst>(Val: V)) {
        // Explicitly perform load combine to make sure no opposing transform
        // can remove the bitcast in the meantime and trigger an infinite loop.
        Builder.SetInsertPoint(LI);
        NewV = combineLoadToNewType(LI&: *LI, NewTy: DestTy);
        // Remove the old load and its use in the old phi, which itself becomes
        // dead once the whole transform finishes.
        replaceInstUsesWith(I&: *LI, V: PoisonValue::get(T: LI->getType()));
        eraseInstFromFunction(I&: *LI);
      } else if (auto *BCI = dyn_cast<BitCastInst>(Val: V)) {
        // An A->B cast: the new PHI can use its A-typed source directly.
        NewV = BCI->getOperand(i_nocapture: 0);
      } else if (auto *PrevPN = dyn_cast<PHINode>(Val: V)) {
        NewV = NewPNodes[PrevPN];
      }
      assert(NewV);
      NewPN->addIncoming(V: NewV, BB: OldPN->getIncomingBlock(i: j));
    }
  }

  // Traverse all accumulated PHI nodes and process its users,
  // which are Stores and BitCasts. Without this processing
  // NewPHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.


  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  Instruction *RetVal = nullptr;
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(Range: OldPN->users())) {
      if (auto *SI = dyn_cast<StoreInst>(Val: V)) {
        assert(SI->isSimple() && SI->getOperand(0) == OldPN);
        Builder.SetInsertPoint(SI);
        // Store through a fresh A->B cast of the new PHI; the cast's only
        // users are stores, so the load/store combiner can finish the job.
        auto *NewBC =
          cast<BitCastInst>(Val: Builder.CreateBitCast(V: NewPN, DestTy: SrcTy));
        SI->setOperand(i_nocapture: 0, Val_nocapture: NewBC);
        Worklist.push(I: SI);
        assert(hasStoreUsersOnly(*NewBC));
      }
      else if (auto *BCI = dyn_cast<BitCastInst>(Val: V)) {
        Type *TyB = BCI->getOperand(i_nocapture: 0)->getType();
        Type *TyA = BCI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void) TyA;
        (void) TyB;
        Instruction *I = replaceInstUsesWith(I&: *BCI, V: NewPN);
        // Report the replacement of the cast we were asked about so the
        // caller erases it.
        if (BCI == &CI)
          RetVal = I;
      } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) {
        assert(OldPhiNodes.contains(PHI));
        (void) PHI;
      } else {
        llvm_unreachable("all uses should be handled" );
      }
    }
  }

  return RetVal;
}
2646 | |
2647 | Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { |
2648 | // If the operands are integer typed then apply the integer transforms, |
2649 | // otherwise just apply the common ones. |
2650 | Value *Src = CI.getOperand(i_nocapture: 0); |
2651 | Type *SrcTy = Src->getType(); |
2652 | Type *DestTy = CI.getType(); |
2653 | |
2654 | // Get rid of casts from one type to the same type. These are useless and can |
2655 | // be replaced by the operand. |
2656 | if (DestTy == Src->getType()) |
2657 | return replaceInstUsesWith(I&: CI, V: Src); |
2658 | |
2659 | if (FixedVectorType *DestVTy = dyn_cast<FixedVectorType>(Val: DestTy)) { |
2660 | // Beware: messing with this target-specific oddity may cause trouble. |
2661 | if (DestVTy->getNumElements() == 1 && SrcTy->isX86_MMXTy()) { |
2662 | Value *Elem = Builder.CreateBitCast(V: Src, DestTy: DestVTy->getElementType()); |
2663 | return InsertElementInst::Create(Vec: PoisonValue::get(T: DestTy), NewElt: Elem, |
2664 | Idx: Constant::getNullValue(Ty: Type::getInt32Ty(C&: CI.getContext()))); |
2665 | } |
2666 | |
2667 | if (isa<IntegerType>(Val: SrcTy)) { |
2668 | // If this is a cast from an integer to vector, check to see if the input |
2669 | // is a trunc or zext of a bitcast from vector. If so, we can replace all |
2670 | // the casts with a shuffle and (potentially) a bitcast. |
2671 | if (isa<TruncInst>(Val: Src) || isa<ZExtInst>(Val: Src)) { |
2672 | CastInst *SrcCast = cast<CastInst>(Val: Src); |
2673 | if (BitCastInst *BCIn = dyn_cast<BitCastInst>(Val: SrcCast->getOperand(i_nocapture: 0))) |
2674 | if (isa<VectorType>(Val: BCIn->getOperand(i_nocapture: 0)->getType())) |
2675 | if (Instruction *I = optimizeVectorResizeWithIntegerBitCasts( |
2676 | InVal: BCIn->getOperand(i_nocapture: 0), DestTy: cast<VectorType>(Val: DestTy), IC&: *this)) |
2677 | return I; |
2678 | } |
2679 | |
2680 | // If the input is an 'or' instruction, we may be doing shifts and ors to |
2681 | // assemble the elements of the vector manually. Try to rip the code out |
2682 | // and replace it with insertelements. |
2683 | if (Value *V = optimizeIntegerToVectorInsertions(CI, IC&: *this)) |
2684 | return replaceInstUsesWith(I&: CI, V); |
2685 | } |
2686 | } |
2687 | |
2688 | if (FixedVectorType *SrcVTy = dyn_cast<FixedVectorType>(Val: SrcTy)) { |
2689 | if (SrcVTy->getNumElements() == 1) { |
2690 | // If our destination is not a vector, then make this a straight |
2691 | // scalar-scalar cast. |
2692 | if (!DestTy->isVectorTy()) { |
2693 | Value *Elem = |
2694 | Builder.CreateExtractElement(Vec: Src, |
2695 | Idx: Constant::getNullValue(Ty: Type::getInt32Ty(C&: CI.getContext()))); |
2696 | return CastInst::Create(Instruction::BitCast, S: Elem, Ty: DestTy); |
2697 | } |
2698 | |
2699 | // Otherwise, see if our source is an insert. If so, then use the scalar |
2700 | // component directly: |
2701 | // bitcast (inselt <1 x elt> V, X, 0) to <n x m> --> bitcast X to <n x m> |
2702 | if (auto *InsElt = dyn_cast<InsertElementInst>(Val: Src)) |
2703 | return new BitCastInst(InsElt->getOperand(i_nocapture: 1), DestTy); |
2704 | } |
2705 | |
2706 | // Convert an artificial vector insert into more analyzable bitwise logic. |
2707 | unsigned BitWidth = DestTy->getScalarSizeInBits(); |
2708 | Value *X, *Y; |
2709 | uint64_t IndexC; |
2710 | if (match(V: Src, P: m_OneUse(SubPattern: m_InsertElt(Val: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: X))), |
2711 | Elt: m_Value(V&: Y), Idx: m_ConstantInt(V&: IndexC)))) && |
2712 | DestTy->isIntegerTy() && X->getType() == DestTy && |
2713 | Y->getType()->isIntegerTy() && isDesirableIntType(BitWidth)) { |
2714 | // Adjust for big endian - the LSBs are at the high index. |
2715 | if (DL.isBigEndian()) |
2716 | IndexC = SrcVTy->getNumElements() - 1 - IndexC; |
2717 | |
2718 | // We only handle (endian-normalized) insert to index 0. Any other insert |
2719 | // would require a left-shift, so that is an extra instruction. |
2720 | if (IndexC == 0) { |
2721 | // bitcast (inselt (bitcast X), Y, 0) --> or (and X, MaskC), (zext Y) |
2722 | unsigned EltWidth = Y->getType()->getScalarSizeInBits(); |
2723 | APInt MaskC = APInt::getHighBitsSet(numBits: BitWidth, hiBitsSet: BitWidth - EltWidth); |
2724 | Value *AndX = Builder.CreateAnd(LHS: X, RHS: MaskC); |
2725 | Value *ZextY = Builder.CreateZExt(V: Y, DestTy); |
2726 | return BinaryOperator::CreateOr(V1: AndX, V2: ZextY); |
2727 | } |
2728 | } |
2729 | } |
2730 | |
2731 | if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: Src)) { |
2732 | // Okay, we have (bitcast (shuffle ..)). Check to see if this is |
2733 | // a bitcast to a vector with the same # elts. |
2734 | Value *ShufOp0 = Shuf->getOperand(i_nocapture: 0); |
2735 | Value *ShufOp1 = Shuf->getOperand(i_nocapture: 1); |
2736 | auto ShufElts = cast<VectorType>(Val: Shuf->getType())->getElementCount(); |
2737 | auto SrcVecElts = cast<VectorType>(Val: ShufOp0->getType())->getElementCount(); |
2738 | if (Shuf->hasOneUse() && DestTy->isVectorTy() && |
2739 | cast<VectorType>(Val: DestTy)->getElementCount() == ShufElts && |
2740 | ShufElts == SrcVecElts) { |
2741 | BitCastInst *Tmp; |
2742 | // If either of the operands is a cast from CI.getType(), then |
2743 | // evaluating the shuffle in the casted destination's type will allow |
2744 | // us to eliminate at least one cast. |
2745 | if (((Tmp = dyn_cast<BitCastInst>(Val: ShufOp0)) && |
2746 | Tmp->getOperand(i_nocapture: 0)->getType() == DestTy) || |
2747 | ((Tmp = dyn_cast<BitCastInst>(Val: ShufOp1)) && |
2748 | Tmp->getOperand(i_nocapture: 0)->getType() == DestTy)) { |
2749 | Value *LHS = Builder.CreateBitCast(V: ShufOp0, DestTy); |
2750 | Value *RHS = Builder.CreateBitCast(V: ShufOp1, DestTy); |
2751 | // Return a new shuffle vector. Use the same element ID's, as we |
2752 | // know the vector types match #elts. |
2753 | return new ShuffleVectorInst(LHS, RHS, Shuf->getShuffleMask()); |
2754 | } |
2755 | } |
2756 | |
2757 | // A bitcasted-to-scalar and byte/bit reversing shuffle is better recognized |
2758 | // as a byte/bit swap: |
2759 | // bitcast <N x i8> (shuf X, undef, <N, N-1,...0>) -> bswap (bitcast X) |
2760 | // bitcast <N x i1> (shuf X, undef, <N, N-1,...0>) -> bitreverse (bitcast X) |
2761 | if (DestTy->isIntegerTy() && ShufElts.getKnownMinValue() % 2 == 0 && |
2762 | Shuf->hasOneUse() && Shuf->isReverse()) { |
2763 | unsigned IntrinsicNum = 0; |
2764 | if (DL.isLegalInteger(Width: DestTy->getScalarSizeInBits()) && |
2765 | SrcTy->getScalarSizeInBits() == 8) { |
2766 | IntrinsicNum = Intrinsic::bswap; |
2767 | } else if (SrcTy->getScalarSizeInBits() == 1) { |
2768 | IntrinsicNum = Intrinsic::bitreverse; |
2769 | } |
2770 | if (IntrinsicNum != 0) { |
2771 | assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask" ); |
2772 | assert(match(ShufOp1, m_Undef()) && "Unexpected shuffle op" ); |
2773 | Function *BswapOrBitreverse = |
2774 | Intrinsic::getDeclaration(M: CI.getModule(), id: IntrinsicNum, Tys: DestTy); |
2775 | Value *ScalarX = Builder.CreateBitCast(V: ShufOp0, DestTy); |
2776 | return CallInst::Create(Func: BswapOrBitreverse, Args: {ScalarX}); |
2777 | } |
2778 | } |
2779 | } |
2780 | |
2781 | // Handle the A->B->A cast, and there is an intervening PHI node. |
2782 | if (PHINode *PN = dyn_cast<PHINode>(Val: Src)) |
2783 | if (Instruction *I = optimizeBitCastFromPhi(CI, PN)) |
2784 | return I; |
2785 | |
2786 | if (Instruction *I = canonicalizeBitCastExtElt(BitCast&: CI, IC&: *this)) |
2787 | return I; |
2788 | |
2789 | if (Instruction *I = foldBitCastBitwiseLogic(BitCast&: CI, Builder)) |
2790 | return I; |
2791 | |
2792 | if (Instruction *I = foldBitCastSelect(BitCast&: CI, Builder)) |
2793 | return I; |
2794 | |
2795 | return commonCastTransforms(CI); |
2796 | } |
2797 | |
// Visit an addrspacecast instruction. No addrspacecast-specific folds are
// performed here; we only apply the transforms shared by all cast opcodes
// (via commonCastTransforms).
Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) {
  return commonCastTransforms(CI);
}
2801 | |