1//===- DAGCombiner.cpp - Implement a DAG node combiner --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass combines dag nodes to form fewer, simpler DAG nodes. It can be run
10// both before and after the DAG is legalized.
11//
12// This pass is not a substitute for the LLVM IR instcombine pass. This pass is
13// primarily intended to handle simplification opportunities that are implicit
14// in the LLVM IR and exposed by the various codegen lowering phases.
15//
16//===----------------------------------------------------------------------===//
17
18#include "llvm/ADT/APFloat.h"
19#include "llvm/ADT/APInt.h"
20#include "llvm/ADT/ArrayRef.h"
21#include "llvm/ADT/DenseMap.h"
22#include "llvm/ADT/IntervalMap.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SetVector.h"
25#include "llvm/ADT/SmallBitVector.h"
26#include "llvm/ADT/SmallPtrSet.h"
27#include "llvm/ADT/SmallSet.h"
28#include "llvm/ADT/SmallVector.h"
29#include "llvm/ADT/Statistic.h"
30#include "llvm/Analysis/AliasAnalysis.h"
31#include "llvm/Analysis/MemoryLocation.h"
32#include "llvm/Analysis/TargetLibraryInfo.h"
33#include "llvm/Analysis/ValueTracking.h"
34#include "llvm/Analysis/VectorUtils.h"
35#include "llvm/CodeGen/ByteProvider.h"
36#include "llvm/CodeGen/DAGCombine.h"
37#include "llvm/CodeGen/ISDOpcodes.h"
38#include "llvm/CodeGen/MachineFunction.h"
39#include "llvm/CodeGen/MachineMemOperand.h"
40#include "llvm/CodeGen/RuntimeLibcalls.h"
41#include "llvm/CodeGen/SDPatternMatch.h"
42#include "llvm/CodeGen/SelectionDAG.h"
43#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
44#include "llvm/CodeGen/SelectionDAGNodes.h"
45#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
46#include "llvm/CodeGen/TargetLowering.h"
47#include "llvm/CodeGen/TargetRegisterInfo.h"
48#include "llvm/CodeGen/TargetSubtargetInfo.h"
49#include "llvm/CodeGen/ValueTypes.h"
50#include "llvm/CodeGenTypes/MachineValueType.h"
51#include "llvm/IR/Attributes.h"
52#include "llvm/IR/Constant.h"
53#include "llvm/IR/DataLayout.h"
54#include "llvm/IR/DerivedTypes.h"
55#include "llvm/IR/Function.h"
56#include "llvm/IR/Metadata.h"
57#include "llvm/Support/Casting.h"
58#include "llvm/Support/CodeGen.h"
59#include "llvm/Support/CommandLine.h"
60#include "llvm/Support/Compiler.h"
61#include "llvm/Support/Debug.h"
62#include "llvm/Support/DebugCounter.h"
63#include "llvm/Support/ErrorHandling.h"
64#include "llvm/Support/KnownBits.h"
65#include "llvm/Support/MathExtras.h"
66#include "llvm/Support/raw_ostream.h"
67#include "llvm/Target/TargetMachine.h"
68#include "llvm/Target/TargetOptions.h"
69#include <algorithm>
70#include <cassert>
71#include <cstdint>
72#include <functional>
73#include <iterator>
74#include <optional>
75#include <string>
76#include <tuple>
77#include <utility>
78#include <variant>
79
80#include "MatchContext.h"
81
82using namespace llvm;
83using namespace llvm::SDPatternMatch;
84
85#define DEBUG_TYPE "dagcombine"
86
87STATISTIC(NodesCombined , "Number of dag nodes combined");
88STATISTIC(PreIndexedNodes , "Number of pre-indexed nodes created");
89STATISTIC(PostIndexedNodes, "Number of post-indexed nodes created");
90STATISTIC(OpsNarrowed , "Number of load/op/store narrowed");
91STATISTIC(LdStFP2Int , "Number of fp load/store pairs transformed to int");
92STATISTIC(SlicedLoads, "Number of loads sliced");
93STATISTIC(NumFPLogicOpsConv, "Number of logic ops converted to fp ops");
94
95DEBUG_COUNTER(DAGCombineCounter, "dagcombine",
96 "Controls whether a DAG combine is performed for a node");
97
98static cl::opt<bool>
99CombinerGlobalAA("combiner-global-alias-analysis", cl::Hidden,
100 cl::desc("Enable DAG combiner's use of IR alias analysis"));
101
102static cl::opt<bool>
103UseTBAA("combiner-use-tbaa", cl::Hidden, cl::init(Val: true),
104 cl::desc("Enable DAG combiner's use of TBAA"));
105
106#ifndef NDEBUG
107static cl::opt<std::string>
108CombinerAAOnlyFunc("combiner-aa-only-func", cl::Hidden,
109 cl::desc("Only use DAG-combiner alias analysis in this"
110 " function"));
111#endif
112
113/// Hidden option to stress test load slicing, i.e., when this option
114/// is enabled, load slicing bypasses most of its profitability guards.
115static cl::opt<bool>
116StressLoadSlicing("combiner-stress-load-slicing", cl::Hidden,
117 cl::desc("Bypass the profitability model of load slicing"),
118 cl::init(Val: false));
119
120static cl::opt<bool>
121 MaySplitLoadIndex("combiner-split-load-index", cl::Hidden, cl::init(Val: true),
122 cl::desc("DAG combiner may split indexing from loads"));
123
124static cl::opt<bool>
125 EnableStoreMerging("combiner-store-merging", cl::Hidden, cl::init(Val: true),
126 cl::desc("DAG combiner enable merging multiple stores "
127 "into a wider store"));
128
129static cl::opt<unsigned> TokenFactorInlineLimit(
130 "combiner-tokenfactor-inline-limit", cl::Hidden, cl::init(Val: 2048),
131 cl::desc("Limit the number of operands to inline for Token Factors"));
132
133static cl::opt<unsigned> StoreMergeDependenceLimit(
134 "combiner-store-merge-dependence-limit", cl::Hidden, cl::init(Val: 10),
135 cl::desc("Limit the number of times for the same StoreNode and RootNode "
136 "to bail out in store merging dependence check"));
137
138static cl::opt<bool> EnableReduceLoadOpStoreWidth(
139 "combiner-reduce-load-op-store-width", cl::Hidden, cl::init(Val: true),
140 cl::desc("DAG combiner enable reducing the width of load/op/store "
141 "sequence"));
142
143static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
144 "combiner-shrink-load-replace-store-with-store", cl::Hidden, cl::init(Val: true),
145 cl::desc("DAG combiner enable load/<replace bytes>/store with "
146 "a narrower store"));
147
148static cl::opt<bool> EnableVectorFCopySignExtendRound(
149 "combiner-vector-fcopysign-extend-round", cl::Hidden, cl::init(Val: false),
150 cl::desc(
151 "Enable merging extends and rounds into FCOPYSIGN on vector types"));
152
153namespace {
154
155 class DAGCombiner {
156 SelectionDAG &DAG;
157 const TargetLowering &TLI;
158 const SelectionDAGTargetInfo *STI;
159 CombineLevel Level = BeforeLegalizeTypes;
160 CodeGenOptLevel OptLevel;
161 bool LegalDAG = false;
162 bool LegalOperations = false;
163 bool LegalTypes = false;
164 bool ForCodeSize;
165 bool DisableGenericCombines;
166
167 /// Worklist of all of the nodes that need to be simplified.
168 ///
169 /// This must behave as a stack -- new nodes to process are pushed onto the
170 /// back and when processing we pop off of the back.
171 ///
172 /// The worklist will not contain duplicates but may contain null entries
173 /// due to nodes being deleted from the underlying DAG.
174 SmallVector<SDNode *, 64> Worklist;
175
176 /// Mapping from an SDNode to its position on the worklist.
177 ///
178 /// This is used to find and remove nodes from the worklist (by nulling
179 /// them) when they are deleted from the underlying DAG. It relies on
180 /// stable indices of nodes within the worklist.
181 DenseMap<SDNode *, unsigned> WorklistMap;
182
183 /// This records all nodes attempted to be added to the worklist since we
184 /// considered a new worklist entry. Because we do not add duplicate nodes
185 /// to the worklist, this is different from the tail of the worklist.
186 SmallSetVector<SDNode *, 32> PruningList;
187
188 /// Set of nodes which have been combined (at least once).
189 ///
190 /// This is used to allow us to reliably add any operands of a DAG node
191 /// which have not yet been combined to the worklist.
192 SmallPtrSet<SDNode *, 32> CombinedNodes;
193
194 /// Map from candidate StoreNode to the pair of RootNode and count.
195 /// The count is used to track how many times we have seen the StoreNode
196 /// with the same RootNode bail out in the dependence check. If we have
197 /// seen the bail-out for the same pair more times than a limit allows, we
198 /// won't consider the StoreNode with the same RootNode as a store merging
199 /// candidate again.
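  /// For illustration: with the default -combiner-store-merge-dependence-limit
  /// of 10 (see the option above), once a given (StoreNode, RootNode) pair has
  /// bailed out of the dependence check ten times, that pair is no longer
  /// treated as a store merging candidate.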
200 DenseMap<SDNode *, std::pair<SDNode *, unsigned>> StoreRootCountMap;
201
202 // AA - Used for DAG load/store alias analysis.
203 AliasAnalysis *AA;
204
205 /// When an instruction is simplified, add all users of the instruction to
206 /// the work lists because they might get more simplified now.
207 void AddUsersToWorklist(SDNode *N) {
208 for (SDNode *Node : N->uses())
209 AddToWorklist(N: Node);
210 }
211
212 /// Convenient shorthand to add a node and all of its users to the worklist.
213 void AddToWorklistWithUsers(SDNode *N) {
214 AddUsersToWorklist(N);
215 AddToWorklist(N);
216 }
217
218 // Prune potentially dangling nodes. This is called after
219 // any visit to a node, but should also be called during a visit after any
220 // failed combine which may have created a DAG node.
221 void clearAddedDanglingWorklistEntries() {
222 // Check any nodes added to the worklist to see if they are prunable.
223 while (!PruningList.empty()) {
224 auto *N = PruningList.pop_back_val();
225 if (N->use_empty())
226 recursivelyDeleteUnusedNodes(N);
227 }
228 }
229
230 SDNode *getNextWorklistEntry() {
231 // Before we do any work, remove nodes that are not in use.
232 clearAddedDanglingWorklistEntries();
233 SDNode *N = nullptr;
234 // The Worklist holds the SDNodes in order, but it may contain null
235 // entries.
236 while (!N && !Worklist.empty()) {
237 N = Worklist.pop_back_val();
238 }
239
240 if (N) {
241 bool GoodWorklistEntry = WorklistMap.erase(Val: N);
242 (void)GoodWorklistEntry;
243 assert(GoodWorklistEntry &&
244 "Found a worklist entry without a corresponding map entry!");
245 }
246 return N;
247 }
248
249 /// Call the node-specific routine that folds each particular type of node.
250 SDValue visit(SDNode *N);
251
252 public:
253 DAGCombiner(SelectionDAG &D, AliasAnalysis *AA, CodeGenOptLevel OL)
254 : DAG(D), TLI(D.getTargetLoweringInfo()),
255 STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
256 ForCodeSize = DAG.shouldOptForSize();
257 DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);
258
259 MaximumLegalStoreInBits = 0;
260 // We use the minimum store size here, since that's all we can guarantee
261 // for the scalable vector types.
262 for (MVT VT : MVT::all_valuetypes())
263 if (EVT(VT).isSimple() && VT != MVT::Other &&
264 TLI.isTypeLegal(EVT(VT)) &&
265 VT.getSizeInBits().getKnownMinValue() >= MaximumLegalStoreInBits)
266 MaximumLegalStoreInBits = VT.getSizeInBits().getKnownMinValue();
267 }
268
269 void ConsiderForPruning(SDNode *N) {
270 // Mark this for potential pruning.
271 PruningList.insert(X: N);
272 }
273
274 /// Add to the worklist, making sure its instance is at the back (next to
275 /// be processed).
276 void AddToWorklist(SDNode *N, bool IsCandidateForPruning = true) {
277 assert(N->getOpcode() != ISD::DELETED_NODE &&
278 "Deleted Node added to Worklist");
279
280 // Skip handle nodes as they can't usefully be combined and confuse the
281 // zero-use deletion strategy.
282 if (N->getOpcode() == ISD::HANDLENODE)
283 return;
284
285 if (IsCandidateForPruning)
286 ConsiderForPruning(N);
287
288 if (WorklistMap.insert(KV: std::make_pair(x&: N, y: Worklist.size())).second)
289 Worklist.push_back(Elt: N);
290 }
291
292 /// Remove all instances of N from the worklist.
293 void removeFromWorklist(SDNode *N) {
294 CombinedNodes.erase(Ptr: N);
295 PruningList.remove(X: N);
296 StoreRootCountMap.erase(Val: N);
297
298 auto It = WorklistMap.find(Val: N);
299 if (It == WorklistMap.end())
300 return; // Not in the worklist.
301
302 // Null out the entry rather than erasing it to avoid a linear operation.
303 Worklist[It->second] = nullptr;
304 WorklistMap.erase(I: It);
305 }
306
307 void deleteAndRecombine(SDNode *N);
308 bool recursivelyDeleteUnusedNodes(SDNode *N);
309
310 /// Replaces all uses of the results of one DAG node with new values.
311 SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
312 bool AddTo = true);
313
314 /// Replaces all uses of the results of one DAG node with new values.
315 SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
316 return CombineTo(N, To: &Res, NumTo: 1, AddTo);
317 }
318
319 /// Replaces all uses of the results of one DAG node with new values.
320 SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
321 bool AddTo = true) {
322 SDValue To[] = { Res0, Res1 };
323 return CombineTo(N, To, NumTo: 2, AddTo);
324 }
325
326 void CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO);
327
328 private:
329 unsigned MaximumLegalStoreInBits;
330
331 /// Check the specified integer node value to see if it can be simplified or
332 /// if things it uses can be simplified by bit propagation.
333 /// If so, return true.
334 bool SimplifyDemandedBits(SDValue Op) {
335 unsigned BitWidth = Op.getScalarValueSizeInBits();
336 APInt DemandedBits = APInt::getAllOnes(numBits: BitWidth);
337 return SimplifyDemandedBits(Op, DemandedBits);
338 }
339
340 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits) {
341 EVT VT = Op.getValueType();
342 APInt DemandedElts = VT.isFixedLengthVector()
343 ? APInt::getAllOnes(numBits: VT.getVectorNumElements())
344 : APInt(1, 1);
345 return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, AssumeSingleUse: false);
346 }
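  // Illustrative usage sketch (an assumption, not taken from any particular
  // visit routine): a combine that only needs the low 8 bits of an operand N0
  // could call
  //   if (SimplifyDemandedBits(N0, APInt(N0.getScalarValueSizeInBits(), 0xFF)))
  //     return SDValue(N, 0);
  // letting CommitTargetLoweringOpt rewrite the DAG in place and requeue the
  // affected nodes.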
347
348 /// Check the specified vector node value to see if it can be simplified or
349 /// if things it uses can be simplified as it only uses some of the
350 /// elements. If so, return true.
351 bool SimplifyDemandedVectorElts(SDValue Op) {
352 // TODO: For now just pretend it cannot be simplified.
353 if (Op.getValueType().isScalableVector())
354 return false;
355
356 unsigned NumElts = Op.getValueType().getVectorNumElements();
357 APInt DemandedElts = APInt::getAllOnes(numBits: NumElts);
358 return SimplifyDemandedVectorElts(Op, DemandedElts);
359 }
360
361 bool SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
362 const APInt &DemandedElts,
363 bool AssumeSingleUse = false);
364 bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedElts,
365 bool AssumeSingleUse = false);
366
367 bool CombineToPreIndexedLoadStore(SDNode *N);
368 bool CombineToPostIndexedLoadStore(SDNode *N);
369 SDValue SplitIndexingFromLoad(LoadSDNode *LD);
370 bool SliceUpLoad(SDNode *N);
371
372 // Looks up the chain to find a unique (unaliased) store feeding the passed
373 // load. If no such store is found, returns a nullptr.
374 // Note: This will look past a CALLSEQ_START if the load is chained to it,
375 // so that it can find stack stores for byval params.
376 StoreSDNode *getUniqueStoreFeeding(LoadSDNode *LD, int64_t &Offset);
377 // Scalars have size 0 to distinguish from singleton vectors.
378 SDValue ForwardStoreValueToDirectLoad(LoadSDNode *LD);
379 bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
380 bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);
381
382 /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
383 /// load.
384 ///
385 /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
386 /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
387 /// \param EltNo index of the vector element to load.
388 /// \param OriginalLoad load that EVE came from to be replaced.
389 /// \returns EVE on success SDValue() on failure.
390 SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
391 SDValue EltNo,
392 LoadSDNode *OriginalLoad);
393 void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
394 SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
395 SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
396 SDValue ZExtPromoteOperand(SDValue Op, EVT PVT);
397 SDValue PromoteIntBinOp(SDValue Op);
398 SDValue PromoteIntShiftOp(SDValue Op);
399 SDValue PromoteExtend(SDValue Op);
400 bool PromoteLoad(SDValue Op);
401
402 SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
403 SDValue RHS, SDValue True, SDValue False,
404 ISD::CondCode CC);
405
406 /// Call the node-specific routine that knows how to fold each
407 /// particular type of node. If that doesn't do anything, try the
408 /// target-specific DAG combines.
409 SDValue combine(SDNode *N);
410
411 // Visitation implementation - Implement dag node combining for different
412 // node types. The semantics are as follows:
413 // Return Value:
414 // SDValue.getNode() == 0 - No change was made
415 // SDValue.getNode() == N - N was replaced, is dead and has been handled.
416 // otherwise - N should be replaced by the returned Operand.
417 //
418 SDValue visitTokenFactor(SDNode *N);
419 SDValue visitMERGE_VALUES(SDNode *N);
420 SDValue visitADD(SDNode *N);
421 SDValue visitADDLike(SDNode *N);
422 SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
423 SDValue visitSUB(SDNode *N);
424 SDValue visitADDSAT(SDNode *N);
425 SDValue visitSUBSAT(SDNode *N);
426 SDValue visitADDC(SDNode *N);
427 SDValue visitADDO(SDNode *N);
428 SDValue visitUADDOLike(SDValue N0, SDValue N1, SDNode *N);
429 SDValue visitSUBC(SDNode *N);
430 SDValue visitSUBO(SDNode *N);
431 SDValue visitADDE(SDNode *N);
432 SDValue visitUADDO_CARRY(SDNode *N);
433 SDValue visitSADDO_CARRY(SDNode *N);
434 SDValue visitUADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
435 SDNode *N);
436 SDValue visitSADDO_CARRYLike(SDValue N0, SDValue N1, SDValue CarryIn,
437 SDNode *N);
438 SDValue visitSUBE(SDNode *N);
439 SDValue visitUSUBO_CARRY(SDNode *N);
440 SDValue visitSSUBO_CARRY(SDNode *N);
441 SDValue visitMUL(SDNode *N);
442 SDValue visitMULFIX(SDNode *N);
443 SDValue useDivRem(SDNode *N);
444 SDValue visitSDIV(SDNode *N);
445 SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
446 SDValue visitUDIV(SDNode *N);
447 SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
448 SDValue visitREM(SDNode *N);
449 SDValue visitMULHU(SDNode *N);
450 SDValue visitMULHS(SDNode *N);
451 SDValue visitAVG(SDNode *N);
452 SDValue visitABD(SDNode *N);
453 SDValue visitSMUL_LOHI(SDNode *N);
454 SDValue visitUMUL_LOHI(SDNode *N);
455 SDValue visitMULO(SDNode *N);
456 SDValue visitIMINMAX(SDNode *N);
457 SDValue visitAND(SDNode *N);
458 SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *N);
459 SDValue visitOR(SDNode *N);
460 SDValue visitORLike(SDValue N0, SDValue N1, const SDLoc &DL);
461 SDValue visitXOR(SDNode *N);
462 SDValue SimplifyVCastOp(SDNode *N, const SDLoc &DL);
463 SDValue SimplifyVBinOp(SDNode *N, const SDLoc &DL);
464 SDValue visitSHL(SDNode *N);
465 SDValue visitSRA(SDNode *N);
466 SDValue visitSRL(SDNode *N);
467 SDValue visitFunnelShift(SDNode *N);
468 SDValue visitSHLSAT(SDNode *N);
469 SDValue visitRotate(SDNode *N);
470 SDValue visitABS(SDNode *N);
471 SDValue visitBSWAP(SDNode *N);
472 SDValue visitBITREVERSE(SDNode *N);
473 SDValue visitCTLZ(SDNode *N);
474 SDValue visitCTLZ_ZERO_UNDEF(SDNode *N);
475 SDValue visitCTTZ(SDNode *N);
476 SDValue visitCTTZ_ZERO_UNDEF(SDNode *N);
477 SDValue visitCTPOP(SDNode *N);
478 SDValue visitSELECT(SDNode *N);
479 SDValue visitVSELECT(SDNode *N);
480 SDValue visitVP_SELECT(SDNode *N);
481 SDValue visitSELECT_CC(SDNode *N);
482 SDValue visitSETCC(SDNode *N);
483 SDValue visitSETCCCARRY(SDNode *N);
484 SDValue visitSIGN_EXTEND(SDNode *N);
485 SDValue visitZERO_EXTEND(SDNode *N);
486 SDValue visitANY_EXTEND(SDNode *N);
487 SDValue visitAssertExt(SDNode *N);
488 SDValue visitAssertAlign(SDNode *N);
489 SDValue visitSIGN_EXTEND_INREG(SDNode *N);
490 SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
491 SDValue visitTRUNCATE(SDNode *N);
492 SDValue visitBITCAST(SDNode *N);
493 SDValue visitFREEZE(SDNode *N);
494 SDValue visitBUILD_PAIR(SDNode *N);
495 SDValue visitFADD(SDNode *N);
496 SDValue visitVP_FADD(SDNode *N);
497 SDValue visitVP_FSUB(SDNode *N);
498 SDValue visitSTRICT_FADD(SDNode *N);
499 SDValue visitFSUB(SDNode *N);
500 SDValue visitFMUL(SDNode *N);
501 template <class MatchContextClass> SDValue visitFMA(SDNode *N);
502 SDValue visitFMAD(SDNode *N);
503 SDValue visitFDIV(SDNode *N);
504 SDValue visitFREM(SDNode *N);
505 SDValue visitFSQRT(SDNode *N);
506 SDValue visitFCOPYSIGN(SDNode *N);
507 SDValue visitFPOW(SDNode *N);
508 SDValue visitSINT_TO_FP(SDNode *N);
509 SDValue visitUINT_TO_FP(SDNode *N);
510 SDValue visitFP_TO_SINT(SDNode *N);
511 SDValue visitFP_TO_UINT(SDNode *N);
512 SDValue visitXRINT(SDNode *N);
513 SDValue visitFP_ROUND(SDNode *N);
514 SDValue visitFP_EXTEND(SDNode *N);
515 SDValue visitFNEG(SDNode *N);
516 SDValue visitFABS(SDNode *N);
517 SDValue visitFCEIL(SDNode *N);
518 SDValue visitFTRUNC(SDNode *N);
519 SDValue visitFFREXP(SDNode *N);
520 SDValue visitFFLOOR(SDNode *N);
521 SDValue visitFMinMax(SDNode *N);
522 SDValue visitBRCOND(SDNode *N);
523 SDValue visitBR_CC(SDNode *N);
524 SDValue visitLOAD(SDNode *N);
525
526 SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
527 SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
528 SDValue replaceStoreOfInsertLoad(StoreSDNode *ST);
529
530 bool refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(SDNode *N);
531
532 SDValue visitSTORE(SDNode *N);
533 SDValue visitLIFETIME_END(SDNode *N);
534 SDValue visitINSERT_VECTOR_ELT(SDNode *N);
535 SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
536 SDValue visitBUILD_VECTOR(SDNode *N);
537 SDValue visitCONCAT_VECTORS(SDNode *N);
538 SDValue visitEXTRACT_SUBVECTOR(SDNode *N);
539 SDValue visitVECTOR_SHUFFLE(SDNode *N);
540 SDValue visitSCALAR_TO_VECTOR(SDNode *N);
541 SDValue visitINSERT_SUBVECTOR(SDNode *N);
542 SDValue visitMLOAD(SDNode *N);
543 SDValue visitMSTORE(SDNode *N);
544 SDValue visitMGATHER(SDNode *N);
545 SDValue visitMSCATTER(SDNode *N);
546 SDValue visitVPGATHER(SDNode *N);
547 SDValue visitVPSCATTER(SDNode *N);
548 SDValue visitVP_STRIDED_LOAD(SDNode *N);
549 SDValue visitVP_STRIDED_STORE(SDNode *N);
550 SDValue visitFP_TO_FP16(SDNode *N);
551 SDValue visitFP16_TO_FP(SDNode *N);
552 SDValue visitFP_TO_BF16(SDNode *N);
553 SDValue visitBF16_TO_FP(SDNode *N);
554 SDValue visitVECREDUCE(SDNode *N);
555 SDValue visitVPOp(SDNode *N);
556 SDValue visitGET_FPENV_MEM(SDNode *N);
557 SDValue visitSET_FPENV_MEM(SDNode *N);
558
559 template <class MatchContextClass>
560 SDValue visitFADDForFMACombine(SDNode *N);
561 template <class MatchContextClass>
562 SDValue visitFSUBForFMACombine(SDNode *N);
563 SDValue visitFMULForFMADistributiveCombine(SDNode *N);
564
565 SDValue XformToShuffleWithZero(SDNode *N);
566 bool reassociationCanBreakAddressingModePattern(unsigned Opc,
567 const SDLoc &DL,
568 SDNode *N,
569 SDValue N0,
570 SDValue N1);
571 SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
572 SDValue N1, SDNodeFlags Flags);
573 SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
574 SDValue N1, SDNodeFlags Flags);
575 SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
576 EVT VT, SDValue N0, SDValue N1,
577 SDNodeFlags Flags = SDNodeFlags());
578
579 SDValue visitShiftByConstant(SDNode *N);
580
581 SDValue foldSelectOfConstants(SDNode *N);
582 SDValue foldVSelectOfConstants(SDNode *N);
583 SDValue foldBinOpIntoSelect(SDNode *BO);
584 bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
585 SDValue hoistLogicOpWithSameOpcodeHands(SDNode *N);
586 SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
587 SDValue SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
588 SDValue N2, SDValue N3, ISD::CondCode CC,
589 bool NotExtCompare = false);
590 SDValue convertSelectOfFPConstantsToLoadOffset(
591 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
592 ISD::CondCode CC);
593 SDValue foldSignChangeInBitcast(SDNode *N);
594 SDValue foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0, SDValue N1,
595 SDValue N2, SDValue N3, ISD::CondCode CC);
596 SDValue foldSelectOfBinops(SDNode *N);
597 SDValue foldSextSetcc(SDNode *N);
598 SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
599 const SDLoc &DL);
600 SDValue foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL);
601 SDValue foldABSToABD(SDNode *N, const SDLoc &DL);
602 SDValue unfoldMaskedMerge(SDNode *N);
603 SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
604 SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
605 const SDLoc &DL, bool foldBooleans);
606 SDValue rebuildSetCC(SDValue N);
607
608 bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
609 SDValue &CC, bool MatchStrict = false) const;
610 bool isOneUseSetCC(SDValue N) const;
611
612 SDValue foldAddToAvg(SDNode *N, const SDLoc &DL);
613 SDValue foldSubToAvg(SDNode *N, const SDLoc &DL);
614
615 SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
616 unsigned HiOp);
617 SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
618 SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
619 const TargetLowering &TLI);
620
621 SDValue CombineExtLoad(SDNode *N);
622 SDValue CombineZExtLogicopShiftLoad(SDNode *N);
623 SDValue combineRepeatedFPDivisors(SDNode *N);
624 SDValue combineFMulOrFDivWithIntPow2(SDNode *N);
625 SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
626 SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
627 SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
628 SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
629 SDValue BuildSDIV(SDNode *N);
630 SDValue BuildSDIVPow2(SDNode *N);
631 SDValue BuildUDIV(SDNode *N);
632 SDValue BuildSREMPow2(SDNode *N);
633 SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
634 SDValue BuildLogBase2(SDValue V, const SDLoc &DL,
635 bool KnownNeverZero = false,
636 bool InexpensiveOnly = false,
637 std::optional<EVT> OutVT = std::nullopt);
638 SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
639 SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
640 SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags);
641 SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, bool Recip);
642 SDValue buildSqrtNROneConst(SDValue Arg, SDValue Est, unsigned Iterations,
643 SDNodeFlags Flags, bool Reciprocal);
644 SDValue buildSqrtNRTwoConst(SDValue Arg, SDValue Est, unsigned Iterations,
645 SDNodeFlags Flags, bool Reciprocal);
646 SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
647 bool DemandHighBits = true);
648 SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
649 SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
650 SDValue InnerPos, SDValue InnerNeg, bool HasPos,
651 unsigned PosOpcode, unsigned NegOpcode,
652 const SDLoc &DL);
653 SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
654 SDValue InnerPos, SDValue InnerNeg, bool HasPos,
655 unsigned PosOpcode, unsigned NegOpcode,
656 const SDLoc &DL);
657 SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
658 SDValue MatchLoadCombine(SDNode *N);
659 SDValue mergeTruncStores(StoreSDNode *N);
660 SDValue reduceLoadWidth(SDNode *N);
661 SDValue ReduceLoadOpStoreWidth(SDNode *N);
662 SDValue splitMergedValStore(StoreSDNode *ST);
663 SDValue TransformFPLoadStorePair(SDNode *N);
664 SDValue convertBuildVecZextToZext(SDNode *N);
665 SDValue convertBuildVecZextToBuildVecWithZeros(SDNode *N);
666 SDValue reduceBuildVecExtToExtBuildVec(SDNode *N);
667 SDValue reduceBuildVecTruncToBitCast(SDNode *N);
668 SDValue reduceBuildVecToShuffle(SDNode *N);
669 SDValue createBuildVecShuffle(const SDLoc &DL, SDNode *N,
670 ArrayRef<int> VectorMask, SDValue VecIn1,
671 SDValue VecIn2, unsigned LeftIdx,
672 bool DidSplitVec);
673 SDValue matchVSelectOpSizesWithSetCC(SDNode *Cast);
674
675 /// Walk up chain skipping non-aliasing memory nodes,
676 /// looking for aliasing nodes and adding them to the Aliases vector.
677 void GatherAllAliases(SDNode *N, SDValue OriginalChain,
678 SmallVectorImpl<SDValue> &Aliases);
679
680 /// Return true if there is any possibility that the two addresses overlap.
681 bool mayAlias(SDNode *Op0, SDNode *Op1) const;
682
683 /// Walk up chain skipping non-aliasing memory nodes, looking for a better
684 /// chain (aliasing node.)
685 SDValue FindBetterChain(SDNode *N, SDValue Chain);
686
687 /// Try to replace a store and any possibly adjacent stores on
688 /// consecutive chains with better chains. Return true only if St is
689 /// replaced.
690 ///
691 /// Notice that other chains may still be replaced even if the function
692 /// returns false.
693 bool findBetterNeighborChains(StoreSDNode *St);
694
695 // Helper for findBetterNeighborChains. Walks up the store chain, adding
696 // additional chained stores that do not overlap and can be parallelized.
697 bool parallelizeChainedStores(StoreSDNode *St);
698
699 /// Holds a pointer to an LSBaseSDNode as well as information on where it
700 /// is located in a sequence of memory operations connected by a chain.
701 struct MemOpLink {
702 // Ptr to the mem node.
703 LSBaseSDNode *MemNode;
704
705 // Offset from the base ptr.
706 int64_t OffsetFromBase;
707
708 MemOpLink(LSBaseSDNode *N, int64_t Offset)
709 : MemNode(N), OffsetFromBase(Offset) {}
710 };
711
712 // Classify the origin of a stored value.
713 enum class StoreSource { Unknown, Constant, Extract, Load };
714 StoreSource getStoreSource(SDValue StoreVal) {
715 switch (StoreVal.getOpcode()) {
716 case ISD::Constant:
717 case ISD::ConstantFP:
718 return StoreSource::Constant;
719 case ISD::BUILD_VECTOR:
720 if (ISD::isBuildVectorOfConstantSDNodes(N: StoreVal.getNode()) ||
721 ISD::isBuildVectorOfConstantFPSDNodes(N: StoreVal.getNode()))
722 return StoreSource::Constant;
723 return StoreSource::Unknown;
724 case ISD::EXTRACT_VECTOR_ELT:
725 case ISD::EXTRACT_SUBVECTOR:
726 return StoreSource::Extract;
727 case ISD::LOAD:
728 return StoreSource::Load;
729 default:
730 return StoreSource::Unknown;
731 }
732 }
733
734 /// This is a helper function for visitMUL to check the profitability
735 /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
736 /// MulNode is the original multiply, AddNode is (add x, c1),
737 /// and ConstNode is c2.
738 bool isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
739 SDValue ConstNode);
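    // A concrete instance of the fold above (values chosen for illustration):
    // with c1 = 4 and c2 = 6, (mul (add x, 4), 6) becomes (add (mul x, 6), 24).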
740
741 /// This is a helper function for visitAND and visitZERO_EXTEND. Returns
742 /// true if the (and (load x) c) pattern matches an extload. ExtVT returns
743 /// the type of the loaded value to be extended.
744 bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
745 EVT LoadResultTy, EVT &ExtVT);
746
747 /// Helper function to calculate whether the given Load/Store can have its
748 /// width reduced to ExtVT.
749 bool isLegalNarrowLdSt(LSBaseSDNode *LDSTN, ISD::LoadExtType ExtType,
750 EVT &MemVT, unsigned ShAmt = 0);
751
752 /// Used by BackwardsPropagateMask to find suitable loads.
753 bool SearchForAndLoads(SDNode *N, SmallVectorImpl<LoadSDNode*> &Loads,
754 SmallPtrSetImpl<SDNode*> &NodesWithConsts,
755 ConstantSDNode *Mask, SDNode *&NodeToMask);
756 /// Attempt to propagate a given AND node back to load leaves so that they
757 /// can be combined into narrow loads.
758 bool BackwardsPropagateMask(SDNode *N);
759
760 /// Helper function for mergeConsecutiveStores which merges the component
761 /// store chains.
762 SDValue getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
763 unsigned NumStores);
764
765 /// Helper function for mergeConsecutiveStores which checks if all the store
766 /// nodes have the same underlying object. We can still reuse the first
767 /// store's pointer info if all the stores are from the same object.
768 bool hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes);
769
770 /// This is a helper function for mergeConsecutiveStores. When the source
771 /// elements of the consecutive stores are all constants or all extracted
772 /// vector elements, try to merge them into one larger store introducing
773 /// bitcasts if necessary. \return True if a merged store was created.
774 bool mergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
775 EVT MemVT, unsigned NumStores,
776 bool IsConstantSrc, bool UseVector,
777 bool UseTrunc);
778
779 /// This is a helper function for mergeConsecutiveStores. Stores that
780 /// potentially may be merged with St are placed in StoreNodes. RootNode is
781 /// a chain predecessor to all store candidates.
782 void getStoreMergeCandidates(StoreSDNode *St,
783 SmallVectorImpl<MemOpLink> &StoreNodes,
784 SDNode *&Root);
785
786 /// Helper function for mergeConsecutiveStores. Checks if candidate stores
787 /// have indirect dependency through their operands. RootNode is the
788 /// predecessor to all stores calculated by getStoreMergeCandidates and is
789 /// used to prune the dependency check. \return True if safe to merge.
790 bool checkMergeStoreCandidatesForDependencies(
791 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
792 SDNode *RootNode);
793
794 /// This is a helper function for mergeConsecutiveStores. Given a list of
795 /// store candidates, find the first N that are consecutive in memory.
796 /// Returns 0 if there are not at least 2 consecutive stores to try merging.
797 unsigned getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
798 int64_t ElementSizeBytes) const;
799
800 /// This is a helper function for mergeConsecutiveStores. It is used for
801 /// store chains that are composed entirely of constant values.
802 bool tryStoreMergeOfConstants(SmallVectorImpl<MemOpLink> &StoreNodes,
803 unsigned NumConsecutiveStores,
804 EVT MemVT, SDNode *Root, bool AllowVectors);
805
806 /// This is a helper function for mergeConsecutiveStores. It is used for
807 /// store chains that are composed entirely of extracted vector elements.
808 /// When extracting multiple vector elements, try to store them in one
809 /// vector store rather than a sequence of scalar stores.
810 bool tryStoreMergeOfExtracts(SmallVectorImpl<MemOpLink> &StoreNodes,
811 unsigned NumConsecutiveStores, EVT MemVT,
812 SDNode *Root);
813
814 /// This is a helper function for mergeConsecutiveStores. It is used for
815 /// store chains that are composed entirely of loaded values.
816 bool tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
817 unsigned NumConsecutiveStores, EVT MemVT,
818 SDNode *Root, bool AllowVectors,
819 bool IsNonTemporalStore, bool IsNonTemporalLoad);
820
821 /// Merge consecutive store operations into a wide store.
822 /// This optimization uses wide integers or vectors when possible.
823 /// \return true if stores were merged.
824 bool mergeConsecutiveStores(StoreSDNode *St);
825
826 /// Try to transform a truncation where C is a constant:
827 /// (trunc (and X, C)) -> (and (trunc X), (trunc C))
828 ///
829 /// \p N needs to be a truncation and its first operand an AND. Other
830 /// requirements are checked by the function (e.g. that the trunc is
831 /// single-use) and, if they are not met, an empty SDValue is returned.
832 SDValue distributeTruncateThroughAnd(SDNode *N);
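    // Worked example (illustrative types and constants): truncating i32 -> i16,
    //   (trunc:i16 (and:i32 X, 0x00FF00FF)) -> (and:i16 (trunc:i16 X), 0x00FF)
    // i.e. the constant is truncated along with the value being masked.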
833
834 /// Helper function to determine whether the target supports the operation
835 /// given by \p Opcode for type \p VT, that is, whether the operation
836 /// is legal or custom before legalizing operations, and whether it is
837 /// legal (but not custom) after legalization.
838 bool hasOperation(unsigned Opcode, EVT VT) {
839 return TLI.isOperationLegalOrCustom(Op: Opcode, VT, LegalOnly: LegalOperations);
840 }
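  // Illustrative usage (an assumption, not a quote of a particular combine):
  // a fold that wants to introduce an ISD::ABDS node would typically be
  // guarded by
  //   if (hasOperation(ISD::ABDS, VT)) { ... }
  // so the new opcode is only created when the target can handle it at the
  // current legalization stage.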
841
842 public:
843 /// Runs the dag combiner on all nodes in the work list
844 void Run(CombineLevel AtLevel);
845
846 SelectionDAG &getDAG() const { return DAG; }
847
848 /// Returns a type large enough to hold any valid shift amount - before type
849 /// legalization these can be huge.
850 EVT getShiftAmountTy(EVT LHSTy) {
851 assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
852 return TLI.getShiftAmountTy(LHSTy, DL: DAG.getDataLayout(), LegalTypes);
853 }
854
855 /// This method returns true if we are running before type legalization or
856 /// if the specified VT is legal.
857 bool isTypeLegal(const EVT &VT) {
858 if (!LegalTypes) return true;
859 return TLI.isTypeLegal(VT);
860 }
861
862 /// Convenience wrapper around TargetLowering::getSetCCResultType
863 EVT getSetCCResultType(EVT VT) const {
864 return TLI.getSetCCResultType(DL: DAG.getDataLayout(), Context&: *DAG.getContext(), VT);
865 }
866
867 void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
868 SDValue OrigLoad, SDValue ExtLoad,
869 ISD::NodeType ExtType);
870 };
871
872/// This class is a DAGUpdateListener that removes any deleted
873/// nodes from the worklist.
874class WorklistRemover : public SelectionDAG::DAGUpdateListener {
875 DAGCombiner &DC;
876
877public:
878 explicit WorklistRemover(DAGCombiner &dc)
879 : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
880
881 void NodeDeleted(SDNode *N, SDNode *E) override {
882 DC.removeFromWorklist(N);
883 }
884};
885
886class WorklistInserter : public SelectionDAG::DAGUpdateListener {
887 DAGCombiner &DC;
888
889public:
890 explicit WorklistInserter(DAGCombiner &dc)
891 : SelectionDAG::DAGUpdateListener(dc.getDAG()), DC(dc) {}
892
893 // FIXME: Ideally we could add N to the worklist, but this causes exponential
894 // compile time costs in large DAGs, e.g. Halide.
895 void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
896};
897
898} // end anonymous namespace
899
900//===----------------------------------------------------------------------===//
901// TargetLowering::DAGCombinerInfo implementation
902//===----------------------------------------------------------------------===//
903
904void TargetLowering::DAGCombinerInfo::AddToWorklist(SDNode *N) {
905 ((DAGCombiner*)DC)->AddToWorklist(N);
906}
907
908SDValue TargetLowering::DAGCombinerInfo::
909CombineTo(SDNode *N, ArrayRef<SDValue> To, bool AddTo) {
910 return ((DAGCombiner*)DC)->CombineTo(N, To: &To[0], NumTo: To.size(), AddTo);
911}
912
913SDValue TargetLowering::DAGCombinerInfo::
914CombineTo(SDNode *N, SDValue Res, bool AddTo) {
915 return ((DAGCombiner*)DC)->CombineTo(N, Res, AddTo);
916}
917
918SDValue TargetLowering::DAGCombinerInfo::
919CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo) {
920 return ((DAGCombiner*)DC)->CombineTo(N, Res0, Res1, AddTo);
921}
922
923bool TargetLowering::DAGCombinerInfo::
924recursivelyDeleteUnusedNodes(SDNode *N) {
925 return ((DAGCombiner*)DC)->recursivelyDeleteUnusedNodes(N);
926}
927
928void TargetLowering::DAGCombinerInfo::
929CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
930 return ((DAGCombiner*)DC)->CommitTargetLoweringOpt(TLO);
931}
932
933//===----------------------------------------------------------------------===//
934// Helper Functions
935//===----------------------------------------------------------------------===//
936
937void DAGCombiner::deleteAndRecombine(SDNode *N) {
938 removeFromWorklist(N);
939
940 // If the operands of this node are only used by the node, they will now be
941 // dead. Make sure to re-visit them and recursively delete dead nodes.
942 for (const SDValue &Op : N->ops())
943 // For an operand generating multiple values, one of the values may
944 // become dead allowing further simplification (e.g. split index
945 // arithmetic from an indexed load).
946 if (Op->hasOneUse() || Op->getNumValues() > 1)
947 AddToWorklist(N: Op.getNode());
948
949 DAG.DeleteNode(N);
950}
951
952// APInts must be the same size for most operations, so this helper
953// function zero-extends the shorter of the pair so that they match.
954// We provide an Offset so that we can create bitwidths that won't overflow.
955static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
956 unsigned Bits = Offset + std::max(a: LHS.getBitWidth(), b: RHS.getBitWidth());
957 LHS = LHS.zext(width: Bits);
958 RHS = RHS.zext(width: Bits);
959}
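// Worked example (illustrative widths): for an 8-bit LHS, a 16-bit RHS and
// Offset == 1, both values are zero-extended to 17 bits, leaving a spare bit
// so that, e.g., adding the two extended values cannot overflow.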
960
961// Return true if this node is a setcc, or is a select_cc
962// that selects between the target values used for true and false, making it
963// equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
964// the appropriate nodes based on the type of node we are checking. This
965// simplifies life a bit for the callers.
966bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
967 SDValue &CC, bool MatchStrict) const {
968 if (N.getOpcode() == ISD::SETCC) {
969 LHS = N.getOperand(i: 0);
970 RHS = N.getOperand(i: 1);
971 CC = N.getOperand(i: 2);
972 return true;
973 }
974
975 if (MatchStrict &&
976 (N.getOpcode() == ISD::STRICT_FSETCC ||
977 N.getOpcode() == ISD::STRICT_FSETCCS)) {
978 LHS = N.getOperand(i: 1);
979 RHS = N.getOperand(i: 2);
980 CC = N.getOperand(i: 3);
981 return true;
982 }
983
984 if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N: N.getOperand(i: 2)) ||
985 !TLI.isConstFalseVal(N: N.getOperand(i: 3)))
986 return false;
987
988 if (TLI.getBooleanContents(Type: N.getValueType()) ==
989 TargetLowering::UndefinedBooleanContent)
990 return false;
991
992 LHS = N.getOperand(i: 0);
993 RHS = N.getOperand(i: 1);
994 CC = N.getOperand(i: 4);
995 return true;
996}
997
998/// Return true if this is a SetCC-equivalent operation with only one use.
999/// If this is true, it allows the users to invert the operation for free when
1000/// it is profitable to do so.
1001bool DAGCombiner::isOneUseSetCC(SDValue N) const {
1002 SDValue N0, N1, N2;
1003 if (isSetCCEquivalent(N, LHS&: N0, RHS&: N1, CC&: N2) && N->hasOneUse())
1004 return true;
1005 return false;
1006}
1007
1008static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
1009 if (!ScalarTy.isSimple())
1010 return false;
1011
1012 uint64_t MaskForTy = 0ULL;
1013 switch (ScalarTy.getSimpleVT().SimpleTy) {
1014 case MVT::i8:
1015 MaskForTy = 0xFFULL;
1016 break;
1017 case MVT::i16:
1018 MaskForTy = 0xFFFFULL;
1019 break;
1020 case MVT::i32:
1021 MaskForTy = 0xFFFFFFFFULL;
1022 break;
1023 default:
1024 return false;
1026 }
1027
1028 APInt Val;
1029 if (ISD::isConstantSplatVector(N, SplatValue&: Val))
1030 return Val.getLimitedValue() == MaskForTy;
1031
1032 return false;
1033}
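// For illustration (hypothetical operands): a v4i32 splat of 0x0000FFFF
// matches ScalarTy == MVT::i16, since 0xFFFF is exactly the mask of an i16
// element, while a splat of 0x000000FF would instead match MVT::i8.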
1034
1035// Determines if N is a constant integer or a splat/build vector of constant
1036// integers (and undefs).
1037// Build vector implicit truncation is not permitted.
1038static bool isConstantOrConstantVector(SDValue N, bool NoOpaques = false) {
1039 if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Val&: N))
1040 return !(Const->isOpaque() && NoOpaques);
1041 if (N.getOpcode() != ISD::BUILD_VECTOR && N.getOpcode() != ISD::SPLAT_VECTOR)
1042 return false;
1043 unsigned BitWidth = N.getScalarValueSizeInBits();
1044 for (const SDValue &Op : N->op_values()) {
1045 if (Op.isUndef())
1046 continue;
1047 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Val: Op);
1048 if (!Const || Const->getAPIntValue().getBitWidth() != BitWidth ||
1049 (Const->isOpaque() && NoOpaques))
1050 return false;
1051 }
1052 return true;
1053}
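// For illustration (hypothetical operands): (BUILD_VECTOR 1, undef, 3, undef)
// with i32 elements qualifies, but a v8i8 BUILD_VECTOR whose operands are i32
// constants does not, because its elements would be implicitly truncated (the
// operands' APInt width differs from the vector's scalar size).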
1054
1055// Determines if a BUILD_VECTOR is composed entirely of constants, possibly
1056// mixed with undefs.
1057static bool isAnyConstantBuildVector(SDValue V, bool NoOpaques = false) {
1058 if (V.getOpcode() != ISD::BUILD_VECTOR)
1059 return false;
1060 return isConstantOrConstantVector(N: V, NoOpaques) ||
1061 ISD::isBuildVectorOfConstantFPSDNodes(N: V.getNode());
1062}
1063
1064// Determine if this is an indexed load with an opaque target constant index.
1065static bool canSplitIdx(LoadSDNode *LD) {
1066 return MaySplitLoadIndex &&
1067 (LD->getOperand(Num: 2).getOpcode() != ISD::TargetConstant ||
1068 !cast<ConstantSDNode>(Val: LD->getOperand(Num: 2))->isOpaque());
1069}
1070
1071bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
1072 const SDLoc &DL,
1073 SDNode *N,
1074 SDValue N0,
1075 SDValue N1) {
1076 // Currently this only tries to ensure we don't undo the GEP splits done by
1077 // CodeGenPrepare when shouldConsiderGEPOffsetSplit is true. To ensure this,
1078 // we check if the following transformation would be problematic:
1079 // (load/store (add, (add, x, offset1), offset2)) ->
1080 // (load/store (add, x, offset1+offset2)).
1081
1082 // (load/store (add, (add, x, y), offset2)) ->
1083 // (load/store (add, (add, x, offset2), y)).
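  // For example (hypothetical target constraints): if immediate offsets must
  // fit in [-4096, 4095], then for (add (add x, 3000), 2000) the per-use
  // offset 2000 is a legal addressing mode but the folded offset 5000 is not,
  // so reassociating the constants would undo the offset split.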
1084
1085 if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
1086 return false;
1087
1088 auto *C2 = dyn_cast<ConstantSDNode>(Val&: N1);
1089 if (!C2)
1090 return false;
1091
1092 const APInt &C2APIntVal = C2->getAPIntValue();
1093 if (C2APIntVal.getSignificantBits() > 64)
1094 return false;
1095
1096 if (auto *C1 = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
1097 if (N0.hasOneUse())
1098 return false;
1099
1100 const APInt &C1APIntVal = C1->getAPIntValue();
1101 const APInt CombinedValueIntVal = C1APIntVal + C2APIntVal;
1102 if (CombinedValueIntVal.getSignificantBits() > 64)
1103 return false;
1104 const int64_t CombinedValue = CombinedValueIntVal.getSExtValue();
1105
1106 for (SDNode *Node : N->uses()) {
1107 if (auto *LoadStore = dyn_cast<MemSDNode>(Val: Node)) {
1108 // Is x[offset2] already not a legal addressing mode? If so then
1109 // reassociating the constants breaks nothing (we test offset2 because
1110 // that's the one we hope to fold into the load or store).
1111 TargetLoweringBase::AddrMode AM;
1112 AM.HasBaseReg = true;
1113 AM.BaseOffs = C2APIntVal.getSExtValue();
1114 EVT VT = LoadStore->getMemoryVT();
1115 unsigned AS = LoadStore->getAddressSpace();
1116 Type *AccessTy = VT.getTypeForEVT(Context&: *DAG.getContext());
1117 if (!TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM, Ty: AccessTy, AddrSpace: AS))
1118 continue;
1119
1120 // Would x[offset1+offset2] still be a legal addressing mode?
1121 AM.BaseOffs = CombinedValue;
1122 if (!TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM, Ty: AccessTy, AddrSpace: AS))
1123 return true;
1124 }
1125 }
1126 } else {
1127 if (auto *GA = dyn_cast<GlobalAddressSDNode>(Val: N0.getOperand(i: 1)))
1128 if (GA->getOpcode() == ISD::GlobalAddress && TLI.isOffsetFoldingLegal(GA))
1129 return false;
1130
1131 for (SDNode *Node : N->uses()) {
1132 auto *LoadStore = dyn_cast<MemSDNode>(Val: Node);
1133 if (!LoadStore)
1134 return false;
1135
1136 // Is x[offset2] a legal addressing mode? If so, then
1137 // reassociating the constants breaks the address pattern.
1138 TargetLoweringBase::AddrMode AM;
1139 AM.HasBaseReg = true;
1140 AM.BaseOffs = C2APIntVal.getSExtValue();
1141 EVT VT = LoadStore->getMemoryVT();
1142 unsigned AS = LoadStore->getAddressSpace();
1143 Type *AccessTy = VT.getTypeForEVT(Context&: *DAG.getContext());
1144 if (!TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM, Ty: AccessTy, AddrSpace: AS))
1145 return false;
1146 }
1147 return true;
1148 }
1149
1150 return false;
1151}
1152
1153/// Helper for DAGCombiner::reassociateOps. Try to reassociate (Opc N0, N1) if
1154/// \p N0 is the same kind of operation as \p Opc.
1155SDValue DAGCombiner::reassociateOpsCommutative(unsigned Opc, const SDLoc &DL,
1156 SDValue N0, SDValue N1,
1157 SDNodeFlags Flags) {
1158 EVT VT = N0.getValueType();
1159
1160 if (N0.getOpcode() != Opc)
1161 return SDValue();
1162
1163 SDValue N00 = N0.getOperand(i: 0);
1164 SDValue N01 = N0.getOperand(i: 1);
1165
1166 if (DAG.isConstantIntBuildVectorOrConstantInt(N: peekThroughBitcasts(V: N01))) {
1167 SDNodeFlags NewFlags;
1168 if (N0.getOpcode() == ISD::ADD && N0->getFlags().hasNoUnsignedWrap() &&
1169 Flags.hasNoUnsignedWrap())
1170 NewFlags.setNoUnsignedWrap(true);
1171
1172 if (DAG.isConstantIntBuildVectorOrConstantInt(N: peekThroughBitcasts(V: N1))) {
1173 // Reassociate: (op (op x, c1), c2) -> (op x, (op c1, c2))
1174 if (SDValue OpNode = DAG.FoldConstantArithmetic(Opcode: Opc, DL, VT, Ops: {N01, N1}))
1175 return DAG.getNode(Opcode: Opc, DL, VT, N1: N00, N2: OpNode, Flags: NewFlags);
1176 return SDValue();
1177 }
1178 if (TLI.isReassocProfitable(DAG, N0, N1)) {
1179 // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1)
1180 // iff (op x, c1) has one use
1181 SDValue OpNode = DAG.getNode(Opcode: Opc, DL: SDLoc(N0), VT, N1: N00, N2: N1, Flags: NewFlags);
1182 return DAG.getNode(Opcode: Opc, DL, VT, N1: OpNode, N2: N01, Flags: NewFlags);
1183 }
1184 }
1185
1186 // Check for repeated operand logic simplifications.
1187 if (Opc == ISD::AND || Opc == ISD::OR) {
1188 // (N00 & N01) & N00 --> N00 & N01
1189 // (N00 & N01) & N01 --> N00 & N01
1190 // (N00 | N01) | N00 --> N00 | N01
1191 // (N00 | N01) | N01 --> N00 | N01
1192 if (N1 == N00 || N1 == N01)
1193 return N0;
1194 }
1195 if (Opc == ISD::XOR) {
1196 // (N00 ^ N01) ^ N00 --> N01
1197 if (N1 == N00)
1198 return N01;
1199 // (N00 ^ N01) ^ N01 --> N00
1200 if (N1 == N01)
1201 return N00;
1202 }
1203
1204 if (TLI.isReassocProfitable(DAG, N0, N1)) {
1205 if (N1 != N01) {
1206 // Reassociate if (op N00, N1) already exist
1207 if (SDNode *NE = DAG.getNodeIfExists(Opcode: Opc, VTList: DAG.getVTList(VT), Ops: {N00, N1})) {
1208 // If Op (Op N00, N1), N01 already exists,
1209 // we need to stop reassociating to avoid an infinite loop.
1210 if (!DAG.doesNodeExist(Opcode: Opc, VTList: DAG.getVTList(VT), Ops: {SDValue(NE, 0), N01}))
1211 return DAG.getNode(Opcode: Opc, DL, VT, N1: SDValue(NE, 0), N2: N01);
1212 }
1213 }
1214
1215 if (N1 != N00) {
1216 // Reassociate if (op N01, N1) already exist
1217 if (SDNode *NE = DAG.getNodeIfExists(Opcode: Opc, VTList: DAG.getVTList(VT), Ops: {N01, N1})) {
1218 // If Op (Op N01, N1), N00 already exists,
1219 // we need to stop reassociating to avoid an infinite loop.
1220 if (!DAG.doesNodeExist(Opcode: Opc, VTList: DAG.getVTList(VT), Ops: {SDValue(NE, 0), N00}))
1221 return DAG.getNode(Opcode: Opc, DL, VT, N1: SDValue(NE, 0), N2: N00);
1222 }
1223 }
1224
1225 // Reassociate the operands from (OR/AND (OR/AND(N00, N01)), N1) to (OR/AND
1226 // (OR/AND(N00, N1)), N01) when N00 and N1 are comparisons with the same
1227 // predicate, or to (OR/AND (OR/AND(N1, N01)), N00) when N01 and N1 are
1228 // comparisons with the same predicate. This enables optimizations such as
1229 // the following:
1230 // CMP(A,C)||CMP(B,C) => CMP(MIN/MAX(A,B), C)
1231 // CMP(A,C)&&CMP(B,C) => CMP(MIN/MAX(A,B), C)
1232 if (Opc == ISD::AND || Opc == ISD::OR) {
1233 if (N1->getOpcode() == ISD::SETCC && N00->getOpcode() == ISD::SETCC &&
1234 N01->getOpcode() == ISD::SETCC) {
1235 ISD::CondCode CC1 = cast<CondCodeSDNode>(Val: N1.getOperand(i: 2))->get();
1236 ISD::CondCode CC00 = cast<CondCodeSDNode>(Val: N00.getOperand(i: 2))->get();
1237 ISD::CondCode CC01 = cast<CondCodeSDNode>(Val: N01.getOperand(i: 2))->get();
1238 if (CC1 == CC00 && CC1 != CC01) {
1239 SDValue OpNode = DAG.getNode(Opcode: Opc, DL: SDLoc(N0), VT, N1: N00, N2: N1, Flags);
1240 return DAG.getNode(Opcode: Opc, DL, VT, N1: OpNode, N2: N01, Flags);
1241 }
1242 if (CC1 == CC01 && CC1 != CC00) {
1243 SDValue OpNode = DAG.getNode(Opcode: Opc, DL: SDLoc(N0), VT, N1: N01, N2: N1, Flags);
1244 return DAG.getNode(Opcode: Opc, DL, VT, N1: OpNode, N2: N00, Flags);
1245 }
1246 }
1247 }
1248 }
1249
1250 return SDValue();
1251}
1252
1253/// Try to reassociate commutative (Opc N0, N1) if either \p N0 or \p N1 is the
1254/// same kind of operation as \p Opc.
1255SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
1256 SDValue N1, SDNodeFlags Flags) {
1257 assert(TLI.isCommutativeBinOp(Opc) && "Operation not commutative.");
1258
1259 // Floating-point reassociation is not allowed without loose FP math.
1260 if (N0.getValueType().isFloatingPoint() ||
1261 N1.getValueType().isFloatingPoint())
1262 if (!Flags.hasAllowReassociation() || !Flags.hasNoSignedZeros())
1263 return SDValue();
1264
1265 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0, N1, Flags))
1266 return Combined;
1267 if (SDValue Combined = reassociateOpsCommutative(Opc, DL, N0: N1, N1: N0, Flags))
1268 return Combined;
1269 return SDValue();
1270}
1271
1272// Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y))
1273// Note that we only expect Flags to be passed from FP operations. For integer
1274// operations they need to be dropped.
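// For instance (illustrative only):
//   (fadd (vecreduce_fadd x), (vecreduce_fadd y))
//     -> (vecreduce_fadd (fadd x, y))
// provided both reductions have a single use and the target considers the
// wider vector operation legal and profitable to reassociate.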
1275SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
1276 const SDLoc &DL, EVT VT, SDValue N0,
1277 SDValue N1, SDNodeFlags Flags) {
1278 if (N0.getOpcode() == RedOpc && N1.getOpcode() == RedOpc &&
1279 N0.getOperand(i: 0).getValueType() == N1.getOperand(i: 0).getValueType() &&
1280 N0->hasOneUse() && N1->hasOneUse() &&
1281 TLI.isOperationLegalOrCustom(Op: Opc, VT: N0.getOperand(i: 0).getValueType()) &&
1282 TLI.shouldReassociateReduction(RedOpc, VT: N0.getOperand(i: 0).getValueType())) {
1283 SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
1284 return DAG.getNode(Opcode: RedOpc, DL, VT,
1285 Operand: DAG.getNode(Opcode: Opc, DL, VT: N0.getOperand(i: 0).getValueType(),
1286 N1: N0.getOperand(i: 0), N2: N1.getOperand(i: 0)));
1287 }
1288 return SDValue();
1289}
1290
1291SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
1292 bool AddTo) {
1293 assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
1294 ++NodesCombined;
1295 LLVM_DEBUG(dbgs() << "\nReplacing.1 "; N->dump(&DAG); dbgs() << "\nWith: ";
1296 To[0].dump(&DAG);
1297 dbgs() << " and " << NumTo - 1 << " other values\n");
1298 for (unsigned i = 0, e = NumTo; i != e; ++i)
1299 assert((!To[i].getNode() ||
1300 N->getValueType(i) == To[i].getValueType()) &&
1301 "Cannot combine value to value of different type!");
1302
1303 WorklistRemover DeadNodes(*this);
1304 DAG.ReplaceAllUsesWith(From: N, To);
1305 if (AddTo) {
1306 // Push the new nodes and any users onto the worklist
1307 for (unsigned i = 0, e = NumTo; i != e; ++i) {
1308 if (To[i].getNode())
1309 AddToWorklistWithUsers(N: To[i].getNode());
1310 }
1311 }
1312
1313 // Finally, if the node is now dead, remove it from the graph. The node
1314 // may not be dead if the replacement process recursively simplified to
1315 // something else needing this node.
1316 if (N->use_empty())
1317 deleteAndRecombine(N);
1318 return SDValue(N, 0);
1319}
1320
1321void DAGCombiner::
1322CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
1323 // Replace the old value with the new one.
1324 ++NodesCombined;
1325 LLVM_DEBUG(dbgs() << "\nReplacing.2 "; TLO.Old.dump(&DAG);
1326 dbgs() << "\nWith: "; TLO.New.dump(&DAG); dbgs() << '\n');
1327
1328 // Replace all uses.
1329 DAG.ReplaceAllUsesOfValueWith(From: TLO.Old, To: TLO.New);
1330
1331 // Push the new node and any (possibly new) users onto the worklist.
1332 AddToWorklistWithUsers(N: TLO.New.getNode());
1333
1334 // Finally, if the node is now dead, remove it from the graph.
1335 recursivelyDeleteUnusedNodes(N: TLO.Old.getNode());
1336}
1337
1338/// Check the specified integer node value to see if it can be simplified or if
1339/// things it uses can be simplified by bit propagation. If so, return true.
1340bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
1341 const APInt &DemandedElts,
1342 bool AssumeSingleUse) {
1343 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1344 KnownBits Known;
1345 if (!TLI.SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth: 0,
1346 AssumeSingleUse))
1347 return false;
1348
1349 // Revisit the node.
1350 AddToWorklist(N: Op.getNode());
1351
1352 CommitTargetLoweringOpt(TLO);
1353 return true;
1354}
1355
1356/// Check the specified vector node value to see if it can be simplified or
1357/// if things it uses can be simplified as it only uses some of the elements.
1358/// If so, return true.
1359bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
1360 const APInt &DemandedElts,
1361 bool AssumeSingleUse) {
1362 TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
1363 APInt KnownUndef, KnownZero;
1364 if (!TLI.SimplifyDemandedVectorElts(Op, DemandedEltMask: DemandedElts, KnownUndef, KnownZero,
1365 TLO, Depth: 0, AssumeSingleUse))
1366 return false;
1367
1368 // Revisit the node.
1369 AddToWorklist(N: Op.getNode());
1370
1371 CommitTargetLoweringOpt(TLO);
1372 return true;
1373}
1374
1375void DAGCombiner::ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad) {
1376 SDLoc DL(Load);
1377 EVT VT = Load->getValueType(ResNo: 0);
1378 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: SDValue(ExtLoad, 0));
1379
1380 LLVM_DEBUG(dbgs() << "\nReplacing.9 "; Load->dump(&DAG); dbgs() << "\nWith: ";
1381 Trunc.dump(&DAG); dbgs() << '\n');
1382
1383 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 0), To: Trunc);
1384 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 1), To: SDValue(ExtLoad, 1));
1385
1386 AddToWorklist(N: Trunc.getNode());
1387 recursivelyDeleteUnusedNodes(N: Load);
1388}
1389
1390SDValue DAGCombiner::PromoteOperand(SDValue Op, EVT PVT, bool &Replace) {
1391 Replace = false;
1392 SDLoc DL(Op);
1393 if (ISD::isUNINDEXEDLoad(N: Op.getNode())) {
1394 LoadSDNode *LD = cast<LoadSDNode>(Val&: Op);
1395 EVT MemVT = LD->getMemoryVT();
1396 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(N: LD) ? ISD::EXTLOAD
1397 : LD->getExtensionType();
1398 Replace = true;
1399 return DAG.getExtLoad(ExtType, dl: DL, VT: PVT,
1400 Chain: LD->getChain(), Ptr: LD->getBasePtr(),
1401 MemVT, MMO: LD->getMemOperand());
1402 }
1403
1404 unsigned Opc = Op.getOpcode();
1405 switch (Opc) {
1406 default: break;
1407 case ISD::AssertSext:
1408 if (SDValue Op0 = SExtPromoteOperand(Op: Op.getOperand(i: 0), PVT))
1409 return DAG.getNode(Opcode: ISD::AssertSext, DL, VT: PVT, N1: Op0, N2: Op.getOperand(i: 1));
1410 break;
1411 case ISD::AssertZext:
1412 if (SDValue Op0 = ZExtPromoteOperand(Op: Op.getOperand(i: 0), PVT))
1413 return DAG.getNode(Opcode: ISD::AssertZext, DL, VT: PVT, N1: Op0, N2: Op.getOperand(i: 1));
1414 break;
1415 case ISD::Constant: {
1416 unsigned ExtOpc =
1417 Op.getValueType().isByteSized() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1418 return DAG.getNode(Opcode: ExtOpc, DL, VT: PVT, Operand: Op);
1419 }
1420 }
1421
1422 if (!TLI.isOperationLegal(Op: ISD::ANY_EXTEND, VT: PVT))
1423 return SDValue();
1424 return DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: PVT, Operand: Op);
1425}
1426
1427SDValue DAGCombiner::SExtPromoteOperand(SDValue Op, EVT PVT) {
1428 if (!TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_INREG, VT: PVT))
1429 return SDValue();
1430 EVT OldVT = Op.getValueType();
1431 SDLoc DL(Op);
1432 bool Replace = false;
1433 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1434 if (!NewOp.getNode())
1435 return SDValue();
1436 AddToWorklist(N: NewOp.getNode());
1437
1438 if (Replace)
1439 ReplaceLoadWithPromotedLoad(Load: Op.getNode(), ExtLoad: NewOp.getNode());
1440 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT: NewOp.getValueType(), N1: NewOp,
1441 N2: DAG.getValueType(OldVT));
1442}
1443
1444SDValue DAGCombiner::ZExtPromoteOperand(SDValue Op, EVT PVT) {
1445 EVT OldVT = Op.getValueType();
1446 SDLoc DL(Op);
1447 bool Replace = false;
1448 SDValue NewOp = PromoteOperand(Op, PVT, Replace);
1449 if (!NewOp.getNode())
1450 return SDValue();
1451 AddToWorklist(N: NewOp.getNode());
1452
1453 if (Replace)
1454 ReplaceLoadWithPromotedLoad(Load: Op.getNode(), ExtLoad: NewOp.getNode());
1455 return DAG.getZeroExtendInReg(Op: NewOp, DL, VT: OldVT);
1456}
1457
1458/// Promote the specified integer binary operation if the target indicates it is
1459/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1460/// i32 since i16 instructions are longer.
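/// For example (illustrative, assuming the target prefers i32): an i16 add is
/// rebuilt as (trunc (add (any_ext x), (any_ext y))), with the add done in i32.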
1461SDValue DAGCombiner::PromoteIntBinOp(SDValue Op) {
1462 if (!LegalOperations)
1463 return SDValue();
1464
1465 EVT VT = Op.getValueType();
1466 if (VT.isVector() || !VT.isInteger())
1467 return SDValue();
1468
1469 // If operation type is 'undesirable', e.g. i16 on x86, consider
1470 // promoting it.
1471 unsigned Opc = Op.getOpcode();
1472 if (TLI.isTypeDesirableForOp(Opc, VT))
1473 return SDValue();
1474
1475 EVT PVT = VT;
1476 // Consult target whether it is a good idea to promote this operation and
1477 // what's the right type to promote it to.
1478 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1479 assert(PVT != VT && "Don't know what type to promote to!");
1480
1481 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1482
1483 bool Replace0 = false;
1484 SDValue N0 = Op.getOperand(i: 0);
1485 SDValue NN0 = PromoteOperand(Op: N0, PVT, Replace&: Replace0);
1486
1487 bool Replace1 = false;
1488 SDValue N1 = Op.getOperand(i: 1);
1489 SDValue NN1 = PromoteOperand(Op: N1, PVT, Replace&: Replace1);
1490 SDLoc DL(Op);
1491
1492 SDValue RV =
1493 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: DAG.getNode(Opcode: Opc, DL, VT: PVT, N1: NN0, N2: NN1));
1494
1495 // We are always replacing N0/N1's use in N and only need additional
1496 // replacements if there are additional uses.
1497 // Note: We are checking uses of the *nodes* (SDNode) rather than values
1498 // (SDValue) here because the node may reference multiple values
1499 // (for example, the chain value of a load node).
1500 Replace0 &= !N0->hasOneUse();
1501 Replace1 &= (N0 != N1) && !N1->hasOneUse();
1502
1503 // Combine Op here so it is preserved past replacements.
1504 CombineTo(N: Op.getNode(), Res: RV);
1505
1506 // If operands have a use ordering, make sure we deal with
1507 // predecessor first.
1508 if (Replace0 && Replace1 && N0->isPredecessorOf(N: N1.getNode())) {
1509 std::swap(a&: N0, b&: N1);
1510 std::swap(a&: NN0, b&: NN1);
1511 }
1512
1513 if (Replace0) {
1514 AddToWorklist(N: NN0.getNode());
1515 ReplaceLoadWithPromotedLoad(Load: N0.getNode(), ExtLoad: NN0.getNode());
1516 }
1517 if (Replace1) {
1518 AddToWorklist(N: NN1.getNode());
1519 ReplaceLoadWithPromotedLoad(Load: N1.getNode(), ExtLoad: NN1.getNode());
1520 }
1521 return Op;
1522 }
1523 return SDValue();
1524}
1525
1526/// Promote the specified integer shift operation if the target indicates it is
1527/// beneficial. e.g. On x86, it's usually better to promote i16 operations to
1528/// i32 since i16 instructions are longer.
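/// For example (illustrative, assuming the target prefers i32): an i16 srl is
/// rebuilt as (trunc (srl (zext x), c)) computed in i32. SRA promotes its value
/// operand with sign-extension instead, so the bits shifted in from the top
/// remain correct.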
1529SDValue DAGCombiner::PromoteIntShiftOp(SDValue Op) {
1530 if (!LegalOperations)
1531 return SDValue();
1532
1533 EVT VT = Op.getValueType();
1534 if (VT.isVector() || !VT.isInteger())
1535 return SDValue();
1536
1537 // If operation type is 'undesirable', e.g. i16 on x86, consider
1538 // promoting it.
1539 unsigned Opc = Op.getOpcode();
1540 if (TLI.isTypeDesirableForOp(Opc, VT))
1541 return SDValue();
1542
1543 EVT PVT = VT;
1544 // Consult target whether it is a good idea to promote this operation and
1545 // what's the right type to promote it to.
1546 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1547 assert(PVT != VT && "Don't know what type to promote to!");
1548
1549 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1550
1551 bool Replace = false;
1552 SDValue N0 = Op.getOperand(i: 0);
1553 if (Opc == ISD::SRA)
1554 N0 = SExtPromoteOperand(Op: N0, PVT);
1555 else if (Opc == ISD::SRL)
1556 N0 = ZExtPromoteOperand(Op: N0, PVT);
1557 else
1558 N0 = PromoteOperand(Op: N0, PVT, Replace);
1559
1560 if (!N0.getNode())
1561 return SDValue();
1562
1563 SDLoc DL(Op);
1564 SDValue N1 = Op.getOperand(i: 1);
1565 SDValue RV =
1566 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: DAG.getNode(Opcode: Opc, DL, VT: PVT, N1: N0, N2: N1));
1567
1568 if (Replace)
1569 ReplaceLoadWithPromotedLoad(Load: Op.getOperand(i: 0).getNode(), ExtLoad: N0.getNode());
1570
1571 // Deal with Op being deleted.
1572 if (Op && Op.getOpcode() != ISD::DELETED_NODE)
1573 return RV;
1574 }
1575 return SDValue();
1576}
1577
1578SDValue DAGCombiner::PromoteExtend(SDValue Op) {
1579 if (!LegalOperations)
1580 return SDValue();
1581
1582 EVT VT = Op.getValueType();
1583 if (VT.isVector() || !VT.isInteger())
1584 return SDValue();
1585
1586 // If operation type is 'undesirable', e.g. i16 on x86, consider
1587 // promoting it.
1588 unsigned Opc = Op.getOpcode();
1589 if (TLI.isTypeDesirableForOp(Opc, VT))
1590 return SDValue();
1591
1592 EVT PVT = VT;
1593 // Consult target whether it is a good idea to promote this operation and
1594 // what's the right type to promote it to.
1595 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1596 assert(PVT != VT && "Don't know what type to promote to!");
1597 // fold (aext (aext x)) -> (aext x)
1598 // fold (aext (zext x)) -> (zext x)
1599 // fold (aext (sext x)) -> (sext x)
1600 LLVM_DEBUG(dbgs() << "\nPromoting "; Op.dump(&DAG));
1601 return DAG.getNode(Opcode: Op.getOpcode(), DL: SDLoc(Op), VT, Operand: Op.getOperand(i: 0));
1602 }
1603 return SDValue();
1604}
1605
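/// Promote an unindexed integer load to a more desirable type by rebuilding it
/// as an extending load of the original memory type and truncating the result
/// for existing users. Returns true if the replacement was performed.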
1606bool DAGCombiner::PromoteLoad(SDValue Op) {
1607 if (!LegalOperations)
1608 return false;
1609
1610 if (!ISD::isUNINDEXEDLoad(N: Op.getNode()))
1611 return false;
1612
1613 EVT VT = Op.getValueType();
1614 if (VT.isVector() || !VT.isInteger())
1615 return false;
1616
1617 // If operation type is 'undesirable', e.g. i16 on x86, consider
1618 // promoting it.
1619 unsigned Opc = Op.getOpcode();
1620 if (TLI.isTypeDesirableForOp(Opc, VT))
1621 return false;
1622
1623 EVT PVT = VT;
1624 // Consult target whether it is a good idea to promote this operation and
1625 // what's the right type to promote it to.
1626 if (TLI.IsDesirableToPromoteOp(Op, PVT)) {
1627 assert(PVT != VT && "Don't know what type to promote to!");
1628
1629 SDLoc DL(Op);
1630 SDNode *N = Op.getNode();
1631 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
1632 EVT MemVT = LD->getMemoryVT();
1633 ISD::LoadExtType ExtType = ISD::isNON_EXTLoad(N: LD) ? ISD::EXTLOAD
1634 : LD->getExtensionType();
1635 SDValue NewLD = DAG.getExtLoad(ExtType, dl: DL, VT: PVT,
1636 Chain: LD->getChain(), Ptr: LD->getBasePtr(),
1637 MemVT, MMO: LD->getMemOperand());
1638 SDValue Result = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: NewLD);
1639
1640 LLVM_DEBUG(dbgs() << "\nPromoting "; N->dump(&DAG); dbgs() << "\nTo: ";
1641 Result.dump(&DAG); dbgs() << '\n');
1642
1643 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result);
1644 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: NewLD.getValue(R: 1));
1645
1646 AddToWorklist(N: Result.getNode());
1647 recursivelyDeleteUnusedNodes(N);
1648 return true;
1649 }
1650
1651 return false;
1652}
1653
1654/// Recursively delete a node which has no uses and any operands for
1655/// which it is the only use.
1656///
1657/// Note that this both deletes the nodes and removes them from the worklist.
1658/// It also adds any nodes that have had a user deleted to the worklist, as
1659/// they may now have only one use and be subject to other combines.
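/// For example (illustrative): deleting a dead (add X, C) also deletes C if the
/// add was its only user, while X, if it still has other users, is simply
/// re-added to the worklist so those users can be revisited.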
1660bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
1661 if (!N->use_empty())
1662 return false;
1663
1664 SmallSetVector<SDNode *, 16> Nodes;
1665 Nodes.insert(X: N);
1666 do {
1667 N = Nodes.pop_back_val();
1668 if (!N)
1669 continue;
1670
1671 if (N->use_empty()) {
1672 for (const SDValue &ChildN : N->op_values())
1673 Nodes.insert(X: ChildN.getNode());
1674
1675 removeFromWorklist(N);
1676 DAG.DeleteNode(N);
1677 } else {
1678 AddToWorklist(N);
1679 }
1680 } while (!Nodes.empty());
1681 return true;
1682}
1683
1684//===----------------------------------------------------------------------===//
1685// Main DAG Combiner implementation
1686//===----------------------------------------------------------------------===//
1687
1688void DAGCombiner::Run(CombineLevel AtLevel) {
1689  // Set the instance variables, so that the various visit routines may use them.
1690 Level = AtLevel;
1691 LegalDAG = Level >= AfterLegalizeDAG;
1692 LegalOperations = Level >= AfterLegalizeVectorOps;
1693 LegalTypes = Level >= AfterLegalizeTypes;
1694
1695 WorklistInserter AddNodes(*this);
1696
1697 // Add all the dag nodes to the worklist.
1698 //
1699  // Note: Not all nodes are added to the PruningList here; this is because the
1700  // only nodes which can be deleted are those which have no uses, and all other
1701  // nodes which would otherwise be added to the worklist by the first call to
1702  // getNextWorklistEntry are already present in it.
1703 for (SDNode &Node : DAG.allnodes())
1704 AddToWorklist(N: &Node, /* IsCandidateForPruning */ Node.use_empty());
1705
1706 // Create a dummy node (which is not added to allnodes), that adds a reference
1707 // to the root node, preventing it from being deleted, and tracking any
1708 // changes of the root.
1709 HandleSDNode Dummy(DAG.getRoot());
1710
1711 // While we have a valid worklist entry node, try to combine it.
1712 while (SDNode *N = getNextWorklistEntry()) {
1713 // If N has no uses, it is dead. Make sure to revisit all N's operands once
1714 // N is deleted from the DAG, since they too may now be dead or may have a
1715 // reduced number of uses, allowing other xforms.
1716 if (recursivelyDeleteUnusedNodes(N))
1717 continue;
1718
1719 WorklistRemover DeadNodes(*this);
1720
1721 // If this combine is running after legalizing the DAG, re-legalize any
1722 // nodes pulled off the worklist.
1723 if (LegalDAG) {
1724 SmallSetVector<SDNode *, 16> UpdatedNodes;
1725 bool NIsValid = DAG.LegalizeOp(N, UpdatedNodes);
1726
1727 for (SDNode *LN : UpdatedNodes)
1728 AddToWorklistWithUsers(N: LN);
1729
1730 if (!NIsValid)
1731 continue;
1732 }
1733
1734 LLVM_DEBUG(dbgs() << "\nCombining: "; N->dump(&DAG));
1735
1736 // Add any operands of the new node which have not yet been combined to the
1737 // worklist as well. Because the worklist uniques things already, this
1738 // won't repeatedly process the same operand.
1739 for (const SDValue &ChildN : N->op_values())
1740 if (!CombinedNodes.count(Ptr: ChildN.getNode()))
1741 AddToWorklist(N: ChildN.getNode());
1742
1743 CombinedNodes.insert(Ptr: N);
1744 SDValue RV = combine(N);
1745
1746 if (!RV.getNode())
1747 continue;
1748
1749 ++NodesCombined;
1750
1751 // If we get back the same node we passed in, rather than a new node or
1752 // zero, we know that the node must have defined multiple values and
1753 // CombineTo was used. Since CombineTo takes care of the worklist
1754 // mechanics for us, we have no work to do in this case.
1755 if (RV.getNode() == N)
1756 continue;
1757
1758 assert(N->getOpcode() != ISD::DELETED_NODE &&
1759 RV.getOpcode() != ISD::DELETED_NODE &&
1760 "Node was deleted but visit returned new node!");
1761
1762 LLVM_DEBUG(dbgs() << " ... into: "; RV.dump(&DAG));
1763
1764 if (N->getNumValues() == RV->getNumValues())
1765 DAG.ReplaceAllUsesWith(From: N, To: RV.getNode());
1766 else {
1767 assert(N->getValueType(0) == RV.getValueType() &&
1768 N->getNumValues() == 1 && "Type mismatch");
1769 DAG.ReplaceAllUsesWith(From: N, To: &RV);
1770 }
1771
1772 // Push the new node and any users onto the worklist. Omit this if the
1773 // new node is the EntryToken (e.g. if a store managed to get optimized
1774 // out), because re-visiting the EntryToken and its users will not uncover
1775 // any additional opportunities, but there may be a large number of such
1776 // users, potentially causing compile time explosion.
1777 if (RV.getOpcode() != ISD::EntryToken)
1778 AddToWorklistWithUsers(N: RV.getNode());
1779
1780 // Finally, if the node is now dead, remove it from the graph. The node
1781 // may not be dead if the replacement process recursively simplified to
1782 // something else needing this node. This will also take care of adding any
1783 // operands which have lost a user to the worklist.
1784 recursivelyDeleteUnusedNodes(N);
1785 }
1786
1787  // If the root changed (e.g. it was a dead load), update the root.
1788 DAG.setRoot(Dummy.getValue());
1789 DAG.RemoveDeadNodes();
1790}
1791
1792SDValue DAGCombiner::visit(SDNode *N) {
1793 // clang-format off
1794 switch (N->getOpcode()) {
1795 default: break;
1796 case ISD::TokenFactor: return visitTokenFactor(N);
1797 case ISD::MERGE_VALUES: return visitMERGE_VALUES(N);
1798 case ISD::ADD: return visitADD(N);
1799 case ISD::SUB: return visitSUB(N);
1800 case ISD::SADDSAT:
1801 case ISD::UADDSAT: return visitADDSAT(N);
1802 case ISD::SSUBSAT:
1803 case ISD::USUBSAT: return visitSUBSAT(N);
1804 case ISD::ADDC: return visitADDC(N);
1805 case ISD::SADDO:
1806 case ISD::UADDO: return visitADDO(N);
1807 case ISD::SUBC: return visitSUBC(N);
1808 case ISD::SSUBO:
1809 case ISD::USUBO: return visitSUBO(N);
1810 case ISD::ADDE: return visitADDE(N);
1811 case ISD::UADDO_CARRY: return visitUADDO_CARRY(N);
1812 case ISD::SADDO_CARRY: return visitSADDO_CARRY(N);
1813 case ISD::SUBE: return visitSUBE(N);
1814 case ISD::USUBO_CARRY: return visitUSUBO_CARRY(N);
1815 case ISD::SSUBO_CARRY: return visitSSUBO_CARRY(N);
1816 case ISD::SMULFIX:
1817 case ISD::SMULFIXSAT:
1818 case ISD::UMULFIX:
1819 case ISD::UMULFIXSAT: return visitMULFIX(N);
1820 case ISD::MUL: return visitMUL(N);
1821 case ISD::SDIV: return visitSDIV(N);
1822 case ISD::UDIV: return visitUDIV(N);
1823 case ISD::SREM:
1824 case ISD::UREM: return visitREM(N);
1825 case ISD::MULHU: return visitMULHU(N);
1826 case ISD::MULHS: return visitMULHS(N);
1827 case ISD::AVGFLOORS:
1828 case ISD::AVGFLOORU:
1829 case ISD::AVGCEILS:
1830 case ISD::AVGCEILU: return visitAVG(N);
1831 case ISD::ABDS:
1832 case ISD::ABDU: return visitABD(N);
1833 case ISD::SMUL_LOHI: return visitSMUL_LOHI(N);
1834 case ISD::UMUL_LOHI: return visitUMUL_LOHI(N);
1835 case ISD::SMULO:
1836 case ISD::UMULO: return visitMULO(N);
1837 case ISD::SMIN:
1838 case ISD::SMAX:
1839 case ISD::UMIN:
1840 case ISD::UMAX: return visitIMINMAX(N);
1841 case ISD::AND: return visitAND(N);
1842 case ISD::OR: return visitOR(N);
1843 case ISD::XOR: return visitXOR(N);
1844 case ISD::SHL: return visitSHL(N);
1845 case ISD::SRA: return visitSRA(N);
1846 case ISD::SRL: return visitSRL(N);
1847 case ISD::ROTR:
1848 case ISD::ROTL: return visitRotate(N);
1849 case ISD::FSHL:
1850 case ISD::FSHR: return visitFunnelShift(N);
1851 case ISD::SSHLSAT:
1852 case ISD::USHLSAT: return visitSHLSAT(N);
1853 case ISD::ABS: return visitABS(N);
1854 case ISD::BSWAP: return visitBSWAP(N);
1855 case ISD::BITREVERSE: return visitBITREVERSE(N);
1856 case ISD::CTLZ: return visitCTLZ(N);
1857 case ISD::CTLZ_ZERO_UNDEF: return visitCTLZ_ZERO_UNDEF(N);
1858 case ISD::CTTZ: return visitCTTZ(N);
1859 case ISD::CTTZ_ZERO_UNDEF: return visitCTTZ_ZERO_UNDEF(N);
1860 case ISD::CTPOP: return visitCTPOP(N);
1861 case ISD::SELECT: return visitSELECT(N);
1862 case ISD::VSELECT: return visitVSELECT(N);
1863 case ISD::SELECT_CC: return visitSELECT_CC(N);
1864 case ISD::SETCC: return visitSETCC(N);
1865 case ISD::SETCCCARRY: return visitSETCCCARRY(N);
1866 case ISD::SIGN_EXTEND: return visitSIGN_EXTEND(N);
1867 case ISD::ZERO_EXTEND: return visitZERO_EXTEND(N);
1868 case ISD::ANY_EXTEND: return visitANY_EXTEND(N);
1869 case ISD::AssertSext:
1870 case ISD::AssertZext: return visitAssertExt(N);
1871 case ISD::AssertAlign: return visitAssertAlign(N);
1872 case ISD::SIGN_EXTEND_INREG: return visitSIGN_EXTEND_INREG(N);
1873 case ISD::SIGN_EXTEND_VECTOR_INREG:
1874 case ISD::ZERO_EXTEND_VECTOR_INREG:
1875 case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
1876 case ISD::TRUNCATE: return visitTRUNCATE(N);
1877 case ISD::BITCAST: return visitBITCAST(N);
1878 case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
1879 case ISD::FADD: return visitFADD(N);
1880 case ISD::STRICT_FADD: return visitSTRICT_FADD(N);
1881 case ISD::FSUB: return visitFSUB(N);
1882 case ISD::FMUL: return visitFMUL(N);
1883 case ISD::FMA: return visitFMA<EmptyMatchContext>(N);
1884 case ISD::FMAD: return visitFMAD(N);
1885 case ISD::FDIV: return visitFDIV(N);
1886 case ISD::FREM: return visitFREM(N);
1887 case ISD::FSQRT: return visitFSQRT(N);
1888 case ISD::FCOPYSIGN: return visitFCOPYSIGN(N);
1889 case ISD::FPOW: return visitFPOW(N);
1890 case ISD::SINT_TO_FP: return visitSINT_TO_FP(N);
1891 case ISD::UINT_TO_FP: return visitUINT_TO_FP(N);
1892 case ISD::FP_TO_SINT: return visitFP_TO_SINT(N);
1893 case ISD::FP_TO_UINT: return visitFP_TO_UINT(N);
1894 case ISD::LRINT:
1895 case ISD::LLRINT: return visitXRINT(N);
1896 case ISD::FP_ROUND: return visitFP_ROUND(N);
1897 case ISD::FP_EXTEND: return visitFP_EXTEND(N);
1898 case ISD::FNEG: return visitFNEG(N);
1899 case ISD::FABS: return visitFABS(N);
1900 case ISD::FFLOOR: return visitFFLOOR(N);
1901 case ISD::FMINNUM:
1902 case ISD::FMAXNUM:
1903 case ISD::FMINIMUM:
1904 case ISD::FMAXIMUM: return visitFMinMax(N);
1905 case ISD::FCEIL: return visitFCEIL(N);
1906 case ISD::FTRUNC: return visitFTRUNC(N);
1907 case ISD::FFREXP: return visitFFREXP(N);
1908 case ISD::BRCOND: return visitBRCOND(N);
1909 case ISD::BR_CC: return visitBR_CC(N);
1910 case ISD::LOAD: return visitLOAD(N);
1911 case ISD::STORE: return visitSTORE(N);
1912 case ISD::INSERT_VECTOR_ELT: return visitINSERT_VECTOR_ELT(N);
1913 case ISD::EXTRACT_VECTOR_ELT: return visitEXTRACT_VECTOR_ELT(N);
1914 case ISD::BUILD_VECTOR: return visitBUILD_VECTOR(N);
1915 case ISD::CONCAT_VECTORS: return visitCONCAT_VECTORS(N);
1916 case ISD::EXTRACT_SUBVECTOR: return visitEXTRACT_SUBVECTOR(N);
1917 case ISD::VECTOR_SHUFFLE: return visitVECTOR_SHUFFLE(N);
1918 case ISD::SCALAR_TO_VECTOR: return visitSCALAR_TO_VECTOR(N);
1919 case ISD::INSERT_SUBVECTOR: return visitINSERT_SUBVECTOR(N);
1920 case ISD::MGATHER: return visitMGATHER(N);
1921 case ISD::MLOAD: return visitMLOAD(N);
1922 case ISD::MSCATTER: return visitMSCATTER(N);
1923 case ISD::MSTORE: return visitMSTORE(N);
1924 case ISD::LIFETIME_END: return visitLIFETIME_END(N);
1925 case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
1926 case ISD::FP16_TO_FP: return visitFP16_TO_FP(N);
1927 case ISD::FP_TO_BF16: return visitFP_TO_BF16(N);
1928 case ISD::BF16_TO_FP: return visitBF16_TO_FP(N);
1929 case ISD::FREEZE: return visitFREEZE(N);
1930 case ISD::GET_FPENV_MEM: return visitGET_FPENV_MEM(N);
1931 case ISD::SET_FPENV_MEM: return visitSET_FPENV_MEM(N);
1932 case ISD::VECREDUCE_FADD:
1933 case ISD::VECREDUCE_FMUL:
1934 case ISD::VECREDUCE_ADD:
1935 case ISD::VECREDUCE_MUL:
1936 case ISD::VECREDUCE_AND:
1937 case ISD::VECREDUCE_OR:
1938 case ISD::VECREDUCE_XOR:
1939 case ISD::VECREDUCE_SMAX:
1940 case ISD::VECREDUCE_SMIN:
1941 case ISD::VECREDUCE_UMAX:
1942 case ISD::VECREDUCE_UMIN:
1943 case ISD::VECREDUCE_FMAX:
1944 case ISD::VECREDUCE_FMIN:
1945 case ISD::VECREDUCE_FMAXIMUM:
1946 case ISD::VECREDUCE_FMINIMUM: return visitVECREDUCE(N);
1947#define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
1948#include "llvm/IR/VPIntrinsics.def"
1949 return visitVPOp(N);
1950 }
1951 // clang-format on
1952 return SDValue();
1953}
1954
1955SDValue DAGCombiner::combine(SDNode *N) {
1956 if (!DebugCounter::shouldExecute(CounterName: DAGCombineCounter))
1957 return SDValue();
1958
1959 SDValue RV;
1960 if (!DisableGenericCombines)
1961 RV = visit(N);
1962
1963 // If nothing happened, try a target-specific DAG combine.
1964 if (!RV.getNode()) {
1965 assert(N->getOpcode() != ISD::DELETED_NODE &&
1966 "Node was deleted but visit returned NULL!");
1967
1968 if (N->getOpcode() >= ISD::BUILTIN_OP_END ||
1969 TLI.hasTargetDAGCombine(NT: (ISD::NodeType)N->getOpcode())) {
1970
1971 // Expose the DAG combiner to the target combiner impls.
1972 TargetLowering::DAGCombinerInfo
1973 DagCombineInfo(DAG, Level, false, this);
1974
1975 RV = TLI.PerformDAGCombine(N, DCI&: DagCombineInfo);
1976 }
1977 }
1978
1979  // If still nothing happened, try promoting the operation.
1980 if (!RV.getNode()) {
1981 switch (N->getOpcode()) {
1982 default: break;
1983 case ISD::ADD:
1984 case ISD::SUB:
1985 case ISD::MUL:
1986 case ISD::AND:
1987 case ISD::OR:
1988 case ISD::XOR:
1989 RV = PromoteIntBinOp(Op: SDValue(N, 0));
1990 break;
1991 case ISD::SHL:
1992 case ISD::SRA:
1993 case ISD::SRL:
1994 RV = PromoteIntShiftOp(Op: SDValue(N, 0));
1995 break;
1996 case ISD::SIGN_EXTEND:
1997 case ISD::ZERO_EXTEND:
1998 case ISD::ANY_EXTEND:
1999 RV = PromoteExtend(Op: SDValue(N, 0));
2000 break;
2001 case ISD::LOAD:
2002 if (PromoteLoad(Op: SDValue(N, 0)))
2003 RV = SDValue(N, 0);
2004 break;
2005 }
2006 }
2007
2008 // If N is a commutative binary node, try to eliminate it if the commuted
2009 // version is already present in the DAG.
2010 if (!RV.getNode() && TLI.isCommutativeBinOp(Opcode: N->getOpcode())) {
2011 SDValue N0 = N->getOperand(Num: 0);
2012 SDValue N1 = N->getOperand(Num: 1);
2013
2014 // Constant operands are canonicalized to RHS.
2015 if (N0 != N1 && (isa<ConstantSDNode>(Val: N0) || !isa<ConstantSDNode>(Val: N1))) {
2016 SDValue Ops[] = {N1, N0};
2017 SDNode *CSENode = DAG.getNodeIfExists(Opcode: N->getOpcode(), VTList: N->getVTList(), Ops,
2018 Flags: N->getFlags());
2019 if (CSENode)
2020 return SDValue(CSENode, 0);
2021 }
2022 }
2023
2024 return RV;
2025}
2026
2027/// Given a node, return its input chain if it has one, otherwise return a null
2028/// sd operand.
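/// The chain, when present, is most often operand 0 or the last operand, so
/// those positions are checked first before scanning the remaining operands.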
2029static SDValue getInputChainForNode(SDNode *N) {
2030 if (unsigned NumOps = N->getNumOperands()) {
2031 if (N->getOperand(Num: 0).getValueType() == MVT::Other)
2032 return N->getOperand(Num: 0);
2033 if (N->getOperand(Num: NumOps-1).getValueType() == MVT::Other)
2034 return N->getOperand(Num: NumOps-1);
2035 for (unsigned i = 1; i < NumOps-1; ++i)
2036 if (N->getOperand(Num: i).getValueType() == MVT::Other)
2037 return N->getOperand(Num: i);
2038 }
2039 return SDValue();
2040}
2041
2042SDValue DAGCombiner::visitTokenFactor(SDNode *N) {
2043 // If N has two operands, where one has an input chain equal to the other,
2044 // the 'other' chain is redundant.
2045 if (N->getNumOperands() == 2) {
2046 if (getInputChainForNode(N: N->getOperand(Num: 0).getNode()) == N->getOperand(Num: 1))
2047 return N->getOperand(Num: 0);
2048 if (getInputChainForNode(N: N->getOperand(Num: 1).getNode()) == N->getOperand(Num: 0))
2049 return N->getOperand(Num: 1);
2050 }
2051
2052 // Don't simplify token factors if optnone.
2053 if (OptLevel == CodeGenOptLevel::None)
2054 return SDValue();
2055
2056 // Don't simplify the token factor if the node itself has too many operands.
2057 if (N->getNumOperands() > TokenFactorInlineLimit)
2058 return SDValue();
2059
2060 // If the sole user is a token factor, we should make sure we have a
2061 // chance to merge them together. This prevents TF chains from inhibiting
2062 // optimizations.
2063 if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::TokenFactor)
2064 AddToWorklist(N: *(N->use_begin()));
2065
2066 SmallVector<SDNode *, 8> TFs; // List of token factors to visit.
2067 SmallVector<SDValue, 8> Ops; // Ops for replacing token factor.
2068 SmallPtrSet<SDNode*, 16> SeenOps;
2069 bool Changed = false; // If we should replace this token factor.
2070
2071 // Start out with this token factor.
2072 TFs.push_back(Elt: N);
2073
2074  // Iterate through token factors. The TFs list grows when new token factors
2075  // are encountered.
2076 for (unsigned i = 0; i < TFs.size(); ++i) {
2077 // Limit number of nodes to inline, to avoid quadratic compile times.
2078 // We have to add the outstanding Token Factors to Ops, otherwise we might
2079 // drop Ops from the resulting Token Factors.
2080 if (Ops.size() > TokenFactorInlineLimit) {
2081 for (unsigned j = i; j < TFs.size(); j++)
2082 Ops.emplace_back(Args&: TFs[j], Args: 0);
2083 // Drop unprocessed Token Factors from TFs, so we do not add them to the
2084 // combiner worklist later.
2085 TFs.resize(N: i);
2086 break;
2087 }
2088
2089 SDNode *TF = TFs[i];
2090 // Check each of the operands.
2091 for (const SDValue &Op : TF->op_values()) {
2092 switch (Op.getOpcode()) {
2093 case ISD::EntryToken:
2094 // Entry tokens don't need to be added to the list. They are
2095 // redundant.
2096 Changed = true;
2097 break;
2098
2099 case ISD::TokenFactor:
2100 if (Op.hasOneUse() && !is_contained(Range&: TFs, Element: Op.getNode())) {
2101 // Queue up for processing.
2102 TFs.push_back(Elt: Op.getNode());
2103 Changed = true;
2104 break;
2105 }
2106 [[fallthrough]];
2107
2108 default:
2109 // Only add if it isn't already in the list.
2110 if (SeenOps.insert(Ptr: Op.getNode()).second)
2111 Ops.push_back(Elt: Op);
2112 else
2113 Changed = true;
2114 break;
2115 }
2116 }
2117 }
2118
2119 // Re-visit inlined Token Factors, to clean them up in case they have been
2120 // removed. Skip the first Token Factor, as this is the current node.
2121 for (unsigned i = 1, e = TFs.size(); i < e; i++)
2122 AddToWorklist(N: TFs[i]);
2123
2124  // Remove nodes that are chained to another node in the list. Do so
2125  // by walking up chains breadth-first, stopping when we've seen
2126 // another operand. In general we must climb to the EntryNode, but we can exit
2127 // early if we find all remaining work is associated with just one operand as
2128 // no further pruning is possible.
2129
2130 // List of nodes to search through and original Ops from which they originate.
2131 SmallVector<std::pair<SDNode *, unsigned>, 8> Worklist;
2132 SmallVector<unsigned, 8> OpWorkCount; // Count of work for each Op.
2133 SmallPtrSet<SDNode *, 16> SeenChains;
2134 bool DidPruneOps = false;
2135
2136 unsigned NumLeftToConsider = 0;
2137 for (const SDValue &Op : Ops) {
2138 Worklist.push_back(Elt: std::make_pair(x: Op.getNode(), y: NumLeftToConsider++));
2139 OpWorkCount.push_back(Elt: 1);
2140 }
2141
2142 auto AddToWorklist = [&](unsigned CurIdx, SDNode *Op, unsigned OpNumber) {
2143    // If this is an Op, we can remove the op from the list. Re-mark any
2144    // search associated with it as coming from the current OpNumber.
2145 if (SeenOps.contains(Ptr: Op)) {
2146 Changed = true;
2147 DidPruneOps = true;
2148 unsigned OrigOpNumber = 0;
2149 while (OrigOpNumber < Ops.size() && Ops[OrigOpNumber].getNode() != Op)
2150 OrigOpNumber++;
2151 assert((OrigOpNumber != Ops.size()) &&
2152 "expected to find TokenFactor Operand");
2153 // Re-mark worklist from OrigOpNumber to OpNumber
2154 for (unsigned i = CurIdx + 1; i < Worklist.size(); ++i) {
2155 if (Worklist[i].second == OrigOpNumber) {
2156 Worklist[i].second = OpNumber;
2157 }
2158 }
2159 OpWorkCount[OpNumber] += OpWorkCount[OrigOpNumber];
2160 OpWorkCount[OrigOpNumber] = 0;
2161 NumLeftToConsider--;
2162 }
2163 // Add if it's a new chain
2164 if (SeenChains.insert(Ptr: Op).second) {
2165 OpWorkCount[OpNumber]++;
2166 Worklist.push_back(Elt: std::make_pair(x&: Op, y&: OpNumber));
2167 }
2168 };
2169
2170 for (unsigned i = 0; i < Worklist.size() && i < 1024; ++i) {
2171    // We need to consider at least 2 Ops to prune.
2172 if (NumLeftToConsider <= 1)
2173 break;
2174 auto CurNode = Worklist[i].first;
2175 auto CurOpNumber = Worklist[i].second;
2176 assert((OpWorkCount[CurOpNumber] > 0) &&
2177 "Node should not appear in worklist");
2178 switch (CurNode->getOpcode()) {
2179 case ISD::EntryToken:
2180      // Hitting EntryToken is the only way for the search to terminate without
2181      // hitting another operand's search. Prevent us from marking this operand
2182      // as considered.
2184 NumLeftToConsider++;
2185 break;
2186 case ISD::TokenFactor:
2187 for (const SDValue &Op : CurNode->op_values())
2188 AddToWorklist(i, Op.getNode(), CurOpNumber);
2189 break;
2190 case ISD::LIFETIME_START:
2191 case ISD::LIFETIME_END:
2192 case ISD::CopyFromReg:
2193 case ISD::CopyToReg:
2194 AddToWorklist(i, CurNode->getOperand(Num: 0).getNode(), CurOpNumber);
2195 break;
2196 default:
2197 if (auto *MemNode = dyn_cast<MemSDNode>(Val: CurNode))
2198 AddToWorklist(i, MemNode->getChain().getNode(), CurOpNumber);
2199 break;
2200 }
2201 OpWorkCount[CurOpNumber]--;
2202 if (OpWorkCount[CurOpNumber] == 0)
2203 NumLeftToConsider--;
2204 }
2205
2206 // If we've changed things around then replace token factor.
2207 if (Changed) {
2208 SDValue Result;
2209 if (Ops.empty()) {
2210 // The entry token is the only possible outcome.
2211 Result = DAG.getEntryNode();
2212 } else {
2213 if (DidPruneOps) {
2214 SmallVector<SDValue, 8> PrunedOps;
2216 for (const SDValue &Op : Ops) {
2217 if (SeenChains.count(Ptr: Op.getNode()) == 0)
2218 PrunedOps.push_back(Elt: Op);
2219 }
2220 Result = DAG.getTokenFactor(DL: SDLoc(N), Vals&: PrunedOps);
2221 } else {
2222 Result = DAG.getTokenFactor(DL: SDLoc(N), Vals&: Ops);
2223 }
2224 }
2225 return Result;
2226 }
2227 return SDValue();
2228}
2229
2230/// MERGE_VALUES can always be eliminated.
2231SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
2232 WorklistRemover DeadNodes(*this);
2233 // Replacing results may cause a different MERGE_VALUES to suddenly
2234 // be CSE'd with N, and carry its uses with it. Iterate until no
2235 // uses remain, to ensure that the node can be safely deleted.
2236 // First add the users of this node to the work list so that they
2237 // can be tried again once they have new operands.
2238 AddUsersToWorklist(N);
2239 do {
2240 // Do as a single replacement to avoid rewalking use lists.
2241 SmallVector<SDValue, 8> Ops;
2242 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2243 Ops.push_back(Elt: N->getOperand(Num: i));
2244 DAG.ReplaceAllUsesWith(From: N, To: Ops.data());
2245 } while (!N->use_empty());
2246 deleteAndRecombine(N);
2247 return SDValue(N, 0); // Return N so it doesn't get rechecked!
2248}
2249
2250/// If \p N is a ConstantSDNode with isOpaque() == false, return it cast to a
2251/// ConstantSDNode pointer; otherwise return nullptr.
2252static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
2253 ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Val&: N);
2254 return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
2255}
2256
2257// isTruncateOf - If N is a truncate of some other value, return true and record
2258// the value being truncated in Op and which of Op's bits are zero/one in Known.
2259// This function computes KnownBits to avoid a duplicated call to
2260// computeKnownBits in the caller.
2261static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
2262 KnownBits &Known) {
2263 if (N->getOpcode() == ISD::TRUNCATE) {
2264 Op = N->getOperand(Num: 0);
2265 Known = DAG.computeKnownBits(Op);
2266 return true;
2267 }
2268
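  // An i1 (setcc ne X, 0) also behaves like a truncate to i1 when every bit of
  // X other than bit 0 is known to be zero; handle that case below.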
2269 if (N.getOpcode() != ISD::SETCC ||
2270 N.getValueType().getScalarType() != MVT::i1 ||
2271 cast<CondCodeSDNode>(Val: N.getOperand(i: 2))->get() != ISD::SETNE)
2272 return false;
2273
2274 SDValue Op0 = N->getOperand(Num: 0);
2275 SDValue Op1 = N->getOperand(Num: 1);
2276 assert(Op0.getValueType() == Op1.getValueType());
2277
2278 if (isNullOrNullSplat(V: Op0))
2279 Op = Op1;
2280 else if (isNullOrNullSplat(V: Op1))
2281 Op = Op0;
2282 else
2283 return false;
2284
2285 Known = DAG.computeKnownBits(Op);
2286
2287 return (Known.Zero | 1).isAllOnes();
2288}
2289
2290/// Return true if 'Use' is a load or a store that uses N as its base pointer
2291/// and that N may be folded in the load / store addressing mode.
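/// For example (illustrative): on targets with [reg + imm] addressing, an
/// (add BasePtr, 16) whose only use is a load's pointer can usually be folded
/// into the load's addressing mode rather than computed separately.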
2292static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
2293 const TargetLowering &TLI) {
2294 EVT VT;
2295 unsigned AS;
2296
2297 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Val: Use)) {
2298 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2299 return false;
2300 VT = LD->getMemoryVT();
2301 AS = LD->getAddressSpace();
2302 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Val: Use)) {
2303 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2304 return false;
2305 VT = ST->getMemoryVT();
2306 AS = ST->getAddressSpace();
2307 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Val: Use)) {
2308 if (LD->isIndexed() || LD->getBasePtr().getNode() != N)
2309 return false;
2310 VT = LD->getMemoryVT();
2311 AS = LD->getAddressSpace();
2312 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Val: Use)) {
2313 if (ST->isIndexed() || ST->getBasePtr().getNode() != N)
2314 return false;
2315 VT = ST->getMemoryVT();
2316 AS = ST->getAddressSpace();
2317 } else {
2318 return false;
2319 }
2320
2321 TargetLowering::AddrMode AM;
2322 if (N->getOpcode() == ISD::ADD) {
2323 AM.HasBaseReg = true;
2324 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
2325 if (Offset)
2326 // [reg +/- imm]
2327 AM.BaseOffs = Offset->getSExtValue();
2328 else
2329 // [reg +/- reg]
2330 AM.Scale = 1;
2331 } else if (N->getOpcode() == ISD::SUB) {
2332 AM.HasBaseReg = true;
2333 ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
2334 if (Offset)
2335 // [reg +/- imm]
2336 AM.BaseOffs = -Offset->getSExtValue();
2337 else
2338 // [reg +/- reg]
2339 AM.Scale = 1;
2340 } else {
2341 return false;
2342 }
2343
2344 return TLI.isLegalAddressingMode(DL: DAG.getDataLayout(), AM,
2345 Ty: VT.getTypeForEVT(Context&: *DAG.getContext()), AddrSpace: AS);
2346}
2347
2348/// This inverts a canonicalization in IR that replaces a variable select arm
2349/// with an identity constant. Codegen improves if we re-use the variable
2350/// operand rather than load a constant. This can also be converted into a
2351/// masked vector operation if the target supports it.
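/// For example (illustrative):
///   add X, (vselect Cond, 0, V) --> vselect Cond, X, (add X, V)
/// where 0 is the identity constant for add; X is frozen first because its
/// number of uses increases.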
2352static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
2353 bool ShouldCommuteOperands) {
2354 // Match a select as operand 1. The identity constant that we are looking for
2355 // is only valid as operand 1 of a non-commutative binop.
2356 SDValue N0 = N->getOperand(Num: 0);
2357 SDValue N1 = N->getOperand(Num: 1);
2358 if (ShouldCommuteOperands)
2359 std::swap(a&: N0, b&: N1);
2360
2361 // TODO: Should this apply to scalar select too?
2362 if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse())
2363 return SDValue();
2364
2365 // We can't hoist all instructions because of immediate UB (not speculatable).
2366 // For example div/rem by zero.
2367 if (!DAG.isSafeToSpeculativelyExecuteNode(N))
2368 return SDValue();
2369
2370 unsigned Opcode = N->getOpcode();
2371 EVT VT = N->getValueType(ResNo: 0);
2372 SDValue Cond = N1.getOperand(i: 0);
2373 SDValue TVal = N1.getOperand(i: 1);
2374 SDValue FVal = N1.getOperand(i: 2);
2375
2376 // This transform increases uses of N0, so freeze it to be safe.
2377 // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
2378 unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2379 if (isNeutralConstant(Opc: Opcode, Flags: N->getFlags(), V: TVal, OperandNo: OpNo)) {
2380 SDValue F0 = DAG.getFreeze(V: N0);
2381 SDValue NewBO = DAG.getNode(Opcode, DL: SDLoc(N), VT, N1: F0, N2: FVal, Flags: N->getFlags());
2382 return DAG.getSelect(DL: SDLoc(N), VT, Cond, LHS: F0, RHS: NewBO);
2383 }
2384 // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2385 if (isNeutralConstant(Opc: Opcode, Flags: N->getFlags(), V: FVal, OperandNo: OpNo)) {
2386 SDValue F0 = DAG.getFreeze(V: N0);
2387 SDValue NewBO = DAG.getNode(Opcode, DL: SDLoc(N), VT, N1: F0, N2: TVal, Flags: N->getFlags());
2388 return DAG.getSelect(DL: SDLoc(N), VT, Cond, LHS: NewBO, RHS: F0);
2389 }
2390
2391 return SDValue();
2392}
2393
2394SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2395 assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
2396 "Unexpected binary operator");
2397
2398 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2399 auto BinOpcode = BO->getOpcode();
2400 EVT VT = BO->getValueType(ResNo: 0);
2401 if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
2402 if (SDValue Sel = foldSelectWithIdentityConstant(N: BO, DAG, ShouldCommuteOperands: false))
2403 return Sel;
2404
2405 if (TLI.isCommutativeBinOp(Opcode: BO->getOpcode()))
2406 if (SDValue Sel = foldSelectWithIdentityConstant(N: BO, DAG, ShouldCommuteOperands: true))
2407 return Sel;
2408 }
2409
2410 // Don't do this unless the old select is going away. We want to eliminate the
2411 // binary operator, not replace a binop with a select.
2412 // TODO: Handle ISD::SELECT_CC.
2413 unsigned SelOpNo = 0;
2414 SDValue Sel = BO->getOperand(Num: 0);
2415 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
2416 SelOpNo = 1;
2417 Sel = BO->getOperand(Num: 1);
2418
2419 // Peek through trunc to shift amount type.
2420 if ((BinOpcode == ISD::SHL || BinOpcode == ISD::SRA ||
2421 BinOpcode == ISD::SRL) && Sel.hasOneUse()) {
2422 // This is valid when the truncated bits of x are already zero.
2423 SDValue Op;
2424 KnownBits Known;
2425 if (isTruncateOf(DAG, N: Sel, Op, Known) &&
2426 Known.countMaxActiveBits() < Sel.getScalarValueSizeInBits())
2427 Sel = Op;
2428 }
2429 }
2430
2431 if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
2432 return SDValue();
2433
2434 SDValue CT = Sel.getOperand(i: 1);
2435 if (!isConstantOrConstantVector(N: CT, NoOpaques: true) &&
2436 !DAG.isConstantFPBuildVectorOrConstantFP(N: CT))
2437 return SDValue();
2438
2439 SDValue CF = Sel.getOperand(i: 2);
2440 if (!isConstantOrConstantVector(N: CF, NoOpaques: true) &&
2441 !DAG.isConstantFPBuildVectorOrConstantFP(N: CF))
2442 return SDValue();
2443
2444 // Bail out if any constants are opaque because we can't constant fold those.
2445 // The exception is "and" and "or" with either 0 or -1 in which case we can
2446 // propagate non constant operands into select. I.e.:
2447 // and (select Cond, 0, -1), X --> select Cond, 0, X
2448 // or X, (select Cond, -1, 0) --> select Cond, -1, X
2449 bool CanFoldNonConst =
2450 (BinOpcode == ISD::AND || BinOpcode == ISD::OR) &&
2451 ((isNullOrNullSplat(V: CT) && isAllOnesOrAllOnesSplat(V: CF)) ||
2452 (isNullOrNullSplat(V: CF) && isAllOnesOrAllOnesSplat(V: CT)));
2453
2454 SDValue CBO = BO->getOperand(Num: SelOpNo ^ 1);
2455 if (!CanFoldNonConst &&
2456 !isConstantOrConstantVector(N: CBO, NoOpaques: true) &&
2457 !DAG.isConstantFPBuildVectorOrConstantFP(N: CBO))
2458 return SDValue();
2459
2460 SDLoc DL(Sel);
2461 SDValue NewCT, NewCF;
2462
2463 if (CanFoldNonConst) {
2464 // If CBO is an opaque constant, we can't rely on getNode to constant fold.
2465 if ((BinOpcode == ISD::AND && isNullOrNullSplat(V: CT)) ||
2466 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(V: CT)))
2467 NewCT = CT;
2468 else
2469 NewCT = CBO;
2470
2471 if ((BinOpcode == ISD::AND && isNullOrNullSplat(V: CF)) ||
2472 (BinOpcode == ISD::OR && isAllOnesOrAllOnesSplat(V: CF)))
2473 NewCF = CF;
2474 else
2475 NewCF = CBO;
2476 } else {
2477 // We have a select-of-constants followed by a binary operator with a
2478 // constant. Eliminate the binop by pulling the constant math into the
2479 // select. Example: add (select Cond, CT, CF), CBO --> select Cond, CT +
2480 // CBO, CF + CBO
2481 NewCT = SelOpNo ? DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CBO, CT})
2482 : DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CT, CBO});
2483 if (!NewCT)
2484 return SDValue();
2485
2486 NewCF = SelOpNo ? DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CBO, CF})
2487 : DAG.FoldConstantArithmetic(Opcode: BinOpcode, DL, VT, Ops: {CF, CBO});
2488 if (!NewCF)
2489 return SDValue();
2490 }
2491
2492 SDValue SelectOp = DAG.getSelect(DL, VT, Cond: Sel.getOperand(i: 0), LHS: NewCT, RHS: NewCF);
2493 SelectOp->setFlags(BO->getFlags());
2494 return SelectOp;
2495}
2496
2497static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
2498 SelectionDAG &DAG) {
2499 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2500 "Expecting add or sub");
2501
2502 // Match a constant operand and a zext operand for the math instruction:
2503 // add Z, C
2504 // sub C, Z
2505 bool IsAdd = N->getOpcode() == ISD::ADD;
2506 SDValue C = IsAdd ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
2507 SDValue Z = IsAdd ? N->getOperand(Num: 0) : N->getOperand(Num: 1);
2508 auto *CN = dyn_cast<ConstantSDNode>(Val&: C);
2509 if (!CN || Z.getOpcode() != ISD::ZERO_EXTEND)
2510 return SDValue();
2511
2512 // Match the zext operand as a setcc of a boolean.
2513 if (Z.getOperand(i: 0).getOpcode() != ISD::SETCC ||
2514 Z.getOperand(i: 0).getValueType() != MVT::i1)
2515 return SDValue();
2516
2517 // Match the compare as: setcc (X & 1), 0, eq.
2518 SDValue SetCC = Z.getOperand(i: 0);
2519 ISD::CondCode CC = cast<CondCodeSDNode>(Val: SetCC->getOperand(Num: 2))->get();
2520 if (CC != ISD::SETEQ || !isNullConstant(V: SetCC.getOperand(i: 1)) ||
2521 SetCC.getOperand(i: 0).getOpcode() != ISD::AND ||
2522 !isOneConstant(V: SetCC.getOperand(i: 0).getOperand(i: 1)))
2523 return SDValue();
2524
2525 // We are adding/subtracting a constant and an inverted low bit. Turn that
2526 // into a subtract/add of the low bit with incremented/decremented constant:
2527 // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
2528 // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
2529 EVT VT = C.getValueType();
2530 SDValue LowBit = DAG.getZExtOrTrunc(Op: SetCC.getOperand(i: 0), DL, VT);
2531 SDValue C1 = IsAdd ? DAG.getConstant(Val: CN->getAPIntValue() + 1, DL, VT) :
2532 DAG.getConstant(Val: CN->getAPIntValue() - 1, DL, VT);
2533 return DAG.getNode(Opcode: IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N1: C1, N2: LowBit);
2534}
2535
2536// Attempt to form avgceil(A, B) from (A | B) - ((A ^ B) >> 1)
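// This is based on the identity avgceilu(A, B) == (A | B) - ((A ^ B) >> 1),
// which computes the rounded-up average without risking overflow; the signed
// form matches an arithmetic shift instead of a logical one.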
2537SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
2538 SDValue N0 = N->getOperand(Num: 0);
2539 EVT VT = N0.getValueType();
2540 SDValue A, B;
2541
2542 if (hasOperation(Opcode: ISD::AVGCEILU, VT) &&
2543 sd_match(N, P: m_Sub(L: m_Or(L: m_Value(N&: A), R: m_Value(N&: B)),
2544 R: m_Srl(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)),
2545 R: m_SpecificInt(V: 1))))) {
2546 return DAG.getNode(Opcode: ISD::AVGCEILU, DL, VT, N1: A, N2: B);
2547 }
2548 if (hasOperation(Opcode: ISD::AVGCEILS, VT) &&
2549 sd_match(N, P: m_Sub(L: m_Or(L: m_Value(N&: A), R: m_Value(N&: B)),
2550 R: m_Sra(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)),
2551 R: m_SpecificInt(V: 1))))) {
2552 return DAG.getNode(Opcode: ISD::AVGCEILS, DL, VT, N1: A, N2: B);
2553 }
2554 return SDValue();
2555}
2556
2557/// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
2558/// a shift and add with a different constant.
2559static SDValue foldAddSubOfSignBit(SDNode *N, const SDLoc &DL,
2560 SelectionDAG &DAG) {
2561 assert((N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::SUB) &&
2562 "Expecting add or sub");
2563
2564 // We need a constant operand for the add/sub, and the other operand is a
2565 // logical shift right: add (srl), C or sub C, (srl).
2566 bool IsAdd = N->getOpcode() == ISD::ADD;
2567 SDValue ConstantOp = IsAdd ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
2568 SDValue ShiftOp = IsAdd ? N->getOperand(Num: 0) : N->getOperand(Num: 1);
2569 if (!DAG.isConstantIntBuildVectorOrConstantInt(N: ConstantOp) ||
2570 ShiftOp.getOpcode() != ISD::SRL)
2571 return SDValue();
2572
2573 // The shift must be of a 'not' value.
2574 SDValue Not = ShiftOp.getOperand(i: 0);
2575 if (!Not.hasOneUse() || !isBitwiseNot(V: Not))
2576 return SDValue();
2577
2578 // The shift must be moving the sign bit to the least-significant-bit.
2579 EVT VT = ShiftOp.getValueType();
2580 SDValue ShAmt = ShiftOp.getOperand(i: 1);
2581 ConstantSDNode *ShAmtC = isConstOrConstSplat(N: ShAmt);
2582 if (!ShAmtC || ShAmtC->getAPIntValue() != (VT.getScalarSizeInBits() - 1))
2583 return SDValue();
2584
2585 // Eliminate the 'not' by adjusting the shift and add/sub constant:
2586 // add (srl (not X), 31), C --> add (sra X, 31), (C + 1)
2587 // sub C, (srl (not X), 31) --> add (srl X, 31), (C - 1)
2588 if (SDValue NewC = DAG.FoldConstantArithmetic(
2589 Opcode: IsAdd ? ISD::ADD : ISD::SUB, DL, VT,
2590 Ops: {ConstantOp, DAG.getConstant(Val: 1, DL, VT)})) {
2591 SDValue NewShift = DAG.getNode(Opcode: IsAdd ? ISD::SRA : ISD::SRL, DL, VT,
2592 N1: Not.getOperand(i: 0), N2: ShAmt);
2593 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: NewShift, N2: NewC);
2594 }
2595
2596 return SDValue();
2597}
2598
2599static bool
2600areBitwiseNotOfEachother(SDValue Op0, SDValue Op1) {
2601 return (isBitwiseNot(V: Op0) && Op0.getOperand(i: 0) == Op1) ||
2602 (isBitwiseNot(V: Op1) && Op1.getOperand(i: 0) == Op0);
2603}
2604
2605/// Try to fold a node that behaves like an ADD (note that N isn't necessarily
2606/// an ISD::ADD here, it could for example be an ISD::OR if we know that there
2607/// are no common bits set in the operands).
2608SDValue DAGCombiner::visitADDLike(SDNode *N) {
2609 SDValue N0 = N->getOperand(Num: 0);
2610 SDValue N1 = N->getOperand(Num: 1);
2611 EVT VT = N0.getValueType();
2612 SDLoc DL(N);
2613
2614 // fold (add x, undef) -> undef
2615 if (N0.isUndef())
2616 return N0;
2617 if (N1.isUndef())
2618 return N1;
2619
2620 // fold (add c1, c2) -> c1+c2
2621 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N0, N1}))
2622 return C;
2623
2624 // canonicalize constant to RHS
2625 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
2626 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
2627 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0);
2628
2629 if (areBitwiseNotOfEachother(Op0: N0, Op1: N1))
2630 return DAG.getConstant(Val: APInt::getAllOnes(numBits: VT.getScalarSizeInBits()),
2631 DL: SDLoc(N), VT);
2632
2633 // fold vector ops
2634 if (VT.isVector()) {
2635 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2636 return FoldedVOp;
2637
2638 // fold (add x, 0) -> x, vector edition
2639 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
2640 return N0;
2641 }
2642
2643 // fold (add x, 0) -> x
2644 if (isNullConstant(V: N1))
2645 return N0;
2646
2647 if (N0.getOpcode() == ISD::SUB) {
2648 SDValue N00 = N0.getOperand(i: 0);
2649 SDValue N01 = N0.getOperand(i: 1);
2650
2651 // fold ((A-c1)+c2) -> (A+(c2-c1))
2652 if (SDValue Sub = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N1, N01}))
2653 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: Sub);
2654
2655 // fold ((c1-A)+c2) -> (c1+c2)-A
2656 if (SDValue Add = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N1, N00}))
2657 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Add, N2: N0.getOperand(i: 1));
2658 }
2659
2660 // add (sext i1 X), 1 -> zext (not i1 X)
2661 // We don't transform this pattern:
2662 // add (zext i1 X), -1 -> sext (not i1 X)
2663 // because most (?) targets generate better code for the zext form.
2664 if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
2665 isOneOrOneSplat(V: N1)) {
2666 SDValue X = N0.getOperand(i: 0);
2667 if ((!LegalOperations ||
2668 (TLI.isOperationLegal(Op: ISD::XOR, VT: X.getValueType()) &&
2669 TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT))) &&
2670 X.getScalarValueSizeInBits() == 1) {
2671 SDValue Not = DAG.getNOT(DL, Val: X, VT: X.getValueType());
2672 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Not);
2673 }
2674 }
2675
2676 // Fold (add (or x, c0), c1) -> (add x, (c0 + c1))
2677 // iff (or x, c0) is equivalent to (add x, c0).
2678 // Fold (add (xor x, c0), c1) -> (add x, (c0 + c1))
2679 // iff (xor x, c0) is equivalent to (add x, c0).
2680 if (DAG.isADDLike(Op: N0)) {
2681 SDValue N01 = N0.getOperand(i: 1);
2682 if (SDValue Add = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N1, N01}))
2683 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: Add);
2684 }
2685
2686 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
2687 return NewSel;
2688
2689 // reassociate add
2690 if (!reassociationCanBreakAddressingModePattern(Opc: ISD::ADD, DL, N, N0, N1)) {
2691 if (SDValue RADD = reassociateOps(Opc: ISD::ADD, DL, N0, N1, Flags: N->getFlags()))
2692 return RADD;
2693
2694    // Reassociate (add (or x, c), y) -> (add (add x, y), c) if (or x, c) is
2695 // equivalent to (add x, c).
2696    // Reassociate (add (xor x, c), y) -> (add (add x, y), c) if (xor x, c) is
2697 // equivalent to (add x, c).
2698 // Do this optimization only when adding c does not introduce instructions
2699 // for adding carries.
2700 auto ReassociateAddOr = [&](SDValue N0, SDValue N1) {
2701 if (DAG.isADDLike(Op: N0) && N0.hasOneUse() &&
2702 isConstantOrConstantVector(N: N0.getOperand(i: 1), /* NoOpaque */ NoOpaques: true)) {
2703        // If N0's type does not split, or its constant operand is a sign mask,
2704        // the add does not introduce a carry.
2705 auto TyActn = TLI.getTypeAction(Context&: *DAG.getContext(), VT: N0.getValueType());
2706 bool NoAddCarry = TyActn == TargetLoweringBase::TypeLegal ||
2707 TyActn == TargetLoweringBase::TypePromoteInteger ||
2708 isMinSignedConstant(V: N0.getOperand(i: 1));
2709 if (NoAddCarry)
2710 return DAG.getNode(
2711 Opcode: ISD::ADD, DL, VT,
2712 N1: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0.getOperand(i: 0)),
2713 N2: N0.getOperand(i: 1));
2714 }
2715 return SDValue();
2716 };
2717 if (SDValue Add = ReassociateAddOr(N0, N1))
2718 return Add;
2719 if (SDValue Add = ReassociateAddOr(N1, N0))
2720 return Add;
2721
2722 // Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y))
2723 if (SDValue SD =
2724 reassociateReduction(RedOpc: ISD::VECREDUCE_ADD, Opc: ISD::ADD, DL, VT, N0, N1))
2725 return SD;
2726 }
2727
2728 SDValue A, B, C;
2729
2730 // fold ((0-A) + B) -> B-A
2731 if (sd_match(N: N0, P: m_Neg(V: m_Value(N&: A))))
2732 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: A);
2733
2734 // fold (A + (0-B)) -> A-B
2735 if (sd_match(N: N1, P: m_Neg(V: m_Value(N&: B))))
2736 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: B);
2737
2738 // fold (A+(B-A)) -> B
2739 if (sd_match(N: N1, P: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N0))))
2740 return B;
2741
2742 // fold ((B-A)+A) -> B
2743 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N1))))
2744 return B;
2745
2746 // fold ((A-B)+(C-A)) -> (C-B)
2747 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B))) &&
2748 sd_match(N: N1, P: m_Sub(L: m_Value(N&: C), R: m_Specific(N: A))))
2749 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: C, N2: B);
2750
2751 // fold ((A-B)+(B-C)) -> (A-C)
2752 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B))) &&
2753 sd_match(N: N1, P: m_Sub(L: m_Specific(N: B), R: m_Value(N&: C))))
2754 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: C);
2755
2756 // fold (A+(B-(A+C))) to (B-C)
2757 // fold (A+(B-(C+A))) to (B-C)
2758 if (sd_match(N: N1, P: m_Sub(L: m_Value(N&: B), R: m_Add(L: m_Specific(N: N0), R: m_Value(N&: C)))))
2759 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: B, N2: C);
2760
2761 // fold (A+((B-A)+or-C)) to (B+or-C)
2762 if (sd_match(N: N1,
2763 P: m_AnyOf(preds: m_Add(L: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N0)), R: m_Value(N&: C)),
2764 preds: m_Sub(L: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N0)), R: m_Value(N&: C)))))
2765 return DAG.getNode(Opcode: N1.getOpcode(), DL, VT, N1: B, N2: C);
2766
2767 // fold (A-B)+(C-D) to (A+C)-(B+D) when A or C is constant
2768 if (N0.getOpcode() == ISD::SUB && N1.getOpcode() == ISD::SUB &&
2769 N0->hasOneUse() && N1->hasOneUse()) {
2770 SDValue N00 = N0.getOperand(i: 0);
2771 SDValue N01 = N0.getOperand(i: 1);
2772 SDValue N10 = N1.getOperand(i: 0);
2773 SDValue N11 = N1.getOperand(i: 1);
2774
2775 if (isConstantOrConstantVector(N: N00) || isConstantOrConstantVector(N: N10))
2776 return DAG.getNode(Opcode: ISD::SUB, DL, VT,
2777 N1: DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N0), VT, N1: N00, N2: N10),
2778 N2: DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N1), VT, N1: N01, N2: N11));
2779 }
2780
2781 // fold (add (umax X, C), -C) --> (usubsat X, C)
2782 if (N0.getOpcode() == ISD::UMAX && hasOperation(Opcode: ISD::USUBSAT, VT)) {
2783 auto MatchUSUBSAT = [](ConstantSDNode *Max, ConstantSDNode *Op) {
2784 return (!Max && !Op) ||
2785 (Max && Op && Max->getAPIntValue() == (-Op->getAPIntValue()));
2786 };
2787 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchUSUBSAT,
2788 /*AllowUndefs*/ true))
2789 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: N0.getOperand(i: 0),
2790 N2: N0.getOperand(i: 1));
2791 }
2792
2793 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
2794 return SDValue(N, 0);
2795
2796 if (isOneOrOneSplat(V: N1)) {
2797 // fold (add (xor a, -1), 1) -> (sub 0, a)
2798 if (isBitwiseNot(V: N0))
2799 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: DAG.getConstant(Val: 0, DL, VT),
2800 N2: N0.getOperand(i: 0));
2801
2802 // fold (add (add (xor a, -1), b), 1) -> (sub b, a)
2803 if (N0.getOpcode() == ISD::ADD) {
2804 SDValue A, Xor;
2805
2806 if (isBitwiseNot(V: N0.getOperand(i: 0))) {
2807 A = N0.getOperand(i: 1);
2808 Xor = N0.getOperand(i: 0);
2809 } else if (isBitwiseNot(V: N0.getOperand(i: 1))) {
2810 A = N0.getOperand(i: 0);
2811 Xor = N0.getOperand(i: 1);
2812 }
2813
2814 if (Xor)
2815 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: Xor.getOperand(i: 0));
2816 }
2817
2818 // Look for:
2819 // add (add x, y), 1
2820 // And if the target does not like this form then turn into:
2821 // sub y, (xor x, -1)
2822 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
2823 N0.hasOneUse() &&
2824 // Limit this to after legalization if the add has wrap flags
2825 (Level >= AfterLegalizeDAG || (!N->getFlags().hasNoUnsignedWrap() &&
2826 !N->getFlags().hasNoSignedWrap()))) {
2827 SDValue Not = DAG.getNOT(DL, Val: N0.getOperand(i: 0), VT);
2828 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 1), N2: Not);
2829 }
2830 }
2831
2832 // (x - y) + -1 -> add (xor y, -1), x
2833 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
2834 isAllOnesOrAllOnesSplat(V: N1, /*AllowUndefs=*/true)) {
2835 SDValue Not = DAG.getNOT(DL, Val: N0.getOperand(i: 1), VT);
2836 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Not, N2: N0.getOperand(i: 0));
2837 }
2838
2839 if (SDValue Combined = visitADDLikeCommutative(N0, N1, LocReference: N))
2840 return Combined;
2841
2842 if (SDValue Combined = visitADDLikeCommutative(N0: N1, N1: N0, LocReference: N))
2843 return Combined;
2844
2845 return SDValue();
2846}
2847
2848// Attempt to form avgfloor(A, B) from (A & B) + ((A ^ B) >> 1)
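// This is based on the identity avgflooru(A, B) == (A & B) + ((A ^ B) >> 1),
// which computes the rounded-down average without risking overflow; the signed
// form matches an arithmetic shift instead of a logical one.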
2849SDValue DAGCombiner::foldAddToAvg(SDNode *N, const SDLoc &DL) {
2850 SDValue N0 = N->getOperand(Num: 0);
2851 EVT VT = N0.getValueType();
2852 SDValue A, B;
2853
2854 if (hasOperation(Opcode: ISD::AVGFLOORU, VT) &&
2855 sd_match(N, P: m_Add(L: m_And(L: m_Value(N&: A), R: m_Value(N&: B)),
2856 R: m_Srl(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)),
2857 R: m_SpecificInt(V: 1))))) {
2858 return DAG.getNode(Opcode: ISD::AVGFLOORU, DL, VT, N1: A, N2: B);
2859 }
2860 if (hasOperation(Opcode: ISD::AVGFLOORS, VT) &&
2861 sd_match(N, P: m_Add(L: m_And(L: m_Value(N&: A), R: m_Value(N&: B)),
2862 R: m_Sra(L: m_Xor(L: m_Deferred(V&: A), R: m_Deferred(V&: B)),
2863 R: m_SpecificInt(V: 1))))) {
2864 return DAG.getNode(Opcode: ISD::AVGFLOORS, DL, VT, N1: A, N2: B);
2865 }
2866
2867 return SDValue();
2868}
2869
2870SDValue DAGCombiner::visitADD(SDNode *N) {
2871 SDValue N0 = N->getOperand(Num: 0);
2872 SDValue N1 = N->getOperand(Num: 1);
2873 EVT VT = N0.getValueType();
2874 SDLoc DL(N);
2875
2876 if (SDValue Combined = visitADDLike(N))
2877 return Combined;
2878
2879 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
2880 return V;
2881
2882 if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
2883 return V;
2884
2885 // Try to match AVGFLOOR fixedwidth pattern
2886 if (SDValue V = foldAddToAvg(N, DL))
2887 return V;
2888
2889 // fold (a+b) -> (a|b) iff a and b share no bits.
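  // e.g. (illustrative, i32) (add (shl X, 8), (and Y, 255)): the shl has its
  // low 8 bits known zero and the and has its high bits known zero, so no
  // carry can be generated and the add can be emitted as a disjoint or.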
2890 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::OR, VT)) &&
2891 DAG.haveNoCommonBitsSet(A: N0, B: N1)) {
2892 SDNodeFlags Flags;
2893 Flags.setDisjoint(true);
2894 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: N0, N2: N1, Flags);
2895 }
2896
2897 // Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
2898 if (N0.getOpcode() == ISD::VSCALE && N1.getOpcode() == ISD::VSCALE) {
2899 const APInt &C0 = N0->getConstantOperandAPInt(Num: 0);
2900 const APInt &C1 = N1->getConstantOperandAPInt(Num: 0);
2901 return DAG.getVScale(DL, VT, MulImm: C0 + C1);
2902 }
2903
2904 // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2)
2905 if (N0.getOpcode() == ISD::ADD &&
2906 N0.getOperand(i: 1).getOpcode() == ISD::VSCALE &&
2907 N1.getOpcode() == ISD::VSCALE) {
2908 const APInt &VS0 = N0.getOperand(i: 1)->getConstantOperandAPInt(Num: 0);
2909 const APInt &VS1 = N1->getConstantOperandAPInt(Num: 0);
2910 SDValue VS = DAG.getVScale(DL, VT, MulImm: VS0 + VS1);
2911 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: VS);
2912 }
2913
2914  // Fold (add step_vector(c1), step_vector(c2)) to step_vector(c1+c2)
2915 if (N0.getOpcode() == ISD::STEP_VECTOR &&
2916 N1.getOpcode() == ISD::STEP_VECTOR) {
2917 const APInt &C0 = N0->getConstantOperandAPInt(Num: 0);
2918 const APInt &C1 = N1->getConstantOperandAPInt(Num: 0);
2919 APInt NewStep = C0 + C1;
2920 return DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
2921 }
2922
2923 // Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
2924 if (N0.getOpcode() == ISD::ADD &&
2925 N0.getOperand(i: 1).getOpcode() == ISD::STEP_VECTOR &&
2926 N1.getOpcode() == ISD::STEP_VECTOR) {
2927 const APInt &SV0 = N0.getOperand(i: 1)->getConstantOperandAPInt(Num: 0);
2928 const APInt &SV1 = N1->getConstantOperandAPInt(Num: 0);
2929 APInt NewStep = SV0 + SV1;
2930 SDValue SV = DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
2931 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: SV);
2932 }
2933
2934 return SDValue();
2935}
2936
2937SDValue DAGCombiner::visitADDSAT(SDNode *N) {
2938 unsigned Opcode = N->getOpcode();
2939 SDValue N0 = N->getOperand(Num: 0);
2940 SDValue N1 = N->getOperand(Num: 1);
2941 EVT VT = N0.getValueType();
2942 bool IsSigned = Opcode == ISD::SADDSAT;
2943 SDLoc DL(N);
2944
2945 // fold (add_sat x, undef) -> -1
2946 if (N0.isUndef() || N1.isUndef())
2947 return DAG.getAllOnesConstant(DL, VT);
2948
2949 // fold (add_sat c1, c2) -> c3
2950 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
2951 return C;
2952
2953 // canonicalize constant to RHS
2954 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
2955 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
2956 return DAG.getNode(Opcode, DL, VT, N1, N2: N0);
2957
2958 // fold vector ops
2959 if (VT.isVector()) {
2960 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
2961 return FoldedVOp;
2962
2963 // fold (add_sat x, 0) -> x, vector edition
2964 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
2965 return N0;
2966 }
2967
2968 // fold (add_sat x, 0) -> x
2969 if (isNullConstant(V: N1))
2970 return N0;
2971
2972 // If it cannot overflow, transform into an add.
2973 if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
2974 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1);
2975
2976 return SDValue();
2977}
2978
2979static SDValue getAsCarry(const TargetLowering &TLI, SDValue V,
2980 bool ForceCarryReconstruction = false) {
2981 bool Masked = false;
2982
2983 // First, peel away TRUNCATE/ZERO_EXTEND/AND nodes due to legalization.
2984 while (true) {
2985 if (V.getOpcode() == ISD::TRUNCATE || V.getOpcode() == ISD::ZERO_EXTEND) {
2986 V = V.getOperand(i: 0);
2987 continue;
2988 }
2989
2990 if (V.getOpcode() == ISD::AND && isOneConstant(V: V.getOperand(i: 1))) {
2991 if (ForceCarryReconstruction)
2992 return V;
2993
2994 Masked = true;
2995 V = V.getOperand(i: 0);
2996 continue;
2997 }
2998
2999 if (ForceCarryReconstruction && V.getValueType() == MVT::i1)
3000 return V;
3001
3002 break;
3003 }
3004
3005 // If this is not a carry, return.
3006 if (V.getResNo() != 1)
3007 return SDValue();
3008
3009 if (V.getOpcode() != ISD::UADDO_CARRY && V.getOpcode() != ISD::USUBO_CARRY &&
3010 V.getOpcode() != ISD::UADDO && V.getOpcode() != ISD::USUBO)
3011 return SDValue();
3012
3013 EVT VT = V->getValueType(ResNo: 0);
3014 if (!TLI.isOperationLegalOrCustom(Op: V.getOpcode(), VT))
3015 return SDValue();
3016
3017  // If the result is masked, then no matter what kind of bool it is we can
3018  // return it. If it isn't, then we need to make sure the bool type is either
3019  // 0 or 1 and never any other value.
3020 if (Masked ||
3021 TLI.getBooleanContents(Type: V.getValueType()) ==
3022 TargetLoweringBase::ZeroOrOneBooleanContent)
3023 return V;
3024
3025 return SDValue();
3026}
3027
3028/// Given the operands of an add/sub operation, see if the 2nd operand is a
3029/// masked 0/1 whose source operand is actually known to be 0/-1. If so, invert
3030/// the opcode and bypass the mask operation.
3031static SDValue foldAddSubMasked1(bool IsAdd, SDValue N0, SDValue N1,
3032 SelectionDAG &DAG, const SDLoc &DL) {
3033 if (N1.getOpcode() == ISD::ZERO_EXTEND)
3034 N1 = N1.getOperand(i: 0);
3035
3036 if (N1.getOpcode() != ISD::AND || !isOneOrOneSplat(V: N1->getOperand(Num: 1)))
3037 return SDValue();
3038
3039 EVT VT = N0.getValueType();
3040 SDValue N10 = N1.getOperand(i: 0);
3041 if (N10.getValueType() != VT && N10.getOpcode() == ISD::TRUNCATE)
3042 N10 = N10.getOperand(i: 0);
3043
3044 if (N10.getValueType() != VT)
3045 return SDValue();
3046
3047 if (DAG.ComputeNumSignBits(Op: N10) != VT.getScalarSizeInBits())
3048 return SDValue();
3049
3050 // add N0, (and (AssertSext X, i1), 1) --> sub N0, X
3051 // sub N0, (and (AssertSext X, i1), 1) --> add N0, X
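  // This works because X is known to be all sign bits (0 or -1), so
  // (and X, 1) == -X; adding the masked value therefore subtracts X, and
  // subtracting it adds X.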
3052 return DAG.getNode(Opcode: IsAdd ? ISD::SUB : ISD::ADD, DL, VT, N1: N0, N2: N10);
3053}
3054
3055/// Helper for doing combines based on N0 and N1 being added to each other.
3056SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
3057 SDNode *LocReference) {
3058 EVT VT = N0.getValueType();
3059 SDLoc DL(LocReference);
3060
3061 // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
3062 SDValue Y, N;
3063 if (sd_match(N: N1, P: m_Shl(L: m_Neg(V: m_Value(N&: Y)), R: m_Value(N))))
3064 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0,
3065 N2: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Y, N2: N));
3066
3067 if (SDValue V = foldAddSubMasked1(IsAdd: true, N0, N1, DAG, DL))
3068 return V;
3069
3070 // Look for:
3071 // add (add x, 1), y
3072 // And if the target does not like this form then turn into:
3073 // sub y, (xor x, -1)
3074 if (!TLI.preferIncOfAddToSubOfNot(VT) && N0.getOpcode() == ISD::ADD &&
3075 N0.hasOneUse() && isOneOrOneSplat(V: N0.getOperand(i: 1)) &&
3076 // Limit this to after legalization if the add has wrap flags
3077 (Level >= AfterLegalizeDAG || (!N0->getFlags().hasNoUnsignedWrap() &&
3078 !N0->getFlags().hasNoSignedWrap()))) {
3079 SDValue Not = DAG.getNOT(DL, Val: N0.getOperand(i: 0), VT);
3080 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: Not);
3081 }
3082
3083 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse()) {
3084 // Hoist one-use subtraction by non-opaque constant:
3085 // (x - C) + y -> (x + y) - C
3086 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3087 if (isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
3088 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
3089 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Add, N2: N0.getOperand(i: 1));
3090 }
3091 // Hoist one-use subtraction from non-opaque constant:
3092 // (C - x) + y -> (y - x) + C
3093 if (isConstantOrConstantVector(N: N0.getOperand(i: 0), /*NoOpaques=*/true)) {
3094 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: N0.getOperand(i: 1));
3095 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Sub, N2: N0.getOperand(i: 0));
3096 }
3097 }
3098
3099 // add (mul x, C), x -> mul x, C+1
3100 if (N0.getOpcode() == ISD::MUL && N0.getOperand(i: 0) == N1 &&
3101 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true) &&
3102 N0.hasOneUse()) {
3103 SDValue NewC = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1),
3104 N2: DAG.getConstant(Val: 1, DL, VT));
3105 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
3106 }
3107
3108 // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
3109 // rather than 'add 0/-1' (the zext should get folded).
3110 // add (sext i1 Y), X --> sub X, (zext i1 Y)
3111 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
3112 N0.getOperand(i: 0).getScalarValueSizeInBits() == 1 &&
3113 TLI.getBooleanContents(Type: VT) == TargetLowering::ZeroOrOneBooleanContent) {
3114 SDValue ZExt = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
3115 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1, N2: ZExt);
3116 }
3117
3118 // add X, (sextinreg Y i1) -> sub X, (and Y 1)
3119 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3120 VTSDNode *TN = cast<VTSDNode>(Val: N1.getOperand(i: 1));
3121 if (TN->getVT() == MVT::i1) {
3122 SDValue ZExt = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N1.getOperand(i: 0),
3123 N2: DAG.getConstant(Val: 1, DL, VT));
3124 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: ZExt);
3125 }
3126 }
3127
3128 // (add X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3129 if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(V: N1.getOperand(i: 1)) &&
3130 N1.getResNo() == 0)
3131 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N1->getVTList(),
3132 N1: N0, N2: N1.getOperand(i: 0), N3: N1.getOperand(i: 2));
3133
3134 // (add X, Carry) -> (uaddo_carry X, 0, Carry)
3135 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT))
3136 if (SDValue Carry = getAsCarry(TLI, V: N1))
3137 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL,
3138 VTList: DAG.getVTList(VT1: VT, VT2: Carry.getValueType()), N1: N0,
3139 N2: DAG.getConstant(Val: 0, DL, VT), N3: Carry);
3140
3141 return SDValue();
3142}
3143
3144SDValue DAGCombiner::visitADDC(SDNode *N) {
3145 SDValue N0 = N->getOperand(Num: 0);
3146 SDValue N1 = N->getOperand(Num: 1);
3147 EVT VT = N0.getValueType();
3148 SDLoc DL(N);
3149
3150 // If the flag result is dead, turn this into an ADD.
3151 if (!N->hasAnyUseOfValue(Value: 1))
3152 return CombineTo(N, DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3153 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3154
3155 // canonicalize constant to RHS.
3156 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3157 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3158 if (N0C && !N1C)
3159 return DAG.getNode(Opcode: ISD::ADDC, DL, VTList: N->getVTList(), N1, N2: N0);
3160
3161 // fold (addc x, 0) -> x + no carry out
3162 if (isNullConstant(V: N1))
3163 return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE,
3164 DL, MVT::Glue));
3165
3166 // If it cannot overflow, transform into an add.
3167 if (DAG.computeOverflowForUnsignedAdd(N0, N1) == SelectionDAG::OFK_Never)
3168 return CombineTo(N, DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3169 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
3170
3171 return SDValue();
3172}
3173
3174/**
3175 * Flips a boolean if it is cheaper to compute. If the Force parameter is set,
3176 * then the flip also occurs if computing the inverse is the same cost.
3177 * This function returns an empty SDValue in case it cannot flip the boolean
3178 * without increasing the cost of the computation. If you want to flip a boolean
3179 * no matter what, use DAG.getLogicalNOT.
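 * For example, with ZeroOrOneBooleanContent a boolean of the form (xor A, 1)
 * is already the negation of A, so the flip is free: we just return A. The
 * same holds for (xor A, -1) with ZeroOrNegativeOneBooleanContent.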
3180 */
3181static SDValue extractBooleanFlip(SDValue V, SelectionDAG &DAG,
3182 const TargetLowering &TLI,
3183 bool Force) {
3184 if (Force && isa<ConstantSDNode>(Val: V))
3185 return DAG.getLogicalNOT(DL: SDLoc(V), Val: V, VT: V.getValueType());
3186
3187 if (V.getOpcode() != ISD::XOR)
3188 return SDValue();
3189
3190 ConstantSDNode *Const = isConstOrConstSplat(N: V.getOperand(i: 1), AllowUndefs: false);
3191 if (!Const)
3192 return SDValue();
3193
3194 EVT VT = V.getValueType();
3195
3196 bool IsFlip = false;
3197  switch (TLI.getBooleanContents(Type: VT)) {
3198 case TargetLowering::ZeroOrOneBooleanContent:
3199 IsFlip = Const->isOne();
3200 break;
3201 case TargetLowering::ZeroOrNegativeOneBooleanContent:
3202 IsFlip = Const->isAllOnes();
3203 break;
3204 case TargetLowering::UndefinedBooleanContent:
3205 IsFlip = (Const->getAPIntValue() & 0x01) == 1;
3206 break;
3207 }
3208
3209 if (IsFlip)
3210 return V.getOperand(i: 0);
3211 if (Force)
3212 return DAG.getLogicalNOT(DL: SDLoc(V), Val: V, VT: V.getValueType());
3213 return SDValue();
3214}
3215
3216SDValue DAGCombiner::visitADDO(SDNode *N) {
3217 SDValue N0 = N->getOperand(Num: 0);
3218 SDValue N1 = N->getOperand(Num: 1);
3219 EVT VT = N0.getValueType();
3220 bool IsSigned = (ISD::SADDO == N->getOpcode());
3221
3222 EVT CarryVT = N->getValueType(ResNo: 1);
3223 SDLoc DL(N);
3224
3225 // If the flag result is dead, turn this into an ADD.
3226 if (!N->hasAnyUseOfValue(Value: 1))
3227 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3228 Res1: DAG.getUNDEF(VT: CarryVT));
3229
3230 // canonicalize constant to RHS.
3231 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
3232 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
3233 return DAG.getNode(Opcode: N->getOpcode(), DL, VTList: N->getVTList(), N1, N2: N0);
3234
3235 // fold (addo x, 0) -> x + no carry out
3236 if (isNullOrNullSplat(V: N1))
3237 return CombineTo(N, Res0: N0, Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3238
3239 // If it cannot overflow, transform into an add.
3240 if (DAG.willNotOverflowAdd(IsSigned, N0, N1))
3241 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1),
3242 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3243
3244 if (IsSigned) {
3245 // fold (saddo (xor a, -1), 1) -> (ssub 0, a).
3246 if (isBitwiseNot(V: N0) && isOneOrOneSplat(V: N1))
3247 return DAG.getNode(Opcode: ISD::SSUBO, DL, VTList: N->getVTList(),
3248 N1: DAG.getConstant(Val: 0, DL, VT), N2: N0.getOperand(i: 0));
3249 } else {
3250 // fold (uaddo (xor a, -1), 1) -> (usub 0, a) and flip carry.
3251 if (isBitwiseNot(V: N0) && isOneOrOneSplat(V: N1)) {
3252 SDValue Sub = DAG.getNode(Opcode: ISD::USUBO, DL, VTList: N->getVTList(),
3253 N1: DAG.getConstant(Val: 0, DL, VT), N2: N0.getOperand(i: 0));
3254 return CombineTo(
3255 N, Res0: Sub, Res1: DAG.getLogicalNOT(DL, Val: Sub.getValue(R: 1), VT: Sub->getValueType(ResNo: 1)));
3256 }
3257
3258 if (SDValue Combined = visitUADDOLike(N0, N1, N))
3259 return Combined;
3260
3261 if (SDValue Combined = visitUADDOLike(N0: N1, N1: N0, N))
3262 return Combined;
3263 }
3264
3265 return SDValue();
3266}
3267
3268SDValue DAGCombiner::visitUADDOLike(SDValue N0, SDValue N1, SDNode *N) {
3269 EVT VT = N0.getValueType();
3270 if (VT.isVector())
3271 return SDValue();
3272
3273 // (uaddo X, (uaddo_carry Y, 0, Carry)) -> (uaddo_carry X, Y, Carry)
3274 // If Y + 1 cannot overflow.
3275 if (N1.getOpcode() == ISD::UADDO_CARRY && isNullConstant(V: N1.getOperand(i: 1))) {
3276 SDValue Y = N1.getOperand(i: 0);
3277 SDValue One = DAG.getConstant(Val: 1, DL: SDLoc(N), VT: Y.getValueType());
3278 if (DAG.computeOverflowForUnsignedAdd(N0: Y, N1: One) == SelectionDAG::OFK_Never)
3279 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: Y,
3280 N3: N1.getOperand(i: 2));
3281 }
3282
3283 // (uaddo X, Carry) -> (uaddo_carry X, 0, Carry)
3284 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT))
3285 if (SDValue Carry = getAsCarry(TLI, V: N1))
3286 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1: N0,
3287 N2: DAG.getConstant(Val: 0, DL: SDLoc(N), VT), N3: Carry);
3288
3289 return SDValue();
3290}
3291
3292SDValue DAGCombiner::visitADDE(SDNode *N) {
3293 SDValue N0 = N->getOperand(Num: 0);
3294 SDValue N1 = N->getOperand(Num: 1);
3295 SDValue CarryIn = N->getOperand(Num: 2);
3296
3297 // canonicalize constant to RHS
3298 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3299 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3300 if (N0C && !N1C)
3301 return DAG.getNode(Opcode: ISD::ADDE, DL: SDLoc(N), VTList: N->getVTList(),
3302 N1, N2: N0, N3: CarryIn);
3303
3304 // fold (adde x, y, false) -> (addc x, y)
3305 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
3306 return DAG.getNode(Opcode: ISD::ADDC, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
3307
3308 return SDValue();
3309}
3310
3311SDValue DAGCombiner::visitUADDO_CARRY(SDNode *N) {
3312 SDValue N0 = N->getOperand(Num: 0);
3313 SDValue N1 = N->getOperand(Num: 1);
3314 SDValue CarryIn = N->getOperand(Num: 2);
3315 SDLoc DL(N);
3316
3317 // canonicalize constant to RHS
3318 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3319 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3320 if (N0C && !N1C)
3321 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N->getVTList(), N1, N2: N0, N3: CarryIn);
3322
3323 // fold (uaddo_carry x, y, false) -> (uaddo x, y)
3324 if (isNullConstant(V: CarryIn)) {
3325 if (!LegalOperations ||
3326 TLI.isOperationLegalOrCustom(Op: ISD::UADDO, VT: N->getValueType(ResNo: 0)))
3327 return DAG.getNode(Opcode: ISD::UADDO, DL, VTList: N->getVTList(), N1: N0, N2: N1);
3328 }
3329
3330 // fold (uaddo_carry 0, 0, X) -> (and (ext/trunc X), 1) and no carry.
3331 if (isNullConstant(V: N0) && isNullConstant(V: N1)) {
3332 EVT VT = N0.getValueType();
3333 EVT CarryVT = CarryIn.getValueType();
3334 SDValue CarryExt = DAG.getBoolExtOrTrunc(Op: CarryIn, SL: DL, VT, OpVT: CarryVT);
3335 AddToWorklist(N: CarryExt.getNode());
3336 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::AND, DL, VT, N1: CarryExt,
3337 N2: DAG.getConstant(Val: 1, DL, VT)),
3338 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
3339 }
3340
3341 if (SDValue Combined = visitUADDO_CARRYLike(N0, N1, CarryIn, N))
3342 return Combined;
3343
3344 if (SDValue Combined = visitUADDO_CARRYLike(N0: N1, N1: N0, CarryIn, N))
3345 return Combined;
3346
3347 // We want to avoid useless duplication.
3348  // TODO: This is done automatically for binary operations. As UADDO_CARRY is
3349  // not a binary operation, it is not really possible to leverage this
3350  // existing mechanism for it. However, if more operations require the same
3351  // deduplication logic, then it may be worth generalizing.
3352 SDValue Ops[] = {N1, N0, CarryIn};
3353 SDNode *CSENode =
3354 DAG.getNodeIfExists(Opcode: ISD::UADDO_CARRY, VTList: N->getVTList(), Ops, Flags: N->getFlags());
3355 if (CSENode)
3356 return SDValue(CSENode, 0);
3357
3358 return SDValue();
3359}
3360
3361/**
3362 * If we are facing some sort of diamond carry propagation pattern, try to
3363 * break it up to generate something like:
3364 * (uaddo_carry X, 0, (uaddo_carry A, B, Z):Carry)
3365 *
3366 * The end result is usually an increase in the number of operations required, but
3367 * because the carry is now linearized, other transforms can kick in and optimize the DAG.
3368 *
3369 * Patterns typically look something like
3370 * (uaddo A, B)
3371 * / \
3372 * Carry Sum
3373 * | \
3374 * | (uaddo_carry *, 0, Z)
3375 * | /
3376 * \ Carry
3377 * | /
3378 * (uaddo_carry X, *, *)
3379 *
3380 * But numerous variations exist. Our goal is to identify A, B, X and Z and
3381 * produce a combine with a single path for carry propagation.
3382 */
3383static SDValue combineUADDO_CARRYDiamond(DAGCombiner &Combiner,
3384 SelectionDAG &DAG, SDValue X,
3385 SDValue Carry0, SDValue Carry1,
3386 SDNode *N) {
3387 if (Carry1.getResNo() != 1 || Carry0.getResNo() != 1)
3388 return SDValue();
3389 if (Carry1.getOpcode() != ISD::UADDO)
3390 return SDValue();
3391
3392 SDValue Z;
3393
3394 /**
3395 * First look for a suitable Z. It will present itself in the form of
3396 * (uaddo_carry Y, 0, Z) or its equivalent (uaddo Y, 1) for Z=true
3397 */
3398 if (Carry0.getOpcode() == ISD::UADDO_CARRY &&
3399 isNullConstant(V: Carry0.getOperand(i: 1))) {
3400 Z = Carry0.getOperand(i: 2);
3401 } else if (Carry0.getOpcode() == ISD::UADDO &&
3402 isOneConstant(V: Carry0.getOperand(i: 1))) {
3403 EVT VT = Carry0->getValueType(ResNo: 1);
3404 Z = DAG.getConstant(Val: 1, DL: SDLoc(Carry0.getOperand(i: 1)), VT);
3405 } else {
3406 // We couldn't find a suitable Z.
3407 return SDValue();
3408 }
3409
3410
3411  auto cancelDiamond = [&](SDValue A, SDValue B) {
3412 SDLoc DL(N);
3413 SDValue NewY =
3414 DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: Carry0->getVTList(), N1: A, N2: B, N3: Z);
3415 Combiner.AddToWorklist(N: NewY.getNode());
3416 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL, VTList: N->getVTList(), N1: X,
3417 N2: DAG.getConstant(Val: 0, DL, VT: X.getValueType()),
3418 N3: NewY.getValue(R: 1));
3419 };
3420
3421 /**
3422 * (uaddo A, B)
3423 * |
3424 * Sum
3425 * |
3426 * (uaddo_carry *, 0, Z)
3427 */
3428 if (Carry0.getOperand(i: 0) == Carry1.getValue(R: 0)) {
3429 return cancelDiamond(Carry1.getOperand(i: 0), Carry1.getOperand(i: 1));
3430 }
3431
3432 /**
3433 * (uaddo_carry A, 0, Z)
3434 * |
3435 * Sum
3436 * |
3437 * (uaddo *, B)
3438 */
3439 if (Carry1.getOperand(i: 0) == Carry0.getValue(R: 0)) {
3440 return cancelDiamond(Carry0.getOperand(i: 0), Carry1.getOperand(i: 1));
3441 }
3442
3443 if (Carry1.getOperand(i: 1) == Carry0.getValue(R: 0)) {
3444 return cancelDiamond(Carry1.getOperand(i: 0), Carry0.getOperand(i: 0));
3445 }
3446
3447 return SDValue();
3448}
3449
3450// If we are facing some sort of diamond carry/borrow in/out pattern, try to
3451// match patterns like:
3452//
3453// (uaddo A, B) CarryIn
3454// | \ |
3455// | \ |
3456// PartialSum PartialCarryOutX /
3457// | | /
3458// | ____|____________/
3459// | / |
3460// (uaddo *, *) \________
3461// | \ \
3462// | \ |
3463// | PartialCarryOutY |
3464// | \ |
3465// | \ /
3466// AddCarrySum | ______/
3467// | /
3468// CarryOut = (or *, *)
3469//
3470// And generate UADDO_CARRY (or USUBO_CARRY) with two result values:
3471//
3472// {AddCarrySum, CarryOut} = (uaddo_carry A, B, CarryIn)
3473//
3474// Our goal is to identify A, B, and CarryIn and produce UADDO_CARRY/USUBO_CARRY
3475// with a single path for carry/borrow out propagation.
3476static SDValue combineCarryDiamond(SelectionDAG &DAG, const TargetLowering &TLI,
3477 SDValue N0, SDValue N1, SDNode *N) {
3478 SDValue Carry0 = getAsCarry(TLI, V: N0);
3479 if (!Carry0)
3480 return SDValue();
3481 SDValue Carry1 = getAsCarry(TLI, V: N1);
3482 if (!Carry1)
3483 return SDValue();
3484
3485 unsigned Opcode = Carry0.getOpcode();
3486 if (Opcode != Carry1.getOpcode())
3487 return SDValue();
3488 if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
3489 return SDValue();
3490 // Guarantee identical type of CarryOut
3491 EVT CarryOutType = N->getValueType(ResNo: 0);
3492 if (CarryOutType != Carry0.getValue(R: 1).getValueType() ||
3493 CarryOutType != Carry1.getValue(R: 1).getValueType())
3494 return SDValue();
3495
3496 // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
3497 // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
3498 if (Carry1.getNode()->isOperandOf(N: Carry0.getNode()))
3499 std::swap(a&: Carry0, b&: Carry1);
3500
3501 // Check if nodes are connected in expected way.
3502 if (Carry1.getOperand(i: 0) != Carry0.getValue(R: 0) &&
3503 Carry1.getOperand(i: 1) != Carry0.getValue(R: 0))
3504 return SDValue();
3505
3506  // The carry-in value must be on the right-hand side for subtraction.
3507 unsigned CarryInOperandNum =
3508 Carry1.getOperand(i: 0) == Carry0.getValue(R: 0) ? 1 : 0;
3509 if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
3510 return SDValue();
3511 SDValue CarryIn = Carry1.getOperand(i: CarryInOperandNum);
3512
3513 unsigned NewOp = Opcode == ISD::UADDO ? ISD::UADDO_CARRY : ISD::USUBO_CARRY;
3514 if (!TLI.isOperationLegalOrCustom(Op: NewOp, VT: Carry0.getValue(R: 0).getValueType()))
3515 return SDValue();
3516
3517 // Verify that the carry/borrow in is plausibly a carry/borrow bit.
3518 CarryIn = getAsCarry(TLI, V: CarryIn, ForceCarryReconstruction: true);
3519 if (!CarryIn)
3520 return SDValue();
3521
3522 SDLoc DL(N);
3523 CarryIn = DAG.getBoolExtOrTrunc(Op: CarryIn, SL: DL, VT: Carry1->getValueType(ResNo: 1),
3524 OpVT: Carry1->getValueType(ResNo: 0));
3525 SDValue Merged =
3526 DAG.getNode(Opcode: NewOp, DL, VTList: Carry1->getVTList(), N1: Carry0.getOperand(i: 0),
3527 N2: Carry0.getOperand(i: 1), N3: CarryIn);
3528
3529  // Note that because we have proven that the result of the UADDO/USUBO of
3530  // A and B feeds into the UADDO/USUBO that does the carry/borrow in, we can
3531  // prove that if the first UADDO/USUBO overflows, the second UADDO/USUBO
3532  // cannot. For example, consider 8-bit numbers where 0xFF is the
3533  // maximum value.
3534 //
3535 // 0xFF + 0xFF == 0xFE with carry but 0xFE + 1 does not carry
3536 // 0x00 - 0xFF == 1 with a carry/borrow but 1 - 1 == 0 (no carry/borrow)
3537 //
3538 // This is important because it means that OR and XOR can be used to merge
3539 // carry flags; and that AND can return a constant zero.
3540 //
3541 // TODO: match other operations that can merge flags (ADD, etc)
3542 DAG.ReplaceAllUsesOfValueWith(From: Carry1.getValue(R: 0), To: Merged.getValue(R: 0));
3543 if (N->getOpcode() == ISD::AND)
3544 return DAG.getConstant(Val: 0, DL, VT: CarryOutType);
3545 return Merged.getValue(R: 1);
3546}
3547
3548SDValue DAGCombiner::visitUADDO_CARRYLike(SDValue N0, SDValue N1,
3549 SDValue CarryIn, SDNode *N) {
3550 // fold (uaddo_carry (xor a, -1), b, c) -> (usubo_carry b, a, !c) and flip
3551 // carry.
3552 if (isBitwiseNot(V: N0))
3553 if (SDValue NotC = extractBooleanFlip(V: CarryIn, DAG, TLI, Force: true)) {
3554 SDLoc DL(N);
3555 SDValue Sub = DAG.getNode(Opcode: ISD::USUBO_CARRY, DL, VTList: N->getVTList(), N1,
3556 N2: N0.getOperand(i: 0), N3: NotC);
3557 return CombineTo(
3558 N, Res0: Sub, Res1: DAG.getLogicalNOT(DL, Val: Sub.getValue(R: 1), VT: Sub->getValueType(ResNo: 1)));
3559 }
3560
3561 // Iff the flag result is dead:
3562 // (uaddo_carry (add|uaddo X, Y), 0, Carry) -> (uaddo_carry X, Y, Carry)
3563 // Don't do this if the Carry comes from the uaddo. It won't remove the uaddo
3564 // or the dependency between the instructions.
3565 if ((N0.getOpcode() == ISD::ADD ||
3566 (N0.getOpcode() == ISD::UADDO && N0.getResNo() == 0 &&
3567 N0.getValue(R: 1) != CarryIn)) &&
3568 isNullConstant(V: N1) && !N->hasAnyUseOfValue(Value: 1))
3569 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL: SDLoc(N), VTList: N->getVTList(),
3570 N1: N0.getOperand(i: 0), N2: N0.getOperand(i: 1), N3: CarryIn);
3571
3572 /**
3573 * When one of the uaddo_carry arguments is itself a carry, we may be facing
3574 * a diamond carry propagation, in which case we try to transform the DAG
3575 * to ensure linear carry propagation if that is possible.
3576 */
3577 if (auto Y = getAsCarry(TLI, V: N1)) {
3578 // Because both are carries, Y and Z can be swapped.
3579 if (auto R = combineUADDO_CARRYDiamond(Combiner&: *this, DAG, X: N0, Carry0: Y, Carry1: CarryIn, N))
3580 return R;
3581 if (auto R = combineUADDO_CARRYDiamond(Combiner&: *this, DAG, X: N0, Carry0: CarryIn, Carry1: Y, N))
3582 return R;
3583 }
3584
3585 return SDValue();
3586}
3587
3588SDValue DAGCombiner::visitSADDO_CARRYLike(SDValue N0, SDValue N1,
3589 SDValue CarryIn, SDNode *N) {
3590 // fold (saddo_carry (xor a, -1), b, c) -> (ssubo_carry b, a, !c)
3591 if (isBitwiseNot(V: N0)) {
3592 if (SDValue NotC = extractBooleanFlip(V: CarryIn, DAG, TLI, Force: true))
3593 return DAG.getNode(Opcode: ISD::SSUBO_CARRY, DL: SDLoc(N), VTList: N->getVTList(), N1,
3594 N2: N0.getOperand(i: 0), N3: NotC);
3595 }
3596
3597 return SDValue();
3598}
3599
3600SDValue DAGCombiner::visitSADDO_CARRY(SDNode *N) {
3601 SDValue N0 = N->getOperand(Num: 0);
3602 SDValue N1 = N->getOperand(Num: 1);
3603 SDValue CarryIn = N->getOperand(Num: 2);
3604 SDLoc DL(N);
3605
3606 // canonicalize constant to RHS
3607 ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(Val&: N0);
3608 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
3609 if (N0C && !N1C)
3610 return DAG.getNode(Opcode: ISD::SADDO_CARRY, DL, VTList: N->getVTList(), N1, N2: N0, N3: CarryIn);
3611
3612 // fold (saddo_carry x, y, false) -> (saddo x, y)
3613 if (isNullConstant(V: CarryIn)) {
3614 if (!LegalOperations ||
3615 TLI.isOperationLegalOrCustom(Op: ISD::SADDO, VT: N->getValueType(ResNo: 0)))
3616 return DAG.getNode(Opcode: ISD::SADDO, DL, VTList: N->getVTList(), N1: N0, N2: N1);
3617 }
3618
3619 if (SDValue Combined = visitSADDO_CARRYLike(N0, N1, CarryIn, N))
3620 return Combined;
3621
3622 if (SDValue Combined = visitSADDO_CARRYLike(N0: N1, N1: N0, CarryIn, N))
3623 return Combined;
3624
3625 return SDValue();
3626}
3627
3628// Attempt to create a USUBSAT(LHS, RHS) node with DstVT, performing a
3629// clamp/truncation if necessary.
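// e.g. (illustrative) SrcVT = i32, DstVT = i16, LHS with its upper 16 bits
// known zero: clamping RHS to 0xFFFF via umin cannot change the saturating
// result (any larger RHS already yields 0), so the usubsat can be performed
// in i16 on the truncated operands.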
3630static SDValue getTruncatedUSUBSAT(EVT DstVT, EVT SrcVT, SDValue LHS,
3631 SDValue RHS, SelectionDAG &DAG,
3632 const SDLoc &DL) {
3633 assert(DstVT.getScalarSizeInBits() <= SrcVT.getScalarSizeInBits() &&
3634 "Illegal truncation");
3635
3636 if (DstVT == SrcVT)
3637 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT: DstVT, N1: LHS, N2: RHS);
3638
3639 // If the LHS is zero-extended then we can perform the USUBSAT as DstVT by
3640 // clamping RHS.
3641 APInt UpperBits = APInt::getBitsSetFrom(numBits: SrcVT.getScalarSizeInBits(),
3642 loBit: DstVT.getScalarSizeInBits());
3643 if (!DAG.MaskedValueIsZero(Op: LHS, Mask: UpperBits))
3644 return SDValue();
3645
3646 SDValue SatLimit =
3647 DAG.getConstant(Val: APInt::getLowBitsSet(numBits: SrcVT.getScalarSizeInBits(),
3648 loBitsSet: DstVT.getScalarSizeInBits()),
3649 DL, VT: SrcVT);
3650 RHS = DAG.getNode(Opcode: ISD::UMIN, DL, VT: SrcVT, N1: RHS, N2: SatLimit);
3651 RHS = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DstVT, Operand: RHS);
3652 LHS = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: DstVT, Operand: LHS);
3653 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT: DstVT, N1: LHS, N2: RHS);
3654}
3655
3656// Try to find umax(a,b) - b or a - umin(a,b) patterns that may be converted to
3657// usubsat(a,b), optionally as a truncated type.
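// e.g. umax(a,b) - b is a - b when a >= b and 0 otherwise, which is exactly
// usubsat(a,b); a - umin(a,b) behaves the same way.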
3658SDValue DAGCombiner::foldSubToUSubSat(EVT DstVT, SDNode *N, const SDLoc &DL) {
3659 if (N->getOpcode() != ISD::SUB ||
3660 !(!LegalOperations || hasOperation(Opcode: ISD::USUBSAT, VT: DstVT)))
3661 return SDValue();
3662
3663 EVT SubVT = N->getValueType(ResNo: 0);
3664 SDValue Op0 = N->getOperand(Num: 0);
3665 SDValue Op1 = N->getOperand(Num: 1);
3666
3667  // Try to find umax(a,b) - b or a - umin(a,b) patterns
3668  // that may be converted to usubsat(a,b).
3669 if (Op0.getOpcode() == ISD::UMAX && Op0.hasOneUse()) {
3670 SDValue MaxLHS = Op0.getOperand(i: 0);
3671 SDValue MaxRHS = Op0.getOperand(i: 1);
3672 if (MaxLHS == Op1)
3673 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: MaxRHS, RHS: Op1, DAG, DL);
3674 if (MaxRHS == Op1)
3675 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: MaxLHS, RHS: Op1, DAG, DL);
3676 }
3677
3678 if (Op1.getOpcode() == ISD::UMIN && Op1.hasOneUse()) {
3679 SDValue MinLHS = Op1.getOperand(i: 0);
3680 SDValue MinRHS = Op1.getOperand(i: 1);
3681 if (MinLHS == Op0)
3682 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: Op0, RHS: MinRHS, DAG, DL);
3683 if (MinRHS == Op0)
3684 return getTruncatedUSUBSAT(DstVT, SrcVT: SubVT, LHS: Op0, RHS: MinLHS, DAG, DL);
3685 }
3686
3687 // sub(a,trunc(umin(zext(a),b))) -> usubsat(a,trunc(umin(b,SatLimit)))
3688 if (Op1.getOpcode() == ISD::TRUNCATE &&
3689 Op1.getOperand(i: 0).getOpcode() == ISD::UMIN &&
3690 Op1.getOperand(i: 0).hasOneUse()) {
3691 SDValue MinLHS = Op1.getOperand(i: 0).getOperand(i: 0);
3692 SDValue MinRHS = Op1.getOperand(i: 0).getOperand(i: 1);
3693 if (MinLHS.getOpcode() == ISD::ZERO_EXTEND && MinLHS.getOperand(i: 0) == Op0)
3694 return getTruncatedUSUBSAT(DstVT, SrcVT: MinLHS.getValueType(), LHS: MinLHS, RHS: MinRHS,
3695 DAG, DL);
3696 if (MinRHS.getOpcode() == ISD::ZERO_EXTEND && MinRHS.getOperand(i: 0) == Op0)
3697 return getTruncatedUSUBSAT(DstVT, SrcVT: MinLHS.getValueType(), LHS: MinRHS, RHS: MinLHS,
3698 DAG, DL);
3699 }
3700
3701 return SDValue();
3702}
3703
3704// Since it may not be valid to emit a fold to zero for vector initializers,
3705// check if we can before folding.
3706static SDValue tryFoldToZero(const SDLoc &DL, const TargetLowering &TLI, EVT VT,
3707 SelectionDAG &DAG, bool LegalOperations) {
3708 if (!VT.isVector())
3709 return DAG.getConstant(Val: 0, DL, VT);
3710 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT))
3711 return DAG.getConstant(Val: 0, DL, VT);
3712 return SDValue();
3713}
3714
3715SDValue DAGCombiner::visitSUB(SDNode *N) {
3716 SDValue N0 = N->getOperand(Num: 0);
3717 SDValue N1 = N->getOperand(Num: 1);
3718 EVT VT = N0.getValueType();
3719 unsigned BitWidth = VT.getScalarSizeInBits();
3720 SDLoc DL(N);
3721
3722 auto PeekThroughFreeze = [](SDValue N) {
3723 if (N->getOpcode() == ISD::FREEZE && N.hasOneUse())
3724 return N->getOperand(Num: 0);
3725 return N;
3726 };
3727
3728 // fold (sub x, x) -> 0
3729 // FIXME: Refactor this and xor and other similar operations together.
3730 if (PeekThroughFreeze(N0) == PeekThroughFreeze(N1))
3731 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
3732
3733 // fold (sub c1, c2) -> c3
3734 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N0, N1}))
3735 return C;
3736
3737 // fold vector ops
3738 if (VT.isVector()) {
3739 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
3740 return FoldedVOp;
3741
3742 // fold (sub x, 0) -> x, vector edition
3743 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
3744 return N0;
3745 }
3746
3747 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
3748 return NewSel;
3749
3750 // fold (sub x, c) -> (add x, -c)
3751 if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N: N1))
3752 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
3753 N2: DAG.getConstant(Val: -N1C->getAPIntValue(), DL, VT));
3754
3755 if (isNullOrNullSplat(V: N0)) {
3756 // Right-shifting everything out but the sign bit followed by negation is
3757 // the same as flipping arithmetic/logical shift type without the negation:
3758 // -(X >>u 31) -> (X >>s 31)
3759 // -(X >>s 31) -> (X >>u 31)
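    // e.g. for i32, (X >>u 31) is 0 or 1 (the sign bit), so its negation is
    // 0 or -1, which is exactly (X >>s 31).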
3760 if (N1->getOpcode() == ISD::SRA || N1->getOpcode() == ISD::SRL) {
3761 ConstantSDNode *ShiftAmt = isConstOrConstSplat(N: N1.getOperand(i: 1));
3762 if (ShiftAmt && ShiftAmt->getAPIntValue() == (BitWidth - 1)) {
3763 auto NewSh = N1->getOpcode() == ISD::SRA ? ISD::SRL : ISD::SRA;
3764 if (!LegalOperations || TLI.isOperationLegal(Op: NewSh, VT))
3765 return DAG.getNode(Opcode: NewSh, DL, VT, N1: N1.getOperand(i: 0), N2: N1.getOperand(i: 1));
3766 }
3767 }
3768
3769 // 0 - X --> 0 if the sub is NUW.
3770 if (N->getFlags().hasNoUnsignedWrap())
3771 return N0;
3772
3773 if (DAG.MaskedValueIsZero(Op: N1, Mask: ~APInt::getSignMask(BitWidth))) {
3774 // N1 is either 0 or the minimum signed value. If the sub is NSW, then
3775 // N1 must be 0 because negating the minimum signed value is undefined.
3776 if (N->getFlags().hasNoSignedWrap())
3777 return N0;
3778
3779 // 0 - X --> X if X is 0 or the minimum signed value.
3780 return N1;
3781 }
3782
3783 // Convert 0 - abs(x).
3784 if (N1.getOpcode() == ISD::ABS && N1.hasOneUse() &&
3785 !TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT))
3786 if (SDValue Result = TLI.expandABS(N: N1.getNode(), DAG, IsNegative: true))
3787 return Result;
3788
3789    // Fold neg(splat(neg(x))) -> splat(x)
3790 if (VT.isVector()) {
3791 SDValue N1S = DAG.getSplatValue(V: N1, LegalTypes: true);
3792 if (N1S && N1S.getOpcode() == ISD::SUB &&
3793 isNullConstant(V: N1S.getOperand(i: 0)))
3794 return DAG.getSplat(VT, DL, Op: N1S.getOperand(i: 1));
3795 }
3796 }
3797
3798 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1)
3799 if (isAllOnesOrAllOnesSplat(V: N0))
3800 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
3801
3802 // fold (A - (0-B)) -> A+B
3803 if (N1.getOpcode() == ISD::SUB && isNullOrNullSplat(V: N1.getOperand(i: 0)))
3804 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1.getOperand(i: 1));
3805
3806 // fold A-(A-B) -> B
3807 if (N1.getOpcode() == ISD::SUB && N0 == N1.getOperand(i: 0))
3808 return N1.getOperand(i: 1);
3809
3810 // fold (A+B)-A -> B
3811 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 0) == N1)
3812 return N0.getOperand(i: 1);
3813
3814 // fold (A+B)-B -> A
3815 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 1) == N1)
3816 return N0.getOperand(i: 0);
3817
3818 // fold (A+C1)-C2 -> A+(C1-C2)
3819 if (N0.getOpcode() == ISD::ADD) {
3820 SDValue N01 = N0.getOperand(i: 1);
3821 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N01, N1}))
3822 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
3823 }
3824
3825 // fold C2-(A+C1) -> (C2-C1)-A
3826 if (N1.getOpcode() == ISD::ADD) {
3827 SDValue N11 = N1.getOperand(i: 1);
3828 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N0, N11}))
3829 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: NewC, N2: N1.getOperand(i: 0));
3830 }
3831
3832 // fold (A-C1)-C2 -> A-(C1+C2)
3833 if (N0.getOpcode() == ISD::SUB) {
3834 SDValue N01 = N0.getOperand(i: 1);
3835 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL, VT, Ops: {N01, N1}))
3836 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
3837 }
3838
3839 // fold (c1-A)-c2 -> (c1-c2)-A
3840 if (N0.getOpcode() == ISD::SUB) {
3841 SDValue N00 = N0.getOperand(i: 0);
3842 if (SDValue NewC = DAG.FoldConstantArithmetic(Opcode: ISD::SUB, DL, VT, Ops: {N00, N1}))
3843 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: NewC, N2: N0.getOperand(i: 1));
3844 }
3845
3846 SDValue A, B, C;
3847
3848 // fold ((A+(B+C))-B) -> A+C
3849 if (sd_match(N: N0, P: m_Add(L: m_Value(N&: A), R: m_Add(L: m_Specific(N: N1), R: m_Value(N&: C)))))
3850 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: A, N2: C);
3851
3852 // fold ((A+(B-C))-B) -> A-C
3853 if (sd_match(N: N0, P: m_Add(L: m_Value(N&: A), R: m_Sub(L: m_Specific(N: N1), R: m_Value(N&: C)))))
3854 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: C);
3855
3856 // fold ((A-(B-C))-C) -> A-B
3857 if (sd_match(N: N0, P: m_Sub(L: m_Value(N&: A), R: m_Sub(L: m_Value(N&: B), R: m_Specific(N: N1)))))
3858 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: A, N2: B);
3859
3860 // fold (A-(B-C)) -> A+(C-B)
3861 if (sd_match(N: N1, P: m_OneUse(P: m_Sub(L: m_Value(N&: B), R: m_Value(N&: C)))))
3862 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
3863 N2: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: C, N2: B));
3864
3865 // A - (A & B) -> A & (~B)
3866 if (sd_match(N: N1, P: m_And(L: m_Specific(N: N0), R: m_Value(N&: B))) &&
3867 (N1.hasOneUse() || isConstantOrConstantVector(N: B, /*NoOpaques=*/true)))
3868 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: DAG.getNOT(DL, Val: B, VT));
3869
3870 // fold (A - (-B * C)) -> (A + (B * C))
3871 if (sd_match(N: N1, P: m_OneUse(P: m_Mul(L: m_Neg(V: m_Value(N&: B)), R: m_Value(N&: C)))))
3872 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
3873 N2: DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: B, N2: C));
3874
3875 // If either operand of a sub is undef, the result is undef
3876 if (N0.isUndef())
3877 return N0;
3878 if (N1.isUndef())
3879 return N1;
3880
3881 if (SDValue V = foldAddSubBoolOfMaskedVal(N, DL, DAG))
3882 return V;
3883
3884 if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
3885 return V;
3886
3887 // Try to match AVGCEIL fixedwidth pattern
3888 if (SDValue V = foldSubToAvg(N, DL))
3889 return V;
3890
3891 if (SDValue V = foldAddSubMasked1(IsAdd: false, N0, N1, DAG, DL))
3892 return V;
3893
3894 if (SDValue V = foldSubToUSubSat(DstVT: VT, N, DL))
3895 return V;
3896
3897 // (A - B) - 1 -> add (xor B, -1), A
3898 if (sd_match(N, P: m_Sub(L: m_OneUse(P: m_Sub(L: m_Value(N&: A), R: m_Value(N&: B))), R: m_One())))
3899 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: A, N2: DAG.getNOT(DL, Val: B, VT));
3900
3901 // Look for:
3902 // sub y, (xor x, -1)
3903 // And if the target does not like this form then turn into:
3904 // add (add x, y), 1
3905 if (TLI.preferIncOfAddToSubOfNot(VT) && N1.hasOneUse() && isBitwiseNot(V: N1)) {
3906 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: N1.getOperand(i: 0));
3907 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Add, N2: DAG.getConstant(Val: 1, DL, VT));
3908 }
3909
3910 // Hoist one-use addition by non-opaque constant:
3911 // (x + C) - y -> (x - y) + C
3912 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
3913 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
3914 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
3915 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Sub, N2: N0.getOperand(i: 1));
3916 }
3917 // y - (x + C) -> (y - x) - C
3918 if (N1.getOpcode() == ISD::ADD && N1.hasOneUse() &&
3919 isConstantOrConstantVector(N: N1.getOperand(i: 1), /*NoOpaques=*/true)) {
3920 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1.getOperand(i: 0));
3921 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Sub, N2: N1.getOperand(i: 1));
3922 }
3923 // (x - C) - y -> (x - y) - C
3924 // This is necessary because SUB(X,C) -> ADD(X,-C) doesn't work for vectors.
3925 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
3926 isConstantOrConstantVector(N: N0.getOperand(i: 1), /*NoOpaques=*/true)) {
3927 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: N1);
3928 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Sub, N2: N0.getOperand(i: 1));
3929 }
3930 // (C - x) - y -> C - (x + y)
3931 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
3932 isConstantOrConstantVector(N: N0.getOperand(i: 0), /*NoOpaques=*/true)) {
3933 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1), N2: N1);
3934 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0.getOperand(i: 0), N2: Add);
3935 }
3936
3937 // If the target's bool is represented as 0/-1, prefer to make this 'add 0/-1'
3938 // rather than 'sub 0/1' (the sext should get folded).
3939 // sub X, (zext i1 Y) --> add X, (sext i1 Y)
3940 if (N1.getOpcode() == ISD::ZERO_EXTEND &&
3941 N1.getOperand(i: 0).getScalarValueSizeInBits() == 1 &&
3942 TLI.getBooleanContents(Type: VT) ==
3943 TargetLowering::ZeroOrNegativeOneBooleanContent) {
3944 SDValue SExt = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N1.getOperand(i: 0));
3945 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: SExt);
3946 }
3947
3948 // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
3949 if (TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT)) {
3950 if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
3951 SDValue X0 = N0.getOperand(i: 0), X1 = N0.getOperand(i: 1);
3952 SDValue S0 = N1.getOperand(i: 0);
3953 if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0))
3954 if (ConstantSDNode *C = isConstOrConstSplat(N: N1.getOperand(i: 1)))
3955 if (C->getAPIntValue() == (BitWidth - 1))
3956 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: S0);
3957 }
3958 }
3959
3960 // If the relocation model supports it, consider symbol offsets.
3961 if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Val&: N0))
3962 if (!LegalOperations && TLI.isOffsetFoldingLegal(GA)) {
3963 // fold (sub Sym+c1, Sym+c2) -> c1-c2
3964 if (GlobalAddressSDNode *GB = dyn_cast<GlobalAddressSDNode>(Val&: N1))
3965 if (GA->getGlobal() == GB->getGlobal())
3966 return DAG.getConstant(Val: (uint64_t)GA->getOffset() - GB->getOffset(),
3967 DL, VT);
3968 }
3969
3970 // sub X, (sextinreg Y i1) -> add X, (and Y 1)
3971 if (N1.getOpcode() == ISD::SIGN_EXTEND_INREG) {
3972 VTSDNode *TN = cast<VTSDNode>(Val: N1.getOperand(i: 1));
3973 if (TN->getVT() == MVT::i1) {
3974 SDValue ZExt = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N1.getOperand(i: 0),
3975 N2: DAG.getConstant(Val: 1, DL, VT));
3976 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: ZExt);
3977 }
3978 }
3979
3980 // canonicalize (sub X, (vscale * C)) to (add X, (vscale * -C))
3981 if (N1.getOpcode() == ISD::VSCALE && N1.hasOneUse()) {
3982 const APInt &IntVal = N1.getConstantOperandAPInt(i: 0);
3983 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: DAG.getVScale(DL, VT, MulImm: -IntVal));
3984 }
3985
3986 // canonicalize (sub X, step_vector(C)) to (add X, step_vector(-C))
3987 if (N1.getOpcode() == ISD::STEP_VECTOR && N1.hasOneUse()) {
3988 APInt NewStep = -N1.getConstantOperandAPInt(i: 0);
3989 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0,
3990 N2: DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep));
3991 }
3992
3993 // Prefer an add for more folding potential and possibly better codegen:
3994 // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1)
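  // (lshr N10, width-1) is 0 or +1 while (ashr N10, width-1) is 0 or -1 for
  // the same input, so subtracting the former equals adding the latter.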
3995 if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) {
3996 SDValue ShAmt = N1.getOperand(i: 1);
3997 ConstantSDNode *ShAmtC = isConstOrConstSplat(N: ShAmt);
3998 if (ShAmtC && ShAmtC->getAPIntValue() == (BitWidth - 1)) {
3999 SDValue SRA = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N1.getOperand(i: 0), N2: ShAmt);
4000 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: SRA);
4001 }
4002 }
4003
4004 // As with the previous fold, prefer add for more folding potential.
4005 // Subtracting SMIN/0 is the same as adding SMIN/0:
4006 // N0 - (X << BW-1) --> N0 + (X << BW-1)
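  // (X << BW-1) is either 0 or the sign-bit-only value (signed minimum), and
  // the signed minimum is its own two's complement negation, so subtracting
  // and adding it give the same result.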
4007 if (N1.getOpcode() == ISD::SHL) {
4008 ConstantSDNode *ShlC = isConstOrConstSplat(N: N1.getOperand(i: 1));
4009 if (ShlC && ShlC->getAPIntValue() == (BitWidth - 1))
4010 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: N0);
4011 }
4012
4013 // (sub (usubo_carry X, 0, Carry), Y) -> (usubo_carry X, Y, Carry)
4014 if (N0.getOpcode() == ISD::USUBO_CARRY && isNullConstant(V: N0.getOperand(i: 1)) &&
4015 N0.getResNo() == 0 && N0.hasOneUse())
4016 return DAG.getNode(Opcode: ISD::USUBO_CARRY, DL, VTList: N0->getVTList(),
4017 N1: N0.getOperand(i: 0), N2: N1, N3: N0.getOperand(i: 2));
4018
4019 if (TLI.isOperationLegalOrCustom(Op: ISD::UADDO_CARRY, VT)) {
4020 // (sub Carry, X) -> (uaddo_carry (sub 0, X), 0, Carry)
4021 if (SDValue Carry = getAsCarry(TLI, V: N0)) {
4022 SDValue X = N1;
4023 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
4024 SDValue NegX = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Zero, N2: X);
4025 return DAG.getNode(Opcode: ISD::UADDO_CARRY, DL,
4026 VTList: DAG.getVTList(VT1: VT, VT2: Carry.getValueType()), N1: NegX, N2: Zero,
4027 N3: Carry);
4028 }
4029 }
4030
4031 // If there's no chance of borrowing from adjacent bits, then sub is xor:
4032 // sub C0, X --> xor X, C0
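  // e.g. (illustrative) sub 15, X with X known to fit in 4 bits: no bit of
  // the subtraction can borrow, so 15 - X == 15 ^ X == xor X, 15.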
4033 if (ConstantSDNode *C0 = isConstOrConstSplat(N: N0)) {
4034 if (!C0->isOpaque()) {
4035 const APInt &C0Val = C0->getAPIntValue();
4036 const APInt &MaybeOnes = ~DAG.computeKnownBits(Op: N1).Zero;
4037 if ((C0Val - MaybeOnes) == (C0Val ^ MaybeOnes))
4038 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
4039 }
4040 }
4041
4042 // smax(a,b) - smin(a,b) --> abds(a,b)
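  // e.g. (illustrative) a = 3, b = -5: smax = 3, smin = -5, 3 - (-5) = 8,
  // which is |a - b|.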
4043 if (hasOperation(Opcode: ISD::ABDS, VT) &&
4044 sd_match(N: N0, P: m_SMax(L: m_Value(N&: A), R: m_Value(N&: B))) &&
4045 sd_match(N: N1, P: m_SMin(L: m_Specific(N: A), R: m_Specific(N: B))))
4046 return DAG.getNode(Opcode: ISD::ABDS, DL, VT, N1: A, N2: B);
4047
4048 // umax(a,b) - umin(a,b) --> abdu(a,b)
4049 if (hasOperation(Opcode: ISD::ABDU, VT) &&
4050 sd_match(N: N0, P: m_UMax(L: m_Value(N&: A), R: m_Value(N&: B))) &&
4051 sd_match(N: N1, P: m_UMin(L: m_Specific(N: A), R: m_Specific(N: B))))
4052 return DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1: A, N2: B);
4053
4054 return SDValue();
4055}
4056
4057SDValue DAGCombiner::visitSUBSAT(SDNode *N) {
4058 unsigned Opcode = N->getOpcode();
4059 SDValue N0 = N->getOperand(Num: 0);
4060 SDValue N1 = N->getOperand(Num: 1);
4061 EVT VT = N0.getValueType();
4062 bool IsSigned = Opcode == ISD::SSUBSAT;
4063 SDLoc DL(N);
4064
4065 // fold (sub_sat x, undef) -> 0
4066 if (N0.isUndef() || N1.isUndef())
4067 return DAG.getConstant(Val: 0, DL, VT);
4068
4069 // fold (sub_sat x, x) -> 0
4070 if (N0 == N1)
4071 return DAG.getConstant(Val: 0, DL, VT);
4072
4073 // fold (sub_sat c1, c2) -> c3
4074 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
4075 return C;
4076
4077 // fold vector ops
4078 if (VT.isVector()) {
4079 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4080 return FoldedVOp;
4081
4082 // fold (sub_sat x, 0) -> x, vector edition
4083 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
4084 return N0;
4085 }
4086
4087 // fold (sub_sat x, 0) -> x
4088 if (isNullConstant(V: N1))
4089 return N0;
4090
4091  // If it cannot overflow, transform into a sub.
4092 if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4093 return DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1);
4094
4095 return SDValue();
4096}
4097
4098SDValue DAGCombiner::visitSUBC(SDNode *N) {
4099 SDValue N0 = N->getOperand(Num: 0);
4100 SDValue N1 = N->getOperand(Num: 1);
4101 EVT VT = N0.getValueType();
4102 SDLoc DL(N);
4103
4104 // If the flag result is dead, turn this into an SUB.
4105 if (!N->hasAnyUseOfValue(Value: 1))
4106 return CombineTo(N, DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4107 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4108
4109 // fold (subc x, x) -> 0 + no borrow
4110 if (N0 == N1)
4111 return CombineTo(N, DAG.getConstant(Val: 0, DL, VT),
4112 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4113
4114 // fold (subc x, 0) -> x + no borrow
4115 if (isNullConstant(V: N1))
4116 return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4117
4118 // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4119 if (isAllOnesConstant(V: N0))
4120 return CombineTo(N, DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0),
4121 DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
4122
4123 return SDValue();
4124}
4125
4126SDValue DAGCombiner::visitSUBO(SDNode *N) {
4127 SDValue N0 = N->getOperand(Num: 0);
4128 SDValue N1 = N->getOperand(Num: 1);
4129 EVT VT = N0.getValueType();
4130 bool IsSigned = (ISD::SSUBO == N->getOpcode());
4131
4132 EVT CarryVT = N->getValueType(ResNo: 1);
4133 SDLoc DL(N);
4134
4135 // If the flag result is dead, turn this into an SUB.
4136 if (!N->hasAnyUseOfValue(Value: 1))
4137 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4138 Res1: DAG.getUNDEF(VT: CarryVT));
4139
4140 // fold (subo x, x) -> 0 + no borrow
4141 if (N0 == N1)
4142 return CombineTo(N, Res0: DAG.getConstant(Val: 0, DL, VT),
4143 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4144
4145  // fold (subo x, c) -> (addo x, -c)
4146 if (ConstantSDNode *N1C = getAsNonOpaqueConstant(N: N1))
4147 if (IsSigned && !N1C->isMinSignedValue())
4148 return DAG.getNode(Opcode: ISD::SADDO, DL, VTList: N->getVTList(), N1: N0,
4149 N2: DAG.getConstant(Val: -N1C->getAPIntValue(), DL, VT));
4150
4151 // fold (subo x, 0) -> x + no borrow
4152 if (isNullOrNullSplat(V: N1))
4153 return CombineTo(N, Res0: N0, Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4154
4155  // If it cannot overflow, transform into a sub.
4156 if (DAG.willNotOverflowSub(IsSigned, N0, N1))
4157 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: N1),
4158 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4159
4160 // Canonicalize (usubo -1, x) -> ~x, i.e. (xor x, -1) + no borrow
4161 if (!IsSigned && isAllOnesOrAllOnesSplat(V: N0))
4162 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0),
4163 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
4164
4165 return SDValue();
4166}
4167
4168SDValue DAGCombiner::visitSUBE(SDNode *N) {
4169 SDValue N0 = N->getOperand(Num: 0);
4170 SDValue N1 = N->getOperand(Num: 1);
4171 SDValue CarryIn = N->getOperand(Num: 2);
4172
4173 // fold (sube x, y, false) -> (subc x, y)
4174 if (CarryIn.getOpcode() == ISD::CARRY_FALSE)
4175 return DAG.getNode(Opcode: ISD::SUBC, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4176
4177 return SDValue();
4178}
4179
4180SDValue DAGCombiner::visitUSUBO_CARRY(SDNode *N) {
4181 SDValue N0 = N->getOperand(Num: 0);
4182 SDValue N1 = N->getOperand(Num: 1);
4183 SDValue CarryIn = N->getOperand(Num: 2);
4184
4185 // fold (usubo_carry x, y, false) -> (usubo x, y)
4186 if (isNullConstant(V: CarryIn)) {
4187 if (!LegalOperations ||
4188 TLI.isOperationLegalOrCustom(Op: ISD::USUBO, VT: N->getValueType(ResNo: 0)))
4189 return DAG.getNode(Opcode: ISD::USUBO, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4190 }
4191
4192 return SDValue();
4193}
4194
4195SDValue DAGCombiner::visitSSUBO_CARRY(SDNode *N) {
4196 SDValue N0 = N->getOperand(Num: 0);
4197 SDValue N1 = N->getOperand(Num: 1);
4198 SDValue CarryIn = N->getOperand(Num: 2);
4199
4200 // fold (ssubo_carry x, y, false) -> (ssubo x, y)
4201 if (isNullConstant(V: CarryIn)) {
4202 if (!LegalOperations ||
4203 TLI.isOperationLegalOrCustom(Op: ISD::SSUBO, VT: N->getValueType(ResNo: 0)))
4204 return DAG.getNode(Opcode: ISD::SSUBO, DL: SDLoc(N), VTList: N->getVTList(), N1: N0, N2: N1);
4205 }
4206
4207 return SDValue();
4208}
4209
4210// Notice that "mulfix" can be any of SMULFIX, SMULFIXSAT, UMULFIX and
4211// UMULFIXSAT here.
4212SDValue DAGCombiner::visitMULFIX(SDNode *N) {
4213 SDValue N0 = N->getOperand(Num: 0);
4214 SDValue N1 = N->getOperand(Num: 1);
4215 SDValue Scale = N->getOperand(Num: 2);
4216 EVT VT = N0.getValueType();
4217
4218 // fold (mulfix x, undef, scale) -> 0
4219 if (N0.isUndef() || N1.isUndef())
4220 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
4221
4222 // Canonicalize constant to RHS (vector doesn't have to splat)
4223 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
4224 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
4225 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, N1, N2: N0, N3: Scale);
4226
4227 // fold (mulfix x, 0, scale) -> 0
4228 if (isNullConstant(V: N1))
4229 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
4230
4231 return SDValue();
4232}
4233
4234SDValue DAGCombiner::visitMUL(SDNode *N) {
4235 SDValue N0 = N->getOperand(Num: 0);
4236 SDValue N1 = N->getOperand(Num: 1);
4237 EVT VT = N0.getValueType();
4238 SDLoc DL(N);
4239
4240 // fold (mul x, undef) -> 0
4241 if (N0.isUndef() || N1.isUndef())
4242 return DAG.getConstant(Val: 0, DL, VT);
4243
4244 // fold (mul c1, c2) -> c1*c2
4245 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MUL, DL, VT, Ops: {N0, N1}))
4246 return C;
4247
4248 // canonicalize constant to RHS (vector doesn't have to splat)
4249 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
4250 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
4251 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1, N2: N0);
4252
4253 bool N1IsConst = false;
4254 bool N1IsOpaqueConst = false;
4255 APInt ConstValue1;
4256
4257 // fold vector ops
4258 if (VT.isVector()) {
4259 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4260 return FoldedVOp;
4261
4262 N1IsConst = ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: ConstValue1);
4263 assert((!N1IsConst ||
4264 ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
4265 "Splat APInt should be element width");
4266 } else {
4267 N1IsConst = isa<ConstantSDNode>(Val: N1);
4268 if (N1IsConst) {
4269 ConstValue1 = N1->getAsAPIntVal();
4270 N1IsOpaqueConst = cast<ConstantSDNode>(Val&: N1)->isOpaque();
4271 }
4272 }
4273
4274 // fold (mul x, 0) -> 0
4275 if (N1IsConst && ConstValue1.isZero())
4276 return N1;
4277
4278 // fold (mul x, 1) -> x
4279 if (N1IsConst && ConstValue1.isOne())
4280 return N0;
4281
4282 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4283 return NewSel;
4284
4285 // fold (mul x, -1) -> 0-x
4286 if (N1IsConst && ConstValue1.isAllOnes())
4287 return DAG.getNegative(Val: N0, DL, VT);
4288
4289 // fold (mul x, (1 << c)) -> x << c
4290 if (isConstantOrConstantVector(N: N1, /*NoOpaques*/ true) &&
4291 (!VT.isVector() || Level <= AfterLegalizeVectorOps)) {
4292 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
4293 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
4294 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ShiftVT);
4295 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: Trunc);
4296 }
4297 }
4298
4299 // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
4300 if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
4301 unsigned Log2Val = (-ConstValue1).logBase2();
4302 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
4303
4304 // FIXME: If the input is something that is easily negated (e.g. a
4305 // single-use add), we should put the negate there.
4306 return DAG.getNode(Opcode: ISD::SUB, DL, VT,
4307 N1: DAG.getConstant(Val: 0, DL, VT),
4308 N2: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0,
4309 N2: DAG.getConstant(Val: Log2Val, DL, VT: ShiftVT)));
4310 }
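  // An illustrative example of the negated-power-of-2 fold above (a sketch,
  // not taken from the source comments): for i32, (mul x, -8) has Log2Val = 3
  // and becomes (sub 0, (shl x, 3)), i.e. 0 - (x << 3) == x * -8.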
4311
4312 // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
4313 // hi result is in use in case we hit this mid-legalization.
4314 for (unsigned LoHiOpc : {ISD::UMUL_LOHI, ISD::SMUL_LOHI}) {
4315 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: LoHiOpc, VT)) {
4316 SDVTList LoHiVT = DAG.getVTList(VT1: VT, VT2: VT);
4317 // TODO: Can we match commutable operands with getNodeIfExists?
4318 if (SDNode *LoHi = DAG.getNodeIfExists(Opcode: LoHiOpc, VTList: LoHiVT, Ops: {N0, N1}))
4319 if (LoHi->hasAnyUseOfValue(Value: 1))
4320 return SDValue(LoHi, 0);
4321 if (SDNode *LoHi = DAG.getNodeIfExists(Opcode: LoHiOpc, VTList: LoHiVT, Ops: {N1, N0}))
4322 if (LoHi->hasAnyUseOfValue(Value: 1))
4323 return SDValue(LoHi, 0);
4324 }
4325 }
4326
4327 // Try to transform:
4328 // (1) multiply-by-(power-of-2 +/- 1) into shift and add/sub.
4329 // mul x, (2^N + 1) --> add (shl x, N), x
4330 // mul x, (2^N - 1) --> sub (shl x, N), x
4331 // Examples: x * 33 --> (x << 5) + x
4332 // x * 15 --> (x << 4) - x
4333 // x * -33 --> -((x << 5) + x)
4334 // x * -15 --> -((x << 4) - x) ; this reduces --> x - (x << 4)
4335 // (2) multiply-by-(power-of-2 +/- power-of-2) into shifts and add/sub.
4336 // mul x, (2^N + 2^M) --> (add (shl x, N), (shl x, M))
4337 // mul x, (2^N - 2^M) --> (sub (shl x, N), (shl x, M))
4338 // Examples: x * 0x8800 --> (x << 15) + (x << 11)
4339 // x * 0xf800 --> (x << 16) - (x << 11)
4340 // x * -0x8800 --> -((x << 15) + (x << 11))
4341 // x * -0xf800 --> -((x << 16) - (x << 11)) ; (x << 11) - (x << 16)
4342 if (N1IsConst && TLI.decomposeMulByConstant(Context&: *DAG.getContext(), VT, C: N1)) {
4343 // TODO: We could handle more general decomposition of any constant by
4344 // having the target set a limit on number of ops and making a
4345 // callback to determine that sequence (similar to sqrt expansion).
4346 unsigned MathOp = ISD::DELETED_NODE;
4347 APInt MulC = ConstValue1.abs();
4348 // The constant `2` should be treated as (2^0 + 1).
4349 unsigned TZeros = MulC == 2 ? 0 : MulC.countr_zero();
4350 MulC.lshrInPlace(ShiftAmt: TZeros);
4351 if ((MulC - 1).isPowerOf2())
4352 MathOp = ISD::ADD;
4353 else if ((MulC + 1).isPowerOf2())
4354 MathOp = ISD::SUB;
4355
4356 if (MathOp != ISD::DELETED_NODE) {
4357 unsigned ShAmt =
4358 MathOp == ISD::ADD ? (MulC - 1).logBase2() : (MulC + 1).logBase2();
4359 ShAmt += TZeros;
4360 assert(ShAmt < VT.getScalarSizeInBits() &&
4361 "multiply-by-constant generated out of bounds shift");
4362 SDValue Shl =
4363 DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0, N2: DAG.getConstant(Val: ShAmt, DL, VT));
4364 SDValue R =
4365 TZeros ? DAG.getNode(Opcode: MathOp, DL, VT, N1: Shl,
4366 N2: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0,
4367 N2: DAG.getConstant(Val: TZeros, DL, VT)))
4368 : DAG.getNode(Opcode: MathOp, DL, VT, N1: Shl, N2: N0);
4369 if (ConstValue1.isNegative())
4370 R = DAG.getNegative(Val: R, DL, VT);
4371 return R;
4372 }
4373 }
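  // A worked example of the decomposition above (illustrative only): for
  // (mul x, 20), MulC = 20 has TZeros = 2 and reduces to 5 = 2^2 + 1, so
  // MathOp = ADD with ShAmt = 2 + 2 = 4, giving (add (shl x, 4), (shl x, 2))
  // = 16*x + 4*x = 20*x. A negative constant such as -20 produces the same
  // sequence followed by a final negation.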
4374
4375 // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
4376 if (N0.getOpcode() == ISD::SHL) {
4377 SDValue N01 = N0.getOperand(i: 1);
4378 if (SDValue C3 = DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {N1, N01}))
4379 return DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0.getOperand(i: 0), N2: C3);
4380 }
4381
4382 // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one
4383 // use.
4384 {
4385 SDValue Sh, Y;
4386
4387 // Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)).
4388 if (N0.getOpcode() == ISD::SHL &&
4389 isConstantOrConstantVector(N: N0.getOperand(i: 1)) && N0->hasOneUse()) {
4390 Sh = N0; Y = N1;
4391 } else if (N1.getOpcode() == ISD::SHL &&
4392 isConstantOrConstantVector(N: N1.getOperand(i: 1)) &&
4393 N1->hasOneUse()) {
4394 Sh = N1; Y = N0;
4395 }
4396
4397 if (Sh.getNode()) {
4398 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: Sh.getOperand(i: 0), N2: Y);
4399 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mul, N2: Sh.getOperand(i: 1));
4400 }
4401 }
4402
4403 // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
4404 if (N0.getOpcode() == ISD::ADD &&
4405 DAG.isConstantIntBuildVectorOrConstantInt(N: N1) &&
4406 DAG.isConstantIntBuildVectorOrConstantInt(N: N0.getOperand(i: 1)) &&
4407 isMulAddWithConstProfitable(MulNode: N, AddNode: N0, ConstNode: N1))
4408 return DAG.getNode(
4409 Opcode: ISD::ADD, DL, VT,
4410 N1: DAG.getNode(Opcode: ISD::MUL, DL: SDLoc(N0), VT, N1: N0.getOperand(i: 0), N2: N1),
4411 N2: DAG.getNode(Opcode: ISD::MUL, DL: SDLoc(N1), VT, N1: N0.getOperand(i: 1), N2: N1));
4412
4413 // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
4414 ConstantSDNode *NC1 = isConstOrConstSplat(N: N1);
4415 if (N0.getOpcode() == ISD::VSCALE && NC1) {
4416 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
4417 const APInt &C1 = NC1->getAPIntValue();
4418 return DAG.getVScale(DL, VT, MulImm: C0 * C1);
4419 }
4420
4421 // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
4422 APInt MulVal;
4423 if (N0.getOpcode() == ISD::STEP_VECTOR &&
4424 ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: MulVal)) {
4425 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
4426 APInt NewStep = C0 * MulVal;
4427 return DAG.getStepVector(DL, ResVT: VT, StepVal: NewStep);
4428 }
4429
  // Fold a multiply by a vector whose elements are each 0, undef, or 1:
  //   (mul x, <0/undef/1, ...>) -> (and x, mask)
4433 // We can replace vectors with '0' and '1' factors with a clearing mask.
4434 if (VT.isFixedLengthVector()) {
4435 unsigned NumElts = VT.getVectorNumElements();
4436 SmallBitVector ClearMask;
4437 ClearMask.reserve(N: NumElts);
4438 auto IsClearMask = [&ClearMask](ConstantSDNode *V) {
4439 if (!V || V->isZero()) {
4440 ClearMask.push_back(Val: true);
4441 return true;
4442 }
4443 ClearMask.push_back(Val: false);
4444 return V->isOne();
4445 };
4446 if ((!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::AND, VT)) &&
4447 ISD::matchUnaryPredicate(Op: N1, Match: IsClearMask, /*AllowUndefs*/ true)) {
4448 assert(N1.getOpcode() == ISD::BUILD_VECTOR && "Unknown constant vector");
4449 EVT LegalSVT = N1.getOperand(i: 0).getValueType();
4450 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: LegalSVT);
4451 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT: LegalSVT);
4452 SmallVector<SDValue, 16> Mask(NumElts, AllOnes);
4453 for (unsigned I = 0; I != NumElts; ++I)
4454 if (ClearMask[I])
4455 Mask[I] = Zero;
4456 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: DAG.getBuildVector(VT, DL, Ops: Mask));
4457 }
4458 }
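  // An illustrative example of the clearing-mask fold above (assumed, not
  // from the source comments): (mul x:v4i32, <0, 1, undef, 1>) becomes
  // (and x, <0, -1, 0, -1>), since 0/undef factors clear the lane and factors
  // of 1 keep it.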
4459
4460 // reassociate mul
4461 if (SDValue RMUL = reassociateOps(Opc: ISD::MUL, DL, N0, N1, Flags: N->getFlags()))
4462 return RMUL;
4463
4464 // Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4465 if (SDValue SD =
4466 reassociateReduction(RedOpc: ISD::VECREDUCE_MUL, Opc: ISD::MUL, DL, VT, N0, N1))
4467 return SD;
4468
4469 // Simplify the operands using demanded-bits information.
4470 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
4471 return SDValue(N, 0);
4472
4473 return SDValue();
4474}
4475
4476/// Return true if divmod libcall is available.
4477static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
4478 const TargetLowering &TLI) {
4479 RTLIB::Libcall LC;
4480 EVT NodeType = Node->getValueType(ResNo: 0);
4481 if (!NodeType.isSimple())
4482 return false;
4483 switch (NodeType.getSimpleVT().SimpleTy) {
4484 default: return false; // No libcall for vector types.
4485 case MVT::i8: LC= isSigned ? RTLIB::SDIVREM_I8 : RTLIB::UDIVREM_I8; break;
4486 case MVT::i16: LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
4487 case MVT::i32: LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
4488 case MVT::i64: LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
4489 case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
4490 }
4491
4492 return TLI.getLibcallName(Call: LC) != nullptr;
4493}
4494
4495/// Issue divrem if both quotient and remainder are needed.
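/// For example (an illustrative sketch): if the DAG contains both (sdiv X, Y)
/// and (srem X, Y), they can be replaced by a single (sdivrem X, Y), with the
/// division users rewired to result 0 and the remainder users to result 1.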
4496SDValue DAGCombiner::useDivRem(SDNode *Node) {
4497 if (Node->use_empty())
4498 return SDValue(); // This is a dead node, leave it alone.
4499
4500 unsigned Opcode = Node->getOpcode();
4501 bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
4502 unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
4503
  // DivMod libcalls can still work on non-legal types when libcalls are used.
4505 EVT VT = Node->getValueType(ResNo: 0);
4506 if (VT.isVector() || !VT.isInteger())
4507 return SDValue();
4508
4509 if (!TLI.isTypeLegal(VT) && !TLI.isOperationCustom(Op: DivRemOpc, VT))
4510 return SDValue();
4511
4512 // If DIVREM is going to get expanded into a libcall,
4513 // but there is no libcall available, then don't combine.
4514 if (!TLI.isOperationLegalOrCustom(Op: DivRemOpc, VT) &&
4515 !isDivRemLibcallAvailable(Node, isSigned, TLI))
4516 return SDValue();
4517
4518 // If div is legal, it's better to do the normal expansion
4519 unsigned OtherOpcode = 0;
4520 if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
4521 OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
4522 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT))
4523 return SDValue();
4524 } else {
4525 OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4526 if (TLI.isOperationLegalOrCustom(Op: OtherOpcode, VT))
4527 return SDValue();
4528 }
4529
4530 SDValue Op0 = Node->getOperand(Num: 0);
4531 SDValue Op1 = Node->getOperand(Num: 1);
4532 SDValue combined;
4533 for (SDNode *User : Op0->uses()) {
4534 if (User == Node || User->getOpcode() == ISD::DELETED_NODE ||
4535 User->use_empty())
4536 continue;
4537 // Convert the other matching node(s), too;
4538 // otherwise, the DIVREM may get target-legalized into something
4539 // target-specific that we won't be able to recognize.
4540 unsigned UserOpc = User->getOpcode();
4541 if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
4542 User->getOperand(Num: 0) == Op0 &&
4543 User->getOperand(Num: 1) == Op1) {
4544 if (!combined) {
4545 if (UserOpc == OtherOpcode) {
4546 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: VT);
4547 combined = DAG.getNode(Opcode: DivRemOpc, DL: SDLoc(Node), VTList: VTs, N1: Op0, N2: Op1);
4548 } else if (UserOpc == DivRemOpc) {
4549 combined = SDValue(User, 0);
4550 } else {
4551 assert(UserOpc == Opcode);
4552 continue;
4553 }
4554 }
4555 if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
4556 CombineTo(N: User, Res: combined);
4557 else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
4558 CombineTo(N: User, Res: combined.getValue(R: 1));
4559 }
4560 }
4561 return combined;
4562}
4563
4564static SDValue simplifyDivRem(SDNode *N, SelectionDAG &DAG) {
4565 SDValue N0 = N->getOperand(Num: 0);
4566 SDValue N1 = N->getOperand(Num: 1);
4567 EVT VT = N->getValueType(ResNo: 0);
4568 SDLoc DL(N);
4569
4570 unsigned Opc = N->getOpcode();
4571 bool IsDiv = (ISD::SDIV == Opc) || (ISD::UDIV == Opc);
4572 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
4573
4574 // X / undef -> undef
4575 // X % undef -> undef
4576 // X / 0 -> undef
4577 // X % 0 -> undef
4578 // NOTE: This includes vectors where any divisor element is zero/undef.
4579 if (DAG.isUndef(Opcode: Opc, Ops: {N0, N1}))
4580 return DAG.getUNDEF(VT);
4581
4582 // undef / X -> 0
4583 // undef % X -> 0
4584 if (N0.isUndef())
4585 return DAG.getConstant(Val: 0, DL, VT);
4586
4587 // 0 / X -> 0
4588 // 0 % X -> 0
4589 ConstantSDNode *N0C = isConstOrConstSplat(N: N0);
4590 if (N0C && N0C->isZero())
4591 return N0;
4592
4593 // X / X -> 1
4594 // X % X -> 0
4595 if (N0 == N1)
4596 return DAG.getConstant(Val: IsDiv ? 1 : 0, DL, VT);
4597
4598 // X / 1 -> X
4599 // X % 1 -> 0
4600 // If this is a boolean op (single-bit element type), we can't have
4601 // division-by-zero or remainder-by-zero, so assume the divisor is 1.
4602 // TODO: Similarly, if we're zero-extending a boolean divisor, then assume
4603 // it's a 1.
4604 if ((N1C && N1C->isOne()) || (VT.getScalarType() == MVT::i1))
4605 return IsDiv ? N0 : DAG.getConstant(Val: 0, DL, VT);
4606
4607 return SDValue();
4608}
4609
4610SDValue DAGCombiner::visitSDIV(SDNode *N) {
4611 SDValue N0 = N->getOperand(Num: 0);
4612 SDValue N1 = N->getOperand(Num: 1);
4613 EVT VT = N->getValueType(ResNo: 0);
4614 EVT CCVT = getSetCCResultType(VT);
4615 SDLoc DL(N);
4616
4617 // fold (sdiv c1, c2) -> c1/c2
4618 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SDIV, DL, VT, Ops: {N0, N1}))
4619 return C;
4620
4621 // fold vector ops
4622 if (VT.isVector())
4623 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4624 return FoldedVOp;
4625
4626 // fold (sdiv X, -1) -> 0-X
4627 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
4628 if (N1C && N1C->isAllOnes())
4629 return DAG.getNegative(Val: N0, DL, VT);
4630
4631 // fold (sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
4632 if (N1C && N1C->isMinSignedValue())
4633 return DAG.getSelect(DL, VT, Cond: DAG.getSetCC(DL, VT: CCVT, LHS: N0, RHS: N1, Cond: ISD::SETEQ),
4634 LHS: DAG.getConstant(Val: 1, DL, VT),
4635 RHS: DAG.getConstant(Val: 0, DL, VT));
4636
4637 if (SDValue V = simplifyDivRem(N, DAG))
4638 return V;
4639
4640 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4641 return NewSel;
4642
  // If we know the sign bits of both operands are zero, strength reduce to a
  // udiv instead. Handles (X&15) /s 4 -> (X&15) >> 2
4645 if (DAG.SignBitIsZero(Op: N1) && DAG.SignBitIsZero(Op: N0))
4646 return DAG.getNode(Opcode: ISD::UDIV, DL, VT: N1.getValueType(), N1: N0, N2: N1);
4647
4648 if (SDValue V = visitSDIVLike(N0, N1, N)) {
    // If the corresponding remainder node exists, update its users with
    // (Dividend - (Quotient * Divisor)).
4651 if (SDNode *RemNode = DAG.getNodeIfExists(Opcode: ISD::SREM, VTList: N->getVTList(),
4652 Ops: { N0, N1 })) {
4653 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: V, N2: N1);
4654 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
4655 AddToWorklist(N: Mul.getNode());
4656 AddToWorklist(N: Sub.getNode());
4657 CombineTo(N: RemNode, Res: Sub);
4658 }
4659 return V;
4660 }
4661
4662 // sdiv, srem -> sdivrem
4663 // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
4664 // true. Otherwise, we break the simplification logic in visitREM().
4665 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4666 if (!N1C || TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
4667 if (SDValue DivRem = useDivRem(Node: N))
4668 return DivRem;
4669
4670 return SDValue();
4671}
4672
4673static bool isDivisorPowerOfTwo(SDValue Divisor) {
4674 // Helper for determining whether a value is a power-2 constant scalar or a
4675 // vector of such elements.
4676 auto IsPowerOfTwo = [](ConstantSDNode *C) {
4677 if (C->isZero() || C->isOpaque())
4678 return false;
4679 if (C->getAPIntValue().isPowerOf2())
4680 return true;
4681 if (C->getAPIntValue().isNegatedPowerOf2())
4682 return true;
4683 return false;
4684 };
4685
4686 return ISD::matchUnaryPredicate(Op: Divisor, Match: IsPowerOfTwo);
4687}
4688
4689SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4690 SDLoc DL(N);
4691 EVT VT = N->getValueType(ResNo: 0);
4692 EVT CCVT = getSetCCResultType(VT);
4693 unsigned BitWidth = VT.getScalarSizeInBits();
4694
4695 // fold (sdiv X, pow2) -> simple ops after legalize
4696 // FIXME: We check for the exact bit here because the generic lowering gives
4697 // better results in that case. The target-specific lowering should learn how
4698 // to handle exact sdivs efficiently.
4699 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(Divisor: N1)) {
4700 // Target-specific implementation of sdiv x, pow2.
4701 if (SDValue Res = BuildSDIVPow2(N))
4702 return Res;
4703
4704 // Create constants that are functions of the shift amount value.
4705 EVT ShiftAmtTy = getShiftAmountTy(LHSTy: N0.getValueType());
4706 SDValue Bits = DAG.getConstant(Val: BitWidth, DL, VT: ShiftAmtTy);
4707 SDValue C1 = DAG.getNode(Opcode: ISD::CTTZ, DL, VT, Operand: N1);
4708 C1 = DAG.getZExtOrTrunc(Op: C1, DL, VT: ShiftAmtTy);
4709 SDValue Inexact = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftAmtTy, N1: Bits, N2: C1);
4710 if (!isConstantOrConstantVector(N: Inexact))
4711 return SDValue();
4712
4713 // Splat the sign bit into the register
4714 SDValue Sign = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0,
4715 N2: DAG.getConstant(Val: BitWidth - 1, DL, VT: ShiftAmtTy));
4716 AddToWorklist(N: Sign.getNode());
4717
    // Add (N0 < 0) ? abs(divisor) - 1 : 0 to N0, so the arithmetic shift
    // below rounds the result toward zero.
4719 SDValue Srl = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Sign, N2: Inexact);
4720 AddToWorklist(N: Srl.getNode());
4721 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0, N2: Srl);
4722 AddToWorklist(N: Add.getNode());
4723 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Add, N2: C1);
4724 AddToWorklist(N: Sra.getNode());
4725
    // Special case: (sdiv X, 1) -> X
    // Special case: (sdiv X, -1) -> 0-X
4728 SDValue One = DAG.getConstant(Val: 1, DL, VT);
4729 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
4730 SDValue IsOne = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: One, Cond: ISD::SETEQ);
4731 SDValue IsAllOnes = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: AllOnes, Cond: ISD::SETEQ);
4732 SDValue IsOneOrAllOnes = DAG.getNode(Opcode: ISD::OR, DL, VT: CCVT, N1: IsOne, N2: IsAllOnes);
4733 Sra = DAG.getSelect(DL, VT, Cond: IsOneOrAllOnes, LHS: N0, RHS: Sra);
4734
4735 // If dividing by a positive value, we're done. Otherwise, the result must
4736 // be negated.
4737 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
4738 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Zero, N2: Sra);
4739
4740 // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
4741 SDValue IsNeg = DAG.getSetCC(DL, VT: CCVT, LHS: N1, RHS: Zero, Cond: ISD::SETLT);
4742 SDValue Res = DAG.getSelect(DL, VT, Cond: IsNeg, LHS: Sub, RHS: Sra);
4743 return Res;
4744 }
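  // A worked example of the power-of-2 sequence above (illustrative, assuming
  // i32 and a divisor of 8): C1 = cttz(8) = 3, Inexact = 32 - 3 = 29,
  // Sign = sra(x, 31), and Srl = srl(Sign, 29), which is 7 for negative x and
  // 0 otherwise, so Sra = sra(x + Srl, 3). For x = -9 this yields
  // sra(-2, 3) = -1, matching C's truncating division -9 / 8.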
4745
4746 // If integer divide is expensive and we satisfy the requirements, emit an
4747 // alternate sequence. Targets may check function attributes for size/speed
4748 // trade-offs.
4749 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4750 if (isConstantOrConstantVector(N: N1) &&
4751 !TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
4752 if (SDValue Op = BuildSDIV(N))
4753 return Op;
4754
4755 return SDValue();
4756}
4757
4758SDValue DAGCombiner::visitUDIV(SDNode *N) {
4759 SDValue N0 = N->getOperand(Num: 0);
4760 SDValue N1 = N->getOperand(Num: 1);
4761 EVT VT = N->getValueType(ResNo: 0);
4762 EVT CCVT = getSetCCResultType(VT);
4763 SDLoc DL(N);
4764
4765 // fold (udiv c1, c2) -> c1/c2
4766 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::UDIV, DL, VT, Ops: {N0, N1}))
4767 return C;
4768
4769 // fold vector ops
4770 if (VT.isVector())
4771 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4772 return FoldedVOp;
4773
4774 // fold (udiv X, -1) -> select(X == -1, 1, 0)
4775 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
4776 if (N1C && N1C->isAllOnes() && CCVT.isVector() == VT.isVector()) {
4777 return DAG.getSelect(DL, VT, Cond: DAG.getSetCC(DL, VT: CCVT, LHS: N0, RHS: N1, Cond: ISD::SETEQ),
4778 LHS: DAG.getConstant(Val: 1, DL, VT),
4779 RHS: DAG.getConstant(Val: 0, DL, VT));
4780 }
4781
4782 if (SDValue V = simplifyDivRem(N, DAG))
4783 return V;
4784
4785 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4786 return NewSel;
4787
4788 if (SDValue V = visitUDIVLike(N0, N1, N)) {
    // If the corresponding remainder node exists, update its users with
    // (Dividend - (Quotient * Divisor)).
4791 if (SDNode *RemNode = DAG.getNodeIfExists(Opcode: ISD::UREM, VTList: N->getVTList(),
4792 Ops: { N0, N1 })) {
4793 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: V, N2: N1);
4794 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
4795 AddToWorklist(N: Mul.getNode());
4796 AddToWorklist(N: Sub.getNode());
4797 CombineTo(N: RemNode, Res: Sub);
4798 }
4799 return V;
4800 }
4801
  // udiv, urem -> udivrem
  // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
  // true. Otherwise, we break the simplification logic in visitREM().
4805 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4806 if (!N1C || TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
4807 if (SDValue DivRem = useDivRem(Node: N))
4808 return DivRem;
4809
4810 return SDValue();
4811}
4812
4813SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
4814 SDLoc DL(N);
4815 EVT VT = N->getValueType(ResNo: 0);
4816
4817 // fold (udiv x, (1 << c)) -> x >>u c
4818 if (isConstantOrConstantVector(N: N1, /*NoOpaques*/ true)) {
4819 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
4820 AddToWorklist(N: LogBase2.getNode());
4821
4822 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
4823 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ShiftVT);
4824 AddToWorklist(N: Trunc.getNode());
4825 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Trunc);
4826 }
4827 }
4828
4829 // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
4830 if (N1.getOpcode() == ISD::SHL) {
4831 SDValue N10 = N1.getOperand(i: 0);
4832 if (isConstantOrConstantVector(N: N10, /*NoOpaques*/ true)) {
4833 if (SDValue LogBase2 = BuildLogBase2(V: N10, DL)) {
4834 AddToWorklist(N: LogBase2.getNode());
4835
4836 EVT ADDVT = N1.getOperand(i: 1).getValueType();
4837 SDValue Trunc = DAG.getZExtOrTrunc(Op: LogBase2, DL, VT: ADDVT);
4838 AddToWorklist(N: Trunc.getNode());
4839 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT: ADDVT, N1: N1.getOperand(i: 1), N2: Trunc);
4840 AddToWorklist(N: Add.getNode());
4841 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Add);
4842 }
4843 }
4844 }
4845
4846 // fold (udiv x, c) -> alternate
4847 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4848 if (isConstantOrConstantVector(N: N1) &&
4849 !TLI.isIntDivCheap(VT: N->getValueType(ResNo: 0), Attr))
4850 if (SDValue Op = BuildUDIV(N))
4851 return Op;
4852
4853 return SDValue();
4854}
4855
4856SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
4857 if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(Divisor: N1) &&
4858 !DAG.doesNodeExist(Opcode: ISD::SDIV, VTList: N->getVTList(), Ops: {N0, N1})) {
4859 // Target-specific implementation of srem x, pow2.
4860 if (SDValue Res = BuildSREMPow2(N))
4861 return Res;
4862 }
4863 return SDValue();
4864}
4865
4866// handles ISD::SREM and ISD::UREM
4867SDValue DAGCombiner::visitREM(SDNode *N) {
4868 unsigned Opcode = N->getOpcode();
4869 SDValue N0 = N->getOperand(Num: 0);
4870 SDValue N1 = N->getOperand(Num: 1);
4871 EVT VT = N->getValueType(ResNo: 0);
4872 EVT CCVT = getSetCCResultType(VT);
4873
4874 bool isSigned = (Opcode == ISD::SREM);
4875 SDLoc DL(N);
4876
4877 // fold (rem c1, c2) -> c1%c2
4878 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
4879 return C;
4880
4881 // fold (urem X, -1) -> select(FX == -1, 0, FX)
4882 // Freeze the numerator to avoid a miscompile with an undefined value.
4883 if (!isSigned && llvm::isAllOnesOrAllOnesSplat(V: N1, /*AllowUndefs*/ false) &&
4884 CCVT.isVector() == VT.isVector()) {
4885 SDValue F0 = DAG.getFreeze(V: N0);
4886 SDValue EqualsNeg1 = DAG.getSetCC(DL, VT: CCVT, LHS: F0, RHS: N1, Cond: ISD::SETEQ);
4887 return DAG.getSelect(DL, VT, Cond: EqualsNeg1, LHS: DAG.getConstant(Val: 0, DL, VT), RHS: F0);
4888 }
4889
4890 if (SDValue V = simplifyDivRem(N, DAG))
4891 return V;
4892
4893 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
4894 return NewSel;
4895
4896 if (isSigned) {
4897 // If we know the sign bits of both operands are zero, strength reduce to a
4898 // urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15
4899 if (DAG.SignBitIsZero(Op: N1) && DAG.SignBitIsZero(Op: N0))
4900 return DAG.getNode(Opcode: ISD::UREM, DL, VT, N1: N0, N2: N1);
4901 } else {
4902 if (DAG.isKnownToBeAPowerOfTwo(Val: N1)) {
4903 // fold (urem x, pow2) -> (and x, pow2-1)
4904 SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
4905 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: NegOne);
4906 AddToWorklist(N: Add.getNode());
4907 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: Add);
4908 }
4909 // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
4910 // fold (urem x, (lshr pow2, y)) -> (and x, (add (lshr pow2, y), -1))
4911 // TODO: We should sink the following into isKnownToBePowerOfTwo
4912 // using a OrZero parameter analogous to our handling in ValueTracking.
4913 if ((N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) &&
4914 DAG.isKnownToBeAPowerOfTwo(Val: N1.getOperand(i: 0))) {
4915 SDValue NegOne = DAG.getAllOnesConstant(DL, VT);
4916 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1, N2: NegOne);
4917 AddToWorklist(N: Add.getNode());
4918 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: Add);
4919 }
4920 }
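  // Illustrative examples of the unsigned folds above (not from the source
  // comments): (urem x, 16) becomes (and x, 15), and (urem x, (shl 8, y))
  // becomes (and x, (add (shl 8, y), -1)).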
4921
4922 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
4923
4924 // If X/C can be simplified by the division-by-constant logic, lower
4925 // X%C to the equivalent of X-X/C*C.
4926 // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
4927 // speculative DIV must not cause a DIVREM conversion. We guard against this
4928 // by skipping the simplification if isIntDivCheap(). When div is not cheap,
4929 // combine will not return a DIVREM. Regardless, checking cheapness here
4930 // makes sense since the simplification results in fatter code.
4931 if (DAG.isKnownNeverZero(Op: N1) && !TLI.isIntDivCheap(VT, Attr)) {
4932 if (isSigned) {
4933 // check if we can build faster implementation for srem
4934 if (SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N))
4935 return OptimizedRem;
4936 }
4937
4938 SDValue OptimizedDiv =
4939 isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
4940 if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
4941 // If the equivalent Div node also exists, update its users.
4942 unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
4943 if (SDNode *DivNode = DAG.getNodeIfExists(Opcode: DivOpcode, VTList: N->getVTList(),
4944 Ops: { N0, N1 }))
4945 CombineTo(N: DivNode, Res: OptimizedDiv);
4946 SDValue Mul = DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: OptimizedDiv, N2: N1);
4947 SDValue Sub = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: N0, N2: Mul);
4948 AddToWorklist(N: OptimizedDiv.getNode());
4949 AddToWorklist(N: Mul.getNode());
4950 return Sub;
4951 }
4952 }
4953
4954 // sdiv, srem -> sdivrem
4955 if (SDValue DivRem = useDivRem(Node: N))
4956 return DivRem.getValue(R: 1);
4957
4958 return SDValue();
4959}
4960
4961SDValue DAGCombiner::visitMULHS(SDNode *N) {
4962 SDValue N0 = N->getOperand(Num: 0);
4963 SDValue N1 = N->getOperand(Num: 1);
4964 EVT VT = N->getValueType(ResNo: 0);
4965 SDLoc DL(N);
4966
4967 // fold (mulhs c1, c2)
4968 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MULHS, DL, VT, Ops: {N0, N1}))
4969 return C;
4970
4971 // canonicalize constant to RHS.
4972 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
4973 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
4974 return DAG.getNode(Opcode: ISD::MULHS, DL, VTList: N->getVTList(), N1, N2: N0);
4975
4976 if (VT.isVector()) {
4977 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
4978 return FoldedVOp;
4979
4980 // fold (mulhs x, 0) -> 0
4981 // do not return N1, because undef node may exist.
4982 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
4983 return DAG.getConstant(Val: 0, DL, VT);
4984 }
4985
4986 // fold (mulhs x, 0) -> 0
4987 if (isNullConstant(V: N1))
4988 return N1;
4989
4990 // fold (mulhs x, 1) -> (sra x, size(x)-1)
4991 if (isOneConstant(V: N1))
4992 return DAG.getNode(Opcode: ISD::SRA, DL, VT: N0.getValueType(), N1: N0,
4993 N2: DAG.getConstant(Val: N0.getScalarValueSizeInBits() - 1, DL,
4994 VT: getShiftAmountTy(LHSTy: N0.getValueType())));
4995
4996 // fold (mulhs x, undef) -> 0
4997 if (N0.isUndef() || N1.isUndef())
4998 return DAG.getConstant(Val: 0, DL, VT);
4999
5000 // If the type twice as wide is legal, transform the mulhs to a wider multiply
5001 // plus a shift.
5002 if (!TLI.isOperationLegalOrCustom(Op: ISD::MULHS, VT) && VT.isSimple() &&
5003 !VT.isVector()) {
5004 MVT Simple = VT.getSimpleVT();
5005 unsigned SimpleSize = Simple.getSizeInBits();
5006 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5007 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5008 N0 = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N0);
5009 N1 = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N1);
5010 N1 = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: N0, N2: N1);
5011 N1 = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1,
5012 N2: DAG.getConstant(Val: SimpleSize, DL,
5013 VT: getShiftAmountTy(LHSTy: N1.getValueType())));
5014 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N1);
5015 }
5016 }
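  // An illustrative example of the widening above (assuming i16 MULHS and a
  // legal i32 MUL): (mulhs a, b) becomes
  // (trunc (srl (mul (sext a to i32), (sext b to i32)), 16)).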
5017
5018 return SDValue();
5019}
5020
5021SDValue DAGCombiner::visitMULHU(SDNode *N) {
5022 SDValue N0 = N->getOperand(Num: 0);
5023 SDValue N1 = N->getOperand(Num: 1);
5024 EVT VT = N->getValueType(ResNo: 0);
5025 SDLoc DL(N);
5026
5027 // fold (mulhu c1, c2)
5028 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::MULHU, DL, VT, Ops: {N0, N1}))
5029 return C;
5030
5031 // canonicalize constant to RHS.
5032 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5033 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5034 return DAG.getNode(Opcode: ISD::MULHU, DL, VTList: N->getVTList(), N1, N2: N0);
5035
5036 if (VT.isVector()) {
5037 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5038 return FoldedVOp;
5039
5040 // fold (mulhu x, 0) -> 0
5041 // do not return N1, because undef node may exist.
5042 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
5043 return DAG.getConstant(Val: 0, DL, VT);
5044 }
5045
5046 // fold (mulhu x, 0) -> 0
5047 if (isNullConstant(V: N1))
5048 return N1;
5049
5050 // fold (mulhu x, 1) -> 0
5051 if (isOneConstant(V: N1))
5052 return DAG.getConstant(Val: 0, DL, VT: N0.getValueType());
5053
5054 // fold (mulhu x, undef) -> 0
5055 if (N0.isUndef() || N1.isUndef())
5056 return DAG.getConstant(Val: 0, DL, VT);
5057
5058 // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c)
5059 if (isConstantOrConstantVector(N: N1, /*NoOpaques*/ true) &&
5060 hasOperation(Opcode: ISD::SRL, VT)) {
5061 if (SDValue LogBase2 = BuildLogBase2(V: N1, DL)) {
5062 unsigned NumEltBits = VT.getScalarSizeInBits();
5063 SDValue SRLAmt = DAG.getNode(
5064 Opcode: ISD::SUB, DL, VT, N1: DAG.getConstant(Val: NumEltBits, DL, VT), N2: LogBase2);
5065 EVT ShiftVT = getShiftAmountTy(LHSTy: N0.getValueType());
5066 SDValue Trunc = DAG.getZExtOrTrunc(Op: SRLAmt, DL, VT: ShiftVT);
5067 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: Trunc);
5068 }
5069 }
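  // An illustrative example of the fold above (assuming i32): (mulhu x, 16)
  // takes the high 32 bits of the 64-bit product x * 16, which is simply
  // (srl x, 28).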
5070
5071 // If the type twice as wide is legal, transform the mulhu to a wider multiply
5072 // plus a shift.
5073 if (!TLI.isOperationLegalOrCustom(Op: ISD::MULHU, VT) && VT.isSimple() &&
5074 !VT.isVector()) {
5075 MVT Simple = VT.getSimpleVT();
5076 unsigned SimpleSize = Simple.getSizeInBits();
5077 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5078 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5079 N0 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N0);
5080 N1 = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N1);
5081 N1 = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: N0, N2: N1);
5082 N1 = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1,
5083 N2: DAG.getConstant(Val: SimpleSize, DL,
5084 VT: getShiftAmountTy(LHSTy: N1.getValueType())));
5085 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N1);
5086 }
5087 }
5088
5089 // Simplify the operands using demanded-bits information.
5090 // We don't have demanded bits support for MULHU so this just enables constant
5091 // folding based on known bits.
5092 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
5093 return SDValue(N, 0);
5094
5095 return SDValue();
5096}
5097
5098SDValue DAGCombiner::visitAVG(SDNode *N) {
5099 unsigned Opcode = N->getOpcode();
5100 SDValue N0 = N->getOperand(Num: 0);
5101 SDValue N1 = N->getOperand(Num: 1);
5102 EVT VT = N->getValueType(ResNo: 0);
5103 SDLoc DL(N);
5104
5105 // fold (avg c1, c2)
5106 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5107 return C;
5108
5109 // canonicalize constant to RHS.
5110 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5111 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5112 return DAG.getNode(Opcode, DL, VTList: N->getVTList(), N1, N2: N0);
5113
5114 if (VT.isVector()) {
5115 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5116 return FoldedVOp;
5117
5118 // fold (avgfloor x, 0) -> x >> 1
5119 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode())) {
5120 if (Opcode == ISD::AVGFLOORS)
5121 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0, N2: DAG.getConstant(Val: 1, DL, VT));
5122 if (Opcode == ISD::AVGFLOORU)
5123 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0, N2: DAG.getConstant(Val: 1, DL, VT));
5124 }
5125 }
5126
5127 // fold (avg x, undef) -> x
5128 if (N0.isUndef())
5129 return N1;
5130 if (N1.isUndef())
5131 return N0;
5132
5133 // Fold (avg x, x) --> x
5134 if (N0 == N1 && Level >= AfterLegalizeTypes)
5135 return N0;
5136
5137 // TODO If we use avg for scalars anywhere, we can add (avgfl x, 0) -> x >> 1
5138
5139 return SDValue();
5140}
5141
5142SDValue DAGCombiner::visitABD(SDNode *N) {
5143 unsigned Opcode = N->getOpcode();
5144 SDValue N0 = N->getOperand(Num: 0);
5145 SDValue N1 = N->getOperand(Num: 1);
5146 EVT VT = N->getValueType(ResNo: 0);
5147 SDLoc DL(N);
5148
5149 // fold (abd c1, c2)
5150 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5151 return C;
5152
5153 // canonicalize constant to RHS.
5154 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5155 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5156 return DAG.getNode(Opcode, DL, VTList: N->getVTList(), N1, N2: N0);
5157
5158 if (VT.isVector()) {
5159 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5160 return FoldedVOp;
5161
5162 // fold (abds x, 0) -> abs x
5163 // fold (abdu x, 0) -> x
5164 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode())) {
5165 if (Opcode == ISD::ABDS)
5166 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: N0);
5167 if (Opcode == ISD::ABDU)
5168 return N0;
5169 }
5170 }
5171
5172 // fold (abd x, undef) -> 0
5173 if (N0.isUndef() || N1.isUndef())
5174 return DAG.getConstant(Val: 0, DL, VT);
5175
5176 // fold (abds x, y) -> (abdu x, y) iff both args are known positive
5177 if (Opcode == ISD::ABDS && hasOperation(Opcode: ISD::ABDU, VT) &&
5178 DAG.SignBitIsZero(Op: N0) && DAG.SignBitIsZero(Op: N1))
5179 return DAG.getNode(Opcode: ISD::ABDU, DL, VT, N1, N2: N0);
5180
5181 return SDValue();
5182}
5183
/// Perform optimizations common to nodes that compute two values. LoOp and
/// HiOp give the opcodes for the two computations that are being performed.
/// Return the combined result if a simplification was made.
5187SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
5188 unsigned HiOp) {
5189 // If the high half is not needed, just compute the low half.
5190 bool HiExists = N->hasAnyUseOfValue(Value: 1);
5191 if (!HiExists && (!LegalOperations ||
5192 TLI.isOperationLegalOrCustom(Op: LoOp, VT: N->getValueType(ResNo: 0)))) {
5193 SDValue Res = DAG.getNode(Opcode: LoOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Ops: N->ops());
5194 return CombineTo(N, Res0: Res, Res1: Res);
5195 }
5196
5197 // If the low half is not needed, just compute the high half.
5198 bool LoExists = N->hasAnyUseOfValue(Value: 0);
5199 if (!LoExists && (!LegalOperations ||
5200 TLI.isOperationLegalOrCustom(Op: HiOp, VT: N->getValueType(ResNo: 1)))) {
5201 SDValue Res = DAG.getNode(Opcode: HiOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 1), Ops: N->ops());
5202 return CombineTo(N, Res0: Res, Res1: Res);
5203 }
5204
5205 // If both halves are used, return as it is.
5206 if (LoExists && HiExists)
5207 return SDValue();
5208
5209 // If the two computed results can be simplified separately, separate them.
5210 if (LoExists) {
5211 SDValue Lo = DAG.getNode(Opcode: LoOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Ops: N->ops());
5212 AddToWorklist(N: Lo.getNode());
5213 SDValue LoOpt = combine(N: Lo.getNode());
5214 if (LoOpt.getNode() && LoOpt.getNode() != Lo.getNode() &&
5215 (!LegalOperations ||
5216 TLI.isOperationLegalOrCustom(Op: LoOpt.getOpcode(), VT: LoOpt.getValueType())))
5217 return CombineTo(N, Res0: LoOpt, Res1: LoOpt);
5218 }
5219
5220 if (HiExists) {
5221 SDValue Hi = DAG.getNode(Opcode: HiOp, DL: SDLoc(N), VT: N->getValueType(ResNo: 1), Ops: N->ops());
5222 AddToWorklist(N: Hi.getNode());
5223 SDValue HiOpt = combine(N: Hi.getNode());
5224 if (HiOpt.getNode() && HiOpt != Hi &&
5225 (!LegalOperations ||
5226 TLI.isOperationLegalOrCustom(Op: HiOpt.getOpcode(), VT: HiOpt.getValueType())))
5227 return CombineTo(N, Res0: HiOpt, Res1: HiOpt);
5228 }
5229
5230 return SDValue();
5231}
5232
5233SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
5234 if (SDValue Res = SimplifyNodeWithTwoResults(N, LoOp: ISD::MUL, HiOp: ISD::MULHS))
5235 return Res;
5236
5237 SDValue N0 = N->getOperand(Num: 0);
5238 SDValue N1 = N->getOperand(Num: 1);
5239 EVT VT = N->getValueType(ResNo: 0);
5240 SDLoc DL(N);
5241
5242 // Constant fold.
5243 if (isa<ConstantSDNode>(Val: N0) && isa<ConstantSDNode>(Val: N1))
5244 return DAG.getNode(Opcode: ISD::SMUL_LOHI, DL, VTList: N->getVTList(), N1: N0, N2: N1);
5245
5246 // canonicalize constant to RHS (vector doesn't have to splat)
5247 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5248 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5249 return DAG.getNode(Opcode: ISD::SMUL_LOHI, DL, VTList: N->getVTList(), N1, N2: N0);
5250
  // If the type twice as wide is legal, transform the smul_lohi to a wider
  // multiply plus a shift.
5253 if (VT.isSimple() && !VT.isVector()) {
5254 MVT Simple = VT.getSimpleVT();
5255 unsigned SimpleSize = Simple.getSizeInBits();
5256 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5257 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5258 SDValue Lo = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N0);
5259 SDValue Hi = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: NewVT, Operand: N1);
5260 Lo = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: Lo, N2: Hi);
5261 // Compute the high part as N1.
5262 Hi = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1: Lo,
5263 N2: DAG.getConstant(Val: SimpleSize, DL,
5264 VT: getShiftAmountTy(LHSTy: Lo.getValueType())));
5265 Hi = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Hi);
5266 // Compute the low part as N0.
5267 Lo = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Lo);
5268 return CombineTo(N, Res0: Lo, Res1: Hi);
5269 }
5270 }
5271
5272 return SDValue();
5273}
5274
5275SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
5276 if (SDValue Res = SimplifyNodeWithTwoResults(N, LoOp: ISD::MUL, HiOp: ISD::MULHU))
5277 return Res;
5278
5279 SDValue N0 = N->getOperand(Num: 0);
5280 SDValue N1 = N->getOperand(Num: 1);
5281 EVT VT = N->getValueType(ResNo: 0);
5282 SDLoc DL(N);
5283
5284 // Constant fold.
5285 if (isa<ConstantSDNode>(Val: N0) && isa<ConstantSDNode>(Val: N1))
5286 return DAG.getNode(Opcode: ISD::UMUL_LOHI, DL, VTList: N->getVTList(), N1: N0, N2: N1);
5287
5288 // canonicalize constant to RHS (vector doesn't have to splat)
5289 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5290 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5291 return DAG.getNode(Opcode: ISD::UMUL_LOHI, DL, VTList: N->getVTList(), N1, N2: N0);
5292
5293 // (umul_lohi N0, 0) -> (0, 0)
5294 if (isNullConstant(V: N1)) {
5295 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
5296 return CombineTo(N, Res0: Zero, Res1: Zero);
5297 }
5298
5299 // (umul_lohi N0, 1) -> (N0, 0)
5300 if (isOneConstant(V: N1)) {
5301 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
5302 return CombineTo(N, Res0: N0, Res1: Zero);
5303 }
5304
  // If the type twice as wide is legal, transform the umul_lohi to a wider
  // multiply plus a shift.
5307 if (VT.isSimple() && !VT.isVector()) {
5308 MVT Simple = VT.getSimpleVT();
5309 unsigned SimpleSize = Simple.getSizeInBits();
5310 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SimpleSize*2);
5311 if (TLI.isOperationLegal(Op: ISD::MUL, VT: NewVT)) {
5312 SDValue Lo = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N0);
5313 SDValue Hi = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: NewVT, Operand: N1);
5314 Lo = DAG.getNode(Opcode: ISD::MUL, DL, VT: NewVT, N1: Lo, N2: Hi);
5315 // Compute the high part as N1.
5316 Hi = DAG.getNode(Opcode: ISD::SRL, DL, VT: NewVT, N1: Lo,
5317 N2: DAG.getConstant(Val: SimpleSize, DL,
5318 VT: getShiftAmountTy(LHSTy: Lo.getValueType())));
5319 Hi = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Hi);
5320 // Compute the low part as N0.
5321 Lo = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Lo);
5322 return CombineTo(N, Res0: Lo, Res1: Hi);
5323 }
5324 }
5325
5326 return SDValue();
5327}
5328
5329SDValue DAGCombiner::visitMULO(SDNode *N) {
5330 SDValue N0 = N->getOperand(Num: 0);
5331 SDValue N1 = N->getOperand(Num: 1);
5332 EVT VT = N0.getValueType();
5333 bool IsSigned = (ISD::SMULO == N->getOpcode());
5334
5335 EVT CarryVT = N->getValueType(ResNo: 1);
5336 SDLoc DL(N);
5337
5338 ConstantSDNode *N0C = isConstOrConstSplat(N: N0);
5339 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
5340
5341 // fold operation with constant operands.
5342 // TODO: Move this to FoldConstantArithmetic when it supports nodes with
5343 // multiple results.
5344 if (N0C && N1C) {
5345 bool Overflow;
5346 APInt Result =
5347 IsSigned ? N0C->getAPIntValue().smul_ov(RHS: N1C->getAPIntValue(), Overflow)
5348 : N0C->getAPIntValue().umul_ov(RHS: N1C->getAPIntValue(), Overflow);
5349 return CombineTo(N, Res0: DAG.getConstant(Val: Result, DL, VT),
5350 Res1: DAG.getBoolConstant(V: Overflow, DL, VT: CarryVT, OpVT: CarryVT));
5351 }
5352
5353 // canonicalize constant to RHS.
5354 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5355 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5356 return DAG.getNode(Opcode: N->getOpcode(), DL, VTList: N->getVTList(), N1, N2: N0);
5357
5358 // fold (mulo x, 0) -> 0 + no carry out
5359 if (isNullOrNullSplat(V: N1))
5360 return CombineTo(N, Res0: DAG.getConstant(Val: 0, DL, VT),
5361 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
5362
5363 // (mulo x, 2) -> (addo x, x)
5364 // FIXME: This needs a freeze.
5365 if (N1C && N1C->getAPIntValue() == 2 &&
5366 (!IsSigned || VT.getScalarSizeInBits() > 2))
5367 return DAG.getNode(Opcode: IsSigned ? ISD::SADDO : ISD::UADDO, DL,
5368 VTList: N->getVTList(), N1: N0, N2: N0);
5369
5370 // A 1 bit SMULO overflows if both inputs are 1.
5371 if (IsSigned && VT.getScalarSizeInBits() == 1) {
5372 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0, N2: N1);
5373 SDValue Cmp = DAG.getSetCC(DL, VT: CarryVT, LHS: And,
5374 RHS: DAG.getConstant(Val: 0, DL, VT), Cond: ISD::SETNE);
5375 return CombineTo(N, Res0: And, Res1: Cmp);
5376 }
5377
5378 // If it cannot overflow, transform into a mul.
5379 if (DAG.willNotOverflowMul(IsSigned, N0, N1))
5380 return CombineTo(N, Res0: DAG.getNode(Opcode: ISD::MUL, DL, VT, N1: N0, N2: N1),
5381 Res1: DAG.getConstant(Val: 0, DL, VT: CarryVT));
5382 return SDValue();
5383}
5384
// Function to calculate whether the Min/Max pair of SDNodes (potentially
// swapped around) make a saturate pattern, clamping to between a signed
// saturate of -2^(BW-1) and 2^(BW-1)-1, or an unsigned saturate of 0 and
// 2^BW-1. Returns the node being clamped and the bitwidth of the clamp in BW.
// Works with both SMIN/SMAX nodes and with a setcc/select combo. The operands
// are the same as SimplifySelectCC: N0<N1 ? N2 : N3.
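// For example (an illustrative sketch): smin(smax(x, -128), 127), in either
// order, clamps x to the signed i8 range and returns x with BW = 8 and
// Unsigned = false, while smin(smax(x, 0), 255) returns x with BW = 8 and
// Unsigned = true.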
5391static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2,
5392 SDValue N3, ISD::CondCode CC, unsigned &BW,
5393 bool &Unsigned, SelectionDAG &DAG) {
5394 auto isSignedMinMax = [&](SDValue N0, SDValue N1, SDValue N2, SDValue N3,
5395 ISD::CondCode CC) {
5396 // The compare and select operand should be the same or the select operands
5397 // should be truncated versions of the comparison.
5398 if (N0 != N2 && (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(i: 0)))
5399 return 0;
5400 // The constants need to be the same or a truncated version of each other.
5401 ConstantSDNode *N1C = isConstOrConstSplat(N: peekThroughTruncates(V: N1));
5402 ConstantSDNode *N3C = isConstOrConstSplat(N: peekThroughTruncates(V: N3));
5403 if (!N1C || !N3C)
5404 return 0;
5405 const APInt &C1 = N1C->getAPIntValue().trunc(width: N1.getScalarValueSizeInBits());
5406 const APInt &C2 = N3C->getAPIntValue().trunc(width: N3.getScalarValueSizeInBits());
5407 if (C1.getBitWidth() < C2.getBitWidth() || C1 != C2.sext(width: C1.getBitWidth()))
5408 return 0;
5409 return CC == ISD::SETLT ? ISD::SMIN : (CC == ISD::SETGT ? ISD::SMAX : 0);
5410 };
5411
5412 // Check the initial value is a SMIN/SMAX equivalent.
5413 unsigned Opcode0 = isSignedMinMax(N0, N1, N2, N3, CC);
5414 if (!Opcode0)
5415 return SDValue();
5416
  // We may only need one range check if the fptosi can never produce
  // the upper value.
5419 if (N0.getOpcode() == ISD::FP_TO_SINT && Opcode0 == ISD::SMAX) {
5420 if (isNullOrNullSplat(V: N3)) {
5421 EVT IntVT = N0.getValueType().getScalarType();
5422 EVT FPVT = N0.getOperand(i: 0).getValueType().getScalarType();
5423 if (FPVT.isSimple()) {
5424 Type *InputTy = FPVT.getTypeForEVT(Context&: *DAG.getContext());
5425 const fltSemantics &Semantics = InputTy->getFltSemantics();
5426 uint32_t MinBitWidth =
5427 APFloatBase::semanticsIntSizeInBits(Semantics, /*isSigned*/ true);
5428 if (IntVT.getSizeInBits() >= MinBitWidth) {
5429 Unsigned = true;
5430 BW = PowerOf2Ceil(A: MinBitWidth);
5431 return N0;
5432 }
5433 }
5434 }
5435 }
5436
5437 SDValue N00, N01, N02, N03;
5438 ISD::CondCode N0CC;
5439 switch (N0.getOpcode()) {
5440 case ISD::SMIN:
5441 case ISD::SMAX:
5442 N00 = N02 = N0.getOperand(i: 0);
5443 N01 = N03 = N0.getOperand(i: 1);
5444 N0CC = N0.getOpcode() == ISD::SMIN ? ISD::SETLT : ISD::SETGT;
5445 break;
5446 case ISD::SELECT_CC:
5447 N00 = N0.getOperand(i: 0);
5448 N01 = N0.getOperand(i: 1);
5449 N02 = N0.getOperand(i: 2);
5450 N03 = N0.getOperand(i: 3);
5451 N0CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 4))->get();
5452 break;
5453 case ISD::SELECT:
5454 case ISD::VSELECT:
5455 if (N0.getOperand(i: 0).getOpcode() != ISD::SETCC)
5456 return SDValue();
5457 N00 = N0.getOperand(i: 0).getOperand(i: 0);
5458 N01 = N0.getOperand(i: 0).getOperand(i: 1);
5459 N02 = N0.getOperand(i: 1);
5460 N03 = N0.getOperand(i: 2);
5461 N0CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 0).getOperand(i: 2))->get();
5462 break;
5463 default:
5464 return SDValue();
5465 }
5466
5467 unsigned Opcode1 = isSignedMinMax(N00, N01, N02, N03, N0CC);
5468 if (!Opcode1 || Opcode0 == Opcode1)
5469 return SDValue();
5470
5471 ConstantSDNode *MinCOp = isConstOrConstSplat(N: Opcode0 == ISD::SMIN ? N1 : N01);
5472 ConstantSDNode *MaxCOp = isConstOrConstSplat(N: Opcode0 == ISD::SMIN ? N01 : N1);
5473 if (!MinCOp || !MaxCOp || MinCOp->getValueType(ResNo: 0) != MaxCOp->getValueType(ResNo: 0))
5474 return SDValue();
5475
5476 const APInt &MinC = MinCOp->getAPIntValue();
5477 const APInt &MaxC = MaxCOp->getAPIntValue();
5478 APInt MinCPlus1 = MinC + 1;
5479 if (-MaxC == MinCPlus1 && MinCPlus1.isPowerOf2()) {
5480 BW = MinCPlus1.exactLogBase2() + 1;
5481 Unsigned = false;
5482 return N02;
5483 }
5484
5485 if (MaxC == 0 && MinCPlus1.isPowerOf2()) {
5486 BW = MinCPlus1.exactLogBase2();
5487 Unsigned = true;
5488 return N02;
5489 }
5490
5491 return SDValue();
5492}
5493
5494static SDValue PerformMinMaxFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5495 SDValue N3, ISD::CondCode CC,
5496 SelectionDAG &DAG) {
5497 unsigned BW;
5498 bool Unsigned;
5499 SDValue Fp = isSaturatingMinMax(N0, N1, N2, N3, CC, BW, Unsigned, DAG);
5500 if (!Fp || Fp.getOpcode() != ISD::FP_TO_SINT)
5501 return SDValue();
5502 EVT FPVT = Fp.getOperand(i: 0).getValueType();
5503 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW);
5504 if (FPVT.isVector())
5505 NewVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewVT,
5506 EC: FPVT.getVectorElementCount());
5507 unsigned NewOpc = Unsigned ? ISD::FP_TO_UINT_SAT : ISD::FP_TO_SINT_SAT;
5508 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(Op: NewOpc, FPVT, VT: NewVT))
5509 return SDValue();
5510 SDLoc DL(Fp);
5511 SDValue Sat = DAG.getNode(Opcode: NewOpc, DL, VT: NewVT, N1: Fp.getOperand(i: 0),
5512 N2: DAG.getValueType(NewVT.getScalarType()));
5513 return DAG.getExtOrTrunc(IsSigned: !Unsigned, Op: Sat, DL, VT: N2->getValueType(ResNo: 0));
5514}
5515
5516static SDValue PerformUMinFpToSatCombine(SDValue N0, SDValue N1, SDValue N2,
5517 SDValue N3, ISD::CondCode CC,
5518 SelectionDAG &DAG) {
  // We are looking for UMIN(FPTOUI(X), (2^n)-1), which may have come via a
  // select/vselect/select_cc. The two operand pairs for the select (N2/N3) may
  // be truncated versions of the setcc operands (N0/N1).
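  // For example (an illustrative sketch, subject to shouldConvertFpToSat):
  // umin(fptoui(x) to i32, 255) becomes (zext (fp_to_uint_sat x to i8) to i32).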
5522 if ((N0 != N2 &&
5523 (N2.getOpcode() != ISD::TRUNCATE || N0 != N2.getOperand(i: 0))) ||
5524 N0.getOpcode() != ISD::FP_TO_UINT || CC != ISD::SETULT)
5525 return SDValue();
5526 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
5527 ConstantSDNode *N3C = isConstOrConstSplat(N: N3);
5528 if (!N1C || !N3C)
5529 return SDValue();
5530 const APInt &C1 = N1C->getAPIntValue();
5531 const APInt &C3 = N3C->getAPIntValue();
5532 if (!(C1 + 1).isPowerOf2() || C1.getBitWidth() < C3.getBitWidth() ||
5533 C1 != C3.zext(width: C1.getBitWidth()))
5534 return SDValue();
5535
5536 unsigned BW = (C1 + 1).exactLogBase2();
5537 EVT FPVT = N0.getOperand(i: 0).getValueType();
5538 EVT NewVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW);
5539 if (FPVT.isVector())
5540 NewVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewVT,
5541 EC: FPVT.getVectorElementCount());
5542 if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(Op: ISD::FP_TO_UINT_SAT,
5543 FPVT, VT: NewVT))
5544 return SDValue();
5545
5546 SDValue Sat =
5547 DAG.getNode(Opcode: ISD::FP_TO_UINT_SAT, DL: SDLoc(N0), VT: NewVT, N1: N0.getOperand(i: 0),
5548 N2: DAG.getValueType(NewVT.getScalarType()));
5549 return DAG.getZExtOrTrunc(Op: Sat, DL: SDLoc(N0), VT: N3.getValueType());
5550}
5551
5552SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
5553 SDValue N0 = N->getOperand(Num: 0);
5554 SDValue N1 = N->getOperand(Num: 1);
5555 EVT VT = N0.getValueType();
5556 unsigned Opcode = N->getOpcode();
5557 SDLoc DL(N);
5558
5559 // fold operation with constant operands.
5560 if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, Ops: {N0, N1}))
5561 return C;
5562
5563 // If the operands are the same, this is a no-op.
5564 if (N0 == N1)
5565 return N0;
5566
5567 // canonicalize constant to RHS
5568 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
5569 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
5570 return DAG.getNode(Opcode, DL, VT, N1, N2: N0);
5571
5572 // fold vector ops
5573 if (VT.isVector())
5574 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
5575 return FoldedVOp;
5576
5577 // reassociate minmax
5578 if (SDValue RMINMAX = reassociateOps(Opc: Opcode, DL, N0, N1, Flags: N->getFlags()))
5579 return RMINMAX;
5580
  // If the sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
  // Only do this if:
  // 1. The current op isn't legal and the flipped one is.
  // 2. The saturation pattern is broken by canonicalization in InstCombine.
5585 bool IsOpIllegal = !TLI.isOperationLegal(Op: Opcode, VT);
5586 bool IsSatBroken = Opcode == ISD::UMIN && N0.getOpcode() == ISD::SMAX;
5587 if ((IsSatBroken || IsOpIllegal) && (N0.isUndef() || DAG.SignBitIsZero(Op: N0)) &&
5588 (N1.isUndef() || DAG.SignBitIsZero(Op: N1))) {
5589 unsigned AltOpcode;
5590 switch (Opcode) {
5591 case ISD::SMIN: AltOpcode = ISD::UMIN; break;
5592 case ISD::SMAX: AltOpcode = ISD::UMAX; break;
5593 case ISD::UMIN: AltOpcode = ISD::SMIN; break;
5594 case ISD::UMAX: AltOpcode = ISD::SMAX; break;
5595 default: llvm_unreachable("Unknown MINMAX opcode");
5596 }
5597 if ((IsSatBroken && IsOpIllegal) || TLI.isOperationLegal(Op: AltOpcode, VT))
5598 return DAG.getNode(Opcode: AltOpcode, DL, VT, N1: N0, N2: N1);
5599 }
5600
5601 if (Opcode == ISD::SMIN || Opcode == ISD::SMAX)
5602 if (SDValue S = PerformMinMaxFpToSatCombine(
5603 N0, N1, N2: N0, N3: N1, CC: Opcode == ISD::SMIN ? ISD::SETLT : ISD::SETGT, DAG))
5604 return S;
5605 if (Opcode == ISD::UMIN)
5606 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2: N0, N3: N1, CC: ISD::SETULT, DAG))
5607 return S;
5608
5609 // Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y))
5610 auto ReductionOpcode = [](unsigned Opcode) {
5611 switch (Opcode) {
5612 case ISD::SMIN:
5613 return ISD::VECREDUCE_SMIN;
5614 case ISD::SMAX:
5615 return ISD::VECREDUCE_SMAX;
5616 case ISD::UMIN:
5617 return ISD::VECREDUCE_UMIN;
5618 case ISD::UMAX:
5619 return ISD::VECREDUCE_UMAX;
5620 default:
5621 llvm_unreachable("Unexpected opcode");
5622 }
5623 };
5624 if (SDValue SD = reassociateReduction(RedOpc: ReductionOpcode(Opcode), Opc: Opcode,
5625 DL: SDLoc(N), VT, N0, N1))
5626 return SD;
5627
5628 // Simplify the operands using demanded-bits information.
5629 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
5630 return SDValue(N, 0);
5631
5632 return SDValue();
5633}
5634
5635/// If this is a bitwise logic instruction and both operands have the same
5636/// opcode, try to sink the other opcode after the logic instruction.
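/// For example (an illustrative sketch):
///   (and (zext i8 X to i32), (zext i8 Y to i32))
///     --> (zext (and X, Y) to i32)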
5637SDValue DAGCombiner::hoistLogicOpWithSameOpcodeHands(SDNode *N) {
5638 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
5639 EVT VT = N0.getValueType();
5640 unsigned LogicOpcode = N->getOpcode();
5641 unsigned HandOpcode = N0.getOpcode();
5642 assert(ISD::isBitwiseLogicOp(LogicOpcode) && "Expected logic opcode");
5643 assert(HandOpcode == N1.getOpcode() && "Bad input!");
5644
5645 // Bail early if none of these transforms apply.
5646 if (N0.getNumOperands() == 0)
5647 return SDValue();
5648
5649 // FIXME: We should check number of uses of the operands to not increase
5650 // the instruction count for all transforms.
5651
5652 // Handle size-changing casts (or sign_extend_inreg).
5653 SDValue X = N0.getOperand(i: 0);
5654 SDValue Y = N1.getOperand(i: 0);
5655 EVT XVT = X.getValueType();
5656 SDLoc DL(N);
5657 if (ISD::isExtOpcode(Opcode: HandOpcode) || ISD::isExtVecInRegOpcode(Opcode: HandOpcode) ||
5658 (HandOpcode == ISD::SIGN_EXTEND_INREG &&
5659 N0.getOperand(i: 1) == N1.getOperand(i: 1))) {
5660 // If both operands have other uses, this transform would create extra
5661 // instructions without eliminating anything.
5662 if (!N0.hasOneUse() && !N1.hasOneUse())
5663 return SDValue();
5664 // We need matching integer source types.
5665 if (XVT != Y.getValueType())
5666 return SDValue();
5667 // Don't create an illegal op during or after legalization. Don't ever
5668 // create an unsupported vector op.
5669 if ((VT.isVector() || LegalOperations) &&
5670 !TLI.isOperationLegalOrCustom(Op: LogicOpcode, VT: XVT))
5671 return SDValue();
5672 // Avoid infinite looping with PromoteIntBinOp.
5673 // TODO: Should we apply desirable/legal constraints to all opcodes?
5674 if ((HandOpcode == ISD::ANY_EXTEND ||
5675 HandOpcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
5676 LegalTypes && !TLI.isTypeDesirableForOp(LogicOpcode, VT: XVT))
5677 return SDValue();
5678 // logic_op (hand_op X), (hand_op Y) --> hand_op (logic_op X, Y)
5679 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5680 if (HandOpcode == ISD::SIGN_EXTEND_INREG)
5681 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic, N2: N0.getOperand(i: 1));
5682 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
5683 }
5684
5685 // logic_op (truncate x), (truncate y) --> truncate (logic_op x, y)
5686 if (HandOpcode == ISD::TRUNCATE) {
5687 // If both operands have other uses, this transform would create extra
5688 // instructions without eliminating anything.
5689 if (!N0.hasOneUse() && !N1.hasOneUse())
5690 return SDValue();
5691 // We need matching source types.
5692 if (XVT != Y.getValueType())
5693 return SDValue();
5694 // Don't create an illegal op during or after legalization.
5695 if (LegalOperations && !TLI.isOperationLegal(Op: LogicOpcode, VT: XVT))
5696 return SDValue();
5697 // Be extra careful sinking truncate. If it's free, there's no benefit in
5698 // widening a binop. Also, don't create a logic op on an illegal type.
5699 if (TLI.isZExtFree(FromTy: VT, ToTy: XVT) && TLI.isTruncateFree(FromVT: XVT, ToVT: VT))
5700 return SDValue();
5701 if (!TLI.isTypeLegal(VT: XVT))
5702 return SDValue();
5703 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5704 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
5705 }
5706
5707 // For binops SHL/SRL/SRA/AND:
5708 // logic_op (OP x, z), (OP y, z) --> OP (logic_op x, y), z
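  // For example: (or (shl x, z), (shl y, z)) --> shl (or x, y), z.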
5709 if ((HandOpcode == ISD::SHL || HandOpcode == ISD::SRL ||
5710 HandOpcode == ISD::SRA || HandOpcode == ISD::AND) &&
5711 N0.getOperand(i: 1) == N1.getOperand(i: 1)) {
5712 // If either operand has other uses, this transform is not an improvement.
5713 if (!N0.hasOneUse() || !N1.hasOneUse())
5714 return SDValue();
5715 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5716 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic, N2: N0.getOperand(i: 1));
5717 }
5718
5719 // Unary ops: logic_op (bswap x), (bswap y) --> bswap (logic_op x, y)
5720 if (HandOpcode == ISD::BSWAP) {
5721 // If either operand has other uses, this transform is not an improvement.
5722 if (!N0.hasOneUse() || !N1.hasOneUse())
5723 return SDValue();
5724 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5725 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
5726 }
5727
5728 // For funnel shifts FSHL/FSHR:
5729 // logic_op (OP x, x1, s), (OP y, y1, s) -->
5730   //      --> OP (logic_op x, y), (logic_op x1, y1), s
5731 if ((HandOpcode == ISD::FSHL || HandOpcode == ISD::FSHR) &&
5732 N0.getOperand(i: 2) == N1.getOperand(i: 2)) {
5733 if (!N0.hasOneUse() || !N1.hasOneUse())
5734 return SDValue();
5735 SDValue X1 = N0.getOperand(i: 1);
5736 SDValue Y1 = N1.getOperand(i: 1);
5737 SDValue S = N0.getOperand(i: 2);
5738 SDValue Logic0 = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X, N2: Y);
5739 SDValue Logic1 = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X1, N2: Y1);
5740 return DAG.getNode(Opcode: HandOpcode, DL, VT, N1: Logic0, N2: Logic1, N3: S);
5741 }
5742
5743 // Simplify xor/and/or (bitcast(A), bitcast(B)) -> bitcast(op (A,B))
5744 // Only perform this optimization up until type legalization, before
5745   // LegalizeVectorOps. LegalizeVectorOps promotes vector operations by
5746 // adding bitcasts. For example (xor v4i32) is promoted to (v2i64), and
5747 // we don't want to undo this promotion.
5748 // We also handle SCALAR_TO_VECTOR because xor/or/and operations are cheaper
5749 // on scalars.
5750 if ((HandOpcode == ISD::BITCAST || HandOpcode == ISD::SCALAR_TO_VECTOR) &&
5751 Level <= AfterLegalizeTypes) {
5752 // Input types must be integer and the same.
5753 if (XVT.isInteger() && XVT == Y.getValueType() &&
5754 !(VT.isVector() && TLI.isTypeLegal(VT) &&
5755 !XVT.isVector() && !TLI.isTypeLegal(VT: XVT))) {
5756 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT: XVT, N1: X, N2: Y);
5757 return DAG.getNode(Opcode: HandOpcode, DL, VT, Operand: Logic);
5758 }
5759 }
5760
5761 // Xor/and/or are indifferent to the swizzle operation (shuffle of one value).
5762 // Simplify xor/and/or (shuff(A), shuff(B)) -> shuff(op (A,B))
5763 // If both shuffles use the same mask, and both shuffle within a single
5764 // vector, then it is worthwhile to move the swizzle after the operation.
5765 // The type-legalizer generates this pattern when loading illegal
5766 // vector types from memory. In many cases this allows additional shuffle
5767 // optimizations.
5768 // There are other cases where moving the shuffle after the xor/and/or
5769 // is profitable even if shuffles don't perform a swizzle.
5770 // If both shuffles use the same mask, and both shuffles have the same first
5771 // or second operand, then it might still be profitable to move the shuffle
5772 // after the xor/and/or operation.
5773 if (HandOpcode == ISD::VECTOR_SHUFFLE && Level < AfterLegalizeDAG) {
5774 auto *SVN0 = cast<ShuffleVectorSDNode>(Val&: N0);
5775 auto *SVN1 = cast<ShuffleVectorSDNode>(Val&: N1);
5776 assert(X.getValueType() == Y.getValueType() &&
5777 "Inputs to shuffles are not the same type");
5778
5779 // Check that both shuffles use the same mask. The masks are known to be of
5780 // the same length because the result vector type is the same.
5781 // Check also that shuffles have only one use to avoid introducing extra
5782 // instructions.
5783 if (!SVN0->hasOneUse() || !SVN1->hasOneUse() ||
5784 !SVN0->getMask().equals(RHS: SVN1->getMask()))
5785 return SDValue();
5786
5787 // Don't try to fold this node if it requires introducing a
5788 // build vector of all zeros that might be illegal at this stage.
5789 SDValue ShOp = N0.getOperand(i: 1);
5790 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
5791 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
5792
5793 // (logic_op (shuf (A, C), shuf (B, C))) --> shuf (logic_op (A, B), C)
5794 if (N0.getOperand(i: 1) == N1.getOperand(i: 1) && ShOp.getNode()) {
5795 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT,
5796 N1: N0.getOperand(i: 0), N2: N1.getOperand(i: 0));
5797 return DAG.getVectorShuffle(VT, dl: DL, N1: Logic, N2: ShOp, Mask: SVN0->getMask());
5798 }
5799
5800 // Don't try to fold this node if it requires introducing a
5801 // build vector of all zeros that might be illegal at this stage.
5802 ShOp = N0.getOperand(i: 0);
5803 if (LogicOpcode == ISD::XOR && !ShOp.isUndef())
5804 ShOp = tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
5805
5806 // (logic_op (shuf (C, A), shuf (C, B))) --> shuf (C, logic_op (A, B))
5807 if (N0.getOperand(i: 0) == N1.getOperand(i: 0) && ShOp.getNode()) {
5808 SDValue Logic = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: N0.getOperand(i: 1),
5809 N2: N1.getOperand(i: 1));
5810 return DAG.getVectorShuffle(VT, dl: DL, N1: ShOp, N2: Logic, Mask: SVN0->getMask());
5811 }
5812 }
5813
5814 return SDValue();
5815}
5816
5817/// Try to make (and/or setcc (LL, LR), setcc (RL, RR)) more efficient.
5818SDValue DAGCombiner::foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
5819 const SDLoc &DL) {
5820 SDValue LL, LR, RL, RR, N0CC, N1CC;
5821 if (!isSetCCEquivalent(N: N0, LHS&: LL, RHS&: LR, CC&: N0CC) ||
5822 !isSetCCEquivalent(N: N1, LHS&: RL, RHS&: RR, CC&: N1CC))
5823 return SDValue();
5824
5825 assert(N0.getValueType() == N1.getValueType() &&
5826 "Unexpected operand types for bitwise logic op");
5827 assert(LL.getValueType() == LR.getValueType() &&
5828 RL.getValueType() == RR.getValueType() &&
5829 "Unexpected operand types for setcc");
5830
5831 // If we're here post-legalization or the logic op type is not i1, the logic
5832 // op type must match a setcc result type. Also, all folds require new
5833 // operations on the left and right operands, so those types must match.
5834 EVT VT = N0.getValueType();
5835 EVT OpVT = LL.getValueType();
5836 if (LegalOperations || VT.getScalarType() != MVT::i1)
5837 if (VT != getSetCCResultType(VT: OpVT))
5838 return SDValue();
5839 if (OpVT != RL.getValueType())
5840 return SDValue();
5841
5842 ISD::CondCode CC0 = cast<CondCodeSDNode>(Val&: N0CC)->get();
5843 ISD::CondCode CC1 = cast<CondCodeSDNode>(Val&: N1CC)->get();
5844 bool IsInteger = OpVT.isInteger();
5845 if (LR == RR && CC0 == CC1 && IsInteger) {
5846 bool IsZero = isNullOrNullSplat(V: LR);
5847 bool IsNeg1 = isAllOnesOrAllOnesSplat(V: LR);
5848
5849 // All bits clear?
5850 bool AndEqZero = IsAnd && CC1 == ISD::SETEQ && IsZero;
5851 // All sign bits clear?
5852 bool AndGtNeg1 = IsAnd && CC1 == ISD::SETGT && IsNeg1;
5853 // Any bits set?
5854 bool OrNeZero = !IsAnd && CC1 == ISD::SETNE && IsZero;
5855 // Any sign bits set?
5856 bool OrLtZero = !IsAnd && CC1 == ISD::SETLT && IsZero;
5857
5858 // (and (seteq X, 0), (seteq Y, 0)) --> (seteq (or X, Y), 0)
5859 // (and (setgt X, -1), (setgt Y, -1)) --> (setgt (or X, Y), -1)
5860 // (or (setne X, 0), (setne Y, 0)) --> (setne (or X, Y), 0)
5861 // (or (setlt X, 0), (setlt Y, 0)) --> (setlt (or X, Y), 0)
5862 if (AndEqZero || AndGtNeg1 || OrNeZero || OrLtZero) {
5863 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: RL);
5864 AddToWorklist(N: Or.getNode());
5865 return DAG.getSetCC(DL, VT, LHS: Or, RHS: LR, Cond: CC1);
5866 }
5867
5868 // All bits set?
5869 bool AndEqNeg1 = IsAnd && CC1 == ISD::SETEQ && IsNeg1;
5870 // All sign bits set?
5871 bool AndLtZero = IsAnd && CC1 == ISD::SETLT && IsZero;
5872 // Any bits clear?
5873 bool OrNeNeg1 = !IsAnd && CC1 == ISD::SETNE && IsNeg1;
5874 // Any sign bits clear?
5875 bool OrGtNeg1 = !IsAnd && CC1 == ISD::SETGT && IsNeg1;
5876
5877 // (and (seteq X, -1), (seteq Y, -1)) --> (seteq (and X, Y), -1)
5878 // (and (setlt X, 0), (setlt Y, 0)) --> (setlt (and X, Y), 0)
5879 // (or (setne X, -1), (setne Y, -1)) --> (setne (and X, Y), -1)
5880 // (or (setgt X, -1), (setgt Y -1)) --> (setgt (and X, Y), -1)
5881 if (AndEqNeg1 || AndLtZero || OrNeNeg1 || OrGtNeg1) {
5882 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: RL);
5883 AddToWorklist(N: And.getNode());
5884 return DAG.getSetCC(DL, VT, LHS: And, RHS: LR, Cond: CC1);
5885 }
5886 }
5887
5888 // TODO: What is the 'or' equivalent of this fold?
5889 // (and (setne X, 0), (setne X, -1)) --> (setuge (add X, 1), 2)
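  // Adding 1 maps -1 to 0 and 0 to 1, so X is neither 0 nor -1 exactly when
  // (X + 1) is unsigned >= 2 (hence the requirement of more than 1 bit).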
5890 if (IsAnd && LL == RL && CC0 == CC1 && OpVT.getScalarSizeInBits() > 1 &&
5891 IsInteger && CC0 == ISD::SETNE &&
5892 ((isNullConstant(V: LR) && isAllOnesConstant(V: RR)) ||
5893 (isAllOnesConstant(V: LR) && isNullConstant(V: RR)))) {
5894 SDValue One = DAG.getConstant(Val: 1, DL, VT: OpVT);
5895 SDValue Two = DAG.getConstant(Val: 2, DL, VT: OpVT);
5896 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: One);
5897 AddToWorklist(N: Add.getNode());
5898 return DAG.getSetCC(DL, VT, LHS: Add, RHS: Two, Cond: ISD::SETUGE);
5899 }
5900
5901 // Try more general transforms if the predicates match and the only user of
5902 // the compares is the 'and' or 'or'.
5903 if (IsInteger && TLI.convertSetCCLogicToBitwiseLogic(VT: OpVT) && CC0 == CC1 &&
5904 N0.hasOneUse() && N1.hasOneUse()) {
5905 // and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
5906 // or (setne A, B), (setne C, D) --> setne (or (xor A, B), (xor C, D)), 0
5907 if ((IsAnd && CC1 == ISD::SETEQ) || (!IsAnd && CC1 == ISD::SETNE)) {
5908 SDValue XorL = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N0), VT: OpVT, N1: LL, N2: LR);
5909 SDValue XorR = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N1), VT: OpVT, N1: RL, N2: RR);
5910 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: OpVT, N1: XorL, N2: XorR);
5911 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: OpVT);
5912 return DAG.getSetCC(DL, VT, LHS: Or, RHS: Zero, Cond: CC1);
5913 }
5914
5915 // Turn compare of constants whose difference is 1 bit into add+and+setcc.
5916 if ((IsAnd && CC1 == ISD::SETNE) || (!IsAnd && CC1 == ISD::SETEQ)) {
5917 // Match a shared variable operand and 2 non-opaque constant operands.
5918 auto MatchDiffPow2 = [&](ConstantSDNode *C0, ConstantSDNode *C1) {
5919 // The difference of the constants must be a single bit.
5920 const APInt &CMax =
5921 APIntOps::umax(A: C0->getAPIntValue(), B: C1->getAPIntValue());
5922 const APInt &CMin =
5923 APIntOps::umin(A: C0->getAPIntValue(), B: C1->getAPIntValue());
5924 return !C0->isOpaque() && !C1->isOpaque() && (CMax - CMin).isPowerOf2();
5925 };
5926 if (LL == RL && ISD::matchBinaryPredicate(LHS: LR, RHS: RR, Match: MatchDiffPow2)) {
5927 // and/or (setcc X, CMax, ne), (setcc X, CMin, ne/eq) -->
5928         // setcc ((sub X, CMin) & ~(CMax - CMin)), 0, ne/eq
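        // For example, with CMin = 8 and CMax = 24 (difference 16, a power of 2):
        //   X == 8 | X == 24 --> ((X - 8) & ~16) == 0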
5929 SDValue Max = DAG.getNode(Opcode: ISD::UMAX, DL, VT: OpVT, N1: LR, N2: RR);
5930 SDValue Min = DAG.getNode(Opcode: ISD::UMIN, DL, VT: OpVT, N1: LR, N2: RR);
5931 SDValue Offset = DAG.getNode(Opcode: ISD::SUB, DL, VT: OpVT, N1: LL, N2: Min);
5932 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: OpVT, N1: Max, N2: Min);
5933 SDValue Mask = DAG.getNOT(DL, Val: Diff, VT: OpVT);
5934 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: Offset, N2: Mask);
5935 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: OpVT);
5936 return DAG.getSetCC(DL, VT, LHS: And, RHS: Zero, Cond: CC0);
5937 }
5938 }
5939 }
5940
5941 // Canonicalize equivalent operands to LL == RL.
5942 if (LL == RR && LR == RL) {
5943 CC1 = ISD::getSetCCSwappedOperands(Operation: CC1);
5944 std::swap(a&: RL, b&: RR);
5945 }
5946
5947 // (and (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
5948 // (or (setcc X, Y, CC0), (setcc X, Y, CC1)) --> (setcc X, Y, NewCC)
5949 if (LL == RL && LR == RR) {
5950 ISD::CondCode NewCC = IsAnd ? ISD::getSetCCAndOperation(Op1: CC0, Op2: CC1, Type: OpVT)
5951 : ISD::getSetCCOrOperation(Op1: CC0, Op2: CC1, Type: OpVT);
5952 if (NewCC != ISD::SETCC_INVALID &&
5953 (!LegalOperations ||
5954 (TLI.isCondCodeLegal(CC: NewCC, VT: LL.getSimpleValueType()) &&
5955 TLI.isOperationLegal(Op: ISD::SETCC, VT: OpVT))))
5956 return DAG.getSetCC(DL, VT, LHS: LL, RHS: LR, Cond: NewCC);
5957 }
5958
5959 return SDValue();
5960}
5961
5962static bool arebothOperandsNotSNan(SDValue Operand1, SDValue Operand2,
5963 SelectionDAG &DAG) {
5964 return DAG.isKnownNeverSNaN(Op: Operand2) && DAG.isKnownNeverSNaN(Op: Operand1);
5965}
5966
5967static bool arebothOperandsNotNan(SDValue Operand1, SDValue Operand2,
5968 SelectionDAG &DAG) {
5969 return DAG.isKnownNeverNaN(Op: Operand2) && DAG.isKnownNeverNaN(Op: Operand1);
5970}
5971
5972static unsigned getMinMaxOpcodeForFP(SDValue Operand1, SDValue Operand2,
5973 ISD::CondCode CC, unsigned OrAndOpcode,
5974 SelectionDAG &DAG,
5975 bool isFMAXNUMFMINNUM_IEEE,
5976 bool isFMAXNUMFMINNUM) {
5977 // The optimization cannot be applied for all the predicates because
5978 // of the way FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle
5979 // NaNs. For FMINNUM_IEEE/FMAXNUM_IEEE, the optimization cannot be
5980 // applied at all if one of the operands is a signaling NaN.
5981
5982 // It is safe to use FMINNUM_IEEE/FMAXNUM_IEEE if all the operands
5983   // are non-NaN values.
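  // For example, when neither operand can be NaN:
  //   (setlt a, x) | (setlt b, x) --> setlt (fminnum_ieee a, b), x
  // because min(a, b) < x iff a < x or b < x.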
5984 if (((CC == ISD::SETLT || CC == ISD::SETLE) && (OrAndOpcode == ISD::OR)) ||
5985 ((CC == ISD::SETGT || CC == ISD::SETGE) && (OrAndOpcode == ISD::AND)))
5986 return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
5987 isFMAXNUMFMINNUM_IEEE
5988 ? ISD::FMINNUM_IEEE
5989 : ISD::DELETED_NODE;
5990 else if (((CC == ISD::SETGT || CC == ISD::SETGE) &&
5991 (OrAndOpcode == ISD::OR)) ||
5992 ((CC == ISD::SETLT || CC == ISD::SETLE) &&
5993 (OrAndOpcode == ISD::AND)))
5994 return arebothOperandsNotNan(Operand1, Operand2, DAG) &&
5995 isFMAXNUMFMINNUM_IEEE
5996 ? ISD::FMAXNUM_IEEE
5997 : ISD::DELETED_NODE;
5998 // Both FMINNUM/FMAXNUM and FMINNUM_IEEE/FMAXNUM_IEEE handle quiet
5999 // NaNs in the same way. But, FMINNUM/FMAXNUM and FMINNUM_IEEE/
6000 // FMAXNUM_IEEE handle signaling NaNs differently. If we cannot prove
6001 // that there are not any sNaNs, then the optimization is not valid
6002 // for FMINNUM_IEEE/FMAXNUM_IEEE. In the presence of sNaNs, we apply
6003 // the optimization using FMINNUM/FMAXNUM for the following cases. If
6004 // we can prove that we do not have any sNaNs, then we can do the
6005 // optimization using FMINNUM_IEEE/FMAXNUM_IEEE for the following
6006 // cases.
6007 else if (((CC == ISD::SETOLT || CC == ISD::SETOLE) &&
6008 (OrAndOpcode == ISD::OR)) ||
6009 ((CC == ISD::SETUGT || CC == ISD::SETUGE) &&
6010 (OrAndOpcode == ISD::AND)))
6011 return isFMAXNUMFMINNUM ? ISD::FMINNUM
6012 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6013 isFMAXNUMFMINNUM_IEEE
6014 ? ISD::FMINNUM_IEEE
6015 : ISD::DELETED_NODE;
6016 else if (((CC == ISD::SETOGT || CC == ISD::SETOGE) &&
6017 (OrAndOpcode == ISD::OR)) ||
6018 ((CC == ISD::SETULT || CC == ISD::SETULE) &&
6019 (OrAndOpcode == ISD::AND)))
6020 return isFMAXNUMFMINNUM ? ISD::FMAXNUM
6021 : arebothOperandsNotSNan(Operand1, Operand2, DAG) &&
6022 isFMAXNUMFMINNUM_IEEE
6023 ? ISD::FMAXNUM_IEEE
6024 : ISD::DELETED_NODE;
6025 return ISD::DELETED_NODE;
6026}
6027
6028static SDValue foldAndOrOfSETCC(SDNode *LogicOp, SelectionDAG &DAG) {
6029 using AndOrSETCCFoldKind = TargetLowering::AndOrSETCCFoldKind;
6030 assert(
6031 (LogicOp->getOpcode() == ISD::AND || LogicOp->getOpcode() == ISD::OR) &&
6032 "Invalid Op to combine SETCC with");
6033
6034 // TODO: Search past casts/truncates.
6035 SDValue LHS = LogicOp->getOperand(Num: 0);
6036 SDValue RHS = LogicOp->getOperand(Num: 1);
6037 if (LHS->getOpcode() != ISD::SETCC || RHS->getOpcode() != ISD::SETCC ||
6038 !LHS->hasOneUse() || !RHS->hasOneUse())
6039 return SDValue();
6040
6041 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6042 AndOrSETCCFoldKind TargetPreference = TLI.isDesirableToCombineLogicOpOfSETCC(
6043 LogicOp, SETCC0: LHS.getNode(), SETCC1: RHS.getNode());
6044
6045 SDValue LHS0 = LHS->getOperand(Num: 0);
6046 SDValue RHS0 = RHS->getOperand(Num: 0);
6047 SDValue LHS1 = LHS->getOperand(Num: 1);
6048 SDValue RHS1 = RHS->getOperand(Num: 1);
6049 // TODO: We don't actually need a splat here, for vectors we just need the
6050 // invariants to hold for each element.
6051 auto *LHS1C = isConstOrConstSplat(N: LHS1);
6052 auto *RHS1C = isConstOrConstSplat(N: RHS1);
6053 ISD::CondCode CCL = cast<CondCodeSDNode>(Val: LHS.getOperand(i: 2))->get();
6054 ISD::CondCode CCR = cast<CondCodeSDNode>(Val: RHS.getOperand(i: 2))->get();
6055 EVT VT = LogicOp->getValueType(ResNo: 0);
6056 EVT OpVT = LHS0.getValueType();
6057 SDLoc DL(LogicOp);
6058
6059 // Check if the operands of an and/or operation are comparisons and if they
6060   // compare against the same value. If so, replace the and/or-cmp-cmp sequence
6061   // with a min/max + cmp sequence. If LHS1 is equal to RHS1, the or-cmp-cmp
6062   // sequence is replaced with a min + cmp sequence:
6063 // (LHS0 < LHS1) | (RHS0 < RHS1) -> min(LHS0, RHS0) < LHS1
6064   // and the and-cmp-cmp sequence is replaced with a max + cmp sequence:
6065 // (LHS0 < LHS1) & (RHS0 < RHS1) -> max(LHS0, RHS0) < LHS1
6066   // The optimization does not work for `==` or `!=`.
6067 // The two comparisons should have either the same predicate or the
6068 // predicate of one of the comparisons is the opposite of the other one.
6069 bool isFMAXNUMFMINNUM_IEEE = TLI.isOperationLegal(Op: ISD::FMAXNUM_IEEE, VT: OpVT) &&
6070 TLI.isOperationLegal(Op: ISD::FMINNUM_IEEE, VT: OpVT);
6071 bool isFMAXNUMFMINNUM = TLI.isOperationLegalOrCustom(Op: ISD::FMAXNUM, VT: OpVT) &&
6072 TLI.isOperationLegalOrCustom(Op: ISD::FMINNUM, VT: OpVT);
6073 if (((OpVT.isInteger() && TLI.isOperationLegal(Op: ISD::UMAX, VT: OpVT) &&
6074 TLI.isOperationLegal(Op: ISD::SMAX, VT: OpVT) &&
6075 TLI.isOperationLegal(Op: ISD::UMIN, VT: OpVT) &&
6076 TLI.isOperationLegal(Op: ISD::SMIN, VT: OpVT)) ||
6077 (OpVT.isFloatingPoint() &&
6078 (isFMAXNUMFMINNUM_IEEE || isFMAXNUMFMINNUM))) &&
6079 !ISD::isIntEqualitySetCC(Code: CCL) && !ISD::isFPEqualitySetCC(Code: CCL) &&
6080 CCL != ISD::SETFALSE && CCL != ISD::SETO && CCL != ISD::SETUO &&
6081 CCL != ISD::SETTRUE &&
6082 (CCL == CCR || CCL == ISD::getSetCCSwappedOperands(Operation: CCR))) {
6083
6084 SDValue CommonValue, Operand1, Operand2;
6085 ISD::CondCode CC = ISD::SETCC_INVALID;
6086 if (CCL == CCR) {
6087 if (LHS0 == RHS0) {
6088 CommonValue = LHS0;
6089 Operand1 = LHS1;
6090 Operand2 = RHS1;
6091 CC = ISD::getSetCCSwappedOperands(Operation: CCL);
6092 } else if (LHS1 == RHS1) {
6093 CommonValue = LHS1;
6094 Operand1 = LHS0;
6095 Operand2 = RHS0;
6096 CC = CCL;
6097 }
6098 } else {
6099 assert(CCL == ISD::getSetCCSwappedOperands(CCR) && "Unexpected CC");
6100 if (LHS0 == RHS1) {
6101 CommonValue = LHS0;
6102 Operand1 = LHS1;
6103 Operand2 = RHS0;
6104 CC = CCR;
6105 } else if (RHS0 == LHS1) {
6106 CommonValue = LHS1;
6107 Operand1 = LHS0;
6108 Operand2 = RHS1;
6109 CC = CCL;
6110 }
6111 }
6112
6113 // Don't do this transform for sign bit tests. Let foldLogicOfSetCCs
6114 // handle it using OR/AND.
6115 if (CC == ISD::SETLT && isNullOrNullSplat(V: CommonValue))
6116 CC = ISD::SETCC_INVALID;
6117 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: CommonValue))
6118 CC = ISD::SETCC_INVALID;
6119
6120 if (CC != ISD::SETCC_INVALID) {
6121 unsigned NewOpcode = ISD::DELETED_NODE;
6122 bool IsSigned = isSignedIntSetCC(Code: CC);
6123 if (OpVT.isInteger()) {
6124 bool IsLess = (CC == ISD::SETLE || CC == ISD::SETULE ||
6125 CC == ISD::SETLT || CC == ISD::SETULT);
6126 bool IsOr = (LogicOp->getOpcode() == ISD::OR);
6127 if (IsLess == IsOr)
6128 NewOpcode = IsSigned ? ISD::SMIN : ISD::UMIN;
6129 else
6130 NewOpcode = IsSigned ? ISD::SMAX : ISD::UMAX;
6131 } else if (OpVT.isFloatingPoint())
6132 NewOpcode =
6133 getMinMaxOpcodeForFP(Operand1, Operand2, CC, OrAndOpcode: LogicOp->getOpcode(),
6134 DAG, isFMAXNUMFMINNUM_IEEE, isFMAXNUMFMINNUM);
6135
6136 if (NewOpcode != ISD::DELETED_NODE) {
6137 SDValue MinMaxValue =
6138 DAG.getNode(Opcode: NewOpcode, DL, VT: OpVT, N1: Operand1, N2: Operand2);
6139 return DAG.getSetCC(DL, VT, LHS: MinMaxValue, RHS: CommonValue, Cond: CC);
6140 }
6141 }
6142 }
6143
6144 if (TargetPreference == AndOrSETCCFoldKind::None)
6145 return SDValue();
6146
6147 if (CCL == CCR &&
6148 CCL == (LogicOp->getOpcode() == ISD::AND ? ISD::SETNE : ISD::SETEQ) &&
6149 LHS0 == RHS0 && LHS1C && RHS1C && OpVT.isInteger()) {
6150 const APInt &APLhs = LHS1C->getAPIntValue();
6151 const APInt &APRhs = RHS1C->getAPIntValue();
6152
6153     // Fold if the target prefers ISD::ABS, or if an ISD::ABS of this operand
6154     // already exists (in which case the fold only adds a compare).
6155 if (APLhs == (-APRhs) &&
6156 ((TargetPreference & AndOrSETCCFoldKind::ABS) ||
6157 DAG.doesNodeExist(Opcode: ISD::ABS, VTList: DAG.getVTList(VT: OpVT), Ops: {LHS0}))) {
6158 const APInt &C = APLhs.isNegative() ? APRhs : APLhs;
6159 // (icmp eq A, C) | (icmp eq A, -C)
6160 // -> (icmp eq Abs(A), C)
6161 // (icmp ne A, C) & (icmp ne A, -C)
6162 // -> (icmp ne Abs(A), C)
6163 SDValue AbsOp = DAG.getNode(Opcode: ISD::ABS, DL, VT: OpVT, Operand: LHS0);
6164 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AbsOp,
6165 N2: DAG.getConstant(Val: C, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6166 } else if (TargetPreference &
6167 (AndOrSETCCFoldKind::AddAnd | AndOrSETCCFoldKind::NotAnd)) {
6168
6169 // AndOrSETCCFoldKind::AddAnd:
6170 // A == C0 | A == C1
6171 // IF IsPow2(smax(C0, C1)-smin(C0, C1))
6172 // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) == 0
6173 // A != C0 & A != C1
6174 // IF IsPow2(smax(C0, C1)-smin(C0, C1))
6175 // -> ((A - smin(C0, C1)) & ~(smax(C0, C1)-smin(C0, C1))) != 0
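      // For example, with C0 = 4 and C1 = 6 (difference 2, a power of 2):
      //   A == 4 | A == 6 --> ((A - 4) & ~2) == 0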
6176
6177 // AndOrSETCCFoldKind::NotAnd:
6178 // A == C0 | A == C1
6179 // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6180 // -> ~A & smin(C0, C1) == 0
6181 // A != C0 & A != C1
6182 // IF smax(C0, C1) == -1 AND IsPow2(smax(C0, C1) - smin(C0, C1))
6183 // -> ~A & smin(C0, C1) != 0
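      // For example, with C0 = -1 and C1 = -5 (smax = -1, difference 4):
      //   A == -1 | A == -5 --> (~A & -5) == 0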
6184
6185 const APInt &MaxC = APIntOps::smax(A: APRhs, B: APLhs);
6186 const APInt &MinC = APIntOps::smin(A: APRhs, B: APLhs);
6187 APInt Dif = MaxC - MinC;
6188 if (!Dif.isZero() && Dif.isPowerOf2()) {
6189 if (MaxC.isAllOnes() &&
6190 (TargetPreference & AndOrSETCCFoldKind::NotAnd)) {
6191 SDValue NotOp = DAG.getNOT(DL, Val: LHS0, VT: OpVT);
6192 SDValue AndOp = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: NotOp,
6193 N2: DAG.getConstant(Val: MinC, DL, VT: OpVT));
6194 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AndOp,
6195 N2: DAG.getConstant(Val: 0, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6196 } else if (TargetPreference & AndOrSETCCFoldKind::AddAnd) {
6197
6198 SDValue AddOp = DAG.getNode(Opcode: ISD::ADD, DL, VT: OpVT, N1: LHS0,
6199 N2: DAG.getConstant(Val: -MinC, DL, VT: OpVT));
6200 SDValue AndOp = DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: AddOp,
6201 N2: DAG.getConstant(Val: ~Dif, DL, VT: OpVT));
6202 return DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: AndOp,
6203 N2: DAG.getConstant(Val: 0, DL, VT: OpVT), N3: LHS.getOperand(i: 2));
6204 }
6205 }
6206 }
6207 }
6208
6209 return SDValue();
6210}
6211
6212// Combine `(select c, (X & 1), 0)` -> `(and (zext c), X)`.
6213// We canonicalize to the `select` form in the middle end, but the `and` form
6214 // gets better codegen on all tested targets (arm, x86, riscv).
6215static SDValue combineSelectAsExtAnd(SDValue Cond, SDValue T, SDValue F,
6216 const SDLoc &DL, SelectionDAG &DAG) {
6217 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6218 if (!isNullConstant(V: F))
6219 return SDValue();
6220
6221 EVT CondVT = Cond.getValueType();
6222 if (TLI.getBooleanContents(Type: CondVT) !=
6223 TargetLoweringBase::ZeroOrOneBooleanContent)
6224 return SDValue();
6225
6226 if (T.getOpcode() != ISD::AND)
6227 return SDValue();
6228
6229 if (!isOneConstant(V: T.getOperand(i: 1)))
6230 return SDValue();
6231
6232 EVT OpVT = T.getValueType();
6233
6234 SDValue CondMask =
6235 OpVT == CondVT ? Cond : DAG.getBoolExtOrTrunc(Op: Cond, SL: DL, VT: OpVT, OpVT: CondVT);
6236 return DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: CondMask, N2: T.getOperand(i: 0));
6237}
6238
6239/// This contains all DAGCombine rules which reduce two values combined by
6240/// an And operation to a single value. This makes them reusable in the context
6241/// of visitSELECT(). Rules involving constants are not included as
6242/// visitSELECT() already handles those cases.
6243SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, SDNode *N) {
6244 EVT VT = N1.getValueType();
6245 SDLoc DL(N);
6246
6247 // fold (and x, undef) -> 0
6248 if (N0.isUndef() || N1.isUndef())
6249 return DAG.getConstant(Val: 0, DL, VT);
6250
6251 if (SDValue V = foldLogicOfSetCCs(IsAnd: true, N0, N1, DL))
6252 return V;
6253
6254 // Canonicalize:
6255 // and(x, add) -> and(add, x)
6256 if (N1.getOpcode() == ISD::ADD)
6257 std::swap(a&: N0, b&: N1);
6258
6259 // TODO: Rewrite this to return a new 'AND' instead of using CombineTo.
6260 if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::SRL &&
6261 VT.isScalarInteger() && VT.getSizeInBits() <= 64 && N0->hasOneUse()) {
6262 if (ConstantSDNode *ADDI = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
6263 if (ConstantSDNode *SRLI = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1))) {
6264 // Look for (and (add x, c1), (lshr y, c2)). If C1 wasn't a legal
6265 // immediate for an add, but it is legal if its top c2 bits are set,
6266 // transform the ADD so the immediate doesn't need to be materialized
6267 // in a register.
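          // Setting the top c2 bits of c1 is safe: addition carries only
          // propagate upwards, and the AND with (lshr y, c2) clears those top
          // bits of the result anyway.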
6268 APInt ADDC = ADDI->getAPIntValue();
6269 APInt SRLC = SRLI->getAPIntValue();
6270 if (ADDC.getSignificantBits() <= 64 && SRLC.ult(RHS: VT.getSizeInBits()) &&
6271 !TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6272 APInt Mask = APInt::getHighBitsSet(numBits: VT.getSizeInBits(),
6273 hiBitsSet: SRLC.getZExtValue());
6274 if (DAG.MaskedValueIsZero(Op: N0.getOperand(i: 1), Mask)) {
6275 ADDC |= Mask;
6276 if (TLI.isLegalAddImmediate(ADDC.getSExtValue())) {
6277 SDLoc DL0(N0);
6278 SDValue NewAdd =
6279 DAG.getNode(Opcode: ISD::ADD, DL: DL0, VT,
6280 N1: N0.getOperand(i: 0), N2: DAG.getConstant(Val: ADDC, DL, VT));
6281 CombineTo(N: N0.getNode(), Res: NewAdd);
6282 // Return N so it doesn't get rechecked!
6283 return SDValue(N, 0);
6284 }
6285 }
6286 }
6287 }
6288 }
6289 }
6290
6291 return SDValue();
6292}
6293
6294bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
6295 EVT LoadResultTy, EVT &ExtVT) {
6296 if (!AndC->getAPIntValue().isMask())
6297 return false;
6298
6299 unsigned ActiveBits = AndC->getAPIntValue().countr_one();
6300
6301 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
6302 EVT LoadedVT = LoadN->getMemoryVT();
6303
6304 if (ExtVT == LoadedVT &&
6305 (!LegalOperations ||
6306 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LoadResultTy, MemVT: ExtVT))) {
6307 // ZEXTLOAD will match without needing to change the size of the value being
6308 // loaded.
6309 return true;
6310 }
6311
6312   // Do not change the width of volatile or atomic loads.
6313 if (!LoadN->isSimple())
6314 return false;
6315
6316 // Do not generate loads of non-round integer types since these can
6317 // be expensive (and would be wrong if the type is not byte sized).
6318 if (!LoadedVT.bitsGT(VT: ExtVT) || !ExtVT.isRound())
6319 return false;
6320
6321 if (LegalOperations &&
6322 !TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LoadResultTy, MemVT: ExtVT))
6323 return false;
6324
6325 if (!TLI.shouldReduceLoadWidth(Load: LoadN, ExtTy: ISD::ZEXTLOAD, NewVT: ExtVT))
6326 return false;
6327
6328 return true;
6329}
6330
6331bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
6332 ISD::LoadExtType ExtType, EVT &MemVT,
6333 unsigned ShAmt) {
6334 if (!LDST)
6335 return false;
6336 // Only allow byte offsets.
6337 if (ShAmt % 8)
6338 return false;
6339
6340 // Do not generate loads of non-round integer types since these can
6341 // be expensive (and would be wrong if the type is not byte sized).
6342 if (!MemVT.isRound())
6343 return false;
6344
6345   // Don't change the width of volatile or atomic loads.
6346 if (!LDST->isSimple())
6347 return false;
6348
6349 EVT LdStMemVT = LDST->getMemoryVT();
6350
6351 // Bail out when changing the scalable property, since we can't be sure that
6352 // we're actually narrowing here.
6353 if (LdStMemVT.isScalableVector() != MemVT.isScalableVector())
6354 return false;
6355
6356 // Verify that we are actually reducing a load width here.
6357 if (LdStMemVT.bitsLT(VT: MemVT))
6358 return false;
6359
6360 // Ensure that this isn't going to produce an unsupported memory access.
6361 if (ShAmt) {
6362 assert(ShAmt % 8 == 0 && "ShAmt is byte offset");
6363 const unsigned ByteShAmt = ShAmt / 8;
6364 const Align LDSTAlign = LDST->getAlign();
6365 const Align NarrowAlign = commonAlignment(A: LDSTAlign, Offset: ByteShAmt);
6366 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: MemVT,
6367 AddrSpace: LDST->getAddressSpace(), Alignment: NarrowAlign,
6368 Flags: LDST->getMemOperand()->getFlags()))
6369 return false;
6370 }
6371
6372 // It's not possible to generate a constant of extended or untyped type.
6373 EVT PtrType = LDST->getBasePtr().getValueType();
6374 if (PtrType == MVT::Untyped || PtrType.isExtended())
6375 return false;
6376
6377 if (isa<LoadSDNode>(Val: LDST)) {
6378 LoadSDNode *Load = cast<LoadSDNode>(Val: LDST);
6379 // Don't transform one with multiple uses, this would require adding a new
6380 // load.
6381 if (!SDValue(Load, 0).hasOneUse())
6382 return false;
6383
6384 if (LegalOperations &&
6385 !TLI.isLoadExtLegal(ExtType, ValVT: Load->getValueType(ResNo: 0), MemVT))
6386 return false;
6387
6388 // For the transform to be legal, the load must produce only two values
6389 // (the value loaded and the chain). Don't transform a pre-increment
6390 // load, for example, which produces an extra value. Otherwise the
6391 // transformation is not equivalent, and the downstream logic to replace
6392 // uses gets things wrong.
6393 if (Load->getNumValues() > 2)
6394 return false;
6395
6396 // If the load that we're shrinking is an extload and we're not just
6397 // discarding the extension we can't simply shrink the load. Bail.
6398 // TODO: It would be possible to merge the extensions in some cases.
6399 if (Load->getExtensionType() != ISD::NON_EXTLOAD &&
6400 Load->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6401 return false;
6402
6403 if (!TLI.shouldReduceLoadWidth(Load, ExtTy: ExtType, NewVT: MemVT))
6404 return false;
6405 } else {
6406 assert(isa<StoreSDNode>(LDST) && "It is not a Load nor a Store SDNode");
6407 StoreSDNode *Store = cast<StoreSDNode>(Val: LDST);
6408 // Can't write outside the original store
6409 if (Store->getMemoryVT().getSizeInBits() < MemVT.getSizeInBits() + ShAmt)
6410 return false;
6411
6412 if (LegalOperations &&
6413 !TLI.isTruncStoreLegal(ValVT: Store->getValue().getValueType(), MemVT))
6414 return false;
6415 }
6416 return true;
6417}
6418
6419bool DAGCombiner::SearchForAndLoads(SDNode *N,
6420 SmallVectorImpl<LoadSDNode*> &Loads,
6421 SmallPtrSetImpl<SDNode*> &NodesWithConsts,
6422 ConstantSDNode *Mask,
6423 SDNode *&NodeToMask) {
6424 // Recursively search for the operands, looking for loads which can be
6425 // narrowed.
6426 for (SDValue Op : N->op_values()) {
6427 if (Op.getValueType().isVector())
6428 return false;
6429
6430 // Some constants may need fixing up later if they are too large.
6431 if (auto *C = dyn_cast<ConstantSDNode>(Val&: Op)) {
6432 if ((N->getOpcode() == ISD::OR || N->getOpcode() == ISD::XOR) &&
6433 (Mask->getAPIntValue() & C->getAPIntValue()) != C->getAPIntValue())
6434 NodesWithConsts.insert(Ptr: N);
6435 continue;
6436 }
6437
6438 if (!Op.hasOneUse())
6439 return false;
6440
6441 switch(Op.getOpcode()) {
6442 case ISD::LOAD: {
6443 auto *Load = cast<LoadSDNode>(Val&: Op);
6444 EVT ExtVT;
6445 if (isAndLoadExtLoad(AndC: Mask, LoadN: Load, LoadResultTy: Load->getValueType(ResNo: 0), ExtVT) &&
6446 isLegalNarrowLdSt(LDST: Load, ExtType: ISD::ZEXTLOAD, MemVT&: ExtVT)) {
6447
6448 // ZEXTLOAD is already small enough.
6449 if (Load->getExtensionType() == ISD::ZEXTLOAD &&
6450 ExtVT.bitsGE(VT: Load->getMemoryVT()))
6451 continue;
6452
6453 // Use LE to convert equal sized loads to zext.
6454 if (ExtVT.bitsLE(VT: Load->getMemoryVT()))
6455 Loads.push_back(Elt: Load);
6456
6457 continue;
6458 }
6459 return false;
6460 }
6461 case ISD::ZERO_EXTEND:
6462 case ISD::AssertZext: {
6463 unsigned ActiveBits = Mask->getAPIntValue().countr_one();
6464 EVT ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
6465 EVT VT = Op.getOpcode() == ISD::AssertZext ?
6466 cast<VTSDNode>(Val: Op.getOperand(i: 1))->getVT() :
6467 Op.getOperand(i: 0).getValueType();
6468
6469       // We can accept extending nodes if the mask is wider than or equal in
6470       // width to the original type.
6471 if (ExtVT.bitsGE(VT))
6472 continue;
6473 break;
6474 }
6475 case ISD::OR:
6476 case ISD::XOR:
6477 case ISD::AND:
6478 if (!SearchForAndLoads(N: Op.getNode(), Loads, NodesWithConsts, Mask,
6479 NodeToMask))
6480 return false;
6481 continue;
6482 }
6483
6484     // Allow one node which will be masked along with any loads found.
6485 if (NodeToMask)
6486 return false;
6487
6488 // Also ensure that the node to be masked only produces one data result.
6489 NodeToMask = Op.getNode();
6490 if (NodeToMask->getNumValues() > 1) {
6491 bool HasValue = false;
6492 for (unsigned i = 0, e = NodeToMask->getNumValues(); i < e; ++i) {
6493 MVT VT = SDValue(NodeToMask, i).getSimpleValueType();
6494 if (VT != MVT::Glue && VT != MVT::Other) {
6495 if (HasValue) {
6496 NodeToMask = nullptr;
6497 return false;
6498 }
6499 HasValue = true;
6500 }
6501 }
6502 assert(HasValue && "Node to be masked has no data result?");
6503 }
6504 }
6505 return true;
6506}
6507
6508bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
6509 auto *Mask = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
6510 if (!Mask)
6511 return false;
6512
6513 if (!Mask->getAPIntValue().isMask())
6514 return false;
6515
6516 // No need to do anything if the and directly uses a load.
6517 if (isa<LoadSDNode>(Val: N->getOperand(Num: 0)))
6518 return false;
6519
6520 SmallVector<LoadSDNode*, 8> Loads;
6521 SmallPtrSet<SDNode*, 2> NodesWithConsts;
6522 SDNode *FixupNode = nullptr;
6523 if (SearchForAndLoads(N, Loads, NodesWithConsts, Mask, NodeToMask&: FixupNode)) {
6524 if (Loads.empty())
6525 return false;
6526
6527 LLVM_DEBUG(dbgs() << "Backwards propagate AND: "; N->dump());
6528 SDValue MaskOp = N->getOperand(Num: 1);
6529
6530 // If it exists, fixup the single node we allow in the tree that needs
6531 // masking.
6532 if (FixupNode) {
6533 LLVM_DEBUG(dbgs() << "First, need to fix up: "; FixupNode->dump());
6534 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(FixupNode),
6535 VT: FixupNode->getValueType(ResNo: 0),
6536 N1: SDValue(FixupNode, 0), N2: MaskOp);
6537 DAG.ReplaceAllUsesOfValueWith(From: SDValue(FixupNode, 0), To: And);
6538       if (And.getOpcode() == ISD::AND)
6539 DAG.UpdateNodeOperands(N: And.getNode(), Op1: SDValue(FixupNode, 0), Op2: MaskOp);
6540 }
6541
6542 // Narrow any constants that need it.
6543 for (auto *LogicN : NodesWithConsts) {
6544 SDValue Op0 = LogicN->getOperand(Num: 0);
6545 SDValue Op1 = LogicN->getOperand(Num: 1);
6546
6547 if (isa<ConstantSDNode>(Val: Op0))
6548 Op0 =
6549 DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Op0), VT: Op0.getValueType(), N1: Op0, N2: MaskOp);
6550
6551 if (isa<ConstantSDNode>(Val: Op1))
6552 Op1 =
6553 DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Op1), VT: Op1.getValueType(), N1: Op1, N2: MaskOp);
6554
6555 if (isa<ConstantSDNode>(Val: Op0) && !isa<ConstantSDNode>(Val: Op1))
6556 std::swap(a&: Op0, b&: Op1);
6557
6558 DAG.UpdateNodeOperands(N: LogicN, Op1: Op0, Op2: Op1);
6559 }
6560
6561 // Create narrow loads.
6562 for (auto *Load : Loads) {
6563 LLVM_DEBUG(dbgs() << "Propagate AND back to: "; Load->dump());
6564 SDValue And = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Load), VT: Load->getValueType(ResNo: 0),
6565 N1: SDValue(Load, 0), N2: MaskOp);
6566 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 0), To: And);
6567       if (And.getOpcode() == ISD::AND)
6568 And = SDValue(
6569 DAG.UpdateNodeOperands(N: And.getNode(), Op1: SDValue(Load, 0), Op2: MaskOp), 0);
6570 SDValue NewLoad = reduceLoadWidth(N: And.getNode());
6571 assert(NewLoad &&
6572 "Shouldn't be masking the load if it can't be narrowed");
6573 CombineTo(N: Load, Res0: NewLoad, Res1: NewLoad.getValue(R: 1));
6574 }
6575 DAG.ReplaceAllUsesWith(From: N, To: N->getOperand(Num: 0).getNode());
6576 return true;
6577 }
6578 return false;
6579}
6580
6581// Unfold
6582// x & (-1 'logical shift' y)
6583// To
6584// (x 'opposite logical shift' y) 'logical shift' y
6585// if it is better for performance.
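// For example: x & (-1 << y) --> (x lshr y) shl y, which clears the low y bits
// of x without materializing the mask constant.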
6586SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
6587 assert(N->getOpcode() == ISD::AND);
6588
6589 SDValue N0 = N->getOperand(Num: 0);
6590 SDValue N1 = N->getOperand(Num: 1);
6591
6592 // Do we actually prefer shifts over mask?
6593 if (!TLI.shouldFoldMaskToVariableShiftPair(X: N0))
6594 return SDValue();
6595
6596 // Try to match (-1 '[outer] logical shift' y)
6597 unsigned OuterShift;
6598 unsigned InnerShift; // The opposite direction to the OuterShift.
6599 SDValue Y; // Shift amount.
6600 auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
6601 if (!M.hasOneUse())
6602 return false;
6603 OuterShift = M->getOpcode();
6604 if (OuterShift == ISD::SHL)
6605 InnerShift = ISD::SRL;
6606 else if (OuterShift == ISD::SRL)
6607 InnerShift = ISD::SHL;
6608 else
6609 return false;
6610 if (!isAllOnesConstant(V: M->getOperand(Num: 0)))
6611 return false;
6612 Y = M->getOperand(Num: 1);
6613 return true;
6614 };
6615
6616 SDValue X;
6617 if (matchMask(N1))
6618 X = N0;
6619 else if (matchMask(N0))
6620 X = N1;
6621 else
6622 return SDValue();
6623
6624 SDLoc DL(N);
6625 EVT VT = N->getValueType(ResNo: 0);
6626
6627 // tmp = x 'opposite logical shift' y
6628 SDValue T0 = DAG.getNode(Opcode: InnerShift, DL, VT, N1: X, N2: Y);
6629 // ret = tmp 'logical shift' y
6630 SDValue T1 = DAG.getNode(Opcode: OuterShift, DL, VT, N1: T0, N2: Y);
6631
6632 return T1;
6633}
6634
6635/// Try to replace shift/logic that tests if a bit is clear with mask + setcc.
6636/// For a target with a bit test, this is expected to become test + set and save
6637/// at least 1 instruction.
6638static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
6639 assert(And->getOpcode() == ISD::AND && "Expected an 'and' op");
6640
6641 // Look through an optional extension.
6642 SDValue And0 = And->getOperand(Num: 0), And1 = And->getOperand(Num: 1);
6643 if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
6644 And0 = And0.getOperand(i: 0);
6645 if (!isOneConstant(V: And1) || !And0.hasOneUse())
6646 return SDValue();
6647
6648 SDValue Src = And0;
6649
6650 // Attempt to find a 'not' op.
6651 // TODO: Should we favor test+set even without the 'not' op?
6652 bool FoundNot = false;
6653 if (isBitwiseNot(V: Src)) {
6654 FoundNot = true;
6655 Src = Src.getOperand(i: 0);
6656
6657     // Look through an optional truncation. The source operand may not be the
6658 // same type as the original 'and', but that is ok because we are masking
6659 // off everything but the low bit.
6660 if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
6661 Src = Src.getOperand(i: 0);
6662 }
6663
6664 // Match a shift-right by constant.
6665 if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
6666 return SDValue();
6667
6668 // This is probably not worthwhile without a supported type.
6669 EVT SrcVT = Src.getValueType();
6670 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6671 if (!TLI.isTypeLegal(VT: SrcVT))
6672 return SDValue();
6673
6674 // We might have looked through casts that make this transform invalid.
6675 unsigned BitWidth = SrcVT.getScalarSizeInBits();
6676 SDValue ShiftAmt = Src.getOperand(i: 1);
6677 auto *ShiftAmtC = dyn_cast<ConstantSDNode>(Val&: ShiftAmt);
6678 if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(RHS: BitWidth))
6679 return SDValue();
6680
6681 // Set source to shift source.
6682 Src = Src.getOperand(i: 0);
6683
6684 // Try again to find a 'not' op.
6685 // TODO: Should we favor test+set even with two 'not' ops?
6686 if (!FoundNot) {
6687 if (!isBitwiseNot(V: Src))
6688 return SDValue();
6689 Src = Src.getOperand(i: 0);
6690 }
6691
6692 if (!TLI.hasBitTest(X: Src, Y: ShiftAmt))
6693 return SDValue();
6694
6695 // Turn this into a bit-test pattern using mask op + setcc:
6696 // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
6697 // and (srl (not X), C)), 1 --> (and X, 1<<C) == 0
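  // Both forms compute the logical negation of bit C of X, so they are 1
  // exactly when (X & (1 << C)) == 0.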
6698 SDLoc DL(And);
6699 SDValue X = DAG.getZExtOrTrunc(Op: Src, DL, VT: SrcVT);
6700 EVT CCVT =
6701 TLI.getSetCCResultType(DL: DAG.getDataLayout(), Context&: *DAG.getContext(), VT: SrcVT);
6702 SDValue Mask = DAG.getConstant(
6703 Val: APInt::getOneBitSet(numBits: BitWidth, BitNo: ShiftAmtC->getZExtValue()), DL, VT: SrcVT);
6704 SDValue NewAnd = DAG.getNode(Opcode: ISD::AND, DL, VT: SrcVT, N1: X, N2: Mask);
6705 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: SrcVT);
6706 SDValue Setcc = DAG.getSetCC(DL, VT: CCVT, LHS: NewAnd, RHS: Zero, Cond: ISD::SETEQ);
6707 return DAG.getZExtOrTrunc(Op: Setcc, DL, VT: And->getValueType(ResNo: 0));
6708}
6709
6710/// For targets that support usubsat, match a bit-hack form of that operation
6711/// that ends in 'and' and convert it.
6712static SDValue foldAndToUsubsat(SDNode *N, SelectionDAG &DAG, const SDLoc &DL) {
6713 EVT VT = N->getValueType(ResNo: 0);
6714 unsigned BitWidth = VT.getScalarSizeInBits();
6715 APInt SignMask = APInt::getSignMask(BitWidth);
6716
6717 // (i8 X ^ 128) & (i8 X s>> 7) --> usubsat X, 128
6718 // (i8 X + 128) & (i8 X s>> 7) --> usubsat X, 128
6719 // xor/add with SMIN (signmask) are logically equivalent.
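  // When X has the sign bit set, (X s>> 7) is all-ones and the result is
  // X ^ 128 == X - 128; otherwise the shift is zero and the result is 0.
  // That matches usubsat X, 128 exactly.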
6720 SDValue X;
6721 if (!sd_match(N, P: m_And(L: m_OneUse(P: m_Xor(L: m_Value(N&: X), R: m_SpecificInt(V: SignMask))),
6722 R: m_OneUse(P: m_Sra(L: m_Deferred(V&: X),
6723 R: m_SpecificInt(V: BitWidth - 1))))) &&
6724 !sd_match(N, P: m_And(L: m_OneUse(P: m_Add(L: m_Value(N&: X), R: m_SpecificInt(V: SignMask))),
6725 R: m_OneUse(P: m_Sra(L: m_Deferred(V&: X),
6726 R: m_SpecificInt(V: BitWidth - 1))))))
6727 return SDValue();
6728
6729 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: X,
6730 N2: DAG.getConstant(Val: SignMask, DL, VT));
6731}
6732
6733/// Given a bitwise logic operation N with a matching bitwise logic operand,
6734/// fold a pattern where 2 of the source operands are identically shifted
6735/// values. For example:
6736/// ((X0 << Y) | Z) | (X1 << Y) --> ((X0 | X1) << Y) | Z
6737static SDValue foldLogicOfShifts(SDNode *N, SDValue LogicOp, SDValue ShiftOp,
6738 SelectionDAG &DAG) {
6739 unsigned LogicOpcode = N->getOpcode();
6740 assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
6741 "Expected bitwise logic operation");
6742
6743 if (!LogicOp.hasOneUse() || !ShiftOp.hasOneUse())
6744 return SDValue();
6745
6746 // Match another bitwise logic op and a shift.
6747 unsigned ShiftOpcode = ShiftOp.getOpcode();
6748 if (LogicOp.getOpcode() != LogicOpcode ||
6749 !(ShiftOpcode == ISD::SHL || ShiftOpcode == ISD::SRL ||
6750 ShiftOpcode == ISD::SRA))
6751 return SDValue();
6752
6753 // Match another shift op inside the first logic operand. Handle both commuted
6754 // possibilities.
6755 // LOGIC (LOGIC (SH X0, Y), Z), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6756 // LOGIC (LOGIC Z, (SH X0, Y)), (SH X1, Y) --> LOGIC (SH (LOGIC X0, X1), Y), Z
6757 SDValue X1 = ShiftOp.getOperand(i: 0);
6758 SDValue Y = ShiftOp.getOperand(i: 1);
6759 SDValue X0, Z;
6760 if (LogicOp.getOperand(i: 0).getOpcode() == ShiftOpcode &&
6761 LogicOp.getOperand(i: 0).getOperand(i: 1) == Y) {
6762 X0 = LogicOp.getOperand(i: 0).getOperand(i: 0);
6763 Z = LogicOp.getOperand(i: 1);
6764 } else if (LogicOp.getOperand(i: 1).getOpcode() == ShiftOpcode &&
6765 LogicOp.getOperand(i: 1).getOperand(i: 1) == Y) {
6766 X0 = LogicOp.getOperand(i: 1).getOperand(i: 0);
6767 Z = LogicOp.getOperand(i: 0);
6768 } else {
6769 return SDValue();
6770 }
6771
6772 EVT VT = N->getValueType(ResNo: 0);
6773 SDLoc DL(N);
6774 SDValue LogicX = DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: X0, N2: X1);
6775 SDValue NewShift = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: LogicX, N2: Y);
6776 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: NewShift, N2: Z);
6777}
6778
6779/// Given a tree of logic operations with shape like
6780/// (LOGIC (LOGIC (X, Y), LOGIC (Z, Y)))
6781/// try to match and fold shift operations with the same shift amount.
6782/// For example:
6783/// LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W) -->
6784/// --> LOGIC (SH (LOGIC X0, X1), Y), (LOGIC Z, W)
6785static SDValue foldLogicTreeOfShifts(SDNode *N, SDValue LeftHand,
6786 SDValue RightHand, SelectionDAG &DAG) {
6787 unsigned LogicOpcode = N->getOpcode();
6788 assert(ISD::isBitwiseLogicOp(LogicOpcode) &&
6789 "Expected bitwise logic operation");
6790 if (LeftHand.getOpcode() != LogicOpcode ||
6791 RightHand.getOpcode() != LogicOpcode)
6792 return SDValue();
6793 if (!LeftHand.hasOneUse() || !RightHand.hasOneUse())
6794 return SDValue();
6795
6796 // Try to match one of following patterns:
6797 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC (SH X1, Y), W)
6798 // LOGIC (LOGIC (SH X0, Y), Z), (LOGIC W, (SH X1, Y))
6799 // Note that foldLogicOfShifts will handle commuted versions of the left hand
6800 // itself.
6801 SDValue CombinedShifts, W;
6802 SDValue R0 = RightHand.getOperand(i: 0);
6803 SDValue R1 = RightHand.getOperand(i: 1);
6804 if ((CombinedShifts = foldLogicOfShifts(N, LogicOp: LeftHand, ShiftOp: R0, DAG)))
6805 W = R1;
6806 else if ((CombinedShifts = foldLogicOfShifts(N, LogicOp: LeftHand, ShiftOp: R1, DAG)))
6807 W = R0;
6808 else
6809 return SDValue();
6810
6811 EVT VT = N->getValueType(ResNo: 0);
6812 SDLoc DL(N);
6813 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: CombinedShifts, N2: W);
6814}
6815
6816SDValue DAGCombiner::visitAND(SDNode *N) {
6817 SDValue N0 = N->getOperand(Num: 0);
6818 SDValue N1 = N->getOperand(Num: 1);
6819 EVT VT = N1.getValueType();
6820 SDLoc DL(N);
6821
6822 // x & x --> x
6823 if (N0 == N1)
6824 return N0;
6825
6826 // fold (and c1, c2) -> c1&c2
6827 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::AND, DL, VT, Ops: {N0, N1}))
6828 return C;
6829
6830 // canonicalize constant to RHS
6831 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
6832 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
6833 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1, N2: N0);
6834
6835 if (areBitwiseNotOfEachother(Op0: N0, Op1: N1))
6836 return DAG.getConstant(Val: APInt::getZero(numBits: VT.getScalarSizeInBits()), DL, VT);
6837
6838 // fold vector ops
6839 if (VT.isVector()) {
6840 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
6841 return FoldedVOp;
6842
6843 // fold (and x, 0) -> 0, vector edition
6844 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
6845 // do not return N1, because undef node may exist in N1
6846 return DAG.getConstant(Val: APInt::getZero(numBits: N1.getScalarValueSizeInBits()), DL,
6847 VT: N1.getValueType());
6848
6849 // fold (and x, -1) -> x, vector edition
6850 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode()))
6851 return N0;
6852
6853 // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
6854 auto *MLoad = dyn_cast<MaskedLoadSDNode>(Val&: N0);
6855 ConstantSDNode *Splat = isConstOrConstSplat(N: N1, AllowUndefs: true, AllowTruncation: true);
6856 if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat &&
6857 N1.hasOneUse()) {
6858 EVT LoadVT = MLoad->getMemoryVT();
6859 EVT ExtVT = VT;
6860 if (TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: ExtVT, MemVT: LoadVT)) {
6861 // For this AND to be a zero extension of the masked load the elements
6862 // of the BuildVec must mask the bottom bits of the extended element
6863 // type
6864 uint64_t ElementSize =
6865 LoadVT.getVectorElementType().getScalarSizeInBits();
6866 if (Splat->getAPIntValue().isMask(numBits: ElementSize)) {
6867 SDValue NewLoad = DAG.getMaskedLoad(
6868 VT: ExtVT, dl: DL, Chain: MLoad->getChain(), Base: MLoad->getBasePtr(),
6869 Offset: MLoad->getOffset(), Mask: MLoad->getMask(), Src0: MLoad->getPassThru(),
6870 MemVT: LoadVT, MMO: MLoad->getMemOperand(), AM: MLoad->getAddressingMode(),
6871 ISD::ZEXTLOAD, IsExpanding: MLoad->isExpandingLoad());
6872 bool LoadHasOtherUsers = !N0.hasOneUse();
6873 CombineTo(N, Res: NewLoad);
6874 if (LoadHasOtherUsers)
6875 CombineTo(N: MLoad, Res0: NewLoad.getValue(R: 0), Res1: NewLoad.getValue(R: 1));
6876 return SDValue(N, 0);
6877 }
6878 }
6879 }
6880 }
6881
6882 // fold (and x, -1) -> x
6883 if (isAllOnesConstant(V: N1))
6884 return N0;
6885
6886 // if (and x, c) is known to be zero, return 0
6887 unsigned BitWidth = VT.getScalarSizeInBits();
6888 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
6889 if (N1C && DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: BitWidth)))
6890 return DAG.getConstant(Val: 0, DL, VT);
6891
6892 if (SDValue R = foldAndOrOfSETCC(LogicOp: N, DAG))
6893 return R;
6894
6895 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
6896 return NewSel;
6897
6898 // reassociate and
6899 if (SDValue RAND = reassociateOps(Opc: ISD::AND, DL, N0, N1, Flags: N->getFlags()))
6900 return RAND;
6901
6902 // Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y))
6903 if (SDValue SD =
6904 reassociateReduction(RedOpc: ISD::VECREDUCE_AND, Opc: ISD::AND, DL, VT, N0, N1))
6905 return SD;
6906
6907 // fold (and (or x, C), D) -> D if (C & D) == D
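  // Every set bit of D is also set in C, so (or x, C) already has all of D's
  // bits set and the AND produces exactly D.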
6908 auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
6909 return RHS->getAPIntValue().isSubsetOf(RHS: LHS->getAPIntValue());
6910 };
6911 if (N0.getOpcode() == ISD::OR &&
6912 ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchSubset))
6913 return N1;
6914
6915 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
6916 SDValue N0Op0 = N0.getOperand(i: 0);
6917 EVT SrcVT = N0Op0.getValueType();
6918 unsigned SrcBitWidth = SrcVT.getScalarSizeInBits();
6919 APInt Mask = ~N1C->getAPIntValue();
6920 Mask = Mask.trunc(width: SrcBitWidth);
6921
6922 // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
6923 if (DAG.MaskedValueIsZero(Op: N0Op0, Mask))
6924 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0Op0);
6925
6926 // fold (and (any_ext V), c) -> (zero_ext (and (trunc V), c)) if profitable.
6927 if (N1C->getAPIntValue().countLeadingZeros() >= (BitWidth - SrcBitWidth) &&
6928 TLI.isTruncateFree(FromVT: VT, ToVT: SrcVT) && TLI.isZExtFree(FromTy: SrcVT, ToTy: VT) &&
6929 TLI.isTypeDesirableForOp(ISD::AND, VT: SrcVT) &&
6930 TLI.isNarrowingProfitable(SrcVT: VT, DestVT: SrcVT))
6931 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT,
6932 Operand: DAG.getNode(Opcode: ISD::AND, DL, VT: SrcVT, N1: N0Op0,
6933 N2: DAG.getZExtOrTrunc(Op: N1, DL, VT: SrcVT)));
6934 }
6935
6936 // fold (and (ext (and V, c1)), c2) -> (and (ext V), (and c1, (ext c2)))
6937 if (ISD::isExtOpcode(Opcode: N0.getOpcode())) {
6938 unsigned ExtOpc = N0.getOpcode();
6939 SDValue N0Op0 = N0.getOperand(i: 0);
6940 if (N0Op0.getOpcode() == ISD::AND &&
6941 (ExtOpc != ISD::ZERO_EXTEND || !TLI.isZExtFree(Val: N0Op0, VT2: VT)) &&
6942 DAG.isConstantIntBuildVectorOrConstantInt(N: N1) &&
6943 DAG.isConstantIntBuildVectorOrConstantInt(N: N0Op0.getOperand(i: 1)) &&
6944 N0->hasOneUse() && N0Op0->hasOneUse()) {
6945 SDValue NewMask =
6946 DAG.getNode(Opcode: ISD::AND, DL, VT, N1,
6947 N2: DAG.getNode(Opcode: ExtOpc, DL, VT, Operand: N0Op0.getOperand(i: 1)));
6948 return DAG.getNode(Opcode: ISD::AND, DL, VT,
6949 N1: DAG.getNode(Opcode: ExtOpc, DL, VT, Operand: N0Op0.getOperand(i: 0)),
6950 N2: NewMask);
6951 }
6952 }
6953
6954 // similarly fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
6955 // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
6956 // already be zero by virtue of the width of the base type of the load.
6957 //
6958 // the 'X' node here can either be nothing or an extract_vector_elt to catch
6959 // more cases.
6960 if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
6961 N0.getValueSizeInBits() == N0.getOperand(i: 0).getScalarValueSizeInBits() &&
6962 N0.getOperand(i: 0).getOpcode() == ISD::LOAD &&
6963 N0.getOperand(i: 0).getResNo() == 0) ||
6964 (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
6965 auto *Load =
6966 cast<LoadSDNode>(Val: (N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(i: 0));
6967
6968 // Get the constant (if applicable) the zero'th operand is being ANDed with.
6969 // This can be a pure constant or a vector splat, in which case we treat the
6970 // vector as a scalar and use the splat value.
6971 APInt Constant = APInt::getZero(numBits: 1);
6972 if (const ConstantSDNode *C = isConstOrConstSplat(
6973 N: N1, /*AllowUndef=*/AllowUndefs: false, /*AllowTruncation=*/true)) {
6974 Constant = C->getAPIntValue();
6975 } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(Val&: N1)) {
6976 unsigned EltBitWidth = Vector->getValueType(ResNo: 0).getScalarSizeInBits();
6977 APInt SplatValue, SplatUndef;
6978 unsigned SplatBitSize;
6979 bool HasAnyUndefs;
6980 // Endianness should not matter here. Code below makes sure that we only
6981 // use the result if the SplatBitSize is a multiple of the vector element
6982 // size. And after that we AND all element sized parts of the splat
6983 // together. So the end result should be the same regardless of in which
6984 // order we do those operations.
6985 const bool IsBigEndian = false;
6986 bool IsSplat =
6987 Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
6988 HasAnyUndefs, MinSplatBits: EltBitWidth, isBigEndian: IsBigEndian);
6989
6990 // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
6991       // multiple of 'EltBitWidth'. Otherwise, we could propagate a wrong value.
6992 if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
6993 // Undef bits can contribute to a possible optimisation if set, so
6994 // set them.
6995 SplatValue |= SplatUndef;
6996
6997 // The splat value may be something like "0x00FFFFFF", which means 0 for
6998 // the first vector value and FF for the rest, repeating. We need a mask
6999 // that will apply equally to all members of the vector, so AND all the
7000 // lanes of the constant together.
7001 Constant = APInt::getAllOnes(numBits: EltBitWidth);
7002 for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
7003 Constant &= SplatValue.extractBits(numBits: EltBitWidth, bitPosition: i * EltBitWidth);
7004 }
7005 }
7006
7007 // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
7008 // actually legal and isn't going to get expanded, else this is a false
7009 // optimisation.
7010 bool CanZextLoadProfitably = TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD,
7011 ValVT: Load->getValueType(ResNo: 0),
7012 MemVT: Load->getMemoryVT());
7013
7014 // Resize the constant to the same size as the original memory access before
7015 // extension. If it is still the AllOnesValue then this AND is completely
7016 // unneeded.
7017 Constant = Constant.zextOrTrunc(width: Load->getMemoryVT().getScalarSizeInBits());
7018
7019 bool B;
7020 switch (Load->getExtensionType()) {
7021 default: B = false; break;
7022 case ISD::EXTLOAD: B = CanZextLoadProfitably; break;
7023 case ISD::ZEXTLOAD:
7024 case ISD::NON_EXTLOAD: B = true; break;
7025 }
7026
7027 if (B && Constant.isAllOnes()) {
7028 // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
7029 // preserve semantics once we get rid of the AND.
7030 SDValue NewLoad(Load, 0);
7031
7032 // Fold the AND away. NewLoad may get replaced immediately.
7033 CombineTo(N, Res: (N0.getNode() == Load) ? NewLoad : N0);
7034
7035 if (Load->getExtensionType() == ISD::EXTLOAD) {
7036 NewLoad = DAG.getLoad(AM: Load->getAddressingMode(), ExtType: ISD::ZEXTLOAD,
7037 VT: Load->getValueType(ResNo: 0), dl: SDLoc(Load),
7038 Chain: Load->getChain(), Ptr: Load->getBasePtr(),
7039 Offset: Load->getOffset(), MemVT: Load->getMemoryVT(),
7040 MMO: Load->getMemOperand());
7041 // Replace uses of the EXTLOAD with the new ZEXTLOAD.
7042 if (Load->getNumValues() == 3) {
7043 // PRE/POST_INC loads have 3 values.
7044 SDValue To[] = { NewLoad.getValue(R: 0), NewLoad.getValue(R: 1),
7045 NewLoad.getValue(R: 2) };
7046 CombineTo(N: Load, To, NumTo: 3, AddTo: true);
7047 } else {
7048 CombineTo(N: Load, Res0: NewLoad.getValue(R: 0), Res1: NewLoad.getValue(R: 1));
7049 }
7050 }
7051
7052 return SDValue(N, 0); // Return N so it doesn't get rechecked!
7053 }
7054 }
7055
7056 // Try to convert a constant mask AND into a shuffle clear mask.
7057 if (VT.isVector())
7058 if (SDValue Shuffle = XformToShuffleWithZero(N))
7059 return Shuffle;
7060
7061 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7062 return Combined;
7063
7064 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() && N1C &&
7065 ISD::isExtOpcode(Opcode: N0.getOperand(i: 0).getOpcode())) {
7066 SDValue Ext = N0.getOperand(i: 0);
7067 EVT ExtVT = Ext->getValueType(ResNo: 0);
7068 SDValue Extendee = Ext->getOperand(Num: 0);
7069
7070 unsigned ScalarWidth = Extendee.getValueType().getScalarSizeInBits();
7071 if (N1C->getAPIntValue().isMask(numBits: ScalarWidth) &&
7072 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT: ExtVT))) {
7073 // (and (extract_subvector (zext|anyext|sext v) _) iN_mask)
7074 // => (extract_subvector (iN_zeroext v))
7075 SDValue ZeroExtExtendee =
7076 DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: ExtVT, Operand: Extendee);
7077
7078 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: ZeroExtExtendee,
7079 N2: N0.getOperand(i: 1));
7080 }
7081 }
7082
7083 // fold (and (masked_gather x)) -> (zext_masked_gather x)
7084 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(Val&: N0)) {
7085 EVT MemVT = GN0->getMemoryVT();
7086 EVT ScalarVT = MemVT.getScalarType();
7087
7088 if (SDValue(GN0, 0).hasOneUse() &&
7089 isConstantSplatVectorMaskForType(N: N1.getNode(), ScalarTy: ScalarVT) &&
7090 TLI.isVectorLoadExtDesirable(ExtVal: SDValue(SDValue(GN0, 0)))) {
7091 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
7092 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
7093
7094 SDValue ZExtLoad = DAG.getMaskedGather(
7095 VTs: DAG.getVTList(VT, MVT::Other), MemVT, dl: DL, Ops, MMO: GN0->getMemOperand(),
7096 IndexType: GN0->getIndexType(), ExtTy: ISD::ZEXTLOAD);
7097
7098 CombineTo(N, Res: ZExtLoad);
7099 AddToWorklist(N: ZExtLoad.getNode());
7100 // Avoid recheck of N.
7101 return SDValue(N, 0);
7102 }
7103 }
7104
7105 // fold (and (load x), 255) -> (zextload x, i8)
7106 // fold (and (extload x, i16), 255) -> (zextload x, i8)
7107 if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
7108 if (SDValue Res = reduceLoadWidth(N))
7109 return Res;
7110
7111 if (LegalTypes) {
7112 // Attempt to propagate the AND back up to the leaves which, if they're
7113 // loads, can be combined to narrow loads and the AND node can be removed.
7114 // Perform after legalization so that extend nodes will already be
7115 // combined into the loads.
7116 if (BackwardsPropagateMask(N))
7117 return SDValue(N, 0);
7118 }
7119
7120 if (SDValue Combined = visitANDLike(N0, N1, N))
7121 return Combined;
7122
7123 // Simplify: (and (op x...), (op y...)) -> (op (and x, y))
7124 if (N0.getOpcode() == N1.getOpcode())
7125 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7126 return V;
7127
7128 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
7129 return R;
7130 if (SDValue R = foldLogicOfShifts(N, LogicOp: N1, ShiftOp: N0, DAG))
7131 return R;
7132
7133 // Masking the negated extension of a boolean is just the zero-extended
7134 // boolean:
7135 // and (sub 0, zext(bool X)), 1 --> zext(bool X)
7136 // and (sub 0, sext(bool X)), 1 --> zext(bool X)
7137 //
7138 // Note: the SimplifyDemandedBits fold below can make an information-losing
7139 // transform, and then we have no way to find this better fold.
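// For example, if X is true: zext(X) == 1, (sub 0, 1) == -1 (all ones), and
// (and -1, 1) == 1 == zext(X). If X is false, every term is 0. The sext case
// works out the same way.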
7140 if (N1C && N1C->isOne() && N0.getOpcode() == ISD::SUB) {
7141 if (isNullOrNullSplat(V: N0.getOperand(i: 0))) {
7142 SDValue SubRHS = N0.getOperand(i: 1);
7143 if (SubRHS.getOpcode() == ISD::ZERO_EXTEND &&
7144 SubRHS.getOperand(i: 0).getScalarValueSizeInBits() == 1)
7145 return SubRHS;
7146 if (SubRHS.getOpcode() == ISD::SIGN_EXTEND &&
7147 SubRHS.getOperand(i: 0).getScalarValueSizeInBits() == 1)
7148 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: SubRHS.getOperand(i: 0));
7149 }
7150 }
7151
7152 // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
7153 // fold (and (sra)) -> (and (srl)) when possible.
7154 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
7155 return SDValue(N, 0);
7156
7157 // fold (zext_inreg (extload x)) -> (zextload x)
7158 // fold (zext_inreg (sextload x)) -> (zextload x) iff load has one use
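// The AND here plays the role of the zext_inreg: N1 is known to clear every
// bit above the loaded memory width, so a zextload can supply those bits
// directly.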
7159 if (ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
7160 (ISD::isEXTLoad(N: N0.getNode()) ||
7161 (ISD::isSEXTLoad(N: N0.getNode()) && N0.hasOneUse()))) {
7162 auto *LN0 = cast<LoadSDNode>(Val&: N0);
7163 EVT MemVT = LN0->getMemoryVT();
7164 // If we zero all the possible extended bits, then we can turn this into
7165 // a zextload if we are running before legalize or the operation is legal.
7166 unsigned ExtBitSize = N1.getScalarValueSizeInBits();
7167 unsigned MemBitSize = MemVT.getScalarSizeInBits();
7168 APInt ExtBits = APInt::getHighBitsSet(numBits: ExtBitSize, hiBitsSet: ExtBitSize - MemBitSize);
7169 if (DAG.MaskedValueIsZero(Op: N1, Mask: ExtBits) &&
7170 ((!LegalOperations && LN0->isSimple()) ||
7171 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT))) {
7172 SDValue ExtLoad =
7173 DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(N0), VT, Chain: LN0->getChain(),
7174 Ptr: LN0->getBasePtr(), MemVT, MMO: LN0->getMemOperand());
7175 AddToWorklist(N);
7176 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
7177 return SDValue(N, 0); // Return N so it doesn't get rechecked!
7178 }
7179 }
7180
7181 // fold (and (or (srl N, 8), (shl N, 8)), 0xffff) -> (srl (bswap N), const)
7182 if (N1C && N1C->getAPIntValue() == 0xffff && N0.getOpcode() == ISD::OR) {
7183 if (SDValue BSwap = MatchBSwapHWordLow(N: N0.getNode(), N0: N0.getOperand(i: 0),
7184 N1: N0.getOperand(i: 1), DemandHighBits: false))
7185 return BSwap;
7186 }
7187
7188 if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
7189 return Shifts;
7190
7191 if (SDValue V = combineShiftAnd1ToBitTest(And: N, DAG))
7192 return V;
7193
7194 // Recognize the following pattern:
7195 //
7196 // AndVT = (and (sign_extend NarrowVT to AndVT) #bitmask)
7197 //
7198 // where bitmask is a mask that clears the upper bits of AndVT. The
7199 // number of bits in bitmask must be a power of two.
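// For example: (and (sign_extend i8 x to i32), 255) -> (zero_extend i8 x to i32).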
7200 auto IsAndZeroExtMask = [](SDValue LHS, SDValue RHS) {
7201 if (LHS->getOpcode() != ISD::SIGN_EXTEND)
7202 return false;
7203
7204 auto *C = dyn_cast<ConstantSDNode>(Val&: RHS);
7205 if (!C)
7206 return false;
7207
7208 if (!C->getAPIntValue().isMask(
7209 numBits: LHS.getOperand(i: 0).getValueType().getFixedSizeInBits()))
7210 return false;
7211
7212 return true;
7213 };
7214
7215 // Replace (and (sign_extend ...) #bitmask) with (zero_extend ...).
7216 if (IsAndZeroExtMask(N0, N1))
7217 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
7218
7219 if (hasOperation(Opcode: ISD::USUBSAT, VT))
7220 if (SDValue V = foldAndToUsubsat(N, DAG, DL))
7221 return V;
7222
// Postpone until legalization has completed to avoid interference with bswap
// folding.
7225 if (LegalOperations || VT.isVector())
7226 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
7227 return R;
7228
7229 return SDValue();
7230}
7231
7232/// Match (a >> 8) | (a << 8) as (bswap a) >> 16.
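/// For example, with i32 a == 0xAABBCCDD the matched form
/// ((a << 8) & 0xff00) | ((a >> 8) & 0xff) computes 0xDDCC, which is exactly
/// (bswap a) >> 16 == 0xDDCCBBAA >> 16.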
7233SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
7234 bool DemandHighBits) {
7235 if (!LegalOperations)
7236 return SDValue();
7237
7238 EVT VT = N->getValueType(ResNo: 0);
7239 if (VT != MVT::i64 && VT != MVT::i32 && VT != MVT::i16)
7240 return SDValue();
7241 if (!TLI.isOperationLegalOrCustom(Op: ISD::BSWAP, VT))
7242 return SDValue();
7243
7244 // Recognize (and (shl a, 8), 0xff00), (and (srl a, 8), 0xff)
7245 bool LookPassAnd0 = false;
7246 bool LookPassAnd1 = false;
7247 if (N0.getOpcode() == ISD::AND && N0.getOperand(i: 0).getOpcode() == ISD::SRL)
7248 std::swap(a&: N0, b&: N1);
7249 if (N1.getOpcode() == ISD::AND && N1.getOperand(i: 0).getOpcode() == ISD::SHL)
7250 std::swap(a&: N0, b&: N1);
7251 if (N0.getOpcode() == ISD::AND) {
7252 if (!N0->hasOneUse())
7253 return SDValue();
7254 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7255 // Also handle 0xffff since the LHS is guaranteed to have zeros there.
7256 // This is needed for X86.
7257 if (!N01C || (N01C->getZExtValue() != 0xFF00 &&
7258 N01C->getZExtValue() != 0xFFFF))
7259 return SDValue();
7260 N0 = N0.getOperand(i: 0);
7261 LookPassAnd0 = true;
7262 }
7263
7264 if (N1.getOpcode() == ISD::AND) {
7265 if (!N1->hasOneUse())
7266 return SDValue();
7267 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1));
7268 if (!N11C || N11C->getZExtValue() != 0xFF)
7269 return SDValue();
7270 N1 = N1.getOperand(i: 0);
7271 LookPassAnd1 = true;
7272 }
7273
7274 if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
7275 std::swap(a&: N0, b&: N1);
7276 if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
7277 return SDValue();
7278 if (!N0->hasOneUse() || !N1->hasOneUse())
7279 return SDValue();
7280
7281 ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7282 ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(Val: N1.getOperand(i: 1));
7283 if (!N01C || !N11C)
7284 return SDValue();
7285 if (N01C->getZExtValue() != 8 || N11C->getZExtValue() != 8)
7286 return SDValue();
7287
7288 // Look for (shl (and a, 0xff), 8), (srl (and a, 0xff00), 8)
7289 SDValue N00 = N0->getOperand(Num: 0);
7290 if (!LookPassAnd0 && N00.getOpcode() == ISD::AND) {
7291 if (!N00->hasOneUse())
7292 return SDValue();
7293 ConstantSDNode *N001C = dyn_cast<ConstantSDNode>(Val: N00.getOperand(i: 1));
7294 if (!N001C || N001C->getZExtValue() != 0xFF)
7295 return SDValue();
7296 N00 = N00.getOperand(i: 0);
7297 LookPassAnd0 = true;
7298 }
7299
7300 SDValue N10 = N1->getOperand(Num: 0);
7301 if (!LookPassAnd1 && N10.getOpcode() == ISD::AND) {
7302 if (!N10->hasOneUse())
7303 return SDValue();
7304 ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(Val: N10.getOperand(i: 1));
7305 // Also allow 0xFFFF since the bits will be shifted out. This is needed
7306 // for X86.
7307 if (!N101C || (N101C->getZExtValue() != 0xFF00 &&
7308 N101C->getZExtValue() != 0xFFFF))
7309 return SDValue();
7310 N10 = N10.getOperand(i: 0);
7311 LookPassAnd1 = true;
7312 }
7313
7314 if (N00 != N10)
7315 return SDValue();
7316
// Make sure everything beyond the low halfword is known to be zero, since
// the SRL by (width - 16) in the replacement clears those top bits.
7319 unsigned OpSizeInBits = VT.getSizeInBits();
7320 if (OpSizeInBits > 16) {
7321 // If the left-shift isn't masked out then the only way this is a bswap is
7322 // if all bits beyond the low 8 are 0. In that case the entire pattern
7323 // reduces to a left shift anyway: leave it for other parts of the combiner.
7324 if (DemandHighBits && !LookPassAnd0)
7325 return SDValue();
7326
7327 // However, if the right shift isn't masked out then it might be because
7328 // it's not needed. See if we can spot that too. If the high bits aren't
7329 // demanded, we only need bits 23:16 to be zero. Otherwise, we need all
7330 // upper bits to be zero.
7331 if (!LookPassAnd1) {
7332 unsigned HighBit = DemandHighBits ? OpSizeInBits : 24;
7333 if (!DAG.MaskedValueIsZero(Op: N10,
7334 Mask: APInt::getBitsSet(numBits: OpSizeInBits, loBit: 16, hiBit: HighBit)))
7335 return SDValue();
7336 }
7337 }
7338
7339 SDValue Res = DAG.getNode(Opcode: ISD::BSWAP, DL: SDLoc(N), VT, Operand: N00);
7340 if (OpSizeInBits > 16) {
7341 SDLoc DL(N);
7342 Res = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Res,
7343 N2: DAG.getConstant(Val: OpSizeInBits - 16, DL,
7344 VT: getShiftAmountTy(LHSTy: VT)));
7345 }
7346 return Res;
7347}
7348
7349/// Return true if the specified node is an element that makes up a 32-bit
7350/// packed halfword byteswap.
7351/// ((x & 0x000000ff) << 8) |
7352/// ((x & 0x0000ff00) >> 8) |
7353/// ((x & 0x00ff0000) << 8) |
7354/// ((x & 0xff000000) >> 8)
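/// On a successful match, the node supplying the source value is recorded in
/// Parts[MaskByteOffset]; a complete match ends up with all four slots holding
/// the same node (checked by the caller).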
7355static bool isBSwapHWordElement(SDValue N, MutableArrayRef<SDNode *> Parts) {
7356 if (!N->hasOneUse())
7357 return false;
7358
7359 unsigned Opc = N.getOpcode();
7360 if (Opc != ISD::AND && Opc != ISD::SHL && Opc != ISD::SRL)
7361 return false;
7362
7363 SDValue N0 = N.getOperand(i: 0);
7364 unsigned Opc0 = N0.getOpcode();
7365 if (Opc0 != ISD::AND && Opc0 != ISD::SHL && Opc0 != ISD::SRL)
7366 return false;
7367
7368 ConstantSDNode *N1C = nullptr;
7369 // SHL or SRL: look upstream for AND mask operand
7370 if (Opc == ISD::AND)
7371 N1C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
7372 else if (Opc0 == ISD::AND)
7373 N1C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7374 if (!N1C)
7375 return false;
7376
7377 unsigned MaskByteOffset;
7378 switch (N1C->getZExtValue()) {
7379 default:
7380 return false;
7381 case 0xFF: MaskByteOffset = 0; break;
7382 case 0xFF00: MaskByteOffset = 1; break;
7383 case 0xFFFF:
7384 // In case demanded bits didn't clear the bits that will be shifted out.
7385 // This is needed for X86.
7386 if (Opc == ISD::SRL || (Opc == ISD::AND && Opc0 == ISD::SHL)) {
7387 MaskByteOffset = 1;
7388 break;
7389 }
7390 return false;
7391 case 0xFF0000: MaskByteOffset = 2; break;
7392 case 0xFF000000: MaskByteOffset = 3; break;
7393 }
7394
7395 // Look for (x & 0xff) << 8 as well as ((x << 8) & 0xff00).
7396 if (Opc == ISD::AND) {
7397 if (MaskByteOffset == 0 || MaskByteOffset == 2) {
7398 // (x >> 8) & 0xff
7399 // (x >> 8) & 0xff0000
7400 if (Opc0 != ISD::SRL)
7401 return false;
7402 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7403 if (!C || C->getZExtValue() != 8)
7404 return false;
7405 } else {
7406 // (x << 8) & 0xff00
7407 // (x << 8) & 0xff000000
7408 if (Opc0 != ISD::SHL)
7409 return false;
7410 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
7411 if (!C || C->getZExtValue() != 8)
7412 return false;
7413 }
7414 } else if (Opc == ISD::SHL) {
7415 // (x & 0xff) << 8
7416 // (x & 0xff0000) << 8
7417 if (MaskByteOffset != 0 && MaskByteOffset != 2)
7418 return false;
7419 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
7420 if (!C || C->getZExtValue() != 8)
7421 return false;
7422 } else { // Opc == ISD::SRL
7423 // (x & 0xff00) >> 8
7424 // (x & 0xff000000) >> 8
7425 if (MaskByteOffset != 1 && MaskByteOffset != 3)
7426 return false;
7427 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val: N.getOperand(i: 1));
7428 if (!C || C->getZExtValue() != 8)
7429 return false;
7430 }
7431
7432 if (Parts[MaskByteOffset])
7433 return false;
7434
7435 Parts[MaskByteOffset] = N0.getOperand(i: 0).getNode();
7436 return true;
7437}
7438
7439// Match 2 elements of a packed halfword bswap.
7440static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
7441 if (N.getOpcode() == ISD::OR)
7442 return isBSwapHWordElement(N: N.getOperand(i: 0), Parts) &&
7443 isBSwapHWordElement(N: N.getOperand(i: 1), Parts);
7444
7445 if (N.getOpcode() == ISD::SRL && N.getOperand(i: 0).getOpcode() == ISD::BSWAP) {
7446 ConstantSDNode *C = isConstOrConstSplat(N: N.getOperand(i: 1));
7447 if (!C || C->getAPIntValue() != 16)
7448 return false;
7449 Parts[0] = Parts[1] = N.getOperand(i: 0).getOperand(i: 0).getNode();
7450 return true;
7451 }
7452
7453 return false;
7454}
7455
7456// Match this pattern:
// (or (and (shl A, 8), 0xff00ff00), (and (srl A, 8), 0x00ff00ff))
7458// And rewrite this to:
7459// (rotr (bswap A), 16)
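// For example, with A == 0xAABBCCDD:
// ((A << 8) & 0xff00ff00) | ((A >> 8) & 0x00ff00ff) == 0xBB00DD00 | 0x00AA00CC
// == 0xBBAADDCC, and (rotr (bswap A), 16) == rotr(0xDDCCBBAA, 16) == 0xBBAADDCC.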
7460static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
7461 SelectionDAG &DAG, SDNode *N, SDValue N0,
7462 SDValue N1, EVT VT, EVT ShiftAmountTy) {
7463 assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
7464 "MatchBSwapHWordOrAndAnd: expecting i32");
7465 if (!TLI.isOperationLegalOrCustom(Op: ISD::ROTR, VT))
7466 return SDValue();
7467 if (N0.getOpcode() != ISD::AND || N1.getOpcode() != ISD::AND)
7468 return SDValue();
7469 // TODO: this is too restrictive; lifting this restriction requires more tests
7470 if (!N0->hasOneUse() || !N1->hasOneUse())
7471 return SDValue();
7472 ConstantSDNode *Mask0 = isConstOrConstSplat(N: N0.getOperand(i: 1));
7473 ConstantSDNode *Mask1 = isConstOrConstSplat(N: N1.getOperand(i: 1));
7474 if (!Mask0 || !Mask1)
7475 return SDValue();
7476 if (Mask0->getAPIntValue() != 0xff00ff00 ||
7477 Mask1->getAPIntValue() != 0x00ff00ff)
7478 return SDValue();
7479 SDValue Shift0 = N0.getOperand(i: 0);
7480 SDValue Shift1 = N1.getOperand(i: 0);
7481 if (Shift0.getOpcode() != ISD::SHL || Shift1.getOpcode() != ISD::SRL)
7482 return SDValue();
7483 ConstantSDNode *ShiftAmt0 = isConstOrConstSplat(N: Shift0.getOperand(i: 1));
7484 ConstantSDNode *ShiftAmt1 = isConstOrConstSplat(N: Shift1.getOperand(i: 1));
7485 if (!ShiftAmt0 || !ShiftAmt1)
7486 return SDValue();
7487 if (ShiftAmt0->getAPIntValue() != 8 || ShiftAmt1->getAPIntValue() != 8)
7488 return SDValue();
7489 if (Shift0.getOperand(i: 0) != Shift1.getOperand(i: 0))
7490 return SDValue();
7491
7492 SDLoc DL(N);
7493 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: Shift0.getOperand(i: 0));
7494 SDValue ShAmt = DAG.getConstant(Val: 16, DL, VT: ShiftAmountTy);
7495 return DAG.getNode(Opcode: ISD::ROTR, DL, VT, N1: BSwap, N2: ShAmt);
7496}
7497
7498/// Match a 32-bit packed halfword bswap. That is
7499/// ((x & 0x000000ff) << 8) |
7500/// ((x & 0x0000ff00) >> 8) |
7501/// ((x & 0x00ff0000) << 8) |
7502/// ((x & 0xff000000) >> 8)
7503/// => (rotl (bswap x), 16)
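/// For example, x == 0xAABBCCDD produces 0xBBAADDCC on both sides.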
7504SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
7505 if (!LegalOperations)
7506 return SDValue();
7507
7508 EVT VT = N->getValueType(ResNo: 0);
7509 if (VT != MVT::i32)
7510 return SDValue();
7511 if (!TLI.isOperationLegalOrCustom(Op: ISD::BSWAP, VT))
7512 return SDValue();
7513
7514 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT,
7515 ShiftAmountTy: getShiftAmountTy(LHSTy: VT)))
7516 return BSwap;
7517
7518 // Try again with commuted operands.
7519 if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0: N1, N1: N0, VT,
7520 ShiftAmountTy: getShiftAmountTy(LHSTy: VT)))
7521 return BSwap;
7522
7523
7524 // Look for either
7525 // (or (bswaphpair), (bswaphpair))
7526 // (or (or (bswaphpair), (and)), (and))
7527 // (or (or (and), (bswaphpair)), (and))
7528 SDNode *Parts[4] = {};
7529
7530 if (isBSwapHWordPair(N: N0, Parts)) {
7531 // (or (or (and), (and)), (or (and), (and)))
7532 if (!isBSwapHWordPair(N: N1, Parts))
7533 return SDValue();
7534 } else if (N0.getOpcode() == ISD::OR) {
7535 // (or (or (or (and), (and)), (and)), (and))
7536 if (!isBSwapHWordElement(N: N1, Parts))
7537 return SDValue();
7538 SDValue N00 = N0.getOperand(i: 0);
7539 SDValue N01 = N0.getOperand(i: 1);
7540 if (!(isBSwapHWordElement(N: N01, Parts) && isBSwapHWordPair(N: N00, Parts)) &&
7541 !(isBSwapHWordElement(N: N00, Parts) && isBSwapHWordPair(N: N01, Parts)))
7542 return SDValue();
7543 } else {
7544 return SDValue();
7545 }
7546
7547 // Make sure the parts are all coming from the same node.
7548 if (Parts[0] != Parts[1] || Parts[0] != Parts[2] || Parts[0] != Parts[3])
7549 return SDValue();
7550
7551 SDLoc DL(N);
7552 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT,
7553 Operand: SDValue(Parts[0], 0));
7554
7555 // Result of the bswap should be rotated by 16. If it's not legal, then
7556 // do (x << 16) | (x >> 16).
7557 SDValue ShAmt = DAG.getConstant(Val: 16, DL, VT: getShiftAmountTy(LHSTy: VT));
7558 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTL, VT))
7559 return DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: BSwap, N2: ShAmt);
7560 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTR, VT))
7561 return DAG.getNode(Opcode: ISD::ROTR, DL, VT, N1: BSwap, N2: ShAmt);
7562 return DAG.getNode(Opcode: ISD::OR, DL, VT,
7563 N1: DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: BSwap, N2: ShAmt),
7564 N2: DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: BSwap, N2: ShAmt));
7565}
7566
7567/// This contains all DAGCombine rules which reduce two values combined by
7568/// an Or operation to a single value \see visitANDLike().
7569SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, const SDLoc &DL) {
7570 EVT VT = N1.getValueType();
7571
7572 // fold (or x, undef) -> -1
7573 if (!LegalOperations && (N0.isUndef() || N1.isUndef()))
7574 return DAG.getAllOnesConstant(DL, VT);
7575
7576 if (SDValue V = foldLogicOfSetCCs(IsAnd: false, N0, N1, DL))
7577 return V;
7578
7579 // (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C3) if possible.
7580 if (N0.getOpcode() == ISD::AND && N1.getOpcode() == ISD::AND &&
7581 // Don't increase # computations.
7582 (N0->hasOneUse() || N1->hasOneUse())) {
7583 // We can only do this xform if we know that bits from X that are set in C2
7584 // but not in C1 are already zero. Likewise for Y.
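// For example, with C1 == 0x0F and C2 == 0xF0: if bits 4-7 of X and bits 0-3
// of Y are known zero, then (X & 0x0F) | (Y & 0xF0) == (X | Y) & 0xFF.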
7585 if (const ConstantSDNode *N0O1C =
7586 getAsNonOpaqueConstant(N: N0.getOperand(i: 1))) {
7587 if (const ConstantSDNode *N1O1C =
7588 getAsNonOpaqueConstant(N: N1.getOperand(i: 1))) {
7589 // We can only do this xform if we know that bits from X that are set in
7590 // C2 but not in C1 are already zero. Likewise for Y.
7591 const APInt &LHSMask = N0O1C->getAPIntValue();
7592 const APInt &RHSMask = N1O1C->getAPIntValue();
7593
7594 if (DAG.MaskedValueIsZero(Op: N0.getOperand(i: 0), Mask: RHSMask&~LHSMask) &&
7595 DAG.MaskedValueIsZero(Op: N1.getOperand(i: 0), Mask: LHSMask&~RHSMask)) {
7596 SDValue X = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT,
7597 N1: N0.getOperand(i: 0), N2: N1.getOperand(i: 0));
7598 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X,
7599 N2: DAG.getConstant(Val: LHSMask | RHSMask, DL, VT));
7600 }
7601 }
7602 }
7603 }
7604
7605 // (or (and X, M), (and X, N)) -> (and X, (or M, N))
7606 if (N0.getOpcode() == ISD::AND &&
7607 N1.getOpcode() == ISD::AND &&
7608 N0.getOperand(i: 0) == N1.getOperand(i: 0) &&
7609 // Don't increase # computations.
7610 (N0->hasOneUse() || N1->hasOneUse())) {
7611 SDValue X = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT,
7612 N1: N0.getOperand(i: 1), N2: N1.getOperand(i: 1));
7613 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0.getOperand(i: 0), N2: X);
7614 }
7615
7616 return SDValue();
7617}
7618
7619/// OR combines for which the commuted variant will be tried as well.
7620static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
7621 SDNode *N) {
7622 EVT VT = N0.getValueType();
7623
7624 auto peekThroughResize = [](SDValue V) {
7625 if (V->getOpcode() == ISD::ZERO_EXTEND || V->getOpcode() == ISD::TRUNCATE)
7626 return V->getOperand(Num: 0);
7627 return V;
7628 };
7629
7630 SDValue N0Resized = peekThroughResize(N0);
7631 if (N0Resized.getOpcode() == ISD::AND) {
7632 SDValue N1Resized = peekThroughResize(N1);
7633 SDValue N00 = N0Resized.getOperand(i: 0);
7634 SDValue N01 = N0Resized.getOperand(i: 1);
7635
7636 // fold or (and x, y), x --> x
7637 if (N00 == N1Resized || N01 == N1Resized)
7638 return N1;
7639
7640 // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
7641 // TODO: Set AllowUndefs = true.
7642 if (SDValue NotOperand = getBitwiseNotOperand(V: N01, Mask: N00,
7643 /* AllowUndefs */ false)) {
7644 if (peekThroughResize(NotOperand) == N1Resized)
7645 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT,
7646 N1: DAG.getZExtOrTrunc(Op: N00, DL: SDLoc(N), VT), N2: N1);
7647 }
7648
7649 // fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
7650 if (SDValue NotOperand = getBitwiseNotOperand(V: N00, Mask: N01,
7651 /* AllowUndefs */ false)) {
7652 if (peekThroughResize(NotOperand) == N1Resized)
7653 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT,
7654 N1: DAG.getZExtOrTrunc(Op: N01, DL: SDLoc(N), VT), N2: N1);
7655 }
7656 }
7657
7658 SDValue X, Y;
7659
7660 // fold or (xor X, N1), N1 --> or X, N1
7661 if (sd_match(N: N0, P: m_Xor(L: m_Value(N&: X), R: m_Specific(N: N1))))
7662 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT, N1: X, N2: N1);
7663
7664 // fold or (xor x, y), (x and/or y) --> or x, y
7665 if (sd_match(N: N0, P: m_Xor(L: m_Value(N&: X), R: m_Value(N&: Y))) &&
7666 (sd_match(N: N1, P: m_And(L: m_Specific(N: X), R: m_Specific(N: Y))) ||
7667 sd_match(N: N1, P: m_Or(L: m_Specific(N: X), R: m_Specific(N: Y)))))
7668 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT, N1: X, N2: Y);
7669
7670 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
7671 return R;
7672
7673 auto peekThroughZext = [](SDValue V) {
7674 if (V->getOpcode() == ISD::ZERO_EXTEND)
7675 return V->getOperand(Num: 0);
7676 return V;
7677 };
7678
7679 // (fshl X, ?, Y) | (shl X, Y) --> fshl X, ?, Y
7680 if (N0.getOpcode() == ISD::FSHL && N1.getOpcode() == ISD::SHL &&
7681 N0.getOperand(i: 0) == N1.getOperand(i: 0) &&
7682 peekThroughZext(N0.getOperand(i: 2)) == peekThroughZext(N1.getOperand(i: 1)))
7683 return N0;
7684
7685 // (fshr ?, X, Y) | (srl X, Y) --> fshr ?, X, Y
7686 if (N0.getOpcode() == ISD::FSHR && N1.getOpcode() == ISD::SRL &&
7687 N0.getOperand(i: 1) == N1.getOperand(i: 0) &&
7688 peekThroughZext(N0.getOperand(i: 2)) == peekThroughZext(N1.getOperand(i: 1)))
7689 return N0;
7690
7691 return SDValue();
7692}
7693
7694SDValue DAGCombiner::visitOR(SDNode *N) {
7695 SDValue N0 = N->getOperand(Num: 0);
7696 SDValue N1 = N->getOperand(Num: 1);
7697 EVT VT = N1.getValueType();
7698 SDLoc DL(N);
7699
7700 // x | x --> x
7701 if (N0 == N1)
7702 return N0;
7703
7704 // fold (or c1, c2) -> c1|c2
7705 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::OR, DL, VT, Ops: {N0, N1}))
7706 return C;
7707
7708 // canonicalize constant to RHS
7709 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
7710 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
7711 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1, N2: N0);
7712
7713 // fold vector ops
7714 if (VT.isVector()) {
7715 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
7716 return FoldedVOp;
7717
7718 // fold (or x, 0) -> x, vector edition
7719 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
7720 return N0;
7721
7722 // fold (or x, -1) -> -1, vector edition
7723 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode()))
// Do not return N1 directly, because it may contain undef elements.
7725 return DAG.getAllOnesConstant(DL, VT: N1.getValueType());
7726
7727 // fold (or (shuf A, V_0, MA), (shuf B, V_0, MB)) -> (shuf A, B, Mask)
7728 // Do this only if the resulting type / shuffle is legal.
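// For example, with 4 x i32 vectors and V_0 == zero vector:
// (or (shuffle A, zero, <0,4,2,4>), (shuffle B, zero, <4,1,4,3>))
// -> (shuffle A, B, <0,5,2,7>)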
7729 auto *SV0 = dyn_cast<ShuffleVectorSDNode>(Val&: N0);
7730 auto *SV1 = dyn_cast<ShuffleVectorSDNode>(Val&: N1);
7731 if (SV0 && SV1 && TLI.isTypeLegal(VT)) {
7732 bool ZeroN00 = ISD::isBuildVectorAllZeros(N: N0.getOperand(i: 0).getNode());
7733 bool ZeroN01 = ISD::isBuildVectorAllZeros(N: N0.getOperand(i: 1).getNode());
7734 bool ZeroN10 = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 0).getNode());
7735 bool ZeroN11 = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 1).getNode());
7736 // Ensure both shuffles have a zero input.
7737 if ((ZeroN00 != ZeroN01) && (ZeroN10 != ZeroN11)) {
7738 assert((!ZeroN00 || !ZeroN01) && "Both inputs zero!");
7739 assert((!ZeroN10 || !ZeroN11) && "Both inputs zero!");
7740 bool CanFold = true;
7741 int NumElts = VT.getVectorNumElements();
7742 SmallVector<int, 4> Mask(NumElts, -1);
7743
7744 for (int i = 0; i != NumElts; ++i) {
7745 int M0 = SV0->getMaskElt(Idx: i);
7746 int M1 = SV1->getMaskElt(Idx: i);
7747
7748 // Determine if either index is pointing to a zero vector.
7749 bool M0Zero = M0 < 0 || (ZeroN00 == (M0 < NumElts));
7750 bool M1Zero = M1 < 0 || (ZeroN10 == (M1 < NumElts));
7751
// If one element is zero and the other side is undef, keep undef.
7753 // This also handles the case that both are undef.
7754 if ((M0Zero && M1 < 0) || (M1Zero && M0 < 0))
7755 continue;
7756
7757 // Make sure only one of the elements is zero.
7758 if (M0Zero == M1Zero) {
7759 CanFold = false;
7760 break;
7761 }
7762
7763 assert((M0 >= 0 || M1 >= 0) && "Undef index!");
7764
7765 // We have a zero and non-zero element. If the non-zero came from
7766 // SV0 make the index a LHS index. If it came from SV1, make it
7767 // a RHS index. We need to mod by NumElts because we don't care
7768 // which operand it came from in the original shuffles.
7769 Mask[i] = M1Zero ? M0 % NumElts : (M1 % NumElts) + NumElts;
7770 }
7771
7772 if (CanFold) {
7773 SDValue NewLHS = ZeroN00 ? N0.getOperand(i: 1) : N0.getOperand(i: 0);
7774 SDValue NewRHS = ZeroN10 ? N1.getOperand(i: 1) : N1.getOperand(i: 0);
7775 SDValue LegalShuffle =
7776 TLI.buildLegalVectorShuffle(VT, DL, N0: NewLHS, N1: NewRHS, Mask, DAG);
7777 if (LegalShuffle)
7778 return LegalShuffle;
7779 }
7780 }
7781 }
7782 }
7783
7784 // fold (or x, 0) -> x
7785 if (isNullConstant(V: N1))
7786 return N0;
7787
7788 // fold (or x, -1) -> -1
7789 if (isAllOnesConstant(V: N1))
7790 return N1;
7791
7792 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
7793 return NewSel;
7794
7795 // fold (or x, c) -> c iff (x & ~c) == 0
7796 ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(Val&: N1);
7797 if (N1C && DAG.MaskedValueIsZero(Op: N0, Mask: ~N1C->getAPIntValue()))
7798 return N1;
7799
7800 if (SDValue R = foldAndOrOfSETCC(LogicOp: N, DAG))
7801 return R;
7802
7803 if (SDValue Combined = visitORLike(N0, N1, DL))
7804 return Combined;
7805
7806 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
7807 return Combined;
7808
7809 // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
7810 if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
7811 return BSwap;
7812 if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
7813 return BSwap;
7814
7815 // reassociate or
7816 if (SDValue ROR = reassociateOps(Opc: ISD::OR, DL, N0, N1, Flags: N->getFlags()))
7817 return ROR;
7818
7819 // Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y))
7820 if (SDValue SD =
7821 reassociateReduction(RedOpc: ISD::VECREDUCE_OR, Opc: ISD::OR, DL, VT, N0, N1))
7822 return SD;
7823
7824 // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
7825 // iff (c1 & c2) != 0 or c1/c2 are undef.
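// For example: (or (and X, 0x0F), 0x01) -> (and (or X, 0x01), 0x0F), since
// 0x0F | 0x01 == 0x0F and both sides agree in every bit position.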
7826 auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
7827 return !C1 || !C2 || C1->getAPIntValue().intersects(RHS: C2->getAPIntValue());
7828 };
7829 if (N0.getOpcode() == ISD::AND && N0->hasOneUse() &&
7830 ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchIntersect, AllowUndefs: true)) {
7831 if (SDValue COR = DAG.FoldConstantArithmetic(Opcode: ISD::OR, DL: SDLoc(N1), VT,
7832 Ops: {N1, N0.getOperand(i: 1)})) {
7833 SDValue IOR = DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N0), VT, N1: N0.getOperand(i: 0), N2: N1);
7834 AddToWorklist(N: IOR.getNode());
7835 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: COR, N2: IOR);
7836 }
7837 }
7838
7839 if (SDValue Combined = visitORCommutative(DAG, N0, N1, N))
7840 return Combined;
7841 if (SDValue Combined = visitORCommutative(DAG, N0: N1, N1: N0, N))
7842 return Combined;
7843
7844 // Simplify: (or (op x...), (op y...)) -> (op (or x, y))
7845 if (N0.getOpcode() == N1.getOpcode())
7846 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
7847 return V;
7848
7849 // See if this is some rotate idiom.
7850 if (SDValue Rot = MatchRotate(LHS: N0, RHS: N1, DL))
7851 return Rot;
7852
7853 if (SDValue Load = MatchLoadCombine(N))
7854 return Load;
7855
7856 // Simplify the operands using demanded-bits information.
7857 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
7858 return SDValue(N, 0);
7859
7860 // If OR can be rewritten into ADD, try combines based on ADD.
7861 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::ADD, VT)) &&
7862 DAG.isADDLike(Op: SDValue(N, 0)))
7863 if (SDValue Combined = visitADDLike(N))
7864 return Combined;
7865
// Postpone until legalization has completed to avoid interference with bswap
// folding.
7868 if (LegalOperations || VT.isVector())
7869 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
7870 return R;
7871
7872 return SDValue();
7873}
7874
7875static SDValue stripConstantMask(const SelectionDAG &DAG, SDValue Op,
7876 SDValue &Mask) {
7877 if (Op.getOpcode() == ISD::AND &&
7878 DAG.isConstantIntBuildVectorOrConstantInt(N: Op.getOperand(i: 1))) {
7879 Mask = Op.getOperand(i: 1);
7880 return Op.getOperand(i: 0);
7881 }
7882 return Op;
7883}
7884
7885/// Match "(X shl/srl V1) & V2" where V2 may not be present.
7886static bool matchRotateHalf(const SelectionDAG &DAG, SDValue Op, SDValue &Shift,
7887 SDValue &Mask) {
7888 Op = stripConstantMask(DAG, Op, Mask);
7889 if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
7890 Shift = Op;
7891 return true;
7892 }
7893 return false;
7894}
7895
7896/// Helper function for visitOR to extract the needed side of a rotate idiom
7897/// from a shl/srl/mul/udiv. This is meant to handle cases where
7898/// InstCombine merged some outside op with one of the shifts from
7899/// the rotate pattern.
7900/// \returns An empty \c SDValue if the needed shift couldn't be extracted.
7901/// Otherwise, returns an expansion of \p ExtractFrom based on the following
7902/// patterns:
7903///
7904/// (or (add v v) (shrl v bitwidth-1)):
7905/// expands (add v v) -> (shl v 1)
7906///
7907/// (or (mul v c0) (shrl (mul v c1) c2)):
7908/// expands (mul v c0) -> (shl (mul v c1) c3)
7909///
7910/// (or (udiv v c0) (shl (udiv v c1) c2)):
7911/// expands (udiv v c0) -> (shrl (udiv v c1) c3)
7912///
7913/// (or (shl v c0) (shrl (shl v c1) c2)):
7914/// expands (shl v c0) -> (shl (shl v c1) c3)
7915///
7916/// (or (shrl v c0) (shl (shrl v c1) c2)):
7917/// expands (shrl v c0) -> (shrl (shrl v c1) c3)
7918///
7919/// Such that in all cases, c3+c2==bitwidth(op v c1).
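/// For instance, for i32:
/// (or (mul v 16), (srl (mul v 2), 29)) gives c3 == 32 - 29 == 3, and
/// (mul v 16) expands to (shl (mul v 2), 3) because 16 == 2 * (1 << 3).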
7920static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
7921 SDValue ExtractFrom, SDValue &Mask,
7922 const SDLoc &DL) {
7923 assert(OppShift && ExtractFrom && "Empty SDValue");
7924 if (OppShift.getOpcode() != ISD::SHL && OppShift.getOpcode() != ISD::SRL)
7925 return SDValue();
7926
7927 ExtractFrom = stripConstantMask(DAG, Op: ExtractFrom, Mask);
7928
7929 // Value and Type of the shift.
7930 SDValue OppShiftLHS = OppShift.getOperand(i: 0);
7931 EVT ShiftedVT = OppShiftLHS.getValueType();
7932
7933 // Amount of the existing shift.
7934 ConstantSDNode *OppShiftCst = isConstOrConstSplat(N: OppShift.getOperand(i: 1));
7935
7936 // (add v v) -> (shl v 1)
7937 // TODO: Should this be a general DAG canonicalization?
7938 if (OppShift.getOpcode() == ISD::SRL && OppShiftCst &&
7939 ExtractFrom.getOpcode() == ISD::ADD &&
7940 ExtractFrom.getOperand(i: 0) == ExtractFrom.getOperand(i: 1) &&
7941 ExtractFrom.getOperand(i: 0) == OppShiftLHS &&
7942 OppShiftCst->getAPIntValue() == ShiftedVT.getScalarSizeInBits() - 1)
7943 return DAG.getNode(Opcode: ISD::SHL, DL, VT: ShiftedVT, N1: OppShiftLHS,
7944 N2: DAG.getShiftAmountConstant(Val: 1, VT: ShiftedVT, DL));
7945
7946 // Preconditions:
7947 // (or (op0 v c0) (shiftl/r (op0 v c1) c2))
7948 //
7949 // Find opcode of the needed shift to be extracted from (op0 v c0).
7950 unsigned Opcode = ISD::DELETED_NODE;
7951 bool IsMulOrDiv = false;
7952 // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
7953 // opcode or its arithmetic (mul or udiv) variant.
7954 auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
7955 IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
7956 if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
7957 return false;
7958 Opcode = NeededShift;
7959 return true;
7960 };
7961 // op0 must be either the needed shift opcode or the mul/udiv equivalent
7962 // that the needed shift can be extracted from.
7963 if ((OppShift.getOpcode() != ISD::SRL || !SelectOpcode(ISD::SHL, ISD::MUL)) &&
7964 (OppShift.getOpcode() != ISD::SHL || !SelectOpcode(ISD::SRL, ISD::UDIV)))
7965 return SDValue();
7966
7967 // op0 must be the same opcode on both sides, have the same LHS argument,
7968 // and produce the same value type.
7969 if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
7970 OppShiftLHS.getOperand(i: 0) != ExtractFrom.getOperand(i: 0) ||
7971 ShiftedVT != ExtractFrom.getValueType())
7972 return SDValue();
7973
7974 // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
7975 ConstantSDNode *OppLHSCst = isConstOrConstSplat(N: OppShiftLHS.getOperand(i: 1));
7976 // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
7977 ConstantSDNode *ExtractFromCst =
7978 isConstOrConstSplat(N: ExtractFrom.getOperand(i: 1));
7979 // TODO: We should be able to handle non-uniform constant vectors for these values
7980 // Check that we have constant values.
7981 if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
7982 !OppLHSCst || !OppLHSCst->getAPIntValue() ||
7983 !ExtractFromCst || !ExtractFromCst->getAPIntValue())
7984 return SDValue();
7985
7986 // Compute the shift amount we need to extract to complete the rotate.
7987 const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
7988 if (OppShiftCst->getAPIntValue().ugt(RHS: VTWidth))
7989 return SDValue();
7990 APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
7991 // Normalize the bitwidth of the two mul/udiv/shift constant operands.
7992 APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
7993 APInt OppLHSAmt = OppLHSCst->getAPIntValue();
7994 zeroExtendToMatch(LHS&: ExtractFromAmt, RHS&: OppLHSAmt);
7995
7996 // Now try extract the needed shift from the ExtractFrom op and see if the
7997 // result matches up with the existing shift's LHS op.
7998 if (IsMulOrDiv) {
7999 // Op to extract from is a mul or udiv by a constant.
8000 // Check:
8001 // c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
8002 // c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
8003 const APInt ExtractDiv = APInt::getOneBitSet(numBits: ExtractFromAmt.getBitWidth(),
8004 BitNo: NeededShiftAmt.getZExtValue());
8005 APInt ResultAmt;
8006 APInt Rem;
8007 APInt::udivrem(LHS: ExtractFromAmt, RHS: ExtractDiv, Quotient&: ResultAmt, Remainder&: Rem);
8008 if (Rem != 0 || ResultAmt != OppLHSAmt)
8009 return SDValue();
8010 } else {
8011 // Op to extract from is a shift by a constant.
8012 // Check:
8013 // c2 - (bitwidth(op0 v c0) - c1) == c0
8014 if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
8015 width: ExtractFromAmt.getBitWidth()))
8016 return SDValue();
8017 }
8018
8019 // Return the expanded shift op that should allow a rotate to be formed.
8020 EVT ShiftVT = OppShift.getOperand(i: 1).getValueType();
8021 EVT ResVT = ExtractFrom.getValueType();
8022 SDValue NewShiftNode = DAG.getConstant(Val: NeededShiftAmt, DL, VT: ShiftVT);
8023 return DAG.getNode(Opcode, DL, VT: ResVT, N1: OppShiftLHS, N2: NewShiftNode);
8024}
8025
8026// Return true if we can prove that, whenever Neg and Pos are both in the
8027// range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
8028// for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
8029//
8030// (or (shift1 X, Neg), (shift2 X, Pos))
8031//
8032// reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
8033// in direction shift1 by Neg. The range [0, EltSize) means that we only need
8034// to consider shift amounts with defined behavior.
8035//
8036// The IsRotate flag should be set when the LHS of both shifts is the same.
8037// Otherwise if matching a general funnel shift, it should be clear.
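// For example, with EltSize == 32, Neg == (sub 32, Pos) is accepted, as is
// Neg == (and (sub 32, Pos), 31) when IsRotate is set.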
8038static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
8039 SelectionDAG &DAG, bool IsRotate) {
8040 const auto &TLI = DAG.getTargetLoweringInfo();
8041 // If EltSize is a power of 2 then:
8042 //
8043 // (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
8044 // (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
8045 //
8046 // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
8047 // for the stronger condition:
8048 //
8049 // Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1) [A]
8050 //
8051 // for all Neg and Pos. Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
8052 // we can just replace Neg with Neg' for the rest of the function.
8053 //
8054 // In other cases we check for the even stronger condition:
8055 //
8056 // Neg == EltSize - Pos [B]
8057 //
8058 // for all Neg and Pos. Note that the (or ...) then invokes undefined
8059 // behavior if Pos == 0 (and consequently Neg == EltSize).
8060 //
8061 // We could actually use [A] whenever EltSize is a power of 2, but the
8062 // only extra cases that it would match are those uninteresting ones
8063 // where Neg and Pos are never in range at the same time. E.g. for
8064 // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
8065 // as well as (sub 32, Pos), but:
8066 //
8067 // (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
8068 //
8069 // always invokes undefined behavior for 32-bit X.
8070 //
8071 // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
8072 // This allows us to peek through any operations that only affect Mask's
8073 // un-demanded bits.
8074 //
8075 // NOTE: We can only do this when matching operations which won't modify the
8076 // least Log2(EltSize) significant bits and not a general funnel shift.
8077 unsigned MaskLoBits = 0;
8078 if (IsRotate && isPowerOf2_64(Value: EltSize)) {
8079 unsigned Bits = Log2_64(Value: EltSize);
8080 unsigned NegBits = Neg.getScalarValueSizeInBits();
8081 if (NegBits >= Bits) {
8082 APInt DemandedBits = APInt::getLowBitsSet(numBits: NegBits, loBitsSet: Bits);
8083 if (SDValue Inner =
8084 TLI.SimplifyMultipleUseDemandedBits(Op: Neg, DemandedBits, DAG)) {
8085 Neg = Inner;
8086 MaskLoBits = Bits;
8087 }
8088 }
8089 }
8090
8091 // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
8092 if (Neg.getOpcode() != ISD::SUB)
8093 return false;
8094 ConstantSDNode *NegC = isConstOrConstSplat(N: Neg.getOperand(i: 0));
8095 if (!NegC)
8096 return false;
8097 SDValue NegOp1 = Neg.getOperand(i: 1);
8098
8099 // On the RHS of [A], if Pos is the result of operation on Pos' that won't
8100 // affect Mask's demanded bits, just replace Pos with Pos'. These operations
8101 // are redundant for the purpose of the equality.
8102 if (MaskLoBits) {
8103 unsigned PosBits = Pos.getScalarValueSizeInBits();
8104 if (PosBits >= MaskLoBits) {
8105 APInt DemandedBits = APInt::getLowBitsSet(numBits: PosBits, loBitsSet: MaskLoBits);
8106 if (SDValue Inner =
8107 TLI.SimplifyMultipleUseDemandedBits(Op: Pos, DemandedBits, DAG)) {
8108 Pos = Inner;
8109 }
8110 }
8111 }
8112
8113 // The condition we need is now:
8114 //
8115 // (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
8116 //
8117 // If NegOp1 == Pos then we need:
8118 //
8119 // EltSize & Mask == NegC & Mask
8120 //
8121 // (because "x & Mask" is a truncation and distributes through subtraction).
8122 //
8123 // We also need to account for a potential truncation of NegOp1 if the amount
8124 // has already been legalized to a shift amount type.
8125 APInt Width;
8126 if ((Pos == NegOp1) ||
8127 (NegOp1.getOpcode() == ISD::TRUNCATE && Pos == NegOp1.getOperand(i: 0)))
8128 Width = NegC->getAPIntValue();
8129
8130 // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
8131 // Then the condition we want to prove becomes:
8132 //
8133 // (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
8134 //
8135 // which, again because "x & Mask" is a truncation, becomes:
8136 //
8137 // NegC & Mask == (EltSize - PosC) & Mask
8138 // EltSize & Mask == (NegC + PosC) & Mask
8139 else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(i: 0) == NegOp1) {
8140 if (ConstantSDNode *PosC = isConstOrConstSplat(N: Pos.getOperand(i: 1)))
8141 Width = PosC->getAPIntValue() + NegC->getAPIntValue();
8142 else
8143 return false;
8144 } else
8145 return false;
8146
8147 // Now we just need to check that EltSize & Mask == Width & Mask.
8148 if (MaskLoBits)
8149 // EltSize & Mask is 0 since Mask is EltSize - 1.
8150 return Width.getLoBits(numBits: MaskLoBits) == 0;
8151 return Width == EltSize;
8152}
8153
8154// A subroutine of MatchRotate used once we have found an OR of two opposite
8155// shifts of Shifted. If Neg == <operand size> - Pos then the OR reduces
8156// to both (PosOpcode Shifted, Pos) and (NegOpcode Shifted, Neg), with the
8157// former being preferred if supported. InnerPos and InnerNeg are Pos and
8158// Neg with outer conversions stripped away.
8159SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
8160 SDValue Neg, SDValue InnerPos,
8161 SDValue InnerNeg, bool HasPos,
8162 unsigned PosOpcode, unsigned NegOpcode,
8163 const SDLoc &DL) {
8164 // fold (or (shl x, (*ext y)),
8165 // (srl x, (*ext (sub 32, y)))) ->
8166 // (rotl x, y) or (rotr x, (sub 32, y))
8167 //
8168 // fold (or (shl x, (*ext (sub 32, y))),
8169 // (srl x, (*ext y))) ->
8170 // (rotr x, y) or (rotl x, (sub 32, y))
8171 EVT VT = Shifted.getValueType();
8172 if (matchRotateSub(Pos: InnerPos, Neg: InnerNeg, EltSize: VT.getScalarSizeInBits(), DAG,
8173 /*IsRotate*/ true)) {
8174 return DAG.getNode(Opcode: HasPos ? PosOpcode : NegOpcode, DL, VT, N1: Shifted,
8175 N2: HasPos ? Pos : Neg);
8176 }
8177
8178 return SDValue();
8179}
8180
8181// A subroutine of MatchRotate used once we have found an OR of two opposite
8182// shifts of N0 + N1. If Neg == <operand size> - Pos then the OR reduces
8183// to both (PosOpcode N0, N1, Pos) and (NegOpcode N0, N1, Neg), with the
8184// former being preferred if supported. InnerPos and InnerNeg are Pos and
8185// Neg with outer conversions stripped away.
8186// TODO: Merge with MatchRotatePosNeg.
8187SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
8188 SDValue Neg, SDValue InnerPos,
8189 SDValue InnerNeg, bool HasPos,
8190 unsigned PosOpcode, unsigned NegOpcode,
8191 const SDLoc &DL) {
8192 EVT VT = N0.getValueType();
8193 unsigned EltBits = VT.getScalarSizeInBits();
8194
8195 // fold (or (shl x0, (*ext y)),
8196 // (srl x1, (*ext (sub 32, y)))) ->
8197 // (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
8198 //
8199 // fold (or (shl x0, (*ext (sub 32, y))),
8200 // (srl x1, (*ext y))) ->
8201 // (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
8202 if (matchRotateSub(Pos: InnerPos, Neg: InnerNeg, EltSize: EltBits, DAG, /*IsRotate*/ N0 == N1)) {
8203 return DAG.getNode(Opcode: HasPos ? PosOpcode : NegOpcode, DL, VT, N1: N0, N2: N1,
8204 N3: HasPos ? Pos : Neg);
8205 }
8206
8207 // Matching the shift+xor cases, we can't easily use the xor'd shift amount
8208 // so for now just use the PosOpcode case if its legal.
8209 // TODO: When can we use the NegOpcode case?
8210 if (PosOpcode == ISD::FSHL && isPowerOf2_32(Value: EltBits)) {
8211 auto IsBinOpImm = [](SDValue Op, unsigned BinOpc, unsigned Imm) {
8212 if (Op.getOpcode() != BinOpc)
8213 return false;
8214 ConstantSDNode *Cst = isConstOrConstSplat(N: Op.getOperand(i: 1));
8215 return Cst && (Cst->getAPIntValue() == Imm);
8216 };
8217
8218 // fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
8219 // -> (fshl x0, x1, y)
8220 if (IsBinOpImm(N1, ISD::SRL, 1) &&
8221 IsBinOpImm(InnerNeg, ISD::XOR, EltBits - 1) &&
8222 InnerPos == InnerNeg.getOperand(i: 0) &&
8223 TLI.isOperationLegalOrCustom(Op: ISD::FSHL, VT)) {
8224 return DAG.getNode(Opcode: ISD::FSHL, DL, VT, N1: N0, N2: N1.getOperand(i: 0), N3: Pos);
8225 }
8226
8227 // fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
8228 // -> (fshr x0, x1, y)
8229 if (IsBinOpImm(N0, ISD::SHL, 1) &&
8230 IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
8231 InnerNeg == InnerPos.getOperand(i: 0) &&
8232 TLI.isOperationLegalOrCustom(Op: ISD::FSHR, VT)) {
8233 return DAG.getNode(Opcode: ISD::FSHR, DL, VT, N1: N0.getOperand(i: 0), N2: N1, N3: Neg);
8234 }
8235
8236 // fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y))
8237 // -> (fshr x0, x1, y)
8238 // TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
8239 if (N0.getOpcode() == ISD::ADD && N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
8240 IsBinOpImm(InnerPos, ISD::XOR, EltBits - 1) &&
8241 InnerNeg == InnerPos.getOperand(i: 0) &&
8242 TLI.isOperationLegalOrCustom(Op: ISD::FSHR, VT)) {
8243 return DAG.getNode(Opcode: ISD::FSHR, DL, VT, N1: N0.getOperand(i: 0), N2: N1, N3: Neg);
8244 }
8245 }
8246
8247 return SDValue();
8248}
8249
8250// MatchRotate - Handle an 'or' of two operands. If this is one of the many
8251// idioms for rotate, and if the target supports rotation instructions, generate
8252// a rot[lr]. This also matches funnel shift patterns, similar to rotation but
8253// with different shifted sources.
8254SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
8255 EVT VT = LHS.getValueType();
8256
8257 // The target must have at least one rotate/funnel flavor.
8258 // We still try to match rotate by constant pre-legalization.
8259 // TODO: Support pre-legalization funnel-shift by constant.
8260 bool HasROTL = hasOperation(Opcode: ISD::ROTL, VT);
8261 bool HasROTR = hasOperation(Opcode: ISD::ROTR, VT);
8262 bool HasFSHL = hasOperation(Opcode: ISD::FSHL, VT);
8263 bool HasFSHR = hasOperation(Opcode: ISD::FSHR, VT);
8264
8265 // If the type is going to be promoted and the target has enabled custom
8266 // lowering for rotate, allow matching rotate by non-constants. Only allow
8267 // this for scalar types.
8268 if (VT.isScalarInteger() && TLI.getTypeAction(Context&: *DAG.getContext(), VT) ==
8269 TargetLowering::TypePromoteInteger) {
8270 HasROTL |= TLI.getOperationAction(Op: ISD::ROTL, VT) == TargetLowering::Custom;
8271 HasROTR |= TLI.getOperationAction(Op: ISD::ROTR, VT) == TargetLowering::Custom;
8272 }
8273
8274 if (LegalOperations && !HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
8275 return SDValue();
8276
8277 // Check for truncated rotate.
8278 if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
8279 LHS.getOperand(i: 0).getValueType() == RHS.getOperand(i: 0).getValueType()) {
8280 assert(LHS.getValueType() == RHS.getValueType());
8281 if (SDValue Rot = MatchRotate(LHS: LHS.getOperand(i: 0), RHS: RHS.getOperand(i: 0), DL)) {
8282 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LHS), VT: LHS.getValueType(), Operand: Rot);
8283 }
8284 }
8285
8286 // Match "(X shl/srl V1) & V2" where V2 may not be present.
8287 SDValue LHSShift; // The shift.
8288 SDValue LHSMask; // AND value if any.
8289 matchRotateHalf(DAG, Op: LHS, Shift&: LHSShift, Mask&: LHSMask);
8290
8291 SDValue RHSShift; // The shift.
8292 SDValue RHSMask; // AND value if any.
8293 matchRotateHalf(DAG, Op: RHS, Shift&: RHSShift, Mask&: RHSMask);
8294
8295 // If neither side matched a rotate half, bail
8296 if (!LHSShift && !RHSShift)
8297 return SDValue();
8298
8299 // InstCombine may have combined a constant shl, srl, mul, or udiv with one
8300 // side of the rotate, so try to handle that here. In all cases we need to
8301 // pass the matched shift from the opposite side to compute the opcode and
8302 // needed shift amount to extract. We still want to do this if both sides
8303 // matched a rotate half because one half may be a potential overshift that
// can be broken down (i.e. if InstCombine merged two shl or srl ops into a
8305 // single one).
8306
8307 // Have LHS side of the rotate, try to extract the needed shift from the RHS.
8308 if (LHSShift)
8309 if (SDValue NewRHSShift =
8310 extractShiftForRotate(DAG, OppShift: LHSShift, ExtractFrom: RHS, Mask&: RHSMask, DL))
8311 RHSShift = NewRHSShift;
8312 // Have RHS side of the rotate, try to extract the needed shift from the LHS.
8313 if (RHSShift)
8314 if (SDValue NewLHSShift =
8315 extractShiftForRotate(DAG, OppShift: RHSShift, ExtractFrom: LHS, Mask&: LHSMask, DL))
8316 LHSShift = NewLHSShift;
8317
8318 // If a side is still missing, nothing else we can do.
8319 if (!RHSShift || !LHSShift)
8320 return SDValue();
8321
8322 // At this point we've matched or extracted a shift op on each side.
8323
8324 if (LHSShift.getOpcode() == RHSShift.getOpcode())
8325 return SDValue(); // Shifts must disagree.
8326
8327 // Canonicalize shl to left side in a shl/srl pair.
8328 if (RHSShift.getOpcode() == ISD::SHL) {
8329 std::swap(a&: LHS, b&: RHS);
8330 std::swap(a&: LHSShift, b&: RHSShift);
8331 std::swap(a&: LHSMask, b&: RHSMask);
8332 }
8333
8334 // Something has gone wrong - we've lost the shl/srl pair - bail.
8335 if (LHSShift.getOpcode() != ISD::SHL || RHSShift.getOpcode() != ISD::SRL)
8336 return SDValue();
8337
8338 unsigned EltSizeInBits = VT.getScalarSizeInBits();
8339 SDValue LHSShiftArg = LHSShift.getOperand(i: 0);
8340 SDValue LHSShiftAmt = LHSShift.getOperand(i: 1);
8341 SDValue RHSShiftArg = RHSShift.getOperand(i: 0);
8342 SDValue RHSShiftAmt = RHSShift.getOperand(i: 1);
8343
8344 auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
8345 ConstantSDNode *RHS) {
8346 return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
8347 };
8348
8349 auto ApplyMasks = [&](SDValue Res) {
8350 // If there is an AND of either shifted operand, apply it to the result.
8351 if (LHSMask.getNode() || RHSMask.getNode()) {
8352 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
8353 SDValue Mask = AllOnes;
8354
8355 if (LHSMask.getNode()) {
8356 SDValue RHSBits = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: AllOnes, N2: RHSShiftAmt);
8357 Mask = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Mask,
8358 N2: DAG.getNode(Opcode: ISD::OR, DL, VT, N1: LHSMask, N2: RHSBits));
8359 }
8360 if (RHSMask.getNode()) {
8361 SDValue LHSBits = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: AllOnes, N2: LHSShiftAmt);
8362 Mask = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Mask,
8363 N2: DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RHSMask, N2: LHSBits));
8364 }
8365
8366 Res = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Res, N2: Mask);
8367 }
8368
8369 return Res;
8370 };
8371
8372 // TODO: Support pre-legalization funnel-shift by constant.
8373 bool IsRotate = LHSShiftArg == RHSShiftArg;
8374 if (!IsRotate && !(HasFSHL || HasFSHR)) {
8375 if (TLI.isTypeLegal(VT) && LHS.hasOneUse() && RHS.hasOneUse() &&
8376 ISD::matchBinaryPredicate(LHS: LHSShiftAmt, RHS: RHSShiftAmt, Match: MatchRotateSum)) {
8377 // Look for a disguised rotate by constant.
8378 // The common shifted operand X may be hidden inside another 'or'.
8379 SDValue X, Y;
8380 auto matchOr = [&X, &Y](SDValue Or, SDValue CommonOp) {
8381 if (!Or.hasOneUse() || Or.getOpcode() != ISD::OR)
8382 return false;
8383 if (CommonOp == Or.getOperand(i: 0)) {
8384 X = CommonOp;
8385 Y = Or.getOperand(i: 1);
8386 return true;
8387 }
8388 if (CommonOp == Or.getOperand(i: 1)) {
8389 X = CommonOp;
8390 Y = Or.getOperand(i: 0);
8391 return true;
8392 }
8393 return false;
8394 };
8395
8396 SDValue Res;
8397 if (matchOr(LHSShiftArg, RHSShiftArg)) {
8398 // (shl (X | Y), C1) | (srl X, C2) --> (rotl X, C1) | (shl Y, C1)
8399 SDValue RotX = DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: X, N2: LHSShiftAmt);
8400 SDValue ShlY = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Y, N2: LHSShiftAmt);
8401 Res = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RotX, N2: ShlY);
8402 } else if (matchOr(RHSShiftArg, LHSShiftArg)) {
8403 // (shl X, C1) | (srl (X | Y), C2) --> (rotl X, C1) | (srl Y, C2)
8404 SDValue RotX = DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: X, N2: LHSShiftAmt);
8405 SDValue SrlY = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Y, N2: RHSShiftAmt);
8406 Res = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: RotX, N2: SrlY);
8407 } else {
8408 return SDValue();
8409 }
8410
8411 return ApplyMasks(Res);
8412 }
8413
8414 return SDValue(); // Requires funnel shift support.
8415 }
8416
8417 // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
8418 // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
8419 // fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
8420 // fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
8421 // iff C1+C2 == EltSizeInBits
8422 if (ISD::matchBinaryPredicate(LHS: LHSShiftAmt, RHS: RHSShiftAmt, Match: MatchRotateSum)) {
8423 SDValue Res;
8424 if (IsRotate && (HasROTL || HasROTR || !(HasFSHL || HasFSHR))) {
8425 bool UseROTL = !LegalOperations || HasROTL;
8426 Res = DAG.getNode(Opcode: UseROTL ? ISD::ROTL : ISD::ROTR, DL, VT, N1: LHSShiftArg,
8427 N2: UseROTL ? LHSShiftAmt : RHSShiftAmt);
8428 } else {
8429 bool UseFSHL = !LegalOperations || HasFSHL;
8430 Res = DAG.getNode(Opcode: UseFSHL ? ISD::FSHL : ISD::FSHR, DL, VT, N1: LHSShiftArg,
8431 N2: RHSShiftArg, N3: UseFSHL ? LHSShiftAmt : RHSShiftAmt);
8432 }
8433
8434 return ApplyMasks(Res);
8435 }
8436
8437 // Even pre-legalization, we can't easily rotate/funnel-shift by a variable
8438 // shift.
8439 if (!HasROTL && !HasROTR && !HasFSHL && !HasFSHR)
8440 return SDValue();
8441
8442 // If there is a mask here, and we have a variable shift, we can't be sure
8443 // that we're masking out the right stuff.
8444 if (LHSMask.getNode() || RHSMask.getNode())
8445 return SDValue();
8446
8447 // If the shift amount is sign/zext/any-extended just peel it off.
8448 SDValue LExtOp0 = LHSShiftAmt;
8449 SDValue RExtOp0 = RHSShiftAmt;
8450 if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
8451 LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
8452 LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
8453 LHSShiftAmt.getOpcode() == ISD::TRUNCATE) &&
8454 (RHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
8455 RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
8456 RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
8457 RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
8458 LExtOp0 = LHSShiftAmt.getOperand(i: 0);
8459 RExtOp0 = RHSShiftAmt.getOperand(i: 0);
8460 }
8461
8462 if (IsRotate && (HasROTL || HasROTR)) {
8463 SDValue TryL =
8464 MatchRotatePosNeg(Shifted: LHSShiftArg, Pos: LHSShiftAmt, Neg: RHSShiftAmt, InnerPos: LExtOp0,
8465 InnerNeg: RExtOp0, HasPos: HasROTL, PosOpcode: ISD::ROTL, NegOpcode: ISD::ROTR, DL);
8466 if (TryL)
8467 return TryL;
8468
8469 SDValue TryR =
8470 MatchRotatePosNeg(Shifted: RHSShiftArg, Pos: RHSShiftAmt, Neg: LHSShiftAmt, InnerPos: RExtOp0,
8471 InnerNeg: LExtOp0, HasPos: HasROTR, PosOpcode: ISD::ROTR, NegOpcode: ISD::ROTL, DL);
8472 if (TryR)
8473 return TryR;
8474 }
8475
8476 SDValue TryL =
8477 MatchFunnelPosNeg(N0: LHSShiftArg, N1: RHSShiftArg, Pos: LHSShiftAmt, Neg: RHSShiftAmt,
8478 InnerPos: LExtOp0, InnerNeg: RExtOp0, HasPos: HasFSHL, PosOpcode: ISD::FSHL, NegOpcode: ISD::FSHR, DL);
8479 if (TryL)
8480 return TryL;
8481
8482 SDValue TryR =
8483 MatchFunnelPosNeg(N0: LHSShiftArg, N1: RHSShiftArg, Pos: RHSShiftAmt, Neg: LHSShiftAmt,
8484 InnerPos: RExtOp0, InnerNeg: LExtOp0, HasPos: HasFSHR, PosOpcode: ISD::FSHR, NegOpcode: ISD::FSHL, DL);
8485 if (TryR)
8486 return TryR;
8487
8488 return SDValue();
8489}
8490
8491/// Recursively traverses the expression calculating the origin of the requested
8492/// byte of the given value. Returns std::nullopt if the provider can't be
8493/// calculated.
8494///
8495/// For all the values except the root of the expression, we verify that the
8496/// value has exactly one use and if not then return std::nullopt. This way if
8497/// the origin of the byte is returned it's guaranteed that the values which
8498/// contribute to the byte are not used outside of this expression.
///
8500/// However, there is a special case when dealing with vector loads -- we allow
8501/// more than one use if the load is a vector type. Since the values that
8502/// contribute to the byte ultimately come from the ExtractVectorElements of the
8503/// Load, we don't care if the Load has uses other than ExtractVectorElements,
8504/// because those operations are independent from the pattern to be combined.
8505/// For vector loads, we simply care that the ByteProviders are adjacent
8506/// positions of the same vector, and their index matches the byte that is being
8507/// provided. This is captured by the \p VectorIndex algorithm. \p VectorIndex
8508/// is the index used in an ExtractVectorElement, and \p StartingIndex is the
8509/// byte position we are trying to provide for the LoadCombine. If these do
/// not match, then we cannot combine the vector loads. \p Index uses the
/// byte position we are trying to provide for and is matched against the
/// shl and load size. The \p Index algorithm ensures the requested byte is
/// provided for by the pattern, and the pattern does not over-provide bytes.
8514///
8515///
/// The supported LoadCombine pattern for vector loads is as follows:
///
///                              or
///                          /        \
///                         or        shl
///                       /    \       |
///                      or    shl    zext
///                    /    \    |      |
///                   shl  zext zext   EVE*
///                    |    |    |      |
///                   zext EVE* EVE*   LOAD
///                    |    |    |
///                   EVE* LOAD LOAD
///                    |
///                   LOAD
8530///
8531/// *ExtractVectorElement
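///
/// As an illustrative scalar example (the names p and q are arbitrary
/// pointers), for
///   (or (shl (zext i8 (load q) to i32), 8), (zext i8 (load p) to i32))
/// byte 0 is provided by the load at p, byte 1 by the load at q (the shl by
/// 8 bits consumes one byte of the requested index), and bytes 2 and 3 are
/// constant zero thanks to the zero-extensions.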
8532using SDByteProvider = ByteProvider<SDNode *>;
8533
8534static std::optional<SDByteProvider>
8535calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
8536 std::optional<uint64_t> VectorIndex,
8537 unsigned StartingIndex = 0) {
8538
// A typical i64-by-i8 pattern requires recursion up to 8 calls deep.
8540 if (Depth == 10)
8541 return std::nullopt;
8542
8543 // Only allow multiple uses if the instruction is a vector load (in which
8544 // case we will use the load for every ExtractVectorElement)
8545 if (Depth && !Op.hasOneUse() &&
8546 (Op.getOpcode() != ISD::LOAD || !Op.getValueType().isVector()))
8547 return std::nullopt;
8548
8549 // Fail to combine if we have encountered anything but a LOAD after handling
8550 // an ExtractVectorElement.
8551 if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value())
8552 return std::nullopt;
8553
8554 unsigned BitWidth = Op.getValueSizeInBits();
8555 if (BitWidth % 8 != 0)
8556 return std::nullopt;
8557 unsigned ByteWidth = BitWidth / 8;
8558 assert(Index < ByteWidth && "invalid index requested");
8559 (void) ByteWidth;
8560
8561 switch (Op.getOpcode()) {
8562 case ISD::OR: {
8563 auto LHS =
8564 calculateByteProvider(Op: Op->getOperand(Num: 0), Index, Depth: Depth + 1, VectorIndex);
8565 if (!LHS)
8566 return std::nullopt;
8567 auto RHS =
8568 calculateByteProvider(Op: Op->getOperand(Num: 1), Index, Depth: Depth + 1, VectorIndex);
8569 if (!RHS)
8570 return std::nullopt;
8571
8572 if (LHS->isConstantZero())
8573 return RHS;
8574 if (RHS->isConstantZero())
8575 return LHS;
8576 return std::nullopt;
8577 }
8578 case ISD::SHL: {
8579 auto ShiftOp = dyn_cast<ConstantSDNode>(Val: Op->getOperand(Num: 1));
8580 if (!ShiftOp)
8581 return std::nullopt;
8582
8583 uint64_t BitShift = ShiftOp->getZExtValue();
8584
8585 if (BitShift % 8 != 0)
8586 return std::nullopt;
8587 uint64_t ByteShift = BitShift / 8;
8588
// If we are shifting by an amount greater than the index we are trying to
// provide, then do not provide anything. Otherwise, subtract the shift
// amount (in bytes) from the index.
8592 return Index < ByteShift
8593 ? SDByteProvider::getConstantZero()
8594 : calculateByteProvider(Op: Op->getOperand(Num: 0), Index: Index - ByteShift,
8595 Depth: Depth + 1, VectorIndex, StartingIndex: Index);
8596 }
8597 case ISD::ANY_EXTEND:
8598 case ISD::SIGN_EXTEND:
8599 case ISD::ZERO_EXTEND: {
8600 SDValue NarrowOp = Op->getOperand(Num: 0);
8601 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
8602 if (NarrowBitWidth % 8 != 0)
8603 return std::nullopt;
8604 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8605
8606 if (Index >= NarrowByteWidth)
8607 return Op.getOpcode() == ISD::ZERO_EXTEND
8608 ? std::optional<SDByteProvider>(
8609 SDByteProvider::getConstantZero())
8610 : std::nullopt;
8611 return calculateByteProvider(Op: NarrowOp, Index, Depth: Depth + 1, VectorIndex,
8612 StartingIndex);
8613 }
8614 case ISD::BSWAP:
8615 return calculateByteProvider(Op: Op->getOperand(Num: 0), Index: ByteWidth - Index - 1,
8616 Depth: Depth + 1, VectorIndex, StartingIndex);
8617 case ISD::EXTRACT_VECTOR_ELT: {
8618 auto OffsetOp = dyn_cast<ConstantSDNode>(Val: Op->getOperand(Num: 1));
8619 if (!OffsetOp)
8620 return std::nullopt;
8621
8622 VectorIndex = OffsetOp->getZExtValue();
8623
8624 SDValue NarrowOp = Op->getOperand(Num: 0);
8625 unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
8626 if (NarrowBitWidth % 8 != 0)
8627 return std::nullopt;
8628 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8629 // EXTRACT_VECTOR_ELT can extend the element type to the width of the return
8630 // type, leaving the high bits undefined.
8631 if (Index >= NarrowByteWidth)
8632 return std::nullopt;
8633
// Check to see if the position of the element in the vector corresponds
// with the byte we are trying to provide for. In the case of a vector of
// i8, this simply means VectorIndex == StartingIndex. For non-i8 cases,
// the element will provide a range of bytes. For example, if we have a
// vector of i16s, each element provides two bytes (V[1] provides bytes 2
// and 3).
8640 if (*VectorIndex * NarrowByteWidth > StartingIndex)
8641 return std::nullopt;
8642 if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
8643 return std::nullopt;
8644
8645 return calculateByteProvider(Op: Op->getOperand(Num: 0), Index, Depth: Depth + 1,
8646 VectorIndex, StartingIndex);
8647 }
8648 case ISD::LOAD: {
8649 auto L = cast<LoadSDNode>(Val: Op.getNode());
8650 if (!L->isSimple() || L->isIndexed())
8651 return std::nullopt;
8652
8653 unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
8654 if (NarrowBitWidth % 8 != 0)
8655 return std::nullopt;
8656 uint64_t NarrowByteWidth = NarrowBitWidth / 8;
8657
// If the width of the load does not reach the byte we are trying to provide
// for and it is not a ZEXTLOAD, then the load does not provide for the byte
// in question.
8661 if (Index >= NarrowByteWidth)
8662 return L->getExtensionType() == ISD::ZEXTLOAD
8663 ? std::optional<SDByteProvider>(
8664 SDByteProvider::getConstantZero())
8665 : std::nullopt;
8666
8667 unsigned BPVectorIndex = VectorIndex.value_or(u: 0U);
8668 return SDByteProvider::getSrc(Val: L, ByteOffset: Index, VectorOffset: BPVectorIndex);
8669 }
8670 }
8671
8672 return std::nullopt;
8673}
8674
8675static unsigned littleEndianByteAt(unsigned BW, unsigned i) {
8676 return i;
8677}
8678
8679static unsigned bigEndianByteAt(unsigned BW, unsigned i) {
8680 return BW - i - 1;
8681}
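// For example, in a 4-byte value the least significant byte (byte 0 of the
// value) lives at memory offset 0 in a little-endian layout
// (littleEndianByteAt(4, 0) == 0) but at offset 3 in a big-endian layout
// (bigEndianByteAt(4, 0) == 3).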
8682
// Check if the byte offsets we are looking at match either a big or little
// endian value load. Return true for big endian, false for little endian,
// and std::nullopt if the match failed.
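// For example, byte offsets {0, 1, 2, 3} relative to a FirstOffset of 0 match
// the little-endian layout (returns false), {3, 2, 1, 0} match the big-endian
// layout (returns true), and any other permutation fails (returns
// std::nullopt).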
8686static std::optional<bool> isBigEndian(const ArrayRef<int64_t> ByteOffsets,
8687 int64_t FirstOffset) {
// Endianness can only be decided when there are at least 2 bytes.
8689 unsigned Width = ByteOffsets.size();
8690 if (Width < 2)
8691 return std::nullopt;
8692
8693 bool BigEndian = true, LittleEndian = true;
8694 for (unsigned i = 0; i < Width; i++) {
8695 int64_t CurrentByteOffset = ByteOffsets[i] - FirstOffset;
8696 LittleEndian &= CurrentByteOffset == littleEndianByteAt(BW: Width, i);
8697 BigEndian &= CurrentByteOffset == bigEndianByteAt(BW: Width, i);
8698 if (!BigEndian && !LittleEndian)
8699 return std::nullopt;
8700 }
8701
assert((BigEndian != LittleEndian) && "It should be either big endian or "
"little endian");
8704 return BigEndian;
8705}
8706
8707static SDValue stripTruncAndExt(SDValue Value) {
8708 switch (Value.getOpcode()) {
8709 case ISD::TRUNCATE:
8710 case ISD::ZERO_EXTEND:
8711 case ISD::SIGN_EXTEND:
8712 case ISD::ANY_EXTEND:
8713 return stripTruncAndExt(Value: Value.getOperand(i: 0));
8714 }
8715 return Value;
8716}
8717
/// Match a pattern where a wide type scalar value is stored by several narrow
/// stores. Fold it into a single store or a BSWAP and a store if the target
/// supports it.
8721///
8722/// Assuming little endian target:
8723/// i8 *p = ...
8724/// i32 val = ...
8725/// p[0] = (val >> 0) & 0xFF;
8726/// p[1] = (val >> 8) & 0xFF;
8727/// p[2] = (val >> 16) & 0xFF;
8728/// p[3] = (val >> 24) & 0xFF;
8729/// =>
8730/// *((i32)p) = val;
8731///
8732/// i8 *p = ...
8733/// i32 val = ...
8734/// p[0] = (val >> 24) & 0xFF;
8735/// p[1] = (val >> 16) & 0xFF;
8736/// p[2] = (val >> 8) & 0xFF;
8737/// p[3] = (val >> 0) & 0xFF;
8738/// =>
8739/// *((i32)p) = BSWAP(val);
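///
/// When exactly two halves of the value are stored in reversed order, a
/// rotate is used instead of a BSWAP; e.g. (illustrative, little endian
/// target, i16 stores):
/// i16 *p = ...
/// i32 val = ...
/// p[0] = (val >> 16) & 0xFFFF;
/// p[1] = (val >> 0) & 0xFFFF;
/// =>
/// *((i32)p) = ROTR(val, 16);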
8740SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
8741 // The matching looks for "store (trunc x)" patterns that appear early but are
8742 // likely to be replaced by truncating store nodes during combining.
8743 // TODO: If there is evidence that running this later would help, this
8744 // limitation could be removed. Legality checks may need to be added
8745 // for the created store and optional bswap/rotate.
8746 if (LegalOperations || OptLevel == CodeGenOptLevel::None)
8747 return SDValue();
8748
8749 // We only handle merging simple stores of 1-4 bytes.
8750 // TODO: Allow unordered atomics when wider type is legal (see D66309)
8751 EVT MemVT = N->getMemoryVT();
8752 if (!(MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) ||
8753 !N->isSimple() || N->isIndexed())
8754 return SDValue();
8755
// Collect all of the stores in the chain, up to the maximum store width (i64).
8757 SDValue Chain = N->getChain();
8758 SmallVector<StoreSDNode *, 8> Stores = {N};
8759 unsigned NarrowNumBits = MemVT.getScalarSizeInBits();
8760 unsigned MaxWideNumBits = 64;
8761 unsigned MaxStores = MaxWideNumBits / NarrowNumBits;
8762 while (auto *Store = dyn_cast<StoreSDNode>(Val&: Chain)) {
8763 // All stores must be the same size to ensure that we are writing all of the
8764 // bytes in the wide value.
8765 // This store should have exactly one use as a chain operand for another
8766 // store in the merging set. If there are other chain uses, then the
8767 // transform may not be safe because order of loads/stores outside of this
8768 // set may not be preserved.
8769 // TODO: We could allow multiple sizes by tracking each stored byte.
8770 if (Store->getMemoryVT() != MemVT || !Store->isSimple() ||
8771 Store->isIndexed() || !Store->hasOneUse())
8772 return SDValue();
8773 Stores.push_back(Elt: Store);
8774 Chain = Store->getChain();
8775 if (MaxStores < Stores.size())
8776 return SDValue();
8777 }
8778 // There is no reason to continue if we do not have at least a pair of stores.
8779 if (Stores.size() < 2)
8780 return SDValue();
8781
8782 // Handle simple types only.
8783 LLVMContext &Context = *DAG.getContext();
8784 unsigned NumStores = Stores.size();
8785 unsigned WideNumBits = NumStores * NarrowNumBits;
8786 EVT WideVT = EVT::getIntegerVT(Context, BitWidth: WideNumBits);
8787 if (WideVT != MVT::i16 && WideVT != MVT::i32 && WideVT != MVT::i64)
8788 return SDValue();
8789
8790 // Check if all bytes of the source value that we are looking at are stored
8791 // to the same base address. Collect offsets from Base address into OffsetMap.
8792 SDValue SourceValue;
8793 SmallVector<int64_t, 8> OffsetMap(NumStores, INT64_MAX);
8794 int64_t FirstOffset = INT64_MAX;
8795 StoreSDNode *FirstStore = nullptr;
8796 std::optional<BaseIndexOffset> Base;
8797 for (auto *Store : Stores) {
8798 // All the stores store different parts of the CombinedValue. A truncate is
8799 // required to get the partial value.
8800 SDValue Trunc = Store->getValue();
8801 if (Trunc.getOpcode() != ISD::TRUNCATE)
8802 return SDValue();
8803 // Other than the first/last part, a shift operation is required to get the
8804 // offset.
8805 int64_t Offset = 0;
8806 SDValue WideVal = Trunc.getOperand(i: 0);
8807 if ((WideVal.getOpcode() == ISD::SRL || WideVal.getOpcode() == ISD::SRA) &&
8808 isa<ConstantSDNode>(Val: WideVal.getOperand(i: 1))) {
8809 // The shift amount must be a constant multiple of the narrow type.
8810 // It is translated to the offset address in the wide source value "y".
8811 //
8812 // x = srl y, ShiftAmtC
8813 // i8 z = trunc x
8814 // store z, ...
8815 uint64_t ShiftAmtC = WideVal.getConstantOperandVal(i: 1);
8816 if (ShiftAmtC % NarrowNumBits != 0)
8817 return SDValue();
8818
8819 Offset = ShiftAmtC / NarrowNumBits;
8820 WideVal = WideVal.getOperand(i: 0);
8821 }
8822
8823 // Stores must share the same source value with different offsets.
8824 // Truncate and extends should be stripped to get the single source value.
8825 if (!SourceValue)
8826 SourceValue = WideVal;
8827 else if (stripTruncAndExt(Value: SourceValue) != stripTruncAndExt(Value: WideVal))
8828 return SDValue();
8829 else if (SourceValue.getValueType() != WideVT) {
8830 if (WideVal.getValueType() == WideVT ||
8831 WideVal.getScalarValueSizeInBits() >
8832 SourceValue.getScalarValueSizeInBits())
8833 SourceValue = WideVal;
8834 // Give up if the source value type is smaller than the store size.
8835 if (SourceValue.getScalarValueSizeInBits() < WideVT.getScalarSizeInBits())
8836 return SDValue();
8837 }
8838
8839 // Stores must share the same base address.
8840 BaseIndexOffset Ptr = BaseIndexOffset::match(N: Store, DAG);
8841 int64_t ByteOffsetFromBase = 0;
8842 if (!Base)
8843 Base = Ptr;
8844 else if (!Base->equalBaseIndex(Other: Ptr, DAG, Off&: ByteOffsetFromBase))
8845 return SDValue();
8846
8847 // Remember the first store.
8848 if (ByteOffsetFromBase < FirstOffset) {
8849 FirstStore = Store;
8850 FirstOffset = ByteOffsetFromBase;
8851 }
8852 // Map the offset in the store and the offset in the combined value, and
8853 // early return if it has been set before.
8854 if (Offset < 0 || Offset >= NumStores || OffsetMap[Offset] != INT64_MAX)
8855 return SDValue();
8856 OffsetMap[Offset] = ByteOffsetFromBase;
8857 }
8858
8859 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
8860 assert(FirstStore && "First store must be set");
8861
8862 // Check that a store of the wide type is both allowed and fast on the target
8863 const DataLayout &Layout = DAG.getDataLayout();
8864 unsigned Fast = 0;
8865 bool Allowed = TLI.allowsMemoryAccess(Context, DL: Layout, VT: WideVT,
8866 MMO: *FirstStore->getMemOperand(), Fast: &Fast);
8867 if (!Allowed || !Fast)
8868 return SDValue();
8869
8870 // Check if the pieces of the value are going to the expected places in memory
8871 // to merge the stores.
8872 auto checkOffsets = [&](bool MatchLittleEndian) {
8873 if (MatchLittleEndian) {
8874 for (unsigned i = 0; i != NumStores; ++i)
8875 if (OffsetMap[i] != i * (NarrowNumBits / 8) + FirstOffset)
8876 return false;
8877 } else { // MatchBigEndian by reversing loop counter.
8878 for (unsigned i = 0, j = NumStores - 1; i != NumStores; ++i, --j)
8879 if (OffsetMap[j] != i * (NarrowNumBits / 8) + FirstOffset)
8880 return false;
8881 }
8882 return true;
8883 };
8884
8885 // Check if the offsets line up for the native data layout of this target.
8886 bool NeedBswap = false;
8887 bool NeedRotate = false;
8888 if (!checkOffsets(Layout.isLittleEndian())) {
8889 // Special-case: check if byte offsets line up for the opposite endian.
8890 if (NarrowNumBits == 8 && checkOffsets(Layout.isBigEndian()))
8891 NeedBswap = true;
8892 else if (NumStores == 2 && checkOffsets(Layout.isBigEndian()))
8893 NeedRotate = true;
8894 else
8895 return SDValue();
8896 }
8897
8898 SDLoc DL(N);
8899 if (WideVT != SourceValue.getValueType()) {
8900 assert(SourceValue.getValueType().getScalarSizeInBits() > WideNumBits &&
8901 "Unexpected store value to merge");
8902 SourceValue = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: WideVT, Operand: SourceValue);
8903 }
8904
8905 // Before legalize we can introduce illegal bswaps/rotates which will be later
8906 // converted to an explicit bswap sequence. This way we end up with a single
8907 // store and byte shuffling instead of several stores and byte shuffling.
8908 if (NeedBswap) {
8909 SourceValue = DAG.getNode(Opcode: ISD::BSWAP, DL, VT: WideVT, Operand: SourceValue);
8910 } else if (NeedRotate) {
8911 assert(WideNumBits % 2 == 0 && "Unexpected type for rotate");
8912 SDValue RotAmt = DAG.getConstant(Val: WideNumBits / 2, DL, VT: WideVT);
8913 SourceValue = DAG.getNode(Opcode: ISD::ROTR, DL, VT: WideVT, N1: SourceValue, N2: RotAmt);
8914 }
8915
8916 SDValue NewStore =
8917 DAG.getStore(Chain, dl: DL, Val: SourceValue, Ptr: FirstStore->getBasePtr(),
8918 PtrInfo: FirstStore->getPointerInfo(), Alignment: FirstStore->getAlign());
8919
8920 // Rely on other DAG combine rules to remove the other individual stores.
8921 DAG.ReplaceAllUsesWith(From: N, To: NewStore.getNode());
8922 return NewStore;
8923}
8924
/// Match a pattern where a wide type scalar value is loaded by several narrow
/// loads and combined by shifts and ors. Fold it into a single load or a load
/// and a BSWAP if the target supports it.
8928///
8929/// Assuming little endian target:
8930/// i8 *a = ...
8931/// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
8932/// =>
8933/// i32 val = *((i32)a)
8934///
8935/// i8 *a = ...
8936/// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
8937/// =>
8938/// i32 val = BSWAP(*((i32)a))
8939///
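/// A pattern that only covers the low bytes can still be combined into a
/// zero-extending load when the remaining high bytes are known to be zero,
/// e.g. (illustrative, little endian target):
/// i8 *a = ...
/// i32 val = a[0] | (a[1] << 8)
/// =>
/// i32 val = (i32) *((i16)a) ; i.e. a zextload of i16 into i32
///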
8940/// TODO: This rule matches complex patterns with OR node roots and doesn't
8941/// interact well with the worklist mechanism. When a part of the pattern is
8942/// updated (e.g. one of the loads) its direct users are put into the worklist,
8943/// but the root node of the pattern which triggers the load combine is not
/// necessarily a direct user of the changed node. For example, once the
/// address of the t28 load is reassociated, the load combine won't be
/// triggered:
8946/// t25: i32 = add t4, Constant:i32<2>
8947/// t26: i64 = sign_extend t25
8948/// t27: i64 = add t2, t26
8949/// t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
8950/// t29: i32 = zero_extend t28
8951/// t32: i32 = shl t29, Constant:i8<8>
8952/// t33: i32 = or t23, t32
8953/// As a possible fix visitLoad can check if the load can be a part of a load
8954/// combine pattern and add corresponding OR roots to the worklist.
8955SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
8956 assert(N->getOpcode() == ISD::OR &&
8957 "Can only match load combining against OR nodes");
8958
8959 // Handles simple types only
8960 EVT VT = N->getValueType(ResNo: 0);
8961 if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
8962 return SDValue();
8963 unsigned ByteWidth = VT.getSizeInBits() / 8;
8964
8965 bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
8966 auto MemoryByteOffset = [&](SDByteProvider P) {
8967 assert(P.hasSrc() && "Must be a memory byte provider");
8968 auto *Load = cast<LoadSDNode>(Val: P.Src.value());
8969
8970 unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits();
8971
8972 assert(LoadBitWidth % 8 == 0 &&
8973 "can only analyze providers for individual bytes not bit");
8974 unsigned LoadByteWidth = LoadBitWidth / 8;
8975 return IsBigEndianTarget ? bigEndianByteAt(BW: LoadByteWidth, i: P.DestOffset)
8976 : littleEndianByteAt(BW: LoadByteWidth, i: P.DestOffset);
8977 };
8978
8979 std::optional<BaseIndexOffset> Base;
8980 SDValue Chain;
8981
8982 SmallPtrSet<LoadSDNode *, 8> Loads;
8983 std::optional<SDByteProvider> FirstByteProvider;
8984 int64_t FirstOffset = INT64_MAX;
8985
// Check if all the bytes of the OR we are looking at are loaded from the same
// base address. Collect byte offsets from the Base address in ByteOffsets.
8988 SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
8989 unsigned ZeroExtendedBytes = 0;
8990 for (int i = ByteWidth - 1; i >= 0; --i) {
8991 auto P =
8992 calculateByteProvider(Op: SDValue(N, 0), Index: i, Depth: 0, /*VectorIndex*/ std::nullopt,
8993 /*StartingIndex*/ i);
8994 if (!P)
8995 return SDValue();
8996
8997 if (P->isConstantZero()) {
// It's OK for the N most significant bytes to be 0; we can just
// zero-extend the load.
9000 if (++ZeroExtendedBytes != (ByteWidth - static_cast<unsigned>(i)))
9001 return SDValue();
9002 continue;
9003 }
9004 assert(P->hasSrc() && "provenance should either be memory or zero");
9005 auto *L = cast<LoadSDNode>(Val: P->Src.value());
9006
9007 // All loads must share the same chain
9008 SDValue LChain = L->getChain();
9009 if (!Chain)
9010 Chain = LChain;
9011 else if (Chain != LChain)
9012 return SDValue();
9013
9014 // Loads must share the same base address
9015 BaseIndexOffset Ptr = BaseIndexOffset::match(N: L, DAG);
9016 int64_t ByteOffsetFromBase = 0;
9017
9018 // For vector loads, the expected load combine pattern will have an
9019 // ExtractElement for each index in the vector. While each of these
9020 // ExtractElements will be accessing the same base address as determined
9021 // by the load instruction, the actual bytes they interact with will differ
9022 // due to different ExtractElement indices. To accurately determine the
9023 // byte position of an ExtractElement, we offset the base load ptr with
9024 // the index multiplied by the byte size of each element in the vector.
9025 if (L->getMemoryVT().isVector()) {
9026 unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits();
9027 if (LoadWidthInBit % 8 != 0)
9028 return SDValue();
9029 unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8;
9030 Ptr.addToOffset(VectorOff: ByteOffsetFromVector);
9031 }
9032
9033 if (!Base)
9034 Base = Ptr;
9035
9036 else if (!Base->equalBaseIndex(Other: Ptr, DAG, Off&: ByteOffsetFromBase))
9037 return SDValue();
9038
9039 // Calculate the offset of the current byte from the base address
9040 ByteOffsetFromBase += MemoryByteOffset(*P);
9041 ByteOffsets[i] = ByteOffsetFromBase;
9042
9043 // Remember the first byte load
9044 if (ByteOffsetFromBase < FirstOffset) {
9045 FirstByteProvider = P;
9046 FirstOffset = ByteOffsetFromBase;
9047 }
9048
9049 Loads.insert(Ptr: L);
9050 }
9051
9052 assert(!Loads.empty() && "All the bytes of the value must be loaded from "
9053 "memory, so there must be at least one load which produces the value");
9054 assert(Base && "Base address of the accessed memory location must be set");
9055 assert(FirstOffset != INT64_MAX && "First byte offset must be set");
9056
9057 bool NeedsZext = ZeroExtendedBytes > 0;
9058
9059 EVT MemVT =
9060 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: (ByteWidth - ZeroExtendedBytes) * 8);
9061
9062 if (!MemVT.isSimple())
9063 return SDValue();
9064
// Before legalize we can introduce illegal loads that are too wide, which
// will later be split into legal-sized loads. This enables us to combine
// i64-load-by-i8 patterns into a couple of i32 loads on 32-bit targets.
9068 if (LegalOperations &&
9069 !TLI.isOperationLegal(Op: NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD,
9070 VT: MemVT))
9071 return SDValue();
9072
// Check if the bytes of the OR we are looking at match either a big or
// little endian value load.
9075 std::optional<bool> IsBigEndian = isBigEndian(
9076 ByteOffsets: ArrayRef(ByteOffsets).drop_back(N: ZeroExtendedBytes), FirstOffset);
9077 if (!IsBigEndian)
9078 return SDValue();
9079
9080 assert(FirstByteProvider && "must be set");
9081
// Ensure that the first byte is loaded from the zero offset of the first
// load, so that the combined value can be loaded from the first load address.
9084 if (MemoryByteOffset(*FirstByteProvider) != 0)
9085 return SDValue();
9086 auto *FirstLoad = cast<LoadSDNode>(Val: FirstByteProvider->Src.value());
9087
// The node we are looking at matches the pattern; check if we can replace it
// with a single (possibly zero-extended) load and a bswap + shift if needed.
9091
// If the load needs a byte swap, check whether the target supports it.
9093 bool NeedsBswap = IsBigEndianTarget != *IsBigEndian;
9094
9095 // Before legalize we can introduce illegal bswaps which will be later
9096 // converted to an explicit bswap sequence. This way we end up with a single
9097 // load and byte shuffling instead of several loads and byte shuffling.
9098 // We do not introduce illegal bswaps when zero-extending as this tends to
9099 // introduce too many arithmetic instructions.
9100 if (NeedsBswap && (LegalOperations || NeedsZext) &&
9101 !TLI.isOperationLegal(Op: ISD::BSWAP, VT))
9102 return SDValue();
9103
9104 // If we need to bswap and zero extend, we have to insert a shift. Check that
9105 // it is legal.
9106 if (NeedsBswap && NeedsZext && LegalOperations &&
9107 !TLI.isOperationLegal(Op: ISD::SHL, VT))
9108 return SDValue();
9109
9110 // Check that a load of the wide type is both allowed and fast on the target
9111 unsigned Fast = 0;
9112 bool Allowed =
9113 TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: MemVT,
9114 MMO: *FirstLoad->getMemOperand(), Fast: &Fast);
9115 if (!Allowed || !Fast)
9116 return SDValue();
9117
9118 SDValue NewLoad =
9119 DAG.getExtLoad(ExtType: NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, dl: SDLoc(N), VT,
9120 Chain, Ptr: FirstLoad->getBasePtr(),
9121 PtrInfo: FirstLoad->getPointerInfo(), MemVT, Alignment: FirstLoad->getAlign());
9122
9123 // Transfer chain users from old loads to the new load.
9124 for (LoadSDNode *L : Loads)
9125 DAG.makeEquivalentMemoryOrdering(OldLoad: L, NewMemOp: NewLoad);
9126
9127 if (!NeedsBswap)
9128 return NewLoad;
9129
9130 SDValue ShiftedLoad =
9131 NeedsZext
9132 ? DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: NewLoad,
9133 N2: DAG.getShiftAmountConstant(Val: ZeroExtendedBytes * 8, VT,
9134 DL: SDLoc(N), LegalTypes: LegalOperations))
9135 : NewLoad;
9136 return DAG.getNode(Opcode: ISD::BSWAP, DL: SDLoc(N), VT, Operand: ShiftedLoad);
9137}
9138
// If the target has andn, bsl, or a similar bit-select instruction,
// we want to unfold masked merge, with canonical pattern of:
//   |     A     | |B|
//   ((x ^ y) & m) ^ y
//    |  D  |
// Into:
//   (x & m) | (y & ~m)
9146// If y is a constant, m is not a 'not', and the 'andn' does not work with
9147// immediates, we unfold into a different pattern:
9148// ~(~x & m) & (m | y)
9149// If x is a constant, m is a 'not', and the 'andn' does not work with
9150// immediates, we unfold into a different pattern:
9151// (x | ~m) & ~(~m & ~y)
9152// NOTE: we don't unfold the pattern if 'xor' is actually a 'not', because at
9153// the very least that breaks andnpd / andnps patterns, and because those
9154// patterns are simplified in IR and shouldn't be created in the DAG
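//
// As a quick illustration of the basic identity (values chosen arbitrarily):
// with m = 0b0110, x = 0b1010, y = 0b0101:
//   ((x ^ y) & m) ^ y  ==  (0b1111 & 0b0110) ^ 0b0101  ==  0b0011
//   (x & m) | (y & ~m) ==  0b0010 | 0b0001             ==  0b0011
// i.e. the result takes bits from x where the mask is set and from y where
// it is clear.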
9155SDValue DAGCombiner::unfoldMaskedMerge(SDNode *N) {
9156 assert(N->getOpcode() == ISD::XOR);
9157
9158 // Don't touch 'not' (i.e. where y = -1).
9159 if (isAllOnesOrAllOnesSplat(V: N->getOperand(Num: 1)))
9160 return SDValue();
9161
9162 EVT VT = N->getValueType(ResNo: 0);
9163
9164 // There are 3 commutable operators in the pattern,
9165 // so we have to deal with 8 possible variants of the basic pattern.
9166 SDValue X, Y, M;
9167 auto matchAndXor = [&X, &Y, &M](SDValue And, unsigned XorIdx, SDValue Other) {
9168 if (And.getOpcode() != ISD::AND || !And.hasOneUse())
9169 return false;
9170 SDValue Xor = And.getOperand(i: XorIdx);
9171 if (Xor.getOpcode() != ISD::XOR || !Xor.hasOneUse())
9172 return false;
9173 SDValue Xor0 = Xor.getOperand(i: 0);
9174 SDValue Xor1 = Xor.getOperand(i: 1);
9175 // Don't touch 'not' (i.e. where y = -1).
9176 if (isAllOnesOrAllOnesSplat(V: Xor1))
9177 return false;
9178 if (Other == Xor0)
9179 std::swap(a&: Xor0, b&: Xor1);
9180 if (Other != Xor1)
9181 return false;
9182 X = Xor0;
9183 Y = Xor1;
9184 M = And.getOperand(i: XorIdx ? 0 : 1);
9185 return true;
9186 };
9187
9188 SDValue N0 = N->getOperand(Num: 0);
9189 SDValue N1 = N->getOperand(Num: 1);
9190 if (!matchAndXor(N0, 0, N1) && !matchAndXor(N0, 1, N1) &&
9191 !matchAndXor(N1, 0, N0) && !matchAndXor(N1, 1, N0))
9192 return SDValue();
9193
// Don't do anything if the mask is constant. This should not be reachable.
// InstCombine should have already unfolded this pattern, and DAGCombiner
// probably shouldn't produce it either.
9197 if (isa<ConstantSDNode>(Val: M.getNode()))
9198 return SDValue();
9199
9200 // We can transform if the target has AndNot
9201 if (!TLI.hasAndNot(X: M))
9202 return SDValue();
9203
9204 SDLoc DL(N);
9205
// If Y is a constant, check that 'andn' works with immediates, unless M is
// a bitwise not that would already allow ANDN to be used.
9208 if (!TLI.hasAndNot(X: Y) && !isBitwiseNot(V: M)) {
9209 assert(TLI.hasAndNot(X) && "Only mask is a variable? Unreachable.");
9210 // If not, we need to do a bit more work to make sure andn is still used.
9211 SDValue NotX = DAG.getNOT(DL, Val: X, VT);
9212 SDValue LHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotX, N2: M);
9213 SDValue NotLHS = DAG.getNOT(DL, Val: LHS, VT);
9214 SDValue RHS = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: M, N2: Y);
9215 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotLHS, N2: RHS);
9216 }
9217
9218 // If X is a constant and M is a bitwise not, check that 'andn' works with
9219 // immediates.
9220 if (!TLI.hasAndNot(X) && isBitwiseNot(V: M)) {
9221 assert(TLI.hasAndNot(Y) && "Only mask is a variable? Unreachable.");
9222 // If not, we need to do a bit more work to make sure andn is still used.
9223 SDValue NotM = M.getOperand(i: 0);
9224 SDValue LHS = DAG.getNode(Opcode: ISD::OR, DL, VT, N1: X, N2: NotM);
9225 SDValue NotY = DAG.getNOT(DL, Val: Y, VT);
9226 SDValue RHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotM, N2: NotY);
9227 SDValue NotRHS = DAG.getNOT(DL, Val: RHS, VT);
9228 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: LHS, N2: NotRHS);
9229 }
9230
9231 SDValue LHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X, N2: M);
9232 SDValue NotM = DAG.getNOT(DL, Val: M, VT);
9233 SDValue RHS = DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Y, N2: NotM);
9234
9235 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: LHS, N2: RHS);
9236}
9237
9238SDValue DAGCombiner::visitXOR(SDNode *N) {
9239 SDValue N0 = N->getOperand(Num: 0);
9240 SDValue N1 = N->getOperand(Num: 1);
9241 EVT VT = N0.getValueType();
9242 SDLoc DL(N);
9243
9244 // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
9245 if (N0.isUndef() && N1.isUndef())
9246 return DAG.getConstant(Val: 0, DL, VT);
9247
9248 // fold (xor x, undef) -> undef
9249 if (N0.isUndef())
9250 return N0;
9251 if (N1.isUndef())
9252 return N1;
9253
9254 // fold (xor c1, c2) -> c1^c2
9255 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::XOR, DL, VT, Ops: {N0, N1}))
9256 return C;
9257
9258 // canonicalize constant to RHS
9259 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
9260 !DAG.isConstantIntBuildVectorOrConstantInt(N: N1))
9261 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1, N2: N0);
9262
9263 // fold vector ops
9264 if (VT.isVector()) {
9265 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
9266 return FoldedVOp;
9267
9268 // fold (xor x, 0) -> x, vector edition
9269 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode()))
9270 return N0;
9271 }
9272
9273 // fold (xor x, 0) -> x
9274 if (isNullConstant(V: N1))
9275 return N0;
9276
9277 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
9278 return NewSel;
9279
9280 // reassociate xor
9281 if (SDValue RXOR = reassociateOps(Opc: ISD::XOR, DL, N0, N1, Flags: N->getFlags()))
9282 return RXOR;
9283
9284 // Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y))
9285 if (SDValue SD =
9286 reassociateReduction(RedOpc: ISD::VECREDUCE_XOR, Opc: ISD::XOR, DL, VT, N0, N1))
9287 return SD;
9288
9289 // fold (a^b) -> (a|b) iff a and b share no bits.
9290 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::OR, VT)) &&
9291 DAG.haveNoCommonBitsSet(A: N0, B: N1)) {
9292 SDNodeFlags Flags;
9293 Flags.setDisjoint(true);
9294 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: N0, N2: N1, Flags);
9295 }
9296
9297 // look for 'add-like' folds:
9298 // XOR(N0,MIN_SIGNED_VALUE) == ADD(N0,MIN_SIGNED_VALUE)
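// (Adding the minimum signed value can only flip the sign bit, since all the
// lower bits of the constant are zero and no carry can reach the sign bit;
// e.g. for i8, (xor X, 0x80) == (add X, 0x80).)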
9299 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::ADD, VT)) &&
9300 isMinSignedConstant(V: N1))
9301 if (SDValue Combined = visitADDLike(N))
9302 return Combined;
9303
9304 // fold !(x cc y) -> (x !cc y)
9305 unsigned N0Opcode = N0.getOpcode();
9306 SDValue LHS, RHS, CC;
9307 if (TLI.isConstTrueVal(N: N1) &&
9308 isSetCCEquivalent(N: N0, LHS, RHS, CC, /*MatchStrict*/ true)) {
9309 ISD::CondCode NotCC = ISD::getSetCCInverse(Operation: cast<CondCodeSDNode>(Val&: CC)->get(),
9310 Type: LHS.getValueType());
9311 if (!LegalOperations ||
9312 TLI.isCondCodeLegal(CC: NotCC, VT: LHS.getSimpleValueType())) {
9313 switch (N0Opcode) {
9314 default:
9315 llvm_unreachable("Unhandled SetCC Equivalent!");
9316 case ISD::SETCC:
9317 return DAG.getSetCC(DL: SDLoc(N0), VT, LHS, RHS, Cond: NotCC);
9318 case ISD::SELECT_CC:
9319 return DAG.getSelectCC(DL: SDLoc(N0), LHS, RHS, True: N0.getOperand(i: 2),
9320 False: N0.getOperand(i: 3), Cond: NotCC);
9321 case ISD::STRICT_FSETCC:
9322 case ISD::STRICT_FSETCCS: {
9323 if (N0.hasOneUse()) {
9324 // FIXME Can we handle multiple uses? Could we token factor the chain
9325 // results from the new/old setcc?
9326 SDValue SetCC =
9327 DAG.getSetCC(DL: SDLoc(N0), VT, LHS, RHS, Cond: NotCC,
9328 Chain: N0.getOperand(i: 0), IsSignaling: N0Opcode == ISD::STRICT_FSETCCS);
9329 CombineTo(N, Res: SetCC);
9330 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: SetCC.getValue(R: 1));
9331 recursivelyDeleteUnusedNodes(N: N0.getNode());
9332 return SDValue(N, 0); // Return N so it doesn't get rechecked!
9333 }
9334 break;
9335 }
9336 }
9337 }
9338 }
9339
9340 // fold (not (zext (setcc x, y))) -> (zext (not (setcc x, y)))
9341 if (isOneConstant(V: N1) && N0Opcode == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9342 isSetCCEquivalent(N: N0.getOperand(i: 0), LHS, RHS, CC)){
9343 SDValue V = N0.getOperand(i: 0);
9344 SDLoc DL0(N0);
9345 V = DAG.getNode(Opcode: ISD::XOR, DL: DL0, VT: V.getValueType(), N1: V,
9346 N2: DAG.getConstant(Val: 1, DL: DL0, VT: V.getValueType()));
9347 AddToWorklist(N: V.getNode());
9348 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: V);
9349 }
9350
9351 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are setcc
9352 if (isOneConstant(N1) && VT == MVT::i1 && N0.hasOneUse() &&
9353 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9354 SDValue N00 = N0.getOperand(i: 0), N01 = N0.getOperand(i: 1);
9355 if (isOneUseSetCC(N: N01) || isOneUseSetCC(N: N00)) {
9356 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9357 N00 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N00), VT, N1: N00, N2: N1); // N00 = ~N00
9358 N01 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N01), VT, N1: N01, N2: N1); // N01 = ~N01
9359 AddToWorklist(N: N00.getNode()); AddToWorklist(N: N01.getNode());
9360 return DAG.getNode(Opcode: NewOpcode, DL, VT, N1: N00, N2: N01);
9361 }
9362 }
9363 // fold (not (or x, y)) -> (and (not x), (not y)) iff x or y are constants
9364 if (isAllOnesConstant(V: N1) && N0.hasOneUse() &&
9365 (N0Opcode == ISD::OR || N0Opcode == ISD::AND)) {
9366 SDValue N00 = N0.getOperand(i: 0), N01 = N0.getOperand(i: 1);
9367 if (isa<ConstantSDNode>(Val: N01) || isa<ConstantSDNode>(Val: N00)) {
9368 unsigned NewOpcode = N0Opcode == ISD::AND ? ISD::OR : ISD::AND;
9369 N00 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N00), VT, N1: N00, N2: N1); // N00 = ~N00
9370 N01 = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N01), VT, N1: N01, N2: N1); // N01 = ~N01
9371 AddToWorklist(N: N00.getNode()); AddToWorklist(N: N01.getNode());
9372 return DAG.getNode(Opcode: NewOpcode, DL, VT, N1: N00, N2: N01);
9373 }
9374 }
9375
9376 // fold (not (neg x)) -> (add X, -1)
9377 // FIXME: This can be generalized to (not (sub Y, X)) -> (add X, ~Y) if
9378 // Y is a constant or the subtract has a single use.
9379 if (isAllOnesConstant(V: N1) && N0.getOpcode() == ISD::SUB &&
9380 isNullConstant(V: N0.getOperand(i: 0))) {
9381 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: N0.getOperand(i: 1),
9382 N2: DAG.getAllOnesConstant(DL, VT));
9383 }
9384
9385 // fold (not (add X, -1)) -> (neg X)
9386 if (isAllOnesConstant(V: N1) && N0.getOpcode() == ISD::ADD &&
9387 isAllOnesOrAllOnesSplat(V: N0.getOperand(i: 1))) {
9388 return DAG.getNegative(Val: N0.getOperand(i: 0), DL, VT);
9389 }
9390
9391 // fold (xor (and x, y), y) -> (and (not x), y)
9392 if (N0Opcode == ISD::AND && N0.hasOneUse() && N0->getOperand(Num: 1) == N1) {
9393 SDValue X = N0.getOperand(i: 0);
9394 SDValue NotX = DAG.getNOT(DL: SDLoc(X), Val: X, VT);
9395 AddToWorklist(N: NotX.getNode());
9396 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: NotX, N2: N1);
9397 }
9398
9399 // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
9400 if (TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT)) {
9401 SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
9402 SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
9403 if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
9404 SDValue A0 = A.getOperand(i: 0), A1 = A.getOperand(i: 1);
9405 SDValue S0 = S.getOperand(i: 0);
9406 if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0))
9407 if (ConstantSDNode *C = isConstOrConstSplat(N: S.getOperand(i: 1)))
9408 if (C->getAPIntValue() == (VT.getScalarSizeInBits() - 1))
9409 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: S0);
9410 }
9411 }
9412
9413 // fold (xor x, x) -> 0
9414 if (N0 == N1)
9415 return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
9416
9417 // fold (xor (shl 1, x), -1) -> (rotl ~1, x)
9418 // Here is a concrete example of this equivalence:
9419 // i16 x == 14
9420 // i16 shl == 1 << 14 == 16384 == 0b0100000000000000
9421 // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111
9422 //
9423 // =>
9424 //
9425 // i16 ~1 == 0b1111111111111110
9426 // i16 rol(~1, 14) == 0b1011111111111111
9427 //
9428 // Some additional tips to help conceptualize this transform:
9429 // - Try to see the operation as placing a single zero in a value of all ones.
9430 // - There exists no value for x which would allow the result to contain zero.
9431 // - Values of x larger than the bitwidth are undefined and do not require a
9432 // consistent result.
// - Pushing the zero left requires shifting one-bits in from the right.
9434 // A rotate left of ~1 is a nice way of achieving the desired result.
9435 if (TLI.isOperationLegalOrCustom(Op: ISD::ROTL, VT) && N0Opcode == ISD::SHL &&
9436 isAllOnesConstant(V: N1) && isOneConstant(V: N0.getOperand(i: 0))) {
9437 return DAG.getNode(Opcode: ISD::ROTL, DL, VT, N1: DAG.getConstant(Val: ~1, DL, VT),
9438 N2: N0.getOperand(i: 1));
9439 }
9440
9441 // Simplify: xor (op x...), (op y...) -> (op (xor x, y))
9442 if (N0Opcode == N1.getOpcode())
9443 if (SDValue V = hoistLogicOpWithSameOpcodeHands(N))
9444 return V;
9445
9446 if (SDValue R = foldLogicOfShifts(N, LogicOp: N0, ShiftOp: N1, DAG))
9447 return R;
9448 if (SDValue R = foldLogicOfShifts(N, LogicOp: N1, ShiftOp: N0, DAG))
9449 return R;
9450 if (SDValue R = foldLogicTreeOfShifts(N, LeftHand: N0, RightHand: N1, DAG))
9451 return R;
9452
9453 // Unfold ((x ^ y) & m) ^ y into (x & m) | (y & ~m) if profitable
9454 if (SDValue MM = unfoldMaskedMerge(N))
9455 return MM;
9456
9457 // Simplify the expression using non-local knowledge.
9458 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
9459 return SDValue(N, 0);
9460
9461 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
9462 return Combined;
9463
9464 return SDValue();
9465}
9466
9467/// If we have a shift-by-constant of a bitwise logic op that itself has a
9468/// shift-by-constant operand with identical opcode, we may be able to convert
9469/// that into 2 independent shifts followed by the logic op. This is a
9470/// throughput improvement.
9471static SDValue combineShiftOfShiftedLogic(SDNode *Shift, SelectionDAG &DAG) {
9472 // Match a one-use bitwise logic op.
9473 SDValue LogicOp = Shift->getOperand(Num: 0);
9474 if (!LogicOp.hasOneUse())
9475 return SDValue();
9476
9477 unsigned LogicOpcode = LogicOp.getOpcode();
9478 if (LogicOpcode != ISD::AND && LogicOpcode != ISD::OR &&
9479 LogicOpcode != ISD::XOR)
9480 return SDValue();
9481
9482 // Find a matching one-use shift by constant.
9483 unsigned ShiftOpcode = Shift->getOpcode();
9484 SDValue C1 = Shift->getOperand(Num: 1);
9485 ConstantSDNode *C1Node = isConstOrConstSplat(N: C1);
9486 assert(C1Node && "Expected a shift with constant operand");
9487 const APInt &C1Val = C1Node->getAPIntValue();
9488 auto matchFirstShift = [&](SDValue V, SDValue &ShiftOp,
9489 const APInt *&ShiftAmtVal) {
9490 if (V.getOpcode() != ShiftOpcode || !V.hasOneUse())
9491 return false;
9492
9493 ConstantSDNode *ShiftCNode = isConstOrConstSplat(N: V.getOperand(i: 1));
9494 if (!ShiftCNode)
9495 return false;
9496
9497 // Capture the shifted operand and shift amount value.
9498 ShiftOp = V.getOperand(i: 0);
9499 ShiftAmtVal = &ShiftCNode->getAPIntValue();
9500
9501 // Shift amount types do not have to match their operand type, so check that
9502 // the constants are the same width.
9503 if (ShiftAmtVal->getBitWidth() != C1Val.getBitWidth())
9504 return false;
9505
9506 // The fold is not valid if the sum of the shift values doesn't fit in the
9507 // given shift amount type.
9508 bool Overflow = false;
9509 APInt NewShiftAmt = C1Val.uadd_ov(RHS: *ShiftAmtVal, Overflow);
9510 if (Overflow)
9511 return false;
9512
9513 // The fold is not valid if the sum of the shift values exceeds bitwidth.
9514 if (NewShiftAmt.uge(RHS: V.getScalarValueSizeInBits()))
9515 return false;
9516
9517 return true;
9518 };
9519
9520 // Logic ops are commutative, so check each operand for a match.
9521 SDValue X, Y;
9522 const APInt *C0Val;
9523 if (matchFirstShift(LogicOp.getOperand(i: 0), X, C0Val))
9524 Y = LogicOp.getOperand(i: 1);
9525 else if (matchFirstShift(LogicOp.getOperand(i: 1), X, C0Val))
9526 Y = LogicOp.getOperand(i: 0);
9527 else
9528 return SDValue();
9529
9530 // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
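// e.g. (srl (xor (srl X, 3), Y), 2) --> (xor (srl X, 5), (srl Y, 2))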
9531 SDLoc DL(Shift);
9532 EVT VT = Shift->getValueType(ResNo: 0);
9533 EVT ShiftAmtVT = Shift->getOperand(Num: 1).getValueType();
9534 SDValue ShiftSumC = DAG.getConstant(Val: *C0Val + C1Val, DL, VT: ShiftAmtVT);
9535 SDValue NewShift1 = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: X, N2: ShiftSumC);
9536 SDValue NewShift2 = DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: Y, N2: C1);
9537 return DAG.getNode(Opcode: LogicOpcode, DL, VT, N1: NewShift1, N2: NewShift2,
9538 Flags: LogicOp->getFlags());
9539}
9540
9541/// Handle transforms common to the three shifts, when the shift amount is a
9542/// constant.
9543/// We are looking for: (shift being one of shl/sra/srl)
9544/// shift (binop X, C0), C1
9545/// And want to transform into:
9546/// binop (shift X, C1), (shift C0, C1)
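///
/// For example (illustrative only):
///   (shl (and X, 0xFF), 8) --> (and (shl X, 8), 0xFF00)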
9547SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
9548 assert(isConstOrConstSplat(N->getOperand(1)) && "Expected constant operand");
9549
9550 // Do not turn a 'not' into a regular xor.
9551 if (isBitwiseNot(V: N->getOperand(Num: 0)))
9552 return SDValue();
9553
9554 // The inner binop must be one-use, since we want to replace it.
9555 SDValue LHS = N->getOperand(Num: 0);
9556 if (!LHS.hasOneUse() || !TLI.isDesirableToCommuteWithShift(N, Level))
9557 return SDValue();
9558
9559 // Fold shift(bitop(shift(x,c1),y), c2) -> bitop(shift(x,c1+c2),shift(y,c2)).
9560 if (SDValue R = combineShiftOfShiftedLogic(Shift: N, DAG))
9561 return R;
9562
9563 // We want to pull some binops through shifts, so that we have (and (shift))
9564 // instead of (shift (and)), likewise for add, or, xor, etc. This sort of
9565 // thing happens with address calculations, so it's important to canonicalize
9566 // it.
9567 switch (LHS.getOpcode()) {
9568 default:
9569 return SDValue();
9570 case ISD::OR:
9571 case ISD::XOR:
9572 case ISD::AND:
9573 break;
9574 case ISD::ADD:
9575 if (N->getOpcode() != ISD::SHL)
9576 return SDValue(); // only shl(add) not sr[al](add).
9577 break;
9578 }
9579
// FIXME: disable this unless the input to the binop is a shift by a constant
// or is copy/select. Enable this in other cases once we figure out when it
// is exactly profitable.
9583 SDValue BinOpLHSVal = LHS.getOperand(i: 0);
9584 bool IsShiftByConstant = (BinOpLHSVal.getOpcode() == ISD::SHL ||
9585 BinOpLHSVal.getOpcode() == ISD::SRA ||
9586 BinOpLHSVal.getOpcode() == ISD::SRL) &&
9587 isa<ConstantSDNode>(Val: BinOpLHSVal.getOperand(i: 1));
9588 bool IsCopyOrSelect = BinOpLHSVal.getOpcode() == ISD::CopyFromReg ||
9589 BinOpLHSVal.getOpcode() == ISD::SELECT;
9590
9591 if (!IsShiftByConstant && !IsCopyOrSelect)
9592 return SDValue();
9593
9594 if (IsCopyOrSelect && N->hasOneUse())
9595 return SDValue();
9596
9597 // Attempt to fold the constants, shifting the binop RHS by the shift amount.
9598 SDLoc DL(N);
9599 EVT VT = N->getValueType(ResNo: 0);
9600 if (SDValue NewRHS = DAG.FoldConstantArithmetic(
9601 Opcode: N->getOpcode(), DL, VT, Ops: {LHS.getOperand(i: 1), N->getOperand(Num: 1)})) {
9602 SDValue NewShift = DAG.getNode(Opcode: N->getOpcode(), DL, VT, N1: LHS.getOperand(i: 0),
9603 N2: N->getOperand(Num: 1));
9604 return DAG.getNode(Opcode: LHS.getOpcode(), DL, VT, N1: NewShift, N2: NewRHS);
9605 }
9606
9607 return SDValue();
9608}
9609
9610SDValue DAGCombiner::distributeTruncateThroughAnd(SDNode *N) {
9611 assert(N->getOpcode() == ISD::TRUNCATE);
9612 assert(N->getOperand(0).getOpcode() == ISD::AND);
9613
9614 // (truncate:TruncVT (and N00, N01C)) -> (and (truncate:TruncVT N00), TruncC)
9615 EVT TruncVT = N->getValueType(ResNo: 0);
9616 if (N->hasOneUse() && N->getOperand(Num: 0).hasOneUse() &&
9617 TLI.isTypeDesirableForOp(ISD::AND, VT: TruncVT)) {
9618 SDValue N01 = N->getOperand(Num: 0).getOperand(i: 1);
9619 if (isConstantOrConstantVector(N: N01, /* NoOpaques */ true)) {
9620 SDLoc DL(N);
9621 SDValue N00 = N->getOperand(Num: 0).getOperand(i: 0);
9622 SDValue Trunc00 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT, Operand: N00);
9623 SDValue Trunc01 = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT, Operand: N01);
9624 AddToWorklist(N: Trunc00.getNode());
9625 AddToWorklist(N: Trunc01.getNode());
9626 return DAG.getNode(Opcode: ISD::AND, DL, VT: TruncVT, N1: Trunc00, N2: Trunc01);
9627 }
9628 }
9629
9630 return SDValue();
9631}
9632
9633SDValue DAGCombiner::visitRotate(SDNode *N) {
9634 SDLoc dl(N);
9635 SDValue N0 = N->getOperand(Num: 0);
9636 SDValue N1 = N->getOperand(Num: 1);
9637 EVT VT = N->getValueType(ResNo: 0);
9638 unsigned Bitsize = VT.getScalarSizeInBits();
9639
9640 // fold (rot x, 0) -> x
9641 if (isNullOrNullSplat(V: N1))
9642 return N0;
9643
9644 // fold (rot x, c) -> x iff (c % BitSize) == 0
9645 if (isPowerOf2_32(Value: Bitsize) && Bitsize > 1) {
9646 APInt ModuloMask(N1.getScalarValueSizeInBits(), Bitsize - 1);
9647 if (DAG.MaskedValueIsZero(Op: N1, Mask: ModuloMask))
9648 return N0;
9649 }
9650
9651 // fold (rot x, c) -> (rot x, c % BitSize)
9652 bool OutOfRange = false;
9653 auto MatchOutOfRange = [Bitsize, &OutOfRange](ConstantSDNode *C) {
9654 OutOfRange |= C->getAPIntValue().uge(RHS: Bitsize);
9655 return true;
9656 };
9657 if (ISD::matchUnaryPredicate(Op: N1, Match: MatchOutOfRange) && OutOfRange) {
9658 EVT AmtVT = N1.getValueType();
9659 SDValue Bits = DAG.getConstant(Val: Bitsize, DL: dl, VT: AmtVT);
9660 if (SDValue Amt =
9661 DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: AmtVT, Ops: {N1, Bits}))
9662 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0, N2: Amt);
9663 }
9664
9665 // rot i16 X, 8 --> bswap X
9666 auto *RotAmtC = isConstOrConstSplat(N: N1);
9667 if (RotAmtC && RotAmtC->getAPIntValue() == 8 &&
9668 VT.getScalarSizeInBits() == 16 && hasOperation(Opcode: ISD::BSWAP, VT))
9669 return DAG.getNode(Opcode: ISD::BSWAP, DL: dl, VT, Operand: N0);
9670
9671 // Simplify the operands using demanded-bits information.
9672 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
9673 return SDValue(N, 0);
9674
9675 // fold (rot* x, (trunc (and y, c))) -> (rot* x, (and (trunc y), (trunc c))).
9676 if (N1.getOpcode() == ISD::TRUNCATE &&
9677 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
9678 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
9679 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0, N2: NewOp1);
9680 }
9681
9682 unsigned NextOp = N0.getOpcode();
9683
9684 // fold (rot* (rot* x, c2), c1)
9685 // -> (rot* x, ((c1 % bitsize) +- (c2 % bitsize) + bitsize) % bitsize)
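// e.g. for i8: (rotl (rotr x, 3), 5) --> (rotl x, (5 - 3 + 8) % 8) == (rotl x, 2)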
9686 if (NextOp == ISD::ROTL || NextOp == ISD::ROTR) {
9687 SDNode *C1 = DAG.isConstantIntBuildVectorOrConstantInt(N: N1);
9688 SDNode *C2 = DAG.isConstantIntBuildVectorOrConstantInt(N: N0.getOperand(i: 1));
9689 if (C1 && C2 && C1->getValueType(ResNo: 0) == C2->getValueType(ResNo: 0)) {
9690 EVT ShiftVT = C1->getValueType(ResNo: 0);
9691 bool SameSide = (N->getOpcode() == NextOp);
9692 unsigned CombineOp = SameSide ? ISD::ADD : ISD::SUB;
9693 SDValue BitsizeC = DAG.getConstant(Val: Bitsize, DL: dl, VT: ShiftVT);
9694 SDValue Norm1 = DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: ShiftVT,
9695 Ops: {N1, BitsizeC});
9696 SDValue Norm2 = DAG.FoldConstantArithmetic(Opcode: ISD::UREM, DL: dl, VT: ShiftVT,
9697 Ops: {N0.getOperand(i: 1), BitsizeC});
9698 if (Norm1 && Norm2)
9699 if (SDValue CombinedShift = DAG.FoldConstantArithmetic(
9700 Opcode: CombineOp, DL: dl, VT: ShiftVT, Ops: {Norm1, Norm2})) {
9701 CombinedShift = DAG.FoldConstantArithmetic(Opcode: ISD::ADD, DL: dl, VT: ShiftVT,
9702 Ops: {CombinedShift, BitsizeC});
9703 SDValue CombinedShiftNorm = DAG.FoldConstantArithmetic(
9704 Opcode: ISD::UREM, DL: dl, VT: ShiftVT, Ops: {CombinedShift, BitsizeC});
9705 return DAG.getNode(Opcode: N->getOpcode(), DL: dl, VT, N1: N0->getOperand(Num: 0),
9706 N2: CombinedShiftNorm);
9707 }
9708 }
9709 }
9710 return SDValue();
9711}
9712
9713SDValue DAGCombiner::visitSHL(SDNode *N) {
9714 SDValue N0 = N->getOperand(Num: 0);
9715 SDValue N1 = N->getOperand(Num: 1);
9716 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
9717 return V;
9718
9719 EVT VT = N0.getValueType();
9720 EVT ShiftVT = N1.getValueType();
9721 unsigned OpSizeInBits = VT.getScalarSizeInBits();
9722
9723 // fold (shl c1, c2) -> c1<<c2
9724 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N), VT, Ops: {N0, N1}))
9725 return C;
9726
9727 // fold vector ops
9728 if (VT.isVector()) {
9729 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL: SDLoc(N)))
9730 return FoldedVOp;
9731
9732 BuildVectorSDNode *N1CV = dyn_cast<BuildVectorSDNode>(Val&: N1);
// If setcc produces an all-ones true value then:
// (shl (and (setcc) N01CV) N1CV) -> (and (setcc) N01CV<<N1CV)
9735 if (N1CV && N1CV->isConstant()) {
9736 if (N0.getOpcode() == ISD::AND) {
9737 SDValue N00 = N0->getOperand(Num: 0);
9738 SDValue N01 = N0->getOperand(Num: 1);
9739 BuildVectorSDNode *N01CV = dyn_cast<BuildVectorSDNode>(Val&: N01);
9740
9741 if (N01CV && N01CV->isConstant() && N00.getOpcode() == ISD::SETCC &&
9742 TLI.getBooleanContents(Type: N00.getOperand(i: 0).getValueType()) ==
9743 TargetLowering::ZeroOrNegativeOneBooleanContent) {
9744 if (SDValue C =
9745 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N), VT, Ops: {N01, N1}))
9746 return DAG.getNode(Opcode: ISD::AND, DL: SDLoc(N), VT, N1: N00, N2: C);
9747 }
9748 }
9749 }
9750 }
9751
9752 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
9753 return NewSel;
9754
9755 // if (shl x, c) is known to be zero, return 0
9756 if (DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: OpSizeInBits)))
9757 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
9758
9759 // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), (trunc c))).
9760 if (N1.getOpcode() == ISD::TRUNCATE &&
9761 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
9762 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
9763 return DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: N0, N2: NewOp1);
9764 }
9765
9766 // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
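// e.g. for i8: (shl (shl x, 3), 4) --> (shl x, 7), but (shl (shl x, 5), 4) --> 0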
9767 if (N0.getOpcode() == ISD::SHL) {
9768 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
9769 ConstantSDNode *RHS) {
9770 APInt c1 = LHS->getAPIntValue();
9771 APInt c2 = RHS->getAPIntValue();
9772 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
9773 return (c1 + c2).uge(RHS: OpSizeInBits);
9774 };
9775 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchOutOfRange))
9776 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
9777
9778 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
9779 ConstantSDNode *RHS) {
9780 APInt c1 = LHS->getAPIntValue();
9781 APInt c2 = RHS->getAPIntValue();
9782 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
9783 return (c1 + c2).ult(RHS: OpSizeInBits);
9784 };
9785 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchInRange)) {
9786 SDLoc DL(N);
9787 SDValue Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1, N2: N0.getOperand(i: 1));
9788 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Sum);
9789 }
9790 }
9791
9792 // fold (shl (ext (shl x, c1)), c2) -> (shl (ext x), (add c1, c2))
9793 // For this to be valid, the second form must not preserve any of the bits
9794 // that are shifted out by the inner shift in the first form. This means
9795 // the outer shift size must be >= the number of bits added by the ext.
9796 // As a corollary, we don't care what kind of ext it is.
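// e.g. (shl (zext i16 (shl x, 4) to i32), 20) --> (shl (zext i16 x to i32), 24)
// is fine because the 4 bits discarded by the inner shift would have been
// shifted past bit 31 by the outer shift anyway.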
9797 if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
9798 N0.getOpcode() == ISD::ANY_EXTEND ||
9799 N0.getOpcode() == ISD::SIGN_EXTEND) &&
9800 N0.getOperand(i: 0).getOpcode() == ISD::SHL) {
9801 SDValue N0Op0 = N0.getOperand(i: 0);
9802 SDValue InnerShiftAmt = N0Op0.getOperand(i: 1);
9803 EVT InnerVT = N0Op0.getValueType();
9804 uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
9805
9806 auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9807 ConstantSDNode *RHS) {
9808 APInt c1 = LHS->getAPIntValue();
9809 APInt c2 = RHS->getAPIntValue();
9810 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
9811 return c2.uge(RHS: OpSizeInBits - InnerBitwidth) &&
9812 (c1 + c2).uge(RHS: OpSizeInBits);
9813 };
9814 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchOutOfRange,
9815 /*AllowUndefs*/ false,
9816 /*AllowTypeMismatch*/ true))
9817 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
9818
9819 auto MatchInRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
9820 ConstantSDNode *RHS) {
9821 APInt c1 = LHS->getAPIntValue();
9822 APInt c2 = RHS->getAPIntValue();
9823 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
9824 return c2.uge(RHS: OpSizeInBits - InnerBitwidth) &&
9825 (c1 + c2).ult(RHS: OpSizeInBits);
9826 };
9827 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchInRange,
9828 /*AllowUndefs*/ false,
9829 /*AllowTypeMismatch*/ true)) {
9830 SDLoc DL(N);
9831 SDValue Ext = DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0Op0.getOperand(i: 0));
9832 SDValue Sum = DAG.getZExtOrTrunc(Op: InnerShiftAmt, DL, VT: ShiftVT);
9833 Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1: Sum, N2: N1);
9834 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Ext, N2: Sum);
9835 }
9836 }
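  // A concrete illustration, extending from i8 to i32:
  //   (shl (zext (shl i8 x, 3)), 24) -> (shl (zext i8 x), 27)  (24 >= 32 - 8, 3 + 24 < 32)
  //   (shl (zext (shl i8 x, 6)), 26) -> 0                      (6 + 26 >= 32)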
9837
9838 // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
9839 // Only fold this if the inner zext has no other uses to avoid increasing
9840 // the total number of instructions.
9841 if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
9842 N0.getOperand(i: 0).getOpcode() == ISD::SRL) {
9843 SDValue N0Op0 = N0.getOperand(i: 0);
9844 SDValue InnerShiftAmt = N0Op0.getOperand(i: 1);
9845
9846 auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
9847 APInt c1 = LHS->getAPIntValue();
9848 APInt c2 = RHS->getAPIntValue();
9849 zeroExtendToMatch(LHS&: c1, RHS&: c2);
9850 return c1.ult(RHS: VT.getScalarSizeInBits()) && (c1 == c2);
9851 };
9852 if (ISD::matchBinaryPredicate(LHS: InnerShiftAmt, RHS: N1, Match: MatchEqual,
9853 /*AllowUndefs*/ false,
9854 /*AllowTypeMismatch*/ true)) {
9855 SDLoc DL(N);
9856 EVT InnerShiftAmtVT = N0Op0.getOperand(i: 1).getValueType();
9857 SDValue NewSHL = DAG.getZExtOrTrunc(Op: N1, DL, VT: InnerShiftAmtVT);
9858 NewSHL = DAG.getNode(Opcode: ISD::SHL, DL, VT: N0Op0.getValueType(), N1: N0Op0, N2: NewSHL);
9859 AddToWorklist(N: NewSHL.getNode());
9860 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N0), VT, Operand: NewSHL);
9861 }
9862 }
9863
9864 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) {
9865 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
9866 ConstantSDNode *RHS) {
9867 const APInt &LHSC = LHS->getAPIntValue();
9868 const APInt &RHSC = RHS->getAPIntValue();
9869 return LHSC.ult(RHS: OpSizeInBits) && RHSC.ult(RHS: OpSizeInBits) &&
9870 LHSC.getZExtValue() <= RHSC.getZExtValue();
9871 };
9872
9873 SDLoc DL(N);
9874
9875 // fold (shl (sr[la] exact X, C1), C2) -> (shl X, (C2-C1)) if C1 <= C2
9876 // fold (shl (sr[la] exact X, C1), C2) -> (sr[la] X, (C2-C1)) if C1 >= C2
9877 if (N0->getFlags().hasExact()) {
9878 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
9879 /*AllowUndefs*/ false,
9880 /*AllowTypeMismatch*/ true)) {
9881 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
9882 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
9883 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
9884 }
9885 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
9886 /*AllowUndefs*/ false,
9887 /*AllowTypeMismatch*/ true)) {
9888 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
9889 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
9890 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
9891 }
9892 }
9893
9894 // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1)), MASK) or
9895 // (and (srl x, (sub c1, c2)), MASK)
9896 // Only fold this if the inner shift has no other uses -- if it does,
9897 // folding this will increase the total number of instructions.
9898 if (N0.getOpcode() == ISD::SRL &&
9899 (N0.getOperand(i: 1) == N1 || N0.hasOneUse()) &&
9900 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
9901 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
9902 /*AllowUndefs*/ false,
9903 /*AllowTypeMismatch*/ true)) {
9904 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
9905 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
9906 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9907 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: N01);
9908 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: Diff);
9909 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
9910 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
9911 }
9912 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
9913 /*AllowUndefs*/ false,
9914 /*AllowTypeMismatch*/ true)) {
9915 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
9916 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
9917 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
9918 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: N1);
9919 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
9920 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
9921 }
9922 }
9923 }
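  // Concrete illustrations of the exact and mask folds above, for i8:
  //   (shl (srl exact x, 2), 5) -> (shl x, 3)
  //   (shl (srl exact x, 5), 2) -> (srl x, 3)
  //   (shl (srl x, 2), 5)       -> (and (shl x, 3), 0xE0)   (without the exact flag)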
9924
9925 // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
9926 if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(i: 1) &&
9927 isConstantOrConstantVector(N: N1, /* No Opaques */ NoOpaques: true)) {
9928 SDLoc DL(N);
9929 SDValue AllBits = DAG.getAllOnesConstant(DL, VT);
9930 SDValue HiBitsMask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: AllBits, N2: N1);
9931 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: N0.getOperand(i: 0), N2: HiBitsMask);
9932 }
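  // For example, for i8:
  //   (shl (sra x, 3), 3) -> (and x, 0xF8)   i.e. (and x, (shl -1, 3))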
9933
9934 // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
9935 // fold (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
9936 // This is a variant of the fold done on multiply, except that a mul by a power
9937 // of 2 is turned into a shift.
9938 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
9939 N0->hasOneUse() && TLI.isDesirableToCommuteWithShift(N, Level)) {
9940 SDValue N01 = N0.getOperand(i: 1);
9941 if (SDValue Shl1 =
9942 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N1), VT, Ops: {N01, N1})) {
9943 SDValue Shl0 = DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N0), VT, N1: N0.getOperand(i: 0), N2: N1);
9944 AddToWorklist(N: Shl0.getNode());
9945 SDNodeFlags Flags;
9946 // Preserve the disjoint flag for Or.
9947 if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
9948 Flags.setDisjoint(true);
9949 return DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N), VT, N1: Shl0, N2: Shl1, Flags);
9950 }
9951 }
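  // For example:
  //   (shl (add x, 3), 2) -> (add (shl x, 2), 12)
  //   (shl (or  x, 3), 2) -> (or  (shl x, 2), 12)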
9952
9953 // fold (shl (sext (add_nsw x, c1)), c2) -> (add (shl (sext x), c2), c1 << c2)
9954 // TODO: Add zext/add_nuw variant with suitable test coverage
9955 // TODO: Should we limit this with isLegalAddImmediate?
9956 if (N0.getOpcode() == ISD::SIGN_EXTEND &&
9957 N0.getOperand(i: 0).getOpcode() == ISD::ADD &&
9958 N0.getOperand(i: 0)->getFlags().hasNoSignedWrap() && N0->hasOneUse() &&
9959 N0.getOperand(i: 0)->hasOneUse() &&
9960 TLI.isDesirableToCommuteWithShift(N, Level)) {
9961 SDValue Add = N0.getOperand(i: 0);
9962 SDLoc DL(N0);
9963 if (SDValue ExtC = DAG.FoldConstantArithmetic(Opcode: N0.getOpcode(), DL, VT,
9964 Ops: {Add.getOperand(i: 1)})) {
9965 if (SDValue ShlC =
9966 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL, VT, Ops: {ExtC, N1})) {
9967 SDValue ExtX = DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: Add.getOperand(i: 0));
9968 SDValue ShlX = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: ExtX, N2: N1);
9969 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: ShlX, N2: ShlC);
9970 }
9971 }
9972 }
9973
9974 // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
9975 if (N0.getOpcode() == ISD::MUL && N0->hasOneUse()) {
9976 SDValue N01 = N0.getOperand(i: 1);
9977 if (SDValue Shl =
9978 DAG.FoldConstantArithmetic(Opcode: ISD::SHL, DL: SDLoc(N1), VT, Ops: {N01, N1}))
9979 return DAG.getNode(Opcode: ISD::MUL, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0), N2: Shl);
9980 }
9981
9982 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
9983 if (N1C && !N1C->isOpaque())
9984 if (SDValue NewSHL = visitShiftByConstant(N))
9985 return NewSHL;
9986
9987 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
9988 return SDValue(N, 0);
9989
9990 // Fold (shl (vscale * C0), C1) to (vscale * (C0 << C1)).
9991 if (N0.getOpcode() == ISD::VSCALE && N1C) {
9992 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
9993 const APInt &C1 = N1C->getAPIntValue();
9994 return DAG.getVScale(DL: SDLoc(N), VT, MulImm: C0 << C1);
9995 }
9996
9997 // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
9998 APInt ShlVal;
9999 if (N0.getOpcode() == ISD::STEP_VECTOR &&
10000 ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: ShlVal)) {
10001 const APInt &C0 = N0.getConstantOperandAPInt(i: 0);
10002 if (ShlVal.ult(RHS: C0.getBitWidth())) {
10003 APInt NewStep = C0 << ShlVal;
10004 return DAG.getStepVector(DL: SDLoc(N), ResVT: VT, StepVal: NewStep);
10005 }
10006 }
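  // For example:
  //   (shl (vscale * 4), 1)    -> (vscale * 8)
  //   (shl (step_vector 3), 1) -> (step_vector 6)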
10007
10008 return SDValue();
10009}
10010
10011// Transform a right shift of a multiply into a multiply-high.
10012// Examples:
10013 // (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
10014 // (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
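// A constant right-hand operand that fits in the narrow type is handled as well,
// e.g. (srl (mul (zext i32:$a to i64), 42), 32) -> (mulhu $a, 42).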
10015static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG,
10016 const TargetLowering &TLI) {
10017 assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
10018 "SRL or SRA node is required here!");
10019
10020 // Check the shift amount. Proceed with the transformation if the shift
10021 // amount is constant.
10022 ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N: N->getOperand(Num: 1));
10023 if (!ShiftAmtSrc)
10024 return SDValue();
10025
10026 SDLoc DL(N);
10027
10028 // The operation feeding into the shift must be a multiply.
10029 SDValue ShiftOperand = N->getOperand(Num: 0);
10030 if (ShiftOperand.getOpcode() != ISD::MUL)
10031 return SDValue();
10032
10033 // Both operands must be equivalent extend nodes.
10034 SDValue LeftOp = ShiftOperand.getOperand(i: 0);
10035 SDValue RightOp = ShiftOperand.getOperand(i: 1);
10036
10037 bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
10038 bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
10039
10040 if (!IsSignExt && !IsZeroExt)
10041 return SDValue();
10042
10043 EVT NarrowVT = LeftOp.getOperand(i: 0).getValueType();
10044 unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
10045
10046 // return true if U may use the lower bits of its operands
10047 auto UserOfLowerBits = [NarrowVTSize](SDNode *U) {
10048 if (U->getOpcode() != ISD::SRL && U->getOpcode() != ISD::SRA) {
10049 return true;
10050 }
10051 ConstantSDNode *UShiftAmtSrc = isConstOrConstSplat(N: U->getOperand(Num: 1));
10052 if (!UShiftAmtSrc) {
10053 return true;
10054 }
10055 unsigned UShiftAmt = UShiftAmtSrc->getZExtValue();
10056 return UShiftAmt < NarrowVTSize;
10057 };
10058
10059 // If the lower part of the MUL is also used and MUL_LOHI is supported,
10060 // do not introduce the MULH; prefer MUL_LOHI instead.
10061 unsigned MulLoHiOp = IsSignExt ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
10062 if (!ShiftOperand.hasOneUse() &&
10063 TLI.isOperationLegalOrCustom(Op: MulLoHiOp, VT: NarrowVT) &&
10064 llvm::any_of(Range: ShiftOperand->uses(), P: UserOfLowerBits)) {
10065 return SDValue();
10066 }
10067
10068 SDValue MulhRightOp;
10069 if (ConstantSDNode *Constant = isConstOrConstSplat(N: RightOp)) {
10070 unsigned ActiveBits = IsSignExt
10071 ? Constant->getAPIntValue().getSignificantBits()
10072 : Constant->getAPIntValue().getActiveBits();
10073 if (ActiveBits > NarrowVTSize)
10074 return SDValue();
10075 MulhRightOp = DAG.getConstant(
10076 Val: Constant->getAPIntValue().trunc(width: NarrowVT.getScalarSizeInBits()), DL,
10077 VT: NarrowVT);
10078 } else {
10079 if (LeftOp.getOpcode() != RightOp.getOpcode())
10080 return SDValue();
10081 // Check that the two extend nodes are the same type.
10082 if (NarrowVT != RightOp.getOperand(i: 0).getValueType())
10083 return SDValue();
10084 MulhRightOp = RightOp.getOperand(i: 0);
10085 }
10086
10087 EVT WideVT = LeftOp.getValueType();
10088 // Proceed with the transformation if the wide types match.
10089 assert((WideVT == RightOp.getValueType()) &&
10090 "Cannot have a multiply node with two different operand types.");
10091
10092 // Proceed with the transformation if the wide type is twice as large
10093 // as the narrow type.
10094 if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
10095 return SDValue();
10096
10097 // Check the shift amount with the narrow type size.
10098 // Proceed with the transformation if the shift amount is the width
10099 // of the narrow type.
10100 unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
10101 if (ShiftAmt != NarrowVTSize)
10102 return SDValue();
10103
10104 // If the operation feeding into the MUL is a sign extend (sext),
10105 // we use mulhs. Otherwise, zero extends (zext) use mulhu.
10106 unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
10107
10108 // Combine to mulh if mulh is legal/custom for the narrow type on the target,
10109 // or, if it is a vector type, if we can transform it to an acceptable type and
10110 // rely on legalization to split/combine the result.
10111 if (NarrowVT.isVector()) {
10112 EVT TransformVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: NarrowVT);
10113 if (TransformVT.getVectorElementType() != NarrowVT.getVectorElementType() ||
10114 !TLI.isOperationLegalOrCustom(Op: MulhOpcode, VT: TransformVT))
10115 return SDValue();
10116 } else {
10117 if (!TLI.isOperationLegalOrCustom(Op: MulhOpcode, VT: NarrowVT))
10118 return SDValue();
10119 }
10120
10121 SDValue Result =
10122 DAG.getNode(Opcode: MulhOpcode, DL, VT: NarrowVT, N1: LeftOp.getOperand(i: 0), N2: MulhRightOp);
10123 bool IsSigned = N->getOpcode() == ISD::SRA;
10124 return DAG.getExtOrTrunc(IsSigned, Op: Result, DL, VT: WideVT);
10125}
10126
10127// fold (bswap (logic_op(bswap(x),y))) -> logic_op(x,bswap(y))
10128 // This helper function accepts SDNodes with opcode ISD::BSWAP or ISD::BITREVERSE.
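// For instance, (bswap (xor (bswap x), y)) -> (xor x, (bswap y)).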
10129static SDValue foldBitOrderCrossLogicOp(SDNode *N, SelectionDAG &DAG) {
10130 unsigned Opcode = N->getOpcode();
10131 if (Opcode != ISD::BSWAP && Opcode != ISD::BITREVERSE)
10132 return SDValue();
10133
10134 SDValue N0 = N->getOperand(Num: 0);
10135 EVT VT = N->getValueType(ResNo: 0);
10136 SDLoc DL(N);
10137 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) && N0.hasOneUse()) {
10138 SDValue OldLHS = N0.getOperand(i: 0);
10139 SDValue OldRHS = N0.getOperand(i: 1);
10140
10141 // If both operands are bswap/bitreverse, ignore the multiuse restriction.
10142 // Otherwise, we need to ensure the logic_op and bswap/bitreverse(x) each have one use.
10143 if (OldLHS.getOpcode() == Opcode && OldRHS.getOpcode() == Opcode) {
10144 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: OldLHS.getOperand(i: 0),
10145 N2: OldRHS.getOperand(i: 0));
10146 }
10147
10148 if (OldLHS.getOpcode() == Opcode && OldLHS.hasOneUse()) {
10149 SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, Operand: OldRHS);
10150 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: OldLHS.getOperand(i: 0),
10151 N2: NewBitReorder);
10152 }
10153
10154 if (OldRHS.getOpcode() == Opcode && OldRHS.hasOneUse()) {
10155 SDValue NewBitReorder = DAG.getNode(Opcode, DL, VT, Operand: OldLHS);
10156 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: NewBitReorder,
10157 N2: OldRHS.getOperand(i: 0));
10158 }
10159 }
10160 return SDValue();
10161}
10162
10163SDValue DAGCombiner::visitSRA(SDNode *N) {
10164 SDValue N0 = N->getOperand(Num: 0);
10165 SDValue N1 = N->getOperand(Num: 1);
10166 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
10167 return V;
10168
10169 EVT VT = N0.getValueType();
10170 unsigned OpSizeInBits = VT.getScalarSizeInBits();
10171
10172 // fold (sra c1, c2) -> c1 >>s c2
10173 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SRA, DL: SDLoc(N), VT, Ops: {N0, N1}))
10174 return C;
10175
10176 // Arithmetic shifting an all-sign-bit value is a no-op.
10177 // fold (sra 0, x) -> 0
10178 // fold (sra -1, x) -> -1
10179 if (DAG.ComputeNumSignBits(Op: N0) == OpSizeInBits)
10180 return N0;
10181
10182 // fold vector ops
10183 if (VT.isVector())
10184 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL: SDLoc(N)))
10185 return FoldedVOp;
10186
10187 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
10188 return NewSel;
10189
10190 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10191
10192 // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
10193 // clamp (add c1, c2) to the maximum shift amount (OpSizeInBits - 1).
10194 if (N0.getOpcode() == ISD::SRA) {
10195 SDLoc DL(N);
10196 EVT ShiftVT = N1.getValueType();
10197 EVT ShiftSVT = ShiftVT.getScalarType();
10198 SmallVector<SDValue, 16> ShiftValues;
10199
10200 auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
10201 APInt c1 = LHS->getAPIntValue();
10202 APInt c2 = RHS->getAPIntValue();
10203 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10204 APInt Sum = c1 + c2;
10205 unsigned ShiftSum =
10206 Sum.uge(RHS: OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
10207 ShiftValues.push_back(Elt: DAG.getConstant(Val: ShiftSum, DL, VT: ShiftSVT));
10208 return true;
10209 };
10210 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: SumOfShifts)) {
10211 SDValue ShiftValue;
10212 if (N1.getOpcode() == ISD::BUILD_VECTOR)
10213 ShiftValue = DAG.getBuildVector(VT: ShiftVT, DL, Ops: ShiftValues);
10214 else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
10215 assert(ShiftValues.size() == 1 &&
10216 "Expected matchBinaryPredicate to return one element for "
10217 "SPLAT_VECTORs");
10218 ShiftValue = DAG.getSplatVector(VT: ShiftVT, DL, Op: ShiftValues[0]);
10219 } else
10220 ShiftValue = ShiftValues[0];
10221 return DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: N0.getOperand(i: 0), N2: ShiftValue);
10222 }
10223 }
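  // A concrete illustration, for i8:
  //   (sra (sra x, 3), 2) -> (sra x, 5)
  //   (sra (sra x, 6), 7) -> (sra x, 7)   (13 clamps to the max shift of 7)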
10224
10225 // fold (sra (shl X, m), (sub result_size, n))
10226 //   -> (sign_extend (trunc (srl X, (sub (sub result_size, n), m)))) for
10227 // result_size - n > m.
10228 // If truncate is free for the target, the sext form is likely to result in
10229 // better code.
10230 if (N0.getOpcode() == ISD::SHL && N1C) {
10231 // Get the two constants of the shifts, CN0 = m, CN = n.
10232 const ConstantSDNode *N01C = isConstOrConstSplat(N: N0.getOperand(i: 1));
10233 if (N01C) {
10234 LLVMContext &Ctx = *DAG.getContext();
10235 // Determine what the truncate's result bitsize and type would be.
10236 EVT TruncVT = EVT::getIntegerVT(Context&: Ctx, BitWidth: OpSizeInBits - N1C->getZExtValue());
10237
10238 if (VT.isVector())
10239 TruncVT = EVT::getVectorVT(Context&: Ctx, VT: TruncVT, EC: VT.getVectorElementCount());
10240
10241 // Determine the residual right-shift amount.
10242 int ShiftAmt = N1C->getZExtValue() - N01C->getZExtValue();
10243
10244 // If the shift is not a no-op (in which case this should be just a sign
10245 // extend already), the type we truncate to is legal, sign_extend is legal
10246 // on that type, and the truncate to that type is both legal and free,
10247 // perform the transform.
10248 if ((ShiftAmt > 0) &&
10249 TLI.isOperationLegalOrCustom(Op: ISD::SIGN_EXTEND, VT: TruncVT) &&
10250 TLI.isOperationLegalOrCustom(Op: ISD::TRUNCATE, VT) &&
10251 TLI.isTruncateFree(FromVT: VT, ToVT: TruncVT)) {
10252 SDLoc DL(N);
10253 SDValue Amt = DAG.getConstant(Val: ShiftAmt, DL,
10254 VT: getShiftAmountTy(LHSTy: N0.getOperand(i: 0).getValueType()));
10255 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT,
10256 N1: N0.getOperand(i: 0), N2: Amt);
10257 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: TruncVT,
10258 Operand: Shift);
10259 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL,
10260 VT: N->getValueType(ResNo: 0), Operand: Trunc);
10261 }
10262 }
10263 }
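  // A concrete illustration for i32 with m = 8 and a shift amount of 24
  // (so the truncated type is i8):
  //   (sra (shl X, 8), 24) -> (sign_extend (trunc i8 (srl X, 16)))
  // assuming sign_extend from i8 and truncate to i8 are legal and the truncate
  // is free.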
10264
10265 // We convert trunc/ext to opposing shifts in IR, but casts may be cheaper.
10266 // sra (add (shl X, N1C), AddC), N1C -->
10267 // sext (add (trunc X to (width - N1C)), AddC')
10268 // sra (sub AddC, (shl X, N1C)), N1C -->
10269 // sext (sub AddC1',(trunc X to (width - N1C)))
10270 if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) && N1C &&
10271 N0.hasOneUse()) {
10272 bool IsAdd = N0.getOpcode() == ISD::ADD;
10273 SDValue Shl = N0.getOperand(i: IsAdd ? 0 : 1);
10274 if (Shl.getOpcode() == ISD::SHL && Shl.getOperand(i: 1) == N1 &&
10275 Shl.hasOneUse()) {
10276 // TODO: AddC does not need to be a splat.
10277 if (ConstantSDNode *AddC =
10278 isConstOrConstSplat(N: N0.getOperand(i: IsAdd ? 1 : 0))) {
10279 // Determine what the truncate's type would be and ask the target if
10280 // that is a free operation.
10281 LLVMContext &Ctx = *DAG.getContext();
10282 unsigned ShiftAmt = N1C->getZExtValue();
10283 EVT TruncVT = EVT::getIntegerVT(Context&: Ctx, BitWidth: OpSizeInBits - ShiftAmt);
10284 if (VT.isVector())
10285 TruncVT = EVT::getVectorVT(Context&: Ctx, VT: TruncVT, EC: VT.getVectorElementCount());
10286
10287 // TODO: The simple type check probably belongs in the default hook
10288 // implementation and/or target-specific overrides (because
10289 // non-simple types likely require masking when legalized), but
10290 // that restriction may conflict with other transforms.
10291 if (TruncVT.isSimple() && isTypeLegal(VT: TruncVT) &&
10292 TLI.isTruncateFree(FromVT: VT, ToVT: TruncVT)) {
10293 SDLoc DL(N);
10294 SDValue Trunc = DAG.getZExtOrTrunc(Op: Shl.getOperand(i: 0), DL, VT: TruncVT);
10295 SDValue ShiftC =
10296 DAG.getConstant(Val: AddC->getAPIntValue().lshr(shiftAmt: ShiftAmt).trunc(
10297 width: TruncVT.getScalarSizeInBits()),
10298 DL, VT: TruncVT);
10299 SDValue Add;
10300 if (IsAdd)
10301 Add = DAG.getNode(Opcode: ISD::ADD, DL, VT: TruncVT, N1: Trunc, N2: ShiftC);
10302 else
10303 Add = DAG.getNode(Opcode: ISD::SUB, DL, VT: TruncVT, N1: ShiftC, N2: Trunc);
10304 return DAG.getSExtOrTrunc(Op: Add, DL, VT);
10305 }
10306 }
10307 }
10308 }
10309
10310 // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), (trunc c))).
10311 if (N1.getOpcode() == ISD::TRUNCATE &&
10312 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
10313 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
10314 return DAG.getNode(Opcode: ISD::SRA, DL: SDLoc(N), VT, N1: N0, N2: NewOp1);
10315 }
10316
10317 // fold (sra (trunc (sra x, c1)), c2) -> (trunc (sra x, c1 + c2))
10318 // fold (sra (trunc (srl x, c1)), c2) -> (trunc (sra x, c1 + c2))
10319 // if c1 is equal to the number of bits the trunc removes
10320 // TODO - support non-uniform vector shift amounts.
10321 if (N0.getOpcode() == ISD::TRUNCATE &&
10322 (N0.getOperand(i: 0).getOpcode() == ISD::SRL ||
10323 N0.getOperand(i: 0).getOpcode() == ISD::SRA) &&
10324 N0.getOperand(i: 0).hasOneUse() &&
10325 N0.getOperand(i: 0).getOperand(i: 1).hasOneUse() && N1C) {
10326 SDValue N0Op0 = N0.getOperand(i: 0);
10327 if (ConstantSDNode *LargeShift = isConstOrConstSplat(N: N0Op0.getOperand(i: 1))) {
10328 EVT LargeVT = N0Op0.getValueType();
10329 unsigned TruncBits = LargeVT.getScalarSizeInBits() - OpSizeInBits;
10330 if (LargeShift->getAPIntValue() == TruncBits) {
10331 SDLoc DL(N);
10332 EVT LargeShiftVT = getShiftAmountTy(LHSTy: LargeVT);
10333 SDValue Amt = DAG.getZExtOrTrunc(Op: N1, DL, VT: LargeShiftVT);
10334 Amt = DAG.getNode(Opcode: ISD::ADD, DL, VT: LargeShiftVT, N1: Amt,
10335 N2: DAG.getConstant(Val: TruncBits, DL, VT: LargeShiftVT));
10336 SDValue SRA =
10337 DAG.getNode(Opcode: ISD::SRA, DL, VT: LargeVT, N1: N0Op0.getOperand(i: 0), N2: Amt);
10338 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: SRA);
10339 }
10340 }
10341 }
10342
10343 // Simplify, based on bits shifted out of the LHS.
10344 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10345 return SDValue(N, 0);
10346
10347 // If the sign bit is known to be zero, switch this to a SRL.
10348 if (DAG.SignBitIsZero(Op: N0))
10349 return DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(N), VT, N1: N0, N2: N1);
10350
10351 if (N1C && !N1C->isOpaque())
10352 if (SDValue NewSRA = visitShiftByConstant(N))
10353 return NewSRA;
10354
10355 // Try to transform this shift into a multiply-high if
10356 // it matches the appropriate pattern detected in combineShiftToMULH.
10357 if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
10358 return MULH;
10359
10360 // Attempt to convert a sra of a load into a narrower sign-extending load.
10361 if (SDValue NarrowLoad = reduceLoadWidth(N))
10362 return NarrowLoad;
10363
10364 return SDValue();
10365}
10366
10367SDValue DAGCombiner::visitSRL(SDNode *N) {
10368 SDValue N0 = N->getOperand(Num: 0);
10369 SDValue N1 = N->getOperand(Num: 1);
10370 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
10371 return V;
10372
10373 EVT VT = N0.getValueType();
10374 EVT ShiftVT = N1.getValueType();
10375 unsigned OpSizeInBits = VT.getScalarSizeInBits();
10376
10377 // fold (srl c1, c2) -> c1 >>u c2
10378 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::SRL, DL: SDLoc(N), VT, Ops: {N0, N1}))
10379 return C;
10380
10381 // fold vector ops
10382 if (VT.isVector())
10383 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL: SDLoc(N)))
10384 return FoldedVOp;
10385
10386 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
10387 return NewSel;
10388
10389 // if (srl x, c) is known to be zero, return 0
10390 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10391 if (N1C &&
10392 DAG.MaskedValueIsZero(Op: SDValue(N, 0), Mask: APInt::getAllOnes(numBits: OpSizeInBits)))
10393 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
10394
10395 // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
10396 if (N0.getOpcode() == ISD::SRL) {
10397 auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
10398 ConstantSDNode *RHS) {
10399 APInt c1 = LHS->getAPIntValue();
10400 APInt c2 = RHS->getAPIntValue();
10401 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10402 return (c1 + c2).uge(RHS: OpSizeInBits);
10403 };
10404 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchOutOfRange))
10405 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
10406
10407 auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
10408 ConstantSDNode *RHS) {
10409 APInt c1 = LHS->getAPIntValue();
10410 APInt c2 = RHS->getAPIntValue();
10411 zeroExtendToMatch(LHS&: c1, RHS&: c2, Offset: 1 /* Overflow Bit */);
10412 return (c1 + c2).ult(RHS: OpSizeInBits);
10413 };
10414 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchInRange)) {
10415 SDLoc DL(N);
10416 SDValue Sum = DAG.getNode(Opcode: ISD::ADD, DL, VT: ShiftVT, N1, N2: N0.getOperand(i: 1));
10417 return DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Sum);
10418 }
10419 }
10420
10421 if (N1C && N0.getOpcode() == ISD::TRUNCATE &&
10422 N0.getOperand(i: 0).getOpcode() == ISD::SRL) {
10423 SDValue InnerShift = N0.getOperand(i: 0);
10424 // TODO - support non-uniform vector shift amounts.
10425 if (auto *N001C = isConstOrConstSplat(N: InnerShift.getOperand(i: 1))) {
10426 uint64_t c1 = N001C->getZExtValue();
10427 uint64_t c2 = N1C->getZExtValue();
10428 EVT InnerShiftVT = InnerShift.getValueType();
10429 EVT ShiftAmtVT = InnerShift.getOperand(i: 1).getValueType();
10430 uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
10431 // srl (trunc (srl x, c1)), c2 --> 0 or (trunc (srl x, (add c1, c2)))
10432 // This is only valid if the OpSizeInBits + c1 = size of inner shift.
10433 if (c1 + OpSizeInBits == InnerShiftSize) {
10434 SDLoc DL(N);
10435 if (c1 + c2 >= InnerShiftSize)
10436 return DAG.getConstant(Val: 0, DL, VT);
10437 SDValue NewShiftAmt = DAG.getConstant(Val: c1 + c2, DL, VT: ShiftAmtVT);
10438 SDValue NewShift = DAG.getNode(Opcode: ISD::SRL, DL, VT: InnerShiftVT,
10439 N1: InnerShift.getOperand(i: 0), N2: NewShiftAmt);
10440 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: NewShift);
10441 }
10442 // In the more general case, we can clear the high bits after the shift:
10443 // srl (trunc (srl x, c1)), c2 --> trunc (and (srl x, (c1+c2)), Mask)
10444 if (N0.hasOneUse() && InnerShift.hasOneUse() &&
10445 c1 + c2 < InnerShiftSize) {
10446 SDLoc DL(N);
10447 SDValue NewShiftAmt = DAG.getConstant(Val: c1 + c2, DL, VT: ShiftAmtVT);
10448 SDValue NewShift = DAG.getNode(Opcode: ISD::SRL, DL, VT: InnerShiftVT,
10449 N1: InnerShift.getOperand(i: 0), N2: NewShiftAmt);
10450 SDValue Mask = DAG.getConstant(Val: APInt::getLowBitsSet(numBits: InnerShiftSize,
10451 loBitsSet: OpSizeInBits - c2),
10452 DL, VT: InnerShiftVT);
10453 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: InnerShiftVT, N1: NewShift, N2: Mask);
10454 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: And);
10455 }
10456 }
10457 }
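  // The two folds above, illustrated for an i64 source truncated to i32:
  //   (srl (trunc (srl i64 x, 32)), 8) -> (trunc (srl i64 x, 40))
  //   (srl (trunc (srl i64 x, 16)), 8) -> (trunc (and (srl i64 x, 24), 0xFFFFFF))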
10458
10459 // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2)), MASK) or
10460 // (and (srl x, (sub c2, c1)), MASK)
10461 if (N0.getOpcode() == ISD::SHL &&
10462 (N0.getOperand(i: 1) == N1 || N0->hasOneUse()) &&
10463 TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
10464 auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
10465 ConstantSDNode *RHS) {
10466 const APInt &LHSC = LHS->getAPIntValue();
10467 const APInt &RHSC = RHS->getAPIntValue();
10468 return LHSC.ult(RHS: OpSizeInBits) && RHSC.ult(RHS: OpSizeInBits) &&
10469 LHSC.getZExtValue() <= RHSC.getZExtValue();
10470 };
10471 if (ISD::matchBinaryPredicate(LHS: N1, RHS: N0.getOperand(i: 1), Match: MatchShiftAmount,
10472 /*AllowUndefs*/ false,
10473 /*AllowTypeMismatch*/ true)) {
10474 SDLoc DL(N);
10475 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10476 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1: N01, N2: N1);
10477 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10478 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: N01);
10479 Mask = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Mask, N2: Diff);
10480 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10481 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10482 }
10483 if (ISD::matchBinaryPredicate(LHS: N0.getOperand(i: 1), RHS: N1, Match: MatchShiftAmount,
10484 /*AllowUndefs*/ false,
10485 /*AllowTypeMismatch*/ true)) {
10486 SDLoc DL(N);
10487 SDValue N01 = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1), DL, VT: ShiftVT);
10488 SDValue Diff = DAG.getNode(Opcode: ISD::SUB, DL, VT: ShiftVT, N1, N2: N01);
10489 SDValue Mask = DAG.getAllOnesConstant(DL, VT);
10490 Mask = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Mask, N2: N1);
10491 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: N0.getOperand(i: 0), N2: Diff);
10492 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shift, N2: Mask);
10493 }
10494 }
10495
10496 // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
10497 // TODO - support non-uniform vector shift amounts.
10498 if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
10499 // Shifting in all undef bits?
10500 EVT SmallVT = N0.getOperand(i: 0).getValueType();
10501 unsigned BitSize = SmallVT.getScalarSizeInBits();
10502 if (N1C->getAPIntValue().uge(RHS: BitSize))
10503 return DAG.getUNDEF(VT);
10504
10505 if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, VT: SmallVT)) {
10506 uint64_t ShiftAmt = N1C->getZExtValue();
10507 SDLoc DL0(N0);
10508 SDValue SmallShift = DAG.getNode(Opcode: ISD::SRL, DL: DL0, VT: SmallVT,
10509 N1: N0.getOperand(i: 0),
10510 N2: DAG.getConstant(Val: ShiftAmt, DL: DL0,
10511 VT: getShiftAmountTy(LHSTy: SmallVT)));
10512 AddToWorklist(N: SmallShift.getNode());
10513 APInt Mask = APInt::getLowBitsSet(numBits: OpSizeInBits, loBitsSet: OpSizeInBits - ShiftAmt);
10514 SDLoc DL(N);
10515 return DAG.getNode(Opcode: ISD::AND, DL, VT,
10516 N1: DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: SmallShift),
10517 N2: DAG.getConstant(Val: Mask, DL, VT));
10518 }
10519 }
10520
10521 // fold (srl (sra X, Y), 31) -> (srl X, 31). This srl only looks at the sign
10522 // bit, which is unmodified by sra.
10523 if (N1C && N1C->getAPIntValue() == (OpSizeInBits - 1)) {
10524 if (N0.getOpcode() == ISD::SRA)
10525 return DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0), N2: N1);
10526 }
10527
10528 // fold (srl (ctlz x), "5") -> x iff x has one bit set (the low bit), and x has a power
10529 // of two bitwidth. The "5" represents (log2 (bitwidth x)).
10530 if (N1C && N0.getOpcode() == ISD::CTLZ &&
10531 isPowerOf2_32(Value: OpSizeInBits) &&
10532 N1C->getAPIntValue() == Log2_32(Value: OpSizeInBits)) {
10533 KnownBits Known = DAG.computeKnownBits(Op: N0.getOperand(i: 0));
10534
10535 // If any of the input bits are KnownOne, then the input couldn't be all
10536 // zeros, thus the result of the srl will always be zero.
10537 if (Known.One.getBoolValue()) return DAG.getConstant(Val: 0, DL: SDLoc(N0), VT);
10538
10539 // If all of the bits input to the ctlz node are known to be zero, then
10540 // the result of the ctlz is "32" and the result of the shift is one.
10541 APInt UnknownBits = ~Known.Zero;
10542 if (UnknownBits == 0) return DAG.getConstant(Val: 1, DL: SDLoc(N0), VT);
10543
10544 // Otherwise, check to see if there is exactly one bit input to the ctlz.
10545 if (UnknownBits.isPowerOf2()) {
10546 // Okay, we know that only the single bit specified by UnknownBits
10547 // could be set on input to the CTLZ node. If this bit is set, the SRL
10548 // will return 0, if it is clear, it returns 1. Change the CTLZ/SRL pair
10549 // to an SRL/XOR pair, which is likely to simplify more.
10550 unsigned ShAmt = UnknownBits.countr_zero();
10551 SDValue Op = N0.getOperand(i: 0);
10552
10553 if (ShAmt) {
10554 SDLoc DL(N0);
10555 Op = DAG.getNode(Opcode: ISD::SRL, DL, VT, N1: Op,
10556 N2: DAG.getConstant(Val: ShAmt, DL,
10557 VT: getShiftAmountTy(LHSTy: Op.getValueType())));
10558 AddToWorklist(N: Op.getNode());
10559 }
10560
10561 SDLoc DL(N);
10562 return DAG.getNode(Opcode: ISD::XOR, DL, VT,
10563 N1: Op, N2: DAG.getConstant(Val: 1, DL, VT));
10564 }
10565 }
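  // For example, for an i32 x known to be either 0 or 8 (only bit 3 unknown):
  //   (srl (ctlz x), 5) -> (xor (srl x, 3), 1)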
10566
10567 // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
10568 if (N1.getOpcode() == ISD::TRUNCATE &&
10569 N1.getOperand(i: 0).getOpcode() == ISD::AND) {
10570 if (SDValue NewOp1 = distributeTruncateThroughAnd(N: N1.getNode()))
10571 return DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(N), VT, N1: N0, N2: NewOp1);
10572 }
10573
10574 // fold operands of srl based on knowledge that the low bits are not
10575 // demanded.
10576 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10577 return SDValue(N, 0);
10578
10579 if (N1C && !N1C->isOpaque())
10580 if (SDValue NewSRL = visitShiftByConstant(N))
10581 return NewSRL;
10582
10583 // Attempt to convert a srl of a load into a narrower zero-extending load.
10584 if (SDValue NarrowLoad = reduceLoadWidth(N))
10585 return NarrowLoad;
10586
10587 // Here is a common situation. We want to optimize:
10588 //
10589 // %a = ...
10590 // %b = and i32 %a, 2
10591 // %c = srl i32 %b, 1
10592 // brcond i32 %c ...
10593 //
10594 // into
10595 //
10596 // %a = ...
10597 // %b = and %a, 2
10598 // %c = setcc eq %b, 0
10599 // brcond %c ...
10600 //
10601 // However, after the source operand of SRL is optimized into AND, the SRL
10602 // itself may not be optimized further. Look for it and add the BRCOND into
10603 // the worklist.
10604 //
10605 // This also tends to happen for binary operations when SimplifyDemandedBits
10606 // is involved.
10607 //
10608 // FIXME: This is unnecessary if we process the DAG in topological order,
10609 // which we plan to do. This workaround can be removed once the DAG is
10610 // processed in topological order.
10611 if (N->hasOneUse()) {
10612 SDNode *Use = *N->use_begin();
10613
10614 // Look past the truncate.
10615 if (Use->getOpcode() == ISD::TRUNCATE && Use->hasOneUse())
10616 Use = *Use->use_begin();
10617
10618 if (Use->getOpcode() == ISD::BRCOND || Use->getOpcode() == ISD::AND ||
10619 Use->getOpcode() == ISD::OR || Use->getOpcode() == ISD::XOR)
10620 AddToWorklist(N: Use);
10621 }
10622
10623 // Try to transform this shift into a multiply-high if
10624 // it matches the appropriate pattern detected in combineShiftToMULH.
10625 if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
10626 return MULH;
10627
10628 return SDValue();
10629}
10630
10631SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
10632 EVT VT = N->getValueType(ResNo: 0);
10633 SDValue N0 = N->getOperand(Num: 0);
10634 SDValue N1 = N->getOperand(Num: 1);
10635 SDValue N2 = N->getOperand(Num: 2);
10636 bool IsFSHL = N->getOpcode() == ISD::FSHL;
10637 unsigned BitWidth = VT.getScalarSizeInBits();
10638
10639 // fold (fshl N0, N1, 0) -> N0
10640 // fold (fshr N0, N1, 0) -> N1
10641 if (isPowerOf2_32(Value: BitWidth))
10642 if (DAG.MaskedValueIsZero(
10643 Op: N2, Mask: APInt(N2.getScalarValueSizeInBits(), BitWidth - 1)))
10644 return IsFSHL ? N0 : N1;
10645
10646 auto IsUndefOrZero = [](SDValue V) {
10647 return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
10648 };
10649
10650 // TODO - support non-uniform vector shift amounts.
10651 if (ConstantSDNode *Cst = isConstOrConstSplat(N: N2)) {
10652 EVT ShAmtTy = N2.getValueType();
10653
10654 // fold (fsh* N0, N1, c) -> (fsh* N0, N1, c % BitWidth)
10655 if (Cst->getAPIntValue().uge(RHS: BitWidth)) {
10656 uint64_t RotAmt = Cst->getAPIntValue().urem(RHS: BitWidth);
10657 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, N1: N0, N2: N1,
10658 N3: DAG.getConstant(Val: RotAmt, DL: SDLoc(N), VT: ShAmtTy));
10659 }
10660
10661 unsigned ShAmt = Cst->getZExtValue();
10662 if (ShAmt == 0)
10663 return IsFSHL ? N0 : N1;
10664
10665 // fold fshl(undef_or_zero, N1, C) -> lshr(N1, BW-C)
10666 // fold fshr(undef_or_zero, N1, C) -> lshr(N1, C)
10667 // fold fshl(N0, undef_or_zero, C) -> shl(N0, C)
10668 // fold fshr(N0, undef_or_zero, C) -> shl(N0, BW-C)
10669 if (IsUndefOrZero(N0))
10670 return DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(N), VT, N1,
10671 N2: DAG.getConstant(Val: IsFSHL ? BitWidth - ShAmt : ShAmt,
10672 DL: SDLoc(N), VT: ShAmtTy));
10673 if (IsUndefOrZero(N1))
10674 return DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: N0,
10675 N2: DAG.getConstant(Val: IsFSHL ? ShAmt : BitWidth - ShAmt,
10676 DL: SDLoc(N), VT: ShAmtTy));
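    // For example, for i8 operands:
    //   (fshl x, 0, 3) -> (shl x, 3)
    //   (fshl 0, y, 3) -> (srl y, 5)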
10677
10678 // fold (fshl ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10679 // fold (fshr ld1, ld0, c) -> (ld0[ofs]) iff ld0 and ld1 are consecutive.
10680 // TODO - bigendian support once we have test coverage.
10681 // TODO - can we merge this with CombineConsecutiveLoads/MatchLoadCombine?
10682 // TODO - permit LHS EXTLOAD if extensions are shifted out.
10683 if ((BitWidth % 8) == 0 && (ShAmt % 8) == 0 && !VT.isVector() &&
10684 !DAG.getDataLayout().isBigEndian()) {
10685 auto *LHS = dyn_cast<LoadSDNode>(Val&: N0);
10686 auto *RHS = dyn_cast<LoadSDNode>(Val&: N1);
10687 if (LHS && RHS && LHS->isSimple() && RHS->isSimple() &&
10688 LHS->getAddressSpace() == RHS->getAddressSpace() &&
10689 (LHS->hasOneUse() || RHS->hasOneUse()) && ISD::isNON_EXTLoad(N: RHS) &&
10690 ISD::isNON_EXTLoad(N: LHS)) {
10691 if (DAG.areNonVolatileConsecutiveLoads(LD: LHS, Base: RHS, Bytes: BitWidth / 8, Dist: 1)) {
10692 SDLoc DL(RHS);
10693 uint64_t PtrOff =
10694 IsFSHL ? (((BitWidth - ShAmt) % BitWidth) / 8) : (ShAmt / 8);
10695 Align NewAlign = commonAlignment(A: RHS->getAlign(), Offset: PtrOff);
10696 unsigned Fast = 0;
10697 if (TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
10698 AddrSpace: RHS->getAddressSpace(), Alignment: NewAlign,
10699 Flags: RHS->getMemOperand()->getFlags(), Fast: &Fast) &&
10700 Fast) {
10701 SDValue NewPtr = DAG.getMemBasePlusOffset(
10702 Base: RHS->getBasePtr(), Offset: TypeSize::getFixed(ExactSize: PtrOff), DL);
10703 AddToWorklist(N: NewPtr.getNode());
10704 SDValue Load = DAG.getLoad(
10705 VT, dl: DL, Chain: RHS->getChain(), Ptr: NewPtr,
10706 PtrInfo: RHS->getPointerInfo().getWithOffset(O: PtrOff), Alignment: NewAlign,
10707 MMOFlags: RHS->getMemOperand()->getFlags(), AAInfo: RHS->getAAInfo());
10708 // Replace the old load's chain with the new load's chain.
10709 WorklistRemover DeadNodes(*this);
10710 DAG.ReplaceAllUsesOfValueWith(From: N1.getValue(R: 1), To: Load.getValue(R: 1));
10711 return Load;
10712 }
10713 }
10714 }
10715 }
10716 }
10717
10718 // fold fshr(undef_or_zero, N1, N2) -> lshr(N1, N2)
10719 // fold fshl(N0, undef_or_zero, N2) -> shl(N0, N2)
10720 // iff we know the shift amount is in range.
10721 // TODO: when is it worth doing SUB(BW, N2) as well?
10722 if (isPowerOf2_32(Value: BitWidth)) {
10723 APInt ModuloBits(N2.getScalarValueSizeInBits(), BitWidth - 1);
10724 if (IsUndefOrZero(N0) && !IsFSHL && DAG.MaskedValueIsZero(Op: N2, Mask: ~ModuloBits))
10725 return DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(N), VT, N1, N2);
10726 if (IsUndefOrZero(N1) && IsFSHL && DAG.MaskedValueIsZero(Op: N2, Mask: ~ModuloBits))
10727 return DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: N0, N2);
10728 }
10729
10730 // fold (fshl N0, N0, N2) -> (rotl N0, N2)
10731 // fold (fshr N0, N0, N2) -> (rotr N0, N2)
10732 // TODO: Investigate flipping this rotate if only one is legal; if the funnel shift
10733 // is legal as well, we might be better off avoiding the non-constant (BW - N2).
10734 unsigned RotOpc = IsFSHL ? ISD::ROTL : ISD::ROTR;
10735 if (N0 == N1 && hasOperation(Opcode: RotOpc, VT))
10736 return DAG.getNode(Opcode: RotOpc, DL: SDLoc(N), VT, N1: N0, N2);
10737
10738 // Simplify, based on bits shifted out of N0/N1.
10739 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
10740 return SDValue(N, 0);
10741
10742 return SDValue();
10743}
10744
10745SDValue DAGCombiner::visitSHLSAT(SDNode *N) {
10746 SDValue N0 = N->getOperand(Num: 0);
10747 SDValue N1 = N->getOperand(Num: 1);
10748 if (SDValue V = DAG.simplifyShift(X: N0, Y: N1))
10749 return V;
10750
10751 EVT VT = N0.getValueType();
10752
10753 // fold (*shlsat c1, c2) -> c1<<c2
10754 if (SDValue C =
10755 DAG.FoldConstantArithmetic(Opcode: N->getOpcode(), DL: SDLoc(N), VT, Ops: {N0, N1}))
10756 return C;
10757
10758 ConstantSDNode *N1C = isConstOrConstSplat(N: N1);
10759
10760 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::SHL, VT)) {
10761 // fold (sshlsat x, c) -> (shl x, c)
10762 if (N->getOpcode() == ISD::SSHLSAT && N1C &&
10763 N1C->getAPIntValue().ult(RHS: DAG.ComputeNumSignBits(Op: N0)))
10764 return DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: N0, N2: N1);
10765
10766 // fold (ushlsat x, c) -> (shl x, c)
10767 if (N->getOpcode() == ISD::USHLSAT && N1C &&
10768 N1C->getAPIntValue().ule(
10769 RHS: DAG.computeKnownBits(Op: N0).countMinLeadingZeros()))
10770 return DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N), VT, N1: N0, N2: N1);
10771 }
10772
10773 return SDValue();
10774}
10775
10776 // Given an ABS node, detect the following patterns:
10777 // (ABS (SUB (EXTEND a), (EXTEND b))).
10778 // (TRUNC (ABS (SUB (EXTEND a), (EXTEND b)))).
10779 // Generates a UABD/SABD instruction.
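// For example, when ABDU/ABDS is available for the narrower type:
// (ABS (SUB (ZEXT i8 a to i32), (ZEXT i8 b to i32))) -> (ZEXT (ABDU a, b) to i32)
// (ABS (SUB (SEXT i8 a to i32), (SEXT i8 b to i32))) -> (ZEXT (ABDS a, b) to i32)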
10780SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
10781 EVT SrcVT = N->getValueType(ResNo: 0);
10782
10783 if (N->getOpcode() == ISD::TRUNCATE)
10784 N = N->getOperand(Num: 0).getNode();
10785
10786 if (N->getOpcode() != ISD::ABS)
10787 return SDValue();
10788
10789 EVT VT = N->getValueType(ResNo: 0);
10790 SDValue AbsOp1 = N->getOperand(Num: 0);
10791 SDValue Op0, Op1;
10792
10793 if (AbsOp1.getOpcode() != ISD::SUB)
10794 return SDValue();
10795
10796 Op0 = AbsOp1.getOperand(i: 0);
10797 Op1 = AbsOp1.getOperand(i: 1);
10798
10799 unsigned Opc0 = Op0.getOpcode();
10800
10801 // Check if the operands of the sub are (zero|sign)-extended.
10802 // TODO: Should we use ValueTracking instead?
10803 if (Opc0 != Op1.getOpcode() ||
10804 (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
10805 Opc0 != ISD::SIGN_EXTEND_INREG)) {
10806 // fold (abs (sub nsw x, y)) -> abds(x, y)
10807 if (AbsOp1->getFlags().hasNoSignedWrap() && hasOperation(Opcode: ISD::ABDS, VT) &&
10808 TLI.preferABDSToABSWithNSW(VT)) {
10809 SDValue ABD = DAG.getNode(Opcode: ISD::ABDS, DL, VT, N1: Op0, N2: Op1);
10810 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
10811 }
10812 return SDValue();
10813 }
10814
10815 EVT VT0, VT1;
10816 if (Opc0 == ISD::SIGN_EXTEND_INREG) {
10817 VT0 = cast<VTSDNode>(Val: Op0.getOperand(i: 1))->getVT();
10818 VT1 = cast<VTSDNode>(Val: Op1.getOperand(i: 1))->getVT();
10819 } else {
10820 VT0 = Op0.getOperand(i: 0).getValueType();
10821 VT1 = Op1.getOperand(i: 0).getValueType();
10822 }
10823 unsigned ABDOpcode = (Opc0 == ISD::ZERO_EXTEND) ? ISD::ABDU : ISD::ABDS;
10824
10825 // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
10826 // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
10827 EVT MaxVT = VT0.bitsGT(VT: VT1) ? VT0 : VT1;
10828 if ((VT0 == MaxVT || Op0->hasOneUse()) &&
10829 (VT1 == MaxVT || Op1->hasOneUse()) && hasOperation(Opcode: ABDOpcode, VT: MaxVT)) {
10830 SDValue ABD = DAG.getNode(Opcode: ABDOpcode, DL, VT: MaxVT,
10831 N1: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MaxVT, Operand: Op0),
10832 N2: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: MaxVT, Operand: Op1));
10833 ABD = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: ABD);
10834 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
10835 }
10836
10837 // fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
10838 // fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
10839 if (hasOperation(Opcode: ABDOpcode, VT)) {
10840 SDValue ABD = DAG.getNode(Opcode: ABDOpcode, DL, VT, N1: Op0, N2: Op1);
10841 return DAG.getZExtOrTrunc(Op: ABD, DL, VT: SrcVT);
10842 }
10843
10844 return SDValue();
10845}
10846
10847SDValue DAGCombiner::visitABS(SDNode *N) {
10848 SDValue N0 = N->getOperand(Num: 0);
10849 EVT VT = N->getValueType(ResNo: 0);
10850 SDLoc DL(N);
10851
10852 // fold (abs c1) -> c2
10853 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::ABS, DL, VT, Ops: {N0}))
10854 return C;
10855 // fold (abs (abs x)) -> (abs x)
10856 if (N0.getOpcode() == ISD::ABS)
10857 return N0;
10858 // fold (abs x) -> x iff not-negative
10859 if (DAG.SignBitIsZero(Op: N0))
10860 return N0;
10861
10862 if (SDValue ABD = foldABSToABD(N, DL))
10863 return ABD;
10864
10865 // fold (abs (sign_extend_inreg x)) -> (zero_extend (abs (truncate x)))
10866 // iff zero_extend/truncate are free.
10867 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
10868 EVT ExtVT = cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT();
10869 if (TLI.isTruncateFree(FromVT: VT, ToVT: ExtVT) && TLI.isZExtFree(FromTy: ExtVT, ToTy: VT) &&
10870 TLI.isTypeDesirableForOp(ISD::ABS, VT: ExtVT) &&
10871 hasOperation(Opcode: ISD::ABS, VT: ExtVT)) {
10872 return DAG.getNode(
10873 Opcode: ISD::ZERO_EXTEND, DL, VT,
10874 Operand: DAG.getNode(Opcode: ISD::ABS, DL, VT: ExtVT,
10875 Operand: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ExtVT, Operand: N0.getOperand(i: 0))));
10876 }
10877 }
10878
10879 return SDValue();
10880}
10881
10882SDValue DAGCombiner::visitBSWAP(SDNode *N) {
10883 SDValue N0 = N->getOperand(Num: 0);
10884 EVT VT = N->getValueType(ResNo: 0);
10885 SDLoc DL(N);
10886
10887 // fold (bswap c1) -> c2
10888 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::BSWAP, DL, VT, Ops: {N0}))
10889 return C;
10890 // fold (bswap (bswap x)) -> x
10891 if (N0.getOpcode() == ISD::BSWAP)
10892 return N0.getOperand(i: 0);
10893
10894 // Canonicalize bswap(bitreverse(x)) -> bitreverse(bswap(x)). If bitreverse
10895 // isn't supported, it will be expanded to bswap followed by a manual reversal
10896 // of bits in each byte. By placing bswaps before bitreverse, we can remove
10897 // the two bswaps if the bitreverse gets expanded.
10898 if (N0.getOpcode() == ISD::BITREVERSE && N0.hasOneUse()) {
10899 SDValue BSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: N0.getOperand(i: 0));
10900 return DAG.getNode(Opcode: ISD::BITREVERSE, DL, VT, Operand: BSwap);
10901 }
10902
10903 // fold (bswap shl(x,c)) -> (zext(bswap(trunc(shl(x,sub(c,bw/2))))))
10904 // iff c >= bw/2 (i.e. the lower half is known zero)
10905 unsigned BW = VT.getScalarSizeInBits();
10906 if (BW >= 32 && N0.getOpcode() == ISD::SHL && N0.hasOneUse()) {
10907 auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
10908 EVT HalfVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: BW / 2);
10909 if (ShAmt && ShAmt->getAPIntValue().ult(RHS: BW) &&
10910 ShAmt->getZExtValue() >= (BW / 2) &&
10911 (ShAmt->getZExtValue() % 16) == 0 && TLI.isTypeLegal(VT: HalfVT) &&
10912 TLI.isTruncateFree(FromVT: VT, ToVT: HalfVT) &&
10913 (!LegalOperations || hasOperation(Opcode: ISD::BSWAP, VT: HalfVT))) {
10914 SDValue Res = N0.getOperand(i: 0);
10915 if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
10916 Res = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Res,
10917 N2: DAG.getConstant(Val: NewShAmt, DL, VT: getShiftAmountTy(LHSTy: VT)));
10918 Res = DAG.getZExtOrTrunc(Op: Res, DL, VT: HalfVT);
10919 Res = DAG.getNode(Opcode: ISD::BSWAP, DL, VT: HalfVT, Operand: Res);
10920 return DAG.getZExtOrTrunc(Op: Res, DL, VT);
10921 }
10922 }
10923
10924 // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as
10925 // inverse-shift-of-bswap:
10926 // bswap (X u<< C) --> (bswap X) u>> C
10927 // bswap (X u>> C) --> (bswap X) u<< C
10928 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
10929 N0.hasOneUse()) {
10930 auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
10931 if (ShAmt && ShAmt->getAPIntValue().ult(RHS: BW) &&
10932 ShAmt->getZExtValue() % 8 == 0) {
10933 SDValue NewSwap = DAG.getNode(Opcode: ISD::BSWAP, DL, VT, Operand: N0.getOperand(i: 0));
10934 unsigned InverseShift = N0.getOpcode() == ISD::SHL ? ISD::SRL : ISD::SHL;
10935 return DAG.getNode(Opcode: InverseShift, DL, VT, N1: NewSwap, N2: N0.getOperand(i: 1));
10936 }
10937 }
10938
10939 if (SDValue V = foldBitOrderCrossLogicOp(N, DAG))
10940 return V;
10941
10942 return SDValue();
10943}
10944
10945SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
10946 SDValue N0 = N->getOperand(Num: 0);
10947 EVT VT = N->getValueType(ResNo: 0);
10948 SDLoc DL(N);
10949
10950 // fold (bitreverse c1) -> c2
10951 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::BITREVERSE, DL, VT, Ops: {N0}))
10952 return C;
10953 // fold (bitreverse (bitreverse x)) -> x
10954 if (N0.getOpcode() == ISD::BITREVERSE)
10955 return N0.getOperand(i: 0);
10956 return SDValue();
10957}
10958
10959SDValue DAGCombiner::visitCTLZ(SDNode *N) {
10960 SDValue N0 = N->getOperand(Num: 0);
10961 EVT VT = N->getValueType(ResNo: 0);
10962 SDLoc DL(N);
10963
10964 // fold (ctlz c1) -> c2
10965 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTLZ, DL, VT, Ops: {N0}))
10966 return C;
10967
10968 // If the value is known never to be zero, switch to the undef version.
10969 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTLZ_ZERO_UNDEF, VT))
10970 if (DAG.isKnownNeverZero(Op: N0))
10971 return DAG.getNode(Opcode: ISD::CTLZ_ZERO_UNDEF, DL, VT, Operand: N0);
10972
10973 return SDValue();
10974}
10975
10976SDValue DAGCombiner::visitCTLZ_ZERO_UNDEF(SDNode *N) {
10977 SDValue N0 = N->getOperand(Num: 0);
10978 EVT VT = N->getValueType(ResNo: 0);
10979 SDLoc DL(N);
10980
10981 // fold (ctlz_zero_undef c1) -> c2
10982 if (SDValue C =
10983 DAG.FoldConstantArithmetic(Opcode: ISD::CTLZ_ZERO_UNDEF, DL, VT, Ops: {N0}))
10984 return C;
10985 return SDValue();
10986}
10987
10988SDValue DAGCombiner::visitCTTZ(SDNode *N) {
10989 SDValue N0 = N->getOperand(Num: 0);
10990 EVT VT = N->getValueType(ResNo: 0);
10991 SDLoc DL(N);
10992
10993 // fold (cttz c1) -> c2
10994 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTTZ, DL, VT, Ops: {N0}))
10995 return C;
10996
10997 // If the value is known never to be zero, switch to the undef version.
10998 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTTZ_ZERO_UNDEF, VT))
10999 if (DAG.isKnownNeverZero(Op: N0))
11000 return DAG.getNode(Opcode: ISD::CTTZ_ZERO_UNDEF, DL, VT, Operand: N0);
11001
11002 return SDValue();
11003}
11004
11005SDValue DAGCombiner::visitCTTZ_ZERO_UNDEF(SDNode *N) {
11006 SDValue N0 = N->getOperand(Num: 0);
11007 EVT VT = N->getValueType(ResNo: 0);
11008 SDLoc DL(N);
11009
11010 // fold (cttz_zero_undef c1) -> c2
11011 if (SDValue C =
11012 DAG.FoldConstantArithmetic(Opcode: ISD::CTTZ_ZERO_UNDEF, DL, VT, Ops: {N0}))
11013 return C;
11014 return SDValue();
11015}
11016
11017SDValue DAGCombiner::visitCTPOP(SDNode *N) {
11018 SDValue N0 = N->getOperand(Num: 0);
11019 EVT VT = N->getValueType(ResNo: 0);
11020 unsigned NumBits = VT.getScalarSizeInBits();
11021 SDLoc DL(N);
11022
11023 // fold (ctpop c1) -> c2
11024 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::CTPOP, DL, VT, Ops: {N0}))
11025 return C;
11026
11027 // If the source is being shifted, but doesn't affect any active bits,
11028 // then we can call CTPOP on the shift source directly.
11029 if (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SHL) {
11030 if (ConstantSDNode *AmtC = isConstOrConstSplat(N: N0.getOperand(i: 1))) {
11031 const APInt &Amt = AmtC->getAPIntValue();
11032 if (Amt.ult(RHS: NumBits)) {
11033 KnownBits KnownSrc = DAG.computeKnownBits(Op: N0.getOperand(i: 0));
11034 if ((N0.getOpcode() == ISD::SRL &&
11035 Amt.ule(RHS: KnownSrc.countMinTrailingZeros())) ||
11036 (N0.getOpcode() == ISD::SHL &&
11037 Amt.ule(RHS: KnownSrc.countMinLeadingZeros()))) {
11038 return DAG.getNode(Opcode: ISD::CTPOP, DL, VT, Operand: N0.getOperand(i: 0));
11039 }
11040 }
11041 }
11042 }
11043
11044 // If the upper bits are known to be zero, then see if it's profitable to
11045 // only count the lower bits.
11046 if (VT.isScalarInteger() && NumBits > 8 && (NumBits & 1) == 0) {
11047 EVT HalfVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NumBits / 2);
11048 if (hasOperation(Opcode: ISD::CTPOP, VT: HalfVT) &&
11049 TLI.isTypeDesirableForOp(ISD::CTPOP, VT: HalfVT) &&
11050 TLI.isTruncateFree(Val: N0, VT2: HalfVT) && TLI.isZExtFree(FromTy: HalfVT, ToTy: VT)) {
11051 APInt UpperBits = APInt::getHighBitsSet(numBits: NumBits, hiBitsSet: NumBits / 2);
11052 if (DAG.MaskedValueIsZero(Op: N0, Mask: UpperBits)) {
11053 SDValue PopCnt = DAG.getNode(Opcode: ISD::CTPOP, DL, VT: HalfVT,
11054 Operand: DAG.getZExtOrTrunc(Op: N0, DL, VT: HalfVT));
11055 return DAG.getZExtOrTrunc(Op: PopCnt, DL, VT);
11056 }
11057 }
11058 }
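  // For example, assuming i16 CTPOP is legal/desirable and the truncate is free:
  //   (ctpop i32 x), with the upper 16 bits of x known zero
  //     -> (zext (ctpop (trunc x to i16)))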
11059
11060 return SDValue();
11061}
11062
11063// FIXME: This should be checking for no signed zeros on individual operands, as
11064// well as no nans.
11065static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS,
11066 SDValue RHS,
11067 const TargetLowering &TLI) {
11068 const TargetOptions &Options = DAG.getTarget().Options;
11069 EVT VT = LHS.getValueType();
11070
11071 return Options.NoSignedZerosFPMath && VT.isFloatingPoint() &&
11072 TLI.isProfitableToCombineMinNumMaxNum(VT) &&
11073 DAG.isKnownNeverNaN(Op: LHS) && DAG.isKnownNeverNaN(Op: RHS);
11074}
11075
11076static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
11077 SDValue RHS, SDValue True, SDValue False,
11078 ISD::CondCode CC,
11079 const TargetLowering &TLI,
11080 SelectionDAG &DAG) {
11081 EVT TransformVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT);
11082 switch (CC) {
11083 case ISD::SETOLT:
11084 case ISD::SETOLE:
11085 case ISD::SETLT:
11086 case ISD::SETLE:
11087 case ISD::SETULT:
11088 case ISD::SETULE: {
11089 // Since the operands are already known never to be NaN here, either fminnum or
11090 // fminnum_ieee is OK. Try the IEEE version first, since fminnum is
11091 // expanded in terms of it.
11092 unsigned IEEEOpcode = (LHS == True) ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
11093 if (TLI.isOperationLegalOrCustom(Op: IEEEOpcode, VT))
11094 return DAG.getNode(Opcode: IEEEOpcode, DL, VT, N1: LHS, N2: RHS);
11095
11096 unsigned Opcode = (LHS == True) ? ISD::FMINNUM : ISD::FMAXNUM;
11097 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT: TransformVT))
11098 return DAG.getNode(Opcode, DL, VT, N1: LHS, N2: RHS);
11099 return SDValue();
11100 }
11101 case ISD::SETOGT:
11102 case ISD::SETOGE:
11103 case ISD::SETGT:
11104 case ISD::SETGE:
11105 case ISD::SETUGT:
11106 case ISD::SETUGE: {
11107 unsigned IEEEOpcode = (LHS == True) ? ISD::FMAXNUM_IEEE : ISD::FMINNUM_IEEE;
11108 if (TLI.isOperationLegalOrCustom(Op: IEEEOpcode, VT))
11109 return DAG.getNode(Opcode: IEEEOpcode, DL, VT, N1: LHS, N2: RHS);
11110
11111 unsigned Opcode = (LHS == True) ? ISD::FMAXNUM : ISD::FMINNUM;
11112 if (TLI.isOperationLegalOrCustom(Op: Opcode, VT: TransformVT))
11113 return DAG.getNode(Opcode, DL, VT, N1: LHS, N2: RHS);
11114 return SDValue();
11115 }
11116 default:
11117 return SDValue();
11118 }
11119}
11120
11121/// Generate Min/Max node
11122SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
11123 SDValue RHS, SDValue True,
11124 SDValue False, ISD::CondCode CC) {
11125 if ((LHS == True && RHS == False) || (LHS == False && RHS == True))
11126 return combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True, False, CC, TLI, DAG);
11127
11128 // If we can't directly match this, try to see if we can pull an fneg out of
11129 // the select.
11130 SDValue NegTrue = TLI.getCheaperOrNeutralNegatedExpression(
11131 Op: True, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize);
11132 if (!NegTrue)
11133 return SDValue();
11134
11135 HandleSDNode NegTrueHandle(NegTrue);
11136
11137 // Try to unfold an fneg from the select if we are comparing the negated
11138 // constant.
11139 //
11140 // select (setcc x, K) (fneg x), -K -> fneg(minnum(x, K))
11141 //
11142 // TODO: Handle fabs
11143 if (LHS == NegTrue) {
11144 // See if the false operand is also the negation of the compare constant,
11145 // so the fneg can be pulled out of the whole select.
11146 SDValue NegRHS = TLI.getCheaperOrNeutralNegatedExpression(
11147 Op: RHS, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize);
11148 if (NegRHS) {
11149 HandleSDNode NegRHSHandle(NegRHS);
11150 if (NegRHS == False) {
11151 SDValue Combined = combineMinNumMaxNumImpl(DL, VT, LHS, RHS, True: NegTrue,
11152 False, CC, TLI, DAG);
11153 if (Combined)
11154 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: Combined);
11155 }
11156 }
11157 }
11158
11159 return SDValue();
11160}
11161
11162/// If a (v)select has a condition value that is a sign-bit test, try to smear
11163/// the condition operand sign-bit across the value width and use it as a mask.
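/// For example, with i8 elements: (X s< 0) ? 42 : 0 --> (X s>> 7) & 42, since
/// the arithmetic shift yields all-ones when X is negative and zero otherwise.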
11164static SDValue foldSelectOfConstantsUsingSra(SDNode *N, SelectionDAG &DAG) {
11165 SDValue Cond = N->getOperand(Num: 0);
11166 SDValue C1 = N->getOperand(Num: 1);
11167 SDValue C2 = N->getOperand(Num: 2);
11168 if (!isConstantOrConstantVector(N: C1) || !isConstantOrConstantVector(N: C2))
11169 return SDValue();
11170
11171 EVT VT = N->getValueType(ResNo: 0);
11172 if (Cond.getOpcode() != ISD::SETCC || !Cond.hasOneUse() ||
11173 VT != Cond.getOperand(i: 0).getValueType())
11174 return SDValue();
11175
11176 // The inverted-condition + commuted-select variants of these patterns are
11177 // canonicalized to these forms in IR.
11178 SDValue X = Cond.getOperand(i: 0);
11179 SDValue CondC = Cond.getOperand(i: 1);
11180 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
11181 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: CondC) &&
11182 isAllOnesOrAllOnesSplat(V: C2)) {
11183 // i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
11184 SDLoc DL(N);
11185 SDValue ShAmtC = DAG.getConstant(Val: X.getScalarValueSizeInBits() - 1, DL, VT);
11186 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: X, N2: ShAmtC);
11187 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Sra, N2: C1);
11188 }
11189 if (CC == ISD::SETLT && isNullOrNullSplat(V: CondC) && isNullOrNullSplat(V: C2)) {
11190 // i8 X < 0 ? C1 : 0 --> (X >>s 7) & C1
11191 SDLoc DL(N);
11192 SDValue ShAmtC = DAG.getConstant(Val: X.getScalarValueSizeInBits() - 1, DL, VT);
11193 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: X, N2: ShAmtC);
11194 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Sra, N2: C1);
11195 }
11196 return SDValue();
11197}
11198
11199static bool shouldConvertSelectOfConstantsToMath(const SDValue &Cond, EVT VT,
11200 const TargetLowering &TLI) {
11201 if (!TLI.convertSelectOfConstantsToMath(VT))
11202 return false;
11203
11204 if (Cond.getOpcode() != ISD::SETCC || !Cond->hasOneUse())
11205 return true;
11206 if (!TLI.isOperationLegalOrCustom(Op: ISD::SELECT_CC, VT))
11207 return true;
11208
11209 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
11210 if (CC == ISD::SETLT && isNullOrNullSplat(V: Cond.getOperand(i: 1)))
11211 return true;
11212 if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: Cond.getOperand(i: 1)))
11213 return true;
11214
11215 return false;
11216}
11217
11218SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
11219 SDValue Cond = N->getOperand(Num: 0);
11220 SDValue N1 = N->getOperand(Num: 1);
11221 SDValue N2 = N->getOperand(Num: 2);
11222 EVT VT = N->getValueType(ResNo: 0);
11223 EVT CondVT = Cond.getValueType();
11224 SDLoc DL(N);
11225
11226 if (!VT.isInteger())
11227 return SDValue();
11228
11229 auto *C1 = dyn_cast<ConstantSDNode>(Val&: N1);
11230 auto *C2 = dyn_cast<ConstantSDNode>(Val&: N2);
11231 if (!C1 || !C2)
11232 return SDValue();
11233
11234 if (CondVT != MVT::i1 || LegalOperations) {
11235 // fold (select Cond, 0, 1) -> (xor Cond, 1)
11236 // We can't do this reliably if integer-based booleans have different
11237 // contents from floating-point-based booleans. This is because we can't
11238 // tell whether we have an integer-based boolean or a floating-point-based
11239 // boolean unless we can find the SETCC that produced it and inspect its
11240 // operands. This is fairly easy if Cond is the SETCC node, but it can
11241 // potentially be undiscoverable (or not reasonably discoverable). For
11242 // example, it could be in another basic block or it could require
11243 // searching a complicated expression.
11244 if (CondVT.isInteger() &&
11245 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/true) ==
11246 TargetLowering::ZeroOrOneBooleanContent &&
11247 TLI.getBooleanContents(/*isVec*/false, /*isFloat*/false) ==
11248 TargetLowering::ZeroOrOneBooleanContent &&
11249 C1->isZero() && C2->isOne()) {
11250 SDValue NotCond =
11251 DAG.getNode(Opcode: ISD::XOR, DL, VT: CondVT, N1: Cond, N2: DAG.getConstant(Val: 1, DL, VT: CondVT));
11252 if (VT.bitsEq(VT: CondVT))
11253 return NotCond;
11254 return DAG.getZExtOrTrunc(Op: NotCond, DL, VT);
11255 }
11256
11257 return SDValue();
11258 }
11259
11260 // Only do this before legalization to avoid conflicting with target-specific
11261 // transforms in the other direction (create a select from a zext/sext). There
11262 // is also a target-independent combine here in DAGCombiner in the other
11263 // direction for (select Cond, -1, 0) when the condition is not i1.
11264 assert(CondVT == MVT::i1 && !LegalOperations);
11265
11266 // select Cond, 1, 0 --> zext (Cond)
11267 if (C1->isOne() && C2->isZero())
11268 return DAG.getZExtOrTrunc(Op: Cond, DL, VT);
11269
11270 // select Cond, -1, 0 --> sext (Cond)
11271 if (C1->isAllOnes() && C2->isZero())
11272 return DAG.getSExtOrTrunc(Op: Cond, DL, VT);
11273
11274 // select Cond, 0, 1 --> zext (!Cond)
11275 if (C1->isZero() && C2->isOne()) {
11276 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11277 NotCond = DAG.getZExtOrTrunc(Op: NotCond, DL, VT);
11278 return NotCond;
11279 }
11280
11281 // select Cond, 0, -1 --> sext (!Cond)
11282 if (C1->isZero() && C2->isAllOnes()) {
11283 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11284 NotCond = DAG.getSExtOrTrunc(Op: NotCond, DL, VT);
11285 return NotCond;
11286 }
11287
11288 // Use a target hook because some targets may prefer to transform in the
11289 // other direction.
11290 if (!shouldConvertSelectOfConstantsToMath(Cond, VT, TLI))
11291 return SDValue();
11292
11293 // For any constants that differ by 1, we can transform the select into
11294 // an extend and add.
11295 const APInt &C1Val = C1->getAPIntValue();
11296 const APInt &C2Val = C2->getAPIntValue();
11297
11298 // select Cond, C1, C1-1 --> add (zext Cond), C1-1
11299 if (C1Val - 1 == C2Val) {
11300 Cond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
11301 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Cond, N2);
11302 }
11303
11304 // select Cond, C1, C1+1 --> add (sext Cond), C1+1
11305 if (C1Val + 1 == C2Val) {
11306 Cond = DAG.getSExtOrTrunc(Op: Cond, DL, VT);
11307 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Cond, N2);
11308 }
11309
11310 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2)
11311 if (C1Val.isPowerOf2() && C2Val.isZero()) {
11312 Cond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
11313 SDValue ShAmtC =
11314 DAG.getShiftAmountConstant(Val: C1Val.exactLogBase2(), VT, DL);
11315 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Cond, N2: ShAmtC);
11316 }
11317
11318 // select Cond, -1, C --> or (sext Cond), C
11319 if (C1->isAllOnes()) {
11320 Cond = DAG.getSExtOrTrunc(Op: Cond, DL, VT);
11321 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Cond, N2);
11322 }
11323
11324 // select Cond, C, -1 --> or (sext (not Cond)), C
11325 if (C2->isAllOnes()) {
11326 SDValue NotCond = DAG.getNOT(DL, Cond, MVT::i1);
11327 NotCond = DAG.getSExtOrTrunc(Op: NotCond, DL, VT);
11328 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: NotCond, N2: N1);
11329 }
11330
11331 if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
11332 return V;
11333
11334 return SDValue();
11335}
11336
11337template <class MatchContextClass>
11338static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
11339 assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
11340 N->getOpcode() == ISD::VP_SELECT) &&
11341 "Expected a (v)(vp.)select");
11342 SDValue Cond = N->getOperand(Num: 0);
11343 SDValue T = N->getOperand(Num: 1), F = N->getOperand(Num: 2);
11344 EVT VT = N->getValueType(ResNo: 0);
11345 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11346 MatchContextClass matcher(DAG, TLI, N);
11347
11348 if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
11349 return SDValue();
11350
11351 // select Cond, Cond, F --> or Cond, F
11352 // select Cond, 1, F --> or Cond, F
11353 if (Cond == T || isOneOrOneSplat(V: T, /* AllowUndefs */ true))
11354 return matcher.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
11355
11356 // select Cond, T, Cond --> and Cond, T
11357 // select Cond, T, 0 --> and Cond, T
11358 if (Cond == F || isNullOrNullSplat(V: F, /* AllowUndefs */ true))
11359 return matcher.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
11360
11361 // select Cond, T, 1 --> or (not Cond), T
11362 if (isOneOrOneSplat(V: F, /* AllowUndefs */ true)) {
11363 SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
11364 DAG.getAllOnesConstant(DL: SDLoc(N), VT));
11365 return matcher.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
11366 }
11367
11368 // select Cond, 0, F --> and (not Cond), F
11369 if (isNullOrNullSplat(V: T, /* AllowUndefs */ true)) {
11370 SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
11371 DAG.getAllOnesConstant(DL: SDLoc(N), VT));
11372 return matcher.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
11373 }
11374
11375 return SDValue();
11376}
11377
11378static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
11379 SDValue N0 = N->getOperand(Num: 0);
11380 SDValue N1 = N->getOperand(Num: 1);
11381 SDValue N2 = N->getOperand(Num: 2);
11382 EVT VT = N->getValueType(ResNo: 0);
11383 if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
11384 return SDValue();
11385
11386 SDValue Cond0 = N0.getOperand(i: 0);
11387 SDValue Cond1 = N0.getOperand(i: 1);
11388 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
11389 if (VT != Cond0.getValueType())
11390 return SDValue();
11391
11392 // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
11393 // compare is inverted from that pattern ("Cond0 s> -1").
11394 if (CC == ISD::SETLT && isNullOrNullSplat(V: Cond1))
11395 ; // This is the pattern we are looking for.
11396 else if (CC == ISD::SETGT && isAllOnesOrAllOnesSplat(V: Cond1))
11397 std::swap(a&: N1, b&: N2);
11398 else
11399 return SDValue();
11400
11401 // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & N1
11402 if (isNullOrNullSplat(V: N2)) {
11403 SDLoc DL(N);
11404 SDValue ShiftAmt = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
11405 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
11406 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Sra, N2: N1);
11407 }
11408
11409 // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | N2
11410 if (isAllOnesOrAllOnesSplat(V: N1)) {
11411 SDLoc DL(N);
11412 SDValue ShiftAmt = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
11413 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
11414 return DAG.getNode(Opcode: ISD::OR, DL, VT, N1: Sra, N2);
11415 }
11416
11417 // If we have to invert the sign bit mask, only do that transform if the
11418 // target has a bitwise 'and not' instruction (the invert is free).
11419 // (Cond0 s< 0) ? 0 : N2 --> ~(Cond0 s>> BW-1) & N2
11420 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11421 if (isNullOrNullSplat(V: N1) && TLI.hasAndNot(X: N1)) {
11422 SDLoc DL(N);
11423 SDValue ShiftAmt = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
11424 SDValue Sra = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: Cond0, N2: ShiftAmt);
11425 SDValue Not = DAG.getNOT(DL, Val: Sra, VT);
11426 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Not, N2);
11427 }
11428
11429 // TODO: There's another pattern in this family, but it may require
11430 // implementing hasOrNot() to check for profitability:
11431 // (Cond0 s> -1) ? -1 : N2 --> ~(Cond0 s>> BW-1) | N2
11432
11433 return SDValue();
11434}
11435
11436SDValue DAGCombiner::visitSELECT(SDNode *N) {
11437 SDValue N0 = N->getOperand(Num: 0);
11438 SDValue N1 = N->getOperand(Num: 1);
11439 SDValue N2 = N->getOperand(Num: 2);
11440 EVT VT = N->getValueType(ResNo: 0);
11441 EVT VT0 = N0.getValueType();
11442 SDLoc DL(N);
11443 SDNodeFlags Flags = N->getFlags();
11444
11445 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
11446 return V;
11447
11448 if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DAG))
11449 return V;
11450
11451 // select (not Cond), N1, N2 -> select Cond, N2, N1
11452 if (SDValue F = extractBooleanFlip(V: N0, DAG, TLI, Force: false)) {
11453 SDValue SelectOp = DAG.getSelect(DL, VT, Cond: F, LHS: N2, RHS: N1);
11454 SelectOp->setFlags(Flags);
11455 return SelectOp;
11456 }
11457
11458 if (SDValue V = foldSelectOfConstants(N))
11459 return V;
11460
11461 // If we can fold this based on the true/false value, do so.
11462 if (SimplifySelectOps(SELECT: N, LHS: N1, RHS: N2))
11463 return SDValue(N, 0); // Don't revisit N.
11464
11465 if (VT0 == MVT::i1) {
11466 // The code in this block deals with the following 2 equivalences:
11467 // select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
11468 // select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
11469 // The target can specify its preferred form with the
11470 // shouldNormalizeToSelectSequence() callback. However, we always transform
11471 // to the right-hand form if the inner select already exists in the DAG,
11472 // and we always transform to the left-hand form if we know that we can
11473 // further optimize the combination of the conditions.
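// For example, select (and C0, C1), X, Y and select C0, (select C1, X, Y), Y
// agree for every C0: if C0 is false both yield Y, and if C0 is true both
// reduce to select C1, X, Y.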
11474 bool normalizeToSequence =
11475 TLI.shouldNormalizeToSelectSequence(Context&: *DAG.getContext(), VT);
11476 // select (and Cond0, Cond1), X, Y
11477 // -> select Cond0, (select Cond1, X, Y), Y
11478 if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
11479 SDValue Cond0 = N0->getOperand(Num: 0);
11480 SDValue Cond1 = N0->getOperand(Num: 1);
11481 SDValue InnerSelect =
11482 DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond1, N2: N1, N3: N2, Flags);
11483 if (normalizeToSequence || !InnerSelect.use_empty())
11484 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond0,
11485 N2: InnerSelect, N3: N2, Flags);
11486 // Cleanup on failure.
11487 if (InnerSelect.use_empty())
11488 recursivelyDeleteUnusedNodes(N: InnerSelect.getNode());
11489 }
11490 // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
11491 if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
11492 SDValue Cond0 = N0->getOperand(Num: 0);
11493 SDValue Cond1 = N0->getOperand(Num: 1);
11494 SDValue InnerSelect = DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(),
11495 N1: Cond1, N2: N1, N3: N2, Flags);
11496 if (normalizeToSequence || !InnerSelect.use_empty())
11497 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Cond0, N2: N1,
11498 N3: InnerSelect, Flags);
11499 // Cleanup on failure.
11500 if (InnerSelect.use_empty())
11501 recursivelyDeleteUnusedNodes(N: InnerSelect.getNode());
11502 }
11503
11504 // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
11505 if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
11506 SDValue N1_0 = N1->getOperand(Num: 0);
11507 SDValue N1_1 = N1->getOperand(Num: 1);
11508 SDValue N1_2 = N1->getOperand(Num: 2);
11509 if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
11510 // Create the actual and node if we can generate good code for it.
11511 if (!normalizeToSequence) {
11512 SDValue And = DAG.getNode(Opcode: ISD::AND, DL, VT: N0.getValueType(), N1: N0, N2: N1_0);
11513 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: And, N2: N1_1,
11514 N3: N2, Flags);
11515 }
11516 // Otherwise see if we can optimize the "and" to a better pattern.
11517 if (SDValue Combined = visitANDLike(N0, N1: N1_0, N)) {
11518 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Combined, N2: N1_1,
11519 N3: N2, Flags);
11520 }
11521 }
11522 }
11523 // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
11524 if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
11525 SDValue N2_0 = N2->getOperand(Num: 0);
11526 SDValue N2_1 = N2->getOperand(Num: 1);
11527 SDValue N2_2 = N2->getOperand(Num: 2);
11528 if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
11529 // Create the actual or node if we can generate good code for it.
11530 if (!normalizeToSequence) {
11531 SDValue Or = DAG.getNode(Opcode: ISD::OR, DL, VT: N0.getValueType(), N1: N0, N2: N2_0);
11532 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Or, N2: N1,
11533 N3: N2_2, Flags);
11534 }
11535 // Otherwise see if we can optimize to a better pattern.
11536 if (SDValue Combined = visitORLike(N0, N1: N2_0, DL))
11537 return DAG.getNode(Opcode: ISD::SELECT, DL, VT: N1.getValueType(), N1: Combined, N2: N1,
11538 N3: N2_2, Flags);
11539 }
11540 }
11541 }
11542
11543 // Fold selects based on a setcc into other things, such as min/max/abs.
11544 if (N0.getOpcode() == ISD::SETCC) {
11545 SDValue Cond0 = N0.getOperand(i: 0), Cond1 = N0.getOperand(i: 1);
11546 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
11547
11548 // select (fcmp lt x, y), x, y -> fminnum x, y
11549 // select (fcmp gt x, y), x, y -> fmaxnum x, y
11550 //
11551 // This is OK if we don't care what happens if either operand is a NaN.
11552 if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS: N1, RHS: N2, TLI))
11553 if (SDValue FMinMax =
11554 combineMinNumMaxNum(DL, VT, LHS: Cond0, RHS: Cond1, True: N1, False: N2, CC))
11555 return FMinMax;
11556
11557 // Use 'unsigned add with overflow' to optimize an unsigned saturating add.
11558 // This is conservatively limited to pre-legal-operations to give targets
11559 // a chance to reverse the transform if they want to do that. Also, it is
11560 // unlikely that the pattern would be formed late, so it's probably not
11561 // worth going through the other checks.
11562 if (!LegalOperations && TLI.isOperationLegalOrCustom(Op: ISD::UADDO, VT) &&
11563 CC == ISD::SETUGT && N0.hasOneUse() && isAllOnesConstant(V: N1) &&
11564 N2.getOpcode() == ISD::ADD && Cond0 == N2.getOperand(i: 0)) {
11565 auto *C = dyn_cast<ConstantSDNode>(Val: N2.getOperand(i: 1));
11566 auto *NotC = dyn_cast<ConstantSDNode>(Val&: Cond1);
11567 if (C && NotC && C->getAPIntValue() == ~NotC->getAPIntValue()) {
11568 // select (setcc Cond0, ~C, ugt), -1, (add Cond0, C) -->
11569 // uaddo Cond0, C; select uaddo.1, -1, uaddo.0
11570 //
11571 // The IR equivalent of this transform would have this form:
11572 // %a = add %x, C
11573 // %c = icmp ugt %x, ~C
11574 // %r = select %c, -1, %a
11575 // =>
11576 // %u = call {iN,i1} llvm.uadd.with.overflow(%x, C)
11577 // %u0 = extractvalue %u, 0
11578 // %u1 = extractvalue %u, 1
11579 // %r = select %u1, -1, %u0
11580 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: VT0);
11581 SDValue UAO = DAG.getNode(Opcode: ISD::UADDO, DL, VTList: VTs, N1: Cond0, N2: N2.getOperand(i: 1));
11582 return DAG.getSelect(DL, VT, Cond: UAO.getValue(R: 1), LHS: N1, RHS: UAO.getValue(R: 0));
11583 }
11584 }
11585
11586 if (TLI.isOperationLegal(Op: ISD::SELECT_CC, VT) ||
11587 (!LegalOperations &&
11588 TLI.isOperationLegalOrCustom(Op: ISD::SELECT_CC, VT))) {
11589 // Any flags available in a select/setcc fold will be on the setcc as they
11590 // migrated from fcmp.
11591 Flags = N0->getFlags();
11592 SDValue SelectNode = DAG.getNode(Opcode: ISD::SELECT_CC, DL, VT, N1: Cond0, N2: Cond1, N3: N1,
11593 N4: N2, N5: N0.getOperand(i: 2));
11594 SelectNode->setFlags(Flags);
11595 return SelectNode;
11596 }
11597
11598 if (SDValue NewSel = SimplifySelect(DL, N0, N1, N2))
11599 return NewSel;
11600 }
11601
11602 if (!VT.isVector())
11603 if (SDValue BinOp = foldSelectOfBinops(N))
11604 return BinOp;
11605
11606 if (SDValue R = combineSelectAsExtAnd(Cond: N0, T: N1, F: N2, DL, DAG))
11607 return R;
11608
11609 return SDValue();
11610}
11611
11612// This function assumes all the vselect's arguments are CONCAT_VECTORS
11613// nodes and that the condition is a BV of ConstantSDNodes (or undefs).
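// For example:
//   vselect <0,0,-1,-1>, (concat_vectors A, B), (concat_vectors C, D)
//     --> concat_vectors C, B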
11614static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
11615 SDLoc DL(N);
11616 SDValue Cond = N->getOperand(Num: 0);
11617 SDValue LHS = N->getOperand(Num: 1);
11618 SDValue RHS = N->getOperand(Num: 2);
11619 EVT VT = N->getValueType(ResNo: 0);
11620 int NumElems = VT.getVectorNumElements();
11621 assert(LHS.getOpcode() == ISD::CONCAT_VECTORS &&
11622 RHS.getOpcode() == ISD::CONCAT_VECTORS &&
11623 Cond.getOpcode() == ISD::BUILD_VECTOR);
11624
11625 // CONCAT_VECTORS can take an arbitrary number of arguments. We only care
11626 // about binary ones here.
11627 if (LHS->getNumOperands() != 2 || RHS->getNumOperands() != 2)
11628 return SDValue();
11629
11630 // We're sure we have an even number of elements due to the
11631 // concat_vectors we have as arguments to vselect.
11632 // Skip BV elements until we find one that's not an UNDEF. After we find a
11633 // non-UNDEF element, keep looping until we get to half the length of the BV
11634 // and check that all the non-undef elements are the same value.
11635 ConstantSDNode *BottomHalf = nullptr;
11636 for (int i = 0; i < NumElems / 2; ++i) {
11637 if (Cond->getOperand(Num: i)->isUndef())
11638 continue;
11639
11640 if (BottomHalf == nullptr)
11641 BottomHalf = cast<ConstantSDNode>(Val: Cond.getOperand(i));
11642 else if (Cond->getOperand(Num: i).getNode() != BottomHalf)
11643 return SDValue();
11644 }
11645
11646 // Do the same for the second half of the BuildVector
11647 ConstantSDNode *TopHalf = nullptr;
11648 for (int i = NumElems / 2; i < NumElems; ++i) {
11649 if (Cond->getOperand(Num: i)->isUndef())
11650 continue;
11651
11652 if (TopHalf == nullptr)
11653 TopHalf = cast<ConstantSDNode>(Val: Cond.getOperand(i));
11654 else if (Cond->getOperand(Num: i).getNode() != TopHalf)
11655 return SDValue();
11656 }
11657
11658 assert(TopHalf && BottomHalf &&
11659 "One half of the selector was all UNDEFs and the other was all the "
11660 "same value. This should have been addressed before this function.");
11661 return DAG.getNode(
11662 Opcode: ISD::CONCAT_VECTORS, DL, VT,
11663 N1: BottomHalf->isZero() ? RHS->getOperand(Num: 0) : LHS->getOperand(Num: 0),
11664 N2: TopHalf->isZero() ? RHS->getOperand(Num: 1) : LHS->getOperand(Num: 1));
11665}
11666
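// Try to fold a uniform (splat) component of a gather/scatter index into the
// base pointer. For example, a gather with base %p and index (splat %off)
// becomes a gather with base (%p + %off) and an all-zero index, and an index
// of the form (add (splat %off), %vec) becomes base (%p + %off) with index
// %vec.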
11667bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
11668 SelectionDAG &DAG, const SDLoc &DL) {
11669
11670 // Only perform the transformation when existing operands can be reused.
11671 if (IndexIsScaled)
11672 return false;
11673
11674 if (!isNullConstant(V: BasePtr) && !Index.hasOneUse())
11675 return false;
11676
11677 EVT VT = BasePtr.getValueType();
11678
11679 if (SDValue SplatVal = DAG.getSplatValue(V: Index);
11680 SplatVal && !isNullConstant(V: SplatVal) &&
11681 SplatVal.getValueType() == VT) {
11682 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
11683 Index = DAG.getSplat(VT: Index.getValueType(), DL, Op: DAG.getConstant(Val: 0, DL, VT));
11684 return true;
11685 }
11686
11687 if (Index.getOpcode() != ISD::ADD)
11688 return false;
11689
11690 if (SDValue SplatVal = DAG.getSplatValue(V: Index.getOperand(i: 0));
11691 SplatVal && SplatVal.getValueType() == VT) {
11692 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
11693 Index = Index.getOperand(i: 1);
11694 return true;
11695 }
11696 if (SDValue SplatVal = DAG.getSplatValue(V: Index.getOperand(i: 1));
11697 SplatVal && SplatVal.getValueType() == VT) {
11698 BasePtr = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: BasePtr, N2: SplatVal);
11699 Index = Index.getOperand(i: 0);
11700 return true;
11701 }
11702 return false;
11703}
11704
11705// Fold sext/zext of index into index type.
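// For example, if the target reports via shouldRemoveExtendFromGSIndex that it
// can use the narrower index type directly, an index of (zero_extend %idx) is
// replaced by %idx and the index type becomes unsigned.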
11706bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
11707 SelectionDAG &DAG) {
11708 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
11709
11710 // It's always safe to look through zero extends.
11711 if (Index.getOpcode() == ISD::ZERO_EXTEND) {
11712 if (TLI.shouldRemoveExtendFromGSIndex(Extend: Index, DataVT)) {
11713 IndexType = ISD::UNSIGNED_SCALED;
11714 Index = Index.getOperand(i: 0);
11715 return true;
11716 }
11717 if (ISD::isIndexTypeSigned(IndexType)) {
11718 IndexType = ISD::UNSIGNED_SCALED;
11719 return true;
11720 }
11721 }
11722
11723 // It's only safe to look through sign extends when Index is signed.
11724 if (Index.getOpcode() == ISD::SIGN_EXTEND &&
11725 ISD::isIndexTypeSigned(IndexType) &&
11726 TLI.shouldRemoveExtendFromGSIndex(Extend: Index, DataVT)) {
11727 Index = Index.getOperand(i: 0);
11728 return true;
11729 }
11730
11731 return false;
11732}
11733
11734SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
11735 VPScatterSDNode *MSC = cast<VPScatterSDNode>(Val: N);
11736 SDValue Mask = MSC->getMask();
11737 SDValue Chain = MSC->getChain();
11738 SDValue Index = MSC->getIndex();
11739 SDValue Scale = MSC->getScale();
11740 SDValue StoreVal = MSC->getValue();
11741 SDValue BasePtr = MSC->getBasePtr();
11742 SDValue VL = MSC->getVectorLength();
11743 ISD::MemIndexType IndexType = MSC->getIndexType();
11744 SDLoc DL(N);
11745
11746 // Zap scatters with a zero mask.
11747 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
11748 return Chain;
11749
11750 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MSC->isIndexScaled(), DAG, DL)) {
11751 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11752 return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11753 DL, Ops, MSC->getMemOperand(), IndexType);
11754 }
11755
11756 if (refineIndexType(Index, IndexType, DataVT: StoreVal.getValueType(), DAG)) {
11757 SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
11758 return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11759 DL, Ops, MSC->getMemOperand(), IndexType);
11760 }
11761
11762 return SDValue();
11763}
11764
11765SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
11766 MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Val: N);
11767 SDValue Mask = MSC->getMask();
11768 SDValue Chain = MSC->getChain();
11769 SDValue Index = MSC->getIndex();
11770 SDValue Scale = MSC->getScale();
11771 SDValue StoreVal = MSC->getValue();
11772 SDValue BasePtr = MSC->getBasePtr();
11773 ISD::MemIndexType IndexType = MSC->getIndexType();
11774 SDLoc DL(N);
11775
11776 // Zap scatters with a zero mask.
11777 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
11778 return Chain;
11779
11780 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MSC->isIndexScaled(), DAG, DL)) {
11781 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11782 return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11783 DL, Ops, MSC->getMemOperand(), IndexType,
11784 MSC->isTruncatingStore());
11785 }
11786
11787 if (refineIndexType(Index, IndexType, DataVT: StoreVal.getValueType(), DAG)) {
11788 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
11789 return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
11790 DL, Ops, MSC->getMemOperand(), IndexType,
11791 MSC->isTruncatingStore());
11792 }
11793
11794 return SDValue();
11795}
11796
11797SDValue DAGCombiner::visitMSTORE(SDNode *N) {
11798 MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(Val: N);
11799 SDValue Mask = MST->getMask();
11800 SDValue Chain = MST->getChain();
11801 SDValue Value = MST->getValue();
11802 SDValue Ptr = MST->getBasePtr();
11803 SDLoc DL(N);
11804
11805 // Zap masked stores with a zero mask.
11806 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
11807 return Chain;
11808
11809 // Remove a masked store if base pointers and masks are equal.
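// For example, if the chained (earlier) masked store writes the same lanes to
// the same base pointer and is no wider than this store, every byte it wrote
// is overwritten here, so the earlier store is dead.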
11810 if (MaskedStoreSDNode *MST1 = dyn_cast<MaskedStoreSDNode>(Val&: Chain)) {
11811 if (MST->isUnindexed() && MST->isSimple() && MST1->isUnindexed() &&
11812 MST1->isSimple() && MST1->getBasePtr() == Ptr &&
11813 !MST->getBasePtr().isUndef() &&
11814 ((Mask == MST1->getMask() && MST->getMemoryVT().getStoreSize() ==
11815 MST1->getMemoryVT().getStoreSize()) ||
11816 ISD::isConstantSplatVectorAllOnes(N: Mask.getNode())) &&
11817 TypeSize::isKnownLE(LHS: MST1->getMemoryVT().getStoreSize(),
11818 RHS: MST->getMemoryVT().getStoreSize())) {
11819 CombineTo(N: MST1, Res: MST1->getChain());
11820 if (N->getOpcode() != ISD::DELETED_NODE)
11821 AddToWorklist(N);
11822 return SDValue(N, 0);
11823 }
11824 }
11825
11826 // If this is a masked store with an all-ones mask, we can use an unmasked store.
11827 // FIXME: Can we do this for indexed, compressing, or truncating stores?
11828 if (ISD::isConstantSplatVectorAllOnes(N: Mask.getNode()) && MST->isUnindexed() &&
11829 !MST->isCompressingStore() && !MST->isTruncatingStore())
11830 return DAG.getStore(Chain: MST->getChain(), dl: SDLoc(N), Val: MST->getValue(),
11831 Ptr: MST->getBasePtr(), PtrInfo: MST->getPointerInfo(),
11832 Alignment: MST->getOriginalAlign(),
11833 MMOFlags: MST->getMemOperand()->getFlags(), AAInfo: MST->getAAInfo());
11834
11835 // Try transforming N to an indexed store.
11836 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
11837 return SDValue(N, 0);
11838
11839 if (MST->isTruncatingStore() && MST->isUnindexed() &&
11840 Value.getValueType().isInteger() &&
11841 (!isa<ConstantSDNode>(Val: Value) ||
11842 !cast<ConstantSDNode>(Val&: Value)->isOpaque())) {
11843 APInt TruncDemandedBits =
11844 APInt::getLowBitsSet(numBits: Value.getScalarValueSizeInBits(),
11845 loBitsSet: MST->getMemoryVT().getScalarSizeInBits());
11846
11847 // See if we can simplify the operation with
11848 // SimplifyDemandedBits, which only works if the value has a single use.
11849 if (SimplifyDemandedBits(Op: Value, DemandedBits: TruncDemandedBits)) {
11850 // Re-visit the store if anything changed and the store hasn't been
11851 // merged with another node (in which case N is deleted).
11852 // SimplifyDemandedBits will add Value's node back to the worklist if
11853 // necessary, but we also need to re-visit the store node itself.
11854 if (N->getOpcode() != ISD::DELETED_NODE)
11855 AddToWorklist(N);
11856 return SDValue(N, 0);
11857 }
11858 }
11859
11860 // If this is a TRUNC followed by a masked store, fold this into a masked
11861 // truncating store. We can do this even if this is already a masked
11862 // truncstore.
11863 // TODO: Try combining to a masked compress store if possible.
11864 if ((Value.getOpcode() == ISD::TRUNCATE) && Value->hasOneUse() &&
11865 MST->isUnindexed() && !MST->isCompressingStore() &&
11866 TLI.canCombineTruncStore(ValVT: Value.getOperand(i: 0).getValueType(),
11867 MemVT: MST->getMemoryVT(), LegalOnly: LegalOperations)) {
11868 auto Mask = TLI.promoteTargetBoolean(DAG, Bool: MST->getMask(),
11869 ValVT: Value.getOperand(i: 0).getValueType());
11870 return DAG.getMaskedStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Base: Ptr,
11871 Offset: MST->getOffset(), Mask, MemVT: MST->getMemoryVT(),
11872 MMO: MST->getMemOperand(), AM: MST->getAddressingMode(),
11873 /*IsTruncating=*/true);
11874 }
11875
11876 return SDValue();
11877}
11878
11879SDValue DAGCombiner::visitVP_STRIDED_STORE(SDNode *N) {
11880 auto *SST = cast<VPStridedStoreSDNode>(Val: N);
11881 EVT EltVT = SST->getValue().getValueType().getVectorElementType();
11882 // Combine strided stores with unit-stride to a regular VP store.
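// For example, a strided store of <4 x i32> with a constant stride of 4 bytes
// writes its elements to consecutive addresses, which is exactly a VP store.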
11883 if (auto *CStride = dyn_cast<ConstantSDNode>(Val: SST->getStride());
11884 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
11885 return DAG.getStoreVP(Chain: SST->getChain(), dl: SDLoc(N), Val: SST->getValue(),
11886 Ptr: SST->getBasePtr(), Offset: SST->getOffset(), Mask: SST->getMask(),
11887 EVL: SST->getVectorLength(), MemVT: SST->getMemoryVT(),
11888 MMO: SST->getMemOperand(), AM: SST->getAddressingMode(),
11889 IsTruncating: SST->isTruncatingStore(), IsCompressing: SST->isCompressingStore());
11890 }
11891 return SDValue();
11892}
11893
11894SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
11895 VPGatherSDNode *MGT = cast<VPGatherSDNode>(Val: N);
11896 SDValue Mask = MGT->getMask();
11897 SDValue Chain = MGT->getChain();
11898 SDValue Index = MGT->getIndex();
11899 SDValue Scale = MGT->getScale();
11900 SDValue BasePtr = MGT->getBasePtr();
11901 SDValue VL = MGT->getVectorLength();
11902 ISD::MemIndexType IndexType = MGT->getIndexType();
11903 SDLoc DL(N);
11904
11905 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MGT->isIndexScaled(), DAG, DL)) {
11906 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
11907 return DAG.getGatherVP(
11908 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
11909 Ops, MGT->getMemOperand(), IndexType);
11910 }
11911
11912 if (refineIndexType(Index, IndexType, DataVT: N->getValueType(ResNo: 0), DAG)) {
11913 SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
11914 return DAG.getGatherVP(
11915 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
11916 Ops, MGT->getMemOperand(), IndexType);
11917 }
11918
11919 return SDValue();
11920}
11921
11922SDValue DAGCombiner::visitMGATHER(SDNode *N) {
11923 MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Val: N);
11924 SDValue Mask = MGT->getMask();
11925 SDValue Chain = MGT->getChain();
11926 SDValue Index = MGT->getIndex();
11927 SDValue Scale = MGT->getScale();
11928 SDValue PassThru = MGT->getPassThru();
11929 SDValue BasePtr = MGT->getBasePtr();
11930 ISD::MemIndexType IndexType = MGT->getIndexType();
11931 SDLoc DL(N);
11932
11933 // Zap gathers with a zero mask.
11934 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
11935 return CombineTo(N, Res0: PassThru, Res1: MGT->getChain());
11936
11937 if (refineUniformBase(BasePtr, Index, IndexIsScaled: MGT->isIndexScaled(), DAG, DL)) {
11938 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
11939 return DAG.getMaskedGather(
11940 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
11941 Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
11942 }
11943
11944 if (refineIndexType(Index, IndexType, DataVT: N->getValueType(ResNo: 0), DAG)) {
11945 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
11946 return DAG.getMaskedGather(
11947 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
11948 Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
11949 }
11950
11951 return SDValue();
11952}
11953
11954SDValue DAGCombiner::visitMLOAD(SDNode *N) {
11955 MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(Val: N);
11956 SDValue Mask = MLD->getMask();
11957 SDLoc DL(N);
11958
11959 // Zap masked loads with a zero mask.
11960 if (ISD::isConstantSplatVectorAllZeros(N: Mask.getNode()))
11961 return CombineTo(N, Res0: MLD->getPassThru(), Res1: MLD->getChain());
11962
11963 // If this is a masked load with an all-ones mask, we can use an unmasked load.
11964 // FIXME: Can we do this for indexed, expanding, or extending loads?
11965 if (ISD::isConstantSplatVectorAllOnes(N: Mask.getNode()) && MLD->isUnindexed() &&
11966 !MLD->isExpandingLoad() && MLD->getExtensionType() == ISD::NON_EXTLOAD) {
11967 SDValue NewLd = DAG.getLoad(
11968 VT: N->getValueType(ResNo: 0), dl: SDLoc(N), Chain: MLD->getChain(), Ptr: MLD->getBasePtr(),
11969 PtrInfo: MLD->getPointerInfo(), Alignment: MLD->getOriginalAlign(),
11970 MMOFlags: MLD->getMemOperand()->getFlags(), AAInfo: MLD->getAAInfo(), Ranges: MLD->getRanges());
11971 return CombineTo(N, Res0: NewLd, Res1: NewLd.getValue(R: 1));
11972 }
11973
11974 // Try transforming N to an indexed load.
11975 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
11976 return SDValue(N, 0);
11977
11978 return SDValue();
11979}
11980
11981SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
11982 auto *SLD = cast<VPStridedLoadSDNode>(Val: N);
11983 EVT EltVT = SLD->getValueType(ResNo: 0).getVectorElementType();
11984 // Combine strided loads with unit-stride to a regular VP load.
11985 if (auto *CStride = dyn_cast<ConstantSDNode>(Val: SLD->getStride());
11986 CStride && CStride->getZExtValue() == EltVT.getStoreSize()) {
11987 SDValue NewLd = DAG.getLoadVP(
11988 AM: SLD->getAddressingMode(), ExtType: SLD->getExtensionType(), VT: SLD->getValueType(ResNo: 0),
11989 dl: SDLoc(N), Chain: SLD->getChain(), Ptr: SLD->getBasePtr(), Offset: SLD->getOffset(),
11990 Mask: SLD->getMask(), EVL: SLD->getVectorLength(), MemVT: SLD->getMemoryVT(),
11991 MMO: SLD->getMemOperand(), IsExpanding: SLD->isExpandingLoad());
11992 return CombineTo(N, Res0: NewLd, Res1: NewLd.getValue(R: 1));
11993 }
11994 return SDValue();
11995}
11996
11997/// A vector select of 2 constant vectors can be simplified to math/logic to
11998/// avoid a variable select instruction and possibly avoid constant loads.
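/// For example, vselect %cond, <3,3,3,3>, <2,2,2,2> becomes
/// add (zext %cond), <2,2,2,2>: true lanes give 1 + 2 = 3, false lanes give 2.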
11999SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
12000 SDValue Cond = N->getOperand(Num: 0);
12001 SDValue N1 = N->getOperand(Num: 1);
12002 SDValue N2 = N->getOperand(Num: 2);
12003 EVT VT = N->getValueType(ResNo: 0);
12004 if (!Cond.hasOneUse() || Cond.getScalarValueSizeInBits() != 1 ||
12005 !shouldConvertSelectOfConstantsToMath(Cond, VT, TLI) ||
12006 !ISD::isBuildVectorOfConstantSDNodes(N: N1.getNode()) ||
12007 !ISD::isBuildVectorOfConstantSDNodes(N: N2.getNode()))
12008 return SDValue();
12009
12010 // Check if we can use the condition value to increment/decrement a single
12011 // constant value. This simplifies a select to an add and removes a constant
12012 // load/materialization from the general case.
12013 bool AllAddOne = true;
12014 bool AllSubOne = true;
12015 unsigned Elts = VT.getVectorNumElements();
12016 for (unsigned i = 0; i != Elts; ++i) {
12017 SDValue N1Elt = N1.getOperand(i);
12018 SDValue N2Elt = N2.getOperand(i);
12019 if (N1Elt.isUndef() || N2Elt.isUndef())
12020 continue;
12021 if (N1Elt.getValueType() != N2Elt.getValueType())
12022 continue;
12023
12024 const APInt &C1 = N1Elt->getAsAPIntVal();
12025 const APInt &C2 = N2Elt->getAsAPIntVal();
12026 if (C1 != C2 + 1)
12027 AllAddOne = false;
12028 if (C1 != C2 - 1)
12029 AllSubOne = false;
12030 }
12031
12032 // Further simplifications for the extra-special cases where the constants are
12033 // all 0 or all -1 should be implemented as folds of these patterns.
12034 SDLoc DL(N);
12035 if (AllAddOne || AllSubOne) {
12036 // vselect <N x i1> Cond, C+1, C --> add (zext Cond), C
12037 // vselect <N x i1> Cond, C-1, C --> add (sext Cond), C
12038 auto ExtendOpcode = AllAddOne ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND;
12039 SDValue ExtendedCond = DAG.getNode(Opcode: ExtendOpcode, DL, VT, Operand: Cond);
12040 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: ExtendedCond, N2);
12041 }
12042
12043 // select Cond, Pow2C, 0 --> (zext Cond) << log2(Pow2C)
12044 APInt Pow2C;
12045 if (ISD::isConstantSplatVector(N: N1.getNode(), SplatValue&: Pow2C) && Pow2C.isPowerOf2() &&
12046 isNullOrNullSplat(V: N2)) {
12047 SDValue ZextCond = DAG.getZExtOrTrunc(Op: Cond, DL, VT);
12048 SDValue ShAmtC = DAG.getConstant(Val: Pow2C.exactLogBase2(), DL, VT);
12049 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: ZextCond, N2: ShAmtC);
12050 }
12051
12052 if (SDValue V = foldSelectOfConstantsUsingSra(N, DAG))
12053 return V;
12054
12055 // The general case for select-of-constants:
12056 // vselect <N x i1> Cond, C1, C2 --> xor (and (sext Cond), (C1^C2)), C2
12057 // ...but that only makes sense if a vselect is slower than 2 logic ops, so
12058 // leave that to a machine-specific pass.
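// (When Cond is true the sext is all-ones, the 'and' keeps C1^C2, and xor
// with C2 yields C1; when Cond is false the xor simply returns C2.)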
12059 return SDValue();
12060}
12061
12062SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
12063 SDValue N0 = N->getOperand(Num: 0);
12064 SDValue N1 = N->getOperand(Num: 1);
12065 SDValue N2 = N->getOperand(Num: 2);
12066
12067 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
12068 return V;
12069
12070 if (SDValue V = foldBoolSelectToLogic<VPMatchContext>(N, DAG))
12071 return V;
12072
12073 return SDValue();
12074}
12075
12076SDValue DAGCombiner::visitVSELECT(SDNode *N) {
12077 SDValue N0 = N->getOperand(Num: 0);
12078 SDValue N1 = N->getOperand(Num: 1);
12079 SDValue N2 = N->getOperand(Num: 2);
12080 EVT VT = N->getValueType(ResNo: 0);
12081 SDLoc DL(N);
12082
12083 if (SDValue V = DAG.simplifySelect(Cond: N0, TVal: N1, FVal: N2))
12084 return V;
12085
12086 if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DAG))
12087 return V;
12088
12089 // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
12090 if (SDValue F = extractBooleanFlip(V: N0, DAG, TLI, Force: false))
12091 return DAG.getSelect(DL, VT, Cond: F, LHS: N2, RHS: N1);
12092
12093 // select (sext m), (add X, C), X --> (add X, (and C, (sext m))))
12094 if (N1.getOpcode() == ISD::ADD && N1.getOperand(i: 0) == N2 && N1->hasOneUse() &&
12095 DAG.isConstantIntBuildVectorOrConstantInt(N: N1.getOperand(i: 1)) &&
12096 N0.getScalarValueSizeInBits() == N1.getScalarValueSizeInBits() &&
12097 TLI.getBooleanContents(Type: N0.getValueType()) ==
12098 TargetLowering::ZeroOrNegativeOneBooleanContent) {
12099 return DAG.getNode(
12100 Opcode: ISD::ADD, DL, VT: N1.getValueType(), N1: N2,
12101 N2: DAG.getNode(Opcode: ISD::AND, DL, VT: N0.getValueType(), N1: N1.getOperand(i: 1), N2: N0));
12102 }
12103
12104 // Canonicalize integer abs.
12105 // vselect (setg[te] X, 0), X, -X ->
12106 // vselect (setgt X, -1), X, -X ->
12107 // vselect (setl[te] X, 0), -X, X ->
12108 // Y = sra (X, size(X)-1); xor (add (X, Y), Y)
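// For example, for i32 X = -5: Y = -1, the add gives -6, and xor with -1
// gives 5; for X = 5: Y = 0, so the add and xor leave the value unchanged.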
12109 if (N0.getOpcode() == ISD::SETCC) {
12110 SDValue LHS = N0.getOperand(i: 0), RHS = N0.getOperand(i: 1);
12111 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
12112 bool isAbs = false;
12113 bool RHSIsAllZeros = ISD::isBuildVectorAllZeros(N: RHS.getNode());
12114
12115 if (((RHSIsAllZeros && (CC == ISD::SETGT || CC == ISD::SETGE)) ||
12116 (ISD::isBuildVectorAllOnes(N: RHS.getNode()) && CC == ISD::SETGT)) &&
12117 N1 == LHS && N2.getOpcode() == ISD::SUB && N1 == N2.getOperand(i: 1))
12118 isAbs = ISD::isBuildVectorAllZeros(N: N2.getOperand(i: 0).getNode());
12119 else if ((RHSIsAllZeros && (CC == ISD::SETLT || CC == ISD::SETLE)) &&
12120 N2 == LHS && N1.getOpcode() == ISD::SUB && N2 == N1.getOperand(i: 1))
12121 isAbs = ISD::isBuildVectorAllZeros(N: N1.getOperand(i: 0).getNode());
12122
12123 if (isAbs) {
12124 if (TLI.isOperationLegalOrCustom(Op: ISD::ABS, VT))
12125 return DAG.getNode(Opcode: ISD::ABS, DL, VT, Operand: LHS);
12126
12127 SDValue Shift = DAG.getNode(Opcode: ISD::SRA, DL, VT, N1: LHS,
12128 N2: DAG.getConstant(Val: VT.getScalarSizeInBits() - 1,
12129 DL, VT: getShiftAmountTy(LHSTy: VT)));
12130 SDValue Add = DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: LHS, N2: Shift);
12131 AddToWorklist(N: Shift.getNode());
12132 AddToWorklist(N: Add.getNode());
12133 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: Add, N2: Shift);
12134 }
12135
12136 // vselect x, y (fcmp lt x, y) -> fminnum x, y
12137 // vselect x, y (fcmp gt x, y) -> fmaxnum x, y
12138 //
12139 // This is OK if we don't care about what happens if either operand is a
12140 // NaN.
12141 //
12142 if (N0.hasOneUse() && isLegalToCombineMinNumMaxNum(DAG, LHS, RHS, TLI)) {
12143 if (SDValue FMinMax = combineMinNumMaxNum(DL, VT, LHS, RHS, True: N1, False: N2, CC))
12144 return FMinMax;
12145 }
12146
12147 if (SDValue S = PerformMinMaxFpToSatCombine(N0: LHS, N1: RHS, N2: N1, N3: N2, CC, DAG))
12148 return S;
12149 if (SDValue S = PerformUMinFpToSatCombine(N0: LHS, N1: RHS, N2: N1, N3: N2, CC, DAG))
12150 return S;
12151
12152 // If this select has a condition (setcc) with narrower operands than the
12153 // select, try to widen the compare to match the select width.
12154 // TODO: This should be extended to handle any constant.
12155 // TODO: This could be extended to handle non-loading patterns, but that
12156 // requires thorough testing to avoid regressions.
12157 if (isNullOrNullSplat(V: RHS)) {
12158 EVT NarrowVT = LHS.getValueType();
12159 EVT WideVT = N1.getValueType().changeVectorElementTypeToInteger();
12160 EVT SetCCVT = getSetCCResultType(VT: LHS.getValueType());
12161 unsigned SetCCWidth = SetCCVT.getScalarSizeInBits();
12162 unsigned WideWidth = WideVT.getScalarSizeInBits();
12163 bool IsSigned = isSignedIntSetCC(Code: CC);
12164 auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
12165 if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
12166 SetCCWidth != 1 && SetCCWidth < WideWidth &&
12167 TLI.isLoadExtLegalOrCustom(ExtType: LoadExtOpcode, ValVT: WideVT, MemVT: NarrowVT) &&
12168 TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT: WideVT)) {
12169 // Both compare operands can be widened for free. The LHS can use an
12170 // extended load, and the RHS is a constant:
12171 // vselect (ext (setcc load(X), C)), N1, N2 -->
12172 // vselect (setcc extload(X), C'), N1, N2
12173 auto ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
12174 SDValue WideLHS = DAG.getNode(Opcode: ExtOpcode, DL, VT: WideVT, Operand: LHS);
12175 SDValue WideRHS = DAG.getNode(Opcode: ExtOpcode, DL, VT: WideVT, Operand: RHS);
12176 EVT WideSetCCVT = getSetCCResultType(VT: WideVT);
12177 SDValue WideSetCC = DAG.getSetCC(DL, VT: WideSetCCVT, LHS: WideLHS, RHS: WideRHS, Cond: CC);
12178 return DAG.getSelect(DL, VT: N1.getValueType(), Cond: WideSetCC, LHS: N1, RHS: N2);
12179 }
12180 }
12181
12182 // Match VSELECTs with absolute difference patterns.
12183 // (vselect (setcc a, b, set?gt), (sub a, b), (sub b, a)) --> (abd? a, b)
12184 // (vselect (setcc a, b, set?ge), (sub a, b), (sub b, a)) --> (abd? a, b)
12185 // (vselect (setcc a, b, set?lt), (sub b, a), (sub a, b)) --> (abd? a, b)
12186 // (vselect (setcc a, b, set?le), (sub b, a), (sub a, b)) --> (abd? a, b)
12187 if (N1.getOpcode() == ISD::SUB && N2.getOpcode() == ISD::SUB &&
12188 N1.getOperand(i: 0) == N2.getOperand(i: 1) &&
12189 N1.getOperand(i: 1) == N2.getOperand(i: 0)) {
12190 bool IsSigned = isSignedIntSetCC(Code: CC);
12191 unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
12192 if (hasOperation(Opcode: ABDOpc, VT)) {
12193 switch (CC) {
12194 case ISD::SETGT:
12195 case ISD::SETGE:
12196 case ISD::SETUGT:
12197 case ISD::SETUGE:
12198 if (LHS == N1.getOperand(i: 0) && RHS == N1.getOperand(i: 1))
12199 return DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS);
12200 break;
12201 case ISD::SETLT:
12202 case ISD::SETLE:
12203 case ISD::SETULT:
12204 case ISD::SETULE:
12205 if (RHS == N1.getOperand(i: 0) && LHS == N1.getOperand(i: 1))
12206 return DAG.getNode(Opcode: ABDOpc, DL, VT, N1: LHS, N2: RHS);
12207 break;
12208 default:
12209 break;
12210 }
12211 }
12212 }
12213
12214 // Match VSELECTs into add with unsigned saturation.
12215 if (hasOperation(Opcode: ISD::UADDSAT, VT)) {
12216 // Check if one of the arms of the VSELECT is a vector with all bits set.
12217 // If it's on the left side, invert the predicate to simplify logic below.
12218 SDValue Other;
12219 ISD::CondCode SatCC = CC;
12220 if (ISD::isConstantSplatVectorAllOnes(N: N1.getNode())) {
12221 Other = N2;
12222 SatCC = ISD::getSetCCInverse(Operation: SatCC, Type: VT.getScalarType());
12223 } else if (ISD::isConstantSplatVectorAllOnes(N: N2.getNode())) {
12224 Other = N1;
12225 }
12226
12227 if (Other && Other.getOpcode() == ISD::ADD) {
12228 SDValue CondLHS = LHS, CondRHS = RHS;
12229 SDValue OpLHS = Other.getOperand(i: 0), OpRHS = Other.getOperand(i: 1);
12230
12231 // Canonicalize condition operands.
12232 if (SatCC == ISD::SETUGE) {
12233 std::swap(a&: CondLHS, b&: CondRHS);
12234 SatCC = ISD::SETULE;
12235 }
12236
12237 // We can test against either of the addition operands.
12238 // x <= x+y ? x+y : ~0 --> uaddsat x, y
12239 // x+y >= x ? x+y : ~0 --> uaddsat x, y
12240 if (SatCC == ISD::SETULE && Other == CondRHS &&
12241 (OpLHS == CondLHS || OpRHS == CondLHS))
12242 return DAG.getNode(Opcode: ISD::UADDSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12243
12244 if (OpRHS.getOpcode() == CondRHS.getOpcode() &&
12245 (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
12246 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) &&
12247 CondLHS == OpLHS) {
12248 // If the RHS is a constant we have to reverse the const
12249 // canonicalization.
12250 // x >= ~C ? x+C : ~0 --> uaddsat x, C
12251 auto MatchUADDSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
12252 return Cond->getAPIntValue() == ~Op->getAPIntValue();
12253 };
12254 if (SatCC == ISD::SETULE &&
12255 ISD::matchBinaryPredicate(LHS: OpRHS, RHS: CondRHS, Match: MatchUADDSAT))
12256 return DAG.getNode(Opcode: ISD::UADDSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12257 }
12258 }
12259 }
12260
12261 // Match VSELECTs into sub with unsigned saturation.
12262 if (hasOperation(Opcode: ISD::USUBSAT, VT)) {
12263 // Check if one of the arms of the VSELECT is a zero vector. If it's on
12264 // the left side, invert the predicate to simplify logic below.
12265 SDValue Other;
12266 ISD::CondCode SatCC = CC;
12267 if (ISD::isConstantSplatVectorAllZeros(N: N1.getNode())) {
12268 Other = N2;
12269 SatCC = ISD::getSetCCInverse(Operation: SatCC, Type: VT.getScalarType());
12270 } else if (ISD::isConstantSplatVectorAllZeros(N: N2.getNode())) {
12271 Other = N1;
12272 }
12273
12274 // zext(x) >= y ? trunc(zext(x) - y) : 0
12275 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
12276 // zext(x) > y ? trunc(zext(x) - y) : 0
12277 // --> usubsat(trunc(zext(x)),trunc(umin(y,SatLimit)))
12278 if (Other && Other.getOpcode() == ISD::TRUNCATE &&
12279 Other.getOperand(i: 0).getOpcode() == ISD::SUB &&
12280 (SatCC == ISD::SETUGE || SatCC == ISD::SETUGT)) {
12281 SDValue OpLHS = Other.getOperand(i: 0).getOperand(i: 0);
12282 SDValue OpRHS = Other.getOperand(i: 0).getOperand(i: 1);
12283 if (LHS == OpLHS && RHS == OpRHS && LHS.getOpcode() == ISD::ZERO_EXTEND)
12284 if (SDValue R = getTruncatedUSUBSAT(DstVT: VT, SrcVT: LHS.getValueType(), LHS, RHS,
12285 DAG, DL))
12286 return R;
12287 }
12288
12289 if (Other && Other.getNumOperands() == 2) {
12290 SDValue CondRHS = RHS;
12291 SDValue OpLHS = Other.getOperand(i: 0), OpRHS = Other.getOperand(i: 1);
12292
12293 if (OpLHS == LHS) {
12294 // Look for a general sub with unsigned saturation first.
12295 // x >= y ? x-y : 0 --> usubsat x, y
12296 // x > y ? x-y : 0 --> usubsat x, y
12297 if ((SatCC == ISD::SETUGE || SatCC == ISD::SETUGT) &&
12298 Other.getOpcode() == ISD::SUB && OpRHS == CondRHS)
12299 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12300
12301 if (OpRHS.getOpcode() == ISD::BUILD_VECTOR ||
12302 OpRHS.getOpcode() == ISD::SPLAT_VECTOR) {
12303 if (CondRHS.getOpcode() == ISD::BUILD_VECTOR ||
12304 CondRHS.getOpcode() == ISD::SPLAT_VECTOR) {
12305 // If the RHS is a constant we have to reverse the const
12306 // canonicalization.
12307 // x > C-1 ? x+-C : 0 --> usubsat x, C
12308 auto MatchUSUBSAT = [](ConstantSDNode *Op, ConstantSDNode *Cond) {
12309 return (!Op && !Cond) ||
12310 (Op && Cond &&
12311 Cond->getAPIntValue() == (-Op->getAPIntValue() - 1));
12312 };
12313 if (SatCC == ISD::SETUGT && Other.getOpcode() == ISD::ADD &&
12314 ISD::matchBinaryPredicate(LHS: OpRHS, RHS: CondRHS, Match: MatchUSUBSAT,
12315 /*AllowUndefs*/ true)) {
12316 OpRHS = DAG.getNegative(Val: OpRHS, DL, VT);
12317 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12318 }
12319
12320 // Another special case: If C was a sign bit, the sub has been
12321 // canonicalized into a xor.
12322 // FIXME: Would it be better to use computeKnownBits to
12323 // determine whether it's safe to decanonicalize the xor?
12324 // x s< 0 ? x^C : 0 --> usubsat x, C
12325 APInt SplatValue;
12326 if (SatCC == ISD::SETLT && Other.getOpcode() == ISD::XOR &&
12327 ISD::isConstantSplatVector(N: OpRHS.getNode(), SplatValue) &&
12328 ISD::isConstantSplatVectorAllZeros(N: CondRHS.getNode()) &&
12329 SplatValue.isSignMask()) {
12330 // Note that we have to rebuild the RHS constant here to
12331 // ensure we don't rely on particular values of undef lanes.
12332 OpRHS = DAG.getConstant(Val: SplatValue, DL, VT);
12333 return DAG.getNode(Opcode: ISD::USUBSAT, DL, VT, N1: OpLHS, N2: OpRHS);
12334 }
12335 }
12336 }
12337 }
12338 }
12339 }
12340 }
12341
12342 if (SimplifySelectOps(SELECT: N, LHS: N1, RHS: N2))
12343 return SDValue(N, 0); // Don't revisit N.
12344
12345 // Fold (vselect all_ones, N1, N2) -> N1
12346 if (ISD::isConstantSplatVectorAllOnes(N: N0.getNode()))
12347 return N1;
12348 // Fold (vselect all_zeros, N1, N2) -> N2
12349 if (ISD::isConstantSplatVectorAllZeros(N: N0.getNode()))
12350 return N2;
12351
12352 // ConvertSelectToConcatVector assumes both of the above checks for
12353 // (vselect (build_vector all{ones,zeros}) ...) have already been made and
12354 // addressed.
12355 if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
12356 N2.getOpcode() == ISD::CONCAT_VECTORS &&
12357 ISD::isBuildVectorOfConstantSDNodes(N: N0.getNode())) {
12358 if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
12359 return CV;
12360 }
12361
12362 if (SDValue V = foldVSelectOfConstants(N))
12363 return V;
12364
12365 if (hasOperation(Opcode: ISD::SRA, VT))
12366 if (SDValue V = foldVSelectToSignBitSplatMask(N, DAG))
12367 return V;
12368
12369 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
12370 return SDValue(N, 0);
12371
12372 return SDValue();
12373}
12374
12375SDValue DAGCombiner::visitSELECT_CC(SDNode *N) {
12376 SDValue N0 = N->getOperand(Num: 0);
12377 SDValue N1 = N->getOperand(Num: 1);
12378 SDValue N2 = N->getOperand(Num: 2);
12379 SDValue N3 = N->getOperand(Num: 3);
12380 SDValue N4 = N->getOperand(Num: 4);
12381 ISD::CondCode CC = cast<CondCodeSDNode>(Val&: N4)->get();
12382
12383 // fold select_cc lhs, rhs, x, x, cc -> x
12384 if (N2 == N3)
12385 return N2;
12386
12387 // select_cc bool, 0, x, y, seteq -> select bool, y, x
12388 if (CC == ISD::SETEQ && !LegalTypes && N0.getValueType() == MVT::i1 &&
12389 isNullConstant(N1))
12390 return DAG.getSelect(DL: SDLoc(N), VT: N2.getValueType(), Cond: N0, LHS: N3, RHS: N2);
12391
12392 // Determine if the condition we're dealing with is constant
12393 if (SDValue SCC = SimplifySetCC(VT: getSetCCResultType(VT: N0.getValueType()), N0, N1,
12394 Cond: CC, DL: SDLoc(N), foldBooleans: false)) {
12395 AddToWorklist(N: SCC.getNode());
12396
12397 // cond always true -> true val
12398 // cond always false -> false val
12399 if (auto *SCCC = dyn_cast<ConstantSDNode>(Val: SCC.getNode()))
12400 return SCCC->isZero() ? N3 : N2;
12401
12402 // When the condition is UNDEF, just return the first operand. This is
12403 // consistent with DAG creation: no setcc node is created in this case.
12404 if (SCC->isUndef())
12405 return N2;
12406
12407 // Fold to a simpler select_cc
12408 if (SCC.getOpcode() == ISD::SETCC) {
12409 SDValue SelectOp = DAG.getNode(
12410 Opcode: ISD::SELECT_CC, DL: SDLoc(N), VT: N2.getValueType(), N1: SCC.getOperand(i: 0),
12411 N2: SCC.getOperand(i: 1), N3: N2, N4: N3, N5: SCC.getOperand(i: 2));
12412 SelectOp->setFlags(SCC->getFlags());
12413 return SelectOp;
12414 }
12415 }
12416
12417 // If we can fold this based on the true/false value, do so.
12418 if (SimplifySelectOps(SELECT: N, LHS: N2, RHS: N3))
12419 return SDValue(N, 0); // Don't revisit N.
12420
12421 // fold select_cc into other things, such as min/max/abs
12422 return SimplifySelectCC(DL: SDLoc(N), N0, N1, N2, N3, CC);
12423}
12424
12425SDValue DAGCombiner::visitSETCC(SDNode *N) {
12426 // setcc is very commonly used as an argument to brcond. This pattern
12427 // also lends itself to numerous combines and, as a result, it is desirable
12428 // to keep the argument to a brcond as a setcc as much as possible.
12429 bool PreferSetCC =
12430 N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BRCOND;
12431
12432 ISD::CondCode Cond = cast<CondCodeSDNode>(Val: N->getOperand(Num: 2))->get();
12433 EVT VT = N->getValueType(ResNo: 0);
12434 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
12435
12436 SDValue Combined = SimplifySetCC(VT, N0, N1, Cond, DL: SDLoc(N), foldBooleans: !PreferSetCC);
12437
12438 if (Combined) {
12439 // If we prefer to have a setcc, and we don't, we'll try our best to
12440 // recreate one using rebuildSetCC.
12441 if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
12442 SDValue NewSetCC = rebuildSetCC(N: Combined);
12443
12444 // We don't have anything interesting to combine to.
12445 if (NewSetCC.getNode() == N)
12446 return SDValue();
12447
12448 if (NewSetCC)
12449 return NewSetCC;
12450 }
12451 return Combined;
12452 }
12453
12454 // Optimize
12455 // 1) (icmp eq/ne (and X, C0), (shift X, C1))
12456 // or
12457 // 2) (icmp eq/ne X, (rotate X, C1))
12458 // If C0 is a mask or shifted mask and the shift amt (C1) isolates the
12459 // remaining bits (i.e. something like `(x64 & UINT32_MAX) == (x64 >> 32)`)
12460 // Then:
12461 // If C1 is a power of 2, then the rotate and shift+and versions are
12462 // equivalent, so we can interchange them depending on target preference.
12463 // Otherwise, if we have the shift+and version we can interchange srl/shl,
12464 // which in turn affects the constant C0. We can use this to get better
12465 // constants, again determined by target preference.
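  // As an illustrative example (not exhaustive), for a 64-bit value x these
  // are equivalent ways of checking that its two 32-bit halves match, and the
  // combine picks between them based on target preference:
  //   (x & 0x00000000FFFFFFFF) == (x >> 32)    // srl + low mask
  //   (x & 0xFFFFFFFF00000000) == (x << 32)    // shl + high mask
  //   x == (rotr x, 32)                        // rotate form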
12466 if (Cond == ISD::SETNE || Cond == ISD::SETEQ) {
12467 auto IsAndWithShift = [](SDValue A, SDValue B) {
12468 return A.getOpcode() == ISD::AND &&
12469 (B.getOpcode() == ISD::SRL || B.getOpcode() == ISD::SHL) &&
12470 A.getOperand(i: 0) == B.getOperand(i: 0);
12471 };
12472 auto IsRotateWithOp = [](SDValue A, SDValue B) {
12473 return (B.getOpcode() == ISD::ROTL || B.getOpcode() == ISD::ROTR) &&
12474 B.getOperand(i: 0) == A;
12475 };
12476 SDValue AndOrOp = SDValue(), ShiftOrRotate = SDValue();
12477 bool IsRotate = false;
12478
12479 // Find either shift+and or rotate pattern.
12480 if (IsAndWithShift(N0, N1)) {
12481 AndOrOp = N0;
12482 ShiftOrRotate = N1;
12483 } else if (IsAndWithShift(N1, N0)) {
12484 AndOrOp = N1;
12485 ShiftOrRotate = N0;
12486 } else if (IsRotateWithOp(N0, N1)) {
12487 IsRotate = true;
12488 AndOrOp = N0;
12489 ShiftOrRotate = N1;
12490 } else if (IsRotateWithOp(N1, N0)) {
12491 IsRotate = true;
12492 AndOrOp = N1;
12493 ShiftOrRotate = N0;
12494 }
12495
12496 if (AndOrOp && ShiftOrRotate && ShiftOrRotate.hasOneUse() &&
12497 (IsRotate || AndOrOp.hasOneUse())) {
12498 EVT OpVT = N0.getValueType();
12499 // Get the constant shift/rotate amount and possibly the mask (if it's
12500 // the shift+and variant).
12501 auto GetAPIntValue = [](SDValue Op) -> std::optional<APInt> {
12502 ConstantSDNode *CNode = isConstOrConstSplat(N: Op, /*AllowUndefs*/ false,
12503 /*AllowTrunc*/ AllowTruncation: false);
12504 if (CNode == nullptr)
12505 return std::nullopt;
12506 return CNode->getAPIntValue();
12507 };
12508 std::optional<APInt> AndCMask =
12509 IsRotate ? std::nullopt : GetAPIntValue(AndOrOp.getOperand(i: 1));
12510 std::optional<APInt> ShiftCAmt =
12511 GetAPIntValue(ShiftOrRotate.getOperand(i: 1));
12512 unsigned NumBits = OpVT.getScalarSizeInBits();
12513
12514 // We found constants.
12515 if (ShiftCAmt && (IsRotate || AndCMask) && ShiftCAmt->ult(RHS: NumBits)) {
12516 unsigned ShiftOpc = ShiftOrRotate.getOpcode();
12517 // Check that the constants meet the constraints.
12518 bool CanTransform = IsRotate;
12519 if (!CanTransform) {
12520 // Check that the mask and shift amount complement each other.
12521 CanTransform = *ShiftCAmt == (~*AndCMask).popcount();
12522 // Check that we are comparing all bits
12523 CanTransform &= (*ShiftCAmt + AndCMask->popcount()) == NumBits;
12524 // Check that the and mask is correct for the shift
12525 CanTransform &=
12526 ShiftOpc == ISD::SHL ? (~*AndCMask).isMask() : AndCMask->isMask();
12527 }
12528
12529 // See if target prefers another shift/rotate opcode.
12530 unsigned NewShiftOpc = TLI.preferedOpcodeForCmpEqPiecesOfOperand(
12531 VT: OpVT, ShiftOpc, MayTransformRotate: ShiftCAmt->isPowerOf2(), ShiftOrRotateAmt: *ShiftCAmt, AndMask: AndCMask);
12532 // Transform is valid and we have a new preference.
12533 if (CanTransform && NewShiftOpc != ShiftOpc) {
12534 SDLoc DL(N);
12535 SDValue NewShiftOrRotate =
12536 DAG.getNode(Opcode: NewShiftOpc, DL, VT: OpVT, N1: ShiftOrRotate.getOperand(i: 0),
12537 N2: ShiftOrRotate.getOperand(i: 1));
12538 SDValue NewAndOrOp = SDValue();
12539
12540 if (NewShiftOpc == ISD::SHL || NewShiftOpc == ISD::SRL) {
12541 APInt NewMask =
12542 NewShiftOpc == ISD::SHL
12543 ? APInt::getHighBitsSet(numBits: NumBits,
12544 hiBitsSet: NumBits - ShiftCAmt->getZExtValue())
12545 : APInt::getLowBitsSet(numBits: NumBits,
12546 loBitsSet: NumBits - ShiftCAmt->getZExtValue());
12547 NewAndOrOp =
12548 DAG.getNode(Opcode: ISD::AND, DL, VT: OpVT, N1: ShiftOrRotate.getOperand(i: 0),
12549 N2: DAG.getConstant(Val: NewMask, DL, VT: OpVT));
12550 } else {
12551 NewAndOrOp = ShiftOrRotate.getOperand(i: 0);
12552 }
12553
12554 return DAG.getSetCC(DL, VT, LHS: NewAndOrOp, RHS: NewShiftOrRotate, Cond);
12555 }
12556 }
12557 }
12558 }
12559 return SDValue();
12560}
12561
12562SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
12563 SDValue LHS = N->getOperand(Num: 0);
12564 SDValue RHS = N->getOperand(Num: 1);
12565 SDValue Carry = N->getOperand(Num: 2);
12566 SDValue Cond = N->getOperand(Num: 3);
12567
12568 // If Carry is false, fold to a regular SETCC.
12569 if (isNullConstant(V: Carry))
12570 return DAG.getNode(Opcode: ISD::SETCC, DL: SDLoc(N), VTList: N->getVTList(), N1: LHS, N2: RHS, N3: Cond);
12571
12572 return SDValue();
12573}
12574
12575/// Check if N satisfies:
12576///   N is used once.
12577///   N is a Load.
12578///   The load is compatible with ExtOpcode; that is, if the load has an
12579///   explicit zero/sign extension, ExtOpcode must perform the same
12580///   extension.
12581///   Otherwise the load is compatible with any ExtOpcode.
12582static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
12583 if (!N.hasOneUse())
12584 return false;
12585
12586 if (!isa<LoadSDNode>(Val: N))
12587 return false;
12588
12589 LoadSDNode *Load = cast<LoadSDNode>(Val&: N);
12590 ISD::LoadExtType LoadExt = Load->getExtensionType();
12591 if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
12592 return true;
12593
12594 // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
12595 // extension.
12596 if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
12597 (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
12598 return false;
12599
12600 return true;
12601}
12602
12603/// Fold
12604/// (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
12605/// (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
12606/// (aext (select c, load x, load y)) -> (select c, extload x, extload y)
12607/// This function is called by the DAGCombiner when visiting sext/zext/aext
12608/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
12609static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
12610 SelectionDAG &DAG,
12611 CombineLevel Level) {
12612 unsigned Opcode = N->getOpcode();
12613 SDValue N0 = N->getOperand(Num: 0);
12614 EVT VT = N->getValueType(ResNo: 0);
12615 SDLoc DL(N);
12616
12617 assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
12618 Opcode == ISD::ANY_EXTEND) &&
12619 "Expected EXTEND dag node in input!");
12620
12621 if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
12622 !N0.hasOneUse())
12623 return SDValue();
12624
12625 SDValue Op1 = N0->getOperand(Num: 1);
12626 SDValue Op2 = N0->getOperand(Num: 2);
12627 if (!isCompatibleLoad(N: Op1, ExtOpcode: Opcode) || !isCompatibleLoad(N: Op2, ExtOpcode: Opcode))
12628 return SDValue();
12629
12630 auto ExtLoadOpcode = ISD::EXTLOAD;
12631 if (Opcode == ISD::SIGN_EXTEND)
12632 ExtLoadOpcode = ISD::SEXTLOAD;
12633 else if (Opcode == ISD::ZERO_EXTEND)
12634 ExtLoadOpcode = ISD::ZEXTLOAD;
12635
12636 // An illegal VSELECT may fail instruction selection if it is created after
12637 // legalization (DAG Combine2), so we conservatively check the OperationAction.
12638 LoadSDNode *Load1 = cast<LoadSDNode>(Val&: Op1);
12639 LoadSDNode *Load2 = cast<LoadSDNode>(Val&: Op2);
12640 if (!TLI.isLoadExtLegal(ExtType: ExtLoadOpcode, ValVT: VT, MemVT: Load1->getMemoryVT()) ||
12641 !TLI.isLoadExtLegal(ExtType: ExtLoadOpcode, ValVT: VT, MemVT: Load2->getMemoryVT()) ||
12642 (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes &&
12643 TLI.getOperationAction(Op: ISD::VSELECT, VT) != TargetLowering::Legal))
12644 return SDValue();
12645
12646 SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Operand: Op1);
12647 SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Operand: Op2);
12648 return DAG.getSelect(DL, VT, Cond: N0->getOperand(Num: 0), LHS: Ext1, RHS: Ext2);
12649}
12650
12651/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
12652/// a build_vector of constants.
12653/// This function is called by the DAGCombiner when visiting sext/zext/aext
12654/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
12655/// Vector extends are not folded if operations are legal; this is to
12656/// avoid introducing illegal build_vector dag nodes.
12657static SDValue tryToFoldExtendOfConstant(SDNode *N, const SDLoc &DL,
12658 const TargetLowering &TLI,
12659 SelectionDAG &DAG, bool LegalTypes) {
12660 unsigned Opcode = N->getOpcode();
12661 SDValue N0 = N->getOperand(Num: 0);
12662 EVT VT = N->getValueType(ResNo: 0);
12663
12664 assert((ISD::isExtOpcode(Opcode) || ISD::isExtVecInRegOpcode(Opcode)) &&
12665 "Expected EXTEND dag node in input!");
12666
12667 // fold (sext c1) -> c1
12668 // fold (zext c1) -> c1
12669 // fold (aext c1) -> c1
12670 if (isa<ConstantSDNode>(Val: N0))
12671 return DAG.getNode(Opcode, DL, VT, Operand: N0);
12672
12673 // fold (sext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
12674 // fold (zext (select cond, c1, c2)) -> (select cond, zext c1, zext c2)
12675 // fold (aext (select cond, c1, c2)) -> (select cond, sext c1, sext c2)
12676 if (N0->getOpcode() == ISD::SELECT) {
12677 SDValue Op1 = N0->getOperand(Num: 1);
12678 SDValue Op2 = N0->getOperand(Num: 2);
12679 if (isa<ConstantSDNode>(Val: Op1) && isa<ConstantSDNode>(Val: Op2) &&
12680 (Opcode != ISD::ZERO_EXTEND || !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT))) {
12681 // For any_extend, choose sign extension of the constants to allow a
12682 // possible further transform to sign_extend_inreg.i.e.
12683 //
12684 // t1: i8 = select t0, Constant:i8<-1>, Constant:i8<0>
12685 // t2: i64 = any_extend t1
12686 // -->
12687 // t3: i64 = select t0, Constant:i64<-1>, Constant:i64<0>
12688 // -->
12689 // t4: i64 = sign_extend_inreg t3
12690 unsigned FoldOpc = Opcode;
12691 if (FoldOpc == ISD::ANY_EXTEND)
12692 FoldOpc = ISD::SIGN_EXTEND;
12693 return DAG.getSelect(DL, VT, Cond: N0->getOperand(Num: 0),
12694 LHS: DAG.getNode(Opcode: FoldOpc, DL, VT, Operand: Op1),
12695 RHS: DAG.getNode(Opcode: FoldOpc, DL, VT, Operand: Op2));
12696 }
12697 }
12698
12699 // fold (sext (build_vector AllConstants)) -> (build_vector AllConstants)
12700 // fold (zext (build_vector AllConstants)) -> (build_vector AllConstants)
12701 // fold (aext (build_vector AllConstants)) -> (build_vector AllConstants)
12702 EVT SVT = VT.getScalarType();
12703 if (!(VT.isVector() && (!LegalTypes || TLI.isTypeLegal(VT: SVT)) &&
12704 ISD::isBuildVectorOfConstantSDNodes(N: N0.getNode())))
12705 return SDValue();
12706
12707 // We can fold this node into a build_vector.
12708 unsigned VTBits = SVT.getSizeInBits();
12709 unsigned EVTBits = N0->getValueType(ResNo: 0).getScalarSizeInBits();
12710 SmallVector<SDValue, 8> Elts;
12711 unsigned NumElts = VT.getVectorNumElements();
12712
12713 for (unsigned i = 0; i != NumElts; ++i) {
12714 SDValue Op = N0.getOperand(i);
12715 if (Op.isUndef()) {
12716 if (Opcode == ISD::ANY_EXTEND || Opcode == ISD::ANY_EXTEND_VECTOR_INREG)
12717 Elts.push_back(Elt: DAG.getUNDEF(VT: SVT));
12718 else
12719 Elts.push_back(Elt: DAG.getConstant(Val: 0, DL, VT: SVT));
12720 continue;
12721 }
12722
12723 SDLoc DL(Op);
12724 // Get the constant value and if needed trunc it to the size of the type.
12725 // Nodes like build_vector might have constants wider than the scalar type.
12726 APInt C = Op->getAsAPIntVal().zextOrTrunc(width: EVTBits);
12727 if (Opcode == ISD::SIGN_EXTEND || Opcode == ISD::SIGN_EXTEND_VECTOR_INREG)
12728 Elts.push_back(Elt: DAG.getConstant(Val: C.sext(width: VTBits), DL, VT: SVT));
12729 else
12730 Elts.push_back(Elt: DAG.getConstant(Val: C.zext(width: VTBits), DL, VT: SVT));
12731 }
12732
12733 return DAG.getBuildVector(VT, DL, Ops: Elts);
12734}
12735
12736// ExtendUsesToFormExtLoad - Try to extend the uses of a load to enable the
12737// "fold ({s|z|a}ext (load x)) -> ({s|z|a}ext (truncate ({s|z|a}extload x)))"
12738// transformation. Returns true if the extensions are possible and the
12739// above-mentioned transformation is profitable.
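// For example (illustrative only): a (setcc (load x), C, cc) user with a
// constant C can itself be rewritten to compare the extended load against the
// extended constant, so the narrow value need not stay live alongside the
// extended one.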
12740static bool ExtendUsesToFormExtLoad(EVT VT, SDNode *N, SDValue N0,
12741 unsigned ExtOpc,
12742 SmallVectorImpl<SDNode *> &ExtendNodes,
12743 const TargetLowering &TLI) {
12744 bool HasCopyToRegUses = false;
12745 bool isTruncFree = TLI.isTruncateFree(FromVT: VT, ToVT: N0.getValueType());
12746 for (SDNode::use_iterator UI = N0->use_begin(), UE = N0->use_end(); UI != UE;
12747 ++UI) {
12748 SDNode *User = *UI;
12749 if (User == N)
12750 continue;
12751 if (UI.getUse().getResNo() != N0.getResNo())
12752 continue;
12753 // FIXME: Only extend SETCC N, N and SETCC N, c for now.
12754 if (ExtOpc != ISD::ANY_EXTEND && User->getOpcode() == ISD::SETCC) {
12755 ISD::CondCode CC = cast<CondCodeSDNode>(Val: User->getOperand(Num: 2))->get();
12756 if (ExtOpc == ISD::ZERO_EXTEND && ISD::isSignedIntSetCC(Code: CC))
12757 // Sign bits will be lost after a zext.
12758 return false;
12759 bool Add = false;
12760 for (unsigned i = 0; i != 2; ++i) {
12761 SDValue UseOp = User->getOperand(Num: i);
12762 if (UseOp == N0)
12763 continue;
12764 if (!isa<ConstantSDNode>(Val: UseOp))
12765 return false;
12766 Add = true;
12767 }
12768 if (Add)
12769 ExtendNodes.push_back(Elt: User);
12770 continue;
12771 }
12772 // If truncates aren't free and there are users we can't
12773 // extend, it isn't worthwhile.
12774 if (!isTruncFree)
12775 return false;
12776 // Remember if this value is live-out.
12777 if (User->getOpcode() == ISD::CopyToReg)
12778 HasCopyToRegUses = true;
12779 }
12780
12781 if (HasCopyToRegUses) {
12782 bool BothLiveOut = false;
12783 for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end();
12784 UI != UE; ++UI) {
12785 SDUse &Use = UI.getUse();
12786 if (Use.getResNo() == 0 && Use.getUser()->getOpcode() == ISD::CopyToReg) {
12787 BothLiveOut = true;
12788 break;
12789 }
12790 }
12791 if (BothLiveOut)
12792 // Both unextended and extended values are live out. There had better be
12793 // a good reason for the transformation.
12794 return !ExtendNodes.empty();
12795 }
12796 return true;
12797}
12798
12799void DAGCombiner::ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
12800 SDValue OrigLoad, SDValue ExtLoad,
12801 ISD::NodeType ExtType) {
12802 // Extend SetCC uses if necessary.
12803 SDLoc DL(ExtLoad);
12804 for (SDNode *SetCC : SetCCs) {
12805 SmallVector<SDValue, 4> Ops;
12806
12807 for (unsigned j = 0; j != 2; ++j) {
12808 SDValue SOp = SetCC->getOperand(Num: j);
12809 if (SOp == OrigLoad)
12810 Ops.push_back(Elt: ExtLoad);
12811 else
12812 Ops.push_back(Elt: DAG.getNode(Opcode: ExtType, DL, VT: ExtLoad->getValueType(ResNo: 0), Operand: SOp));
12813 }
12814
12815 Ops.push_back(Elt: SetCC->getOperand(Num: 2));
12816 CombineTo(N: SetCC, Res: DAG.getNode(Opcode: ISD::SETCC, DL, VT: SetCC->getValueType(ResNo: 0), Ops));
12817 }
12818}
12819
12820// FIXME: Bring more similar combines here, common to sext/zext (maybe aext?).
12821SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
12822 SDValue N0 = N->getOperand(Num: 0);
12823 EVT DstVT = N->getValueType(ResNo: 0);
12824 EVT SrcVT = N0.getValueType();
12825
12826 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
12827 N->getOpcode() == ISD::ZERO_EXTEND) &&
12828 "Unexpected node type (not an extend)!");
12829
12830 // fold (sext (load x)) to multiple smaller sextloads; same for zext.
12831 // For example, on a target with legal v4i32, but illegal v8i32, turn:
12832 // (v8i32 (sext (v8i16 (load x))))
12833 // into:
12834 // (v8i32 (concat_vectors (v4i32 (sextload x)),
12835 // (v4i32 (sextload (x + 16)))))
12836 // Where uses of the original load, i.e.:
12837 // (v8i16 (load x))
12838 // are replaced with:
12839 // (v8i16 (truncate
12840 // (v8i32 (concat_vectors (v4i32 (sextload x)),
12841 // (v4i32 (sextload (x + 16)))))))
12842 //
12843 // This combine is only applicable to illegal, but splittable, vectors.
12844 // All legal types, and illegal non-vector types, are handled elsewhere.
12845 // This combine is controlled by TargetLowering::isVectorLoadExtDesirable.
12846 //
12847 if (N0->getOpcode() != ISD::LOAD)
12848 return SDValue();
12849
12850 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
12851
12852 if (!ISD::isNON_EXTLoad(N: LN0) || !ISD::isUNINDEXEDLoad(N: LN0) ||
12853 !N0.hasOneUse() || !LN0->isSimple() ||
12854 !DstVT.isVector() || !DstVT.isPow2VectorType() ||
12855 !TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0)))
12856 return SDValue();
12857
12858 SmallVector<SDNode *, 4> SetCCs;
12859 if (!ExtendUsesToFormExtLoad(VT: DstVT, N, N0, ExtOpc: N->getOpcode(), ExtendNodes&: SetCCs, TLI))
12860 return SDValue();
12861
12862 ISD::LoadExtType ExtType =
12863 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
12864
12865 // Try to split the vector types to get down to legal types.
12866 EVT SplitSrcVT = SrcVT;
12867 EVT SplitDstVT = DstVT;
12868 while (!TLI.isLoadExtLegalOrCustom(ExtType, ValVT: SplitDstVT, MemVT: SplitSrcVT) &&
12869 SplitSrcVT.getVectorNumElements() > 1) {
12870 SplitDstVT = DAG.GetSplitDestVTs(VT: SplitDstVT).first;
12871 SplitSrcVT = DAG.GetSplitDestVTs(VT: SplitSrcVT).first;
12872 }
12873
12874 if (!TLI.isLoadExtLegalOrCustom(ExtType, ValVT: SplitDstVT, MemVT: SplitSrcVT))
12875 return SDValue();
12876
12877 assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
12878
12879 SDLoc DL(N);
12880 const unsigned NumSplits =
12881 DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
12882 const unsigned Stride = SplitSrcVT.getStoreSize();
12883 SmallVector<SDValue, 4> Loads;
12884 SmallVector<SDValue, 4> Chains;
12885
12886 SDValue BasePtr = LN0->getBasePtr();
12887 for (unsigned Idx = 0; Idx < NumSplits; Idx++) {
12888 const unsigned Offset = Idx * Stride;
12889
12890 SDValue SplitLoad =
12891 DAG.getExtLoad(ExtType, dl: SDLoc(LN0), VT: SplitDstVT, Chain: LN0->getChain(),
12892 Ptr: BasePtr, PtrInfo: LN0->getPointerInfo().getWithOffset(O: Offset),
12893 MemVT: SplitSrcVT, Alignment: LN0->getOriginalAlign(),
12894 MMOFlags: LN0->getMemOperand()->getFlags(), AAInfo: LN0->getAAInfo());
12895
12896 BasePtr = DAG.getMemBasePlusOffset(Base: BasePtr, Offset: TypeSize::getFixed(ExactSize: Stride), DL);
12897
12898 Loads.push_back(Elt: SplitLoad.getValue(R: 0));
12899 Chains.push_back(Elt: SplitLoad.getValue(R: 1));
12900 }
12901
12902 SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains);
12903 SDValue NewValue = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: DstVT, Ops: Loads);
12904
12905 // Simplify TF.
12906 AddToWorklist(N: NewChain.getNode());
12907
12908 CombineTo(N, Res: NewValue);
12909
12910 // Replace uses of the original load (before extension)
12911 // with a truncate of the concatenated sextloaded vectors.
12912 SDValue Trunc =
12913 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: NewValue);
12914 ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad: NewValue, ExtType: (ISD::NodeType)N->getOpcode());
12915 CombineTo(N: N0.getNode(), Res0: Trunc, Res1: NewChain);
12916 return SDValue(N, 0); // Return N so it doesn't get rechecked!
12917}
12918
12919// fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
12920// (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
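// For example (illustrative types only):
//   (zext i64 (and (srl (load i32 x), 8), 255))
//     -> (and (srl (zextload i32 x to i64), 8), 255)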
12921SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
12922 assert(N->getOpcode() == ISD::ZERO_EXTEND);
12923 EVT VT = N->getValueType(ResNo: 0);
12924 EVT OrigVT = N->getOperand(Num: 0).getValueType();
12925 if (TLI.isZExtFree(FromTy: OrigVT, ToTy: VT))
12926 return SDValue();
12927
12928 // and/or/xor
12929 SDValue N0 = N->getOperand(Num: 0);
12930 if (!ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) ||
12931 N0.getOperand(i: 1).getOpcode() != ISD::Constant ||
12932 (LegalOperations && !TLI.isOperationLegal(Op: N0.getOpcode(), VT)))
12933 return SDValue();
12934
12935 // shl/shr
12936 SDValue N1 = N0->getOperand(Num: 0);
12937 if (!(N1.getOpcode() == ISD::SHL || N1.getOpcode() == ISD::SRL) ||
12938 N1.getOperand(i: 1).getOpcode() != ISD::Constant ||
12939 (LegalOperations && !TLI.isOperationLegal(Op: N1.getOpcode(), VT)))
12940 return SDValue();
12941
12942 // load
12943 if (!isa<LoadSDNode>(Val: N1.getOperand(i: 0)))
12944 return SDValue();
12945 LoadSDNode *Load = cast<LoadSDNode>(Val: N1.getOperand(i: 0));
12946 EVT MemVT = Load->getMemoryVT();
12947 if (!TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT) ||
12948 Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
12949 return SDValue();
12950
12951
12952 // If the shift op is SHL, the logic op must be AND, otherwise the result
12953 // will be wrong.
12954 if (N1.getOpcode() == ISD::SHL && N0.getOpcode() != ISD::AND)
12955 return SDValue();
12956
12957 if (!N0.hasOneUse() || !N1.hasOneUse())
12958 return SDValue();
12959
12960 SmallVector<SDNode*, 4> SetCCs;
12961 if (!ExtendUsesToFormExtLoad(VT, N: N1.getNode(), N0: N1.getOperand(i: 0),
12962 ExtOpc: ISD::ZERO_EXTEND, ExtendNodes&: SetCCs, TLI))
12963 return SDValue();
12964
12965 // Actually do the transformation.
12966 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(Load), VT,
12967 Chain: Load->getChain(), Ptr: Load->getBasePtr(),
12968 MemVT: Load->getMemoryVT(), MMO: Load->getMemOperand());
12969
12970 SDLoc DL1(N1);
12971 SDValue Shift = DAG.getNode(Opcode: N1.getOpcode(), DL: DL1, VT, N1: ExtLoad,
12972 N2: N1.getOperand(i: 1));
12973
12974 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
12975 SDLoc DL0(N0);
12976 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL: DL0, VT, N1: Shift,
12977 N2: DAG.getConstant(Val: Mask, DL: DL0, VT));
12978
12979 ExtendSetCCUses(SetCCs, OrigLoad: N1.getOperand(i: 0), ExtLoad, ExtType: ISD::ZERO_EXTEND);
12980 CombineTo(N, Res: And);
12981 if (SDValue(Load, 0).hasOneUse()) {
12982 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Load, 1), To: ExtLoad.getValue(R: 1));
12983 } else {
12984 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(Load),
12985 VT: Load->getValueType(ResNo: 0), Operand: ExtLoad);
12986 CombineTo(N: Load, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
12987 }
12988
12989 // N0 is dead at this point.
12990 recursivelyDeleteUnusedNodes(N: N0.getNode());
12991
12992 return SDValue(N,0); // Return N so it doesn't get rechecked!
12993}
12994
12995/// If we're narrowing or widening the result of a vector select and the final
12996/// size is the same size as a setcc (compare) feeding the select, then try to
12997/// apply the cast operation to the select's operands because matching vector
12998/// sizes for a select condition and other operands should be more efficient.
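/// For example (illustrative types only):
///   (v4i32 (trunc (vselect (setcc X, Y), v4i64 A, v4i64 B)))
///     --> (vselect (setcc X, Y), (v4i32 (trunc A)), (v4i32 (trunc B)))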
12999SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
13000 unsigned CastOpcode = Cast->getOpcode();
13001 assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
13002 CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
13003 CastOpcode == ISD::FP_ROUND) &&
13004 "Unexpected opcode for vector select narrowing/widening");
13005
13006 // We only do this transform before legal ops because the pattern may be
13007 // obfuscated by target-specific operations after legalization. Do not create
13008 // an illegal select op, however, because that may be difficult to lower.
13009 EVT VT = Cast->getValueType(ResNo: 0);
13010 if (LegalOperations || !TLI.isOperationLegalOrCustom(Op: ISD::VSELECT, VT))
13011 return SDValue();
13012
13013 SDValue VSel = Cast->getOperand(Num: 0);
13014 if (VSel.getOpcode() != ISD::VSELECT || !VSel.hasOneUse() ||
13015 VSel.getOperand(i: 0).getOpcode() != ISD::SETCC)
13016 return SDValue();
13017
13018 // Does the setcc have the same vector size as the casted select?
13019 SDValue SetCC = VSel.getOperand(i: 0);
13020 EVT SetCCVT = getSetCCResultType(VT: SetCC.getOperand(i: 0).getValueType());
13021 if (SetCCVT.getSizeInBits() != VT.getSizeInBits())
13022 return SDValue();
13023
13024 // cast (vsel (setcc X), A, B) --> vsel (setcc X), (cast A), (cast B)
13025 SDValue A = VSel.getOperand(i: 1);
13026 SDValue B = VSel.getOperand(i: 2);
13027 SDValue CastA, CastB;
13028 SDLoc DL(Cast);
13029 if (CastOpcode == ISD::FP_ROUND) {
13030 // FP_ROUND (fptrunc) has an extra flag operand to pass along.
13031 CastA = DAG.getNode(Opcode: CastOpcode, DL, VT, N1: A, N2: Cast->getOperand(Num: 1));
13032 CastB = DAG.getNode(Opcode: CastOpcode, DL, VT, N1: B, N2: Cast->getOperand(Num: 1));
13033 } else {
13034 CastA = DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: A);
13035 CastB = DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: B);
13036 }
13037 return DAG.getNode(Opcode: ISD::VSELECT, DL, VT, N1: SetCC, N2: CastA, N3: CastB);
13038}
13039
13040// fold ([s|z]ext ([s|z]extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13041// fold ([s|z]ext ( extload x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13042static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
13043 const TargetLowering &TLI, EVT VT,
13044 bool LegalOperations, SDNode *N,
13045 SDValue N0, ISD::LoadExtType ExtLoadType) {
13046 SDNode *N0Node = N0.getNode();
13047 bool isAExtLoad = (ExtLoadType == ISD::SEXTLOAD) ? ISD::isSEXTLoad(N: N0Node)
13048 : ISD::isZEXTLoad(N: N0Node);
13049 if ((!isAExtLoad && !ISD::isEXTLoad(N: N0Node)) ||
13050 !ISD::isUNINDEXEDLoad(N: N0Node) || !N0.hasOneUse())
13051 return SDValue();
13052
13053 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
13054 EVT MemVT = LN0->getMemoryVT();
13055 if ((LegalOperations || !LN0->isSimple() ||
13056 VT.isVector()) &&
13057 !TLI.isLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT))
13058 return SDValue();
13059
13060 SDValue ExtLoad =
13061 DAG.getExtLoad(ExtType: ExtLoadType, dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
13062 Ptr: LN0->getBasePtr(), MemVT, MMO: LN0->getMemOperand());
13063 Combiner.CombineTo(N, Res: ExtLoad);
13064 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
13065 if (LN0->use_empty())
13066 Combiner.recursivelyDeleteUnusedNodes(N: LN0);
13067 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13068}
13069
13070// fold ([s|z]ext (load x)) -> ([s|z]ext (truncate ([s|z]extload x)))
13071// Only generate vector extloads when 1) they're legal, and 2) they are
13072// deemed desirable by the target. NonNegZExt can be set to true if a zero
13073// extend has the nonneg flag to allow use of sextload if profitable.
13074static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
13075 const TargetLowering &TLI, EVT VT,
13076 bool LegalOperations, SDNode *N, SDValue N0,
13077 ISD::LoadExtType ExtLoadType,
13078 ISD::NodeType ExtOpc,
13079 bool NonNegZExt = false) {
13080 if (!ISD::isNON_EXTLoad(N: N0.getNode()) || !ISD::isUNINDEXEDLoad(N: N0.getNode()))
13081 return {};
13082
13083 // If this is zext nneg, see if it would make sense to treat it as a sext.
13084 if (NonNegZExt) {
13085 assert(ExtLoadType == ISD::ZEXTLOAD && ExtOpc == ISD::ZERO_EXTEND &&
13086 "Unexpected load type or opcode");
13087 for (SDNode *User : N0->uses()) {
13088 if (User->getOpcode() == ISD::SETCC) {
13089 ISD::CondCode CC = cast<CondCodeSDNode>(Val: User->getOperand(Num: 2))->get();
13090 if (ISD::isSignedIntSetCC(Code: CC)) {
13091 ExtLoadType = ISD::SEXTLOAD;
13092 ExtOpc = ISD::SIGN_EXTEND;
13093 break;
13094 }
13095 }
13096 }
13097 }
13098
13099 // TODO: The isFixedLengthVector() check should be removed; any resulting
13100 // negative effects on code generation should then be handled by the target's
13101 // implementation of isVectorLoadExtDesirable().
13102 if ((LegalOperations || VT.isFixedLengthVector() ||
13103 !cast<LoadSDNode>(Val&: N0)->isSimple()) &&
13104 !TLI.isLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT: N0.getValueType()))
13105 return {};
13106
13107 bool DoXform = true;
13108 SmallVector<SDNode *, 4> SetCCs;
13109 if (!N0.hasOneUse())
13110 DoXform = ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc, ExtendNodes&: SetCCs, TLI);
13111 if (VT.isVector())
13112 DoXform &= TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0));
13113 if (!DoXform)
13114 return {};
13115
13116 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
13117 SDValue ExtLoad = DAG.getExtLoad(ExtType: ExtLoadType, dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
13118 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
13119 MMO: LN0->getMemOperand());
13120 Combiner.ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad, ExtType: ExtOpc);
13121 // If the load value is used only by N, replace it via CombineTo N.
13122 bool NoReplaceTrunc = SDValue(LN0, 0).hasOneUse();
13123 Combiner.CombineTo(N, Res: ExtLoad);
13124 if (NoReplaceTrunc) {
13125 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
13126 Combiner.recursivelyDeleteUnusedNodes(N: LN0);
13127 } else {
13128 SDValue Trunc =
13129 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: ExtLoad);
13130 Combiner.CombineTo(N: LN0, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
13131 }
13132 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13133}
13134
13135static SDValue
13136tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
13137 bool LegalOperations, SDNode *N, SDValue N0,
13138 ISD::LoadExtType ExtLoadType, ISD::NodeType ExtOpc) {
13139 if (!N0.hasOneUse())
13140 return SDValue();
13141
13142 MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(Val&: N0);
13143 if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
13144 return SDValue();
13145
13146 if ((LegalOperations || !cast<MaskedLoadSDNode>(Val&: N0)->isSimple()) &&
13147 !TLI.isLoadExtLegalOrCustom(ExtType: ExtLoadType, ValVT: VT, MemVT: Ld->getValueType(ResNo: 0)))
13148 return SDValue();
13149
13150 if (!TLI.isVectorLoadExtDesirable(ExtVal: SDValue(N, 0)))
13151 return SDValue();
13152
13153 SDLoc dl(Ld);
13154 SDValue PassThru = DAG.getNode(Opcode: ExtOpc, DL: dl, VT, Operand: Ld->getPassThru());
13155 SDValue NewLoad = DAG.getMaskedLoad(
13156 VT, dl, Chain: Ld->getChain(), Base: Ld->getBasePtr(), Offset: Ld->getOffset(), Mask: Ld->getMask(),
13157 Src0: PassThru, MemVT: Ld->getMemoryVT(), MMO: Ld->getMemOperand(), AM: Ld->getAddressingMode(),
13158 ExtLoadType, IsExpanding: Ld->isExpandingLoad());
13159 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Ld, 1), To: SDValue(NewLoad.getNode(), 1));
13160 return NewLoad;
13161}
13162
13163// fold ([s|z]ext (atomic_load)) -> ([s|z]ext (truncate ([s|z]ext atomic_load)))
13164static SDValue tryToFoldExtOfAtomicLoad(SelectionDAG &DAG,
13165 const TargetLowering &TLI, EVT VT,
13166 SDValue N0,
13167 ISD::LoadExtType ExtLoadType) {
13168 auto *ALoad = dyn_cast<AtomicSDNode>(Val&: N0);
13169 if (!ALoad || ALoad->getOpcode() != ISD::ATOMIC_LOAD)
13170 return {};
13171 EVT MemoryVT = ALoad->getMemoryVT();
13172 if (!TLI.isAtomicLoadExtLegal(ExtType: ExtLoadType, ValVT: VT, MemVT: MemoryVT))
13173 return {};
13174 // Can't fold into ALoad if it is already extending differently.
13175 ISD::LoadExtType ALoadExtTy = ALoad->getExtensionType();
13176 if ((ALoadExtTy == ISD::ZEXTLOAD && ExtLoadType == ISD::SEXTLOAD) ||
13177 (ALoadExtTy == ISD::SEXTLOAD && ExtLoadType == ISD::ZEXTLOAD))
13178 return {};
13179
13180 EVT OrigVT = ALoad->getValueType(ResNo: 0);
13181 assert(OrigVT.getSizeInBits() < VT.getSizeInBits() && "VT should be wider.");
13182 auto *NewALoad = cast<AtomicSDNode>(Val: DAG.getAtomic(
13183 Opcode: ISD::ATOMIC_LOAD, dl: SDLoc(ALoad), MemVT: MemoryVT, VT, Chain: ALoad->getChain(),
13184 Ptr: ALoad->getBasePtr(), MMO: ALoad->getMemOperand()));
13185 NewALoad->setExtensionType(ExtLoadType);
13186 DAG.ReplaceAllUsesOfValueWith(
13187 From: SDValue(ALoad, 0),
13188 To: DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(ALoad), VT: OrigVT, Operand: SDValue(NewALoad, 0)));
13189 // Update the chain uses.
13190 DAG.ReplaceAllUsesOfValueWith(From: SDValue(ALoad, 1), To: SDValue(NewALoad, 1));
13191 return SDValue(NewALoad, 0);
13192}
13193
13194static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
13195 bool LegalOperations) {
13196 assert((N->getOpcode() == ISD::SIGN_EXTEND ||
13197 N->getOpcode() == ISD::ZERO_EXTEND) && "Expected sext or zext");
13198
13199 SDValue SetCC = N->getOperand(Num: 0);
13200 if (LegalOperations || SetCC.getOpcode() != ISD::SETCC ||
13201 !SetCC.hasOneUse() || SetCC.getValueType() != MVT::i1)
13202 return SDValue();
13203
13204 SDValue X = SetCC.getOperand(i: 0);
13205 SDValue Ones = SetCC.getOperand(i: 1);
13206 ISD::CondCode CC = cast<CondCodeSDNode>(Val: SetCC.getOperand(i: 2))->get();
13207 EVT VT = N->getValueType(ResNo: 0);
13208 EVT XVT = X.getValueType();
13209 // setge X, C is canonicalized to setgt, so we do not need to match that
13210 // pattern. The setlt sibling is folded in SimplifySelectCC() because it does
13211 // not require the 'not' op.
13212 if (CC == ISD::SETGT && isAllOnesConstant(V: Ones) && VT == XVT) {
13213 // Invert and smear/shift the sign bit:
13214 // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
13215 // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
13216 SDLoc DL(N);
13217 unsigned ShCt = VT.getSizeInBits() - 1;
13218 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13219 if (!TLI.shouldAvoidTransformToShift(VT, Amount: ShCt)) {
13220 SDValue NotX = DAG.getNOT(DL, Val: X, VT);
13221 SDValue ShiftAmount = DAG.getConstant(Val: ShCt, DL, VT);
13222 auto ShiftOpcode =
13223 N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
13224 return DAG.getNode(Opcode: ShiftOpcode, DL, VT, N1: NotX, N2: ShiftAmount);
13225 }
13226 }
13227 return SDValue();
13228}
13229
13230SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
13231 SDValue N0 = N->getOperand(Num: 0);
13232 if (N0.getOpcode() != ISD::SETCC)
13233 return SDValue();
13234
13235 SDValue N00 = N0.getOperand(i: 0);
13236 SDValue N01 = N0.getOperand(i: 1);
13237 ISD::CondCode CC = cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get();
13238 EVT VT = N->getValueType(ResNo: 0);
13239 EVT N00VT = N00.getValueType();
13240 SDLoc DL(N);
13241
13242 // Propagate fast-math-flags.
13243 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
13244
13245 // On some architectures (such as SSE/NEON/etc) the SETCC result type is
13246 // the same size as the compared operands. Try to optimize sext(setcc())
13247 // if this is the case.
13248 if (VT.isVector() && !LegalOperations &&
13249 TLI.getBooleanContents(Type: N00VT) ==
13250 TargetLowering::ZeroOrNegativeOneBooleanContent) {
13251 EVT SVT = getSetCCResultType(VT: N00VT);
13252
13253 // If we already have the desired type, don't change it.
13254 if (SVT != N0.getValueType()) {
13255 // We know that the # elements of the results is the same as the
13256 // # elements of the compare (and the # elements of the compare result
13257 // for that matter). Check to see that they are the same size. If so,
13258 // we know that the element size of the sext'd result matches the
13259 // element size of the compare operands.
13260 if (VT.getSizeInBits() == SVT.getSizeInBits())
13261 return DAG.getSetCC(DL, VT, LHS: N00, RHS: N01, Cond: CC);
13262
13263 // If the desired elements are smaller or larger than the source
13264 // elements, we can use a matching integer vector type and then
13265 // truncate/sign extend.
13266 EVT MatchingVecType = N00VT.changeVectorElementTypeToInteger();
13267 if (SVT == MatchingVecType) {
13268 SDValue VsetCC = DAG.getSetCC(DL, VT: MatchingVecType, LHS: N00, RHS: N01, Cond: CC);
13269 return DAG.getSExtOrTrunc(Op: VsetCC, DL, VT);
13270 }
13271 }
13272
13273 // Try to eliminate the sext of a setcc by zexting the compare operands.
13274 if (N0.hasOneUse() && TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT) &&
13275 !TLI.isOperationLegalOrCustom(Op: ISD::SETCC, VT: SVT)) {
13276 bool IsSignedCmp = ISD::isSignedIntSetCC(Code: CC);
13277 unsigned LoadOpcode = IsSignedCmp ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
13278 unsigned ExtOpcode = IsSignedCmp ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
13279
13280 // We have an unsupported narrow vector compare op that would be legal
13281 // if extended to the destination type. See if the compare operands
13282 // can be freely extended to the destination type.
13283 auto IsFreeToExtend = [&](SDValue V) {
13284 if (isConstantOrConstantVector(N: V, /*NoOpaques*/ true))
13285 return true;
13286 // Match a simple, non-extended load that can be converted to a
13287 // legal {z/s}ext-load.
13288 // TODO: Allow widening of an existing {z/s}ext-load?
13289 if (!(ISD::isNON_EXTLoad(N: V.getNode()) &&
13290 ISD::isUNINDEXEDLoad(N: V.getNode()) &&
13291 cast<LoadSDNode>(Val&: V)->isSimple() &&
13292 TLI.isLoadExtLegal(ExtType: LoadOpcode, ValVT: VT, MemVT: V.getValueType())))
13293 return false;
13294
13295 // Non-chain users of this value must either be the setcc in this
13296 // sequence or extends that can be folded into the new {z/s}ext-load.
13297 for (SDNode::use_iterator UI = V->use_begin(), UE = V->use_end();
13298 UI != UE; ++UI) {
13299 // Skip uses of the chain and the setcc.
13300 SDNode *User = *UI;
13301 if (UI.getUse().getResNo() != 0 || User == N0.getNode())
13302 continue;
13303 // Extra users must have exactly the same cast we are about to create.
13304 // TODO: This restriction could be eased if ExtendUsesToFormExtLoad()
13305 // is enhanced similarly.
13306 if (User->getOpcode() != ExtOpcode || User->getValueType(ResNo: 0) != VT)
13307 return false;
13308 }
13309 return true;
13310 };
13311
13312 if (IsFreeToExtend(N00) && IsFreeToExtend(N01)) {
13313 SDValue Ext0 = DAG.getNode(Opcode: ExtOpcode, DL, VT, Operand: N00);
13314 SDValue Ext1 = DAG.getNode(Opcode: ExtOpcode, DL, VT, Operand: N01);
13315 return DAG.getSetCC(DL, VT, LHS: Ext0, RHS: Ext1, Cond: CC);
13316 }
13317 }
13318 }
13319
13320 // sext(setcc x, y, cc) -> (select (setcc x, y, cc), T, 0)
13321 // Here, T can be 1 or -1, depending on the type of the setcc and
13322 // getBooleanContents().
13323 unsigned SetCCWidth = N0.getScalarValueSizeInBits();
13324
13325 // To determine the "true" side of the select, we need to know the high bit
13326 // of the value returned by the setcc if it evaluates to true.
13327 // If the type of the setcc is i1, then the true case of the select is just
13328 // sext(i1 1), that is, -1.
13329 // If the type of the setcc is larger (say, i8) then the value of the high
13330 // bit depends on getBooleanContents(), so ask TLI for a real "true" value
13331 // of the appropriate width.
13332 SDValue ExtTrueVal = (SetCCWidth == 1)
13333 ? DAG.getAllOnesConstant(DL, VT)
13334 : DAG.getBoolConstant(V: true, DL, VT, OpVT: N00VT);
13335 SDValue Zero = DAG.getConstant(Val: 0, DL, VT);
13336 if (SDValue SCC = SimplifySelectCC(DL, N0: N00, N1: N01, N2: ExtTrueVal, N3: Zero, CC, NotExtCompare: true))
13337 return SCC;
13338
13339 if (!VT.isVector() && !shouldConvertSelectOfConstantsToMath(Cond: N0, VT, TLI)) {
13340 EVT SetCCVT = getSetCCResultType(VT: N00VT);
13341 // Don't do this transform for i1 because there's a select transform
13342 // that would reverse it.
13343 // TODO: We should not do this transform at all without a target hook
13344 // because a sext is likely cheaper than a select?
13345 if (SetCCVT.getScalarSizeInBits() != 1 &&
13346 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SETCC, VT: N00VT))) {
13347 SDValue SetCC = DAG.getSetCC(DL, VT: SetCCVT, LHS: N00, RHS: N01, Cond: CC);
13348 return DAG.getSelect(DL, VT, Cond: SetCC, LHS: ExtTrueVal, RHS: Zero);
13349 }
13350 }
13351
13352 return SDValue();
13353}
13354
13355SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
13356 SDValue N0 = N->getOperand(Num: 0);
13357 EVT VT = N->getValueType(ResNo: 0);
13358 SDLoc DL(N);
13359
13360 if (VT.isVector())
13361 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
13362 return FoldedVOp;
13363
13364 // sext(undef) = 0 because the top bit will all be the same.
13365 if (N0.isUndef())
13366 return DAG.getConstant(Val: 0, DL, VT);
13367
13368 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
13369 return Res;
13370
13371 // fold (sext (sext x)) -> (sext x)
13372 // fold (sext (aext x)) -> (sext x)
13373 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND)
13374 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
13375
13376 // fold (sext (aext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
13377 // fold (sext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
13378 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13379 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
13380 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_VECTOR_INREG, DL: SDLoc(N), VT,
13381 Operand: N0.getOperand(i: 0));
13382
13383 // fold (sext (sext_inreg x)) -> (sext (trunc x))
13384 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG) {
13385 SDValue N00 = N0.getOperand(i: 0);
13386 EVT ExtVT = cast<VTSDNode>(Val: N0->getOperand(Num: 1))->getVT();
13387 if ((N00.getOpcode() == ISD::TRUNCATE || TLI.isTruncateFree(Val: N00, VT2: ExtVT)) &&
13388 (!LegalTypes || TLI.isTypeLegal(VT: ExtVT))) {
13389 SDValue T = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ExtVT, Operand: N00);
13390 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: T);
13391 }
13392 }
13393
13394 if (N0.getOpcode() == ISD::TRUNCATE) {
13395 // fold (sext (truncate (load x))) -> (sext (smaller load x))
13396 // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
13397 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
13398 SDNode *oye = N0.getOperand(i: 0).getNode();
13399 if (NarrowLoad.getNode() != N0.getNode()) {
13400 CombineTo(N: N0.getNode(), Res: NarrowLoad);
13401 // CombineTo deleted the truncate, if needed, but not what's under it.
13402 AddToWorklist(N: oye);
13403 }
13404 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13405 }
13406
13407 // See if the value being truncated is already sign extended. If so, just
13408 // eliminate the trunc/sext pair.
13409 SDValue Op = N0.getOperand(i: 0);
13410 unsigned OpBits = Op.getScalarValueSizeInBits();
13411 unsigned MidBits = N0.getScalarValueSizeInBits();
13412 unsigned DestBits = VT.getScalarSizeInBits();
13413 unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
13414
13415 if (OpBits == DestBits) {
13416 // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
13417 // bits, it is already ready.
13418 if (NumSignBits > DestBits-MidBits)
13419 return Op;
13420 } else if (OpBits < DestBits) {
13421 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
13422 // bits, just sext from i32.
13423 if (NumSignBits > OpBits-MidBits)
13424 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: Op);
13425 } else {
13426 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
13427 // bits, just truncate to i32.
13428 if (NumSignBits > OpBits-MidBits)
13429 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Op);
13430 }
13431
13432 // fold (sext (truncate x)) -> (sextinreg x).
13433 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_INREG,
13434 VT: N0.getValueType())) {
13435 if (OpBits < DestBits)
13436 Op = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: SDLoc(N0), VT, Operand: Op);
13437 else if (OpBits > DestBits)
13438 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT, Operand: Op);
13439 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: Op,
13440 N2: DAG.getValueType(N0.getValueType()));
13441 }
13442 }
13443
13444 // Try to simplify (sext (load x)).
13445 if (SDValue foldedExt =
13446 tryToFoldExtOfLoad(DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0,
13447 ExtLoadType: ISD::SEXTLOAD, ExtOpc: ISD::SIGN_EXTEND))
13448 return foldedExt;
13449
13450 if (SDValue foldedExt =
13451 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
13452 ExtLoadType: ISD::SEXTLOAD, ExtOpc: ISD::SIGN_EXTEND))
13453 return foldedExt;
13454
13455 // fold (sext (load x)) to multiple smaller sextloads.
13456 // Only on illegal but splittable vectors.
13457 if (SDValue ExtLoad = CombineExtLoad(N))
13458 return ExtLoad;
13459
13460 // Try to simplify (sext (sextload x)).
13461 if (SDValue foldedExt = tryToFoldExtOfExtload(
13462 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::SEXTLOAD))
13463 return foldedExt;
13464
13465 // Try to simplify (sext (atomic_load x)).
13466 if (SDValue foldedExt =
13467 tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ExtLoadType: ISD::SEXTLOAD))
13468 return foldedExt;
13469
13470 // fold (sext (and/or/xor (load x), cst)) ->
13471 // (and/or/xor (sextload x), (sext cst))
13472 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) &&
13473 isa<LoadSDNode>(Val: N0.getOperand(i: 0)) &&
13474 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
13475 (!LegalOperations && TLI.isOperationLegal(Op: N0.getOpcode(), VT))) {
13476 LoadSDNode *LN00 = cast<LoadSDNode>(Val: N0.getOperand(i: 0));
13477 EVT MemVT = LN00->getMemoryVT();
13478 if (TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT) &&
13479 LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
13480 SmallVector<SDNode*, 4> SetCCs;
13481 bool DoXform = ExtendUsesToFormExtLoad(VT, N: N0.getNode(), N0: N0.getOperand(i: 0),
13482 ExtOpc: ISD::SIGN_EXTEND, ExtendNodes&: SetCCs, TLI);
13483 if (DoXform) {
13484 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: SDLoc(LN00), VT,
13485 Chain: LN00->getChain(), Ptr: LN00->getBasePtr(),
13486 MemVT: LN00->getMemoryVT(),
13487 MMO: LN00->getMemOperand());
13488 APInt Mask = N0.getConstantOperandAPInt(i: 1).sext(width: VT.getSizeInBits());
13489 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
13490 N1: ExtLoad, N2: DAG.getConstant(Val: Mask, DL, VT));
13491 ExtendSetCCUses(SetCCs, OrigLoad: N0.getOperand(i: 0), ExtLoad, ExtType: ISD::SIGN_EXTEND);
13492 bool NoReplaceTruncAnd = !N0.hasOneUse();
13493 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
13494 CombineTo(N, Res: And);
13495 // If N0 has multiple uses, change other uses as well.
13496 if (NoReplaceTruncAnd) {
13497 SDValue TruncAnd =
13498 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N0.getValueType(), Operand: And);
13499 CombineTo(N: N0.getNode(), Res: TruncAnd);
13500 }
13501 if (NoReplaceTrunc) {
13502 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN00, 1), To: ExtLoad.getValue(R: 1));
13503 } else {
13504 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LN00),
13505 VT: LN00->getValueType(ResNo: 0), Operand: ExtLoad);
13506 CombineTo(N: LN00, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
13507 }
13508 return SDValue(N,0); // Return N so it doesn't get rechecked!
13509 }
13510 }
13511 }
13512
13513 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
13514 return V;
13515
13516 if (SDValue V = foldSextSetcc(N))
13517 return V;
13518
13519 // fold (sext x) -> (zext x) if the sign bit is known zero.
13520 if (!TLI.isSExtCheaperThanZExt(FromTy: N0.getValueType(), ToTy: VT) &&
13521 (!LegalOperations || TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT)) &&
13522 DAG.SignBitIsZero(Op: N0)) {
13523 SDNodeFlags Flags;
13524 Flags.setNonNeg(true);
13525 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0, Flags);
13526 }
13527
13528 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
13529 return NewVSel;
13530
13531 // Eliminate this sign extend by doing a negation in the destination type:
13532 // sext i32 (0 - (zext i8 X to i32)) to i64 --> 0 - (zext i8 X to i64)
13533 if (N0.getOpcode() == ISD::SUB && N0.hasOneUse() &&
13534 isNullOrNullSplat(V: N0.getOperand(i: 0)) &&
13535 N0.getOperand(i: 1).getOpcode() == ISD::ZERO_EXTEND &&
13536 TLI.isOperationLegalOrCustom(Op: ISD::SUB, VT)) {
13537 SDValue Zext = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 1).getOperand(i: 0), DL, VT);
13538 return DAG.getNegative(Val: Zext, DL, VT);
13539 }
13540 // Eliminate this sign extend by doing a decrement in the destination type:
13541 // sext i32 ((zext i8 X to i32) + (-1)) to i64 --> (zext i8 X to i64) + (-1)
13542 if (N0.getOpcode() == ISD::ADD && N0.hasOneUse() &&
13543 isAllOnesOrAllOnesSplat(V: N0.getOperand(i: 1)) &&
13544 N0.getOperand(i: 0).getOpcode() == ISD::ZERO_EXTEND &&
13545 TLI.isOperationLegalOrCustom(Op: ISD::ADD, VT)) {
13546 SDValue Zext = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0), DL, VT);
13547 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Zext, N2: DAG.getAllOnesConstant(DL, VT));
13548 }
13549
13550 // fold sext (not i1 X) -> add (zext i1 X), -1
13551 // TODO: This could be extended to handle bool vectors.
13552 if (N0.getValueType() == MVT::i1 && isBitwiseNot(N0) && N0.hasOneUse() &&
13553 (!LegalOperations || (TLI.isOperationLegal(ISD::ZERO_EXTEND, VT) &&
13554 TLI.isOperationLegal(ISD::ADD, VT)))) {
13555 // If we can eliminate the 'not', the sext form should be better
13556 if (SDValue NewXor = visitXOR(N: N0.getNode())) {
13557 // Returning N0 is a form of in-visit replacement that may have
13558 // invalidated N0.
13559 if (NewXor.getNode() == N0.getNode()) {
13560 // Return SDValue here as the xor should have already been replaced in
13561 // this sext.
13562 return SDValue();
13563 }
13564
13565 // Return a new sext with the new xor.
13566 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: NewXor);
13567 }
13568
13569 SDValue Zext = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0));
13570 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: Zext, N2: DAG.getAllOnesConstant(DL, VT));
13571 }
13572
13573 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, Level))
13574 return Res;
13575
13576 return SDValue();
13577}
13578
13579/// Given an extending node with a pop-count operand, if the target does not
13580/// support a pop-count in the narrow source type but does support it in the
13581/// destination type, widen the pop-count to the destination type.
13582static SDValue widenCtPop(SDNode *Extend, SelectionDAG &DAG) {
13583 assert((Extend->getOpcode() == ISD::ZERO_EXTEND ||
13584 Extend->getOpcode() == ISD::ANY_EXTEND) && "Expected extend op");
13585
13586 SDValue CtPop = Extend->getOperand(Num: 0);
13587 if (CtPop.getOpcode() != ISD::CTPOP || !CtPop.hasOneUse())
13588 return SDValue();
13589
13590 EVT VT = Extend->getValueType(ResNo: 0);
13591 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13592 if (TLI.isOperationLegalOrCustom(Op: ISD::CTPOP, VT: CtPop.getValueType()) ||
13593 !TLI.isOperationLegalOrCustom(Op: ISD::CTPOP, VT))
13594 return SDValue();
13595
13596 // zext (ctpop X) --> ctpop (zext X)
13597 SDLoc DL(Extend);
13598 SDValue NewZext = DAG.getZExtOrTrunc(Op: CtPop.getOperand(i: 0), DL, VT);
13599 return DAG.getNode(Opcode: ISD::CTPOP, DL, VT, Operand: NewZext);
13600}
13601
13602// If we have (zext (abs X)) where X is a type that will be promoted by type
13603// legalization, convert to (abs (sext X)). But don't extend past a legal type.
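// For example (illustrative, assuming i8 is promoted to i32 by the target):
//   (zext i32 (abs i8 X)) --> (abs i32 (sext i8 X to i32))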
13604static SDValue widenAbs(SDNode *Extend, SelectionDAG &DAG) {
13605 assert(Extend->getOpcode() == ISD::ZERO_EXTEND && "Expected zero extend.");
13606
13607 EVT VT = Extend->getValueType(ResNo: 0);
13608 if (VT.isVector())
13609 return SDValue();
13610
13611 SDValue Abs = Extend->getOperand(Num: 0);
13612 if (Abs.getOpcode() != ISD::ABS || !Abs.hasOneUse())
13613 return SDValue();
13614
13615 EVT AbsVT = Abs.getValueType();
13616 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
13617 if (TLI.getTypeAction(Context&: *DAG.getContext(), VT: AbsVT) !=
13618 TargetLowering::TypePromoteInteger)
13619 return SDValue();
13620
13621 EVT LegalVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: AbsVT);
13622
13623 SDValue SExt =
13624 DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(Abs), VT: LegalVT, Operand: Abs.getOperand(i: 0));
13625 SDValue NewAbs = DAG.getNode(Opcode: ISD::ABS, DL: SDLoc(Abs), VT: LegalVT, Operand: SExt);
13626 return DAG.getZExtOrTrunc(Op: NewAbs, DL: SDLoc(Extend), VT);
13627}
13628
13629SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
13630 SDValue N0 = N->getOperand(Num: 0);
13631 EVT VT = N->getValueType(ResNo: 0);
13632 SDLoc DL(N);
13633
13634 if (VT.isVector())
13635 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL))
13636 return FoldedVOp;
13637
13638 // zext(undef) = 0
13639 if (N0.isUndef())
13640 return DAG.getConstant(Val: 0, DL, VT);
13641
13642 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
13643 return Res;
13644
13645 // fold (zext (zext x)) -> (zext x)
13646 // fold (zext (aext x)) -> (zext x)
13647 if (N0.getOpcode() == ISD::ZERO_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
13648 SDNodeFlags Flags;
13649 if (N0.getOpcode() == ISD::ZERO_EXTEND)
13650 Flags.setNonNeg(N0->getFlags().hasNonNeg());
13651 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: N0.getOperand(i: 0), Flags);
13652 }
13653
13654 // fold (zext (aext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
13655 // fold (zext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
13656 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13657 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG)
13658 return DAG.getNode(Opcode: ISD::ZERO_EXTEND_VECTOR_INREG, DL: SDLoc(N), VT,
13659 Operand: N0.getOperand(i: 0));
13660
13661 // fold (zext (truncate x)) -> (zext x) or
13662 // (zext (truncate x)) -> (truncate x)
13663 // This is valid when the truncated bits of x are already zero.
13664 SDValue Op;
13665 KnownBits Known;
13666 if (isTruncateOf(DAG, N: N0, Op, Known)) {
13667 APInt TruncatedBits =
13668 (Op.getScalarValueSizeInBits() == N0.getScalarValueSizeInBits()) ?
13669 APInt(Op.getScalarValueSizeInBits(), 0) :
13670 APInt::getBitsSet(numBits: Op.getScalarValueSizeInBits(),
13671 loBit: N0.getScalarValueSizeInBits(),
13672 hiBit: std::min(a: Op.getScalarValueSizeInBits(),
13673 b: VT.getScalarSizeInBits()));
13674 if (TruncatedBits.isSubsetOf(RHS: Known.Zero)) {
13675 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
13676 DAG.salvageDebugInfo(N&: *N0.getNode());
13677
13678 return ZExtOrTrunc;
13679 }
13680 }
13681
13682 // fold (zext (truncate x)) -> (and x, mask)
13683 if (N0.getOpcode() == ISD::TRUNCATE) {
13684 // fold (zext (truncate (load x))) -> (zext (smaller load x))
13685 // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
13686 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
13687 SDNode *oye = N0.getOperand(i: 0).getNode();
13688 if (NarrowLoad.getNode() != N0.getNode()) {
13689 CombineTo(N: N0.getNode(), Res: NarrowLoad);
13690 // CombineTo deleted the truncate, if needed, but not what's under it.
13691 AddToWorklist(N: oye);
13692 }
13693 return SDValue(N, 0); // Return N so it doesn't get rechecked!
13694 }
13695
13696 EVT SrcVT = N0.getOperand(i: 0).getValueType();
13697 EVT MinVT = N0.getValueType();
13698
13699 if (N->getFlags().hasNonNeg()) {
13700 SDValue Op = N0.getOperand(i: 0);
13701 unsigned OpBits = SrcVT.getScalarSizeInBits();
13702 unsigned MidBits = MinVT.getScalarSizeInBits();
13703 unsigned DestBits = VT.getScalarSizeInBits();
13704 unsigned NumSignBits = DAG.ComputeNumSignBits(Op);
13705
13706 if (OpBits == DestBits) {
13707 // Op is i32, Mid is i8, and Dest is i32. If Op has more than 24 sign
13708 // bits, it is already ready.
13709 if (NumSignBits > DestBits - MidBits)
13710 return Op;
13711 } else if (OpBits < DestBits) {
13712 // Op is i32, Mid is i8, and Dest is i64. If Op has more than 24 sign
13713 // bits, just sext from i32.
13714 // FIXME: This can probably be ZERO_EXTEND nneg?
13715 if (NumSignBits > OpBits - MidBits)
13716 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL, VT, Operand: Op);
13717 } else {
13718 // Op is i64, Mid is i8, and Dest is i32. If Op has more than 56 sign
13719 // bits, just truncate to i32.
13720 if (NumSignBits > OpBits - MidBits)
13721 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: Op);
13722 }
13723 }
13724
13725 // Try to mask before the extension to avoid having to generate a larger mask,
13726 // possibly over several sub-vectors.
13727 if (SrcVT.bitsLT(VT) && VT.isVector()) {
13728 if (!LegalOperations || (TLI.isOperationLegal(Op: ISD::AND, VT: SrcVT) &&
13729 TLI.isOperationLegal(Op: ISD::ZERO_EXTEND, VT))) {
13730 SDValue Op = N0.getOperand(i: 0);
13731 Op = DAG.getZeroExtendInReg(Op, DL, VT: MinVT);
13732 AddToWorklist(N: Op.getNode());
13733 SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, DL, VT);
13734 // Transfer the debug info; the new node is equivalent to N0.
13735 DAG.transferDbgValues(From: N0, To: ZExtOrTrunc);
13736 return ZExtOrTrunc;
13737 }
13738 }
13739
13740 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::AND, VT)) {
13741 SDValue Op = DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0), DL, VT);
13742 AddToWorklist(N: Op.getNode());
13743 SDValue And = DAG.getZeroExtendInReg(Op, DL, VT: MinVT);
13744 // We may safely transfer the debug info describing the truncate node over
13745 // to the equivalent and operation.
13746 DAG.transferDbgValues(From: N0, To: And);
13747 return And;
13748 }
13749 }
13750
13751 // Fold (zext (and (trunc x), cst)) -> (and x, cst),
13752 // if either of the casts is not free.
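// For example: (zext (and (trunc i64 %x to i32), 255) to i64) becomes
// (and %x, 255), removing both casts when at least one of them would cost
// an instruction.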
13753 if (N0.getOpcode() == ISD::AND &&
13754 N0.getOperand(i: 0).getOpcode() == ISD::TRUNCATE &&
13755 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
13756 (!TLI.isTruncateFree(Val: N0.getOperand(i: 0).getOperand(i: 0), VT2: N0.getValueType()) ||
13757 !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT))) {
13758 SDValue X = N0.getOperand(i: 0).getOperand(i: 0);
13759 X = DAG.getAnyExtOrTrunc(Op: X, DL: SDLoc(X), VT);
13760 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
13761 return DAG.getNode(Opcode: ISD::AND, DL, VT,
13762 N1: X, N2: DAG.getConstant(Val: Mask, DL, VT));
13763 }
13764
13765 // Try to simplify (zext (load x)).
13766 if (SDValue foldedExt = tryToFoldExtOfLoad(
13767 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::ZEXTLOAD,
13768 ExtOpc: ISD::ZERO_EXTEND, NonNegZExt: N->getFlags().hasNonNeg()))
13769 return foldedExt;
13770
13771 if (SDValue foldedExt =
13772 tryToFoldExtOfMaskedLoad(DAG, TLI, VT, LegalOperations, N, N0,
13773 ExtLoadType: ISD::ZEXTLOAD, ExtOpc: ISD::ZERO_EXTEND))
13774 return foldedExt;
13775
13776 // fold (zext (load x)) to multiple smaller zextloads.
13777 // Only on illegal but splittable vectors.
13778 if (SDValue ExtLoad = CombineExtLoad(N))
13779 return ExtLoad;
13780
13781 // Try to simplify (zext (atomic_load x)).
13782 if (SDValue foldedExt =
13783 tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ExtLoadType: ISD::ZEXTLOAD))
13784 return foldedExt;
13785
13786 // fold (zext (and/or/xor (load x), cst)) ->
13787 // (and/or/xor (zextload x), (zext cst))
13788 // Unless (and (load x) cst) will match as a zextload already and has
13789 // additional users, or the zext is already free.
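// For example, assuming an i16 -> i32 zextload is legal for the target:
// (zext (and (load i16 %p), 7) to i32) -> (and (zextload i16 %p to i32), 7).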
13790 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) && !TLI.isZExtFree(Val: N0, VT2: VT) &&
13791 isa<LoadSDNode>(Val: N0.getOperand(i: 0)) &&
13792 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
13793 (!LegalOperations && TLI.isOperationLegal(Op: N0.getOpcode(), VT))) {
13794 LoadSDNode *LN00 = cast<LoadSDNode>(Val: N0.getOperand(i: 0));
13795 EVT MemVT = LN00->getMemoryVT();
13796 if (TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: VT, MemVT) &&
13797 LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
13798 bool DoXform = true;
13799 SmallVector<SDNode*, 4> SetCCs;
13800 if (!N0.hasOneUse()) {
13801 if (N0.getOpcode() == ISD::AND) {
13802 auto *AndC = cast<ConstantSDNode>(Val: N0.getOperand(i: 1));
13803 EVT LoadResultTy = AndC->getValueType(ResNo: 0);
13804 EVT ExtVT;
13805 if (isAndLoadExtLoad(AndC, LoadN: LN00, LoadResultTy, ExtVT))
13806 DoXform = false;
13807 }
13808 }
13809 if (DoXform)
13810 DoXform = ExtendUsesToFormExtLoad(VT, N: N0.getNode(), N0: N0.getOperand(i: 0),
13811 ExtOpc: ISD::ZERO_EXTEND, ExtendNodes&: SetCCs, TLI);
13812 if (DoXform) {
13813 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::ZEXTLOAD, dl: SDLoc(LN00), VT,
13814 Chain: LN00->getChain(), Ptr: LN00->getBasePtr(),
13815 MemVT: LN00->getMemoryVT(),
13816 MMO: LN00->getMemOperand());
13817 APInt Mask = N0.getConstantOperandAPInt(i: 1).zext(width: VT.getSizeInBits());
13818 SDValue And = DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
13819 N1: ExtLoad, N2: DAG.getConstant(Val: Mask, DL, VT));
13820 ExtendSetCCUses(SetCCs, OrigLoad: N0.getOperand(i: 0), ExtLoad, ExtType: ISD::ZERO_EXTEND);
13821 bool NoReplaceTruncAnd = !N0.hasOneUse();
13822 bool NoReplaceTrunc = SDValue(LN00, 0).hasOneUse();
13823 CombineTo(N, Res: And);
13824 // If N0 has multiple uses, change other uses as well.
13825 if (NoReplaceTruncAnd) {
13826 SDValue TruncAnd =
13827 DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N0.getValueType(), Operand: And);
13828 CombineTo(N: N0.getNode(), Res: TruncAnd);
13829 }
13830 if (NoReplaceTrunc) {
13831 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN00, 1), To: ExtLoad.getValue(R: 1));
13832 } else {
13833 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LN00),
13834 VT: LN00->getValueType(ResNo: 0), Operand: ExtLoad);
13835 CombineTo(N: LN00, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
13836 }
13837 return SDValue(N,0); // Return N so it doesn't get rechecked!
13838 }
13839 }
13840 }
13841
13842 // fold (zext (and/or/xor (shl/shr (load x), cst), cst)) ->
13843 // (and/or/xor (shl/shr (zextload x), (zext cst)), (zext cst))
13844 if (SDValue ZExtLoad = CombineZExtLogicopShiftLoad(N))
13845 return ZExtLoad;
13846
13847 // Try to simplify (zext (zextload x)).
13848 if (SDValue foldedExt = tryToFoldExtOfExtload(
13849 DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0, ExtLoadType: ISD::ZEXTLOAD))
13850 return foldedExt;
13851
13852 if (SDValue V = foldExtendedSignBitTest(N, DAG, LegalOperations))
13853 return V;
13854
13855 if (N0.getOpcode() == ISD::SETCC) {
13856 // Propagate fast-math-flags.
13857 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
13858
13859 // Only do this before legalize for now.
13860 if (!LegalOperations && VT.isVector() &&
13861 N0.getValueType().getVectorElementType() == MVT::i1) {
13862 EVT N00VT = N0.getOperand(i: 0).getValueType();
13863 if (getSetCCResultType(VT: N00VT) == N0.getValueType())
13864 return SDValue();
13865
13866 // We know that the # elements of the result is the same as the #
13867 // elements of the compare (and the # elements of the compare result for
13868 // that matter). Check to see that they are the same size. If so, we know
13869 // that the element size of the zext'd result matches the element size of
13870 // the compare operands.
13871 if (VT.getSizeInBits() == N00VT.getSizeInBits()) {
13872 // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
13873 SDValue VSetCC = DAG.getNode(Opcode: ISD::SETCC, DL, VT, N1: N0.getOperand(i: 0),
13874 N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
13875 return DAG.getZeroExtendInReg(Op: VSetCC, DL, VT: N0.getValueType());
13876 }
13877
13878 // If the desired elements are smaller or larger than the source
13879 // elements we can use a matching integer vector type and then
13880 // truncate/any extend followed by zext_in_reg.
13881 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
13882 SDValue VsetCC =
13883 DAG.getNode(Opcode: ISD::SETCC, DL, VT: MatchingVectorType, N1: N0.getOperand(i: 0),
13884 N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
13885 return DAG.getZeroExtendInReg(Op: DAG.getAnyExtOrTrunc(Op: VsetCC, DL, VT), DL,
13886 VT: N0.getValueType());
13887 }
13888
13889 // zext(setcc x,y,cc) -> zext(select x, y, true, false, cc)
13890 EVT N0VT = N0.getValueType();
13891 EVT N00VT = N0.getOperand(i: 0).getValueType();
13892 if (SDValue SCC = SimplifySelectCC(
13893 DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1),
13894 N2: DAG.getBoolConstant(V: true, DL, VT: N0VT, OpVT: N00VT),
13895 N3: DAG.getBoolConstant(V: false, DL, VT: N0VT, OpVT: N00VT),
13896 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get(), NotExtCompare: true))
13897 return DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: SCC);
13898 }
13899
13900 // (zext (shl (zext x), cst)) -> (shl (zext x), cst)
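// For example: (zext (shl (zext i8 %x to i32), 2) to i64) can become
// (shl (zext %x to i64), 2), since shifting by 2 cannot move any of %x's
// bits out of the original 32-bit value.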
13901 if ((N0.getOpcode() == ISD::SHL || N0.getOpcode() == ISD::SRL) &&
13902 !TLI.isZExtFree(Val: N0, VT2: VT)) {
13903 SDValue ShVal = N0.getOperand(i: 0);
13904 SDValue ShAmt = N0.getOperand(i: 1);
13905 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(Val&: ShAmt)) {
13906 if (ShVal.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse()) {
13907 if (N0.getOpcode() == ISD::SHL) {
13908 // If the original shl may be shifting out bits, do not perform this
13909 // transformation.
13910 unsigned KnownZeroBits = ShVal.getValueSizeInBits() -
13911 ShVal.getOperand(i: 0).getValueSizeInBits();
13912 if (ShAmtC->getAPIntValue().ugt(RHS: KnownZeroBits)) {
13913 // If the shift is too large, then see if we can deduce that the
13914 // shift is safe anyway.
13915 // Create a mask that has ones for the bits being shifted out.
13916 APInt ShiftOutMask =
13917 APInt::getHighBitsSet(numBits: ShVal.getValueSizeInBits(),
13918 hiBitsSet: ShAmtC->getAPIntValue().getZExtValue());
13919
13920 // Check if the bits being shifted out are known to be zero.
13921 if (!DAG.MaskedValueIsZero(Op: ShVal, Mask: ShiftOutMask))
13922 return SDValue();
13923 }
13924 }
13925
13926 // Ensure that the shift amount is wide enough for the shifted value.
13927 if (Log2_32_Ceil(VT.getSizeInBits()) > ShAmt.getValueSizeInBits())
13928 ShAmt = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, ShAmt);
13929
13930 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT,
13931 N1: DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: ShVal), N2: ShAmt);
13932 }
13933 }
13934 }
13935
13936 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
13937 return NewVSel;
13938
13939 if (SDValue NewCtPop = widenCtPop(Extend: N, DAG))
13940 return NewCtPop;
13941
13942 if (SDValue V = widenAbs(Extend: N, DAG))
13943 return V;
13944
13945 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, Level))
13946 return Res;
13947
13948 // CSE zext nneg with sext if the zext is not free.
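// For example, if (sext i32 %x to i64) already exists in the DAG, then
// (zext nneg i32 %x to i64) computes the same value, so the existing sign
// extension can be reused.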
13949 if (N->getFlags().hasNonNeg() && !TLI.isZExtFree(FromTy: N0.getValueType(), ToTy: VT)) {
13950 SDNode *CSENode = DAG.getNodeIfExists(Opcode: ISD::SIGN_EXTEND, VTList: N->getVTList(), Ops: N0);
13951 if (CSENode)
13952 return SDValue(CSENode, 0);
13953 }
13954
13955 return SDValue();
13956}
13957
13958SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
13959 SDValue N0 = N->getOperand(Num: 0);
13960 EVT VT = N->getValueType(ResNo: 0);
13961 SDLoc DL(N);
13962
13963 // aext(undef) = undef
13964 if (N0.isUndef())
13965 return DAG.getUNDEF(VT);
13966
13967 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
13968 return Res;
13969
13970 // fold (aext (aext x)) -> (aext x)
13971 // fold (aext (zext x)) -> (zext x)
13972 // fold (aext (sext x)) -> (sext x)
13973 if (N0.getOpcode() == ISD::ANY_EXTEND || N0.getOpcode() == ISD::ZERO_EXTEND ||
13974 N0.getOpcode() == ISD::SIGN_EXTEND) {
13975 SDNodeFlags Flags;
13976 if (N0.getOpcode() == ISD::ZERO_EXTEND)
13977 Flags.setNonNeg(N0->getFlags().hasNonNeg());
13978 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0), Flags);
13979 }
13980
13981 // fold (aext (aext_extend_vector_inreg x)) -> (aext_extend_vector_inreg x)
13982 // fold (aext (zext_extend_vector_inreg x)) -> (zext_extend_vector_inreg x)
13983 // fold (aext (sext_extend_vector_inreg x)) -> (sext_extend_vector_inreg x)
13984 if (N0.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG ||
13985 N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG ||
13986 N0.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG)
13987 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0));
13988
13989 // fold (aext (truncate (load x))) -> (aext (smaller load x))
13990 // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
13991 if (N0.getOpcode() == ISD::TRUNCATE) {
13992 if (SDValue NarrowLoad = reduceLoadWidth(N: N0.getNode())) {
13993 SDNode *oye = N0.getOperand(i: 0).getNode();
13994 if (NarrowLoad.getNode() != N0.getNode()) {
13995 CombineTo(N: N0.getNode(), Res: NarrowLoad);
13996 // CombineTo deleted the truncate, if needed, but not what's under it.
13997 AddToWorklist(N: oye);
13998 }
13999 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14000 }
14001 }
14002
14003 // fold (aext (truncate x))
14004 if (N0.getOpcode() == ISD::TRUNCATE)
14005 return DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0), DL, VT);
14006
14007 // Fold (aext (and (trunc x), cst)) -> (and x, cst)
14008 // if the trunc is not free.
14009 if (N0.getOpcode() == ISD::AND &&
14010 N0.getOperand(i: 0).getOpcode() == ISD::TRUNCATE &&
14011 N0.getOperand(i: 1).getOpcode() == ISD::Constant &&
14012 !TLI.isTruncateFree(Val: N0.getOperand(i: 0).getOperand(i: 0), VT2: N0.getValueType())) {
14013 SDValue X = DAG.getAnyExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0), DL, VT);
14014 SDValue Y = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT, Operand: N0.getOperand(i: 1));
14015 assert(isa<ConstantSDNode>(Y) && "Expected constant to be folded!");
14016 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: X, N2: Y);
14017 }
14018
14019 // fold (aext (load x)) -> (aext (truncate (extload x)))
14020 // None of the supported targets knows how to perform load and any_ext
14021 // on vectors in one instruction, so attempt to fold to zext instead.
14022 if (VT.isVector()) {
14023 // Try to simplify (zext (load x)).
14024 if (SDValue foldedExt =
14025 tryToFoldExtOfLoad(DAG, Combiner&: *this, TLI, VT, LegalOperations, N, N0,
14026 ExtLoadType: ISD::ZEXTLOAD, ExtOpc: ISD::ZERO_EXTEND))
14027 return foldedExt;
14028 } else if (ISD::isNON_EXTLoad(N: N0.getNode()) &&
14029 ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
14030 TLI.isLoadExtLegal(ExtType: ISD::EXTLOAD, ValVT: VT, MemVT: N0.getValueType())) {
14031 bool DoXform = true;
14032 SmallVector<SDNode *, 4> SetCCs;
14033 if (!N0.hasOneUse())
14034 DoXform =
14035 ExtendUsesToFormExtLoad(VT, N, N0, ExtOpc: ISD::ANY_EXTEND, ExtendNodes&: SetCCs, TLI);
14036 if (DoXform) {
14037 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14038 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: DL, VT, Chain: LN0->getChain(),
14039 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
14040 MMO: LN0->getMemOperand());
14041 ExtendSetCCUses(SetCCs, OrigLoad: N0, ExtLoad, ExtType: ISD::ANY_EXTEND);
14042 // If the load value is used only by N, replace it via CombineTo N.
14043 bool NoReplaceTrunc = N0.hasOneUse();
14044 CombineTo(N, Res: ExtLoad);
14045 if (NoReplaceTrunc) {
14046 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
14047 recursivelyDeleteUnusedNodes(N: LN0);
14048 } else {
14049 SDValue Trunc =
14050 DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N0), VT: N0.getValueType(), Operand: ExtLoad);
14051 CombineTo(N: LN0, Res0: Trunc, Res1: ExtLoad.getValue(R: 1));
14052 }
14053 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14054 }
14055 }
14056
14057 // fold (aext (zextload x)) -> (aext (truncate (zextload x)))
14058 // fold (aext (sextload x)) -> (aext (truncate (sextload x)))
14059 // fold (aext ( extload x)) -> (aext (truncate (extload x)))
14060 if (N0.getOpcode() == ISD::LOAD && !ISD::isNON_EXTLoad(N: N0.getNode()) &&
14061 ISD::isUNINDEXEDLoad(N: N0.getNode()) && N0.hasOneUse()) {
14062 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14063 ISD::LoadExtType ExtType = LN0->getExtensionType();
14064 EVT MemVT = LN0->getMemoryVT();
14065 if (!LegalOperations || TLI.isLoadExtLegal(ExtType, ValVT: VT, MemVT)) {
14066 SDValue ExtLoad =
14067 DAG.getExtLoad(ExtType, dl: DL, VT, Chain: LN0->getChain(), Ptr: LN0->getBasePtr(),
14068 MemVT, MMO: LN0->getMemOperand());
14069 CombineTo(N, Res: ExtLoad);
14070 DAG.ReplaceAllUsesOfValueWith(From: SDValue(LN0, 1), To: ExtLoad.getValue(R: 1));
14071 recursivelyDeleteUnusedNodes(N: LN0);
14072 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14073 }
14074 }
14075
14076 if (N0.getOpcode() == ISD::SETCC) {
14077 // Propagate fast-math-flags.
14078 SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
14079
14080 // For vectors:
14081 // aext(setcc) -> vsetcc
14082 // aext(setcc) -> truncate(vsetcc)
14083 // aext(setcc) -> aext(vsetcc)
14084 // Only do this before legalize for now.
14085 if (VT.isVector() && !LegalOperations) {
14086 EVT N00VT = N0.getOperand(i: 0).getValueType();
14087 if (getSetCCResultType(VT: N00VT) == N0.getValueType())
14088 return SDValue();
14089
14090 // We know that the # elements of the result is the same as the
14091 // # elements of the compare (and the # elements of the compare result
14092 // for that matter). Check to see that they are the same size. If so,
14093 // we know that the element size of the extended result matches the
14094 // element size of the compare operands.
14095 if (VT.getSizeInBits() == N00VT.getSizeInBits())
14096 return DAG.getSetCC(DL, VT, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
14097 Cond: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
14098
14099 // If the desired elements are smaller or larger than the source
14100 // elements we can use a matching integer vector type and then
14101 // truncate/any extend
14102 EVT MatchingVectorType = N00VT.changeVectorElementTypeToInteger();
14103 SDValue VsetCC = DAG.getSetCC(
14104 DL, VT: MatchingVectorType, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
14105 Cond: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
14106 return DAG.getAnyExtOrTrunc(Op: VsetCC, DL, VT);
14107 }
14108
14109 // aext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
14110 if (SDValue SCC = SimplifySelectCC(
14111 DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1), N2: DAG.getConstant(Val: 1, DL, VT),
14112 N3: DAG.getConstant(Val: 0, DL, VT),
14113 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get(), NotExtCompare: true))
14114 return SCC;
14115 }
14116
14117 if (SDValue NewCtPop = widenCtPop(Extend: N, DAG))
14118 return NewCtPop;
14119
14120 if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG, Level))
14121 return Res;
14122
14123 return SDValue();
14124}
14125
14126SDValue DAGCombiner::visitAssertExt(SDNode *N) {
14127 unsigned Opcode = N->getOpcode();
14128 SDValue N0 = N->getOperand(Num: 0);
14129 SDValue N1 = N->getOperand(Num: 1);
14130 EVT AssertVT = cast<VTSDNode>(Val&: N1)->getVT();
14131
14132 // fold (assert?ext (assert?ext x, vt), vt) -> (assert?ext x, vt)
14133 if (N0.getOpcode() == Opcode &&
14134 AssertVT == cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT())
14135 return N0;
14136
14137 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
14138 N0.getOperand(i: 0).getOpcode() == Opcode) {
14139 // We have an assert, truncate, assert sandwich. Make one stronger assert
14140 // by applying the smaller of the two asserted types to the larger source
14141 // value. This eliminates the later assert:
14142 // assert (trunc (assert X, i8) to iN), i1 --> trunc (assert X, i1) to iN
14143 // assert (trunc (assert X, i1) to iN), i8 --> trunc (assert X, i1) to iN
14144 SDLoc DL(N);
14145 SDValue BigA = N0.getOperand(i: 0);
14146 EVT BigA_AssertVT = cast<VTSDNode>(Val: BigA.getOperand(i: 1))->getVT();
14147 EVT MinAssertVT = AssertVT.bitsLT(VT: BigA_AssertVT) ? AssertVT : BigA_AssertVT;
14148 SDValue MinAssertVTVal = DAG.getValueType(MinAssertVT);
14149 SDValue NewAssert = DAG.getNode(Opcode, DL, VT: BigA.getValueType(),
14150 N1: BigA.getOperand(i: 0), N2: MinAssertVTVal);
14151 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N->getValueType(ResNo: 0), Operand: NewAssert);
14152 }
14153
14154 // If we have (AssertZext (truncate (AssertSext X, iX)), iY) and Y is smaller
14155 // than X, just move the AssertZext in front of the truncate and drop the
14156 // AssertSExt.
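// For example: (AssertZext (trunc (AssertSext i64 X, i32) to i32), i8)
// --> (trunc (AssertZext i64 X, i8) to i32).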
14157 if (N0.getOpcode() == ISD::TRUNCATE && N0.hasOneUse() &&
14158 N0.getOperand(i: 0).getOpcode() == ISD::AssertSext &&
14159 Opcode == ISD::AssertZext) {
14160 SDValue BigA = N0.getOperand(i: 0);
14161 EVT BigA_AssertVT = cast<VTSDNode>(Val: BigA.getOperand(i: 1))->getVT();
14162 if (AssertVT.bitsLT(VT: BigA_AssertVT)) {
14163 SDLoc DL(N);
14164 SDValue NewAssert = DAG.getNode(Opcode, DL, VT: BigA.getValueType(),
14165 N1: BigA.getOperand(i: 0), N2: N1);
14166 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: N->getValueType(ResNo: 0), Operand: NewAssert);
14167 }
14168 }
14169
14170 return SDValue();
14171}
14172
14173SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
14174 SDLoc DL(N);
14175
14176 Align AL = cast<AssertAlignSDNode>(Val: N)->getAlign();
14177 SDValue N0 = N->getOperand(Num: 0);
14178
14179 // Fold (assertalign (assertalign x, AL0), AL1) ->
14180 // (assertalign x, max(AL0, AL1))
14181 if (auto *AAN = dyn_cast<AssertAlignSDNode>(Val&: N0))
14182 return DAG.getAssertAlign(DL, V: N0.getOperand(i: 0),
14183 A: std::max(a: AL, b: AAN->getAlign()));
14184
14185 // In rare cases, there are trivial arithmetic ops in source operands. Sink
14186 // this assert down to the source operands so that those arithmetic ops can
14187 // be exposed to DAG combining.
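// For example: (assertalign (add X, 16), align 8) can become
// (add (assertalign X, align 8), 16), since the constant operand already
// satisfies the alignment and the assert can be pushed onto X.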
14188 switch (N0.getOpcode()) {
14189 default:
14190 break;
14191 case ISD::ADD:
14192 case ISD::SUB: {
14193 unsigned AlignShift = Log2(A: AL);
14194 SDValue LHS = N0.getOperand(i: 0);
14195 SDValue RHS = N0.getOperand(i: 1);
14196 unsigned LHSAlignShift = DAG.computeKnownBits(Op: LHS).countMinTrailingZeros();
14197 unsigned RHSAlignShift = DAG.computeKnownBits(Op: RHS).countMinTrailingZeros();
14198 if (LHSAlignShift >= AlignShift || RHSAlignShift >= AlignShift) {
14199 if (LHSAlignShift < AlignShift)
14200 LHS = DAG.getAssertAlign(DL, V: LHS, A: AL);
14201 if (RHSAlignShift < AlignShift)
14202 RHS = DAG.getAssertAlign(DL, V: RHS, A: AL);
14203 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT: N0.getValueType(), N1: LHS, N2: RHS);
14204 }
14205 break;
14206 }
14207 }
14208
14209 return SDValue();
14210}
14211
14212/// If the result of a load is shifted/masked/truncated to an effectively
14213/// narrower type, try to transform the load to a narrower type and/or
14214/// use an extending load.
14215SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
14216 unsigned Opc = N->getOpcode();
14217
14218 ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
14219 SDValue N0 = N->getOperand(Num: 0);
14220 EVT VT = N->getValueType(ResNo: 0);
14221 EVT ExtVT = VT;
14222
14223 // This transformation isn't valid for vector loads.
14224 if (VT.isVector())
14225 return SDValue();
14226
14227 // The ShAmt variable is used to indicate that we've consumed a right
14228 // shift. I.e. we want to narrow the width of the load by skipping the ShAmt
14229 // least significant bits, which are not loaded.
14230 unsigned ShAmt = 0;
14231 // A special case is when the least significant bits from the load are masked
14232 // away, but using an AND rather than a right shift. ShiftedOffset is used to
14233 // indicate that the narrowed load should be left-shifted ShiftedOffset bits
14234 // to get the result.
14235 unsigned ShiftedOffset = 0;
14236 // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
14237 // extending to VT.
14238 if (Opc == ISD::SIGN_EXTEND_INREG) {
14239 ExtType = ISD::SEXTLOAD;
14240 ExtVT = cast<VTSDNode>(Val: N->getOperand(Num: 1))->getVT();
14241 } else if (Opc == ISD::SRL || Opc == ISD::SRA) {
14242 // Another special-case: SRL/SRA is basically zero/sign-extending a narrower
14243 // value, or it may be shifting a higher subword, half or byte into the
14244 // lowest bits.
14245
14246 // Only handle shift with constant shift amount, and the shiftee must be a
14247 // load.
14248 auto *LN = dyn_cast<LoadSDNode>(Val&: N0);
14249 auto *N1C = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
14250 if (!N1C || !LN)
14251 return SDValue();
14252 // If the shift amount is larger than the memory type then we're not
14253 // accessing any of the loaded bytes.
14254 ShAmt = N1C->getZExtValue();
14255 uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
14256 if (MemoryWidth <= ShAmt)
14257 return SDValue();
14258 // Attempt to fold away the SRL by using ZEXTLOAD and SRA by using SEXTLOAD.
14259 ExtType = Opc == ISD::SRL ? ISD::ZEXTLOAD : ISD::SEXTLOAD;
14260 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemoryWidth - ShAmt);
14261 // If original load is a SEXTLOAD then we can't simply replace it by a
14262 // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
14263 // followed by a ZEXT, but that is not handled at the moment). Similarly if
14264 // the original load is a ZEXTLOAD and we want to use a SEXTLOAD.
14265 if ((LN->getExtensionType() == ISD::SEXTLOAD ||
14266 LN->getExtensionType() == ISD::ZEXTLOAD) &&
14267 LN->getExtensionType() != ExtType)
14268 return SDValue();
14269 } else if (Opc == ISD::AND) {
14270 // An AND with a constant mask is the same as a truncate + zero-extend.
14271 auto AndC = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
14272 if (!AndC)
14273 return SDValue();
14274
14275 const APInt &Mask = AndC->getAPIntValue();
14276 unsigned ActiveBits = 0;
14277 if (Mask.isMask()) {
14278 ActiveBits = Mask.countr_one();
14279 } else if (Mask.isShiftedMask(MaskIdx&: ShAmt, MaskLen&: ActiveBits)) {
14280 ShiftedOffset = ShAmt;
14281 } else {
14282 return SDValue();
14283 }
14284
14285 ExtType = ISD::ZEXTLOAD;
14286 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
14287 }
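// For example, on a little-endian target and assuming the narrow load is
// legal: (and (load i32 %p), 0xFF00) is treated as a zero-extending i8 load
// from byte offset 1, with ShiftedOffset == 8 recording that the loaded
// value must be shifted left by 8 bits afterwards.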
14288
14289 // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
14290 // a right shift. Here we redo some of those checks, to possibly adjust the
14291 // ExtVT even further based on "a masking AND". We could also end up here for
14292 // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
14293 // need to be done here as well.
14294 if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
14295 SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
14296 // Bail out when the SRL has more than one use. This is done for historical
14297 // (undocumented) reasons. Maybe the intent was to guard the AND-masking
14298 // check below? And maybe it could be non-profitable to do the transform in
14299 // case the SRL has multiple uses and we get here with Opc!=ISD::SRL?
14300 // FIXME: Can't we just skip this check for the Opc==ISD::SRL case?
14301 if (!SRL.hasOneUse())
14302 return SDValue();
14303
14304 // Only handle shift with constant shift amount, and the shiftee must be a
14305 // load.
14306 auto *LN = dyn_cast<LoadSDNode>(Val: SRL.getOperand(i: 0));
14307 auto *SRL1C = dyn_cast<ConstantSDNode>(Val: SRL.getOperand(i: 1));
14308 if (!SRL1C || !LN)
14309 return SDValue();
14310
14311 // If the shift amount is larger than the input type then we're not
14312 // accessing any of the loaded bytes. If the load was a zextload/extload
14313 // then the result of the shift+trunc is zero/undef (handled elsewhere).
14314 ShAmt = SRL1C->getZExtValue();
14315 uint64_t MemoryWidth = LN->getMemoryVT().getSizeInBits();
14316 if (ShAmt >= MemoryWidth)
14317 return SDValue();
14318
14319 // Because a SRL must be assumed to *need* to zero-extend the high bits
14320 // (as opposed to anyext the high bits), we can't combine the zextload
14321 // lowering of SRL and an sextload.
14322 if (LN->getExtensionType() == ISD::SEXTLOAD)
14323 return SDValue();
14324
14325 // Avoid reading outside the memory accessed by the original load (could
14326 // happen if we only adjust the load base pointer by ShAmt). Instead we
14327 // try to narrow the load even further. The typical scenario here is:
14328 // (i64 (truncate (i96 (srl (load x), 64)))) ->
14329 // (i64 (truncate (i96 (zextload (load i32 + offset) from i32))))
14330 if (ExtVT.getScalarSizeInBits() > MemoryWidth - ShAmt) {
14331 // Don't replace sextload by zextload.
14332 if (ExtType == ISD::SEXTLOAD)
14333 return SDValue();
14334 // Narrow the load.
14335 ExtType = ISD::ZEXTLOAD;
14336 ExtVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: MemoryWidth - ShAmt);
14337 }
14338
14339 // If the SRL is only used by a masking AND, we may be able to adjust
14340 // the ExtVT to make the AND redundant.
14341 SDNode *Mask = *(SRL->use_begin());
14342 if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
14343 isa<ConstantSDNode>(Val: Mask->getOperand(Num: 1))) {
14344 unsigned Offset, ActiveBits;
14345 const APInt& ShiftMask = Mask->getConstantOperandAPInt(Num: 1);
14346 if (ShiftMask.isMask()) {
14347 EVT MaskedVT =
14348 EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ShiftMask.countr_one());
14349 // If the mask is smaller, recompute the type.
14350 if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
14351 TLI.isLoadExtLegal(ExtType, ValVT: SRL.getValueType(), MemVT: MaskedVT))
14352 ExtVT = MaskedVT;
14353 } else if (ExtType == ISD::ZEXTLOAD &&
14354 ShiftMask.isShiftedMask(MaskIdx&: Offset, MaskLen&: ActiveBits) &&
14355 (Offset + ShAmt) < VT.getScalarSizeInBits()) {
14356 EVT MaskedVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ActiveBits);
14357 // If the mask is shifted we can use a narrower load and a shl to insert
14358 // the trailing zeros.
14359 if (((Offset + ActiveBits) <= ExtVT.getScalarSizeInBits()) &&
14360 TLI.isLoadExtLegal(ExtType, ValVT: SRL.getValueType(), MemVT: MaskedVT)) {
14361 ExtVT = MaskedVT;
14362 ShAmt = Offset + ShAmt;
14363 ShiftedOffset = Offset;
14364 }
14365 }
14366 }
14367
14368 N0 = SRL.getOperand(i: 0);
14369 }
14370
14371 // If the load is shifted left (and the result isn't shifted back right), we
14372 // can fold a truncate through the shift. The typical scenario is that N
14373 // points at a TRUNCATE here so the attempted fold is:
14374 // (truncate (shl (load x), c))) -> (shl (narrow load x), c)
14375 // ShLeftAmt will indicate how much a narrowed load should be shifted left.
14376 unsigned ShLeftAmt = 0;
14377 if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
14378 ExtVT == VT && TLI.isNarrowingProfitable(SrcVT: N0.getValueType(), DestVT: VT)) {
14379 if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1))) {
14380 ShLeftAmt = N01->getZExtValue();
14381 N0 = N0.getOperand(i: 0);
14382 }
14383 }
14384
14385 // If we haven't found a load, we can't narrow it.
14386 if (!isa<LoadSDNode>(Val: N0))
14387 return SDValue();
14388
14389 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14390 // Reducing the width of a volatile load is illegal. For atomics, we may be
14391 // able to reduce the width provided we never widen again. (see D66309)
14392 if (!LN0->isSimple() ||
14393 !isLegalNarrowLdSt(LDST: LN0, ExtType, MemVT&: ExtVT, ShAmt))
14394 return SDValue();
14395
14396 auto AdjustBigEndianShift = [&](unsigned ShAmt) {
14397 unsigned LVTStoreBits =
14398 LN0->getMemoryVT().getStoreSizeInBits().getFixedValue();
14399 unsigned EVTStoreBits = ExtVT.getStoreSizeInBits().getFixedValue();
14400 return LVTStoreBits - EVTStoreBits - ShAmt;
14401 };
14402
14403 // We need to adjust the pointer to the load by ShAmt bits in order to load
14404 // the correct bytes.
14405 unsigned PtrAdjustmentInBits =
14406 DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
14407
14408 uint64_t PtrOff = PtrAdjustmentInBits / 8;
14409 SDLoc DL(LN0);
14410 // The original load itself didn't wrap, so an offset within it doesn't.
14411 SDNodeFlags Flags;
14412 Flags.setNoUnsignedWrap(true);
14413 SDValue NewPtr = DAG.getMemBasePlusOffset(
14414 Base: LN0->getBasePtr(), Offset: TypeSize::getFixed(ExactSize: PtrOff), DL, Flags);
14415 AddToWorklist(N: NewPtr.getNode());
14416
14417 SDValue Load;
14418 if (ExtType == ISD::NON_EXTLOAD)
14419 Load = DAG.getLoad(VT, dl: DL, Chain: LN0->getChain(), Ptr: NewPtr,
14420 PtrInfo: LN0->getPointerInfo().getWithOffset(O: PtrOff),
14421 Alignment: LN0->getOriginalAlign(),
14422 MMOFlags: LN0->getMemOperand()->getFlags(), AAInfo: LN0->getAAInfo());
14423 else
14424 Load = DAG.getExtLoad(ExtType, dl: DL, VT, Chain: LN0->getChain(), Ptr: NewPtr,
14425 PtrInfo: LN0->getPointerInfo().getWithOffset(O: PtrOff), MemVT: ExtVT,
14426 Alignment: LN0->getOriginalAlign(),
14427 MMOFlags: LN0->getMemOperand()->getFlags(), AAInfo: LN0->getAAInfo());
14428
14429 // Replace the old load's chain with the new load's chain.
14430 WorklistRemover DeadNodes(*this);
14431 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: Load.getValue(R: 1));
14432
14433 // Shift the result left, if we've swallowed a left shift.
14434 SDValue Result = Load;
14435 if (ShLeftAmt != 0) {
14436 EVT ShImmTy = getShiftAmountTy(LHSTy: Result.getValueType());
14437 if (!isUIntN(N: ShImmTy.getScalarSizeInBits(), x: ShLeftAmt))
14438 ShImmTy = VT;
14439 // If the shift amount is as large as the result size (but, presumably,
14440 // no larger than the source) then the useful bits of the result are
14441 // zero; we can't simply return the shortened shift, because the result
14442 // of that operation is undefined.
14443 if (ShLeftAmt >= VT.getScalarSizeInBits())
14444 Result = DAG.getConstant(Val: 0, DL, VT);
14445 else
14446 Result = DAG.getNode(Opcode: ISD::SHL, DL, VT,
14447 N1: Result, N2: DAG.getConstant(Val: ShLeftAmt, DL, VT: ShImmTy));
14448 }
14449
14450 if (ShiftedOffset != 0) {
14451 // We're using a shifted mask, so the load now has an offset. This means
14452 // the data has been loaded into lower bits of the register than it would
14453 // have been without the offset, so we need to shl the loaded data back into
14454 // the correct position in the register.
14455 SDValue ShiftC = DAG.getConstant(Val: ShiftedOffset, DL, VT);
14456 Result = DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Result, N2: ShiftC);
14457 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result);
14458 }
14459
14460 // Return the new loaded value.
14461 return Result;
14462}
14463
14464SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
14465 SDValue N0 = N->getOperand(Num: 0);
14466 SDValue N1 = N->getOperand(Num: 1);
14467 EVT VT = N->getValueType(ResNo: 0);
14468 EVT ExtVT = cast<VTSDNode>(Val&: N1)->getVT();
14469 unsigned VTBits = VT.getScalarSizeInBits();
14470 unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
14471
14472 // sext_in_reg(undef) = 0 because the top bits will all be the same.
14473 if (N0.isUndef())
14474 return DAG.getConstant(Val: 0, DL: SDLoc(N), VT);
14475
14476 // fold (sext_in_reg c1) -> c1
14477 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0))
14478 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL: SDLoc(N), VT, N1: N0, N2: N1);
14479
14480 // If the input is already sign extended, just drop the extension.
14481 if (ExtVTBits >= DAG.ComputeMaxSignificantBits(Op: N0))
14482 return N0;
14483
14484 // fold (sext_in_reg (sext_in_reg x, VT2), VT1) -> (sext_in_reg x, minVT) pt2
14485 if (N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
14486 ExtVT.bitsLT(VT: cast<VTSDNode>(Val: N0.getOperand(i: 1))->getVT()))
14487 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
14488 N2: N1);
14489
14490 // fold (sext_in_reg (sext x)) -> (sext x)
14491 // fold (sext_in_reg (aext x)) -> (sext x)
14492 // if x is small enough or if we know that x has more than 1 sign bit and the
14493 // sign_extend_inreg is extending from one of them.
14494 if (N0.getOpcode() == ISD::SIGN_EXTEND || N0.getOpcode() == ISD::ANY_EXTEND) {
14495 SDValue N00 = N0.getOperand(i: 0);
14496 unsigned N00Bits = N00.getScalarValueSizeInBits();
14497 if ((N00Bits <= ExtVTBits ||
14498 DAG.ComputeMaxSignificantBits(Op: N00) <= ExtVTBits) &&
14499 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT)))
14500 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT, Operand: N00);
14501 }
14502
14503 // fold (sext_in_reg (*_extend_vector_inreg x)) -> (sext_vector_inreg x)
14504 // if x is small enough or if we know that x has more than 1 sign bit and the
14505 // sign_extend_inreg is extending from one of them.
14506 if (ISD::isExtVecInRegOpcode(Opcode: N0.getOpcode())) {
14507 SDValue N00 = N0.getOperand(i: 0);
14508 unsigned N00Bits = N00.getScalarValueSizeInBits();
14509 unsigned DstElts = N0.getValueType().getVectorMinNumElements();
14510 unsigned SrcElts = N00.getValueType().getVectorMinNumElements();
14511 bool IsZext = N0.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
14512 APInt DemandedSrcElts = APInt::getLowBitsSet(numBits: SrcElts, loBitsSet: DstElts);
14513 if ((N00Bits == ExtVTBits ||
14514 (!IsZext && (N00Bits < ExtVTBits ||
14515 DAG.ComputeMaxSignificantBits(Op: N00) <= ExtVTBits))) &&
14516 (!LegalOperations ||
14517 TLI.isOperationLegal(Op: ISD::SIGN_EXTEND_VECTOR_INREG, VT)))
14518 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_VECTOR_INREG, DL: SDLoc(N), VT, Operand: N00);
14519 }
14520
14521 // fold (sext_in_reg (zext x)) -> (sext x)
14522 // iff we are extending the source sign bit.
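// For example: (sext_in_reg (zext i16 %x to i32), i16) -> (sext i16 %x to i32).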
14523 if (N0.getOpcode() == ISD::ZERO_EXTEND) {
14524 SDValue N00 = N0.getOperand(i: 0);
14525 if (N00.getScalarValueSizeInBits() == ExtVTBits &&
14526 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT)))
14527 return DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT, Operand: N00);
14528 }
14529
14530 // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
14531 if (DAG.MaskedValueIsZero(Op: N0, Mask: APInt::getOneBitSet(numBits: VTBits, BitNo: ExtVTBits - 1)))
14532 return DAG.getZeroExtendInReg(Op: N0, DL: SDLoc(N), VT: ExtVT);
14533
14534 // fold operands of sext_in_reg based on knowledge that the top bits are not
14535 // demanded.
14536 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
14537 return SDValue(N, 0);
14538
14539 // fold (sext_in_reg (load x)) -> (smaller sextload x)
14540 // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
14541 if (SDValue NarrowLoad = reduceLoadWidth(N))
14542 return NarrowLoad;
14543
14544 // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
14545 // fold (sext_in_reg (srl X, 23), i8) -> (sra X, 23) iff possible.
14546 // We already fold "(sext_in_reg (srl X, 25), i8) -> srl X, 25" above.
14547 if (N0.getOpcode() == ISD::SRL) {
14548 if (auto *ShAmt = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 1)))
14549 if (ShAmt->getAPIntValue().ule(RHS: VTBits - ExtVTBits)) {
14550 // We can turn this into an SRA iff the input to the SRL is already sign
14551 // extended enough.
14552 unsigned InSignBits = DAG.ComputeNumSignBits(Op: N0.getOperand(i: 0));
14553 if (((VTBits - ExtVTBits) - ShAmt->getZExtValue()) < InSignBits)
14554 return DAG.getNode(Opcode: ISD::SRA, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
14555 N2: N0.getOperand(i: 1));
14556 }
14557 }
14558
14559 // fold (sext_inreg (extload x)) -> (sextload x)
14560 // If sextload is not supported by target, we can only do the combine when
14561 // load has one use. Doing otherwise can block folding the extload with other
14562 // extends that the target does support.
14563 if (ISD::isEXTLoad(N: N0.getNode()) &&
14564 ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
14565 ExtVT == cast<LoadSDNode>(Val&: N0)->getMemoryVT() &&
14566 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple() &&
14567 N0.hasOneUse()) ||
14568 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT))) {
14569 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14570 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: SDLoc(N), VT,
14571 Chain: LN0->getChain(),
14572 Ptr: LN0->getBasePtr(), MemVT: ExtVT,
14573 MMO: LN0->getMemOperand());
14574 CombineTo(N, Res: ExtLoad);
14575 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
14576 AddToWorklist(N: ExtLoad.getNode());
14577 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14578 }
14579
14580 // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
14581 if (ISD::isZEXTLoad(N: N0.getNode()) && ISD::isUNINDEXEDLoad(N: N0.getNode()) &&
14582 N0.hasOneUse() &&
14583 ExtVT == cast<LoadSDNode>(Val&: N0)->getMemoryVT() &&
14584 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple()) &&
14585 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT))) {
14586 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
14587 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::SEXTLOAD, dl: SDLoc(N), VT,
14588 Chain: LN0->getChain(),
14589 Ptr: LN0->getBasePtr(), MemVT: ExtVT,
14590 MMO: LN0->getMemOperand());
14591 CombineTo(N, Res: ExtLoad);
14592 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
14593 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14594 }
14595
14596 // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
14597 // ignore it if the masked load is already sign extended
14598 if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(Val&: N0)) {
14599 if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
14600 Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
14601 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: VT, MemVT: ExtVT)) {
14602 SDValue ExtMaskedLoad = DAG.getMaskedLoad(
14603 VT, dl: SDLoc(N), Chain: Ld->getChain(), Base: Ld->getBasePtr(), Offset: Ld->getOffset(),
14604 Mask: Ld->getMask(), Src0: Ld->getPassThru(), MemVT: ExtVT, MMO: Ld->getMemOperand(),
14605 AM: Ld->getAddressingMode(), ISD::SEXTLOAD, IsExpanding: Ld->isExpandingLoad());
14606 CombineTo(N, Res: ExtMaskedLoad);
14607 CombineTo(N: N0.getNode(), Res0: ExtMaskedLoad, Res1: ExtMaskedLoad.getValue(R: 1));
14608 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14609 }
14610 }
14611
14612 // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
14613 if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(Val&: N0)) {
14614 if (SDValue(GN0, 0).hasOneUse() &&
14615 ExtVT == GN0->getMemoryVT() &&
14616 TLI.isVectorLoadExtDesirable(ExtVal: SDValue(SDValue(GN0, 0)))) {
14617 SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
14618 GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
14619
14620 SDValue ExtLoad = DAG.getMaskedGather(
14621 DAG.getVTList(VT, MVT::Other), ExtVT, SDLoc(N), Ops,
14622 GN0->getMemOperand(), GN0->getIndexType(), ISD::SEXTLOAD);
14623
14624 CombineTo(N, Res: ExtLoad);
14625 CombineTo(N: N0.getNode(), Res0: ExtLoad, Res1: ExtLoad.getValue(R: 1));
14626 AddToWorklist(N: ExtLoad.getNode());
14627 return SDValue(N, 0); // Return N so it doesn't get rechecked!
14628 }
14629 }
14630
14631 // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
14632 if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
14633 if (SDValue BSwap = MatchBSwapHWordLow(N: N0.getNode(), N0: N0.getOperand(i: 0),
14634 N1: N0.getOperand(i: 1), DemandHighBits: false))
14635 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL: SDLoc(N), VT, N1: BSwap, N2: N1);
14636 }
14637
14638 // Fold (iM_signext_inreg
14639 // (extract_subvector (zext|anyext|sext iN_v to _) _)
14640 // from iN)
14641 // -> (extract_subvector (signext iN_v to iM))
14642 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR && N0.hasOneUse() &&
14643 ISD::isExtOpcode(Opcode: N0.getOperand(i: 0).getOpcode())) {
14644 SDValue InnerExt = N0.getOperand(i: 0);
14645 EVT InnerExtVT = InnerExt->getValueType(ResNo: 0);
14646 SDValue Extendee = InnerExt->getOperand(Num: 0);
14647
14648 if (ExtVTBits == Extendee.getValueType().getScalarSizeInBits() &&
14649 (!LegalOperations ||
14650 TLI.isOperationLegal(Op: ISD::SIGN_EXTEND, VT: InnerExtVT))) {
14651 SDValue SignExtExtendee =
14652 DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT: InnerExtVT, Operand: Extendee);
14653 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT, N1: SignExtExtendee,
14654 N2: N0.getOperand(i: 1));
14655 }
14656 }
14657
14658 return SDValue();
14659}
14660
14661static SDValue foldExtendVectorInregToExtendOfSubvector(
14662 SDNode *N, const SDLoc &DL, const TargetLowering &TLI, SelectionDAG &DAG,
14663 bool LegalOperations) {
14664 unsigned InregOpcode = N->getOpcode();
14665 unsigned Opcode = DAG.getOpcode_EXTEND(Opcode: InregOpcode);
14666
14667 SDValue Src = N->getOperand(Num: 0);
14668 EVT VT = N->getValueType(ResNo: 0);
14669 EVT SrcVT = EVT::getVectorVT(Context&: *DAG.getContext(),
14670 VT: Src.getValueType().getVectorElementType(),
14671 EC: VT.getVectorElementCount());
14672
14673 assert(ISD::isExtVecInRegOpcode(InregOpcode) &&
14674 "Expected EXTEND_VECTOR_INREG dag node in input!");
14675
14676 // Profitability check: our operand must be a one-use CONCAT_VECTORS.
14677 // FIXME: one-use check may be overly restrictive
14678 if (!Src.hasOneUse() || Src.getOpcode() != ISD::CONCAT_VECTORS)
14679 return SDValue();
14680
14681 // Profitability check: we must be extending exactly one of its operands.
14682 // FIXME: this is probably overly restrictive.
14683 Src = Src.getOperand(i: 0);
14684 if (Src.getValueType() != SrcVT)
14685 return SDValue();
14686
14687 if (LegalOperations && !TLI.isOperationLegal(Op: Opcode, VT))
14688 return SDValue();
14689
14690 return DAG.getNode(Opcode, DL, VT, Operand: Src);
14691}
14692
14693SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
14694 SDValue N0 = N->getOperand(Num: 0);
14695 EVT VT = N->getValueType(ResNo: 0);
14696 SDLoc DL(N);
14697
14698 if (N0.isUndef()) {
14699 // aext_vector_inreg(undef) = undef because the top bits are undefined.
14700 // {s/z}ext_vector_inreg(undef) = 0 because the top bits must be the same.
14701 return N->getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG
14702 ? DAG.getUNDEF(VT)
14703 : DAG.getConstant(Val: 0, DL, VT);
14704 }
14705
14706 if (SDValue Res = tryToFoldExtendOfConstant(N, DL, TLI, DAG, LegalTypes))
14707 return Res;
14708
14709 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
14710 return SDValue(N, 0);
14711
14712 if (SDValue R = foldExtendVectorInregToExtendOfSubvector(N, DL, TLI, DAG,
14713 LegalOperations))
14714 return R;
14715
14716 return SDValue();
14717}
14718
14719SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
14720 SDValue N0 = N->getOperand(Num: 0);
14721 EVT VT = N->getValueType(ResNo: 0);
14722 EVT SrcVT = N0.getValueType();
14723 bool isLE = DAG.getDataLayout().isLittleEndian();
14724 SDLoc DL(N);
14725
14726 // trunc(undef) = undef
14727 if (N0.isUndef())
14728 return DAG.getUNDEF(VT);
14729
14730 // fold (truncate (truncate x)) -> (truncate x)
14731 if (N0.getOpcode() == ISD::TRUNCATE)
14732 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
14733
14734 // fold (truncate c1) -> c1
14735 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::TRUNCATE, DL, VT, Ops: {N0}))
14736 return C;
14737
14738 // fold (truncate (ext x)) -> (ext x) or (truncate x) or x
14739 if (N0.getOpcode() == ISD::ZERO_EXTEND ||
14740 N0.getOpcode() == ISD::SIGN_EXTEND ||
14741 N0.getOpcode() == ISD::ANY_EXTEND) {
14742 // if the source is smaller than the dest, we still need an extend.
14743 if (N0.getOperand(i: 0).getValueType().bitsLT(VT))
14744 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, Operand: N0.getOperand(i: 0));
14745 // if the source is larger than the dest, then we just need the truncate.
14746 if (N0.getOperand(i: 0).getValueType().bitsGT(VT))
14747 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
14748 // if the source and dest are the same type, we can drop both the extend
14749 // and the truncate.
14750 return N0.getOperand(i: 0);
14751 }
14752
14753 // Try to narrow a truncate-of-sext_in_reg to the destination type:
14754 // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
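// For example: (trunc (sign_ext_inreg i64 %x, i8) to i32)
// --> (sign_ext_inreg (trunc %x to i32), i8), when the target prefers the
// sign extension to happen in the narrower type.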
14755 if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
14756 N0.hasOneUse()) {
14757 SDValue X = N0.getOperand(i: 0);
14758 SDValue ExtVal = N0.getOperand(i: 1);
14759 EVT ExtVT = cast<VTSDNode>(Val&: ExtVal)->getVT();
14760 if (ExtVT.bitsLT(VT) && TLI.preferSextInRegOfTruncate(TruncVT: VT, VT: SrcVT, ExtVT)) {
14761 SDValue TrX = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: X);
14762 return DAG.getNode(Opcode: ISD::SIGN_EXTEND_INREG, DL, VT, N1: TrX, N2: ExtVal);
14763 }
14764 }
14765
14766 // If this is anyext(trunc), don't fold it, allow ourselves to be folded.
14767 if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
14768 return SDValue();
14769
14770 // Fold extract-and-trunc into a narrow extract. For example:
14771 // i64 x = EXTRACT_VECTOR_ELT(v2i64 val, i32 1)
14772 // i32 y = TRUNCATE(i64 x)
14773 // -- becomes --
14774 // v16i8 b = BITCAST (v2i64 val)
14775 // i8 x = EXTRACT_VECTOR_ELT(v16i8 b, i32 8)
14776 //
14777 // Note: We only run this optimization after type legalization (which often
14778 // creates this pattern) and before operation legalization after which
14779 // we need to be more careful about the vector instructions that we generate.
14780 if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
14781 LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
14782 EVT VecTy = N0.getOperand(i: 0).getValueType();
14783 EVT ExTy = N0.getValueType();
14784 EVT TrTy = N->getValueType(ResNo: 0);
14785
14786 auto EltCnt = VecTy.getVectorElementCount();
14787 unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
14788 auto NewEltCnt = EltCnt * SizeRatio;
14789
14790 EVT NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: TrTy, EC: NewEltCnt);
14791 assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
14792
14793 SDValue EltNo = N0->getOperand(Num: 1);
14794 if (isa<ConstantSDNode>(Val: EltNo) && isTypeLegal(VT: NVT)) {
14795 int Elt = EltNo->getAsZExtVal();
14796 int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
14797 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: TrTy,
14798 N1: DAG.getBitcast(VT: NVT, V: N0.getOperand(i: 0)),
14799 N2: DAG.getVectorIdxConstant(Val: Index, DL));
14800 }
14801 }
14802
14803 // trunc (select c, a, b) -> select c, (trunc a), (trunc b)
14804 if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
14805 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::SELECT, VT: SrcVT)) &&
14806 TLI.isTruncateFree(FromVT: SrcVT, ToVT: VT)) {
14807 SDLoc SL(N0);
14808 SDValue Cond = N0.getOperand(i: 0);
14809 SDValue TruncOp0 = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SL, VT, Operand: N0.getOperand(i: 1));
14810 SDValue TruncOp1 = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SL, VT, Operand: N0.getOperand(i: 2));
14811 return DAG.getNode(Opcode: ISD::SELECT, DL, VT, N1: Cond, N2: TruncOp0, N3: TruncOp1);
14812 }
14813 }
14814
14815 // trunc (shl x, K) -> shl (trunc x), K => K < VT.getScalarSizeInBits()
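// For example: (trunc (shl i64 %x, 3) to i32) can become
// (shl (trunc %x to i32), 3), valid because the shift amount is known to be
// less than 32.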
14816 if (N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
14817 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SHL, VT)) &&
14818 TLI.isTypeDesirableForOp(ISD::SHL, VT)) {
14819 SDValue Amt = N0.getOperand(i: 1);
14820 KnownBits Known = DAG.computeKnownBits(Op: Amt);
14821 unsigned Size = VT.getScalarSizeInBits();
14822 if (Known.countMaxActiveBits() <= Log2_32(Value: Size)) {
14823 EVT AmtVT = TLI.getShiftAmountTy(LHSTy: VT, DL: DAG.getDataLayout());
14824 SDValue Trunc = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
14825 if (AmtVT != Amt.getValueType()) {
14826 Amt = DAG.getZExtOrTrunc(Op: Amt, DL, VT: AmtVT);
14827 AddToWorklist(N: Amt.getNode());
14828 }
14829 return DAG.getNode(Opcode: ISD::SHL, DL, VT, N1: Trunc, N2: Amt);
14830 }
14831 }
14832
14833 if (SDValue V = foldSubToUSubSat(DstVT: VT, N: N0.getNode(), DL))
14834 return V;
14835
14836 if (SDValue ABD = foldABSToABD(N, DL))
14837 return ABD;
14838
14839 // Attempt to pre-truncate BUILD_VECTOR sources.
14840 if (N0.getOpcode() == ISD::BUILD_VECTOR && !LegalOperations &&
14841 N0.hasOneUse() &&
14842 TLI.isTruncateFree(FromVT: SrcVT.getScalarType(), ToVT: VT.getScalarType()) &&
14843 // Avoid creating illegal types if running after type legalizer.
14844 (!LegalTypes || TLI.isTypeLegal(VT: VT.getScalarType()))) {
14845 EVT SVT = VT.getScalarType();
14846 SmallVector<SDValue, 8> TruncOps;
14847 for (const SDValue &Op : N0->op_values()) {
14848 SDValue TruncOp = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: SVT, Operand: Op);
14849 TruncOps.push_back(Elt: TruncOp);
14850 }
14851 return DAG.getBuildVector(VT, DL, Ops: TruncOps);
14852 }
14853
14854 // trunc (splat_vector x) -> splat_vector (trunc x)
14855 if (N0.getOpcode() == ISD::SPLAT_VECTOR &&
14856 (!LegalTypes || TLI.isTypeLegal(VT: VT.getScalarType())) &&
14857 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SPLAT_VECTOR, VT))) {
14858 EVT SVT = VT.getScalarType();
14859 return DAG.getSplatVector(
14860 VT, DL, Op: DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: SVT, Operand: N0->getOperand(Num: 0)));
14861 }
14862
14863 // Fold a series of buildvector, bitcast, and truncate if possible.
14864 // For example fold
14865 // (2xi32 trunc (bitcast ((4xi32)buildvector x, x, y, y) 2xi64)) to
14866 // (2xi32 (buildvector x, y)).
14867 if (Level == AfterLegalizeVectorOps && VT.isVector() &&
14868 N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
14869 N0.getOperand(i: 0).getOpcode() == ISD::BUILD_VECTOR &&
14870 N0.getOperand(i: 0).hasOneUse()) {
14871 SDValue BuildVect = N0.getOperand(i: 0);
14872 EVT BuildVectEltTy = BuildVect.getValueType().getVectorElementType();
14873 EVT TruncVecEltTy = VT.getVectorElementType();
14874
14875 // Check that the element types match.
14876 if (BuildVectEltTy == TruncVecEltTy) {
14877 // Now we only need to compute the offset of the truncated elements.
14878 unsigned BuildVecNumElts = BuildVect.getNumOperands();
14879 unsigned TruncVecNumElts = VT.getVectorNumElements();
14880 unsigned TruncEltOffset = BuildVecNumElts / TruncVecNumElts;
14881
14882 assert((BuildVecNumElts % TruncVecNumElts) == 0 &&
14883 "Invalid number of elements");
14884
14885 SmallVector<SDValue, 8> Opnds;
14886 for (unsigned i = 0, e = BuildVecNumElts; i != e; i += TruncEltOffset)
14887 Opnds.push_back(Elt: BuildVect.getOperand(i));
14888
14889 return DAG.getBuildVector(VT, DL, Ops: Opnds);
14890 }
14891 }
14892
14893 // fold (truncate (load x)) -> (smaller load x)
14894 // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
14895 if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
14896 if (SDValue Reduced = reduceLoadWidth(N))
14897 return Reduced;
14898
14899 // Handle the case where the truncated result is at least as wide as the
14900 // loaded type.
14901 if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N: N0.getNode())) {
14902 auto *LN0 = cast<LoadSDNode>(Val&: N0);
14903 if (LN0->isSimple() && LN0->getMemoryVT().bitsLE(VT)) {
14904 SDValue NewLoad = DAG.getExtLoad(
14905 ExtType: LN0->getExtensionType(), dl: SDLoc(LN0), VT, Chain: LN0->getChain(),
14906 Ptr: LN0->getBasePtr(), MemVT: LN0->getMemoryVT(), MMO: LN0->getMemOperand());
14907 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: NewLoad.getValue(R: 1));
14908 return NewLoad;
14909 }
14910 }
14911 }
14912
14913 // fold (trunc (concat ... x ...)) -> (concat ..., (trunc x), ...)),
14914 // where ... are all 'undef'.
14915 if (N0.getOpcode() == ISD::CONCAT_VECTORS && !LegalTypes) {
14916 SmallVector<EVT, 8> VTs;
14917 SDValue V;
14918 unsigned Idx = 0;
14919 unsigned NumDefs = 0;
14920
14921 for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
14922 SDValue X = N0.getOperand(i);
14923 if (!X.isUndef()) {
14924 V = X;
14925 Idx = i;
14926 NumDefs++;
14927 }
14928 // Stop if more than one member is non-undef.
14929 if (NumDefs > 1)
14930 break;
14931
14932 VTs.push_back(Elt: EVT::getVectorVT(Context&: *DAG.getContext(),
14933 VT: VT.getVectorElementType(),
14934 EC: X.getValueType().getVectorElementCount()));
14935 }
14936
14937 if (NumDefs == 0)
14938 return DAG.getUNDEF(VT);
14939
14940 if (NumDefs == 1) {
14941 assert(V.getNode() && "The single defined operand is empty!");
14942 SmallVector<SDValue, 8> Opnds;
14943 for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
14944 if (i != Idx) {
14945 Opnds.push_back(Elt: DAG.getUNDEF(VT: VTs[i]));
14946 continue;
14947 }
14948 SDValue NV = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(V), VT: VTs[i], Operand: V);
14949 AddToWorklist(N: NV.getNode());
14950 Opnds.push_back(Elt: NV);
14951 }
14952 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: Opnds);
14953 }
14954 }
14955
14956 // Fold truncate of a bitcast of a vector to an extract of the low vector
14957 // element.
14958 //
14959 // e.g. trunc (i64 (bitcast v2i32:x)) -> extract_vector_elt v2i32:x, idx
14960 if (N0.getOpcode() == ISD::BITCAST && !VT.isVector()) {
14961 SDValue VecSrc = N0.getOperand(i: 0);
14962 EVT VecSrcVT = VecSrc.getValueType();
14963 if (VecSrcVT.isVector() && VecSrcVT.getScalarType() == VT &&
14964 (!LegalOperations ||
14965 TLI.isOperationLegal(Op: ISD::EXTRACT_VECTOR_ELT, VT: VecSrcVT))) {
14966 unsigned Idx = isLE ? 0 : VecSrcVT.getVectorNumElements() - 1;
14967 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT, N1: VecSrc,
14968 N2: DAG.getVectorIdxConstant(Val: Idx, DL));
14969 }
14970 }
14971
14972 // Simplify the operands using demanded-bits information.
14973 if (SimplifyDemandedBits(Op: SDValue(N, 0)))
14974 return SDValue(N, 0);
14975
14976 // fold (truncate (extract_subvector(ext x))) ->
14977 // (extract_subvector x)
14978 // TODO: This can be generalized to cover cases where the truncate and extract
14979 // do not fully cancel each other out.
14980 if (!LegalTypes && N0.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
14981 SDValue N00 = N0.getOperand(i: 0);
14982 if (N00.getOpcode() == ISD::SIGN_EXTEND ||
14983 N00.getOpcode() == ISD::ZERO_EXTEND ||
14984 N00.getOpcode() == ISD::ANY_EXTEND) {
14985 if (N00.getOperand(i: 0)->getValueType(ResNo: 0).getVectorElementType() ==
14986 VT.getVectorElementType())
14987 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N0->getOperand(Num: 0)), VT,
14988 N1: N00.getOperand(i: 0), N2: N0.getOperand(i: 1));
14989 }
14990 }
14991
14992 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
14993 return NewVSel;
14994
14995 // Narrow a suitable binary operation with a non-opaque constant operand by
14996 // moving it ahead of the truncate. This is limited to pre-legalization
14997 // because targets may prefer a wider type during later combines and invert
14998 // this transform.
14999 switch (N0.getOpcode()) {
15000 case ISD::ADD:
15001 case ISD::SUB:
15002 case ISD::MUL:
15003 case ISD::AND:
15004 case ISD::OR:
15005 case ISD::XOR:
15006 if (!LegalOperations && N0.hasOneUse() &&
15007 (isConstantOrConstantVector(N: N0.getOperand(i: 0), NoOpaques: true) ||
15008 isConstantOrConstantVector(N: N0.getOperand(i: 1), NoOpaques: true))) {
15009 // TODO: We already restricted this to pre-legalization, but for vectors
15010 // we are extra cautious to not create an unsupported operation.
15011 // Target-specific changes are likely needed to avoid regressions here.
15012 if (VT.isScalarInteger() || TLI.isOperationLegal(Op: N0.getOpcode(), VT)) {
15013 SDValue NarrowL = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
15014 SDValue NarrowR = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 1));
15015 return DAG.getNode(Opcode: N0.getOpcode(), DL, VT, N1: NarrowL, N2: NarrowR);
15016 }
15017 }
15018 break;
15019 case ISD::ADDE:
15020 case ISD::UADDO_CARRY:
15021 // (trunc adde(X, Y, Carry)) -> (adde trunc(X), trunc(Y), Carry)
15022 // (trunc uaddo_carry(X, Y, Carry)) ->
15023 // (uaddo_carry trunc(X), trunc(Y), Carry)
15024    // This is only valid when the adde's carry result is unused.
15025    // For uaddo_carry we only do this before operation legalization.
15026 if (((!LegalOperations && N0.getOpcode() == ISD::UADDO_CARRY) ||
15027 TLI.isOperationLegal(Op: N0.getOpcode(), VT)) &&
15028 N0.hasOneUse() && !N0->hasAnyUseOfValue(Value: 1)) {
15029 SDValue X = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 0));
15030 SDValue Y = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT, Operand: N0.getOperand(i: 1));
15031 SDVTList VTs = DAG.getVTList(VT1: VT, VT2: N0->getValueType(ResNo: 1));
15032 return DAG.getNode(Opcode: N0.getOpcode(), DL, VTList: VTs, N1: X, N2: Y, N3: N0.getOperand(i: 2));
15033 }
15034 break;
15035 case ISD::USUBSAT:
15036    // Truncate the USUBSAT only if the LHS is a known zero-extension; it's not
15037    // enough to know that the upper bits are zero, we must also ensure that we
15038    // don't introduce an extra truncate.
15039 if (!LegalOperations && N0.hasOneUse() &&
15040 N0.getOperand(i: 0).getOpcode() == ISD::ZERO_EXTEND &&
15041 N0.getOperand(i: 0).getOperand(i: 0).getScalarValueSizeInBits() <=
15042 VT.getScalarSizeInBits() &&
15043 hasOperation(Opcode: N0.getOpcode(), VT)) {
15044 return getTruncatedUSUBSAT(DstVT: VT, SrcVT, LHS: N0.getOperand(i: 0), RHS: N0.getOperand(i: 1),
15045 DAG, DL);
15046 }
15047 break;
15048 }
15049
15050 return SDValue();
15051}
15052
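/// Return the node feeding operand \p i of the BUILD_PAIR \p N, looking
/// through a MERGE_VALUES node to the underlying value if necessary.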
15053static SDNode *getBuildPairElt(SDNode *N, unsigned i) {
15054 SDValue Elt = N->getOperand(Num: i);
15055 if (Elt.getOpcode() != ISD::MERGE_VALUES)
15056 return Elt.getNode();
15057 return Elt.getOperand(i: Elt.getResNo()).getNode();
15058}
15059
15060/// build_pair (load, load) -> load
15061/// if load locations are consecutive.
15062SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) {
15063 assert(N->getOpcode() == ISD::BUILD_PAIR);
15064
15065 auto *LD1 = dyn_cast<LoadSDNode>(Val: getBuildPairElt(N, i: 0));
15066 auto *LD2 = dyn_cast<LoadSDNode>(Val: getBuildPairElt(N, i: 1));
15067
15068  // A BUILD_PAIR always has the least significant part in elt 0 and the
15069  // most significant part in elt 1, so when combining them into one large
15070  // load we need to consider the endianness.
15071 if (DAG.getDataLayout().isBigEndian())
15072 std::swap(a&: LD1, b&: LD2);
15073
15074 if (!LD1 || !LD2 || !ISD::isNON_EXTLoad(N: LD1) || !ISD::isNON_EXTLoad(N: LD2) ||
15075 !LD1->hasOneUse() || !LD2->hasOneUse() ||
15076 LD1->getAddressSpace() != LD2->getAddressSpace())
15077 return SDValue();
15078
15079 unsigned LD1Fast = 0;
15080 EVT LD1VT = LD1->getValueType(ResNo: 0);
15081 unsigned LD1Bytes = LD1VT.getStoreSize();
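  // Combine only if a single wide load is acceptable (legal, or we may still
  // legalize it), the two loads are consecutive in memory, and the target
  // reports that the wider access is fast.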
15082 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::LOAD, VT)) &&
15083 DAG.areNonVolatileConsecutiveLoads(LD: LD2, Base: LD1, Bytes: LD1Bytes, Dist: 1) &&
15084 TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT,
15085 MMO: *LD1->getMemOperand(), Fast: &LD1Fast) && LD1Fast)
15086 return DAG.getLoad(VT, dl: SDLoc(N), Chain: LD1->getChain(), Ptr: LD1->getBasePtr(),
15087 PtrInfo: LD1->getPointerInfo(), Alignment: LD1->getAlign());
15088
15089 return SDValue();
15090}
15091
15092static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) {
15093 // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi
15094 // and Lo parts; on big-endian machines it doesn't.
15095 return DAG.getDataLayout().isBigEndian() ? 1 : 0;
15096}
15097
15098SDValue DAGCombiner::foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
15099 const TargetLowering &TLI) {
15100 // If this is not a bitcast to an FP type or if the target doesn't have
15101 // IEEE754-compliant FP logic, we're done.
15102 EVT VT = N->getValueType(ResNo: 0);
15103 SDValue N0 = N->getOperand(Num: 0);
15104 EVT SourceVT = N0.getValueType();
15105
15106 if (!VT.isFloatingPoint())
15107 return SDValue();
15108
15109 // TODO: Handle cases where the integer constant is a different scalar
15110  // bitwidth than the FP.
15111 if (VT.getScalarSizeInBits() != SourceVT.getScalarSizeInBits())
15112 return SDValue();
15113
15114 unsigned FPOpcode;
15115 APInt SignMask;
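  // Map the integer logic opcode to the equivalent FP opcode and the sign-bit
  // mask its constant operand must match: AND clears the sign bit (fabs), XOR
  // flips it (fneg), and OR sets it (fneg of fabs).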
15116 switch (N0.getOpcode()) {
15117 case ISD::AND:
15118 FPOpcode = ISD::FABS;
15119 SignMask = ~APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
15120 break;
15121 case ISD::XOR:
15122 FPOpcode = ISD::FNEG;
15123 SignMask = APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
15124 break;
15125 case ISD::OR:
15126 FPOpcode = ISD::FABS;
15127 SignMask = APInt::getSignMask(BitWidth: SourceVT.getScalarSizeInBits());
15128 break;
15129 default:
15130 return SDValue();
15131 }
15132
15133 if (LegalOperations && !TLI.isOperationLegal(Op: FPOpcode, VT))
15134 return SDValue();
15135
15136  // This needs to be the inverse of the logic in foldSignChangeInBitcast.
15137 // FIXME: I don't think looking for bitcast intrinsically makes sense, but
15138 // removing this would require more changes.
15139 auto IsBitCastOrFree = [&TLI, FPOpcode](SDValue Op, EVT VT) {
15140 if (Op.getOpcode() == ISD::BITCAST && Op.getOperand(i: 0).getValueType() == VT)
15141 return true;
15142
15143 return FPOpcode == ISD::FABS ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
15144 };
15145
15146 // Fold (bitcast int (and (bitcast fp X to int), 0x7fff...) to fp) -> fabs X
15147 // Fold (bitcast int (xor (bitcast fp X to int), 0x8000...) to fp) -> fneg X
15148 // Fold (bitcast int (or (bitcast fp X to int), 0x8000...) to fp) ->
15149 // fneg (fabs X)
15150 SDValue LogicOp0 = N0.getOperand(i: 0);
15151 ConstantSDNode *LogicOp1 = isConstOrConstSplat(N: N0.getOperand(i: 1), AllowUndefs: true);
15152 if (LogicOp1 && LogicOp1->getAPIntValue() == SignMask &&
15153 IsBitCastOrFree(LogicOp0, VT)) {
15154 SDValue CastOp0 = DAG.getNode(Opcode: ISD::BITCAST, DL: SDLoc(N), VT, Operand: LogicOp0);
15155 SDValue FPOp = DAG.getNode(Opcode: FPOpcode, DL: SDLoc(N), VT, Operand: CastOp0);
15156 NumFPLogicOpsConv++;
15157 if (N0.getOpcode() == ISD::OR)
15158 return DAG.getNode(Opcode: ISD::FNEG, DL: SDLoc(N), VT, Operand: FPOp);
15159 return FPOp;
15160 }
15161
15162 return SDValue();
15163}
15164
15165SDValue DAGCombiner::visitBITCAST(SDNode *N) {
15166 SDValue N0 = N->getOperand(Num: 0);
15167 EVT VT = N->getValueType(ResNo: 0);
15168
15169 if (N0.isUndef())
15170 return DAG.getUNDEF(VT);
15171
15172 // If the input is a BUILD_VECTOR with all constant elements, fold this now.
15173 // Only do this before legalize types, unless both types are integer and the
15174 // scalar type is legal. Only do this before legalize ops, since the target
15175  // may be depending on the bitcast.
15176 // First check to see if this is all constant.
15177 // TODO: Support FP bitcasts after legalize types.
15178 if (VT.isVector() &&
15179 (!LegalTypes ||
15180 (!LegalOperations && VT.isInteger() && N0.getValueType().isInteger() &&
15181 TLI.isTypeLegal(VT: VT.getVectorElementType()))) &&
15182 N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
15183 cast<BuildVectorSDNode>(Val&: N0)->isConstant())
15184 return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
15185 VT.getVectorElementType());
15186
15187 // If the input is a constant, let getNode fold it.
15188 if (isIntOrFPConstant(V: N0)) {
15189    // If we can't allow illegal operations, we need to check that this is just
15190    // an fp -> int or int -> fp conversion and that the resulting operation
15191    // will be legal.
15192 if (!LegalOperations ||
15193 (isa<ConstantSDNode>(Val: N0) && VT.isFloatingPoint() && !VT.isVector() &&
15194 TLI.isOperationLegal(Op: ISD::ConstantFP, VT)) ||
15195 (isa<ConstantFPSDNode>(Val: N0) && VT.isInteger() && !VT.isVector() &&
15196 TLI.isOperationLegal(Op: ISD::Constant, VT))) {
15197 SDValue C = DAG.getBitcast(VT, V: N0);
15198 if (C.getNode() != N)
15199 return C;
15200 }
15201 }
15202
15203 // (conv (conv x, t1), t2) -> (conv x, t2)
15204 if (N0.getOpcode() == ISD::BITCAST)
15205 return DAG.getBitcast(VT, V: N0.getOperand(i: 0));
15206
15207 // fold (conv (logicop (conv x), (c))) -> (logicop x, (conv c))
15208 // iff the current bitwise logicop type isn't legal
15209 if (ISD::isBitwiseLogicOp(Opcode: N0.getOpcode()) && VT.isInteger() &&
15210 !TLI.isTypeLegal(VT: N0.getOperand(i: 0).getValueType())) {
15211 auto IsFreeBitcast = [VT](SDValue V) {
15212 return (V.getOpcode() == ISD::BITCAST &&
15213 V.getOperand(i: 0).getValueType() == VT) ||
15214 (ISD::isBuildVectorOfConstantSDNodes(N: V.getNode()) &&
15215 V->hasOneUse());
15216 };
15217 if (IsFreeBitcast(N0.getOperand(i: 0)) && IsFreeBitcast(N0.getOperand(i: 1)))
15218 return DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N), VT,
15219 N1: DAG.getBitcast(VT, V: N0.getOperand(i: 0)),
15220 N2: DAG.getBitcast(VT, V: N0.getOperand(i: 1)));
15221 }
15222
15223 // fold (conv (load x)) -> (load (conv*)x)
15224 // If the resultant load doesn't need a higher alignment than the original!
15225 if (ISD::isNormalLoad(N: N0.getNode()) && N0.hasOneUse() &&
15226 // Do not remove the cast if the types differ in endian layout.
15227 TLI.hasBigEndianPartOrdering(VT: N0.getValueType(), DL: DAG.getDataLayout()) ==
15228 TLI.hasBigEndianPartOrdering(VT, DL: DAG.getDataLayout()) &&
15229 // If the load is volatile, we only want to change the load type if the
15230 // resulting load is legal. Otherwise we might increase the number of
15231 // memory accesses. We don't care if the original type was legal or not
15232 // as we assume software couldn't rely on the number of accesses of an
15233 // illegal type.
15234 ((!LegalOperations && cast<LoadSDNode>(Val&: N0)->isSimple()) ||
15235 TLI.isOperationLegal(Op: ISD::LOAD, VT))) {
15236 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
15237
15238 if (TLI.isLoadBitCastBeneficial(LoadVT: N0.getValueType(), BitcastVT: VT, DAG,
15239 MMO: *LN0->getMemOperand())) {
15240 SDValue Load =
15241 DAG.getLoad(VT, dl: SDLoc(N), Chain: LN0->getChain(), Ptr: LN0->getBasePtr(),
15242 MMO: LN0->getMemOperand());
15243 DAG.ReplaceAllUsesOfValueWith(From: N0.getValue(R: 1), To: Load.getValue(R: 1));
15244 return Load;
15245 }
15246 }
15247
15248 if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
15249 return V;
15250
15251 // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit)
15252 // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit))
15253 //
15254 // For ppc_fp128:
15255 // fold (bitcast (fneg x)) ->
15256 // flipbit = signbit
15257 // (xor (bitcast x) (build_pair flipbit, flipbit))
15258 //
15259 // fold (bitcast (fabs x)) ->
15260 // flipbit = (and (extract_element (bitcast x), 0), signbit)
15261 // (xor (bitcast x) (build_pair flipbit, flipbit))
15262 // This often reduces constant pool loads.
15263 if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(VT: N0.getValueType())) ||
15264 (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(VT: N0.getValueType()))) &&
15265 N0->hasOneUse() && VT.isInteger() && !VT.isVector() &&
15266 !N0.getValueType().isVector()) {
15267 SDValue NewConv = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
15268 AddToWorklist(N: NewConv.getNode());
15269
15270 SDLoc DL(N);
15271 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
15272 assert(VT.getSizeInBits() == 128);
15273 SDValue SignBit = DAG.getConstant(
15274 APInt::getSignMask(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64);
15275 SDValue FlipBit;
15276 if (N0.getOpcode() == ISD::FNEG) {
15277 FlipBit = SignBit;
15278 AddToWorklist(N: FlipBit.getNode());
15279 } else {
15280 assert(N0.getOpcode() == ISD::FABS);
15281 SDValue Hi =
15282 DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv,
15283 DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
15284 SDLoc(NewConv)));
15285 AddToWorklist(N: Hi.getNode());
15286 FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit);
15287 AddToWorklist(N: FlipBit.getNode());
15288 }
15289 SDValue FlipBits =
15290 DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: SDLoc(N0), VT, N1: FlipBit, N2: FlipBit);
15291 AddToWorklist(N: FlipBits.getNode());
15292 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: NewConv, N2: FlipBits);
15293 }
15294 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits());
15295 if (N0.getOpcode() == ISD::FNEG)
15296 return DAG.getNode(Opcode: ISD::XOR, DL, VT,
15297 N1: NewConv, N2: DAG.getConstant(Val: SignBit, DL, VT));
15298 assert(N0.getOpcode() == ISD::FABS);
15299 return DAG.getNode(Opcode: ISD::AND, DL, VT,
15300 N1: NewConv, N2: DAG.getConstant(Val: ~SignBit, DL, VT));
15301 }
15302
15303 // fold (bitconvert (fcopysign cst, x)) ->
15304 // (or (and (bitconvert x), sign), (and cst, (not sign)))
15305 // Note that we don't handle (copysign x, cst) because this can always be
15306 // folded to an fneg or fabs.
15307 //
15308 // For ppc_fp128:
15309 // fold (bitcast (fcopysign cst, x)) ->
15310 // flipbit = (and (extract_element
15311 // (xor (bitcast cst), (bitcast x)), 0),
15312 // signbit)
15313 // (xor (bitcast cst) (build_pair flipbit, flipbit))
15314 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
15315 isa<ConstantFPSDNode>(Val: N0.getOperand(i: 0)) && VT.isInteger() &&
15316 !VT.isVector()) {
15317 unsigned OrigXWidth = N0.getOperand(i: 1).getValueSizeInBits();
15318 EVT IntXVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OrigXWidth);
15319 if (isTypeLegal(VT: IntXVT)) {
15320 SDValue X = DAG.getBitcast(VT: IntXVT, V: N0.getOperand(i: 1));
15321 AddToWorklist(N: X.getNode());
15322
15323 // If X has a different width than the result/lhs, sext it or truncate it.
15324 unsigned VTWidth = VT.getSizeInBits();
15325 if (OrigXWidth < VTWidth) {
15326 X = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(N), VT, Operand: X);
15327 AddToWorklist(N: X.getNode());
15328 } else if (OrigXWidth > VTWidth) {
15329 // To get the sign bit in the right place, we have to shift it right
15330 // before truncating.
15331 SDLoc DL(X);
15332 X = DAG.getNode(Opcode: ISD::SRL, DL,
15333 VT: X.getValueType(), N1: X,
15334 N2: DAG.getConstant(Val: OrigXWidth-VTWidth, DL,
15335 VT: X.getValueType()));
15336 AddToWorklist(N: X.getNode());
15337 X = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(X), VT, Operand: X);
15338 AddToWorklist(N: X.getNode());
15339 }
15340
15341 if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) {
15342 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits() / 2);
15343 SDValue Cst = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
15344 AddToWorklist(N: Cst.getNode());
15345 SDValue X = DAG.getBitcast(VT, V: N0.getOperand(i: 1));
15346 AddToWorklist(N: X.getNode());
15347 SDValue XorResult = DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N0), VT, N1: Cst, N2: X);
15348 AddToWorklist(N: XorResult.getNode());
15349 SDValue XorResult64 = DAG.getNode(
15350 ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult,
15351 DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG),
15352 SDLoc(XorResult)));
15353 AddToWorklist(N: XorResult64.getNode());
15354 SDValue FlipBit =
15355 DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64,
15356 DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64));
15357 AddToWorklist(N: FlipBit.getNode());
15358 SDValue FlipBits =
15359 DAG.getNode(Opcode: ISD::BUILD_PAIR, DL: SDLoc(N0), VT, N1: FlipBit, N2: FlipBit);
15360 AddToWorklist(N: FlipBits.getNode());
15361 return DAG.getNode(Opcode: ISD::XOR, DL: SDLoc(N), VT, N1: Cst, N2: FlipBits);
15362 }
15363 APInt SignBit = APInt::getSignMask(BitWidth: VT.getSizeInBits());
15364 X = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(X), VT,
15365 N1: X, N2: DAG.getConstant(Val: SignBit, DL: SDLoc(X), VT));
15366 AddToWorklist(N: X.getNode());
15367
15368 SDValue Cst = DAG.getBitcast(VT, V: N0.getOperand(i: 0));
15369 Cst = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(Cst), VT,
15370 N1: Cst, N2: DAG.getConstant(Val: ~SignBit, DL: SDLoc(Cst), VT));
15371 AddToWorklist(N: Cst.getNode());
15372
15373 return DAG.getNode(Opcode: ISD::OR, DL: SDLoc(N), VT, N1: X, N2: Cst);
15374 }
15375 }
15376
15377 // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
15378 if (N0.getOpcode() == ISD::BUILD_PAIR)
15379 if (SDValue CombineLD = CombineConsecutiveLoads(N: N0.getNode(), VT))
15380 return CombineLD;
15381
15382 // Remove double bitcasts from shuffles - this is often a legacy of
15383 // XformToShuffleWithZero being used to combine bitmaskings (of
15384 // float vectors bitcast to integer vectors) into shuffles.
15385 // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1)
15386 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() &&
15387 N0->getOpcode() == ISD::VECTOR_SHUFFLE && N0.hasOneUse() &&
15388 VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() &&
15389 !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) {
15390 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val&: N0);
15391
15392    // If an operand is a bitcast, peek through it if it casts from the original VT.
15393    // If an operand is a constant, just bitcast it back to the original VT.
15394 auto PeekThroughBitcast = [&](SDValue Op) {
15395 if (Op.getOpcode() == ISD::BITCAST &&
15396 Op.getOperand(i: 0).getValueType() == VT)
15397 return SDValue(Op.getOperand(i: 0));
15398 if (Op.isUndef() || isAnyConstantBuildVector(V: Op))
15399 return DAG.getBitcast(VT, V: Op);
15400 return SDValue();
15401 };
15402
15403 // FIXME: If either input vector is bitcast, try to convert the shuffle to
15404 // the result type of this bitcast. This would eliminate at least one
15405 // bitcast. See the transform in InstCombine.
15406 SDValue SV0 = PeekThroughBitcast(N0->getOperand(Num: 0));
15407 SDValue SV1 = PeekThroughBitcast(N0->getOperand(Num: 1));
15408 if (!(SV0 && SV1))
15409 return SDValue();
15410
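    // Scale the shuffle mask to the new (narrower-element) type: each original
    // lane expands to MaskScale consecutive lanes.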
15411 int MaskScale =
15412 VT.getVectorNumElements() / N0.getValueType().getVectorNumElements();
15413 SmallVector<int, 8> NewMask;
15414 for (int M : SVN->getMask())
15415 for (int i = 0; i != MaskScale; ++i)
15416 NewMask.push_back(Elt: M < 0 ? -1 : M * MaskScale + i);
15417
15418 SDValue LegalShuffle =
15419 TLI.buildLegalVectorShuffle(VT, DL: SDLoc(N), N0: SV0, N1: SV1, Mask: NewMask, DAG);
15420 if (LegalShuffle)
15421 return LegalShuffle;
15422 }
15423
15424 return SDValue();
15425}
15426
15427SDValue DAGCombiner::visitBUILD_PAIR(SDNode *N) {
15428 EVT VT = N->getValueType(ResNo: 0);
15429 return CombineConsecutiveLoads(N, VT);
15430}
15431
15432SDValue DAGCombiner::visitFREEZE(SDNode *N) {
15433 SDValue N0 = N->getOperand(Num: 0);
15434
15435 if (DAG.isGuaranteedNotToBeUndefOrPoison(Op: N0, /*PoisonOnly*/ false))
15436 return N0;
15437
15438 // Fold freeze(op(x, ...)) -> op(freeze(x), ...).
15439 // Try to push freeze through instructions that propagate but don't produce
15440  // poison as far as possible. If an operand of freeze satisfies three
15441  // conditions: 1) one use, 2) does not produce poison, and 3) all but one of
15442  // its operands are guaranteed non-poison (or it is a BUILD_VECTOR or similar),
15443  // then push the freeze through to the operands that are not guaranteed non-poison.
15444 // NOTE: we will strip poison-generating flags, so ignore them here.
15445 if (DAG.canCreateUndefOrPoison(Op: N0, /*PoisonOnly*/ false,
15446 /*ConsiderFlags*/ false) ||
15447 N0->getNumValues() != 1 || !N0->hasOneUse())
15448 return SDValue();
15449
15450 bool AllowMultipleMaybePoisonOperands = N0.getOpcode() == ISD::BUILD_VECTOR ||
15451 N0.getOpcode() == ISD::BUILD_PAIR ||
15452 N0.getOpcode() == ISD::CONCAT_VECTORS;
15453
15454 SmallSetVector<SDValue, 8> MaybePoisonOperands;
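  // Collect the operands that are not guaranteed non-poison. Unless multiple
  // maybe-poison operands are allowed (BUILD_VECTOR and friends), bail out as
  // soon as we see a second distinct one.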
15455 for (SDValue Op : N0->ops()) {
15456 if (DAG.isGuaranteedNotToBeUndefOrPoison(Op, /*PoisonOnly*/ false,
15457 /*Depth*/ 1))
15458 continue;
15459 bool HadMaybePoisonOperands = !MaybePoisonOperands.empty();
15460 bool IsNewMaybePoisonOperand = MaybePoisonOperands.insert(X: Op);
15461 if (!HadMaybePoisonOperands)
15462 continue;
15463 if (IsNewMaybePoisonOperand && !AllowMultipleMaybePoisonOperands) {
15464 // Multiple maybe-poison ops when not allowed - bail out.
15465 return SDValue();
15466 }
15467 }
15468  // NOTE: the whole op may still not be guaranteed to be non-undef/non-poison,
15469  // because it could create undef or poison due to its poison-generating flags.
15470 // So not finding any maybe-poison operands is fine.
15471
15472 for (SDValue MaybePoisonOperand : MaybePoisonOperands) {
15473 // Don't replace every single UNDEF everywhere with frozen UNDEF, though.
15474 if (MaybePoisonOperand.getOpcode() == ISD::UNDEF)
15475 continue;
15476 // First, freeze each offending operand.
15477 SDValue FrozenMaybePoisonOperand = DAG.getFreeze(V: MaybePoisonOperand);
15478 // Then, change all other uses of unfrozen operand to use frozen operand.
15479 DAG.ReplaceAllUsesOfValueWith(From: MaybePoisonOperand, To: FrozenMaybePoisonOperand);
15480 if (FrozenMaybePoisonOperand.getOpcode() == ISD::FREEZE &&
15481 FrozenMaybePoisonOperand.getOperand(i: 0) == FrozenMaybePoisonOperand) {
15482 // But, that also updated the use in the freeze we just created, thus
15483 // creating a cycle in a DAG. Let's undo that by mutating the freeze.
15484 DAG.UpdateNodeOperands(N: FrozenMaybePoisonOperand.getNode(),
15485 Op: MaybePoisonOperand);
15486 }
15487 }
15488
15489 // This node has been merged with another.
15490 if (N->getOpcode() == ISD::DELETED_NODE)
15491 return SDValue(N, 0);
15492
15493 // The whole node may have been updated, so the value we were holding
15494 // may no longer be valid. Re-fetch the operand we're `freeze`ing.
15495 N0 = N->getOperand(Num: 0);
15496
15497  // Finally, recreate the node. Its operands were updated to use frozen
15498  // operands above, so we just need to use its "original" operands.
15499 SmallVector<SDValue> Ops(N0->op_begin(), N0->op_end());
15500  // Special-handle ISD::UNDEF: each single one of them can be its own thing.
15501 for (SDValue &Op : Ops) {
15502 if (Op.getOpcode() == ISD::UNDEF)
15503 Op = DAG.getFreeze(V: Op);
15504 }
15505 // NOTE: this strips poison generating flags.
15506 SDValue R = DAG.getNode(Opcode: N0.getOpcode(), DL: SDLoc(N0), VTList: N0->getVTList(), Ops);
15507 assert(DAG.isGuaranteedNotToBeUndefOrPoison(R, /*PoisonOnly*/ false) &&
15508 "Can't create node that may be undef/poison!");
15509 return R;
15510}
15511
15512/// We know that BV is a build_vector node with Constant, ConstantFP or Undef
15513/// operands. DstEltVT indicates the destination element value type.
15514SDValue DAGCombiner::
15515ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
15516 EVT SrcEltVT = BV->getValueType(ResNo: 0).getVectorElementType();
15517
15518 // If this is already the right type, we're done.
15519 if (SrcEltVT == DstEltVT) return SDValue(BV, 0);
15520
15521 unsigned SrcBitSize = SrcEltVT.getSizeInBits();
15522 unsigned DstBitSize = DstEltVT.getSizeInBits();
15523
15524 // If this is a conversion of N elements of one type to N elements of another
15525 // type, convert each element. This handles FP<->INT cases.
15526 if (SrcBitSize == DstBitSize) {
15527 SmallVector<SDValue, 8> Ops;
15528 for (SDValue Op : BV->op_values()) {
15529 // If the vector element type is not legal, the BUILD_VECTOR operands
15530 // are promoted and implicitly truncated. Make that explicit here.
15531 if (Op.getValueType() != SrcEltVT)
15532 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(BV), VT: SrcEltVT, Operand: Op);
15533 Ops.push_back(Elt: DAG.getBitcast(VT: DstEltVT, V: Op));
15534 AddToWorklist(N: Ops.back().getNode());
15535 }
15536 EVT VT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: DstEltVT,
15537 NumElements: BV->getValueType(ResNo: 0).getVectorNumElements());
15538 return DAG.getBuildVector(VT, DL: SDLoc(BV), Ops);
15539 }
15540
15541 // Otherwise, we're growing or shrinking the elements. To avoid having to
15542 // handle annoying details of growing/shrinking FP values, we convert them to
15543 // int first.
15544 if (SrcEltVT.isFloatingPoint()) {
15545    // Convert the input float vector to an int vector where the elements are
15546    // the same size.
15547 EVT IntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: SrcEltVT.getSizeInBits());
15548 BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, DstEltVT: IntVT).getNode();
15549 SrcEltVT = IntVT;
15550 }
15551
15552  // Now we know the input is an integer vector. If the output is an FP type,
15553 // convert to integer first, then to FP of the right size.
15554 if (DstEltVT.isFloatingPoint()) {
15555 EVT TmpVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: DstEltVT.getSizeInBits());
15556 SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, DstEltVT: TmpVT).getNode();
15557
15558 // Next, convert to FP elements of the same size.
15559 return ConstantFoldBITCASTofBUILD_VECTOR(BV: Tmp, DstEltVT);
15560 }
15561
15562 // Okay, we know the src/dst types are both integers of differing types.
15563 assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
15564
15565 // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a
15566 // BuildVectorSDNode?
15567 auto *BVN = cast<BuildVectorSDNode>(Val: BV);
15568
15569 // Extract the constant raw bit data.
15570 BitVector UndefElements;
15571 SmallVector<APInt> RawBits;
15572 bool IsLE = DAG.getDataLayout().isLittleEndian();
15573 if (!BVN->getConstantRawBits(IsLittleEndian: IsLE, DstEltSizeInBits: DstBitSize, RawBitElements&: RawBits, UndefElements))
15574 return SDValue();
15575
15576 SDLoc DL(BV);
15577 SmallVector<SDValue, 8> Ops;
15578 for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
15579 if (UndefElements[I])
15580 Ops.push_back(Elt: DAG.getUNDEF(VT: DstEltVT));
15581 else
15582 Ops.push_back(Elt: DAG.getConstant(Val: RawBits[I], DL, VT: DstEltVT));
15583 }
15584
15585 EVT VT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: DstEltVT, NumElements: Ops.size());
15586 return DAG.getBuildVector(VT, DL, Ops);
15587}
15588
15589 // Returns true if floating-point contraction is allowed on the FMUL SDValue
15590 // `N`.
15591static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
15592 assert(N.getOpcode() == ISD::FMUL);
15593
15594 return Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath ||
15595 N->getFlags().hasAllowContract();
15596}
15597
15598// Returns true if `N` can assume no infinities involved in its computation.
15599static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
15600 return Options.NoInfsFPMath || N->getFlags().hasNoInfs();
15601}
15602
15603/// Try to perform FMA combining on a given FADD node.
15604template <class MatchContextClass>
15605SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
15606 SDValue N0 = N->getOperand(Num: 0);
15607 SDValue N1 = N->getOperand(Num: 1);
15608 EVT VT = N->getValueType(ResNo: 0);
15609 SDLoc SL(N);
15610 MatchContextClass matcher(DAG, TLI, N);
15611 const TargetOptions &Options = DAG.getTarget().Options;
15612
15613 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
15614
15615 // Floating-point multiply-add with intermediate rounding.
15616 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
15617 // FIXME: Add VP_FMAD opcode.
15618 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
15619
15620 // Floating-point multiply-add without intermediate rounding.
15621 bool HasFMA =
15622 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT) &&
15623 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT));
15624
15625 // No valid opcode, do not combine.
15626 if (!HasFMAD && !HasFMA)
15627 return SDValue();
15628
15629 bool CanReassociate =
15630 Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
15631 bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
15632 Options.UnsafeFPMath || HasFMAD);
15633 // If the addition is not contractable, do not combine.
15634 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
15635 return SDValue();
15636
15637 // Folding fadd (fmul x, y), (fmul x, y) -> fma x, y, (fmul x, y) is never
15638 // beneficial. It does not reduce latency. It increases register pressure. It
15639 // replaces an fadd with an fma which is a more complex instruction, so is
15640 // likely to have a larger encoding, use more functional units, etc.
15641 if (N0 == N1)
15642 return SDValue();
15643
15644 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
15645 return SDValue();
15646
15647 // Always prefer FMAD to FMA for precision.
15648 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
15649 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
15650
15651 auto isFusedOp = [&](SDValue N) {
15652 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
15653 };
15654
15655  // Checks if the node is an FMUL and contractable, either due to global flags
15656  // or its SDNodeFlags.
15657 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
15658 if (!matcher.match(N, ISD::FMUL))
15659 return false;
15660 return AllowFusionGlobally || N->getFlags().hasAllowContract();
15661 };
15662 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
15663 // prefer to fold the multiply with fewer uses.
15664 if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
15665 if (N0->use_size() > N1->use_size())
15666 std::swap(a&: N0, b&: N1);
15667 }
15668
15669 // fold (fadd (fmul x, y), z) -> (fma x, y, z)
15670 if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
15671 return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0),
15672 N0.getOperand(i: 1), N1);
15673 }
15674
15675 // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
15676 // Note: Commutes FADD operands.
15677 if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
15678 return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(i: 0),
15679 N1.getOperand(i: 1), N0);
15680 }
15681
15682 // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
15683 // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
15684 // This also works with nested fma instructions:
15685  // fadd (fma A, B, (fma C, D, (fmul E, F))), G -->
15686  // fma A, B, (fma C, D, (fma E, F, G))
15687  // fadd G, (fma A, B, (fma C, D, (fmul E, F))) -->
15688  // fma A, B, (fma C, D, (fma E, F, G)).
15689 // This requires reassociation because it changes the order of operations.
15690 if (CanReassociate) {
15691 SDValue FMA, E;
15692 if (isFusedOp(N0) && N0.hasOneUse()) {
15693 FMA = N0;
15694 E = N1;
15695 } else if (isFusedOp(N1) && N1.hasOneUse()) {
15696 FMA = N1;
15697 E = N0;
15698 }
15699
15700 SDValue TmpFMA = FMA;
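    // Walk down the chain of fused ops through their addend (operand 2),
    // looking for an inner FMUL that E can be folded into.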
15701 while (E && isFusedOp(TmpFMA) && TmpFMA.hasOneUse()) {
15702 SDValue FMul = TmpFMA->getOperand(Num: 2);
15703 if (matcher.match(FMul, ISD::FMUL) && FMul.hasOneUse()) {
15704 SDValue C = FMul.getOperand(i: 0);
15705 SDValue D = FMul.getOperand(i: 1);
15706 SDValue CDE = matcher.getNode(PreferredFusedOpcode, SL, VT, C, D, E);
15707 DAG.ReplaceAllUsesOfValueWith(From: FMul, To: CDE);
15708 // Replacing the inner FMul could cause the outer FMA to be simplified
15709 // away.
15710 return FMA.getOpcode() == ISD::DELETED_NODE ? SDValue(N, 0) : FMA;
15711 }
15712
15713 TmpFMA = TmpFMA->getOperand(Num: 2);
15714 }
15715 }
15716
15717 // Look through FP_EXTEND nodes to do more combining.
15718
15719 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
15720 if (matcher.match(N0, ISD::FP_EXTEND)) {
15721 SDValue N00 = N0.getOperand(i: 0);
15722 if (isContractableFMUL(N00) &&
15723 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15724 SrcVT: N00.getValueType())) {
15725 return matcher.getNode(
15726 PreferredFusedOpcode, SL, VT,
15727 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
15728 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)), N1);
15729 }
15730 }
15731
15732 // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
15733 // Note: Commutes FADD operands.
15734 if (matcher.match(N1, ISD::FP_EXTEND)) {
15735 SDValue N10 = N1.getOperand(i: 0);
15736 if (isContractableFMUL(N10) &&
15737 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15738 SrcVT: N10.getValueType())) {
15739 return matcher.getNode(
15740 PreferredFusedOpcode, SL, VT,
15741 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 0)),
15742 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 1)), N0);
15743 }
15744 }
15745
15746 // More folding opportunities when target permits.
15747 if (Aggressive) {
15748 // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
15749 // -> (fma x, y, (fma (fpext u), (fpext v), z))
15750 auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
15751 SDValue Z) {
15752 return matcher.getNode(
15753 PreferredFusedOpcode, SL, VT, X, Y,
15754 matcher.getNode(PreferredFusedOpcode, SL, VT,
15755 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
15756 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
15757 };
15758 if (isFusedOp(N0)) {
15759 SDValue N02 = N0.getOperand(i: 2);
15760 if (matcher.match(N02, ISD::FP_EXTEND)) {
15761 SDValue N020 = N02.getOperand(i: 0);
15762 if (isContractableFMUL(N020) &&
15763 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15764 SrcVT: N020.getValueType())) {
15765 return FoldFAddFMAFPExtFMul(N0.getOperand(i: 0), N0.getOperand(i: 1),
15766 N020.getOperand(i: 0), N020.getOperand(i: 1),
15767 N1);
15768 }
15769 }
15770 }
15771
15772 // fold (fadd (fpext (fma x, y, (fmul u, v))), z)
15773 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z))
15774 // FIXME: This turns two single-precision and one double-precision
15775 // operation into two double-precision operations, which might not be
15776 // interesting for all targets, especially GPUs.
15777 auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
15778 SDValue Z) {
15779 return matcher.getNode(
15780 PreferredFusedOpcode, SL, VT,
15781 matcher.getNode(ISD::FP_EXTEND, SL, VT, X),
15782 matcher.getNode(ISD::FP_EXTEND, SL, VT, Y),
15783 matcher.getNode(PreferredFusedOpcode, SL, VT,
15784 matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
15785 matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
15786 };
15787 if (N0.getOpcode() == ISD::FP_EXTEND) {
15788 SDValue N00 = N0.getOperand(i: 0);
15789 if (isFusedOp(N00)) {
15790 SDValue N002 = N00.getOperand(i: 2);
15791 if (isContractableFMUL(N002) &&
15792 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15793 SrcVT: N00.getValueType())) {
15794 return FoldFAddFPExtFMAFMul(N00.getOperand(i: 0), N00.getOperand(i: 1),
15795 N002.getOperand(i: 0), N002.getOperand(i: 1),
15796 N1);
15797 }
15798 }
15799 }
15800
15801    // fold (fadd x, (fma y, z, (fpext (fmul u, v))))
15802 // -> (fma y, z, (fma (fpext u), (fpext v), x))
15803 if (isFusedOp(N1)) {
15804 SDValue N12 = N1.getOperand(i: 2);
15805 if (N12.getOpcode() == ISD::FP_EXTEND) {
15806 SDValue N120 = N12.getOperand(i: 0);
15807 if (isContractableFMUL(N120) &&
15808 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15809 SrcVT: N120.getValueType())) {
15810 return FoldFAddFMAFPExtFMul(N1.getOperand(i: 0), N1.getOperand(i: 1),
15811 N120.getOperand(i: 0), N120.getOperand(i: 1),
15812 N0);
15813 }
15814 }
15815 }
15816
15817    // fold (fadd x, (fpext (fma y, z, (fmul u, v))))
15818 // -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x))
15819 // FIXME: This turns two single-precision and one double-precision
15820 // operation into two double-precision operations, which might not be
15821 // interesting for all targets, especially GPUs.
15822 if (N1.getOpcode() == ISD::FP_EXTEND) {
15823 SDValue N10 = N1.getOperand(i: 0);
15824 if (isFusedOp(N10)) {
15825 SDValue N102 = N10.getOperand(i: 2);
15826 if (isContractableFMUL(N102) &&
15827 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15828 SrcVT: N10.getValueType())) {
15829 return FoldFAddFPExtFMAFMul(N10.getOperand(i: 0), N10.getOperand(i: 1),
15830 N102.getOperand(i: 0), N102.getOperand(i: 1),
15831 N0);
15832 }
15833 }
15834 }
15835 }
15836
15837 return SDValue();
15838}
15839
15840/// Try to perform FMA combining on a given FSUB node.
15841template <class MatchContextClass>
15842SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
15843 SDValue N0 = N->getOperand(Num: 0);
15844 SDValue N1 = N->getOperand(Num: 1);
15845 EVT VT = N->getValueType(ResNo: 0);
15846 SDLoc SL(N);
15847 MatchContextClass matcher(DAG, TLI, N);
15848 const TargetOptions &Options = DAG.getTarget().Options;
15849
15850 bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
15851
15852 // Floating-point multiply-add with intermediate rounding.
15853 // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
15854 // FIXME: Add VP_FMAD opcode.
15855 bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
15856
15857 // Floating-point multiply-add without intermediate rounding.
15858 bool HasFMA =
15859 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT) &&
15860 (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT));
15861
15862 // No valid opcode, do not combine.
15863 if (!HasFMAD && !HasFMA)
15864 return SDValue();
15865
15866 const SDNodeFlags Flags = N->getFlags();
15867 bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
15868 Options.UnsafeFPMath || HasFMAD);
15869
15870 // If the subtraction is not contractable, do not combine.
15871 if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
15872 return SDValue();
15873
15874 if (TLI.generateFMAsInMachineCombiner(VT, OptLevel))
15875 return SDValue();
15876
15877 // Always prefer FMAD to FMA for precision.
15878 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
15879 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
15880 bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros();
15881
15882  // Checks if the node is an FMUL and contractable, either due to global flags
15883  // or its SDNodeFlags.
15884 auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
15885 if (!matcher.match(N, ISD::FMUL))
15886 return false;
15887 return AllowFusionGlobally || N->getFlags().hasAllowContract();
15888 };
15889
15890 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
15891 auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
15892 if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
15893 return matcher.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(i: 0),
15894 XY.getOperand(i: 1),
15895 matcher.getNode(ISD::FNEG, SL, VT, Z));
15896 }
15897 return SDValue();
15898 };
15899
15900 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
15901 // Note: Commutes FSUB operands.
15902 auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
15903 if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
15904 return matcher.getNode(
15905 PreferredFusedOpcode, SL, VT,
15906 matcher.getNode(ISD::FNEG, SL, VT, YZ.getOperand(i: 0)),
15907 YZ.getOperand(i: 1), X);
15908 }
15909 return SDValue();
15910 };
15911
15912 // If we have two choices trying to fold (fsub (fmul u, v), (fmul x, y)),
15913 // prefer to fold the multiply with fewer uses.
15914 if (isContractableFMUL(N0) && isContractableFMUL(N1) &&
15915 (N0->use_size() > N1->use_size())) {
15916 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma (fneg c), d, (fmul a, b))
15917 if (SDValue V = tryToFoldXSubYZ(N0, N1))
15918 return V;
15919 // fold (fsub (fmul a, b), (fmul c, d)) -> (fma a, b, (fneg (fmul c, d)))
15920 if (SDValue V = tryToFoldXYSubZ(N0, N1))
15921 return V;
15922 } else {
15923 // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
15924 if (SDValue V = tryToFoldXYSubZ(N0, N1))
15925 return V;
15926 // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
15927 if (SDValue V = tryToFoldXSubYZ(N0, N1))
15928 return V;
15929 }
15930
15931  // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
15932 if (matcher.match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(i: 0)) &&
15933 (Aggressive || (N0->hasOneUse() && N0.getOperand(i: 0).hasOneUse()))) {
15934 SDValue N00 = N0.getOperand(i: 0).getOperand(i: 0);
15935 SDValue N01 = N0.getOperand(i: 0).getOperand(i: 1);
15936 return matcher.getNode(PreferredFusedOpcode, SL, VT,
15937 matcher.getNode(ISD::FNEG, SL, VT, N00), N01,
15938 matcher.getNode(ISD::FNEG, SL, VT, N1));
15939 }
15940
15941 // Look through FP_EXTEND nodes to do more combining.
15942
15943 // fold (fsub (fpext (fmul x, y)), z)
15944 // -> (fma (fpext x), (fpext y), (fneg z))
15945 if (matcher.match(N0, ISD::FP_EXTEND)) {
15946 SDValue N00 = N0.getOperand(i: 0);
15947 if (isContractableFMUL(N00) &&
15948 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15949 SrcVT: N00.getValueType())) {
15950 return matcher.getNode(
15951 PreferredFusedOpcode, SL, VT,
15952 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
15953 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)),
15954 matcher.getNode(ISD::FNEG, SL, VT, N1));
15955 }
15956 }
15957
15958 // fold (fsub x, (fpext (fmul y, z)))
15959 // -> (fma (fneg (fpext y)), (fpext z), x)
15960 // Note: Commutes FSUB operands.
15961 if (matcher.match(N1, ISD::FP_EXTEND)) {
15962 SDValue N10 = N1.getOperand(i: 0);
15963 if (isContractableFMUL(N10) &&
15964 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15965 SrcVT: N10.getValueType())) {
15966 return matcher.getNode(
15967 PreferredFusedOpcode, SL, VT,
15968 matcher.getNode(
15969 ISD::FNEG, SL, VT,
15970 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 0))),
15971 matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(i: 1)), N0);
15972 }
15973 }
15974
15975  // fold (fsub (fpext (fneg (fmul x, y))), z)
15976  // -> (fneg (fma (fpext x), (fpext y), z))
15977  // Note: This could be removed with appropriate canonicalization of the
15978  // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
15979  // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent us
15980  // from implementing the canonicalization in visitFSUB.
15981 if (matcher.match(N0, ISD::FP_EXTEND)) {
15982 SDValue N00 = N0.getOperand(i: 0);
15983 if (matcher.match(N00, ISD::FNEG)) {
15984 SDValue N000 = N00.getOperand(i: 0);
15985 if (isContractableFMUL(N000) &&
15986 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
15987 SrcVT: N00.getValueType())) {
15988 return matcher.getNode(
15989 ISD::FNEG, SL, VT,
15990 matcher.getNode(
15991 PreferredFusedOpcode, SL, VT,
15992 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 0)),
15993 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 1)),
15994 N1));
15995 }
15996 }
15997 }
15998
15999  // fold (fsub (fneg (fpext (fmul x, y))), z)
16000  // -> (fneg (fma (fpext x), (fpext y), z))
16001  // Note: This could be removed with appropriate canonicalization of the
16002  // input expression into (fneg (fadd (fpext (fmul x, y)), z)). However, the
16003  // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent us
16004  // from implementing the canonicalization in visitFSUB.
16005 if (matcher.match(N0, ISD::FNEG)) {
16006 SDValue N00 = N0.getOperand(i: 0);
16007 if (matcher.match(N00, ISD::FP_EXTEND)) {
16008 SDValue N000 = N00.getOperand(i: 0);
16009 if (isContractableFMUL(N000) &&
16010 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16011 SrcVT: N000.getValueType())) {
16012 return matcher.getNode(
16013 ISD::FNEG, SL, VT,
16014 matcher.getNode(
16015 PreferredFusedOpcode, SL, VT,
16016 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 0)),
16017 matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(i: 1)),
16018 N1));
16019 }
16020 }
16021 }
16022
16023 auto isReassociable = [&Options](SDNode *N) {
16024 return Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
16025 };
16026
16027 auto isContractableAndReassociableFMUL = [&isContractableFMUL,
16028 &isReassociable](SDValue N) {
16029 return isContractableFMUL(N) && isReassociable(N.getNode());
16030 };
16031
16032 auto isFusedOp = [&](SDValue N) {
16033 return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
16034 };
16035
16036 // More folding opportunities when target permits.
16037 if (Aggressive && isReassociable(N)) {
16038 bool CanFuse = Options.UnsafeFPMath || N->getFlags().hasAllowContract();
16039 // fold (fsub (fma x, y, (fmul u, v)), z)
16040    // -> (fma x, y, (fma u, v, (fneg z)))
16041 if (CanFuse && isFusedOp(N0) &&
16042 isContractableAndReassociableFMUL(N0.getOperand(i: 2)) &&
16043 N0->hasOneUse() && N0.getOperand(i: 2)->hasOneUse()) {
16044 return matcher.getNode(
16045 PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0), N0.getOperand(i: 1),
16046 matcher.getNode(PreferredFusedOpcode, SL, VT,
16047 N0.getOperand(i: 2).getOperand(i: 0),
16048 N0.getOperand(i: 2).getOperand(i: 1),
16049 matcher.getNode(ISD::FNEG, SL, VT, N1)));
16050 }
16051
16052 // fold (fsub x, (fma y, z, (fmul u, v)))
16053 // -> (fma (fneg y), z, (fma (fneg u), v, x))
16054 if (CanFuse && isFusedOp(N1) &&
16055 isContractableAndReassociableFMUL(N1.getOperand(i: 2)) &&
16056 N1->hasOneUse() && NoSignedZero) {
16057 SDValue N20 = N1.getOperand(i: 2).getOperand(i: 0);
16058 SDValue N21 = N1.getOperand(i: 2).getOperand(i: 1);
16059 return matcher.getNode(
16060 PreferredFusedOpcode, SL, VT,
16061 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(i: 0)),
16062 N1.getOperand(i: 1),
16063 matcher.getNode(PreferredFusedOpcode, SL, VT,
16064 matcher.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
16065 }
16066
16067 // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
16068    // -> (fma x, y, (fma (fpext u), (fpext v), (fneg z)))
16069 if (isFusedOp(N0) && N0->hasOneUse()) {
16070 SDValue N02 = N0.getOperand(i: 2);
16071 if (matcher.match(N02, ISD::FP_EXTEND)) {
16072 SDValue N020 = N02.getOperand(i: 0);
16073 if (isContractableAndReassociableFMUL(N020) &&
16074 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16075 SrcVT: N020.getValueType())) {
16076 return matcher.getNode(
16077 PreferredFusedOpcode, SL, VT, N0.getOperand(i: 0), N0.getOperand(i: 1),
16078 matcher.getNode(
16079 PreferredFusedOpcode, SL, VT,
16080 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(i: 0)),
16081 matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(i: 1)),
16082 matcher.getNode(ISD::FNEG, SL, VT, N1)));
16083 }
16084 }
16085 }
16086
16087 // fold (fsub (fpext (fma x, y, (fmul u, v))), z)
16088 // -> (fma (fpext x), (fpext y),
16089 // (fma (fpext u), (fpext v), (fneg z)))
16090 // FIXME: This turns two single-precision and one double-precision
16091 // operation into two double-precision operations, which might not be
16092 // interesting for all targets, especially GPUs.
16093 if (matcher.match(N0, ISD::FP_EXTEND)) {
16094 SDValue N00 = N0.getOperand(i: 0);
16095 if (isFusedOp(N00)) {
16096 SDValue N002 = N00.getOperand(i: 2);
16097 if (isContractableAndReassociableFMUL(N002) &&
16098 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16099 SrcVT: N00.getValueType())) {
16100 return matcher.getNode(
16101 PreferredFusedOpcode, SL, VT,
16102 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 0)),
16103 matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(i: 1)),
16104 matcher.getNode(
16105 PreferredFusedOpcode, SL, VT,
16106 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(i: 0)),
16107 matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(i: 1)),
16108 matcher.getNode(ISD::FNEG, SL, VT, N1)));
16109 }
16110 }
16111 }
16112
16113 // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
16114 // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
16115 if (isFusedOp(N1) && matcher.match(N1.getOperand(i: 2), ISD::FP_EXTEND) &&
16116 N1->hasOneUse()) {
16117 SDValue N120 = N1.getOperand(i: 2).getOperand(i: 0);
16118 if (isContractableAndReassociableFMUL(N120) &&
16119 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16120 SrcVT: N120.getValueType())) {
16121 SDValue N1200 = N120.getOperand(i: 0);
16122 SDValue N1201 = N120.getOperand(i: 1);
16123 return matcher.getNode(
16124 PreferredFusedOpcode, SL, VT,
16125 matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(i: 0)),
16126 N1.getOperand(i: 1),
16127 matcher.getNode(
16128 PreferredFusedOpcode, SL, VT,
16129 matcher.getNode(ISD::FNEG, SL, VT,
16130 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
16131 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
16132 }
16133 }
16134
16135 // fold (fsub x, (fpext (fma y, z, (fmul u, v))))
16136 // -> (fma (fneg (fpext y)), (fpext z),
16137 // (fma (fneg (fpext u)), (fpext v), x))
16138 // FIXME: This turns two single-precision and one double-precision
16139 // operation into two double-precision operations, which might not be
16140 // interesting for all targets, especially GPUs.
16141 if (matcher.match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(i: 0))) {
16142 SDValue CvtSrc = N1.getOperand(i: 0);
16143 SDValue N100 = CvtSrc.getOperand(i: 0);
16144 SDValue N101 = CvtSrc.getOperand(i: 1);
16145 SDValue N102 = CvtSrc.getOperand(i: 2);
16146 if (isContractableAndReassociableFMUL(N102) &&
16147 TLI.isFPExtFoldable(DAG, Opcode: PreferredFusedOpcode, DestVT: VT,
16148 SrcVT: CvtSrc.getValueType())) {
16149 SDValue N1020 = N102.getOperand(i: 0);
16150 SDValue N1021 = N102.getOperand(i: 1);
16151 return matcher.getNode(
16152 PreferredFusedOpcode, SL, VT,
16153 matcher.getNode(ISD::FNEG, SL, VT,
16154 matcher.getNode(ISD::FP_EXTEND, SL, VT, N100)),
16155 matcher.getNode(ISD::FP_EXTEND, SL, VT, N101),
16156 matcher.getNode(
16157 PreferredFusedOpcode, SL, VT,
16158 matcher.getNode(ISD::FNEG, SL, VT,
16159 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
16160 matcher.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
16161 }
16162 }
16163 }
16164
16165 return SDValue();
16166}
16167
16168/// Try to perform FMA combining on a given FMUL node based on the distributive
16169/// law x * (y + 1) = x * y + x and variants thereof (commuted versions,
16170/// subtraction instead of addition).
16171SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
16172 SDValue N0 = N->getOperand(Num: 0);
16173 SDValue N1 = N->getOperand(Num: 1);
16174 EVT VT = N->getValueType(ResNo: 0);
16175 SDLoc SL(N);
16176
16177 assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation");
16178
16179 const TargetOptions &Options = DAG.getTarget().Options;
16180
16181 // The transforms below are incorrect when x == 0 and y == inf, because the
16182 // intermediate multiplication produces a nan.
16183 SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
16184 if (!hasNoInfs(Options, N: FAdd))
16185 return SDValue();
16186
16187 // Floating-point multiply-add without intermediate rounding.
16188 bool HasFMA =
16189 isContractableFMUL(Options, N: SDValue(N, 0)) &&
16190 TLI.isFMAFasterThanFMulAndFAdd(MF: DAG.getMachineFunction(), VT) &&
16191 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FMA, VT));
16192
16193 // Floating-point multiply-add with intermediate rounding. This can result
16194 // in a less precise result due to the changed rounding order.
16195 bool HasFMAD = Options.UnsafeFPMath &&
16196 (LegalOperations && TLI.isFMADLegal(DAG, N));
16197
16198 // No valid opcode, do not combine.
16199 if (!HasFMAD && !HasFMA)
16200 return SDValue();
16201
16202 // Always prefer FMAD to FMA for precision.
16203 unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA;
16204 bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
16205
16206 // fold (fmul (fadd x0, +1.0), y) -> (fma x0, y, y)
16207 // fold (fmul (fadd x0, -1.0), y) -> (fma x0, y, (fneg y))
16208 auto FuseFADD = [&](SDValue X, SDValue Y) {
16209 if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) {
16210 if (auto *C = isConstOrConstSplatFP(N: X.getOperand(i: 1), AllowUndefs: true)) {
16211 if (C->isExactlyValue(V: +1.0))
16212 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
16213 N3: Y);
16214 if (C->isExactlyValue(V: -1.0))
16215 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
16216 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
16217 }
16218 }
16219 return SDValue();
16220 };
16221
16222 if (SDValue FMA = FuseFADD(N0, N1))
16223 return FMA;
16224 if (SDValue FMA = FuseFADD(N1, N0))
16225 return FMA;
16226
16227 // fold (fmul (fsub +1.0, x1), y) -> (fma (fneg x1), y, y)
16228 // fold (fmul (fsub -1.0, x1), y) -> (fma (fneg x1), y, (fneg y))
16229 // fold (fmul (fsub x0, +1.0), y) -> (fma x0, y, (fneg y))
16230 // fold (fmul (fsub x0, -1.0), y) -> (fma x0, y, y)
16231 auto FuseFSUB = [&](SDValue X, SDValue Y) {
16232 if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) {
16233 if (auto *C0 = isConstOrConstSplatFP(N: X.getOperand(i: 0), AllowUndefs: true)) {
16234 if (C0->isExactlyValue(V: +1.0))
16235 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT,
16236 N1: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: X.getOperand(i: 1)), N2: Y,
16237 N3: Y);
16238 if (C0->isExactlyValue(V: -1.0))
16239 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT,
16240 N1: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: X.getOperand(i: 1)), N2: Y,
16241 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
16242 }
16243 if (auto *C1 = isConstOrConstSplatFP(N: X.getOperand(i: 1), AllowUndefs: true)) {
16244 if (C1->isExactlyValue(V: +1.0))
16245 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
16246 N3: DAG.getNode(Opcode: ISD::FNEG, DL: SL, VT, Operand: Y));
16247 if (C1->isExactlyValue(V: -1.0))
16248 return DAG.getNode(Opcode: PreferredFusedOpcode, DL: SL, VT, N1: X.getOperand(i: 0), N2: Y,
16249 N3: Y);
16250 }
16251 }
16252 return SDValue();
16253 };
16254
16255 if (SDValue FMA = FuseFSUB(N0, N1))
16256 return FMA;
16257 if (SDValue FMA = FuseFSUB(N1, N0))
16258 return FMA;
16259
16260 return SDValue();
16261}
16262
16263SDValue DAGCombiner::visitVP_FADD(SDNode *N) {
16264 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16265
16266 // FADD -> FMA combines:
16267 if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) {
16268 if (Fused.getOpcode() != ISD::DELETED_NODE)
16269 AddToWorklist(N: Fused.getNode());
16270 return Fused;
16271 }
16272 return SDValue();
16273}
16274
16275SDValue DAGCombiner::visitFADD(SDNode *N) {
16276 SDValue N0 = N->getOperand(Num: 0);
16277 SDValue N1 = N->getOperand(Num: 1);
16278 SDNode *N0CFP = DAG.isConstantFPBuildVectorOrConstantFP(N: N0);
16279 SDNode *N1CFP = DAG.isConstantFPBuildVectorOrConstantFP(N: N1);
16280 EVT VT = N->getValueType(ResNo: 0);
16281 SDLoc DL(N);
16282 const TargetOptions &Options = DAG.getTarget().Options;
16283 SDNodeFlags Flags = N->getFlags();
16284 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16285
16286 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
16287 return R;
16288
16289 // fold (fadd c1, c2) -> c1 + c2
16290 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FADD, DL, VT, Ops: {N0, N1}))
16291 return C;
16292
16293 // canonicalize constant to RHS
16294 if (N0CFP && !N1CFP)
16295 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1, N2: N0);
16296
16297 // fold vector ops
16298 if (VT.isVector())
16299 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16300 return FoldedVOp;
16301
16302 // N0 + -0.0 --> N0 (also allowed with +0.0 and fast-math)
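// (+0.0 is not a general identity: -0.0 + +0.0 is +0.0, so dropping a +0.0
// operand additionally requires the no-signed-zeros guarantee checked below.)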
16303 ConstantFPSDNode *N1C = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
16304 if (N1C && N1C->isZero())
16305 if (N1C->isNegative() || Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())
16306 return N0;
16307
16308 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
16309 return NewSel;
16310
16311 // fold (fadd A, (fneg B)) -> (fsub A, B)
16312 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FSUB, VT))
16313 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
16314 Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16315 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: N0, N2: NegN1);
16316
16317 // fold (fadd (fneg A), B) -> (fsub B, A)
16318 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::FSUB, VT))
16319 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
16320 Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16321 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1, N2: NegN0);
16322
16323 auto isFMulNegTwo = [](SDValue FMul) {
16324 if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL)
16325 return false;
16326 auto *C = isConstOrConstSplatFP(N: FMul.getOperand(i: 1), AllowUndefs: true);
16327 return C && C->isExactlyValue(V: -2.0);
16328 };
16329
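// B * -2.0 is exactly -(B + B), so the multiply below can be replaced by an
// FADD feeding an FSUB, eliminating the -2.0 constant: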
16330 // fadd (fmul B, -2.0), A --> fsub A, (fadd B, B)
16331 if (isFMulNegTwo(N0)) {
16332 SDValue B = N0.getOperand(i: 0);
16333 SDValue Add = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: B, N2: B);
16334 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1, N2: Add);
16335 }
16336 // fadd A, (fmul B, -2.0) --> fsub A, (fadd B, B)
16337 if (isFMulNegTwo(N1)) {
16338 SDValue B = N1.getOperand(i: 0);
16339 SDValue Add = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: B, N2: B);
16340 return DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: N0, N2: Add);
16341 }
16342
16343 // No FP constant should be created after legalization, as the Instruction
16344 // Selection pass has a hard time dealing with FP constants.
16345 bool AllowNewConst = (Level < AfterLegalizeDAG);
16346
16347 // If nnan is enabled, fold lots of things.
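// x + (-x) is exactly +0.0 for any finite x (default rounding); a NaN or Inf
// input would yield NaN, which the nnan assumption lets us disregard.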
16348 if ((Options.NoNaNsFPMath || Flags.hasNoNaNs()) && AllowNewConst) {
16349 // If allowed, fold (fadd (fneg x), x) -> 0.0
16350 if (N0.getOpcode() == ISD::FNEG && N0.getOperand(i: 0) == N1)
16351 return DAG.getConstantFP(Val: 0.0, DL, VT);
16352
16353 // If allowed, fold (fadd x, (fneg x)) -> 0.0
16354 if (N1.getOpcode() == ISD::FNEG && N1.getOperand(i: 0) == N0)
16355 return DAG.getConstantFP(Val: 0.0, DL, VT);
16356 }
16357
16358 // If 'unsafe math' or reassoc and nsz, fold lots of things.
16359 // TODO: break out portions of the transformations below for which Unsafe is
16360 // considered and which do not require both nsz and reassoc
16361 if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
16362 (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
16363 AllowNewConst) {
16364 // fadd (fadd x, c1), c2 -> fadd x, c1 + c2
16365 if (N1CFP && N0.getOpcode() == ISD::FADD &&
16366 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1))) {
16367 SDValue NewC = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1), N2: N1);
16368 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 0), N2: NewC);
16369 }
16370
16371 // We can fold chains of FADD's of the same value into multiplications.
16372 // This transform is not safe in general because we are reducing the number
16373 // of rounding steps.
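// For example, (x * c) + x rounds the product before the add, while
// x * (c + 1.0) rounds c + 1.0 before the multiply, so the results can differ
// in the final ulp; hence the reassociation / unsafe-math gate above.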
16374 if (TLI.isOperationLegalOrCustom(Op: ISD::FMUL, VT) && !N0CFP && !N1CFP) {
16375 if (N0.getOpcode() == ISD::FMUL) {
16376 SDNode *CFP00 =
16377 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 0));
16378 SDNode *CFP01 =
16379 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1));
16380
16381 // (fadd (fmul x, c), x) -> (fmul x, c+1)
16382 if (CFP01 && !CFP00 && N0.getOperand(i: 0) == N1) {
16383 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1),
16384 N2: DAG.getConstantFP(Val: 1.0, DL, VT));
16385 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1, N2: NewCFP);
16386 }
16387
16388 // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2)
16389 if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD &&
16390 N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
16391 N0.getOperand(i: 0) == N1.getOperand(i: 0)) {
16392 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0.getOperand(i: 1),
16393 N2: DAG.getConstantFP(Val: 2.0, DL, VT));
16394 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0), N2: NewCFP);
16395 }
16396 }
16397
16398 if (N1.getOpcode() == ISD::FMUL) {
16399 SDNode *CFP10 =
16400 DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 0));
16401 SDNode *CFP11 =
16402 DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 1));
16403
16404 // (fadd x, (fmul x, c)) -> (fmul x, c+1)
16405 if (CFP11 && !CFP10 && N1.getOperand(i: 0) == N0) {
16406 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N1.getOperand(i: 1),
16407 N2: DAG.getConstantFP(Val: 1.0, DL, VT));
16408 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: NewCFP);
16409 }
16410
16411 // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2)
16412 if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD &&
16413 N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
16414 N1.getOperand(i: 0) == N0.getOperand(i: 0)) {
16415 SDValue NewCFP = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N1.getOperand(i: 1),
16416 N2: DAG.getConstantFP(Val: 2.0, DL, VT));
16417 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N1.getOperand(i: 0), N2: NewCFP);
16418 }
16419 }
16420
16421 if (N0.getOpcode() == ISD::FADD) {
16422 SDNode *CFP00 =
16423 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 0));
16424 // (fadd (fadd x, x), x) -> (fmul x, 3.0)
16425 if (!CFP00 && N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
16426 (N0.getOperand(i: 0) == N1)) {
16427 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1,
16428 N2: DAG.getConstantFP(Val: 3.0, DL, VT));
16429 }
16430 }
16431
16432 if (N1.getOpcode() == ISD::FADD) {
16433 SDNode *CFP10 =
16434 DAG.isConstantFPBuildVectorOrConstantFP(N: N1.getOperand(i: 0));
16435 // (fadd x, (fadd x, x)) -> (fmul x, 3.0)
16436 if (!CFP10 && N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
16437 N1.getOperand(i: 0) == N0) {
16438 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0,
16439 N2: DAG.getConstantFP(Val: 3.0, DL, VT));
16440 }
16441 }
16442
16443 // (fadd (fadd x, x), (fadd x, x)) -> (fmul x, 4.0)
16444 if (N0.getOpcode() == ISD::FADD && N1.getOpcode() == ISD::FADD &&
16445 N0.getOperand(i: 0) == N0.getOperand(i: 1) &&
16446 N1.getOperand(i: 0) == N1.getOperand(i: 1) &&
16447 N0.getOperand(i: 0) == N1.getOperand(i: 0)) {
16448 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0),
16449 N2: DAG.getConstantFP(Val: 4.0, DL, VT));
16450 }
16451 }
16452
16453 // Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
16454 if (SDValue SD = reassociateReduction(RedOpc: ISD::VECREDUCE_FADD, Opc: ISD::FADD, DL,
16455 VT, N0, N1, Flags))
16456 return SD;
16457 } // enable-unsafe-fp-math
16458
16459 // FADD -> FMA combines:
16460 if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) {
16461 if (Fused.getOpcode() != ISD::DELETED_NODE)
16462 AddToWorklist(N: Fused.getNode());
16463 return Fused;
16464 }
16465 return SDValue();
16466}
16467
16468SDValue DAGCombiner::visitSTRICT_FADD(SDNode *N) {
16469 SDValue Chain = N->getOperand(Num: 0);
16470 SDValue N0 = N->getOperand(Num: 1);
16471 SDValue N1 = N->getOperand(Num: 2);
16472 EVT VT = N->getValueType(ResNo: 0);
16473 EVT ChainVT = N->getValueType(ResNo: 1);
16474 SDLoc DL(N);
16475 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16476
16477 // fold (strict_fadd A, (fneg B)) -> (strict_fsub A, B)
16478 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::STRICT_FSUB, VT))
16479 if (SDValue NegN1 = TLI.getCheaperNegatedExpression(
16480 Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize)) {
16481 return DAG.getNode(Opcode: ISD::STRICT_FSUB, DL, VTList: DAG.getVTList(VT1: VT, VT2: ChainVT),
16482 Ops: {Chain, N0, NegN1});
16483 }
16484
16485 // fold (strict_fadd (fneg A), B) -> (strict_fsub B, A)
16486 if (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::STRICT_FSUB, VT))
16487 if (SDValue NegN0 = TLI.getCheaperNegatedExpression(
16488 Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize)) {
16489 return DAG.getNode(Opcode: ISD::STRICT_FSUB, DL, VTList: DAG.getVTList(VT1: VT, VT2: ChainVT),
16490 Ops: {Chain, N1, NegN0});
16491 }
16492 return SDValue();
16493}
16494
16495SDValue DAGCombiner::visitFSUB(SDNode *N) {
16496 SDValue N0 = N->getOperand(Num: 0);
16497 SDValue N1 = N->getOperand(Num: 1);
16498 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N: N0, AllowUndefs: true);
16499 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
16500 EVT VT = N->getValueType(ResNo: 0);
16501 SDLoc DL(N);
16502 const TargetOptions &Options = DAG.getTarget().Options;
16503 const SDNodeFlags Flags = N->getFlags();
16504 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16505
16506 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
16507 return R;
16508
16509 // fold (fsub c1, c2) -> c1-c2
16510 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FSUB, DL, VT, Ops: {N0, N1}))
16511 return C;
16512
16513 // fold vector ops
16514 if (VT.isVector())
16515 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16516 return FoldedVOp;
16517
16518 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
16519 return NewSel;
16520
16521 // (fsub A, 0) -> A
16522 if (N1CFP && N1CFP->isZero()) {
16523 if (!N1CFP->isNegative() || Options.NoSignedZerosFPMath ||
16524 Flags.hasNoSignedZeros()) {
16525 return N0;
16526 }
16527 }
16528
16529 if (N0 == N1) {
16530 // (fsub x, x) -> 0.0
16531 if (Options.NoNaNsFPMath || Flags.hasNoNaNs())
16532 return DAG.getConstantFP(Val: 0.0f, DL, VT);
16533 }
16534
16535 // (fsub -0.0, N1) -> -N1
16536 if (N0CFP && N0CFP->isZero()) {
16537 if (N0CFP->isNegative() ||
16538 (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) {
16539 // We cannot replace an FSUB(+-0.0,X) with FNEG(X) when denormals are
16540 // flushed to zero, unless all users treat denorms as zero (DAZ).
16541 // FIXME: This transform will change the sign of a NaN and the behavior
16542 // of a signaling NaN. It is only valid when a NoNaN flag is present.
16543 DenormalMode DenormMode = DAG.getDenormalMode(VT);
16544 if (DenormMode == DenormalMode::getIEEE()) {
16545 if (SDValue NegN1 =
16546 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16547 return NegN1;
16548 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FNEG, VT))
16549 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1);
16550 }
16551 }
16552 }
16553
16554 if (((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
16555 (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros())) &&
16556 N1.getOpcode() == ISD::FADD) {
16557 // X - (X + Y) -> -Y
16558 if (N0 == N1->getOperand(Num: 0))
16559 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1->getOperand(Num: 1));
16560 // X - (Y + X) -> -Y
16561 if (N0 == N1->getOperand(Num: 1))
16562 return DAG.getNode(Opcode: ISD::FNEG, DL, VT, Operand: N1->getOperand(Num: 0));
16563 }
16564
16565 // fold (fsub A, (fneg B)) -> (fadd A, B)
16566 if (SDValue NegN1 =
16567 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16568 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0, N2: NegN1);
16569
16570 // FSUB -> FMA combines:
16571 if (SDValue Fused = visitFSUBForFMACombine<EmptyMatchContext>(N)) {
16572 AddToWorklist(N: Fused.getNode());
16573 return Fused;
16574 }
16575
16576 return SDValue();
16577}
16578
16579// Transform IEEE Floats:
16580// (fmul C, (uitofp Pow2))
16581// -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa))
16582// (fdiv C, (uitofp Pow2))
16583// -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa))
16584//
16585 // The rationale is that fmul/fdiv by a power of 2 just changes the exponent,
16586 // so there is no need for more than an add/sub.
16587//
16588// This is valid under the following circumstances:
16589// 1) We are dealing with IEEE floats
16590// 2) C is normal
16591// 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds.
16592 // TODO: Much of this could also be used for generating `ldexp` on targets
16593 // that prefer it.
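// Example (f32, 23 mantissa bits): C = 1.5f is 0x3FC00000; multiplying by
// Pow2 = 8.0 (Log2 = 3) adds 3 << 23 = 0x01800000, giving 0x41400000, which
// is exactly 12.0f.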
16594SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
16595 EVT VT = N->getValueType(ResNo: 0);
16596 SDValue ConstOp, Pow2Op;
16597
16598 std::optional<int> Mantissa;
16599 auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
16600 if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
16601 return false;
16602
16603 ConstOp = peekThroughBitcasts(V: N->getOperand(Num: ConstOpIdx));
16604 Pow2Op = N->getOperand(Num: 1 - ConstOpIdx);
16605 if (Pow2Op.getOpcode() != ISD::UINT_TO_FP &&
16606 (Pow2Op.getOpcode() != ISD::SINT_TO_FP ||
16607 !DAG.computeKnownBits(Op: Pow2Op).isNonNegative()))
16608 return false;
16609
16610 Pow2Op = Pow2Op.getOperand(i: 0);
16611
16612 // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
16613 // TODO: We could use knownbits to make this bound more precise.
16614 int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();
16615
16616 auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) {
16617 if (CFP == nullptr)
16618 return false;
16619
16620 const APFloat &APF = CFP->getValueAPF();
16621
16622 // Make sure we have a normal IEEE constant.
16623 if (!APF.isNormal() || !APF.isIEEE())
16624 return false;
16625
16626 // Make sure the float's exponent is within the bounds for which this
16627 // transform produces a bitwise-equal value.
16628 int CurExp = ilogb(Arg: APF);
16629 // FMul by pow2 will only increase exponent.
16630 int MinExp =
16631 N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
16632 // FDiv by pow2 will only decrease exponent.
16633 int MaxExp =
16634 N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
16635 if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
16636 MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
16637 return false;
16638
16639 // Finally make sure we actually know the mantissa for the float type.
16640 int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
16641 if (!Mantissa)
16642 Mantissa = ThisMantissa;
16643
16644 return *Mantissa == ThisMantissa && ThisMantissa > 0;
16645 };
16646
16647 // TODO: We may be able to include undefs.
16648 return ISD::matchUnaryFpPredicate(Op: ConstOp, Match: IsFPConstValid);
16649 };
16650
16651 if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
16652 return SDValue();
16653
16654 if (!TLI.optimizeFMulOrFDivAsShiftAddBitcast(N, FPConst: ConstOp, IntPow2: Pow2Op))
16655 return SDValue();
16656
16657 // Get log2 after all other checks have taken place. This is because
16658 // BuildLogBase2 may create a new node.
16659 SDLoc DL(N);
16660 // Get Log2 type with same bitwidth as the float type (VT).
16661 EVT NewIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: VT.getScalarSizeInBits());
16662 if (VT.isVector())
16663 NewIntVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewIntVT,
16664 EC: VT.getVectorElementCount());
16665
16666 SDValue Log2 = BuildLogBase2(V: Pow2Op, DL, KnownNeverZero: DAG.isKnownNeverZero(Op: Pow2Op),
16667 /*InexpensiveOnly*/ true, OutVT: NewIntVT);
16668 if (!Log2)
16669 return SDValue();
16670
16671 // Perform actual transform.
16672 SDValue MantissaShiftCnt =
16673 DAG.getConstant(Val: *Mantissa, DL, VT: getShiftAmountTy(LHSTy: NewIntVT));
16674 // TODO: Sometimes Log2 is of the form `(X + C)`. `(X + C) << C1` should fold
16675 // to `(X << C1) + (C << C1)`, but that isn't always the case because of the
16676 // cast. We could implement that by handling the casts here.
16677 SDValue Shift = DAG.getNode(Opcode: ISD::SHL, DL, VT: NewIntVT, N1: Log2, N2: MantissaShiftCnt);
16678 SDValue ResAsInt =
16679 DAG.getNode(Opcode: N->getOpcode() == ISD::FMUL ? ISD::ADD : ISD::SUB, DL,
16680 VT: NewIntVT, N1: DAG.getBitcast(VT: NewIntVT, V: ConstOp), N2: Shift);
16681 SDValue ResAsFP = DAG.getBitcast(VT, V: ResAsInt);
16682 return ResAsFP;
16683}
16684
16685SDValue DAGCombiner::visitFMUL(SDNode *N) {
16686 SDValue N0 = N->getOperand(Num: 0);
16687 SDValue N1 = N->getOperand(Num: 1);
16688 ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1, AllowUndefs: true);
16689 EVT VT = N->getValueType(ResNo: 0);
16690 SDLoc DL(N);
16691 const TargetOptions &Options = DAG.getTarget().Options;
16692 const SDNodeFlags Flags = N->getFlags();
16693 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16694
16695 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
16696 return R;
16697
16698 // fold (fmul c1, c2) -> c1*c2
16699 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FMUL, DL, VT, Ops: {N0, N1}))
16700 return C;
16701
16702 // canonicalize constant to RHS
16703 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
16704 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
16705 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1, N2: N0);
16706
16707 // fold vector ops
16708 if (VT.isVector())
16709 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
16710 return FoldedVOp;
16711
16712 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
16713 return NewSel;
16714
16715 if (Options.UnsafeFPMath || Flags.hasAllowReassociation()) {
16716 // fmul (fmul X, C1), C2 -> fmul X, C1 * C2
16717 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
16718 N0.getOpcode() == ISD::FMUL) {
16719 SDValue N00 = N0.getOperand(i: 0);
16720 SDValue N01 = N0.getOperand(i: 1);
16721 // Avoid an infinite loop by making sure that N00 is not a constant
16722 // (the inner multiply has not been constant folded yet).
16723 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N01) &&
16724 !DAG.isConstantFPBuildVectorOrConstantFP(N: N00)) {
16725 SDValue MulConsts = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N01, N2: N1);
16726 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N00, N2: MulConsts);
16727 }
16728 }
16729
16730 // Match a special case: we convert X * 2.0 into fadd.
16731 // fmul (fadd X, X), C -> fmul X, 2.0 * C
16732 if (N0.getOpcode() == ISD::FADD && N0.hasOneUse() &&
16733 N0.getOperand(i: 0) == N0.getOperand(i: 1)) {
16734 const SDValue Two = DAG.getConstantFP(Val: 2.0, DL, VT);
16735 SDValue MulConsts = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Two, N2: N1);
16736 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0.getOperand(i: 0), N2: MulConsts);
16737 }
16738
16739 // Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y))
16740 if (SDValue SD = reassociateReduction(RedOpc: ISD::VECREDUCE_FMUL, Opc: ISD::FMUL, DL,
16741 VT, N0, N1, Flags))
16742 return SD;
16743 }
16744
16745 // fold (fmul X, 2.0) -> (fadd X, X)
16746 if (N1CFP && N1CFP->isExactlyValue(V: +2.0))
16747 return DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: N0, N2: N0);
16748
16749 // fold (fmul X, -1.0) -> (fsub -0.0, X)
16750 if (N1CFP && N1CFP->isExactlyValue(V: -1.0)) {
16751 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FSUB, VT)) {
16752 return DAG.getNode(Opcode: ISD::FSUB, DL, VT,
16753 N1: DAG.getConstantFP(Val: -0.0, DL, VT), N2: N0, Flags);
16754 }
16755 }
16756
16757 // -N0 * -N1 --> N0 * N1
16758 TargetLowering::NegatibleCost CostN0 =
16759 TargetLowering::NegatibleCost::Expensive;
16760 TargetLowering::NegatibleCost CostN1 =
16761 TargetLowering::NegatibleCost::Expensive;
16762 SDValue NegN0 =
16763 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
16764 if (NegN0) {
16765 HandleSDNode NegN0Handle(NegN0);
16766 SDValue NegN1 =
16767 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
16768 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
16769 CostN1 == TargetLowering::NegatibleCost::Cheaper))
16770 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: NegN0, N2: NegN1);
16771 }
16772
16773 // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X))
16774 // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X)
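// nsz and nnan are required: e.g. for X == +0.0 the compare is false, the
// select yields -1.0, and the multiply produces -0.0, whereas (fabs X) is
// +0.0; NaN inputs are similarly problematic.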
16775 if (Flags.hasNoNaNs() && Flags.hasNoSignedZeros() &&
16776 (N0.getOpcode() == ISD::SELECT || N1.getOpcode() == ISD::SELECT) &&
16777 TLI.isOperationLegal(Op: ISD::FABS, VT)) {
16778 SDValue Select = N0, X = N1;
16779 if (Select.getOpcode() != ISD::SELECT)
16780 std::swap(a&: Select, b&: X);
16781
16782 SDValue Cond = Select.getOperand(i: 0);
16783 auto TrueOpnd = dyn_cast<ConstantFPSDNode>(Val: Select.getOperand(i: 1));
16784 auto FalseOpnd = dyn_cast<ConstantFPSDNode>(Val: Select.getOperand(i: 2));
16785
16786 if (TrueOpnd && FalseOpnd &&
16787 Cond.getOpcode() == ISD::SETCC && Cond.getOperand(i: 0) == X &&
16788 isa<ConstantFPSDNode>(Val: Cond.getOperand(i: 1)) &&
16789 cast<ConstantFPSDNode>(Val: Cond.getOperand(i: 1))->isExactlyValue(V: 0.0)) {
16790 ISD::CondCode CC = cast<CondCodeSDNode>(Val: Cond.getOperand(i: 2))->get();
16791 switch (CC) {
16792 default: break;
16793 case ISD::SETOLT:
16794 case ISD::SETULT:
16795 case ISD::SETOLE:
16796 case ISD::SETULE:
16797 case ISD::SETLT:
16798 case ISD::SETLE:
16799 std::swap(a&: TrueOpnd, b&: FalseOpnd);
16800 [[fallthrough]];
16801 case ISD::SETOGT:
16802 case ISD::SETUGT:
16803 case ISD::SETOGE:
16804 case ISD::SETUGE:
16805 case ISD::SETGT:
16806 case ISD::SETGE:
16807 if (TrueOpnd->isExactlyValue(V: -1.0) && FalseOpnd->isExactlyValue(V: 1.0) &&
16808 TLI.isOperationLegal(Op: ISD::FNEG, VT))
16809 return DAG.getNode(Opcode: ISD::FNEG, DL, VT,
16810 Operand: DAG.getNode(Opcode: ISD::FABS, DL, VT, Operand: X));
16811 if (TrueOpnd->isExactlyValue(V: 1.0) && FalseOpnd->isExactlyValue(V: -1.0))
16812 return DAG.getNode(Opcode: ISD::FABS, DL, VT, Operand: X);
16813
16814 break;
16815 }
16816 }
16817 }
16818
16819 // FMUL -> FMA combines:
16820 if (SDValue Fused = visitFMULForFMADistributiveCombine(N)) {
16821 AddToWorklist(N: Fused.getNode());
16822 return Fused;
16823 }
16824
16825 // Don't do `combineFMulOrFDivWithIntPow2` until after FMUL -> FMA has been
16826 // able to run.
16827 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
16828 return R;
16829
16830 return SDValue();
16831}
16832
16833template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
16834 SDValue N0 = N->getOperand(Num: 0);
16835 SDValue N1 = N->getOperand(Num: 1);
16836 SDValue N2 = N->getOperand(Num: 2);
16837 ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(Val&: N0);
16838 ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(Val&: N1);
16839 EVT VT = N->getValueType(ResNo: 0);
16840 SDLoc DL(N);
16841 const TargetOptions &Options = DAG.getTarget().Options;
16842 // FMA nodes have flags that propagate to the created nodes.
16843 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
16844 MatchContextClass matcher(DAG, TLI, N);
16845
16846 bool CanReassociate =
16847 Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
16848
16849 // Constant fold FMA.
16850 if (isa<ConstantFPSDNode>(Val: N0) &&
16851 isa<ConstantFPSDNode>(Val: N1) &&
16852 isa<ConstantFPSDNode>(Val: N2)) {
16853 return matcher.getNode(ISD::FMA, DL, VT, N0, N1, N2);
16854 }
16855
16856 // (-N0 * -N1) + N2 --> (N0 * N1) + N2
16857 TargetLowering::NegatibleCost CostN0 =
16858 TargetLowering::NegatibleCost::Expensive;
16859 TargetLowering::NegatibleCost CostN1 =
16860 TargetLowering::NegatibleCost::Expensive;
16861 SDValue NegN0 =
16862 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
16863 if (NegN0) {
16864 HandleSDNode NegN0Handle(NegN0);
16865 SDValue NegN1 =
16866 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
16867 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
16868 CostN1 == TargetLowering::NegatibleCost::Cheaper))
16869 return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
16870 }
16871
16872 // FIXME: use fast math flags instead of Options.UnsafeFPMath
16873 if (Options.UnsafeFPMath) {
16874 if (N0CFP && N0CFP->isZero())
16875 return N2;
16876 if (N1CFP && N1CFP->isZero())
16877 return N2;
16878 }
16879
16880 // FIXME: Support splat of constant.
16881 if (N0CFP && N0CFP->isExactlyValue(V: 1.0))
16882 return matcher.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
16883 if (N1CFP && N1CFP->isExactlyValue(V: 1.0))
16884 return matcher.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
16885
16886 // Canonicalize (fma c, x, y) -> (fma x, c, y)
16887 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
16888 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
16889 return matcher.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
16890
16891 if (CanReassociate) {
16892 // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
16893 if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(i: 0) &&
16894 DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
16895 DAG.isConstantFPBuildVectorOrConstantFP(N: N2.getOperand(i: 1))) {
16896 return matcher.getNode(
16897 ISD::FMUL, DL, VT, N0,
16898 matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(i: 1)));
16899 }
16900
16901 // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
16902 if (matcher.match(N0, ISD::FMUL) &&
16903 DAG.isConstantFPBuildVectorOrConstantFP(N: N1) &&
16904 DAG.isConstantFPBuildVectorOrConstantFP(N: N0.getOperand(i: 1))) {
16905 return matcher.getNode(
16906 ISD::FMA, DL, VT, N0.getOperand(i: 0),
16907 matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(i: 1)), N2);
16908 }
16909 }
16910
16911 // (fma x, 1, y) -> (fadd x, y)
// (fma x, -1, y) -> (fadd (fneg x), y)
16912 // FIXME: Support splat of constant.
16913 if (N1CFP) {
16914 if (N1CFP->isExactlyValue(V: 1.0))
16915 return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
16916
16917 if (N1CFP->isExactlyValue(V: -1.0) &&
16918 (!LegalOperations || TLI.isOperationLegal(Op: ISD::FNEG, VT))) {
16919 SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0);
16920 AddToWorklist(N: RHSNeg.getNode());
16921 return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
16922 }
16923
16924 // fma (fneg x), K, y -> fma x, -K, y
16925 if (matcher.match(N0, ISD::FNEG) &&
16926 (TLI.isOperationLegal(Op: ISD::ConstantFP, VT) ||
16927 (N1.hasOneUse() &&
16928 !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
16929 return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(i: 0),
16930 matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
16931 }
16932 }
16933
16934 // FIXME: Support splat of constant.
16935 if (CanReassociate) {
16936 // (fma x, c, x) -> (fmul x, (c+1))
16937 if (N1CFP && N0 == N2) {
16938 return matcher.getNode(ISD::FMUL, DL, VT, N0,
16939 matcher.getNode(ISD::FADD, DL, VT, N1,
16940 DAG.getConstantFP(Val: 1.0, DL, VT)));
16941 }
16942
16943 // (fma x, c, (fneg x)) -> (fmul x, (c-1))
16944 if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(i: 0) == N0) {
16945 return matcher.getNode(ISD::FMUL, DL, VT, N0,
16946 matcher.getNode(ISD::FADD, DL, VT, N1,
16947 DAG.getConstantFP(Val: -1.0, DL, VT)));
16948 }
16949 }
16950
16951 // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z))
16952 // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z))
16953 if (!TLI.isFNegFree(VT))
16954 if (SDValue Neg = TLI.getCheaperNegatedExpression(
16955 Op: SDValue(N, 0), DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
16956 return matcher.getNode(ISD::FNEG, DL, VT, Neg);
16957 return SDValue();
16958}
16959
16960SDValue DAGCombiner::visitFMAD(SDNode *N) {
16961 SDValue N0 = N->getOperand(Num: 0);
16962 SDValue N1 = N->getOperand(Num: 1);
16963 SDValue N2 = N->getOperand(Num: 2);
16964 EVT VT = N->getValueType(ResNo: 0);
16965 SDLoc DL(N);
16966
16967 // Constant fold FMAD.
16968 if (isa<ConstantFPSDNode>(Val: N0) && isa<ConstantFPSDNode>(Val: N1) &&
16969 isa<ConstantFPSDNode>(Val: N2))
16970 return DAG.getNode(Opcode: ISD::FMAD, DL, VT, N1: N0, N2: N1, N3: N2);
16971
16972 return SDValue();
16973}
16974
16975// Combine multiple FDIVs with the same divisor into multiple FMULs by the
16976// reciprocal.
16977// E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
16978// Notice that this is not always beneficial. One reason is different targets
16979// may have different costs for FDIV and FMUL, so sometimes the cost of two
16980// FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason
16981// is the critical path is increased from "one FDIV" to "one FDIV + one FMUL".
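// E.g. with three divisions (a / D; b / D; c / D) the transform emits one
// FDIV and three FMULs instead of three FDIVs, which only pays off when FDIV
// is sufficiently more expensive than FMUL.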
16982SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
16983 // TODO: Limit this transform based on optsize/minsize - it always creates at
16984 // least 1 extra instruction. But the perf win may be substantial enough
16985 // that only minsize should restrict this.
16986 bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath;
16987 const SDNodeFlags Flags = N->getFlags();
16988 if (LegalDAG || (!UnsafeMath && !Flags.hasAllowReciprocal()))
16989 return SDValue();
16990
16991 // Skip if current node is a reciprocal/fneg-reciprocal.
16992 SDValue N0 = N->getOperand(Num: 0), N1 = N->getOperand(Num: 1);
16993 ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N: N0, /* AllowUndefs */ true);
16994 if (N0CFP && (N0CFP->isExactlyValue(V: 1.0) || N0CFP->isExactlyValue(V: -1.0)))
16995 return SDValue();
16996
16997 // Exit early if the target does not want this transform or if there can't
16998 // possibly be enough uses of the divisor to make the transform worthwhile.
16999 unsigned MinUses = TLI.combineRepeatedFPDivisors();
17000
17001 // For splat vectors, scale the number of uses by the splat factor. If we can
17002 // convert the division into a scalar op, that will likely be much faster.
17003 unsigned NumElts = 1;
17004 EVT VT = N->getValueType(ResNo: 0);
17005 if (VT.isVector() && DAG.isSplatValue(V: N1))
17006 NumElts = VT.getVectorMinNumElements();
17007
17008 if (!MinUses || (N1->use_size() * NumElts) < MinUses)
17009 return SDValue();
17010
17011 // Find all FDIV users of the same divisor.
17012 // Use a set because duplicates may be present in the user list.
17013 SetVector<SDNode *> Users;
17014 for (auto *U : N1->uses()) {
17015 if (U->getOpcode() == ISD::FDIV && U->getOperand(Num: 1) == N1) {
17016 // Skip X/sqrt(X) that has not been simplified to sqrt(X) yet.
17017 if (U->getOperand(Num: 1).getOpcode() == ISD::FSQRT &&
17018 U->getOperand(Num: 0) == U->getOperand(Num: 1).getOperand(i: 0) &&
17019 U->getFlags().hasAllowReassociation() &&
17020 U->getFlags().hasNoSignedZeros())
17021 continue;
17022
17023 // This division is eligible for optimization only if global unsafe math
17024 // is enabled or if this division allows reciprocal formation.
17025 if (UnsafeMath || U->getFlags().hasAllowReciprocal())
17026 Users.insert(X: U);
17027 }
17028 }
17029
17030 // Now that we have the actual number of divisor uses, make sure it meets
17031 // the minimum threshold specified by the target.
17032 if ((Users.size() * NumElts) < MinUses)
17033 return SDValue();
17034
17035 SDLoc DL(N);
17036 SDValue FPOne = DAG.getConstantFP(Val: 1.0, DL, VT);
17037 SDValue Reciprocal = DAG.getNode(Opcode: ISD::FDIV, DL, VT, N1: FPOne, N2: N1, Flags);
17038
17039 // Dividend / Divisor -> Dividend * Reciprocal
17040 for (auto *U : Users) {
17041 SDValue Dividend = U->getOperand(Num: 0);
17042 if (Dividend != FPOne) {
17043 SDValue NewNode = DAG.getNode(Opcode: ISD::FMUL, DL: SDLoc(U), VT, N1: Dividend,
17044 N2: Reciprocal, Flags);
17045 CombineTo(N: U, Res: NewNode);
17046 } else if (U != Reciprocal.getNode()) {
17047 // In the absence of fast-math-flags, this user node is always the
17048 // same node as Reciprocal, but with FMF they may be different nodes.
17049 CombineTo(N: U, Res: Reciprocal);
17050 }
17051 }
17052 return SDValue(N, 0); // N was replaced.
17053}
17054
17055SDValue DAGCombiner::visitFDIV(SDNode *N) {
17056 SDValue N0 = N->getOperand(Num: 0);
17057 SDValue N1 = N->getOperand(Num: 1);
17058 EVT VT = N->getValueType(ResNo: 0);
17059 SDLoc DL(N);
17060 const TargetOptions &Options = DAG.getTarget().Options;
17061 SDNodeFlags Flags = N->getFlags();
17062 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17063
17064 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
17065 return R;
17066
17067 // fold (fdiv c1, c2) -> c1/c2
17068 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FDIV, DL, VT, Ops: {N0, N1}))
17069 return C;
17070
17071 // fold vector ops
17072 if (VT.isVector())
17073 if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
17074 return FoldedVOp;
17075
17076 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
17077 return NewSel;
17078
17079 if (SDValue V = combineRepeatedFPDivisors(N))
17080 return V;
17081
17082 if (Options.UnsafeFPMath || Flags.hasAllowReciprocal()) {
17083 // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
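// E.g. X / 4.0 -> X * 0.25 is exact, while X / 3.0 -> X * 0.33333334f is
// inexact; that potential precision loss is why this whole block is gated on
// allow-reciprocal / unsafe math.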
17084 if (auto *N1CFP = dyn_cast<ConstantFPSDNode>(Val&: N1)) {
17085 // Compute the reciprocal 1.0 / c2.
17086 const APFloat &N1APF = N1CFP->getValueAPF();
17087 APFloat Recip(N1APF.getSemantics(), 1); // 1.0
17088 APFloat::opStatus st = Recip.divide(RHS: N1APF, RM: APFloat::rmNearestTiesToEven);
17089 // Only do the transform if the reciprocal is a legal fp immediate that
17090 // isn't too nasty (eg NaN, denormal, ...).
17091 if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty
17092 (!LegalOperations ||
17093 // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM
17094 // backend)... we should handle this gracefully after Legalize.
17095 // TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT) ||
17096 TLI.isOperationLegal(Op: ISD::ConstantFP, VT) ||
17097 TLI.isFPImmLegal(Recip, VT, ForCodeSize)))
17098 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0,
17099 N2: DAG.getConstantFP(Val: Recip, DL, VT));
17100 }
17101
17102 // If this FDIV is part of a reciprocal square root, it may be folded
17103 // into a target-specific square root estimate instruction.
17104 if (N1.getOpcode() == ISD::FSQRT) {
17105 if (SDValue RV = buildRsqrtEstimate(Op: N1.getOperand(i: 0), Flags))
17106 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
17107 } else if (N1.getOpcode() == ISD::FP_EXTEND &&
17108 N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
17109 if (SDValue RV =
17110 buildRsqrtEstimate(Op: N1.getOperand(i: 0).getOperand(i: 0), Flags)) {
17111 RV = DAG.getNode(Opcode: ISD::FP_EXTEND, DL: SDLoc(N1), VT, Operand: RV);
17112 AddToWorklist(N: RV.getNode());
17113 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
17114 }
17115 } else if (N1.getOpcode() == ISD::FP_ROUND &&
17116 N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
17117 if (SDValue RV =
17118 buildRsqrtEstimate(Op: N1.getOperand(i: 0).getOperand(i: 0), Flags)) {
17119 RV = DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N1), VT, N1: RV, N2: N1.getOperand(i: 1));
17120 AddToWorklist(N: RV.getNode());
17121 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: RV);
17122 }
17123 } else if (N1.getOpcode() == ISD::FMUL) {
17124 // Look through an FMUL. Even though this won't remove the FDIV directly,
17125 // it's still worthwhile to get rid of the FSQRT if possible.
17126 SDValue Sqrt, Y;
17127 if (N1.getOperand(i: 0).getOpcode() == ISD::FSQRT) {
17128 Sqrt = N1.getOperand(i: 0);
17129 Y = N1.getOperand(i: 1);
17130 } else if (N1.getOperand(i: 1).getOpcode() == ISD::FSQRT) {
17131 Sqrt = N1.getOperand(i: 1);
17132 Y = N1.getOperand(i: 0);
17133 }
17134 if (Sqrt.getNode()) {
17135 // If the other multiply operand is known positive, pull it into the
17136 // sqrt. That will eliminate the division if we convert to an estimate.
17137 if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
17138 N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
17139 SDValue A;
17140 if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
17141 A = Y.getOperand(i: 0);
17142 else if (Y == Sqrt.getOperand(i: 0))
17143 A = Y;
17144 if (A) {
17145 // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
17146 // X / (A * sqrt(A)) --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
17147 SDValue AA = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: A, N2: A);
17148 SDValue AAZ =
17149 DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AA, N2: Sqrt.getOperand(i: 0));
17150 if (SDValue Rsqrt = buildRsqrtEstimate(Op: AAZ, Flags))
17151 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: Rsqrt);
17152
17153 // Estimate creation failed. Clean up speculatively created nodes.
17154 recursivelyDeleteUnusedNodes(N: AAZ.getNode());
17155 }
17156 }
17157
17158 // We found a FSQRT, so try to make this fold:
17159 // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
17160 if (SDValue Rsqrt = buildRsqrtEstimate(Op: Sqrt.getOperand(i: 0), Flags)) {
17161 SDValue Div = DAG.getNode(Opcode: ISD::FDIV, DL: SDLoc(N1), VT, N1: Rsqrt, N2: Y);
17162 AddToWorklist(N: Div.getNode());
17163 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N0, N2: Div);
17164 }
17165 }
17166 }
17167
17168 // Fold into a reciprocal estimate and multiply instead of a real divide.
17169 if (Options.NoInfsFPMath || Flags.hasNoInfs())
17170 if (SDValue RV = BuildDivEstimate(N: N0, Op: N1, Flags))
17171 return RV;
17172 }
17173
17174 // Fold X/Sqrt(X) -> Sqrt(X)
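// (Mathematically X / sqrt(X) == sqrt(X) for positive X; the nsz and
// reassociation flags checked below cover the zero and rounding corner
// cases.)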
17175 if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
17176 (Options.UnsafeFPMath || Flags.hasAllowReassociation()))
17177 if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(i: 0))
17178 return N1;
17179
17180 // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
17181 TargetLowering::NegatibleCost CostN0 =
17182 TargetLowering::NegatibleCost::Expensive;
17183 TargetLowering::NegatibleCost CostN1 =
17184 TargetLowering::NegatibleCost::Expensive;
17185 SDValue NegN0 =
17186 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN0);
17187 if (NegN0) {
17188 HandleSDNode NegN0Handle(NegN0);
17189 SDValue NegN1 =
17190 TLI.getNegatedExpression(Op: N1, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize, Cost&: CostN1);
17191 if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
17192 CostN1 == TargetLowering::NegatibleCost::Cheaper))
17193 return DAG.getNode(Opcode: ISD::FDIV, DL: SDLoc(N), VT, N1: NegN0, N2: NegN1);
17194 }
17195
17196 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
17197 return R;
17198
17199 return SDValue();
17200}
17201
17202SDValue DAGCombiner::visitFREM(SDNode *N) {
17203 SDValue N0 = N->getOperand(Num: 0);
17204 SDValue N1 = N->getOperand(Num: 1);
17205 EVT VT = N->getValueType(ResNo: 0);
17206 SDNodeFlags Flags = N->getFlags();
17207 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17208
17209 if (SDValue R = DAG.simplifyFPBinop(Opcode: N->getOpcode(), X: N0, Y: N1, Flags))
17210 return R;
17211
17212 // fold (frem c1, c2) -> fmod(c1,c2)
17213 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: ISD::FREM, DL: SDLoc(N), VT, Ops: {N0, N1}))
17214 return C;
17215
17216 if (SDValue NewSel = foldBinOpIntoSelect(BO: N))
17217 return NewSel;
17218
17219 return SDValue();
17220}
17221
17222SDValue DAGCombiner::visitFSQRT(SDNode *N) {
17223 SDNodeFlags Flags = N->getFlags();
17224 const TargetOptions &Options = DAG.getTarget().Options;
17225
17226 // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as:
17227 // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN
17228 if (!Flags.hasApproximateFuncs() ||
17229 (!Options.NoInfsFPMath && !Flags.hasNoInfs()))
17230 return SDValue();
17231
17232 SDValue N0 = N->getOperand(Num: 0);
17233 if (TLI.isFsqrtCheap(X: N0, DAG))
17234 return SDValue();
17235
17236 // FSQRT nodes have flags that propagate to the created nodes.
17237 // TODO: If this is N0/sqrt(N0), and we reach this node before trying to
17238 // transform the fdiv, we may produce a sub-optimal estimate sequence
17239 // because the reciprocal calculation may not have to filter out a
17240 // 0.0 input.
17241 return buildSqrtEstimate(Op: N0, Flags);
17242}
17243
17244/// copysign(x, fp_extend(y)) -> copysign(x, y)
17245/// copysign(x, fp_round(y)) -> copysign(x, y)
17246 /// Operands to the function are the types of X and Y, respectively.
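/// The cast can be dropped because fp_extend / fp_round preserve the sign
/// bit, which is the only part of Y that FCOPYSIGN consumes.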
17247static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(EVT XTy, EVT YTy) {
17248 // Always fold no-op FP casts.
17249 if (XTy == YTy)
17250 return true;
17251
17252 // Do not optimize out type conversion of f128 type yet.
17253 // For some targets like x86_64, configuration is changed to keep one f128
17254 // value in one SSE register, but instruction selection cannot handle
17255 // FCOPYSIGN on SSE registers yet.
17256 if (YTy == MVT::f128)
17257 return false;
17258
17259 return !YTy.isVector() || EnableVectorFCopySignExtendRound;
17260}
17261
17262static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) {
17263 SDValue N1 = N->getOperand(Num: 1);
17264 if (N1.getOpcode() != ISD::FP_EXTEND &&
17265 N1.getOpcode() != ISD::FP_ROUND)
17266 return false;
17267 EVT N1VT = N1->getValueType(ResNo: 0);
17268 EVT N1Op0VT = N1->getOperand(Num: 0).getValueType();
17269 return CanCombineFCOPYSIGN_EXTEND_ROUND(XTy: N1VT, YTy: N1Op0VT);
17270}
17271
17272SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
17273 SDValue N0 = N->getOperand(Num: 0);
17274 SDValue N1 = N->getOperand(Num: 1);
17275 EVT VT = N->getValueType(ResNo: 0);
17276
17277 // fold (fcopysign c1, c2) -> fcopysign(c1,c2)
17278 if (SDValue C =
17279 DAG.FoldConstantArithmetic(Opcode: ISD::FCOPYSIGN, DL: SDLoc(N), VT, Ops: {N0, N1}))
17280 return C;
17281
17282 if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N: N->getOperand(Num: 1))) {
17283 const APFloat &V = N1C->getValueAPF();
17284 // copysign(x, c1) -> fabs(x) iff ispos(c1)
17285 // copysign(x, c1) -> fneg(fabs(x)) iff isneg(c1)
17286 if (!V.isNegative()) {
17287 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FABS, VT))
17288 return DAG.getNode(Opcode: ISD::FABS, DL: SDLoc(N), VT, Operand: N0);
17289 } else {
17290 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::FNEG, VT))
17291 return DAG.getNode(Opcode: ISD::FNEG, DL: SDLoc(N), VT,
17292 Operand: DAG.getNode(Opcode: ISD::FABS, DL: SDLoc(N0), VT, Operand: N0));
17293 }
17294 }
17295
17296 // copysign(fabs(x), y) -> copysign(x, y)
17297 // copysign(fneg(x), y) -> copysign(x, y)
17298 // copysign(copysign(x,z), y) -> copysign(x, y)
17299 if (N0.getOpcode() == ISD::FABS || N0.getOpcode() == ISD::FNEG ||
17300 N0.getOpcode() == ISD::FCOPYSIGN)
17301 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0), N2: N1);
17302
17303 // copysign(x, abs(y)) -> abs(x)
17304 if (N1.getOpcode() == ISD::FABS)
17305 return DAG.getNode(Opcode: ISD::FABS, DL: SDLoc(N), VT, Operand: N0);
17306
17307 // copysign(x, copysign(y,z)) -> copysign(x, z)
17308 if (N1.getOpcode() == ISD::FCOPYSIGN)
17309 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SDLoc(N), VT, N1: N0, N2: N1.getOperand(i: 1));
17310
17311 // copysign(x, fp_extend(y)) -> copysign(x, y)
17312 // copysign(x, fp_round(y)) -> copysign(x, y)
17313 if (CanCombineFCOPYSIGN_EXTEND_ROUND(N))
17314 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SDLoc(N), VT, N1: N0, N2: N1.getOperand(i: 0));
17315
17316 return SDValue();
17317}
17318
17319SDValue DAGCombiner::visitFPOW(SDNode *N) {
17320 ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N: N->getOperand(Num: 1));
17321 if (!ExponentC)
17322 return SDValue();
17323 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17324
17325 // Try to convert x ** (1/3) into cube root.
17326 // TODO: Handle the various flavors of long double.
17327 // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
17328 // Some range near 1/3 should be fine.
17329 EVT VT = N->getValueType(ResNo: 0);
17330 if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
17331 (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
17332 // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
17333 // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
17334 // pow(-val, 1/3) = nan; cbrt(-val) = -cbrt(val).
17335 // For regular numbers, rounding may cause the results to differ.
17336 // Therefore, we require { nsz ninf nnan afn } for this transform.
17337 // TODO: We could select out the special cases if we don't have nsz/ninf.
17338 SDNodeFlags Flags = N->getFlags();
17339 if (!Flags.hasNoSignedZeros() || !Flags.hasNoInfs() || !Flags.hasNoNaNs() ||
17340 !Flags.hasApproximateFuncs())
17341 return SDValue();
17342
17343 // Do not create a cbrt() libcall if the target does not have it, and do not
17344 // turn a pow that has lowering support into a cbrt() libcall.
17345 if (!DAG.getLibInfo().has(F: LibFunc_cbrt) ||
17346 (!DAG.getTargetLoweringInfo().isOperationExpand(Op: ISD::FPOW, VT) &&
17347 DAG.getTargetLoweringInfo().isOperationExpand(Op: ISD::FCBRT, VT)))
17348 return SDValue();
17349
17350 return DAG.getNode(Opcode: ISD::FCBRT, DL: SDLoc(N), VT, Operand: N->getOperand(Num: 0));
17351 }
17352
17353 // Try to convert x ** (1/4) and x ** (3/4) into square roots.
17354 // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
17355 // TODO: This could be extended (using a target hook) to handle smaller
17356 // power-of-2 fractional exponents.
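// (x ** 0.25 == sqrt(sqrt(x)) and x ** 0.75 == sqrt(x) * sqrt(sqrt(x)),
// since 3/4 == 1/2 + 1/4.)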
17357 bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(V: 0.25);
17358 bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(V: 0.75);
17359 if (ExponentIs025 || ExponentIs075) {
17360 // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
17361 // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) = NaN.
17362 // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
17363 // pow(-inf, 0.75) = +inf; sqrt(-inf) * sqrt(sqrt(-inf)) = NaN.
17364 // For regular numbers, rounding may cause the results to differ.
17365 // Therefore, we require { nsz ninf afn } for this transform.
17366 // TODO: We could select out the special cases if we don't have nsz/ninf.
17367 SDNodeFlags Flags = N->getFlags();
17368
17369 // We only need no signed zeros for the 0.25 case.
17370 if ((!Flags.hasNoSignedZeros() && ExponentIs025) || !Flags.hasNoInfs() ||
17371 !Flags.hasApproximateFuncs())
17372 return SDValue();
17373
17374 // Don't double the number of libcalls. We are trying to inline fast code.
17375 if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(Op: ISD::FSQRT, VT))
17376 return SDValue();
17377
17378 // Assume that libcalls are the smallest code.
17379 // TODO: This restriction should probably be lifted for vectors.
17380 if (ForCodeSize)
17381 return SDValue();
17382
17383 // pow(X, 0.25) --> sqrt(sqrt(X))
17384 SDLoc DL(N);
17385 SDValue Sqrt = DAG.getNode(Opcode: ISD::FSQRT, DL, VT, Operand: N->getOperand(Num: 0));
17386 SDValue SqrtSqrt = DAG.getNode(Opcode: ISD::FSQRT, DL, VT, Operand: Sqrt);
17387 if (ExponentIs025)
17388 return SqrtSqrt;
17389 // pow(X, 0.75) --> sqrt(X) * sqrt(sqrt(X))
17390 return DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Sqrt, N2: SqrtSqrt);
17391 }
17392
17393 return SDValue();
17394}
17395
17396static SDValue foldFPToIntToFP(SDNode *N, SelectionDAG &DAG,
17397 const TargetLowering &TLI) {
17398 // We only do this if the target has legal ftrunc. Otherwise, we'd likely be
17399 // replacing casts with a libcall. We also must be allowed to ignore -0.0
17400 // because FTRUNC will return -0.0 for (-1.0, -0.0), but using integer
17401 // conversions would return +0.0.
17402 // FIXME: We should be able to use node-level FMF here.
17403 // TODO: If strict math, should we use FABS (+ range check for signed cast)?
17404 EVT VT = N->getValueType(ResNo: 0);
17405 if (!TLI.isOperationLegal(Op: ISD::FTRUNC, VT) ||
17406 !DAG.getTarget().Options.NoSignedZerosFPMath)
17407 return SDValue();
17408
17409 // fptosi/fptoui round towards zero, so converting from FP to integer and
17410 // back is the same as an 'ftrunc': [us]itofp (fpto[us]i X) --> ftrunc X
17411 SDValue N0 = N->getOperand(Num: 0);
17412 if (N->getOpcode() == ISD::SINT_TO_FP && N0.getOpcode() == ISD::FP_TO_SINT &&
17413 N0.getOperand(i: 0).getValueType() == VT)
17414 return DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
17415
17416 if (N->getOpcode() == ISD::UINT_TO_FP && N0.getOpcode() == ISD::FP_TO_UINT &&
17417 N0.getOperand(i: 0).getValueType() == VT)
17418 return DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
17419
17420 return SDValue();
17421}
17422
17423SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) {
17424 SDValue N0 = N->getOperand(Num: 0);
17425 EVT VT = N->getValueType(ResNo: 0);
17426 EVT OpVT = N0.getValueType();
17427
17428 // [us]itofp(undef) = 0, because the result value is bounded.
17429 if (N0.isUndef())
17430 return DAG.getConstantFP(Val: 0.0, DL: SDLoc(N), VT);
17431
17432 // fold (sint_to_fp c1) -> c1fp
17433 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
17434 // ...but only if the target supports immediate floating-point values
17435 (!LegalOperations ||
17436 TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
17437 return DAG.getNode(Opcode: ISD::SINT_TO_FP, DL: SDLoc(N), VT, Operand: N0);
17438
17439 // If the input is a legal type, and SINT_TO_FP is not legal on this target,
17440 // but UINT_TO_FP is legal on this target, try to convert.
17441 if (!hasOperation(Opcode: ISD::SINT_TO_FP, VT: OpVT) &&
17442 hasOperation(Opcode: ISD::UINT_TO_FP, VT: OpVT)) {
17443 // If the sign bit is known to be zero, we can change this to UINT_TO_FP.
17444 if (DAG.SignBitIsZero(Op: N0))
17445 return DAG.getNode(Opcode: ISD::UINT_TO_FP, DL: SDLoc(N), VT, Operand: N0);
17446 }
17447
17448 // The next optimizations are desirable only if SELECT_CC can be lowered.
17449 // fold (sint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), -1.0, 0.0)
17450 if (N0.getOpcode() == ISD::SETCC && N0.getValueType() == MVT::i1 &&
17451 !VT.isVector() &&
17452 (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT))) {
17453 SDLoc DL(N);
17454 return DAG.getSelect(DL, VT, Cond: N0, LHS: DAG.getConstantFP(Val: -1.0, DL, VT),
17455 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
17456 }
17457
17458 // fold (sint_to_fp (zext (setcc x, y, cc))) ->
17459 // (select (setcc x, y, cc), 1.0, 0.0)
17460 if (N0.getOpcode() == ISD::ZERO_EXTEND &&
17461 N0.getOperand(i: 0).getOpcode() == ISD::SETCC && !VT.isVector() &&
17462 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT))) {
17463 SDLoc DL(N);
17464 return DAG.getSelect(DL, VT, Cond: N0.getOperand(i: 0),
17465 LHS: DAG.getConstantFP(Val: 1.0, DL, VT),
17466 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
17467 }
17468
17469 if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
17470 return FTrunc;
17471
17472 return SDValue();
17473}
17474
17475SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
17476 SDValue N0 = N->getOperand(Num: 0);
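// Strict FP nodes must preserve rounding and exception behaviour, so only
// the exactly-equivalent fneg rewrites below are attempted.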
17477 EVT VT = N->getValueType(ResNo: 0);
17478 EVT OpVT = N0.getValueType();
17479
17480 // [us]itofp(undef) = 0, because the result value is bounded.
17481 if (N0.isUndef())
17482 return DAG.getConstantFP(Val: 0.0, DL: SDLoc(N), VT);
17483
17484 // fold (uint_to_fp c1) -> c1fp
17485 if (DAG.isConstantIntBuildVectorOrConstantInt(N: N0) &&
17486 // ...but only if the target supports immediate floating-point values
17487 (!LegalOperations ||
17488 TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT)))
17489 return DAG.getNode(Opcode: ISD::UINT_TO_FP, DL: SDLoc(N), VT, Operand: N0);
17490
17491 // If the input is a legal type, and UINT_TO_FP is not legal on this target,
17492 // but SINT_TO_FP is legal on this target, try to convert.
17493 if (!hasOperation(Opcode: ISD::UINT_TO_FP, VT: OpVT) &&
17494 hasOperation(Opcode: ISD::SINT_TO_FP, VT: OpVT)) {
17495 // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
17496 if (DAG.SignBitIsZero(Op: N0))
17497 return DAG.getNode(Opcode: ISD::SINT_TO_FP, DL: SDLoc(N), VT, Operand: N0);
17498 }
17499
17500 // fold (uint_to_fp (setcc x, y, cc)) -> (select (setcc x, y, cc), 1.0, 0.0)
17501 if (N0.getOpcode() == ISD::SETCC && !VT.isVector() &&
17502 (!LegalOperations || TLI.isOperationLegalOrCustom(Op: ISD::ConstantFP, VT))) {
17503 SDLoc DL(N);
17504 return DAG.getSelect(DL, VT, Cond: N0, LHS: DAG.getConstantFP(Val: 1.0, DL, VT),
17505 RHS: DAG.getConstantFP(Val: 0.0, DL, VT));
17506 }
17507
17508 if (SDValue FTrunc = foldFPToIntToFP(N, DAG, TLI))
17509 return FTrunc;
17510
17511 return SDValue();
17512}
17513
17514 // Fold (fp_to_{s/u}int ({s/u}int_to_fp x)) -> zext x, sext x, trunc x, or x
17515static SDValue FoldIntToFPToInt(SDNode *N, SelectionDAG &DAG) {
17516 SDValue N0 = N->getOperand(Num: 0);
17517 EVT VT = N->getValueType(ResNo: 0);
17518
17519 if (N0.getOpcode() != ISD::UINT_TO_FP && N0.getOpcode() != ISD::SINT_TO_FP)
17520 return SDValue();
17521
17522 SDValue Src = N0.getOperand(i: 0);
17523 EVT SrcVT = Src.getValueType();
17524 bool IsInputSigned = N0.getOpcode() == ISD::SINT_TO_FP;
17525 bool IsOutputSigned = N->getOpcode() == ISD::FP_TO_SINT;
17526
17527 // We can safely assume the conversion won't overflow the output range,
17528 // because (for example) (uint8_t)18293.f is undefined behavior.
17529
17530 // Since we can assume the conversion won't overflow, our decision as to
17531 // whether the input will fit in the float should depend on the minimum
17532 // of the input range and output range.
17533
17534 // This means this is also safe for a signed input and unsigned output, since
17535 // a negative input would lead to undefined behavior.
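// E.g. for (fp_to_sint i32 (sint_to_fp f32 (i8 x))): ActualSize is 7 and f32
// carries 24 bits of precision, so the round trip is lossless and this folds
// to (sign_extend i32 x).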
17536 unsigned InputSize = (int)SrcVT.getScalarSizeInBits() - IsInputSigned;
17537 unsigned OutputSize = (int)VT.getScalarSizeInBits();
17538 unsigned ActualSize = std::min(a: InputSize, b: OutputSize);
17539 const fltSemantics &sem = DAG.EVTToAPFloatSemantics(VT: N0.getValueType());
17540
17541 // We can only fold away the float conversion if the input range can be
17542 // represented exactly in the float range.
17543 if (APFloat::semanticsPrecision(sem) >= ActualSize) {
17544 if (VT.getScalarSizeInBits() > SrcVT.getScalarSizeInBits()) {
17545 unsigned ExtOp = IsInputSigned && IsOutputSigned ? ISD::SIGN_EXTEND
17546 : ISD::ZERO_EXTEND;
17547 return DAG.getNode(Opcode: ExtOp, DL: SDLoc(N), VT, Operand: Src);
17548 }
17549 if (VT.getScalarSizeInBits() < SrcVT.getScalarSizeInBits())
17550 return DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N), VT, Operand: Src);
17551 return DAG.getBitcast(VT, V: Src);
17552 }
17553 return SDValue();
17554}
17555
17556SDValue DAGCombiner::visitFP_TO_SINT(SDNode *N) {
17557 SDValue N0 = N->getOperand(Num: 0);
17558 EVT VT = N->getValueType(ResNo: 0);
17559
17560 // fold (fp_to_sint undef) -> undef
17561 if (N0.isUndef())
17562 return DAG.getUNDEF(VT);
17563
17564 // fold (fp_to_sint c1fp) -> c1
17565 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17566 return DAG.getNode(Opcode: ISD::FP_TO_SINT, DL: SDLoc(N), VT, Operand: N0);
17567
17568 return FoldIntToFPToInt(N, DAG);
17569}
17570
17571SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
17572 SDValue N0 = N->getOperand(Num: 0);
17573 EVT VT = N->getValueType(ResNo: 0);
17574
17575 // fold (fp_to_uint undef) -> undef
17576 if (N0.isUndef())
17577 return DAG.getUNDEF(VT);
17578
17579 // fold (fp_to_uint c1fp) -> c1
17580 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17581 return DAG.getNode(Opcode: ISD::FP_TO_UINT, DL: SDLoc(N), VT, Operand: N0);
17582
17583 return FoldIntToFPToInt(N, DAG);
17584}
17585
17586SDValue DAGCombiner::visitXRINT(SDNode *N) {
17587 SDValue N0 = N->getOperand(Num: 0);
17588 EVT VT = N->getValueType(ResNo: 0);
17589
17590 // fold (lrint|llrint undef) -> undef
17591 if (N0.isUndef())
17592 return DAG.getUNDEF(VT);
17593
17594 // fold (lrint|llrint c1fp) -> c1
17595 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17596 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, Operand: N0);
17597
17598 return SDValue();
17599}
17600
17601SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
17602 SDValue N0 = N->getOperand(Num: 0);
17603 SDValue N1 = N->getOperand(Num: 1);
17604 EVT VT = N->getValueType(ResNo: 0);
17605
17606 // fold (fp_round c1fp) -> c1fp
17607 if (SDValue C =
17608 DAG.FoldConstantArithmetic(Opcode: ISD::FP_ROUND, DL: SDLoc(N), VT, Ops: {N0, N1}))
17609 return C;
17610
17611 // fold (fp_round (fp_extend x)) -> x
17612 if (N0.getOpcode() == ISD::FP_EXTEND && VT == N0.getOperand(i: 0).getValueType())
17613 return N0.getOperand(i: 0);
17614
17615 // fold (fp_round (fp_round x)) -> (fp_round x)
17616 if (N0.getOpcode() == ISD::FP_ROUND) {
17617 const bool NIsTrunc = N->getConstantOperandVal(Num: 1) == 1;
17618 const bool N0IsTrunc = N0.getConstantOperandVal(i: 1) == 1;
17619
17620 // Avoid folding legal fp_rounds into non-legal ones.
17621 if (!hasOperation(Opcode: ISD::FP_ROUND, VT))
17622 return SDValue();
17623
17624 // Skip this folding if it results in an fp_round from f80 to f16.
17625 //
17626 // f80 to f16 always generates an expensive (and as yet, unimplemented)
17627 // libcall to __truncxfhf2 instead of selecting native f16 conversion
17628 // instructions from f32 or f64. Moreover, the first (value-preserving)
17629 // fp_round from f80 to either f32 or f64 may become a NOP in platforms like
17630 // x86.
17631 if (N0.getOperand(0).getValueType() == MVT::f80 && VT == MVT::f16)
17632 return SDValue();
17633
17634 // If the first fp_round isn't a value preserving truncation, it might
17635 // introduce a tie in the second fp_round, that wouldn't occur in the
17636 // single-step fp_round we want to fold to.
17637 // In other words, double rounding isn't the same as rounding.
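  // E.g. a value just above an f16 rounding boundary can land exactly on that
  // boundary after a first, inexact round to f32, and the second round's
  // tie-breaking may then pick a different f16 result than a direct round
  // would.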
17638 // Also, this is a value preserving truncation iff both fp_round's are.
17639 if (DAG.getTarget().Options.UnsafeFPMath || N0IsTrunc) {
17640 SDLoc DL(N);
17641 return DAG.getNode(
17642 Opcode: ISD::FP_ROUND, DL, VT, N1: N0.getOperand(i: 0),
17643 N2: DAG.getIntPtrConstant(Val: NIsTrunc && N0IsTrunc, DL, /*isTarget=*/true));
17644 }
17645 }
17646
17647 // fold (fp_round (copysign X, Y)) -> (copysign (fp_round X), Y)
17648 // Note: From a legality perspective, this is a two step transform. First,
17649 // we duplicate the fp_round to the arguments of the copysign, then we
17650 // eliminate the fp_round on Y. The second step requires an additional
17651 // predicate to match the implementation above.
17652 if (N0.getOpcode() == ISD::FCOPYSIGN && N0->hasOneUse() &&
17653 CanCombineFCOPYSIGN_EXTEND_ROUND(XTy: VT,
17654 YTy: N0.getValueType())) {
17655 SDValue Tmp = DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N0), VT,
17656 N1: N0.getOperand(i: 0), N2: N1);
17657 AddToWorklist(N: Tmp.getNode());
17658 return DAG.getNode(Opcode: ISD::FCOPYSIGN, DL: SDLoc(N), VT,
17659 N1: Tmp, N2: N0.getOperand(i: 1));
17660 }
17661
17662 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
17663 return NewVSel;
17664
17665 return SDValue();
17666}
17667
17668SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
17669 SDValue N0 = N->getOperand(Num: 0);
17670 EVT VT = N->getValueType(ResNo: 0);
17671
17672 if (VT.isVector())
17673 if (SDValue FoldedVOp = SimplifyVCastOp(N, DL: SDLoc(N)))
17674 return FoldedVOp;
17675
17676 // If this is fp_round(fpextend), don't fold it, allow ourselves to be folded.
17677 if (N->hasOneUse() &&
17678 N->use_begin()->getOpcode() == ISD::FP_ROUND)
17679 return SDValue();
17680
17681 // fold (fp_extend c1fp) -> c1fp
17682 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17683 return DAG.getNode(Opcode: ISD::FP_EXTEND, DL: SDLoc(N), VT, Operand: N0);
17684
17685 // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op)
17686 if (N0.getOpcode() == ISD::FP16_TO_FP &&
17687 TLI.getOperationAction(Op: ISD::FP16_TO_FP, VT) == TargetLowering::Legal)
17688 return DAG.getNode(Opcode: ISD::FP16_TO_FP, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
17689
17690 // Turn fp_extend(fp_round(X, 1)) -> X since the fp_round doesn't affect the
17691 // value of X.
17692 if (N0.getOpcode() == ISD::FP_ROUND
17693 && N0.getConstantOperandVal(i: 1) == 1) {
17694 SDValue In = N0.getOperand(i: 0);
17695 if (In.getValueType() == VT) return In;
17696 if (VT.bitsLT(VT: In.getValueType()))
17697 return DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N), VT,
17698 N1: In, N2: N0.getOperand(i: 1));
17699 return DAG.getNode(Opcode: ISD::FP_EXTEND, DL: SDLoc(N), VT, Operand: In);
17700 }
17701
17702 // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
17703 if (ISD::isNormalLoad(N: N0.getNode()) && N0.hasOneUse() &&
17704 TLI.isLoadExtLegalOrCustom(ExtType: ISD::EXTLOAD, ValVT: VT, MemVT: N0.getValueType())) {
17705 LoadSDNode *LN0 = cast<LoadSDNode>(Val&: N0);
17706 SDValue ExtLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: SDLoc(N), VT,
17707 Chain: LN0->getChain(),
17708 Ptr: LN0->getBasePtr(), MemVT: N0.getValueType(),
17709 MMO: LN0->getMemOperand());
17710 CombineTo(N, Res: ExtLoad);
17711 CombineTo(
17712 N: N0.getNode(),
17713 Res0: DAG.getNode(Opcode: ISD::FP_ROUND, DL: SDLoc(N0), VT: N0.getValueType(), N1: ExtLoad,
17714 N2: DAG.getIntPtrConstant(Val: 1, DL: SDLoc(N0), /*isTarget=*/true)),
17715 Res1: ExtLoad.getValue(R: 1));
17716 return SDValue(N, 0); // Return N so it doesn't get rechecked!
17717 }
17718
17719 if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(Cast: N))
17720 return NewVSel;
17721
17722 return SDValue();
17723}
17724
17725SDValue DAGCombiner::visitFCEIL(SDNode *N) {
17726 SDValue N0 = N->getOperand(Num: 0);
17727 EVT VT = N->getValueType(ResNo: 0);
17728
17729 // fold (fceil c1) -> fceil(c1)
17730 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17731 return DAG.getNode(Opcode: ISD::FCEIL, DL: SDLoc(N), VT, Operand: N0);
17732
17733 return SDValue();
17734}
17735
17736SDValue DAGCombiner::visitFTRUNC(SDNode *N) {
17737 SDValue N0 = N->getOperand(Num: 0);
17738 EVT VT = N->getValueType(ResNo: 0);
17739
17740 // fold (ftrunc c1) -> ftrunc(c1)
17741 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17742 return DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(N), VT, Operand: N0);
17743
17744 // fold ftrunc (known rounded int x) -> x
17745 // ftrunc is part of the fptosi/fptoui expansion on some targets, so it is
17746 // likely to be generated when extracting an integer from a rounded floating-point value.
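  // E.g. (ftrunc (ffloor X)) -> (ffloor X), since ffloor already produces an
  // integral value that ftrunc leaves unchanged.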
17747 switch (N0.getOpcode()) {
17748 default: break;
17749 case ISD::FRINT:
17750 case ISD::FTRUNC:
17751 case ISD::FNEARBYINT:
17752 case ISD::FROUNDEVEN:
17753 case ISD::FFLOOR:
17754 case ISD::FCEIL:
17755 return N0;
17756 }
17757
17758 return SDValue();
17759}
17760
17761SDValue DAGCombiner::visitFFREXP(SDNode *N) {
17762 SDValue N0 = N->getOperand(Num: 0);
17763
17764 // fold (ffrexp c1) -> ffrexp(c1)
17765 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17766 return DAG.getNode(Opcode: ISD::FFREXP, DL: SDLoc(N), VTList: N->getVTList(), N: N0);
17767 return SDValue();
17768}
17769
17770SDValue DAGCombiner::visitFFLOOR(SDNode *N) {
17771 SDValue N0 = N->getOperand(Num: 0);
17772 EVT VT = N->getValueType(ResNo: 0);
17773
17774 // fold (ffloor c1) -> ffloor(c1)
17775 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17776 return DAG.getNode(Opcode: ISD::FFLOOR, DL: SDLoc(N), VT, Operand: N0);
17777
17778 return SDValue();
17779}
17780
17781SDValue DAGCombiner::visitFNEG(SDNode *N) {
17782 SDValue N0 = N->getOperand(Num: 0);
17783 EVT VT = N->getValueType(ResNo: 0);
17784 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17785
17786 // Constant fold FNEG.
17787 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17788 return DAG.getNode(Opcode: ISD::FNEG, DL: SDLoc(N), VT, Operand: N0);
17789
17790 if (SDValue NegN0 =
17791 TLI.getNegatedExpression(Op: N0, DAG, LegalOps: LegalOperations, OptForSize: ForCodeSize))
17792 return NegN0;
17793
17794 // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0
17795 // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't
17796 // know it was called from a context with an nsz flag if the input fsub does
17797 // not.
17798 if (N0.getOpcode() == ISD::FSUB &&
17799 (DAG.getTarget().Options.NoSignedZerosFPMath ||
17800 N->getFlags().hasNoSignedZeros()) && N0.hasOneUse()) {
17801 return DAG.getNode(Opcode: ISD::FSUB, DL: SDLoc(N), VT, N1: N0.getOperand(i: 1),
17802 N2: N0.getOperand(i: 0));
17803 }
17804
17805 if (SDValue Cast = foldSignChangeInBitcast(N))
17806 return Cast;
17807
17808 return SDValue();
17809}
17810
17811SDValue DAGCombiner::visitFMinMax(SDNode *N) {
17812 SDValue N0 = N->getOperand(Num: 0);
17813 SDValue N1 = N->getOperand(Num: 1);
17814 EVT VT = N->getValueType(ResNo: 0);
17815 const SDNodeFlags Flags = N->getFlags();
17816 unsigned Opc = N->getOpcode();
17817 bool PropagatesNaN = Opc == ISD::FMINIMUM || Opc == ISD::FMAXIMUM;
17818 bool IsMin = Opc == ISD::FMINNUM || Opc == ISD::FMINIMUM;
17819 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
17820
17821 // Constant fold.
17822 if (SDValue C = DAG.FoldConstantArithmetic(Opcode: Opc, DL: SDLoc(N), VT, Ops: {N0, N1}))
17823 return C;
17824
17825 // Canonicalize to constant on RHS.
17826 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0) &&
17827 !DAG.isConstantFPBuildVectorOrConstantFP(N: N1))
17828 return DAG.getNode(Opcode: N->getOpcode(), DL: SDLoc(N), VT, N1, N2: N0);
17829
17830 if (const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N: N1)) {
17831 const APFloat &AF = N1CFP->getValueAPF();
17832
17833 // minnum(X, nan) -> X
17834 // maxnum(X, nan) -> X
17835 // minimum(X, nan) -> nan
17836 // maximum(X, nan) -> nan
17837 if (AF.isNaN())
17838 return PropagatesNaN ? N->getOperand(Num: 1) : N->getOperand(Num: 0);
17839
17840 // In the following folds, inf can be replaced with the largest finite
17841 // float, if the ninf flag is set.
17842 if (AF.isInfinity() || (Flags.hasNoInfs() && AF.isLargest())) {
17843 // minnum(X, -inf) -> -inf
17844 // maxnum(X, +inf) -> +inf
17845 // minimum(X, -inf) -> -inf if nnan
17846 // maximum(X, +inf) -> +inf if nnan
17847 if (IsMin == AF.isNegative() && (!PropagatesNaN || Flags.hasNoNaNs()))
17848 return N->getOperand(Num: 1);
17849
17850 // minnum(X, +inf) -> X if nnan
17851 // maxnum(X, -inf) -> X if nnan
17852 // minimum(X, +inf) -> X
17853 // maximum(X, -inf) -> X
17854 if (IsMin != AF.isNegative() && (PropagatesNaN || Flags.hasNoNaNs()))
17855 return N->getOperand(Num: 0);
17856 }
17857 }
17858
17859 if (SDValue SD = reassociateReduction(
17860 RedOpc: PropagatesNaN
17861 ? (IsMin ? ISD::VECREDUCE_FMINIMUM : ISD::VECREDUCE_FMAXIMUM)
17862 : (IsMin ? ISD::VECREDUCE_FMIN : ISD::VECREDUCE_FMAX),
17863 Opc, DL: SDLoc(N), VT, N0, N1, Flags))
17864 return SD;
17865
17866 return SDValue();
17867}
17868
17869SDValue DAGCombiner::visitFABS(SDNode *N) {
17870 SDValue N0 = N->getOperand(Num: 0);
17871 EVT VT = N->getValueType(ResNo: 0);
17872
17873 // fold (fabs c1) -> fabs(c1)
17874 if (DAG.isConstantFPBuildVectorOrConstantFP(N: N0))
17875 return DAG.getNode(Opcode: ISD::FABS, DL: SDLoc(N), VT, Operand: N0);
17876
17877 // fold (fabs (fabs x)) -> (fabs x)
17878 if (N0.getOpcode() == ISD::FABS)
17879 return N->getOperand(Num: 0);
17880
17881 // fold (fabs (fneg x)) -> (fabs x)
17882 // fold (fabs (fcopysign x, y)) -> (fabs x)
17883 if (N0.getOpcode() == ISD::FNEG || N0.getOpcode() == ISD::FCOPYSIGN)
17884 return DAG.getNode(Opcode: ISD::FABS, DL: SDLoc(N), VT, Operand: N0.getOperand(i: 0));
17885
17886 if (SDValue Cast = foldSignChangeInBitcast(N))
17887 return Cast;
17888
17889 return SDValue();
17890}
17891
17892SDValue DAGCombiner::visitBRCOND(SDNode *N) {
17893 SDValue Chain = N->getOperand(Num: 0);
17894 SDValue N1 = N->getOperand(Num: 1);
17895 SDValue N2 = N->getOperand(Num: 2);
17896
17897 // BRCOND(FREEZE(cond)) is equivalent to BRCOND(cond) (both are
17898 // nondeterministic jumps).
17899 if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
17900 return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
17901 N1->getOperand(0), N2);
17902 }
17903
17904 // Variant of the previous fold where there is a SETCC in between:
17905 // BRCOND(SETCC(FREEZE(X), CONST, Cond))
17906 // =>
17907 // BRCOND(FREEZE(SETCC(X, CONST, Cond)))
17908 // =>
17909 // BRCOND(SETCC(X, CONST, Cond))
17910 // This is correct if FREEZE(X) has one use and SETCC(FREEZE(X), CONST, Cond)
17911 // isn't equivalent to true or false.
17912 // For example, SETCC(FREEZE(X), -128, SETULT) cannot be folded to
17913 // FREEZE(SETCC(X, -128, SETULT)) because X can be poison.
17914 if (N1->getOpcode() == ISD::SETCC && N1.hasOneUse()) {
17915 SDValue S0 = N1->getOperand(Num: 0), S1 = N1->getOperand(Num: 1);
17916 ISD::CondCode Cond = cast<CondCodeSDNode>(Val: N1->getOperand(Num: 2))->get();
17917 ConstantSDNode *S0C = dyn_cast<ConstantSDNode>(Val&: S0);
17918 ConstantSDNode *S1C = dyn_cast<ConstantSDNode>(Val&: S1);
17919 bool Updated = false;
17920
17921 // Is 'X Cond C' always true or false?
17922 auto IsAlwaysTrueOrFalse = [](ISD::CondCode Cond, ConstantSDNode *C) {
17923 bool False = (Cond == ISD::SETULT && C->isZero()) ||
17924 (Cond == ISD::SETLT && C->isMinSignedValue()) ||
17925 (Cond == ISD::SETUGT && C->isAllOnes()) ||
17926 (Cond == ISD::SETGT && C->isMaxSignedValue());
17927 bool True = (Cond == ISD::SETULE && C->isAllOnes()) ||
17928 (Cond == ISD::SETLE && C->isMaxSignedValue()) ||
17929 (Cond == ISD::SETUGE && C->isZero()) ||
17930 (Cond == ISD::SETGE && C->isMinSignedValue());
17931 return True || False;
17932 };
17933
17934 if (S0->getOpcode() == ISD::FREEZE && S0.hasOneUse() && S1C) {
17935 if (!IsAlwaysTrueOrFalse(Cond, S1C)) {
17936 S0 = S0->getOperand(Num: 0);
17937 Updated = true;
17938 }
17939 }
17940 if (S1->getOpcode() == ISD::FREEZE && S1.hasOneUse() && S0C) {
17941 if (!IsAlwaysTrueOrFalse(ISD::getSetCCSwappedOperands(Operation: Cond), S0C)) {
17942 S1 = S1->getOperand(Num: 0);
17943 Updated = true;
17944 }
17945 }
17946
17947 if (Updated)
17948 return DAG.getNode(
17949 ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
17950 DAG.getSetCC(SDLoc(N1), N1->getValueType(0), S0, S1, Cond), N2);
17951 }
17952
17953 // If N is a constant we could fold this into a fallthrough or unconditional
17954 // branch. However that doesn't happen very often in normal code, because
17955 // Instcombine/SimplifyCFG should have handled the available opportunities.
17956 // If we did this folding here, it would be necessary to update the
17957 // MachineBasicBlock CFG, which is awkward.
17958
17959 // fold a brcond with a setcc condition into a BR_CC node if BR_CC is legal
17960 // on the target.
17961 if (N1.getOpcode() == ISD::SETCC &&
17962 TLI.isOperationLegalOrCustom(Op: ISD::BR_CC,
17963 VT: N1.getOperand(i: 0).getValueType())) {
17964 return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
17965 Chain, N1.getOperand(2),
17966 N1.getOperand(0), N1.getOperand(1), N2);
17967 }
17968
17969 if (N1.hasOneUse()) {
17970 // rebuildSetCC calls visitXor which may change the Chain when there is a
17971 // STRICT_FSETCC/STRICT_FSETCCS involved. Use a handle to track changes.
17972 HandleSDNode ChainHandle(Chain);
17973 if (SDValue NewN1 = rebuildSetCC(N1))
17974 return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
17975 ChainHandle.getValue(), NewN1, N2);
17976 }
17977
17978 return SDValue();
17979}
17980
17981SDValue DAGCombiner::rebuildSetCC(SDValue N) {
17982 if (N.getOpcode() == ISD::SRL ||
17983 (N.getOpcode() == ISD::TRUNCATE &&
17984 (N.getOperand(i: 0).hasOneUse() &&
17985 N.getOperand(i: 0).getOpcode() == ISD::SRL))) {
17986 // Look past the truncate.
17987 if (N.getOpcode() == ISD::TRUNCATE)
17988 N = N.getOperand(i: 0);
17989
17990 // Match this pattern so that we can generate simpler code:
17991 //
17992 // %a = ...
17993 // %b = and i32 %a, 2
17994 // %c = srl i32 %b, 1
17995 // brcond i32 %c ...
17996 //
17997 // into
17998 //
17999 // %a = ...
18000 // %b = and i32 %a, 2
18001 // %c = setcc ne %b, 0
18002 // brcond %c ...
18003 //
18004 // This applies only when the AND constant value has one bit set and the
18005 // SRL constant is equal to the log2 of the AND constant. The back-end is
18006 // smart enough to convert the result into a TEST/JMP sequence.
18007 SDValue Op0 = N.getOperand(i: 0);
18008 SDValue Op1 = N.getOperand(i: 1);
18009
18010 if (Op0.getOpcode() == ISD::AND && Op1.getOpcode() == ISD::Constant) {
18011 SDValue AndOp1 = Op0.getOperand(i: 1);
18012
18013 if (AndOp1.getOpcode() == ISD::Constant) {
18014 const APInt &AndConst = AndOp1->getAsAPIntVal();
18015
18016 if (AndConst.isPowerOf2() &&
18017 Op1->getAsAPIntVal() == AndConst.logBase2()) {
18018 SDLoc DL(N);
18019 return DAG.getSetCC(DL, VT: getSetCCResultType(VT: Op0.getValueType()),
18020 LHS: Op0, RHS: DAG.getConstant(Val: 0, DL, VT: Op0.getValueType()),
18021 Cond: ISD::SETNE);
18022 }
18023 }
18024 }
18025 }
18026
18027 // Transform (brcond (xor x, y)) -> (brcond (setcc, x, y, ne))
18028 // Transform (brcond (xor (xor x, y), -1)) -> (brcond (setcc, x, y, eq))
18029 if (N.getOpcode() == ISD::XOR) {
18030 // Because we may call this on a speculatively constructed
18031 // SimplifiedSetCC Node, we need to simplify this node first.
18032 // Ideally this should be folded into SimplifySetCC and not
18033 // here. For now, grab a handle to N so we don't lose it from
18034 // replacements internal to the visit.
18035 HandleSDNode XORHandle(N);
18036 while (N.getOpcode() == ISD::XOR) {
18037 SDValue Tmp = visitXOR(N: N.getNode());
18038 // No simplification done.
18039 if (!Tmp.getNode())
18040 break;
18041 // Returning N is a form of in-visit replacement that may invalidate
18042 // N. Grab the value from the handle.
18043 if (Tmp.getNode() == N.getNode())
18044 N = XORHandle.getValue();
18045 else // Node simplified. Try simplifying again.
18046 N = Tmp;
18047 }
18048
18049 if (N.getOpcode() != ISD::XOR)
18050 return N;
18051
18052 SDValue Op0 = N->getOperand(Num: 0);
18053 SDValue Op1 = N->getOperand(Num: 1);
18054
18055 if (Op0.getOpcode() != ISD::SETCC && Op1.getOpcode() != ISD::SETCC) {
18056 bool Equal = false;
18057 // (brcond (xor (xor x, y), -1)) -> (brcond (setcc x, y, eq))
18058 if (isBitwiseNot(N) && Op0.hasOneUse() && Op0.getOpcode() == ISD::XOR &&
18059 Op0.getValueType() == MVT::i1) {
18060 N = Op0;
18061 Op0 = N->getOperand(Num: 0);
18062 Op1 = N->getOperand(Num: 1);
18063 Equal = true;
18064 }
18065
18066 EVT SetCCVT = N.getValueType();
18067 if (LegalTypes)
18068 SetCCVT = getSetCCResultType(VT: SetCCVT);
18069 // Replace the uses of XOR with SETCC
18070 return DAG.getSetCC(DL: SDLoc(N), VT: SetCCVT, LHS: Op0, RHS: Op1,
18071 Cond: Equal ? ISD::SETEQ : ISD::SETNE);
18072 }
18073 }
18074
18075 return SDValue();
18076}
18077
18078// Operand List for BR_CC: Chain, CondCC, CondLHS, CondRHS, DestBB.
18079//
18080SDValue DAGCombiner::visitBR_CC(SDNode *N) {
18081 CondCodeSDNode *CC = cast<CondCodeSDNode>(Val: N->getOperand(Num: 1));
18082 SDValue CondLHS = N->getOperand(Num: 2), CondRHS = N->getOperand(Num: 3);
18083
18084 // If N is a constant we could fold this into a fallthrough or unconditional
18085 // branch. However that doesn't happen very often in normal code, because
18086 // Instcombine/SimplifyCFG should have handled the available opportunities.
18087 // If we did this folding here, it would be necessary to update the
18088 // MachineBasicBlock CFG, which is awkward.
18089
18090 // Use SimplifySetCC to simplify SETCC's.
18091 SDValue Simp = SimplifySetCC(VT: getSetCCResultType(VT: CondLHS.getValueType()),
18092 N0: CondLHS, N1: CondRHS, Cond: CC->get(), DL: SDLoc(N),
18093 foldBooleans: false);
18094 if (Simp.getNode()) AddToWorklist(N: Simp.getNode());
18095
18096 // fold to a simpler setcc
18097 if (Simp.getNode() && Simp.getOpcode() == ISD::SETCC)
18098 return DAG.getNode(ISD::BR_CC, SDLoc(N), MVT::Other,
18099 N->getOperand(0), Simp.getOperand(2),
18100 Simp.getOperand(0), Simp.getOperand(1),
18101 N->getOperand(4));
18102
18103 return SDValue();
18104}
18105
18106static bool getCombineLoadStoreParts(SDNode *N, unsigned Inc, unsigned Dec,
18107 bool &IsLoad, bool &IsMasked, SDValue &Ptr,
18108 const TargetLowering &TLI) {
18109 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(Val: N)) {
18110 if (LD->isIndexed())
18111 return false;
18112 EVT VT = LD->getMemoryVT();
18113 if (!TLI.isIndexedLoadLegal(IdxMode: Inc, VT) && !TLI.isIndexedLoadLegal(IdxMode: Dec, VT))
18114 return false;
18115 Ptr = LD->getBasePtr();
18116 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(Val: N)) {
18117 if (ST->isIndexed())
18118 return false;
18119 EVT VT = ST->getMemoryVT();
18120 if (!TLI.isIndexedStoreLegal(IdxMode: Inc, VT) && !TLI.isIndexedStoreLegal(IdxMode: Dec, VT))
18121 return false;
18122 Ptr = ST->getBasePtr();
18123 IsLoad = false;
18124 } else if (MaskedLoadSDNode *LD = dyn_cast<MaskedLoadSDNode>(Val: N)) {
18125 if (LD->isIndexed())
18126 return false;
18127 EVT VT = LD->getMemoryVT();
18128 if (!TLI.isIndexedMaskedLoadLegal(IdxMode: Inc, VT) &&
18129 !TLI.isIndexedMaskedLoadLegal(IdxMode: Dec, VT))
18130 return false;
18131 Ptr = LD->getBasePtr();
18132 IsMasked = true;
18133 } else if (MaskedStoreSDNode *ST = dyn_cast<MaskedStoreSDNode>(Val: N)) {
18134 if (ST->isIndexed())
18135 return false;
18136 EVT VT = ST->getMemoryVT();
18137 if (!TLI.isIndexedMaskedStoreLegal(IdxMode: Inc, VT) &&
18138 !TLI.isIndexedMaskedStoreLegal(IdxMode: Dec, VT))
18139 return false;
18140 Ptr = ST->getBasePtr();
18141 IsLoad = false;
18142 IsMasked = true;
18143 } else {
18144 return false;
18145 }
18146 return true;
18147}
18148
18149/// Try turning a load/store into a pre-indexed load/store when the base
18150/// pointer is an add or subtract and it has other uses besides the load/store.
18151/// After the transformation, the new indexed load/store has effectively folded
18152/// the add/subtract in and all of its other uses are redirected to the
18153/// new load/store.
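///
/// Illustrative shape of the rewrite (when the target supports it):
///   t = add BasePtr, Offset    ; t has other users
///   load [t]
/// becomes a pre-indexed load of [BasePtr + Offset] whose address result
/// replaces t for those other users.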
18154bool DAGCombiner::CombineToPreIndexedLoadStore(SDNode *N) {
18155 if (Level < AfterLegalizeDAG)
18156 return false;
18157
18158 bool IsLoad = true;
18159 bool IsMasked = false;
18160 SDValue Ptr;
18161 if (!getCombineLoadStoreParts(N, Inc: ISD::PRE_INC, Dec: ISD::PRE_DEC, IsLoad, IsMasked,
18162 Ptr, TLI))
18163 return false;
18164
18165 // If the pointer is not an add/sub, or if it doesn't have multiple uses, bail
18166 // out. There is no reason to make this a preinc/predec.
18167 if ((Ptr.getOpcode() != ISD::ADD && Ptr.getOpcode() != ISD::SUB) ||
18168 Ptr->hasOneUse())
18169 return false;
18170
18171 // Ask the target to do addressing mode selection.
18172 SDValue BasePtr;
18173 SDValue Offset;
18174 ISD::MemIndexedMode AM = ISD::UNINDEXED;
18175 if (!TLI.getPreIndexedAddressParts(N, BasePtr, Offset, AM, DAG))
18176 return false;
18177
18178 // Backends without true r+i pre-indexed forms may need to pass a
18179 // constant base with a variable offset so that constant coercion
18180 // will work with the patterns in canonical form.
18181 bool Swapped = false;
18182 if (isa<ConstantSDNode>(Val: BasePtr)) {
18183 std::swap(a&: BasePtr, b&: Offset);
18184 Swapped = true;
18185 }
18186
18187 // Don't create an indexed load / store with zero offset.
18188 if (isNullConstant(V: Offset))
18189 return false;
18190
18191 // Try turning it into a pre-indexed load / store except when:
18192 // 1) The new base ptr is a frame index.
18193 // 2) If N is a store and the new base ptr is either the same as or is a
18194 // predecessor of the value being stored.
18195 // 3) Another use of old base ptr is a predecessor of N. If ptr is folded
18196 // that would create a cycle.
18197 // 4) All uses are load / store ops that use it as old base ptr.
18198
18199 // Check #1. Preinc'ing a frame index would require copying the stack pointer
18200 // (plus the implicit offset) to a register to preinc anyway.
18201 if (isa<FrameIndexSDNode>(Val: BasePtr) || isa<RegisterSDNode>(Val: BasePtr))
18202 return false;
18203
18204 // Check #2.
18205 if (!IsLoad) {
18206 SDValue Val = IsMasked ? cast<MaskedStoreSDNode>(Val: N)->getValue()
18207 : cast<StoreSDNode>(Val: N)->getValue();
18208
18209 // Would require a copy.
18210 if (Val == BasePtr)
18211 return false;
18212
18213 // Would create a cycle.
18214 if (Val == Ptr || Ptr->isPredecessorOf(N: Val.getNode()))
18215 return false;
18216 }
18217
18218 // Caches for hasPredecessorHelper.
18219 SmallPtrSet<const SDNode *, 32> Visited;
18220 SmallVector<const SDNode *, 16> Worklist;
18221 Worklist.push_back(Elt: N);
18222
18223 // If the offset is a constant, there may be other adds of constants that
18224 // can be folded with this one. We should do this to avoid having to keep
18225 // a copy of the original base pointer.
18226 SmallVector<SDNode *, 16> OtherUses;
18227 constexpr unsigned int MaxSteps = 8192;
18228 if (isa<ConstantSDNode>(Val: Offset))
18229 for (SDNode::use_iterator UI = BasePtr->use_begin(),
18230 UE = BasePtr->use_end();
18231 UI != UE; ++UI) {
18232 SDUse &Use = UI.getUse();
18233 // Skip the use that is Ptr and uses of other results from BasePtr's
18234 // node (important for nodes that return multiple results).
18235 if (Use.getUser() == Ptr.getNode() || Use != BasePtr)
18236 continue;
18237
18238 if (SDNode::hasPredecessorHelper(N: Use.getUser(), Visited, Worklist,
18239 MaxSteps))
18240 continue;
18241
18242 if (Use.getUser()->getOpcode() != ISD::ADD &&
18243 Use.getUser()->getOpcode() != ISD::SUB) {
18244 OtherUses.clear();
18245 break;
18246 }
18247
18248 SDValue Op1 = Use.getUser()->getOperand(Num: (UI.getOperandNo() + 1) & 1);
18249 if (!isa<ConstantSDNode>(Val: Op1)) {
18250 OtherUses.clear();
18251 break;
18252 }
18253
18254 // FIXME: In some cases, we can be smarter about this.
18255 if (Op1.getValueType() != Offset.getValueType()) {
18256 OtherUses.clear();
18257 break;
18258 }
18259
18260 OtherUses.push_back(Elt: Use.getUser());
18261 }
18262
18263 if (Swapped)
18264 std::swap(a&: BasePtr, b&: Offset);
18265
18266 // Now check for #3 and #4.
18267 bool RealUse = false;
18268
18269 for (SDNode *Use : Ptr->uses()) {
18270 if (Use == N)
18271 continue;
18272 if (SDNode::hasPredecessorHelper(N: Use, Visited, Worklist, MaxSteps))
18273 return false;
18274
18275 // If Ptr may be folded in addressing mode of other use, then it's
18276 // not profitable to do this transformation.
18277 if (!canFoldInAddressingMode(N: Ptr.getNode(), Use, DAG, TLI))
18278 RealUse = true;
18279 }
18280
18281 if (!RealUse)
18282 return false;
18283
18284 SDValue Result;
18285 if (!IsMasked) {
18286 if (IsLoad)
18287 Result = DAG.getIndexedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr, Offset, AM);
18288 else
18289 Result =
18290 DAG.getIndexedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr, Offset, AM);
18291 } else {
18292 if (IsLoad)
18293 Result = DAG.getIndexedMaskedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
18294 Offset, AM);
18295 else
18296 Result = DAG.getIndexedMaskedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
18297 Offset, AM);
18298 }
18299 ++PreIndexedNodes;
18300 ++NodesCombined;
18301 LLVM_DEBUG(dbgs() << "\nReplacing.4 "; N->dump(&DAG); dbgs() << "\nWith: ";
18302 Result.dump(&DAG); dbgs() << '\n');
18303 WorklistRemover DeadNodes(*this);
18304 if (IsLoad) {
18305 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 0));
18306 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Result.getValue(R: 2));
18307 } else {
18308 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 1));
18309 }
18310
18311 // Finally, since the node is now dead, remove it from the graph.
18312 deleteAndRecombine(N);
18313
18314 if (Swapped)
18315 std::swap(a&: BasePtr, b&: Offset);
18316
18317 // Replace other uses of BasePtr that can be updated to use Ptr
18318 for (unsigned i = 0, e = OtherUses.size(); i != e; ++i) {
18319 unsigned OffsetIdx = 1;
18320 if (OtherUses[i]->getOperand(Num: OffsetIdx).getNode() == BasePtr.getNode())
18321 OffsetIdx = 0;
18322 assert(OtherUses[i]->getOperand(!OffsetIdx).getNode() ==
18323 BasePtr.getNode() && "Expected BasePtr operand");
18324
18325 // We need to replace ptr0 in the following expression:
18326 // x0 * offset0 + y0 * ptr0 = t0
18327 // knowing that
18328 // x1 * offset1 + y1 * ptr0 = t1 (the indexed load/store)
18329 //
18330 // where x0, x1, y0 and y1 in {-1, 1} are given by the types of the
18331 // indexed load/store and the expression that needs to be re-written.
18332 //
18333 // Therefore, we have:
18334 // t0 = (x0 * offset0 - x1 * y0 * y1 *offset1) + (y0 * y1) * t1
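    // E.g. when both the indexed access and the other use are ADDs
    // (x0 = y0 = x1 = y1 = 1), this simplifies to t0 = (offset0 - offset1) + t1.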
18335
18336 auto *CN = cast<ConstantSDNode>(Val: OtherUses[i]->getOperand(Num: OffsetIdx));
18337 const APInt &Offset0 = CN->getAPIntValue();
18338 const APInt &Offset1 = Offset->getAsAPIntVal();
18339 int X0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 1) ? -1 : 1;
18340 int Y0 = (OtherUses[i]->getOpcode() == ISD::SUB && OffsetIdx == 0) ? -1 : 1;
18341 int X1 = (AM == ISD::PRE_DEC && !Swapped) ? -1 : 1;
18342 int Y1 = (AM == ISD::PRE_DEC && Swapped) ? -1 : 1;
18343
18344 unsigned Opcode = (Y0 * Y1 < 0) ? ISD::SUB : ISD::ADD;
18345
18346 APInt CNV = Offset0;
18347 if (X0 < 0) CNV = -CNV;
18348 if (X1 * Y0 * Y1 < 0) CNV = CNV + Offset1;
18349 else CNV = CNV - Offset1;
18350
18351 SDLoc DL(OtherUses[i]);
18352
18353 // We can now generate the new expression.
18354 SDValue NewOp1 = DAG.getConstant(Val: CNV, DL, VT: CN->getValueType(ResNo: 0));
18355 SDValue NewOp2 = Result.getValue(R: IsLoad ? 1 : 0);
18356
18357 SDValue NewUse = DAG.getNode(Opcode,
18358 DL,
18359 VT: OtherUses[i]->getValueType(ResNo: 0), N1: NewOp1, N2: NewOp2);
18360 DAG.ReplaceAllUsesOfValueWith(From: SDValue(OtherUses[i], 0), To: NewUse);
18361 deleteAndRecombine(N: OtherUses[i]);
18362 }
18363
18364 // Replace the uses of Ptr with uses of the updated base value.
18365 DAG.ReplaceAllUsesOfValueWith(From: Ptr, To: Result.getValue(R: IsLoad ? 1 : 0));
18366 deleteAndRecombine(N: Ptr.getNode());
18367 AddToWorklist(N: Result.getNode());
18368
18369 return true;
18370}
18371
18372static bool shouldCombineToPostInc(SDNode *N, SDValue Ptr, SDNode *PtrUse,
18373 SDValue &BasePtr, SDValue &Offset,
18374 ISD::MemIndexedMode &AM,
18375 SelectionDAG &DAG,
18376 const TargetLowering &TLI) {
18377 if (PtrUse == N ||
18378 (PtrUse->getOpcode() != ISD::ADD && PtrUse->getOpcode() != ISD::SUB))
18379 return false;
18380
18381 if (!TLI.getPostIndexedAddressParts(N, PtrUse, BasePtr, Offset, AM, DAG))
18382 return false;
18383
18384 // Don't create an indexed load / store with zero offset.
18385 if (isNullConstant(V: Offset))
18386 return false;
18387
18388 if (isa<FrameIndexSDNode>(Val: BasePtr) || isa<RegisterSDNode>(Val: BasePtr))
18389 return false;
18390
18391 SmallPtrSet<const SDNode *, 32> Visited;
18392 for (SDNode *Use : BasePtr->uses()) {
18393 if (Use == Ptr.getNode())
18394 continue;
18395
18396 // Bail out if there's a later user which could perform the index instead.
18397 if (isa<MemSDNode>(Val: Use)) {
18398 bool IsLoad = true;
18399 bool IsMasked = false;
18400 SDValue OtherPtr;
18401 if (getCombineLoadStoreParts(N: Use, Inc: ISD::POST_INC, Dec: ISD::POST_DEC, IsLoad,
18402 IsMasked, Ptr&: OtherPtr, TLI)) {
18403 SmallVector<const SDNode *, 2> Worklist;
18404 Worklist.push_back(Elt: Use);
18405 if (SDNode::hasPredecessorHelper(N, Visited, Worklist))
18406 return false;
18407 }
18408 }
18409
18410 // If all the uses are load / store addresses, then don't do the
18411 // transformation.
18412 if (Use->getOpcode() == ISD::ADD || Use->getOpcode() == ISD::SUB) {
18413 for (SDNode *UseUse : Use->uses())
18414 if (canFoldInAddressingMode(N: Use, Use: UseUse, DAG, TLI))
18415 return false;
18416 }
18417 }
18418 return true;
18419}
18420
18421static SDNode *getPostIndexedLoadStoreOp(SDNode *N, bool &IsLoad,
18422 bool &IsMasked, SDValue &Ptr,
18423 SDValue &BasePtr, SDValue &Offset,
18424 ISD::MemIndexedMode &AM,
18425 SelectionDAG &DAG,
18426 const TargetLowering &TLI) {
18427 if (!getCombineLoadStoreParts(N, Inc: ISD::POST_INC, Dec: ISD::POST_DEC, IsLoad,
18428 IsMasked, Ptr, TLI) ||
18429 Ptr->hasOneUse())
18430 return nullptr;
18431
18432 // Try turning it into a post-indexed load / store except when
18433 // 1) All uses are load / store ops that use it as base ptr (and
18434 // it may be folded into the addressing mode).
18435 // 2) Op must be independent of N, i.e. Op is neither a predecessor
18436 // nor a successor of N. Otherwise, if Op is folded that would
18437 // create a cycle.
18438 for (SDNode *Op : Ptr->uses()) {
18439 // Check for #1.
18440 if (!shouldCombineToPostInc(N, Ptr, PtrUse: Op, BasePtr, Offset, AM, DAG, TLI))
18441 continue;
18442
18443 // Check for #2.
18444 SmallPtrSet<const SDNode *, 32> Visited;
18445 SmallVector<const SDNode *, 8> Worklist;
18446 constexpr unsigned int MaxSteps = 8192;
18447 // Ptr is predecessor to both N and Op.
18448 Visited.insert(Ptr: Ptr.getNode());
18449 Worklist.push_back(Elt: N);
18450 Worklist.push_back(Elt: Op);
18451 if (!SDNode::hasPredecessorHelper(N, Visited, Worklist, MaxSteps) &&
18452 !SDNode::hasPredecessorHelper(N: Op, Visited, Worklist, MaxSteps))
18453 return Op;
18454 }
18455 return nullptr;
18456}
18457
18458/// Try to combine a load/store with an add/sub of the base pointer node into a
18459/// post-indexed load/store. The transformation effectively folds the add/subtract
18460/// into the new indexed load/store, and all other uses of the add/sub are
18461/// redirected to the new load/store.
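///
/// Illustrative shape of the rewrite (when the target supports it):
///   load [Ptr]
///   t = add Ptr, Offset        ; t has other users
/// becomes a post-indexed load of [Ptr] whose address result replaces t.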
18462bool DAGCombiner::CombineToPostIndexedLoadStore(SDNode *N) {
18463 if (Level < AfterLegalizeDAG)
18464 return false;
18465
18466 bool IsLoad = true;
18467 bool IsMasked = false;
18468 SDValue Ptr;
18469 SDValue BasePtr;
18470 SDValue Offset;
18471 ISD::MemIndexedMode AM = ISD::UNINDEXED;
18472 SDNode *Op = getPostIndexedLoadStoreOp(N, IsLoad, IsMasked, Ptr, BasePtr,
18473 Offset, AM, DAG, TLI);
18474 if (!Op)
18475 return false;
18476
18477 SDValue Result;
18478 if (!IsMasked)
18479 Result = IsLoad ? DAG.getIndexedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N), Base: BasePtr,
18480 Offset, AM)
18481 : DAG.getIndexedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N),
18482 Base: BasePtr, Offset, AM);
18483 else
18484 Result = IsLoad ? DAG.getIndexedMaskedLoad(OrigLoad: SDValue(N, 0), dl: SDLoc(N),
18485 Base: BasePtr, Offset, AM)
18486 : DAG.getIndexedMaskedStore(OrigStore: SDValue(N, 0), dl: SDLoc(N),
18487 Base: BasePtr, Offset, AM);
18488 ++PostIndexedNodes;
18489 ++NodesCombined;
18490 LLVM_DEBUG(dbgs() << "\nReplacing.5 "; N->dump(&DAG); dbgs() << "\nWith: ";
18491 Result.dump(&DAG); dbgs() << '\n');
18492 WorklistRemover DeadNodes(*this);
18493 if (IsLoad) {
18494 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 0));
18495 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Result.getValue(R: 2));
18496 } else {
18497 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Result.getValue(R: 1));
18498 }
18499
18500 // Finally, since the node is now dead, remove it from the graph.
18501 deleteAndRecombine(N);
18502
18503 // Replace the uses of Op with uses of the updated base value.
18504 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Op, 0),
18505 To: Result.getValue(R: IsLoad ? 1 : 0));
18506 deleteAndRecombine(N: Op);
18507 return true;
18508}
18509
18510/// Return the base-pointer arithmetic from an indexed \p LD.
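/// E.g. for the PRE_INC/POST_INC modes this returns (add BasePtr, Inc); for
/// the *_DEC modes it returns (sub BasePtr, Inc).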
18511SDValue DAGCombiner::SplitIndexingFromLoad(LoadSDNode *LD) {
18512 ISD::MemIndexedMode AM = LD->getAddressingMode();
18513 assert(AM != ISD::UNINDEXED);
18514 SDValue BP = LD->getOperand(Num: 1);
18515 SDValue Inc = LD->getOperand(Num: 2);
18516
18517 // Some backends use TargetConstants for load offsets, but don't expect
18518 // TargetConstants in general ADD nodes. We can convert these constants into
18519 // regular Constants (if the constant is not opaque).
18520 assert((Inc.getOpcode() != ISD::TargetConstant ||
18521 !cast<ConstantSDNode>(Inc)->isOpaque()) &&
18522 "Cannot split out indexing using opaque target constants");
18523 if (Inc.getOpcode() == ISD::TargetConstant) {
18524 ConstantSDNode *ConstInc = cast<ConstantSDNode>(Val&: Inc);
18525 Inc = DAG.getConstant(Val: *ConstInc->getConstantIntValue(), DL: SDLoc(Inc),
18526 VT: ConstInc->getValueType(ResNo: 0));
18527 }
18528
18529 unsigned Opc =
18530 (AM == ISD::PRE_INC || AM == ISD::POST_INC ? ISD::ADD : ISD::SUB);
18531 return DAG.getNode(Opcode: Opc, DL: SDLoc(LD), VT: BP.getSimpleValueType(), N1: BP, N2: Inc);
18532}
18533
18534static inline ElementCount numVectorEltsOrZero(EVT T) {
18535 return T.isVector() ? T.getVectorElementCount() : ElementCount::getFixed(MinVal: 0);
18536}
18537
18538bool DAGCombiner::getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val) {
18539 EVT STType = Val.getValueType();
18540 EVT STMemType = ST->getMemoryVT();
18541 if (STType == STMemType)
18542 return true;
18543 if (isTypeLegal(VT: STMemType))
18544 return false; // fail.
18545 if (STType.isFloatingPoint() && STMemType.isFloatingPoint() &&
18546 TLI.isOperationLegal(Op: ISD::FTRUNC, VT: STMemType)) {
18547 Val = DAG.getNode(Opcode: ISD::FTRUNC, DL: SDLoc(ST), VT: STMemType, Operand: Val);
18548 return true;
18549 }
18550 if (numVectorEltsOrZero(T: STType) == numVectorEltsOrZero(T: STMemType) &&
18551 STType.isInteger() && STMemType.isInteger()) {
18552 Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(ST), VT: STMemType, Operand: Val);
18553 return true;
18554 }
18555 if (STType.getSizeInBits() == STMemType.getSizeInBits()) {
18556 Val = DAG.getBitcast(VT: STMemType, V: Val);
18557 return true;
18558 }
18559 return false; // fail.
18560}
18561
18562bool DAGCombiner::extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val) {
18563 EVT LDMemType = LD->getMemoryVT();
18564 EVT LDType = LD->getValueType(ResNo: 0);
18565 assert(Val.getValueType() == LDMemType &&
18566 "Attempting to extend value of non-matching type");
18567 if (LDType == LDMemType)
18568 return true;
18569 if (LDMemType.isInteger() && LDType.isInteger()) {
18570 switch (LD->getExtensionType()) {
18571 case ISD::NON_EXTLOAD:
18572 Val = DAG.getBitcast(VT: LDType, V: Val);
18573 return true;
18574 case ISD::EXTLOAD:
18575 Val = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
18576 return true;
18577 case ISD::SEXTLOAD:
18578 Val = DAG.getNode(Opcode: ISD::SIGN_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
18579 return true;
18580 case ISD::ZEXTLOAD:
18581 Val = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(LD), VT: LDType, Operand: Val);
18582 return true;
18583 }
18584 }
18585 return false;
18586}
18587
18588StoreSDNode *DAGCombiner::getUniqueStoreFeeding(LoadSDNode *LD,
18589 int64_t &Offset) {
18590 SDValue Chain = LD->getOperand(Num: 0);
18591
18592 // Look through CALLSEQ_START.
18593 if (Chain.getOpcode() == ISD::CALLSEQ_START)
18594 Chain = Chain->getOperand(Num: 0);
18595
18596 StoreSDNode *ST = nullptr;
18597 SmallVector<SDValue, 8> Aliases;
18598 if (Chain.getOpcode() == ISD::TokenFactor) {
18599 // Look for unique store within the TokenFactor.
18600 for (SDValue Op : Chain->ops()) {
18601 StoreSDNode *Store = dyn_cast<StoreSDNode>(Val: Op.getNode());
18602 if (!Store)
18603 continue;
18604 BaseIndexOffset BasePtrLD = BaseIndexOffset::match(N: LD, DAG);
18605 BaseIndexOffset BasePtrST = BaseIndexOffset::match(N: Store, DAG);
18606 if (!BasePtrST.equalBaseIndex(Other: BasePtrLD, DAG, Off&: Offset))
18607 continue;
18608 // Make sure the store is not aliased with any nodes in TokenFactor.
18609 GatherAllAliases(N: Store, OriginalChain: Chain, Aliases);
18610 if (Aliases.empty() ||
18611 (Aliases.size() == 1 && Aliases.front().getNode() == Store))
18612 ST = Store;
18613 break;
18614 }
18615 } else {
18616 StoreSDNode *Store = dyn_cast<StoreSDNode>(Val: Chain.getNode());
18617 if (Store) {
18618 BaseIndexOffset BasePtrLD = BaseIndexOffset::match(N: LD, DAG);
18619 BaseIndexOffset BasePtrST = BaseIndexOffset::match(N: Store, DAG);
18620 if (BasePtrST.equalBaseIndex(Other: BasePtrLD, DAG, Off&: Offset))
18621 ST = Store;
18622 }
18623 }
18624
18625 return ST;
18626}
18627
18628SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
18629 if (OptLevel == CodeGenOptLevel::None || !LD->isSimple())
18630 return SDValue();
18631 SDValue Chain = LD->getOperand(Num: 0);
18632 int64_t Offset;
18633
18634 StoreSDNode *ST = getUniqueStoreFeeding(LD, Offset);
18635 // TODO: Relax this restriction for unordered atomics (see D66309)
18636 if (!ST || !ST->isSimple() || ST->getAddressSpace() != LD->getAddressSpace())
18637 return SDValue();
18638
18639 EVT LDType = LD->getValueType(ResNo: 0);
18640 EVT LDMemType = LD->getMemoryVT();
18641 EVT STMemType = ST->getMemoryVT();
18642 EVT STType = ST->getValue().getValueType();
18643
18644 // There are two cases to consider here:
18645 // 1. The store is fixed width and the load is scalable. In this case we
18646 // don't know at compile time if the store completely envelops the load
18647 // so we abandon the optimisation.
18648 // 2. The store is scalable and the load is fixed width. We could
18649 // potentially support a limited number of cases here, but there has been
18650 // no cost-benefit analysis to prove it's worth it.
18651 bool LdStScalable = LDMemType.isScalableVT();
18652 if (LdStScalable != STMemType.isScalableVT())
18653 return SDValue();
18654
18655 // If we are dealing with scalable vectors on a big endian platform the
18656 // calculation of offsets below becomes trickier, since we do not know at
18657 // compile time the absolute size of the vector. Until we've done more
18658 // analysis on big-endian platforms it seems better to bail out for now.
18659 if (LdStScalable && DAG.getDataLayout().isBigEndian())
18660 return SDValue();
18661
18662 // Normalize for endianness. After this, Offset=0 will denote that the least
18663 // significant bit in the loaded value maps to the least significant bit in
18664 // the stored value. With Offset=n (for n > 0) the loaded value starts at the
18665 // n:th least significant byte of the stored value.
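  // E.g. on a big-endian target, a 32-bit load starting at the base of a
  // 64-bit store (OrigOffset == 0) reads the 4 most significant bytes, so the
  // normalized Offset becomes (64 - 32) / 8 - 0 = 4.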
18666 int64_t OrigOffset = Offset;
18667 if (DAG.getDataLayout().isBigEndian())
18668 Offset = ((int64_t)STMemType.getStoreSizeInBits().getFixedValue() -
18669 (int64_t)LDMemType.getStoreSizeInBits().getFixedValue()) /
18670 8 -
18671 Offset;
18672
18673 // Check that the stored value covers all bits that are loaded.
18674 bool STCoversLD;
18675
18676 TypeSize LdMemSize = LDMemType.getSizeInBits();
18677 TypeSize StMemSize = STMemType.getSizeInBits();
18678 if (LdStScalable)
18679 STCoversLD = (Offset == 0) && LdMemSize == StMemSize;
18680 else
18681 STCoversLD = (Offset >= 0) && (Offset * 8 + LdMemSize.getFixedValue() <=
18682 StMemSize.getFixedValue());
18683
18684 auto ReplaceLd = [&](LoadSDNode *LD, SDValue Val, SDValue Chain) -> SDValue {
18685 if (LD->isIndexed()) {
18686 // Cannot handle opaque target constants and we must respect the user's
18687 // request not to split indexes from loads.
18688 if (!canSplitIdx(LD))
18689 return SDValue();
18690 SDValue Idx = SplitIndexingFromLoad(LD);
18691 SDValue Ops[] = {Val, Idx, Chain};
18692 return CombineTo(N: LD, To: Ops, NumTo: 3);
18693 }
18694 return CombineTo(N: LD, Res0: Val, Res1: Chain);
18695 };
18696
18697 if (!STCoversLD)
18698 return SDValue();
18699
18700 // Memory as copy space (potentially masked).
18701 if (Offset == 0 && LDType == STType && STMemType == LDMemType) {
18702 // Simple case: Direct non-truncating forwarding
18703 if (LDType.getSizeInBits() == LdMemSize)
18704 return ReplaceLd(LD, ST->getValue(), Chain);
18705 // Can we model the truncate and extension with an and mask?
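    // E.g. an i32 value stored through a truncating i16 store and reloaded by
    // an i16 zext-load back to i32 can be forwarded as (and StoredVal, 0xFFFF).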
18706 if (STType.isInteger() && LDMemType.isInteger() && !STType.isVector() &&
18707 !LDMemType.isVector() && LD->getExtensionType() != ISD::SEXTLOAD) {
18708 // Mask to size of LDMemType
18709 auto Mask =
18710 DAG.getConstant(Val: APInt::getLowBitsSet(numBits: STType.getFixedSizeInBits(),
18711 loBitsSet: StMemSize.getFixedValue()),
18712 DL: SDLoc(ST), VT: STType);
18713 auto Val = DAG.getNode(Opcode: ISD::AND, DL: SDLoc(LD), VT: LDType, N1: ST->getValue(), N2: Mask);
18714 return ReplaceLd(LD, Val, Chain);
18715 }
18716 }
18717
18718 // Handle some big-endian cases that would have Offset 0 on little-endian
18719 // targets and are handled by the code above.
18720 SDValue Val = ST->getValue();
18721 if (DAG.getDataLayout().isBigEndian() && Offset > 0 && OrigOffset == 0) {
18722 if (STType.isInteger() && !STType.isVector() && LDType.isInteger() &&
18723 !LDType.isVector() && isTypeLegal(VT: STType) &&
18724 TLI.isOperationLegal(Op: ISD::SRL, VT: STType)) {
18725 Val = DAG.getNode(Opcode: ISD::SRL, DL: SDLoc(LD), VT: STType, N1: Val,
18726 N2: DAG.getConstant(Val: Offset * 8, DL: SDLoc(LD), VT: STType));
18727 Offset = 0;
18728 }
18729 }
18730
18731 // TODO: Deal with nonzero offset.
18732 if (LD->getBasePtr().isUndef() || Offset != 0)
18733 return SDValue();
18734 // Model necessary truncations / extensions.
18735 // Truncate Value To Stored Memory Size.
18736 do {
18737 if (!getTruncatedStoreValue(ST, Val))
18738 continue;
18739 if (!isTypeLegal(VT: LDMemType))
18740 continue;
18741 if (STMemType != LDMemType) {
18742 // TODO: Support vectors? This requires extract_subvector/bitcast.
18743 if (!STMemType.isVector() && !LDMemType.isVector() &&
18744 STMemType.isInteger() && LDMemType.isInteger())
18745 Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(LD), VT: LDMemType, Operand: Val);
18746 else
18747 continue;
18748 }
18749 if (!extendLoadedValueToExtension(LD, Val))
18750 continue;
18751 return ReplaceLd(LD, Val, Chain);
18752 } while (false);
18753
18754 // On failure, cleanup dead nodes we may have created.
18755 if (Val->use_empty())
18756 deleteAndRecombine(N: Val.getNode());
18757 return SDValue();
18758}
18759
18760SDValue DAGCombiner::visitLOAD(SDNode *N) {
18761 LoadSDNode *LD = cast<LoadSDNode>(Val: N);
18762 SDValue Chain = LD->getChain();
18763 SDValue Ptr = LD->getBasePtr();
18764
18765 // If load is not volatile and there are no uses of the loaded value (and
18766 // the updated indexed value in case of indexed loads), change uses of the
18767 // chain value into uses of the chain input (i.e. delete the dead load).
18768 // TODO: Allow this for unordered atomics (see D66309)
18769 if (LD->isSimple()) {
18770 if (N->getValueType(1) == MVT::Other) {
18771 // Unindexed loads.
18772 if (!N->hasAnyUseOfValue(Value: 0)) {
18773 // It's not safe to use the two value CombineTo variant here. e.g.
18774 // v1, chain2 = load chain1, loc
18775 // v2, chain3 = load chain2, loc
18776 // v3 = add v2, c
18777 // Now we replace use of chain2 with chain1. This makes the second load
18778 // isomorphic to the one we are deleting, and thus makes this load live.
18779 LLVM_DEBUG(dbgs() << "\nReplacing.6 "; N->dump(&DAG);
18780 dbgs() << "\nWith chain: "; Chain.dump(&DAG);
18781 dbgs() << "\n");
18782 WorklistRemover DeadNodes(*this);
18783 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Chain);
18784 AddUsersToWorklist(N: Chain.getNode());
18785 if (N->use_empty())
18786 deleteAndRecombine(N);
18787
18788 return SDValue(N, 0); // Return N so it doesn't get rechecked!
18789 }
18790 } else {
18791 // Indexed loads.
18792 assert(N->getValueType(2) == MVT::Other && "Malformed indexed loads?");
18793
18794 // If this load has an opaque TargetConstant offset, then we cannot split
18795 // the indexing into an add/sub directly (that TargetConstant may not be
18796 // valid for a different type of node, and we cannot convert an opaque
18797 // target constant into a regular constant).
18798 bool CanSplitIdx = canSplitIdx(LD);
18799
18800 if (!N->hasAnyUseOfValue(Value: 0) && (CanSplitIdx || !N->hasAnyUseOfValue(Value: 1))) {
18801 SDValue Undef = DAG.getUNDEF(VT: N->getValueType(ResNo: 0));
18802 SDValue Index;
18803 if (N->hasAnyUseOfValue(Value: 1) && CanSplitIdx) {
18804 Index = SplitIndexingFromLoad(LD);
18805 // Try to fold the base pointer arithmetic into subsequent loads and
18806 // stores.
18807 AddUsersToWorklist(N);
18808 } else
18809 Index = DAG.getUNDEF(VT: N->getValueType(ResNo: 1));
18810 LLVM_DEBUG(dbgs() << "\nReplacing.7 "; N->dump(&DAG);
18811 dbgs() << "\nWith: "; Undef.dump(&DAG);
18812 dbgs() << " and 2 other values\n");
18813 WorklistRemover DeadNodes(*this);
18814 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 0), To: Undef);
18815 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 1), To: Index);
18816 DAG.ReplaceAllUsesOfValueWith(From: SDValue(N, 2), To: Chain);
18817 deleteAndRecombine(N);
18818 return SDValue(N, 0); // Return N so it doesn't get rechecked!
18819 }
18820 }
18821 }
18822
18823 // If this load is directly stored, replace the load value with the stored
18824 // value.
18825 if (auto V = ForwardStoreValueToDirectLoad(LD))
18826 return V;
18827
18828 // Try to infer better alignment information than the load already has.
18829 if (OptLevel != CodeGenOptLevel::None && LD->isUnindexed() &&
18830 !LD->isAtomic()) {
18831 if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
18832 if (*Alignment > LD->getAlign() &&
18833 isAligned(Lhs: *Alignment, SizeInBytes: LD->getSrcValueOffset())) {
18834 SDValue NewLoad = DAG.getExtLoad(
18835 ExtType: LD->getExtensionType(), dl: SDLoc(N), VT: LD->getValueType(ResNo: 0), Chain, Ptr,
18836 PtrInfo: LD->getPointerInfo(), MemVT: LD->getMemoryVT(), Alignment: *Alignment,
18837 MMOFlags: LD->getMemOperand()->getFlags(), AAInfo: LD->getAAInfo());
18838 // NewLoad will always be N as we are only refining the alignment
18839 assert(NewLoad.getNode() == N);
18840 (void)NewLoad;
18841 }
18842 }
18843 }
18844
18845 if (LD->isUnindexed()) {
18846 // Walk up chain skipping non-aliasing memory nodes.
18847 SDValue BetterChain = FindBetterChain(N: LD, Chain);
18848
18849 // If there is a better chain.
18850 if (Chain != BetterChain) {
18851 SDValue ReplLoad;
18852
18853 // Replace the chain to avoid a dependency.
18854 if (LD->getExtensionType() == ISD::NON_EXTLOAD) {
18855 ReplLoad = DAG.getLoad(VT: N->getValueType(ResNo: 0), dl: SDLoc(LD),
18856 Chain: BetterChain, Ptr, MMO: LD->getMemOperand());
18857 } else {
18858 ReplLoad = DAG.getExtLoad(ExtType: LD->getExtensionType(), dl: SDLoc(LD),
18859 VT: LD->getValueType(ResNo: 0),
18860 Chain: BetterChain, Ptr, MemVT: LD->getMemoryVT(),
18861 MMO: LD->getMemOperand());
18862 }
18863
18864 // Create token factor to keep old chain connected.
18865 SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N),
18866 MVT::Other, Chain, ReplLoad.getValue(1));
18867
18868 // Replace uses with load result and token factor
18869 return CombineTo(N, Res0: ReplLoad.getValue(R: 0), Res1: Token);
18870 }
18871 }
18872
18873 // Try transforming N to an indexed load.
18874 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
18875 return SDValue(N, 0);
18876
18877 // Try to slice up N to more direct loads if the slices are mapped to
18878 // different register banks or pairing can take place.
18879 if (SliceUpLoad(N))
18880 return SDValue(N, 0);
18881
18882 return SDValue();
18883}
18884
18885namespace {
18886
18887/// Helper structure used to slice a load in smaller loads.
18888/// Basically a slice is obtained from the following sequence:
18889/// Origin = load Ty1, Base
18890/// Shift = srl Ty1 Origin, CstTy Amount
18891/// Inst = trunc Shift to Ty2
18892///
18893/// Then, it will be rewritten into:
18894/// Slice = load SliceTy, Base + SliceOffset
18895/// [Inst = zext Slice to Ty2], only if SliceTy <> Ty2
18896///
18897/// SliceTy is deduced from the number of bits that are actually used to
18898/// build Inst.
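///
/// E.g. on a little-endian target:
///   Origin = load i64, Base
///   Shift  = srl i64 Origin, 32
///   Inst   = trunc Shift to i32
/// becomes a single i32 load from Base + 4.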
18899struct LoadedSlice {
18900 /// Helper structure used to compute the cost of a slice.
18901 struct Cost {
18902 /// Are we optimizing for code size.
18903 bool ForCodeSize = false;
18904
18905 /// Various costs.
18906 unsigned Loads = 0;
18907 unsigned Truncates = 0;
18908 unsigned CrossRegisterBanksCopies = 0;
18909 unsigned ZExts = 0;
18910 unsigned Shift = 0;
18911
18912 explicit Cost(bool ForCodeSize) : ForCodeSize(ForCodeSize) {}
18913
18914 /// Get the cost of one isolated slice.
18915 Cost(const LoadedSlice &LS, bool ForCodeSize)
18916 : ForCodeSize(ForCodeSize), Loads(1) {
18917 EVT TruncType = LS.Inst->getValueType(ResNo: 0);
18918 EVT LoadedType = LS.getLoadedType();
18919 if (TruncType != LoadedType &&
18920 !LS.DAG->getTargetLoweringInfo().isZExtFree(FromTy: LoadedType, ToTy: TruncType))
18921 ZExts = 1;
18922 }
18923
18924 /// Account for slicing gain in the current cost.
18925 /// Slicing provides a few gains like removing a shift or a
18926 /// truncate. This method allows growing the cost of the original
18927 /// load with the gain from this slice.
18928 void addSliceGain(const LoadedSlice &LS) {
18929 // Each slice saves a truncate.
18930 const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo();
18931 if (!TLI.isTruncateFree(Val: LS.Inst->getOperand(Num: 0), VT2: LS.Inst->getValueType(ResNo: 0)))
18932 ++Truncates;
18933 // If there is a shift amount, this slice gets rid of it.
18934 if (LS.Shift)
18935 ++Shift;
18936 // If this slice can merge a cross register bank copy, account for it.
18937 if (LS.canMergeExpensiveCrossRegisterBankCopy())
18938 ++CrossRegisterBanksCopies;
18939 }
18940
18941 Cost &operator+=(const Cost &RHS) {
18942 Loads += RHS.Loads;
18943 Truncates += RHS.Truncates;
18944 CrossRegisterBanksCopies += RHS.CrossRegisterBanksCopies;
18945 ZExts += RHS.ZExts;
18946 Shift += RHS.Shift;
18947 return *this;
18948 }
18949
18950 bool operator==(const Cost &RHS) const {
18951 return Loads == RHS.Loads && Truncates == RHS.Truncates &&
18952 CrossRegisterBanksCopies == RHS.CrossRegisterBanksCopies &&
18953 ZExts == RHS.ZExts && Shift == RHS.Shift;
18954 }
18955
18956 bool operator!=(const Cost &RHS) const { return !(*this == RHS); }
18957
18958 bool operator<(const Cost &RHS) const {
18959 // Assume cross register banks copies are as expensive as loads.
18960 // FIXME: Do we want some more target hooks?
18961 unsigned ExpensiveOpsLHS = Loads + CrossRegisterBanksCopies;
18962 unsigned ExpensiveOpsRHS = RHS.Loads + RHS.CrossRegisterBanksCopies;
18963 // Unless we are optimizing for code size, consider the
18964 // expensive operation first.
18965 if (!ForCodeSize && ExpensiveOpsLHS != ExpensiveOpsRHS)
18966 return ExpensiveOpsLHS < ExpensiveOpsRHS;
18967 return (Truncates + ZExts + Shift + ExpensiveOpsLHS) <
18968 (RHS.Truncates + RHS.ZExts + RHS.Shift + ExpensiveOpsRHS);
18969 }
18970
18971 bool operator>(const Cost &RHS) const { return RHS < *this; }
18972
18973 bool operator<=(const Cost &RHS) const { return !(RHS < *this); }
18974
18975 bool operator>=(const Cost &RHS) const { return !(*this < RHS); }
18976 };
18977
18978 // The last instruction that represents the slice. This should be a
18979 // truncate instruction.
18980 SDNode *Inst;
18981
18982 // The original load instruction.
18983 LoadSDNode *Origin;
18984
18985 // The right shift amount in bits from the original load.
18986 unsigned Shift;
18987
18988 // The DAG from which Origin came.
18989 // This is used to get some contextual information about legal types, etc.
18990 SelectionDAG *DAG;
18991
18992 LoadedSlice(SDNode *Inst = nullptr, LoadSDNode *Origin = nullptr,
18993 unsigned Shift = 0, SelectionDAG *DAG = nullptr)
18994 : Inst(Inst), Origin(Origin), Shift(Shift), DAG(DAG) {}
18995
18996 /// Get the bits used in a chunk of bits \p BitWidth large.
18997 /// \return Result is \p BitWidth bits wide and has used bits set to 1 and
18998 /// unused bits set to 0.
18999 APInt getUsedBits() const {
19000 // Reproduce the trunc(lshr) sequence:
19001 // - Start from the truncated value.
19002 // - Zero extend to the desired bit width.
19003 // - Shift left.
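    // E.g. for a 64-bit Origin, a 32-bit Inst and Shift == 32, this yields
    // 0xFFFFFFFF00000000.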
19004 assert(Origin && "No original load to compare against.");
19005 unsigned BitWidth = Origin->getValueSizeInBits(ResNo: 0);
19006 assert(Inst && "This slice is not bound to an instruction");
19007 assert(Inst->getValueSizeInBits(0) <= BitWidth &&
19008 "Extracted slice is bigger than the whole type!");
19009 APInt UsedBits(Inst->getValueSizeInBits(ResNo: 0), 0);
19010 UsedBits.setAllBits();
19011 UsedBits = UsedBits.zext(width: BitWidth);
19012 UsedBits <<= Shift;
19013 return UsedBits;
19014 }
19015
19016 /// Get the size of the slice to be loaded in bytes.
19017 unsigned getLoadedSize() const {
19018 unsigned SliceSize = getUsedBits().popcount();
19019 assert(!(SliceSize & 0x7) && "Size is not a multiple of a byte.");
19020 return SliceSize / 8;
19021 }
19022
19023 /// Get the type that will be loaded for this slice.
19024 /// Note: This may not be the final type for the slice.
19025 EVT getLoadedType() const {
19026 assert(DAG && "Missing context");
19027 LLVMContext &Ctxt = *DAG->getContext();
19028 return EVT::getIntegerVT(Context&: Ctxt, BitWidth: getLoadedSize() * 8);
19029 }
19030
19031 /// Get the alignment of the load used for this slice.
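  /// E.g. an 8-byte aligned origin sliced at byte offset 4 yields a 4-byte
  /// aligned slice load.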
19032 Align getAlign() const {
19033 Align Alignment = Origin->getAlign();
19034 uint64_t Offset = getOffsetFromBase();
19035 if (Offset != 0)
19036 Alignment = commonAlignment(A: Alignment, Offset: Alignment.value() + Offset);
19037 return Alignment;
19038 }

  /// Check if this slice can be rewritten with legal operations.
  bool isLegal() const {
    // An invalid slice is not legal.
    if (!Origin || !Inst || !DAG)
      return false;

    // Offsets are for indexed loads only; we do not handle that.
    if (!Origin->getOffset().isUndef())
      return false;

    const TargetLowering &TLI = DAG->getTargetLoweringInfo();

    // Check that the type is legal.
    EVT SliceType = getLoadedType();
    if (!TLI.isTypeLegal(SliceType))
      return false;

    // Check that the load is legal for this type.
    if (!TLI.isOperationLegal(ISD::LOAD, SliceType))
      return false;

    // Check that the offset can be computed.
    // 1. Check its type.
    EVT PtrType = Origin->getBasePtr().getValueType();
    if (PtrType == MVT::Untyped || PtrType.isExtended())
      return false;

    // 2. Check that it fits in the immediate.
    if (!TLI.isLegalAddImmediate(getOffsetFromBase()))
      return false;

    // 3. Check that the computation is legal.
    if (!TLI.isOperationLegal(ISD::ADD, PtrType))
      return false;

    // Check that the zext is legal if it needs one.
    EVT TruncateType = Inst->getValueType(0);
    if (TruncateType != SliceType &&
        !TLI.isOperationLegal(ISD::ZERO_EXTEND, TruncateType))
      return false;

    return true;
  }

  /// Get the offset in bytes of this slice in the original chunk of
  /// bits.
  /// \pre DAG != nullptr.
  uint64_t getOffsetFromBase() const {
    assert(DAG && "Missing context.");
    bool IsBigEndian = DAG->getDataLayout().isBigEndian();
    assert(!(Shift & 0x7) && "Shifts not aligned on bytes are not supported.");
    uint64_t Offset = Shift / 8;
    unsigned TySizeInBytes = Origin->getValueSizeInBits(0) / 8;
    assert(!(Origin->getValueSizeInBits(0) & 0x7) &&
           "The size of the original loaded type is not a multiple of a"
           " byte.");
    // If Offset is bigger than TySizeInBytes, it means we are loading all
    // zeros. This should have been optimized before in the process.
    assert(TySizeInBytes > Offset &&
           "Invalid shift amount for given loaded size");
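    // Illustrative example (values assumed for exposition): for a 32-bit
    // load, Shift == 16, and an 8-bit slice, Offset is 2 on little-endian
    // targets and 4 - 2 - 1 == 1 on big-endian targets.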
    if (IsBigEndian)
      Offset = TySizeInBytes - Offset - getLoadedSize();
    return Offset;
  }

  /// Generate the sequence of instructions to load the slice
  /// represented by this object and redirect the uses of this slice to
  /// this new sequence of instructions.
  /// \pre this->Inst && this->Origin are valid Instructions and this
  /// object passed the legal check: LoadedSlice::isLegal returned true.
  /// \return The last instruction of the sequence used to load the slice.
  SDValue loadSlice() const {
    assert(Inst && Origin && "Unable to replace a non-existing slice.");
    const SDValue &OldBaseAddr = Origin->getBasePtr();
    SDValue BaseAddr = OldBaseAddr;
    // Get the offset in that chunk of bytes w.r.t. the endianness.
    int64_t Offset = static_cast<int64_t>(getOffsetFromBase());
    assert(Offset >= 0 && "Offset too big to fit in int64_t!");
    if (Offset) {
      // BaseAddr = BaseAddr + Offset.
      EVT ArithType = BaseAddr.getValueType();
      SDLoc DL(Origin);
      BaseAddr = DAG->getNode(ISD::ADD, DL, ArithType, BaseAddr,
                              DAG->getConstant(Offset, DL, ArithType));
    }

    // Create the type of the loaded slice according to its size.
    EVT SliceType = getLoadedType();

    // Create the load for the slice.
    SDValue LastInst =
        DAG->getLoad(SliceType, SDLoc(Origin), Origin->getChain(), BaseAddr,
                     Origin->getPointerInfo().getWithOffset(Offset), getAlign(),
                     Origin->getMemOperand()->getFlags());
    // If the final type is not the same as the loaded type, this means that
    // we have to pad with zero. Create a zero extend for that.
    EVT FinalType = Inst->getValueType(0);
    if (SliceType != FinalType)
      LastInst =
          DAG->getNode(ISD::ZERO_EXTEND, SDLoc(LastInst), FinalType, LastInst);
    return LastInst;
  }

  /// Check if this slice can be merged with an expensive cross register
  /// bank copy. E.g.,
  /// i = load i32
  /// f = bitcast i32 i to float
  bool canMergeExpensiveCrossRegisterBankCopy() const {
    if (!Inst || !Inst->hasOneUse())
      return false;
    SDNode *Use = *Inst->use_begin();
    if (Use->getOpcode() != ISD::BITCAST)
      return false;
    assert(DAG && "Missing context");
    const TargetLowering &TLI = DAG->getTargetLoweringInfo();
    EVT ResVT = Use->getValueType(0);
    const TargetRegisterClass *ResRC =
        TLI.getRegClassFor(ResVT.getSimpleVT(), Use->isDivergent());
    const TargetRegisterClass *ArgRC =
        TLI.getRegClassFor(Use->getOperand(0).getValueType().getSimpleVT(),
                           Use->getOperand(0)->isDivergent());
    if (ArgRC == ResRC || !TLI.isOperationLegal(ISD::LOAD, ResVT))
      return false;

    // At this point, we know that we perform a cross-register-bank copy.
    // Check if it is expensive.
    const TargetRegisterInfo *TRI = DAG->getSubtarget().getRegisterInfo();
    // Assume bitcasts are cheap unless the two register classes do not
    // explicitly share a common subclass.
    if (!TRI || TRI->getCommonSubClass(ArgRC, ResRC))
      return false;

    // Check if it will be merged with the load.
    // 1. Check the alignment / fast memory access constraint.
    unsigned IsFast = 0;
    if (!TLI.allowsMemoryAccess(*DAG->getContext(), DAG->getDataLayout(), ResVT,
                                Origin->getAddressSpace(), getAlign(),
                                Origin->getMemOperand()->getFlags(), &IsFast) ||
        !IsFast)
      return false;

    // 2. Check that the load is a legal operation for that type.
    if (!TLI.isOperationLegal(ISD::LOAD, ResVT))
      return false;

    // 3. Check that we do not have a zext in the way.
    if (Inst->getValueType(0) != getLoadedType())
      return false;

    return true;
  }
};

} // end anonymous namespace

/// Check that all bits set in \p UsedBits form a dense region, i.e.,
/// \p UsedBits looks like 0..0 1..1 0..0.
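/// For example (illustrative values, not from a specific caller): 0x00FF0000
/// is dense, while 0x00FF00FF is not.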
static bool areUsedBitsDense(const APInt &UsedBits) {
  // If all the bits are one, this is dense!
  if (UsedBits.isAllOnes())
    return true;

  // Get rid of the unused bits on the right.
  APInt NarrowedUsedBits = UsedBits.lshr(UsedBits.countr_zero());
  // Get rid of the unused bits on the left.
  if (NarrowedUsedBits.countl_zero())
    NarrowedUsedBits = NarrowedUsedBits.trunc(NarrowedUsedBits.getActiveBits());
  // Check that the chunk of bits is completely used.
  return NarrowedUsedBits.isAllOnes();
}

/// Check whether or not \p First and \p Second are next to each other
/// in memory. This means that there is no hole between the bits loaded
/// by \p First and the bits loaded by \p Second.
static bool areSlicesNextToEachOther(const LoadedSlice &First,
                                     const LoadedSlice &Second) {
  assert(First.Origin == Second.Origin && First.Origin &&
         "Unable to match different memory origins.");
  APInt UsedBits = First.getUsedBits();
  assert((UsedBits & Second.getUsedBits()) == 0 &&
         "Slices are not supposed to overlap.");
  UsedBits |= Second.getUsedBits();
  return areUsedBitsDense(UsedBits);
}

/// Adjust the \p GlobalLSCost according to the target
/// pairing capabilities and the layout of the slices.
/// \pre \p GlobalLSCost should account for at least as many loads as
/// there are slices in \p LoadedSlices.
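/// For example (illustrative): when two adjacent slices of the same type can
/// use a paired load on the target, one load is saved, so
/// \p GlobalLSCost.Loads is decremented once per matched pair.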
static void adjustCostForPairing(SmallVectorImpl<LoadedSlice> &LoadedSlices,
                                 LoadedSlice::Cost &GlobalLSCost) {
  unsigned NumberOfSlices = LoadedSlices.size();
  // If there are fewer than 2 elements, no pairing is possible.
  if (NumberOfSlices < 2)
    return;

  // Sort the slices so that elements that are likely to be next to each
  // other in memory are next to each other in the list.
  llvm::sort(LoadedSlices, [](const LoadedSlice &LHS, const LoadedSlice &RHS) {
    assert(LHS.Origin == RHS.Origin && "Different bases not implemented.");
    return LHS.getOffsetFromBase() < RHS.getOffsetFromBase();
  });
  const TargetLowering &TLI = LoadedSlices[0].DAG->getTargetLoweringInfo();
  // First (resp. Second) is the first (resp. second) potential candidate
  // to be placed in a paired load.
  const LoadedSlice *First = nullptr;
  const LoadedSlice *Second = nullptr;
  for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice,
                // Set the beginning of the pair.
                First = Second) {
    Second = &LoadedSlices[CurrSlice];

    // If First is NULL, it means we start a new pair.
    // Get to the next slice.
    if (!First)
      continue;

    EVT LoadedType = First->getLoadedType();

    // If the types of the slices are different, we cannot pair them.
    if (LoadedType != Second->getLoadedType())
      continue;

    // Check if the target supplies paired loads for this type.
    Align RequiredAlignment;
    if (!TLI.hasPairedLoad(LoadedType, RequiredAlignment)) {
      // Move to the next pair; this type is hopeless.
      Second = nullptr;
      continue;
    }
    // Check if we meet the alignment requirement.
    if (First->getAlign() < RequiredAlignment)
      continue;

    // Check that both loads are next to each other in memory.
    if (!areSlicesNextToEachOther(*First, *Second))
      continue;

    assert(GlobalLSCost.Loads > 0 && "We save more loads than we created!");
    --GlobalLSCost.Loads;
    // Move to the next pair.
    Second = nullptr;
  }
}

/// Check the profitability of all involved LoadedSlices.
/// Currently, it is considered profitable if there are exactly two
/// involved slices (1) which are (2) next to each other in memory, and
/// whose cost (\see LoadedSlice::Cost) is smaller than the original load (3).
///
/// Note: The order of the elements in \p LoadedSlices may be modified, but not
/// the elements themselves.
///
/// FIXME: When the cost model is mature enough, we can relax
/// constraints (1) and (2).
static bool isSlicingProfitable(SmallVectorImpl<LoadedSlice> &LoadedSlices,
                                const APInt &UsedBits, bool ForCodeSize) {
  unsigned NumberOfSlices = LoadedSlices.size();
  if (StressLoadSlicing)
    return NumberOfSlices > 1;

  // Check (1).
  if (NumberOfSlices != 2)
    return false;

  // Check (2).
  if (!areUsedBitsDense(UsedBits))
    return false;

  // Check (3).
  LoadedSlice::Cost OrigCost(ForCodeSize), GlobalSlicingCost(ForCodeSize);
  // The original code has one big load.
  OrigCost.Loads = 1;
  for (unsigned CurrSlice = 0; CurrSlice < NumberOfSlices; ++CurrSlice) {
    const LoadedSlice &LS = LoadedSlices[CurrSlice];
    // Accumulate the cost of all the slices.
    LoadedSlice::Cost SliceCost(LS, ForCodeSize);
    GlobalSlicingCost += SliceCost;

    // Account as cost in the original configuration the gain obtained
    // with the current slices.
    OrigCost.addSliceGain(LS);
  }

  // If the target supports paired load, adjust the cost accordingly.
  adjustCostForPairing(LoadedSlices, GlobalSlicingCost);
  return OrigCost > GlobalSlicingCost;
}

/// If the given load, \p LI, is used only by trunc or trunc(lshr)
/// operations, split it into the various pieces being extracted.
///
/// This sort of thing is introduced by SROA.
/// This slicing takes care not to insert overlapping loads.
/// \pre LI is a simple load (i.e., not an atomic or volatile load).
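/// For example (illustrative IR, not from a particular test): a load %v of
/// i32 used only as "trunc %v to i16" and "trunc (lshr %v, 16) to i16" may be
/// rewritten as two independent i16 loads when that is legal and profitable.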
bool DAGCombiner::SliceUpLoad(SDNode *N) {
  if (Level < AfterLegalizeDAG)
    return false;

  LoadSDNode *LD = cast<LoadSDNode>(N);
  if (!LD->isSimple() || !ISD::isNormalLoad(LD) ||
      !LD->getValueType(0).isInteger())
    return false;

  // The algorithm to split up a load of a scalable vector into individual
  // elements currently requires knowing the length of the loaded type,
  // so will need adjusting to work on scalable vectors.
  if (LD->getValueType(0).isScalableVector())
    return false;

  // Keep track of already used bits to detect overlapping values.
  // In that case, we will just abort the transformation.
  APInt UsedBits(LD->getValueSizeInBits(0), 0);

  SmallVector<LoadedSlice, 4> LoadedSlices;

  // Check if this load is used as several smaller chunks of bits.
  // Basically, look for uses in trunc or trunc(lshr) and record a new chain
  // of computation for each trunc.
  for (SDNode::use_iterator UI = LD->use_begin(), UIEnd = LD->use_end();
       UI != UIEnd; ++UI) {
    // Skip the uses of the chain.
    if (UI.getUse().getResNo() != 0)
      continue;

    SDNode *User = *UI;
    unsigned Shift = 0;

    // Check if this is a trunc(lshr).
    if (User->getOpcode() == ISD::SRL && User->hasOneUse() &&
        isa<ConstantSDNode>(User->getOperand(1))) {
      Shift = User->getConstantOperandVal(1);
      User = *User->use_begin();
    }

    // At this point, User is a truncate iff we encountered trunc or
    // trunc(lshr).
    if (User->getOpcode() != ISD::TRUNCATE)
      return false;

    // The width of the type must be a power of 2 and at least 8 bits.
    // Otherwise the load cannot be represented in LLVM IR.
    // Moreover, if we shifted with a non-8-bits multiple, the slice
    // will be across several bytes. We do not support that.
    unsigned Width = User->getValueSizeInBits(0);
    if (Width < 8 || !isPowerOf2_32(Width) || (Shift & 0x7))
      return false;

    // Build the slice for this chain of computations.
    LoadedSlice LS(User, LD, Shift, &DAG);
    APInt CurrentUsedBits = LS.getUsedBits();

    // Check if this slice overlaps with another.
    if ((CurrentUsedBits & UsedBits) != 0)
      return false;
    // Update the bits used globally.
    UsedBits |= CurrentUsedBits;

    // Check if the new slice would be legal.
    if (!LS.isLegal())
      return false;

    // Record the slice.
    LoadedSlices.push_back(LS);
  }

  // Abort slicing if it does not seem to be profitable.
  if (!isSlicingProfitable(LoadedSlices, UsedBits, ForCodeSize))
    return false;

  ++SlicedLoads;

  // Rewrite each chain to use an independent load.
  // By construction, each chain can be represented by a unique load.

  // Prepare the argument for the new token factor for all the slices.
  SmallVector<SDValue, 8> ArgChains;
  for (const LoadedSlice &LS : LoadedSlices) {
    SDValue SliceInst = LS.loadSlice();
    CombineTo(LS.Inst, SliceInst, true);
    if (SliceInst.getOpcode() != ISD::LOAD)
      SliceInst = SliceInst.getOperand(0);
    assert(SliceInst->getOpcode() == ISD::LOAD &&
           "It takes more than a zext to get to the loaded slice!!");
    ArgChains.push_back(SliceInst.getValue(1));
  }

  SDValue Chain = DAG.getNode(ISD::TokenFactor, SDLoc(LD), MVT::Other,
                              ArgChains);
  DAG.ReplaceAllUsesOfValueWith(SDValue(N, 1), Chain);
  AddToWorklist(Chain.getNode());
  return true;
}

/// Check to see if V is (and load (ptr), imm), where the load has specific
/// bytes cleared out. If so, return the byte size being masked out and the
/// shift amount.
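/// For example (illustrative): with an i32 V == (and (load Ptr), 0xFFFF00FF),
/// only byte 1 is cleared, so this returns {1 /*MaskedBytes*/, 1 /*ByteShift*/}.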
static std::pair<unsigned, unsigned>
CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
  std::pair<unsigned, unsigned> Result(0, 0);

  // Check for the structure we're looking for.
  if (V->getOpcode() != ISD::AND ||
      !isa<ConstantSDNode>(V->getOperand(1)) ||
      !ISD::isNormalLoad(V->getOperand(0).getNode()))
    return Result;

  // Check the chain and pointer.
  LoadSDNode *LD = cast<LoadSDNode>(V->getOperand(0));
  if (LD->getBasePtr() != Ptr) return Result;  // Not from same pointer.

  // This only handles simple types.
  if (V.getValueType() != MVT::i16 &&
      V.getValueType() != MVT::i32 &&
      V.getValueType() != MVT::i64)
    return Result;

  // Check the constant mask. Invert it so that the bits being masked out are
  // 0 and the bits being kept are 1. Use getSExtValue so that leading bits
  // follow the sign bit for uniformity.
  uint64_t NotMask = ~cast<ConstantSDNode>(V->getOperand(1))->getSExtValue();
  unsigned NotMaskLZ = llvm::countl_zero(NotMask);
  if (NotMaskLZ & 7) return Result;  // Must be multiple of a byte.
  unsigned NotMaskTZ = llvm::countr_zero(NotMask);
  if (NotMaskTZ & 7) return Result;  // Must be multiple of a byte.
  if (NotMaskLZ == 64) return Result;  // All zero mask.

  // See if we have a continuous run of bits. If so, we have 0*1+0*
  if (llvm::countr_one(NotMask >> NotMaskTZ) + NotMaskTZ + NotMaskLZ != 64)
    return Result;

  // Adjust NotMaskLZ down to be from the actual size of the int instead of i64.
  if (V.getValueType() != MVT::i64 && NotMaskLZ)
    NotMaskLZ -= 64 - V.getValueSizeInBits();

  unsigned MaskedBytes = (V.getValueSizeInBits() - NotMaskLZ - NotMaskTZ) / 8;
  switch (MaskedBytes) {
  case 1:
  case 2:
  case 4: break;
  default: return Result;  // All one mask, or 5-byte mask.
  }

  // Verify that the masked region starts at a byte offset that is a multiple
  // of the access width, so that the narrowed access is naturally aligned.
  if (NotMaskTZ && NotMaskTZ / 8 % MaskedBytes) return Result;

  // For narrowing to be valid, the load must be the memory operation
  // immediately preceding the store.
  if (LD == Chain.getNode())
    ; // ok.
  else if (Chain->getOpcode() == ISD::TokenFactor &&
           SDValue(LD, 1).hasOneUse()) {
    // LD has only one chain use, so there are no indirect dependencies.
    if (!LD->isOperandOf(Chain.getNode()))
      return Result;
  } else
    return Result; // Fail.

  Result.first = MaskedBytes;
  Result.second = NotMaskTZ / 8;
  return Result;
}

/// Check to see if IVal is something that provides a value as specified by
/// MaskInfo. If so, replace the specified store with a narrower store of
/// truncated IVal.
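/// For example (illustrative): with MaskInfo == {1, 1} and an i32 IVal known
/// to be zero outside bits [8, 16), the store becomes an i8 store of
/// (trunc (srl IVal, 8)) at byte offset 1 on a little-endian target.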
static SDValue
ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
                                SDValue IVal, StoreSDNode *St,
                                DAGCombiner *DC) {
  unsigned NumBytes = MaskInfo.first;
  unsigned ByteShift = MaskInfo.second;
  SelectionDAG &DAG = DC->getDAG();

  // Check to see if IVal is all zeros in the part being masked in by the 'or'
  // that uses this. If not, this is not a replacement.
  APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
                                  ByteShift * 8, (ByteShift + NumBytes) * 8);
  if (!DAG.MaskedValueIsZero(IVal, Mask)) return SDValue();

  // Check that it is legal on the target to do this. It is legal if the new
  // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
  // legalization. If the source type is legal, but the store type isn't, see
  // if we can use a truncating store.
  MVT VT = MVT::getIntegerVT(NumBytes * 8);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
  bool UseTruncStore;
  if (DC->isTypeLegal(VT))
    UseTruncStore = false;
  else if (TLI.isTypeLegal(IVal.getValueType()) &&
           TLI.isTruncStoreLegal(IVal.getValueType(), VT))
    UseTruncStore = true;
  else
    return SDValue();

  // Can't do this for indexed stores.
  if (St->isIndexed())
    return SDValue();

  // Check that the target doesn't think this is a bad idea.
  if (St->getMemOperand() &&
      !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
                              *St->getMemOperand()))
    return SDValue();

  // Okay, we can do this! Replace the 'St' store with a store of IVal that is
  // shifted by ByteShift and truncated down to NumBytes.
  if (ByteShift) {
    SDLoc DL(IVal);
    IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
                       DAG.getConstant(ByteShift * 8, DL,
                                       DC->getShiftAmountTy(IVal.getValueType())));
  }

  // Figure out the offset for the store and the alignment of the access.
  unsigned StOffset;
  if (DAG.getDataLayout().isLittleEndian())
    StOffset = ByteShift;
  else
    StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;

  SDValue Ptr = St->getBasePtr();
  if (StOffset) {
    SDLoc DL(IVal);
    Ptr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(StOffset), DL);
  }

  ++OpsNarrowed;
  if (UseTruncStore)
    return DAG.getTruncStore(St->getChain(), SDLoc(St), IVal, Ptr,
                             St->getPointerInfo().getWithOffset(StOffset),
                             VT, St->getOriginalAlign());

  // Truncate down to the new size.
  IVal = DAG.getNode(ISD::TRUNCATE, SDLoc(IVal), VT, IVal);

  return DAG
      .getStore(St->getChain(), SDLoc(St), IVal, Ptr,
                St->getPointerInfo().getWithOffset(StOffset),
                St->getOriginalAlign());
}

/// Look for sequence of load / op / store where op is one of 'or', 'xor', and
/// 'and' of immediates. If 'op' is only touching some of the loaded bits, try
/// narrowing the load and store if it would end up being a win for performance
/// or code size.
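/// For example (illustrative): for "store (or (load P), 0x00FF0000), P" on an
/// i32 value only byte 2 changes, so the sequence may be narrowed to an i8
/// load / or / store at byte offset 2 on a little-endian target, when the
/// target allows it.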
SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
  StoreSDNode *ST = cast<StoreSDNode>(N);
  if (!ST->isSimple())
    return SDValue();

  SDValue Chain = ST->getChain();
  SDValue Value = ST->getValue();
  SDValue Ptr = ST->getBasePtr();
  EVT VT = Value.getValueType();

  if (ST->isTruncatingStore() || VT.isVector())
    return SDValue();

  unsigned Opc = Value.getOpcode();

  if ((Opc != ISD::OR && Opc != ISD::XOR && Opc != ISD::AND) ||
      !Value.hasOneUse())
    return SDValue();

  // If this is "store (or X, Y), P" and X is "(and (load P), cst)", where cst
  // is a byte mask indicating a consecutive number of bytes, check to see if
  // Y is known to provide just those bytes. If so, we try to replace the
  // load + replace + store sequence with a single (narrower) store, which makes
  // the load dead.
  if (Opc == ISD::OR && EnableShrinkLoadReplaceStoreWithStore) {
    std::pair<unsigned, unsigned> MaskedLoad;
    MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
    if (MaskedLoad.first)
      if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(
              MaskedLoad, Value.getOperand(1), ST, this))
        return NewST;

    // Or is commutative, so try swapping X and Y.
    MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
    if (MaskedLoad.first)
      if (SDValue NewST = ShrinkLoadReplaceStoreWithStore(
              MaskedLoad, Value.getOperand(0), ST, this))
        return NewST;
  }

  if (!EnableReduceLoadOpStoreWidth)
    return SDValue();

  if (Value.getOperand(1).getOpcode() != ISD::Constant)
    return SDValue();

  SDValue N0 = Value.getOperand(0);
  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
      Chain == SDValue(N0.getNode(), 1)) {
    LoadSDNode *LD = cast<LoadSDNode>(N0);
    if (LD->getBasePtr() != Ptr ||
        LD->getPointerInfo().getAddrSpace() !=
            ST->getPointerInfo().getAddrSpace())
      return SDValue();

    // Find the type to narrow the load / op / store to.
    SDValue N1 = Value.getOperand(1);
    unsigned BitWidth = N1.getValueSizeInBits();
    APInt Imm = N1->getAsAPIntVal();
    if (Opc == ISD::AND)
      Imm ^= APInt::getAllOnes(BitWidth);
    if (Imm == 0 || Imm.isAllOnes())
      return SDValue();
    unsigned ShAmt = Imm.countr_zero();
    unsigned MSB = BitWidth - Imm.countl_zero() - 1;
    unsigned NewBW = NextPowerOf2(MSB - ShAmt);
    EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
    // The narrowing should be profitable, the load/store operation should be
    // legal (or custom) and the store size should be equal to the NewVT width.
    while (NewBW < BitWidth &&
           (NewVT.getStoreSizeInBits() != NewBW ||
            !TLI.isOperationLegalOrCustom(Opc, NewVT) ||
            !TLI.isNarrowingProfitable(VT, NewVT))) {
      NewBW = NextPowerOf2(NewBW);
      NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
    }
    if (NewBW >= BitWidth)
      return SDValue();

    // If the changed lsb does not start at a NewBW-bit boundary, start at the
    // previous boundary.
    if (ShAmt % NewBW)
      ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
    APInt Mask = APInt::getBitsSet(BitWidth, ShAmt,
                                   std::min(BitWidth, ShAmt + NewBW));
    if ((Imm & Mask) == Imm) {
      APInt NewImm = (Imm & Mask).lshr(ShAmt).trunc(NewBW);
      if (Opc == ISD::AND)
        NewImm ^= APInt::getAllOnes(NewBW);
      uint64_t PtrOff = ShAmt / 8;
      // For big endian targets, we need to adjust the offset to the pointer to
      // load the correct bytes.
      if (DAG.getDataLayout().isBigEndian())
        PtrOff = (BitWidth + 7 - NewBW) / 8 - PtrOff;

      unsigned IsFast = 0;
      Align NewAlign = commonAlignment(LD->getAlign(), PtrOff);
      if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), NewVT,
                                  LD->getAddressSpace(), NewAlign,
                                  LD->getMemOperand()->getFlags(), &IsFast) ||
          !IsFast)
        return SDValue();

      SDValue NewPtr =
          DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(PtrOff), SDLoc(LD));
      SDValue NewLD =
          DAG.getLoad(NewVT, SDLoc(N0), LD->getChain(), NewPtr,
                      LD->getPointerInfo().getWithOffset(PtrOff), NewAlign,
                      LD->getMemOperand()->getFlags(), LD->getAAInfo());
      SDValue NewVal = DAG.getNode(Opc, SDLoc(Value), NewVT, NewLD,
                                   DAG.getConstant(NewImm, SDLoc(Value),
                                                   NewVT));
      SDValue NewST =
          DAG.getStore(Chain, SDLoc(N), NewVal, NewPtr,
                       ST->getPointerInfo().getWithOffset(PtrOff), NewAlign);

      AddToWorklist(NewPtr.getNode());
      AddToWorklist(NewLD.getNode());
      AddToWorklist(NewVal.getNode());
      WorklistRemover DeadNodes(*this);
      DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), NewLD.getValue(1));
      ++OpsNarrowed;
      return NewST;
    }
  }

  return SDValue();
}

/// For a given floating point load / store pair, if the load value isn't used
/// by any other operations, then consider transforming the pair to integer
/// load / store operations if the target deems the transformation profitable.
SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) {
  StoreSDNode *ST = cast<StoreSDNode>(N);
  SDValue Value = ST->getValue();
  if (ISD::isNormalStore(ST) && ISD::isNormalLoad(Value.getNode()) &&
      Value.hasOneUse()) {
    LoadSDNode *LD = cast<LoadSDNode>(Value);
    EVT VT = LD->getMemoryVT();
    if (!VT.isFloatingPoint() ||
        VT != ST->getMemoryVT() ||
        LD->isNonTemporal() ||
        ST->isNonTemporal() ||
        LD->getPointerInfo().getAddrSpace() != 0 ||
        ST->getPointerInfo().getAddrSpace() != 0)
      return SDValue();

    TypeSize VTSize = VT.getSizeInBits();

    // We don't know the size of scalable types at compile time so we cannot
    // create an integer of the equivalent size.
    if (VTSize.isScalable())
      return SDValue();

    unsigned FastLD = 0, FastST = 0;
    EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VTSize.getFixedValue());
    if (!TLI.isOperationLegal(ISD::LOAD, IntVT) ||
        !TLI.isOperationLegal(ISD::STORE, IntVT) ||
        !TLI.isDesirableToTransformToIntegerOp(ISD::LOAD, VT) ||
        !TLI.isDesirableToTransformToIntegerOp(ISD::STORE, VT) ||
        !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
                                *LD->getMemOperand(), &FastLD) ||
        !TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), IntVT,
                                *ST->getMemOperand(), &FastST) ||
        !FastLD || !FastST)
      return SDValue();

    SDValue NewLD =
        DAG.getLoad(IntVT, SDLoc(Value), LD->getChain(), LD->getBasePtr(),
                    LD->getPointerInfo(), LD->getAlign());

    SDValue NewST =
        DAG.getStore(ST->getChain(), SDLoc(N), NewLD, ST->getBasePtr(),
                     ST->getPointerInfo(), ST->getAlign());

    AddToWorklist(NewLD.getNode());
    AddToWorklist(NewST.getNode());
    WorklistRemover DeadNodes(*this);
    DAG.ReplaceAllUsesOfValueWith(Value.getValue(1), NewLD.getValue(1));
    ++LdStFP2Int;
    return NewST;
  }

  return SDValue();
}

// This is a helper function for visitMUL to check the profitability
// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
// MulNode is the original multiply, AddNode is (add x, c1),
// and ConstNode is c2.
//
// If the (add x, c1) has multiple uses, we could increase
// the number of adds if we make this transformation.
// It would only be worth doing this if we can remove a
// multiply in the process. Check for that here.
// To illustrate:
//     (A + c1) * c3
//     (A + c2) * c3
// We're checking for cases where we have common "c3 * A" expressions.
bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
                                              SDValue ConstNode) {
  APInt Val;

  // If the add only has one use, and the target thinks the folding is
  // profitable or does not lead to worse code, this would be OK to do.
  if (AddNode->hasOneUse() &&
      TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
    return true;

  // Walk all the users of the constant with which we're multiplying.
  for (SDNode *Use : ConstNode->uses()) {
    if (Use == MulNode) // This use is the one we're on right now. Skip it.
      continue;

    if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
      SDNode *OtherOp;
      SDNode *MulVar = AddNode.getOperand(0).getNode();

      // OtherOp is what we're multiplying against the constant.
      if (Use->getOperand(0) == ConstNode)
        OtherOp = Use->getOperand(1).getNode();
      else
        OtherOp = Use->getOperand(0).getNode();

      // Check to see if multiply is with the same operand of our "add".
      //
      //     ConstNode  = CONST
      //     Use = ConstNode * A  <-- visiting Use. OtherOp is A.
      //     ...
      //     AddNode  = (A + c1)  <-- MulVar is A.
      //              = AddNode * ConstNode   <-- current visiting instruction.
      //
      // If we make this transformation, we will have a common
      // multiply (ConstNode * A) that we can save.
      if (OtherOp == MulVar)
        return true;

      // Now check to see if a future expansion will give us a common
      // multiply.
      //
      //     ConstNode  = CONST
      //     AddNode    = (A + c1)
      //     ...   = AddNode * ConstNode <-- current visiting instruction.
      //     ...
      //     OtherOp = (A + c2)
      //     Use     = OtherOp * ConstNode <-- visiting Use.
      //
      // If we make this transformation, we will have a common
      // multiply (CONST * A) after we also do the same transformation
      // to the "Use" instruction.
      if (OtherOp->getOpcode() == ISD::ADD &&
          DAG.isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) &&
          OtherOp->getOperand(0).getNode() == MulVar)
        return true;
    }
  }

  // Didn't find a case where this would be profitable.
  return false;
}

SDValue DAGCombiner::getMergeStoreChains(SmallVectorImpl<MemOpLink> &StoreNodes,
                                         unsigned NumStores) {
  SmallVector<SDValue, 8> Chains;
  SmallPtrSet<const SDNode *, 8> Visited;
  SDLoc StoreDL(StoreNodes[0].MemNode);

  for (unsigned i = 0; i < NumStores; ++i) {
    Visited.insert(StoreNodes[i].MemNode);
  }

  // Don't include nodes that are children or repeated nodes.
  for (unsigned i = 0; i < NumStores; ++i) {
    if (Visited.insert(StoreNodes[i].MemNode->getChain().getNode()).second)
      Chains.push_back(StoreNodes[i].MemNode->getChain());
  }

  assert(!Chains.empty() && "Chain should have generated a chain");
  return DAG.getTokenFactor(StoreDL, Chains);
}

bool DAGCombiner::hasSameUnderlyingObj(ArrayRef<MemOpLink> StoreNodes) {
  const Value *UnderlyingObj = nullptr;
  for (const auto &MemOp : StoreNodes) {
    const MachineMemOperand *MMO = MemOp.MemNode->getMemOperand();
    // A pseudo value like a stack frame has its own frame index and size; we
    // should not use the first store's frame index for other frames.
    if (MMO->getPseudoValue())
      return false;

    if (!MMO->getValue())
      return false;

    const Value *Obj = getUnderlyingObject(MMO->getValue());

    if (UnderlyingObj && UnderlyingObj != Obj)
      return false;

    if (!UnderlyingObj)
      UnderlyingObj = Obj;
  }

  return true;
}

bool DAGCombiner::mergeStoresOfConstantsOrVecElts(
    SmallVectorImpl<MemOpLink> &StoreNodes, EVT MemVT, unsigned NumStores,
    bool IsConstantSrc, bool UseVector, bool UseTrunc) {
  // Make sure we have something to merge.
  if (NumStores < 2)
    return false;

  assert((!UseTrunc || !UseVector) &&
         "This optimization cannot emit a vector truncating store");

  // The latest Node in the DAG.
  SDLoc DL(StoreNodes[0].MemNode);

  TypeSize ElementSizeBits = MemVT.getStoreSizeInBits();
  unsigned SizeInBits = NumStores * ElementSizeBits;
  unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;

  std::optional<MachineMemOperand::Flags> Flags;
  AAMDNodes AAInfo;
  for (unsigned I = 0; I != NumStores; ++I) {
    StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
    if (!Flags) {
      Flags = St->getMemOperand()->getFlags();
      AAInfo = St->getAAInfo();
      continue;
    }
    // Skip merging if there's an inconsistent flag.
    if (Flags != St->getMemOperand()->getFlags())
      return false;
    // Concatenate AA metadata.
    AAInfo = AAInfo.concat(St->getAAInfo());
  }

  EVT StoreTy;
  if (UseVector) {
    unsigned Elts = NumStores * NumMemElts;
    // Get the type for the merged vector store.
    StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts);
  } else
    StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits);

  SDValue StoredVal;
  if (UseVector) {
    if (IsConstantSrc) {
      SmallVector<SDValue, 8> BuildVector;
      for (unsigned I = 0; I != NumStores; ++I) {
        StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
        SDValue Val = St->getValue();
        // If constant is of the wrong type, convert it now. This comes up
        // when one of our stores was truncating.
        if (MemVT != Val.getValueType()) {
          Val = peekThroughBitcasts(Val);
          // Deal with constants of wrong size.
          if (ElementSizeBits != Val.getValueSizeInBits()) {
            auto *C = dyn_cast<ConstantSDNode>(Val);
            if (!C)
              // Not clear how to truncate FP values.
              // TODO: Handle truncation of build_vector constants
              return false;

            EVT IntMemVT =
                EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
            Val = DAG.getConstant(C->getAPIntValue()
                                      .zextOrTrunc(Val.getValueSizeInBits())
                                      .zextOrTrunc(ElementSizeBits),
                                  SDLoc(C), IntMemVT);
          }
          // Make sure the correctly sized type is used.
          Val = DAG.getBitcast(MemVT, Val);
        }
        BuildVector.push_back(Val);
      }
      StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
                                               : ISD::BUILD_VECTOR,
                              DL, StoreTy, BuildVector);
    } else {
      SmallVector<SDValue, 8> Ops;
      for (unsigned i = 0; i < NumStores; ++i) {
        StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
        SDValue Val = peekThroughBitcasts(St->getValue());
        // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
        // type MemVT. If the underlying value is not the correct
        // type, but it is an extraction of an appropriate vector we
        // can recast Val to be of the correct type. This may require
        // converting between EXTRACT_VECTOR_ELT and
        // EXTRACT_SUBVECTOR.
        if ((MemVT != Val.getValueType()) &&
            (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
             Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
          EVT MemVTScalarTy = MemVT.getScalarType();
          // We may need to add a bitcast here to get types to line up.
          if (MemVTScalarTy != Val.getValueType().getScalarType()) {
            Val = DAG.getBitcast(MemVT, Val);
          } else if (MemVT.isVector() &&
                     Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
            Val = DAG.getNode(ISD::BUILD_VECTOR, DL, MemVT, Val);
          } else {
            unsigned OpC = MemVT.isVector() ? ISD::EXTRACT_SUBVECTOR
                                            : ISD::EXTRACT_VECTOR_ELT;
            SDValue Vec = Val.getOperand(0);
            SDValue Idx = Val.getOperand(1);
            Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Idx);
          }
        }
        Ops.push_back(Val);
      }

      // Build the extracted vector elements back into a vector.
      StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
                                               : ISD::BUILD_VECTOR,
                              DL, StoreTy, Ops);
    }
  } else {
    // We should always use a vector store when merging extracted vector
    // elements, so this path implies a store of constants.
    assert(IsConstantSrc && "Merged vector elements should use vector store");

    APInt StoreInt(SizeInBits, 0);

    // Construct a single integer constant which is made of the smaller
    // constant inputs.
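    // Illustrative example (values assumed for exposition): merging two i16
    // stores of 0x1111 and 0x2222 (in address order) produces the i32
    // constant 0x22221111 on a little-endian target.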
    bool IsLE = DAG.getDataLayout().isLittleEndian();
    for (unsigned i = 0; i < NumStores; ++i) {
      unsigned Idx = IsLE ? (NumStores - 1 - i) : i;
      StoreSDNode *St = cast<StoreSDNode>(StoreNodes[Idx].MemNode);

      SDValue Val = St->getValue();
      Val = peekThroughBitcasts(Val);
      StoreInt <<= ElementSizeBits;
      if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val)) {
        StoreInt |= C->getAPIntValue()
                        .zextOrTrunc(ElementSizeBits)
                        .zextOrTrunc(SizeInBits);
      } else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val)) {
        StoreInt |= C->getValueAPF()
                        .bitcastToAPInt()
                        .zextOrTrunc(ElementSizeBits)
                        .zextOrTrunc(SizeInBits);
        // If fp truncation is necessary give up for now.
        if (MemVT.getSizeInBits() != ElementSizeBits)
          return false;
      } else if (ISD::isBuildVectorOfConstantSDNodes(Val.getNode()) ||
                 ISD::isBuildVectorOfConstantFPSDNodes(Val.getNode())) {
        // Not yet handled
        return false;
      } else {
        llvm_unreachable("Invalid constant element type");
      }
    }

    // Create the new Load and Store operations.
    StoredVal = DAG.getConstant(StoreInt, DL, StoreTy);
  }

  LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
  SDValue NewChain = getMergeStoreChains(StoreNodes, NumStores);
  bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);

  // Make sure we use a truncating store if that is necessary for it to be
  // legal. When generating the new widened store, if the first store's
  // pointer info cannot be reused, discard the pointer info except for the
  // address space, because the widened store can no longer be represented by
  // the original pointer info, which describes the narrower memory object.
  SDValue NewStore;
  if (!UseTrunc) {
    NewStore = DAG.getStore(
        NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
        CanReusePtrInfo
            ? FirstInChain->getPointerInfo()
            : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
        FirstInChain->getAlign(), *Flags, AAInfo);
  } else { // Must be realized as a trunc store
    EVT LegalizedStoredValTy =
        TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType());
    unsigned LegalizedStoreSize = LegalizedStoredValTy.getSizeInBits();
    ConstantSDNode *C = cast<ConstantSDNode>(StoredVal);
    SDValue ExtendedStoreVal =
        DAG.getConstant(C->getAPIntValue().zextOrTrunc(LegalizedStoreSize), DL,
                        LegalizedStoredValTy);
    NewStore = DAG.getTruncStore(
        NewChain, DL, ExtendedStoreVal, FirstInChain->getBasePtr(),
        CanReusePtrInfo
            ? FirstInChain->getPointerInfo()
            : MachinePointerInfo(FirstInChain->getPointerInfo().getAddrSpace()),
        StoredVal.getValueType() /*TVT*/, FirstInChain->getAlign(), *Flags,
        AAInfo);
  }

  // Replace all merged stores with the new store.
  for (unsigned i = 0; i < NumStores; ++i)
    CombineTo(StoreNodes[i].MemNode, NewStore);

  AddToWorklist(NewChain.getNode());
  return true;
}

void DAGCombiner::getStoreMergeCandidates(
    StoreSDNode *St, SmallVectorImpl<MemOpLink> &StoreNodes,
    SDNode *&RootNode) {
  // This holds the base pointer, index, and the offset in bytes from the base
  // pointer. We must have a base and an offset. Do not handle stores to undef
  // base pointers.
  BaseIndexOffset BasePtr = BaseIndexOffset::match(St, DAG);
  if (!BasePtr.getBase().getNode() || BasePtr.getBase().isUndef())
    return;

  SDValue Val = peekThroughBitcasts(St->getValue());
  StoreSource StoreSrc = getStoreSource(Val);
  assert(StoreSrc != StoreSource::Unknown && "Expected known source for store");

  // Match on loadbaseptr if relevant.
  EVT MemVT = St->getMemoryVT();
  BaseIndexOffset LBasePtr;
  EVT LoadVT;
  if (StoreSrc == StoreSource::Load) {
    auto *Ld = cast<LoadSDNode>(Val);
    LBasePtr = BaseIndexOffset::match(Ld, DAG);
    LoadVT = Ld->getMemoryVT();
    // Load and store should be the same type.
    if (MemVT != LoadVT)
      return;
    // Loads must only have one use.
    if (!Ld->hasNUsesOfValue(1, 0))
      return;
    // The memory operands must not be volatile/indexed/atomic.
    // TODO: May be able to relax for unordered atomics (see D66309)
    if (!Ld->isSimple() || Ld->isIndexed())
      return;
  }
  auto CandidateMatch = [&](StoreSDNode *Other, BaseIndexOffset &Ptr,
                            int64_t &Offset) -> bool {
    // The memory operands must not be volatile/indexed/atomic.
    // TODO: May be able to relax for unordered atomics (see D66309)
    if (!Other->isSimple() || Other->isIndexed())
      return false;
    // Don't mix temporal stores with non-temporal stores.
    if (St->isNonTemporal() != Other->isNonTemporal())
      return false;
    if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*St, *Other))
      return false;
    SDValue OtherBC = peekThroughBitcasts(Other->getValue());
    // Allow merging constants of different types as integers.
    bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
                                           : Other->getMemoryVT() != MemVT;
    switch (StoreSrc) {
    case StoreSource::Load: {
      if (NoTypeMatch)
        return false;
      // The Load's Base Ptr must also match.
      auto *OtherLd = dyn_cast<LoadSDNode>(OtherBC);
      if (!OtherLd)
        return false;
      BaseIndexOffset LPtr = BaseIndexOffset::match(OtherLd, DAG);
      if (LoadVT != OtherLd->getMemoryVT())
        return false;
      // Loads must only have one use.
      if (!OtherLd->hasNUsesOfValue(1, 0))
        return false;
      // The memory operands must not be volatile/indexed/atomic.
      // TODO: May be able to relax for unordered atomics (see D66309)
      if (!OtherLd->isSimple() || OtherLd->isIndexed())
        return false;
      // Don't mix temporal loads with non-temporal loads.
      if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
        return false;
      if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*cast<LoadSDNode>(Val),
                                                   *OtherLd))
        return false;
      if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
        return false;
      break;
    }
    case StoreSource::Constant:
      if (NoTypeMatch)
        return false;
      if (getStoreSource(OtherBC) != StoreSource::Constant)
        return false;
      break;
    case StoreSource::Extract:
      // Do not merge truncated stores here.
      if (Other->isTruncatingStore())
        return false;
      if (!MemVT.bitsEq(OtherBC.getValueType()))
        return false;
      if (OtherBC.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
          OtherBC.getOpcode() != ISD::EXTRACT_SUBVECTOR)
        return false;
      break;
    default:
      llvm_unreachable("Unhandled store source for merging");
    }
    Ptr = BaseIndexOffset::match(Other, DAG);
    return (BasePtr.equalBaseIndex(Ptr, DAG, Offset));
  };

  // Check whether the pair of StoreNode and RootNode has already bailed out
  // of the dependence check more times than the limit allows.
  auto OverLimitInDependenceCheck = [&](SDNode *StoreNode,
                                        SDNode *RootNode) -> bool {
    auto RootCount = StoreRootCountMap.find(StoreNode);
    return RootCount != StoreRootCountMap.end() &&
           RootCount->second.first == RootNode &&
           RootCount->second.second > StoreMergeDependenceLimit;
  };

  auto TryToAddCandidate = [&](SDNode::use_iterator UseIter) {
    // This must be a chain use.
    if (UseIter.getOperandNo() != 0)
      return;
    if (auto *OtherStore = dyn_cast<StoreSDNode>(*UseIter)) {
      BaseIndexOffset Ptr;
      int64_t PtrDiff;
      if (CandidateMatch(OtherStore, Ptr, PtrDiff) &&
          !OverLimitInDependenceCheck(OtherStore, RootNode))
        StoreNodes.push_back(MemOpLink(OtherStore, PtrDiff));
    }
  };

  // We are looking for a root node that is an ancestor to all mergeable
  // stores. We search up through a load, to our root, and then down
  // through all children. For instance, we will find Store{1,2,3} if
  // St is Store1, Store2, or Store3 where the root is not a load,
  // which is always true for non-volatile ops. TODO: Expand
  // the search to find all valid candidates through multiple layers of loads.
  //
  //          Root
  //     |-------|-------|
  //   Load    Load    Store3
  //     |       |
  //  Store1  Store2
  //
  // FIXME: We should be able to climb and
  // descend TokenFactors to find candidates as well.

  RootNode = St->getChain().getNode();

  unsigned NumNodesExplored = 0;
  const unsigned MaxSearchNodes = 1024;
  if (auto *Ldn = dyn_cast<LoadSDNode>(RootNode)) {
    RootNode = Ldn->getChain().getNode();
    for (auto I = RootNode->use_begin(), E = RootNode->use_end();
         I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored) {
      if (I.getOperandNo() == 0 && isa<LoadSDNode>(*I)) { // walk down chain
        for (auto I2 = (*I)->use_begin(), E2 = (*I)->use_end(); I2 != E2; ++I2)
          TryToAddCandidate(I2);
      }
      // Check stores that depend on the root (e.g. Store 3 in the chart above).
      if (I.getOperandNo() == 0 && isa<StoreSDNode>(*I)) {
        TryToAddCandidate(I);
      }
    }
  } else {
    for (auto I = RootNode->use_begin(), E = RootNode->use_end();
         I != E && NumNodesExplored < MaxSearchNodes; ++I, ++NumNodesExplored)
      TryToAddCandidate(I);
  }
}

// We need to check that merging these stores does not cause a loop in the
// DAG. Any store candidate may depend on another candidate indirectly through
// its operands. Check in parallel by searching up from operands of candidates.
bool DAGCombiner::checkMergeStoreCandidatesForDependencies(
    SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumStores,
    SDNode *RootNode) {
  // FIXME: We should be able to truncate a full search of predecessors by
  // doing a BFS and keeping tabs on the originating stores from which the
  // worklist nodes come, in a similar way to TokenFactor simplification.

  SmallPtrSet<const SDNode *, 32> Visited;
  SmallVector<const SDNode *, 8> Worklist;

  // RootNode is a predecessor to all candidates so we need not search
  // past it. Add RootNode (peeking through TokenFactors). Do not count
  // these towards size check.

  Worklist.push_back(RootNode);
  while (!Worklist.empty()) {
    auto N = Worklist.pop_back_val();
    if (!Visited.insert(N).second)
      continue; // Already present in Visited.
    if (N->getOpcode() == ISD::TokenFactor) {
      for (SDValue Op : N->ops())
        Worklist.push_back(Op.getNode());
    }
  }

  // Don't count pruning nodes towards max.
  unsigned int Max = 1024 + Visited.size();
  // Search Ops of store candidates.
  for (unsigned i = 0; i < NumStores; ++i) {
    SDNode *N = StoreNodes[i].MemNode;
    // Of the 4 Store Operands:
    // * Chain (Op 0) -> We have already considered these
    //                   in candidate selection, but only by following the
    //                   chain dependencies. We could still have a chain
    //                   dependency to a load, that has a non-chain dep to
    //                   another load, that depends on a store, etc. So it is
    //                   possible to have dependencies that consist of a mix
    //                   of chain and non-chain deps, and we need to include
    //                   chain operands in the analysis here.
    // * Value (Op 1) -> Cycles may happen (e.g. through load chains)
    // * Address (Op 2) -> Merged addresses may only vary by a fixed constant,
    //                     but aren't necessarily from the same base node, so
    //                     cycles are possible (e.g. via indexed store).
    // * (Op 3) -> Represents the pre or post-indexing offset (or undef for
    //             non-indexed stores). Not constant on all targets (e.g. ARM)
    //             and so can participate in a cycle.
    for (unsigned j = 0; j < N->getNumOperands(); ++j)
      Worklist.push_back(N->getOperand(j).getNode());
  }
  // Search through DAG. We can stop early if we find a store node.
  for (unsigned i = 0; i < NumStores; ++i)
    if (SDNode::hasPredecessorHelper(StoreNodes[i].MemNode, Visited, Worklist,
                                     Max)) {
      // If the search bails out, record the StoreNode and RootNode in the
      // StoreRootCountMap. If we have seen the pair more times than the
      // limit allows, we won't add the StoreNode into the StoreNodes set
      // again.
      if (Visited.size() >= Max) {
        auto &RootCount = StoreRootCountMap[StoreNodes[i].MemNode];
        if (RootCount.first == RootNode)
          RootCount.second++;
        else
          RootCount = {RootNode, 1};
      }
      return false;
    }
  return true;
}
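
// Illustrative example for the routine below (assumed, already-sorted
// offsets): for candidate offsets {0, 4, 8, 20} with ElementSizeBytes == 4,
// it returns 3, covering the consecutive run at offsets 0, 4, and 8.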
unsigned
DAGCombiner::getConsecutiveStores(SmallVectorImpl<MemOpLink> &StoreNodes,
                                  int64_t ElementSizeBytes) const {
  while (true) {
    // Find a store past the width of the first store.
    size_t StartIdx = 0;
    while ((StartIdx + 1 < StoreNodes.size()) &&
           StoreNodes[StartIdx].OffsetFromBase + ElementSizeBytes !=
               StoreNodes[StartIdx + 1].OffsetFromBase)
      ++StartIdx;

    // Bail if we don't have enough candidates to merge.
    if (StartIdx + 1 >= StoreNodes.size())
      return 0;

    // Trim stores that overlapped with the first store.
    if (StartIdx)
      StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + StartIdx);

    // Scan the memory operations on the chain and find the first
    // non-consecutive store memory address.
    unsigned NumConsecutiveStores = 1;
    int64_t StartAddress = StoreNodes[0].OffsetFromBase;
    // Check that the addresses are consecutive starting from the second
    // element in the list of stores.
    for (unsigned i = 1, e = StoreNodes.size(); i < e; ++i) {
      int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
      if (CurrAddress - StartAddress != (ElementSizeBytes * i))
        break;
      NumConsecutiveStores = i + 1;
    }
    if (NumConsecutiveStores > 1)
      return NumConsecutiveStores;

    // There are no consecutive stores at the start of the list.
    // Remove the first store and try again.
    StoreNodes.erase(StoreNodes.begin(), StoreNodes.begin() + 1);
  }
}
20361
20362bool DAGCombiner::tryStoreMergeOfConstants(
20363 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
20364 EVT MemVT, SDNode *RootNode, bool AllowVectors) {
20365 LLVMContext &Context = *DAG.getContext();
20366 const DataLayout &DL = DAG.getDataLayout();
20367 int64_t ElementSizeBytes = MemVT.getStoreSize();
20368 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20369 bool MadeChange = false;
20370
20371 // Store the constants into memory as one consecutive store.
20372 while (NumConsecutiveStores >= 2) {
20373 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20374 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20375 Align FirstStoreAlign = FirstInChain->getAlign();
20376 unsigned LastLegalType = 1;
20377 unsigned LastLegalVectorType = 1;
20378 bool LastIntegerTrunc = false;
20379 bool NonZero = false;
20380 unsigned FirstZeroAfterNonZero = NumConsecutiveStores;
20381 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20382 StoreSDNode *ST = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
20383 SDValue StoredVal = ST->getValue();
20384 bool IsElementZero = false;
20385 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Val&: StoredVal))
20386 IsElementZero = C->isZero();
20387 else if (ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Val&: StoredVal))
20388 IsElementZero = C->getConstantFPValue()->isNullValue();
20389 else if (ISD::isBuildVectorAllZeros(N: StoredVal.getNode()))
20390 IsElementZero = true;
20391 if (IsElementZero) {
20392 if (NonZero && FirstZeroAfterNonZero == NumConsecutiveStores)
20393 FirstZeroAfterNonZero = i;
20394 }
20395 NonZero |= !IsElementZero;
20396
20397 // Find a legal type for the constant store.
20398 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
20399 EVT StoreTy = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
20400 unsigned IsFast = 0;
20401
20402 // Break early when size is too large to be legal.
20403 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
20404 break;
20405
20406 if (TLI.isTypeLegal(VT: StoreTy) &&
20407 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
20408 MF: DAG.getMachineFunction()) &&
20409 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20410 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
20411 IsFast) {
20412 LastIntegerTrunc = false;
20413 LastLegalType = i + 1;
20414 // Or check whether a truncstore is legal.
20415 } else if (TLI.getTypeAction(Context, VT: StoreTy) ==
20416 TargetLowering::TypePromoteInteger) {
20417 EVT LegalizedStoredValTy =
20418 TLI.getTypeToTransformTo(Context, VT: StoredVal.getValueType());
20419 if (TLI.isTruncStoreLegal(ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
20420 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: LegalizedStoredValTy,
20421 MF: DAG.getMachineFunction()) &&
20422 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20423 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
20424 IsFast) {
20425 LastIntegerTrunc = true;
20426 LastLegalType = i + 1;
20427 }
20428 }
20429
20430 // We only use vectors if the target allows it and the function is not
20431 // marked with the noimplicitfloat attribute.
20432 if (TLI.storeOfVectorConstantIsCheap(IsZero: !NonZero, MemVT, NumElem: i + 1, AddrSpace: FirstStoreAS) &&
20433 AllowVectors) {
20434 // Find a legal type for the vector store.
20435 unsigned Elts = (i + 1) * NumMemElts;
20436 EVT Ty = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
20437 if (TLI.isTypeLegal(VT: Ty) && TLI.isTypeLegal(VT: MemVT) &&
20438 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: Ty, MF: DAG.getMachineFunction()) &&
20439 TLI.allowsMemoryAccess(Context, DL, VT: Ty,
20440 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
20441 IsFast)
20442 LastLegalVectorType = i + 1;
20443 }
20444 }
20445
20446 bool UseVector = (LastLegalVectorType > LastLegalType) && AllowVectors;
20447 unsigned NumElem = (UseVector) ? LastLegalVectorType : LastLegalType;
20448 bool UseTrunc = LastIntegerTrunc && !UseVector;
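    // Prefer the vector form only if it covers strictly more stores than the
    // widest legal integer form; on a tie the integer form is used.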
20449
20450 // Check if we found a legal integer type that creates a meaningful
20451 // merge.
20452 if (NumElem < 2) {
      // We know that the candidate stores are in order and of the correct
      // shape. While there is no mergeable sequence starting at the
      // beginning, one may start later in the sequence. The only reason a
      // merge of size N could have failed where another of the same size
      // would not have is if the alignment has improved or we have dropped
      // a non-zero value. Drop as many candidates as we can here.
20460 unsigned NumSkip = 1;
20461 while ((NumSkip < NumConsecutiveStores) &&
20462 (NumSkip < FirstZeroAfterNonZero) &&
20463 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
20464 NumSkip++;
20465
20466 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
20467 NumConsecutiveStores -= NumSkip;
20468 continue;
20469 }
20470
20471 // Check that we can merge these candidates without causing a cycle.
20472 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumElem,
20473 RootNode)) {
20474 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
20475 NumConsecutiveStores -= NumElem;
20476 continue;
20477 }
20478
20479 MadeChange |= mergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumStores: NumElem,
20480 /*IsConstantSrc*/ true,
20481 UseVector, UseTrunc);
20482
20483 // Remove merged stores for next iteration.
20484 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
20485 NumConsecutiveStores -= NumElem;
20486 }
20487 return MadeChange;
20488}
20489
20490bool DAGCombiner::tryStoreMergeOfExtracts(
20491 SmallVectorImpl<MemOpLink> &StoreNodes, unsigned NumConsecutiveStores,
20492 EVT MemVT, SDNode *RootNode) {
20493 LLVMContext &Context = *DAG.getContext();
20494 const DataLayout &DL = DAG.getDataLayout();
20495 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20496 bool MadeChange = false;
20497
20498 // Loop on Consecutive Stores on success.
20499 while (NumConsecutiveStores >= 2) {
20500 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20501 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20502 Align FirstStoreAlign = FirstInChain->getAlign();
20503 unsigned NumStoresToMerge = 1;
20504 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20505 // Find a legal type for the vector store.
20506 unsigned Elts = (i + 1) * NumMemElts;
20507 EVT Ty = EVT::getVectorVT(Context&: *DAG.getContext(), VT: MemVT.getScalarType(), NumElements: Elts);
20508 unsigned IsFast = 0;
20509
20510 // Break early when size is too large to be legal.
20511 if (Ty.getSizeInBits() > MaximumLegalStoreInBits)
20512 break;
20513
20514 if (TLI.isTypeLegal(VT: Ty) &&
20515 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: Ty, MF: DAG.getMachineFunction()) &&
20516 TLI.allowsMemoryAccess(Context, DL, VT: Ty,
20517 MMO: *FirstInChain->getMemOperand(), Fast: &IsFast) &&
20518 IsFast)
20519 NumStoresToMerge = i + 1;
20520 }
20521
    // Check if we found a legal vector type that creates a meaningful
    // merge.
20524 if (NumStoresToMerge < 2) {
      // We know that the candidate stores are in order and of the correct
      // shape. While there is no mergeable sequence starting at the
      // beginning, one may start later in the sequence. The only reason a
      // merge of size N could have failed where another of the same size
      // would not have is if the alignment has improved. Drop as many
      // candidates as we can here.
20531 unsigned NumSkip = 1;
20532 while ((NumSkip < NumConsecutiveStores) &&
20533 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
20534 NumSkip++;
20535
20536 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
20537 NumConsecutiveStores -= NumSkip;
20538 continue;
20539 }
20540
20541 // Check that we can merge these candidates without causing a cycle.
20542 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumStoresToMerge,
20543 RootNode)) {
20544 StoreNodes.erase(CS: StoreNodes.begin(),
20545 CE: StoreNodes.begin() + NumStoresToMerge);
20546 NumConsecutiveStores -= NumStoresToMerge;
20547 continue;
20548 }
20549
20550 MadeChange |= mergeStoresOfConstantsOrVecElts(
20551 StoreNodes, MemVT, NumStores: NumStoresToMerge, /*IsConstantSrc*/ false,
20552 /*UseVector*/ true, /*UseTrunc*/ false);
20553
20554 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumStoresToMerge);
20555 NumConsecutiveStores -= NumStoresToMerge;
20556 }
20557 return MadeChange;
20558}
20559
20560bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
20561 unsigned NumConsecutiveStores, EVT MemVT,
20562 SDNode *RootNode, bool AllowVectors,
20563 bool IsNonTemporalStore,
20564 bool IsNonTemporalLoad) {
20565 LLVMContext &Context = *DAG.getContext();
20566 const DataLayout &DL = DAG.getDataLayout();
20567 int64_t ElementSizeBytes = MemVT.getStoreSize();
20568 unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;
20569 bool MadeChange = false;
20570
20571 // Look for load nodes which are used by the stored values.
20572 SmallVector<MemOpLink, 8> LoadNodes;
20573
  // Find acceptable loads. The loads must share the same chain (token
  // factor), must not be zero-extended, volatile, or indexed, and they must
  // be consecutive.
20576 BaseIndexOffset LdBasePtr;
20577
20578 for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
20579 StoreSDNode *St = cast<StoreSDNode>(Val: StoreNodes[i].MemNode);
20580 SDValue Val = peekThroughBitcasts(V: St->getValue());
20581 LoadSDNode *Ld = cast<LoadSDNode>(Val);
20582
20583 BaseIndexOffset LdPtr = BaseIndexOffset::match(N: Ld, DAG);
    // If this is not the first pointer that we are checking.
20585 int64_t LdOffset = 0;
20586 if (LdBasePtr.getBase().getNode()) {
20587 // The base ptr must be the same.
20588 if (!LdBasePtr.equalBaseIndex(Other: LdPtr, DAG, Off&: LdOffset))
20589 break;
20590 } else {
      // Remember the first base pointer; all later base pointers must match it.
20592 LdBasePtr = LdPtr;
20593 }
20594
20595 // We found a potential memory operand to merge.
20596 LoadNodes.push_back(Elt: MemOpLink(Ld, LdOffset));
20597 }
20598
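  // Repeatedly try to merge the longest legal run from the front of the
  // candidate lists, dropping leading entries that cannot be merged, until
  // fewer than two stores or loads remain.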
20599 while (NumConsecutiveStores >= 2 && LoadNodes.size() >= 2) {
20600 Align RequiredAlignment;
20601 bool NeedRotate = false;
20602 if (LoadNodes.size() == 2) {
20603 // If we have load/store pair instructions and we only have two values,
20604 // don't bother merging.
20605 if (TLI.hasPairedLoad(MemVT, RequiredAlignment) &&
20606 StoreNodes[0].MemNode->getAlign() >= RequiredAlignment) {
20607 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + 2);
20608 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + 2);
20609 break;
20610 }
20611 // If the loads are reversed, see if we can rotate the halves into place.
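      // E.g. two i16 loads from offsets {2, 0} feeding stores to offsets
      // {0, 2} can become one i32 load plus a rotate by 16 bits before the
      // merged store.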
20612 int64_t Offset0 = LoadNodes[0].OffsetFromBase;
20613 int64_t Offset1 = LoadNodes[1].OffsetFromBase;
20614 EVT PairVT = EVT::getIntegerVT(Context, BitWidth: ElementSizeBytes * 8 * 2);
20615 if (Offset0 - Offset1 == ElementSizeBytes &&
20616 (hasOperation(Opcode: ISD::ROTL, VT: PairVT) ||
20617 hasOperation(Opcode: ISD::ROTR, VT: PairVT))) {
20618 std::swap(a&: LoadNodes[0], b&: LoadNodes[1]);
20619 NeedRotate = true;
20620 }
20621 }
20622 LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode;
20623 unsigned FirstStoreAS = FirstInChain->getAddressSpace();
20624 Align FirstStoreAlign = FirstInChain->getAlign();
20625 LoadSDNode *FirstLoad = cast<LoadSDNode>(Val: LoadNodes[0].MemNode);
20626
    // Scan the memory operations on the chain and find the first
    // non-consecutive load memory address. The variable below holds the index
    // of the last consecutive load in the load/store node arrays.
20630
20631 unsigned LastConsecutiveLoad = 1;
20632
    // These variables refer to counts rather than indices into the array.
20634 unsigned LastLegalVectorType = 1;
20635 unsigned LastLegalIntegerType = 1;
20636 bool isDereferenceable = true;
20637 bool DoIntegerTruncate = false;
20638 int64_t StartAddress = LoadNodes[0].OffsetFromBase;
20639 SDValue LoadChain = FirstLoad->getChain();
20640 for (unsigned i = 1; i < LoadNodes.size(); ++i) {
20641 // All loads must share the same chain.
20642 if (LoadNodes[i].MemNode->getChain() != LoadChain)
20643 break;
20644
20645 int64_t CurrAddress = LoadNodes[i].OffsetFromBase;
20646 if (CurrAddress - StartAddress != (ElementSizeBytes * i))
20647 break;
20648 LastConsecutiveLoad = i;
20649
20650 if (isDereferenceable && !LoadNodes[i].MemNode->isDereferenceable())
20651 isDereferenceable = false;
20652
20653 // Find a legal type for the vector store.
20654 unsigned Elts = (i + 1) * NumMemElts;
20655 EVT StoreTy = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
20656
20657 // Break early when size is too large to be legal.
20658 if (StoreTy.getSizeInBits() > MaximumLegalStoreInBits)
20659 break;
20660
20661 unsigned IsFastSt = 0;
20662 unsigned IsFastLd = 0;
20663 // Don't try vector types if we need a rotate. We may still fail the
20664 // legality checks for the integer type, but we can't handle the rotate
20665 // case with vectors.
20666 // FIXME: We could use a shuffle in place of the rotate.
20667 if (!NeedRotate && TLI.isTypeLegal(VT: StoreTy) &&
20668 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
20669 MF: DAG.getMachineFunction()) &&
20670 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20671 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
20672 IsFastSt &&
20673 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20674 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
20675 IsFastLd) {
20676 LastLegalVectorType = i + 1;
20677 }
20678
20679 // Find a legal type for the integer store.
20680 unsigned SizeInBits = (i + 1) * ElementSizeBytes * 8;
20681 StoreTy = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
20682 if (TLI.isTypeLegal(VT: StoreTy) &&
20683 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: StoreTy,
20684 MF: DAG.getMachineFunction()) &&
20685 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20686 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
20687 IsFastSt &&
20688 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20689 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
20690 IsFastLd) {
20691 LastLegalIntegerType = i + 1;
20692 DoIntegerTruncate = false;
20693 // Or check whether a truncstore and extload is legal.
20694 } else if (TLI.getTypeAction(Context, VT: StoreTy) ==
20695 TargetLowering::TypePromoteInteger) {
20696 EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, VT: StoreTy);
20697 if (TLI.isTruncStoreLegal(ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
20698 TLI.canMergeStoresTo(AS: FirstStoreAS, MemVT: LegalizedStoredValTy,
20699 MF: DAG.getMachineFunction()) &&
20700 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
20701 TLI.isLoadExtLegal(ExtType: ISD::SEXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
20702 TLI.isLoadExtLegal(ExtType: ISD::EXTLOAD, ValVT: LegalizedStoredValTy, MemVT: StoreTy) &&
20703 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20704 MMO: *FirstInChain->getMemOperand(), Fast: &IsFastSt) &&
20705 IsFastSt &&
20706 TLI.allowsMemoryAccess(Context, DL, VT: StoreTy,
20707 MMO: *FirstLoad->getMemOperand(), Fast: &IsFastLd) &&
20708 IsFastLd) {
20709 LastLegalIntegerType = i + 1;
20710 DoIntegerTruncate = true;
20711 }
20712 }
20713 }
20714
20715 // Only use vector types if the vector type is larger than the integer
20716 // type. If they are the same, use integers.
20717 bool UseVectorTy =
20718 LastLegalVectorType > LastLegalIntegerType && AllowVectors;
20719 unsigned LastLegalType =
20720 std::max(a: LastLegalVectorType, b: LastLegalIntegerType);
20721
    // Add +1 because LastConsecutiveLoad is a zero-based index, while NumElem
    // is an element count.
20724 unsigned NumElem = std::min(a: NumConsecutiveStores, b: LastConsecutiveLoad + 1);
20725 NumElem = std::min(a: LastLegalType, b: NumElem);
20726 Align FirstLoadAlign = FirstLoad->getAlign();
20727
20728 if (NumElem < 2) {
      // We know that the candidate stores are in order and of the correct
      // shape. While there is no mergeable sequence starting at the
      // beginning, one may start later in the sequence. The only reason a
      // merge of size N could have failed where another of the same size
      // would not have is if the alignment of either the load or the store
      // has improved. Drop as many candidates as we can here.
20736 unsigned NumSkip = 1;
20737 while ((NumSkip < LoadNodes.size()) &&
20738 (LoadNodes[NumSkip].MemNode->getAlign() <= FirstLoadAlign) &&
20739 (StoreNodes[NumSkip].MemNode->getAlign() <= FirstStoreAlign))
20740 NumSkip++;
20741 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumSkip);
20742 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumSkip);
20743 NumConsecutiveStores -= NumSkip;
20744 continue;
20745 }
20746
20747 // Check that we can merge these candidates without causing a cycle.
20748 if (!checkMergeStoreCandidatesForDependencies(StoreNodes, NumStores: NumElem,
20749 RootNode)) {
20750 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
20751 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumElem);
20752 NumConsecutiveStores -= NumElem;
20753 continue;
20754 }
20755
20756 // Find if it is better to use vectors or integers to load and store
20757 // to memory.
20758 EVT JointMemOpVT;
20759 if (UseVectorTy) {
20760 // Find a legal type for the vector store.
20761 unsigned Elts = NumElem * NumMemElts;
20762 JointMemOpVT = EVT::getVectorVT(Context, VT: MemVT.getScalarType(), NumElements: Elts);
20763 } else {
20764 unsigned SizeInBits = NumElem * ElementSizeBytes * 8;
20765 JointMemOpVT = EVT::getIntegerVT(Context, BitWidth: SizeInBits);
20766 }
20767
20768 SDLoc LoadDL(LoadNodes[0].MemNode);
20769 SDLoc StoreDL(StoreNodes[0].MemNode);
20770
20771 // The merged loads are required to have the same incoming chain, so
20772 // using the first's chain is acceptable.
20773
20774 SDValue NewStoreChain = getMergeStoreChains(StoreNodes, NumStores: NumElem);
20775 bool CanReusePtrInfo = hasSameUnderlyingObj(StoreNodes);
20776 AddToWorklist(N: NewStoreChain.getNode());
20777
20778 MachineMemOperand::Flags LdMMOFlags =
20779 isDereferenceable ? MachineMemOperand::MODereferenceable
20780 : MachineMemOperand::MONone;
20781 if (IsNonTemporalLoad)
20782 LdMMOFlags |= MachineMemOperand::MONonTemporal;
20783
20784 LdMMOFlags |= TLI.getTargetMMOFlags(Node: *FirstLoad);
20785
20786 MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
20787 ? MachineMemOperand::MONonTemporal
20788 : MachineMemOperand::MONone;
20789
20790 StMMOFlags |= TLI.getTargetMMOFlags(Node: *StoreNodes[0].MemNode);
20791
20792 SDValue NewLoad, NewStore;
20793 if (UseVectorTy || !DoIntegerTruncate) {
20794 NewLoad = DAG.getLoad(
20795 VT: JointMemOpVT, dl: LoadDL, Chain: FirstLoad->getChain(), Ptr: FirstLoad->getBasePtr(),
20796 PtrInfo: FirstLoad->getPointerInfo(), Alignment: FirstLoadAlign, MMOFlags: LdMMOFlags);
20797 SDValue StoreOp = NewLoad;
20798 if (NeedRotate) {
20799 unsigned LoadWidth = ElementSizeBytes * 8 * 2;
20800 assert(JointMemOpVT == EVT::getIntegerVT(Context, LoadWidth) &&
20801 "Unexpected type for rotate-able load pair");
20802 SDValue RotAmt =
20803 DAG.getShiftAmountConstant(Val: LoadWidth / 2, VT: JointMemOpVT, DL: LoadDL);
20804 // Target can convert to the identical ROTR if it does not have ROTL.
20805 StoreOp = DAG.getNode(Opcode: ISD::ROTL, DL: LoadDL, VT: JointMemOpVT, N1: NewLoad, N2: RotAmt);
20806 }
20807 NewStore = DAG.getStore(
20808 Chain: NewStoreChain, dl: StoreDL, Val: StoreOp, Ptr: FirstInChain->getBasePtr(),
20809 PtrInfo: CanReusePtrInfo ? FirstInChain->getPointerInfo()
20810 : MachinePointerInfo(FirstStoreAS),
20811 Alignment: FirstStoreAlign, MMOFlags: StMMOFlags);
20812 } else { // This must be the truncstore/extload case
20813 EVT ExtendedTy =
20814 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: JointMemOpVT);
20815 NewLoad = DAG.getExtLoad(ExtType: ISD::EXTLOAD, dl: LoadDL, VT: ExtendedTy,
20816 Chain: FirstLoad->getChain(), Ptr: FirstLoad->getBasePtr(),
20817 PtrInfo: FirstLoad->getPointerInfo(), MemVT: JointMemOpVT,
20818 Alignment: FirstLoadAlign, MMOFlags: LdMMOFlags);
20819 NewStore = DAG.getTruncStore(
20820 Chain: NewStoreChain, dl: StoreDL, Val: NewLoad, Ptr: FirstInChain->getBasePtr(),
20821 PtrInfo: CanReusePtrInfo ? FirstInChain->getPointerInfo()
20822 : MachinePointerInfo(FirstStoreAS),
20823 SVT: JointMemOpVT, Alignment: FirstInChain->getAlign(),
20824 MMOFlags: FirstInChain->getMemOperand()->getFlags());
20825 }
20826
20827 // Transfer chain users from old loads to the new load.
20828 for (unsigned i = 0; i < NumElem; ++i) {
20829 LoadSDNode *Ld = cast<LoadSDNode>(Val: LoadNodes[i].MemNode);
20830 DAG.ReplaceAllUsesOfValueWith(From: SDValue(Ld, 1),
20831 To: SDValue(NewLoad.getNode(), 1));
20832 }
20833
20834 // Replace all stores with the new store. Recursively remove corresponding
20835 // values if they are no longer used.
20836 for (unsigned i = 0; i < NumElem; ++i) {
20837 SDValue Val = StoreNodes[i].MemNode->getOperand(Num: 1);
20838 CombineTo(N: StoreNodes[i].MemNode, Res: NewStore);
20839 if (Val->use_empty())
20840 recursivelyDeleteUnusedNodes(N: Val.getNode());
20841 }
20842
20843 MadeChange = true;
20844 StoreNodes.erase(CS: StoreNodes.begin(), CE: StoreNodes.begin() + NumElem);
20845 LoadNodes.erase(CS: LoadNodes.begin(), CE: LoadNodes.begin() + NumElem);
20846 NumConsecutiveStores -= NumElem;
20847 }
20848 return MadeChange;
20849}
20850
20851bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
20852 if (OptLevel == CodeGenOptLevel::None || !EnableStoreMerging)
20853 return false;
20854
20855 // TODO: Extend this function to merge stores of scalable vectors.
20856 // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
20857 // store since we know <vscale x 16 x i8> is exactly twice as large as
20858 // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
20859 EVT MemVT = St->getMemoryVT();
20860 if (MemVT.isScalableVT())
20861 return false;
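  // Merging at least two stores requires a type of twice the element width,
  // so bail out if even that would exceed the widest legal store.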
20862 if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
20863 return false;
20864
20865 // This function cannot currently deal with non-byte-sized memory sizes.
20866 int64_t ElementSizeBytes = MemVT.getStoreSize();
20867 if (ElementSizeBytes * 8 != (int64_t)MemVT.getSizeInBits())
20868 return false;
20869
20870 // Do not bother looking at stored values that are not constants, loads, or
20871 // extracted vector elements.
20872 SDValue StoredVal = peekThroughBitcasts(V: St->getValue());
20873 const StoreSource StoreSrc = getStoreSource(StoreVal: StoredVal);
20874 if (StoreSrc == StoreSource::Unknown)
20875 return false;
20876
20877 SmallVector<MemOpLink, 8> StoreNodes;
20878 SDNode *RootNode;
  // Find potential store merge candidates by searching through the chain
  // sub-DAG.
20880 getStoreMergeCandidates(St, StoreNodes, RootNode);
20881
20882 // Check if there is anything to merge.
20883 if (StoreNodes.size() < 2)
20884 return false;
20885
20886 // Sort the memory operands according to their distance from the
20887 // base pointer.
20888 llvm::sort(C&: StoreNodes, Comp: [](MemOpLink LHS, MemOpLink RHS) {
20889 return LHS.OffsetFromBase < RHS.OffsetFromBase;
20890 });
20891
20892 bool AllowVectors = !DAG.getMachineFunction().getFunction().hasFnAttribute(
20893 Attribute::NoImplicitFloat);
20894 bool IsNonTemporalStore = St->isNonTemporal();
20895 bool IsNonTemporalLoad = StoreSrc == StoreSource::Load &&
20896 cast<LoadSDNode>(Val&: StoredVal)->isNonTemporal();
20897
  // Store merging attempts to merge the lowest stores first. This generally
  // works out well when the merge succeeds, as the remaining stores are
  // checked after the first collection of stores is merged. However, if a
  // non-mergeable store is found first, e.g., {p[-2], p[0], p[1], p[2],
  // p[3]}, we would fail and miss the subsequent mergeable cases. To prevent
  // this, such stores are pruned from the front of StoreNodes here.
20905 bool MadeChange = false;
20906 while (StoreNodes.size() > 1) {
20907 unsigned NumConsecutiveStores =
20908 getConsecutiveStores(StoreNodes, ElementSizeBytes);
20909 // There are no more stores in the list to examine.
20910 if (NumConsecutiveStores == 0)
20911 return MadeChange;
20912
20913 // We have at least 2 consecutive stores. Try to merge them.
20914 assert(NumConsecutiveStores >= 2 && "Expected at least 2 stores");
20915 switch (StoreSrc) {
20916 case StoreSource::Constant:
20917 MadeChange |= tryStoreMergeOfConstants(StoreNodes, NumConsecutiveStores,
20918 MemVT, RootNode, AllowVectors);
20919 break;
20920
20921 case StoreSource::Extract:
20922 MadeChange |= tryStoreMergeOfExtracts(StoreNodes, NumConsecutiveStores,
20923 MemVT, RootNode);
20924 break;
20925
20926 case StoreSource::Load:
20927 MadeChange |= tryStoreMergeOfLoads(StoreNodes, NumConsecutiveStores,
20928 MemVT, RootNode, AllowVectors,
20929 IsNonTemporalStore, IsNonTemporalLoad);
20930 break;
20931
20932 default:
20933 llvm_unreachable("Unhandled store source type");
20934 }
20935 }
20936 return MadeChange;
20937}
20938
20939SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) {
20940 SDLoc SL(ST);
20941 SDValue ReplStore;
20942
20943 // Replace the chain to avoid dependency.
20944 if (ST->isTruncatingStore()) {
20945 ReplStore = DAG.getTruncStore(Chain: BetterChain, dl: SL, Val: ST->getValue(),
20946 Ptr: ST->getBasePtr(), SVT: ST->getMemoryVT(),
20947 MMO: ST->getMemOperand());
20948 } else {
20949 ReplStore = DAG.getStore(Chain: BetterChain, dl: SL, Val: ST->getValue(), Ptr: ST->getBasePtr(),
20950 MMO: ST->getMemOperand());
20951 }
20952
20953 // Create token to keep both nodes around.
20954 SDValue Token = DAG.getNode(ISD::TokenFactor, SL,
20955 MVT::Other, ST->getChain(), ReplStore);
20956
20957 // Make sure the new and old chains are cleaned up.
20958 AddToWorklist(N: Token.getNode());
20959
20960 // Don't add users to work list.
20961 return CombineTo(N: ST, Res: Token, AddTo: false);
20962}
20963
20964SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) {
20965 SDValue Value = ST->getValue();
20966 if (Value.getOpcode() == ISD::TargetConstantFP)
20967 return SDValue();
20968
20969 if (!ISD::isNormalStore(N: ST))
20970 return SDValue();
20971
20972 SDLoc DL(ST);
20973
20974 SDValue Chain = ST->getChain();
20975 SDValue Ptr = ST->getBasePtr();
20976
20977 const ConstantFPSDNode *CFP = cast<ConstantFPSDNode>(Val&: Value);
20978
20979 // NOTE: If the original store is volatile, this transform must not increase
20980 // the number of stores. For example, on x86-32 an f64 can be stored in one
20981 // processor operation but an i64 (which is not legal) requires two. So the
20982 // transform should not be done in this case.
20983
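  // Replace the FP store with an integer store of the constant's bit pattern
  // when that is legal, e.g. 'store float 1.0' -> 'store i32 0x3F800000';
  // an f64 may instead be split into two i32 stores.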
20984 SDValue Tmp;
20985 switch (CFP->getSimpleValueType(ResNo: 0).SimpleTy) {
20986 default:
20987 llvm_unreachable("Unknown FP type");
20988 case MVT::f16: // We don't do this for these yet.
20989 case MVT::bf16:
20990 case MVT::f80:
20991 case MVT::f128:
20992 case MVT::ppcf128:
20993 return SDValue();
20994 case MVT::f32:
20995 if ((isTypeLegal(MVT::i32) && !LegalOperations && ST->isSimple()) ||
20996 TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) {
20997 Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF().
20998 bitcastToAPInt().getZExtValue(), SDLoc(CFP),
20999 MVT::i32);
21000 return DAG.getStore(Chain, dl: DL, Val: Tmp, Ptr, MMO: ST->getMemOperand());
21001 }
21002
21003 return SDValue();
21004 case MVT::f64:
21005 if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations &&
21006 ST->isSimple()) ||
21007 TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) {
21008 Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt().
21009 getZExtValue(), SDLoc(CFP), MVT::i64);
21010 return DAG.getStore(Chain, dl: DL, Val: Tmp,
21011 Ptr, MMO: ST->getMemOperand());
21012 }
21013
21014 if (ST->isSimple() && TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32) &&
21015 !TLI.isFPImmLegal(CFP->getValueAPF(), MVT::f64)) {
21016 // Many FP stores are not made apparent until after legalize, e.g. for
21017 // argument passing. Since this is so common, custom legalize the
21018 // 64-bit integer store into two 32-bit stores.
21019 uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue();
21020 SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32);
21021 SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32);
21022 if (DAG.getDataLayout().isBigEndian())
21023 std::swap(a&: Lo, b&: Hi);
21024
21025 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
21026 AAMDNodes AAInfo = ST->getAAInfo();
21027
21028 SDValue St0 = DAG.getStore(Chain, dl: DL, Val: Lo, Ptr, PtrInfo: ST->getPointerInfo(),
21029 Alignment: ST->getOriginalAlign(), MMOFlags, AAInfo);
21030 Ptr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: 4), DL);
21031 SDValue St1 = DAG.getStore(Chain, dl: DL, Val: Hi, Ptr,
21032 PtrInfo: ST->getPointerInfo().getWithOffset(O: 4),
21033 Alignment: ST->getOriginalAlign(), MMOFlags, AAInfo);
21034 return DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
21035 St0, St1);
21036 }
21037
21038 return SDValue();
21039 }
21040}
21041
21042// (store (insert_vector_elt (load p), x, i), p) -> (store x, p+offset)
21043//
// If the stored value is a load from the same address with a single element
// inserted into it, and nothing else on the chain uses the load in between,
// then the vector store only changes that one element in memory and can be
// replaced with a store of just the single scalar element.
21047SDValue DAGCombiner::replaceStoreOfInsertLoad(StoreSDNode *ST) {
21048 SDLoc DL(ST);
21049 SDValue Value = ST->getValue();
21050 SDValue Ptr = ST->getBasePtr();
21051 SDValue Chain = ST->getChain();
21052 if (Value.getOpcode() != ISD::INSERT_VECTOR_ELT || !Value.hasOneUse())
21053 return SDValue();
21054
21055 SDValue Elt = Value.getOperand(i: 1);
21056 SDValue Idx = Value.getOperand(i: 2);
21057
21058 // If the element isn't byte sized or is implicitly truncated then we can't
21059 // compute an offset.
21060 EVT EltVT = Elt.getValueType();
21061 if (!EltVT.isByteSized() ||
21062 EltVT != Value.getOperand(i: 0).getValueType().getVectorElementType())
21063 return SDValue();
21064
21065 auto *Ld = dyn_cast<LoadSDNode>(Val: Value.getOperand(i: 0));
21066 if (!Ld || Ld->getBasePtr() != Ptr ||
21067 ST->getMemoryVT() != Ld->getMemoryVT() || !ST->isSimple() ||
21068 !ISD::isNormalStore(N: ST) ||
21069 Ld->getAddressSpace() != ST->getAddressSpace() ||
21070 !Chain.reachesChainWithoutSideEffects(Dest: SDValue(Ld, 1)))
21071 return SDValue();
21072
21073 unsigned IsFast;
21074 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(),
21075 VT: Elt.getValueType(), AddrSpace: ST->getAddressSpace(),
21076 Alignment: ST->getAlign(), Flags: ST->getMemOperand()->getFlags(),
21077 Fast: &IsFast) ||
21078 !IsFast)
21079 return SDValue();
21080
21081 MachinePointerInfo PointerInfo(ST->getAddressSpace());
21082
21083 // If the offset is a known constant then try to recover the pointer
21084 // info
21085 SDValue NewPtr;
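  // With a constant index the element's byte offset is Idx * sizeof(element),
  // so the store's pointer info can be reused with that offset; otherwise the
  // element pointer has to be computed generically.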
21086 if (auto *CIdx = dyn_cast<ConstantSDNode>(Val&: Idx)) {
21087 unsigned COffset = CIdx->getSExtValue() * EltVT.getSizeInBits() / 8;
21088 NewPtr = DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: COffset), DL);
21089 PointerInfo = ST->getPointerInfo().getWithOffset(O: COffset);
21090 } else {
21091 NewPtr = TLI.getVectorElementPointer(DAG, VecPtr: Ptr, VecVT: Value.getValueType(), Index: Idx);
21092 }
21093
21094 return DAG.getStore(Chain, dl: DL, Val: Elt, Ptr: NewPtr, PtrInfo: PointerInfo, Alignment: ST->getAlign(),
21095 MMOFlags: ST->getMemOperand()->getFlags());
21096}
21097
21098SDValue DAGCombiner::visitSTORE(SDNode *N) {
21099 StoreSDNode *ST = cast<StoreSDNode>(Val: N);
21100 SDValue Chain = ST->getChain();
21101 SDValue Value = ST->getValue();
21102 SDValue Ptr = ST->getBasePtr();
21103
21104 // If this is a store of a bit convert, store the input value if the
21105 // resultant store does not need a higher alignment than the original.
21106 if (Value.getOpcode() == ISD::BITCAST && !ST->isTruncatingStore() &&
21107 ST->isUnindexed()) {
21108 EVT SVT = Value.getOperand(i: 0).getValueType();
21109 // If the store is volatile, we only want to change the store type if the
21110 // resulting store is legal. Otherwise we might increase the number of
21111 // memory accesses. We don't care if the original type was legal or not
21112 // as we assume software couldn't rely on the number of accesses of an
21113 // illegal type.
21114 // TODO: May be able to relax for unordered atomics (see D66309)
21115 if (((!LegalOperations && ST->isSimple()) ||
21116 TLI.isOperationLegal(Op: ISD::STORE, VT: SVT)) &&
21117 TLI.isStoreBitCastBeneficial(StoreVT: Value.getValueType(), BitcastVT: SVT,
21118 DAG, MMO: *ST->getMemOperand())) {
21119 return DAG.getStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Ptr,
21120 MMO: ST->getMemOperand());
21121 }
21122 }
21123
21124 // Turn 'store undef, Ptr' -> nothing.
21125 if (Value.isUndef() && ST->isUnindexed())
21126 return Chain;
21127
21128 // Try to infer better alignment information than the store already has.
21129 if (OptLevel != CodeGenOptLevel::None && ST->isUnindexed() &&
21130 !ST->isAtomic()) {
21131 if (MaybeAlign Alignment = DAG.InferPtrAlign(Ptr)) {
21132 if (*Alignment > ST->getAlign() &&
21133 isAligned(Lhs: *Alignment, SizeInBytes: ST->getSrcValueOffset())) {
21134 SDValue NewStore =
21135 DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Value, Ptr, PtrInfo: ST->getPointerInfo(),
21136 SVT: ST->getMemoryVT(), Alignment: *Alignment,
21137 MMOFlags: ST->getMemOperand()->getFlags(), AAInfo: ST->getAAInfo());
21138 // NewStore will always be N as we are only refining the alignment
21139 assert(NewStore.getNode() == N);
21140 (void)NewStore;
21141 }
21142 }
21143 }
21144
  // Try transforming a floating point load / store pair into integer
  // load / store ops.
21147 if (SDValue NewST = TransformFPLoadStorePair(N))
21148 return NewST;
21149
21150 // Try transforming several stores into STORE (BSWAP).
21151 if (SDValue Store = mergeTruncStores(N: ST))
21152 return Store;
21153
21154 if (ST->isUnindexed()) {
21155 // Walk up chain skipping non-aliasing memory nodes, on this store and any
21156 // adjacent stores.
21157 if (findBetterNeighborChains(St: ST)) {
21158 // replaceStoreChain uses CombineTo, which handled all of the worklist
21159 // manipulation. Return the original node to not do anything else.
21160 return SDValue(ST, 0);
21161 }
21162 Chain = ST->getChain();
21163 }
21164
21165 // FIXME: is there such a thing as a truncating indexed store?
21166 if (ST->isTruncatingStore() && ST->isUnindexed() &&
21167 Value.getValueType().isInteger() &&
21168 (!isa<ConstantSDNode>(Val: Value) ||
21169 !cast<ConstantSDNode>(Val&: Value)->isOpaque())) {
    // Convert a truncating store of an extension into a standard store.
21171 if ((Value.getOpcode() == ISD::ZERO_EXTEND ||
21172 Value.getOpcode() == ISD::SIGN_EXTEND ||
21173 Value.getOpcode() == ISD::ANY_EXTEND) &&
21174 Value.getOperand(i: 0).getValueType() == ST->getMemoryVT() &&
21175 TLI.isOperationLegalOrCustom(Op: ISD::STORE, VT: ST->getMemoryVT()))
21176 return DAG.getStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0), Ptr,
21177 MMO: ST->getMemOperand());
21178
21179 APInt TruncDemandedBits =
21180 APInt::getLowBitsSet(numBits: Value.getScalarValueSizeInBits(),
21181 loBitsSet: ST->getMemoryVT().getScalarSizeInBits());
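    // Only the low bits covered by the memory type are actually stored, so
    // bits above them can be treated as don't-care when simplifying below.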
21182
21183 // See if we can simplify the operation with SimplifyDemandedBits, which
21184 // only works if the value has a single use.
21185 AddToWorklist(N: Value.getNode());
21186 if (SimplifyDemandedBits(Op: Value, DemandedBits: TruncDemandedBits)) {
      // Re-visit the store if anything changed and the store hasn't been
      // merged with another node (in which case N is deleted).
      // SimplifyDemandedBits will add Value's node back to the worklist if
      // necessary, but we also need to re-visit the Store node itself.
21191 if (N->getOpcode() != ISD::DELETED_NODE)
21192 AddToWorklist(N);
21193 return SDValue(N, 0);
21194 }
21195
21196 // Otherwise, see if we can simplify the input to this truncstore with
21197 // knowledge that only the low bits are being used. For example:
21198 // "truncstore (or (shl x, 8), y), i8" -> "truncstore y, i8"
21199 if (SDValue Shorter =
21200 TLI.SimplifyMultipleUseDemandedBits(Op: Value, DemandedBits: TruncDemandedBits, DAG))
21201 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Shorter, Ptr, SVT: ST->getMemoryVT(),
21202 MMO: ST->getMemOperand());
21203
21204 // If we're storing a truncated constant, see if we can simplify it.
21205 // TODO: Move this to targetShrinkDemandedConstant?
21206 if (auto *Cst = dyn_cast<ConstantSDNode>(Val&: Value))
21207 if (!Cst->isOpaque()) {
21208 const APInt &CValue = Cst->getAPIntValue();
21209 APInt NewVal = CValue & TruncDemandedBits;
21210 if (NewVal != CValue) {
21211 SDValue Shorter =
21212 DAG.getConstant(Val: NewVal, DL: SDLoc(N), VT: Value.getValueType());
21213 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Shorter, Ptr,
21214 SVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
21215 }
21216 }
21217 }
21218
21219 // If this is a load followed by a store to the same location, then the store
21220 // is dead/noop. Peek through any truncates if canCombineTruncStore failed.
21221 // TODO: Add big-endian truncate support with test coverage.
21222 // TODO: Can relax for unordered atomics (see D66309)
21223 SDValue TruncVal = DAG.getDataLayout().isLittleEndian()
21224 ? peekThroughTruncates(V: Value)
21225 : Value;
21226 if (auto *Ld = dyn_cast<LoadSDNode>(Val&: TruncVal)) {
21227 if (Ld->getBasePtr() == Ptr && ST->getMemoryVT() == Ld->getMemoryVT() &&
21228 ST->isUnindexed() && ST->isSimple() &&
21229 Ld->getAddressSpace() == ST->getAddressSpace() &&
21230 // There can't be any side effects between the load and store, such as
21231 // a call or store.
21232 Chain.reachesChainWithoutSideEffects(Dest: SDValue(Ld, 1))) {
21233 // The store is dead, remove it.
21234 return Chain;
21235 }
21236 }
21237
21238 // Try scalarizing vector stores of loads where we only change one element
21239 if (SDValue NewST = replaceStoreOfInsertLoad(ST))
21240 return NewST;
21241
21242 // TODO: Can relax for unordered atomics (see D66309)
21243 if (StoreSDNode *ST1 = dyn_cast<StoreSDNode>(Val&: Chain)) {
21244 if (ST->isUnindexed() && ST->isSimple() &&
21245 ST1->isUnindexed() && ST1->isSimple()) {
21246 if (OptLevel != CodeGenOptLevel::None && ST1->getBasePtr() == Ptr &&
21247 ST1->getValue() == Value && ST->getMemoryVT() == ST1->getMemoryVT() &&
21248 ST->getAddressSpace() == ST1->getAddressSpace()) {
21249 // If this is a store followed by a store with the same value to the
21250 // same location, then the store is dead/noop.
21251 return Chain;
21252 }
21253
21254 if (OptLevel != CodeGenOptLevel::None && ST1->hasOneUse() &&
21255 !ST1->getBasePtr().isUndef() &&
21256 ST->getAddressSpace() == ST1->getAddressSpace()) {
21257 // If we consider two stores and one smaller in size is a scalable
21258 // vector type and another one a bigger size store with a fixed type,
21259 // then we could not allow the scalable store removal because we don't
21260 // know its final size in the end.
21261 if (ST->getMemoryVT().isScalableVector() ||
21262 ST1->getMemoryVT().isScalableVector()) {
21263 if (ST1->getBasePtr() == Ptr &&
21264 TypeSize::isKnownLE(LHS: ST1->getMemoryVT().getStoreSize(),
21265 RHS: ST->getMemoryVT().getStoreSize())) {
21266 CombineTo(N: ST1, Res: ST1->getChain());
21267 return SDValue(N, 0);
21268 }
21269 } else {
21270 const BaseIndexOffset STBase = BaseIndexOffset::match(N: ST, DAG);
21271 const BaseIndexOffset ChainBase = BaseIndexOffset::match(N: ST1, DAG);
          // If the preceding store writes to a subset of the current store's
          // location and no other node is chained to that store, we can
          // effectively drop the preceding store. Do not remove stores to
          // undef as they may be used as data sinks.
21276 if (STBase.contains(DAG, BitSize: ST->getMemoryVT().getFixedSizeInBits(),
21277 Other: ChainBase,
21278 OtherBitSize: ST1->getMemoryVT().getFixedSizeInBits())) {
21279 CombineTo(N: ST1, Res: ST1->getChain());
21280 return SDValue(N, 0);
21281 }
21282 }
21283 }
21284 }
21285 }
21286
21287 // If this is an FP_ROUND or TRUNC followed by a store, fold this into a
21288 // truncating store. We can do this even if this is already a truncstore.
21289 if ((Value.getOpcode() == ISD::FP_ROUND ||
21290 Value.getOpcode() == ISD::TRUNCATE) &&
21291 Value->hasOneUse() && ST->isUnindexed() &&
21292 TLI.canCombineTruncStore(ValVT: Value.getOperand(i: 0).getValueType(),
21293 MemVT: ST->getMemoryVT(), LegalOnly: LegalOperations)) {
21294 return DAG.getTruncStore(Chain, dl: SDLoc(N), Val: Value.getOperand(i: 0),
21295 Ptr, SVT: ST->getMemoryVT(), MMO: ST->getMemOperand());
21296 }
21297
21298 // Always perform this optimization before types are legal. If the target
21299 // prefers, also try this after legalization to catch stores that were created
21300 // by intrinsics or other nodes.
21301 if (!LegalTypes || (TLI.mergeStoresAfterLegalization(MemVT: ST->getMemoryVT()))) {
21302 while (true) {
21303 // There can be multiple store sequences on the same chain.
21304 // Keep trying to merge store sequences until we are unable to do so
21305 // or until we merge the last store on the chain.
21306 bool Changed = mergeConsecutiveStores(St: ST);
21307 if (!Changed) break;
      // Return N, as the merge only uses CombineTo and no worklist cleanup
      // is necessary.
21310 if (N->getOpcode() == ISD::DELETED_NODE || !isa<StoreSDNode>(Val: N))
21311 return SDValue(N, 0);
21312 }
21313 }
21314
21315 // Try transforming N to an indexed store.
21316 if (CombineToPreIndexedLoadStore(N) || CombineToPostIndexedLoadStore(N))
21317 return SDValue(N, 0);
21318
21319 // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr'
21320 //
21321 // Make sure to do this only after attempting to merge stores in order to
21322 // avoid changing the types of some subset of stores due to visit order,
21323 // preventing their merging.
21324 if (isa<ConstantFPSDNode>(Val: ST->getValue())) {
21325 if (SDValue NewSt = replaceStoreOfFPConstant(ST))
21326 return NewSt;
21327 }
21328
21329 if (SDValue NewSt = splitMergedValStore(ST))
21330 return NewSt;
21331
21332 return ReduceLoadOpStoreWidth(N);
21333}
21334
21335SDValue DAGCombiner::visitLIFETIME_END(SDNode *N) {
21336 const auto *LifetimeEnd = cast<LifetimeSDNode>(Val: N);
21337 if (!LifetimeEnd->hasOffset())
21338 return SDValue();
21339
21340 const BaseIndexOffset LifetimeEndBase(N->getOperand(Num: 1), SDValue(),
21341 LifetimeEnd->getOffset(), false);
21342
21343 // We walk up the chains to find stores.
21344 SmallVector<SDValue, 8> Chains = {N->getOperand(Num: 0)};
21345 while (!Chains.empty()) {
21346 SDValue Chain = Chains.pop_back_val();
21347 if (!Chain.hasOneUse())
21348 continue;
21349 switch (Chain.getOpcode()) {
21350 case ISD::TokenFactor:
21351 for (unsigned Nops = Chain.getNumOperands(); Nops;)
21352 Chains.push_back(Elt: Chain.getOperand(i: --Nops));
21353 break;
21354 case ISD::LIFETIME_START:
21355 case ISD::LIFETIME_END:
21356 // We can forward past any lifetime start/end that can be proven not to
21357 // alias the node.
21358 if (!mayAlias(Op0: Chain.getNode(), Op1: N))
21359 Chains.push_back(Elt: Chain.getOperand(i: 0));
21360 break;
21361 case ISD::STORE: {
21362 StoreSDNode *ST = dyn_cast<StoreSDNode>(Val&: Chain);
21363 // TODO: Can relax for unordered atomics (see D66309)
21364 if (!ST->isSimple() || ST->isIndexed())
21365 continue;
21366 const TypeSize StoreSize = ST->getMemoryVT().getStoreSize();
21367 // The bounds of a scalable store are not known until runtime, so this
21368 // store cannot be elided.
21369 if (StoreSize.isScalable())
21370 continue;
21371 const BaseIndexOffset StoreBase = BaseIndexOffset::match(N: ST, DAG);
21372 // If we store purely within object bounds just before its lifetime ends,
21373 // we can remove the store.
21374 if (LifetimeEndBase.contains(DAG, BitSize: LifetimeEnd->getSize() * 8, Other: StoreBase,
21375 OtherBitSize: StoreSize.getFixedValue() * 8)) {
21376 LLVM_DEBUG(dbgs() << "\nRemoving store:"; StoreBase.dump();
21377 dbgs() << "\nwithin LIFETIME_END of : ";
21378 LifetimeEndBase.dump(); dbgs() << "\n");
21379 CombineTo(N: ST, Res: ST->getChain());
21380 return SDValue(N, 0);
21381 }
21382 }
21383 }
21384 }
21385 return SDValue();
21386}
21387
/// For the store instruction sequence below, the F and I values
/// are bundled together as an i64 value before being stored into memory.
/// Sometimes it is more efficient to generate separate stores for F and I,
/// which can remove the bitwise instructions or sink them to colder places.
21392///
21393/// (store (or (zext (bitcast F to i32) to i64),
21394/// (shl (zext I to i64), 32)), addr) -->
21395/// (store F, addr) and (store I, addr+4)
21396///
21397/// Similarly, splitting for other merged store can also be beneficial, like:
21398/// For pair of {i32, i32}, i64 store --> two i32 stores.
21399/// For pair of {i32, i16}, i64 store --> two i32 stores.
21400/// For pair of {i16, i16}, i32 store --> two i16 stores.
21401/// For pair of {i16, i8}, i32 store --> two i16 stores.
21402/// For pair of {i8, i8}, i16 store --> two i8 stores.
21403///
21404/// We allow each target to determine specifically which kind of splitting is
21405/// supported.
21406///
/// The store patterns are commonly seen in the simple code snippet below
/// if only std::make_pair(...) is SROA-transformed before being inlined into
/// hoo.
21409/// void goo(const std::pair<int, float> &);
21410/// hoo() {
21411/// ...
21412/// goo(std::make_pair(tmp, ftmp));
21413/// ...
21414/// }
21415///
21416SDValue DAGCombiner::splitMergedValStore(StoreSDNode *ST) {
21417 if (OptLevel == CodeGenOptLevel::None)
21418 return SDValue();
21419
21420 // Can't change the number of memory accesses for a volatile store or break
21421 // atomicity for an atomic one.
21422 if (!ST->isSimple())
21423 return SDValue();
21424
21425 SDValue Val = ST->getValue();
21426 SDLoc DL(ST);
21427
21428 // Match OR operand.
21429 if (!Val.getValueType().isScalarInteger() || Val.getOpcode() != ISD::OR)
21430 return SDValue();
21431
21432 // Match SHL operand and get Lower and Higher parts of Val.
21433 SDValue Op1 = Val.getOperand(i: 0);
21434 SDValue Op2 = Val.getOperand(i: 1);
21435 SDValue Lo, Hi;
21436 if (Op1.getOpcode() != ISD::SHL) {
21437 std::swap(a&: Op1, b&: Op2);
21438 if (Op1.getOpcode() != ISD::SHL)
21439 return SDValue();
21440 }
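  // Op1 is now the SHL producing the high half; Op2 holds the low half.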
21441 Lo = Op2;
21442 Hi = Op1.getOperand(i: 0);
21443 if (!Op1.hasOneUse())
21444 return SDValue();
21445
21446 // Match shift amount to HalfValBitSize.
21447 unsigned HalfValBitSize = Val.getValueSizeInBits() / 2;
21448 ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(Val: Op1.getOperand(i: 1));
21449 if (!ShAmt || ShAmt->getAPIntValue() != HalfValBitSize)
21450 return SDValue();
21451
  // Lo and Hi must be zero-extended from integer types no wider than
  // HalfValBitSize.
21454 if (Lo.getOpcode() != ISD::ZERO_EXTEND || !Lo.hasOneUse() ||
21455 !Lo.getOperand(i: 0).getValueType().isScalarInteger() ||
21456 Lo.getOperand(i: 0).getValueSizeInBits() > HalfValBitSize ||
21457 Hi.getOpcode() != ISD::ZERO_EXTEND || !Hi.hasOneUse() ||
21458 !Hi.getOperand(i: 0).getValueType().isScalarInteger() ||
21459 Hi.getOperand(i: 0).getValueSizeInBits() > HalfValBitSize)
21460 return SDValue();
21461
21462 // Use the EVT of low and high parts before bitcast as the input
21463 // of target query.
21464 EVT LowTy = (Lo.getOperand(i: 0).getOpcode() == ISD::BITCAST)
21465 ? Lo.getOperand(i: 0).getValueType()
21466 : Lo.getValueType();
21467 EVT HighTy = (Hi.getOperand(i: 0).getOpcode() == ISD::BITCAST)
21468 ? Hi.getOperand(i: 0).getValueType()
21469 : Hi.getValueType();
21470 if (!TLI.isMultiStoresCheaperThanBitsMerge(LTy: LowTy, HTy: HighTy))
21471 return SDValue();
21472
21473 // Start to split store.
21474 MachineMemOperand::Flags MMOFlags = ST->getMemOperand()->getFlags();
21475 AAMDNodes AAInfo = ST->getAAInfo();
21476
21477 // Change the sizes of Lo and Hi's value types to HalfValBitSize.
21478 EVT VT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: HalfValBitSize);
21479 Lo = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Lo.getOperand(i: 0));
21480 Hi = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL, VT, Operand: Hi.getOperand(i: 0));
21481
21482 SDValue Chain = ST->getChain();
21483 SDValue Ptr = ST->getBasePtr();
21484 // Lower value store.
21485 SDValue St0 = DAG.getStore(Chain, dl: DL, Val: Lo, Ptr, PtrInfo: ST->getPointerInfo(),
21486 Alignment: ST->getOriginalAlign(), MMOFlags, AAInfo);
21487 Ptr =
21488 DAG.getMemBasePlusOffset(Base: Ptr, Offset: TypeSize::getFixed(ExactSize: HalfValBitSize / 8), DL);
21489 // Higher value store.
21490 SDValue St1 = DAG.getStore(
21491 Chain: St0, dl: DL, Val: Hi, Ptr, PtrInfo: ST->getPointerInfo().getWithOffset(O: HalfValBitSize / 8),
21492 Alignment: ST->getOriginalAlign(), MMOFlags, AAInfo);
21493 return St1;
21494}
21495
21496// Merge an insertion into an existing shuffle:
21497// (insert_vector_elt (vector_shuffle X, Y, Mask),
//                     (extract_vector_elt X, N), InsIndex)
21499// --> (vector_shuffle X, Y, NewMask)
21500// and variations where shuffle operands may be CONCAT_VECTORS.
21501static bool mergeEltWithShuffle(SDValue &X, SDValue &Y, ArrayRef<int> Mask,
21502 SmallVectorImpl<int> &NewMask, SDValue Elt,
21503 unsigned InsIndex) {
21504 if (Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
21505 !isa<ConstantSDNode>(Val: Elt.getOperand(i: 1)))
21506 return false;
21507
21508 // Vec's operand 0 is using indices from 0 to N-1 and
21509 // operand 1 from N to 2N - 1, where N is the number of
21510 // elements in the vectors.
21511 SDValue InsertVal0 = Elt.getOperand(i: 0);
21512 int ElementOffset = -1;
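  // ElementOffset will become the index of InsertVal0's first element within
  // the shuffle's combined (X, Y) index space, or stay -1 if it is not found.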
21513
21514 // We explore the inputs of the shuffle in order to see if we find the
21515 // source of the extract_vector_elt. If so, we can use it to modify the
21516 // shuffle rather than perform an insert_vector_elt.
21517 SmallVector<std::pair<int, SDValue>, 8> ArgWorkList;
21518 ArgWorkList.emplace_back(Args: Mask.size(), Args&: Y);
21519 ArgWorkList.emplace_back(Args: 0, Args&: X);
21520
21521 while (!ArgWorkList.empty()) {
21522 int ArgOffset;
21523 SDValue ArgVal;
21524 std::tie(args&: ArgOffset, args&: ArgVal) = ArgWorkList.pop_back_val();
21525
21526 if (ArgVal == InsertVal0) {
21527 ElementOffset = ArgOffset;
21528 break;
21529 }
21530
21531 // Peek through concat_vector.
21532 if (ArgVal.getOpcode() == ISD::CONCAT_VECTORS) {
21533 int CurrentArgOffset =
21534 ArgOffset + ArgVal.getValueType().getVectorNumElements();
21535 int Step = ArgVal.getOperand(i: 0).getValueType().getVectorNumElements();
21536 for (SDValue Op : reverse(C: ArgVal->ops())) {
21537 CurrentArgOffset -= Step;
21538 ArgWorkList.emplace_back(Args&: CurrentArgOffset, Args&: Op);
21539 }
21540
21541 // Make sure we went through all the elements and did not screw up index
21542 // computation.
21543 assert(CurrentArgOffset == ArgOffset);
21544 }
21545 }
21546
21547 // If we failed to find a match, see if we can replace an UNDEF shuffle
21548 // operand.
21549 if (ElementOffset == -1) {
21550 if (!Y.isUndef() || InsertVal0.getValueType() != Y.getValueType())
21551 return false;
21552 ElementOffset = Mask.size();
21553 Y = InsertVal0;
21554 }
21555
21556 NewMask.assign(in_start: Mask.begin(), in_end: Mask.end());
21557 NewMask[InsIndex] = ElementOffset + Elt.getConstantOperandVal(i: 1);
21558 assert(NewMask[InsIndex] < (int)(2 * Mask.size()) && NewMask[InsIndex] >= 0 &&
21559 "NewMask[InsIndex] is out of bound");
21560 return true;
21561}
21562
21563// Merge an insertion into an existing shuffle:
21564// (insert_vector_elt (vector_shuffle X, Y), (extract_vector_elt X, N),
21565// InsIndex)
21566// --> (vector_shuffle X, Y) and variations where shuffle operands may be
21567// CONCAT_VECTORS.
21568SDValue DAGCombiner::mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex) {
21569 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
21570 "Expected extract_vector_elt");
21571 SDValue InsertVal = N->getOperand(Num: 1);
21572 SDValue Vec = N->getOperand(Num: 0);
21573
21574 auto *SVN = dyn_cast<ShuffleVectorSDNode>(Val&: Vec);
21575 if (!SVN || !Vec.hasOneUse())
21576 return SDValue();
21577
21578 ArrayRef<int> Mask = SVN->getMask();
21579 SDValue X = Vec.getOperand(i: 0);
21580 SDValue Y = Vec.getOperand(i: 1);
21581
21582 SmallVector<int, 16> NewMask(Mask);
21583 if (mergeEltWithShuffle(X, Y, Mask, NewMask, Elt: InsertVal, InsIndex)) {
21584 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
21585 VT: Vec.getValueType(), DL: SDLoc(N), N0: X, N1: Y, Mask: NewMask, DAG);
21586 if (LegalShuffle)
21587 return LegalShuffle;
21588 }
21589
21590 return SDValue();
21591}
21592
21593// Convert a disguised subvector insertion into a shuffle:
21594// insert_vector_elt V, (bitcast X from vector type), IdxC -->
21595// bitcast(shuffle (bitcast V), (extended X), Mask)
21596// Note: We do not use an insert_subvector node because that requires a
21597// legal subvector type.
21598SDValue DAGCombiner::combineInsertEltToShuffle(SDNode *N, unsigned InsIndex) {
21599 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT &&
21600 "Expected extract_vector_elt");
21601 SDValue InsertVal = N->getOperand(Num: 1);
21602
21603 if (InsertVal.getOpcode() != ISD::BITCAST || !InsertVal.hasOneUse() ||
21604 !InsertVal.getOperand(i: 0).getValueType().isVector())
21605 return SDValue();
21606
21607 SDValue SubVec = InsertVal.getOperand(i: 0);
21608 SDValue DestVec = N->getOperand(Num: 0);
21609 EVT SubVecVT = SubVec.getValueType();
21610 EVT VT = DestVec.getValueType();
21611 unsigned NumSrcElts = SubVecVT.getVectorNumElements();
  // If the source only has a single vector element, the cost of creating the
  // wide padded vector and shuffling it in is likely to exceed the cost of a
  // single insert_vector_elt.
21614 if (NumSrcElts == 1)
21615 return SDValue();
21616 unsigned ExtendRatio = VT.getSizeInBits() / SubVecVT.getSizeInBits();
21617 unsigned NumMaskVals = ExtendRatio * NumSrcElts;
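  // The shuffle is done on the destination viewed as NumMaskVals elements of
  // the subvector's element type, so the inserted destination element spans
  // NumSrcElts consecutive mask positions.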
21618
21619 // Step 1: Create a shuffle mask that implements this insert operation. The
21620 // vector that we are inserting into will be operand 0 of the shuffle, so
21621 // those elements are just 'i'. The inserted subvector is in the first
21622 // positions of operand 1 of the shuffle. Example:
21623 // insert v4i32 V, (v2i16 X), 2 --> shuffle v8i16 V', X', {0,1,2,3,8,9,6,7}
21624 SmallVector<int, 16> Mask(NumMaskVals);
21625 for (unsigned i = 0; i != NumMaskVals; ++i) {
21626 if (i / NumSrcElts == InsIndex)
21627 Mask[i] = (i % NumSrcElts) + NumMaskVals;
21628 else
21629 Mask[i] = i;
21630 }
21631
21632 // Bail out if the target can not handle the shuffle we want to create.
21633 EVT SubVecEltVT = SubVecVT.getVectorElementType();
21634 EVT ShufVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SubVecEltVT, NumElements: NumMaskVals);
21635 if (!TLI.isShuffleMaskLegal(Mask, ShufVT))
21636 return SDValue();
21637
21638 // Step 2: Create a wide vector from the inserted source vector by appending
21639 // undefined elements. This is the same size as our destination vector.
21640 SDLoc DL(N);
21641 SmallVector<SDValue, 8> ConcatOps(ExtendRatio, DAG.getUNDEF(VT: SubVecVT));
21642 ConcatOps[0] = SubVec;
21643 SDValue PaddedSubV = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: ShufVT, Ops: ConcatOps);
21644
21645 // Step 3: Shuffle in the padded subvector.
21646 SDValue DestVecBC = DAG.getBitcast(VT: ShufVT, V: DestVec);
21647 SDValue Shuf = DAG.getVectorShuffle(VT: ShufVT, dl: DL, N1: DestVecBC, N2: PaddedSubV, Mask);
21648 AddToWorklist(N: PaddedSubV.getNode());
21649 AddToWorklist(N: DestVecBC.getNode());
21650 AddToWorklist(N: Shuf.getNode());
21651 return DAG.getBitcast(VT, V: Shuf);
21652}
21653
21654// Combine insert(shuffle(load, <u,0,1,2>), load, 0) into a single load if
21655// possible and the new load will be quick. We use more loads but less shuffles
21656// and inserts.
21657SDValue DAGCombiner::combineInsertEltToLoad(SDNode *N, unsigned InsIndex) {
21658 EVT VT = N->getValueType(ResNo: 0);
21659
  // InsIndex is expected to be the first or last lane.
21661 if (!VT.isFixedLengthVector() ||
21662 (InsIndex != 0 && InsIndex != VT.getVectorNumElements() - 1))
21663 return SDValue();
21664
21665 // Look for a shuffle with the mask u,0,1,2,3,4,5,6 or 1,2,3,4,5,6,7,u
21666 // depending on the InsIndex.
21667 auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(Val: N->getOperand(Num: 0));
21668 SDValue Scalar = N->getOperand(Num: 1);
21669 if (!Shuffle || !all_of(Range: enumerate(First: Shuffle->getMask()), P: [&](auto P) {
21670 return InsIndex == P.index() || P.value() < 0 ||
21671 (InsIndex == 0 && P.value() == (int)P.index() - 1) ||
21672 (InsIndex == VT.getVectorNumElements() - 1 &&
21673 P.value() == (int)P.index() + 1);
21674 }))
21675 return SDValue();
21676
21677 // We optionally skip over an extend so long as both loads are extended in the
21678 // same way from the same type.
21679 unsigned Extend = 0;
21680 if (Scalar.getOpcode() == ISD::ZERO_EXTEND ||
21681 Scalar.getOpcode() == ISD::SIGN_EXTEND ||
21682 Scalar.getOpcode() == ISD::ANY_EXTEND) {
21683 Extend = Scalar.getOpcode();
21684 Scalar = Scalar.getOperand(i: 0);
21685 }
21686
21687 auto *ScalarLoad = dyn_cast<LoadSDNode>(Val&: Scalar);
21688 if (!ScalarLoad)
21689 return SDValue();
21690
21691 SDValue Vec = Shuffle->getOperand(Num: 0);
21692 if (Extend) {
21693 if (Vec.getOpcode() != Extend)
21694 return SDValue();
21695 Vec = Vec.getOperand(i: 0);
21696 }
21697 auto *VecLoad = dyn_cast<LoadSDNode>(Val&: Vec);
21698 if (!VecLoad || Vec.getValueType().getScalarType() != Scalar.getValueType())
21699 return SDValue();
21700
21701 int EltSize = ScalarLoad->getValueType(ResNo: 0).getScalarSizeInBits();
21702 if (EltSize == 0 || EltSize % 8 != 0 || !ScalarLoad->isSimple() ||
21703 !VecLoad->isSimple() || VecLoad->getExtensionType() != ISD::NON_EXTLOAD ||
21704 ScalarLoad->getExtensionType() != ISD::NON_EXTLOAD ||
21705 ScalarLoad->getAddressSpace() != VecLoad->getAddressSpace())
21706 return SDValue();
21707
  // Check that the offset between the pointers is such that they produce a
  // single contiguous load.
21710 if (InsIndex == 0) {
21711 if (!DAG.areNonVolatileConsecutiveLoads(LD: ScalarLoad, Base: VecLoad, Bytes: EltSize / 8,
21712 Dist: -1))
21713 return SDValue();
21714 } else {
21715 if (!DAG.areNonVolatileConsecutiveLoads(
21716 LD: VecLoad, Base: ScalarLoad, Bytes: VT.getVectorNumElements() * EltSize / 8, Dist: -1))
21717 return SDValue();
21718 }
21719
21720 // And that the new unaligned load will be fast.
21721 unsigned IsFast = 0;
21722 Align NewAlign = commonAlignment(A: VecLoad->getAlign(), Offset: EltSize / 8);
21723 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(),
21724 VT: Vec.getValueType(), AddrSpace: VecLoad->getAddressSpace(),
21725 Alignment: NewAlign, Flags: VecLoad->getMemOperand()->getFlags(),
21726 Fast: &IsFast) ||
21727 !IsFast)
21728 return SDValue();
21729
21730 // Calculate the new Ptr and create the new load.
21731 SDLoc DL(N);
21732 SDValue Ptr = ScalarLoad->getBasePtr();
21733 if (InsIndex != 0)
21734 Ptr = DAG.getNode(Opcode: ISD::ADD, DL, VT: Ptr.getValueType(), N1: VecLoad->getBasePtr(),
21735 N2: DAG.getConstant(Val: EltSize / 8, DL, VT: Ptr.getValueType()));
21736 MachinePointerInfo PtrInfo =
21737 InsIndex == 0 ? ScalarLoad->getPointerInfo()
21738 : VecLoad->getPointerInfo().getWithOffset(O: EltSize / 8);
21739
21740 SDValue Load = DAG.getLoad(VT: VecLoad->getValueType(ResNo: 0), dl: DL,
21741 Chain: ScalarLoad->getChain(), Ptr, PtrInfo, Alignment: NewAlign);
21742 DAG.makeEquivalentMemoryOrdering(OldLoad: ScalarLoad, NewMemOp: Load.getValue(R: 1));
21743 DAG.makeEquivalentMemoryOrdering(OldLoad: VecLoad, NewMemOp: Load.getValue(R: 1));
21744 return Extend ? DAG.getNode(Opcode: Extend, DL, VT, Operand: Load) : Load;
21745}
21746
21747SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
21748 SDValue InVec = N->getOperand(Num: 0);
21749 SDValue InVal = N->getOperand(Num: 1);
21750 SDValue EltNo = N->getOperand(Num: 2);
21751 SDLoc DL(N);
21752
21753 EVT VT = InVec.getValueType();
21754 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: EltNo);
21755
21756 // Inserting into an out-of-bounds element is undefined.
21757 if (IndexC && VT.isFixedLengthVector() &&
21758 IndexC->getZExtValue() >= VT.getVectorNumElements())
21759 return DAG.getUNDEF(VT);
21760
21761 // Remove redundant insertions:
21762 // (insert_vector_elt x (extract_vector_elt x idx) idx) -> x
21763 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
21764 InVec == InVal.getOperand(i: 0) && EltNo == InVal.getOperand(i: 1))
21765 return InVec;
21766
21767 if (!IndexC) {
21768 // If this is a variable insert into an undef vector, it might be better to splat:
21769 // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
21770 if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
21771 return DAG.getSplat(VT, DL, Op: InVal);
21772 return SDValue();
21773 }
21774
21775 if (VT.isScalableVector())
21776 return SDValue();
21777
21778 unsigned NumElts = VT.getVectorNumElements();
21779
21780 // We must know which element is being inserted for folds below here.
21781 unsigned Elt = IndexC->getZExtValue();
21782
21783 // Handle <1 x ???> vector insertion special cases.
21784 if (NumElts == 1) {
21785 // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
21786 if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
21787 InVal.getOperand(i: 0).getValueType() == VT &&
21788 isNullConstant(V: InVal.getOperand(i: 1)))
21789 return InVal.getOperand(i: 0);
21790 }
21791
21792 // Canonicalize insert_vector_elt dag nodes.
21793 // Example:
21794 // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1)
21795 // -> (insert_vector_elt (insert_vector_elt A, Idx1), Idx0)
21796 //
21797 // Do this only if the child insert_vector_elt node has one use; also
21798 // do this only if both indices are constants and Idx1 < Idx0.
21799 if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT && InVec.hasOneUse()
21800 && isa<ConstantSDNode>(Val: InVec.getOperand(i: 2))) {
21801 unsigned OtherElt = InVec.getConstantOperandVal(i: 2);
21802 if (Elt < OtherElt) {
21803 // Swap nodes.
21804 SDValue NewOp = DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL, VT,
21805 N1: InVec.getOperand(i: 0), N2: InVal, N3: EltNo);
21806 AddToWorklist(N: NewOp.getNode());
21807 return DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL: SDLoc(InVec.getNode()),
21808 VT, N1: NewOp, N2: InVec.getOperand(i: 1), N3: InVec.getOperand(i: 2));
21809 }
21810 }
21811
21812 if (SDValue Shuf = mergeInsertEltWithShuffle(N, InsIndex: Elt))
21813 return Shuf;
21814
21815 if (SDValue Shuf = combineInsertEltToShuffle(N, InsIndex: Elt))
21816 return Shuf;
21817
21818 if (SDValue Shuf = combineInsertEltToLoad(N, InsIndex: Elt))
21819 return Shuf;
21820
21821 // Attempt to convert an insert_vector_elt chain into a legal build_vector.
21822 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)) {
21823 // Single-element vector - we don't need to recurse.
21824 if (NumElts == 1)
21825 return DAG.getBuildVector(VT, DL, Ops: {InVal});
21826
21827 // If we haven't already collected the element, insert into the op list.
21828 EVT MaxEltVT = InVal.getValueType();
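// MaxEltVT tracks the widest integer element type collected so far; smaller
// elements are any-extended to it when the BUILD_VECTOR is finally created.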
21829 auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
21830 unsigned Idx) {
21831 if (!Ops[Idx]) {
21832 Ops[Idx] = Elt;
21833 if (VT.isInteger()) {
21834 EVT EltVT = Elt.getValueType();
21835 MaxEltVT = MaxEltVT.bitsGE(VT: EltVT) ? MaxEltVT : EltVT;
21836 }
21837 }
21838 };
21839
21840 // Ensure all the operands are the same value type, fill any missing
21841 // operands with UNDEF and create the BUILD_VECTOR.
21842 auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops) {
21843 assert(Ops.size() == NumElts && "Unexpected vector size");
21844 for (SDValue &Op : Ops) {
21845 if (Op)
21846 Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, VT: MaxEltVT) : Op;
21847 else
21848 Op = DAG.getUNDEF(VT: MaxEltVT);
21849 }
21850 return DAG.getBuildVector(VT, DL, Ops);
21851 };
21852
21853 SmallVector<SDValue, 8> Ops(NumElts, SDValue());
21854 Ops[Elt] = InVal;
21855
21856 // Recurse up an INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
21857 for (SDValue CurVec = InVec; CurVec;) {
21858 // UNDEF - build new BUILD_VECTOR from already inserted operands.
21859 if (CurVec.isUndef())
21860 return CanonicalizeBuildVector(Ops);
21861
21862 // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
21863 if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
21864 for (unsigned I = 0; I != NumElts; ++I)
21865 AddBuildVectorOp(Ops, CurVec.getOperand(i: I), I);
21866 return CanonicalizeBuildVector(Ops);
21867 }
21868
21869 // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
21870 if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
21871 AddBuildVectorOp(Ops, CurVec.getOperand(i: 0), 0);
21872 return CanonicalizeBuildVector(Ops);
21873 }
21874
21875 // INSERT_VECTOR_ELT - insert operand and continue up the chain.
21876 if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
21877 if (auto *CurIdx = dyn_cast<ConstantSDNode>(Val: CurVec.getOperand(i: 2)))
21878 if (CurIdx->getAPIntValue().ult(RHS: NumElts)) {
21879 unsigned Idx = CurIdx->getZExtValue();
21880 AddBuildVectorOp(Ops, CurVec.getOperand(i: 1), Idx);
21881
21882 // Found entire BUILD_VECTOR.
21883 if (all_of(Range&: Ops, P: [](SDValue Op) { return !!Op; }))
21884 return CanonicalizeBuildVector(Ops);
21885
21886 CurVec = CurVec->getOperand(Num: 0);
21887 continue;
21888 }
21889
21890 // VECTOR_SHUFFLE - if all the operands match the shuffle's sources,
21891 // update the shuffle mask (and second operand if we started with a unary
21892 // shuffle) and create a new legal shuffle.
21893 if (CurVec.getOpcode() == ISD::VECTOR_SHUFFLE && CurVec.hasOneUse()) {
21894 auto *SVN = cast<ShuffleVectorSDNode>(Val&: CurVec);
21895 SDValue LHS = SVN->getOperand(Num: 0);
21896 SDValue RHS = SVN->getOperand(Num: 1);
21897 SmallVector<int, 16> Mask(SVN->getMask());
21898 bool Merged = true;
21899 for (auto I : enumerate(First&: Ops)) {
21900 SDValue &Op = I.value();
21901 if (Op) {
21902 SmallVector<int, 16> NewMask;
21903 if (!mergeEltWithShuffle(X&: LHS, Y&: RHS, Mask, NewMask, Elt: Op, InsIndex: I.index())) {
21904 Merged = false;
21905 break;
21906 }
21907 Mask = std::move(NewMask);
21908 }
21909 }
21910 if (Merged)
21911 if (SDValue NewShuffle =
21912 TLI.buildLegalVectorShuffle(VT, DL, N0: LHS, N1: RHS, Mask, DAG))
21913 return NewShuffle;
21914 }
21915
21916 // If all insertions are zero value, try to convert to AND mask.
21917 // TODO: Do this for -1 with OR mask?
21918 if (!LegalOperations && llvm::isNullConstant(V: InVal) &&
21919 all_of(Range&: Ops, P: [InVal](SDValue Op) { return !Op || Op == InVal; }) &&
21920 count_if(Range&: Ops, P: [InVal](SDValue Op) { return Op == InVal; }) >= 2) {
21921 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: MaxEltVT);
21922 SDValue AllOnes = DAG.getAllOnesConstant(DL, VT: MaxEltVT);
21923 SmallVector<SDValue, 8> Mask(NumElts);
21924 for (unsigned I = 0; I != NumElts; ++I)
21925 Mask[I] = Ops[I] ? Zero : AllOnes;
21926 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: CurVec,
21927 N2: DAG.getBuildVector(VT, DL, Ops: Mask));
21928 }
21929
21930 // Failed to find a match in the chain - bail.
21931 break;
21932 }
21933
21934 // See if we can fill in the missing constant elements as zeros.
21935 // TODO: Should we do this for any constant?
21936 APInt DemandedZeroElts = APInt::getZero(numBits: NumElts);
21937 for (unsigned I = 0; I != NumElts; ++I)
21938 if (!Ops[I])
21939 DemandedZeroElts.setBit(I);
21940
21941 if (DAG.MaskedVectorIsZero(Op: InVec, DemandedElts: DemandedZeroElts)) {
21942 SDValue Zero = VT.isInteger() ? DAG.getConstant(Val: 0, DL, VT: MaxEltVT)
21943 : DAG.getConstantFP(Val: 0, DL, VT: MaxEltVT);
21944 for (unsigned I = 0; I != NumElts; ++I)
21945 if (!Ops[I])
21946 Ops[I] = Zero;
21947
21948 return CanonicalizeBuildVector(Ops);
21949 }
21950 }
21951
21952 return SDValue();
21953}
21954
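/// Replace an extract_vector_elt of a simple vector load with a narrow scalar
/// load of just the extracted element, reusing the original load's chain.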
21955SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
21956 SDValue EltNo,
21957 LoadSDNode *OriginalLoad) {
21958 assert(OriginalLoad->isSimple());
21959
21960 EVT ResultVT = EVE->getValueType(ResNo: 0);
21961 EVT VecEltVT = InVecVT.getVectorElementType();
21962
21963 // If the vector element type is not a multiple of a byte then we are unable
21964 // to correctly compute an address to load only the extracted element as a
21965 // scalar.
21966 if (!VecEltVT.isByteSized())
21967 return SDValue();
21968
21969 ISD::LoadExtType ExtTy =
21970 ResultVT.bitsGT(VT: VecEltVT) ? ISD::NON_EXTLOAD : ISD::EXTLOAD;
21971 if (!TLI.isOperationLegalOrCustom(Op: ISD::LOAD, VT: VecEltVT) ||
21972 !TLI.shouldReduceLoadWidth(Load: OriginalLoad, ExtTy, NewVT: VecEltVT))
21973 return SDValue();
21974
21975 Align Alignment = OriginalLoad->getAlign();
21976 MachinePointerInfo MPI;
21977 SDLoc DL(EVE);
21978 if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(Val&: EltNo)) {
21979 int Elt = ConstEltNo->getZExtValue();
21980 unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
21981 MPI = OriginalLoad->getPointerInfo().getWithOffset(O: PtrOff);
21982 Alignment = commonAlignment(A: Alignment, Offset: PtrOff);
21983 } else {
21984 // Discard the pointer info except the address space because the memory
21985 // operand can't represent this new access since the offset is variable.
21986 MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
21987 Alignment = commonAlignment(A: Alignment, Offset: VecEltVT.getSizeInBits() / 8);
21988 }
21989
21990 unsigned IsFast = 0;
21991 if (!TLI.allowsMemoryAccess(Context&: *DAG.getContext(), DL: DAG.getDataLayout(), VT: VecEltVT,
21992 AddrSpace: OriginalLoad->getAddressSpace(), Alignment,
21993 Flags: OriginalLoad->getMemOperand()->getFlags(),
21994 Fast: &IsFast) ||
21995 !IsFast)
21996 return SDValue();
21997
21998 SDValue NewPtr = TLI.getVectorElementPointer(DAG, VecPtr: OriginalLoad->getBasePtr(),
21999 VecVT: InVecVT, Index: EltNo);
22000
22001 // We are replacing a vector load with a scalar load. The new load must have
22002 // identical memory op ordering to the original.
22003 SDValue Load;
22004 if (ResultVT.bitsGT(VT: VecEltVT)) {
22005 // If the result type of vextract is wider than the load, then issue an
22006 // extending load instead.
22007 ISD::LoadExtType ExtType =
22008 TLI.isLoadExtLegal(ExtType: ISD::ZEXTLOAD, ValVT: ResultVT, MemVT: VecEltVT) ? ISD::ZEXTLOAD
22009 : ISD::EXTLOAD;
22010 Load = DAG.getExtLoad(ExtType, dl: DL, VT: ResultVT, Chain: OriginalLoad->getChain(),
22011 Ptr: NewPtr, PtrInfo: MPI, MemVT: VecEltVT, Alignment,
22012 MMOFlags: OriginalLoad->getMemOperand()->getFlags(),
22013 AAInfo: OriginalLoad->getAAInfo());
22014 DAG.makeEquivalentMemoryOrdering(OldLoad: OriginalLoad, NewMemOp: Load);
22015 } else {
22016 // The result type is narrower or the same width as the vector element
22017 Load = DAG.getLoad(VT: VecEltVT, dl: DL, Chain: OriginalLoad->getChain(), Ptr: NewPtr, PtrInfo: MPI,
22018 Alignment, MMOFlags: OriginalLoad->getMemOperand()->getFlags(),
22019 AAInfo: OriginalLoad->getAAInfo());
22020 DAG.makeEquivalentMemoryOrdering(OldLoad: OriginalLoad, NewMemOp: Load);
22021 if (ResultVT.bitsLT(VT: VecEltVT))
22022 Load = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ResultVT, Operand: Load);
22023 else
22024 Load = DAG.getBitcast(VT: ResultVT, V: Load);
22025 }
22026 ++OpsNarrowed;
22027 return Load;
22028}
22029
22030/// Transform a vector binary operation into a scalar binary operation by moving
22031/// the math/logic after an extract element of a vector.
22032static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
22033 bool LegalOperations) {
22034 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22035 SDValue Vec = ExtElt->getOperand(Num: 0);
22036 SDValue Index = ExtElt->getOperand(Num: 1);
22037 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: Index);
22038 if (!IndexC || !TLI.isBinOp(Opcode: Vec.getOpcode()) || !Vec.hasOneUse() ||
22039 Vec->getNumValues() != 1)
22040 return SDValue();
22041
22042 // Targets may want to avoid this to prevent an expensive register transfer.
22043 if (!TLI.shouldScalarizeBinop(VecOp: Vec))
22044 return SDValue();
22045
22046 // Extracting an element of a vector constant is constant-folded, so this
22047 // transform is just replacing a vector op with a scalar op while moving the
22048 // extract.
22049 SDValue Op0 = Vec.getOperand(i: 0);
22050 SDValue Op1 = Vec.getOperand(i: 1);
22051 APInt SplatVal;
22052 if (isAnyConstantBuildVector(V: Op0, NoOpaques: true) ||
22053 ISD::isConstantSplatVector(N: Op0.getNode(), SplatValue&: SplatVal) ||
22054 isAnyConstantBuildVector(V: Op1, NoOpaques: true) ||
22055 ISD::isConstantSplatVector(N: Op1.getNode(), SplatValue&: SplatVal)) {
22056 // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
22057 // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
22058 SDLoc DL(ExtElt);
22059 EVT VT = ExtElt->getValueType(ResNo: 0);
22060 SDValue Ext0 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT, N1: Op0, N2: Index);
22061 SDValue Ext1 = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT, N1: Op1, N2: Index);
22062 return DAG.getNode(Opcode: Vec.getOpcode(), DL, VT, N1: Ext0, N2: Ext1);
22063 }
22064
22065 return SDValue();
22066}
22067
22068 // Given an ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
22069 // recursively analyse all of its users and try to model them as bit sequence
22070 // extractions. If all of them agree on the new, narrower element type, and all
22071 // of them can be modelled as ISD::EXTRACT_VECTOR_ELT's of that new element
22072 // type, do so now.
22073 // This is mainly useful to recover from legalization that scalarized the
22074 // vector as wide elements; we try to rebuild it with narrower elements.
22075//
22076// Some more nodes could be modelled if that helps cover interesting patterns.
22077bool DAGCombiner::refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(
22078 SDNode *N) {
22079 // We perform this optimization post type-legalization because
22080 // the type-legalizer often scalarizes integer-promoted vectors.
22081 // Performing this optimization earlier may cause legalization cycles.
22082 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
22083 return false;
22084
22085 // TODO: Add support for big-endian.
22086 if (DAG.getDataLayout().isBigEndian())
22087 return false;
22088
22089 SDValue VecOp = N->getOperand(Num: 0);
22090 EVT VecVT = VecOp.getValueType();
22091 assert(!VecVT.isScalableVector() && "Only for fixed vectors.");
22092
22093 // We must start with a constant extraction index.
22094 auto *IndexC = dyn_cast<ConstantSDNode>(Val: N->getOperand(Num: 1));
22095 if (!IndexC)
22096 return false;
22097
22098 assert(IndexC->getZExtValue() < VecVT.getVectorNumElements() &&
22099 "Original ISD::EXTRACT_VECTOR_ELT is undefinend?");
22100
22101 // TODO: deal with the case of implicit anyext of the extraction.
22102 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
22103 EVT ScalarVT = N->getValueType(ResNo: 0);
22104 if (VecVT.getScalarType() != ScalarVT)
22105 return false;
22106
22107 // TODO: deal with the cases other than everything being integer-typed.
22108 if (!ScalarVT.isScalarInteger())
22109 return false;
22110
22111 struct Entry {
22112 SDNode *Producer;
22113
22114 // Which bits of VecOp does it contain?
22115 unsigned BitPos;
22116 int NumBits;
22117 // NOTE: the actual width of \p Producer may be wider than NumBits!
22118
22119 Entry(Entry &&) = default;
22120 Entry(SDNode *Producer_, unsigned BitPos_, int NumBits_)
22121 : Producer(Producer_), BitPos(BitPos_), NumBits(NumBits_) {}
22122
22123 Entry() = delete;
22124 Entry(const Entry &) = delete;
22125 Entry &operator=(const Entry &) = delete;
22126 Entry &operator=(Entry &&) = delete;
22127 };
22128 SmallVector<Entry, 32> Worklist;
22129 SmallVector<Entry, 32> Leafs;
22130
22131 // We start at the "root" ISD::EXTRACT_VECTOR_ELT.
22132 Worklist.emplace_back(Args&: N, /*BitPos=*/Args: VecEltBitWidth * IndexC->getZExtValue(),
22133 /*NumBits=*/Args&: VecEltBitWidth);
22134
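// Walk the use graph, modelling each user as a (BitPos, NumBits) slice of
// VecOp; producers whose users we cannot model become leaves that will be
// rewritten as narrow extracts below.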
22135 while (!Worklist.empty()) {
22136 Entry E = Worklist.pop_back_val();
22137 // Does the node not even use any of the VecOp bits?
22138 if (!(E.NumBits > 0 && E.BitPos < VecVT.getSizeInBits() &&
22139 E.BitPos + E.NumBits <= VecVT.getSizeInBits()))
22140 return false; // Let's allow the other combines to clean this up first.
22141 // Did we fail to model any of the users of the Producer?
22142 bool ProducerIsLeaf = false;
22143 // Look at each user of this Producer.
22144 for (SDNode *User : E.Producer->uses()) {
22145 switch (User->getOpcode()) {
22146 // TODO: support ISD::BITCAST
22147 // TODO: support ISD::ANY_EXTEND
22148 // TODO: support ISD::ZERO_EXTEND
22149 // TODO: support ISD::SIGN_EXTEND
22150 case ISD::TRUNCATE:
22151 // Truncation simply means we keep position, but extract less bits.
22152 Worklist.emplace_back(Args&: User, Args&: E.BitPos,
22153 /*NumBits=*/Args: User->getValueSizeInBits(ResNo: 0));
22154 break;
22155 // TODO: support ISD::SRA
22156 // TODO: support ISD::SHL
22157 case ISD::SRL:
22158 // We should be shifting the Producer by a constant amount.
22159 if (auto *ShAmtC = dyn_cast<ConstantSDNode>(Val: User->getOperand(Num: 1));
22160 User->getOperand(Num: 0).getNode() == E.Producer && ShAmtC) {
22161 // Logical right-shift means that we start extraction later,
22162 // but stop it at the same position we did previously.
22163 unsigned ShAmt = ShAmtC->getZExtValue();
22164 Worklist.emplace_back(Args&: User, Args: E.BitPos + ShAmt, Args: E.NumBits - ShAmt);
22165 break;
22166 }
22167 [[fallthrough]];
22168 default:
22169 // We cannot model this user of the Producer,
22170 // which means the current Producer will be an ISD::EXTRACT_VECTOR_ELT.
22171 ProducerIsLeaf = true;
22172 // Profitability check: all users that we cannot model
22173 // must be ISD::BUILD_VECTOR's.
22174 if (User->getOpcode() != ISD::BUILD_VECTOR)
22175 return false;
22176 break;
22177 }
22178 }
22179 if (ProducerIsLeaf)
22180 Leafs.emplace_back(Args: std::move(E));
22181 }
22182
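// Use the first leaf's bit width as the candidate new element width; every
// other leaf is checked against it below.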
22183 unsigned NewVecEltBitWidth = Leafs.front().NumBits;
22184
22185 // If we are still at the same element granularity, give up.
22186 if (NewVecEltBitWidth == VecEltBitWidth)
22187 return false;
22188
22189 // The vector width must be a multiple of the new element width.
22190 if (VecVT.getSizeInBits() % NewVecEltBitWidth != 0)
22191 return false;
22192
22193 // All leafs must agree on the new element width.
22194 // All leafs must not expect any "padding" bits on top of that width.
22195 // All leafs must start extraction from a multiple of that width.
22196 if (!all_of(Range&: Leafs, P: [NewVecEltBitWidth](const Entry &E) {
22197 return (unsigned)E.NumBits == NewVecEltBitWidth &&
22198 E.Producer->getValueSizeInBits(ResNo: 0) == NewVecEltBitWidth &&
22199 E.BitPos % NewVecEltBitWidth == 0;
22200 }))
22201 return false;
22202
22203 EVT NewScalarVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NewVecEltBitWidth);
22204 EVT NewVecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewScalarVT,
22205 NumElements: VecVT.getSizeInBits() / NewVecEltBitWidth);
22206
22207 if (LegalTypes &&
22208 !(TLI.isTypeLegal(VT: NewScalarVT) && TLI.isTypeLegal(VT: NewVecVT)))
22209 return false;
22210
22211 if (LegalOperations &&
22212 !(TLI.isOperationLegalOrCustom(Op: ISD::BITCAST, VT: NewVecVT) &&
22213 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_VECTOR_ELT, VT: NewVecVT)))
22214 return false;
22215
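// All checks passed: bitcast the source vector to the narrow element type and
// replace each leaf producer with an extract of the corresponding narrow
// element.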
22216 SDValue NewVecOp = DAG.getBitcast(VT: NewVecVT, V: VecOp);
22217 for (const Entry &E : Leafs) {
22218 SDLoc DL(E.Producer);
22219 unsigned NewIndex = E.BitPos / NewVecEltBitWidth;
22220 assert(NewIndex < NewVecVT.getVectorNumElements() &&
22221 "Creating out-of-bounds ISD::EXTRACT_VECTOR_ELT?");
22222 SDValue V = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: NewScalarVT, N1: NewVecOp,
22223 N2: DAG.getVectorIdxConstant(Val: NewIndex, DL));
22224 CombineTo(N: E.Producer, Res: V);
22225 }
22226
22227 return true;
22228}
22229
22230SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
22231 SDValue VecOp = N->getOperand(Num: 0);
22232 SDValue Index = N->getOperand(Num: 1);
22233 EVT ScalarVT = N->getValueType(ResNo: 0);
22234 EVT VecVT = VecOp.getValueType();
22235 if (VecOp.isUndef())
22236 return DAG.getUNDEF(VT: ScalarVT);
22237
22238 // extract_vector_elt (insert_vector_elt vec, val, idx), idx) -> val
22239 //
22240 // This only really matters if the index is non-constant since other combines
22241 // on the constant elements already work.
22242 SDLoc DL(N);
22243 if (VecOp.getOpcode() == ISD::INSERT_VECTOR_ELT &&
22244 Index == VecOp.getOperand(i: 2)) {
22245 SDValue Elt = VecOp.getOperand(i: 1);
22246 return VecVT.isInteger() ? DAG.getAnyExtOrTrunc(Op: Elt, DL, VT: ScalarVT) : Elt;
22247 }
22248
22249 // (vextract (scalar_to_vector val), 0) -> val
22250 if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR) {
22251 // Only 0'th element of SCALAR_TO_VECTOR is defined.
22252 if (DAG.isKnownNeverZero(Op: Index))
22253 return DAG.getUNDEF(VT: ScalarVT);
22254
22255 // Check if the result type doesn't match the inserted element type.
22256 // The inserted element and extracted element may have mismatched bitwidths.
22257 // As a result, EXTRACT_VECTOR_ELT may extend or truncate the extracted value.
22258 SDValue InOp = VecOp.getOperand(i: 0);
22259 if (InOp.getValueType() != ScalarVT) {
22260 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
22261 if (InOp.getValueType().bitsGT(VT: ScalarVT))
22262 return DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: ScalarVT, Operand: InOp);
22263 return DAG.getNode(Opcode: ISD::ANY_EXTEND, DL, VT: ScalarVT, Operand: InOp);
22264 }
22265 return InOp;
22266 }
22267
22268 // extract_vector_elt of out-of-bounds element -> UNDEF
22269 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: Index);
22270 if (IndexC && VecVT.isFixedLengthVector() &&
22271 IndexC->getAPIntValue().uge(RHS: VecVT.getVectorNumElements()))
22272 return DAG.getUNDEF(VT: ScalarVT);
22273
22274 // extract_vector_elt (build_vector x, y), 1 -> y
22275 if (((IndexC && VecOp.getOpcode() == ISD::BUILD_VECTOR) ||
22276 VecOp.getOpcode() == ISD::SPLAT_VECTOR) &&
22277 TLI.isTypeLegal(VT: VecVT)) {
22278 assert((VecOp.getOpcode() != ISD::BUILD_VECTOR ||
22279 VecVT.isFixedLengthVector()) &&
22280 "BUILD_VECTOR used for scalable vectors");
22281 unsigned IndexVal =
22282 VecOp.getOpcode() == ISD::BUILD_VECTOR ? IndexC->getZExtValue() : 0;
22283 SDValue Elt = VecOp.getOperand(i: IndexVal);
22284 EVT InEltVT = Elt.getValueType();
22285
22286 if (VecOp.hasOneUse() || TLI.aggressivelyPreferBuildVectorSources(VecVT) ||
22287 isNullConstant(V: Elt)) {
22288 // Sometimes build_vector's scalar input types do not match result type.
22289 if (ScalarVT == InEltVT)
22290 return Elt;
22291
22292 // TODO: It may be useful to truncate, if it is free, when the build_vector
22293 // implicitly converts.
22294 }
22295 }
22296
22297 if (SDValue BO = scalarizeExtractedBinop(ExtElt: N, DAG, LegalOperations))
22298 return BO;
22299
22300 if (VecVT.isScalableVector())
22301 return SDValue();
22302
22303 // All the code from this point onwards assumes fixed width vectors, but it's
22304 // possible that some of the combinations could be made to work for scalable
22305 // vectors too.
22306 unsigned NumElts = VecVT.getVectorNumElements();
22307 unsigned VecEltBitWidth = VecVT.getScalarSizeInBits();
22308
22309 // See if the extracted element is constant, in which case fold it if it's
22310 // a legal fp immediate.
22311 if (IndexC && ScalarVT.isFloatingPoint()) {
22312 APInt EltMask = APInt::getOneBitSet(numBits: NumElts, BitNo: IndexC->getZExtValue());
22313 KnownBits KnownElt = DAG.computeKnownBits(Op: VecOp, DemandedElts: EltMask);
22314 if (KnownElt.isConstant()) {
22315 APFloat CstFP =
22316 APFloat(DAG.EVTToAPFloatSemantics(VT: ScalarVT), KnownElt.getConstant());
22317 if (TLI.isFPImmLegal(CstFP, ScalarVT))
22318 return DAG.getConstantFP(Val: CstFP, DL, VT: ScalarVT);
22319 }
22320 }
22321
22322 // TODO: These transforms should not require the 'hasOneUse' restriction, but
22323 // there are regressions on multiple targets without it. We can end up with a
22324 // mess of scalar and vector code if we reduce only part of the DAG to scalar.
22325 if (IndexC && VecOp.getOpcode() == ISD::BITCAST && VecVT.isInteger() &&
22326 VecOp.hasOneUse()) {
22327 // The vector index of the LSBs of the source depends on the endianness.
22328 bool IsLE = DAG.getDataLayout().isLittleEndian();
22329 unsigned ExtractIndex = IndexC->getZExtValue();
22330 // extract_elt (v2i32 (bitcast i64:x)), BCTruncElt -> i32 (trunc i64:x)
22331 unsigned BCTruncElt = IsLE ? 0 : NumElts - 1;
22332 SDValue BCSrc = VecOp.getOperand(i: 0);
22333 if (ExtractIndex == BCTruncElt && BCSrc.getValueType().isScalarInteger())
22334 return DAG.getAnyExtOrTrunc(Op: BCSrc, DL, VT: ScalarVT);
22335
22336 if (LegalTypes && BCSrc.getValueType().isInteger() &&
22337 BCSrc.getOpcode() == ISD::SCALAR_TO_VECTOR) {
22338 // ext_elt (bitcast (scalar_to_vec i64 X to v2i64) to v4i32), TruncElt -->
22339 // trunc i64 X to i32
22340 SDValue X = BCSrc.getOperand(i: 0);
22341 assert(X.getValueType().isScalarInteger() && ScalarVT.isScalarInteger() &&
22342 "Extract element and scalar to vector can't change element type "
22343 "from FP to integer.");
22344 unsigned XBitWidth = X.getValueSizeInBits();
22345 BCTruncElt = IsLE ? 0 : XBitWidth / VecEltBitWidth - 1;
22346
22347 // An extract element return value type can be wider than its vector
22348 // operand element type. In that case, the high bits are undefined, so
22349 // it's possible that we may need to extend rather than truncate.
22350 if (ExtractIndex == BCTruncElt && XBitWidth > VecEltBitWidth) {
22351 assert(XBitWidth % VecEltBitWidth == 0 &&
22352 "Scalar bitwidth must be a multiple of vector element bitwidth");
22353 return DAG.getAnyExtOrTrunc(Op: X, DL, VT: ScalarVT);
22354 }
22355 }
22356 }
22357
22358 // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
22359 // We only perform this optimization before the op legalization phase because
22360 // we may introduce new vector instructions which are not backed by TD
22361 // patterns. For example, on AVX we could end up extracting elements from a
22362 // wide vector without using extract_subvector. However, if we can find an
22363 // underlying scalar value, then we can always use that.
22364 if (IndexC && VecOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
22365 auto *Shuf = cast<ShuffleVectorSDNode>(Val&: VecOp);
22366 // Find the new index to extract from.
22367 int OrigElt = Shuf->getMaskElt(Idx: IndexC->getZExtValue());
22368
22369 // Extracting an undef index is undef.
22370 if (OrigElt == -1)
22371 return DAG.getUNDEF(VT: ScalarVT);
22372
22373 // Select the right vector half to extract from.
22374 SDValue SVInVec;
22375 if (OrigElt < (int)NumElts) {
22376 SVInVec = VecOp.getOperand(i: 0);
22377 } else {
22378 SVInVec = VecOp.getOperand(i: 1);
22379 OrigElt -= NumElts;
22380 }
22381
22382 if (SVInVec.getOpcode() == ISD::BUILD_VECTOR) {
22383 SDValue InOp = SVInVec.getOperand(i: OrigElt);
22384 if (InOp.getValueType() != ScalarVT) {
22385 assert(InOp.getValueType().isInteger() && ScalarVT.isInteger());
22386 InOp = DAG.getSExtOrTrunc(Op: InOp, DL, VT: ScalarVT);
22387 }
22388
22389 return InOp;
22390 }
22391
22392 // FIXME: We should handle recursing on other vector shuffles and
22393 // scalar_to_vector here as well.
22394
22395 if (!LegalOperations ||
22396 // FIXME: Should really be just isOperationLegalOrCustom.
22397 TLI.isOperationLegal(Op: ISD::EXTRACT_VECTOR_ELT, VT: VecVT) ||
22398 TLI.isOperationExpand(Op: ISD::VECTOR_SHUFFLE, VT: VecVT)) {
22399 return DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT, N1: SVInVec,
22400 N2: DAG.getVectorIdxConstant(Val: OrigElt, DL));
22401 }
22402 }
22403
22404 // If only EXTRACT_VECTOR_ELT nodes use the source vector we can
22405 // simplify it based on the (valid) extraction indices.
22406 if (llvm::all_of(Range: VecOp->uses(), P: [&](SDNode *Use) {
22407 return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
22408 Use->getOperand(Num: 0) == VecOp &&
22409 isa<ConstantSDNode>(Val: Use->getOperand(Num: 1));
22410 })) {
22411 APInt DemandedElts = APInt::getZero(numBits: NumElts);
22412 for (SDNode *Use : VecOp->uses()) {
22413 auto *CstElt = cast<ConstantSDNode>(Val: Use->getOperand(Num: 1));
22414 if (CstElt->getAPIntValue().ult(RHS: NumElts))
22415 DemandedElts.setBit(CstElt->getZExtValue());
22416 }
22417 if (SimplifyDemandedVectorElts(Op: VecOp, DemandedElts, AssumeSingleUse: true)) {
22418 // We simplified the vector operand of this extract element. If this
22419 // extract is not dead, visit it again so it is folded properly.
22420 if (N->getOpcode() != ISD::DELETED_NODE)
22421 AddToWorklist(N);
22422 return SDValue(N, 0);
22423 }
22424 APInt DemandedBits = APInt::getAllOnes(numBits: VecEltBitWidth);
22425 if (SimplifyDemandedBits(Op: VecOp, DemandedBits, DemandedElts, AssumeSingleUse: true)) {
22426 // We simplified the vector operand of this extract element. If this
22427 // extract is not dead, visit it again so it is folded properly.
22428 if (N->getOpcode() != ISD::DELETED_NODE)
22429 AddToWorklist(N);
22430 return SDValue(N, 0);
22431 }
22432 }
22433
22434 if (refineExtractVectorEltIntoMultipleNarrowExtractVectorElts(N))
22435 return SDValue(N, 0);
22436
22437 // Everything under here is trying to match an extract of a loaded value.
22438 // If the result of the load has to be truncated, then it's not necessarily
22439 // profitable.
22440 bool BCNumEltsChanged = false;
22441 EVT ExtVT = VecVT.getVectorElementType();
22442 EVT LVT = ExtVT;
22443 if (ScalarVT.bitsLT(VT: LVT) && !TLI.isTruncateFree(FromVT: LVT, ToVT: ScalarVT))
22444 return SDValue();
22445
22446 if (VecOp.getOpcode() == ISD::BITCAST) {
22447 // Don't duplicate a load with other uses.
22448 if (!VecOp.hasOneUse())
22449 return SDValue();
22450
22451 EVT BCVT = VecOp.getOperand(i: 0).getValueType();
22452 if (!BCVT.isVector() || ExtVT.bitsGT(VT: BCVT.getVectorElementType()))
22453 return SDValue();
22454 if (NumElts != BCVT.getVectorNumElements())
22455 BCNumEltsChanged = true;
22456 VecOp = VecOp.getOperand(i: 0);
22457 ExtVT = BCVT.getVectorElementType();
22458 }
22459
22460 // extract (vector load $addr), i --> load $addr + i * size
22461 if (!LegalOperations && !IndexC && VecOp.hasOneUse() &&
22462 ISD::isNormalLoad(N: VecOp.getNode()) &&
22463 !Index->hasPredecessor(N: VecOp.getNode())) {
22464 auto *VecLoad = dyn_cast<LoadSDNode>(Val&: VecOp);
22465 if (VecLoad && VecLoad->isSimple())
22466 return scalarizeExtractedVectorLoad(EVE: N, InVecVT: VecVT, EltNo: Index, OriginalLoad: VecLoad);
22467 }
22468
22469 // Perform only after legalization to ensure build_vector / vector_shuffle
22470 // optimizations have already been done.
22471 if (!LegalOperations || !IndexC)
22472 return SDValue();
22473
22474 // (vextract (v4f32 load $addr), c) -> (f32 load $addr+c*size)
22475 // (vextract (v4f32 s2v (f32 load $addr)), c) -> (f32 load $addr+c*size)
22476 // (vextract (v4f32 shuffle (load $addr), <1,u,u,u>), 0) -> (f32 load $addr)
22477 int Elt = IndexC->getZExtValue();
22478 LoadSDNode *LN0 = nullptr;
22479 if (ISD::isNormalLoad(N: VecOp.getNode())) {
22480 LN0 = cast<LoadSDNode>(Val&: VecOp);
22481 } else if (VecOp.getOpcode() == ISD::SCALAR_TO_VECTOR &&
22482 VecOp.getOperand(i: 0).getValueType() == ExtVT &&
22483 ISD::isNormalLoad(N: VecOp.getOperand(i: 0).getNode())) {
22484 // Don't duplicate a load with other uses.
22485 if (!VecOp.hasOneUse())
22486 return SDValue();
22487
22488 LN0 = cast<LoadSDNode>(Val: VecOp.getOperand(i: 0));
22489 }
22490 if (auto *Shuf = dyn_cast<ShuffleVectorSDNode>(Val&: VecOp)) {
22491 // (vextract (vector_shuffle (load $addr), v2, <1, u, u, u>), 1)
22492 // =>
22493 // (load $addr+1*size)
22494
22495 // Don't duplicate a load with other uses.
22496 if (!VecOp.hasOneUse())
22497 return SDValue();
22498
22499 // If the bit convert changed the number of elements, it is unsafe
22500 // to examine the mask.
22501 if (BCNumEltsChanged)
22502 return SDValue();
22503
22504 // Select the input vector, guarding against an out-of-range extract index.
22505 int Idx = (Elt > (int)NumElts) ? -1 : Shuf->getMaskElt(Idx: Elt);
22506 VecOp = (Idx < (int)NumElts) ? VecOp.getOperand(i: 0) : VecOp.getOperand(i: 1);
22507
22508 if (VecOp.getOpcode() == ISD::BITCAST) {
22509 // Don't duplicate a load with other uses.
22510 if (!VecOp.hasOneUse())
22511 return SDValue();
22512
22513 VecOp = VecOp.getOperand(i: 0);
22514 }
22515 if (ISD::isNormalLoad(N: VecOp.getNode())) {
22516 LN0 = cast<LoadSDNode>(Val&: VecOp);
22517 Elt = (Idx < (int)NumElts) ? Idx : Idx - (int)NumElts;
22518 Index = DAG.getConstant(Val: Elt, DL, VT: Index.getValueType());
22519 }
22520 } else if (VecOp.getOpcode() == ISD::CONCAT_VECTORS && !BCNumEltsChanged &&
22521 VecVT.getVectorElementType() == ScalarVT &&
22522 (!LegalTypes ||
22523 TLI.isTypeLegal(
22524 VT: VecOp.getOperand(i: 0).getValueType().getVectorElementType()))) {
22525 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 0
22526 // -> extract_vector_elt a, 0
22527 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 1
22528 // -> extract_vector_elt a, 1
22529 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 2
22530 // -> extract_vector_elt b, 0
22531 // extract_vector_elt (concat_vectors v2i16:a, v2i16:b), 3
22532 // -> extract_vector_elt b, 1
22533 EVT ConcatVT = VecOp.getOperand(i: 0).getValueType();
22534 unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
22535 SDValue NewIdx = DAG.getConstant(Val: Elt % ConcatNumElts, DL,
22536 VT: Index.getValueType());
22537
22538 SDValue ConcatOp = VecOp.getOperand(i: Elt / ConcatNumElts);
22539 SDValue Elt = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL,
22540 VT: ConcatVT.getVectorElementType(),
22541 N1: ConcatOp, N2: NewIdx);
22542 return DAG.getNode(Opcode: ISD::BITCAST, DL, VT: ScalarVT, Operand: Elt);
22543 }
22544
22545 // Make sure we found a non-volatile load and the extractelement is
22546 // the only use.
22547 if (!LN0 || !LN0->hasNUsesOfValue(NUses: 1,Value: 0) || !LN0->isSimple())
22548 return SDValue();
22549
22550 // If Idx was -1 above, Elt is going to be -1, so just return undef.
22551 if (Elt == -1)
22552 return DAG.getUNDEF(VT: LVT);
22553
22554 return scalarizeExtractedVectorLoad(EVE: N, InVecVT: VecVT, EltNo: Index, OriginalLoad: LN0);
22555}
22556
22557// Simplify (build_vec (ext )) to (bitcast (build_vec ))
22558SDValue DAGCombiner::reduceBuildVecExtToExtBuildVec(SDNode *N) {
22559 // We perform this optimization post type-legalization because
22560 // the type-legalizer often scalarizes integer-promoted vectors.
22561 // Performing this optimization earlier may create bit-casts which
22562 // will be type-legalized into complex code sequences.
22563 // We perform this optimization only before the operation legalizer because we
22564 // may introduce illegal operations.
22565 if (Level != AfterLegalizeVectorOps && Level != AfterLegalizeTypes)
22566 return SDValue();
22567
22568 unsigned NumInScalars = N->getNumOperands();
22569 SDLoc DL(N);
22570 EVT VT = N->getValueType(ResNo: 0);
22571
22572 // Check to see if this is a BUILD_VECTOR of a bunch of values
22573 // which come from any_extend or zero_extend nodes. If so, we can create
22574 // a new BUILD_VECTOR using bit-casts which may enable other BUILD_VECTOR
22575 // optimizations. We do not handle sign-extend because we can't fill the sign
22576 // using shuffles.
22577 EVT SourceType = MVT::Other;
22578 bool AllAnyExt = true;
22579
22580 for (unsigned i = 0; i != NumInScalars; ++i) {
22581 SDValue In = N->getOperand(Num: i);
22582 // Ignore undef inputs.
22583 if (In.isUndef()) continue;
22584
22585 bool AnyExt = In.getOpcode() == ISD::ANY_EXTEND;
22586 bool ZeroExt = In.getOpcode() == ISD::ZERO_EXTEND;
22587
22588 // Abort if the element is not an extension.
22589 if (!ZeroExt && !AnyExt) {
22590 SourceType = MVT::Other;
22591 break;
22592 }
22593
22594 // The input is a ZeroExt or AnyExt. Check the original type.
22595 EVT InTy = In.getOperand(i: 0).getValueType();
22596
22597 // Check that all of the widened source types are the same.
22598 if (SourceType == MVT::Other)
22599 // First time.
22600 SourceType = InTy;
22601 else if (InTy != SourceType) {
22602 // Multiple incoming types. Abort.
22603 SourceType = MVT::Other;
22604 break;
22605 }
22606
22607 // Check if all of the extends are ANY_EXTENDs.
22608 AllAnyExt &= AnyExt;
22609 }
22610
22611 // In order to have valid types, all of the inputs must be extended from the
22612 // same source type and all of the inputs must be any or zero extend.
22613 // Scalar sizes must be a power of two.
22614 EVT OutScalarTy = VT.getScalarType();
22615 bool ValidTypes =
22616 SourceType != MVT::Other &&
22617 llvm::has_single_bit<uint32_t>(OutScalarTy.getSizeInBits()) &&
22618 llvm::has_single_bit<uint32_t>(SourceType.getSizeInBits());
22619
22620 // Create a new simpler BUILD_VECTOR sequence which other optimizations can
22621 // turn into a single shuffle instruction.
22622 if (!ValidTypes)
22623 return SDValue();
22624
22625 // If we already have a splat buildvector, then don't fold it if it means
22626 // introducing zeros.
22627 if (!AllAnyExt && DAG.isSplatValue(V: SDValue(N, 0), /*AllowUndefs*/ true))
22628 return SDValue();
22629
22630 bool isLE = DAG.getDataLayout().isLittleEndian();
22631 unsigned ElemRatio = OutScalarTy.getSizeInBits()/SourceType.getSizeInBits();
22632 assert(ElemRatio > 1 && "Invalid element size ratio");
22633 SDValue Filler = AllAnyExt ? DAG.getUNDEF(VT: SourceType):
22634 DAG.getConstant(Val: 0, DL, VT: SourceType);
22635
22636 unsigned NewBVElems = ElemRatio * VT.getVectorNumElements();
22637 SmallVector<SDValue, 8> Ops(NewBVElems, Filler);
22638
22639 // Populate the new build_vector
22640 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
22641 SDValue Cast = N->getOperand(Num: i);
22642 assert((Cast.getOpcode() == ISD::ANY_EXTEND ||
22643 Cast.getOpcode() == ISD::ZERO_EXTEND ||
22644 Cast.isUndef()) && "Invalid cast opcode");
22645 SDValue In;
22646 if (Cast.isUndef())
22647 In = DAG.getUNDEF(VT: SourceType);
22648 else
22649 In = Cast->getOperand(Num: 0);
22650 unsigned Index = isLE ? (i * ElemRatio) :
22651 (i * ElemRatio + (ElemRatio - 1));
22652
22653 assert(Index < Ops.size() && "Invalid index");
22654 Ops[Index] = In;
22655 }
22656
22657 // The type of the new BUILD_VECTOR node.
22658 EVT VecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SourceType, NumElements: NewBVElems);
22659 assert(VecVT.getSizeInBits() == VT.getSizeInBits() &&
22660 "Invalid vector size");
22661 // Check if the new vector type is legal.
22662 if (!isTypeLegal(VT: VecVT) ||
22663 (!TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT: VecVT) &&
22664 TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)))
22665 return SDValue();
22666
22667 // Make the new BUILD_VECTOR.
22668 SDValue BV = DAG.getBuildVector(VT: VecVT, DL, Ops);
22669
22670 // The new BUILD_VECTOR node has the potential to be further optimized.
22671 AddToWorklist(N: BV.getNode());
22672 // Bitcast to the desired type.
22673 return DAG.getBitcast(VT, V: BV);
22674}
22675
22676// Simplify (build_vec (trunc $1)
22677// (trunc (srl $1 half-width))
22678// (trunc (srl $1 (2 * half-width))))
22679// to (bitcast $1)
22680SDValue DAGCombiner::reduceBuildVecTruncToBitCast(SDNode *N) {
22681 assert(N->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
22682
22683 EVT VT = N->getValueType(ResNo: 0);
22684
22685 // Don't run this before LegalizeTypes if VT is legal.
22686 // Targets may have other preferences.
22687 if (Level < AfterLegalizeTypes && TLI.isTypeLegal(VT))
22688 return SDValue();
22689
22690 // Only for little endian
22691 if (!DAG.getDataLayout().isLittleEndian())
22692 return SDValue();
22693
22694 SDLoc DL(N);
22695 EVT OutScalarTy = VT.getScalarType();
22696 uint64_t ScalarTypeBitsize = OutScalarTy.getSizeInBits();
22697
22698 // Only for power-of-two types, to be sure that the bitcast works well.
22699 if (!isPowerOf2_64(Value: ScalarTypeBitsize))
22700 return SDValue();
22701
22702 unsigned NumInScalars = N->getNumOperands();
22703
22704 // Look through bitcasts
22705 auto PeekThroughBitcast = [](SDValue Op) {
22706 if (Op.getOpcode() == ISD::BITCAST)
22707 return Op.getOperand(i: 0);
22708 return Op;
22709 };
22710
22711 // The source value where all the parts are extracted.
22712 SDValue Src;
22713 for (unsigned i = 0; i != NumInScalars; ++i) {
22714 SDValue In = PeekThroughBitcast(N->getOperand(Num: i));
22715 // Ignore undef inputs.
22716 if (In.isUndef()) continue;
22717
22718 if (In.getOpcode() != ISD::TRUNCATE)
22719 return SDValue();
22720
22721 In = PeekThroughBitcast(In.getOperand(i: 0));
22722
22723 if (In.getOpcode() != ISD::SRL) {
22724 // For now, only handle build_vec without shuffling; handle shifts here in
22725 // the future.
22726 if (i != 0)
22727 return SDValue();
22728
22729 Src = In;
22730 } else {
22731 // In is SRL
22732 SDValue part = PeekThroughBitcast(In.getOperand(i: 0));
22733
22734 if (!Src) {
22735 Src = part;
22736 } else if (Src != part) {
22737 // Vector parts do not stem from the same variable
22738 return SDValue();
22739 }
22740
22741 SDValue ShiftAmtVal = In.getOperand(i: 1);
22742 if (!isa<ConstantSDNode>(Val: ShiftAmtVal))
22743 return SDValue();
22744
22745 uint64_t ShiftAmt = In.getConstantOperandVal(i: 1);
22746
22747 // The extracted value is not extracted at the right position
22748 if (ShiftAmt != i * ScalarTypeBitsize)
22749 return SDValue();
22750 }
22751 }
22752
22753 // Only cast if the size is the same
22754 if (!Src || Src.getValueType().getSizeInBits() != VT.getSizeInBits())
22755 return SDValue();
22756
22757 return DAG.getBitcast(VT, V: Src);
22758}
22759
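// Helper for reduceBuildVecToShuffle: try to shuffle VecIn1/VecIn2 (the pair
// of source vectors whose VectorMask indices are LeftIdx and LeftIdx + 1)
// into the element layout the BUILD_VECTOR node N requires, reconciling
// input/output type mismatches first.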
22760SDValue DAGCombiner::createBuildVecShuffle(const SDLoc &DL, SDNode *N,
22761 ArrayRef<int> VectorMask,
22762 SDValue VecIn1, SDValue VecIn2,
22763 unsigned LeftIdx, bool DidSplitVec) {
22764 SDValue ZeroIdx = DAG.getVectorIdxConstant(Val: 0, DL);
22765
22766 EVT VT = N->getValueType(ResNo: 0);
22767 EVT InVT1 = VecIn1.getValueType();
22768 EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1;
22769
22770 unsigned NumElems = VT.getVectorNumElements();
22771 unsigned ShuffleNumElems = NumElems;
22772
22773 // If we artificially split a vector in two already, then the offsets in the
22774 // operands will all be based off of VecIn1, even those in VecIn2.
22775 unsigned Vec2Offset = DidSplitVec ? 0 : InVT1.getVectorNumElements();
22776
22777 uint64_t VTSize = VT.getFixedSizeInBits();
22778 uint64_t InVT1Size = InVT1.getFixedSizeInBits();
22779 uint64_t InVT2Size = InVT2.getFixedSizeInBits();
22780
22781 assert(InVT2Size <= InVT1Size &&
22782 "Inputs must be sorted to be in non-increasing vector size order.");
22783
22784 // We can't generate a shuffle node with mismatched input and output types.
22785 // Try to make the types match the type of the output.
22786 if (InVT1 != VT || InVT2 != VT) {
22787 if ((VTSize % InVT1Size == 0) && InVT1 == InVT2) {
22788 // If the output vector length is a multiple of both input lengths,
22789 // we can concatenate them and pad the rest with undefs.
22790 unsigned NumConcats = VTSize / InVT1Size;
22791 assert(NumConcats >= 2 && "Concat needs at least two inputs!");
22792 SmallVector<SDValue, 2> ConcatOps(NumConcats, DAG.getUNDEF(VT: InVT1));
22793 ConcatOps[0] = VecIn1;
22794 ConcatOps[1] = VecIn2 ? VecIn2 : DAG.getUNDEF(VT: InVT1);
22795 VecIn1 = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
22796 VecIn2 = SDValue();
22797 } else if (InVT1Size == VTSize * 2) {
22798 if (!TLI.isExtractSubvectorCheap(ResVT: VT, SrcVT: InVT1, Index: NumElems))
22799 return SDValue();
22800
22801 if (!VecIn2.getNode()) {
22802 // If we only have one input vector, and it's twice the size of the
22803 // output, split it in two.
22804 VecIn2 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: VecIn1,
22805 N2: DAG.getVectorIdxConstant(Val: NumElems, DL));
22806 VecIn1 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: VecIn1, N2: ZeroIdx);
22807 // Since we now have shorter input vectors, adjust the offset of the
22808 // second vector's start.
22809 Vec2Offset = NumElems;
22810 } else {
22811 assert(InVT2Size <= InVT1Size &&
22812 "Second input is not going to be larger than the first one.");
22813
22814 // VecIn1 is wider than the output, and we have another, possibly
22815 // smaller input. Pad the smaller input with undefs, shuffle at the
22816 // input vector width, and extract the output.
22817 // The shuffle type is different than VT, so check legality again.
22818 if (LegalOperations &&
22819 !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT: InVT1))
22820 return SDValue();
22821
22822 // Legalizing INSERT_SUBVECTOR is tricky - you basically have to
22823 // lower it back into a BUILD_VECTOR. So if the inserted type is
22824 // illegal, don't even try.
22825 if (InVT1 != InVT2) {
22826 if (!TLI.isTypeLegal(VT: InVT2))
22827 return SDValue();
22828 VecIn2 = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT: InVT1,
22829 N1: DAG.getUNDEF(VT: InVT1), N2: VecIn2, N3: ZeroIdx);
22830 }
22831 ShuffleNumElems = NumElems * 2;
22832 }
22833 } else if (InVT2Size * 2 == VTSize && InVT1Size == VTSize) {
22834 SmallVector<SDValue, 2> ConcatOps(2, DAG.getUNDEF(VT: InVT2));
22835 ConcatOps[0] = VecIn2;
22836 VecIn2 = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
22837 } else if (InVT1Size / VTSize > 1 && InVT1Size % VTSize == 0) {
22838 if (!TLI.isExtractSubvectorCheap(ResVT: VT, SrcVT: InVT1, Index: NumElems) ||
22839 !TLI.isTypeLegal(VT: InVT1) || !TLI.isTypeLegal(VT: InVT2))
22840 return SDValue();
22841 // If the dest vector has fewer than two elements, then using a shuffle and
22842 // extract from larger regs will cost even more.
22843 if (VT.getVectorNumElements() <= 2 || !VecIn2.getNode())
22844 return SDValue();
22845 assert(InVT2Size <= InVT1Size &&
22846 "Second input is not going to be larger than the first one.");
22847
22848 // VecIn1 is wider than the output, and we have another, possibly
22849 // smaller input. Pad the smaller input with undefs, shuffle at the
22850 // input vector width, and extract the output.
22851 // The shuffle type is different than VT, so check legality again.
22852 if (LegalOperations && !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT: InVT1))
22853 return SDValue();
22854
22855 if (InVT1 != InVT2) {
22856 VecIn2 = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT: InVT1,
22857 N1: DAG.getUNDEF(VT: InVT1), N2: VecIn2, N3: ZeroIdx);
22858 }
22859 ShuffleNumElems = InVT1Size / VTSize * NumElems;
22860 } else {
22861 // TODO: Support cases where the length mismatch isn't exactly by a
22862 // factor of 2.
22863 // TODO: Move this check upwards, so that if we have bad type
22864 // mismatches, we don't create any DAG nodes.
22865 return SDValue();
22866 }
22867 }
22868
22869 // Initialize mask to undef.
22870 SmallVector<int, 8> Mask(ShuffleNumElems, -1);
22871
22872 // Only need to run up to the number of elements actually used, not the
22873 // total number of elements in the shuffle - if we are shuffling a wider
22874 // vector, the high lanes should be set to undef.
22875 for (unsigned i = 0; i != NumElems; ++i) {
22876 if (VectorMask[i] <= 0)
22877 continue;
22878
22879 unsigned ExtIndex = N->getOperand(Num: i).getConstantOperandVal(i: 1);
22880 if (VectorMask[i] == (int)LeftIdx) {
22881 Mask[i] = ExtIndex;
22882 } else if (VectorMask[i] == (int)LeftIdx + 1) {
22883 Mask[i] = Vec2Offset + ExtIndex;
22884 }
22885 }
22886
22887 // The type the input vectors may have changed above.
22888 InVT1 = VecIn1.getValueType();
22889
22890 // If we already have a VecIn2, it should have the same type as VecIn1.
22891 // If we don't, get an undef/zero vector of the appropriate type.
22892 VecIn2 = VecIn2.getNode() ? VecIn2 : DAG.getUNDEF(VT: InVT1);
22893 assert(InVT1 == VecIn2.getValueType() && "Unexpected second input type.");
22894
22895 SDValue Shuffle = DAG.getVectorShuffle(VT: InVT1, dl: DL, N1: VecIn1, N2: VecIn2, Mask);
22896 if (ShuffleNumElems > NumElems)
22897 Shuffle = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT, N1: Shuffle, N2: ZeroIdx);
22898
22899 return Shuffle;
22900}
22901
22902static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
22903 assert(BV->getOpcode() == ISD::BUILD_VECTOR && "Expected build vector");
22904
22905 // First, determine where the build vector is not undef.
22906 // TODO: We could extend this to handle zero elements as well as undefs.
22907 int NumBVOps = BV->getNumOperands();
22908 int ZextElt = -1;
22909 for (int i = 0; i != NumBVOps; ++i) {
22910 SDValue Op = BV->getOperand(Num: i);
22911 if (Op.isUndef())
22912 continue;
22913 if (ZextElt == -1)
22914 ZextElt = i;
22915 else
22916 return SDValue();
22917 }
22918 // Bail out if there's no non-undef element.
22919 if (ZextElt == -1)
22920 return SDValue();
22921
22922 // The build vector contains some number of undef elements and exactly
22923 // one other element. That other element must be a zero-extended scalar
22924 // extracted from a vector at a constant index to turn this into a shuffle.
22925 // Also, require that the build vector does not implicitly truncate/extend
22926 // its elements.
22927 // TODO: This could be enhanced to allow ANY_EXTEND as well as ZERO_EXTEND.
22928 EVT VT = BV->getValueType(ResNo: 0);
22929 SDValue Zext = BV->getOperand(Num: ZextElt);
22930 if (Zext.getOpcode() != ISD::ZERO_EXTEND || !Zext.hasOneUse() ||
22931 Zext.getOperand(i: 0).getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
22932 !isa<ConstantSDNode>(Val: Zext.getOperand(i: 0).getOperand(i: 1)) ||
22933 Zext.getValueSizeInBits() != VT.getScalarSizeInBits())
22934 return SDValue();
22935
22936 // The zero-extended size must be a multiple of the source size, and we must
22937 // be building a vector of the same size as the source of the extract element.
22938 SDValue Extract = Zext.getOperand(i: 0);
22939 unsigned DestSize = Zext.getValueSizeInBits();
22940 unsigned SrcSize = Extract.getValueSizeInBits();
22941 if (DestSize % SrcSize != 0 ||
22942 Extract.getOperand(i: 0).getValueSizeInBits() != VT.getSizeInBits())
22943 return SDValue();
22944
22945 // Create a shuffle mask that will combine the extracted element with zeros
22946 // and undefs.
22947 int ZextRatio = DestSize / SrcSize;
22948 int NumMaskElts = NumBVOps * ZextRatio;
22949 SmallVector<int, 32> ShufMask(NumMaskElts, -1);
22950 for (int i = 0; i != NumMaskElts; ++i) {
22951 if (i / ZextRatio == ZextElt) {
22952 // The low bits of the (potentially translated) extracted element map to
22953 // the source vector. The high bits map to zero. We will use a zero vector
22954 // as the 2nd source operand of the shuffle, so use the 1st element of
22955 // that vector (mask value is number-of-elements) for the high bits.
22956 int Low = DAG.getDataLayout().isBigEndian() ? (ZextRatio - 1) : 0;
22957 ShufMask[i] = (i % ZextRatio == Low) ? Extract.getConstantOperandVal(i: 1)
22958 : NumMaskElts;
22959 }
22960
22961 // Undef elements of the build vector remain undef because we initialize
22962 // the shuffle mask with -1.
22963 }
22964
22965 // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
22966 // bitcast (shuffle V, ZeroVec, VectorMask)
22967 SDLoc DL(BV);
22968 EVT VecVT = Extract.getOperand(i: 0).getValueType();
22969 SDValue ZeroVec = DAG.getConstant(Val: 0, DL, VT: VecVT);
22970 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22971 SDValue Shuf = TLI.buildLegalVectorShuffle(VT: VecVT, DL, N0: Extract.getOperand(i: 0),
22972 N1: ZeroVec, Mask: ShufMask, DAG);
22973 if (!Shuf)
22974 return SDValue();
22975 return DAG.getBitcast(VT, V: Shuf);
22976}
22977
22978// FIXME: promote to STLExtras.
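// Returns the index of the first occurrence of Val in Range, or -1 if it is
// not present.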
22979template <typename R, typename T>
22980static auto getFirstIndexOf(R &&Range, const T &Val) {
22981 auto I = find(Range, Val);
22982 if (I == Range.end())
22983 return static_cast<decltype(std::distance(Range.begin(), I))>(-1);
22984 return std::distance(Range.begin(), I);
22985}
22986
22987// Check to see if this is a BUILD_VECTOR of a bunch of EXTRACT_VECTOR_ELT
22988// operations. If the types of the vectors we're extracting from allow it,
22989// turn this into a vector_shuffle node.
22990SDValue DAGCombiner::reduceBuildVecToShuffle(SDNode *N) {
22991 SDLoc DL(N);
22992 EVT VT = N->getValueType(ResNo: 0);
22993
22994 // Only type-legal BUILD_VECTOR nodes are converted to shuffle nodes.
22995 if (!isTypeLegal(VT))
22996 return SDValue();
22997
22998 if (SDValue V = reduceBuildVecToShuffleWithZero(BV: N, DAG))
22999 return V;
23000
23001 // May only combine to shuffle after legalize if shuffle is legal.
23002 if (LegalOperations && !TLI.isOperationLegal(Op: ISD::VECTOR_SHUFFLE, VT))
23003 return SDValue();
23004
23005 bool UsesZeroVector = false;
23006 unsigned NumElems = N->getNumOperands();
23007
23008 // Record, for each element of the newly built vector, which input vector
23009 // that element comes from. -1 stands for undef, 0 for the zero vector,
23010 // and positive values for the input vectors.
23011 // VectorMask maps each element to its vector number, and VecIn maps vector
23012 // numbers to their initial SDValues.
23013
23014 SmallVector<int, 8> VectorMask(NumElems, -1);
23015 SmallVector<SDValue, 8> VecIn;
23016 VecIn.push_back(Elt: SDValue());
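// Slot 0 of VecIn is reserved for the implicit zero vector.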
23017
23018 for (unsigned i = 0; i != NumElems; ++i) {
23019 SDValue Op = N->getOperand(Num: i);
23020
23021 if (Op.isUndef())
23022 continue;
23023
23024 // See if we can use a blend with a zero vector.
23025 // TODO: Should we generalize this to a blend with an arbitrary constant
23026 // vector?
23027 if (isNullConstant(V: Op) || isNullFPConstant(V: Op)) {
23028 UsesZeroVector = true;
23029 VectorMask[i] = 0;
23030 continue;
23031 }
23032
23033 // Not an undef or zero. If the input is something other than an
23034 // EXTRACT_VECTOR_ELT with an in-range constant index, bail out.
23035 if (Op.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
23036 !isa<ConstantSDNode>(Val: Op.getOperand(i: 1)))
23037 return SDValue();
23038 SDValue ExtractedFromVec = Op.getOperand(i: 0);
23039
23040 if (ExtractedFromVec.getValueType().isScalableVector())
23041 return SDValue();
23042
23043 const APInt &ExtractIdx = Op.getConstantOperandAPInt(i: 1);
23044 if (ExtractIdx.uge(RHS: ExtractedFromVec.getValueType().getVectorNumElements()))
23045 return SDValue();
23046
23047 // All inputs must have the same element type as the output.
23048 if (VT.getVectorElementType() !=
23049 ExtractedFromVec.getValueType().getVectorElementType())
23050 return SDValue();
23051
23052 // Have we seen this input vector before?
23053 // The vectors are expected to be tiny (usually 1 or 2 elements), so using
23054 // a map back from SDValues to numbers isn't worth it.
23055 int Idx = getFirstIndexOf(Range&: VecIn, Val: ExtractedFromVec);
23056 if (Idx == -1) { // A new source vector?
23057 Idx = VecIn.size();
23058 VecIn.push_back(Elt: ExtractedFromVec);
23059 }
23060
23061 VectorMask[i] = Idx;
23062 }
23063
23064 // If we didn't find at least one input vector, bail out.
23065 if (VecIn.size() < 2)
23066 return SDValue();
23067
23068 // If all the operands of the BUILD_VECTOR extract from the same
23069 // vector, then split that vector efficiently based on the maximum
23070 // vector access index, and adjust the VectorMask and
23071 // VecIn accordingly.
23072 bool DidSplitVec = false;
23073 if (VecIn.size() == 2) {
23074 unsigned MaxIndex = 0;
23075 unsigned NearestPow2 = 0;
23076 SDValue Vec = VecIn.back();
23077 EVT InVT = Vec.getValueType();
23078 SmallVector<unsigned, 8> IndexVec(NumElems, 0);
23079
23080 for (unsigned i = 0; i < NumElems; i++) {
23081 if (VectorMask[i] <= 0)
23082 continue;
23083 unsigned Index = N->getOperand(Num: i).getConstantOperandVal(i: 1);
23084 IndexVec[i] = Index;
23085 MaxIndex = std::max(a: MaxIndex, b: Index);
23086 }
23087
23088 NearestPow2 = PowerOf2Ceil(A: MaxIndex);
23089 if (InVT.isSimple() && NearestPow2 > 2 && MaxIndex < NearestPow2 &&
23090 NumElems * 2 < NearestPow2) {
23091 unsigned SplitSize = NearestPow2 / 2;
23092 EVT SplitVT = EVT::getVectorVT(Context&: *DAG.getContext(),
23093 VT: InVT.getVectorElementType(), NumElements: SplitSize);
23094 if (TLI.isTypeLegal(VT: SplitVT) &&
23095 SplitSize + SplitVT.getVectorNumElements() <=
23096 InVT.getVectorNumElements()) {
23097 SDValue VecIn2 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: SplitVT, N1: Vec,
23098 N2: DAG.getVectorIdxConstant(Val: SplitSize, DL));
23099 SDValue VecIn1 = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: SplitVT, N1: Vec,
23100 N2: DAG.getVectorIdxConstant(Val: 0, DL));
23101 VecIn.pop_back();
23102 VecIn.push_back(Elt: VecIn1);
23103 VecIn.push_back(Elt: VecIn2);
23104 DidSplitVec = true;
23105
23106 for (unsigned i = 0; i < NumElems; i++) {
23107 if (VectorMask[i] <= 0)
23108 continue;
23109 VectorMask[i] = (IndexVec[i] < SplitSize) ? 1 : 2;
23110 }
23111 }
23112 }
23113 }
23114
23115 // Sort input vectors by decreasing vector element count,
23116 // while preserving the relative order of equally-sized vectors.
23117 // Note that we keep the first "implicit" zero vector as-is.
23118 SmallVector<SDValue, 8> SortedVecIn(VecIn);
23119 llvm::stable_sort(Range: MutableArrayRef<SDValue>(SortedVecIn).drop_front(),
23120 C: [](const SDValue &a, const SDValue &b) {
23121 return a.getValueType().getVectorNumElements() >
23122 b.getValueType().getVectorNumElements();
23123 });
23124
23125 // We now also need to rebuild the VectorMask, because it referenced element
23126 // order in VecIn, and we just sorted them.
23127 for (int &SourceVectorIndex : VectorMask) {
23128 if (SourceVectorIndex <= 0)
23129 continue;
23130 unsigned Idx = getFirstIndexOf(Range&: SortedVecIn, Val: VecIn[SourceVectorIndex]);
23131 assert(Idx > 0 && Idx < SortedVecIn.size() &&
23132 VecIn[SourceVectorIndex] == SortedVecIn[Idx] && "Remapping failure");
23133 SourceVectorIndex = Idx;
23134 }
23135
23136 VecIn = std::move(SortedVecIn);
23137
23138 // TODO: Should this fire if some of the input vectors have an illegal type
23139 // (like it does now), or should we let legalization run its course first?
23140
23141 // Shuffle phase:
23142 // Take pairs of vectors, and shuffle them so that the result has elements
23143 // from these vectors in the correct places.
23144 // For example, given:
23145 // t10: i32 = extract_vector_elt t1, Constant:i64<0>
23146 // t11: i32 = extract_vector_elt t2, Constant:i64<0>
23147 // t12: i32 = extract_vector_elt t3, Constant:i64<0>
23148 // t13: i32 = extract_vector_elt t1, Constant:i64<1>
23149 // t14: v4i32 = BUILD_VECTOR t10, t11, t12, t13
23150 // We will generate:
23151 // t20: v4i32 = vector_shuffle<0,4,u,1> t1, t2
23152 // t21: v4i32 = vector_shuffle<u,u,0,u> t3, undef
23153 SmallVector<SDValue, 4> Shuffles;
23154 for (unsigned In = 0, Len = (VecIn.size() / 2); In < Len; ++In) {
23155 unsigned LeftIdx = 2 * In + 1;
23156 SDValue VecLeft = VecIn[LeftIdx];
23157 SDValue VecRight =
23158 (LeftIdx + 1) < VecIn.size() ? VecIn[LeftIdx + 1] : SDValue();
23159
23160 if (SDValue Shuffle = createBuildVecShuffle(DL, N, VectorMask, VecIn1: VecLeft,
23161 VecIn2: VecRight, LeftIdx, DidSplitVec))
23162 Shuffles.push_back(Elt: Shuffle);
23163 else
23164 return SDValue();
23165 }
23166
23167 // If we need the zero vector as an "ingredient" in the blend tree, add it
23168 // to the list of shuffles.
23169 if (UsesZeroVector)
23170 Shuffles.push_back(Elt: VT.isInteger() ? DAG.getConstant(Val: 0, DL, VT)
23171 : DAG.getConstantFP(Val: 0.0, DL, VT));
23172
23173 // If we only have one shuffle, we're done.
23174 if (Shuffles.size() == 1)
23175 return Shuffles[0];
23176
23177 // Update the vector mask to point to the post-shuffle vectors.
23178 for (int &Vec : VectorMask)
23179 if (Vec == 0)
23180 Vec = Shuffles.size() - 1;
23181 else
23182 Vec = (Vec - 1) / 2;
23183
23184 // More than one shuffle. Generate a binary tree of blends, e.g. if from
23185 // the previous step we got the set of shuffles t10, t11, t12, t13, we will
23186 // generate:
23187 // t10: v8i32 = vector_shuffle<0,8,u,u,u,u,u,u> t1, t2
23188 // t11: v8i32 = vector_shuffle<u,u,0,8,u,u,u,u> t3, t4
23189 // t12: v8i32 = vector_shuffle<u,u,u,u,0,8,u,u> t5, t6
23190 // t13: v8i32 = vector_shuffle<u,u,u,u,u,u,0,8> t7, t8
23191 // t20: v8i32 = vector_shuffle<0,1,10,11,u,u,u,u> t10, t11
23192 // t21: v8i32 = vector_shuffle<u,u,u,u,4,5,14,15> t12, t13
23193 // t30: v8i32 = vector_shuffle<0,1,2,3,12,13,14,15> t20, t21
23194
23195 // Make sure the initial size of the shuffle list is even.
23196 if (Shuffles.size() % 2)
23197 Shuffles.push_back(Elt: DAG.getUNDEF(VT));
23198
23199 for (unsigned CurSize = Shuffles.size(); CurSize > 1; CurSize /= 2) {
23200 if (CurSize % 2) {
23201 Shuffles[CurSize] = DAG.getUNDEF(VT);
23202 CurSize++;
23203 }
23204 for (unsigned In = 0, Len = CurSize / 2; In < Len; ++In) {
23205 int Left = 2 * In;
23206 int Right = 2 * In + 1;
23207 SmallVector<int, 8> Mask(NumElems, -1);
23208 SDValue L = Shuffles[Left];
23209 ArrayRef<int> LMask;
23210 bool IsLeftShuffle = L.getOpcode() == ISD::VECTOR_SHUFFLE &&
23211 L.use_empty() && L.getOperand(i: 1).isUndef() &&
23212 L.getOperand(i: 0).getValueType() == L.getValueType();
23213 if (IsLeftShuffle) {
23214 LMask = cast<ShuffleVectorSDNode>(Val: L.getNode())->getMask();
23215 L = L.getOperand(i: 0);
23216 }
23217 SDValue R = Shuffles[Right];
23218 ArrayRef<int> RMask;
23219 bool IsRightShuffle = R.getOpcode() == ISD::VECTOR_SHUFFLE &&
23220 R.use_empty() && R.getOperand(i: 1).isUndef() &&
23221 R.getOperand(i: 0).getValueType() == R.getValueType();
23222 if (IsRightShuffle) {
23223 RMask = cast<ShuffleVectorSDNode>(Val: R.getNode())->getMask();
23224 R = R.getOperand(i: 0);
23225 }
23226 for (unsigned I = 0; I != NumElems; ++I) {
23227 if (VectorMask[I] == Left) {
23228 Mask[I] = I;
23229 if (IsLeftShuffle)
23230 Mask[I] = LMask[I];
23231 VectorMask[I] = In;
23232 } else if (VectorMask[I] == Right) {
23233 Mask[I] = I + NumElems;
23234 if (IsRightShuffle)
23235 Mask[I] = RMask[I] + NumElems;
23236 VectorMask[I] = In;
23237 }
23238 }
23239
23240 Shuffles[In] = DAG.getVectorShuffle(VT, dl: DL, N1: L, N2: R, Mask);
23241 }
23242 }
23243 return Shuffles[0];
23244}
23245
23246 // Try to turn a build vector of zero extends of extract vector elts into a
23247 // vector zero extend and possibly an extract subvector.
23248// TODO: Support sign extend?
23249// TODO: Allow undef elements?
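// Illustrative sketch (element count and starting offset chosen for
// exposition; the offset must be a multiple of the result's element count):
//   t1: i8 = extract_vector_elt t0:v8i8, Constant:i64<4>
//   t2: i8 = extract_vector_elt t0:v8i8, Constant:i64<5>
//   t3: i8 = extract_vector_elt t0:v8i8, Constant:i64<6>
//   t4: i8 = extract_vector_elt t0:v8i8, Constant:i64<7>
//   v4i32 = BUILD_VECTOR (zero_extend t1), (zero_extend t2),
//                        (zero_extend t3), (zero_extend t4)
//     --> v4i32 = zero_extend (v4i8 extract_subvector t0, Constant:i64<4>)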
23250SDValue DAGCombiner::convertBuildVecZextToZext(SDNode *N) {
23251 if (LegalOperations)
23252 return SDValue();
23253
23254 EVT VT = N->getValueType(ResNo: 0);
23255
23256 bool FoundZeroExtend = false;
23257 SDValue Op0 = N->getOperand(Num: 0);
23258 auto checkElem = [&](SDValue Op) -> int64_t {
23259 unsigned Opc = Op.getOpcode();
23260 FoundZeroExtend |= (Opc == ISD::ZERO_EXTEND);
23261 if ((Opc == ISD::ZERO_EXTEND || Opc == ISD::ANY_EXTEND) &&
23262 Op.getOperand(i: 0).getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
23263 Op0.getOperand(i: 0).getOperand(i: 0) == Op.getOperand(i: 0).getOperand(i: 0))
23264 if (auto *C = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 0).getOperand(i: 1)))
23265 return C->getZExtValue();
23266 return -1;
23267 };
23268
23269 // Make sure the first element matches
23270 // (zext (extract_vector_elt X, C))
23271 // Offset must be a constant multiple of the
23272 // known-minimum vector length of the result type.
23273 int64_t Offset = checkElem(Op0);
23274 if (Offset < 0 || (Offset % VT.getVectorNumElements()) != 0)
23275 return SDValue();
23276
23277 unsigned NumElems = N->getNumOperands();
23278 SDValue In = Op0.getOperand(i: 0).getOperand(i: 0);
23279 EVT InSVT = In.getValueType().getScalarType();
23280 EVT InVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: InSVT, NumElements: NumElems);
23281
23282 // Don't create an illegal input type after type legalization.
23283 if (LegalTypes && !TLI.isTypeLegal(VT: InVT))
23284 return SDValue();
23285
23286 // Ensure all the elements come from the same vector and are adjacent.
23287 for (unsigned i = 1; i != NumElems; ++i) {
23288 if ((Offset + i) != checkElem(N->getOperand(Num: i)))
23289 return SDValue();
23290 }
23291
23292 SDLoc DL(N);
23293 In = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: InVT, N1: In,
23294 N2: Op0.getOperand(i: 0).getOperand(i: 1));
23295 return DAG.getNode(Opcode: FoundZeroExtend ? ISD::ZERO_EXTEND : ISD::ANY_EXTEND, DL,
23296 VT, Operand: In);
23297}
23298
23299 // If this is a very simple BUILD_VECTOR with its first element being a
23300 // ZERO_EXTEND, and all other elements being constant zeros, granularize the
23301 // BUILD_VECTOR's element width, absorbing the ZERO_EXTEND and turning it into
23302 // a constant zero op. This pattern can appear during legalization.
23303 //
23304 // NOTE: This can be generalized to allow more than a single
23305 // non-constant-zero op, UNDEFs, and to be KnownBits-based.
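//
// Illustrative sketch (little-endian; assumes i32 and v4i32 are legal):
//   t1: i64 = zero_extend t0:i32
//   t2: v2i64 = BUILD_VECTOR t1, Constant:i64<0>
// is rebuilt with half-width elements as
//   t3: v4i32 = BUILD_VECTOR (i32 truncate (i64 bitcast t1)), Constant:i32<0>,
//                            Constant:i32<0>, Constant:i32<0>
//   t4: v2i64 = bitcast t3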
23306SDValue DAGCombiner::convertBuildVecZextToBuildVecWithZeros(SDNode *N) {
23307 // Don't run this after legalization. Targets may have other preferences.
23308 if (Level >= AfterLegalizeDAG)
23309 return SDValue();
23310
23311 // FIXME: support big-endian.
23312 if (DAG.getDataLayout().isBigEndian())
23313 return SDValue();
23314
23315 EVT VT = N->getValueType(ResNo: 0);
23316 EVT OpVT = N->getOperand(Num: 0).getValueType();
23317 assert(!VT.isScalableVector() && "Encountered scalable BUILD_VECTOR?");
23318
23319 EVT OpIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OpVT.getSizeInBits());
23320
23321 if (!TLI.isTypeLegal(VT: OpIntVT) ||
23322 (LegalOperations && !TLI.isOperationLegalOrCustom(Op: ISD::BITCAST, VT: OpIntVT)))
23323 return SDValue();
23324
23325 unsigned EltBitwidth = VT.getScalarSizeInBits();
23326 // NOTE: the actual width of operands may be wider than that!
23327
23328 // Analyze all operands of this BUILD_VECTOR. What is the largest number of
23329 // active bits they all have? We'll want to truncate them all to that width.
23330 unsigned ActiveBits = 0;
23331 APInt KnownZeroOps(VT.getVectorNumElements(), 0);
23332 for (auto I : enumerate(First: N->ops())) {
23333 SDValue Op = I.value();
23334 // FIXME: support UNDEF elements?
23335 if (auto *Cst = dyn_cast<ConstantSDNode>(Val&: Op)) {
23336 unsigned OpActiveBits =
23337 Cst->getAPIntValue().trunc(width: EltBitwidth).getActiveBits();
23338 if (OpActiveBits == 0) {
23339 KnownZeroOps.setBit(I.index());
23340 continue;
23341 }
23342 // Profitability check: don't allow non-zero constant operands.
23343 return SDValue();
23344 }
23345 // Profitability check: there must only be a single non-zero operand,
23346 // and it must be the first operand of the BUILD_VECTOR.
23347 if (I.index() != 0)
23348 return SDValue();
23349 // The operand must be a zero-extension itself.
23350 // FIXME: this could be generalized to known leading zeros check.
23351 if (Op.getOpcode() != ISD::ZERO_EXTEND)
23352 return SDValue();
23353 unsigned CurrActiveBits =
23354 Op.getOperand(i: 0).getValueSizeInBits().getFixedValue();
23355 assert(!ActiveBits && "Already encountered non-constant-zero operand?");
23356 ActiveBits = CurrActiveBits;
23357 // We want to at least halve the element size.
23358 if (2 * ActiveBits > EltBitwidth)
23359 return SDValue();
23360 }
23361
23362 // This BUILD_VECTOR must have at least one non-constant-zero operand.
23363 if (ActiveBits == 0)
23364 return SDValue();
23365
23366 // We have EltBitwidth bits and the *minimal* chunk size is ActiveBits;
23367 // into how many chunks can we split our element width?
23368 EVT NewScalarIntVT, NewIntVT;
23369 std::optional<unsigned> Factor;
23370 // We can split the element into at least two chunks, but not into more
23371 // than |_ EltBitwidth / ActiveBits _| chunks. Find the largest split factor
23372 // for which the element width is a multiple of it,
23373 // and the resulting types/operations on that chunk width are legal.
23374 assert(2 * ActiveBits <= EltBitwidth &&
23375 "We know that half or less bits of the element are active.");
23376 for (unsigned Scale = EltBitwidth / ActiveBits; Scale >= 2; --Scale) {
23377 if (EltBitwidth % Scale != 0)
23378 continue;
23379 unsigned ChunkBitwidth = EltBitwidth / Scale;
23380 assert(ChunkBitwidth >= ActiveBits && "As per starting point.");
23381 NewScalarIntVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: ChunkBitwidth);
23382 NewIntVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: NewScalarIntVT,
23383 NumElements: Scale * N->getNumOperands());
23384 if (!TLI.isTypeLegal(VT: NewScalarIntVT) || !TLI.isTypeLegal(VT: NewIntVT) ||
23385 (LegalOperations &&
23386 !(TLI.isOperationLegalOrCustom(Op: ISD::TRUNCATE, VT: NewScalarIntVT) &&
23387 TLI.isOperationLegalOrCustom(Op: ISD::BUILD_VECTOR, VT: NewIntVT))))
23388 continue;
23389 Factor = Scale;
23390 break;
23391 }
23392 if (!Factor)
23393 return SDValue();
23394
23395 SDLoc DL(N);
23396 SDValue ZeroOp = DAG.getConstant(Val: 0, DL, VT: NewScalarIntVT);
23397
23398 // Recreate the BUILD_VECTOR, with elements now being Factor times smaller.
23399 SmallVector<SDValue, 16> NewOps;
23400 NewOps.reserve(N: NewIntVT.getVectorNumElements());
23401 for (auto I : enumerate(First: N->ops())) {
23402 SDValue Op = I.value();
23403 assert(!Op.isUndef() && "FIXME: after allowing UNDEF's, handle them here.");
23404 unsigned SrcOpIdx = I.index();
23405 if (KnownZeroOps[SrcOpIdx]) {
23406 NewOps.append(NumInputs: *Factor, Elt: ZeroOp);
23407 continue;
23408 }
23409 Op = DAG.getBitcast(VT: OpIntVT, V: Op);
23410 Op = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: NewScalarIntVT, Operand: Op);
23411 NewOps.emplace_back(Args&: Op);
23412 NewOps.append(NumInputs: *Factor - 1, Elt: ZeroOp);
23413 }
23414 assert(NewOps.size() == NewIntVT.getVectorNumElements());
23415 SDValue NewBV = DAG.getBuildVector(VT: NewIntVT, DL, Ops: NewOps);
23416 NewBV = DAG.getBitcast(VT, V: NewBV);
23417 return NewBV;
23418}
23419
23420SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
23421 EVT VT = N->getValueType(ResNo: 0);
23422
23423 // A vector built entirely of undefs is undef.
23424 if (ISD::allOperandsUndef(N))
23425 return DAG.getUNDEF(VT);
23426
23427 // If this is a splat of a bitcast from another vector, change to a
23428 // concat_vector.
23429 // For example:
23430 // (build_vector (i64 (bitcast (v2i32 X))), (i64 (bitcast (v2i32 X)))) ->
23431 // (v2i64 (bitcast (concat_vectors (v2i32 X), (v2i32 X))))
23432 //
23433 // If X is a build_vector itself, the concat can become a larger build_vector.
23434 // TODO: Maybe this is useful for non-splat too?
23435 if (!LegalOperations) {
23436 SDValue Splat = cast<BuildVectorSDNode>(Val: N)->getSplatValue();
23437 // Only change build_vector to a concat_vector if the splat value type is
23438 // the same as the vector element type.
23439 if (Splat && Splat.getValueType() == VT.getVectorElementType()) {
23440 Splat = peekThroughBitcasts(V: Splat);
23441 EVT SrcVT = Splat.getValueType();
23442 if (SrcVT.isVector()) {
23443 unsigned NumElts = N->getNumOperands() * SrcVT.getVectorNumElements();
23444 EVT NewVT = EVT::getVectorVT(Context&: *DAG.getContext(),
23445 VT: SrcVT.getVectorElementType(), NumElements: NumElts);
23446 if (!LegalTypes || TLI.isTypeLegal(VT: NewVT)) {
23447 SmallVector<SDValue, 8> Ops(N->getNumOperands(), Splat);
23448 SDValue Concat =
23449 DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT: NewVT, Ops);
23450 return DAG.getBitcast(VT, V: Concat);
23451 }
23452 }
23453 }
23454 }
23455
23456 // Check if we can express the BUILD_VECTOR via a subvector extract.
23457 if (!LegalTypes && (N->getNumOperands() > 1)) {
23458 SDValue Op0 = N->getOperand(Num: 0);
23459 auto checkElem = [&](SDValue Op) -> uint64_t {
23460 if ((Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT) &&
23461 (Op0.getOperand(i: 0) == Op.getOperand(i: 0)))
23462 if (auto CNode = dyn_cast<ConstantSDNode>(Val: Op.getOperand(i: 1)))
23463 return CNode->getZExtValue();
23464 return -1;
23465 };
23466
23467 int Offset = checkElem(Op0);
23468 for (unsigned i = 0; i < N->getNumOperands(); ++i) {
23469 if (Offset + i != checkElem(N->getOperand(Num: i))) {
23470 Offset = -1;
23471 break;
23472 }
23473 }
23474
23475 if ((Offset == 0) &&
23476 (Op0.getOperand(i: 0).getValueType() == N->getValueType(ResNo: 0)))
23477 return Op0.getOperand(i: 0);
23478 if ((Offset != -1) &&
23479 ((Offset % N->getValueType(ResNo: 0).getVectorNumElements()) ==
23480 0)) // IDX must be multiple of output size.
23481 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT: N->getValueType(ResNo: 0),
23482 N1: Op0.getOperand(i: 0), N2: Op0.getOperand(i: 1));
23483 }
23484
23485 if (SDValue V = convertBuildVecZextToZext(N))
23486 return V;
23487
23488 if (SDValue V = convertBuildVecZextToBuildVecWithZeros(N))
23489 return V;
23490
23491 if (SDValue V = reduceBuildVecExtToExtBuildVec(N))
23492 return V;
23493
23494 if (SDValue V = reduceBuildVecTruncToBitCast(N))
23495 return V;
23496
23497 if (SDValue V = reduceBuildVecToShuffle(N))
23498 return V;
23499
23500 // A splat of a single element is a SPLAT_VECTOR if supported on the target.
23501 // Do this late as some of the above may replace the splat.
23502 if (TLI.getOperationAction(Op: ISD::SPLAT_VECTOR, VT) != TargetLowering::Expand)
23503 if (SDValue V = cast<BuildVectorSDNode>(Val: N)->getSplatValue()) {
23504 assert(!V.isUndef() && "Splat of undef should have been handled earlier");
23505 return DAG.getNode(Opcode: ISD::SPLAT_VECTOR, DL: SDLoc(N), VT, Operand: V);
23506 }
23507
23508 return SDValue();
23509}
23510
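// Fold a CONCAT_VECTORS whose (not-yet-legal) vector operands are each a
// bitcast of a scalar, or undef, into a single scalar BUILD_VECTOR plus one
// bitcast. Illustrative sketch (assumes v1f64 is not a legal type):
//   t2: v1f64 = bitcast t0:f64
//   t3: v1f64 = bitcast t1:f64
//   v2f64 = concat_vectors t2, t3
//     --> v2f64 = BUILD_VECTOR t0, t1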
23511static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
23512 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23513 EVT OpVT = N->getOperand(Num: 0).getValueType();
23514
23515 // If the operands are legal vectors, leave them alone.
23516 if (TLI.isTypeLegal(VT: OpVT) || OpVT.isScalableVector())
23517 return SDValue();
23518
23519 SDLoc DL(N);
23520 EVT VT = N->getValueType(ResNo: 0);
23521 SmallVector<SDValue, 8> Ops;
23522 EVT SVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: OpVT.getSizeInBits());
23523
23524 // Keep track of what we encounter.
23525 bool AnyInteger = false;
23526 bool AnyFP = false;
23527 for (const SDValue &Op : N->ops()) {
23528 if (ISD::BITCAST == Op.getOpcode() &&
23529 !Op.getOperand(i: 0).getValueType().isVector())
23530 Ops.push_back(Elt: Op.getOperand(i: 0));
23531 else if (ISD::UNDEF == Op.getOpcode())
23532 Ops.push_back(Elt: DAG.getNode(Opcode: ISD::UNDEF, DL, VT: SVT));
23533 else
23534 return SDValue();
23535
23536 // Note whether we encounter an integer or floating point scalar.
23537 // If it's neither, bail out, it could be something weird like x86mmx.
23538 EVT LastOpVT = Ops.back().getValueType();
23539 if (LastOpVT.isFloatingPoint())
23540 AnyFP = true;
23541 else if (LastOpVT.isInteger())
23542 AnyInteger = true;
23543 else
23544 return SDValue();
23545 }
23546
23547 // If any of the operands is a floating point scalar bitcast to a vector,
23548 // use floating point types throughout, and bitcast everything.
23549 // Replace UNDEFs by another scalar UNDEF node, of the final desired type.
23550 if (AnyFP) {
23551 SVT = EVT::getFloatingPointVT(BitWidth: OpVT.getSizeInBits());
23552 if (AnyInteger) {
23553 for (SDValue &Op : Ops) {
23554 if (Op.getValueType() == SVT)
23555 continue;
23556 if (Op.isUndef())
23557 Op = DAG.getNode(Opcode: ISD::UNDEF, DL, VT: SVT);
23558 else
23559 Op = DAG.getBitcast(VT: SVT, V: Op);
23560 }
23561 }
23562 }
23563
23564 EVT VecVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SVT,
23565 NumElements: VT.getSizeInBits() / SVT.getSizeInBits());
23566 return DAG.getBitcast(VT, V: DAG.getBuildVector(VT: VecVT, DL, Ops));
23567}
23568
23569// Attempt to merge nested concat_vectors/undefs.
23570// Fold concat_vectors(concat_vectors(x,y,z,w),u,u,concat_vectors(a,b,c,d))
23571// --> concat_vectors(x,y,z,w,u,u,u,u,u,u,u,u,a,b,c,d)
23572static SDValue combineConcatVectorOfConcatVectors(SDNode *N,
23573 SelectionDAG &DAG) {
23574 EVT VT = N->getValueType(ResNo: 0);
23575
23576 // Ensure we're concatenating UNDEF and CONCAT_VECTORS nodes of similar types.
23577 EVT SubVT;
23578 SDValue FirstConcat;
23579 for (const SDValue &Op : N->ops()) {
23580 if (Op.isUndef())
23581 continue;
23582 if (Op.getOpcode() != ISD::CONCAT_VECTORS)
23583 return SDValue();
23584 if (!FirstConcat) {
23585 SubVT = Op.getOperand(i: 0).getValueType();
23586 if (!DAG.getTargetLoweringInfo().isTypeLegal(VT: SubVT))
23587 return SDValue();
23588 FirstConcat = Op;
23589 continue;
23590 }
23591 if (SubVT != Op.getOperand(i: 0).getValueType())
23592 return SDValue();
23593 }
23594 assert(FirstConcat && "Concat of all-undefs found");
23595
23596 SmallVector<SDValue> ConcatOps;
23597 for (const SDValue &Op : N->ops()) {
23598 if (Op.isUndef()) {
23599 ConcatOps.append(NumInputs: FirstConcat->getNumOperands(), Elt: DAG.getUNDEF(VT: SubVT));
23600 continue;
23601 }
23602 ConcatOps.append(in_start: Op->op_begin(), in_end: Op->op_end());
23603 }
23604 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops: ConcatOps);
23605}
23606
23607// Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
23608// operations. If so, and if the EXTRACT_SUBVECTOR vector inputs come from at
23609// most two distinct vectors the same size as the result, attempt to turn this
23610// into a legal shuffle.
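// Illustrative sketch (v2i32 operands extracted from v4i32 sources; indices
// chosen for exposition):
//   t2: v2i32 = extract_subvector t0:v4i32, Constant:i64<2>
//   t3: v2i32 = extract_subvector t1:v4i32, Constant:i64<0>
//   v4i32 = concat_vectors t2, t3
//     --> v4i32 = vector_shuffle<2,3,4,5> t0, t1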
23611static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
23612 EVT VT = N->getValueType(ResNo: 0);
23613 EVT OpVT = N->getOperand(Num: 0).getValueType();
23614
23615 // We currently can't generate an appropriate shuffle for a scalable vector.
23616 if (VT.isScalableVector())
23617 return SDValue();
23618
23619 int NumElts = VT.getVectorNumElements();
23620 int NumOpElts = OpVT.getVectorNumElements();
23621
23622 SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT);
23623 SmallVector<int, 8> Mask;
23624
23625 for (SDValue Op : N->ops()) {
23626 Op = peekThroughBitcasts(V: Op);
23627
23628 // UNDEF nodes convert to UNDEF shuffle mask values.
23629 if (Op.isUndef()) {
23630 Mask.append(NumInputs: (unsigned)NumOpElts, Elt: -1);
23631 continue;
23632 }
23633
23634 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
23635 return SDValue();
23636
23637 // What vector are we extracting the subvector from and at what index?
23638 SDValue ExtVec = Op.getOperand(i: 0);
23639 int ExtIdx = Op.getConstantOperandVal(i: 1);
23640
23641 // We want the EVT of the original extraction to correctly scale the
23642 // extraction index.
23643 EVT ExtVT = ExtVec.getValueType();
23644 ExtVec = peekThroughBitcasts(V: ExtVec);
23645
23646 // UNDEF nodes convert to UNDEF shuffle mask values.
23647 if (ExtVec.isUndef()) {
23648 Mask.append(NumInputs: (unsigned)NumOpElts, Elt: -1);
23649 continue;
23650 }
23651
23652 // Ensure that we are extracting a subvector from a vector the same
23653 // size as the result.
23654 if (ExtVT.getSizeInBits() != VT.getSizeInBits())
23655 return SDValue();
23656
23657 // Scale the subvector index to account for any bitcast.
23658 int NumExtElts = ExtVT.getVectorNumElements();
23659 if (0 == (NumExtElts % NumElts))
23660 ExtIdx /= (NumExtElts / NumElts);
23661 else if (0 == (NumElts % NumExtElts))
23662 ExtIdx *= (NumElts / NumExtElts);
23663 else
23664 return SDValue();
23665
23666 // At most we can reference 2 inputs in the final shuffle.
23667 if (SV0.isUndef() || SV0 == ExtVec) {
23668 SV0 = ExtVec;
23669 for (int i = 0; i != NumOpElts; ++i)
23670 Mask.push_back(Elt: i + ExtIdx);
23671 } else if (SV1.isUndef() || SV1 == ExtVec) {
23672 SV1 = ExtVec;
23673 for (int i = 0; i != NumOpElts; ++i)
23674 Mask.push_back(Elt: i + ExtIdx + NumElts);
23675 } else {
23676 return SDValue();
23677 }
23678 }
23679
23680 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23681 return TLI.buildLegalVectorShuffle(VT, DL: SDLoc(N), N0: DAG.getBitcast(VT, V: SV0),
23682 N1: DAG.getBitcast(VT, V: SV1), Mask, DAG);
23683}
23684
23685static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) {
23686 unsigned CastOpcode = N->getOperand(Num: 0).getOpcode();
23687 switch (CastOpcode) {
23688 case ISD::SINT_TO_FP:
23689 case ISD::UINT_TO_FP:
23690 case ISD::FP_TO_SINT:
23691 case ISD::FP_TO_UINT:
23692 // TODO: Allow more opcodes?
23693 // case ISD::BITCAST:
23694 // case ISD::TRUNCATE:
23695 // case ISD::ZERO_EXTEND:
23696 // case ISD::SIGN_EXTEND:
23697 // case ISD::FP_EXTEND:
23698 break;
23699 default:
23700 return SDValue();
23701 }
23702
23703 EVT SrcVT = N->getOperand(Num: 0).getOperand(i: 0).getValueType();
23704 if (!SrcVT.isVector())
23705 return SDValue();
23706
23707 // All operands of the concat must be the same kind of cast from the same
23708 // source type.
23709 SmallVector<SDValue, 4> SrcOps;
23710 for (SDValue Op : N->ops()) {
23711 if (Op.getOpcode() != CastOpcode || !Op.hasOneUse() ||
23712 Op.getOperand(i: 0).getValueType() != SrcVT)
23713 return SDValue();
23714 SrcOps.push_back(Elt: Op.getOperand(i: 0));
23715 }
23716
23717 // The wider cast must be supported by the target. This is unusual because
23718 // the operation support type parameter depends on the opcode. In addition,
23719 // check the other type in the cast to make sure this is really legal.
23720 EVT VT = N->getValueType(ResNo: 0);
23721 EVT SrcEltVT = SrcVT.getVectorElementType();
23722 ElementCount NumElts = SrcVT.getVectorElementCount() * N->getNumOperands();
23723 EVT ConcatSrcVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SrcEltVT, EC: NumElts);
23724 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
23725 switch (CastOpcode) {
23726 case ISD::SINT_TO_FP:
23727 case ISD::UINT_TO_FP:
23728 if (!TLI.isOperationLegalOrCustom(Op: CastOpcode, VT: ConcatSrcVT) ||
23729 !TLI.isTypeLegal(VT))
23730 return SDValue();
23731 break;
23732 case ISD::FP_TO_SINT:
23733 case ISD::FP_TO_UINT:
23734 if (!TLI.isOperationLegalOrCustom(Op: CastOpcode, VT) ||
23735 !TLI.isTypeLegal(VT: ConcatSrcVT))
23736 return SDValue();
23737 break;
23738 default:
23739 llvm_unreachable("Unexpected cast opcode");
23740 }
23741
23742 // concat (cast X), (cast Y)... -> cast (concat X, Y...)
23743 SDLoc DL(N);
23744 SDValue NewConcat = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT: ConcatSrcVT, Ops: SrcOps);
23745 return DAG.getNode(Opcode: CastOpcode, DL, VT, Operand: NewConcat);
23746}
23747
23748 // See if this is a simple CONCAT_VECTORS with no UNDEF operands, and if one of
23749 // the operands is a SHUFFLE_VECTOR, and all other operands are also operands
23750 // of that SHUFFLE_VECTOR, create a wider SHUFFLE_VECTOR.
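// Illustrative sketch (a unary shuffle concatenated with its own operand;
// mask values chosen for exposition):
//   t2: v4i32 = vector_shuffle<1,0,3,2> t0, undef
//   v8i32 = concat_vectors t2, t0
//     --> v8i32 = vector_shuffle<1,0,3,2,0,1,2,3> (concat_vectors t0, undef),
//                                                 undef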
23751static SDValue combineConcatVectorOfShuffleAndItsOperands(
23752 SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
23753 bool LegalOperations) {
23754 EVT VT = N->getValueType(ResNo: 0);
23755 EVT OpVT = N->getOperand(Num: 0).getValueType();
23756 if (VT.isScalableVector())
23757 return SDValue();
23758
23759 // For now, only allow simple 2-operand concatenations.
23760 if (N->getNumOperands() != 2)
23761 return SDValue();
23762
23763 // Don't create illegal types/shuffles when not allowed to.
23764 if ((LegalTypes && !TLI.isTypeLegal(VT)) ||
23765 (LegalOperations &&
23766 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT)))
23767 return SDValue();
23768
23769 // Analyze all of the operands of the CONCAT_VECTORS. Out of all of them,
23770 // we want to find one that is: (1) a SHUFFLE_VECTOR (2) only used by us,
23771 // and (3) all operands of CONCAT_VECTORS must be either that SHUFFLE_VECTOR,
23772 // or one of the operands of that SHUFFLE_VECTOR (but not UNDEF!).
23773 // (4) and for now, the SHUFFLE_VECTOR must be unary.
23774 ShuffleVectorSDNode *SVN = nullptr;
23775 for (SDValue Op : N->ops()) {
23776 if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Val&: Op);
23777 CurSVN && CurSVN->getOperand(Num: 1).isUndef() && N->isOnlyUserOf(N: CurSVN) &&
23778 all_of(Range: N->ops(), P: [CurSVN](SDValue Op) {
23779 // FIXME: can we allow UNDEF operands?
23780 return !Op.isUndef() &&
23781 (Op.getNode() == CurSVN || is_contained(Range: CurSVN->ops(), Element: Op));
23782 })) {
23783 SVN = CurSVN;
23784 break;
23785 }
23786 }
23787 if (!SVN)
23788 return SDValue();
23789
23790 // We are going to pad the shuffle operands, so any index that was picking
23791 // from the second operand must be adjusted.
23792 SmallVector<int, 16> AdjustedMask;
23793 AdjustedMask.reserve(N: SVN->getMask().size());
23794 assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!");
23795 append_range(C&: AdjustedMask, R: SVN->getMask());
23796
23797 // Identity masks for the operands of the (padded) shuffle.
23798 SmallVector<int, 32> IdentityMask(2 * OpVT.getVectorNumElements());
23799 MutableArrayRef<int> FirstShufOpIdentityMask =
23800 MutableArrayRef<int>(IdentityMask)
23801 .take_front(N: OpVT.getVectorNumElements());
23802 MutableArrayRef<int> SecondShufOpIdentityMask =
23803 MutableArrayRef<int>(IdentityMask).take_back(N: OpVT.getVectorNumElements());
23804 std::iota(first: FirstShufOpIdentityMask.begin(), last: FirstShufOpIdentityMask.end(), value: 0);
23805 std::iota(first: SecondShufOpIdentityMask.begin(), last: SecondShufOpIdentityMask.end(),
23806 value: VT.getVectorNumElements());
23807
23808 // New combined shuffle mask.
23809 SmallVector<int, 32> Mask;
23810 Mask.reserve(N: VT.getVectorNumElements());
23811 for (SDValue Op : N->ops()) {
23812 assert(!Op.isUndef() && "Not expecting to concatenate UNDEF.");
23813 if (Op.getNode() == SVN) {
23814 append_range(C&: Mask, R&: AdjustedMask);
23815 continue;
23816 }
23817 if (Op == SVN->getOperand(Num: 0)) {
23818 append_range(C&: Mask, R&: FirstShufOpIdentityMask);
23819 continue;
23820 }
23821 if (Op == SVN->getOperand(Num: 1)) {
23822 append_range(C&: Mask, R&: SecondShufOpIdentityMask);
23823 continue;
23824 }
23825 llvm_unreachable("Unexpected operand!");
23826 }
23827
23828 // Don't create illegal shuffle masks.
23829 if (!TLI.isShuffleMaskLegal(Mask, VT))
23830 return SDValue();
23831
23832 // Pad the shuffle operands with UNDEF.
23833 SDLoc dl(N);
23834 std::array<SDValue, 2> ShufOps;
23835 for (auto I : zip(t: SVN->ops(), u&: ShufOps)) {
23836 SDValue ShufOp = std::get<0>(t&: I);
23837 SDValue &NewShufOp = std::get<1>(t&: I);
23838 if (ShufOp.isUndef())
23839 NewShufOp = DAG.getUNDEF(VT);
23840 else {
23841 SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(),
23842 DAG.getUNDEF(VT: OpVT));
23843 ShufOpParts[0] = ShufOp;
23844 NewShufOp = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: dl, VT, Ops: ShufOpParts);
23845 }
23846 }
23847 // Finally, create the new wide shuffle.
23848 return DAG.getVectorShuffle(VT, dl, N1: ShufOps[0], N2: ShufOps[1], Mask);
23849}
23850
23851SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
23852 // If we only have one input vector, we don't need to do any concatenation.
23853 if (N->getNumOperands() == 1)
23854 return N->getOperand(Num: 0);
23855
23856 // Check if all of the operands are undefs.
23857 EVT VT = N->getValueType(ResNo: 0);
23858 if (ISD::allOperandsUndef(N))
23859 return DAG.getUNDEF(VT);
23860
23861 // Optimize concat_vectors where all but the first of the vectors are undef.
23862 if (all_of(Range: drop_begin(RangeOrContainer: N->ops()),
23863 P: [](const SDValue &Op) { return Op.isUndef(); })) {
23864 SDValue In = N->getOperand(Num: 0);
23865 assert(In.getValueType().isVector() && "Must concat vectors");
23866
23867 // If the input is a concat_vectors, just make a larger concat by padding
23868 // with smaller undefs.
23869 //
23870 // Legalizing in AArch64TargetLowering::LowerCONCAT_VECTORS() and combining
23871 // here could cause an infinite loop. That legalization happens when LegalDAG
23872 // is true and the input of AArch64TargetLowering::LowerCONCAT_VECTORS() is
23873 // scalable.
23874 if (In.getOpcode() == ISD::CONCAT_VECTORS && In.hasOneUse() &&
23875 !(LegalDAG && In.getValueType().isScalableVector())) {
23876 unsigned NumOps = N->getNumOperands() * In.getNumOperands();
23877 SmallVector<SDValue, 4> Ops(In->op_begin(), In->op_end());
23878 Ops.resize(N: NumOps, NV: DAG.getUNDEF(VT: Ops[0].getValueType()));
23879 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
23880 }
23881
23882 SDValue Scalar = peekThroughOneUseBitcasts(V: In);
23883
23884 // concat_vectors(scalar_to_vector(scalar), undef) ->
23885 // scalar_to_vector(scalar)
23886 if (!LegalOperations && Scalar.getOpcode() == ISD::SCALAR_TO_VECTOR &&
23887 Scalar.hasOneUse()) {
23888 EVT SVT = Scalar.getValueType().getVectorElementType();
23889 if (SVT == Scalar.getOperand(i: 0).getValueType())
23890 Scalar = Scalar.getOperand(i: 0);
23891 }
23892
23893 // concat_vectors(scalar, undef) -> scalar_to_vector(scalar)
23894 if (!Scalar.getValueType().isVector() && In.hasOneUse()) {
23895 // If the bitcast type isn't legal, it might be a trunc of a legal type;
23896 // look through the trunc so we can still do the transform:
23897 // concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar)
23898 if (Scalar->getOpcode() == ISD::TRUNCATE &&
23899 !TLI.isTypeLegal(VT: Scalar.getValueType()) &&
23900 TLI.isTypeLegal(VT: Scalar->getOperand(Num: 0).getValueType()))
23901 Scalar = Scalar->getOperand(Num: 0);
23902
23903 EVT SclTy = Scalar.getValueType();
23904
23905 if (!SclTy.isFloatingPoint() && !SclTy.isInteger())
23906 return SDValue();
23907
23908 // Bail out if the vector size is not a multiple of the scalar size.
23909 if (VT.getSizeInBits() % SclTy.getSizeInBits())
23910 return SDValue();
23911
23912 unsigned VNTNumElms = VT.getSizeInBits() / SclTy.getSizeInBits();
23913 if (VNTNumElms < 2)
23914 return SDValue();
23915
23916 EVT NVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: SclTy, NumElements: VNTNumElms);
23917 if (!TLI.isTypeLegal(VT: NVT) || !TLI.isTypeLegal(VT: Scalar.getValueType()))
23918 return SDValue();
23919
23920 SDValue Res = DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL: SDLoc(N), VT: NVT, Operand: Scalar);
23921 return DAG.getBitcast(VT, V: Res);
23922 }
23923 }
23924
23925 // Fold any combination of BUILD_VECTOR or UNDEF nodes into one BUILD_VECTOR.
23926 // We have already tested above for an UNDEF only concatenation.
23927 // fold (concat_vectors (BUILD_VECTOR A, B, ...), (BUILD_VECTOR C, D, ...))
23928 // -> (BUILD_VECTOR A, B, ..., C, D, ...)
23929 auto IsBuildVectorOrUndef = [](const SDValue &Op) {
23930 return ISD::UNDEF == Op.getOpcode() || ISD::BUILD_VECTOR == Op.getOpcode();
23931 };
23932 if (llvm::all_of(Range: N->ops(), P: IsBuildVectorOrUndef)) {
23933 SmallVector<SDValue, 8> Opnds;
23934 EVT SVT = VT.getScalarType();
23935
23936 EVT MinVT = SVT;
23937 if (!SVT.isFloatingPoint()) {
23938 // If the BUILD_VECTORs are built from integers, they may have different
23939 // operand types. Get the smallest type and truncate all operands to it.
23940 bool FoundMinVT = false;
23941 for (const SDValue &Op : N->ops())
23942 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
23943 EVT OpSVT = Op.getOperand(i: 0).getValueType();
23944 MinVT = (!FoundMinVT || OpSVT.bitsLE(VT: MinVT)) ? OpSVT : MinVT;
23945 FoundMinVT = true;
23946 }
23947 assert(FoundMinVT && "Concat vector type mismatch");
23948 }
23949
23950 for (const SDValue &Op : N->ops()) {
23951 EVT OpVT = Op.getValueType();
23952 unsigned NumElts = OpVT.getVectorNumElements();
23953
23954 if (ISD::UNDEF == Op.getOpcode())
23955 Opnds.append(NumInputs: NumElts, Elt: DAG.getUNDEF(VT: MinVT));
23956
23957 if (ISD::BUILD_VECTOR == Op.getOpcode()) {
23958 if (SVT.isFloatingPoint()) {
23959 assert(SVT == OpVT.getScalarType() && "Concat vector type mismatch");
23960 Opnds.append(in_start: Op->op_begin(), in_end: Op->op_begin() + NumElts);
23961 } else {
23962 for (unsigned i = 0; i != NumElts; ++i)
23963 Opnds.push_back(
23964 Elt: DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(N), VT: MinVT, Operand: Op.getOperand(i)));
23965 }
23966 }
23967 }
23968
23969 assert(VT.getVectorNumElements() == Opnds.size() &&
23970 "Concat vector type mismatch");
23971 return DAG.getBuildVector(VT, DL: SDLoc(N), Ops: Opnds);
23972 }
23973
23974 // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR.
23975 // FIXME: Add support for concat_vectors(bitcast(vec0),bitcast(vec1),...).
23976 if (SDValue V = combineConcatVectorOfScalars(N, DAG))
23977 return V;
23978
23979 if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) {
23980 // Fold CONCAT_VECTORS of CONCAT_VECTORS (or undef) to VECTOR_SHUFFLE.
23981 if (SDValue V = combineConcatVectorOfConcatVectors(N, DAG))
23982 return V;
23983
23984 // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE.
23985 if (SDValue V = combineConcatVectorOfExtracts(N, DAG))
23986 return V;
23987 }
23988
23989 if (SDValue V = combineConcatVectorOfCasts(N, DAG))
23990 return V;
23991
23992 if (SDValue V = combineConcatVectorOfShuffleAndItsOperands(
23993 N, DAG, TLI, LegalTypes, LegalOperations))
23994 return V;
23995
23996 // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
23997 // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
23998 // operands and look for CONCAT operations that place the incoming vectors
23999 // at the exact same location.
24000 //
24001 // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
24002 SDValue SingleSource = SDValue();
24003 unsigned PartNumElem =
24004 N->getOperand(Num: 0).getValueType().getVectorMinNumElements();
24005
24006 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
24007 SDValue Op = N->getOperand(Num: i);
24008
24009 if (Op.isUndef())
24010 continue;
24011
24012 // Check if this is the identity extract:
24013 if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR)
24014 return SDValue();
24015
24016 // Find the single incoming vector for the extract_subvector.
24017 if (SingleSource.getNode()) {
24018 if (Op.getOperand(i: 0) != SingleSource)
24019 return SDValue();
24020 } else {
24021 SingleSource = Op.getOperand(i: 0);
24022
24023 // Check the source type is the same as the type of the result.
24024 // If not, this concat may extend the vector, so we can not
24025 // optimize it away.
24026 if (SingleSource.getValueType() != N->getValueType(ResNo: 0))
24027 return SDValue();
24028 }
24029
24030 // Check that we are reading from the identity index.
24031 unsigned IdentityIndex = i * PartNumElem;
24032 if (Op.getConstantOperandAPInt(i: 1) != IdentityIndex)
24033 return SDValue();
24034 }
24035
24036 if (SingleSource.getNode())
24037 return SingleSource;
24038
24039 return SDValue();
24040}
24041
24042// Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
24043// if the subvector can be sourced for free.
24044static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
24045 if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
24046 V.getOperand(i: 1).getValueType() == SubVT && V.getOperand(i: 2) == Index) {
24047 return V.getOperand(i: 1);
24048 }
24049 auto *IndexC = dyn_cast<ConstantSDNode>(Val&: Index);
24050 if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
24051 V.getOperand(i: 0).getValueType() == SubVT &&
24052 (IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
24053 uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
24054 return V.getOperand(i: SubIdx);
24055 }
24056 return SDValue();
24057}
24058
24059static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
24060 SelectionDAG &DAG,
24061 bool LegalOperations) {
24062 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24063 SDValue BinOp = Extract->getOperand(Num: 0);
24064 unsigned BinOpcode = BinOp.getOpcode();
24065 if (!TLI.isBinOp(Opcode: BinOpcode) || BinOp->getNumValues() != 1)
24066 return SDValue();
24067
24068 EVT VecVT = BinOp.getValueType();
24069 SDValue Bop0 = BinOp.getOperand(i: 0), Bop1 = BinOp.getOperand(i: 1);
24070 if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
24071 return SDValue();
24072
24073 SDValue Index = Extract->getOperand(Num: 1);
24074 EVT SubVT = Extract->getValueType(ResNo: 0);
24075 if (!TLI.isOperationLegalOrCustom(Op: BinOpcode, VT: SubVT, LegalOnly: LegalOperations))
24076 return SDValue();
24077
24078 SDValue Sub0 = getSubVectorSrc(V: Bop0, Index, SubVT);
24079 SDValue Sub1 = getSubVectorSrc(V: Bop1, Index, SubVT);
24080
24081 // TODO: We could handle the case where only 1 operand is being inserted by
24082 // creating an extract of the other operand, but that requires checking
24083 // number of uses and/or costs.
24084 if (!Sub0 || !Sub1)
24085 return SDValue();
24086
24087 // We are inserting both operands of the wide binop only to extract back
24088 // to the narrow vector size. Eliminate all of the insert/extract:
24089 // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
24090 return DAG.getNode(Opcode: BinOpcode, DL: SDLoc(Extract), VT: SubVT, N1: Sub0, N2: Sub1,
24091 Flags: BinOp->getFlags());
24092}
24093
24094/// If we are extracting a subvector produced by a wide binary operator try
24095/// to use a narrow binary operator and/or avoid concatenation and extraction.
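/// Illustrative sketch (assumes the narrow binop is legal and the extraction
/// is considered cheap by the target, e.g. an AVX1-style 256->128-bit split):
///   t3: v8i32 = and t1, t2
///   v4i32 = extract_subvector t3, Constant:i64<4>
///     --> v4i32 = and (v4i32 extract_subvector t1, Constant:i64<4>),
///                     (v4i32 extract_subvector t2, Constant:i64<4>)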
24096static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
24097 bool LegalOperations) {
24098 // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
24099 // some of these bailouts with other transforms.
24100
24101 if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
24102 return V;
24103
24104 // The extract index must be a constant, so we can map it to a concat operand.
24105 auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Val: Extract->getOperand(Num: 1));
24106 if (!ExtractIndexC)
24107 return SDValue();
24108
24109 // We are looking for an optionally bitcasted wide vector binary operator
24110 // feeding an extract subvector.
24111 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24112 SDValue BinOp = peekThroughBitcasts(V: Extract->getOperand(Num: 0));
24113 unsigned BOpcode = BinOp.getOpcode();
24114 if (!TLI.isBinOp(Opcode: BOpcode) || BinOp->getNumValues() != 1)
24115 return SDValue();
24116
24117 // Exclude the fake form of fneg (fsub -0.0, x) because that is likely to be
24118 // reduced to the unary fneg when it is visited, and we probably want to deal
24119 // with fneg in a target-specific way.
24120 if (BOpcode == ISD::FSUB) {
24121 auto *C = isConstOrConstSplatFP(N: BinOp.getOperand(i: 0), /*AllowUndefs*/ true);
24122 if (C && C->getValueAPF().isNegZero())
24123 return SDValue();
24124 }
24125
24126 // The binop must be a vector type, so we can extract some fraction of it.
24127 EVT WideBVT = BinOp.getValueType();
24128 // The optimisations below currently assume we are dealing with fixed length
24129 // vectors. It is possible to add support for scalable vectors, but at the
24130 // moment we've done no analysis to prove whether they are profitable or not.
24131 if (!WideBVT.isFixedLengthVector())
24132 return SDValue();
24133
24134 EVT VT = Extract->getValueType(ResNo: 0);
24135 unsigned ExtractIndex = ExtractIndexC->getZExtValue();
24136 assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
24137 "Extract index is not a multiple of the vector length.");
24138
24139 // Bail out if this is not a proper multiple width extraction.
24140 unsigned WideWidth = WideBVT.getSizeInBits();
24141 unsigned NarrowWidth = VT.getSizeInBits();
24142 if (WideWidth % NarrowWidth != 0)
24143 return SDValue();
24144
24145 // Bail out if we are extracting a fraction of a single operation. This can
24146 // occur because we potentially looked through a bitcast of the binop.
24147 unsigned NarrowingRatio = WideWidth / NarrowWidth;
24148 unsigned WideNumElts = WideBVT.getVectorNumElements();
24149 if (WideNumElts % NarrowingRatio != 0)
24150 return SDValue();
24151
24152 // Bail out if the target does not support a narrower version of the binop.
24153 EVT NarrowBVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: WideBVT.getScalarType(),
24154 NumElements: WideNumElts / NarrowingRatio);
24155 if (!TLI.isOperationLegalOrCustomOrPromote(Op: BOpcode, VT: NarrowBVT,
24156 LegalOnly: LegalOperations))
24157 return SDValue();
24158
24159 // If extraction is cheap, we don't need to look at the binop operands
24160 // for concat ops. The narrow binop alone makes this transform profitable.
24161 // We can't just reuse the original extract index operand because we may have
24162 // bitcasted.
24163 unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
24164 unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
24165 if (TLI.isExtractSubvectorCheap(ResVT: NarrowBVT, SrcVT: WideBVT, Index: ExtBOIdx) &&
24166 BinOp.hasOneUse() && Extract->getOperand(Num: 0)->hasOneUse()) {
24167 // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
24168 SDLoc DL(Extract);
24169 SDValue NewExtIndex = DAG.getVectorIdxConstant(Val: ExtBOIdx, DL);
24170 SDValue X = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
24171 N1: BinOp.getOperand(i: 0), N2: NewExtIndex);
24172 SDValue Y = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
24173 N1: BinOp.getOperand(i: 1), N2: NewExtIndex);
24174 SDValue NarrowBinOp =
24175 DAG.getNode(Opcode: BOpcode, DL, VT: NarrowBVT, N1: X, N2: Y, Flags: BinOp->getFlags());
24176 return DAG.getBitcast(VT, V: NarrowBinOp);
24177 }
24178
24179 // Only handle the case where we are doubling and then halving. A larger ratio
24180 // may require more than two narrow binops to replace the wide binop.
24181 if (NarrowingRatio != 2)
24182 return SDValue();
24183
24184 // TODO: The motivating case for this transform is an x86 AVX1 target. That
24185 // target has temptingly almost legal versions of bitwise logic ops in 256-bit
24186 // flavors, but no other 256-bit integer support. This could be extended to
24187 // handle any binop, but that may require fixing/adding other folds to avoid
24188 // codegen regressions.
24189 if (BOpcode != ISD::AND && BOpcode != ISD::OR && BOpcode != ISD::XOR)
24190 return SDValue();
24191
24192 // We need at least one concatenation operation of a binop operand to make
24193 // this transform worthwhile. The concat must double the input vector sizes.
24194 auto GetSubVector = [ConcatOpNum](SDValue V) -> SDValue {
24195 if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
24196 return V.getOperand(i: ConcatOpNum);
24197 return SDValue();
24198 };
24199 SDValue SubVecL = GetSubVector(peekThroughBitcasts(V: BinOp.getOperand(i: 0)));
24200 SDValue SubVecR = GetSubVector(peekThroughBitcasts(V: BinOp.getOperand(i: 1)));
24201
24202 if (SubVecL || SubVecR) {
24203 // If a binop operand was not the result of a concat, we must extract a
24204 // half-sized operand for our new narrow binop:
24205 // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
24206 // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
24207 // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
24208 SDLoc DL(Extract);
24209 SDValue IndexC = DAG.getVectorIdxConstant(Val: ExtBOIdx, DL);
24210 SDValue X = SubVecL ? DAG.getBitcast(VT: NarrowBVT, V: SubVecL)
24211 : DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
24212 N1: BinOp.getOperand(i: 0), N2: IndexC);
24213
24214 SDValue Y = SubVecR ? DAG.getBitcast(VT: NarrowBVT, V: SubVecR)
24215 : DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowBVT,
24216 N1: BinOp.getOperand(i: 1), N2: IndexC);
24217
24218 SDValue NarrowBinOp = DAG.getNode(Opcode: BOpcode, DL, VT: NarrowBVT, N1: X, N2: Y);
24219 return DAG.getBitcast(VT, V: NarrowBinOp);
24220 }
24221
24222 return SDValue();
24223}
24224
24225/// If we are extracting a subvector from a wide vector load, convert to a
24226/// narrow load to eliminate the extraction:
24227/// (extract_subvector (load wide vector)) --> (load narrow vector)
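/// Illustrative sketch (little-endian; byte offsets; assumes the target allows
/// reducing the load width and the extracted type is byte-sized):
///   t2: v8i32,ch = load t0, t1, undef
///   v2i32 = extract_subvector t2, Constant:i64<4>
///     --> v2i32,ch = load t0, (add t1, Constant:i64<16>), undef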
24228static SDValue narrowExtractedVectorLoad(SDNode *Extract, SelectionDAG &DAG) {
24229 // TODO: Add support for big-endian. The offset calculation must be adjusted.
24230 if (DAG.getDataLayout().isBigEndian())
24231 return SDValue();
24232
24233 auto *Ld = dyn_cast<LoadSDNode>(Val: Extract->getOperand(Num: 0));
24234 if (!Ld || Ld->getExtensionType() || !Ld->isSimple())
24235 return SDValue();
24236
24237 // Allow targets to opt-out.
24238 EVT VT = Extract->getValueType(ResNo: 0);
24239
24240 // We can only create byte sized loads.
24241 if (!VT.isByteSized())
24242 return SDValue();
24243
24244 unsigned Index = Extract->getConstantOperandVal(Num: 1);
24245 unsigned NumElts = VT.getVectorMinNumElements();
24246 // A fixed length vector being extracted from a scalable vector
24247 // may not be any *smaller* than the scalable one.
24248 if (Index == 0 && NumElts >= Ld->getValueType(ResNo: 0).getVectorMinNumElements())
24249 return SDValue();
24250
24251 // The definition of EXTRACT_SUBVECTOR states that the index must be a
24252 // multiple of the minimum number of elements in the result type.
24253 assert(Index % NumElts == 0 && "The extract subvector index is not a "
24254 "multiple of the result's element count");
24255
24256 // It's fine to use TypeSize here as we know the offset will not be negative.
24257 TypeSize Offset = VT.getStoreSize() * (Index / NumElts);
24258
24259 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24260 if (!TLI.shouldReduceLoadWidth(Load: Ld, ExtTy: Ld->getExtensionType(), NewVT: VT))
24261 return SDValue();
24262
24263 // The narrow load will be offset from the base address of the old load if
24264 // we are extracting from something besides index 0 (little-endian).
24265 SDLoc DL(Extract);
24266
24267 // TODO: Use "BaseIndexOffset" to make this more effective.
24268 SDValue NewAddr = DAG.getMemBasePlusOffset(Base: Ld->getBasePtr(), Offset, DL);
24269
24270 LocationSize StoreSize = LocationSize::precise(Value: VT.getStoreSize());
24271 MachineFunction &MF = DAG.getMachineFunction();
24272 MachineMemOperand *MMO;
24273 if (Offset.isScalable()) {
24274 MachinePointerInfo MPI =
24275 MachinePointerInfo(Ld->getPointerInfo().getAddrSpace());
24276 MMO = MF.getMachineMemOperand(MMO: Ld->getMemOperand(), PtrInfo: MPI, Size: StoreSize);
24277 } else
24278 MMO = MF.getMachineMemOperand(MMO: Ld->getMemOperand(), Offset: Offset.getFixedValue(),
24279 Size: StoreSize);
24280
24281 SDValue NewLd = DAG.getLoad(VT, dl: DL, Chain: Ld->getChain(), Ptr: NewAddr, MMO);
24282 DAG.makeEquivalentMemoryOrdering(OldLoad: Ld, NewMemOp: NewLd);
24283 return NewLd;
24284}
24285
24286/// Given EXTRACT_SUBVECTOR(VECTOR_SHUFFLE(Op0, Op1, Mask)),
24287/// try to produce VECTOR_SHUFFLE(EXTRACT_SUBVECTOR(Op?, ?),
24288/// EXTRACT_SUBVECTOR(Op?, ?),
24289/// Mask'))
24290/// iff it is legal and profitable to do so. Notably, the trimmed mask
24291/// (containing only the elements that are extracted)
24292/// must reference at most two subvectors.
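/// Illustrative sketch (extracting the low v2i32 half of a v4i32 shuffle whose
/// selected lanes all come from a single subvector of t1):
///   t2: v4i32 = vector_shuffle<6,7,u,u> t0, t1
///   v2i32 = extract_subvector t2, Constant:i64<0>
///     --> v2i32 = vector_shuffle<0,1> (v2i32 extract_subvector t1,
///                                      Constant:i64<2>), undef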
24293static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
24294 SelectionDAG &DAG,
24295 const TargetLowering &TLI,
24296 bool LegalOperations) {
24297 assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
24298 "Must only be called on EXTRACT_SUBVECTOR's");
24299
24300 SDValue N0 = N->getOperand(Num: 0);
24301
24302 // Only deal with non-scalable vectors.
24303 EVT NarrowVT = N->getValueType(ResNo: 0);
24304 EVT WideVT = N0.getValueType();
24305 if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
24306 return SDValue();
24307
24308 // The operand must be a shufflevector.
24309 auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(Val&: N0);
24310 if (!WideShuffleVector)
24311 return SDValue();
24312
24313 // The old shuffle needs to go away.
24314 if (!WideShuffleVector->hasOneUse())
24315 return SDValue();
24316
24317 // And the narrow shufflevector that we'll form must be legal.
24318 if (LegalOperations &&
24319 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT: NarrowVT))
24320 return SDValue();
24321
24322 uint64_t FirstExtractedEltIdx = N->getConstantOperandVal(Num: 1);
24323 int NumEltsExtracted = NarrowVT.getVectorNumElements();
24324 assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
24325 "Extract index is not a multiple of the output vector length.");
24326
24327 int WideNumElts = WideVT.getVectorNumElements();
24328
24329 SmallVector<int, 16> NewMask;
24330 NewMask.reserve(N: NumEltsExtracted);
24331 SmallSetVector<std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>, 2>
24332 DemandedSubvectors;
24333
24334 // Try to decode the wide mask into narrow mask from at most two subvectors.
24335 for (int M : WideShuffleVector->getMask().slice(N: FirstExtractedEltIdx,
24336 M: NumEltsExtracted)) {
24337 assert((M >= -1) && (M < (2 * WideNumElts)) &&
24338 "Out-of-bounds shuffle mask?");
24339
24340 if (M < 0) {
24341 // Does not depend on operands, does not require adjustment.
24342 NewMask.emplace_back(Args&: M);
24343 continue;
24344 }
24345
24346 // From which operand of the shuffle does this shuffle mask element pick?
24347 int WideShufOpIdx = M / WideNumElts;
24348 // Which element of that operand is picked?
24349 int OpEltIdx = M % WideNumElts;
24350
24351 assert((OpEltIdx + WideShufOpIdx * WideNumElts) == M &&
24352 "Shuffle mask vector decomposition failure.");
24353
24354 // And which NumEltsExtracted-sized subvector of that operand is that?
24355 int OpSubvecIdx = OpEltIdx / NumEltsExtracted;
24356 // And which element within that subvector of that operand is that?
24357 int OpEltIdxInSubvec = OpEltIdx % NumEltsExtracted;
24358
24359 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted) == OpEltIdx &&
24360 "Shuffle mask subvector decomposition failure.");
24361
24362 assert((OpEltIdxInSubvec + OpSubvecIdx * NumEltsExtracted +
24363 WideShufOpIdx * WideNumElts) == M &&
24364 "Shuffle mask full decomposition failure.");
24365
24366 SDValue Op = WideShuffleVector->getOperand(Num: WideShufOpIdx);
24367
24368 if (Op.isUndef()) {
24369 // Picking from an undef operand. Let's adjust mask instead.
24370 NewMask.emplace_back(Args: -1);
24371 continue;
24372 }
24373
24374 const std::pair<SDValue, int> DemandedSubvector =
24375 std::make_pair(x&: Op, y&: OpSubvecIdx);
24376
24377 if (DemandedSubvectors.insert(X: DemandedSubvector)) {
24378 if (DemandedSubvectors.size() > 2)
24379 return SDValue(); // We can't handle more than two subvectors.
24380 // How many elements into the WideVT does this subvector start?
24381 int Index = NumEltsExtracted * OpSubvecIdx;
24382 // Bail out if the extraction isn't going to be cheap.
24383 if (!TLI.isExtractSubvectorCheap(ResVT: NarrowVT, SrcVT: WideVT, Index))
24384 return SDValue();
24385 }
24386
24387 // Ok, but from which operand of the new shuffle will this element pick?
24388 int NewOpIdx =
24389 getFirstIndexOf(Range: DemandedSubvectors.getArrayRef(), Val: DemandedSubvector);
24390 assert((NewOpIdx == 0 || NewOpIdx == 1) && "Unexpected operand index.");
24391
24392 int AdjM = OpEltIdxInSubvec + NewOpIdx * NumEltsExtracted;
24393 NewMask.emplace_back(Args&: AdjM);
24394 }
24395 assert(NewMask.size() == (unsigned)NumEltsExtracted && "Produced bad mask.");
24396 assert(DemandedSubvectors.size() <= 2 &&
24397 "Should have ended up demanding at most two subvectors.");
24398
24399 // Did we discover that the shuffle does not actually depend on operands?
24400 if (DemandedSubvectors.empty())
24401 return DAG.getUNDEF(VT: NarrowVT);
24402
24403 // Profitability check: only deal with extractions from the first subvector
24404 // unless the mask becomes an identity mask.
24405 if (!ShuffleVectorInst::isIdentityMask(Mask: NewMask, NumSrcElts: NewMask.size()) ||
24406 any_of(Range&: NewMask, P: [](int M) { return M < 0; }))
24407 for (auto &DemandedSubvector : DemandedSubvectors)
24408 if (DemandedSubvector.second != 0)
24409 return SDValue();
24410
24411 // We still perform the exact same EXTRACT_SUBVECTOR, just on different
24412 // operand[s]/index[es], so there is no point in checking for its legality.
24413
24414 // Do not turn a legal shuffle into an illegal one.
24415 if (TLI.isShuffleMaskLegal(WideShuffleVector->getMask(), WideVT) &&
24416 !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
24417 return SDValue();
24418
24419 SDLoc DL(N);
24420
24421 SmallVector<SDValue, 2> NewOps;
24422 for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
24423 &DemandedSubvector : DemandedSubvectors) {
24424 // How many elements into the WideVT does this subvector start?
24425 int Index = NumEltsExtracted * DemandedSubvector.second;
24426 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index, DL);
24427 NewOps.emplace_back(Args: DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NarrowVT,
24428 N1: DemandedSubvector.first, N2: IndexC));
24429 }
24430 assert((NewOps.size() == 1 || NewOps.size() == 2) &&
24431 "Should end up with either one or two ops");
24432
24433 // If we ended up with only one operand, pad with an undef.
24434 if (NewOps.size() == 1)
24435 NewOps.emplace_back(Args: DAG.getUNDEF(VT: NarrowVT));
24436
24437 return DAG.getVectorShuffle(VT: NarrowVT, dl: DL, N1: NewOps[0], N2: NewOps[1], Mask: NewMask);
24438}
24439
24440SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
24441 EVT NVT = N->getValueType(ResNo: 0);
24442 SDValue V = N->getOperand(Num: 0);
24443 uint64_t ExtIdx = N->getConstantOperandVal(Num: 1);
24444 SDLoc DL(N);
24445
24446 // Extract from UNDEF is UNDEF.
24447 if (V.isUndef())
24448 return DAG.getUNDEF(VT: NVT);
24449
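// If the source vector comes from a load, try to replace the extract with a
// narrower load of just the demanded subvector.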
24450 if (TLI.isOperationLegalOrCustomOrPromote(Op: ISD::LOAD, VT: NVT))
24451 if (SDValue NarrowLoad = narrowExtractedVectorLoad(Extract: N, DAG))
24452 return NarrowLoad;
24453
24454 // Combine an extract of an extract into a single extract_subvector.
24455 // ext (ext X, C), 0 --> ext X, C
24456 if (ExtIdx == 0 && V.getOpcode() == ISD::EXTRACT_SUBVECTOR && V.hasOneUse()) {
24457 if (TLI.isExtractSubvectorCheap(ResVT: NVT, SrcVT: V.getOperand(i: 0).getValueType(),
24458 Index: V.getConstantOperandVal(i: 1)) &&
24459 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NVT)) {
24460 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT, N1: V.getOperand(i: 0),
24461 N2: V.getOperand(i: 1));
24462 }
24463 }
24464
24465 // ty1 extract_subvector(ty2 splat(V)) -> ty1 splat(V)
24466 if (V.getOpcode() == ISD::SPLAT_VECTOR)
24467 if (DAG.isConstantValueOfAnyType(N: V.getOperand(i: 0)) || V.hasOneUse())
24468 if (!LegalOperations || TLI.isOperationLegal(Op: ISD::SPLAT_VECTOR, VT: NVT))
24469 return DAG.getSplatVector(VT: NVT, DL, Op: V.getOperand(i: 0));
24470
24471 // extract_subvector(insert_subvector(x,y,c1),c2)
24472 // --> extract_subvector(y,c2-c1)
24473 // iff we're just extracting from the inserted subvector.
24474 if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
24475 SDValue InsSub = V.getOperand(i: 1);
24476 EVT InsSubVT = InsSub.getValueType();
24477 unsigned NumInsElts = InsSubVT.getVectorMinNumElements();
24478 unsigned InsIdx = V.getConstantOperandVal(i: 2);
24479 unsigned NumSubElts = NVT.getVectorMinNumElements();
24480 if (InsIdx <= ExtIdx && (ExtIdx + NumSubElts) <= (InsIdx + NumInsElts) &&
24481 TLI.isExtractSubvectorCheap(ResVT: NVT, SrcVT: InsSubVT, Index: ExtIdx - InsIdx) &&
24482 InsSubVT.isFixedLengthVector() && NVT.isFixedLengthVector() &&
24483 V.getValueType().isFixedLengthVector())
24484 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT, N1: InsSub,
24485 N2: DAG.getVectorIdxConstant(Val: ExtIdx - InsIdx, DL));
24486 }
24487
24488 // Try to move vector bitcast after extract_subv by scaling extraction index:
24489 // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
24490 if (V.getOpcode() == ISD::BITCAST &&
24491 V.getOperand(i: 0).getValueType().isVector() &&
24492 (!LegalOperations || TLI.isOperationLegal(Op: ISD::BITCAST, VT: NVT))) {
24493 SDValue SrcOp = V.getOperand(i: 0);
24494 EVT SrcVT = SrcOp.getValueType();
24495 unsigned SrcNumElts = SrcVT.getVectorMinNumElements();
24496 unsigned DestNumElts = V.getValueType().getVectorMinNumElements();
24497 if ((SrcNumElts % DestNumElts) == 0) {
24498 unsigned SrcDestRatio = SrcNumElts / DestNumElts;
24499 ElementCount NewExtEC = NVT.getVectorElementCount() * SrcDestRatio;
24500 EVT NewExtVT =
24501 EVT::getVectorVT(Context&: *DAG.getContext(), VT: SrcVT.getScalarType(), EC: NewExtEC);
24502 if (TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NewExtVT)) {
24503 SDValue NewIndex = DAG.getVectorIdxConstant(Val: ExtIdx * SrcDestRatio, DL);
24504 SDValue NewExtract = DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NewExtVT,
24505 N1: V.getOperand(i: 0), N2: NewIndex);
24506 return DAG.getBitcast(VT: NVT, V: NewExtract);
24507 }
24508 }
24509 if ((DestNumElts % SrcNumElts) == 0) {
24510 unsigned DestSrcRatio = DestNumElts / SrcNumElts;
24511 if (NVT.getVectorElementCount().isKnownMultipleOf(RHS: DestSrcRatio)) {
24512 ElementCount NewExtEC =
24513 NVT.getVectorElementCount().divideCoefficientBy(RHS: DestSrcRatio);
24514 EVT ScalarVT = SrcVT.getScalarType();
24515 if ((ExtIdx % DestSrcRatio) == 0) {
24516 unsigned IndexValScaled = ExtIdx / DestSrcRatio;
24517 EVT NewExtVT =
24518 EVT::getVectorVT(Context&: *DAG.getContext(), VT: ScalarVT, EC: NewExtEC);
24519 if (TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_SUBVECTOR, VT: NewExtVT)) {
24520 SDValue NewIndex = DAG.getVectorIdxConstant(Val: IndexValScaled, DL);
24521 SDValue NewExtract =
24522 DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NewExtVT,
24523 N1: V.getOperand(i: 0), N2: NewIndex);
24524 return DAG.getBitcast(VT: NVT, V: NewExtract);
24525 }
24526 if (NewExtEC.isScalar() &&
24527 TLI.isOperationLegalOrCustom(Op: ISD::EXTRACT_VECTOR_ELT, VT: ScalarVT)) {
24528 SDValue NewIndex = DAG.getVectorIdxConstant(Val: IndexValScaled, DL);
24529 SDValue NewExtract =
24530 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: ScalarVT,
24531 N1: V.getOperand(i: 0), N2: NewIndex);
24532 return DAG.getBitcast(VT: NVT, V: NewExtract);
24533 }
24534 }
24535 }
24536 }
24537 }
24538
24539 if (V.getOpcode() == ISD::CONCAT_VECTORS) {
24540 unsigned ExtNumElts = NVT.getVectorMinNumElements();
24541 EVT ConcatSrcVT = V.getOperand(i: 0).getValueType();
24542 assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
24543 "Concat and extract subvector do not change element type");
24544 assert((ExtIdx % ExtNumElts) == 0 &&
24545 "Extract index is not a multiple of the input vector length.");
24546
24547 unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
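// Which concat operand does the extract index point into?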
24548 unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
24549
24550 // If the concatenated source types match this extract, it's a direct
24551 // simplification:
24552 // extract_subvec (concat V1, V2, ...), i --> Vi
24553 if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
24554 return V.getOperand(i: ConcatOpIdx);
24555
24556 // If each concatenated source vector's length is a multiple of this extract's
24557 // length, then extract a fraction of one of those source vectors directly
24558 // from a concat operand. Example:
24559 // v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y)), 14 -->
24560 // v2i8 extract_subvec (v8i8 Y), 6
24561 if (NVT.isFixedLengthVector() && ConcatSrcVT.isFixedLengthVector() &&
24562 ConcatSrcNumElts % ExtNumElts == 0) {
24563 unsigned NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
24564 assert(NewExtIdx + ExtNumElts <= ConcatSrcNumElts &&
24565 "Trying to extract from >1 concat operand?");
24566 assert(NewExtIdx % ExtNumElts == 0 &&
24567 "Extract index is not a multiple of the input vector length.");
24568 SDValue NewIndexC = DAG.getVectorIdxConstant(Val: NewExtIdx, DL);
24569 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT,
24570 N1: V.getOperand(i: ConcatOpIdx), N2: NewIndexC);
24571 }
24572 }
24573
24574 if (SDValue V =
24575 foldExtractSubvectorFromShuffleVector(N, DAG, TLI, LegalOperations))
24576 return V;
24577
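// Look through bitcasts of the source for the remaining folds.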
24578 V = peekThroughBitcasts(V);
24579
24580 // If the input is a build vector, try to make a smaller build vector.
24581 if (V.getOpcode() == ISD::BUILD_VECTOR) {
24582 EVT InVT = V.getValueType();
24583 unsigned ExtractSize = NVT.getSizeInBits();
24584 unsigned EltSize = InVT.getScalarSizeInBits();
24585 // Only do this if we won't split any elements.
24586 if (ExtractSize % EltSize == 0) {
24587 unsigned NumElems = ExtractSize / EltSize;
24588 EVT EltVT = InVT.getVectorElementType();
24589 EVT ExtractVT =
24590 NumElems == 1 ? EltVT
24591 : EVT::getVectorVT(Context&: *DAG.getContext(), VT: EltVT, NumElements: NumElems);
24592 if ((Level < AfterLegalizeDAG ||
24593 (NumElems == 1 ||
24594 TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT: ExtractVT))) &&
24595 (!LegalTypes || TLI.isTypeLegal(VT: ExtractVT))) {
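// Translate the extract index from NVT elements into elements of the
// source build_vector.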
24596 unsigned IdxVal = (ExtIdx * NVT.getScalarSizeInBits()) / EltSize;
24597
24598 if (NumElems == 1) {
24599 SDValue Src = V->getOperand(Num: IdxVal);
24600 if (EltVT != Src.getValueType())
24601 Src = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: EltVT, Operand: Src);
24602 return DAG.getBitcast(VT: NVT, V: Src);
24603 }
24604
24605 // Extract the pieces from the original build_vector.
24606 SDValue BuildVec =
24607 DAG.getBuildVector(VT: ExtractVT, DL, Ops: V->ops().slice(N: IdxVal, M: NumElems));
24608 return DAG.getBitcast(VT: NVT, V: BuildVec);
24609 }
24610 }
24611 }
24612
24613 if (V.getOpcode() == ISD::INSERT_SUBVECTOR) {
24614 // Handle only the simple case where the vector being inserted and the
24615 // vector being extracted are of the same size.
24616 EVT SmallVT = V.getOperand(i: 1).getValueType();
24617 if (!NVT.bitsEq(VT: SmallVT))
24618 return SDValue();
24619
24620 // Combine:
24621 // (extract_subvec (insert_subvec V1, V2, InsIdx), ExtIdx)
24622 // Into:
24623 // indices are equal or bit offsets are equal => V2
24624 // otherwise => (extract_subvec V1, ExtIdx)
24625 uint64_t InsIdx = V.getConstantOperandVal(i: 2);
24626 if (InsIdx * SmallVT.getScalarSizeInBits() ==
24627 ExtIdx * NVT.getScalarSizeInBits()) {
24628 if (LegalOperations && !TLI.isOperationLegal(Op: ISD::BITCAST, VT: NVT))
24629 return SDValue();
24630
24631 return DAG.getBitcast(VT: NVT, V: V.getOperand(i: 1));
24632 }
24633 return DAG.getNode(
24634 Opcode: ISD::EXTRACT_SUBVECTOR, DL, VT: NVT,
24635 N1: DAG.getBitcast(VT: N->getOperand(Num: 0).getValueType(), V: V.getOperand(i: 0)),
24636 N2: N->getOperand(Num: 1));
24637 }
24638
24639 if (SDValue NarrowBOp = narrowExtractedVectorBinOp(Extract: N, DAG, LegalOperations))
24640 return NarrowBOp;
24641
24642 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
24643 return SDValue(N, 0);
24644
24645 return SDValue();
24646}
24647
24648/// Try to convert a wide shuffle of concatenated vectors into 2 narrow shuffles
24649/// followed by concatenation. Narrow vector ops may have better performance
24650/// than wide ops, and this can unlock further narrowing of other vector ops.
24651/// Targets can invert this transform later if it is not profitable.
24652static SDValue foldShuffleOfConcatUndefs(ShuffleVectorSDNode *Shuf,
24653 SelectionDAG &DAG) {
24654 SDValue N0 = Shuf->getOperand(Num: 0), N1 = Shuf->getOperand(Num: 1);
24655 if (N0.getOpcode() != ISD::CONCAT_VECTORS || N0.getNumOperands() != 2 ||
24656 N1.getOpcode() != ISD::CONCAT_VECTORS || N1.getNumOperands() != 2 ||
24657 !N0.getOperand(i: 1).isUndef() || !N1.getOperand(i: 1).isUndef())
24658 return SDValue();
24659
24660 // Split the wide shuffle mask into halves. Any mask element that is accessing
24661 // operand 1 is offset down to account for narrowing of the vectors.
24662 ArrayRef<int> Mask = Shuf->getMask();
24663 EVT VT = Shuf->getValueType(ResNo: 0);
24664 unsigned NumElts = VT.getVectorNumElements();
24665 unsigned HalfNumElts = NumElts / 2;
24666 SmallVector<int, 16> Mask0(HalfNumElts, -1);
24667 SmallVector<int, 16> Mask1(HalfNumElts, -1);
24668 for (unsigned i = 0; i != NumElts; ++i) {
24669 if (Mask[i] == -1)
24670 continue;
24671 // If we reference the upper (undef) subvector then the element is undef.
24672 if ((Mask[i] % NumElts) >= HalfNumElts)
24673 continue;
24674 int M = Mask[i] < (int)NumElts ? Mask[i] : Mask[i] - (int)HalfNumElts;
24675 if (i < HalfNumElts)
24676 Mask0[i] = M;
24677 else
24678 Mask1[i - HalfNumElts] = M;
24679 }
24680
24681 // Ask the target if this is a valid transform.
24682 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24683 EVT HalfVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: VT.getScalarType(),
24684 NumElements: HalfNumElts);
24685 if (!TLI.isShuffleMaskLegal(Mask0, HalfVT) ||
24686 !TLI.isShuffleMaskLegal(Mask1, HalfVT))
24687 return SDValue();
24688
24689 // shuffle (concat X, undef), (concat Y, undef), Mask -->
24690 // concat (shuffle X, Y, Mask0), (shuffle X, Y, Mask1)
24691 SDValue X = N0.getOperand(i: 0), Y = N1.getOperand(i: 0);
24692 SDLoc DL(Shuf);
24693 SDValue Shuf0 = DAG.getVectorShuffle(VT: HalfVT, dl: DL, N1: X, N2: Y, Mask: Mask0);
24694 SDValue Shuf1 = DAG.getVectorShuffle(VT: HalfVT, dl: DL, N1: X, N2: Y, Mask: Mask1);
24695 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, N1: Shuf0, N2: Shuf1);
24696}
24697
24698// Tries to turn a shuffle of two CONCAT_VECTORS into a single concat,
24699 // or to turn a shuffle of a single concat into a simpler shuffle then a concat.
24700static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) {
24701 EVT VT = N->getValueType(ResNo: 0);
24702 unsigned NumElts = VT.getVectorNumElements();
24703
24704 SDValue N0 = N->getOperand(Num: 0);
24705 SDValue N1 = N->getOperand(Num: 1);
24706 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: N);
24707 ArrayRef<int> Mask = SVN->getMask();
24708
24709 SmallVector<SDValue, 4> Ops;
24710 EVT ConcatVT = N0.getOperand(i: 0).getValueType();
24711 unsigned NumElemsPerConcat = ConcatVT.getVectorNumElements();
24712 unsigned NumConcats = NumElts / NumElemsPerConcat;
24713
24714 auto IsUndefMaskElt = [](int i) { return i == -1; };
24715
24716 // Special case: shuffle(concat(A,B)) can be more efficiently represented
24717 // as concat(shuffle(A,B),UNDEF) if the shuffle doesn't set any of the high
24718 // half vector elements.
24719 if (NumElemsPerConcat * 2 == NumElts && N1.isUndef() &&
24720 llvm::all_of(Range: Mask.slice(N: NumElemsPerConcat, M: NumElemsPerConcat),
24721 P: IsUndefMaskElt)) {
24722 N0 = DAG.getVectorShuffle(VT: ConcatVT, dl: SDLoc(N), N1: N0.getOperand(i: 0),
24723 N2: N0.getOperand(i: 1),
24724 Mask: Mask.slice(N: 0, M: NumElemsPerConcat));
24725 N1 = DAG.getUNDEF(VT: ConcatVT);
24726 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, N1: N0, N2: N1);
24727 }
24728
24729 // Look at every vector that's inserted. We're looking for exact
24730 // subvector-sized copies from a concatenated vector
24731 for (unsigned I = 0; I != NumConcats; ++I) {
24732 unsigned Begin = I * NumElemsPerConcat;
24733 ArrayRef<int> SubMask = Mask.slice(N: Begin, M: NumElemsPerConcat);
24734
24735 // Make sure we're dealing with a copy.
24736 if (llvm::all_of(Range&: SubMask, P: IsUndefMaskElt)) {
24737 Ops.push_back(Elt: DAG.getUNDEF(VT: ConcatVT));
24738 continue;
24739 }
24740
24741 int OpIdx = -1;
24742 for (int i = 0; i != (int)NumElemsPerConcat; ++i) {
24743 if (IsUndefMaskElt(SubMask[i]))
24744 continue;
24745 if ((SubMask[i] % (int)NumElemsPerConcat) != i)
24746 return SDValue();
24747 int EltOpIdx = SubMask[i] / NumElemsPerConcat;
24748 if (0 <= OpIdx && EltOpIdx != OpIdx)
24749 return SDValue();
24750 OpIdx = EltOpIdx;
24751 }
24752 assert(0 <= OpIdx && "Unknown concat_vectors op");
24753
24754 if (OpIdx < (int)N0.getNumOperands())
24755 Ops.push_back(Elt: N0.getOperand(i: OpIdx));
24756 else
24757 Ops.push_back(Elt: N1.getOperand(i: OpIdx - N0.getNumOperands()));
24758 }
24759
24760 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
24761}
24762
24763// Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
24764// BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
24765//
24766// SHUFFLE(BUILD_VECTOR(), BUILD_VECTOR()) -> BUILD_VECTOR() is always
24767// a simplification in some sense, but it isn't appropriate in general: some
24768// BUILD_VECTORs are substantially cheaper than others. The general case
24769// of a BUILD_VECTOR requires inserting each element individually (or
24770// performing the equivalent in a temporary stack variable). A BUILD_VECTOR of
24771// all constants is a single constant pool load. A BUILD_VECTOR where each
24772// element is identical is a splat. A BUILD_VECTOR where most of the operands
24773// are undef lowers to a small number of element insertions.
24774//
24775// To deal with this, we currently use a bunch of mostly arbitrary heuristics.
24776// We don't fold shuffles where one side is a non-zero constant, and we don't
24777// fold shuffles if the resulting (non-splat) BUILD_VECTOR would have duplicate
24778// non-constant operands. This seems to work out reasonably well in practice.
24779static SDValue combineShuffleOfScalars(ShuffleVectorSDNode *SVN,
24780 SelectionDAG &DAG,
24781 const TargetLowering &TLI) {
24782 EVT VT = SVN->getValueType(ResNo: 0);
24783 unsigned NumElts = VT.getVectorNumElements();
24784 SDValue N0 = SVN->getOperand(Num: 0);
24785 SDValue N1 = SVN->getOperand(Num: 1);
24786
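// Folding is only worthwhile if this shuffle is the sole user of the source.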
24787 if (!N0->hasOneUse())
24788 return SDValue();
24789
24790 // If only one of N0,N1 is constant, bail out if it is not ALL_ZEROS as
24791 // discussed above.
24792 if (!N1.isUndef()) {
24793 if (!N1->hasOneUse())
24794 return SDValue();
24795
24796 bool N0AnyConst = isAnyConstantBuildVector(V: N0);
24797 bool N1AnyConst = isAnyConstantBuildVector(V: N1);
24798 if (N0AnyConst && !N1AnyConst && !ISD::isBuildVectorAllZeros(N: N0.getNode()))
24799 return SDValue();
24800 if (!N0AnyConst && N1AnyConst && !ISD::isBuildVectorAllZeros(N: N1.getNode()))
24801 return SDValue();
24802 }
24803
24804 // If both inputs are splats of the same value then we can safely merge this
24805 // to a single BUILD_VECTOR with undef elements based on the shuffle mask.
24806 bool IsSplat = false;
24807 auto *BV0 = dyn_cast<BuildVectorSDNode>(Val&: N0);
24808 auto *BV1 = dyn_cast<BuildVectorSDNode>(Val&: N1);
24809 if (BV0 && BV1)
24810 if (SDValue Splat0 = BV0->getSplatValue())
24811 IsSplat = (Splat0 == BV1->getSplatValue());
24812
24813 SmallVector<SDValue, 8> Ops;
24814 SmallSet<SDValue, 16> DuplicateOps;
24815 for (int M : SVN->getMask()) {
24816 SDValue Op = DAG.getUNDEF(VT: VT.getScalarType());
24817 if (M >= 0) {
24818 int Idx = M < (int)NumElts ? M : M - NumElts;
24819 SDValue &S = (M < (int)NumElts ? N0 : N1);
24820 if (S.getOpcode() == ISD::BUILD_VECTOR) {
24821 Op = S.getOperand(i: Idx);
24822 } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR) {
24823 SDValue Op0 = S.getOperand(i: 0);
24824 Op = Idx == 0 ? Op0 : DAG.getUNDEF(VT: Op0.getValueType());
24825 } else {
24826 // Operand can't be combined - bail out.
24827 return SDValue();
24828 }
24829 }
24830
24831 // Don't duplicate a non-constant BUILD_VECTOR operand unless we're
24832 // generating a splat; semantically, this is fine, but it's likely to
24833 // generate low-quality code if the target can't reconstruct an appropriate
24834 // shuffle.
24835 if (!Op.isUndef() && !isIntOrFPConstant(V: Op))
24836 if (!IsSplat && !DuplicateOps.insert(V: Op).second)
24837 return SDValue();
24838
24839 Ops.push_back(Elt: Op);
24840 }
24841
24842 // BUILD_VECTOR requires all inputs to be of the same type; find the
24843 // maximum type and extend them all to it.
24844 EVT SVT = VT.getScalarType();
24845 if (SVT.isInteger())
24846 for (SDValue &Op : Ops)
24847 SVT = (SVT.bitsLT(VT: Op.getValueType()) ? Op.getValueType() : SVT);
24848 if (SVT != VT.getScalarType())
24849 for (SDValue &Op : Ops)
24850 Op = Op.isUndef() ? DAG.getUNDEF(VT: SVT)
24851 : (TLI.isZExtFree(FromTy: Op.getValueType(), ToTy: SVT)
24852 ? DAG.getZExtOrTrunc(Op, DL: SDLoc(SVN), VT: SVT)
24853 : DAG.getSExtOrTrunc(Op, DL: SDLoc(SVN), VT: SVT));
24854 return DAG.getBuildVector(VT, DL: SDLoc(SVN), Ops);
24855}
24856
24857// Match shuffles that can be converted to *_vector_extend_in_reg.
24858// This is often generated during legalization.
24859// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src)),
24860// and returns the EVT to which the extension should be performed.
24861// NOTE: this assumes that the src is the first operand of the shuffle.
24862static std::optional<EVT> canCombineShuffleToExtendVectorInreg(
24863 unsigned Opcode, EVT VT, std::function<bool(unsigned)> Match,
24864 SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes,
24865 bool LegalOperations) {
24866 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
24867
24868 // TODO Add support for big-endian when we have a test case.
24869 if (!VT.isInteger() || IsBigEndian)
24870 return std::nullopt;
24871
24872 unsigned NumElts = VT.getVectorNumElements();
24873 unsigned EltSizeInBits = VT.getScalarSizeInBits();
24874
24875 // Attempt to match a '*_extend_vector_inreg' shuffle; we only search for
24876 // power-of-2 extensions as they are the most likely.
24877 // FIXME: should try the Scale == NumElts case too.
24878 for (unsigned Scale = 2; Scale < NumElts; Scale *= 2) {
24879 // The vector width must be a multiple of Scale.
24880 if (NumElts % Scale != 0)
24881 continue;
24882
24883 EVT OutSVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: EltSizeInBits * Scale);
24884 EVT OutVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: OutSVT, NumElements: NumElts / Scale);
24885
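// Skip extension types that are illegal, or whose operation is unsupported,
// at this stage of legalization.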
24886 if ((LegalTypes && !TLI.isTypeLegal(VT: OutVT)) ||
24887 (LegalOperations && !TLI.isOperationLegalOrCustom(Op: Opcode, VT: OutVT)))
24888 continue;
24889
24890 if (Match(Scale))
24891 return OutVT;
24892 }
24893
24894 return std::nullopt;
24895}
24896
24897// Match shuffles that can be converted to any_vector_extend_in_reg.
24898// This is often generated during legalization.
24899// e.g. v4i32 <0,u,1,u> -> (v2i64 any_vector_extend_in_reg(v4i32 src))
24900static SDValue combineShuffleToAnyExtendVectorInreg(ShuffleVectorSDNode *SVN,
24901 SelectionDAG &DAG,
24902 const TargetLowering &TLI,
24903 bool LegalOperations) {
24904 EVT VT = SVN->getValueType(ResNo: 0);
24905 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
24906
24907 // TODO Add support for big-endian when we have a test case.
24908 if (!VT.isInteger() || IsBigEndian)
24909 return SDValue();
24910
24911 // shuffle<0,-1,1,-1> == (v2i64 anyextend_vector_inreg(v4i32))
24912 auto isAnyExtend = [NumElts = VT.getVectorNumElements(),
24913 Mask = SVN->getMask()](unsigned Scale) {
24914 for (unsigned i = 0; i != NumElts; ++i) {
24915 if (Mask[i] < 0)
24916 continue;
24917 if ((i % Scale) == 0 && Mask[i] == (int)(i / Scale))
24918 continue;
24919 return false;
24920 }
24921 return true;
24922 };
24923
24924 unsigned Opcode = ISD::ANY_EXTEND_VECTOR_INREG;
24925 SDValue N0 = SVN->getOperand(Num: 0);
24926 // Never create an illegal type. Only create unsupported operations if we
24927 // are pre-legalization.
24928 std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
24929 Opcode, VT, Match: isAnyExtend, DAG, TLI, /*LegalTypes=*/true, LegalOperations);
24930 if (!OutVT)
24931 return SDValue();
24932 return DAG.getBitcast(VT, V: DAG.getNode(Opcode, DL: SDLoc(SVN), VT: *OutVT, Operand: N0));
24933}
24934
24935// Match shuffles that can be converted to zero_extend_vector_inreg.
24936// This is often generated during legalization.
24937// e.g. v4i32 <0,z,1,u> -> (v2i64 zero_extend_vector_inreg(v4i32 src))
24938static SDValue combineShuffleToZeroExtendVectorInReg(ShuffleVectorSDNode *SVN,
24939 SelectionDAG &DAG,
24940 const TargetLowering &TLI,
24941 bool LegalOperations) {
24942 bool LegalTypes = true;
24943 EVT VT = SVN->getValueType(ResNo: 0);
24944 assert(!VT.isScalableVector() && "Encountered scalable shuffle?");
24945 unsigned NumElts = VT.getVectorNumElements();
24946 unsigned EltSizeInBits = VT.getScalarSizeInBits();
24947
24948 // TODO: add support for big-endian when we have a test case.
24949 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
24950 if (!VT.isInteger() || IsBigEndian)
24951 return SDValue();
24952
24953 SmallVector<int, 16> Mask(SVN->getMask().begin(), SVN->getMask().end());
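// Helper that visits each defined mask index, decomposed into the operand it
// selects from and the element index within that operand.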
24954 auto ForEachDecomposedIndice = [NumElts, &Mask](auto Fn) {
24955 for (int &Indice : Mask) {
24956 if (Indice < 0)
24957 continue;
24958 int OpIdx = (unsigned)Indice < NumElts ? 0 : 1;
24959 int OpEltIdx = (unsigned)Indice < NumElts ? Indice : Indice - NumElts;
24960 Fn(Indice, OpIdx, OpEltIdx);
24961 }
24962 };
24963
24964 // Which elements of which operand does this shuffle demand?
24965 std::array<APInt, 2> OpsDemandedElts;
24966 for (APInt &OpDemandedElts : OpsDemandedElts)
24967 OpDemandedElts = APInt::getZero(numBits: NumElts);
24968 ForEachDecomposedIndice(
24969 [&OpsDemandedElts](int &Indice, int OpIdx, int OpEltIdx) {
24970 OpsDemandedElts[OpIdx].setBit(OpEltIdx);
24971 });
24972
24973 // Element-wise(!), which of these demanded elements are known to be zero?
24974 std::array<APInt, 2> OpsKnownZeroElts;
24975 for (auto I : zip(t: SVN->ops(), u&: OpsDemandedElts, args&: OpsKnownZeroElts))
24976 std::get<2>(t&: I) =
24977 DAG.computeVectorKnownZeroElements(Op: std::get<0>(t&: I), DemandedElts: std::get<1>(t&: I));
24978
24979 // Manifest zeroable element knowledge in the shuffle mask.
24980 // NOTE: we don't have a 'zeroable' sentinel value in the generic DAG;
24981 // this is a local invention, but it won't leak into the DAG.
24982 // FIXME: should we not manifest them, but just check when matching?
24983 bool HadZeroableElts = false;
24984 ForEachDecomposedIndice([&OpsKnownZeroElts, &HadZeroableElts](
24985 int &Indice, int OpIdx, int OpEltIdx) {
24986 if (OpsKnownZeroElts[OpIdx][OpEltIdx]) {
24987 Indice = -2; // Zeroable element.
24988 HadZeroableElts = true;
24989 }
24990 });
24991
24992 // Don't proceed unless we've refined at least one zeroable mask index.
24993 // If we didn't, then we are still trying to match the same shuffle mask
24994 // we previously tried to match as ISD::ANY_EXTEND_VECTOR_INREG,
24995 // and evidently failed. Proceeding will lead to endless combine loops.
24996 if (!HadZeroableElts)
24997 return SDValue();
24998
24999 // The shuffle may be more fine-grained than we want. Widen elements first.
25000 // FIXME: should we do this before manifesting zeroable shuffle mask indices?
25001 SmallVector<int, 16> ScaledMask;
25002 getShuffleMaskWithWidestElts(Mask, ScaledMask);
25003 assert(Mask.size() >= ScaledMask.size() &&
25004 Mask.size() % ScaledMask.size() == 0 && "Unexpected mask widening.");
25005 int Prescale = Mask.size() / ScaledMask.size();
25006
25007 NumElts = ScaledMask.size();
25008 EltSizeInBits *= Prescale;
25009
25010 EVT PrescaledVT = EVT::getVectorVT(
25011 Context&: *DAG.getContext(), VT: EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: EltSizeInBits),
25012 NumElements: NumElts);
25013
25014 if (LegalTypes && !TLI.isTypeLegal(VT: PrescaledVT) && TLI.isTypeLegal(VT))
25015 return SDValue();
25016
25017 // For example,
25018 // shuffle<0,z,1,-1> == (v2i64 zero_extend_vector_inreg(v4i32))
25019 // But not shuffle<z,z,1,-1> and not shuffle<0,z,z,-1> ! (for same types)
25020 auto isZeroExtend = [NumElts, &ScaledMask](unsigned Scale) {
25021 assert(Scale >= 2 && Scale <= NumElts && NumElts % Scale == 0 &&
25022 "Unexpected mask scaling factor.");
25023 ArrayRef<int> Mask = ScaledMask;
25024 for (unsigned SrcElt = 0, NumSrcElts = NumElts / Scale;
25025 SrcElt != NumSrcElts; ++SrcElt) {
25026 // Analyze the shuffle mask in Scale-sized chunks.
25027 ArrayRef<int> MaskChunk = Mask.take_front(N: Scale);
25028 assert(MaskChunk.size() == Scale && "Unexpected mask size.");
25029 Mask = Mask.drop_front(N: MaskChunk.size());
25030 // The first index in this chunk must be SrcElt, but not zero!
25031 // FIXME: undef should be fine, but that results in a more-defined result.
25032 if (int FirstIndice = MaskChunk[0]; (unsigned)FirstIndice != SrcElt)
25033 return false;
25034 // The rest of the indices in this chunk must be zeros.
25035 // FIXME: undef should be fine, but that results in a more-defined result.
25036 if (!all_of(Range: MaskChunk.drop_front(N: 1),
25037 P: [](int Indice) { return Indice == -2; }))
25038 return false;
25039 }
25040 assert(Mask.empty() && "Did not process the whole mask?");
25041 return true;
25042 };
25043
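// Try each operand as the extension source; commuting the mask swaps which
// operand the low (extended) elements are taken from.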
25044 unsigned Opcode = ISD::ZERO_EXTEND_VECTOR_INREG;
25045 for (bool Commuted : {false, true}) {
25046 SDValue Op = SVN->getOperand(Num: !Commuted ? 0 : 1);
25047 if (Commuted)
25048 ShuffleVectorSDNode::commuteMask(Mask: ScaledMask);
25049 std::optional<EVT> OutVT = canCombineShuffleToExtendVectorInreg(
25050 Opcode, VT: PrescaledVT, Match: isZeroExtend, DAG, TLI, LegalTypes,
25051 LegalOperations);
25052 if (OutVT)
25053 return DAG.getBitcast(VT, V: DAG.getNode(Opcode, DL: SDLoc(SVN), VT: *OutVT,
25054 Operand: DAG.getBitcast(VT: PrescaledVT, V: Op)));
25055 }
25056 return SDValue();
25057}
25058
25059// Detect 'truncate_vector_inreg' style shuffles that pack the lower parts of
25060// each source element of a large type into the lowest elements of a smaller
25061// destination type. This is often generated during legalization.
25062// If the source node itself was a '*_extend_vector_inreg' node then we should
25063 // be able to remove it.
25064static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN,
25065 SelectionDAG &DAG) {
25066 EVT VT = SVN->getValueType(ResNo: 0);
25067 bool IsBigEndian = DAG.getDataLayout().isBigEndian();
25068
25069 // TODO Add support for big-endian when we have a test case.
25070 if (!VT.isInteger() || IsBigEndian)
25071 return SDValue();
25072
25073 SDValue N0 = peekThroughBitcasts(V: SVN->getOperand(Num: 0));
25074
25075 unsigned Opcode = N0.getOpcode();
25076 if (!ISD::isExtVecInRegOpcode(Opcode))
25077 return SDValue();
25078
25079 SDValue N00 = N0.getOperand(i: 0);
25080 ArrayRef<int> Mask = SVN->getMask();
25081 unsigned NumElts = VT.getVectorNumElements();
25082 unsigned EltSizeInBits = VT.getScalarSizeInBits();
25083 unsigned ExtSrcSizeInBits = N00.getScalarValueSizeInBits();
25084 unsigned ExtDstSizeInBits = N0.getScalarValueSizeInBits();
25085
25086 if (ExtDstSizeInBits % ExtSrcSizeInBits != 0)
25087 return SDValue();
25088 unsigned ExtScale = ExtDstSizeInBits / ExtSrcSizeInBits;
25089
25090 // (v4i32 truncate_vector_inreg(v2i64)) == shuffle<0,2,-1,-1>
25091 // (v8i16 truncate_vector_inreg(v4i32)) == shuffle<0,2,4,6,-1,-1,-1,-1>
25092 // (v8i16 truncate_vector_inreg(v2i64)) == shuffle<0,4,-1,-1,-1,-1,-1,-1>
25093 auto isTruncate = [&Mask, &NumElts](unsigned Scale) {
25094 for (unsigned i = 0; i != NumElts; ++i) {
25095 if (Mask[i] < 0)
25096 continue;
25097 if ((i * Scale) < NumElts && Mask[i] == (int)(i * Scale))
25098 continue;
25099 return false;
25100 }
25101 return true;
25102 };
25103
25104 // At the moment we just handle the case where we've truncated back to the
25105 // same size as before the extension.
25106 // TODO: handle more extension/truncation cases as cases arise.
25107 if (EltSizeInBits != ExtSrcSizeInBits)
25108 return SDValue();
25109
25110 // We can remove *extend_vector_inreg only if the truncation happens at
25111 // the same scale as the extension.
25112 if (isTruncate(ExtScale))
25113 return DAG.getBitcast(VT, V: N00);
25114
25115 return SDValue();
25116}
25117
25118// Combine shuffles of splat-shuffles of the form:
25119// shuffle (shuffle V, undef, splat-mask), undef, M
25120// If splat-mask contains undef elements, we need to be careful about
25121 // introducing undefs in the folded mask which are not the result of composing
25122// the masks of the shuffles.
25123static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
25124 SelectionDAG &DAG) {
25125 EVT VT = Shuf->getValueType(ResNo: 0);
25126 unsigned NumElts = VT.getVectorNumElements();
25127
25128 if (!Shuf->getOperand(Num: 1).isUndef())
25129 return SDValue();
25130
25131 // See if this unary non-splat shuffle actually *is* a splat shuffle,
25132 // in disguise, with all demanded elements being identical.
25133 // FIXME: this can be done per-operand.
25134 if (!Shuf->isSplat()) {
25135 APInt DemandedElts(NumElts, 0);
25136 for (int Idx : Shuf->getMask()) {
25137 if (Idx < 0)
25138 continue; // Ignore sentinel indices.
25139 assert((unsigned)Idx < NumElts && "Out-of-bounds shuffle indice?");
25140 DemandedElts.setBit(Idx);
25141 }
25142 assert(DemandedElts.popcount() > 1 && "Is a splat shuffle already?");
25143 APInt UndefElts;
25144 if (DAG.isSplatValue(V: Shuf->getOperand(Num: 0), DemandedElts, UndefElts)) {
25145 // Even if all demanded elements are splat, some of them could be undef.
25146 // Which lowest demanded element is *not* known-undef?
25147 std::optional<unsigned> MinNonUndefIdx;
25148 for (int Idx : Shuf->getMask()) {
25149 if (Idx < 0 || UndefElts[Idx])
25150 continue; // Ignore sentinel indices, and undef elements.
25151 MinNonUndefIdx = std::min<unsigned>(a: Idx, b: MinNonUndefIdx.value_or(u: ~0U));
25152 }
25153 if (!MinNonUndefIdx)
25154 return DAG.getUNDEF(VT); // All undef - result is undef.
25155 assert(*MinNonUndefIdx < NumElts && "Expected valid element index.");
25156 SmallVector<int, 8> SplatMask(Shuf->getMask().begin(),
25157 Shuf->getMask().end());
25158 for (int &Idx : SplatMask) {
25159 if (Idx < 0)
25160 continue; // Passthrough sentinel indices.
25161 // Otherwise, just pick the lowest demanded non-undef element.
25162 // Or sentinel undef, if we know we'd pick a known-undef element.
25163 Idx = UndefElts[Idx] ? -1 : *MinNonUndefIdx;
25164 }
25165 assert(SplatMask != Shuf->getMask() && "Expected mask to change!");
25166 return DAG.getVectorShuffle(VT, dl: SDLoc(Shuf), N1: Shuf->getOperand(Num: 0),
25167 N2: Shuf->getOperand(Num: 1), Mask: SplatMask);
25168 }
25169 }
25170
25171 // If the inner operand is a known splat with no undefs, just return that directly.
25172 // TODO: Create DemandedElts mask from Shuf's mask.
25173 // TODO: Allow undef elements and merge with the shuffle code below.
25174 if (DAG.isSplatValue(V: Shuf->getOperand(Num: 0), /*AllowUndefs*/ false))
25175 return Shuf->getOperand(Num: 0);
25176
25177 auto *Splat = dyn_cast<ShuffleVectorSDNode>(Val: Shuf->getOperand(Num: 0));
25178 if (!Splat || !Splat->isSplat())
25179 return SDValue();
25180
25181 ArrayRef<int> ShufMask = Shuf->getMask();
25182 ArrayRef<int> SplatMask = Splat->getMask();
25183 assert(ShufMask.size() == SplatMask.size() && "Mask length mismatch");
25184
25185 // Prefer simplifying to the splat-shuffle, if possible. This is legal if
25186 // every undef mask element in the splat-shuffle has a corresponding undef
25187 // element in the user-shuffle's mask or if the composition of mask elements
25188 // would result in undef.
25189 // Examples for (shuffle (shuffle v, undef, SplatMask), undef, UserMask):
25190 // * UserMask=[0,2,u,u], SplatMask=[2,u,2,u] -> [2,2,u,u]
25191 // In this case it is not legal to simplify to the splat-shuffle because we
25192 // may be exposing to the users of the shuffle an undef element at index 1
25193 // which was not there before the combine.
25194 // * UserMask=[0,u,2,u], SplatMask=[2,u,2,u] -> [2,u,2,u]
25195 // In this case the composition of masks yields SplatMask, so it's ok to
25196 // simplify to the splat-shuffle.
25197 // * UserMask=[3,u,2,u], SplatMask=[2,u,2,u] -> [u,u,2,u]
25198 // In this case the composed mask includes all undef elements of SplatMask
25199 // and in addition sets element zero to undef. It is safe to simplify to
25200 // the splat-shuffle.
25201 auto CanSimplifyToExistingSplat = [](ArrayRef<int> UserMask,
25202 ArrayRef<int> SplatMask) {
25203 for (unsigned i = 0, e = UserMask.size(); i != e; ++i)
25204 if (UserMask[i] != -1 && SplatMask[i] == -1 &&
25205 SplatMask[UserMask[i]] != -1)
25206 return false;
25207 return true;
25208 };
25209 if (CanSimplifyToExistingSplat(ShufMask, SplatMask))
25210 return Shuf->getOperand(Num: 0);
25211
25212 // Create a new shuffle with a mask that is composed of the two shuffles'
25213 // masks.
25214 SmallVector<int, 32> NewMask;
25215 for (int Idx : ShufMask)
25216 NewMask.push_back(Elt: Idx == -1 ? -1 : SplatMask[Idx]);
25217
25218 return DAG.getVectorShuffle(VT: Splat->getValueType(ResNo: 0), dl: SDLoc(Splat),
25219 N1: Splat->getOperand(Num: 0), N2: Splat->getOperand(Num: 1),
25220 Mask: NewMask);
25221}
25222
25223 // Combine shuffles of bitcasts into a shuffle of the bitcast type, provided
25224// the mask can be treated as a larger type.
25225static SDValue combineShuffleOfBitcast(ShuffleVectorSDNode *SVN,
25226 SelectionDAG &DAG,
25227 const TargetLowering &TLI,
25228 bool LegalOperations) {
25229 SDValue Op0 = SVN->getOperand(Num: 0);
25230 SDValue Op1 = SVN->getOperand(Num: 1);
25231 EVT VT = SVN->getValueType(ResNo: 0);
25232 if (Op0.getOpcode() != ISD::BITCAST)
25233 return SDValue();
25234 EVT InVT = Op0.getOperand(i: 0).getValueType();
25235 if (!InVT.isVector() ||
25236 (!Op1.isUndef() && (Op1.getOpcode() != ISD::BITCAST ||
25237 Op1.getOperand(i: 0).getValueType() != InVT)))
25238 return SDValue();
25239 if (isAnyConstantBuildVector(V: Op0.getOperand(i: 0)) &&
25240 (Op1.isUndef() || isAnyConstantBuildVector(V: Op1.getOperand(i: 0))))
25241 return SDValue();
25242
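// The bitcast source must have strictly fewer lanes, dividing the shuffle
// type's lane count evenly, so each source lane covers a whole group of
// shuffle lanes.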
25243 int VTLanes = VT.getVectorNumElements();
25244 int InLanes = InVT.getVectorNumElements();
25245 if (VTLanes <= InLanes || VTLanes % InLanes != 0 ||
25246 (LegalOperations &&
25247 !TLI.isOperationLegalOrCustom(Op: ISD::VECTOR_SHUFFLE, VT: InVT)))
25248 return SDValue();
25249 int Factor = VTLanes / InLanes;
25250
25251 // Check that each group of lanes in the mask is either all-undef or makes a
25252 // valid mask for the wider lane type.
25253 ArrayRef<int> Mask = SVN->getMask();
25254 SmallVector<int> NewMask;
25255 if (!widenShuffleMaskElts(Scale: Factor, Mask, ScaledMask&: NewMask))
25256 return SDValue();
25257
25258 if (!TLI.isShuffleMaskLegal(NewMask, InVT))
25259 return SDValue();
25260
25261 // Create the new shuffle with the new mask and bitcast it back to the
25262 // original type.
25263 SDLoc DL(SVN);
25264 Op0 = Op0.getOperand(i: 0);
25265 Op1 = Op1.isUndef() ? DAG.getUNDEF(VT: InVT) : Op1.getOperand(i: 0);
25266 SDValue NewShuf = DAG.getVectorShuffle(VT: InVT, dl: DL, N1: Op0, N2: Op1, Mask: NewMask);
25267 return DAG.getBitcast(VT, V: NewShuf);
25268}
25269
25270/// Combine shuffle of shuffle of the form:
25271/// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
25272static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
25273 SelectionDAG &DAG) {
25274 if (!OuterShuf->getOperand(Num: 1).isUndef())
25275 return SDValue();
25276 auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(Val: OuterShuf->getOperand(Num: 0));
25277 if (!InnerShuf || !InnerShuf->getOperand(Num: 1).isUndef())
25278 return SDValue();
25279
25280 ArrayRef<int> OuterMask = OuterShuf->getMask();
25281 ArrayRef<int> InnerMask = InnerShuf->getMask();
25282 unsigned NumElts = OuterMask.size();
25283 assert(NumElts == InnerMask.size() && "Mask length mismatch");
25284 SmallVector<int, 32> CombinedMask(NumElts, -1);
25285 int SplatIndex = -1;
25286 for (unsigned i = 0; i != NumElts; ++i) {
25287 // Undef lanes remain undef.
25288 int OuterMaskElt = OuterMask[i];
25289 if (OuterMaskElt == -1)
25290 continue;
25291
25292 // Peek through the shuffle masks to get the underlying source element.
25293 int InnerMaskElt = InnerMask[OuterMaskElt];
25294 if (InnerMaskElt == -1)
25295 continue;
25296
25297 // Initialize the splatted element.
25298 if (SplatIndex == -1)
25299 SplatIndex = InnerMaskElt;
25300
25301 // Non-matching index - this is not a splat.
25302 if (SplatIndex != InnerMaskElt)
25303 return SDValue();
25304
25305 CombinedMask[i] = InnerMaskElt;
25306 }
25307 assert((all_of(CombinedMask, [](int M) { return M == -1; }) ||
25308 getSplatIndex(CombinedMask) != -1) &&
25309 "Expected a splat mask");
25310
25311 // TODO: The transform may be a win even if the mask is not legal.
25312 EVT VT = OuterShuf->getValueType(ResNo: 0);
25313 assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
25314 if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
25315 return SDValue();
25316
25317 return DAG.getVectorShuffle(VT, dl: SDLoc(OuterShuf), N1: InnerShuf->getOperand(Num: 0),
25318 N2: InnerShuf->getOperand(Num: 1), Mask: CombinedMask);
25319}
25320
25321/// If the shuffle mask is taking exactly one element from the first vector
25322/// operand and passing through all other elements from the second vector
25323/// operand, return the index of the mask element that is choosing an element
25324/// from the first operand. Otherwise, return -1.
25325static int getShuffleMaskIndexOfOneElementFromOp0IntoOp1(ArrayRef<int> Mask) {
25326 int MaskSize = Mask.size();
25327 int EltFromOp0 = -1;
25328 // TODO: This does not match if there are undef elements in the shuffle mask.
25329 // Should we ignore undefs in the shuffle mask instead? The trade-off is
25330 // removing an instruction (a shuffle), but losing the knowledge that some
25331 // vector lanes are not needed.
25332 for (int i = 0; i != MaskSize; ++i) {
25333 if (Mask[i] >= 0 && Mask[i] < MaskSize) {
25334 // We're looking for a shuffle of exactly one element from operand 0.
25335 if (EltFromOp0 != -1)
25336 return -1;
25337 EltFromOp0 = i;
25338 } else if (Mask[i] != i + MaskSize) {
25339 // Nothing from operand 1 can change lanes.
25340 return -1;
25341 }
25342 }
25343 return EltFromOp0;
25344}
25345
25346/// If a shuffle inserts exactly one element from a source vector operand into
25347/// another vector operand and we can access the specified element as a scalar,
25348/// then we can eliminate the shuffle.
25349static SDValue replaceShuffleOfInsert(ShuffleVectorSDNode *Shuf,
25350 SelectionDAG &DAG) {
25351 // First, check if we are taking one element of a vector and shuffling that
25352 // element into another vector.
25353 ArrayRef<int> Mask = Shuf->getMask();
25354 SmallVector<int, 16> CommutedMask(Mask);
25355 SDValue Op0 = Shuf->getOperand(Num: 0);
25356 SDValue Op1 = Shuf->getOperand(Num: 1);
25357 int ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask);
25358 if (ShufOp0Index == -1) {
25359 // Commute mask and check again.
25360 ShuffleVectorSDNode::commuteMask(Mask: CommutedMask);
25361 ShufOp0Index = getShuffleMaskIndexOfOneElementFromOp0IntoOp1(Mask: CommutedMask);
25362 if (ShufOp0Index == -1)
25363 return SDValue();
25364 // Commute operands to match the commuted shuffle mask.
25365 std::swap(a&: Op0, b&: Op1);
25366 Mask = CommutedMask;
25367 }
25368
25369 // The shuffle inserts exactly one element from operand 0 into operand 1.
25370 // Now see if we can access that element as a scalar via a real insert element
25371 // instruction.
25372 // TODO: We can try harder to locate the element as a scalar. Examples: it
25373 // could be an operand of SCALAR_TO_VECTOR, BUILD_VECTOR, or a constant.
25374 assert(Mask[ShufOp0Index] >= 0 && Mask[ShufOp0Index] < (int)Mask.size() &&
25375 "Shuffle mask value must be from operand 0");
25376 if (Op0.getOpcode() != ISD::INSERT_VECTOR_ELT)
25377 return SDValue();
25378
25379 auto *InsIndexC = dyn_cast<ConstantSDNode>(Val: Op0.getOperand(i: 2));
25380 if (!InsIndexC || InsIndexC->getSExtValue() != Mask[ShufOp0Index])
25381 return SDValue();
25382
25383 // There's an existing insertelement with constant insertion index, so we
25384 // don't need to check the legality/profitability of a replacement operation
25385 // that differs at most in the constant value. The target should be able to
25386 // lower any of those in a similar way. If not, legalization will expand this
25387 // to a scalar-to-vector plus shuffle.
25388 //
25389 // Note that the shuffle may move the scalar from the position that the insert
25390 // element used. Therefore, our new insert element occurs at the shuffle's
25391 // mask index value, not the insert's index value.
25392 // shuffle (insertelt v1, x, C), v2, mask --> insertelt v2, x, C'
25393 SDValue NewInsIndex = DAG.getVectorIdxConstant(Val: ShufOp0Index, DL: SDLoc(Shuf));
25394 return DAG.getNode(Opcode: ISD::INSERT_VECTOR_ELT, DL: SDLoc(Shuf), VT: Op0.getValueType(),
25395 N1: Op1, N2: Op0.getOperand(i: 1), N3: NewInsIndex);
25396}
25397
25398/// If we have a unary shuffle of a shuffle, see if it can be folded away
25399/// completely. This has the potential to lose undef knowledge because the first
25400/// shuffle may not have an undef mask element where the second one does. So
25401/// only call this after doing simplifications based on demanded elements.
25402static SDValue simplifyShuffleOfShuffle(ShuffleVectorSDNode *Shuf) {
25403 // shuf (shuf0 X, Y, Mask0), undef, Mask
25404 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Val: Shuf->getOperand(Num: 0));
25405 if (!Shuf0 || !Shuf->getOperand(Num: 1).isUndef())
25406 return SDValue();
25407
25408 ArrayRef<int> Mask = Shuf->getMask();
25409 ArrayRef<int> Mask0 = Shuf0->getMask();
25410 for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
25411 // Ignore undef elements.
25412 if (Mask[i] == -1)
25413 continue;
25414 assert(Mask[i] >= 0 && Mask[i] < e && "Unexpected shuffle mask value");
25415
25416 // Is the element of the shuffle operand chosen by this shuffle the same as
25417 // the element chosen by the shuffle operand itself?
25418 if (Mask0[Mask[i]] != Mask0[i])
25419 return SDValue();
25420 }
25421 // Every element of this shuffle is identical to the result of the previous
25422 // shuffle, so we can replace this value.
25423 return Shuf->getOperand(Num: 0);
25424}
25425
25426SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
25427 EVT VT = N->getValueType(ResNo: 0);
25428 unsigned NumElts = VT.getVectorNumElements();
25429
25430 SDValue N0 = N->getOperand(Num: 0);
25431 SDValue N1 = N->getOperand(Num: 1);
25432
25433 assert(N0.getValueType() == VT && "Vector shuffle must be normalized in DAG");
25434
25435 // Canonicalize shuffle undef, undef -> undef
25436 if (N0.isUndef() && N1.isUndef())
25437 return DAG.getUNDEF(VT);
25438
25439 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Val: N);
25440
25441 // Canonicalize shuffle v, v -> v, undef
25442 if (N0 == N1)
25443 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: N0, N2: DAG.getUNDEF(VT),
25444 Mask: createUnaryMask(Mask: SVN->getMask(), NumElts));
25445
25446 // Canonicalize shuffle undef, v -> v, undef. Commute the shuffle mask.
25447 if (N0.isUndef())
25448 return DAG.getCommutedVectorShuffle(SV: *SVN);
25449
25450 // Remove references to rhs if it is undef
25451 if (N1.isUndef()) {
25452 bool Changed = false;
25453 SmallVector<int, 8> NewMask;
25454 for (unsigned i = 0; i != NumElts; ++i) {
25455 int Idx = SVN->getMaskElt(Idx: i);
25456 if (Idx >= (int)NumElts) {
25457 Idx = -1;
25458 Changed = true;
25459 }
25460 NewMask.push_back(Elt: Idx);
25461 }
25462 if (Changed)
25463 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: N0, N2: N1, Mask: NewMask);
25464 }
25465
25466 if (SDValue InsElt = replaceShuffleOfInsert(Shuf: SVN, DAG))
25467 return InsElt;
25468
25469 // A shuffle of a single vector that is a splatted value can always be folded.
25470 if (SDValue V = combineShuffleOfSplatVal(Shuf: SVN, DAG))
25471 return V;
25472
25473 if (SDValue V = formSplatFromShuffles(OuterShuf: SVN, DAG))
25474 return V;
25475
25476 // If it is a splat, check if the argument vector is another splat or a
25477 // build_vector.
25478 if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {
25479 int SplatIndex = SVN->getSplatIndex();
25480 if (N0.hasOneUse() && TLI.isExtractVecEltCheap(VT, Index: SplatIndex) &&
25481 TLI.isBinOp(Opcode: N0.getOpcode()) && N0->getNumValues() == 1) {
25482 // splat (vector_bo L, R), Index -->
25483 // splat (scalar_bo (extelt L, Index), (extelt R, Index))
25484 SDValue L = N0.getOperand(i: 0), R = N0.getOperand(i: 1);
25485 SDLoc DL(N);
25486 EVT EltVT = VT.getScalarType();
25487 SDValue Index = DAG.getVectorIdxConstant(Val: SplatIndex, DL);
25488 SDValue ExtL = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: L, N2: Index);
25489 SDValue ExtR = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: R, N2: Index);
25490 SDValue NewBO =
25491 DAG.getNode(Opcode: N0.getOpcode(), DL, VT: EltVT, N1: ExtL, N2: ExtR, Flags: N0->getFlags());
25492 SDValue Insert = DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL, VT, Operand: NewBO);
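// Broadcast lane 0 of the scalar result to every lane.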
25493 SmallVector<int, 16> ZeroMask(VT.getVectorNumElements(), 0);
25494 return DAG.getVectorShuffle(VT, dl: DL, N1: Insert, N2: DAG.getUNDEF(VT), Mask: ZeroMask);
25495 }
25496
25497 // splat(scalar_to_vector(x), 0) -> build_vector(x,...,x)
25498 // splat(insert_vector_elt(v, x, c), c) -> build_vector(x,...,x)
25499 if ((!LegalOperations || TLI.isOperationLegal(Op: ISD::BUILD_VECTOR, VT)) &&
25500 N0.hasOneUse()) {
25501 if (N0.getOpcode() == ISD::SCALAR_TO_VECTOR && SplatIndex == 0)
25502 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op: N0.getOperand(i: 0));
25503
25504 if (N0.getOpcode() == ISD::INSERT_VECTOR_ELT)
25505 if (auto *Idx = dyn_cast<ConstantSDNode>(Val: N0.getOperand(i: 2)))
25506 if (Idx->getAPIntValue() == SplatIndex)
25507 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op: N0.getOperand(i: 1));
25508
25509 // Look through a bitcast if LE and splatting lane 0, through to a
25510 // scalar_to_vector or a build_vector.
25511 if (N0.getOpcode() == ISD::BITCAST && N0.getOperand(i: 0).hasOneUse() &&
25512 SplatIndex == 0 && DAG.getDataLayout().isLittleEndian() &&
25513 (N0.getOperand(i: 0).getOpcode() == ISD::SCALAR_TO_VECTOR ||
25514 N0.getOperand(i: 0).getOpcode() == ISD::BUILD_VECTOR)) {
25515 EVT N00VT = N0.getOperand(i: 0).getValueType();
25516 if (VT.getScalarSizeInBits() <= N00VT.getScalarSizeInBits() &&
25517 VT.isInteger() && N00VT.isInteger()) {
25518 EVT InVT =
25519 TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: VT.getScalarType());
25520 SDValue Op = DAG.getZExtOrTrunc(Op: N0.getOperand(i: 0).getOperand(i: 0),
25521 DL: SDLoc(N), VT: InVT);
25522 return DAG.getSplatBuildVector(VT, DL: SDLoc(N), Op);
25523 }
25524 }
25525 }
25526
25527 // If this is a bit convert that changes the element type of the vector but
25528 // not the number of vector elements, look through it. Be careful not to
25529 // look through conversions that change things like v4f32 to v2f64.
25530 SDNode *V = N0.getNode();
25531 if (V->getOpcode() == ISD::BITCAST) {
25532 SDValue ConvInput = V->getOperand(Num: 0);
25533 if (ConvInput.getValueType().isVector() &&
25534 ConvInput.getValueType().getVectorNumElements() == NumElts)
25535 V = ConvInput.getNode();
25536 }
25537
25538 if (V->getOpcode() == ISD::BUILD_VECTOR) {
25539 assert(V->getNumOperands() == NumElts &&
25540 "BUILD_VECTOR has wrong number of operands");
25541 SDValue Base;
25542 bool AllSame = true;
25543 for (unsigned i = 0; i != NumElts; ++i) {
25544 if (!V->getOperand(Num: i).isUndef()) {
25545 Base = V->getOperand(Num: i);
25546 break;
25547 }
25548 }
25549 // Splat of <u, u, u, u>, return <u, u, u, u>
25550 if (!Base.getNode())
25551 return N0;
25552 for (unsigned i = 0; i != NumElts; ++i) {
25553 if (V->getOperand(Num: i) != Base) {
25554 AllSame = false;
25555 break;
25556 }
25557 }
25558 // Splat of <x, x, x, x>, return <x, x, x, x>
25559 if (AllSame)
25560 return N0;
25561
25562 // Canonicalize any other splat as a build_vector.
25563 SDValue Splatted = V->getOperand(Num: SplatIndex);
25564 SmallVector<SDValue, 8> Ops(NumElts, Splatted);
25565 SDValue NewBV = DAG.getBuildVector(VT: V->getValueType(ResNo: 0), DL: SDLoc(N), Ops);
25566
25567 // We may have jumped through bitcasts, so the type of the
25568 // BUILD_VECTOR may not match the type of the shuffle.
25569 if (V->getValueType(ResNo: 0) != VT)
25570 NewBV = DAG.getBitcast(VT, V: NewBV);
25571 return NewBV;
25572 }
25573 }
25574
25575 // Simplify source operands based on shuffle mask.
25576 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
25577 return SDValue(N, 0);
25578
25579 // This is intentionally placed after demanded elements simplification because
25580 // it could eliminate knowledge of undef elements created by this shuffle.
25581 if (SDValue ShufOp = simplifyShuffleOfShuffle(Shuf: SVN))
25582 return ShufOp;
25583
25584 // Match shuffles that can be converted to any_vector_extend_in_reg.
25585 if (SDValue V =
25586 combineShuffleToAnyExtendVectorInreg(SVN, DAG, TLI, LegalOperations))
25587 return V;
25588
25589 // Combine "truncate_vector_in_reg" style shuffles.
25590 if (SDValue V = combineTruncationShuffle(SVN, DAG))
25591 return V;
25592
25593 if (N0.getOpcode() == ISD::CONCAT_VECTORS &&
25594 Level < AfterLegalizeVectorOps &&
25595 (N1.isUndef() ||
25596 (N1.getOpcode() == ISD::CONCAT_VECTORS &&
25597 N0.getOperand(i: 0).getValueType() == N1.getOperand(i: 0).getValueType()))) {
25598 if (SDValue V = partitionShuffleOfConcats(N, DAG))
25599 return V;
25600 }
25601
25602 // A shuffle of a concat of the same narrow vector can be reduced to use
25603 // only low-half elements of a concat with undef:
25604 // shuf (concat X, X), undef, Mask --> shuf (concat X, undef), undef, Mask'
25605 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N1.isUndef() &&
25606 N0.getNumOperands() == 2 &&
25607 N0.getOperand(i: 0) == N0.getOperand(i: 1)) {
25608 int HalfNumElts = (int)NumElts / 2;
25609 SmallVector<int, 8> NewMask;
25610 for (unsigned i = 0; i != NumElts; ++i) {
25611 int Idx = SVN->getMaskElt(Idx: i);
25612 if (Idx >= HalfNumElts) {
25613 assert(Idx < (int)NumElts && "Shuffle mask chooses undef op");
25614 Idx -= HalfNumElts;
25615 }
25616 NewMask.push_back(Elt: Idx);
25617 }
25618 if (TLI.isShuffleMaskLegal(NewMask, VT)) {
25619 SDValue UndefVec = DAG.getUNDEF(VT: N0.getOperand(i: 0).getValueType());
25620 SDValue NewCat = DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT,
25621 N1: N0.getOperand(i: 0), N2: UndefVec);
25622 return DAG.getVectorShuffle(VT, dl: SDLoc(N), N1: NewCat, N2: N1, Mask: NewMask);
25623 }
25624 }
25625
25626 // See if we can replace a shuffle with an insert_subvector.
25627 // e.g. v2i32 into v8i32:
25628 // shuffle(lhs,concat(rhs0,rhs1,rhs2,rhs3),0,1,2,3,10,11,6,7).
25629 // --> insert_subvector(lhs,rhs1,4).
25630 if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT) &&
25631 TLI.isOperationLegalOrCustom(Op: ISD::INSERT_SUBVECTOR, VT)) {
25632 auto ShuffleToInsert = [&](SDValue LHS, SDValue RHS, ArrayRef<int> Mask) {
25633 // Ensure RHS subvectors are legal.
25634 assert(RHS.getOpcode() == ISD::CONCAT_VECTORS && "Can't find subvectors");
25635 EVT SubVT = RHS.getOperand(i: 0).getValueType();
25636 int NumSubVecs = RHS.getNumOperands();
25637 int NumSubElts = SubVT.getVectorNumElements();
25638 assert((NumElts % NumSubElts) == 0 && "Subvector mismatch");
25639 if (!TLI.isTypeLegal(VT: SubVT))
25640 return SDValue();
25641
25642 // Don't bother if we have a unary shuffle (matches undef + LHS elts).
25643 if (all_of(Range&: Mask, P: [NumElts](int M) { return M < (int)NumElts; }))
25644 return SDValue();
25645
25646 // Search [NumSubElts] spans for RHS sequence.
25647 // TODO: Can we avoid nested loops to increase performance?
25648 SmallVector<int> InsertionMask(NumElts);
25649 for (int SubVec = 0; SubVec != NumSubVecs; ++SubVec) {
25650 for (int SubIdx = 0; SubIdx != (int)NumElts; SubIdx += NumSubElts) {
25651 // Reset mask to identity.
25652 std::iota(first: InsertionMask.begin(), last: InsertionMask.end(), value: 0);
25653
25654 // Add subvector insertion.
25655 std::iota(first: InsertionMask.begin() + SubIdx,
25656 last: InsertionMask.begin() + SubIdx + NumSubElts,
25657 value: NumElts + (SubVec * NumSubElts));
25658
25659 // See if the shuffle mask matches the reference insertion mask.
25660 bool MatchingShuffle = true;
25661 for (int i = 0; i != (int)NumElts; ++i) {
25662 int ExpectIdx = InsertionMask[i];
25663 int ActualIdx = Mask[i];
25664 if (0 <= ActualIdx && ExpectIdx != ActualIdx) {
25665 MatchingShuffle = false;
25666 break;
25667 }
25668 }
25669
25670 if (MatchingShuffle)
25671 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT, N1: LHS,
25672 N2: RHS.getOperand(i: SubVec),
25673 N3: DAG.getVectorIdxConstant(Val: SubIdx, DL: SDLoc(N)));
25674 }
25675 }
25676 return SDValue();
25677 };
25678 ArrayRef<int> Mask = SVN->getMask();
25679 if (N1.getOpcode() == ISD::CONCAT_VECTORS)
25680 if (SDValue InsertN1 = ShuffleToInsert(N0, N1, Mask))
25681 return InsertN1;
25682 if (N0.getOpcode() == ISD::CONCAT_VECTORS) {
25683 SmallVector<int> CommuteMask(Mask);
25684 ShuffleVectorSDNode::commuteMask(Mask: CommuteMask);
25685 if (SDValue InsertN0 = ShuffleToInsert(N1, N0, CommuteMask))
25686 return InsertN0;
25687 }
25688 }
25689
  // If we're not performing a select/blend shuffle, see if we can convert the
  // shuffle into an AND node, where all the out-of-lane elements are known to
  // be zero.
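  // For example (illustrative), if Y[2] and Y[3] are known to be zero, then
  // shuffle(X, Y, <0,1,7,6>) can become and(X, <-1,-1,0,0>).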
25692 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
25693 bool IsInLaneMask = true;
25694 ArrayRef<int> Mask = SVN->getMask();
25695 SmallVector<int, 16> ClearMask(NumElts, -1);
25696 APInt DemandedLHS = APInt::getZero(numBits: NumElts);
25697 APInt DemandedRHS = APInt::getZero(numBits: NumElts);
25698 for (int I = 0; I != (int)NumElts; ++I) {
25699 int M = Mask[I];
25700 if (M < 0)
25701 continue;
25702 ClearMask[I] = M == I ? I : (I + NumElts);
25703 IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
25704 if (M != I) {
25705 APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
25706 Demanded.setBit(M % NumElts);
25707 }
25708 }
25709 // TODO: Should we try to mask with N1 as well?
25710 if (!IsInLaneMask && (!DemandedLHS.isZero() || !DemandedRHS.isZero()) &&
25711 (DemandedLHS.isZero() || DAG.MaskedVectorIsZero(Op: N0, DemandedElts: DemandedLHS)) &&
25712 (DemandedRHS.isZero() || DAG.MaskedVectorIsZero(Op: N1, DemandedElts: DemandedRHS))) {
25713 SDLoc DL(N);
25714 EVT IntVT = VT.changeVectorElementTypeToInteger();
25715 EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
      // Transform the type to a legal type so that the buildvector constant
      // elements are not illegal. Make sure that the result is no smaller
      // than the original type, in case the value is split into two
      // (e.g. i64->i32).
25719 if (!TLI.isTypeLegal(VT: IntSVT) && LegalTypes)
25720 IntSVT = TLI.getTypeToTransformTo(Context&: *DAG.getContext(), VT: IntSVT);
25721 if (IntSVT.getSizeInBits() >= IntVT.getScalarSizeInBits()) {
25722 SDValue ZeroElt = DAG.getConstant(Val: 0, DL, VT: IntSVT);
25723 SDValue AllOnesElt = DAG.getAllOnesConstant(DL, VT: IntSVT);
25724 SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(VT: IntSVT));
25725 for (int I = 0; I != (int)NumElts; ++I)
25726 if (0 <= Mask[I])
25727 AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
25728
25729 // See if a clear mask is legal instead of going via
25730 // XformToShuffleWithZero which loses UNDEF mask elements.
25731 if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
25732 return DAG.getBitcast(
25733 VT, V: DAG.getVectorShuffle(VT: IntVT, dl: DL, N1: DAG.getBitcast(VT: IntVT, V: N0),
25734 N2: DAG.getConstant(Val: 0, DL, VT: IntVT), Mask: ClearMask));
25735
25736 if (TLI.isOperationLegalOrCustom(Op: ISD::AND, VT: IntVT))
25737 return DAG.getBitcast(
25738 VT, V: DAG.getNode(Opcode: ISD::AND, DL, VT: IntVT, N1: DAG.getBitcast(VT: IntVT, V: N0),
25739 N2: DAG.getBuildVector(VT: IntVT, DL, Ops: AndMask)));
25740 }
25741 }
25742 }
25743
25744 // Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
25745 // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
25746 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
25747 if (SDValue Res = combineShuffleOfScalars(SVN, DAG, TLI))
25748 return Res;
25749
25750 // If this shuffle only has a single input that is a bitcasted shuffle,
25751 // attempt to merge the 2 shuffles and suitably bitcast the inputs/output
25752 // back to their original types.
25753 if (N0.getOpcode() == ISD::BITCAST && N0.hasOneUse() &&
25754 N1.isUndef() && Level < AfterLegalizeVectorOps &&
25755 TLI.isTypeLegal(VT)) {
25756
25757 SDValue BC0 = peekThroughOneUseBitcasts(V: N0);
25758 if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) {
25759 EVT SVT = VT.getScalarType();
25760 EVT InnerVT = BC0->getValueType(ResNo: 0);
25761 EVT InnerSVT = InnerVT.getScalarType();
25762
25763 // Determine which shuffle works with the smaller scalar type.
25764 EVT ScaleVT = SVT.bitsLT(VT: InnerSVT) ? VT : InnerVT;
25765 EVT ScaleSVT = ScaleVT.getScalarType();
25766
25767 if (TLI.isTypeLegal(VT: ScaleVT) &&
25768 0 == (InnerSVT.getSizeInBits() % ScaleSVT.getSizeInBits()) &&
25769 0 == (SVT.getSizeInBits() % ScaleSVT.getSizeInBits())) {
25770 int InnerScale = InnerSVT.getSizeInBits() / ScaleSVT.getSizeInBits();
25771 int OuterScale = SVT.getSizeInBits() / ScaleSVT.getSizeInBits();
25772
25773 // Scale the shuffle masks to the smaller scalar type.
25774 ShuffleVectorSDNode *InnerSVN = cast<ShuffleVectorSDNode>(Val&: BC0);
25775 SmallVector<int, 8> InnerMask;
25776 SmallVector<int, 8> OuterMask;
25777 narrowShuffleMaskElts(Scale: InnerScale, Mask: InnerSVN->getMask(), ScaledMask&: InnerMask);
25778 narrowShuffleMaskElts(Scale: OuterScale, Mask: SVN->getMask(), ScaledMask&: OuterMask);
25779
25780 // Merge the shuffle masks.
25781 SmallVector<int, 8> NewMask;
25782 for (int M : OuterMask)
25783 NewMask.push_back(Elt: M < 0 ? -1 : InnerMask[M]);
25784
25785 // Test for shuffle mask legality over both commutations.
25786 SDValue SV0 = BC0->getOperand(Num: 0);
25787 SDValue SV1 = BC0->getOperand(Num: 1);
25788 bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
25789 if (!LegalMask) {
25790 std::swap(a&: SV0, b&: SV1);
25791 ShuffleVectorSDNode::commuteMask(Mask: NewMask);
25792 LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT);
25793 }
25794
25795 if (LegalMask) {
25796 SV0 = DAG.getBitcast(VT: ScaleVT, V: SV0);
25797 SV1 = DAG.getBitcast(VT: ScaleVT, V: SV1);
25798 return DAG.getBitcast(
25799 VT, V: DAG.getVectorShuffle(VT: ScaleVT, dl: SDLoc(N), N1: SV0, N2: SV1, Mask: NewMask));
25800 }
25801 }
25802 }
25803 }
25804
25805 // Match shuffles of bitcasts, so long as the mask can be treated as the
25806 // larger type.
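  // For example (illustrative), a v8i32 shuffle of bitcast(X:v4i64) with mask
  // <4,5,6,7,0,1,2,3> moves whole i64 elements and so may be performed as
  // bitcast(shuffle(X, undef, <2,3,0,1>)) instead.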
25807 if (SDValue V = combineShuffleOfBitcast(SVN, DAG, TLI, LegalOperations))
25808 return V;
25809
25810 // Compute the combined shuffle mask for a shuffle with SV0 as the first
25811 // operand, and SV1 as the second operand.
25812 // i.e. Merge SVN(OtherSVN, N1) -> shuffle(SV0, SV1, Mask) iff Commute = false
25813 // Merge SVN(N1, OtherSVN) -> shuffle(SV0, SV1, Mask') iff Commute = true
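  // For example (illustrative), with NumElts == 4 and Commute == false,
  // shuffle(shuffle(A, B, <0,4,2,6>), undef, <0,1,-1,3>) can merge to
  // shuffle(A, B, <0,4,-1,6>), subject to mask legality.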
25814 auto MergeInnerShuffle =
25815 [NumElts, &VT](bool Commute, ShuffleVectorSDNode *SVN,
25816 ShuffleVectorSDNode *OtherSVN, SDValue N1,
25817 const TargetLowering &TLI, SDValue &SV0, SDValue &SV1,
25818 SmallVectorImpl<int> &Mask) -> bool {
25819 // Don't try to fold splats; they're likely to simplify somehow, or they
25820 // might be free.
25821 if (OtherSVN->isSplat())
25822 return false;
25823
25824 SV0 = SV1 = SDValue();
25825 Mask.clear();
25826
25827 for (unsigned i = 0; i != NumElts; ++i) {
25828 int Idx = SVN->getMaskElt(Idx: i);
25829 if (Idx < 0) {
25830 // Propagate Undef.
25831 Mask.push_back(Elt: Idx);
25832 continue;
25833 }
25834
25835 if (Commute)
25836 Idx = (Idx < (int)NumElts) ? (Idx + NumElts) : (Idx - NumElts);
25837
25838 SDValue CurrentVec;
25839 if (Idx < (int)NumElts) {
            // This shuffle index refers to the inner shuffle N0. Look up the
            // inner shuffle mask to identify which vector is actually
            // referenced.
25842 Idx = OtherSVN->getMaskElt(Idx);
25843 if (Idx < 0) {
25844 // Propagate Undef.
25845 Mask.push_back(Elt: Idx);
25846 continue;
25847 }
25848 CurrentVec = (Idx < (int)NumElts) ? OtherSVN->getOperand(Num: 0)
25849 : OtherSVN->getOperand(Num: 1);
25850 } else {
25851 // This shuffle index references an element within N1.
25852 CurrentVec = N1;
25853 }
25854
25855 // Simple case where 'CurrentVec' is UNDEF.
25856 if (CurrentVec.isUndef()) {
25857 Mask.push_back(Elt: -1);
25858 continue;
25859 }
25860
25861 // Canonicalize the shuffle index. We don't know yet if CurrentVec
25862 // will be the first or second operand of the combined shuffle.
25863 Idx = Idx % NumElts;
25864 if (!SV0.getNode() || SV0 == CurrentVec) {
25865 // Ok. CurrentVec is the left hand side.
25866 // Update the mask accordingly.
25867 SV0 = CurrentVec;
25868 Mask.push_back(Elt: Idx);
25869 continue;
25870 }
25871 if (!SV1.getNode() || SV1 == CurrentVec) {
25872 // Ok. CurrentVec is the right hand side.
25873 // Update the mask accordingly.
25874 SV1 = CurrentVec;
25875 Mask.push_back(Elt: Idx + NumElts);
25876 continue;
25877 }
25878
25879 // Last chance - see if the vector is another shuffle and if it
25880 // uses one of the existing candidate shuffle ops.
25881 if (auto *CurrentSVN = dyn_cast<ShuffleVectorSDNode>(Val&: CurrentVec)) {
25882 int InnerIdx = CurrentSVN->getMaskElt(Idx);
25883 if (InnerIdx < 0) {
25884 Mask.push_back(Elt: -1);
25885 continue;
25886 }
25887 SDValue InnerVec = (InnerIdx < (int)NumElts)
25888 ? CurrentSVN->getOperand(Num: 0)
25889 : CurrentSVN->getOperand(Num: 1);
25890 if (InnerVec.isUndef()) {
25891 Mask.push_back(Elt: -1);
25892 continue;
25893 }
25894 InnerIdx %= NumElts;
25895 if (InnerVec == SV0) {
25896 Mask.push_back(Elt: InnerIdx);
25897 continue;
25898 }
25899 if (InnerVec == SV1) {
25900 Mask.push_back(Elt: InnerIdx + NumElts);
25901 continue;
25902 }
25903 }
25904
25905 // Bail out if we cannot convert the shuffle pair into a single shuffle.
25906 return false;
25907 }
25908
25909 if (llvm::all_of(Range&: Mask, P: [](int M) { return M < 0; }))
25910 return true;
25911
25912 // Avoid introducing shuffles with illegal mask.
25913 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
25914 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
25915 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
25916 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
25917 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
25918 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
25919 if (TLI.isShuffleMaskLegal(Mask, VT))
25920 return true;
25921
25922 std::swap(a&: SV0, b&: SV1);
25923 ShuffleVectorSDNode::commuteMask(Mask);
25924 return TLI.isShuffleMaskLegal(Mask, VT);
25925 };
25926
25927 if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
25928 // Canonicalize shuffles according to rules:
25929 // shuffle(A, shuffle(A, B)) -> shuffle(shuffle(A,B), A)
25930 // shuffle(B, shuffle(A, B)) -> shuffle(shuffle(A,B), B)
25931 // shuffle(B, shuffle(A, Undef)) -> shuffle(shuffle(A, Undef), B)
25932 if (N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
25933 N0.getOpcode() != ISD::VECTOR_SHUFFLE) {
25934 // The incoming shuffle must be of the same type as the result of the
25935 // current shuffle.
25936 assert(N1->getOperand(0).getValueType() == VT &&
25937 "Shuffle types don't match");
25938
25939 SDValue SV0 = N1->getOperand(Num: 0);
25940 SDValue SV1 = N1->getOperand(Num: 1);
25941 bool HasSameOp0 = N0 == SV0;
25942 bool IsSV1Undef = SV1.isUndef();
25943 if (HasSameOp0 || IsSV1Undef || N0 == SV1)
25944 // Commute the operands of this shuffle so merging below will trigger.
25945 return DAG.getCommutedVectorShuffle(SV: *SVN);
25946 }
25947
25948 // Canonicalize splat shuffles to the RHS to improve merging below.
25949 // shuffle(splat(A,u), shuffle(C,D)) -> shuffle'(shuffle(C,D), splat(A,u))
25950 if (N0.getOpcode() == ISD::VECTOR_SHUFFLE &&
25951 N1.getOpcode() == ISD::VECTOR_SHUFFLE &&
25952 cast<ShuffleVectorSDNode>(Val&: N0)->isSplat() &&
25953 !cast<ShuffleVectorSDNode>(Val&: N1)->isSplat()) {
25954 return DAG.getCommutedVectorShuffle(SV: *SVN);
25955 }
25956
25957 // Try to fold according to rules:
25958 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
25959 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
25960 // shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
25961 // Don't try to fold shuffles with illegal type.
25962 // Only fold if this shuffle is the only user of the other shuffle.
    // Try matching shuffle(C,shuffle(A,B)) commuted patterns as well.
25964 for (int i = 0; i != 2; ++i) {
25965 if (N->getOperand(Num: i).getOpcode() == ISD::VECTOR_SHUFFLE &&
25966 N->isOnlyUserOf(N: N->getOperand(Num: i).getNode())) {
25967 // The incoming shuffle must be of the same type as the result of the
25968 // current shuffle.
25969 auto *OtherSV = cast<ShuffleVectorSDNode>(Val: N->getOperand(Num: i));
25970 assert(OtherSV->getOperand(0).getValueType() == VT &&
25971 "Shuffle types don't match");
25972
25973 SDValue SV0, SV1;
25974 SmallVector<int, 4> Mask;
25975 if (MergeInnerShuffle(i != 0, SVN, OtherSV, N->getOperand(Num: 1 - i), TLI,
25976 SV0, SV1, Mask)) {
          // Check if all indices in Mask are Undef. If so, propagate Undef.
25978 if (llvm::all_of(Range&: Mask, P: [](int M) { return M < 0; }))
25979 return DAG.getUNDEF(VT);
25980
25981 return DAG.getVectorShuffle(VT, dl: SDLoc(N),
25982 N1: SV0 ? SV0 : DAG.getUNDEF(VT),
25983 N2: SV1 ? SV1 : DAG.getUNDEF(VT), Mask);
25984 }
25985 }
25986 }
25987
    // Merge shuffles through binops if we are able to merge them with at
    // least one other shuffle.
25990 // shuffle(bop(shuffle(x,y),shuffle(z,w)),undef)
25991 // shuffle(bop(shuffle(x,y),shuffle(z,w)),bop(shuffle(a,b),shuffle(c,d)))
25992 unsigned SrcOpcode = N0.getOpcode();
25993 if (TLI.isBinOp(Opcode: SrcOpcode) && N->isOnlyUserOf(N: N0.getNode()) &&
25994 (N1.isUndef() ||
25995 (SrcOpcode == N1.getOpcode() && N->isOnlyUserOf(N: N1.getNode())))) {
25996 // Get binop source ops, or just pass on the undef.
25997 SDValue Op00 = N0.getOperand(i: 0);
25998 SDValue Op01 = N0.getOperand(i: 1);
25999 SDValue Op10 = N1.isUndef() ? N1 : N1.getOperand(i: 0);
26000 SDValue Op11 = N1.isUndef() ? N1 : N1.getOperand(i: 1);
      // TODO: We might be able to relax the VT check but we don't currently
      // have any isBinOp() with different result/operand VTs, so play it safe
      // until we have test coverage.
26004 if (Op00.getValueType() == VT && Op10.getValueType() == VT &&
26005 Op01.getValueType() == VT && Op11.getValueType() == VT &&
26006 (Op00.getOpcode() == ISD::VECTOR_SHUFFLE ||
26007 Op10.getOpcode() == ISD::VECTOR_SHUFFLE ||
26008 Op01.getOpcode() == ISD::VECTOR_SHUFFLE ||
26009 Op11.getOpcode() == ISD::VECTOR_SHUFFLE)) {
26010 auto CanMergeInnerShuffle = [&](SDValue &SV0, SDValue &SV1,
26011 SmallVectorImpl<int> &Mask, bool LeftOp,
26012 bool Commute) {
26013 SDValue InnerN = Commute ? N1 : N0;
26014 SDValue Op0 = LeftOp ? Op00 : Op01;
26015 SDValue Op1 = LeftOp ? Op10 : Op11;
26016 if (Commute)
26017 std::swap(a&: Op0, b&: Op1);
26018 // Only accept the merged shuffle if we don't introduce undef elements,
26019 // or the inner shuffle already contained undef elements.
26020 auto *SVN0 = dyn_cast<ShuffleVectorSDNode>(Val&: Op0);
26021 return SVN0 && InnerN->isOnlyUserOf(N: SVN0) &&
26022 MergeInnerShuffle(Commute, SVN, SVN0, Op1, TLI, SV0, SV1,
26023 Mask) &&
26024 (llvm::any_of(Range: SVN0->getMask(), P: [](int M) { return M < 0; }) ||
26025 llvm::none_of(Range&: Mask, P: [](int M) { return M < 0; }));
26026 };
26027
26028 // Ensure we don't increase the number of shuffles - we must merge a
26029 // shuffle from at least one of the LHS and RHS ops.
26030 bool MergedLeft = false;
26031 SDValue LeftSV0, LeftSV1;
26032 SmallVector<int, 4> LeftMask;
26033 if (CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, false) ||
26034 CanMergeInnerShuffle(LeftSV0, LeftSV1, LeftMask, true, true)) {
26035 MergedLeft = true;
26036 } else {
26037 LeftMask.assign(in_start: SVN->getMask().begin(), in_end: SVN->getMask().end());
26038 LeftSV0 = Op00, LeftSV1 = Op10;
26039 }
26040
26041 bool MergedRight = false;
26042 SDValue RightSV0, RightSV1;
26043 SmallVector<int, 4> RightMask;
26044 if (CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, false) ||
26045 CanMergeInnerShuffle(RightSV0, RightSV1, RightMask, false, true)) {
26046 MergedRight = true;
26047 } else {
26048 RightMask.assign(in_start: SVN->getMask().begin(), in_end: SVN->getMask().end());
26049 RightSV0 = Op01, RightSV1 = Op11;
26050 }
26051
26052 if (MergedLeft || MergedRight) {
26053 SDLoc DL(N);
26054 SDValue LHS = DAG.getVectorShuffle(
26055 VT, dl: DL, N1: LeftSV0 ? LeftSV0 : DAG.getUNDEF(VT),
26056 N2: LeftSV1 ? LeftSV1 : DAG.getUNDEF(VT), Mask: LeftMask);
26057 SDValue RHS = DAG.getVectorShuffle(
26058 VT, dl: DL, N1: RightSV0 ? RightSV0 : DAG.getUNDEF(VT),
26059 N2: RightSV1 ? RightSV1 : DAG.getUNDEF(VT), Mask: RightMask);
26060 return DAG.getNode(Opcode: SrcOpcode, DL, VT, N1: LHS, N2: RHS);
26061 }
26062 }
26063 }
26064 }
26065
26066 if (SDValue V = foldShuffleOfConcatUndefs(Shuf: SVN, DAG))
26067 return V;
26068
26069 // Match shuffles that can be converted to ISD::ZERO_EXTEND_VECTOR_INREG.
26070 // Perform this really late, because it could eliminate knowledge
26071 // of undef elements created by this shuffle.
26072 if (Level < AfterLegalizeTypes)
26073 if (SDValue V = combineShuffleToZeroExtendVectorInReg(SVN, DAG, TLI,
26074 LegalOperations))
26075 return V;
26076
26077 return SDValue();
26078}
26079
26080SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
26081 EVT VT = N->getValueType(ResNo: 0);
26082 if (!VT.isFixedLengthVector())
26083 return SDValue();
26084
26085 // Try to convert a scalar binop with an extracted vector element to a vector
26086 // binop. This is intended to reduce potentially expensive register moves.
26087 // TODO: Check if both operands are extracted.
  // TODO: How to prefer scalar/vector ops with multiple uses of the extract?
26089 // TODO: Generalize this, so it can be called from visitINSERT_VECTOR_ELT().
26090 SDValue Scalar = N->getOperand(Num: 0);
26091 unsigned Opcode = Scalar.getOpcode();
26092 EVT VecEltVT = VT.getScalarType();
26093 if (Scalar.hasOneUse() && Scalar->getNumValues() == 1 &&
26094 TLI.isBinOp(Opcode) && Scalar.getValueType() == VecEltVT &&
26095 Scalar.getOperand(i: 0).getValueType() == VecEltVT &&
26096 Scalar.getOperand(i: 1).getValueType() == VecEltVT &&
26097 Scalar->isOnlyUserOf(N: Scalar.getOperand(i: 0).getNode()) &&
26098 Scalar->isOnlyUserOf(N: Scalar.getOperand(i: 1).getNode()) &&
26099 DAG.isSafeToSpeculativelyExecute(Opcode) && hasOperation(Opcode, VT)) {
26100 // Match an extract element and get a shuffle mask equivalent.
26101 SmallVector<int, 8> ShufMask(VT.getVectorNumElements(), -1);
26102
26103 for (int i : {0, 1}) {
26104 // s2v (bo (extelt V, Idx), C) --> shuffle (bo V, C'), {Idx, -1, -1...}
26105 // s2v (bo C, (extelt V, Idx)) --> shuffle (bo C', V), {Idx, -1, -1...}
26106 SDValue EE = Scalar.getOperand(i);
26107 auto *C = dyn_cast<ConstantSDNode>(Val: Scalar.getOperand(i: i ? 0 : 1));
26108 if (C && EE.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
26109 EE.getOperand(i: 0).getValueType() == VT &&
26110 isa<ConstantSDNode>(Val: EE.getOperand(i: 1))) {
26111 // Mask = {ExtractIndex, undef, undef....}
26112 ShufMask[0] = EE.getConstantOperandVal(i: 1);
26113 // Make sure the shuffle is legal if we are crossing lanes.
26114 if (TLI.isShuffleMaskLegal(ShufMask, VT)) {
26115 SDLoc DL(N);
26116 SDValue V[] = {EE.getOperand(i: 0),
26117 DAG.getConstant(Val: C->getAPIntValue(), DL, VT)};
26118 SDValue VecBO = DAG.getNode(Opcode, DL, VT, N1: V[i], N2: V[1 - i]);
26119 return DAG.getVectorShuffle(VT, dl: DL, N1: VecBO, N2: DAG.getUNDEF(VT),
26120 Mask: ShufMask);
26121 }
26122 }
26123 }
26124 }
26125
26126 // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern
26127 // with a VECTOR_SHUFFLE and possible truncate.
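  // For example (illustrative):
  //   scalar_to_vector (extract_vector_elt V:v4i32, 2)
  //     --> shuffle(V, undef, <2,-1,-1,-1>)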
26128 if (Opcode != ISD::EXTRACT_VECTOR_ELT ||
26129 !Scalar.getOperand(i: 0).getValueType().isFixedLengthVector())
26130 return SDValue();
26131
26132 // If we have an implicit truncate, truncate here if it is legal.
26133 if (VecEltVT != Scalar.getValueType() &&
26134 Scalar.getValueType().isScalarInteger() && isTypeLegal(VT: VecEltVT)) {
26135 SDValue Val = DAG.getNode(Opcode: ISD::TRUNCATE, DL: SDLoc(Scalar), VT: VecEltVT, Operand: Scalar);
26136 return DAG.getNode(Opcode: ISD::SCALAR_TO_VECTOR, DL: SDLoc(N), VT, Operand: Val);
26137 }
26138
26139 auto *ExtIndexC = dyn_cast<ConstantSDNode>(Val: Scalar.getOperand(i: 1));
26140 if (!ExtIndexC)
26141 return SDValue();
26142
26143 SDValue SrcVec = Scalar.getOperand(i: 0);
26144 EVT SrcVT = SrcVec.getValueType();
26145 unsigned SrcNumElts = SrcVT.getVectorNumElements();
26146 unsigned VTNumElts = VT.getVectorNumElements();
26147 if (VecEltVT == SrcVT.getScalarType() && VTNumElts <= SrcNumElts) {
26148 // Create a shuffle equivalent for scalar-to-vector: {ExtIndex, -1, -1, ...}
26149 SmallVector<int, 8> Mask(SrcNumElts, -1);
26150 Mask[0] = ExtIndexC->getZExtValue();
26151 SDValue LegalShuffle = TLI.buildLegalVectorShuffle(
26152 VT: SrcVT, DL: SDLoc(N), N0: SrcVec, N1: DAG.getUNDEF(VT: SrcVT), Mask, DAG);
26153 if (!LegalShuffle)
26154 return SDValue();
26155
26156 // If the initial vector is the same size, the shuffle is the result.
26157 if (VT == SrcVT)
26158 return LegalShuffle;
26159
26160 // If not, shorten the shuffled vector.
26161 if (VTNumElts != SrcNumElts) {
26162 SDValue ZeroIdx = DAG.getVectorIdxConstant(Val: 0, DL: SDLoc(N));
26163 EVT SubVT = EVT::getVectorVT(Context&: *DAG.getContext(),
26164 VT: SrcVT.getVectorElementType(), NumElements: VTNumElts);
26165 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N), VT: SubVT, N1: LegalShuffle,
26166 N2: ZeroIdx);
26167 }
26168 }
26169
26170 return SDValue();
26171}
26172
26173SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
26174 EVT VT = N->getValueType(ResNo: 0);
26175 SDValue N0 = N->getOperand(Num: 0);
26176 SDValue N1 = N->getOperand(Num: 1);
26177 SDValue N2 = N->getOperand(Num: 2);
26178 uint64_t InsIdx = N->getConstantOperandVal(Num: 2);
26179
26180 // If inserting an UNDEF, just return the original vector.
26181 if (N1.isUndef())
26182 return N0;
26183
26184 // If this is an insert of an extracted vector into an undef vector, we can
26185 // just use the input to the extract if the types match, and can simplify
26186 // in some cases even if they don't.
26187 if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26188 N1.getOperand(i: 1) == N2) {
26189 EVT SrcVT = N1.getOperand(i: 0).getValueType();
26190 if (SrcVT == VT)
26191 return N1.getOperand(i: 0);
    // TODO: To remove the zero check, we would need to adjust the offset to
    // be a multiple of the new source type.
26194 if (isNullConstant(V: N2) &&
26195 VT.isScalableVector() == SrcVT.isScalableVector()) {
26196 if (VT.getVectorMinNumElements() >= SrcVT.getVectorMinNumElements())
26197 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N),
26198 VT, N1: N0, N2: N1.getOperand(i: 0), N3: N2);
26199 else
26200 return DAG.getNode(Opcode: ISD::EXTRACT_SUBVECTOR, DL: SDLoc(N),
26201 VT, N1: N1.getOperand(i: 0), N2);
26202 }
26203 }
26204
26205 // Handle case where we've ended up inserting back into the source vector
26206 // we extracted the subvector from.
26207 // insert_subvector(N0, extract_subvector(N0, N2), N2) --> N0
26208 if (N1.getOpcode() == ISD::EXTRACT_SUBVECTOR && N1.getOperand(i: 0) == N0 &&
26209 N1.getOperand(i: 1) == N2)
26210 return N0;
26211
26212 // Simplify scalar inserts into an undef vector:
26213 // insert_subvector undef, (splat X), N2 -> splat X
26214 if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR)
26215 if (DAG.isConstantValueOfAnyType(N: N1.getOperand(i: 0)) || N1.hasOneUse())
26216 return DAG.getNode(Opcode: ISD::SPLAT_VECTOR, DL: SDLoc(N), VT, Operand: N1.getOperand(i: 0));
26217
26218 // If we are inserting a bitcast value into an undef, with the same
26219 // number of elements, just use the bitcast input of the extract.
26220 // i.e. INSERT_SUBVECTOR UNDEF (BITCAST N1) N2 ->
26221 // BITCAST (INSERT_SUBVECTOR UNDEF N1 N2)
26222 if (N0.isUndef() && N1.getOpcode() == ISD::BITCAST &&
26223 N1.getOperand(i: 0).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
26224 N1.getOperand(i: 0).getOperand(i: 1) == N2 &&
26225 N1.getOperand(i: 0).getOperand(i: 0).getValueType().getVectorElementCount() ==
26226 VT.getVectorElementCount() &&
26227 N1.getOperand(i: 0).getOperand(i: 0).getValueType().getSizeInBits() ==
26228 VT.getSizeInBits()) {
26229 return DAG.getBitcast(VT, V: N1.getOperand(i: 0).getOperand(i: 0));
26230 }
26231
  // If both N0 and N1 are bitcast values on which insert_subvector
  // would make sense, pull the bitcast through.
26234 // i.e. INSERT_SUBVECTOR (BITCAST N0) (BITCAST N1) N2 ->
26235 // BITCAST (INSERT_SUBVECTOR N0 N1 N2)
26236 if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
26237 SDValue CN0 = N0.getOperand(i: 0);
26238 SDValue CN1 = N1.getOperand(i: 0);
26239 EVT CN0VT = CN0.getValueType();
26240 EVT CN1VT = CN1.getValueType();
26241 if (CN0VT.isVector() && CN1VT.isVector() &&
26242 CN0VT.getVectorElementType() == CN1VT.getVectorElementType() &&
26243 CN0VT.getVectorElementCount() == VT.getVectorElementCount()) {
26244 SDValue NewINSERT = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N),
26245 VT: CN0.getValueType(), N1: CN0, N2: CN1, N3: N2);
26246 return DAG.getBitcast(VT, V: NewINSERT);
26247 }
26248 }
26249
26250 // Combine INSERT_SUBVECTORs where we are inserting to the same index.
26251 // INSERT_SUBVECTOR( INSERT_SUBVECTOR( Vec, SubOld, Idx ), SubNew, Idx )
26252 // --> INSERT_SUBVECTOR( Vec, SubNew, Idx )
26253 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
26254 N0.getOperand(i: 1).getValueType() == N1.getValueType() &&
26255 N0.getOperand(i: 2) == N2)
26256 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT, N1: N0.getOperand(i: 0),
26257 N2: N1, N3: N2);
26258
26259 // Eliminate an intermediate insert into an undef vector:
26260 // insert_subvector undef, (insert_subvector undef, X, 0), 0 -->
26261 // insert_subvector undef, X, 0
26262 if (N0.isUndef() && N1.getOpcode() == ISD::INSERT_SUBVECTOR &&
26263 N1.getOperand(i: 0).isUndef() && isNullConstant(V: N1.getOperand(i: 2)) &&
26264 isNullConstant(V: N2))
26265 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT, N1: N0,
26266 N2: N1.getOperand(i: 1), N3: N2);
26267
26268 // Push subvector bitcasts to the output, adjusting the index as we go.
26269 // insert_subvector(bitcast(v), bitcast(s), c1)
26270 // -> bitcast(insert_subvector(v, s, c2))
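  // For example (illustrative), inserting (v2i64 bitcast of S:v4i32) into
  // (v4i64 bitcast of V:v8i32) at index 2 may become
  // bitcast(insert_subvector(V, S, 4)) in v8i32, if the narrower insert is
  // supported.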
26271 if ((N0.isUndef() || N0.getOpcode() == ISD::BITCAST) &&
26272 N1.getOpcode() == ISD::BITCAST) {
26273 SDValue N0Src = peekThroughBitcasts(V: N0);
26274 SDValue N1Src = peekThroughBitcasts(V: N1);
26275 EVT N0SrcSVT = N0Src.getValueType().getScalarType();
26276 EVT N1SrcSVT = N1Src.getValueType().getScalarType();
26277 if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
26278 N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
26279 EVT NewVT;
26280 SDLoc DL(N);
26281 SDValue NewIdx;
26282 LLVMContext &Ctx = *DAG.getContext();
26283 ElementCount NumElts = VT.getVectorElementCount();
26284 unsigned EltSizeInBits = VT.getScalarSizeInBits();
26285 if ((EltSizeInBits % N1SrcSVT.getSizeInBits()) == 0) {
26286 unsigned Scale = EltSizeInBits / N1SrcSVT.getSizeInBits();
26287 NewVT = EVT::getVectorVT(Context&: Ctx, VT: N1SrcSVT, EC: NumElts * Scale);
26288 NewIdx = DAG.getVectorIdxConstant(Val: InsIdx * Scale, DL);
26289 } else if ((N1SrcSVT.getSizeInBits() % EltSizeInBits) == 0) {
26290 unsigned Scale = N1SrcSVT.getSizeInBits() / EltSizeInBits;
26291 if (NumElts.isKnownMultipleOf(RHS: Scale) && (InsIdx % Scale) == 0) {
26292 NewVT = EVT::getVectorVT(Context&: Ctx, VT: N1SrcSVT,
26293 EC: NumElts.divideCoefficientBy(RHS: Scale));
26294 NewIdx = DAG.getVectorIdxConstant(Val: InsIdx / Scale, DL);
26295 }
26296 }
26297 if (NewIdx && hasOperation(Opcode: ISD::INSERT_SUBVECTOR, VT: NewVT)) {
26298 SDValue Res = DAG.getBitcast(VT: NewVT, V: N0Src);
26299 Res = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT: NewVT, N1: Res, N2: N1Src, N3: NewIdx);
26300 return DAG.getBitcast(VT, V: Res);
26301 }
26302 }
26303 }
26304
26305 // Canonicalize insert_subvector dag nodes.
26306 // Example:
26307 // (insert_subvector (insert_subvector A, Idx0), Idx1)
26308 // -> (insert_subvector (insert_subvector A, Idx1), Idx0)
26309 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR && N0.hasOneUse() &&
26310 N1.getValueType() == N0.getOperand(i: 1).getValueType()) {
26311 unsigned OtherIdx = N0.getConstantOperandVal(i: 2);
26312 if (InsIdx < OtherIdx) {
26313 // Swap nodes.
26314 SDValue NewOp = DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N), VT,
26315 N1: N0.getOperand(i: 0), N2: N1, N3: N2);
26316 AddToWorklist(N: NewOp.getNode());
26317 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL: SDLoc(N0.getNode()),
26318 VT, N1: NewOp, N2: N0.getOperand(i: 1), N3: N0.getOperand(i: 2));
26319 }
26320 }
26321
26322 // If the input vector is a concatenation, and the insert replaces
26323 // one of the pieces, we can optimize into a single concat_vectors.
26324 if (N0.getOpcode() == ISD::CONCAT_VECTORS && N0.hasOneUse() &&
26325 N0.getOperand(i: 0).getValueType() == N1.getValueType() &&
26326 N0.getOperand(i: 0).getValueType().isScalableVector() ==
26327 N1.getValueType().isScalableVector()) {
26328 unsigned Factor = N1.getValueType().getVectorMinNumElements();
26329 SmallVector<SDValue, 8> Ops(N0->op_begin(), N0->op_end());
26330 Ops[InsIdx / Factor] = N1;
26331 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL: SDLoc(N), VT, Ops);
26332 }
26333
26334 // Simplify source operands based on insertion.
26335 if (SimplifyDemandedVectorElts(Op: SDValue(N, 0)))
26336 return SDValue(N, 0);
26337
26338 return SDValue();
26339}
26340
26341SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
26342 SDValue N0 = N->getOperand(Num: 0);
26343
26344 // fold (fp_to_fp16 (fp16_to_fp op)) -> op
26345 if (N0->getOpcode() == ISD::FP16_TO_FP)
26346 return N0->getOperand(Num: 0);
26347
26348 return SDValue();
26349}
26350
26351SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
26352 auto Op = N->getOpcode();
26353 assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
26354 "opcode should be FP16_TO_FP or BF16_TO_FP.");
26355 SDValue N0 = N->getOperand(Num: 0);
26356
26357 // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op) or
26358 // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
26359 if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
26360 ConstantSDNode *AndConst = getAsNonOpaqueConstant(N: N0.getOperand(i: 1));
26361 if (AndConst && AndConst->getAPIntValue() == 0xffff) {
26362 return DAG.getNode(Opcode: Op, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: N0.getOperand(i: 0));
26363 }
26364 }
26365
26366 return SDValue();
26367}
26368
26369SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
26370 SDValue N0 = N->getOperand(Num: 0);
26371
26372 // fold (fp_to_bf16 (bf16_to_fp op)) -> op
26373 if (N0->getOpcode() == ISD::BF16_TO_FP)
26374 return N0->getOperand(Num: 0);
26375
26376 return SDValue();
26377}
26378
26379SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
26380 // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
26381 return visitFP16_TO_FP(N);
26382}
26383
26384SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
26385 SDValue N0 = N->getOperand(Num: 0);
26386 EVT VT = N0.getValueType();
26387 unsigned Opcode = N->getOpcode();
26388
26389 // VECREDUCE over 1-element vector is just an extract.
26390 if (VT.getVectorElementCount().isScalar()) {
26391 SDLoc dl(N);
26392 SDValue Res =
26393 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL: dl, VT: VT.getVectorElementType(), N1: N0,
26394 N2: DAG.getVectorIdxConstant(Val: 0, DL: dl));
26395 if (Res.getValueType() != N->getValueType(ResNo: 0))
26396 Res = DAG.getNode(Opcode: ISD::ANY_EXTEND, DL: dl, VT: N->getValueType(ResNo: 0), Operand: Res);
26397 return Res;
26398 }
26399
  // On a boolean vector an and/or reduction is the same as a umin/umax
  // reduction. Convert them if the latter is legal while the former isn't.
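  // For example, on a vXi1 (or all-sign-bits) vector:
  //   vecreduce_and(X) == vecreduce_umin(X)
  //   vecreduce_or(X)  == vecreduce_umax(X)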
26402 if (Opcode == ISD::VECREDUCE_AND || Opcode == ISD::VECREDUCE_OR) {
26403 unsigned NewOpcode = Opcode == ISD::VECREDUCE_AND
26404 ? ISD::VECREDUCE_UMIN : ISD::VECREDUCE_UMAX;
26405 if (!TLI.isOperationLegalOrCustom(Op: Opcode, VT) &&
26406 TLI.isOperationLegalOrCustom(Op: NewOpcode, VT) &&
26407 DAG.ComputeNumSignBits(Op: N0) == VT.getScalarSizeInBits())
26408 return DAG.getNode(Opcode: NewOpcode, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: N0);
26409 }
26410
26411 // vecreduce_or(insert_subvector(zero or undef, val)) -> vecreduce_or(val)
26412 // vecreduce_and(insert_subvector(ones or undef, val)) -> vecreduce_and(val)
26413 if (N0.getOpcode() == ISD::INSERT_SUBVECTOR &&
26414 TLI.isTypeLegal(VT: N0.getOperand(i: 1).getValueType())) {
26415 SDValue Vec = N0.getOperand(i: 0);
26416 SDValue Subvec = N0.getOperand(i: 1);
26417 if ((Opcode == ISD::VECREDUCE_OR &&
26418 (N0.getOperand(i: 0).isUndef() || isNullOrNullSplat(V: Vec))) ||
26419 (Opcode == ISD::VECREDUCE_AND &&
26420 (N0.getOperand(i: 0).isUndef() || isAllOnesOrAllOnesSplat(V: Vec))))
26421 return DAG.getNode(Opcode, DL: SDLoc(N), VT: N->getValueType(ResNo: 0), Operand: Subvec);
26422 }
26423
26424 return SDValue();
26425}
26426
26427SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
26428 SelectionDAG::FlagInserter FlagsInserter(DAG, N);
26429
26430 // FSUB -> FMA combines:
26431 if (SDValue Fused = visitFSUBForFMACombine<VPMatchContext>(N)) {
26432 AddToWorklist(N: Fused.getNode());
26433 return Fused;
26434 }
26435 return SDValue();
26436}
26437
26438SDValue DAGCombiner::visitVPOp(SDNode *N) {
26439
26440 if (N->getOpcode() == ISD::VP_GATHER)
26441 if (SDValue SD = visitVPGATHER(N))
26442 return SD;
26443
26444 if (N->getOpcode() == ISD::VP_SCATTER)
26445 if (SDValue SD = visitVPSCATTER(N))
26446 return SD;
26447
26448 if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD)
26449 if (SDValue SD = visitVP_STRIDED_LOAD(N))
26450 return SD;
26451
26452 if (N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_STORE)
26453 if (SDValue SD = visitVP_STRIDED_STORE(N))
26454 return SD;
26455
26456 // VP operations in which all vector elements are disabled - either by
26457 // determining that the mask is all false or that the EVL is 0 - can be
26458 // eliminated.
26459 bool AreAllEltsDisabled = false;
26460 if (auto EVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode: N->getOpcode()))
26461 AreAllEltsDisabled |= isNullConstant(V: N->getOperand(Num: *EVLIdx));
26462 if (auto MaskIdx = ISD::getVPMaskIdx(Opcode: N->getOpcode()))
26463 AreAllEltsDisabled |=
26464 ISD::isConstantSplatVectorAllZeros(N: N->getOperand(Num: *MaskIdx).getNode());
26465
26466 // This is the only generic VP combine we support for now.
26467 if (!AreAllEltsDisabled) {
26468 switch (N->getOpcode()) {
26469 case ISD::VP_FADD:
26470 return visitVP_FADD(N);
26471 case ISD::VP_FSUB:
26472 return visitVP_FSUB(N);
26473 case ISD::VP_FMA:
26474 return visitFMA<VPMatchContext>(N);
26475 case ISD::VP_SELECT:
26476 return visitVP_SELECT(N);
26477 }
26478 return SDValue();
26479 }
26480
26481 // Binary operations can be replaced by UNDEF.
26482 if (ISD::isVPBinaryOp(Opcode: N->getOpcode()))
26483 return DAG.getUNDEF(VT: N->getValueType(ResNo: 0));
26484
26485 // VP Memory operations can be replaced by either the chain (stores) or the
26486 // chain + undef (loads).
26487 if (const auto *MemSD = dyn_cast<MemSDNode>(Val: N)) {
26488 if (MemSD->writeMem())
26489 return MemSD->getChain();
26490 return CombineTo(N, Res0: DAG.getUNDEF(VT: N->getValueType(ResNo: 0)), Res1: MemSD->getChain());
26491 }
26492
26493 // Reduction operations return the start operand when no elements are active.
26494 if (ISD::isVPReduction(Opcode: N->getOpcode()))
26495 return N->getOperand(Num: 0);
26496
26497 return SDValue();
26498}
26499
26500SDValue DAGCombiner::visitGET_FPENV_MEM(SDNode *N) {
26501 SDValue Chain = N->getOperand(Num: 0);
26502 SDValue Ptr = N->getOperand(Num: 1);
26503 EVT MemVT = cast<FPStateAccessSDNode>(Val: N)->getMemoryVT();
26504
  // Check if the memory where the FP state is written is used only in a
  // single load operation.
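  // Conceptually: (get_fpenv_mem p; x = load p; store x -> q) can instead
  // write the FP environment directly to q (see the node created below).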
26507 LoadSDNode *LdNode = nullptr;
26508 for (auto *U : Ptr->uses()) {
26509 if (U == N)
26510 continue;
26511 if (auto *Ld = dyn_cast<LoadSDNode>(Val: U)) {
26512 if (LdNode && LdNode != Ld)
26513 return SDValue();
26514 LdNode = Ld;
26515 continue;
26516 }
26517 return SDValue();
26518 }
26519 if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
26520 !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
26521 !LdNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(N, 0)))
26522 return SDValue();
26523
26524 // Check if the loaded value is used only in a store operation.
26525 StoreSDNode *StNode = nullptr;
26526 for (auto I = LdNode->use_begin(), E = LdNode->use_end(); I != E; ++I) {
26527 SDUse &U = I.getUse();
26528 if (U.getResNo() == 0) {
26529 if (auto *St = dyn_cast<StoreSDNode>(Val: U.getUser())) {
26530 if (StNode)
26531 return SDValue();
26532 StNode = St;
26533 } else {
26534 return SDValue();
26535 }
26536 }
26537 }
26538 if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
26539 !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
26540 !StNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(LdNode, 1)))
26541 return SDValue();
26542
  // Create a new GET_FPENV_MEM node, which uses the store address to write
  // the FP environment.
26545 SDValue Res = DAG.getGetFPEnv(Chain, dl: SDLoc(N), Ptr: StNode->getBasePtr(), MemVT,
26546 MMO: StNode->getMemOperand());
26547 CombineTo(N: StNode, Res, AddTo: false);
26548 return Res;
26549}
26550
26551SDValue DAGCombiner::visitSET_FPENV_MEM(SDNode *N) {
26552 SDValue Chain = N->getOperand(Num: 0);
26553 SDValue Ptr = N->getOperand(Num: 1);
26554 EVT MemVT = cast<FPStateAccessSDNode>(Val: N)->getMemoryVT();
26555
  // Check if the address of the FP state is also used only in a store
  // operation.
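  // Conceptually: (x = load p; store x -> q; set_fpenv_mem q) can instead
  // read the FP environment directly from p (see the node created below).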
26557 StoreSDNode *StNode = nullptr;
26558 for (auto *U : Ptr->uses()) {
26559 if (U == N)
26560 continue;
26561 if (auto *St = dyn_cast<StoreSDNode>(Val: U)) {
26562 if (StNode && StNode != St)
26563 return SDValue();
26564 StNode = St;
26565 continue;
26566 }
26567 return SDValue();
26568 }
26569 if (!StNode || !StNode->isSimple() || StNode->isIndexed() ||
26570 !StNode->getOffset().isUndef() || StNode->getMemoryVT() != MemVT ||
26571 !Chain.reachesChainWithoutSideEffects(Dest: SDValue(StNode, 0)))
26572 return SDValue();
26573
26574 // Check if the stored value is loaded from some location and the loaded
26575 // value is used only in the store operation.
26576 SDValue StValue = StNode->getValue();
26577 auto *LdNode = dyn_cast<LoadSDNode>(Val&: StValue);
26578 if (!LdNode || !LdNode->isSimple() || LdNode->isIndexed() ||
26579 !LdNode->getOffset().isUndef() || LdNode->getMemoryVT() != MemVT ||
26580 !StNode->getChain().reachesChainWithoutSideEffects(Dest: SDValue(LdNode, 1)))
26581 return SDValue();
26582
  // Create a new SET_FPENV_MEM node, which uses the load address to read
  // the FP environment.
26585 SDValue Res =
26586 DAG.getSetFPEnv(Chain: LdNode->getChain(), dl: SDLoc(N), Ptr: LdNode->getBasePtr(), MemVT,
26587 MMO: LdNode->getMemOperand());
26588 return Res;
26589}
26590
/// Returns a vector_shuffle if it is able to transform an AND to a
/// vector_shuffle with the destination vector and a zero vector.
26593/// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>
26594/// vector_shuffle V, Zero, <0, 4, 2, 4>
26595SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) {
26596 assert(N->getOpcode() == ISD::AND && "Unexpected opcode!");
26597
26598 EVT VT = N->getValueType(ResNo: 0);
26599 SDValue LHS = N->getOperand(Num: 0);
26600 SDValue RHS = peekThroughBitcasts(V: N->getOperand(Num: 1));
26601 SDLoc DL(N);
26602
26603 // Make sure we're not running after operation legalization where it
26604 // may have custom lowered the vector shuffles.
26605 if (LegalOperations)
26606 return SDValue();
26607
26608 if (RHS.getOpcode() != ISD::BUILD_VECTOR)
26609 return SDValue();
26610
26611 EVT RVT = RHS.getValueType();
26612 unsigned NumElts = RHS.getNumOperands();
26613
  // Attempt to create a valid clear mask, splitting the mask into
  // sub elements and checking to see if each is all zeros or all ones -
  // suitable for shuffle masking.
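  // For example (illustrative), on a little-endian target
  //   and v2i64 X, <0x00000000FFFFFFFF, 0xFFFFFFFF00000000>
  // can be handled at Split == 2 as
  //   shuffle(bitcast(X):v4i32, zero, <0,5,6,3>), if that clear mask is legal.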
26617 auto BuildClearMask = [&](int Split) {
26618 int NumSubElts = NumElts * Split;
26619 int NumSubBits = RVT.getScalarSizeInBits() / Split;
26620
26621 SmallVector<int, 8> Indices;
26622 for (int i = 0; i != NumSubElts; ++i) {
26623 int EltIdx = i / Split;
26624 int SubIdx = i % Split;
26625 SDValue Elt = RHS.getOperand(i: EltIdx);
26626 // X & undef --> 0 (not undef). So this lane must be converted to choose
26627 // from the zero constant vector (same as if the element had all 0-bits).
26628 if (Elt.isUndef()) {
26629 Indices.push_back(Elt: i + NumSubElts);
26630 continue;
26631 }
26632
26633 APInt Bits;
26634 if (auto *Cst = dyn_cast<ConstantSDNode>(Val&: Elt))
26635 Bits = Cst->getAPIntValue();
26636 else if (auto *CstFP = dyn_cast<ConstantFPSDNode>(Val&: Elt))
26637 Bits = CstFP->getValueAPF().bitcastToAPInt();
26638 else
26639 return SDValue();
26640
26641 // Extract the sub element from the constant bit mask.
26642 if (DAG.getDataLayout().isBigEndian())
26643 Bits = Bits.extractBits(numBits: NumSubBits, bitPosition: (Split - SubIdx - 1) * NumSubBits);
26644 else
26645 Bits = Bits.extractBits(numBits: NumSubBits, bitPosition: SubIdx * NumSubBits);
26646
26647 if (Bits.isAllOnes())
26648 Indices.push_back(Elt: i);
26649 else if (Bits == 0)
26650 Indices.push_back(Elt: i + NumSubElts);
26651 else
26652 return SDValue();
26653 }
26654
26655 // Let's see if the target supports this vector_shuffle.
26656 EVT ClearSVT = EVT::getIntegerVT(Context&: *DAG.getContext(), BitWidth: NumSubBits);
26657 EVT ClearVT = EVT::getVectorVT(Context&: *DAG.getContext(), VT: ClearSVT, NumElements: NumSubElts);
26658 if (!TLI.isVectorClearMaskLegal(Indices, ClearVT))
26659 return SDValue();
26660
26661 SDValue Zero = DAG.getConstant(Val: 0, DL, VT: ClearVT);
26662 return DAG.getBitcast(VT, V: DAG.getVectorShuffle(VT: ClearVT, dl: DL,
26663 N1: DAG.getBitcast(VT: ClearVT, V: LHS),
26664 N2: Zero, Mask: Indices));
26665 };
26666
26667 // Determine maximum split level (byte level masking).
26668 int MaxSplit = 1;
26669 if (RVT.getScalarSizeInBits() % 8 == 0)
26670 MaxSplit = RVT.getScalarSizeInBits() / 8;
26671
26672 for (int Split = 1; Split <= MaxSplit; ++Split)
26673 if (RVT.getScalarSizeInBits() % Split == 0)
26674 if (SDValue S = BuildClearMask(Split))
26675 return S;
26676
26677 return SDValue();
26678}
26679
26680/// If a vector binop is performed on splat values, it may be profitable to
26681/// extract, scalarize, and insert/splat.
26682static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
26683 const SDLoc &DL) {
26684 SDValue N0 = N->getOperand(Num: 0);
26685 SDValue N1 = N->getOperand(Num: 1);
26686 unsigned Opcode = N->getOpcode();
26687 EVT VT = N->getValueType(ResNo: 0);
26688 EVT EltVT = VT.getVectorElementType();
26689 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26690
26691 // TODO: Remove/replace the extract cost check? If the elements are available
26692 // as scalars, then there may be no extract cost. Should we ask if
26693 // inserting a scalar back into a vector is cheap instead?
26694 int Index0, Index1;
26695 SDValue Src0 = DAG.getSplatSourceVector(V: N0, SplatIndex&: Index0);
26696 SDValue Src1 = DAG.getSplatSourceVector(V: N1, SplatIndex&: Index1);
26697 // Extract element from splat_vector should be free.
26698 // TODO: use DAG.isSplatValue instead?
26699 bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
26700 N1.getOpcode() == ISD::SPLAT_VECTOR;
26701 if (!Src0 || !Src1 || Index0 != Index1 ||
26702 Src0.getValueType().getVectorElementType() != EltVT ||
26703 Src1.getValueType().getVectorElementType() != EltVT ||
26704 !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index: Index0)) ||
26705 !TLI.isOperationLegalOrCustom(Op: Opcode, VT: EltVT))
26706 return SDValue();
26707
26708 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index0, DL);
26709 SDValue X = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Src0, N2: IndexC);
26710 SDValue Y = DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: EltVT, N1: Src1, N2: IndexC);
26711 SDValue ScalarBO = DAG.getNode(Opcode, DL, VT: EltVT, N1: X, N2: Y, Flags: N->getFlags());
26712
26713 // If all lanes but 1 are undefined, no need to splat the scalar result.
26714 // TODO: Keep track of undefs and use that info in the general case.
26715 if (N0.getOpcode() == ISD::BUILD_VECTOR && N0.getOpcode() == N1.getOpcode() &&
26716 count_if(Range: N0->ops(), P: [](SDValue V) { return !V.isUndef(); }) == 1 &&
26717 count_if(Range: N1->ops(), P: [](SDValue V) { return !V.isUndef(); }) == 1) {
26718 // bo (build_vec ..undef, X, undef...), (build_vec ..undef, Y, undef...) -->
26719 // build_vec ..undef, (bo X, Y), undef...
26720 SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), DAG.getUNDEF(VT: EltVT));
26721 Ops[Index0] = ScalarBO;
26722 return DAG.getBuildVector(VT, DL, Ops);
26723 }
26724
26725 // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
26726 return DAG.getSplat(VT, DL, Op: ScalarBO);
26727}
26728
26729/// Visit a vector cast operation, like FP_EXTEND.
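/// For example (illustrative), a cast of a splat may be scalarized:
///   fp_extend (splat X) --> splat (fp_extend X)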
26730SDValue DAGCombiner::SimplifyVCastOp(SDNode *N, const SDLoc &DL) {
26731 EVT VT = N->getValueType(ResNo: 0);
26732 assert(VT.isVector() && "SimplifyVCastOp only works on vectors!");
26733 EVT EltVT = VT.getVectorElementType();
26734 unsigned Opcode = N->getOpcode();
26735
26736 SDValue N0 = N->getOperand(Num: 0);
26737 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
26738
  // TODO: promoting the operation might also be good here?
26740 int Index0;
26741 SDValue Src0 = DAG.getSplatSourceVector(V: N0, SplatIndex&: Index0);
26742 if (Src0 &&
26743 (N0.getOpcode() == ISD::SPLAT_VECTOR ||
26744 TLI.isExtractVecEltCheap(VT, Index: Index0)) &&
26745 TLI.isOperationLegalOrCustom(Op: Opcode, VT: EltVT) &&
26746 TLI.preferScalarizeSplat(N)) {
26747 EVT SrcVT = N0.getValueType();
26748 EVT SrcEltVT = SrcVT.getVectorElementType();
26749 SDValue IndexC = DAG.getVectorIdxConstant(Val: Index0, DL);
26750 SDValue Elt =
26751 DAG.getNode(Opcode: ISD::EXTRACT_VECTOR_ELT, DL, VT: SrcEltVT, N1: Src0, N2: IndexC);
26752 SDValue ScalarBO = DAG.getNode(Opcode, DL, VT: EltVT, Operand: Elt, Flags: N->getFlags());
26753 if (VT.isScalableVector())
26754 return DAG.getSplatVector(VT, DL, Op: ScalarBO);
26755 SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
26756 return DAG.getBuildVector(VT, DL, Ops);
26757 }
26758
26759 return SDValue();
26760}
26761
26762/// Visit a binary vector operation, like ADD.
26763SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
26764 EVT VT = N->getValueType(ResNo: 0);
26765 assert(VT.isVector() && "SimplifyVBinOp only works on vectors!");
26766
26767 SDValue LHS = N->getOperand(Num: 0);
26768 SDValue RHS = N->getOperand(Num: 1);
26769 unsigned Opcode = N->getOpcode();
26770 SDNodeFlags Flags = N->getFlags();
26771
26772 // Move unary shuffles with identical masks after a vector binop:
26773 // VBinOp (shuffle A, Undef, Mask), (shuffle B, Undef, Mask))
26774 // --> shuffle (VBinOp A, B), Undef, Mask
26775 // This does not require type legality checks because we are creating the
26776 // same types of operations that are in the original sequence. We do have to
26777 // restrict ops like integer div that have immediate UB (eg, div-by-zero)
26778 // though. This code is adapted from the identical transform in instcombine.
26779 if (DAG.isSafeToSpeculativelyExecute(Opcode)) {
26780 auto *Shuf0 = dyn_cast<ShuffleVectorSDNode>(Val&: LHS);
26781 auto *Shuf1 = dyn_cast<ShuffleVectorSDNode>(Val&: RHS);
26782 if (Shuf0 && Shuf1 && Shuf0->getMask().equals(RHS: Shuf1->getMask()) &&
26783 LHS.getOperand(i: 1).isUndef() && RHS.getOperand(i: 1).isUndef() &&
26784 (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) {
26785 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: LHS.getOperand(i: 0),
26786 N2: RHS.getOperand(i: 0), Flags);
26787 SDValue UndefV = LHS.getOperand(i: 1);
26788 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: UndefV, Mask: Shuf0->getMask());
26789 }
26790
26791 // Try to sink a splat shuffle after a binop with a uniform constant.
26792 // This is limited to cases where neither the shuffle nor the constant have
26793 // undefined elements because that could be poison-unsafe or inhibit
26794 // demanded elements analysis. It is further limited to not change a splat
26795 // of an inserted scalar because that may be optimized better by
26796 // load-folding or other target-specific behaviors.
26797 if (isConstOrConstSplat(N: RHS) && Shuf0 && all_equal(Range: Shuf0->getMask()) &&
26798 Shuf0->hasOneUse() && Shuf0->getOperand(Num: 1).isUndef() &&
26799 Shuf0->getOperand(Num: 0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
26800 // binop (splat X), (splat C) --> splat (binop X, C)
26801 SDValue X = Shuf0->getOperand(Num: 0);
26802 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: X, N2: RHS, Flags);
26803 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: DAG.getUNDEF(VT),
26804 Mask: Shuf0->getMask());
26805 }
26806 if (isConstOrConstSplat(N: LHS) && Shuf1 && all_equal(Range: Shuf1->getMask()) &&
26807 Shuf1->hasOneUse() && Shuf1->getOperand(Num: 1).isUndef() &&
26808 Shuf1->getOperand(Num: 0).getOpcode() != ISD::INSERT_VECTOR_ELT) {
26809 // binop (splat C), (splat X) --> splat (binop C, X)
26810 SDValue X = Shuf1->getOperand(Num: 0);
26811 SDValue NewBinOp = DAG.getNode(Opcode, DL, VT, N1: LHS, N2: X, Flags);
26812 return DAG.getVectorShuffle(VT, dl: DL, N1: NewBinOp, N2: DAG.getUNDEF(VT),
26813 Mask: Shuf1->getMask());
26814 }
26815 }
26816
26817 // The following pattern is likely to emerge with vector reduction ops. Moving
26818 // the binary operation ahead of insertion may allow using a narrower vector
26819 // instruction that has better performance than the wide version of the op:
26820 // VBinOp (ins undef, X, Z), (ins undef, Y, Z) --> ins VecC, (VBinOp X, Y), Z
26821 if (LHS.getOpcode() == ISD::INSERT_SUBVECTOR && LHS.getOperand(i: 0).isUndef() &&
26822 RHS.getOpcode() == ISD::INSERT_SUBVECTOR && RHS.getOperand(i: 0).isUndef() &&
26823 LHS.getOperand(i: 2) == RHS.getOperand(i: 2) &&
26824 (LHS.hasOneUse() || RHS.hasOneUse())) {
26825 SDValue X = LHS.getOperand(i: 1);
26826 SDValue Y = RHS.getOperand(i: 1);
26827 SDValue Z = LHS.getOperand(i: 2);
26828 EVT NarrowVT = X.getValueType();
26829 if (NarrowVT == Y.getValueType() &&
26830 TLI.isOperationLegalOrCustomOrPromote(Op: Opcode, VT: NarrowVT,
26831 LegalOnly: LegalOperations)) {
26832 // (binop undef, undef) may not return undef, so compute that result.
26833 SDValue VecC =
26834 DAG.getNode(Opcode, DL, VT, N1: DAG.getUNDEF(VT), N2: DAG.getUNDEF(VT));
26835 SDValue NarrowBO = DAG.getNode(Opcode, DL, VT: NarrowVT, N1: X, N2: Y);
26836 return DAG.getNode(Opcode: ISD::INSERT_SUBVECTOR, DL, VT, N1: VecC, N2: NarrowBO, N3: Z);
26837 }
26838 }
26839
26840 // Make sure all but the first op are undef or constant.
26841 auto ConcatWithConstantOrUndef = [](SDValue Concat) {
26842 return Concat.getOpcode() == ISD::CONCAT_VECTORS &&
26843 all_of(Range: drop_begin(RangeOrContainer: Concat->ops()), P: [](const SDValue &Op) {
26844 return Op.isUndef() ||
26845 ISD::isBuildVectorOfConstantSDNodes(N: Op.getNode());
26846 });
26847 };
26848
26849 // The following pattern is likely to emerge with vector reduction ops. Moving
26850 // the binary operation ahead of the concat may allow using a narrower vector
26851 // instruction that has better performance than the wide version of the op:
26852 // VBinOp (concat X, undef/constant), (concat Y, undef/constant) -->
26853 // concat (VBinOp X, Y), VecC
26854 if (ConcatWithConstantOrUndef(LHS) && ConcatWithConstantOrUndef(RHS) &&
26855 (LHS.hasOneUse() || RHS.hasOneUse())) {
26856 EVT NarrowVT = LHS.getOperand(i: 0).getValueType();
26857 if (NarrowVT == RHS.getOperand(i: 0).getValueType() &&
26858 TLI.isOperationLegalOrCustomOrPromote(Op: Opcode, VT: NarrowVT)) {
26859 unsigned NumOperands = LHS.getNumOperands();
26860 SmallVector<SDValue, 4> ConcatOps;
26861 for (unsigned i = 0; i != NumOperands; ++i) {
        // For operands 1 and up, this constant folds.
26863 ConcatOps.push_back(Elt: DAG.getNode(Opcode, DL, VT: NarrowVT, N1: LHS.getOperand(i),
26864 N2: RHS.getOperand(i)));
26865 }
26866
26867 return DAG.getNode(Opcode: ISD::CONCAT_VECTORS, DL, VT, Ops: ConcatOps);
26868 }
26869 }
26870
26871 if (SDValue V = scalarizeBinOpOfSplats(N, DAG, DL))
26872 return V;
26873
26874 return SDValue();
26875}
26876
26877SDValue DAGCombiner::SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1,
26878 SDValue N2) {
26879 assert(N0.getOpcode() == ISD::SETCC &&
26880 "First argument must be a SetCC node!");
26881
26882 SDValue SCC = SimplifySelectCC(DL, N0: N0.getOperand(i: 0), N1: N0.getOperand(i: 1), N2: N1, N3: N2,
26883 CC: cast<CondCodeSDNode>(Val: N0.getOperand(i: 2))->get());
26884
26885 // If we got a simplified select_cc node back from SimplifySelectCC, then
26886 // break it down into a new SETCC node, and a new SELECT node, and then return
26887 // the SELECT node, since we were called with a SELECT node.
26888 if (SCC.getNode()) {
26889 // Check to see if we got a select_cc back (to turn into setcc/select).
26890 // Otherwise, just return whatever node we got back, like fabs.
26891 if (SCC.getOpcode() == ISD::SELECT_CC) {
26892 const SDNodeFlags Flags = N0->getFlags();
26893 SDValue SETCC = DAG.getNode(Opcode: ISD::SETCC, DL: SDLoc(N0),
26894 VT: N0.getValueType(),
26895 N1: SCC.getOperand(i: 0), N2: SCC.getOperand(i: 1),
26896 N3: SCC.getOperand(i: 4), Flags);
26897 AddToWorklist(N: SETCC.getNode());
26898 SDValue SelectNode = DAG.getSelect(DL: SDLoc(SCC), VT: SCC.getValueType(), Cond: SETCC,
26899 LHS: SCC.getOperand(i: 2), RHS: SCC.getOperand(i: 3));
26900 SelectNode->setFlags(Flags);
26901 return SelectNode;
26902 }
26903
26904 return SCC;
26905 }
26906 return SDValue();
26907}
26908
26909/// Given a SELECT or a SELECT_CC node, where LHS and RHS are the two values
26910/// being selected between, see if we can simplify the select. Callers of this
26911/// should assume that TheSelect is deleted if this returns true. As such, they
26912/// should return the appropriate thing (e.g. the node) back to the top-level of
26913/// the DAG combiner loop to avoid it being looked at.
26914bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS,
26915 SDValue RHS) {
26916 // fold (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
26917 // The select + setcc is redundant, because fsqrt returns NaN for X < 0.
26918 if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(N: LHS)) {
26919 if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) {
26920 // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?))
26921 SDValue Sqrt = RHS;
26922 ISD::CondCode CC;
26923 SDValue CmpLHS;
26924 const ConstantFPSDNode *Zero = nullptr;
26925
26926 if (TheSelect->getOpcode() == ISD::SELECT_CC) {
26927 CC = cast<CondCodeSDNode>(Val: TheSelect->getOperand(Num: 4))->get();
26928 CmpLHS = TheSelect->getOperand(Num: 0);
26929 Zero = isConstOrConstSplatFP(N: TheSelect->getOperand(Num: 1));
26930 } else {
26931 // SELECT or VSELECT
26932 SDValue Cmp = TheSelect->getOperand(Num: 0);
26933 if (Cmp.getOpcode() == ISD::SETCC) {
26934 CC = cast<CondCodeSDNode>(Val: Cmp.getOperand(i: 2))->get();
26935 CmpLHS = Cmp.getOperand(i: 0);
26936 Zero = isConstOrConstSplatFP(N: Cmp.getOperand(i: 1));
26937 }
26938 }
26939 if (Zero && Zero->isZero() &&
26940 Sqrt.getOperand(i: 0) == CmpLHS && (CC == ISD::SETOLT ||
26941 CC == ISD::SETULT || CC == ISD::SETLT)) {
26942 // We have: (select (setcc x, [+-]0.0, *lt), NaN, (fsqrt x))
26943 CombineTo(N: TheSelect, Res: Sqrt);
26944 return true;
26945 }
26946 }
26947 }
26948 // Cannot simplify select with vector condition
26949 if (TheSelect->getOperand(Num: 0).getValueType().isVector()) return false;
26950
26951 // If this is a select from two identical things, try to pull the operation
26952 // through the select.
26953 if (LHS.getOpcode() != RHS.getOpcode() ||
26954 !LHS.hasOneUse() || !RHS.hasOneUse())
26955 return false;
26956
26957 // If this is a load and the token chain is identical, replace the select
26958 // of two loads with a load through a select of the address to load from.
26959 // This triggers in things like "select bool X, 10.0, 123.0" after the FP
26960 // constants have been dropped into the constant pool.
26961 if (LHS.getOpcode() == ISD::LOAD) {
26962 LoadSDNode *LLD = cast<LoadSDNode>(Val&: LHS);
26963 LoadSDNode *RLD = cast<LoadSDNode>(Val&: RHS);
26964
26965 // Token chains must be identical.
26966 if (LHS.getOperand(i: 0) != RHS.getOperand(i: 0) ||
26967 // Do not let this transformation reduce the number of volatile loads.
26968 // Be conservative for atomics for the moment
26969 // TODO: This does appear to be legal for unordered atomics (see D66309)
26970 !LLD->isSimple() || !RLD->isSimple() ||
26971 // FIXME: If either is a pre/post inc/dec load,
26972 // we'd need to split out the address adjustment.
26973 LLD->isIndexed() || RLD->isIndexed() ||
26974 // If this is an EXTLOAD, the VT's must match.
26975 LLD->getMemoryVT() != RLD->getMemoryVT() ||
26976 // If this is an EXTLOAD, the kind of extension must match.
26977 (LLD->getExtensionType() != RLD->getExtensionType() &&
26978 // The only exception is if one of the extensions is anyext.
26979 LLD->getExtensionType() != ISD::EXTLOAD &&
26980 RLD->getExtensionType() != ISD::EXTLOAD) ||
26981 // FIXME: this discards src value information. This is
26982 // over-conservative. It would be beneficial to be able to remember
26983 // both potential memory locations. Since we are discarding
26984 // src value info, don't do the transformation if the memory
26985 // locations are not in the default address space.
26986 LLD->getPointerInfo().getAddrSpace() != 0 ||
26987 RLD->getPointerInfo().getAddrSpace() != 0 ||
26988 // We can't produce a CMOV of a TargetFrameIndex since we won't
26989 // generate the address generation required.
26990 LLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
26991 RLD->getBasePtr().getOpcode() == ISD::TargetFrameIndex ||
26992 !TLI.isOperationLegalOrCustom(Op: TheSelect->getOpcode(),
26993 VT: LLD->getBasePtr().getValueType()))
26994 return false;
26995
26996 // The loads must not depend on one another.
26997 if (LLD->isPredecessorOf(N: RLD) || RLD->isPredecessorOf(N: LLD))
26998 return false;
26999
27000 // Check that the select condition doesn't reach either load. If so,
27001 // folding this will induce a cycle into the DAG. If not, this is safe to
27002 // xform, so create a select of the addresses.
27003
27004 SmallPtrSet<const SDNode *, 32> Visited;
27005 SmallVector<const SDNode *, 16> Worklist;
27006
27007 // Always fail if LLD and RLD are not independent. TheSelect is a
27008 // predecessor to all Nodes in question so we need not search past it.
27009
27010 Visited.insert(Ptr: TheSelect);
27011 Worklist.push_back(Elt: LLD);
27012 Worklist.push_back(Elt: RLD);
27013
27014 if (SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist) ||
27015 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist))
27016 return false;
27017
27018 SDValue Addr;
27019 if (TheSelect->getOpcode() == ISD::SELECT) {
27020 // We cannot do this optimization if any pair of {RLD, LLD} is a
27021 // predecessor to {RLD, LLD, CondNode}. As we've already compared the
27022 // Loads, we only need to check if CondNode is a successor to one of the
27023 // loads. We can further avoid this if there's no use of their chain
27024 // value.
27025 SDNode *CondNode = TheSelect->getOperand(Num: 0).getNode();
27026 Worklist.push_back(Elt: CondNode);
27027
27028 if ((LLD->hasAnyUseOfValue(Value: 1) &&
27029 SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist)) ||
27030 (RLD->hasAnyUseOfValue(Value: 1) &&
27031 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist)))
27032 return false;
27033
27034 Addr = DAG.getSelect(DL: SDLoc(TheSelect),
27035 VT: LLD->getBasePtr().getValueType(),
27036 Cond: TheSelect->getOperand(Num: 0), LHS: LLD->getBasePtr(),
27037 RHS: RLD->getBasePtr());
27038 } else { // Otherwise SELECT_CC
27039 // We cannot do this optimization if any pair of {RLD, LLD} is a
27040 // predecessor to {RLD, LLD, CondLHS, CondRHS}. As we've already compared
27041 // the Loads, we only need to check if CondLHS/CondRHS is a successor to
27042 // one of the loads. We can further avoid this if there's no use of their
27043 // chain value.
27044
27045 SDNode *CondLHS = TheSelect->getOperand(Num: 0).getNode();
27046 SDNode *CondRHS = TheSelect->getOperand(Num: 1).getNode();
27047 Worklist.push_back(Elt: CondLHS);
27048 Worklist.push_back(Elt: CondRHS);
27049
27050 if ((LLD->hasAnyUseOfValue(Value: 1) &&
27051 SDNode::hasPredecessorHelper(N: LLD, Visited, Worklist)) ||
27052 (RLD->hasAnyUseOfValue(Value: 1) &&
27053 SDNode::hasPredecessorHelper(N: RLD, Visited, Worklist)))
27054 return false;
27055
27056 Addr = DAG.getNode(Opcode: ISD::SELECT_CC, DL: SDLoc(TheSelect),
27057 VT: LLD->getBasePtr().getValueType(),
27058 N1: TheSelect->getOperand(Num: 0),
27059 N2: TheSelect->getOperand(Num: 1),
27060 N3: LLD->getBasePtr(), N4: RLD->getBasePtr(),
27061 N5: TheSelect->getOperand(Num: 4));
27062 }
27063
27064 SDValue Load;
27065 // It is safe to replace the two loads if they have different alignments,
27066 // but the new load must be the minimum (most restrictive) alignment of the
27067 // inputs.
27068 Align Alignment = std::min(a: LLD->getAlign(), b: RLD->getAlign());
27069 MachineMemOperand::Flags MMOFlags = LLD->getMemOperand()->getFlags();
27070 if (!RLD->isInvariant())
27071 MMOFlags &= ~MachineMemOperand::MOInvariant;
27072 if (!RLD->isDereferenceable())
27073 MMOFlags &= ~MachineMemOperand::MODereferenceable;
27074 if (LLD->getExtensionType() == ISD::NON_EXTLOAD) {
27075 // FIXME: Discards pointer and AA info.
27076 Load = DAG.getLoad(VT: TheSelect->getValueType(ResNo: 0), dl: SDLoc(TheSelect),
27077 Chain: LLD->getChain(), Ptr: Addr, PtrInfo: MachinePointerInfo(), Alignment,
27078 MMOFlags);
27079 } else {
27080 // FIXME: Discards pointer and AA info.
27081 Load = DAG.getExtLoad(
27082 ExtType: LLD->getExtensionType() == ISD::EXTLOAD ? RLD->getExtensionType()
27083 : LLD->getExtensionType(),
27084 dl: SDLoc(TheSelect), VT: TheSelect->getValueType(ResNo: 0), Chain: LLD->getChain(), Ptr: Addr,
27085 PtrInfo: MachinePointerInfo(), MemVT: LLD->getMemoryVT(), Alignment, MMOFlags);
27086 }
27087
27088 // Users of the select now use the result of the load.
27089 CombineTo(N: TheSelect, Res: Load);
27090
27091 // Users of the old loads now use the new load's chain. We know the
27092 // old-load value is dead now.
27093 CombineTo(N: LHS.getNode(), Res0: Load.getValue(R: 0), Res1: Load.getValue(R: 1));
27094 CombineTo(N: RHS.getNode(), Res0: Load.getValue(R: 0), Res1: Load.getValue(R: 1));
27095 return true;
27096 }
27097
27098 return false;
27099}
27100
27101/// Try to fold an expression of the form (N0 cond N1) ? N2 : N3 to a shift and
27102/// bitwise 'and'.
27103SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
27104 SDValue N1, SDValue N2, SDValue N3,
27105 ISD::CondCode CC) {
27106 // If this is a select where the false operand is zero and the compare is a
27107 // check of the sign bit, see if we can perform the "gzip trick":
27108 // select_cc setlt X, 0, A, 0 -> and (sra X, size(X)-1), A
27109 // select_cc setgt X, 0, A, 0 -> and (not (sra X, size(X)-1)), A
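  // The trick works because (sra X, size(X)-1) is 0 when X is non-negative and
  // all-ones when X is negative, so the AND yields A exactly when the sign bit
  // of X is set (the NOT-variant handles the setgt form).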
27110 EVT XType = N0.getValueType();
27111 EVT AType = N2.getValueType();
27112 if (!isNullConstant(V: N3) || !XType.bitsGE(VT: AType))
27113 return SDValue();
27114
27115 // If the comparison is testing for a positive value, we have to invert
27116 // the sign bit mask, so only do that transform if the target has a bitwise
27117 // 'and not' instruction (the invert is free).
27118 if (CC == ISD::SETGT && TLI.hasAndNot(X: N2)) {
27119 // (X > -1) ? A : 0
27120 // (X > 0) ? X : 0 <-- This is canonical signed max.
27121 if (!(isAllOnesConstant(V: N1) || (isNullConstant(V: N1) && N0 == N2)))
27122 return SDValue();
27123 } else if (CC == ISD::SETLT) {
27124 // (X < 0) ? A : 0
27125 // (X < 1) ? X : 0 <-- This is un-canonicalized signed min.
27126 if (!(isNullConstant(V: N1) || (isOneConstant(V: N1) && N0 == N2)))
27127 return SDValue();
27128 } else {
27129 return SDValue();
27130 }
27131
27132 // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
27133 // constant.
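  // Illustrative example: for i32 with A = 4 (only bit 2 set), ShCt becomes
  // 32 - 2 - 1 = 29, so the fold emits (and (srl X, 29), 4), which moves the
  // sign bit of X into bit position 2 before masking.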
27134 EVT ShiftAmtTy = getShiftAmountTy(LHSTy: N0.getValueType());
27135 auto *N2C = dyn_cast<ConstantSDNode>(Val: N2.getNode());
27136 if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
27137 unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
27138 if (!TLI.shouldAvoidTransformToShift(VT: XType, Amount: ShCt)) {
27139 SDValue ShiftAmt = DAG.getConstant(Val: ShCt, DL, VT: ShiftAmtTy);
27140 SDValue Shift = DAG.getNode(Opcode: ISD::SRL, DL, VT: XType, N1: N0, N2: ShiftAmt);
27141 AddToWorklist(N: Shift.getNode());
27142
27143 if (XType.bitsGT(VT: AType)) {
27144 Shift = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: AType, Operand: Shift);
27145 AddToWorklist(N: Shift.getNode());
27146 }
27147
27148 if (CC == ISD::SETGT)
27149 Shift = DAG.getNOT(DL, Val: Shift, VT: AType);
27150
27151 return DAG.getNode(Opcode: ISD::AND, DL, VT: AType, N1: Shift, N2);
27152 }
27153 }
27154
27155 unsigned ShCt = XType.getSizeInBits() - 1;
27156 if (TLI.shouldAvoidTransformToShift(VT: XType, Amount: ShCt))
27157 return SDValue();
27158
27159 SDValue ShiftAmt = DAG.getConstant(Val: ShCt, DL, VT: ShiftAmtTy);
27160 SDValue Shift = DAG.getNode(Opcode: ISD::SRA, DL, VT: XType, N1: N0, N2: ShiftAmt);
27161 AddToWorklist(N: Shift.getNode());
27162
27163 if (XType.bitsGT(VT: AType)) {
27164 Shift = DAG.getNode(Opcode: ISD::TRUNCATE, DL, VT: AType, Operand: Shift);
27165 AddToWorklist(N: Shift.getNode());
27166 }
27167
27168 if (CC == ISD::SETGT)
27169 Shift = DAG.getNOT(DL, Val: Shift, VT: AType);
27170
27171 return DAG.getNode(Opcode: ISD::AND, DL, VT: AType, N1: Shift, N2);
27172}
27173
27174// Fold select(cc, binop(), binop()) -> binop(select(), select()) etc.
27175SDValue DAGCombiner::foldSelectOfBinops(SDNode *N) {
27176 SDValue N0 = N->getOperand(Num: 0);
27177 SDValue N1 = N->getOperand(Num: 1);
27178 SDValue N2 = N->getOperand(Num: 2);
27179 SDLoc DL(N);
27180
27181 unsigned BinOpc = N1.getOpcode();
27182 if (!TLI.isBinOp(Opcode: BinOpc) || (N2.getOpcode() != BinOpc) ||
27183 (N1.getResNo() != N2.getResNo()))
27184 return SDValue();
27185
27186 // The use checks are intentionally on SDNode because we may be dealing
27187 // with opcodes that produce more than one SDValue.
27188 // TODO: Do we really need to check N0 (the condition operand of the select)?
27189 // But removing that clause could cause an infinite loop...
27190 if (!N0->hasOneUse() || !N1->hasOneUse() || !N2->hasOneUse())
27191 return SDValue();
27192
27193 // Binops may include opcodes that return multiple values, so all values
27194 // must be created/propagated from the newly created binops below.
27195 SDVTList OpVTs = N1->getVTList();
27196
27197 // Fold select(cond, binop(x, y), binop(z, y))
27198 // --> binop(select(cond, x, z), y)
27199 if (N1.getOperand(i: 1) == N2.getOperand(i: 1)) {
27200 SDValue N10 = N1.getOperand(i: 0);
27201 SDValue N20 = N2.getOperand(i: 0);
27202 SDValue NewSel = DAG.getSelect(DL, VT: N10.getValueType(), Cond: N0, LHS: N10, RHS: N20);
27203 SDValue NewBinOp = DAG.getNode(Opcode: BinOpc, DL, VTList: OpVTs, N1: NewSel, N2: N1.getOperand(i: 1));
27204 NewBinOp->setFlags(N1->getFlags());
27205 NewBinOp->intersectFlagsWith(Flags: N2->getFlags());
27206 return SDValue(NewBinOp.getNode(), N1.getResNo());
27207 }
27208
27209 // Fold select(cond, binop(x, y), binop(x, z))
27210 // --> binop(x, select(cond, y, z))
27211 if (N1.getOperand(i: 0) == N2.getOperand(i: 0)) {
27212 SDValue N11 = N1.getOperand(i: 1);
27213 SDValue N21 = N2.getOperand(i: 1);
27214 // Second op VT might be different (e.g. shift amount type)
27215 if (N11.getValueType() == N21.getValueType()) {
27216 SDValue NewSel = DAG.getSelect(DL, VT: N11.getValueType(), Cond: N0, LHS: N11, RHS: N21);
27217 SDValue NewBinOp =
27218 DAG.getNode(Opcode: BinOpc, DL, VTList: OpVTs, N1: N1.getOperand(i: 0), N2: NewSel);
27219 NewBinOp->setFlags(N1->getFlags());
27220 NewBinOp->intersectFlagsWith(Flags: N2->getFlags());
27221 return SDValue(NewBinOp.getNode(), N1.getResNo());
27222 }
27223 }
27224
27225 // TODO: Handle isCommutativeBinOp patterns as well?
27226 return SDValue();
27227}
27228
27229// Transform (fneg/fabs (bitconvert x)) to avoid loading constant pool values.
27230SDValue DAGCombiner::foldSignChangeInBitcast(SDNode *N) {
27231 SDValue N0 = N->getOperand(Num: 0);
27232 EVT VT = N->getValueType(ResNo: 0);
27233 bool IsFabs = N->getOpcode() == ISD::FABS;
27234 bool IsFree = IsFabs ? TLI.isFAbsFree(VT) : TLI.isFNegFree(VT);
27235
27236 if (IsFree || N0.getOpcode() != ISD::BITCAST || !N0.hasOneUse())
27237 return SDValue();
27238
27239 SDValue Int = N0.getOperand(i: 0);
27240 EVT IntVT = Int.getValueType();
27241
27242 // The operand to cast should be integer.
27243 if (!IntVT.isInteger() || IntVT.isVector())
27244 return SDValue();
27245
27246 // (fneg (bitconvert x)) -> (bitconvert (xor x sign))
27247 // (fabs (bitconvert x)) -> (bitconvert (and x ~sign))
27248 APInt SignMask;
27249 if (N0.getValueType().isVector()) {
27250 // For vector, create a sign mask (0x80...) or its inverse (for fabs,
27251 // 0x7f...) per element and splat it.
27252 SignMask = APInt::getSignMask(BitWidth: N0.getScalarValueSizeInBits());
27253 if (IsFabs)
27254 SignMask = ~SignMask;
27255 SignMask = APInt::getSplat(NewLen: IntVT.getSizeInBits(), V: SignMask);
27256 } else {
27257 // For scalar, just use the sign mask (0x80... or the inverse, 0x7f...)
27258 SignMask = APInt::getSignMask(BitWidth: IntVT.getSizeInBits());
27259 if (IsFabs)
27260 SignMask = ~SignMask;
27261 }
27262 SDLoc DL(N0);
27263 Int = DAG.getNode(Opcode: IsFabs ? ISD::AND : ISD::XOR, DL, VT: IntVT, N1: Int,
27264 N2: DAG.getConstant(Val: SignMask, DL, VT: IntVT));
27265 AddToWorklist(N: Int.getNode());
27266 return DAG.getBitcast(VT, V: Int);
27267}
27268
27269/// Turn "(a cond b) ? 1.0f : 2.0f" into "load (tmp + ((a cond b) ? 0 : 4))"
27270/// where "tmp" is a constant pool entry containing an array with 1.0 and 2.0
27271/// in it. This may be a win when the constant is not otherwise available
27272/// because it replaces two constant pool loads with one.
27273SDValue DAGCombiner::convertSelectOfFPConstantsToLoadOffset(
27274 const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2, SDValue N3,
27275 ISD::CondCode CC) {
27276 if (!TLI.reduceSelectOfFPConstantLoads(CmpOpVT: N0.getValueType()))
27277 return SDValue();
27278
27279 // If we are before legalize types, we want the other legalization to happen
27280 // first (for example, to avoid messing with soft float).
27281 auto *TV = dyn_cast<ConstantFPSDNode>(Val&: N2);
27282 auto *FV = dyn_cast<ConstantFPSDNode>(Val&: N3);
27283 EVT VT = N2.getValueType();
27284 if (!TV || !FV || !TLI.isTypeLegal(VT))
27285 return SDValue();
27286
27287 // If a constant can be materialized without loads, this does not make sense.
27288 if (TLI.getOperationAction(Op: ISD::ConstantFP, VT) == TargetLowering::Legal ||
27289 TLI.isFPImmLegal(TV->getValueAPF(), TV->getValueType(ResNo: 0), ForCodeSize) ||
27290 TLI.isFPImmLegal(FV->getValueAPF(), FV->getValueType(ResNo: 0), ForCodeSize))
27291 return SDValue();
27292
27293 // If both constants have multiple uses, then we won't need to do an extra
27294 // load. The values are likely around in registers for other users.
27295 if (!TV->hasOneUse() && !FV->hasOneUse())
27296 return SDValue();
27297
27298 Constant *Elts[] = { const_cast<ConstantFP*>(FV->getConstantFPValue()),
27299 const_cast<ConstantFP*>(TV->getConstantFPValue()) };
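  // Note the ordering: FV is element 0 and TV is element 1, so a true
  // condition selects the nonzero offset computed below and loads TV.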
27300 Type *FPTy = Elts[0]->getType();
27301 const DataLayout &TD = DAG.getDataLayout();
27302
27303 // Create a ConstantArray of the two constants.
27304 Constant *CA = ConstantArray::get(T: ArrayType::get(ElementType: FPTy, NumElements: 2), V: Elts);
27305 SDValue CPIdx = DAG.getConstantPool(C: CA, VT: TLI.getPointerTy(DL: DAG.getDataLayout()),
27306 Align: TD.getPrefTypeAlign(Ty: FPTy));
27307 Align Alignment = cast<ConstantPoolSDNode>(Val&: CPIdx)->getAlign();
27308
27309 // Get offsets to the 0 and 1 elements of the array, so we can select between
27310 // them.
27311 SDValue Zero = DAG.getIntPtrConstant(Val: 0, DL);
27312 unsigned EltSize = (unsigned)TD.getTypeAllocSize(Ty: Elts[0]->getType());
27313 SDValue One = DAG.getIntPtrConstant(Val: EltSize, DL: SDLoc(FV));
27314 SDValue Cond =
27315 DAG.getSetCC(DL, VT: getSetCCResultType(VT: N0.getValueType()), LHS: N0, RHS: N1, Cond: CC);
27316 AddToWorklist(N: Cond.getNode());
27317 SDValue CstOffset = DAG.getSelect(DL, VT: Zero.getValueType(), Cond, LHS: One, RHS: Zero);
27318 AddToWorklist(N: CstOffset.getNode());
27319 CPIdx = DAG.getNode(Opcode: ISD::ADD, DL, VT: CPIdx.getValueType(), N1: CPIdx, N2: CstOffset);
27320 AddToWorklist(N: CPIdx.getNode());
27321 return DAG.getLoad(VT: TV->getValueType(ResNo: 0), dl: DL, Chain: DAG.getEntryNode(), Ptr: CPIdx,
27322 PtrInfo: MachinePointerInfo::getConstantPool(
27323 MF&: DAG.getMachineFunction()), Alignment);
27324}
27325
27326/// Simplify an expression of the form (N0 cond N1) ? N2 : N3
27327/// where 'cond' is the comparison specified by CC.
27328SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
27329 SDValue N2, SDValue N3, ISD::CondCode CC,
27330 bool NotExtCompare) {
27331 // (x ? y : y) -> y.
27332 if (N2 == N3) return N2;
27333
27334 EVT CmpOpVT = N0.getValueType();
27335 EVT CmpResVT = getSetCCResultType(VT: CmpOpVT);
27336 EVT VT = N2.getValueType();
27337 auto *N1C = dyn_cast<ConstantSDNode>(Val: N1.getNode());
27338 auto *N2C = dyn_cast<ConstantSDNode>(Val: N2.getNode());
27339 auto *N3C = dyn_cast<ConstantSDNode>(Val: N3.getNode());
27340
27341 // Determine if the condition we're dealing with is constant.
27342 if (SDValue SCC = DAG.FoldSetCC(VT: CmpResVT, N1: N0, N2: N1, Cond: CC, dl: DL)) {
27343 AddToWorklist(N: SCC.getNode());
27344 if (auto *SCCC = dyn_cast<ConstantSDNode>(Val&: SCC)) {
27345 // fold select_cc true, x, y -> x
27346 // fold select_cc false, x, y -> y
27347 return !(SCCC->isZero()) ? N2 : N3;
27348 }
27349 }
27350
27351 if (SDValue V =
27352 convertSelectOfFPConstantsToLoadOffset(DL, N0, N1, N2, N3, CC))
27353 return V;
27354
27355 if (SDValue V = foldSelectCCToShiftAnd(DL, N0, N1, N2, N3, CC))
27356 return V;
27357
27358 // fold (select_cc seteq (and x, y), 0, 0, A) -> (and (sra (shl x)) A)
27359  // where y has a single bit set.
27360 // A plaintext description would be, we can turn the SELECT_CC into an AND
27361 // when the condition can be materialized as an all-ones register. Any
27362 // single bit-test can be materialized as an all-ones register with
27363 // shift-left and shift-right-arith.
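  // Worked example: for i32 with y = 0x10, the shift-left amount is
  // countl_zero(0x10) = 27, which moves bit 4 into the sign bit; the
  // arithmetic shift right by 31 then yields all-ones exactly when
  // (x & 0x10) != 0, so the final AND produces A in that case and 0 otherwise.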
27364 if (CC == ISD::SETEQ && N0->getOpcode() == ISD::AND &&
27365 N0->getValueType(ResNo: 0) == VT && isNullConstant(V: N1) && isNullConstant(V: N2)) {
27366 SDValue AndLHS = N0->getOperand(Num: 0);
27367 auto *ConstAndRHS = dyn_cast<ConstantSDNode>(Val: N0->getOperand(Num: 1));
27368 if (ConstAndRHS && ConstAndRHS->getAPIntValue().popcount() == 1) {
27369 // Shift the tested bit over the sign bit.
27370 const APInt &AndMask = ConstAndRHS->getAPIntValue();
27371 if (TLI.shouldFoldSelectWithSingleBitTest(VT, AndMask)) {
27372 unsigned ShCt = AndMask.getBitWidth() - 1;
27373 SDValue ShlAmt =
27374 DAG.getConstant(Val: AndMask.countl_zero(), DL: SDLoc(AndLHS),
27375 VT: getShiftAmountTy(LHSTy: AndLHS.getValueType()));
27376 SDValue Shl = DAG.getNode(Opcode: ISD::SHL, DL: SDLoc(N0), VT, N1: AndLHS, N2: ShlAmt);
27377
27378 // Now arithmetic right shift it all the way over, so the result is
27379 // either all-ones, or zero.
27380 SDValue ShrAmt =
27381 DAG.getConstant(Val: ShCt, DL: SDLoc(Shl),
27382 VT: getShiftAmountTy(LHSTy: Shl.getValueType()));
27383 SDValue Shr = DAG.getNode(Opcode: ISD::SRA, DL: SDLoc(N0), VT, N1: Shl, N2: ShrAmt);
27384
27385 return DAG.getNode(Opcode: ISD::AND, DL, VT, N1: Shr, N2: N3);
27386 }
27387 }
27388 }
27389
27390 // fold select C, 16, 0 -> shl C, 4
27391 bool Fold = N2C && isNullConstant(V: N3) && N2C->getAPIntValue().isPowerOf2();
27392 bool Swap = N3C && isNullConstant(V: N2) && N3C->getAPIntValue().isPowerOf2();
27393
27394 if ((Fold || Swap) &&
27395 TLI.getBooleanContents(Type: CmpOpVT) ==
27396 TargetLowering::ZeroOrOneBooleanContent &&
27397 (!LegalOperations || TLI.isOperationLegal(Op: ISD::SETCC, VT: CmpOpVT))) {
27398
27399 if (Swap) {
27400 CC = ISD::getSetCCInverse(Operation: CC, Type: CmpOpVT);
27401 std::swap(a&: N2C, b&: N3C);
27402 }
27403
27404 // If the caller doesn't want us to simplify this into a zext of a compare,
27405 // don't do it.
27406 if (NotExtCompare && N2C->isOne())
27407 return SDValue();
27408
27409 SDValue Temp, SCC;
27410 // zext (setcc n0, n1)
27411 if (LegalTypes) {
27412 SCC = DAG.getSetCC(DL, VT: CmpResVT, LHS: N0, RHS: N1, Cond: CC);
27413 Temp = DAG.getZExtOrTrunc(Op: SCC, DL: SDLoc(N2), VT);
27414 } else {
27415 SCC = DAG.getSetCC(SDLoc(N0), MVT::i1, N0, N1, CC);
27416 Temp = DAG.getNode(Opcode: ISD::ZERO_EXTEND, DL: SDLoc(N2), VT, Operand: SCC);
27417 }
27418
27419 AddToWorklist(N: SCC.getNode());
27420 AddToWorklist(N: Temp.getNode());
27421
27422 if (N2C->isOne())
27423 return Temp;
27424
27425 unsigned ShCt = N2C->getAPIntValue().logBase2();
27426 if (TLI.shouldAvoidTransformToShift(VT, Amount: ShCt))
27427 return SDValue();
27428
27429 // shl setcc result by log2 n2c
27430 return DAG.getNode(Opcode: ISD::SHL, DL, VT: N2.getValueType(), N1: Temp,
27431 N2: DAG.getConstant(Val: ShCt, DL: SDLoc(Temp),
27432 VT: getShiftAmountTy(LHSTy: Temp.getValueType())));
27433 }
27434
27435 // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
27436 // select_cc seteq X, 0, sizeof(X), ctlz_zero_undef(X) -> ctlz(X)
27437 // select_cc seteq X, 0, sizeof(X), cttz(X) -> cttz(X)
27438 // select_cc seteq X, 0, sizeof(X), cttz_zero_undef(X) -> cttz(X)
27439 // select_cc setne X, 0, ctlz(X), sizeof(X) -> ctlz(X)
27440 // select_cc setne X, 0, ctlz_zero_undef(X), sizeof(X) -> ctlz(X)
27441 // select_cc setne X, 0, cttz(X), sizeof(X) -> cttz(X)
27442 // select_cc setne X, 0, cttz_zero_undef(X), sizeof(X) -> cttz(X)
27443 if (N1C && N1C->isZero() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
27444 SDValue ValueOnZero = N2;
27445 SDValue Count = N3;
27446 // If the condition is NE instead of E, swap the operands.
27447 if (CC == ISD::SETNE)
27448 std::swap(a&: ValueOnZero, b&: Count);
27449 // Check if the value on zero is a constant equal to the bits in the type.
27450 if (auto *ValueOnZeroC = dyn_cast<ConstantSDNode>(Val&: ValueOnZero)) {
27451 if (ValueOnZeroC->getAPIntValue() == VT.getSizeInBits()) {
27452 // If the other operand is cttz/cttz_zero_undef of N0, and cttz is
27453 // legal, combine to just cttz.
27454 if ((Count.getOpcode() == ISD::CTTZ ||
27455 Count.getOpcode() == ISD::CTTZ_ZERO_UNDEF) &&
27456 N0 == Count.getOperand(i: 0) &&
27457 (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTTZ, VT)))
27458 return DAG.getNode(Opcode: ISD::CTTZ, DL, VT, Operand: N0);
27459 // If the other operand is ctlz/ctlz_zero_undef of N0, and ctlz is
27460 // legal, combine to just ctlz.
27461 if ((Count.getOpcode() == ISD::CTLZ ||
27462 Count.getOpcode() == ISD::CTLZ_ZERO_UNDEF) &&
27463 N0 == Count.getOperand(i: 0) &&
27464 (!LegalOperations || TLI.isOperationLegal(Op: ISD::CTLZ, VT)))
27465 return DAG.getNode(Opcode: ISD::CTLZ, DL, VT, Operand: N0);
27466 }
27467 }
27468 }
27469
27470 // Fold select_cc setgt X, -1, C, ~C -> xor (ashr X, BW-1), C
27471 // Fold select_cc setlt X, 0, C, ~C -> xor (ashr X, BW-1), ~C
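  // This works because (ashr X, BW-1) is 0 for non-negative X and all-ones for
  // negative X, so the xor either leaves the constant unchanged or flips it to
  // its complement, matching the two select arms.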
27472 if (!NotExtCompare && N1C && N2C && N3C &&
27473 N2C->getAPIntValue() == ~N3C->getAPIntValue() &&
27474 ((N1C->isAllOnes() && CC == ISD::SETGT) ||
27475 (N1C->isZero() && CC == ISD::SETLT)) &&
27476 !TLI.shouldAvoidTransformToShift(VT, Amount: CmpOpVT.getScalarSizeInBits() - 1)) {
27477 SDValue ASR = DAG.getNode(
27478 Opcode: ISD::SRA, DL, VT: CmpOpVT, N1: N0,
27479 N2: DAG.getConstant(Val: CmpOpVT.getScalarSizeInBits() - 1, DL, VT: CmpOpVT));
27480 return DAG.getNode(Opcode: ISD::XOR, DL, VT, N1: DAG.getSExtOrTrunc(Op: ASR, DL, VT),
27481 N2: DAG.getSExtOrTrunc(Op: CC == ISD::SETLT ? N3 : N2, DL, VT));
27482 }
27483
27484 if (SDValue S = PerformMinMaxFpToSatCombine(N0, N1, N2, N3, CC, DAG))
27485 return S;
27486 if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N2, N3, CC, DAG))
27487 return S;
27488
27489 return SDValue();
27490}
27491
27492/// This is a stub for TargetLowering::SimplifySetCC.
27493SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
27494 ISD::CondCode Cond, const SDLoc &DL,
27495 bool foldBooleans) {
27496 TargetLowering::DAGCombinerInfo
27497 DagCombineInfo(DAG, Level, false, this);
27498 return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DCI&: DagCombineInfo, dl: DL);
27499}
27500
27501/// Given an ISD::SDIV node expressing a divide by constant, return
27502/// a DAG expression to select that will generate the same value by multiplying
27503/// by a magic number.
27504/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
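/// As a rough sketch of the usual shape (following Hacker's Delight; the exact
/// sequence and constants are chosen by TargetLowering::BuildSDIV), a 32-bit
/// signed divide by 7 becomes roughly:
///   q = mulhs(x, M);             // high half of the multiply by the magic M
///   q = q + x;                   // correction needed because M is negative
///   q = q >> 2;                  // arithmetic shift by the magic shift amount
///   q = q + ((unsigned)x >> 31); // round toward zero for negative x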
27505SDValue DAGCombiner::BuildSDIV(SDNode *N) {
27506  // When optimising for minimum size, we don't want to expand a div to a mul
27507 // and a shift.
27508 if (DAG.getMachineFunction().getFunction().hasMinSize())
27509 return SDValue();
27510
27511 SmallVector<SDNode *, 8> Built;
27512 if (SDValue S = TLI.BuildSDIV(N, DAG, IsAfterLegalization: LegalOperations, Created&: Built)) {
27513 for (SDNode *N : Built)
27514 AddToWorklist(N);
27515 return S;
27516 }
27517
27518 return SDValue();
27519}
27520
27521/// Given an ISD::SDIV node expressing a divide by constant power of 2, return a
27522/// DAG expression that will generate the same value by right shifting.
27523SDValue DAGCombiner::BuildSDIVPow2(SDNode *N) {
27524 ConstantSDNode *C = isConstOrConstSplat(N: N->getOperand(Num: 1));
27525 if (!C)
27526 return SDValue();
27527
27528 // Avoid division by zero.
27529 if (C->isZero())
27530 return SDValue();
27531
27532 SmallVector<SDNode *, 8> Built;
27533 if (SDValue S = TLI.BuildSDIVPow2(N, Divisor: C->getAPIntValue(), DAG, Created&: Built)) {
27534 for (SDNode *N : Built)
27535 AddToWorklist(N);
27536 return S;
27537 }
27538
27539 return SDValue();
27540}
27541
27542/// Given an ISD::UDIV node expressing a divide by constant, return a DAG
27543/// expression that will generate the same value by multiplying by a magic
27544/// number.
27545/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
27546SDValue DAGCombiner::BuildUDIV(SDNode *N) {
27547  // When optimising for minimum size, we don't want to expand a div to a mul
27548 // and a shift.
27549 if (DAG.getMachineFunction().getFunction().hasMinSize())
27550 return SDValue();
27551
27552 SmallVector<SDNode *, 8> Built;
27553 if (SDValue S = TLI.BuildUDIV(N, DAG, IsAfterLegalization: LegalOperations, Created&: Built)) {
27554 for (SDNode *N : Built)
27555 AddToWorklist(N);
27556 return S;
27557 }
27558
27559 return SDValue();
27560}
27561
27562/// Given an ISD::SREM node expressing a remainder by constant power of 2,
27563/// return a DAG expression that will generate the same value.
27564SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
27565 ConstantSDNode *C = isConstOrConstSplat(N: N->getOperand(Num: 1));
27566 if (!C)
27567 return SDValue();
27568
27569 // Avoid division by zero.
27570 if (C->isZero())
27571 return SDValue();
27572
27573 SmallVector<SDNode *, 8> Built;
27574 if (SDValue S = TLI.BuildSREMPow2(N, Divisor: C->getAPIntValue(), DAG, Created&: Built)) {
27575 for (SDNode *N : Built)
27576 AddToWorklist(N);
27577 return S;
27578 }
27579
27580 return SDValue();
27581}
27582
27583// This is basically just a port of takeLog2 from InstCombineMulDivRem.cpp
27584//
27585// Returns the node that represents `Log2(Op)`. This may create a new node. If
27586// we are unable to compute `Log2(Op)`, it returns `SDValue()`.
27587//
27588// All nodes will be created at `DL` and the output will be of type `VT`.
27589//
27590// This will only return `Log2(Op)` if we can prove `Op` is non-zero. Set
27591// `AssumeNonZero` if this function should simply assume (rather than require
27592// proving) that `Op` is non-zero.
27593static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
27594 SDValue Op, unsigned Depth,
27595 bool AssumeNonZero) {
27596 assert(VT.isInteger() && "Only integer types are supported!");
27597
27598 auto PeekThroughCastsAndTrunc = [](SDValue V) {
27599 while (true) {
27600 switch (V.getOpcode()) {
27601 case ISD::TRUNCATE:
27602 case ISD::ZERO_EXTEND:
27603 V = V.getOperand(i: 0);
27604 break;
27605 default:
27606 return V;
27607 }
27608 }
27609 };
27610
27611 if (VT.isScalableVector())
27612 return SDValue();
27613
27614 Op = PeekThroughCastsAndTrunc(Op);
27615
27616 // Helper for determining whether a value is a power-2 constant scalar or a
27617 // vector of such elements.
27618 SmallVector<APInt> Pow2Constants;
27619 auto IsPowerOfTwo = [&Pow2Constants](ConstantSDNode *C) {
27620 if (C->isZero() || C->isOpaque())
27621 return false;
27622 // TODO: We may also be able to support negative powers of 2 here.
27623 if (C->getAPIntValue().isPowerOf2()) {
27624 Pow2Constants.emplace_back(Args: C->getAPIntValue());
27625 return true;
27626 }
27627 return false;
27628 };
27629
27630 if (ISD::matchUnaryPredicate(Op, Match: IsPowerOfTwo)) {
27631 if (!VT.isVector())
27632 return DAG.getConstant(Val: Pow2Constants.back().logBase2(), DL, VT);
27633 // We need to create a build vector
27634 SmallVector<SDValue> Log2Ops;
27635 for (const APInt &Pow2 : Pow2Constants)
27636 Log2Ops.emplace_back(
27637 Args: DAG.getConstant(Val: Pow2.logBase2(), DL, VT: VT.getScalarType()));
27638 return DAG.getBuildVector(VT, DL, Ops: Log2Ops);
27639 }
27640
27641 if (Depth >= DAG.MaxRecursionDepth)
27642 return SDValue();
27643
27644 auto CastToVT = [&](EVT NewVT, SDValue ToCast) {
27645 ToCast = PeekThroughCastsAndTrunc(ToCast);
27646 EVT CurVT = ToCast.getValueType();
27647 if (NewVT == CurVT)
27648 return ToCast;
27649
27650 if (NewVT.getSizeInBits() == CurVT.getSizeInBits())
27651 return DAG.getBitcast(VT: NewVT, V: ToCast);
27652
27653 return DAG.getZExtOrTrunc(Op: ToCast, DL, VT: NewVT);
27654 };
27655
27656 // log2(X << Y) -> log2(X) + Y
27657 if (Op.getOpcode() == ISD::SHL) {
27658 // 1 << Y and X nuw/nsw << Y are all non-zero.
27659 if (AssumeNonZero || Op->getFlags().hasNoUnsignedWrap() ||
27660 Op->getFlags().hasNoSignedWrap() || isOneConstant(V: Op.getOperand(i: 0)))
27661 if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 0),
27662 Depth: Depth + 1, AssumeNonZero))
27663 return DAG.getNode(Opcode: ISD::ADD, DL, VT, N1: LogX,
27664 N2: CastToVT(VT, Op.getOperand(i: 1)));
27665 }
27666
27667 // c ? X : Y -> c ? Log2(X) : Log2(Y)
27668 if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) &&
27669 Op.hasOneUse()) {
27670 if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 1),
27671 Depth: Depth + 1, AssumeNonZero))
27672 if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 2),
27673 Depth: Depth + 1, AssumeNonZero))
27674 return DAG.getSelect(DL, VT, Cond: Op.getOperand(i: 0), LHS: LogX, RHS: LogY);
27675 }
27676
27677 // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
27678 // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
27679 if ((Op.getOpcode() == ISD::UMIN || Op.getOpcode() == ISD::UMAX) &&
27680 Op.hasOneUse()) {
27681    // Use AssumeNonZero as false here. Otherwise we can hit a case where
27682    // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because of overflow).
27683 if (SDValue LogX =
27684 takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 0), Depth: Depth + 1,
27685 /*AssumeNonZero*/ false))
27686 if (SDValue LogY =
27687 takeInexpensiveLog2(DAG, DL, VT, Op: Op.getOperand(i: 1), Depth: Depth + 1,
27688 /*AssumeNonZero*/ false))
27689 return DAG.getNode(Opcode: Op.getOpcode(), DL, VT, N1: LogX, N2: LogY);
27690 }
27691
27692 return SDValue();
27693}
27694
27695/// Determines the LogBase2 value for a non-null input value using the
27696/// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
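/// For example, for the power-of-two i32 value V = 8: ctlz(8) = 28, so the
/// result is 31 - 28 = 3.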
27697SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL,
27698 bool KnownNonZero, bool InexpensiveOnly,
27699 std::optional<EVT> OutVT) {
27700 EVT VT = OutVT ? *OutVT : V.getValueType();
27701 SDValue InexpensiveLogBase2 =
27702 takeInexpensiveLog2(DAG, DL, VT, Op: V, /*Depth*/ 0, AssumeNonZero: KnownNonZero);
27703 if (InexpensiveLogBase2 || InexpensiveOnly || !DAG.isKnownToBeAPowerOfTwo(Val: V))
27704 return InexpensiveLogBase2;
27705
27706 SDValue Ctlz = DAG.getNode(Opcode: ISD::CTLZ, DL, VT, Operand: V);
27707 SDValue Base = DAG.getConstant(Val: VT.getScalarSizeInBits() - 1, DL, VT);
27708 SDValue LogBase2 = DAG.getNode(Opcode: ISD::SUB, DL, VT, N1: Base, N2: Ctlz);
27709 return LogBase2;
27710}
27711
27712/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
27713/// For the reciprocal, we need to find the zero of the function:
27714/// F(X) = 1/X - A [which has a zero at X = 1/A]
27715/// =>
27716/// X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
27717/// does not require additional intermediate precision]
27718/// For the last iteration, put numerator N into it to gain more precision:
27719/// Result = N X_i + X_i (N - N A X_i)
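/// As a concrete illustration of the convergence, with A = 3.0 and an initial
/// estimate X_0 = 0.3: X_1 = 0.3 * (2 - 3 * 0.3) = 0.33 and
/// X_2 = 0.33 * (2 - 3 * 0.33) = 0.3333, approaching 1/3 quadratically.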
27720SDValue DAGCombiner::BuildDivEstimate(SDValue N, SDValue Op,
27721 SDNodeFlags Flags) {
27722 if (LegalDAG)
27723 return SDValue();
27724
27725 // TODO: Handle extended types?
27726 EVT VT = Op.getValueType();
27727 if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
27728 VT.getScalarType() != MVT::f64)
27729 return SDValue();
27730
27731 // If estimates are explicitly disabled for this function, we're done.
27732 MachineFunction &MF = DAG.getMachineFunction();
27733 int Enabled = TLI.getRecipEstimateDivEnabled(VT, MF);
27734 if (Enabled == TLI.ReciprocalEstimate::Disabled)
27735 return SDValue();
27736
27737 // Estimates may be explicitly enabled for this type with a custom number of
27738 // refinement steps.
27739 int Iterations = TLI.getDivRefinementSteps(VT, MF);
27740 if (SDValue Est = TLI.getRecipEstimate(Operand: Op, DAG, Enabled, RefinementSteps&: Iterations)) {
27741 AddToWorklist(N: Est.getNode());
27742
27743 SDLoc DL(Op);
27744 if (Iterations) {
27745 SDValue FPOne = DAG.getConstantFP(Val: 1.0, DL, VT);
27746
27747 // Newton iterations: Est = Est + Est (N - Arg * Est)
27748 // If this is the last iteration, also multiply by the numerator.
27749 for (int i = 0; i < Iterations; ++i) {
27750 SDValue MulEst = Est;
27751
27752 if (i == Iterations - 1) {
27753 MulEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: N, N2: Est, Flags);
27754 AddToWorklist(N: MulEst.getNode());
27755 }
27756
27757 SDValue NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Op, N2: MulEst, Flags);
27758 AddToWorklist(N: NewEst.getNode());
27759
27760 NewEst = DAG.getNode(Opcode: ISD::FSUB, DL, VT,
27761 N1: (i == Iterations - 1 ? N : FPOne), N2: NewEst, Flags);
27762 AddToWorklist(N: NewEst.getNode());
27763
27764 NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: NewEst, Flags);
27765 AddToWorklist(N: NewEst.getNode());
27766
27767 Est = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: MulEst, N2: NewEst, Flags);
27768 AddToWorklist(N: Est.getNode());
27769 }
27770 } else {
27771 // If no iterations are available, multiply with N.
27772 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: N, Flags);
27773 AddToWorklist(N: Est.getNode());
27774 }
27775
27776 return Est;
27777 }
27778
27779 return SDValue();
27780}
27781
27782/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
27783/// For the reciprocal sqrt, we need to find the zero of the function:
27784/// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
27785/// =>
27786/// X_{i+1} = X_i (1.5 - A X_i^2 / 2)
27787/// As a result, we precompute A/2 prior to the iteration loop.
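/// For example, with A = 2.0 and an initial estimate X_0 = 0.7:
/// X_1 = 0.7 * (1.5 - 2.0 * 0.7 * 0.7 / 2) = 0.7 * 1.01 = 0.707, already close
/// to 1/sqrt(2) ~= 0.7071.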
27788SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
27789 unsigned Iterations,
27790 SDNodeFlags Flags, bool Reciprocal) {
27791 EVT VT = Arg.getValueType();
27792 SDLoc DL(Arg);
27793 SDValue ThreeHalves = DAG.getConstantFP(Val: 1.5, DL, VT);
27794
27795 // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
27796 // this entire sequence requires only one FP constant.
27797 SDValue HalfArg = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: ThreeHalves, N2: Arg, Flags);
27798 HalfArg = DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: HalfArg, N2: Arg, Flags);
27799
27800 // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
27801 for (unsigned i = 0; i < Iterations; ++i) {
27802 SDValue NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: Est, Flags);
27803 NewEst = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: HalfArg, N2: NewEst, Flags);
27804 NewEst = DAG.getNode(Opcode: ISD::FSUB, DL, VT, N1: ThreeHalves, N2: NewEst, Flags);
27805 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: NewEst, Flags);
27806 }
27807
27808 // If non-reciprocal square root is requested, multiply the result by Arg.
27809 if (!Reciprocal)
27810 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: Arg, Flags);
27811
27812 return Est;
27813}
27814
27815/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
27816/// For the reciprocal sqrt, we need to find the zero of the function:
27817/// F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
27818/// =>
27819/// X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
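/// Algebraically this is the same recurrence as the one-constant form above:
/// (-0.5 * X_i) * (A * X_i^2 - 3.0) = X_i * (1.5 - 0.5 * A * X_i^2),
/// but factored so that only the constants -0.5 and -3.0 are materialized.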
27820SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
27821 unsigned Iterations,
27822 SDNodeFlags Flags, bool Reciprocal) {
27823 EVT VT = Arg.getValueType();
27824 SDLoc DL(Arg);
27825 SDValue MinusThree = DAG.getConstantFP(Val: -3.0, DL, VT);
27826 SDValue MinusHalf = DAG.getConstantFP(Val: -0.5, DL, VT);
27827
27828 // This routine must enter the loop below to work correctly
27829 // when (Reciprocal == false).
27830 assert(Iterations > 0);
27831
27832 // Newton iterations for reciprocal square root:
27833 // E = (E * -0.5) * ((A * E) * E + -3.0)
27834 for (unsigned i = 0; i < Iterations; ++i) {
27835 SDValue AE = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Arg, N2: Est, Flags);
27836 SDValue AEE = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AE, N2: Est, Flags);
27837 SDValue RHS = DAG.getNode(Opcode: ISD::FADD, DL, VT, N1: AEE, N2: MinusThree, Flags);
27838
27839 // When calculating a square root at the last iteration build:
27840 // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
27841 // (notice a common subexpression)
27842 SDValue LHS;
27843 if (Reciprocal || (i + 1) < Iterations) {
27844 // RSQRT: LHS = (E * -0.5)
27845 LHS = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: Est, N2: MinusHalf, Flags);
27846 } else {
27847 // SQRT: LHS = (A * E) * -0.5
27848 LHS = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: AE, N2: MinusHalf, Flags);
27849 }
27850
27851 Est = DAG.getNode(Opcode: ISD::FMUL, DL, VT, N1: LHS, N2: RHS, Flags);
27852 }
27853
27854 return Est;
27855}
27856
27857/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
27858/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
27859/// Op can be zero.
27860SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
27861 bool Reciprocal) {
27862 if (LegalDAG)
27863 return SDValue();
27864
27865 // TODO: Handle extended types?
27866 EVT VT = Op.getValueType();
27867 if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
27868 VT.getScalarType() != MVT::f64)
27869 return SDValue();
27870
27871 // If estimates are explicitly disabled for this function, we're done.
27872 MachineFunction &MF = DAG.getMachineFunction();
27873 int Enabled = TLI.getRecipEstimateSqrtEnabled(VT, MF);
27874 if (Enabled == TLI.ReciprocalEstimate::Disabled)
27875 return SDValue();
27876
27877 // Estimates may be explicitly enabled for this type with a custom number of
27878 // refinement steps.
27879 int Iterations = TLI.getSqrtRefinementSteps(VT, MF);
27880
27881 bool UseOneConstNR = false;
27882 if (SDValue Est =
27883 TLI.getSqrtEstimate(Operand: Op, DAG, Enabled, RefinementSteps&: Iterations, UseOneConstNR,
27884 Reciprocal)) {
27885 AddToWorklist(N: Est.getNode());
27886
27887 if (Iterations > 0)
27888 Est = UseOneConstNR
27889 ? buildSqrtNROneConst(Arg: Op, Est, Iterations, Flags, Reciprocal)
27890 : buildSqrtNRTwoConst(Arg: Op, Est, Iterations, Flags, Reciprocal);
27891 if (!Reciprocal) {
27892 SDLoc DL(Op);
27893 // Try the target specific test first.
27894 SDValue Test = TLI.getSqrtInputTest(Operand: Op, DAG, Mode: DAG.getDenormalMode(VT));
27895
27896 // The estimate is now completely wrong if the input was exactly 0.0 or
27897 // possibly a denormal. Force the answer to 0.0 or value provided by
27898 // target for those cases.
27899 Est = DAG.getNode(
27900 Opcode: Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
27901 N1: Test, N2: TLI.getSqrtResultForDenormInput(Operand: Op, DAG), N3: Est);
27902 }
27903 return Est;
27904 }
27905
27906 return SDValue();
27907}
27908
27909SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags) {
27910 return buildSqrtEstimateImpl(Op, Flags, Reciprocal: true);
27911}
27912
27913SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags Flags) {
27914 return buildSqrtEstimateImpl(Op, Flags, Reciprocal: false);
27915}
27916
27917/// Return true if there is any possibility that the two addresses overlap.
27918bool DAGCombiner::mayAlias(SDNode *Op0, SDNode *Op1) const {
27919
27920 struct MemUseCharacteristics {
27921 bool IsVolatile;
27922 bool IsAtomic;
27923 SDValue BasePtr;
27924 int64_t Offset;
27925 LocationSize NumBytes;
27926 MachineMemOperand *MMO;
27927 };
27928
27929 auto getCharacteristics = [](SDNode *N) -> MemUseCharacteristics {
27930 if (const auto *LSN = dyn_cast<LSBaseSDNode>(Val: N)) {
27931 int64_t Offset = 0;
27932 if (auto *C = dyn_cast<ConstantSDNode>(Val: LSN->getOffset()))
27933 Offset = (LSN->getAddressingMode() == ISD::PRE_INC) ? C->getSExtValue()
27934 : (LSN->getAddressingMode() == ISD::PRE_DEC)
27935 ? -1 * C->getSExtValue()
27936 : 0;
27937 TypeSize Size = LSN->getMemoryVT().getStoreSize();
27938 return {.IsVolatile: LSN->isVolatile(), .IsAtomic: LSN->isAtomic(),
27939 .BasePtr: LSN->getBasePtr(), .Offset: Offset /*base offset*/,
27940 .NumBytes: LocationSize::precise(Value: Size), .MMO: LSN->getMemOperand()};
27941 }
27942 if (const auto *LN = cast<LifetimeSDNode>(Val: N))
27943 return {.IsVolatile: false /*isVolatile*/,
27944 /*isAtomic*/ .IsAtomic: false,
27945 .BasePtr: LN->getOperand(Num: 1),
27946 .Offset: (LN->hasOffset()) ? LN->getOffset() : 0,
27947 .NumBytes: (LN->hasOffset()) ? LocationSize::precise(Value: LN->getSize())
27948 : LocationSize::beforeOrAfterPointer(),
27949 .MMO: (MachineMemOperand *)nullptr};
27950 // Default.
27951 return {.IsVolatile: false /*isvolatile*/,
27952 /*isAtomic*/ .IsAtomic: false,
27953 .BasePtr: SDValue(),
27954 .Offset: (int64_t)0 /*offset*/,
27955 .NumBytes: LocationSize::beforeOrAfterPointer() /*size*/,
27956 .MMO: (MachineMemOperand *)nullptr};
27957 };
27958
27959 MemUseCharacteristics MUC0 = getCharacteristics(Op0),
27960 MUC1 = getCharacteristics(Op1);
27961
27962 // If they are to the same address, then they must be aliases.
27963 if (MUC0.BasePtr.getNode() && MUC0.BasePtr == MUC1.BasePtr &&
27964 MUC0.Offset == MUC1.Offset)
27965 return true;
27966
27967 // If they are both volatile then they cannot be reordered.
27968 if (MUC0.IsVolatile && MUC1.IsVolatile)
27969 return true;
27970
27971 // Be conservative about atomics for the moment
27972 // TODO: This is way overconservative for unordered atomics (see D66309)
27973 if (MUC0.IsAtomic && MUC1.IsAtomic)
27974 return true;
27975
27976 if (MUC0.MMO && MUC1.MMO) {
27977 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
27978 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
27979 return false;
27980 }
27981
27982  // If NumBytes is scalable and the offset is not 0, conservatively return
27983  // may-alias.
27984 if ((MUC0.NumBytes.hasValue() && MUC0.NumBytes.isScalable() &&
27985 MUC0.Offset != 0) ||
27986 (MUC1.NumBytes.hasValue() && MUC1.NumBytes.isScalable() &&
27987 MUC1.Offset != 0))
27988 return true;
27989 // Try to prove that there is aliasing, or that there is no aliasing. Either
27990 // way, we can return now. If nothing can be proved, proceed with more tests.
27991 bool IsAlias;
27992 if (BaseIndexOffset::computeAliasing(Op0, NumBytes0: MUC0.NumBytes, Op1, NumBytes1: MUC1.NumBytes,
27993 DAG, IsAlias))
27994 return IsAlias;
27995
27996 // The following all rely on MMO0 and MMO1 being valid. Fail conservatively if
27997 // either are not known.
27998 if (!MUC0.MMO || !MUC1.MMO)
27999 return true;
28000
28001 // If one operation reads from invariant memory, and the other may store, they
28002 // cannot alias. These should really be checking the equivalent of mayWrite,
28003  // but it only matters for memory nodes other than load/store.
28004 if ((MUC0.MMO->isInvariant() && MUC1.MMO->isStore()) ||
28005 (MUC1.MMO->isInvariant() && MUC0.MMO->isStore()))
28006 return false;
28007
28008 // If we know required SrcValue1 and SrcValue2 have relatively large
28009 // alignment compared to the size and offset of the access, we may be able
28010 // to prove they do not alias. This check is conservative for now to catch
28011  // cases created by splitting vector types; it only works when the offsets
28012  // are multiples of the size of the data.
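  // For example, two 4-byte accesses with base alignment 16 at offsets 0 and 8
  // satisfy these conditions and cannot overlap, since 0 + 4 <= 8.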
28013 int64_t SrcValOffset0 = MUC0.MMO->getOffset();
28014 int64_t SrcValOffset1 = MUC1.MMO->getOffset();
28015 Align OrigAlignment0 = MUC0.MMO->getBaseAlign();
28016 Align OrigAlignment1 = MUC1.MMO->getBaseAlign();
28017 LocationSize Size0 = MUC0.NumBytes;
28018 LocationSize Size1 = MUC1.NumBytes;
28019
28020 if (OrigAlignment0 == OrigAlignment1 && SrcValOffset0 != SrcValOffset1 &&
28021 Size0.hasValue() && Size1.hasValue() && !Size0.isScalable() &&
28022 !Size1.isScalable() && Size0 == Size1 &&
28023 OrigAlignment0 > Size0.getValue().getKnownMinValue() &&
28024 SrcValOffset0 % Size0.getValue().getKnownMinValue() == 0 &&
28025 SrcValOffset1 % Size1.getValue().getKnownMinValue() == 0) {
28026 int64_t OffAlign0 = SrcValOffset0 % OrigAlignment0.value();
28027 int64_t OffAlign1 = SrcValOffset1 % OrigAlignment1.value();
28028
28029 // There is no overlap between these relatively aligned accesses of
28030 // similar size. Return no alias.
28031 if ((OffAlign0 + static_cast<int64_t>(
28032 Size0.getValue().getKnownMinValue())) <= OffAlign1 ||
28033 (OffAlign1 + static_cast<int64_t>(
28034 Size1.getValue().getKnownMinValue())) <= OffAlign0)
28035 return false;
28036 }
28037
28038 bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0
28039 ? CombinerGlobalAA
28040 : DAG.getSubtarget().useAA();
28041#ifndef NDEBUG
28042 if (CombinerAAOnlyFunc.getNumOccurrences() &&
28043 CombinerAAOnlyFunc != DAG.getMachineFunction().getName())
28044 UseAA = false;
28045#endif
28046
28047 if (UseAA && AA && MUC0.MMO->getValue() && MUC1.MMO->getValue() &&
28048 Size0.hasValue() && Size1.hasValue()) {
28049 // Use alias analysis information.
28050 int64_t MinOffset = std::min(a: SrcValOffset0, b: SrcValOffset1);
28051 int64_t Overlap0 =
28052 Size0.getValue().getKnownMinValue() + SrcValOffset0 - MinOffset;
28053 int64_t Overlap1 =
28054 Size1.getValue().getKnownMinValue() + SrcValOffset1 - MinOffset;
28055 LocationSize Loc0 =
28056 Size0.isScalable() ? Size0 : LocationSize::precise(Value: Overlap0);
28057 LocationSize Loc1 =
28058 Size1.isScalable() ? Size1 : LocationSize::precise(Value: Overlap1);
28059 if (AA->isNoAlias(
28060 LocA: MemoryLocation(MUC0.MMO->getValue(), Loc0,
28061 UseTBAA ? MUC0.MMO->getAAInfo() : AAMDNodes()),
28062 LocB: MemoryLocation(MUC1.MMO->getValue(), Loc1,
28063 UseTBAA ? MUC1.MMO->getAAInfo() : AAMDNodes())))
28064 return false;
28065 }
28066
28067 // Otherwise we have to assume they alias.
28068 return true;
28069}
28070
28071/// Walk up chain skipping non-aliasing memory nodes,
28072/// looking for aliasing nodes and adding them to the Aliases vector.
28073void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain,
28074 SmallVectorImpl<SDValue> &Aliases) {
28075 SmallVector<SDValue, 8> Chains; // List of chains to visit.
28076 SmallPtrSet<SDNode *, 16> Visited; // Visited node set.
28077
28078 // Get alias information for node.
28079 // TODO: relax aliasing for unordered atomics (see D66309)
28080 const bool IsLoad = isa<LoadSDNode>(Val: N) && cast<LoadSDNode>(Val: N)->isSimple();
28081
28082 // Starting off.
28083 Chains.push_back(Elt: OriginalChain);
28084 unsigned Depth = 0;
28085
28086 // Attempt to improve chain by a single step
28087 auto ImproveChain = [&](SDValue &C) -> bool {
28088 switch (C.getOpcode()) {
28089 case ISD::EntryToken:
28090 // No need to mark EntryToken.
28091 C = SDValue();
28092 return true;
28093 case ISD::LOAD:
28094 case ISD::STORE: {
28095 // Get alias information for C.
28096 // TODO: Relax aliasing for unordered atomics (see D66309)
28097 bool IsOpLoad = isa<LoadSDNode>(Val: C.getNode()) &&
28098 cast<LSBaseSDNode>(Val: C.getNode())->isSimple();
28099 if ((IsLoad && IsOpLoad) || !mayAlias(Op0: N, Op1: C.getNode())) {
28100 // Look further up the chain.
28101 C = C.getOperand(i: 0);
28102 return true;
28103 }
28104 // Alias, so stop here.
28105 return false;
28106 }
28107
28108 case ISD::CopyFromReg:
28109 // Always forward past CopyFromReg.
28110 C = C.getOperand(i: 0);
28111 return true;
28112
28113 case ISD::LIFETIME_START:
28114 case ISD::LIFETIME_END: {
28115 // We can forward past any lifetime start/end that can be proven not to
28116 // alias the memory access.
28117 if (!mayAlias(Op0: N, Op1: C.getNode())) {
28118 // Look further up the chain.
28119 C = C.getOperand(i: 0);
28120 return true;
28121 }
28122 return false;
28123 }
28124 default:
28125 return false;
28126 }
28127 };
28128
28129 // Look at each chain and determine if it is an alias. If so, add it to the
28130 // aliases list. If not, then continue up the chain looking for the next
28131 // candidate.
28132 while (!Chains.empty()) {
28133 SDValue Chain = Chains.pop_back_val();
28134
28135 // Don't bother if we've seen Chain before.
28136 if (!Visited.insert(Ptr: Chain.getNode()).second)
28137 continue;
28138
28139 // For TokenFactor nodes, look at each operand and only continue up the
28140 // chain until we reach the depth limit.
28141 //
28142 // FIXME: The depth check could be made to return the last non-aliasing
28143 // chain we found before we hit a tokenfactor rather than the original
28144 // chain.
28145 if (Depth > TLI.getGatherAllAliasesMaxDepth()) {
28146 Aliases.clear();
28147 Aliases.push_back(Elt: OriginalChain);
28148 return;
28149 }
28150
28151 if (Chain.getOpcode() == ISD::TokenFactor) {
28152 // We have to check each of the operands of the token factor for "small"
28153 // token factors, so we queue them up. Adding the operands to the queue
28154 // (stack) in reverse order maintains the original order and increases the
28155      // likelihood that getNode will find a matching token factor (CSE).
28156 if (Chain.getNumOperands() > 16) {
28157 Aliases.push_back(Elt: Chain);
28158 continue;
28159 }
28160 for (unsigned n = Chain.getNumOperands(); n;)
28161 Chains.push_back(Elt: Chain.getOperand(i: --n));
28162 ++Depth;
28163 continue;
28164 }
28165 // Everything else
28166 if (ImproveChain(Chain)) {
28167 // Updated Chain Found, Consider new chain if one exists.
28168 if (Chain.getNode())
28169 Chains.push_back(Elt: Chain);
28170 ++Depth;
28171 continue;
28172 }
28173 // No Improved Chain Possible, treat as Alias.
28174 Aliases.push_back(Elt: Chain);
28175 }
28176}
28177
28178/// Walk up chain skipping non-aliasing memory nodes, looking for a better chain
28179/// (aliasing node).
28180SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) {
28181 if (OptLevel == CodeGenOptLevel::None)
28182 return OldChain;
28183
28184 // Ops for replacing token factor.
28185 SmallVector<SDValue, 8> Aliases;
28186
28187 // Accumulate all the aliases to this node.
28188 GatherAllAliases(N, OriginalChain: OldChain, Aliases);
28189
28190 // If no operands then chain to entry token.
28191 if (Aliases.empty())
28192 return DAG.getEntryNode();
28193
28194 // If a single operand then chain to it. We don't need to revisit it.
28195 if (Aliases.size() == 1)
28196 return Aliases[0];
28197
28198 // Construct a custom tailored token factor.
28199 return DAG.getTokenFactor(DL: SDLoc(N), Vals&: Aliases);
28200}
28201
28202// This function tries to collect a bunch of potentially interesting
28203// nodes to improve the chains of, all at once. This might seem
28204// redundant, as this function gets called when visiting every store
28205// node, so why not let the work be done on each store as it's visited?
28206//
28207// I believe this is mainly important because mergeConsecutiveStores
28208// is unable to deal with merging stores of different sizes, so unless
28209// we improve the chains of all the potential candidates up-front
28210// before running mergeConsecutiveStores, it might only see some of
28211// the nodes that will eventually be candidates, and then not be able
28212// to go from a partially-merged state to the desired final
28213// fully-merged state.
28214
28215bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
28216 SmallVector<StoreSDNode *, 8> ChainedStores;
28217 StoreSDNode *STChain = St;
28218  // Intervals records which offsets from BaseIndex have been covered. In the
28219  // common case, every store writes to the address immediately after the
28220  // previous one and is thus merged with the previous interval at insertion time.
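  // For example, a chain of three 4-byte stores at offsets 0, 4 and 8 from
  // BaseIndex is expected to coalesce into the single interval [0, 12).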
28221
28222 using IMap = llvm::IntervalMap<int64_t, std::monostate, 8,
28223 IntervalMapHalfOpenInfo<int64_t>>;
28224 IMap::Allocator A;
28225 IMap Intervals(A);
28226
28227 // This holds the base pointer, index, and the offset in bytes from the base
28228 // pointer.
28229 const BaseIndexOffset BasePtr = BaseIndexOffset::match(N: St, DAG);
28230
28231 // We must have a base and an offset.
28232 if (!BasePtr.getBase().getNode())
28233 return false;
28234
28235 // Do not handle stores to undef base pointers.
28236 if (BasePtr.getBase().isUndef())
28237 return false;
28238
28239 // Do not handle stores to opaque types
28240 if (St->getMemoryVT().isZeroSized())
28241 return false;
28242
28243 // BaseIndexOffset assumes that offsets are fixed-size, which
28244 // is not valid for scalable vectors where the offsets are
28245 // scaled by `vscale`, so bail out early.
28246 if (St->getMemoryVT().isScalableVT())
28247 return false;
28248
28249 // Add ST's interval.
28250 Intervals.insert(a: 0, b: (St->getMemoryVT().getSizeInBits() + 7) / 8,
28251 y: std::monostate{});
28252
28253 while (StoreSDNode *Chain = dyn_cast<StoreSDNode>(Val: STChain->getChain())) {
28254 if (Chain->getMemoryVT().isScalableVector())
28255 return false;
28256
28257 // If the chain has more than one use, then we can't reorder the mem ops.
28258 if (!SDValue(Chain, 0)->hasOneUse())
28259 break;
28260 // TODO: Relax for unordered atomics (see D66309)
28261 if (!Chain->isSimple() || Chain->isIndexed())
28262 break;
28263
28264 // Find the base pointer and offset for this memory node.
28265 const BaseIndexOffset Ptr = BaseIndexOffset::match(N: Chain, DAG);
28266 // Check that the base pointer is the same as the original one.
28267 int64_t Offset;
28268 if (!BasePtr.equalBaseIndex(Other: Ptr, DAG, Off&: Offset))
28269 break;
28270 int64_t Length = (Chain->getMemoryVT().getSizeInBits() + 7) / 8;
28271 // Make sure we don't overlap with other intervals by checking the ones to
28272 // the left or right before inserting.
28273 auto I = Intervals.find(x: Offset);
28274 // If there's a next interval, we should end before it.
28275 if (I != Intervals.end() && I.start() < (Offset + Length))
28276 break;
28277 // If there's a previous interval, we should start after it.
28278 if (I != Intervals.begin() && (--I).stop() <= Offset)
28279 break;
28280 Intervals.insert(a: Offset, b: Offset + Length, y: std::monostate{});
28281
28282 ChainedStores.push_back(Elt: Chain);
28283 STChain = Chain;
28284 }
28285
28286 // If we didn't find a chained store, exit.
28287 if (ChainedStores.empty())
28288 return false;
28289
28290 // Improve all chained stores (St and ChainedStores members) starting from
28291 // where the store chain ended and return single TokenFactor.
28292 SDValue NewChain = STChain->getChain();
28293 SmallVector<SDValue, 8> TFOps;
28294 for (unsigned I = ChainedStores.size(); I;) {
28295 StoreSDNode *S = ChainedStores[--I];
28296 SDValue BetterChain = FindBetterChain(N: S, OldChain: NewChain);
28297 S = cast<StoreSDNode>(Val: DAG.UpdateNodeOperands(
28298 N: S, Op1: BetterChain, Op2: S->getOperand(Num: 1), Op3: S->getOperand(Num: 2), Op4: S->getOperand(Num: 3)));
28299 TFOps.push_back(Elt: SDValue(S, 0));
28300 ChainedStores[I] = S;
28301 }
28302
28303 // Improve St's chain. Use a new node to avoid creating a loop from CombineTo.
28304 SDValue BetterChain = FindBetterChain(N: St, OldChain: NewChain);
28305 SDValue NewST;
28306 if (St->isTruncatingStore())
28307 NewST = DAG.getTruncStore(Chain: BetterChain, dl: SDLoc(St), Val: St->getValue(),
28308 Ptr: St->getBasePtr(), SVT: St->getMemoryVT(),
28309 MMO: St->getMemOperand());
28310 else
28311 NewST = DAG.getStore(Chain: BetterChain, dl: SDLoc(St), Val: St->getValue(),
28312 Ptr: St->getBasePtr(), MMO: St->getMemOperand());
28313
28314 TFOps.push_back(Elt: NewST);
28315
28316 // If we improved every element of TFOps, then we've lost the dependence on
28317 // NewChain to successors of St and we need to add it back to TFOps. Do so at
28318 // the beginning to keep relative order consistent with FindBetterChains.
28319 auto hasImprovedChain = [&](SDValue ST) -> bool {
28320 return ST->getOperand(Num: 0) != NewChain;
28321 };
28322 bool AddNewChain = llvm::all_of(Range&: TFOps, P: hasImprovedChain);
28323 if (AddNewChain)
28324 TFOps.insert(I: TFOps.begin(), Elt: NewChain);
28325
28326 SDValue TF = DAG.getTokenFactor(DL: SDLoc(STChain), Vals&: TFOps);
28327 CombineTo(N: St, Res: TF);
28328
28329 // Add TF and its operands to the worklist.
28330 AddToWorklist(N: TF.getNode());
28331 for (const SDValue &Op : TF->ops())
28332 AddToWorklist(N: Op.getNode());
28333 AddToWorklist(N: STChain);
28334 return true;
28335}
28336
28337bool DAGCombiner::findBetterNeighborChains(StoreSDNode *St) {
28338 if (OptLevel == CodeGenOptLevel::None)
28339 return false;
28340
28341 const BaseIndexOffset BasePtr = BaseIndexOffset::match(N: St, DAG);
28342
28343 // We must have a base and an offset.
28344 if (!BasePtr.getBase().getNode())
28345 return false;
28346
28347 // Do not handle stores to undef base pointers.
28348 if (BasePtr.getBase().isUndef())
28349 return false;
28350
28351 // Directly improve a chain of disjoint stores starting at St.
28352 if (parallelizeChainedStores(St))
28353 return true;
28354
28355  // Improve St's chain.
28356 SDValue BetterChain = FindBetterChain(N: St, OldChain: St->getChain());
28357 if (St->getChain() != BetterChain) {
28358 replaceStoreChain(ST: St, BetterChain);
28359 return true;
28360 }
28361 return false;
28362}
28363
28364/// This is the entry point for the file.
28365void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis *AA,
28366 CodeGenOptLevel OptLevel) {
28367 /// This is the main entry point to this class.
28368 DAGCombiner(*this, AA, OptLevel).Run(AtLevel: Level);
28369}
28370
