//===-- lib/CodeGen/GlobalISel/CombinerHelper.cpp -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/LowLevelTypeUtils.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegisterBankInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DivisionByConstantInfo.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Target/TargetMachine.h"
#include <cmath>
#include <optional>
#include <tuple>

#define DEBUG_TYPE "gi-combiner"

using namespace llvm;
using namespace MIPatternMatch;

// Option to allow testing of the combiner while no targets know about indexed
// addressing.
static cl::opt<bool>
    ForceLegalIndexing("force-legal-indexing", cl::Hidden, cl::init(false),
                       cl::desc("Force all indexed operations to be "
                                "legal for the GlobalISel combiner"));

CombinerHelper::CombinerHelper(GISelChangeObserver &Observer,
                               MachineIRBuilder &B, bool IsPreLegalize,
                               GISelKnownBits *KB, MachineDominatorTree *MDT,
                               const LegalizerInfo *LI)
    : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), KB(KB),
      MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI),
      RBI(Builder.getMF().getSubtarget().getRegBankInfo()),
      TRI(Builder.getMF().getSubtarget().getRegisterInfo()) {
  (void)this->KB;
}

const TargetLowering &CombinerHelper::getTargetLowering() const {
  return *Builder.getMF().getSubtarget().getTargetLowering();
}

/// \returns The little endian in-memory byte position of byte \p I in a
/// \p ByteWidth bytes wide type.
///
/// E.g. Given a 4-byte type x, x[0] -> byte 0
static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) {
  assert(I < ByteWidth && "I must be in [0, ByteWidth)");
  return I;
}

/// Determines the LogBase2 value for a non-null input value using the
/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
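///
/// For example (illustrative, assuming a 32-bit scalar type): for V = 16,
/// ctlz(16) = 27, so LogBase2(V) = 31 - 27 = 4. For vectors the transform
/// applies per element.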
static Register buildLogBase2(Register V, MachineIRBuilder &MIB) {
  auto &MRI = *MIB.getMRI();
  LLT Ty = MRI.getType(V);
  auto Ctlz = MIB.buildCTLZ(Ty, V);
  auto Base = MIB.buildConstant(Ty, Ty.getScalarSizeInBits() - 1);
  return MIB.buildSub(Ty, Base, Ctlz).getReg(0);
}

/// \returns The big endian in-memory byte position of byte \p I in a
/// \p ByteWidth bytes wide type.
///
/// E.g. Given a 4-byte type x, x[0] -> byte 3
static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) {
  assert(I < ByteWidth && "I must be in [0, ByteWidth)");
  return ByteWidth - I - 1;
}

/// Given a map from byte offsets in memory to indices in a load/store,
/// determine if that map corresponds to a little or big endian byte pattern.
///
/// \param MemOffset2Idx maps memory offsets to indices in the load/store.
/// \param LowestIdx is the lowest index in \p MemOffset2Idx.
///
/// \returns true if the map corresponds to a big endian byte pattern, false if
/// it corresponds to a little endian byte pattern, and std::nullopt otherwise.
///
/// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns
/// are as follows:
///
/// AddrOffset   Little endian    Big endian
/// 0            0                3
/// 1            1                2
/// 2            2                1
/// 3            3                0
static std::optional<bool>
isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx,
            int64_t LowestIdx) {
  // Need at least two byte positions to decide on endianness.
  unsigned Width = MemOffset2Idx.size();
  if (Width < 2)
    return std::nullopt;
  bool BigEndian = true, LittleEndian = true;
  for (unsigned MemOffset = 0; MemOffset < Width; ++MemOffset) {
    auto MemOffsetAndIdx = MemOffset2Idx.find(MemOffset);
    if (MemOffsetAndIdx == MemOffset2Idx.end())
      return std::nullopt;
    const int64_t Idx = MemOffsetAndIdx->second - LowestIdx;
    assert(Idx >= 0 && "Expected non-negative byte offset?");
    LittleEndian &= Idx == littleEndianByteAt(Width, MemOffset);
    BigEndian &= Idx == bigEndianByteAt(Width, MemOffset);
    if (!BigEndian && !LittleEndian)
      return std::nullopt;
  }

  assert((BigEndian != LittleEndian) &&
         "Pattern cannot be both big and little endian!");
  return BigEndian;
}

bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; }

bool CombinerHelper::isLegal(const LegalityQuery &Query) const {
  assert(LI && "Must have LegalizerInfo to query isLegal!");
  return LI->getAction(Query).Action == LegalizeActions::Legal;
}

bool CombinerHelper::isLegalOrBeforeLegalizer(
    const LegalityQuery &Query) const {
  return isPreLegalize() || isLegal(Query);
}

bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const {
  if (!Ty.isVector())
    return isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, {Ty}});
  // Vector constants are represented as a G_BUILD_VECTOR of scalar
  // G_CONSTANTs.
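  // For example (illustrative MIR), a <4 x s32> splat of 1 would be built as:
  //   %c:_(s32) = G_CONSTANT i32 1
  //   %v:_(<4 x s32>) = G_BUILD_VECTOR %c, %c, %c, %c
  // so both opcodes queried below must be legal after legalization.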
  if (isPreLegalize())
    return true;
  LLT EltTy = Ty.getElementType();
  return isLegal({TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) &&
         isLegal({TargetOpcode::G_CONSTANT, {EltTy}});
}

void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg,
                                    Register ToReg) const {
  Observer.changingAllUsesOfReg(MRI, FromReg);

  if (MRI.constrainRegAttrs(ToReg, FromReg))
    MRI.replaceRegWith(FromReg, ToReg);
  else
    Builder.buildCopy(ToReg, FromReg);

  Observer.finishedChangingAllUsesOfReg();
}

void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI,
                                      MachineOperand &FromRegOp,
                                      Register ToReg) const {
  assert(FromRegOp.getParent() && "Expected an operand in an MI");
  Observer.changingInstr(*FromRegOp.getParent());

  FromRegOp.setReg(ToReg);

  Observer.changedInstr(*FromRegOp.getParent());
}

void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI,
                                       unsigned ToOpcode) const {
  Observer.changingInstr(FromMI);

  FromMI.setDesc(Builder.getTII().get(ToOpcode));

  Observer.changedInstr(FromMI);
}

const RegisterBank *CombinerHelper::getRegBank(Register Reg) const {
  return RBI->getRegBank(Reg, MRI, *TRI);
}

void CombinerHelper::setRegBank(Register Reg, const RegisterBank *RegBank) {
  if (RegBank)
    MRI.setRegBank(Reg, *RegBank);
}

bool CombinerHelper::tryCombineCopy(MachineInstr &MI) {
  if (matchCombineCopy(MI)) {
    applyCombineCopy(MI);
    return true;
  }
  return false;
}
bool CombinerHelper::matchCombineCopy(MachineInstr &MI) {
  if (MI.getOpcode() != TargetOpcode::COPY)
    return false;
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  return canReplaceReg(DstReg, SrcReg, MRI);
}
void CombinerHelper::applyCombineCopy(MachineInstr &MI) {
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  MI.eraseFromParent();
  replaceRegWith(MRI, DstReg, SrcReg);
}

bool CombinerHelper::matchCombineConcatVectors(MachineInstr &MI,
                                               SmallVector<Register> &Ops) {
  assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS &&
         "Invalid instruction");
  bool IsUndef = true;
  MachineInstr *Undef = nullptr;

  // Walk over all the operands of concat vectors and check if they are
  // build_vector themselves or undef.
  // Then collect their operands in Ops.
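  //
  // For example (illustrative MIR, with <2 x s32> operands):
  //   %v1:_(<2 x s32>) = G_BUILD_VECTOR %a, %b
  //   %v2:_(<2 x s32>) = G_IMPLICIT_DEF
  //   %cat:_(<4 x s32>) = G_CONCAT_VECTORS %v1, %v2
  // flattens to:
  //   %u:_(s32) = G_IMPLICIT_DEF
  //   %cat:_(<4 x s32>) = G_BUILD_VECTOR %a, %b, %u, %u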
  for (const MachineOperand &MO : MI.uses()) {
    Register Reg = MO.getReg();
    MachineInstr *Def = MRI.getVRegDef(Reg);
    assert(Def && "Operand not defined");
    if (!MRI.hasOneNonDBGUse(Reg))
      return false;
    switch (Def->getOpcode()) {
    case TargetOpcode::G_BUILD_VECTOR:
      IsUndef = false;
      // Remember the operands of the build_vector to fold
      // them into the yet-to-build flattened concat vectors.
      for (const MachineOperand &BuildVecMO : Def->uses())
        Ops.push_back(BuildVecMO.getReg());
      break;
    case TargetOpcode::G_IMPLICIT_DEF: {
      LLT OpType = MRI.getType(Reg);
      // Keep one undef value for all the undef operands.
      if (!Undef) {
        Builder.setInsertPt(*MI.getParent(), MI);
        Undef = Builder.buildUndef(OpType.getScalarType());
      }
      assert(MRI.getType(Undef->getOperand(0).getReg()) ==
                 OpType.getScalarType() &&
             "All undefs should have the same type");
      // Break the undef vector into as many scalar elements as needed
      // for the flattening.
      for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements();
           EltIdx != EltEnd; ++EltIdx)
        Ops.push_back(Undef->getOperand(0).getReg());
      break;
    }
    default:
      return false;
    }
  }

  // Check that the resulting build_vector would be legal.
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_BUILD_VECTOR, {DstTy, MRI.getType(Ops[0])}})) {
    return false;
  }

  if (IsUndef)
    Ops.clear();

  return true;
}
void CombinerHelper::applyCombineConcatVectors(MachineInstr &MI,
                                               SmallVector<Register> &Ops) {
  // We determined that the concat_vectors can be flattened.
  // Generate the flattened build_vector.
  Register DstReg = MI.getOperand(0).getReg();
  Builder.setInsertPt(*MI.getParent(), MI);
  Register NewDstReg = MRI.cloneVirtualRegister(DstReg);

  // Note: IsUndef is sort of redundant. We could have determined it by
  // checking that all Ops are undef. Alternatively, we could have
  // generated a build_vector of undefs and relied on another combine to
  // clean that up. For now, given we already gather this information
  // in matchCombineConcatVectors, just save compile time and issue the
  // right thing.
  if (Ops.empty())
    Builder.buildUndef(NewDstReg);
  else
    Builder.buildBuildVector(NewDstReg, Ops);
  MI.eraseFromParent();
  replaceRegWith(MRI, DstReg, NewDstReg);
}

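/// Match a G_SHUFFLE_VECTOR whose operands are G_CONCAT_VECTORS and whose mask
/// selects whole concat sources. For example (illustrative MIR):
///   %c1:_(<4 x s32>) = G_CONCAT_VECTORS %a:_(<2 x s32>), %b:_(<2 x s32>)
///   %c2:_(<4 x s32>) = G_CONCAT_VECTORS %c:_(<2 x s32>), %d:_(<2 x s32>)
///   %s:_(<4 x s32>) = G_SHUFFLE_VECTOR %c1, %c2, shufflemask(2,3,4,5)
/// collects Ops = {%b, %c} so the shuffle can be rebuilt as
/// G_CONCAT_VECTORS %b, %c.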
bool CombinerHelper::matchCombineShuffleConcat(MachineInstr &MI,
                                               SmallVector<Register> &Ops) {
  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
  auto ConcatMI1 =
      dyn_cast<GConcatVectors>(MRI.getVRegDef(MI.getOperand(1).getReg()));
  auto ConcatMI2 =
      dyn_cast<GConcatVectors>(MRI.getVRegDef(MI.getOperand(2).getReg()));
  if (!ConcatMI1 || !ConcatMI2)
    return false;

  // Check that the sources of the concat instructions have the same type.
  if (MRI.getType(ConcatMI1->getSourceReg(0)) !=
      MRI.getType(ConcatMI2->getSourceReg(0)))
    return false;

  LLT ConcatSrcTy = MRI.getType(ConcatMI1->getReg(1));
  LLT ShuffleSrcTy1 = MRI.getType(MI.getOperand(1).getReg());
  unsigned ConcatSrcNumElt = ConcatSrcTy.getNumElements();
  for (unsigned i = 0; i < Mask.size(); i += ConcatSrcNumElt) {
    // Check if the index takes a whole source register from G_CONCAT_VECTORS.
    // Assumes that all sources of G_CONCAT_VECTORS are the same type.
    if (Mask[i] == -1) {
      for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
        if (i + j >= Mask.size())
          return false;
        if (Mask[i + j] != -1)
          return false;
      }
      if (!isLegalOrBeforeLegalizer(
              {TargetOpcode::G_IMPLICIT_DEF, {ConcatSrcTy}}))
        return false;
      Ops.push_back(0);
    } else if (Mask[i] % ConcatSrcNumElt == 0) {
      for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
        if (i + j >= Mask.size())
          return false;
        if (Mask[i + j] != Mask[i] + static_cast<int>(j))
          return false;
      }
      // Retrieve the source register from its respective G_CONCAT_VECTORS
      // instruction.
      if (Mask[i] < ShuffleSrcTy1.getNumElements()) {
        Ops.push_back(ConcatMI1->getSourceReg(Mask[i] / ConcatSrcNumElt));
      } else {
        Ops.push_back(ConcatMI2->getSourceReg(Mask[i] / ConcatSrcNumElt -
                                              ConcatMI1->getNumSources()));
      }
    } else {
      return false;
    }
  }

  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_CONCAT_VECTORS,
           {MRI.getType(MI.getOperand(0).getReg()), ConcatSrcTy}}))
    return false;

  return !Ops.empty();
}

void CombinerHelper::applyCombineShuffleConcat(MachineInstr &MI,
                                               SmallVector<Register> &Ops) {
  LLT SrcTy = MRI.getType(Ops[0]);
  Register UndefReg = 0;

  for (unsigned i = 0; i < Ops.size(); i++) {
    if (Ops[i] == 0) {
      if (UndefReg == 0)
        UndefReg = Builder.buildUndef(SrcTy).getReg(0);
      Ops[i] = UndefReg;
    }
  }

  Builder.buildConcatVectors(MI.getOperand(0).getReg(), Ops);
  MI.eraseFromParent();
}

bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) {
  SmallVector<Register, 4> Ops;
  if (matchCombineShuffleVector(MI, Ops)) {
    applyCombineShuffleVector(MI, Ops);
    return true;
  }
  return false;
}

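/// For example (illustrative MIR, with <2 x s32> sources):
///   %s:_(<4 x s32>) = G_SHUFFLE_VECTOR %x, %y, shufflemask(0,1,2,3)
/// is a pure concatenation and can be rebuilt as:
///   %s:_(<4 x s32>) = G_CONCAT_VECTORS %x, %y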
bool CombinerHelper::matchCombineShuffleVector(MachineInstr &MI,
                                               SmallVectorImpl<Register> &Ops) {
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
         "Invalid instruction kind");
  LLT DstType = MRI.getType(MI.getOperand(0).getReg());
  Register Src1 = MI.getOperand(1).getReg();
  LLT SrcType = MRI.getType(Src1);
  // As bizarre as it may look, shuffle vector can actually produce
  // scalar! This is because at the IR level a <1 x ty> shuffle
  // vector is perfectly valid.
  unsigned DstNumElts = DstType.isVector() ? DstType.getNumElements() : 1;
  unsigned SrcNumElts = SrcType.isVector() ? SrcType.getNumElements() : 1;

  // If the resulting vector is smaller than the size of the source
  // vectors being concatenated, we won't be able to replace the
  // shuffle vector with a concat_vectors.
  //
  // Note: We may still be able to produce a concat_vectors fed by
  // extract_vector_elt and so on. It is less clear that would
  // be better though, so don't bother for now.
  //
  // If the destination is a scalar, the size of the sources doesn't
  // matter. We will lower the shuffle to a plain copy. This will
  // work only if the source and destination have the same size. But
  // that's covered by the next condition.
  //
  // TODO: If the sizes of the source and destination don't match
  // we could still emit an extract vector element in that case.
  if (DstNumElts < 2 * SrcNumElts && DstNumElts != 1)
    return false;

  // Check that the shuffle mask can be broken evenly between the
  // different sources.
  if (DstNumElts % SrcNumElts != 0)
    return false;

  // Mask length is a multiple of the source vector length.
  // Check if the shuffle is some kind of concatenation of the input
  // vectors.
  unsigned NumConcat = DstNumElts / SrcNumElts;
  SmallVector<int, 8> ConcatSrcs(NumConcat, -1);
  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
  for (unsigned i = 0; i != DstNumElts; ++i) {
    int Idx = Mask[i];
    // Undef value.
    if (Idx < 0)
      continue;
    // Ensure the indices in each SrcType sized piece are sequential and that
    // the same source is used for the whole piece.
    if ((Idx % SrcNumElts != (i % SrcNumElts)) ||
        (ConcatSrcs[i / SrcNumElts] >= 0 &&
         ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts)))
      return false;
    // Remember which source this index came from.
    ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts;
  }

  // The shuffle is concatenating multiple vectors together.
  // Collect the different operands for that.
  Register UndefReg;
  Register Src2 = MI.getOperand(2).getReg();
  for (auto Src : ConcatSrcs) {
    if (Src < 0) {
      if (!UndefReg) {
        Builder.setInsertPt(*MI.getParent(), MI);
        UndefReg = Builder.buildUndef(SrcType).getReg(0);
      }
      Ops.push_back(UndefReg);
    } else if (Src == 0)
      Ops.push_back(Src1);
    else
      Ops.push_back(Src2);
  }
  return true;
}

void CombinerHelper::applyCombineShuffleVector(MachineInstr &MI,
                                               const ArrayRef<Register> Ops) {
  Register DstReg = MI.getOperand(0).getReg();
  Builder.setInsertPt(*MI.getParent(), MI);
  Register NewDstReg = MRI.cloneVirtualRegister(DstReg);

  if (Ops.size() == 1)
    Builder.buildCopy(NewDstReg, Ops[0]);
  else
    Builder.buildMergeLikeInstr(NewDstReg, Ops);

  MI.eraseFromParent();
  replaceRegWith(MRI, DstReg, NewDstReg);
}

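/// A single-element shuffle mask is really just an extract. For example
/// (illustrative MIR):
///   %e:_(s32) = G_SHUFFLE_VECTOR %x:_(<2 x s32>), %y, shufflemask(1)
/// extracts element 1 of %x; the apply below also handles the undef and
/// scalar-source cases.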
bool CombinerHelper::matchShuffleToExtract(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
         "Invalid instruction kind");

  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
  return Mask.size() == 1;
}

void CombinerHelper::applyShuffleToExtract(MachineInstr &MI) {
  Register DstReg = MI.getOperand(0).getReg();
  Builder.setInsertPt(*MI.getParent(), MI);

  int I = MI.getOperand(3).getShuffleMask()[0];
  Register Src1 = MI.getOperand(1).getReg();
  LLT Src1Ty = MRI.getType(Src1);
  int Src1NumElts = Src1Ty.isVector() ? Src1Ty.getNumElements() : 1;
  Register SrcReg;
  if (I >= Src1NumElts) {
    SrcReg = MI.getOperand(2).getReg();
    I -= Src1NumElts;
  } else if (I >= 0)
    SrcReg = Src1;

  if (I < 0)
    Builder.buildUndef(DstReg);
  else if (!MRI.getType(SrcReg).isVector())
    Builder.buildCopy(DstReg, SrcReg);
  else
    Builder.buildExtractVectorElementConstant(DstReg, SrcReg, I);

  MI.eraseFromParent();
}

namespace {

/// Select a preference between two uses. CurrentUse is the current preference
/// while *ForCandidate are the attributes of the candidate under
/// consideration.
PreferredTuple ChoosePreferredUse(MachineInstr &LoadMI,
                                  PreferredTuple &CurrentUse,
                                  const LLT TyForCandidate,
                                  unsigned OpcodeForCandidate,
                                  MachineInstr *MIForCandidate) {
  if (!CurrentUse.Ty.isValid()) {
    if (CurrentUse.ExtendOpcode == OpcodeForCandidate ||
        CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT)
      return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
    return CurrentUse;
  }

  // We permit the extend to hoist through basic blocks but this is only
  // sensible if the target has extending loads. If you end up lowering back
  // into a load and extend during the legalizer then the end result is
  // hoisting the extend up to the load.

  // Prefer defined extensions to undefined extensions as these are more
  // likely to reduce the number of instructions.
  if (OpcodeForCandidate == TargetOpcode::G_ANYEXT &&
      CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT)
    return CurrentUse;
  else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT &&
           OpcodeForCandidate != TargetOpcode::G_ANYEXT)
    return {TyForCandidate, OpcodeForCandidate, MIForCandidate};

  // Prefer sign extensions to zero extensions as sign-extensions tend to be
  // more expensive. Don't do this if the load is already a zero-extend load
  // though, otherwise we'll rewrite a zero-extend load into a sign-extend
  // later.
  if (!isa<GZExtLoad>(LoadMI) && CurrentUse.Ty == TyForCandidate) {
    if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT &&
        OpcodeForCandidate == TargetOpcode::G_ZEXT)
      return CurrentUse;
    else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT &&
             OpcodeForCandidate == TargetOpcode::G_SEXT)
      return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
  }

  // This is potentially target specific. We've chosen the largest type
  // because G_TRUNC is usually free. One potential catch with this is that
  // some targets have a reduced number of larger registers than smaller
  // registers and this choice potentially increases the live-range for the
  // larger value.
  if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) {
    return {TyForCandidate, OpcodeForCandidate, MIForCandidate};
  }
  return CurrentUse;
}

/// Find a suitable place to insert some instructions and insert them. This
/// function accounts for special cases like inserting before a PHI node.
/// The current strategy for inserting before PHIs is to duplicate the
/// instructions for each predecessor. However, while that's ok for G_TRUNC
/// on most targets since it generally requires no code, other targets/cases
/// may want to try harder to find a dominating block.
static void InsertInsnsWithoutSideEffectsBeforeUse(
    MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO,
    std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator,
                       MachineOperand &UseMO)>
        Inserter) {
  MachineInstr &UseMI = *UseMO.getParent();

  MachineBasicBlock *InsertBB = UseMI.getParent();

  // If the use is a PHI then we want the predecessor block instead.
  if (UseMI.isPHI()) {
    MachineOperand *PredBB = std::next(&UseMO);
    InsertBB = PredBB->getMBB();
  }

  // If the block is the same block as the def then we want to insert just
  // after the def instead of at the start of the block.
  if (InsertBB == DefMI.getParent()) {
    MachineBasicBlock::iterator InsertPt = &DefMI;
    Inserter(InsertBB, std::next(InsertPt), UseMO);
    return;
  }

  // Otherwise we want the start of the BB.
  Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO);
}
} // end anonymous namespace

bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) {
  PreferredTuple Preferred;
  if (matchCombineExtendingLoads(MI, Preferred)) {
    applyCombineExtendingLoads(MI, Preferred);
    return true;
  }
  return false;
}

static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) {
  unsigned CandidateLoadOpc;
  switch (ExtOpc) {
  case TargetOpcode::G_ANYEXT:
    CandidateLoadOpc = TargetOpcode::G_LOAD;
    break;
  case TargetOpcode::G_SEXT:
    CandidateLoadOpc = TargetOpcode::G_SEXTLOAD;
    break;
  case TargetOpcode::G_ZEXT:
    CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD;
    break;
  default:
    llvm_unreachable("Unexpected extend opc");
  }
  return CandidateLoadOpc;
}

bool CombinerHelper::matchCombineExtendingLoads(MachineInstr &MI,
                                                PreferredTuple &Preferred) {
  // We match the loads and follow the uses to the extend instead of matching
  // the extends and following the def to the load. This is because the load
  // must remain in the same position for correctness (unless we also add code
  // to find a safe place to sink it) whereas the extend is freely movable.
  // It also prevents us from duplicating the load for the volatile case or
  // just for performance.
  GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(&MI);
  if (!LoadMI)
    return false;

  Register LoadReg = LoadMI->getDstReg();

  LLT LoadValueTy = MRI.getType(LoadReg);
  if (!LoadValueTy.isScalar())
    return false;

  // Most architectures are going to legalize <s8 loads into at least a 1 byte
  // load, and the MMOs can only describe memory accesses in multiples of
  // bytes. If we try to perform extload combining on those, we can end up
  // with
  //   %a(s8) = extload %ptr (load 1 byte from %ptr)
  // ...which is an illegal extload instruction.
  if (LoadValueTy.getSizeInBits() < 8)
    return false;

  // Non-power-of-2 types will very likely be legalized into multiple loads.
  // Don't bother trying to match them into extending loads.
  if (!llvm::has_single_bit<uint32_t>(LoadValueTy.getSizeInBits()))
    return false;

  // Find the preferred type aside from the any-extends (unless it's the only
  // one) and non-extending ops. We'll emit an extending load to that type and
  // emit a variant of (extend (trunc X)) for the others according to the
  // relative type sizes. At the same time, pick an extend to use based on the
  // extend involved in the chosen type.
  unsigned PreferredOpcode =
      isa<GLoad>(&MI)
          ? TargetOpcode::G_ANYEXT
          : isa<GSExtLoad>(&MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
  Preferred = {LLT(), PreferredOpcode, nullptr};
  for (auto &UseMI : MRI.use_nodbg_instructions(LoadReg)) {
    if (UseMI.getOpcode() == TargetOpcode::G_SEXT ||
        UseMI.getOpcode() == TargetOpcode::G_ZEXT ||
        (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) {
      const auto &MMO = LoadMI->getMMO();
      // Don't do anything for atomics.
      if (MMO.isAtomic())
        continue;
      // Check for legality.
      if (!isPreLegalize()) {
        LegalityQuery::MemDesc MMDesc(MMO);
        unsigned CandidateLoadOpc = getExtLoadOpcForExtend(UseMI.getOpcode());
        LLT UseTy = MRI.getType(UseMI.getOperand(0).getReg());
        LLT SrcTy = MRI.getType(LoadMI->getPointerReg());
        if (LI->getAction({CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}})
                .Action != LegalizeActions::Legal)
          continue;
      }
      Preferred = ChoosePreferredUse(MI, Preferred,
                                     MRI.getType(UseMI.getOperand(0).getReg()),
                                     UseMI.getOpcode(), &UseMI);
    }
  }

  // There were no extends.
  if (!Preferred.MI)
    return false;
  // It should be impossible to choose an extend without selecting a different
  // type since by definition the result of an extend is larger.
  assert(Preferred.Ty != LoadValueTy && "Extending to same type?");

  LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI);
  return true;
}

void CombinerHelper::applyCombineExtendingLoads(MachineInstr &MI,
                                                PreferredTuple &Preferred) {
  // Rewrite the load to the chosen extending load.
  Register ChosenDstReg = Preferred.MI->getOperand(0).getReg();

  // Inserter to insert a truncate back to the original type at a given point
  // with some basic CSE to limit truncate duplication to one per BB.
  DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns;
  auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB,
                           MachineBasicBlock::iterator InsertBefore,
                           MachineOperand &UseMO) {
    MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(InsertIntoBB);
    if (PreviouslyEmitted) {
      Observer.changingInstr(*UseMO.getParent());
      UseMO.setReg(PreviouslyEmitted->getOperand(0).getReg());
      Observer.changedInstr(*UseMO.getParent());
      return;
    }

    Builder.setInsertPt(*InsertIntoBB, InsertBefore);
    Register NewDstReg = MRI.cloneVirtualRegister(MI.getOperand(0).getReg());
    MachineInstr *NewMI = Builder.buildTrunc(NewDstReg, ChosenDstReg);
    EmittedInsns[InsertIntoBB] = NewMI;
    replaceRegOpWith(MRI, UseMO, NewDstReg);
  };

  Observer.changingInstr(MI);
  unsigned LoadOpc = getExtLoadOpcForExtend(Preferred.ExtendOpcode);
  MI.setDesc(Builder.getTII().get(LoadOpc));

  // Rewrite all the uses to fix up the types.
  auto &LoadValue = MI.getOperand(0);
  SmallVector<MachineOperand *, 4> Uses;
  for (auto &UseMO : MRI.use_operands(LoadValue.getReg()))
    Uses.push_back(&UseMO);

  for (auto *UseMO : Uses) {
    MachineInstr *UseMI = UseMO->getParent();

    // If the extend is compatible with the preferred extend then we should
    // fix up the type and extend so that it uses the preferred use.
    if (UseMI->getOpcode() == Preferred.ExtendOpcode ||
        UseMI->getOpcode() == TargetOpcode::G_ANYEXT) {
      Register UseDstReg = UseMI->getOperand(0).getReg();
      MachineOperand &UseSrcMO = UseMI->getOperand(1);
      const LLT UseDstTy = MRI.getType(UseDstReg);
      if (UseDstReg != ChosenDstReg) {
        if (Preferred.Ty == UseDstTy) {
          // If the use has the same type as the preferred use, then merge
          // the vregs and erase the extend. For example:
          //   %1:_(s8) = G_LOAD ...
          //   %2:_(s32) = G_SEXT %1(s8)
          //   %3:_(s32) = G_ANYEXT %1(s8)
          //   ... = ... %3(s32)
          // rewrites to:
          //   %2:_(s32) = G_SEXTLOAD ...
          //   ... = ... %2(s32)
          replaceRegWith(MRI, UseDstReg, ChosenDstReg);
          Observer.erasingInstr(*UseMO->getParent());
          UseMO->getParent()->eraseFromParent();
        } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) {
          // If the preferred size is smaller, then keep the extend but extend
          // from the result of the extending load. For example:
          //   %1:_(s8) = G_LOAD ...
          //   %2:_(s32) = G_SEXT %1(s8)
          //   %3:_(s64) = G_ANYEXT %1(s8)
          //   ... = ... %3(s64)
          // rewrites to:
          //   %2:_(s32) = G_SEXTLOAD ...
          //   %3:_(s64) = G_ANYEXT %2:_(s32)
          //   ... = ... %3(s64)
          replaceRegOpWith(MRI, UseSrcMO, ChosenDstReg);
        } else {
          // If the preferred size is large, then insert a truncate. For
          // example:
          //   %1:_(s8) = G_LOAD ...
          //   %2:_(s64) = G_SEXT %1(s8)
          //   %3:_(s32) = G_ZEXT %1(s8)
          //   ... = ... %3(s32)
          // rewrites to:
          //   %2:_(s64) = G_SEXTLOAD ...
          //   %4:_(s8) = G_TRUNC %2:_(s64)
          //   %3:_(s32) = G_ZEXT %4:_(s8)
          //   ... = ... %3(s32)
          InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO,
                                                 InsertTruncAt);
        }
        continue;
      }
      // The use is (one of) the uses of the preferred use we chose earlier.
      // We're going to update the load to def this value later so just erase
      // the old extend.
      Observer.erasingInstr(*UseMO->getParent());
      UseMO->getParent()->eraseFromParent();
      continue;
    }

    // The use isn't an extend. Truncate back to the type we originally
    // loaded. This is free on many targets.
    InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO, InsertTruncAt);
  }

  MI.getOperand(0).setReg(ChosenDstReg);
  Observer.changedInstr(MI);
}

bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI,
                                                 BuildFnTy &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_AND);

  // If we have the following code:
  //   %mask = G_CONSTANT 255
  //   %ld = G_LOAD %ptr, (load s16)
  //   %and = G_AND %ld, %mask
  //
  // Try to fold it into
  //   %ld = G_ZEXTLOAD %ptr, (load s8)

  Register Dst = MI.getOperand(0).getReg();
  if (MRI.getType(Dst).isVector())
    return false;

  auto MaybeMask =
      getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
  if (!MaybeMask)
    return false;

  APInt MaskVal = MaybeMask->Value;

  if (!MaskVal.isMask())
    return false;

  Register SrcReg = MI.getOperand(1).getReg();
  // Don't use getOpcodeDef() here since intermediate instructions may have
  // multiple users.
  GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(MRI.getVRegDef(SrcReg));
  if (!LoadMI || !MRI.hasOneNonDBGUse(LoadMI->getDstReg()))
    return false;

  Register LoadReg = LoadMI->getDstReg();
  LLT RegTy = MRI.getType(LoadReg);
  Register PtrReg = LoadMI->getPointerReg();
  unsigned RegSize = RegTy.getSizeInBits();
  LocationSize LoadSizeBits = LoadMI->getMemSizeInBits();
  unsigned MaskSizeBits = MaskVal.countr_one();

  // The mask may not be larger than the in-memory type, as it might cover
  // sign-extended bits.
  if (MaskSizeBits > LoadSizeBits.getValue())
    return false;

  // If the mask covers the whole destination register, there's nothing to
  // extend.
  if (MaskSizeBits >= RegSize)
    return false;

  // Most targets cannot deal with loads of size < 8 and need to re-legalize
  // to at least byte loads. Avoid creating such loads here.
  if (MaskSizeBits < 8 || !isPowerOf2_32(MaskSizeBits))
    return false;

  const MachineMemOperand &MMO = LoadMI->getMMO();
  LegalityQuery::MemDesc MemDesc(MMO);

  // Don't modify the memory access size if this is atomic/volatile, but we
  // can still adjust the opcode to indicate the high bit behavior.
  if (LoadMI->isSimple())
    MemDesc.MemoryTy = LLT::scalar(MaskSizeBits);
  else if (LoadSizeBits.getValue() > MaskSizeBits ||
           LoadSizeBits.getValue() == RegSize)
    return false;

  // TODO: Could check if it's legal with the reduced or original memory size.
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(PtrReg)}, {MemDesc}}))
    return false;

  MatchInfo = [=](MachineIRBuilder &B) {
    B.setInstrAndDebugLoc(*LoadMI);
    auto &MF = B.getMF();
    auto PtrInfo = MMO.getPointerInfo();
    auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, MemDesc.MemoryTy);
    B.buildLoadInstr(TargetOpcode::G_ZEXTLOAD, Dst, PtrReg, *NewMMO);
    LoadMI->eraseFromParent();
  };
  return true;
}

bool CombinerHelper::isPredecessor(const MachineInstr &DefMI,
                                   const MachineInstr &UseMI) {
  assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
         "shouldn't consider debug uses");
  assert(DefMI.getParent() == UseMI.getParent());
  if (&DefMI == &UseMI)
    return true;
  const MachineBasicBlock &MBB = *DefMI.getParent();
  auto DefOrUse = find_if(MBB, [&DefMI, &UseMI](const MachineInstr &MI) {
    return &MI == &DefMI || &MI == &UseMI;
  });
  if (DefOrUse == MBB.end())
    llvm_unreachable("Block must contain both DefMI and UseMI!");
  return &*DefOrUse == &DefMI;
}

bool CombinerHelper::dominates(const MachineInstr &DefMI,
                               const MachineInstr &UseMI) {
  assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() &&
         "shouldn't consider debug uses");
  if (MDT)
    return MDT->dominates(&DefMI, &UseMI);
  else if (DefMI.getParent() != UseMI.getParent())
    return false;

  return isPredecessor(DefMI, UseMI);
}

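/// For example (illustrative MIR), the sign bits are already correct here:
///   %ld:_(s32) = G_SEXTLOAD %ptr (load 1)
///   %t:_(s16) = G_TRUNC %ld
///   %s:_(s16) = G_SEXT_INREG %t, 8
/// so the G_SEXT_INREG can be replaced by a plain copy of its source.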
bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Register SrcReg = MI.getOperand(1).getReg();
  Register LoadUser = SrcReg;

  if (MRI.getType(SrcReg).isVector())
    return false;

  Register TruncSrc;
  if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc))))
    LoadUser = TruncSrc;

  uint64_t SizeInBits = MI.getOperand(2).getImm();
  // If the source is a G_SEXTLOAD from the same bit width, then we don't
  // need any extend at all, just a truncate.
  if (auto *LoadMI = getOpcodeDef<GSExtLoad>(LoadUser, MRI)) {
    // If truncating more than the original extended value, abort.
    auto LoadSizeBits = LoadMI->getMemSizeInBits();
    if (TruncSrc &&
        MRI.getType(TruncSrc).getSizeInBits() < LoadSizeBits.getValue())
      return false;
    if (LoadSizeBits == SizeInBits)
      return true;
  }
  return false;
}

void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Builder.buildCopy(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
  MI.eraseFromParent();
}

bool CombinerHelper::matchSextInRegOfLoad(
    MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);

  Register DstReg = MI.getOperand(0).getReg();
  LLT RegTy = MRI.getType(DstReg);

  // Only supports scalars for now.
  if (RegTy.isVector())
    return false;

  Register SrcReg = MI.getOperand(1).getReg();
  auto *LoadDef = getOpcodeDef<GLoad>(SrcReg, MRI);
  if (!LoadDef || !MRI.hasOneNonDBGUse(DstReg))
    return false;

  uint64_t MemBits = LoadDef->getMemSizeInBits().getValue();

  // If the sign extend extends from a narrower width than the load's width,
  // then we can narrow the load width when we combine to a G_SEXTLOAD.
  // Avoid widening the load at all.
  unsigned NewSizeBits = std::min((uint64_t)MI.getOperand(2).getImm(), MemBits);

  // Don't generate G_SEXTLOADs with a < 1 byte width.
  if (NewSizeBits < 8)
    return false;
  // Don't bother creating a non-power-of-2 sextload, it will likely be broken
  // up anyway for most targets.
  if (!isPowerOf2_32(NewSizeBits))
    return false;

  const MachineMemOperand &MMO = LoadDef->getMMO();
  LegalityQuery::MemDesc MMDesc(MMO);

  // Don't modify the memory access size if this is atomic/volatile, but we
  // can still adjust the opcode to indicate the high bit behavior.
  if (LoadDef->isSimple())
    MMDesc.MemoryTy = LLT::scalar(NewSizeBits);
  else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits())
    return false;

  // TODO: Could check if it's legal with the reduced or original memory size.
  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SEXTLOAD,
                                 {MRI.getType(LoadDef->getDstReg()),
                                  MRI.getType(LoadDef->getPointerReg())},
                                 {MMDesc}}))
    return false;

  MatchInfo = std::make_tuple(LoadDef->getDstReg(), NewSizeBits);
  return true;
}

void CombinerHelper::applySextInRegOfLoad(
    MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Register LoadReg;
  unsigned ScalarSizeBits;
  std::tie(LoadReg, ScalarSizeBits) = MatchInfo;
  GLoad *LoadDef = cast<GLoad>(MRI.getVRegDef(LoadReg));

  // If we have the following:
  //   %ld = G_LOAD %ptr, (load 2)
  //   %ext = G_SEXT_INREG %ld, 8
  // ==>
  //   %ld = G_SEXTLOAD %ptr (load 1)

  auto &MMO = LoadDef->getMMO();
  Builder.setInstrAndDebugLoc(*LoadDef);
  auto &MF = Builder.getMF();
  auto PtrInfo = MMO.getPointerInfo();
  auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, ScalarSizeBits / 8);
  Builder.buildLoadInstr(TargetOpcode::G_SEXTLOAD, MI.getOperand(0).getReg(),
                         LoadDef->getPointerReg(), *NewMMO);
  MI.eraseFromParent();
}

static Type *getTypeForLLT(LLT Ty, LLVMContext &C) {
  if (Ty.isVector())
    return FixedVectorType::get(IntegerType::get(C, Ty.getScalarSizeInBits()),
                                Ty.getNumElements());
  return IntegerType::get(C, Ty.getSizeInBits());
}

/// Return true if 'MI' is a load or a store that may be able to fold its
/// address operand into the load / store addressing mode.
static bool canFoldInAddressingMode(GLoadStore *MI, const TargetLowering &TLI,
                                    MachineRegisterInfo &MRI) {
  TargetLowering::AddrMode AM;
  auto *MF = MI->getMF();
  auto *Addr = getOpcodeDef<GPtrAdd>(MI->getPointerReg(), MRI);
  if (!Addr)
    return false;

  AM.HasBaseReg = true;
  if (auto CstOff = getIConstantVRegVal(Addr->getOffsetReg(), MRI))
    AM.BaseOffs = CstOff->getSExtValue(); // [reg +/- imm]
  else
    AM.Scale = 1; // [reg +/- reg]

  return TLI.isLegalAddressingMode(
      MF->getDataLayout(), AM,
      getTypeForLLT(MI->getMMO().getMemoryType(),
                    MF->getFunction().getContext()),
      MI->getMMO().getAddrSpace());
}

static unsigned getIndexedOpc(unsigned LdStOpc) {
  switch (LdStOpc) {
  case TargetOpcode::G_LOAD:
    return TargetOpcode::G_INDEXED_LOAD;
  case TargetOpcode::G_STORE:
    return TargetOpcode::G_INDEXED_STORE;
  case TargetOpcode::G_ZEXTLOAD:
    return TargetOpcode::G_INDEXED_ZEXTLOAD;
  case TargetOpcode::G_SEXTLOAD:
    return TargetOpcode::G_INDEXED_SEXTLOAD;
  default:
    llvm_unreachable("Unexpected opcode");
  }
}

bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const {
  // Check for legality.
  LLT PtrTy = MRI.getType(LdSt.getPointerReg());
  LLT Ty = MRI.getType(LdSt.getReg(0));
  LLT MemTy = LdSt.getMMO().getMemoryType();
  SmallVector<LegalityQuery::MemDesc, 2> MemDescrs(
      {{MemTy, MemTy.getSizeInBits(), AtomicOrdering::NotAtomic}});
  unsigned IndexedOpc = getIndexedOpc(LdSt.getOpcode());
  SmallVector<LLT> OpTys;
  if (IndexedOpc == TargetOpcode::G_INDEXED_STORE)
    OpTys = {PtrTy, Ty, Ty};
  else
    OpTys = {Ty, PtrTy}; // For G_INDEXED_LOAD, G_INDEXED_[SZ]EXTLOAD

  LegalityQuery Q(IndexedOpc, OpTys, MemDescrs);
  return isLegal(Q);
}

static cl::opt<unsigned> PostIndexUseThreshold(
    "post-index-use-threshold", cl::Hidden, cl::init(32),
    cl::desc("Number of uses of a base pointer to check before it is no longer "
             "considered for post-indexing."));

bool CombinerHelper::findPostIndexCandidate(GLoadStore &LdSt, Register &Addr,
                                            Register &Base, Register &Offset,
                                            bool &RematOffset) {
  // We're looking for the following pattern, for either load or store:
  //   %baseptr:_(p0) = ...
  //   G_STORE %val(s64), %baseptr(p0)
  //   %offset:_(s64) = G_CONSTANT i64 -256
  //   %new_addr:_(p0) = G_PTR_ADD %baseptr, %offset(s64)
  const auto &TLI = getTargetLowering();

  Register Ptr = LdSt.getPointerReg();
  // If the store is the only use, don't bother.
  if (MRI.hasOneNonDBGUse(Ptr))
    return false;

  if (!isIndexedLoadStoreLegal(LdSt))
    return false;

  if (getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Ptr, MRI))
    return false;

  MachineInstr *StoredValDef = getDefIgnoringCopies(LdSt.getReg(0), MRI);
  auto *PtrDef = MRI.getVRegDef(Ptr);

  unsigned NumUsesChecked = 0;
  for (auto &Use : MRI.use_nodbg_instructions(Ptr)) {
    if (++NumUsesChecked > PostIndexUseThreshold)
      return false; // Try to avoid exploding compile time.

    auto *PtrAdd = dyn_cast<GPtrAdd>(&Use);
    // The use itself might be dead. This can happen during combines if DCE
    // hasn't had a chance to run yet. Don't allow it to form an indexed op.
    if (!PtrAdd || MRI.use_nodbg_empty(PtrAdd->getReg(0)))
      continue;

    // Check the user of this isn't the store, otherwise we'd be generating an
    // indexed store defining its own use.
    if (StoredValDef == &Use)
      continue;

    Offset = PtrAdd->getOffsetReg();
    if (!ForceLegalIndexing &&
        !TLI.isIndexingLegal(LdSt, PtrAdd->getBaseReg(), Offset,
                             /*IsPre*/ false, MRI))
      continue;

    // Make sure the offset calculation is before the potentially indexed op.
    MachineInstr *OffsetDef = MRI.getVRegDef(Offset);
    RematOffset = false;
    if (!dominates(*OffsetDef, LdSt)) {
      // If the offset, however, is just a G_CONSTANT, we can always just
      // rematerialize it where we need it.
      if (OffsetDef->getOpcode() != TargetOpcode::G_CONSTANT)
        continue;
      RematOffset = true;
    }

    for (auto &BasePtrUse : MRI.use_nodbg_instructions(PtrAdd->getBaseReg())) {
      if (&BasePtrUse == PtrDef)
        continue;

      // If the user is a later load/store that can be post-indexed, then
      // don't combine this one.
      auto *BasePtrLdSt = dyn_cast<GLoadStore>(&BasePtrUse);
      if (BasePtrLdSt && BasePtrLdSt != &LdSt &&
          dominates(LdSt, *BasePtrLdSt) &&
          isIndexedLoadStoreLegal(*BasePtrLdSt))
        return false;

      // Now we're looking for the key G_PTR_ADD instruction, which contains
      // the offset add that we want to fold.
      if (auto *BasePtrUseDef = dyn_cast<GPtrAdd>(&BasePtrUse)) {
        Register PtrAddDefReg = BasePtrUseDef->getReg(0);
        for (auto &BaseUseUse : MRI.use_nodbg_instructions(PtrAddDefReg)) {
          // If the use is in a different block, then we may produce worse
          // code due to the extra register pressure.
          if (BaseUseUse.getParent() != LdSt.getParent())
            return false;

          if (auto *UseUseLdSt = dyn_cast<GLoadStore>(&BaseUseUse))
            if (canFoldInAddressingMode(UseUseLdSt, TLI, MRI))
              return false;
        }
        if (!dominates(LdSt, BasePtrUse))
          return false; // All uses must be dominated by the load/store.
      }
    }

    Addr = PtrAdd->getReg(0);
    Base = PtrAdd->getBaseReg();
    return true;
  }

  return false;
}

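/// Look for a pre-indexing opportunity. For example (illustrative MIR):
///   %addr:_(p0) = G_PTR_ADD %base, %offset
///   G_STORE %val(s32), %addr(p0)
/// where the address computation can be folded into an indexed store that
/// also produces the written-back %addr.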
bool CombinerHelper::findPreIndexCandidate(GLoadStore &LdSt, Register &Addr,
                                           Register &Base, Register &Offset) {
  auto &MF = *LdSt.getParent()->getParent();
  const auto &TLI = *MF.getSubtarget().getTargetLowering();

  Addr = LdSt.getPointerReg();
  if (!mi_match(Addr, MRI, m_GPtrAdd(m_Reg(Base), m_Reg(Offset))) ||
      MRI.hasOneNonDBGUse(Addr))
    return false;

  if (!ForceLegalIndexing &&
      !TLI.isIndexingLegal(LdSt, Base, Offset, /*IsPre*/ true, MRI))
    return false;

  if (!isIndexedLoadStoreLegal(LdSt))
    return false;

  MachineInstr *BaseDef = getDefIgnoringCopies(Base, MRI);
  if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX)
    return false;

  if (auto *St = dyn_cast<GStore>(&LdSt)) {
    // Would require a copy.
    if (Base == St->getValueReg())
      return false;

    // We're expecting one use of Addr in MI, but it could also be the
    // value stored, which isn't actually dominated by the instruction.
    if (St->getValueReg() == Addr)
      return false;
  }

  // Avoid increasing cross-block register pressure.
  for (auto &AddrUse : MRI.use_nodbg_instructions(Addr))
    if (AddrUse.getParent() != LdSt.getParent())
      return false;

  // FIXME: check whether all uses of the base pointer are constant PtrAdds.
  // That might allow us to end base's liveness here by adjusting the constant.
  bool RealUse = false;
  for (auto &AddrUse : MRI.use_nodbg_instructions(Addr)) {
    if (!dominates(LdSt, AddrUse))
      return false; // All uses must be dominated by the load/store.

    // If Ptr may be folded into the addressing mode of another use, then it's
    // not profitable to do this transformation.
    if (auto *UseLdSt = dyn_cast<GLoadStore>(&AddrUse)) {
      if (!canFoldInAddressingMode(UseLdSt, TLI, MRI))
        RealUse = true;
    } else {
      RealUse = true;
    }
  }
  return RealUse;
}

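/// Match the extraction of an element from a freshly loaded vector. For
/// example (illustrative MIR):
///   %v:_(<4 x s32>) = G_LOAD %p(p0) (load 16)
///   %e:_(s32) = G_EXTRACT_VECTOR_ELT %v, %idx
/// can, when profitable, be narrowed to a scalar load of just that element
/// at address %p + %idx * sizeof(s32).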
1252 | bool CombinerHelper::(MachineInstr &MI, |
1253 | BuildFnTy &MatchInfo) { |
1254 | assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT); |
1255 | |
1256 | // Check if there is a load that defines the vector being extracted from. |
1257 | auto *LoadMI = getOpcodeDef<GLoad>(Reg: MI.getOperand(i: 1).getReg(), MRI); |
1258 | if (!LoadMI) |
1259 | return false; |
1260 | |
1261 | Register Vector = MI.getOperand(i: 1).getReg(); |
1262 | LLT VecEltTy = MRI.getType(Reg: Vector).getElementType(); |
1263 | |
1264 | assert(MRI.getType(MI.getOperand(0).getReg()) == VecEltTy); |
1265 | |
1266 | // Checking whether we should reduce the load width. |
1267 | if (!MRI.hasOneNonDBGUse(RegNo: Vector)) |
1268 | return false; |
1269 | |
1270 | // Check if the defining load is simple. |
1271 | if (!LoadMI->isSimple()) |
1272 | return false; |
1273 | |
1274 | // If the vector element type is not a multiple of a byte then we are unable |
1275 | // to correctly compute an address to load only the extracted element as a |
1276 | // scalar. |
1277 | if (!VecEltTy.isByteSized()) |
1278 | return false; |
1279 | |
1280 | // Check for load fold barriers between the extraction and the load. |
1281 | if (MI.getParent() != LoadMI->getParent()) |
1282 | return false; |
1283 | const unsigned MaxIter = 20; |
1284 | unsigned Iter = 0; |
1285 | for (auto II = LoadMI->getIterator(), IE = MI.getIterator(); II != IE; ++II) { |
1286 | if (II->isLoadFoldBarrier()) |
1287 | return false; |
1288 | if (Iter++ == MaxIter) |
1289 | return false; |
1290 | } |
1291 | |
1292 | // Check if the new load that we are going to create is legal |
1293 | // if we are in the post-legalization phase. |
1294 | MachineMemOperand MMO = LoadMI->getMMO(); |
1295 | Align Alignment = MMO.getAlign(); |
1296 | MachinePointerInfo PtrInfo; |
1297 | uint64_t Offset; |
1298 | |
1299 | // Finding the appropriate PtrInfo if offset is a known constant. |
1300 | // This is required to create the memory operand for the narrowed load. |
1301 | // This machine memory operand object helps us infer about legality |
1302 | // before we proceed to combine the instruction. |
1303 | if (auto CVal = getIConstantVRegVal(VReg: Vector, MRI)) { |
1304 | int Elt = CVal->getZExtValue(); |
1305 | // FIXME: should be (ABI size)*Elt. |
1306 | Offset = VecEltTy.getSizeInBits() * Elt / 8; |
1307 | PtrInfo = MMO.getPointerInfo().getWithOffset(O: Offset); |
1308 | } else { |
1309 | // Discard the pointer info except the address space because the memory |
1310 | // operand can't represent this new access since the offset is variable. |
1311 | Offset = VecEltTy.getSizeInBits() / 8; |
1312 | PtrInfo = MachinePointerInfo(MMO.getPointerInfo().getAddrSpace()); |
1313 | } |
1314 | |
1315 | Alignment = commonAlignment(A: Alignment, Offset); |
1316 | |
1317 | Register VecPtr = LoadMI->getPointerReg(); |
1318 | LLT PtrTy = MRI.getType(Reg: VecPtr); |
1319 | |
1320 | MachineFunction &MF = *MI.getMF(); |
1321 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Ty: VecEltTy); |
1322 | |
1323 | LegalityQuery::MemDesc MMDesc(*NewMMO); |
1324 | |
1325 | LegalityQuery Q = {TargetOpcode::G_LOAD, {VecEltTy, PtrTy}, {MMDesc}}; |
1326 | |
1327 | if (!isLegalOrBeforeLegalizer(Query: Q)) |
1328 | return false; |
1329 | |
1330 | // Load must be allowed and fast on the target. |
1331 | LLVMContext &C = MF.getFunction().getContext(); |
1332 | auto &DL = MF.getDataLayout(); |
1333 | unsigned Fast = 0; |
1334 | if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty: VecEltTy, MMO: *NewMMO, |
1335 | Fast: &Fast) || |
1336 | !Fast) |
1337 | return false; |
1338 | |
1339 | Register Result = MI.getOperand(i: 0).getReg(); |
1340 | Register Index = MI.getOperand(i: 2).getReg(); |
1341 | |
1342 | MatchInfo = [=](MachineIRBuilder &B) { |
1343 | GISelObserverWrapper DummyObserver; |
1344 | LegalizerHelper Helper(B.getMF(), DummyObserver, B); |
    // Get a pointer to the vector element.
    Register FinalPtr = Helper.getVectorElementPointer(
        VecPtr: LoadMI->getPointerReg(), VecTy: MRI.getType(Reg: LoadMI->getOperand(i: 0).getReg()),
        Index);
    // Build the new, narrowed G_LOAD.
    B.buildLoad(Res: Result, Addr: FinalPtr, PtrInfo, Alignment);
    // Remove the original G_LOAD instruction.
    LoadMI->eraseFromParent();
1353 | }; |
1354 | |
1355 | return true; |
1356 | } |
1357 | |
1358 | bool CombinerHelper::matchCombineIndexedLoadStore( |
1359 | MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) { |
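  // Try to fold an address computation into the addressing mode of a
  // load/store, forming a pre- or post-indexed access. Informal sketch for
  // the pre-indexed load case:
  //   %addr:_(p0) = G_PTR_ADD %base, %offset
  //   %val:_(sN) = G_LOAD %addr
  // -->
  //   %val:_(sN), %addr:_(p0) = G_INDEXED_LOAD %base, %offset, 1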
1360 | auto &LdSt = cast<GLoadStore>(Val&: MI); |
1361 | |
1362 | if (LdSt.isAtomic()) |
1363 | return false; |
1364 | |
1365 | MatchInfo.IsPre = findPreIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base, |
1366 | Offset&: MatchInfo.Offset); |
1367 | if (!MatchInfo.IsPre && |
1368 | !findPostIndexCandidate(LdSt, Addr&: MatchInfo.Addr, Base&: MatchInfo.Base, |
1369 | Offset&: MatchInfo.Offset, RematOffset&: MatchInfo.RematOffset)) |
1370 | return false; |
1371 | |
1372 | return true; |
1373 | } |
1374 | |
1375 | void CombinerHelper::applyCombineIndexedLoadStore( |
1376 | MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) { |
1377 | MachineInstr &AddrDef = *MRI.getUniqueVRegDef(Reg: MatchInfo.Addr); |
1378 | unsigned Opcode = MI.getOpcode(); |
1379 | bool IsStore = Opcode == TargetOpcode::G_STORE; |
1380 | unsigned NewOpcode = getIndexedOpc(LdStOpc: Opcode); |
1381 | |
1382 | // If the offset constant didn't happen to dominate the load/store, we can |
1383 | // just clone it as needed. |
1384 | if (MatchInfo.RematOffset) { |
1385 | auto *OldCst = MRI.getVRegDef(Reg: MatchInfo.Offset); |
1386 | auto NewCst = Builder.buildConstant(Res: MRI.getType(Reg: MatchInfo.Offset), |
1387 | Val: *OldCst->getOperand(i: 1).getCImm()); |
1388 | MatchInfo.Offset = NewCst.getReg(Idx: 0); |
1389 | } |
1390 | |
1391 | auto MIB = Builder.buildInstr(Opcode: NewOpcode); |
1392 | if (IsStore) { |
1393 | MIB.addDef(RegNo: MatchInfo.Addr); |
1394 | MIB.addUse(RegNo: MI.getOperand(i: 0).getReg()); |
1395 | } else { |
1396 | MIB.addDef(RegNo: MI.getOperand(i: 0).getReg()); |
1397 | MIB.addDef(RegNo: MatchInfo.Addr); |
1398 | } |
1399 | |
1400 | MIB.addUse(RegNo: MatchInfo.Base); |
1401 | MIB.addUse(RegNo: MatchInfo.Offset); |
1402 | MIB.addImm(Val: MatchInfo.IsPre); |
1403 | MIB->cloneMemRefs(MF&: *MI.getMF(), MI); |
1404 | MI.eraseFromParent(); |
1405 | AddrDef.eraseFromParent(); |
1406 | |
  LLVM_DEBUG(dbgs() << "    Combined to indexed operation" );
1408 | } |
1409 | |
1410 | bool CombinerHelper::matchCombineDivRem(MachineInstr &MI, |
1411 | MachineInstr *&OtherMI) { |
1412 | unsigned Opcode = MI.getOpcode(); |
1413 | bool IsDiv, IsSigned; |
1414 | |
1415 | switch (Opcode) { |
1416 | default: |
1417 | llvm_unreachable("Unexpected opcode!" ); |
1418 | case TargetOpcode::G_SDIV: |
1419 | case TargetOpcode::G_UDIV: { |
1420 | IsDiv = true; |
1421 | IsSigned = Opcode == TargetOpcode::G_SDIV; |
1422 | break; |
1423 | } |
1424 | case TargetOpcode::G_SREM: |
1425 | case TargetOpcode::G_UREM: { |
1426 | IsDiv = false; |
1427 | IsSigned = Opcode == TargetOpcode::G_SREM; |
1428 | break; |
1429 | } |
1430 | } |
1431 | |
1432 | Register Src1 = MI.getOperand(i: 1).getReg(); |
1433 | unsigned DivOpcode, RemOpcode, DivremOpcode; |
1434 | if (IsSigned) { |
1435 | DivOpcode = TargetOpcode::G_SDIV; |
1436 | RemOpcode = TargetOpcode::G_SREM; |
1437 | DivremOpcode = TargetOpcode::G_SDIVREM; |
1438 | } else { |
1439 | DivOpcode = TargetOpcode::G_UDIV; |
1440 | RemOpcode = TargetOpcode::G_UREM; |
1441 | DivremOpcode = TargetOpcode::G_UDIVREM; |
1442 | } |
1443 | |
1444 | if (!isLegalOrBeforeLegalizer(Query: {DivremOpcode, {MRI.getType(Reg: Src1)}})) |
1445 | return false; |
1446 | |
1447 | // Combine: |
1448 | // %div:_ = G_[SU]DIV %src1:_, %src2:_ |
1449 | // %rem:_ = G_[SU]REM %src1:_, %src2:_ |
1450 | // into: |
1451 | // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_ |
1452 | |
1453 | // Combine: |
1454 | // %rem:_ = G_[SU]REM %src1:_, %src2:_ |
1455 | // %div:_ = G_[SU]DIV %src1:_, %src2:_ |
1456 | // into: |
1457 | // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_ |
1458 | |
1459 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: Src1)) { |
1460 | if (MI.getParent() == UseMI.getParent() && |
1461 | ((IsDiv && UseMI.getOpcode() == RemOpcode) || |
1462 | (!IsDiv && UseMI.getOpcode() == DivOpcode)) && |
1463 | matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: UseMI.getOperand(i: 2)) && |
1464 | matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: UseMI.getOperand(i: 1))) { |
1465 | OtherMI = &UseMI; |
1466 | return true; |
1467 | } |
1468 | } |
1469 | |
1470 | return false; |
1471 | } |
1472 | |
1473 | void CombinerHelper::applyCombineDivRem(MachineInstr &MI, |
1474 | MachineInstr *&OtherMI) { |
1475 | unsigned Opcode = MI.getOpcode(); |
1476 | assert(OtherMI && "OtherMI shouldn't be empty." ); |
1477 | |
1478 | Register DestDivReg, DestRemReg; |
1479 | if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) { |
1480 | DestDivReg = MI.getOperand(i: 0).getReg(); |
1481 | DestRemReg = OtherMI->getOperand(i: 0).getReg(); |
1482 | } else { |
1483 | DestDivReg = OtherMI->getOperand(i: 0).getReg(); |
1484 | DestRemReg = MI.getOperand(i: 0).getReg(); |
1485 | } |
1486 | |
1487 | bool IsSigned = |
1488 | Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM; |
1489 | |
1490 | // Check which instruction is first in the block so we don't break def-use |
1491 | // deps by "moving" the instruction incorrectly. Also keep track of which |
  // instruction is first so we pick its operands, avoiding use-before-def
1493 | // bugs. |
1494 | MachineInstr *FirstInst = dominates(DefMI: MI, UseMI: *OtherMI) ? &MI : OtherMI; |
1495 | Builder.setInstrAndDebugLoc(*FirstInst); |
1496 | |
1497 | Builder.buildInstr(Opc: IsSigned ? TargetOpcode::G_SDIVREM |
1498 | : TargetOpcode::G_UDIVREM, |
1499 | DstOps: {DestDivReg, DestRemReg}, |
                     SrcOps: {FirstInst->getOperand(i: 1), FirstInst->getOperand(i: 2)});
1501 | MI.eraseFromParent(); |
1502 | OtherMI->eraseFromParent(); |
1503 | } |
1504 | |
1505 | bool CombinerHelper::matchOptBrCondByInvertingCond(MachineInstr &MI, |
1506 | MachineInstr *&BrCond) { |
1507 | assert(MI.getOpcode() == TargetOpcode::G_BR); |
1508 | |
1509 | // Try to match the following: |
1510 | // bb1: |
1511 | // G_BRCOND %c1, %bb2 |
1512 | // G_BR %bb3 |
1513 | // bb2: |
1514 | // ... |
1515 | // bb3: |
1516 | |
  // The above pattern does not have a fallthrough to the successor bb2; it
  // always results in a branch no matter which path is taken. Here we try to
  // find and replace that pattern with a conditional branch to bb3 and
  // otherwise a fallthrough to bb2. This is generally better for branch
  // predictors.
1521 | |
1522 | MachineBasicBlock *MBB = MI.getParent(); |
1523 | MachineBasicBlock::iterator BrIt(MI); |
1524 | if (BrIt == MBB->begin()) |
1525 | return false; |
1526 | assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator" ); |
1527 | |
1528 | BrCond = &*std::prev(x: BrIt); |
1529 | if (BrCond->getOpcode() != TargetOpcode::G_BRCOND) |
1530 | return false; |
1531 | |
1532 | // Check that the next block is the conditional branch target. Also make sure |
1533 | // that it isn't the same as the G_BR's target (otherwise, this will loop.) |
1534 | MachineBasicBlock *BrCondTarget = BrCond->getOperand(i: 1).getMBB(); |
1535 | return BrCondTarget != MI.getOperand(i: 0).getMBB() && |
1536 | MBB->isLayoutSuccessor(MBB: BrCondTarget); |
1537 | } |
1538 | |
1539 | void CombinerHelper::applyOptBrCondByInvertingCond(MachineInstr &MI, |
1540 | MachineInstr *&BrCond) { |
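  // With the pattern from the match function above, the rewrite is
  // (informal sketch):
  //   bb1:
  //     %inv:_(s1) = G_XOR %c1, (G_CONSTANT true)
  //     G_BRCOND %inv, %bb3
  //     G_BR %bb2        ; bb2 is the layout successor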
1541 | MachineBasicBlock *BrTarget = MI.getOperand(i: 0).getMBB(); |
1542 | Builder.setInstrAndDebugLoc(*BrCond); |
1543 | LLT Ty = MRI.getType(Reg: BrCond->getOperand(i: 0).getReg()); |
1544 | // FIXME: Does int/fp matter for this? If so, we might need to restrict |
1545 | // this to i1 only since we might not know for sure what kind of |
1546 | // compare generated the condition value. |
1547 | auto True = Builder.buildConstant( |
1548 | Res: Ty, Val: getICmpTrueVal(TLI: getTargetLowering(), IsVector: false, IsFP: false)); |
1549 | auto Xor = Builder.buildXor(Dst: Ty, Src0: BrCond->getOperand(i: 0), Src1: True); |
1550 | |
1551 | auto *FallthroughBB = BrCond->getOperand(i: 1).getMBB(); |
1552 | Observer.changingInstr(MI); |
1553 | MI.getOperand(i: 0).setMBB(FallthroughBB); |
1554 | Observer.changedInstr(MI); |
1555 | |
1556 | // Change the conditional branch to use the inverted condition and |
1557 | // new target block. |
1558 | Observer.changingInstr(MI&: *BrCond); |
1559 | BrCond->getOperand(i: 0).setReg(Xor.getReg(Idx: 0)); |
1560 | BrCond->getOperand(i: 1).setMBB(BrTarget); |
1561 | Observer.changedInstr(MI&: *BrCond); |
1562 | } |
1563 | |
1565 | bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) { |
1566 | MachineIRBuilder HelperBuilder(MI); |
1567 | GISelObserverWrapper DummyObserver; |
1568 | LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder); |
1569 | return Helper.lowerMemcpyInline(MI) == |
1570 | LegalizerHelper::LegalizeResult::Legalized; |
1571 | } |
1572 | |
1573 | bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI, unsigned MaxLen) { |
1574 | MachineIRBuilder HelperBuilder(MI); |
1575 | GISelObserverWrapper DummyObserver; |
1576 | LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder); |
1577 | return Helper.lowerMemCpyFamily(MI, MaxLen) == |
1578 | LegalizerHelper::LegalizeResult::Legalized; |
1579 | } |
1580 | |
1581 | static APFloat constantFoldFpUnary(const MachineInstr &MI, |
1582 | const MachineRegisterInfo &MRI, |
1583 | const APFloat &Val) { |
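  // Evaluates the operation on the constant \p Val. For example, G_FABS of
  // -2.0 folds to +2.0. G_FSQRT and G_FLOG2 are computed in double
  // precision and converted back to \p Val's semantics at the end.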
1584 | APFloat Result(Val); |
1585 | switch (MI.getOpcode()) { |
1586 | default: |
1587 | llvm_unreachable("Unexpected opcode!" ); |
1588 | case TargetOpcode::G_FNEG: { |
1589 | Result.changeSign(); |
1590 | return Result; |
1591 | } |
1592 | case TargetOpcode::G_FABS: { |
1593 | Result.clearSign(); |
1594 | return Result; |
1595 | } |
1596 | case TargetOpcode::G_FPTRUNC: { |
1597 | bool Unused; |
1598 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
1599 | Result.convert(ToSemantics: getFltSemanticForLLT(Ty: DstTy), RM: APFloat::rmNearestTiesToEven, |
1600 | losesInfo: &Unused); |
1601 | return Result; |
1602 | } |
1603 | case TargetOpcode::G_FSQRT: { |
1604 | bool Unused; |
1605 | Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven, |
1606 | losesInfo: &Unused); |
1607 | Result = APFloat(sqrt(x: Result.convertToDouble())); |
1608 | break; |
1609 | } |
1610 | case TargetOpcode::G_FLOG2: { |
1611 | bool Unused; |
1612 | Result.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven, |
1613 | losesInfo: &Unused); |
1614 | Result = APFloat(log2(x: Result.convertToDouble())); |
1615 | break; |
1616 | } |
1617 | } |
  // Convert the `APFloat` back to the operand's semantics. Otherwise,
  // `buildFConstant` will assert on a size mismatch. Only `G_FSQRT` and
  // `G_FLOG2` reach here.
1621 | bool Unused; |
1622 | Result.convert(ToSemantics: Val.getSemantics(), RM: APFloat::rmNearestTiesToEven, losesInfo: &Unused); |
1623 | return Result; |
1624 | } |
1625 | |
1626 | void CombinerHelper::applyCombineConstantFoldFpUnary(MachineInstr &MI, |
1627 | const ConstantFP *Cst) { |
1628 | APFloat Folded = constantFoldFpUnary(MI, MRI, Val: Cst->getValue()); |
1629 | const ConstantFP *NewCst = ConstantFP::get(Context&: Builder.getContext(), V: Folded); |
1630 | Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: *NewCst); |
1631 | MI.eraseFromParent(); |
1632 | } |
1633 | |
1634 | bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI, |
1635 | PtrAddChain &MatchInfo) { |
1636 | // We're trying to match the following pattern: |
1637 | // %t1 = G_PTR_ADD %base, G_CONSTANT imm1 |
1638 | // %root = G_PTR_ADD %t1, G_CONSTANT imm2 |
1639 | // --> |
1640 | // %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2) |
1641 | |
1642 | if (MI.getOpcode() != TargetOpcode::G_PTR_ADD) |
1643 | return false; |
1644 | |
1645 | Register Add2 = MI.getOperand(i: 1).getReg(); |
1646 | Register Imm1 = MI.getOperand(i: 2).getReg(); |
1647 | auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI); |
1648 | if (!MaybeImmVal) |
1649 | return false; |
1650 | |
1651 | MachineInstr *Add2Def = MRI.getVRegDef(Reg: Add2); |
1652 | if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD) |
1653 | return false; |
1654 | |
1655 | Register Base = Add2Def->getOperand(i: 1).getReg(); |
1656 | Register Imm2 = Add2Def->getOperand(i: 2).getReg(); |
1657 | auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI); |
1658 | if (!MaybeImm2Val) |
1659 | return false; |
1660 | |
1661 | // Check if the new combined immediate forms an illegal addressing mode. |
1662 | // Do not combine if it was legal before but would get illegal. |
1663 | // To do so, we need to find a load/store user of the pointer to get |
1664 | // the access type. |
1665 | Type *AccessTy = nullptr; |
1666 | auto &MF = *MI.getMF(); |
1667 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: MI.getOperand(i: 0).getReg())) { |
1668 | if (auto *LdSt = dyn_cast<GLoadStore>(Val: &UseMI)) { |
1669 | AccessTy = getTypeForLLT(Ty: MRI.getType(Reg: LdSt->getReg(Idx: 0)), |
1670 | C&: MF.getFunction().getContext()); |
1671 | break; |
1672 | } |
1673 | } |
1674 | TargetLoweringBase::AddrMode AMNew; |
1675 | APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value; |
1676 | AMNew.BaseOffs = CombinedImm.getSExtValue(); |
1677 | if (AccessTy) { |
1678 | AMNew.HasBaseReg = true; |
1679 | TargetLoweringBase::AddrMode AMOld; |
1680 | AMOld.BaseOffs = MaybeImmVal->Value.getSExtValue(); |
1681 | AMOld.HasBaseReg = true; |
1682 | unsigned AS = MRI.getType(Reg: Add2).getAddressSpace(); |
1683 | const auto &TLI = *MF.getSubtarget().getTargetLowering(); |
1684 | if (TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMOld, Ty: AccessTy, AddrSpace: AS) && |
1685 | !TLI.isLegalAddressingMode(DL: MF.getDataLayout(), AM: AMNew, Ty: AccessTy, AddrSpace: AS)) |
1686 | return false; |
1687 | } |
1688 | |
1689 | // Pass the combined immediate to the apply function. |
1690 | MatchInfo.Imm = AMNew.BaseOffs; |
1691 | MatchInfo.Base = Base; |
1692 | MatchInfo.Bank = getRegBank(Reg: Imm2); |
1693 | return true; |
1694 | } |
1695 | |
1696 | void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI, |
1697 | PtrAddChain &MatchInfo) { |
1698 | assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD" ); |
1699 | MachineIRBuilder MIB(MI); |
1700 | LLT OffsetTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg()); |
1701 | auto NewOffset = MIB.buildConstant(Res: OffsetTy, Val: MatchInfo.Imm); |
1702 | setRegBank(Reg: NewOffset.getReg(Idx: 0), RegBank: MatchInfo.Bank); |
1703 | Observer.changingInstr(MI); |
1704 | MI.getOperand(i: 1).setReg(MatchInfo.Base); |
1705 | MI.getOperand(i: 2).setReg(NewOffset.getReg(Idx: 0)); |
1706 | Observer.changedInstr(MI); |
1707 | } |
1708 | |
1709 | bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI, |
1710 | RegisterImmPair &MatchInfo) { |
1711 | // We're trying to match the following pattern with any of |
1712 | // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions: |
1713 | // %t1 = SHIFT %base, G_CONSTANT imm1 |
1714 | // %root = SHIFT %t1, G_CONSTANT imm2 |
1715 | // --> |
1716 | // %root = SHIFT %base, G_CONSTANT (imm1 + imm2) |
1717 | |
1718 | unsigned Opcode = MI.getOpcode(); |
1719 | assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || |
1720 | Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT || |
1721 | Opcode == TargetOpcode::G_USHLSAT) && |
1722 | "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT" ); |
1723 | |
1724 | Register Shl2 = MI.getOperand(i: 1).getReg(); |
1725 | Register Imm1 = MI.getOperand(i: 2).getReg(); |
1726 | auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: Imm1, MRI); |
1727 | if (!MaybeImmVal) |
1728 | return false; |
1729 | |
1730 | MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Reg: Shl2); |
1731 | if (Shl2Def->getOpcode() != Opcode) |
1732 | return false; |
1733 | |
1734 | Register Base = Shl2Def->getOperand(i: 1).getReg(); |
1735 | Register Imm2 = Shl2Def->getOperand(i: 2).getReg(); |
1736 | auto MaybeImm2Val = getIConstantVRegValWithLookThrough(VReg: Imm2, MRI); |
1737 | if (!MaybeImm2Val) |
1738 | return false; |
1739 | |
1740 | // Pass the combined immediate to the apply function. |
1741 | MatchInfo.Imm = |
1742 | (MaybeImmVal->Value.getZExtValue() + MaybeImm2Val->Value).getZExtValue(); |
1743 | MatchInfo.Reg = Base; |
1744 | |
1745 | // There is no simple replacement for a saturating unsigned left shift that |
1746 | // exceeds the scalar size. |
1747 | if (Opcode == TargetOpcode::G_USHLSAT && |
1748 | MatchInfo.Imm >= MRI.getType(Reg: Shl2).getScalarSizeInBits()) |
1749 | return false; |
1750 | |
1751 | return true; |
1752 | } |
1753 | |
1754 | void CombinerHelper::applyShiftImmedChain(MachineInstr &MI, |
1755 | RegisterImmPair &MatchInfo) { |
1756 | unsigned Opcode = MI.getOpcode(); |
1757 | assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || |
1758 | Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT || |
1759 | Opcode == TargetOpcode::G_USHLSAT) && |
1760 | "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT" ); |
1761 | |
1762 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 1).getReg()); |
1763 | unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits(); |
1764 | auto Imm = MatchInfo.Imm; |
1765 | |
1766 | if (Imm >= ScalarSizeInBits) { |
1767 | // Any logical shift that exceeds scalar size will produce zero. |
1768 | if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) { |
1769 | Builder.buildConstant(Res: MI.getOperand(i: 0), Val: 0); |
1770 | MI.eraseFromParent(); |
1771 | return; |
1772 | } |
1773 | // Arithmetic shift and saturating signed left shift have no effect beyond |
1774 | // scalar size. |
1775 | Imm = ScalarSizeInBits - 1; |
1776 | } |
1777 | |
1778 | LLT ImmTy = MRI.getType(Reg: MI.getOperand(i: 2).getReg()); |
1779 | Register NewImm = Builder.buildConstant(Res: ImmTy, Val: Imm).getReg(Idx: 0); |
1780 | Observer.changingInstr(MI); |
1781 | MI.getOperand(i: 1).setReg(MatchInfo.Reg); |
1782 | MI.getOperand(i: 2).setReg(NewImm); |
1783 | Observer.changedInstr(MI); |
1784 | } |
1785 | |
1786 | bool CombinerHelper::matchShiftOfShiftedLogic(MachineInstr &MI, |
1787 | ShiftOfShiftedLogic &MatchInfo) { |
1788 | // We're trying to match the following pattern with any of |
1789 | // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination |
1790 | // with any of G_AND/G_OR/G_XOR logic instructions. |
1791 | // %t1 = SHIFT %X, G_CONSTANT C0 |
1792 | // %t2 = LOGIC %t1, %Y |
1793 | // %root = SHIFT %t2, G_CONSTANT C1 |
1794 | // --> |
1795 | // %t3 = SHIFT %X, G_CONSTANT (C0+C1) |
1796 | // %t4 = SHIFT %Y, G_CONSTANT C1 |
1797 | // %root = LOGIC %t3, %t4 |
1798 | unsigned ShiftOpcode = MI.getOpcode(); |
1799 | assert((ShiftOpcode == TargetOpcode::G_SHL || |
1800 | ShiftOpcode == TargetOpcode::G_ASHR || |
1801 | ShiftOpcode == TargetOpcode::G_LSHR || |
1802 | ShiftOpcode == TargetOpcode::G_USHLSAT || |
1803 | ShiftOpcode == TargetOpcode::G_SSHLSAT) && |
1804 | "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT" ); |
1805 | |
1806 | // Match a one-use bitwise logic op. |
1807 | Register LogicDest = MI.getOperand(i: 1).getReg(); |
1808 | if (!MRI.hasOneNonDBGUse(RegNo: LogicDest)) |
1809 | return false; |
1810 | |
1811 | MachineInstr *LogicMI = MRI.getUniqueVRegDef(Reg: LogicDest); |
1812 | unsigned LogicOpcode = LogicMI->getOpcode(); |
1813 | if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR && |
1814 | LogicOpcode != TargetOpcode::G_XOR) |
1815 | return false; |
1816 | |
1817 | // Find a matching one-use shift by constant. |
1818 | const Register C1 = MI.getOperand(i: 2).getReg(); |
1819 | auto MaybeImmVal = getIConstantVRegValWithLookThrough(VReg: C1, MRI); |
1820 | if (!MaybeImmVal || MaybeImmVal->Value == 0) |
1821 | return false; |
1822 | |
1823 | const uint64_t C1Val = MaybeImmVal->Value.getZExtValue(); |
1824 | |
1825 | auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) { |
1826 | // Shift should match previous one and should be a one-use. |
1827 | if (MI->getOpcode() != ShiftOpcode || |
1828 | !MRI.hasOneNonDBGUse(RegNo: MI->getOperand(i: 0).getReg())) |
1829 | return false; |
1830 | |
1831 | // Must be a constant. |
1832 | auto MaybeImmVal = |
1833 | getIConstantVRegValWithLookThrough(VReg: MI->getOperand(i: 2).getReg(), MRI); |
1834 | if (!MaybeImmVal) |
1835 | return false; |
1836 | |
1837 | ShiftVal = MaybeImmVal->Value.getSExtValue(); |
1838 | return true; |
1839 | }; |
1840 | |
1841 | // Logic ops are commutative, so check each operand for a match. |
1842 | Register LogicMIReg1 = LogicMI->getOperand(i: 1).getReg(); |
1843 | MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(Reg: LogicMIReg1); |
1844 | Register LogicMIReg2 = LogicMI->getOperand(i: 2).getReg(); |
1845 | MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(Reg: LogicMIReg2); |
1846 | uint64_t C0Val; |
1847 | |
1848 | if (matchFirstShift(LogicMIOp1, C0Val)) { |
1849 | MatchInfo.LogicNonShiftReg = LogicMIReg2; |
1850 | MatchInfo.Shift2 = LogicMIOp1; |
1851 | } else if (matchFirstShift(LogicMIOp2, C0Val)) { |
1852 | MatchInfo.LogicNonShiftReg = LogicMIReg1; |
1853 | MatchInfo.Shift2 = LogicMIOp2; |
1854 | } else |
1855 | return false; |
1856 | |
1857 | MatchInfo.ValSum = C0Val + C1Val; |
1858 | |
1859 | // The fold is not valid if the sum of the shift values exceeds bitwidth. |
1860 | if (MatchInfo.ValSum >= MRI.getType(Reg: LogicDest).getScalarSizeInBits()) |
1861 | return false; |
1862 | |
1863 | MatchInfo.Logic = LogicMI; |
1864 | return true; |
1865 | } |
1866 | |
1867 | void CombinerHelper::applyShiftOfShiftedLogic(MachineInstr &MI, |
1868 | ShiftOfShiftedLogic &MatchInfo) { |
1869 | unsigned Opcode = MI.getOpcode(); |
1870 | assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || |
1871 | Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT || |
1872 | Opcode == TargetOpcode::G_SSHLSAT) && |
1873 | "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT" ); |
1874 | |
1875 | LLT ShlType = MRI.getType(Reg: MI.getOperand(i: 2).getReg()); |
1876 | LLT DestType = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
1877 | |
1878 | Register Const = Builder.buildConstant(Res: ShlType, Val: MatchInfo.ValSum).getReg(Idx: 0); |
1879 | |
1880 | Register Shift1Base = MatchInfo.Shift2->getOperand(i: 1).getReg(); |
1881 | Register Shift1 = |
1882 | Builder.buildInstr(Opc: Opcode, DstOps: {DestType}, SrcOps: {Shift1Base, Const}).getReg(Idx: 0); |
1883 | |
  // If LogicNonShiftReg is the same as Shift1Base, and the shift1 constant
  // is the same as the MatchInfo.Shift2 constant, CSEMIRBuilder will reuse
  // the old shift1 when building shift2. In that case, erasing
  // MatchInfo.Shift2 at the end would actually erase the old shift1 and
  // cause a crash later. Erase it earlier to avoid the crash.
1889 | MatchInfo.Shift2->eraseFromParent(); |
1890 | |
1891 | Register Shift2Const = MI.getOperand(i: 2).getReg(); |
1892 | Register Shift2 = Builder |
1893 | .buildInstr(Opc: Opcode, DstOps: {DestType}, |
1894 | SrcOps: {MatchInfo.LogicNonShiftReg, Shift2Const}) |
1895 | .getReg(Idx: 0); |
1896 | |
1897 | Register Dest = MI.getOperand(i: 0).getReg(); |
1898 | Builder.buildInstr(Opc: MatchInfo.Logic->getOpcode(), DstOps: {Dest}, SrcOps: {Shift1, Shift2}); |
1899 | |
1900 | // This was one use so it's safe to remove it. |
1901 | MatchInfo.Logic->eraseFromParent(); |
1902 | |
1903 | MI.eraseFromParent(); |
1904 | } |
1905 | |
1906 | bool CombinerHelper::matchCommuteShift(MachineInstr &MI, BuildFnTy &MatchInfo) { |
1907 | assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected G_SHL" ); |
1908 | // Combine (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2) |
1909 | // Combine (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2) |
1910 | auto &Shl = cast<GenericMachineInstr>(Val&: MI); |
1911 | Register DstReg = Shl.getReg(Idx: 0); |
1912 | Register SrcReg = Shl.getReg(Idx: 1); |
1913 | Register ShiftReg = Shl.getReg(Idx: 2); |
1914 | Register X, C1; |
1915 | |
1916 | if (!getTargetLowering().isDesirableToCommuteWithShift(MI, IsAfterLegal: !isPreLegalize())) |
1917 | return false; |
1918 | |
1919 | if (!mi_match(R: SrcReg, MRI, |
1920 | P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: C1)), |
1921 | preds: m_GOr(L: m_Reg(R&: X), R: m_Reg(R&: C1)))))) |
1922 | return false; |
1923 | |
1924 | APInt C1Val, C2Val; |
1925 | if (!mi_match(R: C1, MRI, P: m_ICstOrSplat(Cst&: C1Val)) || |
1926 | !mi_match(R: ShiftReg, MRI, P: m_ICstOrSplat(Cst&: C2Val))) |
1927 | return false; |
1928 | |
1929 | auto *SrcDef = MRI.getVRegDef(Reg: SrcReg); |
1930 | assert((SrcDef->getOpcode() == TargetOpcode::G_ADD || |
1931 | SrcDef->getOpcode() == TargetOpcode::G_OR) && "Unexpected op" ); |
1932 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
1933 | MatchInfo = [=](MachineIRBuilder &B) { |
1934 | auto S1 = B.buildShl(Dst: SrcTy, Src0: X, Src1: ShiftReg); |
1935 | auto S2 = B.buildShl(Dst: SrcTy, Src0: C1, Src1: ShiftReg); |
1936 | B.buildInstr(Opc: SrcDef->getOpcode(), DstOps: {DstReg}, SrcOps: {S1, S2}); |
1937 | }; |
1938 | return true; |
1939 | } |
1940 | |
1941 | bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI, |
1942 | unsigned &ShiftVal) { |
1943 | assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL" ); |
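  // Fold a multiply by a power of two into a left shift (informal sketch):
  //   %z:_(sN) = G_MUL %x, (G_CONSTANT 2^C)
  // -->
  //   %z:_(sN) = G_SHL %x, (G_CONSTANT C)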
1944 | auto MaybeImmVal = |
1945 | getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
1946 | if (!MaybeImmVal) |
1947 | return false; |
1948 | |
1949 | ShiftVal = MaybeImmVal->Value.exactLogBase2(); |
1950 | return (static_cast<int32_t>(ShiftVal) != -1); |
1951 | } |
1952 | |
1953 | void CombinerHelper::applyCombineMulToShl(MachineInstr &MI, |
1954 | unsigned &ShiftVal) { |
1955 | assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL" ); |
1956 | MachineIRBuilder MIB(MI); |
1957 | LLT ShiftTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
1958 | auto ShiftCst = MIB.buildConstant(Res: ShiftTy, Val: ShiftVal); |
1959 | Observer.changingInstr(MI); |
1960 | MI.setDesc(MIB.getTII().get(Opcode: TargetOpcode::G_SHL)); |
1961 | MI.getOperand(i: 2).setReg(ShiftCst.getReg(Idx: 0)); |
1962 | Observer.changedInstr(MI); |
1963 | } |
1964 | |
1965 | // shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source |
1966 | bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI, |
1967 | RegisterImmPair &MatchData) { |
1968 | assert(MI.getOpcode() == TargetOpcode::G_SHL && KB); |
1969 | if (!getTargetLowering().isDesirableToPullExtFromShl(MI)) |
1970 | return false; |
1971 | |
1972 | Register LHS = MI.getOperand(i: 1).getReg(); |
1973 | |
1974 | Register ExtSrc; |
1975 | if (!mi_match(R: LHS, MRI, P: m_GAnyExt(Src: m_Reg(R&: ExtSrc))) && |
1976 | !mi_match(R: LHS, MRI, P: m_GZExt(Src: m_Reg(R&: ExtSrc))) && |
1977 | !mi_match(R: LHS, MRI, P: m_GSExt(Src: m_Reg(R&: ExtSrc)))) |
1978 | return false; |
1979 | |
1980 | Register RHS = MI.getOperand(i: 2).getReg(); |
1981 | MachineInstr *MIShiftAmt = MRI.getVRegDef(Reg: RHS); |
1982 | auto MaybeShiftAmtVal = isConstantOrConstantSplatVector(MI&: *MIShiftAmt, MRI); |
1983 | if (!MaybeShiftAmtVal) |
1984 | return false; |
1985 | |
1986 | if (LI) { |
1987 | LLT SrcTy = MRI.getType(Reg: ExtSrc); |
1988 | |
    // We only really care about the legality of the shifted value. We can
    // pick any type for the constant shift amount, so ask the target what to
    // use. Otherwise we would have to guess and hope it is reported as legal.
1992 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: SrcTy); |
1993 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}})) |
1994 | return false; |
1995 | } |
1996 | |
1997 | int64_t ShiftAmt = MaybeShiftAmtVal->getSExtValue(); |
1998 | MatchData.Reg = ExtSrc; |
1999 | MatchData.Imm = ShiftAmt; |
2000 | |
2001 | unsigned MinLeadingZeros = KB->getKnownZeroes(R: ExtSrc).countl_one(); |
2002 | unsigned SrcTySize = MRI.getType(Reg: ExtSrc).getScalarSizeInBits(); |
2003 | return MinLeadingZeros >= ShiftAmt && ShiftAmt < SrcTySize; |
2004 | } |
2005 | |
2006 | void CombinerHelper::applyCombineShlOfExtend(MachineInstr &MI, |
2007 | const RegisterImmPair &MatchData) { |
2008 | Register ExtSrcReg = MatchData.Reg; |
2009 | int64_t ShiftAmtVal = MatchData.Imm; |
2010 | |
2011 | LLT ExtSrcTy = MRI.getType(Reg: ExtSrcReg); |
2012 | auto ShiftAmt = Builder.buildConstant(Res: ExtSrcTy, Val: ShiftAmtVal); |
2013 | auto NarrowShift = |
2014 | Builder.buildShl(Dst: ExtSrcTy, Src0: ExtSrcReg, Src1: ShiftAmt, Flags: MI.getFlags()); |
2015 | Builder.buildZExt(Res: MI.getOperand(i: 0), Op: NarrowShift); |
2016 | MI.eraseFromParent(); |
2017 | } |
2018 | |
2019 | bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI, |
2020 | Register &MatchInfo) { |
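  // Fold a merge of all results of an unmerge, in the original order, back
  // to the unmerge source (informal sketch):
  //   %a:_(s32), %b:_(s32) = G_UNMERGE_VALUES %x:_(s64)
  //   %y:_(s64) = G_MERGE_VALUES %a, %b
  // -->
  //   replace %y with %x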
2021 | GMerge &Merge = cast<GMerge>(Val&: MI); |
2022 | SmallVector<Register, 16> MergedValues; |
2023 | for (unsigned I = 0; I < Merge.getNumSources(); ++I) |
2024 | MergedValues.emplace_back(Args: Merge.getSourceReg(I)); |
2025 | |
2026 | auto *Unmerge = getOpcodeDef<GUnmerge>(Reg: MergedValues[0], MRI); |
2027 | if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources()) |
2028 | return false; |
2029 | |
2030 | for (unsigned I = 0; I < MergedValues.size(); ++I) |
2031 | if (MergedValues[I] != Unmerge->getReg(Idx: I)) |
2032 | return false; |
2033 | |
2034 | MatchInfo = Unmerge->getSourceReg(); |
2035 | return true; |
2036 | } |
2037 | |
2038 | static Register peekThroughBitcast(Register Reg, |
2039 | const MachineRegisterInfo &MRI) { |
2040 | while (mi_match(R: Reg, MRI, P: m_GBitcast(Src: m_Reg(R&: Reg)))) |
2041 | ; |
2042 | |
2043 | return Reg; |
2044 | } |
2045 | |
2046 | bool CombinerHelper::matchCombineUnmergeMergeToPlainValues( |
2047 | MachineInstr &MI, SmallVectorImpl<Register> &Operands) { |
2048 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2049 | "Expected an unmerge" ); |
2050 | auto &Unmerge = cast<GUnmerge>(Val&: MI); |
2051 | Register SrcReg = peekThroughBitcast(Reg: Unmerge.getSourceReg(), MRI); |
2052 | |
2053 | auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(Reg: SrcReg, MRI); |
2054 | if (!SrcInstr) |
2055 | return false; |
2056 | |
2057 | // Check the source type of the merge. |
2058 | LLT SrcMergeTy = MRI.getType(Reg: SrcInstr->getSourceReg(I: 0)); |
2059 | LLT Dst0Ty = MRI.getType(Reg: Unmerge.getReg(Idx: 0)); |
2060 | bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits(); |
2061 | if (SrcMergeTy != Dst0Ty && !SameSize) |
2062 | return false; |
2063 | // They are the same now (modulo a bitcast). |
2064 | // We can collect all the src registers. |
2065 | for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx) |
2066 | Operands.push_back(Elt: SrcInstr->getSourceReg(I: Idx)); |
2067 | return true; |
2068 | } |
2069 | |
2070 | void CombinerHelper::applyCombineUnmergeMergeToPlainValues( |
2071 | MachineInstr &MI, SmallVectorImpl<Register> &Operands) { |
2072 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2073 | "Expected an unmerge" ); |
2074 | assert((MI.getNumOperands() - 1 == Operands.size()) && |
2075 | "Not enough operands to replace all defs" ); |
2076 | unsigned NumElems = MI.getNumOperands() - 1; |
2077 | |
2078 | LLT SrcTy = MRI.getType(Reg: Operands[0]); |
2079 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2080 | bool CanReuseInputDirectly = DstTy == SrcTy; |
2081 | for (unsigned Idx = 0; Idx < NumElems; ++Idx) { |
2082 | Register DstReg = MI.getOperand(i: Idx).getReg(); |
2083 | Register SrcReg = Operands[Idx]; |
2084 | |
2085 | // This combine may run after RegBankSelect, so we need to be aware of |
2086 | // register banks. |
2087 | const auto &DstCB = MRI.getRegClassOrRegBank(Reg: DstReg); |
2088 | if (!DstCB.isNull() && DstCB != MRI.getRegClassOrRegBank(Reg: SrcReg)) { |
2089 | SrcReg = Builder.buildCopy(Res: MRI.getType(Reg: SrcReg), Op: SrcReg).getReg(Idx: 0); |
2090 | MRI.setRegClassOrRegBank(Reg: SrcReg, RCOrRB: DstCB); |
2091 | } |
2092 | |
2093 | if (CanReuseInputDirectly) |
2094 | replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg); |
2095 | else |
2096 | Builder.buildCast(Dst: DstReg, Src: SrcReg); |
2097 | } |
2098 | MI.eraseFromParent(); |
2099 | } |
2100 | |
2101 | bool CombinerHelper::matchCombineUnmergeConstant(MachineInstr &MI, |
2102 | SmallVectorImpl<APInt> &Csts) { |
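  // Break an unmerge of a constant into one constant per lane; the first
  // result receives the low bits. Informal worked example:
  //   %c:_(s64) = G_CONSTANT i64 0x0000000100000002
  //   %lo:_(s32), %hi:_(s32) = G_UNMERGE_VALUES %c
  // -->
  //   %lo:_(s32) = G_CONSTANT i32 2
  //   %hi:_(s32) = G_CONSTANT i32 1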
2103 | unsigned SrcIdx = MI.getNumOperands() - 1; |
2104 | Register SrcReg = MI.getOperand(i: SrcIdx).getReg(); |
2105 | MachineInstr *SrcInstr = MRI.getVRegDef(Reg: SrcReg); |
2106 | if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT && |
2107 | SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT) |
2108 | return false; |
2109 | // Break down the big constant in smaller ones. |
2110 | const MachineOperand &CstVal = SrcInstr->getOperand(i: 1); |
2111 | APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT |
2112 | ? CstVal.getCImm()->getValue() |
2113 | : CstVal.getFPImm()->getValueAPF().bitcastToAPInt(); |
2114 | |
2115 | LLT Dst0Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2116 | unsigned ShiftAmt = Dst0Ty.getSizeInBits(); |
2117 | // Unmerge a constant. |
2118 | for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) { |
2119 | Csts.emplace_back(Args: Val.trunc(width: ShiftAmt)); |
2120 | Val = Val.lshr(shiftAmt: ShiftAmt); |
2121 | } |
2122 | |
2123 | return true; |
2124 | } |
2125 | |
2126 | void CombinerHelper::applyCombineUnmergeConstant(MachineInstr &MI, |
2127 | SmallVectorImpl<APInt> &Csts) { |
2128 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2129 | "Expected an unmerge" ); |
2130 | assert((MI.getNumOperands() - 1 == Csts.size()) && |
2131 | "Not enough operands to replace all defs" ); |
2132 | unsigned NumElems = MI.getNumOperands() - 1; |
2133 | for (unsigned Idx = 0; Idx < NumElems; ++Idx) { |
2134 | Register DstReg = MI.getOperand(i: Idx).getReg(); |
2135 | Builder.buildConstant(Res: DstReg, Val: Csts[Idx]); |
2136 | } |
2137 | |
2138 | MI.eraseFromParent(); |
2139 | } |
2140 | |
2141 | bool CombinerHelper::matchCombineUnmergeUndef( |
2142 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
2143 | unsigned SrcIdx = MI.getNumOperands() - 1; |
2144 | Register SrcReg = MI.getOperand(i: SrcIdx).getReg(); |
2145 | MatchInfo = [&MI](MachineIRBuilder &B) { |
2146 | unsigned NumElems = MI.getNumOperands() - 1; |
2147 | for (unsigned Idx = 0; Idx < NumElems; ++Idx) { |
2148 | Register DstReg = MI.getOperand(i: Idx).getReg(); |
2149 | B.buildUndef(Res: DstReg); |
2150 | } |
2151 | }; |
2152 | return isa<GImplicitDef>(Val: MRI.getVRegDef(Reg: SrcReg)); |
2153 | } |
2154 | |
2155 | bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc(MachineInstr &MI) { |
2156 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2157 | "Expected an unmerge" ); |
2158 | if (MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector() || |
2159 | MRI.getType(Reg: MI.getOperand(i: MI.getNumDefs()).getReg()).isVector()) |
2160 | return false; |
2161 | // Check that all the lanes are dead except the first one. |
2162 | for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) { |
2163 | if (!MRI.use_nodbg_empty(RegNo: MI.getOperand(i: Idx).getReg())) |
2164 | return false; |
2165 | } |
2166 | return true; |
2167 | } |
2168 | |
2169 | void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc(MachineInstr &MI) { |
2170 | Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg(); |
2171 | Register Dst0Reg = MI.getOperand(i: 0).getReg(); |
2172 | Builder.buildTrunc(Res: Dst0Reg, Op: SrcReg); |
2173 | MI.eraseFromParent(); |
2174 | } |
2175 | |
2176 | bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) { |
2177 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2178 | "Expected an unmerge" ); |
2179 | Register Dst0Reg = MI.getOperand(i: 0).getReg(); |
2180 | LLT Dst0Ty = MRI.getType(Reg: Dst0Reg); |
2181 | // G_ZEXT on vector applies to each lane, so it will |
2182 | // affect all destinations. Therefore we won't be able |
2183 | // to simplify the unmerge to just the first definition. |
2184 | if (Dst0Ty.isVector()) |
2185 | return false; |
2186 | Register SrcReg = MI.getOperand(i: MI.getNumDefs()).getReg(); |
2187 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2188 | if (SrcTy.isVector()) |
2189 | return false; |
2190 | |
2191 | Register ZExtSrcReg; |
2192 | if (!mi_match(R: SrcReg, MRI, P: m_GZExt(Src: m_Reg(R&: ZExtSrcReg)))) |
2193 | return false; |
2194 | |
2195 | // Finally we can replace the first definition with |
2196 | // a zext of the source if the definition is big enough to hold |
2197 | // all of ZExtSrc bits. |
2198 | LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg); |
2199 | return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits(); |
2200 | } |
2201 | |
2202 | void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) { |
2203 | assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && |
2204 | "Expected an unmerge" ); |
2205 | |
2206 | Register Dst0Reg = MI.getOperand(i: 0).getReg(); |
2207 | |
2208 | MachineInstr *ZExtInstr = |
2209 | MRI.getVRegDef(Reg: MI.getOperand(i: MI.getNumDefs()).getReg()); |
2210 | assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT && |
2211 | "Expecting a G_ZEXT" ); |
2212 | |
2213 | Register ZExtSrcReg = ZExtInstr->getOperand(i: 1).getReg(); |
2214 | LLT Dst0Ty = MRI.getType(Reg: Dst0Reg); |
2215 | LLT ZExtSrcTy = MRI.getType(Reg: ZExtSrcReg); |
2216 | |
2217 | if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) { |
2218 | Builder.buildZExt(Res: Dst0Reg, Op: ZExtSrcReg); |
2219 | } else { |
2220 | assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() && |
2221 | "ZExt src doesn't fit in destination" ); |
2222 | replaceRegWith(MRI, FromReg: Dst0Reg, ToReg: ZExtSrcReg); |
2223 | } |
2224 | |
2225 | Register ZeroReg; |
2226 | for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) { |
2227 | if (!ZeroReg) |
2228 | ZeroReg = Builder.buildConstant(Res: Dst0Ty, Val: 0).getReg(Idx: 0); |
2229 | replaceRegWith(MRI, FromReg: MI.getOperand(i: Idx).getReg(), ToReg: ZeroReg); |
2230 | } |
2231 | MI.eraseFromParent(); |
2232 | } |
2233 | |
2234 | bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI, |
2235 | unsigned TargetShiftSize, |
2236 | unsigned &ShiftVal) { |
2237 | assert((MI.getOpcode() == TargetOpcode::G_SHL || |
2238 | MI.getOpcode() == TargetOpcode::G_LSHR || |
2239 | MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift" ); |
2240 | |
2241 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
  if (Ty.isVector()) // TODO: Handle vector types.
2243 | return false; |
2244 | |
2245 | // Don't narrow further than the requested size. |
2246 | unsigned Size = Ty.getSizeInBits(); |
2247 | if (Size <= TargetShiftSize) |
2248 | return false; |
2249 | |
2250 | auto MaybeImmVal = |
2251 | getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
2252 | if (!MaybeImmVal) |
2253 | return false; |
2254 | |
2255 | ShiftVal = MaybeImmVal->Value.getSExtValue(); |
2256 | return ShiftVal >= Size / 2 && ShiftVal < Size; |
2257 | } |
2258 | |
2259 | void CombinerHelper::applyCombineShiftToUnmerge(MachineInstr &MI, |
2260 | const unsigned &ShiftVal) { |
2261 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2262 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2263 | LLT Ty = MRI.getType(Reg: SrcReg); |
2264 | unsigned Size = Ty.getSizeInBits(); |
2265 | unsigned HalfSize = Size / 2; |
2266 | assert(ShiftVal >= HalfSize); |
2267 | |
2268 | LLT HalfTy = LLT::scalar(SizeInBits: HalfSize); |
2269 | |
2270 | auto Unmerge = Builder.buildUnmerge(Res: HalfTy, Op: SrcReg); |
2271 | unsigned NarrowShiftAmt = ShiftVal - HalfSize; |
2272 | |
2273 | if (MI.getOpcode() == TargetOpcode::G_LSHR) { |
2274 | Register Narrowed = Unmerge.getReg(Idx: 1); |
2275 | |
2276 | // dst = G_LSHR s64:x, C for C >= 32 |
2277 | // => |
2278 | // lo, hi = G_UNMERGE_VALUES x |
2279 | // dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0 |
2280 | |
2281 | if (NarrowShiftAmt != 0) { |
2282 | Narrowed = Builder.buildLShr(Dst: HalfTy, Src0: Narrowed, |
2283 | Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0); |
2284 | } |
2285 | |
2286 | auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0); |
2287 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Narrowed, Zero}); |
2288 | } else if (MI.getOpcode() == TargetOpcode::G_SHL) { |
2289 | Register Narrowed = Unmerge.getReg(Idx: 0); |
2290 | // dst = G_SHL s64:x, C for C >= 32 |
2291 | // => |
2292 | // lo, hi = G_UNMERGE_VALUES x |
2293 | // dst = G_MERGE_VALUES 0, (G_SHL hi, C - 32) |
2294 | if (NarrowShiftAmt != 0) { |
2295 | Narrowed = Builder.buildShl(Dst: HalfTy, Src0: Narrowed, |
2296 | Src1: Builder.buildConstant(Res: HalfTy, Val: NarrowShiftAmt)).getReg(Idx: 0); |
2297 | } |
2298 | |
2299 | auto Zero = Builder.buildConstant(Res: HalfTy, Val: 0); |
2300 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Zero, Narrowed}); |
2301 | } else { |
2302 | assert(MI.getOpcode() == TargetOpcode::G_ASHR); |
2303 | auto Hi = Builder.buildAShr( |
2304 | Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1), |
2305 | Src1: Builder.buildConstant(Res: HalfTy, Val: HalfSize - 1)); |
2306 | |
2307 | if (ShiftVal == HalfSize) { |
2308 | // (G_ASHR i64:x, 32) -> |
2309 | // G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31) |
2310 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Unmerge.getReg(Idx: 1), Hi}); |
2311 | } else if (ShiftVal == Size - 1) { |
2312 | // Don't need a second shift. |
2313 | // (G_ASHR i64:x, 63) -> |
2314 | // %narrowed = (G_ASHR hi_32(x), 31) |
2315 | // G_MERGE_VALUES %narrowed, %narrowed |
2316 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Hi, Hi}); |
2317 | } else { |
2318 | auto Lo = Builder.buildAShr( |
2319 | Dst: HalfTy, Src0: Unmerge.getReg(Idx: 1), |
2320 | Src1: Builder.buildConstant(Res: HalfTy, Val: ShiftVal - HalfSize)); |
2321 | |
2322 | // (G_ASHR i64:x, C) ->, for C >= 32 |
2323 | // G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31) |
2324 | Builder.buildMergeLikeInstr(Res: DstReg, Ops: {Lo, Hi}); |
2325 | } |
2326 | } |
2327 | |
2328 | MI.eraseFromParent(); |
2329 | } |
2330 | |
2331 | bool CombinerHelper::tryCombineShiftToUnmerge(MachineInstr &MI, |
2332 | unsigned TargetShiftAmount) { |
2333 | unsigned ShiftAmt; |
2334 | if (matchCombineShiftToUnmerge(MI, TargetShiftSize: TargetShiftAmount, ShiftVal&: ShiftAmt)) { |
2335 | applyCombineShiftToUnmerge(MI, ShiftVal: ShiftAmt); |
2336 | return true; |
2337 | } |
2338 | |
2339 | return false; |
2340 | } |
2341 | |
2342 | bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI, Register &Reg) { |
2343 | assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR" ); |
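  // Fold an inttoptr of a ptrtoint back to the original pointer when the
  // pointer types round-trip (informal sketch):
  //   %i:_(s64) = G_PTRTOINT %p:_(p0)
  //   %q:_(p0) = G_INTTOPTR %i
  // -->
  //   %q:_(p0) = COPY %p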
2344 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2345 | LLT DstTy = MRI.getType(Reg: DstReg); |
2346 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2347 | return mi_match(R: SrcReg, MRI, |
2348 | P: m_GPtrToInt(Src: m_all_of(preds: m_SpecificType(Ty: DstTy), preds: m_Reg(R&: Reg)))); |
2349 | } |
2350 | |
2351 | void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI, Register &Reg) { |
2352 | assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR" ); |
2353 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2354 | Builder.buildCopy(Res: DstReg, Op: Reg); |
2355 | MI.eraseFromParent(); |
2356 | } |
2357 | |
2358 | void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI, Register &Reg) { |
2359 | assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT" ); |
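  // Fold a ptrtoint of an inttoptr back to the original integer, zero
  // extending or truncating to the destination size (informal sketch):
  //   %p:_(p0) = G_INTTOPTR %x:_(s32)
  //   %i:_(s64) = G_PTRTOINT %p
  // -->
  //   %i:_(s64) = G_ZEXT %x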
2360 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2361 | Builder.buildZExtOrTrunc(Res: DstReg, Op: Reg); |
2362 | MI.eraseFromParent(); |
2363 | } |
2364 | |
2365 | bool CombinerHelper::matchCombineAddP2IToPtrAdd( |
2366 | MachineInstr &MI, std::pair<Register, bool> &PtrReg) { |
2367 | assert(MI.getOpcode() == TargetOpcode::G_ADD); |
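  // Rewrite integer addition on a converted pointer as pointer arithmetic
  // (informal sketch):
  //   %i:_(s64) = G_PTRTOINT %p:_(p0)
  //   %a:_(s64) = G_ADD %i, %y
  // -->
  //   %q:_(p0) = G_PTR_ADD %p, %y
  //   %a:_(s64) = G_PTRTOINT %q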
2368 | Register LHS = MI.getOperand(i: 1).getReg(); |
2369 | Register RHS = MI.getOperand(i: 2).getReg(); |
2370 | LLT IntTy = MRI.getType(Reg: LHS); |
2371 | |
2372 | // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the |
2373 | // instruction. |
2374 | PtrReg.second = false; |
2375 | for (Register SrcReg : {LHS, RHS}) { |
2376 | if (mi_match(R: SrcReg, MRI, P: m_GPtrToInt(Src: m_Reg(R&: PtrReg.first)))) { |
2377 | // Don't handle cases where the integer is implicitly converted to the |
2378 | // pointer width. |
2379 | LLT PtrTy = MRI.getType(Reg: PtrReg.first); |
2380 | if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits()) |
2381 | return true; |
2382 | } |
2383 | |
2384 | PtrReg.second = true; |
2385 | } |
2386 | |
2387 | return false; |
2388 | } |
2389 | |
2390 | void CombinerHelper::applyCombineAddP2IToPtrAdd( |
2391 | MachineInstr &MI, std::pair<Register, bool> &PtrReg) { |
2392 | Register Dst = MI.getOperand(i: 0).getReg(); |
2393 | Register LHS = MI.getOperand(i: 1).getReg(); |
2394 | Register RHS = MI.getOperand(i: 2).getReg(); |
2395 | |
2396 | const bool DoCommute = PtrReg.second; |
2397 | if (DoCommute) |
2398 | std::swap(a&: LHS, b&: RHS); |
2399 | LHS = PtrReg.first; |
2400 | |
2401 | LLT PtrTy = MRI.getType(Reg: LHS); |
2402 | |
2403 | auto PtrAdd = Builder.buildPtrAdd(Res: PtrTy, Op0: LHS, Op1: RHS); |
2404 | Builder.buildPtrToInt(Dst, Src: PtrAdd); |
2405 | MI.eraseFromParent(); |
2406 | } |
2407 | |
2408 | bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI, |
2409 | APInt &NewCst) { |
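  // Constant fold a G_PTR_ADD whose base is an inttoptr of a constant
  // (informal sketch):
  //   %p:_(p0) = G_INTTOPTR (G_CONSTANT C1)
  //   %q:_(p0) = G_PTR_ADD %p, (G_CONSTANT C2)
  // -->
  //   %q:_(p0) = constant zext(C1) + sext(C2)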
2410 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
2411 | Register LHS = PtrAdd.getBaseReg(); |
2412 | Register RHS = PtrAdd.getOffsetReg(); |
2413 | MachineRegisterInfo &MRI = Builder.getMF().getRegInfo(); |
2414 | |
2415 | if (auto RHSCst = getIConstantVRegVal(VReg: RHS, MRI)) { |
2416 | APInt Cst; |
2417 | if (mi_match(R: LHS, MRI, P: m_GIntToPtr(Src: m_ICst(Cst)))) { |
2418 | auto DstTy = MRI.getType(Reg: PtrAdd.getReg(Idx: 0)); |
2419 | // G_INTTOPTR uses zero-extension |
2420 | NewCst = Cst.zextOrTrunc(width: DstTy.getSizeInBits()); |
2421 | NewCst += RHSCst->sextOrTrunc(width: DstTy.getSizeInBits()); |
2422 | return true; |
2423 | } |
2424 | } |
2425 | |
2426 | return false; |
2427 | } |
2428 | |
2429 | void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI, |
2430 | APInt &NewCst) { |
2431 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
2432 | Register Dst = PtrAdd.getReg(Idx: 0); |
2433 | |
2434 | Builder.buildConstant(Res: Dst, Val: NewCst); |
2435 | PtrAdd.eraseFromParent(); |
2436 | } |
2437 | |
2438 | bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI, Register &Reg) { |
2439 | assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT" ); |
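  // Fold an anyext of a trunc back to the original value when the types
  // round-trip (informal sketch):
  //   %t:_(s32) = G_TRUNC %x:_(s64)
  //   %y:_(s64) = G_ANYEXT %t
  // -->
  //   replace %y with %x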
2440 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2441 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2442 | Register OriginalSrcReg = getSrcRegIgnoringCopies(Reg: SrcReg, MRI); |
2443 | if (OriginalSrcReg.isValid()) |
2444 | SrcReg = OriginalSrcReg; |
2445 | LLT DstTy = MRI.getType(Reg: DstReg); |
2446 | return mi_match(R: SrcReg, MRI, |
2447 | P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy)))); |
2448 | } |
2449 | |
2450 | bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI, Register &Reg) { |
2451 | assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT" ); |
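  // Fold a zext of a trunc back to the original value when the truncated
  // bits are known to be zero (informal sketch):
  //   %t:_(s32) = G_TRUNC %x:_(s64)
  //   %y:_(s64) = G_ZEXT %t
  // -->
  //   replace %y with %x   ; if the top 32 bits of %x are known zero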
2452 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2453 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2454 | LLT DstTy = MRI.getType(Reg: DstReg); |
2455 | if (mi_match(R: SrcReg, MRI, |
2456 | P: m_GTrunc(Src: m_all_of(preds: m_Reg(R&: Reg), preds: m_SpecificType(Ty: DstTy))))) { |
2457 | unsigned DstSize = DstTy.getScalarSizeInBits(); |
2458 | unsigned SrcSize = MRI.getType(Reg: SrcReg).getScalarSizeInBits(); |
2459 | return KB->getKnownBits(R: Reg).countMinLeadingZeros() >= DstSize - SrcSize; |
2460 | } |
2461 | return false; |
2462 | } |
2463 | |
2464 | bool CombinerHelper::matchCombineExtOfExt( |
2465 | MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) { |
2466 | assert((MI.getOpcode() == TargetOpcode::G_ANYEXT || |
2467 | MI.getOpcode() == TargetOpcode::G_SEXT || |
2468 | MI.getOpcode() == TargetOpcode::G_ZEXT) && |
2469 | "Expected a G_[ASZ]EXT" ); |
2470 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2471 | Register OriginalSrcReg = getSrcRegIgnoringCopies(Reg: SrcReg, MRI); |
2472 | if (OriginalSrcReg.isValid()) |
2473 | SrcReg = OriginalSrcReg; |
2474 | MachineInstr *SrcMI = MRI.getVRegDef(Reg: SrcReg); |
2475 | // Match exts with the same opcode, anyext([sz]ext) and sext(zext). |
2476 | unsigned Opc = MI.getOpcode(); |
2477 | unsigned SrcOpc = SrcMI->getOpcode(); |
2478 | if (Opc == SrcOpc || |
2479 | (Opc == TargetOpcode::G_ANYEXT && |
2480 | (SrcOpc == TargetOpcode::G_SEXT || SrcOpc == TargetOpcode::G_ZEXT)) || |
2481 | (Opc == TargetOpcode::G_SEXT && SrcOpc == TargetOpcode::G_ZEXT)) { |
2482 | MatchInfo = std::make_tuple(args: SrcMI->getOperand(i: 1).getReg(), args&: SrcOpc); |
2483 | return true; |
2484 | } |
2485 | return false; |
2486 | } |
2487 | |
2488 | void CombinerHelper::applyCombineExtOfExt( |
2489 | MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) { |
2490 | assert((MI.getOpcode() == TargetOpcode::G_ANYEXT || |
2491 | MI.getOpcode() == TargetOpcode::G_SEXT || |
2492 | MI.getOpcode() == TargetOpcode::G_ZEXT) && |
2493 | "Expected a G_[ASZ]EXT" ); |
2494 | |
2495 | Register Reg = std::get<0>(t&: MatchInfo); |
2496 | unsigned SrcExtOp = std::get<1>(t&: MatchInfo); |
2497 | |
2498 | // Combine exts with the same opcode. |
2499 | if (MI.getOpcode() == SrcExtOp) { |
2500 | Observer.changingInstr(MI); |
2501 | MI.getOperand(i: 1).setReg(Reg); |
2502 | Observer.changedInstr(MI); |
2503 | return; |
2504 | } |
2505 | |
2506 | // Combine: |
2507 | // - anyext([sz]ext x) to [sz]ext x |
2508 | // - sext(zext x) to zext x |
2509 | if (MI.getOpcode() == TargetOpcode::G_ANYEXT || |
2510 | (MI.getOpcode() == TargetOpcode::G_SEXT && |
2511 | SrcExtOp == TargetOpcode::G_ZEXT)) { |
2512 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2513 | Builder.buildInstr(Opc: SrcExtOp, DstOps: {DstReg}, SrcOps: {Reg}); |
2514 | MI.eraseFromParent(); |
2515 | } |
2516 | } |
2517 | |
2518 | bool CombinerHelper::matchCombineTruncOfExt( |
2519 | MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) { |
2520 | assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC" ); |
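  // Fold a trunc of an extend (informal sketch; the result depends on the
  // relative sizes):
  //   %e:_(s64) = G_[ASZ]EXT %x:_(s16)
  //   %t:_(s32) = G_TRUNC %e
  // -->
  //   %t:_(s32) = G_[ASZ]EXT %x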
2521 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2522 | MachineInstr *SrcMI = MRI.getVRegDef(Reg: SrcReg); |
2523 | unsigned SrcOpc = SrcMI->getOpcode(); |
2524 | if (SrcOpc == TargetOpcode::G_ANYEXT || SrcOpc == TargetOpcode::G_SEXT || |
2525 | SrcOpc == TargetOpcode::G_ZEXT) { |
2526 | MatchInfo = std::make_pair(x: SrcMI->getOperand(i: 1).getReg(), y&: SrcOpc); |
2527 | return true; |
2528 | } |
2529 | return false; |
2530 | } |
2531 | |
2532 | void CombinerHelper::applyCombineTruncOfExt( |
2533 | MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) { |
2534 | assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC" ); |
2535 | Register SrcReg = MatchInfo.first; |
2536 | unsigned SrcExtOp = MatchInfo.second; |
2537 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2538 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2539 | LLT DstTy = MRI.getType(Reg: DstReg); |
2540 | if (SrcTy == DstTy) { |
2541 | MI.eraseFromParent(); |
2542 | replaceRegWith(MRI, FromReg: DstReg, ToReg: SrcReg); |
2543 | return; |
2544 | } |
2545 | if (SrcTy.getSizeInBits() < DstTy.getSizeInBits()) |
2546 | Builder.buildInstr(Opc: SrcExtOp, DstOps: {DstReg}, SrcOps: {SrcReg}); |
2547 | else |
2548 | Builder.buildTrunc(Res: DstReg, Op: SrcReg); |
2549 | MI.eraseFromParent(); |
2550 | } |
2551 | |
2552 | static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) { |
2553 | const unsigned ShiftSize = ShiftTy.getScalarSizeInBits(); |
2554 | const unsigned TruncSize = TruncTy.getScalarSizeInBits(); |
2555 | |
2556 | // ShiftTy > 32 > TruncTy -> 32 |
2557 | if (ShiftSize > 32 && TruncSize < 32) |
2558 | return ShiftTy.changeElementSize(NewEltSize: 32); |
2559 | |
2560 | // TODO: We could also reduce to 16 bits, but that's more target-dependent. |
2561 | // Some targets like it, some don't, some only like it under certain |
2562 | // conditions/processor versions, etc. |
2563 | // A TL hook might be needed for this. |
2564 | |
2565 | // Don't combine |
2566 | return ShiftTy; |
2567 | } |
2568 | |
2569 | bool CombinerHelper::matchCombineTruncOfShift( |
2570 | MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) { |
2571 | assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC" ); |
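  // Narrow a shift whose only user is a trunc by shifting at a narrower
  // type. Informal sketch for G_SHL when the shift amount is known small
  // enough:
  //   %s:_(s64) = G_SHL %x:_(s64), %k
  //   %t:_(s32) = G_TRUNC %s
  // -->
  //   %x32:_(s32) = G_TRUNC %x
  //   %t:_(s32) = G_SHL %x32, %k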
2572 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2573 | Register SrcReg = MI.getOperand(i: 1).getReg(); |
2574 | |
2575 | if (!MRI.hasOneNonDBGUse(RegNo: SrcReg)) |
2576 | return false; |
2577 | |
2578 | LLT SrcTy = MRI.getType(Reg: SrcReg); |
2579 | LLT DstTy = MRI.getType(Reg: DstReg); |
2580 | |
2581 | MachineInstr *SrcMI = getDefIgnoringCopies(Reg: SrcReg, MRI); |
2582 | const auto &TL = getTargetLowering(); |
2583 | |
2584 | LLT NewShiftTy; |
2585 | switch (SrcMI->getOpcode()) { |
2586 | default: |
2587 | return false; |
2588 | case TargetOpcode::G_SHL: { |
2589 | NewShiftTy = DstTy; |
2590 | |
2591 | // Make sure new shift amount is legal. |
2592 | KnownBits Known = KB->getKnownBits(R: SrcMI->getOperand(i: 2).getReg()); |
2593 | if (Known.getMaxValue().uge(RHS: NewShiftTy.getScalarSizeInBits())) |
2594 | return false; |
2595 | break; |
2596 | } |
2597 | case TargetOpcode::G_LSHR: |
2598 | case TargetOpcode::G_ASHR: { |
2599 | // For right shifts, we conservatively do not do the transform if the TRUNC |
2600 | // has any STORE users. The reason is that if we change the type of the |
2601 | // shift, we may break the truncstore combine. |
2602 | // |
2603 | // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)). |
2604 | for (auto &User : MRI.use_instructions(Reg: DstReg)) |
2605 | if (User.getOpcode() == TargetOpcode::G_STORE) |
2606 | return false; |
2607 | |
2608 | NewShiftTy = getMidVTForTruncRightShiftCombine(ShiftTy: SrcTy, TruncTy: DstTy); |
2609 | if (NewShiftTy == SrcTy) |
2610 | return false; |
2611 | |
2612 | // Make sure we won't lose information by truncating the high bits. |
2613 | KnownBits Known = KB->getKnownBits(R: SrcMI->getOperand(i: 2).getReg()); |
2614 | if (Known.getMaxValue().ugt(RHS: NewShiftTy.getScalarSizeInBits() - |
2615 | DstTy.getScalarSizeInBits())) |
2616 | return false; |
2617 | break; |
2618 | } |
2619 | } |
2620 | |
2621 | if (!isLegalOrBeforeLegalizer( |
2622 | Query: {SrcMI->getOpcode(), |
2623 | {NewShiftTy, TL.getPreferredShiftAmountTy(ShiftValueTy: NewShiftTy)}})) |
2624 | return false; |
2625 | |
2626 | MatchInfo = std::make_pair(x&: SrcMI, y&: NewShiftTy); |
2627 | return true; |
2628 | } |
2629 | |
2630 | void CombinerHelper::applyCombineTruncOfShift( |
2631 | MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) { |
2632 | MachineInstr *ShiftMI = MatchInfo.first; |
2633 | LLT NewShiftTy = MatchInfo.second; |
2634 | |
2635 | Register Dst = MI.getOperand(i: 0).getReg(); |
2636 | LLT DstTy = MRI.getType(Reg: Dst); |
2637 | |
2638 | Register ShiftAmt = ShiftMI->getOperand(i: 2).getReg(); |
2639 | Register ShiftSrc = ShiftMI->getOperand(i: 1).getReg(); |
2640 | ShiftSrc = Builder.buildTrunc(Res: NewShiftTy, Op: ShiftSrc).getReg(Idx: 0); |
2641 | |
2642 | Register NewShift = |
2643 | Builder |
2644 | .buildInstr(Opc: ShiftMI->getOpcode(), DstOps: {NewShiftTy}, SrcOps: {ShiftSrc, ShiftAmt}) |
2645 | .getReg(Idx: 0); |
2646 | |
2647 | if (NewShiftTy == DstTy) |
2648 | replaceRegWith(MRI, FromReg: Dst, ToReg: NewShift); |
2649 | else |
2650 | Builder.buildTrunc(Res: Dst, Op: NewShift); |
2651 | |
2652 | eraseInst(MI); |
2653 | } |
2654 | |
2655 | bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) { |
2656 | return any_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) { |
2657 | return MO.isReg() && |
2658 | getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI); |
2659 | }); |
2660 | } |
2661 | |
2662 | bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) { |
2663 | return all_of(Range: MI.explicit_uses(), P: [this](const MachineOperand &MO) { |
2664 | return !MO.isReg() || |
2665 | getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI); |
2666 | }); |
2667 | } |
2668 | |
2669 | bool CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) { |
2670 | assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR); |
2671 | ArrayRef<int> Mask = MI.getOperand(i: 3).getShuffleMask(); |
2672 | return all_of(Range&: Mask, P: [](int Elt) { return Elt < 0; }); |
2673 | } |
2674 | |
2675 | bool CombinerHelper::matchUndefStore(MachineInstr &MI) { |
2676 | assert(MI.getOpcode() == TargetOpcode::G_STORE); |
2677 | return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 0).getReg(), |
2678 | MRI); |
2679 | } |
2680 | |
2681 | bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) { |
2682 | assert(MI.getOpcode() == TargetOpcode::G_SELECT); |
2683 | return getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MI.getOperand(i: 1).getReg(), |
2684 | MRI); |
2685 | } |
2686 | |
bool CombinerHelper::matchInsertExtractVecEltOutOfBounds(MachineInstr &MI) {
2688 | assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT || |
2689 | MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) && |
2690 | "Expected an insert/extract element op" ); |
2691 | LLT VecTy = MRI.getType(Reg: MI.getOperand(i: 1).getReg()); |
2692 | unsigned IdxIdx = |
2693 | MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3; |
2694 | auto Idx = getIConstantVRegVal(VReg: MI.getOperand(i: IdxIdx).getReg(), MRI); |
2695 | if (!Idx) |
2696 | return false; |
2697 | return Idx->getZExtValue() >= VecTy.getNumElements(); |
2698 | } |
2699 | |
2700 | bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI, unsigned &OpIdx) { |
2701 | GSelect &SelMI = cast<GSelect>(Val&: MI); |
2702 | auto Cst = |
2703 | isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: SelMI.getCondReg()), MRI); |
2704 | if (!Cst) |
2705 | return false; |
2706 | OpIdx = Cst->isZero() ? 3 : 2; |
2707 | return true; |
2708 | } |
2709 | |
2710 | void CombinerHelper::eraseInst(MachineInstr &MI) { MI.eraseFromParent(); } |
2711 | |
2712 | bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1, |
2713 | const MachineOperand &MOP2) { |
2714 | if (!MOP1.isReg() || !MOP2.isReg()) |
2715 | return false; |
2716 | auto InstAndDef1 = getDefSrcRegIgnoringCopies(Reg: MOP1.getReg(), MRI); |
2717 | if (!InstAndDef1) |
2718 | return false; |
2719 | auto InstAndDef2 = getDefSrcRegIgnoringCopies(Reg: MOP2.getReg(), MRI); |
2720 | if (!InstAndDef2) |
2721 | return false; |
2722 | MachineInstr *I1 = InstAndDef1->MI; |
2723 | MachineInstr *I2 = InstAndDef2->MI; |
2724 | |
2725 | // Handle a case like this: |
2726 | // |
2727 | // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>) |
2728 | // |
2729 | // Even though %0 and %1 are produced by the same instruction they are not |
2730 | // the same values. |
2731 | if (I1 == I2) |
2732 | return MOP1.getReg() == MOP2.getReg(); |
2733 | |
2734 | // If we have an instruction which loads or stores, we can't guarantee that |
2735 | // it is identical. |
2736 | // |
2737 | // For example, we may have |
2738 | // |
2739 | // %x1 = G_LOAD %addr (load N from @somewhere) |
2740 | // ... |
2741 | // call @foo |
2742 | // ... |
2743 | // %x2 = G_LOAD %addr (load N from @somewhere) |
2744 | // ... |
2745 | // %or = G_OR %x1, %x2 |
2746 | // |
2747 | // It's possible that @foo will modify whatever lives at the address we're |
2748 | // loading from. To be safe, let's just assume that all loads and stores |
2749 | // are different (unless we have something which is guaranteed to not |
2750 | // change.) |
2751 | if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad()) |
2752 | return false; |
2753 | |
2754 | // If both instructions are loads or stores, they are equal only if both |
2755 | // are dereferenceable invariant loads with the same number of bits. |
2756 | if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) { |
2757 | GLoadStore *LS1 = dyn_cast<GLoadStore>(Val: I1); |
2758 | GLoadStore *LS2 = dyn_cast<GLoadStore>(Val: I2); |
2759 | if (!LS1 || !LS2) |
2760 | return false; |
2761 | |
2762 | if (!I2->isDereferenceableInvariantLoad() || |
2763 | (LS1->getMemSizeInBits() != LS2->getMemSizeInBits())) |
2764 | return false; |
2765 | } |
2766 | |
2767 | // Check for physical registers on the instructions first to avoid cases |
2768 | // like this: |
2769 | // |
2770 | // %a = COPY $physreg |
2771 | // ... |
2772 | // SOMETHING implicit-def $physreg |
2773 | // ... |
2774 | // %b = COPY $physreg |
2775 | // |
2776 | // These copies are not equivalent. |
2777 | if (any_of(Range: I1->uses(), P: [](const MachineOperand &MO) { |
2778 | return MO.isReg() && MO.getReg().isPhysical(); |
2779 | })) { |
2780 | // Check if we have a case like this: |
2781 | // |
2782 | // %a = COPY $physreg |
2783 | // %b = COPY %a |
2784 | // |
2785 | // In this case, I1 and I2 will both be equal to %a = COPY $physreg. |
2786 | // From that, we know that they must have the same value, since they must |
2787 | // have come from the same COPY. |
2788 | return I1->isIdenticalTo(Other: *I2); |
2789 | } |
2790 | |
2791 | // We don't have any physical registers, so we don't necessarily need the |
2792 | // same vreg defs. |
2793 | // |
2794 | // On the off-chance that there's some target instruction feeding into the |
2795 | // instruction, let's use produceSameValue instead of isIdenticalTo. |
2796 | if (Builder.getTII().produceSameValue(MI0: *I1, MI1: *I2, MRI: &MRI)) { |
2797 | // Handle instructions with multiple defs that produce same values. Values |
2798 | // are same for operands with same index. |
2799 | // %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>) |
2800 | // %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>) |
2801 | // I1 and I2 are different instructions but produce same values, |
2802 | // %1 and %6 are same, %1 and %7 are not the same value. |
2803 | return I1->findRegisterDefOperandIdx(Reg: InstAndDef1->Reg, /*TRI=*/nullptr) == |
2804 | I2->findRegisterDefOperandIdx(Reg: InstAndDef2->Reg, /*TRI=*/nullptr); |
2805 | } |
2806 | return false; |
2807 | } |
2808 | |
2809 | bool CombinerHelper::matchConstantOp(const MachineOperand &MOP, int64_t C) { |
2810 | if (!MOP.isReg()) |
2811 | return false; |
2812 | auto *MI = MRI.getVRegDef(Reg: MOP.getReg()); |
2813 | auto MaybeCst = isConstantOrConstantSplatVector(MI&: *MI, MRI); |
2814 | return MaybeCst && MaybeCst->getBitWidth() <= 64 && |
2815 | MaybeCst->getSExtValue() == C; |
2816 | } |
2817 | |
2818 | bool CombinerHelper::matchConstantFPOp(const MachineOperand &MOP, double C) { |
2819 | if (!MOP.isReg()) |
2820 | return false; |
2821 | std::optional<FPValueAndVReg> MaybeCst; |
2822 | if (!mi_match(R: MOP.getReg(), MRI, P: m_GFCstOrSplat(FPValReg&: MaybeCst))) |
2823 | return false; |
2824 | |
2825 | return MaybeCst->Value.isExactlyValue(V: C); |
2826 | } |
2827 | |
2828 | void CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI, |
2829 | unsigned OpIdx) { |
2830 | assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?" ); |
2831 | Register OldReg = MI.getOperand(i: 0).getReg(); |
2832 | Register Replacement = MI.getOperand(i: OpIdx).getReg(); |
2833 | assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?" ); |
2834 | MI.eraseFromParent(); |
2835 | replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement); |
2836 | } |
2837 | |
2838 | void CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI, |
2839 | Register Replacement) { |
2840 | assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?" ); |
2841 | Register OldReg = MI.getOperand(i: 0).getReg(); |
2842 | assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?" ); |
2843 | MI.eraseFromParent(); |
2844 | replaceRegWith(MRI, FromReg: OldReg, ToReg: Replacement); |
2845 | } |
2846 | |
2847 | bool CombinerHelper::matchConstantLargerBitWidth(MachineInstr &MI, |
2848 | unsigned ConstIdx) { |
2849 | Register ConstReg = MI.getOperand(i: ConstIdx).getReg(); |
2850 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2851 | |
2852 | // Get the shift amount |
2853 | auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI); |
2854 | if (!VRegAndVal) |
2855 | return false; |
2856 | |
  // Return true if the shift amount is >= the bitwidth.
2858 | return (VRegAndVal->Value.uge(RHS: DstTy.getSizeInBits())); |
2859 | } |
2860 | |
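// E.g. a 32-bit G_FSHL with a constant shift amount of 37 is rebuilt with a
// shift amount of 37 % 32 = 5 (illustrative; the new constant is emitted as
// a fresh G_CONSTANT of the shift amount's type).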
2861 | void CombinerHelper::applyFunnelShiftConstantModulo(MachineInstr &MI) { |
2862 | assert((MI.getOpcode() == TargetOpcode::G_FSHL || |
2863 | MI.getOpcode() == TargetOpcode::G_FSHR) && |
2864 | "This is not a funnel shift operation" ); |
2865 | |
2866 | Register ConstReg = MI.getOperand(i: 3).getReg(); |
2867 | LLT ConstTy = MRI.getType(Reg: ConstReg); |
2868 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
2869 | |
2870 | auto VRegAndVal = getIConstantVRegValWithLookThrough(VReg: ConstReg, MRI); |
2871 | assert((VRegAndVal) && "Value is not a constant" ); |
2872 | |
2873 | // Calculate the new Shift Amount = Old Shift Amount % BitWidth |
2874 | APInt NewConst = VRegAndVal->Value.urem( |
2875 | RHS: APInt(ConstTy.getSizeInBits(), DstTy.getScalarSizeInBits())); |
2876 | |
2877 | auto NewConstInstr = Builder.buildConstant(Res: ConstTy, Val: NewConst.getZExtValue()); |
2878 | Builder.buildInstr( |
2879 | Opc: MI.getOpcode(), DstOps: {MI.getOperand(i: 0)}, |
2880 | SrcOps: {MI.getOperand(i: 1), MI.getOperand(i: 2), NewConstInstr.getReg(Idx: 0)}); |
2881 | |
2882 | MI.eraseFromParent(); |
2883 | } |
2884 | |
2885 | bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) { |
2886 | assert(MI.getOpcode() == TargetOpcode::G_SELECT); |
2887 | // Match (cond ? x : x) |
2888 | return matchEqualDefs(MOP1: MI.getOperand(i: 2), MOP2: MI.getOperand(i: 3)) && |
2889 | canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 2).getReg(), |
2890 | MRI); |
2891 | } |
2892 | |
2893 | bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) { |
2894 | return matchEqualDefs(MOP1: MI.getOperand(i: 1), MOP2: MI.getOperand(i: 2)) && |
2895 | canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: 1).getReg(), |
2896 | MRI); |
2897 | } |
2898 | |
2899 | bool CombinerHelper::matchOperandIsZero(MachineInstr &MI, unsigned OpIdx) { |
2900 | return matchConstantOp(MOP: MI.getOperand(i: OpIdx), C: 0) && |
2901 | canReplaceReg(DstReg: MI.getOperand(i: 0).getReg(), SrcReg: MI.getOperand(i: OpIdx).getReg(), |
2902 | MRI); |
2903 | } |
2904 | |
2905 | bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI, unsigned OpIdx) { |
2906 | MachineOperand &MO = MI.getOperand(i: OpIdx); |
2907 | return MO.isReg() && |
2908 | getOpcodeDef(Opcode: TargetOpcode::G_IMPLICIT_DEF, Reg: MO.getReg(), MRI); |
2909 | } |
2910 | |
2911 | bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI, |
2912 | unsigned OpIdx) { |
2913 | MachineOperand &MO = MI.getOperand(i: OpIdx); |
2914 | return isKnownToBeAPowerOfTwo(Val: MO.getReg(), MRI, KnownBits: KB); |
2915 | } |
2916 | |
2917 | void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, double C) { |
2918 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
2919 | Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: C); |
2920 | MI.eraseFromParent(); |
2921 | } |
2922 | |
2923 | void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, int64_t C) { |
2924 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
2925 | Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C); |
2926 | MI.eraseFromParent(); |
2927 | } |
2928 | |
2929 | void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) { |
2930 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
2931 | Builder.buildConstant(Res: MI.getOperand(i: 0), Val: C); |
2932 | MI.eraseFromParent(); |
2933 | } |
2934 | |
2935 | void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, |
2936 | ConstantFP *CFP) { |
2937 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
2938 | Builder.buildFConstant(Res: MI.getOperand(i: 0), Val: CFP->getValueAPF()); |
2939 | MI.eraseFromParent(); |
2940 | } |
2941 | |
2942 | void CombinerHelper::replaceInstWithUndef(MachineInstr &MI) { |
2943 | assert(MI.getNumDefs() == 1 && "Expected only one def?" ); |
2944 | Builder.buildUndef(Res: MI.getOperand(i: 0)); |
2945 | MI.eraseFromParent(); |
2946 | } |
2947 | |
2948 | bool CombinerHelper::matchSimplifyAddToSub( |
2949 | MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) { |
2950 | Register LHS = MI.getOperand(i: 1).getReg(); |
2951 | Register RHS = MI.getOperand(i: 2).getReg(); |
2952 | Register &NewLHS = std::get<0>(t&: MatchInfo); |
2953 | Register &NewRHS = std::get<1>(t&: MatchInfo); |
2954 | |
2955 | // Helper lambda to check for opportunities for |
2956 | // ((0-A) + B) -> B - A |
2957 | // (A + (0-B)) -> A - B |
2958 | auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) { |
2959 | if (!mi_match(R: MaybeSub, MRI, P: m_Neg(Src: m_Reg(R&: NewRHS)))) |
2960 | return false; |
2961 | NewLHS = MaybeNewLHS; |
2962 | return true; |
2963 | }; |
2964 | |
2965 | return CheckFold(LHS, RHS) || CheckFold(RHS, LHS); |
2966 | } |
2967 | |
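// Fold a chain of constant-indexed G_INSERT_VECTOR_ELTs into a single
// G_BUILD_VECTOR, e.g. (an illustrative MIR sketch; the constant index
// operands are shown inline for brevity):
//
//   %u:_(<2 x s32>) = G_IMPLICIT_DEF
//   %i:_(<2 x s32>) = G_INSERT_VECTOR_ELT %u, %a:_(s32), 0
//   %v:_(<2 x s32>) = G_INSERT_VECTOR_ELT %i, %b:_(s32), 1
// ->
//   %v:_(<2 x s32>) = G_BUILD_VECTOR %a:_(s32), %b:_(s32)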
2968 | bool CombinerHelper::matchCombineInsertVecElts( |
2969 | MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) { |
2970 | assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT && |
2971 | "Invalid opcode" ); |
2972 | Register DstReg = MI.getOperand(i: 0).getReg(); |
2973 | LLT DstTy = MRI.getType(Reg: DstReg); |
2974 | assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?" ); |
2975 | unsigned NumElts = DstTy.getNumElements(); |
2976 | // If this MI is part of a sequence of insert_vec_elts, then |
2977 | // don't do the combine in the middle of the sequence. |
2978 | if (MRI.hasOneUse(RegNo: DstReg) && MRI.use_instr_begin(RegNo: DstReg)->getOpcode() == |
2979 | TargetOpcode::G_INSERT_VECTOR_ELT) |
2980 | return false; |
2981 | MachineInstr *CurrInst = &MI; |
2982 | MachineInstr *TmpInst; |
2983 | int64_t IntImm; |
2984 | Register TmpReg; |
2985 | MatchInfo.resize(N: NumElts); |
2986 | while (mi_match( |
2987 | R: CurrInst->getOperand(i: 0).getReg(), MRI, |
2988 | P: m_GInsertVecElt(Src0: m_MInstr(MI&: TmpInst), Src1: m_Reg(R&: TmpReg), Src2: m_ICst(Cst&: IntImm)))) { |
2989 | if (IntImm >= NumElts || IntImm < 0) |
2990 | return false; |
2991 | if (!MatchInfo[IntImm]) |
2992 | MatchInfo[IntImm] = TmpReg; |
2993 | CurrInst = TmpInst; |
2994 | } |
2995 | // Variable index. |
2996 | if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT) |
2997 | return false; |
2998 | if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) { |
2999 | for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) { |
3000 | if (!MatchInfo[I - 1].isValid()) |
3001 | MatchInfo[I - 1] = TmpInst->getOperand(i: I).getReg(); |
3002 | } |
3003 | return true; |
3004 | } |
3005 | // If we didn't end in a G_IMPLICIT_DEF and the source is not fully |
3006 | // overwritten, bail out. |
3007 | return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF || |
3008 | all_of(Range&: MatchInfo, P: [](Register Reg) { return !!Reg; }); |
3009 | } |
3010 | |
3011 | void CombinerHelper::applyCombineInsertVecElts( |
3012 | MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) { |
3013 | Register UndefReg; |
3014 | auto GetUndef = [&]() { |
3015 | if (UndefReg) |
3016 | return UndefReg; |
3017 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
3018 | UndefReg = Builder.buildUndef(Res: DstTy.getScalarType()).getReg(Idx: 0); |
3019 | return UndefReg; |
3020 | }; |
3021 | for (unsigned I = 0; I < MatchInfo.size(); ++I) { |
3022 | if (!MatchInfo[I]) |
3023 | MatchInfo[I] = GetUndef(); |
3024 | } |
3025 | Builder.buildBuildVector(Res: MI.getOperand(i: 0).getReg(), Ops: MatchInfo); |
3026 | MI.eraseFromParent(); |
3027 | } |
3028 | |
3029 | void CombinerHelper::applySimplifyAddToSub( |
3030 | MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) { |
3031 | Register SubLHS, SubRHS; |
3032 | std::tie(args&: SubLHS, args&: SubRHS) = MatchInfo; |
3033 | Builder.buildSub(Dst: MI.getOperand(i: 0).getReg(), Src0: SubLHS, Src1: SubRHS); |
3034 | MI.eraseFromParent(); |
3035 | } |
3036 | |
3037 | bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands( |
3038 | MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) { |
3039 | // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ... |
3040 | // |
3041 | // Creates the new hand + logic instruction (but does not insert them.) |
3042 | // |
3043 | // On success, MatchInfo is populated with the new instructions. These are |
3044 | // inserted in applyHoistLogicOpWithSameOpcodeHands. |
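  //
  // E.g. for a G_AND of two single-use G_SEXTs (illustrative):
  //
  //   %lx:_(s64) = G_SEXT %x:_(s32)
  //   %ly:_(s64) = G_SEXT %y:_(s32)
  //   %r:_(s64) = G_AND %lx, %ly
  // ->
  //   %a:_(s32) = G_AND %x, %y
  //   %r:_(s64) = G_SEXT %a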
3045 | unsigned LogicOpcode = MI.getOpcode(); |
3046 | assert(LogicOpcode == TargetOpcode::G_AND || |
3047 | LogicOpcode == TargetOpcode::G_OR || |
3048 | LogicOpcode == TargetOpcode::G_XOR); |
3049 | MachineIRBuilder MIB(MI); |
3050 | Register Dst = MI.getOperand(i: 0).getReg(); |
3051 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
3052 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
3053 | |
3054 | // Don't recompute anything. |
3055 | if (!MRI.hasOneNonDBGUse(RegNo: LHSReg) || !MRI.hasOneNonDBGUse(RegNo: RHSReg)) |
3056 | return false; |
3057 | |
3058 | // Make sure we have (hand x, ...), (hand y, ...) |
3059 | MachineInstr *LeftHandInst = getDefIgnoringCopies(Reg: LHSReg, MRI); |
3060 | MachineInstr *RightHandInst = getDefIgnoringCopies(Reg: RHSReg, MRI); |
3061 | if (!LeftHandInst || !RightHandInst) |
3062 | return false; |
3063 | unsigned HandOpcode = LeftHandInst->getOpcode(); |
3064 | if (HandOpcode != RightHandInst->getOpcode()) |
3065 | return false; |
3066 | if (!LeftHandInst->getOperand(i: 1).isReg() || |
3067 | !RightHandInst->getOperand(i: 1).isReg()) |
3068 | return false; |
3069 | |
3070 | // Make sure the types match up, and if we're doing this post-legalization, |
3071 | // we end up with legal types. |
3072 | Register X = LeftHandInst->getOperand(i: 1).getReg(); |
3073 | Register Y = RightHandInst->getOperand(i: 1).getReg(); |
3074 | LLT XTy = MRI.getType(Reg: X); |
3075 | LLT YTy = MRI.getType(Reg: Y); |
3076 | if (!XTy.isValid() || XTy != YTy) |
3077 | return false; |
3078 | |
3079 | // Optional extra source register. |
3080 | Register ExtraHandOpSrcReg; |
3081 | switch (HandOpcode) { |
3082 | default: |
3083 | return false; |
3084 | case TargetOpcode::G_ANYEXT: |
3085 | case TargetOpcode::G_SEXT: |
3086 | case TargetOpcode::G_ZEXT: { |
3087 | // Match: logic (ext X), (ext Y) --> ext (logic X, Y) |
3088 | break; |
3089 | } |
3090 | case TargetOpcode::G_AND: |
3091 | case TargetOpcode::G_ASHR: |
3092 | case TargetOpcode::G_LSHR: |
3093 | case TargetOpcode::G_SHL: { |
3094 | // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z |
3095 | MachineOperand &ZOp = LeftHandInst->getOperand(i: 2); |
3096 | if (!matchEqualDefs(MOP1: ZOp, MOP2: RightHandInst->getOperand(i: 2))) |
3097 | return false; |
3098 | ExtraHandOpSrcReg = ZOp.getReg(); |
3099 | break; |
3100 | } |
3101 | } |
3102 | |
3103 | if (!isLegalOrBeforeLegalizer(Query: {LogicOpcode, {XTy, YTy}})) |
3104 | return false; |
3105 | |
3106 | // Record the steps to build the new instructions. |
3107 | // |
3108 | // Steps to build (logic x, y) |
3109 | auto NewLogicDst = MRI.createGenericVirtualRegister(Ty: XTy); |
3110 | OperandBuildSteps LogicBuildSteps = { |
3111 | [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: NewLogicDst); }, |
3112 | [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: X); }, |
3113 | [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: Y); }}; |
3114 | InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps); |
3115 | |
3116 | // Steps to build hand (logic x, y), ...z |
3117 | OperandBuildSteps HandBuildSteps = { |
3118 | [=](MachineInstrBuilder &MIB) { MIB.addDef(RegNo: Dst); }, |
3119 | [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: NewLogicDst); }}; |
3120 | if (ExtraHandOpSrcReg.isValid()) |
3121 | HandBuildSteps.push_back( |
3122 | Elt: [=](MachineInstrBuilder &MIB) { MIB.addReg(RegNo: ExtraHandOpSrcReg); }); |
3123 | InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps); |
3124 | |
3125 | MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps}); |
3126 | return true; |
3127 | } |
3128 | |
3129 | void CombinerHelper::applyBuildInstructionSteps( |
3130 | MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) { |
3131 | assert(MatchInfo.InstrsToBuild.size() && |
3132 | "Expected at least one instr to build?" ); |
3133 | for (auto &InstrToBuild : MatchInfo.InstrsToBuild) { |
3134 | assert(InstrToBuild.Opcode && "Expected a valid opcode?" ); |
3135 | assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?" ); |
3136 | MachineInstrBuilder Instr = Builder.buildInstr(Opcode: InstrToBuild.Opcode); |
3137 | for (auto &OperandFn : InstrToBuild.OperandFns) |
3138 | OperandFn(Instr); |
3139 | } |
3140 | MI.eraseFromParent(); |
3141 | } |
3142 | |
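// Fold (ashr (shl x, C), C) into a sign extension of the low bits that
// survive the shift pair. E.g. on s32 with C = 24 (illustrative):
//
//   %s:_(s32) = G_SHL %x, 24
//   %a:_(s32) = G_ASHR %s, 24   -->   %a:_(s32) = G_SEXT_INREG %x, 8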
3143 | bool CombinerHelper::matchAshrShlToSextInreg( |
3144 | MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) { |
3145 | assert(MI.getOpcode() == TargetOpcode::G_ASHR); |
3146 | int64_t ShlCst, AshrCst; |
3147 | Register Src; |
3148 | if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI, |
3149 | P: m_GAShr(L: m_GShl(L: m_Reg(R&: Src), R: m_ICstOrSplat(Cst&: ShlCst)), |
3150 | R: m_ICstOrSplat(Cst&: AshrCst)))) |
3151 | return false; |
3152 | if (ShlCst != AshrCst) |
3153 | return false; |
3154 | if (!isLegalOrBeforeLegalizer( |
3155 | Query: {TargetOpcode::G_SEXT_INREG, {MRI.getType(Reg: Src)}})) |
3156 | return false; |
3157 | MatchInfo = std::make_tuple(args&: Src, args&: ShlCst); |
3158 | return true; |
3159 | } |
3160 | |
3161 | void CombinerHelper::applyAshShlToSextInreg( |
3162 | MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) { |
3163 | assert(MI.getOpcode() == TargetOpcode::G_ASHR); |
3164 | Register Src; |
3165 | int64_t ShiftAmt; |
3166 | std::tie(args&: Src, args&: ShiftAmt) = MatchInfo; |
3167 | unsigned Size = MRI.getType(Reg: Src).getScalarSizeInBits(); |
3168 | Builder.buildSExtInReg(Res: MI.getOperand(i: 0).getReg(), Op: Src, ImmOp: Size - ShiftAmt); |
3169 | MI.eraseFromParent(); |
3170 | } |
3171 | |
3172 | /// and(and(x, C1), C2) -> C1&C2 ? and(x, C1&C2) : 0 |
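/// E.g. and(and(x, 0xFE), 0x0F) -> and(x, 0x0E), while and(and(x, 0xF0),
/// 0x0F) folds to the constant 0 (illustrative constants).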
3173 | bool CombinerHelper::matchOverlappingAnd( |
3174 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
3175 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
3176 | |
3177 | Register Dst = MI.getOperand(i: 0).getReg(); |
3178 | LLT Ty = MRI.getType(Reg: Dst); |
3179 | |
3180 | Register R; |
3181 | int64_t C1; |
3182 | int64_t C2; |
3183 | if (!mi_match( |
3184 | R: Dst, MRI, |
3185 | P: m_GAnd(L: m_GAnd(L: m_Reg(R), R: m_ICst(Cst&: C1)), R: m_ICst(Cst&: C2)))) |
3186 | return false; |
3187 | |
3188 | MatchInfo = [=](MachineIRBuilder &B) { |
3189 | if (C1 & C2) { |
3190 | B.buildAnd(Dst, Src0: R, Src1: B.buildConstant(Res: Ty, Val: C1 & C2)); |
3191 | return; |
3192 | } |
3193 | auto Zero = B.buildConstant(Res: Ty, Val: 0); |
3194 | replaceRegWith(MRI, FromReg: Dst, ToReg: Zero->getOperand(i: 0).getReg()); |
3195 | }; |
3196 | return true; |
3197 | } |
3198 | |
3199 | bool CombinerHelper::matchRedundantAnd(MachineInstr &MI, |
3200 | Register &Replacement) { |
3201 | // Given |
3202 | // |
3203 | // %y:_(sN) = G_SOMETHING |
3204 | // %x:_(sN) = G_SOMETHING |
3205 | // %res:_(sN) = G_AND %x, %y |
3206 | // |
3207 | // Eliminate the G_AND when it is known that x & y == x or x & y == y. |
3208 | // |
3209 | // Patterns like this can appear as a result of legalization. E.g. |
3210 | // |
3211 | // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y |
3212 | // %one:_(s32) = G_CONSTANT i32 1 |
3213 | // %and:_(s32) = G_AND %cmp, %one |
3214 | // |
3215 | // In this case, G_ICMP only produces a single bit, so x & 1 == x. |
3216 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
3217 | if (!KB) |
3218 | return false; |
3219 | |
3220 | Register AndDst = MI.getOperand(i: 0).getReg(); |
3221 | Register LHS = MI.getOperand(i: 1).getReg(); |
3222 | Register RHS = MI.getOperand(i: 2).getReg(); |
3223 | KnownBits LHSBits = KB->getKnownBits(R: LHS); |
3224 | KnownBits RHSBits = KB->getKnownBits(R: RHS); |
3225 | |
3226 | // Check that x & Mask == x. |
3227 | // x & 1 == x, always |
3228 | // x & 0 == x, only if x is also 0 |
3229 | // Meaning Mask has no effect if every bit is either one in Mask or zero in x. |
3230 | // |
3231 | // Check if we can replace AndDst with the LHS of the G_AND |
3232 | if (canReplaceReg(DstReg: AndDst, SrcReg: LHS, MRI) && |
3233 | (LHSBits.Zero | RHSBits.One).isAllOnes()) { |
3234 | Replacement = LHS; |
3235 | return true; |
3236 | } |
3237 | |
3238 | // Check if we can replace AndDst with the RHS of the G_AND |
3239 | if (canReplaceReg(DstReg: AndDst, SrcReg: RHS, MRI) && |
3240 | (LHSBits.One | RHSBits.Zero).isAllOnes()) { |
3241 | Replacement = RHS; |
3242 | return true; |
3243 | } |
3244 | |
3245 | return false; |
3246 | } |
3247 | |
3248 | bool CombinerHelper::matchRedundantOr(MachineInstr &MI, Register &Replacement) { |
3249 | // Given |
3250 | // |
3251 | // %y:_(sN) = G_SOMETHING |
3252 | // %x:_(sN) = G_SOMETHING |
3253 | // %res:_(sN) = G_OR %x, %y |
3254 | // |
3255 | // Eliminate the G_OR when it is known that x | y == x or x | y == y. |
3256 | assert(MI.getOpcode() == TargetOpcode::G_OR); |
3257 | if (!KB) |
3258 | return false; |
3259 | |
3260 | Register OrDst = MI.getOperand(i: 0).getReg(); |
3261 | Register LHS = MI.getOperand(i: 1).getReg(); |
3262 | Register RHS = MI.getOperand(i: 2).getReg(); |
3263 | KnownBits LHSBits = KB->getKnownBits(R: LHS); |
3264 | KnownBits RHSBits = KB->getKnownBits(R: RHS); |
3265 | |
3266 | // Check that x | Mask == x. |
3267 | // x | 0 == x, always |
3268 | // x | 1 == x, only if x is also 1 |
3269 | // Meaning Mask has no effect if every bit is either zero in Mask or one in x. |
3270 | // |
3271 | // Check if we can replace OrDst with the LHS of the G_OR |
3272 | if (canReplaceReg(DstReg: OrDst, SrcReg: LHS, MRI) && |
3273 | (LHSBits.One | RHSBits.Zero).isAllOnes()) { |
3274 | Replacement = LHS; |
3275 | return true; |
3276 | } |
3277 | |
3278 | // Check if we can replace OrDst with the RHS of the G_OR |
3279 | if (canReplaceReg(DstReg: OrDst, SrcReg: RHS, MRI) && |
3280 | (LHSBits.Zero | RHSBits.One).isAllOnes()) { |
3281 | Replacement = RHS; |
3282 | return true; |
3283 | } |
3284 | |
3285 | return false; |
3286 | } |
3287 | |
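// E.g. %x below is known to carry 25 sign bits, so a G_SEXT_INREG of %x from
// 8 bits (which only needs 32 - 8 + 1 = 25 of them) is a no-op
// (illustrative):
//
//   %x:_(s32) = G_SEXT %y:_(s8)
//   %r:_(s32) = G_SEXT_INREG %x, 8   // %r can be replaced by %x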
3288 | bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) { |
3289 | // If the input is already sign extended, just drop the extension. |
3290 | Register Src = MI.getOperand(i: 1).getReg(); |
3291 | unsigned ExtBits = MI.getOperand(i: 2).getImm(); |
3292 | unsigned TypeSize = MRI.getType(Reg: Src).getScalarSizeInBits(); |
3293 | return KB->computeNumSignBits(R: Src) >= (TypeSize - ExtBits + 1); |
3294 | } |
3295 | |
3296 | static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits, |
3297 | int64_t Cst, bool IsVector, bool IsFP) { |
3298 | // For i1, Cst will always be -1 regardless of boolean contents. |
3299 | return (ScalarSizeBits == 1 && Cst == -1) || |
3300 | isConstTrueVal(TLI, Val: Cst, IsVector, IsFP); |
3301 | } |
3302 | |
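// E.g. xor'ing a tree of single-use compares with "true" pushes the negation
// down to the compare predicates via De Morgan's laws (an illustrative MIR
// sketch; the xor'd constant is really a G_CONSTANT operand):
//
//   %c1:_(s1) = G_ICMP intpred(eq), %a, %b
//   %c2:_(s1) = G_ICMP intpred(slt), %c, %d
//   %t:_(s1) = G_AND %c1, %c2
//   %n:_(s1) = G_XOR %t, -1
// ->
//   %c1:_(s1) = G_ICMP intpred(ne), %a, %b
//   %c2:_(s1) = G_ICMP intpred(sge), %c, %d
//   %n:_(s1) = G_OR %c1, %c2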
3303 | bool CombinerHelper::matchNotCmp(MachineInstr &MI, |
3304 | SmallVectorImpl<Register> &RegsToNegate) { |
3305 | assert(MI.getOpcode() == TargetOpcode::G_XOR); |
3306 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
3307 | const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering(); |
3308 | Register XorSrc; |
3309 | Register CstReg; |
3310 | // We match xor(src, true) here. |
3311 | if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI, |
3312 | P: m_GXor(L: m_Reg(R&: XorSrc), R: m_Reg(R&: CstReg)))) |
3313 | return false; |
3314 | |
3315 | if (!MRI.hasOneNonDBGUse(RegNo: XorSrc)) |
3316 | return false; |
3317 | |
3318 | // Check that XorSrc is the root of a tree of comparisons combined with ANDs |
  // and ORs. The suffix of RegsToNegate starting from index I is used as a
  // work list of tree nodes to visit.
3321 | RegsToNegate.push_back(Elt: XorSrc); |
3322 | // Remember whether the comparisons are all integer or all floating point. |
3323 | bool IsInt = false; |
3324 | bool IsFP = false; |
3325 | for (unsigned I = 0; I < RegsToNegate.size(); ++I) { |
3326 | Register Reg = RegsToNegate[I]; |
3327 | if (!MRI.hasOneNonDBGUse(RegNo: Reg)) |
3328 | return false; |
3329 | MachineInstr *Def = MRI.getVRegDef(Reg); |
3330 | switch (Def->getOpcode()) { |
3331 | default: |
3332 | // Don't match if the tree contains anything other than ANDs, ORs and |
3333 | // comparisons. |
3334 | return false; |
3335 | case TargetOpcode::G_ICMP: |
3336 | if (IsFP) |
3337 | return false; |
3338 | IsInt = true; |
3339 | // When we apply the combine we will invert the predicate. |
3340 | break; |
3341 | case TargetOpcode::G_FCMP: |
3342 | if (IsInt) |
3343 | return false; |
3344 | IsFP = true; |
3345 | // When we apply the combine we will invert the predicate. |
3346 | break; |
3347 | case TargetOpcode::G_AND: |
3348 | case TargetOpcode::G_OR: |
3349 | // Implement De Morgan's laws: |
3350 | // ~(x & y) -> ~x | ~y |
3351 | // ~(x | y) -> ~x & ~y |
3352 | // When we apply the combine we will change the opcode and recursively |
3353 | // negate the operands. |
3354 | RegsToNegate.push_back(Elt: Def->getOperand(i: 1).getReg()); |
3355 | RegsToNegate.push_back(Elt: Def->getOperand(i: 2).getReg()); |
3356 | break; |
3357 | } |
3358 | } |
3359 | |
3360 | // Now we know whether the comparisons are integer or floating point, check |
3361 | // the constant in the xor. |
3362 | int64_t Cst; |
3363 | if (Ty.isVector()) { |
3364 | MachineInstr *CstDef = MRI.getVRegDef(Reg: CstReg); |
3365 | auto MaybeCst = getIConstantSplatSExtVal(MI: *CstDef, MRI); |
3366 | if (!MaybeCst) |
3367 | return false; |
3368 | if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getScalarSizeInBits(), Cst: *MaybeCst, IsVector: true, IsFP)) |
3369 | return false; |
3370 | } else { |
3371 | if (!mi_match(R: CstReg, MRI, P: m_ICst(Cst))) |
3372 | return false; |
3373 | if (!isConstValidTrue(TLI, ScalarSizeBits: Ty.getSizeInBits(), Cst, IsVector: false, IsFP)) |
3374 | return false; |
3375 | } |
3376 | |
3377 | return true; |
3378 | } |
3379 | |
3380 | void CombinerHelper::applyNotCmp(MachineInstr &MI, |
3381 | SmallVectorImpl<Register> &RegsToNegate) { |
3382 | for (Register Reg : RegsToNegate) { |
3383 | MachineInstr *Def = MRI.getVRegDef(Reg); |
3384 | Observer.changingInstr(MI&: *Def); |
3385 | // For each comparison, invert the opcode. For each AND and OR, change the |
3386 | // opcode. |
3387 | switch (Def->getOpcode()) { |
3388 | default: |
3389 | llvm_unreachable("Unexpected opcode" ); |
3390 | case TargetOpcode::G_ICMP: |
3391 | case TargetOpcode::G_FCMP: { |
3392 | MachineOperand &PredOp = Def->getOperand(i: 1); |
3393 | CmpInst::Predicate NewP = CmpInst::getInversePredicate( |
3394 | pred: (CmpInst::Predicate)PredOp.getPredicate()); |
3395 | PredOp.setPredicate(NewP); |
3396 | break; |
3397 | } |
3398 | case TargetOpcode::G_AND: |
3399 | Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_OR)); |
3400 | break; |
3401 | case TargetOpcode::G_OR: |
3402 | Def->setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND)); |
3403 | break; |
3404 | } |
3405 | Observer.changedInstr(MI&: *Def); |
3406 | } |
3407 | |
3408 | replaceRegWith(MRI, FromReg: MI.getOperand(i: 0).getReg(), ToReg: MI.getOperand(i: 1).getReg()); |
3409 | MI.eraseFromParent(); |
3410 | } |
3411 | |
3412 | bool CombinerHelper::matchXorOfAndWithSameReg( |
3413 | MachineInstr &MI, std::pair<Register, Register> &MatchInfo) { |
3414 | // Match (xor (and x, y), y) (or any of its commuted cases) |
3415 | assert(MI.getOpcode() == TargetOpcode::G_XOR); |
3416 | Register &X = MatchInfo.first; |
3417 | Register &Y = MatchInfo.second; |
3418 | Register AndReg = MI.getOperand(i: 1).getReg(); |
3419 | Register SharedReg = MI.getOperand(i: 2).getReg(); |
3420 | |
3421 | // Find a G_AND on either side of the G_XOR. |
3422 | // Look for one of |
3423 | // |
3424 | // (xor (and x, y), SharedReg) |
3425 | // (xor SharedReg, (and x, y)) |
3426 | if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y)))) { |
3427 | std::swap(a&: AndReg, b&: SharedReg); |
3428 | if (!mi_match(R: AndReg, MRI, P: m_GAnd(L: m_Reg(R&: X), R: m_Reg(R&: Y)))) |
3429 | return false; |
3430 | } |
3431 | |
3432 | // Only do this if we'll eliminate the G_AND. |
3433 | if (!MRI.hasOneNonDBGUse(RegNo: AndReg)) |
3434 | return false; |
3435 | |
3436 | // We can combine if SharedReg is the same as either the LHS or RHS of the |
3437 | // G_AND. |
3438 | if (Y != SharedReg) |
3439 | std::swap(a&: X, b&: Y); |
3440 | return Y == SharedReg; |
3441 | } |
3442 | |
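// E.g. on s32 (illustrative; the G_XOR with -1 below is what buildNot
// emits):
//
//   %and:_(s32) = G_AND %x, %y
//   %r:_(s32) = G_XOR %and, %y
// ->
//   %not:_(s32) = G_XOR %x, -1
//   %r:_(s32) = G_AND %not, %y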
3443 | void CombinerHelper::applyXorOfAndWithSameReg( |
3444 | MachineInstr &MI, std::pair<Register, Register> &MatchInfo) { |
3445 | // Fold (xor (and x, y), y) -> (and (not x), y) |
3446 | Register X, Y; |
3447 | std::tie(args&: X, args&: Y) = MatchInfo; |
3448 | auto Not = Builder.buildNot(Dst: MRI.getType(Reg: X), Src0: X); |
3449 | Observer.changingInstr(MI); |
3450 | MI.setDesc(Builder.getTII().get(Opcode: TargetOpcode::G_AND)); |
3451 | MI.getOperand(i: 1).setReg(Not->getOperand(i: 0).getReg()); |
3452 | MI.getOperand(i: 2).setReg(Y); |
3453 | Observer.changedInstr(MI); |
3454 | } |
3455 | |
3456 | bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) { |
3457 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
3458 | Register DstReg = PtrAdd.getReg(Idx: 0); |
3459 | LLT Ty = MRI.getType(Reg: DstReg); |
3460 | const DataLayout &DL = Builder.getMF().getDataLayout(); |
3461 | |
3462 | if (DL.isNonIntegralAddressSpace(AddrSpace: Ty.getScalarType().getAddressSpace())) |
3463 | return false; |
3464 | |
3465 | if (Ty.isPointer()) { |
3466 | auto ConstVal = getIConstantVRegVal(VReg: PtrAdd.getBaseReg(), MRI); |
3467 | return ConstVal && *ConstVal == 0; |
3468 | } |
3469 | |
3470 | assert(Ty.isVector() && "Expecting a vector type" ); |
3471 | const MachineInstr *VecMI = MRI.getVRegDef(Reg: PtrAdd.getBaseReg()); |
3472 | return isBuildVectorAllZeros(MI: *VecMI, MRI); |
3473 | } |
3474 | |
3475 | void CombinerHelper::applyPtrAddZero(MachineInstr &MI) { |
3476 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
3477 | Builder.buildIntToPtr(Dst: PtrAdd.getReg(Idx: 0), Src: PtrAdd.getOffsetReg()); |
3478 | PtrAdd.eraseFromParent(); |
3479 | } |
3480 | |
3481 | /// The second source operand is known to be a power of 2. |
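/// E.g. (urem x, 8) -> (and x, 7). The mask is computed as pow2 + (-1), so
/// the fold also applies when the power of 2 is not a compile-time constant
/// (illustrative).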
3482 | void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) { |
3483 | Register DstReg = MI.getOperand(i: 0).getReg(); |
3484 | Register Src0 = MI.getOperand(i: 1).getReg(); |
3485 | Register Pow2Src1 = MI.getOperand(i: 2).getReg(); |
3486 | LLT Ty = MRI.getType(Reg: DstReg); |
3487 | |
3488 | // Fold (urem x, pow2) -> (and x, pow2-1) |
3489 | auto NegOne = Builder.buildConstant(Res: Ty, Val: -1); |
3490 | auto Add = Builder.buildAdd(Dst: Ty, Src0: Pow2Src1, Src1: NegOne); |
3491 | Builder.buildAnd(Dst: DstReg, Src0, Src1: Add); |
3492 | MI.eraseFromParent(); |
3493 | } |
3494 | |
3495 | bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI, |
3496 | unsigned &SelectOpNo) { |
3497 | Register LHS = MI.getOperand(i: 1).getReg(); |
3498 | Register RHS = MI.getOperand(i: 2).getReg(); |
3499 | |
3500 | Register OtherOperandReg = RHS; |
3501 | SelectOpNo = 1; |
3502 | MachineInstr *Select = MRI.getVRegDef(Reg: LHS); |
3503 | |
3504 | // Don't do this unless the old select is going away. We want to eliminate the |
3505 | // binary operator, not replace a binop with a select. |
3506 | if (Select->getOpcode() != TargetOpcode::G_SELECT || |
3507 | !MRI.hasOneNonDBGUse(RegNo: LHS)) { |
3508 | OtherOperandReg = LHS; |
3509 | SelectOpNo = 2; |
3510 | Select = MRI.getVRegDef(Reg: RHS); |
3511 | if (Select->getOpcode() != TargetOpcode::G_SELECT || |
3512 | !MRI.hasOneNonDBGUse(RegNo: RHS)) |
3513 | return false; |
3514 | } |
3515 | |
3516 | MachineInstr *SelectLHS = MRI.getVRegDef(Reg: Select->getOperand(i: 2).getReg()); |
3517 | MachineInstr *SelectRHS = MRI.getVRegDef(Reg: Select->getOperand(i: 3).getReg()); |
3518 | |
3519 | if (!isConstantOrConstantVector(MI: *SelectLHS, MRI, |
3520 | /*AllowFP*/ true, |
3521 | /*AllowOpaqueConstants*/ false)) |
3522 | return false; |
3523 | if (!isConstantOrConstantVector(MI: *SelectRHS, MRI, |
3524 | /*AllowFP*/ true, |
3525 | /*AllowOpaqueConstants*/ false)) |
3526 | return false; |
3527 | |
3528 | unsigned BinOpcode = MI.getOpcode(); |
3529 | |
3530 | // We know that one of the operands is a select of constants. Now verify that |
3531 | // the other binary operator operand is either a constant, or we can handle a |
3532 | // variable. |
3533 | bool CanFoldNonConst = |
3534 | (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) && |
3535 | (isNullOrNullSplat(MI: *SelectLHS, MRI) || |
3536 | isAllOnesOrAllOnesSplat(MI: *SelectLHS, MRI)) && |
3537 | (isNullOrNullSplat(MI: *SelectRHS, MRI) || |
3538 | isAllOnesOrAllOnesSplat(MI: *SelectRHS, MRI)); |
3539 | if (CanFoldNonConst) |
3540 | return true; |
3541 | |
3542 | return isConstantOrConstantVector(MI: *MRI.getVRegDef(Reg: OtherOperandReg), MRI, |
3543 | /*AllowFP*/ true, |
3544 | /*AllowOpaqueConstants*/ false); |
3545 | } |
3546 | |
3547 | /// \p SelectOperand is the operand in binary operator \p MI that is the select |
3548 | /// to fold. |
3549 | void CombinerHelper::applyFoldBinOpIntoSelect(MachineInstr &MI, |
3550 | const unsigned &SelectOperand) { |
3551 | Register Dst = MI.getOperand(i: 0).getReg(); |
3552 | Register LHS = MI.getOperand(i: 1).getReg(); |
3553 | Register RHS = MI.getOperand(i: 2).getReg(); |
3554 | MachineInstr *Select = MRI.getVRegDef(Reg: MI.getOperand(i: SelectOperand).getReg()); |
3555 | |
3556 | Register SelectCond = Select->getOperand(i: 1).getReg(); |
3557 | Register SelectTrue = Select->getOperand(i: 2).getReg(); |
3558 | Register SelectFalse = Select->getOperand(i: 3).getReg(); |
3559 | |
3560 | LLT Ty = MRI.getType(Reg: Dst); |
3561 | unsigned BinOpcode = MI.getOpcode(); |
3562 | |
3563 | Register FoldTrue, FoldFalse; |
3564 | |
3565 | // We have a select-of-constants followed by a binary operator with a |
3566 | // constant. Eliminate the binop by pulling the constant math into the select. |
3567 | // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO |
3568 | if (SelectOperand == 1) { |
3569 | // TODO: SelectionDAG verifies this actually constant folds before |
3570 | // committing to the combine. |
3571 | |
3572 | FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectTrue, RHS}).getReg(Idx: 0); |
3573 | FoldFalse = |
3574 | Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {SelectFalse, RHS}).getReg(Idx: 0); |
3575 | } else { |
3576 | FoldTrue = Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectTrue}).getReg(Idx: 0); |
3577 | FoldFalse = |
3578 | Builder.buildInstr(Opc: BinOpcode, DstOps: {Ty}, SrcOps: {LHS, SelectFalse}).getReg(Idx: 0); |
3579 | } |
3580 | |
3581 | Builder.buildSelect(Res: Dst, Tst: SelectCond, Op0: FoldTrue, Op1: FoldFalse, Flags: MI.getFlags()); |
3582 | MI.eraseFromParent(); |
3583 | } |
3584 | |
3585 | std::optional<SmallVector<Register, 8>> |
3586 | CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const { |
3587 | assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!" ); |
3588 | // We want to detect if Root is part of a tree which represents a bunch |
3589 | // of loads being merged into a larger load. We'll try to recognize patterns |
3590 | // like, for example: |
3591 | // |
3592 | // Reg Reg |
3593 | // \ / |
3594 | // OR_1 Reg |
3595 | // \ / |
3596 | // OR_2 |
3597 | // \ Reg |
3598 | // .. / |
3599 | // Root |
3600 | // |
3601 | // Reg Reg Reg Reg |
3602 | // \ / \ / |
3603 | // OR_1 OR_2 |
3604 | // \ / |
3605 | // \ / |
3606 | // ... |
3607 | // Root |
3608 | // |
3609 | // Each "Reg" may have been produced by a load + some arithmetic. This |
3610 | // function will save each of them. |
3611 | SmallVector<Register, 8> RegsToVisit; |
3612 | SmallVector<const MachineInstr *, 7> Ors = {Root}; |
3613 | |
3614 | // In the "worst" case, we're dealing with a load for each byte. So, there |
3615 | // are at most #bytes - 1 ORs. |
3616 | const unsigned MaxIter = |
3617 | MRI.getType(Reg: Root->getOperand(i: 0).getReg()).getSizeInBytes() - 1; |
3618 | for (unsigned Iter = 0; Iter < MaxIter; ++Iter) { |
3619 | if (Ors.empty()) |
3620 | break; |
3621 | const MachineInstr *Curr = Ors.pop_back_val(); |
3622 | Register OrLHS = Curr->getOperand(i: 1).getReg(); |
3623 | Register OrRHS = Curr->getOperand(i: 2).getReg(); |
3624 | |
    // In the combine, we want to eliminate the entire tree.
3626 | if (!MRI.hasOneNonDBGUse(RegNo: OrLHS) || !MRI.hasOneNonDBGUse(RegNo: OrRHS)) |
3627 | return std::nullopt; |
3628 | |
3629 | // If it's a G_OR, save it and continue to walk. If it's not, then it's |
3630 | // something that may be a load + arithmetic. |
3631 | if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrLHS, MRI)) |
3632 | Ors.push_back(Elt: Or); |
3633 | else |
3634 | RegsToVisit.push_back(Elt: OrLHS); |
3635 | if (const MachineInstr *Or = getOpcodeDef(Opcode: TargetOpcode::G_OR, Reg: OrRHS, MRI)) |
3636 | Ors.push_back(Elt: Or); |
3637 | else |
3638 | RegsToVisit.push_back(Elt: OrRHS); |
3639 | } |
3640 | |
3641 | // We're going to try and merge each register into a wider power-of-2 type, |
3642 | // so we ought to have an even number of registers. |
3643 | if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0) |
3644 | return std::nullopt; |
3645 | return RegsToVisit; |
3646 | } |
3647 | |
3648 | /// Helper function for findLoadOffsetsForLoadOrCombine. |
3649 | /// |
3650 | /// Check if \p Reg is the result of loading a \p MemSizeInBits wide value, |
3651 | /// and then moving that value into a specific byte offset. |
3652 | /// |
3653 | /// e.g. x[i] << 24 |
3654 | /// |
3655 | /// \returns The load instruction and the byte offset it is moved into. |
3656 | static std::optional<std::pair<GZExtLoad *, int64_t>> |
3657 | matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits, |
3658 | const MachineRegisterInfo &MRI) { |
3659 | assert(MRI.hasOneNonDBGUse(Reg) && |
3660 | "Expected Reg to only have one non-debug use?" ); |
3661 | Register MaybeLoad; |
3662 | int64_t Shift; |
3663 | if (!mi_match(R: Reg, MRI, |
3664 | P: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: MaybeLoad), R: m_ICst(Cst&: Shift))))) { |
3665 | Shift = 0; |
3666 | MaybeLoad = Reg; |
3667 | } |
3668 | |
3669 | if (Shift % MemSizeInBits != 0) |
3670 | return std::nullopt; |
3671 | |
3672 | // TODO: Handle other types of loads. |
3673 | auto *Load = getOpcodeDef<GZExtLoad>(Reg: MaybeLoad, MRI); |
3674 | if (!Load) |
3675 | return std::nullopt; |
3676 | |
3677 | if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits) |
3678 | return std::nullopt; |
3679 | |
3680 | return std::make_pair(x&: Load, y: Shift / MemSizeInBits); |
3681 | } |
3682 | |
3683 | std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>> |
3684 | CombinerHelper::findLoadOffsetsForLoadOrCombine( |
3685 | SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx, |
3686 | const SmallVector<Register, 8> &RegsToVisit, const unsigned MemSizeInBits) { |
3687 | |
3688 | // Each load found for the pattern. There should be one for each RegsToVisit. |
3689 | SmallSetVector<const MachineInstr *, 8> Loads; |
3690 | |
3691 | // The lowest index used in any load. (The lowest "i" for each x[i].) |
3692 | int64_t LowestIdx = INT64_MAX; |
3693 | |
3694 | // The load which uses the lowest index. |
3695 | GZExtLoad *LowestIdxLoad = nullptr; |
3696 | |
3697 | // Keeps track of the load indices we see. We shouldn't see any indices twice. |
3698 | SmallSet<int64_t, 8> SeenIdx; |
3699 | |
3700 | // Ensure each load is in the same MBB. |
3701 | // TODO: Support multiple MachineBasicBlocks. |
3702 | MachineBasicBlock *MBB = nullptr; |
3703 | const MachineMemOperand *MMO = nullptr; |
3704 | |
3705 | // Earliest instruction-order load in the pattern. |
3706 | GZExtLoad *EarliestLoad = nullptr; |
3707 | |
3708 | // Latest instruction-order load in the pattern. |
3709 | GZExtLoad *LatestLoad = nullptr; |
3710 | |
3711 | // Base pointer which every load should share. |
3712 | Register BasePtr; |
3713 | |
3714 | // We want to find a load for each register. Each load should have some |
3715 | // appropriate bit twiddling arithmetic. During this loop, we will also keep |
3716 | // track of the load which uses the lowest index. Later, we will check if we |
3717 | // can use its pointer in the final, combined load. |
3718 | for (auto Reg : RegsToVisit) { |
    // Find the load, and find the position that its value will end up in
    // (e.g. a shifted value).
3721 | auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI); |
3722 | if (!LoadAndPos) |
3723 | return std::nullopt; |
3724 | GZExtLoad *Load; |
3725 | int64_t DstPos; |
3726 | std::tie(args&: Load, args&: DstPos) = *LoadAndPos; |
3727 | |
3728 | // TODO: Handle multiple MachineBasicBlocks. Currently not handled because |
3729 | // it is difficult to check for stores/calls/etc between loads. |
3730 | MachineBasicBlock *LoadMBB = Load->getParent(); |
3731 | if (!MBB) |
3732 | MBB = LoadMBB; |
3733 | if (LoadMBB != MBB) |
3734 | return std::nullopt; |
3735 | |
3736 | // Make sure that the MachineMemOperands of every seen load are compatible. |
3737 | auto &LoadMMO = Load->getMMO(); |
3738 | if (!MMO) |
3739 | MMO = &LoadMMO; |
3740 | if (MMO->getAddrSpace() != LoadMMO.getAddrSpace()) |
3741 | return std::nullopt; |
3742 | |
3743 | // Find out what the base pointer and index for the load is. |
3744 | Register LoadPtr; |
3745 | int64_t Idx; |
3746 | if (!mi_match(R: Load->getOperand(i: 1).getReg(), MRI, |
3747 | P: m_GPtrAdd(L: m_Reg(R&: LoadPtr), R: m_ICst(Cst&: Idx)))) { |
3748 | LoadPtr = Load->getOperand(i: 1).getReg(); |
3749 | Idx = 0; |
3750 | } |
3751 | |
3752 | // Don't combine things like a[i], a[i] -> a bigger load. |
3753 | if (!SeenIdx.insert(V: Idx).second) |
3754 | return std::nullopt; |
3755 | |
3756 | // Every load must share the same base pointer; don't combine things like: |
3757 | // |
3758 | // a[i], b[i + 1] -> a bigger load. |
3759 | if (!BasePtr.isValid()) |
3760 | BasePtr = LoadPtr; |
3761 | if (BasePtr != LoadPtr) |
3762 | return std::nullopt; |
3763 | |
3764 | if (Idx < LowestIdx) { |
3765 | LowestIdx = Idx; |
3766 | LowestIdxLoad = Load; |
3767 | } |
3768 | |
3769 | // Keep track of the byte offset that this load ends up at. If we have seen |
3770 | // the byte offset, then stop here. We do not want to combine: |
3771 | // |
3772 | // a[i] << 16, a[i + k] << 16 -> a bigger load. |
3773 | if (!MemOffset2Idx.try_emplace(Key: DstPos, Args&: Idx).second) |
3774 | return std::nullopt; |
3775 | Loads.insert(X: Load); |
3776 | |
3777 | // Keep track of the position of the earliest/latest loads in the pattern. |
3778 | // We will check that there are no load fold barriers between them later |
3779 | // on. |
3780 | // |
3781 | // FIXME: Is there a better way to check for load fold barriers? |
3782 | if (!EarliestLoad || dominates(DefMI: *Load, UseMI: *EarliestLoad)) |
3783 | EarliestLoad = Load; |
3784 | if (!LatestLoad || dominates(DefMI: *LatestLoad, UseMI: *Load)) |
3785 | LatestLoad = Load; |
3786 | } |
3787 | |
3788 | // We found a load for each register. Let's check if each load satisfies the |
3789 | // pattern. |
3790 | assert(Loads.size() == RegsToVisit.size() && |
3791 | "Expected to find a load for each register?" ); |
3792 | assert(EarliestLoad != LatestLoad && EarliestLoad && |
3793 | LatestLoad && "Expected at least two loads?" ); |
3794 | |
3795 | // Check if there are any stores, calls, etc. between any of the loads. If |
3796 | // there are, then we can't safely perform the combine. |
3797 | // |
3798 | // MaxIter is chosen based off the (worst case) number of iterations it |
3799 | // typically takes to succeed in the LLVM test suite plus some padding. |
3800 | // |
3801 | // FIXME: Is there a better way to check for load fold barriers? |
3802 | const unsigned MaxIter = 20; |
3803 | unsigned Iter = 0; |
3804 | for (const auto &MI : instructionsWithoutDebug(It: EarliestLoad->getIterator(), |
3805 | End: LatestLoad->getIterator())) { |
3806 | if (Loads.count(key: &MI)) |
3807 | continue; |
3808 | if (MI.isLoadFoldBarrier()) |
3809 | return std::nullopt; |
3810 | if (Iter++ == MaxIter) |
3811 | return std::nullopt; |
3812 | } |
3813 | |
3814 | return std::make_tuple(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad); |
3815 | } |
3816 | |
3817 | bool CombinerHelper::matchLoadOrCombine( |
3818 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
3819 | assert(MI.getOpcode() == TargetOpcode::G_OR); |
3820 | MachineFunction &MF = *MI.getMF(); |
3821 | // Assuming a little-endian target, transform: |
3822 | // s8 *a = ... |
3823 | // s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24) |
3824 | // => |
3825 | // s32 val = *((i32)a) |
3826 | // |
3827 | // s8 *a = ... |
3828 | // s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3] |
3829 | // => |
3830 | // s32 val = BSWAP(*((s32)a)) |
3831 | Register Dst = MI.getOperand(i: 0).getReg(); |
3832 | LLT Ty = MRI.getType(Reg: Dst); |
3833 | if (Ty.isVector()) |
3834 | return false; |
3835 | |
3836 | // We need to combine at least two loads into this type. Since the smallest |
3837 | // possible load is into a byte, we need at least a 16-bit wide type. |
3838 | const unsigned WideMemSizeInBits = Ty.getSizeInBits(); |
3839 | if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0) |
3840 | return false; |
3841 | |
3842 | // Match a collection of non-OR instructions in the pattern. |
3843 | auto RegsToVisit = findCandidatesForLoadOrCombine(Root: &MI); |
3844 | if (!RegsToVisit) |
3845 | return false; |
3846 | |
3847 | // We have a collection of non-OR instructions. Figure out how wide each of |
3848 | // the small loads should be based off of the number of potential loads we |
3849 | // found. |
3850 | const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size(); |
3851 | if (NarrowMemSizeInBits % 8 != 0) |
3852 | return false; |
3853 | |
3854 | // Check if each register feeding into each OR is a load from the same |
3855 | // base pointer + some arithmetic. |
3856 | // |
3857 | // e.g. a[0], a[1] << 8, a[2] << 16, etc. |
3858 | // |
3859 | // Also verify that each of these ends up putting a[i] into the same memory |
3860 | // offset as a load into a wide type would. |
3861 | SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx; |
3862 | GZExtLoad *LowestIdxLoad, *LatestLoad; |
3863 | int64_t LowestIdx; |
3864 | auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine( |
3865 | MemOffset2Idx, RegsToVisit: *RegsToVisit, MemSizeInBits: NarrowMemSizeInBits); |
3866 | if (!MaybeLoadInfo) |
3867 | return false; |
3868 | std::tie(args&: LowestIdxLoad, args&: LowestIdx, args&: LatestLoad) = *MaybeLoadInfo; |
3869 | |
3870 | // We have a bunch of loads being OR'd together. Using the addresses + offsets |
3871 | // we found before, check if this corresponds to a big or little endian byte |
3872 | // pattern. If it does, then we can represent it using a load + possibly a |
3873 | // BSWAP. |
3874 | bool IsBigEndianTarget = MF.getDataLayout().isBigEndian(); |
3875 | std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx); |
3876 | if (!IsBigEndian) |
3877 | return false; |
3878 | bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian; |
3879 | if (NeedsBSwap && !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_BSWAP, {Ty}})) |
3880 | return false; |
3881 | |
3882 | // Make sure that the load from the lowest index produces offset 0 in the |
3883 | // final value. |
3884 | // |
3885 | // This ensures that we won't combine something like this: |
3886 | // |
3887 | // load x[i] -> byte 2 |
3888 | // load x[i+1] -> byte 0 ---> wide_load x[i] |
3889 | // load x[i+2] -> byte 1 |
3890 | const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits; |
3891 | const unsigned ZeroByteOffset = |
3892 | *IsBigEndian |
3893 | ? bigEndianByteAt(ByteWidth: NumLoadsInTy, I: 0) |
3894 | : littleEndianByteAt(ByteWidth: NumLoadsInTy, I: 0); |
3895 | auto ZeroOffsetIdx = MemOffset2Idx.find(Val: ZeroByteOffset); |
3896 | if (ZeroOffsetIdx == MemOffset2Idx.end() || |
3897 | ZeroOffsetIdx->second != LowestIdx) |
3898 | return false; |
3899 | |
  // We will reuse the pointer from the load which ends up at byte offset 0. It
3901 | // may not use index 0. |
3902 | Register Ptr = LowestIdxLoad->getPointerReg(); |
3903 | const MachineMemOperand &MMO = LowestIdxLoad->getMMO(); |
3904 | LegalityQuery::MemDesc MMDesc(MMO); |
3905 | MMDesc.MemoryTy = Ty; |
3906 | if (!isLegalOrBeforeLegalizer( |
3907 | Query: {TargetOpcode::G_LOAD, {Ty, MRI.getType(Reg: Ptr)}, {MMDesc}})) |
3908 | return false; |
3909 | auto PtrInfo = MMO.getPointerInfo(); |
3910 | auto *NewMMO = MF.getMachineMemOperand(MMO: &MMO, PtrInfo, Size: WideMemSizeInBits / 8); |
3911 | |
3912 | // Load must be allowed and fast on the target. |
3913 | LLVMContext &C = MF.getFunction().getContext(); |
3914 | auto &DL = MF.getDataLayout(); |
3915 | unsigned Fast = 0; |
3916 | if (!getTargetLowering().allowsMemoryAccess(Context&: C, DL, Ty, MMO: *NewMMO, Fast: &Fast) || |
3917 | !Fast) |
3918 | return false; |
3919 | |
3920 | MatchInfo = [=](MachineIRBuilder &MIB) { |
3921 | MIB.setInstrAndDebugLoc(*LatestLoad); |
3922 | Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(VReg: Dst) : Dst; |
3923 | MIB.buildLoad(Res: LoadDst, Addr: Ptr, MMO&: *NewMMO); |
3924 | if (NeedsBSwap) |
3925 | MIB.buildBSwap(Dst, Src0: LoadDst); |
3926 | }; |
3927 | return true; |
3928 | } |
3929 | |
3930 | bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI, |
3931 | MachineInstr *&ExtMI) { |
3932 | auto &PHI = cast<GPhi>(Val&: MI); |
3933 | Register DstReg = PHI.getReg(Idx: 0); |
3934 | |
3935 | // TODO: Extending a vector may be expensive, don't do this until heuristics |
3936 | // are better. |
3937 | if (MRI.getType(Reg: DstReg).isVector()) |
3938 | return false; |
3939 | |
  // Try to match a phi whose only use is an extend.
3941 | if (!MRI.hasOneNonDBGUse(RegNo: DstReg)) |
3942 | return false; |
3943 | ExtMI = &*MRI.use_instr_nodbg_begin(RegNo: DstReg); |
3944 | switch (ExtMI->getOpcode()) { |
3945 | case TargetOpcode::G_ANYEXT: |
3946 | return true; // G_ANYEXT is usually free. |
3947 | case TargetOpcode::G_ZEXT: |
3948 | case TargetOpcode::G_SEXT: |
3949 | break; |
3950 | default: |
3951 | return false; |
3952 | } |
3953 | |
3954 | // If the target is likely to fold this extend away, don't propagate. |
3955 | if (Builder.getTII().isExtendLikelyToBeFolded(ExtMI&: *ExtMI, MRI)) |
3956 | return false; |
3957 | |
3958 | // We don't want to propagate the extends unless there's a good chance that |
3959 | // they'll be optimized in some way. |
3960 | // Collect the unique incoming values. |
3961 | SmallPtrSet<MachineInstr *, 4> InSrcs; |
3962 | for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) { |
3963 | auto *DefMI = getDefIgnoringCopies(Reg: PHI.getIncomingValue(I), MRI); |
3964 | switch (DefMI->getOpcode()) { |
3965 | case TargetOpcode::G_LOAD: |
3966 | case TargetOpcode::G_TRUNC: |
3967 | case TargetOpcode::G_SEXT: |
3968 | case TargetOpcode::G_ZEXT: |
3969 | case TargetOpcode::G_ANYEXT: |
3970 | case TargetOpcode::G_CONSTANT: |
3971 | InSrcs.insert(Ptr: DefMI); |
3972 | // Don't try to propagate if there are too many places to create new |
3973 | // extends, chances are it'll increase code size. |
3974 | if (InSrcs.size() > 2) |
3975 | return false; |
3976 | break; |
3977 | default: |
3978 | return false; |
3979 | } |
3980 | } |
3981 | return true; |
3982 | } |
3983 | |
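// E.g. when a phi's only user is a single G_ZEXT, an extend is built after
// each incoming value's defining instruction and the phi is rebuilt at the
// wider type (illustrative):
//
//   %p:_(s16) = G_PHI %a(s16), %bb.0, %b(s16), %bb.1
//   %z:_(s32) = G_ZEXT %p
// ->
//   %za:_(s32) = G_ZEXT %a and %zb:_(s32) = G_ZEXT %b next to their defs
//   %z:_(s32) = G_PHI %za(s32), %bb.0, %zb(s32), %bb.1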
3984 | void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI, |
3985 | MachineInstr *&ExtMI) { |
3986 | auto &PHI = cast<GPhi>(Val&: MI); |
3987 | Register DstReg = ExtMI->getOperand(i: 0).getReg(); |
3988 | LLT ExtTy = MRI.getType(Reg: DstReg); |
3989 | |
  // Propagate the extension into each incoming value's defining block.
3991 | // Use a SetVector here because PHIs can have duplicate edges, and we want |
3992 | // deterministic iteration order. |
3993 | SmallSetVector<MachineInstr *, 8> SrcMIs; |
3994 | SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap; |
3995 | for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) { |
3996 | auto SrcReg = PHI.getIncomingValue(I); |
3997 | auto *SrcMI = MRI.getVRegDef(Reg: SrcReg); |
3998 | if (!SrcMIs.insert(X: SrcMI)) |
3999 | continue; |
4000 | |
4001 | // Build an extend after each src inst. |
4002 | auto *MBB = SrcMI->getParent(); |
4003 | MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator(); |
4004 | if (InsertPt != MBB->end() && InsertPt->isPHI()) |
4005 | InsertPt = MBB->getFirstNonPHI(); |
4006 | |
4007 | Builder.setInsertPt(MBB&: *SrcMI->getParent(), II: InsertPt); |
4008 | Builder.setDebugLoc(MI.getDebugLoc()); |
4009 | auto NewExt = Builder.buildExtOrTrunc(ExtOpc: ExtMI->getOpcode(), Res: ExtTy, Op: SrcReg); |
4010 | OldToNewSrcMap[SrcMI] = NewExt; |
4011 | } |
4012 | |
4013 | // Create a new phi with the extended inputs. |
4014 | Builder.setInstrAndDebugLoc(MI); |
4015 | auto NewPhi = Builder.buildInstrNoInsert(Opcode: TargetOpcode::G_PHI); |
4016 | NewPhi.addDef(RegNo: DstReg); |
4017 | for (const MachineOperand &MO : llvm::drop_begin(RangeOrContainer: MI.operands())) { |
4018 | if (!MO.isReg()) { |
4019 | NewPhi.addMBB(MBB: MO.getMBB()); |
4020 | continue; |
4021 | } |
4022 | auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(Reg: MO.getReg())]; |
4023 | NewPhi.addUse(RegNo: NewSrc->getOperand(i: 0).getReg()); |
4024 | } |
4025 | Builder.insertInstr(MIB: NewPhi); |
4026 | ExtMI->eraseFromParent(); |
4027 | } |
4028 | |
bool CombinerHelper::matchExtractVecEltBuildVec(MachineInstr &MI,
                                                Register &Reg) {
4031 | assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT); |
4032 | // If we have a constant index, look for a G_BUILD_VECTOR source |
4033 | // and find the source register that the index maps to. |
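  // E.g. (vreg names here are illustrative):
  //   %v:_(<4 x s32>) = G_BUILD_VECTOR %a, %b, %c, %d
  //   %el:_(s32) = G_EXTRACT_VECTOR_ELT %v(<4 x s32>), 2
  // can forward %c directly.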
4034 | Register SrcVec = MI.getOperand(i: 1).getReg(); |
4035 | LLT SrcTy = MRI.getType(Reg: SrcVec); |
4036 | |
4037 | auto Cst = getIConstantVRegValWithLookThrough(VReg: MI.getOperand(i: 2).getReg(), MRI); |
4038 | if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements()) |
4039 | return false; |
4040 | |
4041 | unsigned VecIdx = Cst->Value.getZExtValue(); |
4042 | |
4043 | // Check if we have a build_vector or build_vector_trunc with an optional |
4044 | // trunc in front. |
4045 | MachineInstr *SrcVecMI = MRI.getVRegDef(Reg: SrcVec); |
4046 | if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC) { |
4047 | SrcVecMI = MRI.getVRegDef(Reg: SrcVecMI->getOperand(i: 1).getReg()); |
4048 | } |
4049 | |
4050 | if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR && |
4051 | SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC) |
4052 | return false; |
4053 | |
4054 | EVT Ty(getMVTForLLT(Ty: SrcTy)); |
4055 | if (!MRI.hasOneNonDBGUse(RegNo: SrcVec) && |
4056 | !getTargetLowering().aggressivelyPreferBuildVectorSources(VecVT: Ty)) |
4057 | return false; |
4058 | |
4059 | Reg = SrcVecMI->getOperand(i: VecIdx + 1).getReg(); |
4060 | return true; |
4061 | } |
4062 | |
void CombinerHelper::applyExtractVecEltBuildVec(MachineInstr &MI,
                                                Register &Reg) {
4065 | // Check the type of the register, since it may have come from a |
4066 | // G_BUILD_VECTOR_TRUNC. |
4067 | LLT ScalarTy = MRI.getType(Reg); |
4068 | Register DstReg = MI.getOperand(i: 0).getReg(); |
4069 | LLT DstTy = MRI.getType(Reg: DstReg); |
4070 | |
4071 | if (ScalarTy != DstTy) { |
4072 | assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits()); |
4073 | Builder.buildTrunc(Res: DstReg, Op: Reg); |
4074 | MI.eraseFromParent(); |
4075 | return; |
4076 | } |
4077 | replaceSingleDefInstWithReg(MI, Replacement: Reg); |
4078 | } |
4079 | |
bool CombinerHelper::matchExtractAllEltsFromBuildVector(
4081 | MachineInstr &MI, |
4082 | SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) { |
4083 | assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); |
4084 | // This combine tries to find build_vector's which have every source element |
  // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like
  // masked-load scalarization are run late in the pipeline. There's already
4087 | // a combine for a similar pattern starting from the extract, but that |
4088 | // doesn't attempt to do it if there are multiple uses of the build_vector, |
4089 | // which in this case is true. Starting the combine from the build_vector |
4090 | // feels more natural than trying to find sibling nodes of extracts. |
4091 | // E.g. |
4092 | // %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4 |
4093 | // %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0 |
4094 | // %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1 |
4095 | // %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2 |
4096 | // %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3 |
4097 | // ==> |
4098 | // replace ext{1,2,3,4} with %s{1,2,3,4} |
4099 | |
4100 | Register DstReg = MI.getOperand(i: 0).getReg(); |
4101 | LLT DstTy = MRI.getType(Reg: DstReg); |
4102 | unsigned NumElts = DstTy.getNumElements(); |
4103 | |
  SmallBitVector ExtractedElts(NumElts);
4105 | for (MachineInstr &II : MRI.use_nodbg_instructions(Reg: DstReg)) { |
4106 | if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT) |
4107 | return false; |
4108 | auto Cst = getIConstantVRegVal(VReg: II.getOperand(i: 2).getReg(), MRI); |
4109 | if (!Cst) |
4110 | return false; |
4111 | unsigned Idx = Cst->getZExtValue(); |
4112 | if (Idx >= NumElts) |
4113 | return false; // Out of range. |
4114 | ExtractedElts.set(Idx); |
4115 | SrcDstPairs.emplace_back( |
4116 | Args: std::make_pair(x: MI.getOperand(i: Idx + 1).getReg(), y: &II)); |
4117 | } |
4118 | // Match if every element was extracted. |
4119 | return ExtractedElts.all(); |
4120 | } |
4121 | |
void CombinerHelper::applyExtractAllEltsFromBuildVector(
4123 | MachineInstr &MI, |
4124 | SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) { |
4125 | assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); |
4126 | for (auto &Pair : SrcDstPairs) { |
4127 | auto *ExtMI = Pair.second; |
4128 | replaceRegWith(MRI, FromReg: ExtMI->getOperand(i: 0).getReg(), ToReg: Pair.first); |
4129 | ExtMI->eraseFromParent(); |
4130 | } |
4131 | MI.eraseFromParent(); |
4132 | } |
4133 | |
4134 | void CombinerHelper::applyBuildFn( |
4135 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4136 | applyBuildFnNoErase(MI, MatchInfo); |
4137 | MI.eraseFromParent(); |
4138 | } |
4139 | |
4140 | void CombinerHelper::applyBuildFnMO(const MachineOperand &MO, |
4141 | BuildFnTy &MatchInfo) { |
4142 | MachineInstr *Root = getDefIgnoringCopies(Reg: MO.getReg(), MRI); |
4143 | Builder.setInstrAndDebugLoc(*Root); |
4144 | MatchInfo(Builder); |
4145 | Root->eraseFromParent(); |
4146 | } |
4147 | |
4148 | void CombinerHelper::applyBuildFnNoErase( |
4149 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4150 | MatchInfo(Builder); |
4151 | } |
4152 | |
4153 | bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI, |
4154 | BuildFnTy &MatchInfo) { |
4155 | assert(MI.getOpcode() == TargetOpcode::G_OR); |
4156 | |
4157 | Register Dst = MI.getOperand(i: 0).getReg(); |
4158 | LLT Ty = MRI.getType(Reg: Dst); |
4159 | unsigned BitWidth = Ty.getScalarSizeInBits(); |
4160 | |
4161 | Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt; |
4162 | unsigned FshOpc = 0; |
4163 | |
4164 | // Match (or (shl ...), (lshr ...)). |
4165 | if (!mi_match(R: Dst, MRI, |
4166 | // m_GOr() handles the commuted version as well. |
4167 | P: m_GOr(L: m_GShl(L: m_Reg(R&: ShlSrc), R: m_Reg(R&: ShlAmt)), |
4168 | R: m_GLShr(L: m_Reg(R&: LShrSrc), R: m_Reg(R&: LShrAmt))))) |
4169 | return false; |
4170 | |
4171 | // Given constants C0 and C1 such that C0 + C1 is bit-width: |
4172 | // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1) |
4173 | int64_t CstShlAmt, CstLShrAmt; |
4174 | if (mi_match(R: ShlAmt, MRI, P: m_ICstOrSplat(Cst&: CstShlAmt)) && |
4175 | mi_match(R: LShrAmt, MRI, P: m_ICstOrSplat(Cst&: CstLShrAmt)) && |
4176 | CstShlAmt + CstLShrAmt == BitWidth) { |
4177 | FshOpc = TargetOpcode::G_FSHR; |
4178 | Amt = LShrAmt; |
4179 | |
4180 | } else if (mi_match(R: LShrAmt, MRI, |
4181 | P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) && |
4182 | ShlAmt == Amt) { |
4183 | // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt) |
4184 | FshOpc = TargetOpcode::G_FSHL; |
4185 | |
4186 | } else if (mi_match(R: ShlAmt, MRI, |
4187 | P: m_GSub(L: m_SpecificICstOrSplat(RequestedValue: BitWidth), R: m_Reg(R&: Amt))) && |
4188 | LShrAmt == Amt) { |
4189 | // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt) |
4190 | FshOpc = TargetOpcode::G_FSHR; |
4191 | |
4192 | } else { |
4193 | return false; |
4194 | } |
4195 | |
4196 | LLT AmtTy = MRI.getType(Reg: Amt); |
4197 | if (!isLegalOrBeforeLegalizer(Query: {FshOpc, {Ty, AmtTy}})) |
4198 | return false; |
4199 | |
4200 | MatchInfo = [=](MachineIRBuilder &B) { |
4201 | B.buildInstr(Opc: FshOpc, DstOps: {Dst}, SrcOps: {ShlSrc, LShrSrc, Amt}); |
4202 | }; |
4203 | return true; |
4204 | } |
4205 | |
4206 | /// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate. |
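/// E.g. (fshl x, x, amt) -> (rotl x, amt): a funnel shift with both inputs
/// equal is by definition a rotate.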
4207 | bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) { |
4208 | unsigned Opc = MI.getOpcode(); |
4209 | assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR); |
4210 | Register X = MI.getOperand(i: 1).getReg(); |
4211 | Register Y = MI.getOperand(i: 2).getReg(); |
4212 | if (X != Y) |
4213 | return false; |
4214 | unsigned RotateOpc = |
4215 | Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR; |
4216 | return isLegalOrBeforeLegalizer(Query: {RotateOpc, {MRI.getType(Reg: X), MRI.getType(Reg: Y)}}); |
4217 | } |
4218 | |
4219 | void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) { |
4220 | unsigned Opc = MI.getOpcode(); |
4221 | assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR); |
4222 | bool IsFSHL = Opc == TargetOpcode::G_FSHL; |
4223 | Observer.changingInstr(MI); |
4224 | MI.setDesc(Builder.getTII().get(Opcode: IsFSHL ? TargetOpcode::G_ROTL |
4225 | : TargetOpcode::G_ROTR)); |
4226 | MI.removeOperand(OpNo: 2); |
4227 | Observer.changedInstr(MI); |
4228 | } |
4229 | |
4230 | // Fold (rot x, c) -> (rot x, c % BitSize) |
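// E.g. on s32, (rotl x, 40) -> (rotl x, 8): rotating by the full bit width
// is the identity, so only the amount modulo the width matters.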
4231 | bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) { |
4232 | assert(MI.getOpcode() == TargetOpcode::G_ROTL || |
4233 | MI.getOpcode() == TargetOpcode::G_ROTR); |
4234 | unsigned Bitsize = |
4235 | MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits(); |
4236 | Register AmtReg = MI.getOperand(i: 2).getReg(); |
4237 | bool OutOfRange = false; |
4238 | auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) { |
4239 | if (auto *CI = dyn_cast<ConstantInt>(Val: C)) |
4240 | OutOfRange |= CI->getValue().uge(RHS: Bitsize); |
4241 | return true; |
4242 | }; |
4243 | return matchUnaryPredicate(MRI, Reg: AmtReg, Match: MatchOutOfRange) && OutOfRange; |
4244 | } |
4245 | |
4246 | void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) { |
4247 | assert(MI.getOpcode() == TargetOpcode::G_ROTL || |
4248 | MI.getOpcode() == TargetOpcode::G_ROTR); |
4249 | unsigned Bitsize = |
4250 | MRI.getType(Reg: MI.getOperand(i: 0).getReg()).getScalarSizeInBits(); |
4251 | Register Amt = MI.getOperand(i: 2).getReg(); |
4252 | LLT AmtTy = MRI.getType(Reg: Amt); |
4253 | auto Bits = Builder.buildConstant(Res: AmtTy, Val: Bitsize); |
4254 | Amt = Builder.buildURem(Dst: AmtTy, Src0: MI.getOperand(i: 2).getReg(), Src1: Bits).getReg(Idx: 0); |
4255 | Observer.changingInstr(MI); |
4256 | MI.getOperand(i: 2).setReg(Amt); |
4257 | Observer.changedInstr(MI); |
4258 | } |
4259 | |
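/// Fold a G_ICMP whose result is fully determined by known bits into the
/// target's canonical true value or into 0. An illustrative sketch: if the
/// LHS is (and x, 0xf), its known maximum is 15, so (icmp ult lhs, 16) is
/// always true.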
4260 | bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI, |
4261 | int64_t &MatchInfo) { |
4262 | assert(MI.getOpcode() == TargetOpcode::G_ICMP); |
4263 | auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate()); |
4264 | auto KnownLHS = KB->getKnownBits(R: MI.getOperand(i: 2).getReg()); |
4265 | auto KnownRHS = KB->getKnownBits(R: MI.getOperand(i: 3).getReg()); |
4266 | std::optional<bool> KnownVal; |
4267 | switch (Pred) { |
4268 | default: |
4269 | llvm_unreachable("Unexpected G_ICMP predicate?" ); |
4270 | case CmpInst::ICMP_EQ: |
4271 | KnownVal = KnownBits::eq(LHS: KnownLHS, RHS: KnownRHS); |
4272 | break; |
4273 | case CmpInst::ICMP_NE: |
4274 | KnownVal = KnownBits::ne(LHS: KnownLHS, RHS: KnownRHS); |
4275 | break; |
4276 | case CmpInst::ICMP_SGE: |
4277 | KnownVal = KnownBits::sge(LHS: KnownLHS, RHS: KnownRHS); |
4278 | break; |
4279 | case CmpInst::ICMP_SGT: |
4280 | KnownVal = KnownBits::sgt(LHS: KnownLHS, RHS: KnownRHS); |
4281 | break; |
4282 | case CmpInst::ICMP_SLE: |
4283 | KnownVal = KnownBits::sle(LHS: KnownLHS, RHS: KnownRHS); |
4284 | break; |
4285 | case CmpInst::ICMP_SLT: |
4286 | KnownVal = KnownBits::slt(LHS: KnownLHS, RHS: KnownRHS); |
4287 | break; |
4288 | case CmpInst::ICMP_UGE: |
4289 | KnownVal = KnownBits::uge(LHS: KnownLHS, RHS: KnownRHS); |
4290 | break; |
4291 | case CmpInst::ICMP_UGT: |
4292 | KnownVal = KnownBits::ugt(LHS: KnownLHS, RHS: KnownRHS); |
4293 | break; |
4294 | case CmpInst::ICMP_ULE: |
4295 | KnownVal = KnownBits::ule(LHS: KnownLHS, RHS: KnownRHS); |
4296 | break; |
4297 | case CmpInst::ICMP_ULT: |
4298 | KnownVal = KnownBits::ult(LHS: KnownLHS, RHS: KnownRHS); |
4299 | break; |
4300 | } |
4301 | if (!KnownVal) |
4302 | return false; |
4303 | MatchInfo = |
4304 | *KnownVal |
4305 | ? getICmpTrueVal(TLI: getTargetLowering(), |
4306 | /*IsVector = */ |
4307 | MRI.getType(Reg: MI.getOperand(i: 0).getReg()).isVector(), |
4308 | /* IsFP = */ false) |
4309 | : 0; |
4310 | return true; |
4311 | } |
4312 | |
4313 | bool CombinerHelper::matchICmpToLHSKnownBits( |
4314 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4315 | assert(MI.getOpcode() == TargetOpcode::G_ICMP); |
4316 | // Given: |
4317 | // |
4318 | // %x = G_WHATEVER (... x is known to be 0 or 1 ...) |
4319 | // %cmp = G_ICMP ne %x, 0 |
4320 | // |
4321 | // Or: |
4322 | // |
4323 | // %x = G_WHATEVER (... x is known to be 0 or 1 ...) |
4324 | // %cmp = G_ICMP eq %x, 1 |
4325 | // |
4326 | // We can replace %cmp with %x assuming true is 1 on the target. |
4327 | auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(i: 1).getPredicate()); |
4328 | if (!CmpInst::isEquality(pred: Pred)) |
4329 | return false; |
4330 | Register Dst = MI.getOperand(i: 0).getReg(); |
4331 | LLT DstTy = MRI.getType(Reg: Dst); |
4332 | if (getICmpTrueVal(TLI: getTargetLowering(), IsVector: DstTy.isVector(), |
4333 | /* IsFP = */ false) != 1) |
4334 | return false; |
4335 | int64_t OneOrZero = Pred == CmpInst::ICMP_EQ; |
4336 | if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICst(RequestedValue: OneOrZero))) |
4337 | return false; |
4338 | Register LHS = MI.getOperand(i: 2).getReg(); |
4339 | auto KnownLHS = KB->getKnownBits(R: LHS); |
4340 | if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1) |
4341 | return false; |
4342 | // Make sure replacing Dst with the LHS is a legal operation. |
4343 | LLT LHSTy = MRI.getType(Reg: LHS); |
4344 | unsigned LHSSize = LHSTy.getSizeInBits(); |
4345 | unsigned DstSize = DstTy.getSizeInBits(); |
4346 | unsigned Op = TargetOpcode::COPY; |
4347 | if (DstSize != LHSSize) |
4348 | Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT; |
4349 | if (!isLegalOrBeforeLegalizer(Query: {Op, {DstTy, LHSTy}})) |
4350 | return false; |
4351 | MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Opc: Op, DstOps: {Dst}, SrcOps: {LHS}); }; |
4352 | return true; |
4353 | } |
4354 | |
4355 | // Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0 |
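// E.g. (and (or x, 0xF0F0), 0x0F0F) -> (and x, 0x0F0F): the OR can only set
// bits that the AND mask clears anyway.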
4356 | bool CombinerHelper::matchAndOrDisjointMask( |
4357 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4358 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
4359 | |
4360 | // Ignore vector types to simplify matching the two constants. |
4361 | // TODO: do this for vectors and scalars via a demanded bits analysis. |
4362 | LLT Ty = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
4363 | if (Ty.isVector()) |
4364 | return false; |
4365 | |
4366 | Register Src; |
4367 | Register AndMaskReg; |
4368 | int64_t AndMaskBits; |
4369 | int64_t OrMaskBits; |
4370 | if (!mi_match(MI, MRI, |
4371 | P: m_GAnd(L: m_GOr(L: m_Reg(R&: Src), R: m_ICst(Cst&: OrMaskBits)), |
4372 | R: m_all_of(preds: m_ICst(Cst&: AndMaskBits), preds: m_Reg(R&: AndMaskReg))))) |
4373 | return false; |
4374 | |
4375 | // Check if OrMask could turn on any bits in Src. |
4376 | if (AndMaskBits & OrMaskBits) |
4377 | return false; |
4378 | |
4379 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4380 | Observer.changingInstr(MI); |
4381 | // Canonicalize the result to have the constant on the RHS. |
4382 | if (MI.getOperand(i: 1).getReg() == AndMaskReg) |
4383 | MI.getOperand(i: 2).setReg(AndMaskReg); |
4384 | MI.getOperand(i: 1).setReg(Src); |
4385 | Observer.changedInstr(MI); |
4386 | }; |
4387 | return true; |
4388 | } |
4389 | |
4390 | /// Form a G_SBFX from a G_SEXT_INREG fed by a right shift. |
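/// E.g. on s32, (sext_inreg (lshr x, 4), 8) -> (sbfx x, 4, 8): a
/// sign-extending extract of 8 bits starting at bit 4.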
bool CombinerHelper::matchBitfieldExtractFromSExtInReg(
4392 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4393 | assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); |
4394 | Register Dst = MI.getOperand(i: 0).getReg(); |
4395 | Register Src = MI.getOperand(i: 1).getReg(); |
4396 | LLT Ty = MRI.getType(Reg: Src); |
  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4398 | if (!LI || !LI->isLegalOrCustom(Query: {TargetOpcode::G_SBFX, {Ty, ExtractTy}})) |
4399 | return false; |
4400 | int64_t Width = MI.getOperand(i: 2).getImm(); |
4401 | Register ShiftSrc; |
4402 | int64_t ShiftImm; |
4403 | if (!mi_match( |
4404 | R: Src, MRI, |
4405 | P: m_OneNonDBGUse(SP: m_any_of(preds: m_GAShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm)), |
4406 | preds: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: ShiftImm)))))) |
4407 | return false; |
4408 | if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits()) |
4409 | return false; |
4410 | |
4411 | MatchInfo = [=](MachineIRBuilder &B) { |
4412 | auto Cst1 = B.buildConstant(Res: ExtractTy, Val: ShiftImm); |
4413 | auto Cst2 = B.buildConstant(Res: ExtractTy, Val: Width); |
4414 | B.buildSbfx(Dst, Src: ShiftSrc, LSB: Cst1, Width: Cst2); |
4415 | }; |
4416 | return true; |
4417 | } |
4418 | |
4419 | /// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants. |
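/// E.g. on s32, (and (lshr x, 4), 0xff) -> (ubfx x, 4, 8): a zero-extending
/// extract of the 8 bits starting at bit 4.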
4420 | bool CombinerHelper::matchBitfieldExtractFromAnd( |
4421 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4422 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
4423 | Register Dst = MI.getOperand(i: 0).getReg(); |
4424 | LLT Ty = MRI.getType(Reg: Dst); |
  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4426 | if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}})) |
4427 | return false; |
4428 | |
4429 | int64_t AndImm, LSBImm; |
4430 | Register ShiftSrc; |
4431 | const unsigned Size = Ty.getScalarSizeInBits(); |
4432 | if (!mi_match(R: MI.getOperand(i: 0).getReg(), MRI, |
4433 | P: m_GAnd(L: m_OneNonDBGUse(SP: m_GLShr(L: m_Reg(R&: ShiftSrc), R: m_ICst(Cst&: LSBImm))), |
4434 | R: m_ICst(Cst&: AndImm)))) |
4435 | return false; |
4436 | |
4437 | // The mask is a mask of the low bits iff imm & (imm+1) == 0. |
4438 | auto MaybeMask = static_cast<uint64_t>(AndImm); |
4439 | if (MaybeMask & (MaybeMask + 1)) |
4440 | return false; |
4441 | |
4442 | // LSB must fit within the register. |
4443 | if (static_cast<uint64_t>(LSBImm) >= Size) |
4444 | return false; |
4445 | |
4446 | uint64_t Width = APInt(Size, AndImm).countr_one(); |
4447 | MatchInfo = [=](MachineIRBuilder &B) { |
4448 | auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width); |
4449 | auto LSBCst = B.buildConstant(Res: ExtractTy, Val: LSBImm); |
4450 | B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {ShiftSrc, LSBCst, WidthCst}); |
4451 | }; |
4452 | return true; |
4453 | } |
4454 | |
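/// Form a G_SBFX/G_UBFX from "(shl x, c1) >> c2" with constant shifts.
/// E.g. on s32, (lshr (shl x, 4), 8) -> (ubfx x, 4, 24): the shift pair
/// isolates bits [4, 28) of x, zero-extended into the low 24 bits.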
bool CombinerHelper::matchBitfieldExtractFromShr(
4456 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4457 | const unsigned Opcode = MI.getOpcode(); |
4458 | assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR); |
4459 | |
4460 | const Register Dst = MI.getOperand(i: 0).getReg(); |
4461 | |
4462 | const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR |
4463 | ? TargetOpcode::G_SBFX |
4464 | : TargetOpcode::G_UBFX; |
4465 | |
4466 | // Check if the type we would use for the extract is legal |
4467 | LLT Ty = MRI.getType(Reg: Dst); |
  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4469 | if (!LI || !LI->isLegalOrCustom(Query: {ExtrOpcode, {Ty, ExtractTy}})) |
4470 | return false; |
4471 | |
4472 | Register ShlSrc; |
4473 | int64_t ShrAmt; |
4474 | int64_t ShlAmt; |
4475 | const unsigned Size = Ty.getScalarSizeInBits(); |
4476 | |
4477 | // Try to match shr (shl x, c1), c2 |
4478 | if (!mi_match(R: Dst, MRI, |
4479 | P: m_BinOp(Opcode, |
4480 | L: m_OneNonDBGUse(SP: m_GShl(L: m_Reg(R&: ShlSrc), R: m_ICst(Cst&: ShlAmt))), |
4481 | R: m_ICst(Cst&: ShrAmt)))) |
4482 | return false; |
4483 | |
4484 | // Make sure that the shift sizes can fit a bitfield extract |
4485 | if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size) |
4486 | return false; |
4487 | |
4488 | // Skip this combine if the G_SEXT_INREG combine could handle it |
4489 | if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt) |
4490 | return false; |
4491 | |
4492 | // Calculate start position and width of the extract |
4493 | const int64_t Pos = ShrAmt - ShlAmt; |
4494 | const int64_t Width = Size - ShrAmt; |
4495 | |
4496 | MatchInfo = [=](MachineIRBuilder &B) { |
4497 | auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width); |
4498 | auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos); |
4499 | B.buildInstr(Opc: ExtrOpcode, DstOps: {Dst}, SrcOps: {ShlSrc, PosCst, WidthCst}); |
4500 | }; |
4501 | return true; |
4502 | } |
4503 | |
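/// Form a G_UBFX from "(and x, mask) >> c" when the shifted mask is
/// contiguous. E.g. on s32, (lshr (and x, 0xff0), 4) -> (ubfx x, 4, 8).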
4504 | bool CombinerHelper::matchBitfieldExtractFromShrAnd( |
4505 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4506 | const unsigned Opcode = MI.getOpcode(); |
4507 | assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR); |
4508 | |
4509 | const Register Dst = MI.getOperand(i: 0).getReg(); |
4510 | LLT Ty = MRI.getType(Reg: Dst); |
  LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty);
4512 | if (LI && !LI->isLegalOrCustom(Query: {TargetOpcode::G_UBFX, {Ty, ExtractTy}})) |
4513 | return false; |
4514 | |
4515 | // Try to match shr (and x, c1), c2 |
4516 | Register AndSrc; |
4517 | int64_t ShrAmt; |
4518 | int64_t SMask; |
4519 | if (!mi_match(R: Dst, MRI, |
4520 | P: m_BinOp(Opcode, |
4521 | L: m_OneNonDBGUse(SP: m_GAnd(L: m_Reg(R&: AndSrc), R: m_ICst(Cst&: SMask))), |
4522 | R: m_ICst(Cst&: ShrAmt)))) |
4523 | return false; |
4524 | |
4525 | const unsigned Size = Ty.getScalarSizeInBits(); |
4526 | if (ShrAmt < 0 || ShrAmt >= Size) |
4527 | return false; |
4528 | |
4529 | // If the shift subsumes the mask, emit the 0 directly. |
4530 | if (0 == (SMask >> ShrAmt)) { |
4531 | MatchInfo = [=](MachineIRBuilder &B) { |
4532 | B.buildConstant(Res: Dst, Val: 0); |
4533 | }; |
4534 | return true; |
4535 | } |
4536 | |
4537 | // Check that ubfx can do the extraction, with no holes in the mask. |
4538 | uint64_t UMask = SMask; |
4539 | UMask |= maskTrailingOnes<uint64_t>(N: ShrAmt); |
4540 | UMask &= maskTrailingOnes<uint64_t>(N: Size); |
4541 | if (!isMask_64(Value: UMask)) |
4542 | return false; |
4543 | |
4544 | // Calculate start position and width of the extract. |
4545 | const int64_t Pos = ShrAmt; |
4546 | const int64_t Width = llvm::countr_one(Value: UMask) - ShrAmt; |
4547 | |
4548 | // It's preferable to keep the shift, rather than form G_SBFX. |
4549 | // TODO: remove the G_AND via demanded bits analysis. |
4550 | if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size) |
4551 | return false; |
4552 | |
4553 | MatchInfo = [=](MachineIRBuilder &B) { |
4554 | auto WidthCst = B.buildConstant(Res: ExtractTy, Val: Width); |
4555 | auto PosCst = B.buildConstant(Res: ExtractTy, Val: Pos); |
4556 | B.buildInstr(Opc: TargetOpcode::G_UBFX, DstOps: {Dst}, SrcOps: {AndSrc, PosCst, WidthCst}); |
4557 | }; |
4558 | return true; |
4559 | } |
4560 | |
4561 | bool CombinerHelper::reassociationCanBreakAddressingModePattern( |
4562 | MachineInstr &MI) { |
4563 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
4564 | |
4565 | Register Src1Reg = PtrAdd.getBaseReg(); |
4566 | auto *Src1Def = getOpcodeDef<GPtrAdd>(Reg: Src1Reg, MRI); |
4567 | if (!Src1Def) |
4568 | return false; |
4569 | |
4570 | Register Src2Reg = PtrAdd.getOffsetReg(); |
4571 | |
4572 | if (MRI.hasOneNonDBGUse(RegNo: Src1Reg)) |
4573 | return false; |
4574 | |
4575 | auto C1 = getIConstantVRegVal(VReg: Src1Def->getOffsetReg(), MRI); |
4576 | if (!C1) |
4577 | return false; |
4578 | auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI); |
4579 | if (!C2) |
4580 | return false; |
4581 | |
4582 | const APInt &C1APIntVal = *C1; |
4583 | const APInt &C2APIntVal = *C2; |
4584 | const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue(); |
4585 | |
4586 | for (auto &UseMI : MRI.use_nodbg_instructions(Reg: PtrAdd.getReg(Idx: 0))) { |
4587 | // This combine may end up running before ptrtoint/inttoptr combines |
4588 | // manage to eliminate redundant conversions, so try to look through them. |
4589 | MachineInstr *ConvUseMI = &UseMI; |
4590 | unsigned ConvUseOpc = ConvUseMI->getOpcode(); |
4591 | while (ConvUseOpc == TargetOpcode::G_INTTOPTR || |
4592 | ConvUseOpc == TargetOpcode::G_PTRTOINT) { |
4593 | Register DefReg = ConvUseMI->getOperand(i: 0).getReg(); |
4594 | if (!MRI.hasOneNonDBGUse(RegNo: DefReg)) |
4595 | break; |
4596 | ConvUseMI = &*MRI.use_instr_nodbg_begin(RegNo: DefReg); |
4597 | ConvUseOpc = ConvUseMI->getOpcode(); |
4598 | } |
4599 | auto *LdStMI = dyn_cast<GLoadStore>(Val: ConvUseMI); |
4600 | if (!LdStMI) |
4601 | continue; |
4602 | // Is x[offset2] already not a legal addressing mode? If so then |
4603 | // reassociating the constants breaks nothing (we test offset2 because |
4604 | // that's the one we hope to fold into the load or store). |
4605 | TargetLoweringBase::AddrMode AM; |
4606 | AM.HasBaseReg = true; |
4607 | AM.BaseOffs = C2APIntVal.getSExtValue(); |
4608 | unsigned AS = MRI.getType(Reg: LdStMI->getPointerReg()).getAddressSpace(); |
4609 | Type *AccessTy = getTypeForLLT(Ty: LdStMI->getMMO().getMemoryType(), |
4610 | C&: PtrAdd.getMF()->getFunction().getContext()); |
4611 | const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering(); |
4612 | if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM, |
4613 | Ty: AccessTy, AddrSpace: AS)) |
4614 | continue; |
4615 | |
4616 | // Would x[offset1+offset2] still be a legal addressing mode? |
4617 | AM.BaseOffs = CombinedValue; |
4618 | if (!TLI.isLegalAddressingMode(DL: PtrAdd.getMF()->getDataLayout(), AM, |
4619 | Ty: AccessTy, AddrSpace: AS)) |
4620 | return true; |
4621 | } |
4622 | |
4623 | return false; |
4624 | } |
4625 | |
4626 | bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI, |
4627 | MachineInstr *RHS, |
4628 | BuildFnTy &MatchInfo) { |
4629 | // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C) |
4630 | Register Src1Reg = MI.getOperand(i: 1).getReg(); |
4631 | if (RHS->getOpcode() != TargetOpcode::G_ADD) |
4632 | return false; |
4633 | auto C2 = getIConstantVRegVal(VReg: RHS->getOperand(i: 2).getReg(), MRI); |
4634 | if (!C2) |
4635 | return false; |
4636 | |
4637 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4638 | LLT PtrTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
4639 | |
4640 | auto NewBase = |
4641 | Builder.buildPtrAdd(Res: PtrTy, Op0: Src1Reg, Op1: RHS->getOperand(i: 1).getReg()); |
4642 | Observer.changingInstr(MI); |
4643 | MI.getOperand(i: 1).setReg(NewBase.getReg(Idx: 0)); |
4644 | MI.getOperand(i: 2).setReg(RHS->getOperand(i: 2).getReg()); |
4645 | Observer.changedInstr(MI); |
4646 | }; |
4647 | return !reassociationCanBreakAddressingModePattern(MI); |
4648 | } |
4649 | |
4650 | bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI, |
4651 | MachineInstr *LHS, |
4652 | MachineInstr *RHS, |
4653 | BuildFnTy &MatchInfo) { |
  // G_PTR_ADD(G_PTR_ADD(X, C), Y) -> G_PTR_ADD(G_PTR_ADD(X, Y), C)
  // if and only if (G_PTR_ADD X, C) has one use.
4656 | Register LHSBase; |
4657 | std::optional<ValueAndVReg> LHSCstOff; |
4658 | if (!mi_match(R: MI.getBaseReg(), MRI, |
4659 | P: m_OneNonDBGUse(SP: m_GPtrAdd(L: m_Reg(R&: LHSBase), R: m_GCst(ValReg&: LHSCstOff))))) |
4660 | return false; |
4661 | |
4662 | auto *LHSPtrAdd = cast<GPtrAdd>(Val: LHS); |
4663 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
    // When we change LHSPtrAdd's offset register we might cause it to use a reg
    // before its def. Sink the instruction to just before the outer PTR_ADD to
    // ensure this doesn't happen.
4667 | LHSPtrAdd->moveBefore(MovePos: &MI); |
4668 | Register RHSReg = MI.getOffsetReg(); |
    // Setting the vreg directly could cause a type mismatch if the offset came
    // through an extend/trunc, so materialize a fresh constant of the right
    // type instead.
4670 | auto NewCst = B.buildConstant(Res: MRI.getType(Reg: RHSReg), Val: LHSCstOff->Value); |
4671 | Observer.changingInstr(MI); |
4672 | MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0)); |
4673 | Observer.changedInstr(MI); |
4674 | Observer.changingInstr(MI&: *LHSPtrAdd); |
4675 | LHSPtrAdd->getOperand(i: 2).setReg(RHSReg); |
4676 | Observer.changedInstr(MI&: *LHSPtrAdd); |
4677 | }; |
4678 | return !reassociationCanBreakAddressingModePattern(MI); |
4679 | } |
4680 | |
4681 | bool CombinerHelper::matchReassocFoldConstantsInSubTree(GPtrAdd &MI, |
4682 | MachineInstr *LHS, |
4683 | MachineInstr *RHS, |
4684 | BuildFnTy &MatchInfo) { |
4685 | // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2) |
4686 | auto *LHSPtrAdd = dyn_cast<GPtrAdd>(Val: LHS); |
4687 | if (!LHSPtrAdd) |
4688 | return false; |
4689 | |
4690 | Register Src2Reg = MI.getOperand(i: 2).getReg(); |
4691 | Register LHSSrc1 = LHSPtrAdd->getBaseReg(); |
4692 | Register LHSSrc2 = LHSPtrAdd->getOffsetReg(); |
4693 | auto C1 = getIConstantVRegVal(VReg: LHSSrc2, MRI); |
4694 | if (!C1) |
4695 | return false; |
4696 | auto C2 = getIConstantVRegVal(VReg: Src2Reg, MRI); |
4697 | if (!C2) |
4698 | return false; |
4699 | |
4700 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4701 | auto NewCst = B.buildConstant(Res: MRI.getType(Reg: Src2Reg), Val: *C1 + *C2); |
4702 | Observer.changingInstr(MI); |
4703 | MI.getOperand(i: 1).setReg(LHSSrc1); |
4704 | MI.getOperand(i: 2).setReg(NewCst.getReg(Idx: 0)); |
4705 | Observer.changedInstr(MI); |
4706 | }; |
4707 | return !reassociationCanBreakAddressingModePattern(MI); |
4708 | } |
4709 | |
4710 | bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI, |
4711 | BuildFnTy &MatchInfo) { |
4712 | auto &PtrAdd = cast<GPtrAdd>(Val&: MI); |
4713 | // We're trying to match a few pointer computation patterns here for |
4714 | // re-association opportunities. |
4715 | // 1) Isolating a constant operand to be on the RHS, e.g.: |
4716 | // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C) |
4717 | // |
4718 | // 2) Folding two constants in each sub-tree as long as such folding |
4719 | // doesn't break a legal addressing mode. |
4720 | // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2) |
4721 | // |
4722 | // 3) Move a constant from the LHS of an inner op to the RHS of the outer. |
  //    G_PTR_ADD(G_PTR_ADD(X, C), Y) -> G_PTR_ADD(G_PTR_ADD(X, Y), C)
  //    iff (G_PTR_ADD X, C) has one use.
4725 | MachineInstr *LHS = MRI.getVRegDef(Reg: PtrAdd.getBaseReg()); |
4726 | MachineInstr *RHS = MRI.getVRegDef(Reg: PtrAdd.getOffsetReg()); |
4727 | |
4728 | // Try to match example 2. |
4729 | if (matchReassocFoldConstantsInSubTree(MI&: PtrAdd, LHS, RHS, MatchInfo)) |
4730 | return true; |
4731 | |
4732 | // Try to match example 3. |
4733 | if (matchReassocConstantInnerLHS(MI&: PtrAdd, LHS, RHS, MatchInfo)) |
4734 | return true; |
4735 | |
4736 | // Try to match example 1. |
4737 | if (matchReassocConstantInnerRHS(MI&: PtrAdd, RHS, MatchInfo)) |
4738 | return true; |
4739 | |
4740 | return false; |
4741 | } |

bool CombinerHelper::tryReassocBinOp(unsigned Opc, Register DstReg,
4743 | Register OpLHS, Register OpRHS, |
4744 | BuildFnTy &MatchInfo) { |
4745 | LLT OpRHSTy = MRI.getType(Reg: OpRHS); |
4746 | MachineInstr *OpLHSDef = MRI.getVRegDef(Reg: OpLHS); |
4747 | |
4748 | if (OpLHSDef->getOpcode() != Opc) |
4749 | return false; |
4750 | |
4751 | MachineInstr *OpRHSDef = MRI.getVRegDef(Reg: OpRHS); |
4752 | Register OpLHSLHS = OpLHSDef->getOperand(i: 1).getReg(); |
4753 | Register OpLHSRHS = OpLHSDef->getOperand(i: 2).getReg(); |
4754 | |
4755 | // If the inner op is (X op C), pull the constant out so it can be folded with |
4756 | // other constants in the expression tree. Folding is not guaranteed so we |
4757 | // might have (C1 op C2). In that case do not pull a constant out because it |
4758 | // won't help and can lead to infinite loops. |
4759 | if (isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSRHS), MRI) && |
4760 | !isConstantOrConstantSplatVector(MI&: *MRI.getVRegDef(Reg: OpLHSLHS), MRI)) { |
4761 | if (isConstantOrConstantSplatVector(MI&: *OpRHSDef, MRI)) { |
4762 | // (Opc (Opc X, C1), C2) -> (Opc X, (Opc C1, C2)) |
4763 | MatchInfo = [=](MachineIRBuilder &B) { |
4764 | auto NewCst = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSRHS, OpRHS}); |
4765 | B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {OpLHSLHS, NewCst}); |
4766 | }; |
4767 | return true; |
4768 | } |
4769 | if (getTargetLowering().isReassocProfitable(MRI, N0: OpLHS, N1: OpRHS)) { |
4770 | // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1) |
4771 | // iff (op x, c1) has one use |
4772 | MatchInfo = [=](MachineIRBuilder &B) { |
4773 | auto NewLHSLHS = B.buildInstr(Opc, DstOps: {OpRHSTy}, SrcOps: {OpLHSLHS, OpRHS}); |
4774 | B.buildInstr(Opc, DstOps: {DstReg}, SrcOps: {NewLHSLHS, OpLHSRHS}); |
4775 | }; |
4776 | return true; |
4777 | } |
4778 | } |
4779 | |
4780 | return false; |
4781 | } |
4782 | |
4783 | bool CombinerHelper::matchReassocCommBinOp(MachineInstr &MI, |
4784 | BuildFnTy &MatchInfo) { |
4785 | // We don't check if the reassociation will break a legal addressing mode |
4786 | // here since pointer arithmetic is handled by G_PTR_ADD. |
4787 | unsigned Opc = MI.getOpcode(); |
4788 | Register DstReg = MI.getOperand(i: 0).getReg(); |
4789 | Register LHSReg = MI.getOperand(i: 1).getReg(); |
4790 | Register RHSReg = MI.getOperand(i: 2).getReg(); |
4791 | |
4792 | if (tryReassocBinOp(Opc, DstReg, OpLHS: LHSReg, OpRHS: RHSReg, MatchInfo)) |
4793 | return true; |
4794 | if (tryReassocBinOp(Opc, DstReg, OpLHS: RHSReg, OpRHS: LHSReg, MatchInfo)) |
4795 | return true; |
4796 | return false; |
4797 | } |
4798 | |
bool CombinerHelper::matchConstantFoldCastOp(MachineInstr &MI,
                                             APInt &MatchInfo) {
4800 | LLT DstTy = MRI.getType(Reg: MI.getOperand(i: 0).getReg()); |
4801 | Register SrcOp = MI.getOperand(i: 1).getReg(); |
4802 | |
4803 | if (auto MaybeCst = ConstantFoldCastOp(Opcode: MI.getOpcode(), DstTy, Op0: SrcOp, MRI)) { |
4804 | MatchInfo = *MaybeCst; |
4805 | return true; |
4806 | } |
4807 | |
4808 | return false; |
4809 | } |
4810 | |
bool CombinerHelper::matchConstantFoldBinOp(MachineInstr &MI,
                                            APInt &MatchInfo) {
4812 | Register Op1 = MI.getOperand(i: 1).getReg(); |
4813 | Register Op2 = MI.getOperand(i: 2).getReg(); |
4814 | auto MaybeCst = ConstantFoldBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI); |
4815 | if (!MaybeCst) |
4816 | return false; |
4817 | MatchInfo = *MaybeCst; |
4818 | return true; |
4819 | } |
4820 | |
bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI,
                                              ConstantFP *&MatchInfo) {
4822 | Register Op1 = MI.getOperand(i: 1).getReg(); |
4823 | Register Op2 = MI.getOperand(i: 2).getReg(); |
4824 | auto MaybeCst = ConstantFoldFPBinOp(Opcode: MI.getOpcode(), Op1, Op2, MRI); |
4825 | if (!MaybeCst) |
4826 | return false; |
4827 | MatchInfo = |
4828 | ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: *MaybeCst); |
4829 | return true; |
4830 | } |
4831 | |
4832 | bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI, |
4833 | ConstantFP *&MatchInfo) { |
4834 | assert(MI.getOpcode() == TargetOpcode::G_FMA || |
4835 | MI.getOpcode() == TargetOpcode::G_FMAD); |
4836 | auto [_, Op1, Op2, Op3] = MI.getFirst4Regs(); |
4837 | |
4838 | const ConstantFP *Op3Cst = getConstantFPVRegVal(VReg: Op3, MRI); |
4839 | if (!Op3Cst) |
4840 | return false; |
4841 | |
4842 | const ConstantFP *Op2Cst = getConstantFPVRegVal(VReg: Op2, MRI); |
4843 | if (!Op2Cst) |
4844 | return false; |
4845 | |
4846 | const ConstantFP *Op1Cst = getConstantFPVRegVal(VReg: Op1, MRI); |
4847 | if (!Op1Cst) |
4848 | return false; |
4849 | |
4850 | APFloat Op1F = Op1Cst->getValueAPF(); |
4851 | Op1F.fusedMultiplyAdd(Multiplicand: Op2Cst->getValueAPF(), Addend: Op3Cst->getValueAPF(), |
4852 | RM: APFloat::rmNearestTiesToEven); |
4853 | MatchInfo = ConstantFP::get(Context&: MI.getMF()->getFunction().getContext(), V: Op1F); |
4854 | return true; |
4855 | } |
4856 | |
4857 | bool CombinerHelper::matchNarrowBinopFeedingAnd( |
4858 | MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) { |
4859 | // Look for a binop feeding into an AND with a mask: |
4860 | // |
4861 | // %add = G_ADD %lhs, %rhs |
4862 | // %and = G_AND %add, 000...11111111 |
4863 | // |
4864 | // Check if it's possible to perform the binop at a narrower width and zext |
4865 | // back to the original width like so: |
4866 | // |
4867 | // %narrow_lhs = G_TRUNC %lhs |
4868 | // %narrow_rhs = G_TRUNC %rhs |
4869 | // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs |
4870 | // %new_add = G_ZEXT %narrow_add |
4871 | // %and = G_AND %new_add, 000...11111111 |
4872 | // |
4873 | // This can allow later combines to eliminate the G_AND if it turns out |
4874 | // that the mask is irrelevant. |
4875 | assert(MI.getOpcode() == TargetOpcode::G_AND); |
4876 | Register Dst = MI.getOperand(i: 0).getReg(); |
4877 | Register AndLHS = MI.getOperand(i: 1).getReg(); |
4878 | Register AndRHS = MI.getOperand(i: 2).getReg(); |
4879 | LLT WideTy = MRI.getType(Reg: Dst); |
4880 | |
4881 | // If the potential binop has more than one use, then it's possible that one |
4882 | // of those uses will need its full width. |
4883 | if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(RegNo: AndLHS)) |
4884 | return false; |
4885 | |
4886 | // Check if the LHS feeding the AND is impacted by the high bits that we're |
4887 | // masking out. |
4888 | // |
4889 | // e.g. for 64-bit x, y: |
4890 | // |
4891 | // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535 |
4892 | MachineInstr *LHSInst = getDefIgnoringCopies(Reg: AndLHS, MRI); |
4893 | if (!LHSInst) |
4894 | return false; |
4895 | unsigned LHSOpc = LHSInst->getOpcode(); |
4896 | switch (LHSOpc) { |
4897 | default: |
4898 | return false; |
4899 | case TargetOpcode::G_ADD: |
4900 | case TargetOpcode::G_SUB: |
4901 | case TargetOpcode::G_MUL: |
4902 | case TargetOpcode::G_AND: |
4903 | case TargetOpcode::G_OR: |
4904 | case TargetOpcode::G_XOR: |
4905 | break; |
4906 | } |
4907 | |
4908 | // Find the mask on the RHS. |
4909 | auto Cst = getIConstantVRegValWithLookThrough(VReg: AndRHS, MRI); |
4910 | if (!Cst) |
4911 | return false; |
4912 | auto Mask = Cst->Value; |
4913 | if (!Mask.isMask()) |
4914 | return false; |
4915 | |
4916 | // No point in combining if there's nothing to truncate. |
4917 | unsigned NarrowWidth = Mask.countr_one(); |
4918 | if (NarrowWidth == WideTy.getSizeInBits()) |
4919 | return false; |
4920 | LLT NarrowTy = LLT::scalar(SizeInBits: NarrowWidth); |
4921 | |
4922 | // Check if adding the zext + truncates could be harmful. |
4923 | auto &MF = *MI.getMF(); |
4924 | const auto &TLI = getTargetLowering(); |
4925 | LLVMContext &Ctx = MF.getFunction().getContext(); |
4926 | auto &DL = MF.getDataLayout(); |
4927 | if (!TLI.isTruncateFree(FromTy: WideTy, ToTy: NarrowTy, DL, Ctx) || |
4928 | !TLI.isZExtFree(FromTy: NarrowTy, ToTy: WideTy, DL, Ctx)) |
4929 | return false; |
4930 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) || |
4931 | !isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_ZEXT, {WideTy, NarrowTy}})) |
4932 | return false; |
4933 | Register BinOpLHS = LHSInst->getOperand(i: 1).getReg(); |
4934 | Register BinOpRHS = LHSInst->getOperand(i: 2).getReg(); |
4935 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4936 | auto NarrowLHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpLHS); |
4937 | auto NarrowRHS = Builder.buildTrunc(Res: NarrowTy, Op: BinOpRHS); |
4938 | auto NarrowBinOp = |
4939 | Builder.buildInstr(Opc: LHSOpc, DstOps: {NarrowTy}, SrcOps: {NarrowLHS, NarrowRHS}); |
4940 | auto Ext = Builder.buildZExt(Res: WideTy, Op: NarrowBinOp); |
4941 | Observer.changingInstr(MI); |
4942 | MI.getOperand(i: 1).setReg(Ext.getReg(Idx: 0)); |
4943 | Observer.changedInstr(MI); |
4944 | }; |
4945 | return true; |
4946 | } |
4947 | |
4948 | bool CombinerHelper::matchMulOBy2(MachineInstr &MI, BuildFnTy &MatchInfo) { |
4949 | unsigned Opc = MI.getOpcode(); |
4950 | assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO); |
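  // (G_*MULO x, 2) -> (G_*ADDO x, x): multiplying by two is adding the value
  // to itself, overflow behavior included.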
4951 | |
4952 | if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 2))) |
4953 | return false; |
4954 | |
4955 | MatchInfo = [=, &MI](MachineIRBuilder &B) { |
4956 | Observer.changingInstr(MI); |
4957 | unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO |
4958 | : TargetOpcode::G_SADDO; |
4959 | MI.setDesc(Builder.getTII().get(Opcode: NewOpc)); |
4960 | MI.getOperand(i: 3).setReg(MI.getOperand(i: 2).getReg()); |
4961 | Observer.changedInstr(MI); |
4962 | }; |
4963 | return true; |
4964 | } |
4965 | |
4966 | bool CombinerHelper::matchMulOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) { |
4967 | // (G_*MULO x, 0) -> 0 + no carry out |
4968 | assert(MI.getOpcode() == TargetOpcode::G_UMULO || |
4969 | MI.getOpcode() == TargetOpcode::G_SMULO); |
4970 | if (!mi_match(R: MI.getOperand(i: 3).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0))) |
4971 | return false; |
4972 | Register Dst = MI.getOperand(i: 0).getReg(); |
4973 | Register Carry = MI.getOperand(i: 1).getReg(); |
4974 | if (!isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Dst)) || |
4975 | !isConstantLegalOrBeforeLegalizer(Ty: MRI.getType(Reg: Carry))) |
4976 | return false; |
4977 | MatchInfo = [=](MachineIRBuilder &B) { |
4978 | B.buildConstant(Res: Dst, Val: 0); |
4979 | B.buildConstant(Res: Carry, Val: 0); |
4980 | }; |
4981 | return true; |
4982 | } |
4983 | |
4984 | bool CombinerHelper::matchAddEToAddO(MachineInstr &MI, BuildFnTy &MatchInfo) { |
4985 | // (G_*ADDE x, y, 0) -> (G_*ADDO x, y) |
4986 | // (G_*SUBE x, y, 0) -> (G_*SUBO x, y) |
4987 | assert(MI.getOpcode() == TargetOpcode::G_UADDE || |
4988 | MI.getOpcode() == TargetOpcode::G_SADDE || |
4989 | MI.getOpcode() == TargetOpcode::G_USUBE || |
4990 | MI.getOpcode() == TargetOpcode::G_SSUBE); |
4991 | if (!mi_match(R: MI.getOperand(i: 4).getReg(), MRI, P: m_SpecificICstOrSplat(RequestedValue: 0))) |
4992 | return false; |
4993 | MatchInfo = [&](MachineIRBuilder &B) { |
4994 | unsigned NewOpcode; |
4995 | switch (MI.getOpcode()) { |
4996 | case TargetOpcode::G_UADDE: |
4997 | NewOpcode = TargetOpcode::G_UADDO; |
4998 | break; |
4999 | case TargetOpcode::G_SADDE: |
5000 | NewOpcode = TargetOpcode::G_SADDO; |
5001 | break; |
5002 | case TargetOpcode::G_USUBE: |
5003 | NewOpcode = TargetOpcode::G_USUBO; |
5004 | break; |
    case TargetOpcode::G_SSUBE:
      NewOpcode = TargetOpcode::G_SSUBO;
      break;
    default:
      llvm_unreachable("Unexpected opcode");
    }
5009 | Observer.changingInstr(MI); |
5010 | MI.setDesc(B.getTII().get(Opcode: NewOpcode)); |
5011 | MI.removeOperand(OpNo: 4); |
5012 | Observer.changedInstr(MI); |
5013 | }; |
5014 | return true; |
5015 | } |
5016 | |
5017 | bool CombinerHelper::matchSubAddSameReg(MachineInstr &MI, |
5018 | BuildFnTy &MatchInfo) { |
5019 | assert(MI.getOpcode() == TargetOpcode::G_SUB); |
5020 | Register Dst = MI.getOperand(i: 0).getReg(); |
5021 | // (x + y) - z -> x (if y == z) |
5022 | // (x + y) - z -> y (if x == z) |
5023 | Register X, Y, Z; |
5024 | if (mi_match(R: Dst, MRI, P: m_GSub(L: m_GAdd(L: m_Reg(R&: X), R: m_Reg(R&: Y)), R: m_Reg(R&: Z)))) { |
5025 | Register ReplaceReg; |
5026 | int64_t CstX, CstY; |
5027 | if (Y == Z || (mi_match(R: Y, MRI, P: m_ICstOrSplat(Cst&: CstY)) && |
5028 | mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstY)))) |
5029 | ReplaceReg = X; |
5030 | else if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) && |
5031 | mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX)))) |
5032 | ReplaceReg = Y; |
5033 | if (ReplaceReg) { |
5034 | MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Res: Dst, Op: ReplaceReg); }; |
5035 | return true; |
5036 | } |
5037 | } |
5038 | |
5039 | // x - (y + z) -> 0 - y (if x == z) |
5040 | // x - (y + z) -> 0 - z (if x == y) |
5041 | if (mi_match(R: Dst, MRI, P: m_GSub(L: m_Reg(R&: X), R: m_GAdd(L: m_Reg(R&: Y), R: m_Reg(R&: Z))))) { |
5042 | Register ReplaceReg; |
5043 | int64_t CstX; |
5044 | if (X == Z || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) && |
5045 | mi_match(R: Z, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX)))) |
5046 | ReplaceReg = Y; |
5047 | else if (X == Y || (mi_match(R: X, MRI, P: m_ICstOrSplat(Cst&: CstX)) && |
5048 | mi_match(R: Y, MRI, P: m_SpecificICstOrSplat(RequestedValue: CstX)))) |
5049 | ReplaceReg = Z; |
5050 | if (ReplaceReg) { |
5051 | MatchInfo = [=](MachineIRBuilder &B) { |
5052 | auto Zero = B.buildConstant(Res: MRI.getType(Reg: Dst), Val: 0); |
5053 | B.buildSub(Dst, Src0: Zero, Src1: ReplaceReg); |
5054 | }; |
5055 | return true; |
5056 | } |
5057 | } |
5058 | return false; |
5059 | } |
5060 | |
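/// Build a sequence computing unsigned division by a constant via
/// multiplication by a "magic" fixed-point reciprocal. As an illustrative
/// sketch (the exact constants come from UnsignedDivisionByConstantInfo),
/// %q = G_UDIV %x, 5 on s32 becomes roughly:
///   %m:_(s32) = G_UMULH %x, 0xCCCCCCCD
///   %q:_(s32) = G_LSHR %m, 2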
5061 | MachineInstr *CombinerHelper::buildUDivUsingMul(MachineInstr &MI) { |
5062 | assert(MI.getOpcode() == TargetOpcode::G_UDIV); |
5063 | auto &UDiv = cast<GenericMachineInstr>(Val&: MI); |
5064 | Register Dst = UDiv.getReg(Idx: 0); |
5065 | Register LHS = UDiv.getReg(Idx: 1); |
5066 | Register RHS = UDiv.getReg(Idx: 2); |
5067 | LLT Ty = MRI.getType(Reg: Dst); |
5068 | LLT ScalarTy = Ty.getScalarType(); |
5069 | const unsigned EltBits = ScalarTy.getScalarSizeInBits(); |
5070 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5071 | LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType(); |
5072 | auto &MIB = Builder; |
5073 | |
5074 | bool UseNPQ = false; |
5075 | SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors; |
5076 | |
5077 | auto BuildUDIVPattern = [&](const Constant *C) { |
5078 | auto *CI = cast<ConstantInt>(Val: C); |
5079 | const APInt &Divisor = CI->getValue(); |
5080 | |
5081 | bool SelNPQ = false; |
5082 | APInt Magic(Divisor.getBitWidth(), 0); |
5083 | unsigned PreShift = 0, PostShift = 0; |
5084 | |
    // The magic algorithm doesn't work for division by 1. We need to emit a
    // select at the end.
5087 | // TODO: Use undef values for divisor of 1. |
5088 | if (!Divisor.isOne()) { |
5089 | UnsignedDivisionByConstantInfo magics = |
5090 | UnsignedDivisionByConstantInfo::get(D: Divisor); |
5091 | |
5092 | Magic = std::move(magics.Magic); |
5093 | |
5094 | assert(magics.PreShift < Divisor.getBitWidth() && |
5095 | "We shouldn't generate an undefined shift!" ); |
5096 | assert(magics.PostShift < Divisor.getBitWidth() && |
5097 | "We shouldn't generate an undefined shift!" ); |
5098 | assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift" ); |
5099 | PreShift = magics.PreShift; |
5100 | PostShift = magics.PostShift; |
5101 | SelNPQ = magics.IsAdd; |
5102 | } |
5103 | |
5104 | PreShifts.push_back( |
5105 | Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PreShift).getReg(Idx: 0)); |
5106 | MagicFactors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Magic).getReg(Idx: 0)); |
5107 | NPQFactors.push_back( |
5108 | Elt: MIB.buildConstant(Res: ScalarTy, |
5109 | Val: SelNPQ ? APInt::getOneBitSet(numBits: EltBits, BitNo: EltBits - 1) |
5110 | : APInt::getZero(numBits: EltBits)) |
5111 | .getReg(Idx: 0)); |
5112 | PostShifts.push_back( |
5113 | Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: PostShift).getReg(Idx: 0)); |
5114 | UseNPQ |= SelNPQ; |
5115 | return true; |
5116 | }; |
5117 | |
5118 | // Collect the shifts/magic values from each element. |
5119 | bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildUDIVPattern); |
5120 | (void)Matched; |
5121 | assert(Matched && "Expected unary predicate match to succeed" ); |
5122 | |
5123 | Register PreShift, PostShift, MagicFactor, NPQFactor; |
5124 | auto *RHSDef = getOpcodeDef<GBuildVector>(Reg: RHS, MRI); |
5125 | if (RHSDef) { |
5126 | PreShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PreShifts).getReg(Idx: 0); |
5127 | MagicFactor = MIB.buildBuildVector(Res: Ty, Ops: MagicFactors).getReg(Idx: 0); |
5128 | NPQFactor = MIB.buildBuildVector(Res: Ty, Ops: NPQFactors).getReg(Idx: 0); |
5129 | PostShift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: PostShifts).getReg(Idx: 0); |
5130 | } else { |
5131 | assert(MRI.getType(RHS).isScalar() && |
5132 | "Non-build_vector operation should have been a scalar" ); |
5133 | PreShift = PreShifts[0]; |
5134 | MagicFactor = MagicFactors[0]; |
5135 | PostShift = PostShifts[0]; |
5136 | } |
5137 | |
5138 | Register Q = LHS; |
5139 | Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PreShift).getReg(Idx: 0); |
5140 | |
5141 | // Multiply the numerator (operand 0) by the magic value. |
5142 | Q = MIB.buildUMulH(Dst: Ty, Src0: Q, Src1: MagicFactor).getReg(Idx: 0); |
5143 | |
5144 | if (UseNPQ) { |
5145 | Register NPQ = MIB.buildSub(Dst: Ty, Src0: LHS, Src1: Q).getReg(Idx: 0); |
5146 | |
5147 | // For vectors we might have a mix of non-NPQ/NPQ paths, so use |
5148 | // G_UMULH to act as a SRL-by-1 for NPQ, else multiply by zero. |
5149 | if (Ty.isVector()) |
5150 | NPQ = MIB.buildUMulH(Dst: Ty, Src0: NPQ, Src1: NPQFactor).getReg(Idx: 0); |
5151 | else |
5152 | NPQ = MIB.buildLShr(Dst: Ty, Src0: NPQ, Src1: MIB.buildConstant(Res: ShiftAmtTy, Val: 1)).getReg(Idx: 0); |
5153 | |
5154 | Q = MIB.buildAdd(Dst: Ty, Src0: NPQ, Src1: Q).getReg(Idx: 0); |
5155 | } |
5156 | |
5157 | Q = MIB.buildLShr(Dst: Ty, Src0: Q, Src1: PostShift).getReg(Idx: 0); |
5158 | auto One = MIB.buildConstant(Res: Ty, Val: 1); |
5159 | auto IsOne = MIB.buildICmp( |
5160 | Pred: CmpInst::Predicate::ICMP_EQ, |
5161 | Res: Ty.isScalar() ? LLT::scalar(SizeInBits: 1) : Ty.changeElementSize(NewEltSize: 1), Op0: RHS, Op1: One); |
5162 | return MIB.buildSelect(Res: Ty, Tst: IsOne, Op0: LHS, Op1: Q); |
5163 | } |
5164 | |
5165 | bool CombinerHelper::matchUDivByConst(MachineInstr &MI) { |
5166 | assert(MI.getOpcode() == TargetOpcode::G_UDIV); |
5167 | Register Dst = MI.getOperand(i: 0).getReg(); |
5168 | Register RHS = MI.getOperand(i: 2).getReg(); |
5169 | LLT DstTy = MRI.getType(Reg: Dst); |
5170 | auto *RHSDef = MRI.getVRegDef(Reg: RHS); |
5171 | if (!isConstantOrConstantVector(MI&: *RHSDef, MRI)) |
5172 | return false; |
5173 | |
5174 | auto &MF = *MI.getMF(); |
5175 | AttributeList Attr = MF.getFunction().getAttributes(); |
5176 | const auto &TLI = getTargetLowering(); |
5177 | LLVMContext &Ctx = MF.getFunction().getContext(); |
5178 | auto &DL = MF.getDataLayout(); |
5179 | if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, DL, Ctx), Attr)) |
5180 | return false; |
5181 | |
5182 | // Don't do this for minsize because the instruction sequence is usually |
5183 | // larger. |
5184 | if (MF.getFunction().hasMinSize()) |
5185 | return false; |
5186 | |
5187 | // Don't do this if the types are not going to be legal. |
5188 | if (LI) { |
5189 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_MUL, {DstTy, DstTy}})) |
5190 | return false; |
5191 | if (!isLegalOrBeforeLegalizer(Query: {TargetOpcode::G_UMULH, {DstTy}})) |
5192 | return false; |
5193 | if (!isLegalOrBeforeLegalizer( |
5194 | Query: {TargetOpcode::G_ICMP, |
5195 | {DstTy.isVector() ? DstTy.changeElementSize(NewEltSize: 1) : LLT::scalar(SizeInBits: 1), |
5196 | DstTy}})) |
5197 | return false; |
5198 | } |
5199 | |
5200 | return matchUnaryPredicate( |
5201 | MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); }); |
5202 | } |
5203 | |
5204 | void CombinerHelper::applyUDivByConst(MachineInstr &MI) { |
5205 | auto *NewMI = buildUDivUsingMul(MI); |
5206 | replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg()); |
5207 | } |
5208 | |
5209 | bool CombinerHelper::matchSDivByConst(MachineInstr &MI) { |
5210 | assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV" ); |
5211 | Register Dst = MI.getOperand(i: 0).getReg(); |
5212 | Register RHS = MI.getOperand(i: 2).getReg(); |
5213 | LLT DstTy = MRI.getType(Reg: Dst); |
5214 | |
5215 | auto &MF = *MI.getMF(); |
5216 | AttributeList Attr = MF.getFunction().getAttributes(); |
5217 | const auto &TLI = getTargetLowering(); |
5218 | LLVMContext &Ctx = MF.getFunction().getContext(); |
5219 | auto &DL = MF.getDataLayout(); |
5220 | if (TLI.isIntDivCheap(VT: getApproximateEVTForLLT(Ty: DstTy, DL, Ctx), Attr)) |
5221 | return false; |
5222 | |
5223 | // Don't do this for minsize because the instruction sequence is usually |
5224 | // larger. |
5225 | if (MF.getFunction().hasMinSize()) |
5226 | return false; |
5227 | |
5228 | // If the sdiv has an 'exact' flag we can use a simpler lowering. |
5229 | if (MI.getFlag(Flag: MachineInstr::MIFlag::IsExact)) { |
5230 | return matchUnaryPredicate( |
5231 | MRI, Reg: RHS, Match: [](const Constant *C) { return C && !C->isNullValue(); }); |
5232 | } |
5233 | |
5234 | // Don't support the general case for now. |
5235 | return false; |
5236 | } |
5237 | |
5238 | void CombinerHelper::applySDivByConst(MachineInstr &MI) { |
5239 | auto *NewMI = buildSDivUsingMul(MI); |
5240 | replaceSingleDefInstWithReg(MI, Replacement: NewMI->getOperand(i: 0).getReg()); |
5241 | } |
5242 | |
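/// Build an exact signed division by a constant as a shift plus a multiply
/// by the divisor's multiplicative inverse. An illustrative sketch for s32:
///   %q = G_SDIV exact %x, 6
/// becomes
///   %s:_(s32) = G_ASHR exact %x, 1
///   %q:_(s32) = G_MUL %s, 0xAAAAAAAB   (the inverse of 3 mod 2^32)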
5243 | MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) { |
5244 | assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV" ); |
5245 | auto &SDiv = cast<GenericMachineInstr>(Val&: MI); |
5246 | Register Dst = SDiv.getReg(Idx: 0); |
5247 | Register LHS = SDiv.getReg(Idx: 1); |
5248 | Register RHS = SDiv.getReg(Idx: 2); |
5249 | LLT Ty = MRI.getType(Reg: Dst); |
5250 | LLT ScalarTy = Ty.getScalarType(); |
5251 | LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(ShiftValueTy: Ty); |
5252 | LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType(); |
5253 | auto &MIB = Builder; |
5254 | |
5255 | bool UseSRA = false; |
5256 | SmallVector<Register, 16> Shifts, Factors; |
5257 | |
5258 | auto *RHSDef = cast<GenericMachineInstr>(Val: getDefIgnoringCopies(Reg: RHS, MRI)); |
5259 | bool IsSplat = getIConstantSplatVal(MI: *RHSDef, MRI).has_value(); |
5260 | |
5261 | auto BuildSDIVPattern = [&](const Constant *C) { |
5262 | // Don't recompute inverses for each splat element. |
5263 | if (IsSplat && !Factors.empty()) { |
5264 | Shifts.push_back(Elt: Shifts[0]); |
5265 | Factors.push_back(Elt: Factors[0]); |
5266 | return true; |
5267 | } |
5268 | |
5269 | auto *CI = cast<ConstantInt>(Val: C); |
5270 | APInt Divisor = CI->getValue(); |
5271 | unsigned Shift = Divisor.countr_zero(); |
5272 | if (Shift) { |
5273 | Divisor.ashrInPlace(ShiftAmt: Shift); |
5274 | UseSRA = true; |
5275 | } |
5276 | |
    // Calculate the multiplicative inverse of the (now odd) divisor modulo
    // 2^W; an odd divisor always has one.
5279 | APInt Factor = Divisor.multiplicativeInverse(); |
5280 | Shifts.push_back(Elt: MIB.buildConstant(Res: ScalarShiftAmtTy, Val: Shift).getReg(Idx: 0)); |
5281 | Factors.push_back(Elt: MIB.buildConstant(Res: ScalarTy, Val: Factor).getReg(Idx: 0)); |
5282 | return true; |
5283 | }; |
5284 | |
5285 | // Collect all magic values from the build vector. |
5286 | bool Matched = matchUnaryPredicate(MRI, Reg: RHS, Match: BuildSDIVPattern); |
5287 | (void)Matched; |
5288 | assert(Matched && "Expected unary predicate match to succeed" ); |
5289 | |
5290 | Register Shift, Factor; |
5291 | if (Ty.isVector()) { |
5292 | Shift = MIB.buildBuildVector(Res: ShiftAmtTy, Ops: Shifts).getReg(Idx: 0); |
5293 | Factor = MIB.buildBuildVector(Res: Ty, Ops: Factors).getReg(Idx: 0); |
5294 | } else { |
5295 | Shift = Shifts[0]; |
5296 | Factor = Factors[0]; |
5297 | } |
5298 | |
5299 | Register Res = LHS; |
5300 | |
5301 | if (UseSRA) |
5302 | Res = MIB.buildAShr(Dst: Ty, Src0: Res, Src1: Shift, Flags: MachineInstr::IsExact).getReg(Idx: 0); |
5303 | |
5304 | return MIB.buildMul(Dst: Ty, Src0: Res, Src1: Factor); |
5305 | } |

bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) {
  assert((MI.getOpcode() == TargetOpcode::G_SDIV ||
          MI.getOpcode() == TargetOpcode::G_UDIV) &&
         "Expected SDIV or UDIV");
  auto &Div = cast<GenericMachineInstr>(MI);
  Register RHS = Div.getReg(2);
  auto MatchPow2 = [&](const Constant *C) {
    auto *CI = dyn_cast<ConstantInt>(C);
    return CI && (CI->getValue().isPowerOf2() ||
                  (IsSigned && CI->getValue().isNegatedPowerOf2()));
  };
  return matchUnaryPredicate(MRI, RHS, MatchPow2, /*AllowUndefs=*/false);
}

void CombinerHelper::applySDivByPow2(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
  auto &SDiv = cast<GenericMachineInstr>(MI);
  Register Dst = SDiv.getReg(0);
  Register LHS = SDiv.getReg(1);
  Register RHS = SDiv.getReg(2);
  LLT Ty = MRI.getType(Dst);
  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
  LLT CCVT =
      Ty.isVector() ? LLT::vector(Ty.getElementCount(), 1) : LLT::scalar(1);

  // Effectively we want to lower G_SDIV %lhs, %rhs, where %rhs is a power of
  // 2, to the following version:
  //
  // %c1 = G_CTTZ %rhs
  // %inexact = G_SUB $bitwidth, %c1
  // %sign = G_ASHR %lhs, $(bitwidth - 1)
  // %lshr = G_LSHR %sign, %inexact
  // %add = G_ADD %lhs, %lshr
  // %ashr = G_ASHR %add, %c1
  // %ashr = G_SELECT %isoneorminusone, %lhs, %ashr
  // %zero = G_CONSTANT $0
  // %neg = G_NEG %ashr
  // %isneg = G_ICMP SLT %rhs, %zero
  // %res = G_SELECT %isneg, %neg, %ashr
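  //
  // For example, G_SDIV %x:s32, 8 becomes
  //   %add = G_ADD %x, (G_LSHR (G_ASHR %x, 31), 29)   ; adds 7 iff %x < 0
  //   %res = G_ASHR %add, 3
  // once the guards for %rhs == 1, %rhs == -1, and negative divisors are
  // constant-folded away for the known-positive divisor 8.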

  unsigned BitWidth = Ty.getScalarSizeInBits();
  auto Zero = Builder.buildConstant(Ty, 0);

  auto Bits = Builder.buildConstant(ShiftAmtTy, BitWidth);
  auto C1 = Builder.buildCTTZ(ShiftAmtTy, RHS);
  auto Inexact = Builder.buildSub(ShiftAmtTy, Bits, C1);
  // Splat the sign bit into the register.
  auto Sign = Builder.buildAShr(
      Ty, LHS, Builder.buildConstant(ShiftAmtTy, BitWidth - 1));

  // Add (LHS < 0) ? |RHS| - 1 : 0 so that the arithmetic shift below rounds
  // the quotient toward zero.
  auto LSrl = Builder.buildLShr(Ty, Sign, Inexact);
  auto Add = Builder.buildAdd(Ty, LHS, LSrl);
  auto AShr = Builder.buildAShr(Ty, Add, C1);

  // Special case: (sdiv X, 1) -> X
  // Special case: (sdiv X, -1) -> 0-X
  auto One = Builder.buildConstant(Ty, 1);
  auto MinusOne = Builder.buildConstant(Ty, -1);
  auto IsOne = Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, One);
  auto IsMinusOne =
      Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, MinusOne);
  auto IsOneOrMinusOne = Builder.buildOr(CCVT, IsOne, IsMinusOne);
  AShr = Builder.buildSelect(Ty, IsOneOrMinusOne, LHS, AShr);

  // If divided by a positive value, we're done. Otherwise, the result must
  // be negated.
  auto Neg = Builder.buildNeg(Ty, AShr);
  auto IsNeg =
      Builder.buildICmp(CmpInst::Predicate::ICMP_SLT, CCVT, RHS, Zero);
  Builder.buildSelect(MI.getOperand(0).getReg(), IsNeg, Neg, AShr);
  MI.eraseFromParent();
}

void CombinerHelper::applyUDivByPow2(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
  auto &UDiv = cast<GenericMachineInstr>(MI);
  Register Dst = UDiv.getReg(0);
  Register LHS = UDiv.getReg(1);
  Register RHS = UDiv.getReg(2);
  LLT Ty = MRI.getType(Dst);
  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);

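  // An unsigned divide by a power of two is just a logical right shift, and
  // G_CTTZ of the (known power-of-two) divisor computes the shift amount.
  // Using G_CTTZ also handles non-splat vector divisors.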
  auto C1 = Builder.buildCTTZ(ShiftAmtTy, RHS);
  Builder.buildLShr(MI.getOperand(0).getReg(), LHS, C1);
  MI.eraseFromParent();
}

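// G_UMULH %x, (1 << k) yields the high half of the widened product, which is
// simply %x >> (bitwidth - k). A multiplier of 1 (k == 0) is excluded below
// because it would require a shift by the full bit width.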
bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_UMULH);
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  LLT Ty = MRI.getType(Dst);
  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
  auto MatchPow2ExceptOne = [&](const Constant *C) {
    if (auto *CI = dyn_cast<ConstantInt>(C))
      return CI->getValue().isPowerOf2() && !CI->getValue().isOne();
    return false;
  };
  if (!matchUnaryPredicate(MRI, RHS, MatchPow2ExceptOne,
                           /*AllowUndefs=*/false))
    return false;
  return isLegalOrBeforeLegalizer({TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}});
}

void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) {
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  LLT Ty = MRI.getType(Dst);
  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
  unsigned NumEltBits = Ty.getScalarSizeInBits();

  auto LogBase2 = buildLogBase2(RHS, Builder);
  auto ShiftAmt =
      Builder.buildSub(Ty, Builder.buildConstant(Ty, NumEltBits), LogBase2);
  auto Trunc = Builder.buildZExtOrTrunc(ShiftAmtTy, ShiftAmt);
  Builder.buildLShr(Dst, LHS, Trunc);
  MI.eraseFromParent();
}

bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI,
                                               BuildFnTy &MatchInfo) {
  unsigned Opc = MI.getOpcode();
  assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB ||
         Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
         Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA);

  Register Dst = MI.getOperand(0).getReg();
  Register X = MI.getOperand(1).getReg();
  Register Y = MI.getOperand(2).getReg();
  LLT Type = MRI.getType(Dst);

  // fold (fadd x, fneg(y)) -> (fsub x, y)
  // fold (fadd fneg(y), x) -> (fsub x, y)
  // G_FADD is commutative, so both cases are checked by m_GFAdd.
  if (mi_match(Dst, MRI, m_GFAdd(m_Reg(X), m_GFNeg(m_Reg(Y)))) &&
      isLegalOrBeforeLegalizer({TargetOpcode::G_FSUB, {Type}})) {
    Opc = TargetOpcode::G_FSUB;
  }
  // fold (fsub x, fneg(y)) -> (fadd x, y)
  else if (mi_match(Dst, MRI, m_GFSub(m_Reg(X), m_GFNeg(m_Reg(Y)))) &&
           isLegalOrBeforeLegalizer({TargetOpcode::G_FADD, {Type}})) {
    Opc = TargetOpcode::G_FADD;
  }
  // fold (fmul fneg(x), fneg(y)) -> (fmul x, y)
  // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y)
  // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z)
  // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z)
  else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV ||
            Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) &&
           mi_match(X, MRI, m_GFNeg(m_Reg(X))) &&
           mi_match(Y, MRI, m_GFNeg(m_Reg(Y)))) {
    // No opcode change needed; only the operands are replaced.
  } else
    return false;

  MatchInfo = [=, &MI](MachineIRBuilder &B) {
    Observer.changingInstr(MI);
    MI.setDesc(B.getTII().get(Opc));
    MI.getOperand(1).setReg(X);
    MI.getOperand(2).setReg(Y);
    Observer.changedInstr(MI);
  };
  return true;
}

bool CombinerHelper::matchFsubToFneg(MachineInstr &MI, Register &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_FSUB);

  Register LHS = MI.getOperand(1).getReg();
  MatchInfo = MI.getOperand(2).getReg();
  LLT Ty = MRI.getType(MI.getOperand(0).getReg());

  const auto LHSCst = Ty.isVector()
                          ? getFConstantSplat(LHS, MRI, /* allowUndef */ true)
                          : getFConstantVRegValWithLookThrough(LHS, MRI);
  if (!LHSCst)
    return false;

  // -0.0 is always allowed.
  if (LHSCst->Value.isNegZero())
    return true;

  // +0.0 is only allowed if nsz is set: (fsub +0.0, +0.0) is +0.0, whereas
  // (fneg +0.0) is -0.0, so the fold is only valid when signed zeros don't
  // matter.
  if (LHSCst->Value.isPosZero())
    return MI.getFlag(MachineInstr::FmNsz);

  return false;
}

void CombinerHelper::applyFsubToFneg(MachineInstr &MI, Register &MatchInfo) {
  Register Dst = MI.getOperand(0).getReg();
  Builder.buildFNeg(
      Dst, Builder.buildFCanonicalize(MRI.getType(Dst), MatchInfo).getReg(0));
  eraseInst(MI);
}

/// Checks if \p MI is TargetOpcode::G_FMUL and contractable either
/// due to global flags or MachineInstr flags.
static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) {
  if (MI.getOpcode() != TargetOpcode::G_FMUL)
    return false;
  return AllowFusionGlobally || MI.getFlag(MachineInstr::MIFlag::FmContract);
}

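/// Returns true if \p MI0's def has more non-debug uses than \p MI1's; used
/// below to decide which of two candidate multiplies should be fused.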
static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1,
                        const MachineRegisterInfo &MRI) {
  return std::distance(MRI.use_instr_nodbg_begin(MI0.getOperand(0).getReg()),
                       MRI.use_instr_nodbg_end()) >
         std::distance(MRI.use_instr_nodbg_begin(MI1.getOperand(0).getReg()),
                       MRI.use_instr_nodbg_end());
}

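/// Common precondition check for forming a fused multiply-add: computes
/// whether contraction is allowed globally (fast FP-op fusion, unsafe FP
/// math, or a legal G_FMAD), whether G_FMAD (which keeps the intermediate
/// rounding step) is available, and whether the target wants fusion even
/// when a multiply has multiple uses.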
bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI,
                                         bool &AllowFusionGlobally,
                                         bool &HasFMAD, bool &Aggressive,
                                         bool CanReassociate) {
  auto *MF = MI.getMF();
  const auto &TLI = *MF->getSubtarget().getTargetLowering();
  const TargetOptions &Options = MF->getTarget().Options;
  LLT DstType = MRI.getType(MI.getOperand(0).getReg());

  if (CanReassociate &&
      !(Options.UnsafeFPMath || MI.getFlag(MachineInstr::MIFlag::FmReassoc)))
    return false;

  // Floating-point multiply-add with intermediate rounding.
  HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, DstType));
  // Floating-point multiply-add without intermediate rounding.
  bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(*MF, DstType) &&
                isLegalOrBeforeLegalizer({TargetOpcode::G_FMA, {DstType}});
  // No valid opcode, do not combine.
  if (!HasFMAD && !HasFMA)
    return false;

  AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast ||
                        Options.UnsafeFPMath || HasFMAD;
  // If the addition is not contractable, do not combine.
  if (!AllowFusionGlobally && !MI.getFlag(MachineInstr::MIFlag::FmContract))
    return false;

  Aggressive = TLI.enableAggressiveFMAFusion(DstType);
  return true;
}

bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA(
    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_FADD);

  bool AllowFusionGlobally, HasFMAD, Aggressive;
  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
    return false;

  Register Op1 = MI.getOperand(1).getReg();
  Register Op2 = MI.getOperand(2).getReg();
  DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
  DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
  unsigned PreferredFusedOpcode =
      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;

  // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
  // prefer to fold the multiply with fewer uses, since it is more likely to
  // become dead after the fusion.
  if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
      isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
    if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
      std::swap(LHS, RHS);
  }

  // fold (fadd (fmul x, y), z) -> (fma x, y, z)
  if (isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
      (Aggressive || MRI.hasOneNonDBGUse(LHS.Reg))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {LHS.MI->getOperand(1).getReg(),
                    LHS.MI->getOperand(2).getReg(), RHS.Reg});
    };
    return true;
  }

  // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
  if (isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
      (Aggressive || MRI.hasOneNonDBGUse(RHS.Reg))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {RHS.MI->getOperand(1).getReg(),
                    RHS.MI->getOperand(2).getReg(), LHS.Reg});
    };
    return true;
  }

  return false;
}

bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA(
    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_FADD);

  bool AllowFusionGlobally, HasFMAD, Aggressive;
  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
    return false;

  const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
  Register Op1 = MI.getOperand(1).getReg();
  Register Op2 = MI.getOperand(2).getReg();
  DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
  DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
  LLT DstType = MRI.getType(MI.getOperand(0).getReg());

  unsigned PreferredFusedOpcode =
      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;

  // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
  // prefer to fold the multiply with fewer uses.
  if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
      isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
    if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
      std::swap(LHS, RHS);
  }

  // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
  MachineInstr *FpExtSrc;
  if (mi_match(LHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) &&
      isContractableFMul(*FpExtSrc, AllowFusionGlobally) &&
      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
                          MRI.getType(FpExtSrc->getOperand(1).getReg()))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg());
      auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg());
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {FpExtX.getReg(0), FpExtY.getReg(0), RHS.Reg});
    };
    return true;
  }

  // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z)
  // Note: Commutes FADD operands.
  if (mi_match(RHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) &&
      isContractableFMul(*FpExtSrc, AllowFusionGlobally) &&
      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
                          MRI.getType(FpExtSrc->getOperand(1).getReg()))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg());
      auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg());
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {FpExtX.getReg(0), FpExtY.getReg(0), LHS.Reg});
    };
    return true;
  }

  return false;
}

bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA(
    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_FADD);

  bool AllowFusionGlobally, HasFMAD, Aggressive;
  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive,
                           /*CanReassociate=*/true))
    return false;

  Register Op1 = MI.getOperand(1).getReg();
  Register Op2 = MI.getOperand(2).getReg();
  DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
  DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  unsigned PreferredFusedOpcode =
      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;

  // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
  // prefer to fold the multiply with fewer uses.
  if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
      isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
    if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
      std::swap(LHS, RHS);
  }

  MachineInstr *FMA = nullptr;
  Register Z;
  // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z))
  if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
      (MRI.getVRegDef(LHS.MI->getOperand(3).getReg())->getOpcode() ==
       TargetOpcode::G_FMUL) &&
      MRI.hasOneNonDBGUse(LHS.MI->getOperand(0).getReg()) &&
      MRI.hasOneNonDBGUse(LHS.MI->getOperand(3).getReg())) {
    FMA = LHS.MI;
    Z = RHS.Reg;
  }
  // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z))
  else if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
           (MRI.getVRegDef(RHS.MI->getOperand(3).getReg())->getOpcode() ==
            TargetOpcode::G_FMUL) &&
           MRI.hasOneNonDBGUse(RHS.MI->getOperand(0).getReg()) &&
           MRI.hasOneNonDBGUse(RHS.MI->getOperand(3).getReg())) {
    Z = LHS.Reg;
    FMA = RHS.MI;
  }

  if (FMA) {
    MachineInstr *FMulMI = MRI.getVRegDef(FMA->getOperand(3).getReg());
    Register X = FMA->getOperand(1).getReg();
    Register Y = FMA->getOperand(2).getReg();
    Register U = FMulMI->getOperand(1).getReg();
    Register V = FMulMI->getOperand(2).getReg();

    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register InnerFMA = MRI.createGenericVirtualRegister(DstTy);
      B.buildInstr(PreferredFusedOpcode, {InnerFMA}, {U, V, Z});
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {X, Y, InnerFMA});
    };
    return true;
  }

  return false;
}

bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive(
    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_FADD);

  bool AllowFusionGlobally, HasFMAD, Aggressive;
  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
    return false;

  if (!Aggressive)
    return false;

  const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
  LLT DstType = MRI.getType(MI.getOperand(0).getReg());
  Register Op1 = MI.getOperand(1).getReg();
  Register Op2 = MI.getOperand(2).getReg();
  DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
  DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};

  unsigned PreferredFusedOpcode =
      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;

  // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
  // prefer to fold the multiply with fewer uses.
  if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
      isContractableFMul(*RHS.MI, AllowFusionGlobally)) {
    if (hasMoreUses(*LHS.MI, *RHS.MI, MRI))
      std::swap(LHS, RHS);
  }

  // Builds: (fma x, y, (fma (fpext u), (fpext v), z))
  auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z,
                                 Register X, Register Y,
                                 MachineIRBuilder &B) {
    Register FpExtU = B.buildFPExt(DstType, U).getReg(0);
    Register FpExtV = B.buildFPExt(DstType, V).getReg(0);
    Register InnerFMA =
        B.buildInstr(PreferredFusedOpcode, {DstType}, {FpExtU, FpExtV, Z})
            .getReg(0);
    B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                 {X, Y, InnerFMA});
  };

  MachineInstr *FMulMI, *FMAMI;
  // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
  //   -> (fma x, y, (fma (fpext u), (fpext v), z))
  if (LHS.MI->getOpcode() == PreferredFusedOpcode &&
      mi_match(LHS.MI->getOperand(3).getReg(), MRI,
               m_GFPExt(m_MInstr(FMulMI))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
                          MRI.getType(FMulMI->getOperand(0).getReg()))) {
    MatchInfo = [=](MachineIRBuilder &B) {
      buildMatchInfo(FMulMI->getOperand(1).getReg(),
                     FMulMI->getOperand(2).getReg(), RHS.Reg,
                     LHS.MI->getOperand(1).getReg(),
                     LHS.MI->getOperand(2).getReg(), B);
    };
    return true;
  }

  // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
  //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
  // FIXME: This turns two single-precision and one double-precision
  // operation into two double-precision operations, which might not be
  // interesting for all targets, especially GPUs.
  if (mi_match(LHS.Reg, MRI, m_GFPExt(m_MInstr(FMAMI))) &&
      FMAMI->getOpcode() == PreferredFusedOpcode) {
    MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg());
    if (isContractableFMul(*FMulMI, AllowFusionGlobally) &&
        TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
                            MRI.getType(FMAMI->getOperand(0).getReg()))) {
      MatchInfo = [=](MachineIRBuilder &B) {
        Register X = FMAMI->getOperand(1).getReg();
        Register Y = FMAMI->getOperand(2).getReg();
        X = B.buildFPExt(DstType, X).getReg(0);
        Y = B.buildFPExt(DstType, Y).getReg(0);
        buildMatchInfo(FMulMI->getOperand(1).getReg(),
                       FMulMI->getOperand(2).getReg(), RHS.Reg, X, Y, B);
      };

      return true;
    }
  }

  // fold (fadd z, (fma x, y, (fpext (fmul u, v))))
  //   -> (fma x, y, (fma (fpext u), (fpext v), z))
  if (RHS.MI->getOpcode() == PreferredFusedOpcode &&
      mi_match(RHS.MI->getOperand(3).getReg(), MRI,
               m_GFPExt(m_MInstr(FMulMI))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
                          MRI.getType(FMulMI->getOperand(0).getReg()))) {
    MatchInfo = [=](MachineIRBuilder &B) {
      buildMatchInfo(FMulMI->getOperand(1).getReg(),
                     FMulMI->getOperand(2).getReg(), LHS.Reg,
                     RHS.MI->getOperand(1).getReg(),
                     RHS.MI->getOperand(2).getReg(), B);
    };
    return true;
  }

  // fold (fadd z, (fpext (fma x, y, (fmul u, v))))
  //   -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
  // FIXME: This turns two single-precision and one double-precision
  // operation into two double-precision operations, which might not be
  // interesting for all targets, especially GPUs.
  if (mi_match(RHS.Reg, MRI, m_GFPExt(m_MInstr(FMAMI))) &&
      FMAMI->getOpcode() == PreferredFusedOpcode) {
    MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg());
    if (isContractableFMul(*FMulMI, AllowFusionGlobally) &&
        TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
                            MRI.getType(FMAMI->getOperand(0).getReg()))) {
      MatchInfo = [=](MachineIRBuilder &B) {
        Register X = FMAMI->getOperand(1).getReg();
        Register Y = FMAMI->getOperand(2).getReg();
        X = B.buildFPExt(DstType, X).getReg(0);
        Y = B.buildFPExt(DstType, Y).getReg(0);
        buildMatchInfo(FMulMI->getOperand(1).getReg(),
                       FMulMI->getOperand(2).getReg(), LHS.Reg, X, Y, B);
      };
      return true;
    }
  }

  return false;
}

bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA(
    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_FSUB);

  bool AllowFusionGlobally, HasFMAD, Aggressive;
  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
    return false;

  Register Op1 = MI.getOperand(1).getReg();
  Register Op2 = MI.getOperand(2).getReg();
  DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1};
  DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2};
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
  // prefer to fold the multiply with fewer uses.
  bool FirstMulHasFewerUses = true;
  if (isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
      isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
      hasMoreUses(*LHS.MI, *RHS.MI, MRI))
    FirstMulHasFewerUses = false;

  unsigned PreferredFusedOpcode =
      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;

  // fold (fsub (fmul x, y), z) -> (fma x, y, -z)
  if (FirstMulHasFewerUses &&
      (isContractableFMul(*LHS.MI, AllowFusionGlobally) &&
       (Aggressive || MRI.hasOneNonDBGUse(LHS.Reg)))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register NegZ = B.buildFNeg(DstTy, RHS.Reg).getReg(0);
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {LHS.MI->getOperand(1).getReg(),
                    LHS.MI->getOperand(2).getReg(), NegZ});
    };
    return true;
  }
  // fold (fsub x, (fmul y, z)) -> (fma -y, z, x)
  else if ((isContractableFMul(*RHS.MI, AllowFusionGlobally) &&
            (Aggressive || MRI.hasOneNonDBGUse(RHS.Reg)))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register NegY =
          B.buildFNeg(DstTy, RHS.MI->getOperand(1).getReg()).getReg(0);
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {NegY, RHS.MI->getOperand(2).getReg(), LHS.Reg});
    };
    return true;
  }

  return false;
}

bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA(
    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_FSUB);

  bool AllowFusionGlobally, HasFMAD, Aggressive;
  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
    return false;

  Register LHSReg = MI.getOperand(1).getReg();
  Register RHSReg = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  unsigned PreferredFusedOpcode =
      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;

  MachineInstr *FMulMI;
  // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
  if (mi_match(LHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) &&
      (Aggressive || (MRI.hasOneNonDBGUse(LHSReg) &&
                      MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally)) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register NegX =
          B.buildFNeg(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
      Register NegZ = B.buildFNeg(DstTy, RHSReg).getReg(0);
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {NegX, FMulMI->getOperand(2).getReg(), NegZ});
    };
    return true;
  }

  // fold (fsub x, (fneg (fmul y, z))) -> (fma y, z, x)
  if (mi_match(RHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) &&
      (Aggressive || (MRI.hasOneNonDBGUse(RHSReg) &&
                      MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally)) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {FMulMI->getOperand(1).getReg(),
                    FMulMI->getOperand(2).getReg(), LHSReg});
    };
    return true;
  }

  return false;
}

bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA(
    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_FSUB);

  bool AllowFusionGlobally, HasFMAD, Aggressive;
  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
    return false;

  Register LHSReg = MI.getOperand(1).getReg();
  Register RHSReg = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  unsigned PreferredFusedOpcode =
      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;

  MachineInstr *FMulMI;
  // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z))
  if (mi_match(LHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
      (Aggressive || MRI.hasOneNonDBGUse(LHSReg))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register FpExtX =
          B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
      Register FpExtY =
          B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0);
      Register NegZ = B.buildFNeg(DstTy, RHSReg).getReg(0);
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {FpExtX, FpExtY, NegZ});
    };
    return true;
  }

  // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x)
  if (mi_match(RHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
      (Aggressive || MRI.hasOneNonDBGUse(RHSReg))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register FpExtY =
          B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0);
      Register NegY = B.buildFNeg(DstTy, FpExtY).getReg(0);
      Register FpExtZ =
          B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0);
      B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
                   {NegY, FpExtZ, LHSReg});
    };
    return true;
  }

  return false;
}

bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA(
    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_FSUB);

  bool AllowFusionGlobally, HasFMAD, Aggressive;
  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
    return false;

  const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  Register LHSReg = MI.getOperand(1).getReg();
  Register RHSReg = MI.getOperand(2).getReg();

  unsigned PreferredFusedOpcode =
      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;

  auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z,
                            MachineIRBuilder &B) {
    Register FpExtX = B.buildFPExt(DstTy, X).getReg(0);
    Register FpExtY = B.buildFPExt(DstTy, Y).getReg(0);
    B.buildInstr(PreferredFusedOpcode, {Dst}, {FpExtX, FpExtY, Z});
  };

  MachineInstr *FMulMI;
  // fold (fsub (fpext (fneg (fmul x, y))), z) ->
  //      (fneg (fma (fpext x), (fpext y), z))
  // fold (fsub (fneg (fpext (fmul x, y))), z) ->
  //      (fneg (fma (fpext x), (fpext y), z))
  if ((mi_match(LHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) ||
       mi_match(LHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy,
                          MRI.getType(FMulMI->getOperand(0).getReg()))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      Register FMAReg = MRI.createGenericVirtualRegister(DstTy);
      buildMatchInfo(FMAReg, FMulMI->getOperand(1).getReg(),
                     FMulMI->getOperand(2).getReg(), RHSReg, B);
      B.buildFNeg(MI.getOperand(0).getReg(), FMAReg);
    };
    return true;
  }

  // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
  // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x)
  if ((mi_match(RHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) ||
       mi_match(RHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) &&
      isContractableFMul(*FMulMI, AllowFusionGlobally) &&
      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy,
                          MRI.getType(FMulMI->getOperand(0).getReg()))) {
    MatchInfo = [=, &MI](MachineIRBuilder &B) {
      buildMatchInfo(MI.getOperand(0).getReg(),
                     FMulMI->getOperand(1).getReg(),
                     FMulMI->getOperand(2).getReg(), LHSReg, B);
    };
    return true;
  }

  return false;
}

bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI,
                                            unsigned &IdxToPropagate) {
  bool PropagateNaN;
  switch (MI.getOpcode()) {
  default:
    return false;
  case TargetOpcode::G_FMINNUM:
  case TargetOpcode::G_FMAXNUM:
    PropagateNaN = false;
    break;
  case TargetOpcode::G_FMINIMUM:
  case TargetOpcode::G_FMAXIMUM:
    PropagateNaN = true;
    break;
  }

  auto MatchNaN = [&](unsigned Idx) {
    Register MaybeNaNReg = MI.getOperand(Idx).getReg();
    const ConstantFP *MaybeCst = getConstantFPVRegVal(MaybeNaNReg, MRI);
    if (!MaybeCst || !MaybeCst->getValueAPF().isNaN())
      return false;
    IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 2 : 1);
    return true;
  };

  return MatchNaN(1) || MatchNaN(2);
}

bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) {
  assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD");
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();

  // Helper lambda to check for opportunities for
  //   A + (B - A) -> B
  //   (B - A) + A -> B
  auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) {
    Register Reg;
    return mi_match(MaybeSub, MRI, m_GSub(m_Reg(Src), m_Reg(Reg))) &&
           Reg == MaybeSameReg;
  };
  return CheckFold(LHS, RHS) || CheckFold(RHS, LHS);
}

bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI,
                                                  Register &MatchInfo) {
  // This combine folds the following patterns:
  //
  //  G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k))
  //  G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k)))
  //    into
  //  x
  //    if
  //  k == sizeof(VecEltTy)/2
  //  type(x) == type(dst)
  //
  //  G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef)
  //    into
  //  x
  //    if
  //  type(x) == type(dst)
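  //
  // For example, with dst and x both <2 x s32>, G_BITCAST x gives an s64
  // holding both elements, so truncating it and its 32-bit logical right
  // shift back into a build_vector just reassembles x.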

  LLT DstVecTy = MRI.getType(MI.getOperand(0).getReg());
  LLT DstEltTy = DstVecTy.getElementType();

  Register Lo, Hi;

  if (mi_match(
          MI, MRI,
          m_GBuildVector(m_GTrunc(m_GBitcast(m_Reg(Lo))), m_GImplicitDef()))) {
    MatchInfo = Lo;
    return MRI.getType(MatchInfo) == DstVecTy;
  }

  std::optional<ValueAndVReg> ShiftAmount;
  const auto LoPattern = m_GBitcast(m_Reg(Lo));
  const auto HiPattern = m_GLShr(m_GBitcast(m_Reg(Hi)), m_GCst(ShiftAmount));
  if (mi_match(
          MI, MRI,
          m_any_of(m_GBuildVectorTrunc(LoPattern, HiPattern),
                   m_GBuildVector(m_GTrunc(LoPattern), m_GTrunc(HiPattern))))) {
    if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) {
      MatchInfo = Lo;
      return MRI.getType(MatchInfo) == DstVecTy;
    }
  }

  return false;
}

bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI,
                                               Register &MatchInfo) {
  // Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y))) with just x
  // if type(x) == type(G_TRUNC).
  if (!mi_match(MI.getOperand(1).getReg(), MRI,
                m_GBitcast(m_GBuildVector(m_Reg(MatchInfo), m_Reg()))))
    return false;

  return MRI.getType(MatchInfo) == MRI.getType(MI.getOperand(0).getReg());
}

bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI,
                                                   Register &MatchInfo) {
  // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with
  // y if K == size of vector element type.
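  // For example, with x,y:s32 the G_BITCAST produces an s64 whose high 32
  // bits come from y, so a logical shift right by K == 32 followed by a
  // G_TRUNC to s32 yields y.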
  std::optional<ValueAndVReg> ShiftAmt;
  if (!mi_match(MI.getOperand(1).getReg(), MRI,
                m_GLShr(m_GBitcast(m_GBuildVector(m_Reg(), m_Reg(MatchInfo))),
                        m_GCst(ShiftAmt))))
    return false;

  LLT MatchTy = MRI.getType(MatchInfo);
  return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() &&
         MatchTy == MRI.getType(MI.getOperand(0).getReg());
}

unsigned CombinerHelper::getFPMinMaxOpcForSelect(
    CmpInst::Predicate Pred, LLT DstTy,
    SelectPatternNaNBehaviour VsNaNRetVal) const {
  assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE &&
         "Expected a NaN behaviour");
  // Choose an opcode based on legality or the behaviour when one of the
  // LHS/RHS may be NaN.
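  // G_FMINNUM/G_FMAXNUM return the non-NaN operand when exactly one input is
  // NaN, while G_FMINIMUM/G_FMAXIMUM propagate NaN, so the select's known
  // behaviour against NaN dictates which opcode is a faithful replacement.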
  switch (Pred) {
  default:
    return 0;
  case CmpInst::FCMP_UGT:
  case CmpInst::FCMP_UGE:
  case CmpInst::FCMP_OGT:
  case CmpInst::FCMP_OGE:
    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
      return TargetOpcode::G_FMAXNUM;
    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
      return TargetOpcode::G_FMAXIMUM;
    if (isLegal({TargetOpcode::G_FMAXNUM, {DstTy}}))
      return TargetOpcode::G_FMAXNUM;
    if (isLegal({TargetOpcode::G_FMAXIMUM, {DstTy}}))
      return TargetOpcode::G_FMAXIMUM;
    return 0;
  case CmpInst::FCMP_ULT:
  case CmpInst::FCMP_ULE:
  case CmpInst::FCMP_OLT:
  case CmpInst::FCMP_OLE:
    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER)
      return TargetOpcode::G_FMINNUM;
    if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN)
      return TargetOpcode::G_FMINIMUM;
    if (isLegal({TargetOpcode::G_FMINNUM, {DstTy}}))
      return TargetOpcode::G_FMINNUM;
    if (isLegal({TargetOpcode::G_FMINIMUM, {DstTy}}))
      return TargetOpcode::G_FMINIMUM;
    return 0;
  }
}

CombinerHelper::SelectPatternNaNBehaviour
CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS,
                                        bool IsOrderedComparison) const {
  bool LHSSafe = isKnownNeverNaN(LHS, MRI);
  bool RHSSafe = isKnownNeverNaN(RHS, MRI);
  // Completely unsafe.
  if (!LHSSafe && !RHSSafe)
    return SelectPatternNaNBehaviour::NOT_APPLICABLE;
  if (LHSSafe && RHSSafe)
    return SelectPatternNaNBehaviour::RETURNS_ANY;
  // An ordered comparison will return false when given a NaN, so it
  // returns the RHS.
  if (IsOrderedComparison)
    return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN
                   : SelectPatternNaNBehaviour::RETURNS_OTHER;
  // An unordered comparison will return true when given a NaN, so it
  // returns the LHS.
  return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER
                 : SelectPatternNaNBehaviour::RETURNS_NAN;
}

bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond,
                                           Register TrueVal, Register FalseVal,
                                           BuildFnTy &MatchInfo) {
  // Match: select (fcmp cond x, y) x, y
  //        select (fcmp cond x, y) y, x
  // And turn it into fminnum/fmaxnum or fminimum/fmaximum based on the
  // condition.
  LLT DstTy = MRI.getType(Dst);
  // Bail out early on pointers, since we'll never want to fold to a min/max.
  if (DstTy.isPointer())
    return false;
  // Match a floating point compare with a less-than/greater-than predicate.
  // TODO: Allow multiple users of the compare if they are all selects.
  CmpInst::Predicate Pred;
  Register CmpLHS, CmpRHS;
  if (!mi_match(Cond, MRI,
                m_OneNonDBGUse(
                    m_GFCmp(m_Pred(Pred), m_Reg(CmpLHS), m_Reg(CmpRHS)))) ||
      CmpInst::isEquality(Pred))
    return false;
  SelectPatternNaNBehaviour ResWithKnownNaNInfo =
      computeRetValAgainstNaN(CmpLHS, CmpRHS, CmpInst::isOrdered(Pred));
  if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE)
    return false;
  if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
    std::swap(CmpLHS, CmpRHS);
    Pred = CmpInst::getSwappedPredicate(Pred);
    if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN)
      ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER;
    else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER)
      ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN;
  }
  if (TrueVal != CmpLHS || FalseVal != CmpRHS)
    return false;
  // Decide what type of max/min this should be based on the predicate.
  unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, ResWithKnownNaNInfo);
  if (!Opc || !isLegal({Opc, {DstTy}}))
    return false;
  // Comparisons between signed zero and zero may have different results...
  // unless we have fmaximum/fminimum. In that case, we know -0 < 0.
  if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) {
    // We don't know if a comparison between two 0s will give us a consistent
    // result. Be conservative and only proceed if at least one side is
    // non-zero.
    auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpLHS, MRI);
    if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) {
      KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpRHS, MRI);
      if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero())
        return false;
    }
  }
  MatchInfo = [=](MachineIRBuilder &B) {
    B.buildInstr(Opc, {Dst}, {CmpLHS, CmpRHS});
  };
  return true;
}

bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI,
                                                 BuildFnTy &MatchInfo) {
  // TODO: Handle integer cases.
  assert(MI.getOpcode() == TargetOpcode::G_SELECT);
  // Condition may be fed by a truncated compare.
  Register Cond = MI.getOperand(1).getReg();
  Register MaybeTrunc;
  if (mi_match(Cond, MRI, m_OneNonDBGUse(m_GTrunc(m_Reg(MaybeTrunc)))))
    Cond = MaybeTrunc;
  Register Dst = MI.getOperand(0).getReg();
  Register TrueVal = MI.getOperand(2).getReg();
  Register FalseVal = MI.getOperand(3).getReg();
  return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo);
}

bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI,
                                                   BuildFnTy &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_ICMP);
  // (X + Y) == X --> Y == 0
  // (X + Y) != X --> Y != 0
  // (X - Y) == X --> Y == 0
  // (X - Y) != X --> Y != 0
  // (X ^ Y) == X --> Y == 0
  // (X ^ Y) != X --> Y != 0
  Register Dst = MI.getOperand(0).getReg();
  CmpInst::Predicate Pred;
  Register X, Y, OpLHS, OpRHS;
  bool MatchedSub = mi_match(
      Dst, MRI,
      m_c_GICmp(m_Pred(Pred), m_Reg(X), m_GSub(m_Reg(OpLHS), m_Reg(Y))));
  if (MatchedSub && X != OpLHS)
    return false;
  if (!MatchedSub) {
    if (!mi_match(Dst, MRI,
                  m_c_GICmp(m_Pred(Pred), m_Reg(X),
                            m_any_of(m_GAdd(m_Reg(OpLHS), m_Reg(OpRHS)),
                                     m_GXor(m_Reg(OpLHS), m_Reg(OpRHS))))))
      return false;
    // G_ADD and G_XOR are commutative: X may match either operand.
    Y = X == OpLHS ? OpRHS : X == OpRHS ? OpLHS : Register();
  }
  MatchInfo = [=](MachineIRBuilder &B) {
    auto Zero = B.buildConstant(MRI.getType(Y), 0);
    B.buildICmp(Pred, Dst, Y, Zero);
  };
  return CmpInst::isEquality(Pred) && Y.isValid();
}

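// A constant shift amount that is >= the scalar bit width makes the shift
// produce poison (the same rule as for LLVM IR shifts), so such an
// instruction can validly be replaced with undef.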
bool CombinerHelper::matchShiftsTooBig(MachineInstr &MI) {
  Register ShiftReg = MI.getOperand(2).getReg();
  LLT ResTy = MRI.getType(MI.getOperand(0).getReg());
  auto IsShiftTooBig = [&](const Constant *C) {
    auto *CI = dyn_cast<ConstantInt>(C);
    return CI && CI->uge(ResTy.getScalarSizeInBits());
  };
  return matchUnaryPredicate(MRI, ShiftReg, IsShiftTooBig);
}

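// Canonicalizing a constant (or constant-fold barrier) to the RHS means later
// combines only have to look for constants in one operand position.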
bool CombinerHelper::matchCommuteConstantToRHS(MachineInstr &MI) {
  unsigned LHSOpndIdx = 1;
  unsigned RHSOpndIdx = 2;
  switch (MI.getOpcode()) {
  case TargetOpcode::G_UADDO:
  case TargetOpcode::G_SADDO:
  case TargetOpcode::G_UMULO:
  case TargetOpcode::G_SMULO:
    // These also define a carry/overflow result, so the value operands start
    // at index 2.
    LHSOpndIdx = 2;
    RHSOpndIdx = 3;
    break;
  default:
    break;
  }
  Register LHS = MI.getOperand(LHSOpndIdx).getReg();
  Register RHS = MI.getOperand(RHSOpndIdx).getReg();
  if (!getIConstantVRegVal(LHS, MRI)) {
    // Skip commuting if LHS is not a constant. But, LHS may be a
    // G_CONSTANT_FOLD_BARRIER. If so we commute as long as we don't already
    // have a constant on the RHS.
    if (MRI.getVRegDef(LHS)->getOpcode() !=
        TargetOpcode::G_CONSTANT_FOLD_BARRIER)
      return false;
  }
  // Commute as long as RHS is not a constant or G_CONSTANT_FOLD_BARRIER.
  return MRI.getVRegDef(RHS)->getOpcode() !=
             TargetOpcode::G_CONSTANT_FOLD_BARRIER &&
         !getIConstantVRegVal(RHS, MRI);
}

bool CombinerHelper::matchCommuteFPConstantToRHS(MachineInstr &MI) {
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  std::optional<FPValueAndVReg> ValAndVReg;
  if (!mi_match(LHS, MRI, m_GFCstOrSplat(ValAndVReg)))
    return false;
  return !mi_match(RHS, MRI, m_GFCstOrSplat(ValAndVReg));
}

void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) {
  Observer.changingInstr(MI);
  unsigned LHSOpndIdx = 1;
  unsigned RHSOpndIdx = 2;
  switch (MI.getOpcode()) {
  case TargetOpcode::G_UADDO:
  case TargetOpcode::G_SADDO:
  case TargetOpcode::G_UMULO:
  case TargetOpcode::G_SMULO:
    LHSOpndIdx = 2;
    RHSOpndIdx = 3;
    break;
  default:
    break;
  }
  Register LHSReg = MI.getOperand(LHSOpndIdx).getReg();
  Register RHSReg = MI.getOperand(RHSOpndIdx).getReg();
  MI.getOperand(LHSOpndIdx).setReg(RHSReg);
  MI.getOperand(RHSOpndIdx).setReg(LHSReg);
  Observer.changedInstr(MI);
}

bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) {
  LLT SrcTy = MRI.getType(Src);
  if (SrcTy.isFixedVector())
    return isConstantSplatVector(Src, 1, AllowUndefs);
  if (SrcTy.isScalar()) {
    if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr)
      return true;
    auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
    return IConstant && IConstant->Value == 1;
  }
  return false; // scalable vector
}

bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) {
  LLT SrcTy = MRI.getType(Src);
  if (SrcTy.isFixedVector())
    return isConstantSplatVector(Src, 0, AllowUndefs);
  if (SrcTy.isScalar()) {
    if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr)
      return true;
    auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
    return IConstant && IConstant->Value == 0;
  }
  return false; // scalable vector
}

// Ignores COPYs during conformance checks.
// FIXME: Handle scalable vectors.
bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue,
                                           bool AllowUndefs) {
  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
  if (!BuildVector)
    return false;
  unsigned NumSources = BuildVector->getNumSources();

  for (unsigned I = 0; I < NumSources; ++I) {
    GImplicitDef *ImplicitDef =
        getOpcodeDef<GImplicitDef>(BuildVector->getSourceReg(I), MRI);
    if (ImplicitDef && AllowUndefs)
      continue;
    if (ImplicitDef && !AllowUndefs)
      return false;
    std::optional<ValueAndVReg> IConstant =
        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
    if (IConstant && IConstant->Value == SplatValue)
      continue;
    return false;
  }
  return true;
}

// Ignores COPYs during lookups.
// FIXME: Handle scalable vectors.
std::optional<APInt>
CombinerHelper::getConstantOrConstantSplatVector(Register Src) {
  auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
  if (IConstant)
    return IConstant->Value;

  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
  if (!BuildVector)
    return std::nullopt;
  unsigned NumSources = BuildVector->getNumSources();

  std::optional<APInt> Value = std::nullopt;
  for (unsigned I = 0; I < NumSources; ++I) {
    std::optional<ValueAndVReg> IConstant =
        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
    if (!IConstant)
      return std::nullopt;
    if (!Value)
      Value = IConstant->Value;
    else if (*Value != IConstant->Value)
      return std::nullopt;
  }
  return Value;
}

// FIXME: Handle G_SPLAT_VECTOR.
bool CombinerHelper::isConstantOrConstantVectorI(Register Src) const {
  auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
  if (IConstant)
    return true;

  GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
  if (!BuildVector)
    return false;

  unsigned NumSources = BuildVector->getNumSources();
  for (unsigned I = 0; I < NumSources; ++I) {
    std::optional<ValueAndVReg> IConstant =
        getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
    if (!IConstant)
      return false;
  }
  return true;
}
6506 | |
6507 | // TODO: use knownbits to determine zeros |
6508 | bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select, |
6509 | BuildFnTy &MatchInfo) { |
6510 | uint32_t Flags = Select->getFlags(); |
6511 | Register Dest = Select->getReg(Idx: 0); |
6512 | Register Cond = Select->getCondReg(); |
6513 | Register True = Select->getTrueReg(); |
6514 | Register False = Select->getFalseReg(); |
6515 | LLT CondTy = MRI.getType(Reg: Select->getCondReg()); |
6516 | LLT TrueTy = MRI.getType(Reg: Select->getTrueReg()); |
6517 | |
6518 | // We only do this combine for scalar boolean conditions. |
6519 | if (CondTy != LLT::scalar(SizeInBits: 1)) |
6520 | return false; |
6521 | |
6522 | if (TrueTy.isPointer()) |
6523 | return false; |
6524 | |
6525 | // Both are scalars. |
6526 | std::optional<ValueAndVReg> TrueOpt = |
6527 | getIConstantVRegValWithLookThrough(VReg: True, MRI); |
6528 | std::optional<ValueAndVReg> FalseOpt = |
6529 | getIConstantVRegValWithLookThrough(VReg: False, MRI); |
6530 | |
6531 | if (!TrueOpt || !FalseOpt) |
6532 | return false; |
6533 | |
6534 | APInt TrueValue = TrueOpt->Value; |
6535 | APInt FalseValue = FalseOpt->Value; |
6536 | |
6537 | // select Cond, 1, 0 --> zext (Cond) |
6538 | if (TrueValue.isOne() && FalseValue.isZero()) { |
6539 | MatchInfo = [=](MachineIRBuilder &B) { |
6540 | B.setInstrAndDebugLoc(*Select); |
      B.buildZExtOrTrunc(Dest, Cond);
6542 | }; |
6543 | return true; |
6544 | } |
6545 | |
6546 | // select Cond, -1, 0 --> sext (Cond) |
6547 | if (TrueValue.isAllOnes() && FalseValue.isZero()) { |
6548 | MatchInfo = [=](MachineIRBuilder &B) { |
6549 | B.setInstrAndDebugLoc(*Select); |
      B.buildSExtOrTrunc(Dest, Cond);
6551 | }; |
6552 | return true; |
6553 | } |
6554 | |
6555 | // select Cond, 0, 1 --> zext (!Cond) |
6556 | if (TrueValue.isZero() && FalseValue.isOne()) { |
6557 | MatchInfo = [=](MachineIRBuilder &B) { |
6558 | B.setInstrAndDebugLoc(*Select); |
      Register Inner = MRI.createGenericVirtualRegister(CondTy);
      B.buildNot(Inner, Cond);
      B.buildZExtOrTrunc(Dest, Inner);
6562 | }; |
6563 | return true; |
6564 | } |
6565 | |
6566 | // select Cond, 0, -1 --> sext (!Cond) |
6567 | if (TrueValue.isZero() && FalseValue.isAllOnes()) { |
6568 | MatchInfo = [=](MachineIRBuilder &B) { |
6569 | B.setInstrAndDebugLoc(*Select); |
      Register Inner = MRI.createGenericVirtualRegister(CondTy);
      B.buildNot(Inner, Cond);
      B.buildSExtOrTrunc(Dest, Inner);
6573 | }; |
6574 | return true; |
6575 | } |
6576 | |
6577 | // select Cond, C1, C1-1 --> add (zext Cond), C1-1 |
6578 | if (TrueValue - 1 == FalseValue) { |
6579 | MatchInfo = [=](MachineIRBuilder &B) { |
6580 | B.setInstrAndDebugLoc(*Select); |
      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
      B.buildZExtOrTrunc(Inner, Cond);
      B.buildAdd(Dest, Inner, False);
6584 | }; |
6585 | return true; |
6586 | } |
6587 | |
6588 | // select Cond, C1, C1+1 --> add (sext Cond), C1+1 |
6589 | if (TrueValue + 1 == FalseValue) { |
6590 | MatchInfo = [=](MachineIRBuilder &B) { |
6591 | B.setInstrAndDebugLoc(*Select); |
      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
      B.buildSExtOrTrunc(Inner, Cond);
      B.buildAdd(Dest, Inner, False);
6595 | }; |
6596 | return true; |
6597 | } |
6598 | |
6599 | // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2) |
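  // E.g. select Cond, 8, 0 --> shl (zext Cond), 3.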
6600 | if (TrueValue.isPowerOf2() && FalseValue.isZero()) { |
6601 | MatchInfo = [=](MachineIRBuilder &B) { |
6602 | B.setInstrAndDebugLoc(*Select); |
      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
      B.buildZExtOrTrunc(Inner, Cond);
      // The shift amount must be scalar.
      LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
      auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
      B.buildShl(Dest, Inner, ShAmtC, Flags);
6609 | }; |
6610 | return true; |
6611 | } |
6612 | // select Cond, -1, C --> or (sext Cond), C |
6613 | if (TrueValue.isAllOnes()) { |
6614 | MatchInfo = [=](MachineIRBuilder &B) { |
6615 | B.setInstrAndDebugLoc(*Select); |
      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
      B.buildSExtOrTrunc(Inner, Cond);
      B.buildOr(Dest, Inner, False, Flags);
6619 | }; |
6620 | return true; |
6621 | } |
6622 | |
6623 | // select Cond, C, -1 --> or (sext (not Cond)), C |
6624 | if (FalseValue.isAllOnes()) { |
6625 | MatchInfo = [=](MachineIRBuilder &B) { |
6626 | B.setInstrAndDebugLoc(*Select); |
      Register Not = MRI.createGenericVirtualRegister(CondTy);
      B.buildNot(Not, Cond);
      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
      B.buildSExtOrTrunc(Inner, Not);
      B.buildOr(Dest, Inner, True, Flags);
6632 | }; |
6633 | return true; |
6634 | } |
6635 | |
6636 | return false; |
6637 | } |
6638 | |
6639 | // TODO: use knownbits to determine zeros |
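// Fold a G_SELECT with a boolean-typed result into plain logic, e.g.
// select Cond, T, 0 --> and Cond, freeze(T). The kept operand is frozen
// because, unlike the select, the logic op would propagate its poison.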
6640 | bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select, |
6641 | BuildFnTy &MatchInfo) { |
6642 | uint32_t Flags = Select->getFlags(); |
  Register DstReg = Select->getReg(0);
6644 | Register Cond = Select->getCondReg(); |
6645 | Register True = Select->getTrueReg(); |
6646 | Register False = Select->getFalseReg(); |
  LLT CondTy = MRI.getType(Select->getCondReg());
  LLT TrueTy = MRI.getType(Select->getTrueReg());
6649 | |
6650 | // Boolean or fixed vector of booleans. |
6651 | if (CondTy.isScalableVector() || |
6652 | (CondTy.isFixedVector() && |
6653 | CondTy.getElementType().getScalarSizeInBits() != 1) || |
6654 | CondTy.getScalarSizeInBits() != 1) |
6655 | return false; |
6656 | |
6657 | if (CondTy != TrueTy) |
6658 | return false; |
6659 | |
6660 | // select Cond, Cond, F --> or Cond, F |
6661 | // select Cond, 1, F --> or Cond, F |
  if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
6663 | MatchInfo = [=](MachineIRBuilder &B) { |
6664 | B.setInstrAndDebugLoc(*Select); |
      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
      B.buildZExtOrTrunc(Ext, Cond);
      auto FreezeFalse = B.buildFreeze(TrueTy, False);
      B.buildOr(DstReg, Ext, FreezeFalse, Flags);
6669 | }; |
6670 | return true; |
6671 | } |
6672 | |
6673 | // select Cond, T, Cond --> and Cond, T |
6674 | // select Cond, T, 0 --> and Cond, T |
  if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
6676 | MatchInfo = [=](MachineIRBuilder &B) { |
6677 | B.setInstrAndDebugLoc(*Select); |
      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
      B.buildZExtOrTrunc(Ext, Cond);
      auto FreezeTrue = B.buildFreeze(TrueTy, True);
      B.buildAnd(DstReg, Ext, FreezeTrue);
6682 | }; |
6683 | return true; |
6684 | } |
6685 | |
6686 | // select Cond, T, 1 --> or (not Cond), T |
  if (isOneOrOneSplat(False, /* AllowUndefs */ true)) {
6688 | MatchInfo = [=](MachineIRBuilder &B) { |
6689 | B.setInstrAndDebugLoc(*Select); |
6690 | // First the not. |
      Register Inner = MRI.createGenericVirtualRegister(CondTy);
      B.buildNot(Inner, Cond);
      // Then an ext to match the destination register.
      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
      B.buildZExtOrTrunc(Ext, Inner);
      auto FreezeTrue = B.buildFreeze(TrueTy, True);
      B.buildOr(DstReg, Ext, FreezeTrue, Flags);
6698 | }; |
6699 | return true; |
6700 | } |
6701 | |
6702 | // select Cond, 0, F --> and (not Cond), F |
  if (isZeroOrZeroSplat(True, /* AllowUndefs */ true)) {
6704 | MatchInfo = [=](MachineIRBuilder &B) { |
6705 | B.setInstrAndDebugLoc(*Select); |
6706 | // First the not. |
      Register Inner = MRI.createGenericVirtualRegister(CondTy);
      B.buildNot(Inner, Cond);
      // Then an ext to match the destination register.
      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
      B.buildZExtOrTrunc(Ext, Inner);
      auto FreezeFalse = B.buildFreeze(TrueTy, False);
      B.buildAnd(DstReg, Ext, FreezeFalse);
6714 | }; |
6715 | return true; |
6716 | } |
6717 | |
6718 | return false; |
6719 | } |
6720 | |
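// Fold (icmp Pred X, Y) ? X : Y into the matching integer min/max, e.g.
// (icmp ult X, Y) ? X : Y --> umin X, Y, when the min/max opcode is legal.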
6721 | bool CombinerHelper::tryFoldSelectToIntMinMax(GSelect *Select, |
6722 | BuildFnTy &MatchInfo) { |
  Register DstReg = Select->getReg(0);
6724 | Register Cond = Select->getCondReg(); |
6725 | Register True = Select->getTrueReg(); |
6726 | Register False = Select->getFalseReg(); |
  LLT DstTy = MRI.getType(DstReg);
6728 | |
6729 | if (DstTy.isPointer()) |
6730 | return false; |
6731 | |
  // We need a G_ICMP on the condition register.
  GICmp *Cmp = getOpcodeDef<GICmp>(Cond, MRI);
6734 | if (!Cmp) |
6735 | return false; |
6736 | |
6737 | // We want to fold the icmp and replace the select. |
  if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
6739 | return false; |
6740 | |
6741 | CmpInst::Predicate Pred = Cmp->getCond(); |
  // We need a relational predicate for the canonicalization below; equality
  // compares do not correspond to a min/max pattern.
  if (CmpInst::isEquality(Pred))
6745 | return false; |
6746 | |
6747 | Register CmpLHS = Cmp->getLHSReg(); |
6748 | Register CmpRHS = Cmp->getRHSReg(); |
6749 | |
  // We can swap CmpLHS and CmpRHS for a higher hit rate.
  if (True == CmpRHS && False == CmpLHS) {
    std::swap(CmpLHS, CmpRHS);
    Pred = CmpInst::getSwappedPredicate(Pred);
6754 | } |
6755 | |
6756 | // (icmp X, Y) ? X : Y -> integer minmax. |
6757 | // see matchSelectPattern in ValueTracking. |
6758 | // Legality between G_SELECT and integer minmax can differ. |
6759 | if (True == CmpLHS && False == CmpRHS) { |
    switch (Pred) {
    case ICmpInst::ICMP_UGT:
    case ICmpInst::ICMP_UGE: {
      if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMAX, DstTy}))
        return false;
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildUMax(DstReg, True, False);
      };
      return true;
    }
    case ICmpInst::ICMP_SGT:
    case ICmpInst::ICMP_SGE: {
      if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SMAX, DstTy}))
        return false;
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildSMax(DstReg, True, False);
      };
      return true;
    }
    case ICmpInst::ICMP_ULT:
    case ICmpInst::ICMP_ULE: {
      if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMIN, DstTy}))
        return false;
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildUMin(DstReg, True, False);
      };
      return true;
    }
    case ICmpInst::ICMP_SLT:
    case ICmpInst::ICMP_SLE: {
      if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SMIN, DstTy}))
        return false;
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildSMin(DstReg, True, False);
      };
      return true;
    }
    default:
      return false;
    }
6800 | } |
6801 | |
6802 | return false; |
6803 | } |
6804 | |
6805 | bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) { |
  GSelect *Select = cast<GSelect>(&MI);
6807 | |
6808 | if (tryFoldSelectOfConstants(Select, MatchInfo)) |
6809 | return true; |
6810 | |
6811 | if (tryFoldBoolSelectToLogic(Select, MatchInfo)) |
6812 | return true; |
6813 | |
6814 | if (tryFoldSelectToIntMinMax(Select, MatchInfo)) |
6815 | return true; |
6816 | |
6817 | return false; |
6818 | } |
6819 | |
6820 | /// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2) |
6821 | /// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2) |
6822 | /// into a single comparison using range-based reasoning. |
6823 | /// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges. |
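/// E.g. (icmp uge X, 4) && (icmp ule X, 11) describes the range [4, 12) and
/// folds to the single compare (icmp ult (add X, -4), 8).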
6824 | bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges(GLogicalBinOp *Logic, |
6825 | BuildFnTy &MatchInfo) { |
  assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
  bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
  Register DstReg = Logic->getReg(0);
6829 | Register LHS = Logic->getLHSReg(); |
6830 | Register RHS = Logic->getRHSReg(); |
6831 | unsigned Flags = Logic->getFlags(); |
6832 | |
  // We need a G_ICMP on the LHS register.
  GICmp *Cmp1 = getOpcodeDef<GICmp>(LHS, MRI);
6835 | if (!Cmp1) |
6836 | return false; |
6837 | |
  // We need a G_ICMP on the RHS register.
  GICmp *Cmp2 = getOpcodeDef<GICmp>(RHS, MRI);
6840 | if (!Cmp2) |
6841 | return false; |
6842 | |
6843 | // We want to fold the icmps. |
  if (!MRI.hasOneNonDBGUse(Cmp1->getReg(0)) ||
      !MRI.hasOneNonDBGUse(Cmp2->getReg(0)))
6846 | return false; |
6847 | |
6848 | APInt C1; |
6849 | APInt C2; |
  std::optional<ValueAndVReg> MaybeC1 =
      getIConstantVRegValWithLookThrough(Cmp1->getRHSReg(), MRI);
6852 | if (!MaybeC1) |
6853 | return false; |
6854 | C1 = MaybeC1->Value; |
6855 | |
  std::optional<ValueAndVReg> MaybeC2 =
      getIConstantVRegValWithLookThrough(Cmp2->getRHSReg(), MRI);
6858 | if (!MaybeC2) |
6859 | return false; |
6860 | C2 = MaybeC2->Value; |
6861 | |
6862 | Register R1 = Cmp1->getLHSReg(); |
6863 | Register R2 = Cmp2->getLHSReg(); |
6864 | CmpInst::Predicate Pred1 = Cmp1->getCond(); |
6865 | CmpInst::Predicate Pred2 = Cmp2->getCond(); |
  LLT CmpTy = MRI.getType(Cmp1->getReg(0));
  LLT CmpOperandTy = MRI.getType(R1);
6868 | |
6869 | if (CmpOperandTy.isPointer()) |
6870 | return false; |
6871 | |
6872 | // We build ands, adds, and constants of type CmpOperandTy. |
6873 | // They must be legal to build. |
  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_AND, CmpOperandTy}) ||
      !isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, CmpOperandTy}) ||
      !isConstantLegalOrBeforeLegalizer(CmpOperandTy))
6877 | return false; |
6878 | |
  // Look through an add of a constant offset on R1, R2, or both operands.
  // This lets us interpret the R + C' < C'' range idiom as a proper range.
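  // E.g. (add X, -4) u< 8 encodes X in the range [4, 12).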
6881 | std::optional<APInt> Offset1; |
6882 | std::optional<APInt> Offset2; |
6883 | if (R1 != R2) { |
    if (GAdd *Add = getOpcodeDef<GAdd>(R1, MRI)) {
      std::optional<ValueAndVReg> MaybeOffset1 =
          getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI);
6887 | if (MaybeOffset1) { |
6888 | R1 = Add->getLHSReg(); |
6889 | Offset1 = MaybeOffset1->Value; |
6890 | } |
6891 | } |
    if (GAdd *Add = getOpcodeDef<GAdd>(R2, MRI)) {
      std::optional<ValueAndVReg> MaybeOffset2 =
          getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI);
6895 | if (MaybeOffset2) { |
6896 | R2 = Add->getLHSReg(); |
6897 | Offset2 = MaybeOffset2->Value; |
6898 | } |
6899 | } |
6900 | } |
6901 | |
6902 | if (R1 != R2) |
6903 | return false; |
6904 | |
6905 | // We calculate the icmp ranges including maybe offsets. |
  ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
      IsAnd ? ICmpInst::getInversePredicate(Pred1) : Pred1, C1);
  if (Offset1)
    CR1 = CR1.subtract(*Offset1);

  ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
      IsAnd ? ICmpInst::getInversePredicate(Pred2) : Pred2, C2);
  if (Offset2)
    CR2 = CR2.subtract(*Offset2);
6915 | |
6916 | bool CreateMask = false; |
6917 | APInt LowerDiff; |
  std::optional<ConstantRange> CR = CR1.exactUnionWith(CR2);
6919 | if (!CR) { |
6920 | // We need non-wrapping ranges. |
6921 | if (CR1.isWrappedSet() || CR2.isWrappedSet()) |
6922 | return false; |
6923 | |
6924 | // Check whether we have equal-size ranges that only differ by one bit. |
6925 | // In that case we can apply a mask to map one range onto the other. |
6926 | LowerDiff = CR1.getLower() ^ CR2.getLower(); |
6927 | APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1); |
6928 | APInt CR1Size = CR1.getUpper() - CR1.getLower(); |
6929 | if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff || |
6930 | CR1Size != CR2.getUpper() - CR2.getLower()) |
6931 | return false; |
6932 | |
    CR = CR1.getLower().ult(CR2.getLower()) ? CR1 : CR2;
6934 | CreateMask = true; |
6935 | } |
6936 | |
6937 | if (IsAnd) |
6938 | CR = CR->inverse(); |
6939 | |
6940 | CmpInst::Predicate NewPred; |
6941 | APInt NewC, Offset; |
  CR->getEquivalentICmp(NewPred, NewC, Offset);
6943 | |
  // We take the result type of one of the original icmps, CmpTy, for the
  // icmp to be built. The operand type, CmpOperandTy, is used for the other
  // instructions and constants to be built. The parameter and result types
  // of add and and coincide. CmpTy and the type of DstReg might differ; that
  // is why we zext or trunc the icmp result into the destination register.
6950 | |
  MatchInfo = [=](MachineIRBuilder &B) {
    if (CreateMask && Offset != 0) {
      auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff);
      auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask.
      auto OffsetC = B.buildConstant(CmpOperandTy, Offset);
      auto Add = B.buildAdd(CmpOperandTy, And, OffsetC, Flags);
      auto NewCon = B.buildConstant(CmpOperandTy, NewC);
      auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon);
      B.buildZExtOrTrunc(DstReg, ICmp);
    } else if (CreateMask && Offset == 0) {
      auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff);
      auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask.
      auto NewCon = B.buildConstant(CmpOperandTy, NewC);
      auto ICmp = B.buildICmp(NewPred, CmpTy, And, NewCon);
      B.buildZExtOrTrunc(DstReg, ICmp);
    } else if (!CreateMask && Offset != 0) {
      auto OffsetC = B.buildConstant(CmpOperandTy, Offset);
      auto Add = B.buildAdd(CmpOperandTy, R1, OffsetC, Flags);
      auto NewCon = B.buildConstant(CmpOperandTy, NewC);
      auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon);
      B.buildZExtOrTrunc(DstReg, ICmp);
    } else if (!CreateMask && Offset == 0) {
      auto NewCon = B.buildConstant(CmpOperandTy, NewC);
      auto ICmp = B.buildICmp(NewPred, CmpTy, R1, NewCon);
      B.buildZExtOrTrunc(DstReg, ICmp);
    } else {
      llvm_unreachable("unexpected configuration of CreateMask and Offset");
    }
  };
6980 | return true; |
6981 | } |
6982 | |
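/// Fold the logic of two G_FCMPs on identical operands into one G_FCMP whose
/// predicate is the union (or) or intersection (and) of the original
/// predicates, e.g. (fcmp olt X, Y) || (fcmp ogt X, Y) --> fcmp one X, Y.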
6983 | bool CombinerHelper::tryFoldLogicOfFCmps(GLogicalBinOp *Logic, |
6984 | BuildFnTy &MatchInfo) { |
  assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
  Register DestReg = Logic->getReg(0);
6987 | Register LHS = Logic->getLHSReg(); |
6988 | Register RHS = Logic->getRHSReg(); |
6989 | bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND; |
6990 | |
6991 | // We need a compare on the LHS register. |
  GFCmp *Cmp1 = getOpcodeDef<GFCmp>(LHS, MRI);
6993 | if (!Cmp1) |
6994 | return false; |
6995 | |
6996 | // We need a compare on the RHS register. |
  GFCmp *Cmp2 = getOpcodeDef<GFCmp>(RHS, MRI);
6998 | if (!Cmp2) |
6999 | return false; |
7000 | |
  LLT CmpTy = MRI.getType(Cmp1->getReg(0));
  LLT CmpOperandTy = MRI.getType(Cmp1->getLHSReg());
7003 | |
7004 | // We build one fcmp, want to fold the fcmps, replace the logic op, |
7005 | // and the fcmps must have the same shape. |
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_FCMP, {CmpTy, CmpOperandTy}}) ||
      !MRI.hasOneNonDBGUse(Logic->getReg(0)) ||
      !MRI.hasOneNonDBGUse(Cmp1->getReg(0)) ||
      !MRI.hasOneNonDBGUse(Cmp2->getReg(0)) ||
      MRI.getType(Cmp1->getLHSReg()) != MRI.getType(Cmp2->getLHSReg()))
7012 | return false; |
7013 | |
7014 | CmpInst::Predicate PredL = Cmp1->getCond(); |
7015 | CmpInst::Predicate PredR = Cmp2->getCond(); |
7016 | Register LHS0 = Cmp1->getLHSReg(); |
7017 | Register LHS1 = Cmp1->getRHSReg(); |
7018 | Register RHS0 = Cmp2->getLHSReg(); |
7019 | Register RHS1 = Cmp2->getRHSReg(); |
7020 | |
7021 | if (LHS0 == RHS1 && LHS1 == RHS0) { |
7022 | // Swap RHS operands to match LHS. |
    PredR = CmpInst::getSwappedPredicate(PredR);
    std::swap(RHS0, RHS1);
7025 | } |
7026 | |
7027 | if (LHS0 == RHS0 && LHS1 == RHS1) { |
7028 | // We determine the new predicate. |
    unsigned CmpCodeL = getFCmpCode(PredL);
    unsigned CmpCodeR = getFCmpCode(PredR);
7031 | unsigned NewPred = IsAnd ? CmpCodeL & CmpCodeR : CmpCodeL | CmpCodeR; |
7032 | unsigned Flags = Cmp1->getFlags() | Cmp2->getFlags(); |
7033 | MatchInfo = [=](MachineIRBuilder &B) { |
7034 | // The fcmp predicates fill the lower part of the enum. |
7035 | FCmpInst::Predicate Pred = static_cast<FCmpInst::Predicate>(NewPred); |
7036 | if (Pred == FCmpInst::FCMP_FALSE && |
          isConstantLegalOrBeforeLegalizer(CmpTy)) {
        auto False = B.buildConstant(CmpTy, 0);
        B.buildZExtOrTrunc(DestReg, False);
      } else if (Pred == FCmpInst::FCMP_TRUE &&
                 isConstantLegalOrBeforeLegalizer(CmpTy)) {
        auto True =
            B.buildConstant(CmpTy, getICmpTrueVal(getTargetLowering(),
                                                  CmpTy.isVector() /*isVector*/,
                                                  true /*isFP*/));
        B.buildZExtOrTrunc(DestReg, True);
      } else { // Take the combined predicate as-is.
        auto Cmp = B.buildFCmp(Pred, CmpTy, LHS0, LHS1, Flags);
        B.buildZExtOrTrunc(DestReg, Cmp);
      }
7050 | } |
7051 | }; |
7052 | return true; |
7053 | } |
7054 | |
7055 | return false; |
7056 | } |
7057 | |
7058 | bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) { |
  GAnd *And = cast<GAnd>(&MI);

  if (tryFoldAndOrOrICmpsUsingRanges(And, MatchInfo))
    return true;

  if (tryFoldLogicOfFCmps(And, MatchInfo))
7065 | return true; |
7066 | |
7067 | return false; |
7068 | } |
7069 | |
7070 | bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) { |
  GOr *Or = cast<GOr>(&MI);

  if (tryFoldAndOrOrICmpsUsingRanges(Or, MatchInfo))
    return true;

  if (tryFoldLogicOfFCmps(Or, MatchInfo))
7077 | return true; |
7078 | |
7079 | return false; |
7080 | } |
7081 | |
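// Combine G_UADDO/G_SADDO. We drop dead carry-outs, canonicalize constants
// to the RHS, constant-fold, and rewrite to a plain G_ADD with a constant
// carry when known bits or sign bits make the overflow decidable.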
7082 | bool CombinerHelper::matchAddOverflow(MachineInstr &MI, BuildFnTy &MatchInfo) { |
  GAddCarryOut *Add = cast<GAddCarryOut>(&MI);
7084 | |
  // Addo has no flags.
  Register Dst = Add->getReg(0);
  Register Carry = Add->getReg(1);
  Register LHS = Add->getLHSReg();
  Register RHS = Add->getRHSReg();
  bool IsSigned = Add->isSigned();
  LLT DstTy = MRI.getType(Dst);
  LLT CarryTy = MRI.getType(Carry);
7093 | |
  // If the carry is dead, fold addo -> add, undef.
  if (MRI.use_nodbg_empty(Carry) &&
      isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}})) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildAdd(Dst, LHS, RHS);
      B.buildUndef(Carry);
7100 | }; |
7101 | return true; |
7102 | } |
7103 | |
7104 | // Canonicalize constant to RHS. |
  if (isConstantOrConstantVectorI(LHS) && !isConstantOrConstantVectorI(RHS)) {
    if (IsSigned) {
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildSAddo(Dst, Carry, RHS, LHS);
      };
      return true;
    }
    // !IsSigned
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildUAddo(Dst, Carry, RHS, LHS);
7115 | }; |
7116 | return true; |
7117 | } |
7118 | |
  std::optional<APInt> MaybeLHS = getConstantOrConstantSplatVector(LHS);
  std::optional<APInt> MaybeRHS = getConstantOrConstantSplatVector(RHS);
7121 | |
7122 | // Fold addo(c1, c2) -> c3, carry. |
  if (MaybeLHS && MaybeRHS && isConstantLegalOrBeforeLegalizer(DstTy) &&
      isConstantLegalOrBeforeLegalizer(CarryTy)) {
    bool Overflow;
    APInt Result = IsSigned ? MaybeLHS->sadd_ov(*MaybeRHS, Overflow)
                            : MaybeLHS->uadd_ov(*MaybeRHS, Overflow);
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildConstant(Dst, Result);
      B.buildConstant(Carry, Overflow);
7131 | }; |
7132 | return true; |
7133 | } |
7134 | |
7135 | // Fold (addo x, 0) -> x, no carry |
  if (MaybeRHS && *MaybeRHS == 0 && isConstantLegalOrBeforeLegalizer(CarryTy)) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildCopy(Dst, LHS);
      B.buildConstant(Carry, 0);
7140 | }; |
7141 | return true; |
7142 | } |
7143 | |
7144 | // Given 2 constant operands whose sum does not overflow: |
7145 | // uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1 |
7146 | // saddo (X +nsw C0), C1 -> saddo X, C0 + C1 |
  GAdd *AddLHS = getOpcodeDef<GAdd>(LHS, MRI);
  if (MaybeRHS && AddLHS && MRI.hasOneNonDBGUse(Add->getReg(0)) &&
      ((IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoSWrap)) ||
       (!IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoUWrap)))) {
    std::optional<APInt> MaybeAddRHS =
        getConstantOrConstantSplatVector(AddLHS->getRHSReg());
    if (MaybeAddRHS) {
      bool Overflow;
      APInt NewC = IsSigned ? MaybeAddRHS->sadd_ov(*MaybeRHS, Overflow)
                            : MaybeAddRHS->uadd_ov(*MaybeRHS, Overflow);
      if (!Overflow && isConstantLegalOrBeforeLegalizer(DstTy)) {
        if (IsSigned) {
          MatchInfo = [=](MachineIRBuilder &B) {
            auto ConstRHS = B.buildConstant(DstTy, NewC);
            B.buildSAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS);
          };
          return true;
        }
        // !IsSigned
        MatchInfo = [=](MachineIRBuilder &B) {
          auto ConstRHS = B.buildConstant(DstTy, NewC);
          B.buildUAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS);
        };
        return true;
      }
    }
  }
7174 | |
7175 | // We try to combine addo to non-overflowing add. |
  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}}) ||
      !isConstantLegalOrBeforeLegalizer(CarryTy))
7178 | return false; |
7179 | |
7180 | // We try to combine uaddo to non-overflowing add. |
7181 | if (!IsSigned) { |
    ConstantRange CRLHS =
        ConstantRange::fromKnownBits(KB->getKnownBits(LHS), /*IsSigned=*/false);
    ConstantRange CRRHS =
        ConstantRange::fromKnownBits(KB->getKnownBits(RHS), /*IsSigned=*/false);

    switch (CRLHS.unsignedAddMayOverflow(CRRHS)) {
7188 | case ConstantRange::OverflowResult::MayOverflow: |
7189 | return false; |
7190 | case ConstantRange::OverflowResult::NeverOverflows: { |
7191 | MatchInfo = [=](MachineIRBuilder &B) { |
        B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoUWrap);
        B.buildConstant(Carry, 0);
7194 | }; |
7195 | return true; |
7196 | } |
7197 | case ConstantRange::OverflowResult::AlwaysOverflowsLow: |
7198 | case ConstantRange::OverflowResult::AlwaysOverflowsHigh: { |
7199 | MatchInfo = [=](MachineIRBuilder &B) { |
        B.buildAdd(Dst, LHS, RHS);
        B.buildConstant(Carry, 1);
7202 | }; |
7203 | return true; |
7204 | } |
7205 | } |
7206 | return false; |
7207 | } |
7208 | |
7209 | // We try to combine saddo to non-overflowing add. |
7210 | |
7211 | // If LHS and RHS each have at least two sign bits, then there is no signed |
7212 | // overflow. |
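  // With at least two sign bits, each operand lies in
  // [-2^(BitWidth - 2), 2^(BitWidth - 2)), so the sum fits the signed range.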
  if (KB->computeNumSignBits(RHS) > 1 && KB->computeNumSignBits(LHS) > 1) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
      B.buildConstant(Carry, 0);
7217 | }; |
7218 | return true; |
7219 | } |
7220 | |
  ConstantRange CRLHS =
      ConstantRange::fromKnownBits(KB->getKnownBits(LHS), /*IsSigned=*/true);
  ConstantRange CRRHS =
      ConstantRange::fromKnownBits(KB->getKnownBits(RHS), /*IsSigned=*/true);

  switch (CRLHS.signedAddMayOverflow(CRRHS)) {
7227 | case ConstantRange::OverflowResult::MayOverflow: |
7228 | return false; |
7229 | case ConstantRange::OverflowResult::NeverOverflows: { |
7230 | MatchInfo = [=](MachineIRBuilder &B) { |
      B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
      B.buildConstant(Carry, 0);
7233 | }; |
7234 | return true; |
7235 | } |
7236 | case ConstantRange::OverflowResult::AlwaysOverflowsLow: |
7237 | case ConstantRange::OverflowResult::AlwaysOverflowsHigh: { |
7238 | MatchInfo = [=](MachineIRBuilder &B) { |
      B.buildAdd(Dst, LHS, RHS);
      B.buildConstant(Carry, 1);
7241 | }; |
7242 | return true; |
7243 | } |
7244 | } |
7245 | |
7246 | return false; |
7247 | } |
7248 | |