//===- LoopIdiomRecognize.cpp - Loop idiom recognition --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements an idiom recognizer that transforms simple loops into a
// non-loop form. In cases where this kicks in, it can be a significant
// performance win.
//
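// As a simple illustration (one of several idioms this file handles), a loop
// like
//   for (size_t i = 0; i != n; ++i)
//     p[i] = 0;   // p is a 'char *'
// can be replaced with a single call
//   memset(p, 0, n);
// and an analogous element-by-element copy loop can become a memcpy.
//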
// If compiling for code size we avoid idiom recognition if the resulting
// code could be larger than the code for the original loop. One way this could
// happen is if the loop is not removable after idiom recognition due to the
// presence of non-idiom instructions. The initial implementation of the
// heuristics applies to idioms in multi-block loops.
//
//===----------------------------------------------------------------------===//
//
// TODO List:
//
// Future loop memory idioms to recognize:
//   memcmp, strlen, etc.
// Future floating point idioms to recognize in -ffast-math mode:
//   fpowi
//
// This could recognize common matrix multiplies and dot product idioms and
// replace them with calls to BLAS (if linked in??).
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LoopIdiomRecognize.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/MustExecute.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/InstructionCost.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using namespace llvm;

#define DEBUG_TYPE "loop-idiom"

STATISTIC(NumMemSet, "Number of memset's formed from loop stores");
STATISTIC(NumMemCpy, "Number of memcpy's formed from loop load+stores");
STATISTIC(NumMemMove, "Number of memmove's formed from loop load+stores");
STATISTIC(
    NumShiftUntilBitTest,
    "Number of uncountable loops recognized as 'shift until bittest' idiom");
STATISTIC(NumShiftUntilZero,
          "Number of uncountable loops recognized as 'shift until zero' idiom");

bool DisableLIRP::All;
static cl::opt<bool, true>
    DisableLIRPAll("disable-" DEBUG_TYPE "-all",
                   cl::desc("Options to disable Loop Idiom Recognize Pass."),
                   cl::location(DisableLIRP::All), cl::init(false),
                   cl::ReallyHidden);

bool DisableLIRP::Memset;
static cl::opt<bool, true>
    DisableLIRPMemset("disable-" DEBUG_TYPE "-memset",
                      cl::desc("Proceed with loop idiom recognize pass, but do "
                               "not convert loop(s) to memset."),
                      cl::location(DisableLIRP::Memset), cl::init(false),
                      cl::ReallyHidden);

bool DisableLIRP::Memcpy;
static cl::opt<bool, true>
    DisableLIRPMemcpy("disable-" DEBUG_TYPE "-memcpy",
                      cl::desc("Proceed with loop idiom recognize pass, but do "
                               "not convert loop(s) to memcpy."),
                      cl::location(DisableLIRP::Memcpy), cl::init(false),
                      cl::ReallyHidden);

static cl::opt<bool> UseLIRCodeSizeHeurs(
    "use-lir-code-size-heurs",
    cl::desc("Use loop idiom recognition code size heuristics when compiling "
             "with -Os/-Oz"),
    cl::init(true), cl::Hidden);
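
// Note: with DEBUG_TYPE "loop-idiom", the options above surface on `opt` as
// -disable-loop-idiom-all, -disable-loop-idiom-memset, and
// -disable-loop-idiom-memcpy; the exact spellings follow from the string
// concatenation in the option names.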
137
138namespace {
139
140class LoopIdiomRecognize {
141 Loop *CurLoop = nullptr;
142 AliasAnalysis *AA;
143 DominatorTree *DT;
144 LoopInfo *LI;
145 ScalarEvolution *SE;
146 TargetLibraryInfo *TLI;
147 const TargetTransformInfo *TTI;
148 const DataLayout *DL;
149 OptimizationRemarkEmitter &ORE;
150 bool ApplyCodeSizeHeuristics;
151 std::unique_ptr<MemorySSAUpdater> MSSAU;
152
153public:
154 explicit LoopIdiomRecognize(AliasAnalysis *AA, DominatorTree *DT,
155 LoopInfo *LI, ScalarEvolution *SE,
156 TargetLibraryInfo *TLI,
157 const TargetTransformInfo *TTI, MemorySSA *MSSA,
158 const DataLayout *DL,
159 OptimizationRemarkEmitter &ORE)
160 : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), ORE(ORE) {
161 if (MSSA)
162 MSSAU = std::make_unique<MemorySSAUpdater>(args&: MSSA);
163 }
164
165 bool runOnLoop(Loop *L);
166
167private:
168 using StoreList = SmallVector<StoreInst *, 8>;
169 using StoreListMap = MapVector<Value *, StoreList>;
170
171 StoreListMap StoreRefsForMemset;
172 StoreListMap StoreRefsForMemsetPattern;
173 StoreList StoreRefsForMemcpy;
174 bool HasMemset;
175 bool HasMemsetPattern;
176 bool HasMemcpy;
177
178 /// Return code for isLegalStore()
179 enum LegalStoreKind {
180 None = 0,
181 Memset,
182 MemsetPattern,
183 Memcpy,
184 UnorderedAtomicMemcpy,
185 DontUse // Dummy retval never to be used. Allows catching errors in retval
186 // handling.
187 };
188
189 /// \name Countable Loop Idiom Handling
190 /// @{
191
192 bool runOnCountableLoop();
193 bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
194 SmallVectorImpl<BasicBlock *> &ExitBlocks);
195
196 void collectStores(BasicBlock *BB);
197 LegalStoreKind isLegalStore(StoreInst *SI);
198 enum class ForMemset { No, Yes };
199 bool processLoopStores(SmallVectorImpl<StoreInst *> &SL, const SCEV *BECount,
200 ForMemset For);
201
202 template <typename MemInst>
203 bool processLoopMemIntrinsic(
204 BasicBlock *BB,
205 bool (LoopIdiomRecognize::*Processor)(MemInst *, const SCEV *),
206 const SCEV *BECount);
207 bool processLoopMemCpy(MemCpyInst *MCI, const SCEV *BECount);
208 bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount);
209
210 bool processLoopStridedStore(Value *DestPtr, const SCEV *StoreSizeSCEV,
211 MaybeAlign StoreAlignment, Value *StoredVal,
212 Instruction *TheStore,
213 SmallPtrSetImpl<Instruction *> &Stores,
214 const SCEVAddRecExpr *Ev, const SCEV *BECount,
215 bool IsNegStride, bool IsLoopMemset = false);
216 bool processLoopStoreOfLoopLoad(StoreInst *SI, const SCEV *BECount);
217 bool processLoopStoreOfLoopLoad(Value *DestPtr, Value *SourcePtr,
218 const SCEV *StoreSize, MaybeAlign StoreAlign,
219 MaybeAlign LoadAlign, Instruction *TheStore,
220 Instruction *TheLoad,
221 const SCEVAddRecExpr *StoreEv,
222 const SCEVAddRecExpr *LoadEv,
223 const SCEV *BECount);
224 bool avoidLIRForMultiBlockLoop(bool IsMemset = false,
225 bool IsLoopMemset = false);
226
227 /// @}
  /// \name Noncountable Loop Idiom Handling
  /// @{

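  // Among the idioms handled below, recognizePopcount() looks for the classic
  // bit-counting loop (shown schematically; the matcher checks the precise
  // instruction pattern), e.g.
  //   for (cnt = init; x != 0; ++cnt)
  //     x &= x - 1;
  // and transformLoopToPopcount() rewrites it around an llvm.ctpop intrinsic.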
  bool runOnNoncountableLoop();

  bool recognizePopcount();
  void transformLoopToPopcount(BasicBlock *PreCondBB, Instruction *CntInst,
                               PHINode *CntPhi, Value *Var);
  bool recognizeAndInsertFFS(); /// Find First Set: ctlz or cttz
  void transformLoopToCountable(Intrinsic::ID IntrinID, BasicBlock *PreCondBB,
                                Instruction *CntInst, PHINode *CntPhi,
                                Value *Var, Instruction *DefX,
                                const DebugLoc &DL, bool ZeroCheck,
                                bool IsCntPhiUsedOutsideLoop);

  bool recognizeShiftUntilBitTest();
  bool recognizeShiftUntilZero();

  /// @}
};
} // end anonymous namespace
249
250PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM,
251 LoopStandardAnalysisResults &AR,
252 LPMUpdater &) {
253 if (DisableLIRP::All)
254 return PreservedAnalyses::all();
255
256 const auto *DL = &L.getHeader()->getModule()->getDataLayout();
257
258 // For the new PM, we also can't use OptimizationRemarkEmitter as an analysis
259 // pass. Function analyses need to be preserved across loop transformations
260 // but ORE cannot be preserved (see comment before the pass definition).
261 OptimizationRemarkEmitter ORE(L.getHeader()->getParent());
262
263 LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI,
264 AR.MSSA, DL, ORE);
265 if (!LIR.runOnLoop(L: &L))
266 return PreservedAnalyses::all();
267
268 auto PA = getLoopPassPreservedAnalyses();
269 if (AR.MSSA)
270 PA.preserve<MemorySSAAnalysis>();
271 return PA;
272}
273
274static void deleteDeadInstruction(Instruction *I) {
275 I->replaceAllUsesWith(V: PoisonValue::get(T: I->getType()));
276 I->eraseFromParent();
277}
278
279//===----------------------------------------------------------------------===//
280//
281// Implementation of LoopIdiomRecognize
282//
283//===----------------------------------------------------------------------===//
284
285bool LoopIdiomRecognize::runOnLoop(Loop *L) {
  CurLoop = L;
  // If the loop could not be converted to canonical form, it must have an
  // indirectbr in it; just give up.
  if (!L->getLoopPreheader())
    return false;
291
292 // Disable loop idiom recognition if the function's name is a common idiom.
293 StringRef Name = L->getHeader()->getParent()->getName();
294 if (Name == "memset" || Name == "memcpy")
295 return false;
296
297 // Determine if code size heuristics need to be applied.
298 ApplyCodeSizeHeuristics =
299 L->getHeader()->getParent()->hasOptSize() && UseLIRCodeSizeHeurs;
300
301 HasMemset = TLI->has(F: LibFunc_memset);
302 HasMemsetPattern = TLI->has(F: LibFunc_memset_pattern16);
303 HasMemcpy = TLI->has(F: LibFunc_memcpy);
304
305 if (HasMemset || HasMemsetPattern || HasMemcpy)
306 if (SE->hasLoopInvariantBackedgeTakenCount(L))
307 return runOnCountableLoop();
308
309 return runOnNoncountableLoop();
310}
311
312bool LoopIdiomRecognize::runOnCountableLoop() {
  const SCEV *BECount = SE->getBackedgeTakenCount(CurLoop);
  assert(!isa<SCEVCouldNotCompute>(BECount) &&
         "runOnCountableLoop() called on a loop without a predictable"
         " backedge-taken count");
317
318 // If this loop executes exactly one time, then it should be peeled, not
319 // optimized by this pass.
320 if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(Val: BECount))
321 if (BECst->getAPInt() == 0)
322 return false;
323
324 SmallVector<BasicBlock *, 8> ExitBlocks;
325 CurLoop->getUniqueExitBlocks(ExitBlocks);
326
327 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F["
328 << CurLoop->getHeader()->getParent()->getName()
329 << "] Countable Loop %" << CurLoop->getHeader()->getName()
330 << "\n");
331
332 // The following transforms hoist stores/memsets into the loop pre-header.
333 // Give up if the loop has instructions that may throw.
334 SimpleLoopSafetyInfo SafetyInfo;
335 SafetyInfo.computeLoopSafetyInfo(CurLoop);
336 if (SafetyInfo.anyBlockMayThrow())
337 return false;
338
339 bool MadeChange = false;
340
341 // Scan all the blocks in the loop that are not in subloops.
342 for (auto *BB : CurLoop->getBlocks()) {
343 // Ignore blocks in subloops.
344 if (LI->getLoopFor(BB) != CurLoop)
345 continue;
346
347 MadeChange |= runOnLoopBlock(BB, BECount, ExitBlocks);
348 }
349 return MadeChange;
350}
351
352static APInt getStoreStride(const SCEVAddRecExpr *StoreEv) {
353 const SCEVConstant *ConstStride = cast<SCEVConstant>(Val: StoreEv->getOperand(i: 1));
354 return ConstStride->getAPInt();
355}
356
/// getMemSetPatternValue - If a strided store of the specified value is safe
/// to turn into a memset_pattern16, return a ConstantArray of 16 bytes that
/// should be passed in. Otherwise, return null.
///
/// Note that we don't ever attempt to use memset_pattern8 or 4, because these
/// just replicate their input array and then pass on to memset_pattern16.
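///
/// For illustration (the 16/Size replication below): a strided store of the
/// 4-byte constant i32 0x01020304 yields a 16-byte ConstantArray holding four
/// copies of that value, which becomes the pattern argument of
/// memset_pattern16.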
363static Constant *getMemSetPatternValue(Value *V, const DataLayout *DL) {
364 // FIXME: This could check for UndefValue because it can be merged into any
365 // other valid pattern.
366
367 // If the value isn't a constant, we can't promote it to being in a constant
368 // array. We could theoretically do a store to an alloca or something, but
369 // that doesn't seem worthwhile.
370 Constant *C = dyn_cast<Constant>(Val: V);
371 if (!C || isa<ConstantExpr>(Val: C))
372 return nullptr;
373
374 // Only handle simple values that are a power of two bytes in size.
375 uint64_t Size = DL->getTypeSizeInBits(Ty: V->getType());
376 if (Size == 0 || (Size & 7) || (Size & (Size - 1)))
377 return nullptr;
378
379 // Don't care enough about darwin/ppc to implement this.
380 if (DL->isBigEndian())
381 return nullptr;
382
383 // Convert to size in bytes.
384 Size /= 8;
385
386 // TODO: If CI is larger than 16-bytes, we can try slicing it in half to see
387 // if the top and bottom are the same (e.g. for vectors and large integers).
388 if (Size > 16)
389 return nullptr;
390
391 // If the constant is exactly 16 bytes, just use it.
392 if (Size == 16)
393 return C;
394
395 // Otherwise, we'll use an array of the constants.
396 unsigned ArraySize = 16 / Size;
397 ArrayType *AT = ArrayType::get(ElementType: V->getType(), NumElements: ArraySize);
398 return ConstantArray::get(T: AT, V: std::vector<Constant *>(ArraySize, C));
399}
400
401LoopIdiomRecognize::LegalStoreKind
402LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
403 // Don't touch volatile stores.
404 if (SI->isVolatile())
405 return LegalStoreKind::None;
406 // We only want simple or unordered-atomic stores.
407 if (!SI->isUnordered())
408 return LegalStoreKind::None;
409
410 // Avoid merging nontemporal stores.
411 if (SI->getMetadata(KindID: LLVMContext::MD_nontemporal))
412 return LegalStoreKind::None;
413
414 Value *StoredVal = SI->getValueOperand();
415 Value *StorePtr = SI->getPointerOperand();
416
  // Don't convert stores of non-integral pointer types to memsets (which
  // store integers).
  if (DL->isNonIntegralPointerType(StoredVal->getType()->getScalarType()))
    return LegalStoreKind::None;
421
422 // Reject stores that are so large that they overflow an unsigned.
423 // When storing out scalable vectors we bail out for now, since the code
424 // below currently only works for constant strides.
425 TypeSize SizeInBits = DL->getTypeSizeInBits(Ty: StoredVal->getType());
426 if (SizeInBits.isScalable() || (SizeInBits.getFixedValue() & 7) ||
427 (SizeInBits.getFixedValue() >> 32) != 0)
428 return LegalStoreKind::None;
429
430 // See if the pointer expression is an AddRec like {base,+,1} on the current
431 // loop, which indicates a strided store. If we have something else, it's a
432 // random store we can't handle.
433 const SCEVAddRecExpr *StoreEv =
434 dyn_cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: StorePtr));
435 if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
436 return LegalStoreKind::None;
437
438 // Check to see if we have a constant stride.
439 if (!isa<SCEVConstant>(Val: StoreEv->getOperand(i: 1)))
440 return LegalStoreKind::None;
441
442 // See if the store can be turned into a memset.
443
444 // If the stored value is a byte-wise value (like i32 -1), then it may be
445 // turned into a memset of i8 -1, assuming that all the consecutive bytes
446 // are stored. A store of i32 0x01020304 can never be turned into a memset,
447 // but it can be turned into memset_pattern if the target supports it.
448 Value *SplatValue = isBytewiseValue(V: StoredVal, DL: *DL);
449
  // Note: memset and memset_pattern on unordered-atomic stores are not yet
  // supported.
  bool UnorderedAtomic = SI->isUnordered() && !SI->isSimple();
452
453 // If we're allowed to form a memset, and the stored value would be
454 // acceptable for memset, use it.
455 if (!UnorderedAtomic && HasMemset && SplatValue && !DisableLIRP::Memset &&
456 // Verify that the stored value is loop invariant. If not, we can't
457 // promote the memset.
458 CurLoop->isLoopInvariant(V: SplatValue)) {
459 // It looks like we can use SplatValue.
460 return LegalStoreKind::Memset;
461 }
462 if (!UnorderedAtomic && HasMemsetPattern && !DisableLIRP::Memset &&
463 // Don't create memset_pattern16s with address spaces.
464 StorePtr->getType()->getPointerAddressSpace() == 0 &&
465 getMemSetPatternValue(V: StoredVal, DL)) {
466 // It looks like we can use PatternValue!
467 return LegalStoreKind::MemsetPattern;
468 }
469
470 // Otherwise, see if the store can be turned into a memcpy.
471 if (HasMemcpy && !DisableLIRP::Memcpy) {
472 // Check to see if the stride matches the size of the store. If so, then we
473 // know that every byte is touched in the loop.
474 APInt Stride = getStoreStride(StoreEv);
475 unsigned StoreSize = DL->getTypeStoreSize(Ty: SI->getValueOperand()->getType());
476 if (StoreSize != Stride && StoreSize != -Stride)
477 return LegalStoreKind::None;
478
479 // The store must be feeding a non-volatile load.
480 LoadInst *LI = dyn_cast<LoadInst>(Val: SI->getValueOperand());
481
482 // Only allow non-volatile loads
483 if (!LI || LI->isVolatile())
484 return LegalStoreKind::None;
485 // Only allow simple or unordered-atomic loads
486 if (!LI->isUnordered())
487 return LegalStoreKind::None;
488
489 // See if the pointer expression is an AddRec like {base,+,1} on the current
490 // loop, which indicates a strided load. If we have something else, it's a
491 // random load we can't handle.
492 const SCEVAddRecExpr *LoadEv =
493 dyn_cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: LI->getPointerOperand()));
494 if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
495 return LegalStoreKind::None;
496
497 // The store and load must share the same stride.
498 if (StoreEv->getOperand(i: 1) != LoadEv->getOperand(i: 1))
499 return LegalStoreKind::None;
500
501 // Success. This store can be converted into a memcpy.
502 UnorderedAtomic = UnorderedAtomic || LI->isAtomic();
503 return UnorderedAtomic ? LegalStoreKind::UnorderedAtomicMemcpy
504 : LegalStoreKind::Memcpy;
505 }
506 // This store can't be transformed into a memset/memcpy.
507 return LegalStoreKind::None;
508}
509
510void LoopIdiomRecognize::collectStores(BasicBlock *BB) {
511 StoreRefsForMemset.clear();
512 StoreRefsForMemsetPattern.clear();
513 StoreRefsForMemcpy.clear();
514 for (Instruction &I : *BB) {
515 StoreInst *SI = dyn_cast<StoreInst>(Val: &I);
516 if (!SI)
517 continue;
518
519 // Make sure this is a strided store with a constant stride.
520 switch (isLegalStore(SI)) {
521 case LegalStoreKind::None:
522 // Nothing to do
523 break;
524 case LegalStoreKind::Memset: {
525 // Find the base pointer.
526 Value *Ptr = getUnderlyingObject(V: SI->getPointerOperand());
527 StoreRefsForMemset[Ptr].push_back(Elt: SI);
528 } break;
529 case LegalStoreKind::MemsetPattern: {
530 // Find the base pointer.
531 Value *Ptr = getUnderlyingObject(V: SI->getPointerOperand());
532 StoreRefsForMemsetPattern[Ptr].push_back(Elt: SI);
533 } break;
534 case LegalStoreKind::Memcpy:
535 case LegalStoreKind::UnorderedAtomicMemcpy:
536 StoreRefsForMemcpy.push_back(Elt: SI);
537 break;
538 default:
539 assert(false && "unhandled return value");
540 break;
541 }
542 }
543}
544
545/// runOnLoopBlock - Process the specified block, which lives in a counted loop
546/// with the specified backedge count. This block is known to be in the current
547/// loop and not in any subloops.
548bool LoopIdiomRecognize::runOnLoopBlock(
549 BasicBlock *BB, const SCEV *BECount,
550 SmallVectorImpl<BasicBlock *> &ExitBlocks) {
551 // We can only promote stores in this block if they are unconditionally
552 // executed in the loop. For a block to be unconditionally executed, it has
553 // to dominate all the exit blocks of the loop. Verify this now.
554 for (BasicBlock *ExitBlock : ExitBlocks)
555 if (!DT->dominates(A: BB, B: ExitBlock))
556 return false;
557
558 bool MadeChange = false;
559 // Look for store instructions, which may be optimized to memset/memcpy.
560 collectStores(BB);
561
  // Look for a single store or sets of stores with a common base, which can be
  // optimized into a memset (memset_pattern). The latter most commonly happens
  // with structs and hand-unrolled loops.
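  // For instance (a schematic example, not the exact matched IR), a
  // hand-unrolled loop such as
  //   for (i) { p[i].a = 0; p[i].b = 0; }
  // yields two store chains with a common base that can merge into one memset.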
565 for (auto &SL : StoreRefsForMemset)
566 MadeChange |= processLoopStores(SL&: SL.second, BECount, For: ForMemset::Yes);
567
568 for (auto &SL : StoreRefsForMemsetPattern)
569 MadeChange |= processLoopStores(SL&: SL.second, BECount, For: ForMemset::No);
570
  // Optimize the store into a memcpy, if it feeds a similarly strided load.
572 for (auto &SI : StoreRefsForMemcpy)
573 MadeChange |= processLoopStoreOfLoopLoad(SI, BECount);
574
575 MadeChange |= processLoopMemIntrinsic<MemCpyInst>(
576 BB, Processor: &LoopIdiomRecognize::processLoopMemCpy, BECount);
577 MadeChange |= processLoopMemIntrinsic<MemSetInst>(
578 BB, Processor: &LoopIdiomRecognize::processLoopMemSet, BECount);
579
580 return MadeChange;
581}
582
/// See if this store, or set of stores, can be promoted to a memset.
584bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL,
585 const SCEV *BECount, ForMemset For) {
586 // Try to find consecutive stores that can be transformed into memsets.
587 SetVector<StoreInst *> Heads, Tails;
588 SmallDenseMap<StoreInst *, StoreInst *> ConsecutiveChain;
589
590 // Do a quadratic search on all of the given stores and find
591 // all of the pairs of stores that follow each other.
592 SmallVector<unsigned, 16> IndexQueue;
593 for (unsigned i = 0, e = SL.size(); i < e; ++i) {
594 assert(SL[i]->isSimple() && "Expected only non-volatile stores.");
595
596 Value *FirstStoredVal = SL[i]->getValueOperand();
597 Value *FirstStorePtr = SL[i]->getPointerOperand();
598 const SCEVAddRecExpr *FirstStoreEv =
599 cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: FirstStorePtr));
600 APInt FirstStride = getStoreStride(StoreEv: FirstStoreEv);
601 unsigned FirstStoreSize = DL->getTypeStoreSize(Ty: SL[i]->getValueOperand()->getType());
602
603 // See if we can optimize just this store in isolation.
604 if (FirstStride == FirstStoreSize || -FirstStride == FirstStoreSize) {
605 Heads.insert(X: SL[i]);
606 continue;
607 }
608
609 Value *FirstSplatValue = nullptr;
610 Constant *FirstPatternValue = nullptr;
611
612 if (For == ForMemset::Yes)
613 FirstSplatValue = isBytewiseValue(V: FirstStoredVal, DL: *DL);
614 else
615 FirstPatternValue = getMemSetPatternValue(V: FirstStoredVal, DL);
616
617 assert((FirstSplatValue || FirstPatternValue) &&
618 "Expected either splat value or pattern value.");
619
    IndexQueue.clear();
    // If a store has multiple consecutive store candidates, search Stores
    // array according to the sequence: from i+1 to e, then from i-1 to 0.
    // This is because pairing with the immediately succeeding or preceding
    // candidate usually creates the best chance to find a memset opportunity.
625 unsigned j = 0;
626 for (j = i + 1; j < e; ++j)
627 IndexQueue.push_back(Elt: j);
628 for (j = i; j > 0; --j)
629 IndexQueue.push_back(Elt: j - 1);
630
631 for (auto &k : IndexQueue) {
632 assert(SL[k]->isSimple() && "Expected only non-volatile stores.");
633 Value *SecondStorePtr = SL[k]->getPointerOperand();
634 const SCEVAddRecExpr *SecondStoreEv =
635 cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: SecondStorePtr));
636 APInt SecondStride = getStoreStride(StoreEv: SecondStoreEv);
637
638 if (FirstStride != SecondStride)
639 continue;
640
641 Value *SecondStoredVal = SL[k]->getValueOperand();
642 Value *SecondSplatValue = nullptr;
643 Constant *SecondPatternValue = nullptr;
644
645 if (For == ForMemset::Yes)
646 SecondSplatValue = isBytewiseValue(V: SecondStoredVal, DL: *DL);
647 else
648 SecondPatternValue = getMemSetPatternValue(V: SecondStoredVal, DL);
649
650 assert((SecondSplatValue || SecondPatternValue) &&
651 "Expected either splat value or pattern value.");
652
653 if (isConsecutiveAccess(A: SL[i], B: SL[k], DL: *DL, SE&: *SE, CheckType: false)) {
654 if (For == ForMemset::Yes) {
655 if (isa<UndefValue>(Val: FirstSplatValue))
656 FirstSplatValue = SecondSplatValue;
657 if (FirstSplatValue != SecondSplatValue)
658 continue;
659 } else {
660 if (isa<UndefValue>(Val: FirstPatternValue))
661 FirstPatternValue = SecondPatternValue;
662 if (FirstPatternValue != SecondPatternValue)
663 continue;
664 }
665 Tails.insert(X: SL[k]);
666 Heads.insert(X: SL[i]);
667 ConsecutiveChain[SL[i]] = SL[k];
668 break;
669 }
670 }
671 }
672
673 // We may run into multiple chains that merge into a single chain. We mark the
674 // stores that we transformed so that we don't visit the same store twice.
675 SmallPtrSet<Value *, 16> TransformedStores;
676 bool Changed = false;
677
678 // For stores that start but don't end a link in the chain:
679 for (StoreInst *I : Heads) {
680 if (Tails.count(key: I))
681 continue;
682
683 // We found a store instr that starts a chain. Now follow the chain and try
684 // to transform it.
685 SmallPtrSet<Instruction *, 8> AdjacentStores;
686 StoreInst *HeadStore = I;
687 unsigned StoreSize = 0;
688
689 // Collect the chain into a list.
690 while (Tails.count(key: I) || Heads.count(key: I)) {
691 if (TransformedStores.count(Ptr: I))
692 break;
693 AdjacentStores.insert(Ptr: I);
694
695 StoreSize += DL->getTypeStoreSize(Ty: I->getValueOperand()->getType());
696 // Move to the next value in the chain.
697 I = ConsecutiveChain[I];
698 }
699
700 Value *StoredVal = HeadStore->getValueOperand();
701 Value *StorePtr = HeadStore->getPointerOperand();
702 const SCEVAddRecExpr *StoreEv = cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: StorePtr));
703 APInt Stride = getStoreStride(StoreEv);
704
705 // Check to see if the stride matches the size of the stores. If so, then
706 // we know that every byte is touched in the loop.
707 if (StoreSize != Stride && StoreSize != -Stride)
708 continue;
709
710 bool IsNegStride = StoreSize == -Stride;
711
712 Type *IntIdxTy = DL->getIndexType(PtrTy: StorePtr->getType());
713 const SCEV *StoreSizeSCEV = SE->getConstant(Ty: IntIdxTy, V: StoreSize);
714 if (processLoopStridedStore(DestPtr: StorePtr, StoreSizeSCEV,
715 StoreAlignment: MaybeAlign(HeadStore->getAlign()), StoredVal,
716 TheStore: HeadStore, Stores&: AdjacentStores, Ev: StoreEv, BECount,
717 IsNegStride)) {
718 TransformedStores.insert(I: AdjacentStores.begin(), E: AdjacentStores.end());
719 Changed = true;
720 }
721 }
722
723 return Changed;
724}
725
726/// processLoopMemIntrinsic - Template function for calling different processor
727/// functions based on mem intrinsic type.
728template <typename MemInst>
729bool LoopIdiomRecognize::processLoopMemIntrinsic(
730 BasicBlock *BB,
731 bool (LoopIdiomRecognize::*Processor)(MemInst *, const SCEV *),
732 const SCEV *BECount) {
733 bool MadeChange = false;
734 for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) {
735 Instruction *Inst = &*I++;
736 // Look for memory instructions, which may be optimized to a larger one.
737 if (MemInst *MI = dyn_cast<MemInst>(Inst)) {
738 WeakTrackingVH InstPtr(&*I);
739 if (!(this->*Processor)(MI, BECount))
740 continue;
741 MadeChange = true;
742
743 // If processing the instruction invalidated our iterator, start over from
744 // the top of the block.
745 if (!InstPtr)
746 I = BB->begin();
747 }
748 }
749 return MadeChange;
750}
751
752/// processLoopMemCpy - See if this memcpy can be promoted to a large memcpy
753bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
754 const SCEV *BECount) {
755 // We can only handle non-volatile memcpys with a constant size.
756 if (MCI->isVolatile() || !isa<ConstantInt>(Val: MCI->getLength()))
757 return false;
758
759 // If we're not allowed to hack on memcpy, we fail.
760 if ((!HasMemcpy && !isa<MemCpyInlineInst>(Val: MCI)) || DisableLIRP::Memcpy)
761 return false;
762
763 Value *Dest = MCI->getDest();
764 Value *Source = MCI->getSource();
765 if (!Dest || !Source)
766 return false;
767
768 // See if the load and store pointer expressions are AddRec like {base,+,1} on
769 // the current loop, which indicates a strided load and store. If we have
770 // something else, it's a random load or store we can't handle.
771 const SCEVAddRecExpr *StoreEv = dyn_cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: Dest));
772 if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
773 return false;
774 const SCEVAddRecExpr *LoadEv = dyn_cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: Source));
775 if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
776 return false;
777
778 // Reject memcpys that are so large that they overflow an unsigned.
779 uint64_t SizeInBytes = cast<ConstantInt>(Val: MCI->getLength())->getZExtValue();
780 if ((SizeInBytes >> 32) != 0)
781 return false;
782
783 // Check if the stride matches the size of the memcpy. If so, then we know
784 // that every byte is touched in the loop.
785 const SCEVConstant *ConstStoreStride =
786 dyn_cast<SCEVConstant>(Val: StoreEv->getOperand(i: 1));
787 const SCEVConstant *ConstLoadStride =
788 dyn_cast<SCEVConstant>(Val: LoadEv->getOperand(i: 1));
789 if (!ConstStoreStride || !ConstLoadStride)
790 return false;
791
792 APInt StoreStrideValue = ConstStoreStride->getAPInt();
793 APInt LoadStrideValue = ConstLoadStride->getAPInt();
794 // Huge stride value - give up
795 if (StoreStrideValue.getBitWidth() > 64 || LoadStrideValue.getBitWidth() > 64)
796 return false;
797
798 if (SizeInBytes != StoreStrideValue && SizeInBytes != -StoreStrideValue) {
799 ORE.emit(RemarkBuilder: [&]() {
800 return OptimizationRemarkMissed(DEBUG_TYPE, "SizeStrideUnequal", MCI)
801 << ore::NV("Inst", "memcpy") << " in "
802 << ore::NV("Function", MCI->getFunction())
803 << " function will not be hoisted: "
804 << ore::NV("Reason", "memcpy size is not equal to stride");
805 });
806 return false;
807 }
808
809 int64_t StoreStrideInt = StoreStrideValue.getSExtValue();
810 int64_t LoadStrideInt = LoadStrideValue.getSExtValue();
811 // Check if the load stride matches the store stride.
812 if (StoreStrideInt != LoadStrideInt)
813 return false;
814
815 return processLoopStoreOfLoopLoad(
816 DestPtr: Dest, SourcePtr: Source, StoreSize: SE->getConstant(Ty: Dest->getType(), V: SizeInBytes),
817 StoreAlign: MCI->getDestAlign(), LoadAlign: MCI->getSourceAlign(), TheStore: MCI, TheLoad: MCI, StoreEv, LoadEv,
818 BECount);
819}
820
821/// processLoopMemSet - See if this memset can be promoted to a large memset.
822bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
823 const SCEV *BECount) {
824 // We can only handle non-volatile memsets.
825 if (MSI->isVolatile())
826 return false;
827
828 // If we're not allowed to hack on memset, we fail.
829 if (!HasMemset || DisableLIRP::Memset)
830 return false;
831
832 Value *Pointer = MSI->getDest();
833
834 // See if the pointer expression is an AddRec like {base,+,1} on the current
835 // loop, which indicates a strided store. If we have something else, it's a
836 // random store we can't handle.
837 const SCEVAddRecExpr *Ev = dyn_cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: Pointer));
838 if (!Ev || Ev->getLoop() != CurLoop)
839 return false;
840 if (!Ev->isAffine()) {
841 LLVM_DEBUG(dbgs() << " Pointer is not affine, abort\n");
842 return false;
843 }
844
845 const SCEV *PointerStrideSCEV = Ev->getOperand(i: 1);
846 const SCEV *MemsetSizeSCEV = SE->getSCEV(V: MSI->getLength());
847 if (!PointerStrideSCEV || !MemsetSizeSCEV)
848 return false;
849
850 bool IsNegStride = false;
851 const bool IsConstantSize = isa<ConstantInt>(Val: MSI->getLength());
852
853 if (IsConstantSize) {
854 // Memset size is constant.
855 // Check if the pointer stride matches the memset size. If so, then
856 // we know that every byte is touched in the loop.
857 LLVM_DEBUG(dbgs() << " memset size is constant\n");
858 uint64_t SizeInBytes = cast<ConstantInt>(Val: MSI->getLength())->getZExtValue();
859 const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Val: Ev->getOperand(i: 1));
860 if (!ConstStride)
861 return false;
862
863 APInt Stride = ConstStride->getAPInt();
864 if (SizeInBytes != Stride && SizeInBytes != -Stride)
865 return false;
866
867 IsNegStride = SizeInBytes == -Stride;
868 } else {
    // Memset size is non-constant.
    // Check if the pointer stride matches the memset size.
    // To be conservative, the pass does not promote pointers that aren't in
    // address space zero. Also, the pass only handles memset lengths and
    // strides that are invariant for the top level loop.
874 LLVM_DEBUG(dbgs() << " memset size is non-constant\n");
875 if (Pointer->getType()->getPointerAddressSpace() != 0) {
876 LLVM_DEBUG(dbgs() << " pointer is not in address space zero, "
877 << "abort\n");
878 return false;
879 }
    if (!SE->isLoopInvariant(MemsetSizeSCEV, CurLoop)) {
      LLVM_DEBUG(dbgs() << "  memset size is not loop-invariant, "
                        << "abort\n");
      return false;
    }
885
886 // Compare positive direction PointerStrideSCEV with MemsetSizeSCEV
887 IsNegStride = PointerStrideSCEV->isNonConstantNegative();
888 const SCEV *PositiveStrideSCEV =
889 IsNegStride ? SE->getNegativeSCEV(V: PointerStrideSCEV)
890 : PointerStrideSCEV;
891 LLVM_DEBUG(dbgs() << " MemsetSizeSCEV: " << *MemsetSizeSCEV << "\n"
892 << " PositiveStrideSCEV: " << *PositiveStrideSCEV
893 << "\n");
894
895 if (PositiveStrideSCEV != MemsetSizeSCEV) {
896 // If an expression is covered by the loop guard, compare again and
897 // proceed with optimization if equal.
898 const SCEV *FoldedPositiveStride =
899 SE->applyLoopGuards(Expr: PositiveStrideSCEV, L: CurLoop);
900 const SCEV *FoldedMemsetSize =
901 SE->applyLoopGuards(Expr: MemsetSizeSCEV, L: CurLoop);
902
903 LLVM_DEBUG(dbgs() << " Try to fold SCEV based on loop guard\n"
904 << " FoldedMemsetSize: " << *FoldedMemsetSize << "\n"
905 << " FoldedPositiveStride: " << *FoldedPositiveStride
906 << "\n");
907
      if (FoldedPositiveStride != FoldedMemsetSize) {
        LLVM_DEBUG(dbgs() << "  SCEVs don't match, abort\n");
        return false;
      }
912 }
913 }
914
915 // Verify that the memset value is loop invariant. If not, we can't promote
916 // the memset.
917 Value *SplatValue = MSI->getValue();
918 if (!SplatValue || !CurLoop->isLoopInvariant(V: SplatValue))
919 return false;
920
921 SmallPtrSet<Instruction *, 1> MSIs;
922 MSIs.insert(Ptr: MSI);
923 return processLoopStridedStore(DestPtr: Pointer, StoreSizeSCEV: SE->getSCEV(V: MSI->getLength()),
924 StoreAlignment: MSI->getDestAlign(), StoredVal: SplatValue, TheStore: MSI, Stores&: MSIs, Ev,
925 BECount, IsNegStride, /*IsLoopMemset=*/true);
926}
927
928/// mayLoopAccessLocation - Return true if the specified loop might access the
929/// specified pointer location, which is a loop-strided access. The 'Access'
930/// argument specifies what the verboten forms of access are (read or write).
931static bool
932mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
933 const SCEV *BECount, const SCEV *StoreSizeSCEV,
934 AliasAnalysis &AA,
935 SmallPtrSetImpl<Instruction *> &IgnoredInsts) {
936 // Get the location that may be stored across the loop. Since the access is
937 // strided positively through memory, we say that the modified location starts
938 // at the pointer and has infinite size.
939 LocationSize AccessSize = LocationSize::afterPointer();
940
941 // If the loop iterates a fixed number of times, we can refine the access size
942 // to be exactly the size of the memset, which is (BECount+1)*StoreSize
943 const SCEVConstant *BECst = dyn_cast<SCEVConstant>(Val: BECount);
944 const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(Val: StoreSizeSCEV);
945 if (BECst && ConstSize) {
946 std::optional<uint64_t> BEInt = BECst->getAPInt().tryZExtValue();
947 std::optional<uint64_t> SizeInt = ConstSize->getAPInt().tryZExtValue();
948 // FIXME: Should this check for overflow?
949 if (BEInt && SizeInt)
950 AccessSize = LocationSize::precise(Value: (*BEInt + 1) * *SizeInt);
951 }
952
  // TODO: For this to be really effective, we have to dive into the pointer
  // operand in the store. Store to &A[i] of 100 will always return may-alias
  // with store of &A[100]; we need StoreLoc to be "A" with size of 100, which
  // will then no-alias a store to &A[100].
957 MemoryLocation StoreLoc(Ptr, AccessSize);
958
959 for (BasicBlock *B : L->blocks())
960 for (Instruction &I : *B)
961 if (!IgnoredInsts.contains(Ptr: &I) &&
962 isModOrRefSet(MRI: AA.getModRefInfo(I: &I, OptLoc: StoreLoc) & Access))
963 return true;
964 return false;
965}
966
// If we have a negative stride, Start refers to the end of the memory location
// we're trying to memset. Therefore, we need to recompute the base pointer,
// which is just Start - BECount*Size.
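//
// For example, for a loop "for (i = n - 1; i >= 0; --i) a[i] = 0;" the addrec
// start is &a[n-1] and BECount is n-1, so the recomputed base pointer is
// &a[n-1] - (n-1)*ElemSize == &a[0], which is where the memset must begin.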
970static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
971 Type *IntPtr, const SCEV *StoreSizeSCEV,
972 ScalarEvolution *SE) {
973 const SCEV *Index = SE->getTruncateOrZeroExtend(V: BECount, Ty: IntPtr);
974 if (!StoreSizeSCEV->isOne()) {
975 // index = back edge count * store size
976 Index = SE->getMulExpr(LHS: Index,
977 RHS: SE->getTruncateOrZeroExtend(V: StoreSizeSCEV, Ty: IntPtr),
978 Flags: SCEV::FlagNUW);
979 }
980 // base pointer = start - index * store size
981 return SE->getMinusSCEV(LHS: Start, RHS: Index);
982}
983
/// Compute the number of bytes as a SCEV from the backedge taken count.
///
/// This also maps the SCEV into the provided type and tries to handle the
/// computation in a way that will fold cleanly.
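///
/// In other words, the result is TripCount * StoreSize, where the trip count
/// is the backedge-taken count plus one.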
988static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
989 const SCEV *StoreSizeSCEV, Loop *CurLoop,
990 const DataLayout *DL, ScalarEvolution *SE) {
991 const SCEV *TripCountSCEV =
992 SE->getTripCountFromExitCount(ExitCount: BECount, EvalTy: IntPtr, L: CurLoop);
993 return SE->getMulExpr(LHS: TripCountSCEV,
994 RHS: SE->getTruncateOrZeroExtend(V: StoreSizeSCEV, Ty: IntPtr),
995 Flags: SCEV::FlagNUW);
996}
997
998/// processLoopStridedStore - We see a strided store of some value. If we can
999/// transform this into a memset or memset_pattern in the loop preheader, do so.
1000bool LoopIdiomRecognize::processLoopStridedStore(
1001 Value *DestPtr, const SCEV *StoreSizeSCEV, MaybeAlign StoreAlignment,
1002 Value *StoredVal, Instruction *TheStore,
1003 SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev,
1004 const SCEV *BECount, bool IsNegStride, bool IsLoopMemset) {
1005 Module *M = TheStore->getModule();
1006 Value *SplatValue = isBytewiseValue(V: StoredVal, DL: *DL);
1007 Constant *PatternValue = nullptr;
1008
1009 if (!SplatValue)
1010 PatternValue = getMemSetPatternValue(V: StoredVal, DL);
1011
1012 assert((SplatValue || PatternValue) &&
1013 "Expected either splat value or pattern value.");
1014
  // The trip count of the loop and the base pointer of the addrec SCEV are
  // guaranteed to be loop invariant, which means that they should dominate the
  // header. This allows us to insert code for them in the preheader.
1018 unsigned DestAS = DestPtr->getType()->getPointerAddressSpace();
1019 BasicBlock *Preheader = CurLoop->getLoopPreheader();
1020 IRBuilder<> Builder(Preheader->getTerminator());
1021 SCEVExpander Expander(*SE, *DL, "loop-idiom");
1022 SCEVExpanderCleaner ExpCleaner(Expander);
1023
1024 Type *DestInt8PtrTy = Builder.getPtrTy(AddrSpace: DestAS);
1025 Type *IntIdxTy = DL->getIndexType(PtrTy: DestPtr->getType());
1026
1027 bool Changed = false;
1028 const SCEV *Start = Ev->getStart();
1029 // Handle negative strided loops.
1030 if (IsNegStride)
1031 Start = getStartForNegStride(Start, BECount, IntPtr: IntIdxTy, StoreSizeSCEV, SE);
1032
1033 // TODO: ideally we should still be able to generate memset if SCEV expander
1034 // is taught to generate the dependencies at the latest point.
1035 if (!Expander.isSafeToExpand(S: Start))
1036 return Changed;
1037
1038 // Okay, we have a strided store "p[i]" of a splattable value. We can turn
1039 // this into a memset in the loop preheader now if we want. However, this
1040 // would be unsafe to do if there is anything else in the loop that may read
1041 // or write to the aliased location. Check for any overlap by generating the
1042 // base pointer and checking the region.
1043 Value *BasePtr =
1044 Expander.expandCodeFor(SH: Start, Ty: DestInt8PtrTy, I: Preheader->getTerminator());
1045
1046 // From here on out, conservatively report to the pass manager that we've
1047 // changed the IR, even if we later clean up these added instructions. There
1048 // may be structural differences e.g. in the order of use lists not accounted
1049 // for in just a textual dump of the IR. This is written as a variable, even
1050 // though statically all the places this dominates could be replaced with
1051 // 'true', with the hope that anyone trying to be clever / "more precise" with
1052 // the return value will read this comment, and leave them alone.
1053 Changed = true;
1054
1055 if (mayLoopAccessLocation(Ptr: BasePtr, Access: ModRefInfo::ModRef, L: CurLoop, BECount,
1056 StoreSizeSCEV, AA&: *AA, IgnoredInsts&: Stores))
1057 return Changed;
1058
1059 if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset))
1060 return Changed;
1061
1062 // Okay, everything looks good, insert the memset.
1063
1064 const SCEV *NumBytesS =
1065 getNumBytes(BECount, IntPtr: IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
1066
1067 // TODO: ideally we should still be able to generate memset if SCEV expander
1068 // is taught to generate the dependencies at the latest point.
1069 if (!Expander.isSafeToExpand(S: NumBytesS))
1070 return Changed;
1071
1072 Value *NumBytes =
1073 Expander.expandCodeFor(SH: NumBytesS, Ty: IntIdxTy, I: Preheader->getTerminator());
1074
1075 if (!SplatValue && !isLibFuncEmittable(M, TLI, TheLibFunc: LibFunc_memset_pattern16))
1076 return Changed;
1077
1078 AAMDNodes AATags = TheStore->getAAMetadata();
1079 for (Instruction *Store : Stores)
1080 AATags = AATags.merge(Other: Store->getAAMetadata());
1081 if (auto CI = dyn_cast<ConstantInt>(Val: NumBytes))
1082 AATags = AATags.extendTo(Len: CI->getZExtValue());
1083 else
1084 AATags = AATags.extendTo(Len: -1);
1085
1086 CallInst *NewCall;
1087 if (SplatValue) {
1088 NewCall = Builder.CreateMemSet(
1089 Ptr: BasePtr, Val: SplatValue, Size: NumBytes, Align: MaybeAlign(StoreAlignment),
1090 /*isVolatile=*/false, TBAATag: AATags.TBAA, ScopeTag: AATags.Scope, NoAliasTag: AATags.NoAlias);
1091 } else {
1092 assert (isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16));
1093 // Everything is emitted in default address space
1094 Type *Int8PtrTy = DestInt8PtrTy;
1095
1096 StringRef FuncName = "memset_pattern16";
1097 FunctionCallee MSP = getOrInsertLibFunc(M, TLI: *TLI, TheLibFunc: LibFunc_memset_pattern16,
1098 RetTy: Builder.getVoidTy(), Args: Int8PtrTy, Args: Int8PtrTy, Args: IntIdxTy);
1099 inferNonMandatoryLibFuncAttrs(M, Name: FuncName, TLI: *TLI);
1100
    // Otherwise we should form a memset_pattern16. PatternValue is known to be
    // a constant array of 16 bytes. Plop the value into a mergable global.
1103 GlobalVariable *GV = new GlobalVariable(*M, PatternValue->getType(), true,
1104 GlobalValue::PrivateLinkage,
1105 PatternValue, ".memset_pattern");
1106 GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); // Ok to merge these.
1107 GV->setAlignment(Align(16));
1108 Value *PatternPtr = GV;
1109 NewCall = Builder.CreateCall(Callee: MSP, Args: {BasePtr, PatternPtr, NumBytes});
1110
1111 // Set the TBAA info if present.
1112 if (AATags.TBAA)
1113 NewCall->setMetadata(KindID: LLVMContext::MD_tbaa, Node: AATags.TBAA);
1114
1115 if (AATags.Scope)
1116 NewCall->setMetadata(KindID: LLVMContext::MD_alias_scope, Node: AATags.Scope);
1117
1118 if (AATags.NoAlias)
1119 NewCall->setMetadata(KindID: LLVMContext::MD_noalias, Node: AATags.NoAlias);
1120 }
1121
1122 NewCall->setDebugLoc(TheStore->getDebugLoc());
1123
1124 if (MSSAU) {
1125 MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB(
1126 I: NewCall, Definition: nullptr, BB: NewCall->getParent(), Point: MemorySSA::BeforeTerminator);
1127 MSSAU->insertDef(Def: cast<MemoryDef>(Val: NewMemAcc), RenameUses: true);
1128 }
1129
1130 LLVM_DEBUG(dbgs() << " Formed memset: " << *NewCall << "\n"
1131 << " from store to: " << *Ev << " at: " << *TheStore
1132 << "\n");
1133
1134 ORE.emit(RemarkBuilder: [&]() {
1135 OptimizationRemark R(DEBUG_TYPE, "ProcessLoopStridedStore",
1136 NewCall->getDebugLoc(), Preheader);
1137 R << "Transformed loop-strided store in "
1138 << ore::NV("Function", TheStore->getFunction())
1139 << " function into a call to "
1140 << ore::NV("NewFunction", NewCall->getCalledFunction())
1141 << "() intrinsic";
1142 if (!Stores.empty())
1143 R << ore::setExtraArgs();
1144 for (auto *I : Stores) {
1145 R << ore::NV("FromBlock", I->getParent()->getName())
1146 << ore::NV("ToBlock", Preheader->getName());
1147 }
1148 return R;
1149 });
1150
1151 // Okay, the memset has been formed. Zap the original store and anything that
1152 // feeds into it.
1153 for (auto *I : Stores) {
1154 if (MSSAU)
1155 MSSAU->removeMemoryAccess(I, OptimizePhis: true);
1156 deleteDeadInstruction(I);
1157 }
1158 if (MSSAU && VerifyMemorySSA)
1159 MSSAU->getMemorySSA()->verifyMemorySSA();
1160 ++NumMemSet;
1161 ExpCleaner.markResultUsed();
1162 return true;
1163}
1164
/// If the stored value is a strided load in the same loop with the same stride
/// this may be transformable into a memcpy. This kicks in for stuff like
///   for (i) A[i] = B[i];
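///
/// When the safety checks below pass, a single call of the form
///   memcpy(&A[0], &B[0], TripCount * StoreSize)
/// is emitted in the loop preheader (a memmove when source and destination may
/// overlap) and the original store is deleted.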
1168bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,
1169 const SCEV *BECount) {
1170 assert(SI->isUnordered() && "Expected only non-volatile non-ordered stores.");
1171
1172 Value *StorePtr = SI->getPointerOperand();
1173 const SCEVAddRecExpr *StoreEv = cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: StorePtr));
1174 unsigned StoreSize = DL->getTypeStoreSize(Ty: SI->getValueOperand()->getType());
1175
1176 // The store must be feeding a non-volatile load.
1177 LoadInst *LI = cast<LoadInst>(Val: SI->getValueOperand());
1178 assert(LI->isUnordered() && "Expected only non-volatile non-ordered loads.");
1179
1180 // See if the pointer expression is an AddRec like {base,+,1} on the current
1181 // loop, which indicates a strided load. If we have something else, it's a
1182 // random load we can't handle.
1183 Value *LoadPtr = LI->getPointerOperand();
1184 const SCEVAddRecExpr *LoadEv = cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: LoadPtr));
1185
1186 const SCEV *StoreSizeSCEV = SE->getConstant(Ty: StorePtr->getType(), V: StoreSize);
1187 return processLoopStoreOfLoopLoad(DestPtr: StorePtr, SourcePtr: LoadPtr, StoreSize: StoreSizeSCEV,
1188 StoreAlign: SI->getAlign(), LoadAlign: LI->getAlign(), TheStore: SI, TheLoad: LI,
1189 StoreEv, LoadEv, BECount);
1190}
1191
1192namespace {
1193class MemmoveVerifier {
1194public:
1195 explicit MemmoveVerifier(const Value &LoadBasePtr, const Value &StoreBasePtr,
1196 const DataLayout &DL)
1197 : DL(DL), BP1(llvm::GetPointerBaseWithConstantOffset(
1198 Ptr: LoadBasePtr.stripPointerCasts(), Offset&: LoadOff, DL)),
1199 BP2(llvm::GetPointerBaseWithConstantOffset(
1200 Ptr: StoreBasePtr.stripPointerCasts(), Offset&: StoreOff, DL)),
1201 IsSameObject(BP1 == BP2) {}
1202
1203 bool loadAndStoreMayFormMemmove(unsigned StoreSize, bool IsNegStride,
1204 const Instruction &TheLoad,
1205 bool IsMemCpy) const {
1206 if (IsMemCpy) {
1207 // Ensure that LoadBasePtr is after StoreBasePtr or before StoreBasePtr
1208 // for negative stride.
1209 if ((!IsNegStride && LoadOff <= StoreOff) ||
1210 (IsNegStride && LoadOff >= StoreOff))
1211 return false;
1212 } else {
1213 // Ensure that LoadBasePtr is after StoreBasePtr or before StoreBasePtr
1214 // for negative stride. LoadBasePtr shouldn't overlap with StoreBasePtr.
1215 int64_t LoadSize =
1216 DL.getTypeSizeInBits(Ty: TheLoad.getType()).getFixedValue() / 8;
1217 if (BP1 != BP2 || LoadSize != int64_t(StoreSize))
1218 return false;
1219 if ((!IsNegStride && LoadOff < StoreOff + int64_t(StoreSize)) ||
1220 (IsNegStride && LoadOff + LoadSize > StoreOff))
1221 return false;
1222 }
1223 return true;
1224 }
1225
1226private:
1227 const DataLayout &DL;
1228 int64_t LoadOff = 0;
1229 int64_t StoreOff = 0;
1230 const Value *BP1;
1231 const Value *BP2;
1232
1233public:
1234 const bool IsSameObject;
1235};
1236} // namespace
1237
1238bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
1239 Value *DestPtr, Value *SourcePtr, const SCEV *StoreSizeSCEV,
1240 MaybeAlign StoreAlign, MaybeAlign LoadAlign, Instruction *TheStore,
1241 Instruction *TheLoad, const SCEVAddRecExpr *StoreEv,
1242 const SCEVAddRecExpr *LoadEv, const SCEV *BECount) {
1243
1244 // FIXME: until llvm.memcpy.inline supports dynamic sizes, we need to
1245 // conservatively bail here, since otherwise we may have to transform
1246 // llvm.memcpy.inline into llvm.memcpy which is illegal.
1247 if (isa<MemCpyInlineInst>(Val: TheStore))
1248 return false;
1249
  // The trip count of the loop and the base pointer of the addrec SCEV are
  // guaranteed to be loop invariant, which means that they should dominate the
  // header. This allows us to insert code for them in the preheader.
1253 BasicBlock *Preheader = CurLoop->getLoopPreheader();
1254 IRBuilder<> Builder(Preheader->getTerminator());
1255 SCEVExpander Expander(*SE, *DL, "loop-idiom");
1256
1257 SCEVExpanderCleaner ExpCleaner(Expander);
1258
1259 bool Changed = false;
1260 const SCEV *StrStart = StoreEv->getStart();
1261 unsigned StrAS = DestPtr->getType()->getPointerAddressSpace();
1262 Type *IntIdxTy = Builder.getIntNTy(N: DL->getIndexSizeInBits(AS: StrAS));
1263
1264 APInt Stride = getStoreStride(StoreEv);
1265 const SCEVConstant *ConstStoreSize = dyn_cast<SCEVConstant>(Val: StoreSizeSCEV);
1266
  // TODO: Deal with non-constant sizes; currently we expect a constant store
  // size.
  assert(ConstStoreSize && "store size is expected to be a constant");
1269
1270 int64_t StoreSize = ConstStoreSize->getValue()->getZExtValue();
1271 bool IsNegStride = StoreSize == -Stride;
1272
1273 // Handle negative strided loops.
1274 if (IsNegStride)
1275 StrStart =
1276 getStartForNegStride(Start: StrStart, BECount, IntPtr: IntIdxTy, StoreSizeSCEV, SE);
1277
1278 // Okay, we have a strided store "p[i]" of a loaded value. We can turn
1279 // this into a memcpy in the loop preheader now if we want. However, this
1280 // would be unsafe to do if there is anything else in the loop that may read
1281 // or write the memory region we're storing to. This includes the load that
1282 // feeds the stores. Check for an alias by generating the base address and
1283 // checking everything.
1284 Value *StoreBasePtr = Expander.expandCodeFor(
1285 SH: StrStart, Ty: Builder.getPtrTy(AddrSpace: StrAS), I: Preheader->getTerminator());
1286
1287 // From here on out, conservatively report to the pass manager that we've
1288 // changed the IR, even if we later clean up these added instructions. There
1289 // may be structural differences e.g. in the order of use lists not accounted
1290 // for in just a textual dump of the IR. This is written as a variable, even
1291 // though statically all the places this dominates could be replaced with
1292 // 'true', with the hope that anyone trying to be clever / "more precise" with
1293 // the return value will read this comment, and leave them alone.
1294 Changed = true;
1295
1296 SmallPtrSet<Instruction *, 2> IgnoredInsts;
1297 IgnoredInsts.insert(Ptr: TheStore);
1298
1299 bool IsMemCpy = isa<MemCpyInst>(Val: TheStore);
1300 const StringRef InstRemark = IsMemCpy ? "memcpy" : "load and store";
1301
1302 bool LoopAccessStore =
1303 mayLoopAccessLocation(Ptr: StoreBasePtr, Access: ModRefInfo::ModRef, L: CurLoop, BECount,
1304 StoreSizeSCEV, AA&: *AA, IgnoredInsts);
1305 if (LoopAccessStore) {
    // For the memmove case it's not enough to guarantee that the loop doesn't
    // access TheStore and TheLoad. Additionally we need to make sure that
    // TheStore is the only user of TheLoad.
1309 if (!TheLoad->hasOneUse())
1310 return Changed;
1311 IgnoredInsts.insert(Ptr: TheLoad);
1312 if (mayLoopAccessLocation(Ptr: StoreBasePtr, Access: ModRefInfo::ModRef, L: CurLoop,
1313 BECount, StoreSizeSCEV, AA&: *AA, IgnoredInsts)) {
1314 ORE.emit(RemarkBuilder: [&]() {
1315 return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessStore",
1316 TheStore)
1317 << ore::NV("Inst", InstRemark) << " in "
1318 << ore::NV("Function", TheStore->getFunction())
1319 << " function will not be hoisted: "
1320 << ore::NV("Reason", "The loop may access store location");
1321 });
1322 return Changed;
1323 }
1324 IgnoredInsts.erase(Ptr: TheLoad);
1325 }
1326
1327 const SCEV *LdStart = LoadEv->getStart();
1328 unsigned LdAS = SourcePtr->getType()->getPointerAddressSpace();
1329
1330 // Handle negative strided loops.
1331 if (IsNegStride)
1332 LdStart =
1333 getStartForNegStride(Start: LdStart, BECount, IntPtr: IntIdxTy, StoreSizeSCEV, SE);
1334
1335 // For a memcpy, we have to make sure that the input array is not being
1336 // mutated by the loop.
1337 Value *LoadBasePtr = Expander.expandCodeFor(SH: LdStart, Ty: Builder.getPtrTy(AddrSpace: LdAS),
1338 I: Preheader->getTerminator());
1339
1340 // If the store is a memcpy instruction, we must check if it will write to
1341 // the load memory locations. So remove it from the ignored stores.
1342 MemmoveVerifier Verifier(*LoadBasePtr, *StoreBasePtr, *DL);
1343 if (IsMemCpy && !Verifier.IsSameObject)
1344 IgnoredInsts.erase(Ptr: TheStore);
1345 if (mayLoopAccessLocation(Ptr: LoadBasePtr, Access: ModRefInfo::Mod, L: CurLoop, BECount,
1346 StoreSizeSCEV, AA&: *AA, IgnoredInsts)) {
1347 ORE.emit(RemarkBuilder: [&]() {
1348 return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessLoad", TheLoad)
1349 << ore::NV("Inst", InstRemark) << " in "
1350 << ore::NV("Function", TheStore->getFunction())
1351 << " function will not be hoisted: "
1352 << ore::NV("Reason", "The loop may access load location");
1353 });
1354 return Changed;
1355 }
1356
1357 bool UseMemMove = IsMemCpy ? Verifier.IsSameObject : LoopAccessStore;
1358 if (UseMemMove)
1359 if (!Verifier.loadAndStoreMayFormMemmove(StoreSize, IsNegStride, TheLoad: *TheLoad,
1360 IsMemCpy))
1361 return Changed;
1362
1363 if (avoidLIRForMultiBlockLoop())
1364 return Changed;
1365
1366 // Okay, everything is safe, we can transform this!
1367
1368 const SCEV *NumBytesS =
1369 getNumBytes(BECount, IntPtr: IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
1370
1371 Value *NumBytes =
1372 Expander.expandCodeFor(SH: NumBytesS, Ty: IntIdxTy, I: Preheader->getTerminator());
1373
1374 AAMDNodes AATags = TheLoad->getAAMetadata();
1375 AAMDNodes StoreAATags = TheStore->getAAMetadata();
1376 AATags = AATags.merge(Other: StoreAATags);
1377 if (auto CI = dyn_cast<ConstantInt>(Val: NumBytes))
1378 AATags = AATags.extendTo(Len: CI->getZExtValue());
1379 else
1380 AATags = AATags.extendTo(Len: -1);
1381
1382 CallInst *NewCall = nullptr;
  // Check whether to generate an unordered atomic memcpy:
  //  If either the load or the store is atomic, then it must necessarily be
  //  unordered by previous checks.
1386 if (!TheStore->isAtomic() && !TheLoad->isAtomic()) {
1387 if (UseMemMove)
1388 NewCall = Builder.CreateMemMove(
1389 Dst: StoreBasePtr, DstAlign: StoreAlign, Src: LoadBasePtr, SrcAlign: LoadAlign, Size: NumBytes,
1390 /*isVolatile=*/false, TBAATag: AATags.TBAA, ScopeTag: AATags.Scope, NoAliasTag: AATags.NoAlias);
1391 else
1392 NewCall =
1393 Builder.CreateMemCpy(Dst: StoreBasePtr, DstAlign: StoreAlign, Src: LoadBasePtr, SrcAlign: LoadAlign,
1394 Size: NumBytes, /*isVolatile=*/false, TBAATag: AATags.TBAA,
1395 TBAAStructTag: AATags.TBAAStruct, ScopeTag: AATags.Scope, NoAliasTag: AATags.NoAlias);
1396 } else {
1397 // For now don't support unordered atomic memmove.
1398 if (UseMemMove)
1399 return Changed;
1400 // We cannot allow unaligned ops for unordered load/store, so reject
1401 // anything where the alignment isn't at least the element size.
1402 assert((StoreAlign && LoadAlign) &&
1403 "Expect unordered load/store to have align.");
1404 if (*StoreAlign < StoreSize || *LoadAlign < StoreSize)
1405 return Changed;
1406
1407 // If the element.atomic memcpy is not lowered into explicit
1408 // loads/stores later, then it will be lowered into an element-size
1409 // specific lib call. If the lib call doesn't exist for our store size, then
1410 // we shouldn't generate the memcpy.
1411 if (StoreSize > TTI->getAtomicMemIntrinsicMaxElementSize())
1412 return Changed;
1413
1414 // Create the call.
1415 // Note that unordered atomic loads/stores are *required* by the spec to
1416 // have an alignment but non-atomic loads/stores may not.
1417 NewCall = Builder.CreateElementUnorderedAtomicMemCpy(
1418 Dst: StoreBasePtr, DstAlign: *StoreAlign, Src: LoadBasePtr, SrcAlign: *LoadAlign, Size: NumBytes, ElementSize: StoreSize,
1419 TBAATag: AATags.TBAA, TBAAStructTag: AATags.TBAAStruct, ScopeTag: AATags.Scope, NoAliasTag: AATags.NoAlias);
1420 }
1421 NewCall->setDebugLoc(TheStore->getDebugLoc());
1422
1423 if (MSSAU) {
1424 MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB(
1425 I: NewCall, Definition: nullptr, BB: NewCall->getParent(), Point: MemorySSA::BeforeTerminator);
1426 MSSAU->insertDef(Def: cast<MemoryDef>(Val: NewMemAcc), RenameUses: true);
1427 }
1428
1429 LLVM_DEBUG(dbgs() << " Formed new call: " << *NewCall << "\n"
1430 << " from load ptr=" << *LoadEv << " at: " << *TheLoad
1431 << "\n"
1432 << " from store ptr=" << *StoreEv << " at: " << *TheStore
1433 << "\n");
1434
1435 ORE.emit(RemarkBuilder: [&]() {
1436 return OptimizationRemark(DEBUG_TYPE, "ProcessLoopStoreOfLoopLoad",
1437 NewCall->getDebugLoc(), Preheader)
1438 << "Formed a call to "
1439 << ore::NV("NewFunction", NewCall->getCalledFunction())
1440 << "() intrinsic from " << ore::NV("Inst", InstRemark)
1441 << " instruction in " << ore::NV("Function", TheStore->getFunction())
1442 << " function"
1443 << ore::setExtraArgs()
1444 << ore::NV("FromBlock", TheStore->getParent()->getName())
1445 << ore::NV("ToBlock", Preheader->getName());
1446 });
1447
1448 // Okay, a new call to memcpy/memmove has been formed. Zap the original store
1449 // and anything that feeds into it.
1450 if (MSSAU)
1451 MSSAU->removeMemoryAccess(I: TheStore, OptimizePhis: true);
1452 deleteDeadInstruction(I: TheStore);
1453 if (MSSAU && VerifyMemorySSA)
1454 MSSAU->getMemorySSA()->verifyMemorySSA();
1455 if (UseMemMove)
1456 ++NumMemMove;
1457 else
1458 ++NumMemCpy;
1459 ExpCleaner.markResultUsed();
1460 return true;
1461}
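
// An illustrative source-level example (not part of this file; the names
// below are hypothetical): a unit-stride load/store loop of the kind the
// routine above rewrites into a single memcpy call, or a memmove when the
// verifier proves the two pointers refer to the same underlying object.
//
//   #include <cstddef>
//   void copyInts(int *Dst, const int *Src, std::size_t N) {
//     for (std::size_t I = 0; I != N; ++I)
//       Dst[I] = Src[I]; // store of a load with the same stride and count
//   }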
1462
// When compiling for code size we avoid idiom recognition for a multi-block
// loop unless it is a loop_memset idiom or a memset/memcpy idiom in a nested
// loop.
1465//
1466bool LoopIdiomRecognize::avoidLIRForMultiBlockLoop(bool IsMemset,
1467 bool IsLoopMemset) {
1468 if (ApplyCodeSizeHeuristics && CurLoop->getNumBlocks() > 1) {
1469 if (CurLoop->isOutermost() && (!IsMemset || !IsLoopMemset)) {
1470 LLVM_DEBUG(dbgs() << " " << CurLoop->getHeader()->getParent()->getName()
1471 << " : LIR " << (IsMemset ? "Memset" : "Memcpy")
1472 << " avoided: multi-block top-level loop\n");
1473 return true;
1474 }
1475 }
1476
1477 return false;
1478}
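
// An illustrative sketch (hypothetical names, not from this file): with the
// code-size heuristics enabled, the memset-like store in this top-level,
// multi-block loop is left alone, because the extra work in the body keeps
// the loop from being deleted and a library call would only add code.
//
//   unsigned clearAndCount(char *Buf, const bool *Flag, unsigned N) {
//     unsigned Seen = 0;
//     for (unsigned I = 0; I != N; ++I) {
//       Buf[I] = 0;    // memset-like store
//       if (Flag[I])   // the branch makes this a multi-block loop body
//         ++Seen;      // extra work keeps the loop alive
//     }
//     return Seen;
//   }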
1479
1480bool LoopIdiomRecognize::runOnNoncountableLoop() {
1481 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F["
1482 << CurLoop->getHeader()->getParent()->getName()
1483 << "] Noncountable Loop %"
1484 << CurLoop->getHeader()->getName() << "\n");
1485
1486 return recognizePopcount() || recognizeAndInsertFFS() ||
1487 recognizeShiftUntilBitTest() || recognizeShiftUntilZero();
1488}
1489
/// Check if the given conditional branch is based on a comparison between
/// a variable and zero, and if the variable is non-zero (or zero, when
/// JmpOnZero is true) the control yields to the loop entry. If the branch
/// matches that behavior, the variable involved in the comparison is
/// returned. This function is used to check whether the precondition and
/// postcondition of the loop are in the desired form.
1496static Value *matchCondition(BranchInst *BI, BasicBlock *LoopEntry,
1497 bool JmpOnZero = false) {
1498 if (!BI || !BI->isConditional())
1499 return nullptr;
1500
1501 ICmpInst *Cond = dyn_cast<ICmpInst>(Val: BI->getCondition());
1502 if (!Cond)
1503 return nullptr;
1504
1505 ConstantInt *CmpZero = dyn_cast<ConstantInt>(Val: Cond->getOperand(i_nocapture: 1));
1506 if (!CmpZero || !CmpZero->isZero())
1507 return nullptr;
1508
1509 BasicBlock *TrueSucc = BI->getSuccessor(i: 0);
1510 BasicBlock *FalseSucc = BI->getSuccessor(i: 1);
1511 if (JmpOnZero)
1512 std::swap(a&: TrueSucc, b&: FalseSucc);
1513
1514 ICmpInst::Predicate Pred = Cond->getPredicate();
1515 if ((Pred == ICmpInst::ICMP_NE && TrueSucc == LoopEntry) ||
1516 (Pred == ICmpInst::ICMP_EQ && FalseSucc == LoopEntry))
1517 return Cond->getOperand(i_nocapture: 0);
1518
1519 return nullptr;
1520}
1521
1522// Check if the recurrence variable `VarX` is in the right form to create
1523// the idiom. Returns the value coerced to a PHINode if so.
1524static PHINode *getRecurrenceVar(Value *VarX, Instruction *DefX,
1525 BasicBlock *LoopEntry) {
1526 auto *PhiX = dyn_cast<PHINode>(Val: VarX);
1527 if (PhiX && PhiX->getParent() == LoopEntry &&
1528 (PhiX->getOperand(i_nocapture: 0) == DefX || PhiX->getOperand(i_nocapture: 1) == DefX))
1529 return PhiX;
1530 return nullptr;
1531}
1532
1533/// Return true iff the idiom is detected in the loop.
1534///
1535/// Additionally:
/// 1) \p CntInst is set to the instruction counting the population (set bits).
1537/// 2) \p CntPhi is set to the corresponding phi node.
1538/// 3) \p Var is set to the value whose population bits are being counted.
1539///
1540/// The core idiom we are trying to detect is:
1541/// \code
/// if (x0 == 0)
///   goto loop-exit // the precondition of the loop
1544/// cnt0 = init-val;
1545/// do {
1546/// x1 = phi (x0, x2);
1547/// cnt1 = phi(cnt0, cnt2);
1548///
1549/// cnt2 = cnt1 + 1;
1550/// ...
1551/// x2 = x1 & (x1 - 1);
1552/// ...
/// } while (x2 != 0);
1554///
1555/// loop-exit:
1556/// \endcode
1557static bool detectPopcountIdiom(Loop *CurLoop, BasicBlock *PreCondBB,
1558 Instruction *&CntInst, PHINode *&CntPhi,
1559 Value *&Var) {
  // step 1: Check to see if the loop-back branch matches this pattern:
  //    "if (a != 0) goto loop-entry".
1562 BasicBlock *LoopEntry;
1563 Instruction *DefX2, *CountInst;
1564 Value *VarX1, *VarX0;
1565 PHINode *PhiX, *CountPhi;
1566
1567 DefX2 = CountInst = nullptr;
1568 VarX1 = VarX0 = nullptr;
1569 PhiX = CountPhi = nullptr;
1570 LoopEntry = *(CurLoop->block_begin());
1571
1572 // step 1: Check if the loop-back branch is in desirable form.
1573 {
1574 if (Value *T = matchCondition(
1575 BI: dyn_cast<BranchInst>(Val: LoopEntry->getTerminator()), LoopEntry))
1576 DefX2 = dyn_cast<Instruction>(Val: T);
1577 else
1578 return false;
1579 }
1580
1581 // step 2: detect instructions corresponding to "x2 = x1 & (x1 - 1)"
1582 {
1583 if (!DefX2 || DefX2->getOpcode() != Instruction::And)
1584 return false;
1585
1586 BinaryOperator *SubOneOp;
1587
1588 if ((SubOneOp = dyn_cast<BinaryOperator>(Val: DefX2->getOperand(i: 0))))
1589 VarX1 = DefX2->getOperand(i: 1);
1590 else {
1591 VarX1 = DefX2->getOperand(i: 0);
1592 SubOneOp = dyn_cast<BinaryOperator>(Val: DefX2->getOperand(i: 1));
1593 }
1594 if (!SubOneOp || SubOneOp->getOperand(i_nocapture: 0) != VarX1)
1595 return false;
1596
1597 ConstantInt *Dec = dyn_cast<ConstantInt>(Val: SubOneOp->getOperand(i_nocapture: 1));
1598 if (!Dec ||
1599 !((SubOneOp->getOpcode() == Instruction::Sub && Dec->isOne()) ||
1600 (SubOneOp->getOpcode() == Instruction::Add &&
1601 Dec->isMinusOne()))) {
1602 return false;
1603 }
1604 }
1605
1606 // step 3: Check the recurrence of variable X
1607 PhiX = getRecurrenceVar(VarX: VarX1, DefX: DefX2, LoopEntry);
1608 if (!PhiX)
1609 return false;
1610
  // step 4: Find the instruction which counts the population: cnt2 = cnt1 + 1
1612 {
1613 CountInst = nullptr;
1614 for (Instruction &Inst : llvm::make_range(
1615 x: LoopEntry->getFirstNonPHI()->getIterator(), y: LoopEntry->end())) {
1616 if (Inst.getOpcode() != Instruction::Add)
1617 continue;
1618
1619 ConstantInt *Inc = dyn_cast<ConstantInt>(Val: Inst.getOperand(i: 1));
1620 if (!Inc || !Inc->isOne())
1621 continue;
1622
1623 PHINode *Phi = getRecurrenceVar(VarX: Inst.getOperand(i: 0), DefX: &Inst, LoopEntry);
1624 if (!Phi)
1625 continue;
1626
      // Check if the result of the instruction is live out of the loop.
1628 bool LiveOutLoop = false;
1629 for (User *U : Inst.users()) {
1630 if ((cast<Instruction>(Val: U))->getParent() != LoopEntry) {
1631 LiveOutLoop = true;
1632 break;
1633 }
1634 }
1635
1636 if (LiveOutLoop) {
1637 CountInst = &Inst;
1638 CountPhi = Phi;
1639 break;
1640 }
1641 }
1642
1643 if (!CountInst)
1644 return false;
1645 }
1646
1647 // step 5: check if the precondition is in this form:
1648 // "if (x != 0) goto loop-head ; else goto somewhere-we-don't-care;"
1649 {
1650 auto *PreCondBr = dyn_cast<BranchInst>(Val: PreCondBB->getTerminator());
1651 Value *T = matchCondition(BI: PreCondBr, LoopEntry: CurLoop->getLoopPreheader());
1652 if (T != PhiX->getOperand(i_nocapture: 0) && T != PhiX->getOperand(i_nocapture: 1))
1653 return false;
1654
1655 CntInst = CountInst;
1656 CntPhi = CountPhi;
1657 Var = T;
1658 }
1659
1660 return true;
1661}
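
// An illustrative source-level form of the idiom detected above (not part of
// this file; names are hypothetical): Kernighan's population-count loop.
//
//   unsigned popcount(unsigned X) {
//     unsigned Cnt = 0;    // cnt0 = init-val
//     while (X != 0) {     // precondition + loop-back test
//       ++Cnt;             // CntInst / CntPhi
//       X &= X - 1;        // x2 = x1 & (x1 - 1): clears the lowest set bit
//     }
//     return Cnt;          // count used outside the loop
//   }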
1662
1663/// Return true if the idiom is detected in the loop.
1664///
1665/// Additionally:
1666/// 1) \p CntInst is set to the instruction Counting Leading Zeros (CTLZ)
1667/// or nullptr if there is no such.
1668/// 2) \p CntPhi is set to the corresponding phi node
1669/// or nullptr if there is no such.
1670/// 3) \p Var is set to the value whose CTLZ could be used.
1671/// 4) \p DefX is set to the instruction calculating Loop exit condition.
1672///
1673/// The core idiom we are trying to detect is:
1674/// \code
1675/// if (x0 == 0)
1676/// goto loop-exit // the precondition of the loop
1677/// cnt0 = init-val;
1678/// do {
1679/// x = phi (x0, x.next); //PhiX
1680/// cnt = phi(cnt0, cnt.next);
1681///
1682/// cnt.next = cnt + 1;
1683/// ...
1684/// x.next = x >> 1; // DefX
1685/// ...
1686/// } while(x.next != 0);
1687///
1688/// loop-exit:
1689/// \endcode
1690static bool detectShiftUntilZeroIdiom(Loop *CurLoop, const DataLayout &DL,
1691 Intrinsic::ID &IntrinID, Value *&InitX,
1692 Instruction *&CntInst, PHINode *&CntPhi,
1693 Instruction *&DefX) {
1694 BasicBlock *LoopEntry;
1695 Value *VarX = nullptr;
1696
1697 DefX = nullptr;
1698 CntInst = nullptr;
1699 CntPhi = nullptr;
1700 LoopEntry = *(CurLoop->block_begin());
1701
1702 // step 1: Check if the loop-back branch is in desirable form.
1703 if (Value *T = matchCondition(
1704 BI: dyn_cast<BranchInst>(Val: LoopEntry->getTerminator()), LoopEntry))
1705 DefX = dyn_cast<Instruction>(Val: T);
1706 else
1707 return false;
1708
1709 // step 2: detect instructions corresponding to "x.next = x >> 1 or x << 1"
1710 if (!DefX || !DefX->isShift())
1711 return false;
1712 IntrinID = DefX->getOpcode() == Instruction::Shl ? Intrinsic::cttz :
1713 Intrinsic::ctlz;
1714 ConstantInt *Shft = dyn_cast<ConstantInt>(Val: DefX->getOperand(i: 1));
1715 if (!Shft || !Shft->isOne())
1716 return false;
1717 VarX = DefX->getOperand(i: 0);
1718
1719 // step 3: Check the recurrence of variable X
1720 PHINode *PhiX = getRecurrenceVar(VarX, DefX, LoopEntry);
1721 if (!PhiX)
1722 return false;
1723
1724 InitX = PhiX->getIncomingValueForBlock(BB: CurLoop->getLoopPreheader());
1725
  // Make sure the initial value can't be negative, otherwise the ashr in the
  // loop might never reach zero, which would make the loop infinite.
1728 if (DefX->getOpcode() == Instruction::AShr && !isKnownNonNegative(V: InitX, SQ: DL))
1729 return false;
1730
  // step 4: Find the instruction which counts the CTLZ: cnt.next = cnt + 1
1732 // or cnt.next = cnt + -1.
1733 // TODO: We can skip the step. If loop trip count is known (CTLZ),
1734 // then all uses of "cnt.next" could be optimized to the trip count
1735 // plus "cnt0". Currently it is not optimized.
1736 // This step could be used to detect POPCNT instruction:
1737 // cnt.next = cnt + (x.next & 1)
1738 for (Instruction &Inst : llvm::make_range(
1739 x: LoopEntry->getFirstNonPHI()->getIterator(), y: LoopEntry->end())) {
1740 if (Inst.getOpcode() != Instruction::Add)
1741 continue;
1742
1743 ConstantInt *Inc = dyn_cast<ConstantInt>(Val: Inst.getOperand(i: 1));
1744 if (!Inc || (!Inc->isOne() && !Inc->isMinusOne()))
1745 continue;
1746
1747 PHINode *Phi = getRecurrenceVar(VarX: Inst.getOperand(i: 0), DefX: &Inst, LoopEntry);
1748 if (!Phi)
1749 continue;
1750
1751 CntInst = &Inst;
1752 CntPhi = Phi;
1753 break;
1754 }
1755 if (!CntInst)
1756 return false;
1757
1758 return true;
1759}
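
// An illustrative source-level loop matching the shift-until-zero pattern
// above (not from this file; names are hypothetical). recognizeAndInsertFFS()
// rewrites its trip count in terms of the corresponding CTLZ/CTTZ intrinsic.
//
//   unsigned significantBits(unsigned X) { // callers guard X != 0 (ZeroCheck)
//     unsigned Cnt = 0;
//     do {
//       ++Cnt;             // cnt.next = cnt + 1  (CntInst / CntPhi)
//       X >>= 1;           // x.next = x >> 1     (DefX)
//     } while (X != 0);    // loop-back test on x.next
//     return Cnt;          // becomes BitWidth - CTLZ(X) after the rewrite
//   }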
1760
/// Recognize CTLZ or CTTZ idiom in a non-countable loop and convert the loop
/// to a countable one (with a CTLZ / CTTZ based trip count). Returns true if
/// such a trip count was inserted; otherwise, returns false.
1764bool LoopIdiomRecognize::recognizeAndInsertFFS() {
1765 // Give up if the loop has multiple blocks or multiple backedges.
1766 if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
1767 return false;
1768
1769 Intrinsic::ID IntrinID;
1770 Value *InitX;
1771 Instruction *DefX = nullptr;
1772 PHINode *CntPhi = nullptr;
1773 Instruction *CntInst = nullptr;
1774 // Help decide if transformation is profitable. For ShiftUntilZero idiom,
1775 // this is always 6.
1776 size_t IdiomCanonicalSize = 6;
1777
1778 if (!detectShiftUntilZeroIdiom(CurLoop, DL: *DL, IntrinID, InitX,
1779 CntInst, CntPhi, DefX))
1780 return false;
1781
1782 bool IsCntPhiUsedOutsideLoop = false;
1783 for (User *U : CntPhi->users())
1784 if (!CurLoop->contains(Inst: cast<Instruction>(Val: U))) {
1785 IsCntPhiUsedOutsideLoop = true;
1786 break;
1787 }
1788 bool IsCntInstUsedOutsideLoop = false;
1789 for (User *U : CntInst->users())
1790 if (!CurLoop->contains(Inst: cast<Instruction>(Val: U))) {
1791 IsCntInstUsedOutsideLoop = true;
1792 break;
1793 }
1794 // If both CntInst and CntPhi are used outside the loop the profitability
1795 // is questionable.
1796 if (IsCntInstUsedOutsideLoop && IsCntPhiUsedOutsideLoop)
1797 return false;
1798
  // For some CPUs the result of the CTLZ(X) intrinsic is undefined when X is
  // 0. If we cannot guarantee X != 0, we need to check for this case when
  // expanding the intrinsic.
1802 bool ZeroCheck = false;
  // It is safe to assume the Preheader exists, as it was checked in the
  // parent function runOnLoop.
1805 BasicBlock *PH = CurLoop->getLoopPreheader();
1806
  // If we are using the count instruction outside the loop, make sure we
  // have a zero check as a precondition. Without the check the loop would run
  // one iteration before any check of the input value. This means 0 and 1
  // would have identical behavior in the original loop, while the CTLZ/CTTZ
  // based replacement would distinguish them, so the zero case must be
  // excluded by the precondition.
1811 if (!IsCntPhiUsedOutsideLoop) {
1812 auto *PreCondBB = PH->getSinglePredecessor();
1813 if (!PreCondBB)
1814 return false;
1815 auto *PreCondBI = dyn_cast<BranchInst>(Val: PreCondBB->getTerminator());
1816 if (!PreCondBI)
1817 return false;
1818 if (matchCondition(BI: PreCondBI, LoopEntry: PH) != InitX)
1819 return false;
1820 ZeroCheck = true;
1821 }
1822
1823 // Check if CTLZ / CTTZ intrinsic is profitable. Assume it is always
1824 // profitable if we delete the loop.
1825
1826 // the loop has only 6 instructions:
1827 // %n.addr.0 = phi [ %n, %entry ], [ %shr, %while.cond ]
1828 // %i.0 = phi [ %i0, %entry ], [ %inc, %while.cond ]
1829 // %shr = ashr %n.addr.0, 1
1830 // %tobool = icmp eq %shr, 0
1831 // %inc = add nsw %i.0, 1
1832 // br i1 %tobool
1833
1834 const Value *Args[] = {InitX,
1835 ConstantInt::getBool(Context&: InitX->getContext(), V: ZeroCheck)};
1836
  // @llvm.dbg intrinsics don't count, as they have no semantic effect.
1838 auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug();
1839 uint32_t HeaderSize =
1840 std::distance(first: InstWithoutDebugIt.begin(), last: InstWithoutDebugIt.end());
1841
1842 IntrinsicCostAttributes Attrs(IntrinID, InitX->getType(), Args);
1843 InstructionCost Cost =
1844 TTI->getIntrinsicInstrCost(ICA: Attrs, CostKind: TargetTransformInfo::TCK_SizeAndLatency);
1845 if (HeaderSize != IdiomCanonicalSize &&
1846 Cost > TargetTransformInfo::TCC_Basic)
1847 return false;
1848
1849 transformLoopToCountable(IntrinID, PreCondBB: PH, CntInst, CntPhi, Var: InitX, DefX,
1850 DL: DefX->getDebugLoc(), ZeroCheck,
1851 IsCntPhiUsedOutsideLoop);
1852 return true;
1853}
1854
1855/// Recognizes a population count idiom in a non-countable loop.
1856///
1857/// If detected, transforms the relevant code to issue the popcount intrinsic
1858/// function call, and returns true; otherwise, returns false.
1859bool LoopIdiomRecognize::recognizePopcount() {
1860 if (TTI->getPopcntSupport(IntTyWidthInBit: 32) != TargetTransformInfo::PSK_FastHardware)
1861 return false;
1862
  // Counting the population is usually done with a few arithmetic
  // instructions. Such instructions can be easily "absorbed" by vacant slots
  // in a non-compact loop. Therefore, recognizing a popcount idiom only makes
  // sense in a compact loop.
1867
1868 // Give up if the loop has multiple blocks or multiple backedges.
1869 if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
1870 return false;
1871
1872 BasicBlock *LoopBody = *(CurLoop->block_begin());
1873 if (LoopBody->size() >= 20) {
1874 // The loop is too big, bail out.
1875 return false;
1876 }
1877
1878 // It should have a preheader containing nothing but an unconditional branch.
1879 BasicBlock *PH = CurLoop->getLoopPreheader();
1880 if (!PH || &PH->front() != PH->getTerminator())
1881 return false;
1882 auto *EntryBI = dyn_cast<BranchInst>(Val: PH->getTerminator());
1883 if (!EntryBI || EntryBI->isConditional())
1884 return false;
1885
1886 // It should have a precondition block where the generated popcount intrinsic
1887 // function can be inserted.
1888 auto *PreCondBB = PH->getSinglePredecessor();
1889 if (!PreCondBB)
1890 return false;
1891 auto *PreCondBI = dyn_cast<BranchInst>(Val: PreCondBB->getTerminator());
1892 if (!PreCondBI || PreCondBI->isUnconditional())
1893 return false;
1894
1895 Instruction *CntInst;
1896 PHINode *CntPhi;
1897 Value *Val;
1898 if (!detectPopcountIdiom(CurLoop, PreCondBB, CntInst, CntPhi, Var&: Val))
1899 return false;
1900
1901 transformLoopToPopcount(PreCondBB, CntInst, CntPhi, Var: Val);
1902 return true;
1903}
1904
1905static CallInst *createPopcntIntrinsic(IRBuilder<> &IRBuilder, Value *Val,
1906 const DebugLoc &DL) {
1907 Value *Ops[] = {Val};
1908 Type *Tys[] = {Val->getType()};
1909
1910 Module *M = IRBuilder.GetInsertBlock()->getParent()->getParent();
  Function *Func = Intrinsic::getDeclaration(M, id: Intrinsic::ctpop, Tys);
1912 CallInst *CI = IRBuilder.CreateCall(Callee: Func, Args: Ops);
1913 CI->setDebugLoc(DL);
1914
1915 return CI;
1916}
1917
1918static CallInst *createFFSIntrinsic(IRBuilder<> &IRBuilder, Value *Val,
1919 const DebugLoc &DL, bool ZeroCheck,
1920 Intrinsic::ID IID) {
1921 Value *Ops[] = {Val, IRBuilder.getInt1(V: ZeroCheck)};
1922 Type *Tys[] = {Val->getType()};
1923
1924 Module *M = IRBuilder.GetInsertBlock()->getParent()->getParent();
1925 Function *Func = Intrinsic::getDeclaration(M, id: IID, Tys);
1926 CallInst *CI = IRBuilder.CreateCall(Callee: Func, Args: Ops);
1927 CI->setDebugLoc(DL);
1928
1929 return CI;
1930}
1931
1932/// Transform the following loop (Using CTLZ, CTTZ is similar):
1933/// loop:
1934/// CntPhi = PHI [Cnt0, CntInst]
1935/// PhiX = PHI [InitX, DefX]
1936/// CntInst = CntPhi + 1
1937/// DefX = PhiX >> 1
1938/// LOOP_BODY
1939/// Br: loop if (DefX != 0)
1940/// Use(CntPhi) or Use(CntInst)
1941///
1942/// Into:
1943/// If CntPhi used outside the loop:
1944/// CountPrev = BitWidth(InitX) - CTLZ(InitX >> 1)
1945/// Count = CountPrev + 1
1946/// else
1947/// Count = BitWidth(InitX) - CTLZ(InitX)
1948/// loop:
1949/// CntPhi = PHI [Cnt0, CntInst]
1950/// PhiX = PHI [InitX, DefX]
1951/// PhiCount = PHI [Count, Dec]
1952/// CntInst = CntPhi + 1
1953/// DefX = PhiX >> 1
1954/// Dec = PhiCount - 1
1955/// LOOP_BODY
1956/// Br: loop if (Dec != 0)
1957/// Use(CountPrev + Cnt0) // Use(CntPhi)
1958/// or
1959/// Use(Count + Cnt0) // Use(CntInst)
1960///
1961/// If LOOP_BODY is empty the loop will be deleted.
1962/// If CntInst and DefX are not used in LOOP_BODY they will be removed.
1963void LoopIdiomRecognize::transformLoopToCountable(
1964 Intrinsic::ID IntrinID, BasicBlock *Preheader, Instruction *CntInst,
1965 PHINode *CntPhi, Value *InitX, Instruction *DefX, const DebugLoc &DL,
1966 bool ZeroCheck, bool IsCntPhiUsedOutsideLoop) {
1967 BranchInst *PreheaderBr = cast<BranchInst>(Val: Preheader->getTerminator());
1968
1969 // Step 1: Insert the CTLZ/CTTZ instruction at the end of the preheader block
1970 IRBuilder<> Builder(PreheaderBr);
1971 Builder.SetCurrentDebugLocation(DL);
1972
  // If there are no uses of CntPhi outside the loop, create:
  // Count = BitWidth - CTLZ(InitX);
  // NewCount = Count;
  // If there are uses of CntPhi outside the loop, create:
  // NewCount = BitWidth - CTLZ(InitX >> 1);
  // Count = NewCount + 1;
1979 Value *InitXNext;
1980 if (IsCntPhiUsedOutsideLoop) {
1981 if (DefX->getOpcode() == Instruction::AShr)
1982 InitXNext = Builder.CreateAShr(LHS: InitX, RHS: 1);
1983 else if (DefX->getOpcode() == Instruction::LShr)
1984 InitXNext = Builder.CreateLShr(LHS: InitX, RHS: 1);
1985 else if (DefX->getOpcode() == Instruction::Shl) // cttz
1986 InitXNext = Builder.CreateShl(LHS: InitX, RHS: 1);
1987 else
1988 llvm_unreachable("Unexpected opcode!");
1989 } else
1990 InitXNext = InitX;
1991 Value *Count =
1992 createFFSIntrinsic(IRBuilder&: Builder, Val: InitXNext, DL, ZeroCheck, IID: IntrinID);
1993 Type *CountTy = Count->getType();
1994 Count = Builder.CreateSub(
1995 LHS: ConstantInt::get(Ty: CountTy, V: CountTy->getIntegerBitWidth()), RHS: Count);
1996 Value *NewCount = Count;
1997 if (IsCntPhiUsedOutsideLoop)
1998 Count = Builder.CreateAdd(LHS: Count, RHS: ConstantInt::get(Ty: CountTy, V: 1));
1999
2000 NewCount = Builder.CreateZExtOrTrunc(V: NewCount, DestTy: CntInst->getType());
2001
2002 Value *CntInitVal = CntPhi->getIncomingValueForBlock(BB: Preheader);
2003 if (cast<ConstantInt>(Val: CntInst->getOperand(i: 1))->isOne()) {
2004 // If the counter was being incremented in the loop, add NewCount to the
2005 // counter's initial value, but only if the initial value is not zero.
2006 ConstantInt *InitConst = dyn_cast<ConstantInt>(Val: CntInitVal);
2007 if (!InitConst || !InitConst->isZero())
2008 NewCount = Builder.CreateAdd(LHS: NewCount, RHS: CntInitVal);
2009 } else {
2010 // If the count was being decremented in the loop, subtract NewCount from
2011 // the counter's initial value.
2012 NewCount = Builder.CreateSub(LHS: CntInitVal, RHS: NewCount);
2013 }
2014
2015 // Step 2: Insert new IV and loop condition:
2016 // loop:
2017 // ...
2018 // PhiCount = PHI [Count, Dec]
2019 // ...
2020 // Dec = PhiCount - 1
2021 // ...
2022 // Br: loop if (Dec != 0)
2023 BasicBlock *Body = *(CurLoop->block_begin());
2024 auto *LbBr = cast<BranchInst>(Val: Body->getTerminator());
2025 ICmpInst *LbCond = cast<ICmpInst>(Val: LbBr->getCondition());
2026
2027 PHINode *TcPhi = PHINode::Create(Ty: CountTy, NumReservedValues: 2, NameStr: "tcphi");
2028 TcPhi->insertBefore(InsertPos: Body->begin());
2029
2030 Builder.SetInsertPoint(LbCond);
2031 Instruction *TcDec = cast<Instruction>(Val: Builder.CreateSub(
2032 LHS: TcPhi, RHS: ConstantInt::get(Ty: CountTy, V: 1), Name: "tcdec", HasNUW: false, HasNSW: true));
2033
2034 TcPhi->addIncoming(V: Count, BB: Preheader);
2035 TcPhi->addIncoming(V: TcDec, BB: Body);
2036
2037 CmpInst::Predicate Pred =
2038 (LbBr->getSuccessor(i: 0) == Body) ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ;
2039 LbCond->setPredicate(Pred);
2040 LbCond->setOperand(i_nocapture: 0, Val_nocapture: TcDec);
2041 LbCond->setOperand(i_nocapture: 1, Val_nocapture: ConstantInt::get(Ty: CountTy, V: 0));
2042
2043 // Step 3: All the references to the original counter outside
2044 // the loop are replaced with the NewCount
2045 if (IsCntPhiUsedOutsideLoop)
2046 CntPhi->replaceUsesOutsideBlock(V: NewCount, BB: Body);
2047 else
2048 CntInst->replaceUsesOutsideBlock(V: NewCount, BB: Body);
2049
  // Step 4: Forget the "non-computable" trip-count SCEV associated with the
2051 // loop. The loop would otherwise not be deleted even if it becomes empty.
2052 SE->forgetLoop(L: CurLoop);
2053}
2054
2055void LoopIdiomRecognize::transformLoopToPopcount(BasicBlock *PreCondBB,
2056 Instruction *CntInst,
2057 PHINode *CntPhi, Value *Var) {
2058 BasicBlock *PreHead = CurLoop->getLoopPreheader();
2059 auto *PreCondBr = cast<BranchInst>(Val: PreCondBB->getTerminator());
2060 const DebugLoc &DL = CntInst->getDebugLoc();
2061
2062 // Assuming before transformation, the loop is following:
2063 // if (x) // the precondition
2064 // do { cnt++; x &= x - 1; } while(x);
2065
2066 // Step 1: Insert the ctpop instruction at the end of the precondition block
2067 IRBuilder<> Builder(PreCondBr);
2068 Value *PopCnt, *PopCntZext, *NewCount, *TripCnt;
2069 {
2070 PopCnt = createPopcntIntrinsic(IRBuilder&: Builder, Val: Var, DL);
2071 NewCount = PopCntZext =
2072 Builder.CreateZExtOrTrunc(V: PopCnt, DestTy: cast<IntegerType>(Val: CntPhi->getType()));
2073
2074 if (NewCount != PopCnt)
2075 (cast<Instruction>(Val: NewCount))->setDebugLoc(DL);
2076
2077 // TripCnt is exactly the number of iterations the loop has
2078 TripCnt = NewCount;
2079
2080 // If the population counter's initial value is not zero, insert Add Inst.
2081 Value *CntInitVal = CntPhi->getIncomingValueForBlock(BB: PreHead);
2082 ConstantInt *InitConst = dyn_cast<ConstantInt>(Val: CntInitVal);
2083 if (!InitConst || !InitConst->isZero()) {
2084 NewCount = Builder.CreateAdd(LHS: NewCount, RHS: CntInitVal);
2085 (cast<Instruction>(Val: NewCount))->setDebugLoc(DL);
2086 }
2087 }
2088
2089 // Step 2: Replace the precondition from "if (x == 0) goto loop-exit" to
2090 // "if (NewCount == 0) loop-exit". Without this change, the intrinsic
2091 // function would be partial dead code, and downstream passes will drag
2092 // it back from the precondition block to the preheader.
2093 {
2094 ICmpInst *PreCond = cast<ICmpInst>(Val: PreCondBr->getCondition());
2095
2096 Value *Opnd0 = PopCntZext;
2097 Value *Opnd1 = ConstantInt::get(Ty: PopCntZext->getType(), V: 0);
2098 if (PreCond->getOperand(i_nocapture: 0) != Var)
2099 std::swap(a&: Opnd0, b&: Opnd1);
2100
2101 ICmpInst *NewPreCond = cast<ICmpInst>(
2102 Val: Builder.CreateICmp(P: PreCond->getPredicate(), LHS: Opnd0, RHS: Opnd1));
2103 PreCondBr->setCondition(NewPreCond);
2104
2105 RecursivelyDeleteTriviallyDeadInstructions(V: PreCond, TLI);
2106 }
2107
  // Step 3: Note that the population count is exactly the trip count of the
  // loop in question, which enables us to convert the loop from a
  // noncountable loop into a countable one. The benefit is twofold:
2111 //
2112 // - If the loop only counts population, the entire loop becomes dead after
2113 // the transformation. It is a lot easier to prove a countable loop dead
2114 // than to prove a noncountable one. (In some C dialects, an infinite loop
2115 // isn't dead even if it computes nothing useful. In general, DCE needs
  //   to prove a noncountable loop finite before it can safely delete it.)
2117 //
2118 // - If the loop also performs something else, it remains alive.
2119 // Since it is transformed to countable form, it can be aggressively
2120 // optimized by some optimizations which are in general not applicable
2121 // to a noncountable loop.
2122 //
2123 // After this step, this loop (conceptually) would look like following:
2124 // newcnt = __builtin_ctpop(x);
2125 // t = newcnt;
2126 // if (x)
  //   do { cnt++; x &= x-1; t--; } while (t > 0);
2128 BasicBlock *Body = *(CurLoop->block_begin());
2129 {
2130 auto *LbBr = cast<BranchInst>(Val: Body->getTerminator());
2131 ICmpInst *LbCond = cast<ICmpInst>(Val: LbBr->getCondition());
2132 Type *Ty = TripCnt->getType();
2133
2134 PHINode *TcPhi = PHINode::Create(Ty, NumReservedValues: 2, NameStr: "tcphi");
2135 TcPhi->insertBefore(InsertPos: Body->begin());
2136
2137 Builder.SetInsertPoint(LbCond);
2138 Instruction *TcDec = cast<Instruction>(
2139 Val: Builder.CreateSub(LHS: TcPhi, RHS: ConstantInt::get(Ty, V: 1),
2140 Name: "tcdec", HasNUW: false, HasNSW: true));
2141
2142 TcPhi->addIncoming(V: TripCnt, BB: PreHead);
2143 TcPhi->addIncoming(V: TcDec, BB: Body);
2144
2145 CmpInst::Predicate Pred =
2146 (LbBr->getSuccessor(i: 0) == Body) ? CmpInst::ICMP_UGT : CmpInst::ICMP_SLE;
2147 LbCond->setPredicate(Pred);
2148 LbCond->setOperand(i_nocapture: 0, Val_nocapture: TcDec);
2149 LbCond->setOperand(i_nocapture: 1, Val_nocapture: ConstantInt::get(Ty, V: 0));
2150 }
2151
2152 // Step 4: All the references to the original population counter outside
2153 // the loop are replaced with the NewCount -- the value returned from
2154 // __builtin_ctpop().
2155 CntInst->replaceUsesOutsideBlock(V: NewCount, BB: Body);
2156
  // Step 5: Forget the "non-computable" trip-count SCEV associated with the
2158 // loop. The loop would otherwise not be deleted even if it becomes empty.
2159 SE->forgetLoop(L: CurLoop);
2160}
2161
2162/// Match loop-invariant value.
2163template <typename SubPattern_t> struct match_LoopInvariant {
2164 SubPattern_t SubPattern;
2165 const Loop *L;
2166
2167 match_LoopInvariant(const SubPattern_t &SP, const Loop *L)
2168 : SubPattern(SP), L(L) {}
2169
2170 template <typename ITy> bool match(ITy *V) {
2171 return L->isLoopInvariant(V) && SubPattern.match(V);
2172 }
2173};
2174
2175/// Matches if the value is loop-invariant.
2176template <typename Ty>
2177inline match_LoopInvariant<Ty> m_LoopInvariant(const Ty &M, const Loop *L) {
2178 return match_LoopInvariant<Ty>(M, L);
2179}
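
// An illustrative usage sketch (not from this file): how the matcher above
// composes with the other PatternMatch helpers, given some Value *V and the
// Loop *L being analyzed (both placeholders here). This mirrors the
// MatchVariableBitMask lambda used further below.
//
//   using namespace PatternMatch;
//   Value *BitPos;
//   // Succeeds only if V is a 'shl 1, %bitpos' whose value is invariant in L.
//   bool Matched =
//       match(V, m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)), L));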
2180
2181/// Return true if the idiom is detected in the loop.
2182///
2183/// The core idiom we are trying to detect is:
2184/// \code
2185/// entry:
2186/// <...>
2187/// %bitmask = shl i32 1, %bitpos
2188/// br label %loop
2189///
2190/// loop:
2191/// %x.curr = phi i32 [ %x, %entry ], [ %x.next, %loop ]
2192/// %x.curr.bitmasked = and i32 %x.curr, %bitmask
2193/// %x.curr.isbitunset = icmp eq i32 %x.curr.bitmasked, 0
2194/// %x.next = shl i32 %x.curr, 1
2195/// <...>
2196/// br i1 %x.curr.isbitunset, label %loop, label %end
2197///
2198/// end:
2199/// %x.curr.res = phi i32 [ %x.curr, %loop ] <...>
2200/// %x.next.res = phi i32 [ %x.next, %loop ] <...>
2201/// <...>
2202/// \endcode
2203static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
2204 Value *&BitMask, Value *&BitPos,
2205 Value *&CurrX, Instruction *&NextX) {
2206 LLVM_DEBUG(dbgs() << DEBUG_TYPE
2207 " Performing shift-until-bittest idiom detection.\n");
2208
2209 // Give up if the loop has multiple blocks or multiple backedges.
2210 if (CurLoop->getNumBlocks() != 1 || CurLoop->getNumBackEdges() != 1) {
2211 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad block/backedge count.\n");
2212 return false;
2213 }
2214
2215 BasicBlock *LoopHeaderBB = CurLoop->getHeader();
2216 BasicBlock *LoopPreheaderBB = CurLoop->getLoopPreheader();
2217 assert(LoopPreheaderBB && "There is always a loop preheader.");
2218
2219 using namespace PatternMatch;
2220
2221 // Step 1: Check if the loop backedge is in desirable form.
2222
2223 ICmpInst::Predicate Pred;
2224 Value *CmpLHS, *CmpRHS;
2225 BasicBlock *TrueBB, *FalseBB;
2226 if (!match(V: LoopHeaderBB->getTerminator(),
2227 P: m_Br(C: m_ICmp(Pred, L: m_Value(V&: CmpLHS), R: m_Value(V&: CmpRHS)),
2228 T: m_BasicBlock(V&: TrueBB), F: m_BasicBlock(V&: FalseBB)))) {
2229 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge structure.\n");
2230 return false;
2231 }
2232
2233 // Step 2: Check if the backedge's condition is in desirable form.
2234
2235 auto MatchVariableBitMask = [&]() {
2236 return ICmpInst::isEquality(P: Pred) && match(V: CmpRHS, P: m_Zero()) &&
2237 match(V: CmpLHS,
2238 P: m_c_And(L: m_Value(V&: CurrX),
2239 R: m_CombineAnd(
2240 L: m_Value(V&: BitMask),
2241 R: m_LoopInvariant(M: m_Shl(L: m_One(), R: m_Value(V&: BitPos)),
2242 L: CurLoop))));
2243 };
2244 auto MatchConstantBitMask = [&]() {
2245 return ICmpInst::isEquality(P: Pred) && match(V: CmpRHS, P: m_Zero()) &&
2246 match(V: CmpLHS, P: m_And(L: m_Value(V&: CurrX),
2247 R: m_CombineAnd(L: m_Value(V&: BitMask), R: m_Power2()))) &&
2248 (BitPos = ConstantExpr::getExactLogBase2(C: cast<Constant>(Val: BitMask)));
2249 };
2250 auto MatchDecomposableConstantBitMask = [&]() {
2251 APInt Mask;
2252 return llvm::decomposeBitTestICmp(LHS: CmpLHS, RHS: CmpRHS, Pred, X&: CurrX, Mask) &&
2253 ICmpInst::isEquality(P: Pred) && Mask.isPowerOf2() &&
2254 (BitMask = ConstantInt::get(Ty: CurrX->getType(), V: Mask)) &&
2255 (BitPos = ConstantInt::get(Ty: CurrX->getType(), V: Mask.logBase2()));
2256 };
2257
2258 if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
2259 !MatchDecomposableConstantBitMask()) {
2260 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
2261 return false;
2262 }
2263
2264 // Step 3: Check if the recurrence is in desirable form.
2265 auto *CurrXPN = dyn_cast<PHINode>(Val: CurrX);
2266 if (!CurrXPN || CurrXPN->getParent() != LoopHeaderBB) {
2267 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Not an expected PHI node.\n");
2268 return false;
2269 }
2270
2271 BaseX = CurrXPN->getIncomingValueForBlock(BB: LoopPreheaderBB);
2272 NextX =
2273 dyn_cast<Instruction>(Val: CurrXPN->getIncomingValueForBlock(BB: LoopHeaderBB));
2274
2275 assert(CurLoop->isLoopInvariant(BaseX) &&
2276 "Expected BaseX to be avaliable in the preheader!");
2277
2278 if (!NextX || !match(V: NextX, P: m_Shl(L: m_Specific(V: CurrX), R: m_One()))) {
2279 // FIXME: support right-shift?
2280 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad recurrence.\n");
2281 return false;
2282 }
2283
2284 // Step 4: Check if the backedge's destinations are in desirable form.
2285
2286 assert(ICmpInst::isEquality(Pred) &&
2287 "Should only get equality predicates here.");
2288
2289 // cmp-br is commutative, so canonicalize to a single variant.
2290 if (Pred != ICmpInst::Predicate::ICMP_EQ) {
2291 Pred = ICmpInst::getInversePredicate(pred: Pred);
2292 std::swap(a&: TrueBB, b&: FalseBB);
2293 }
2294
  // We expect to exit the loop when the comparison yields false,
  // so when it yields true we should branch back to the loop header.
2297 if (TrueBB != LoopHeaderBB) {
2298 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge flow.\n");
2299 return false;
2300 }
2301
2302 // Okay, idiom checks out.
2303 return true;
2304}
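
// An illustrative source-level form of the idiom detected above (not from
// this file; names are hypothetical, and the sketch assumes the tested bit is
// eventually reached so the loop terminates):
//
//   unsigned shiftUntilBitSet(unsigned X, unsigned BitPos) {
//     unsigned Mask = 1u << BitPos;   // %bitmask
//     while ((X & Mask) == 0)         // %x.curr.isbitunset
//       X <<= 1;                      // %x.next = shl i32 %x.curr, 1
//     return X;                       // %x.curr used after the loop
//   }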
2305
2306/// Look for the following loop:
2307/// \code
2308/// entry:
2309/// <...>
2310/// %bitmask = shl i32 1, %bitpos
2311/// br label %loop
2312///
2313/// loop:
2314/// %x.curr = phi i32 [ %x, %entry ], [ %x.next, %loop ]
2315/// %x.curr.bitmasked = and i32 %x.curr, %bitmask
2316/// %x.curr.isbitunset = icmp eq i32 %x.curr.bitmasked, 0
2317/// %x.next = shl i32 %x.curr, 1
2318/// <...>
2319/// br i1 %x.curr.isbitunset, label %loop, label %end
2320///
2321/// end:
2322/// %x.curr.res = phi i32 [ %x.curr, %loop ] <...>
2323/// %x.next.res = phi i32 [ %x.next, %loop ] <...>
2324/// <...>
2325/// \endcode
2326///
2327/// And transform it into:
2328/// \code
2329/// entry:
2330/// %bitmask = shl i32 1, %bitpos
2331/// %lowbitmask = add i32 %bitmask, -1
2332/// %mask = or i32 %lowbitmask, %bitmask
2333/// %x.masked = and i32 %x, %mask
2334/// %x.masked.numleadingzeros = call i32 @llvm.ctlz.i32(i32 %x.masked,
2335/// i1 true)
2336/// %x.masked.numactivebits = sub i32 32, %x.masked.numleadingzeros
2337/// %x.masked.leadingonepos = add i32 %x.masked.numactivebits, -1
2338/// %backedgetakencount = sub i32 %bitpos, %x.masked.leadingonepos
2339/// %tripcount = add i32 %backedgetakencount, 1
2340/// %x.curr = shl i32 %x, %backedgetakencount
2341/// %x.next = shl i32 %x, %tripcount
2342/// br label %loop
2343///
2344/// loop:
2345/// %loop.iv = phi i32 [ 0, %entry ], [ %loop.iv.next, %loop ]
2346/// %loop.iv.next = add nuw i32 %loop.iv, 1
2347/// %loop.ivcheck = icmp eq i32 %loop.iv.next, %tripcount
2348/// <...>
2349/// br i1 %loop.ivcheck, label %end, label %loop
2350///
2351/// end:
2352/// %x.curr.res = phi i32 [ %x.curr, %loop ] <...>
2353/// %x.next.res = phi i32 [ %x.next, %loop ] <...>
2354/// <...>
2355/// \endcode
2356bool LoopIdiomRecognize::recognizeShiftUntilBitTest() {
2357 bool MadeChange = false;
2358
2359 Value *X, *BitMask, *BitPos, *XCurr;
2360 Instruction *XNext;
2361 if (!detectShiftUntilBitTestIdiom(CurLoop, BaseX&: X, BitMask, BitPos, CurrX&: XCurr,
2362 NextX&: XNext)) {
2363 LLVM_DEBUG(dbgs() << DEBUG_TYPE
2364 " shift-until-bittest idiom detection failed.\n");
2365 return MadeChange;
2366 }
2367 LLVM_DEBUG(dbgs() << DEBUG_TYPE " shift-until-bittest idiom detected!\n");
2368
2369 // Ok, it is the idiom we were looking for, we *could* transform this loop,
2370 // but is it profitable to transform?
2371
2372 BasicBlock *LoopHeaderBB = CurLoop->getHeader();
2373 BasicBlock *LoopPreheaderBB = CurLoop->getLoopPreheader();
2374 assert(LoopPreheaderBB && "There is always a loop preheader.");
2375
2376 BasicBlock *SuccessorBB = CurLoop->getExitBlock();
2377 assert(SuccessorBB && "There is only a single successor.");
2378
2379 IRBuilder<> Builder(LoopPreheaderBB->getTerminator());
2380 Builder.SetCurrentDebugLocation(cast<Instruction>(Val: XCurr)->getDebugLoc());
2381
2382 Intrinsic::ID IntrID = Intrinsic::ctlz;
2383 Type *Ty = X->getType();
2384 unsigned Bitwidth = Ty->getScalarSizeInBits();
2385
2386 TargetTransformInfo::TargetCostKind CostKind =
2387 TargetTransformInfo::TCK_SizeAndLatency;
2388
  // The rewrite is considered to be unprofitable if and only if the
2390 // intrinsic/shift we'll use are not cheap. Note that we are okay with *just*
2391 // making the loop countable, even if nothing else changes.
2392 IntrinsicCostAttributes Attrs(
2393 IntrID, Ty, {PoisonValue::get(T: Ty), /*is_zero_poison=*/Builder.getTrue()});
2394 InstructionCost Cost = TTI->getIntrinsicInstrCost(ICA: Attrs, CostKind);
2395 if (Cost > TargetTransformInfo::TCC_Basic) {
2396 LLVM_DEBUG(dbgs() << DEBUG_TYPE
2397 " Intrinsic is too costly, not beneficial\n");
2398 return MadeChange;
2399 }
2400 if (TTI->getArithmeticInstrCost(Opcode: Instruction::Shl, Ty, CostKind) >
2401 TargetTransformInfo::TCC_Basic) {
2402 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Shift is too costly, not beneficial\n");
2403 return MadeChange;
2404 }
2405
2406 // Ok, transform appears worthwhile.
2407 MadeChange = true;
2408
2409 if (!isGuaranteedNotToBeUndefOrPoison(V: BitPos)) {
    // BitMask may be computed from BitPos; freeze BitPos so that we can
    // increase its use count.
2412 Instruction *InsertPt = nullptr;
2413 if (auto *BitPosI = dyn_cast<Instruction>(Val: BitPos))
2414 InsertPt = &**BitPosI->getInsertionPointAfterDef();
2415 else
2416 InsertPt = &*DT->getRoot()->getFirstNonPHIOrDbgOrAlloca();
2417 if (!InsertPt)
2418 return false;
2419 FreezeInst *BitPosFrozen =
2420 new FreezeInst(BitPos, BitPos->getName() + ".fr", InsertPt);
2421 BitPos->replaceUsesWithIf(New: BitPosFrozen, ShouldReplace: [BitPosFrozen](Use &U) {
2422 return U.getUser() != BitPosFrozen;
2423 });
2424 BitPos = BitPosFrozen;
2425 }
2426
2427 // Step 1: Compute the loop trip count.
2428
2429 Value *LowBitMask = Builder.CreateAdd(LHS: BitMask, RHS: Constant::getAllOnesValue(Ty),
2430 Name: BitPos->getName() + ".lowbitmask");
2431 Value *Mask =
2432 Builder.CreateOr(LHS: LowBitMask, RHS: BitMask, Name: BitPos->getName() + ".mask");
2433 Value *XMasked = Builder.CreateAnd(LHS: X, RHS: Mask, Name: X->getName() + ".masked");
2434 CallInst *XMaskedNumLeadingZeros = Builder.CreateIntrinsic(
2435 ID: IntrID, Types: Ty, Args: {XMasked, /*is_zero_poison=*/Builder.getTrue()},
2436 /*FMFSource=*/nullptr, Name: XMasked->getName() + ".numleadingzeros");
2437 Value *XMaskedNumActiveBits = Builder.CreateSub(
2438 LHS: ConstantInt::get(Ty, V: Ty->getScalarSizeInBits()), RHS: XMaskedNumLeadingZeros,
2439 Name: XMasked->getName() + ".numactivebits", /*HasNUW=*/true,
2440 /*HasNSW=*/Bitwidth != 2);
2441 Value *XMaskedLeadingOnePos =
2442 Builder.CreateAdd(LHS: XMaskedNumActiveBits, RHS: Constant::getAllOnesValue(Ty),
2443 Name: XMasked->getName() + ".leadingonepos", /*HasNUW=*/false,
2444 /*HasNSW=*/Bitwidth > 2);
2445
2446 Value *LoopBackedgeTakenCount = Builder.CreateSub(
2447 LHS: BitPos, RHS: XMaskedLeadingOnePos, Name: CurLoop->getName() + ".backedgetakencount",
2448 /*HasNUW=*/true, /*HasNSW=*/true);
2449 // We know loop's backedge-taken count, but what's loop's trip count?
  // Note that while NUW is always safe, NSW is only safe for bitwidths != 2.
2451 Value *LoopTripCount =
2452 Builder.CreateAdd(LHS: LoopBackedgeTakenCount, RHS: ConstantInt::get(Ty, V: 1),
2453 Name: CurLoop->getName() + ".tripcount", /*HasNUW=*/true,
2454 /*HasNSW=*/Bitwidth != 2);
2455
2456 // Step 2: Compute the recurrence's final value without a loop.
2457
2458 // NewX is always safe to compute, because `LoopBackedgeTakenCount`
2459 // will always be smaller than `bitwidth(X)`, i.e. we never get poison.
2460 Value *NewX = Builder.CreateShl(LHS: X, RHS: LoopBackedgeTakenCount);
2461 NewX->takeName(V: XCurr);
2462 if (auto *I = dyn_cast<Instruction>(NewX))
2463 I->copyIRFlags(XNext, /*IncludeWrapFlags=*/true);
2464
2465 Value *NewXNext;
2466 // Rewriting XNext is more complicated, however, because `X << LoopTripCount`
2467 // will be poison iff `LoopTripCount == bitwidth(X)` (which will happen
2468 // iff `BitPos` is `bitwidth(x) - 1` and `X` is `1`). So unless we know
2469 // that isn't the case, we'll need to emit an alternative, safe IR.
2470 if (XNext->hasNoSignedWrap() || XNext->hasNoUnsignedWrap() ||
2471 PatternMatch::match(
2472 V: BitPos, P: PatternMatch::m_SpecificInt_ICMP(
2473 Predicate: ICmpInst::ICMP_NE, Threshold: APInt(Ty->getScalarSizeInBits(),
2474 Ty->getScalarSizeInBits() - 1))))
2475 NewXNext = Builder.CreateShl(LHS: X, RHS: LoopTripCount);
2476 else {
2477 // Otherwise, just additionally shift by one. It's the smallest solution,
2478 // alternatively, we could check that NewX is INT_MIN (or BitPos is )
2479 // and select 0 instead.
2480 NewXNext = Builder.CreateShl(LHS: NewX, RHS: ConstantInt::get(Ty, V: 1));
2481 }
2482
2483 NewXNext->takeName(V: XNext);
2484 if (auto *I = dyn_cast<Instruction>(Val: NewXNext))
2485 I->copyIRFlags(V: XNext, /*IncludeWrapFlags=*/true);
2486
  // Step 3: Adjust the successor basic block to receive the computed
2488 // recurrence's final value instead of the recurrence itself.
2489
2490 XCurr->replaceUsesOutsideBlock(V: NewX, BB: LoopHeaderBB);
2491 XNext->replaceUsesOutsideBlock(V: NewXNext, BB: LoopHeaderBB);
2492
2493 // Step 4: Rewrite the loop into a countable form, with canonical IV.
2494
2495 // The new canonical induction variable.
2496 Builder.SetInsertPoint(TheBB: LoopHeaderBB, IP: LoopHeaderBB->begin());
2497 auto *IV = Builder.CreatePHI(Ty, NumReservedValues: 2, Name: CurLoop->getName() + ".iv");
2498
2499 // The induction itself.
  // Note that while NUW is always safe, NSW is only safe for bitwidths != 2.
2501 Builder.SetInsertPoint(LoopHeaderBB->getTerminator());
2502 auto *IVNext =
2503 Builder.CreateAdd(LHS: IV, RHS: ConstantInt::get(Ty, V: 1), Name: IV->getName() + ".next",
2504 /*HasNUW=*/true, /*HasNSW=*/Bitwidth != 2);
2505
2506 // The loop trip count check.
2507 auto *IVCheck = Builder.CreateICmpEQ(LHS: IVNext, RHS: LoopTripCount,
2508 Name: CurLoop->getName() + ".ivcheck");
2509 Builder.CreateCondBr(IVCheck, SuccessorBB, LoopHeaderBB);
2510 LoopHeaderBB->getTerminator()->eraseFromParent();
2511
2512 // Populate the IV PHI.
2513 IV->addIncoming(V: ConstantInt::get(Ty, V: 0), BB: LoopPreheaderBB);
2514 IV->addIncoming(V: IVNext, BB: LoopHeaderBB);
2515
2516 // Step 5: Forget the "non-computable" trip-count SCEV associated with the
2517 // loop. The loop would otherwise not be deleted even if it becomes empty.
2518
2519 SE->forgetLoop(L: CurLoop);
2520
2521 // Other passes will take care of actually deleting the loop if possible.
2522
2523 LLVM_DEBUG(dbgs() << DEBUG_TYPE " shift-until-bittest idiom optimized!\n");
2524
2525 ++NumShiftUntilBitTest;
2526 return MadeChange;
2527}
2528
2529/// Return true if the idiom is detected in the loop.
2530///
2531/// The core idiom we are trying to detect is:
2532/// \code
2533/// entry:
2534/// <...>
2535/// %start = <...>
2536/// %extraoffset = <...>
2537/// <...>
2538/// br label %for.cond
2539///
2540/// loop:
2541/// %iv = phi i8 [ %start, %entry ], [ %iv.next, %for.cond ]
2542/// %nbits = add nsw i8 %iv, %extraoffset
2543/// %val.shifted = {{l,a}shr,shl} i8 %val, %nbits
2544/// %val.shifted.iszero = icmp eq i8 %val.shifted, 0
2545/// %iv.next = add i8 %iv, 1
2546/// <...>
2547/// br i1 %val.shifted.iszero, label %end, label %loop
2548///
2549/// end:
2550/// %iv.res = phi i8 [ %iv, %loop ] <...>
2551/// %nbits.res = phi i8 [ %nbits, %loop ] <...>
2552/// %val.shifted.res = phi i8 [ %val.shifted, %loop ] <...>
2553/// %val.shifted.iszero.res = phi i1 [ %val.shifted.iszero, %loop ] <...>
2554/// %iv.next.res = phi i8 [ %iv.next, %loop ] <...>
2555/// <...>
2556/// \endcode
2557static bool detectShiftUntilZeroIdiom(Loop *CurLoop, ScalarEvolution *SE,
2558 Instruction *&ValShiftedIsZero,
2559 Intrinsic::ID &IntrinID, Instruction *&IV,
2560 Value *&Start, Value *&Val,
2561 const SCEV *&ExtraOffsetExpr,
2562 bool &InvertedCond) {
2563 LLVM_DEBUG(dbgs() << DEBUG_TYPE
2564 " Performing shift-until-zero idiom detection.\n");
2565
2566 // Give up if the loop has multiple blocks or multiple backedges.
2567 if (CurLoop->getNumBlocks() != 1 || CurLoop->getNumBackEdges() != 1) {
2568 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad block/backedge count.\n");
2569 return false;
2570 }
2571
2572 Instruction *ValShifted, *NBits, *IVNext;
2573 Value *ExtraOffset;
2574
2575 BasicBlock *LoopHeaderBB = CurLoop->getHeader();
2576 BasicBlock *LoopPreheaderBB = CurLoop->getLoopPreheader();
2577 assert(LoopPreheaderBB && "There is always a loop preheader.");
2578
2579 using namespace PatternMatch;
2580
  // Step 1: Check if the loop backedge and its condition are in desirable form.
2582
2583 ICmpInst::Predicate Pred;
2584 BasicBlock *TrueBB, *FalseBB;
2585 if (!match(V: LoopHeaderBB->getTerminator(),
2586 P: m_Br(C: m_Instruction(I&: ValShiftedIsZero), T: m_BasicBlock(V&: TrueBB),
2587 F: m_BasicBlock(V&: FalseBB))) ||
2588 !match(V: ValShiftedIsZero,
2589 P: m_ICmp(Pred, L: m_Instruction(I&: ValShifted), R: m_Zero())) ||
2590 !ICmpInst::isEquality(P: Pred)) {
2591 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge structure.\n");
2592 return false;
2593 }
2594
2595 // Step 2: Check if the comparison's operand is in desirable form.
2596 // FIXME: Val could be a one-input PHI node, which we should look past.
2597 if (!match(V: ValShifted, P: m_Shift(L: m_LoopInvariant(M: m_Value(V&: Val), L: CurLoop),
2598 R: m_Instruction(I&: NBits)))) {
2599 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad comparisons value computation.\n");
2600 return false;
2601 }
2602 IntrinID = ValShifted->getOpcode() == Instruction::Shl ? Intrinsic::cttz
2603 : Intrinsic::ctlz;
2604
2605 // Step 3: Check if the shift amount is in desirable form.
2606
2607 if (match(V: NBits, P: m_c_Add(L: m_Instruction(I&: IV),
2608 R: m_LoopInvariant(M: m_Value(V&: ExtraOffset), L: CurLoop))) &&
2609 (NBits->hasNoSignedWrap() || NBits->hasNoUnsignedWrap()))
2610 ExtraOffsetExpr = SE->getNegativeSCEV(V: SE->getSCEV(V: ExtraOffset));
2611 else if (match(V: NBits,
2612 P: m_Sub(L: m_Instruction(I&: IV),
2613 R: m_LoopInvariant(M: m_Value(V&: ExtraOffset), L: CurLoop))) &&
2614 NBits->hasNoSignedWrap())
2615 ExtraOffsetExpr = SE->getSCEV(V: ExtraOffset);
2616 else {
2617 IV = NBits;
2618 ExtraOffsetExpr = SE->getZero(Ty: NBits->getType());
2619 }
2620
2621 // Step 4: Check if the recurrence is in desirable form.
2622 auto *IVPN = dyn_cast<PHINode>(Val: IV);
2623 if (!IVPN || IVPN->getParent() != LoopHeaderBB) {
2624 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Not an expected PHI node.\n");
2625 return false;
2626 }
2627
2628 Start = IVPN->getIncomingValueForBlock(BB: LoopPreheaderBB);
2629 IVNext = dyn_cast<Instruction>(Val: IVPN->getIncomingValueForBlock(BB: LoopHeaderBB));
2630
2631 if (!IVNext || !match(V: IVNext, P: m_Add(L: m_Specific(V: IVPN), R: m_One()))) {
2632 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad recurrence.\n");
2633 return false;
2634 }
2635
  // Step 5: Check if the backedge's destinations are in desirable form.
2637
2638 assert(ICmpInst::isEquality(Pred) &&
2639 "Should only get equality predicates here.");
2640
2641 // cmp-br is commutative, so canonicalize to a single variant.
2642 InvertedCond = Pred != ICmpInst::Predicate::ICMP_EQ;
2643 if (InvertedCond) {
2644 Pred = ICmpInst::getInversePredicate(pred: Pred);
2645 std::swap(a&: TrueBB, b&: FalseBB);
2646 }
2647
  // We expect to exit the loop when the comparison yields true,
  // so when it yields false we should branch back to the loop header.
2650 if (FalseBB != LoopHeaderBB) {
2651 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge flow.\n");
2652 return false;
2653 }
2654
  // The new, countable loop will certainly only run a known number of
  // iterations; it won't be infinite. But the old loop might be infinite
2657 // under certain conditions. For logical shifts, the value will become zero
2658 // after at most bitwidth(%Val) loop iterations. However, for arithmetic
2659 // right-shift, iff the sign bit was set, the value will never become zero,
2660 // and the loop may never finish.
2661 if (ValShifted->getOpcode() == Instruction::AShr &&
2662 !isMustProgress(L: CurLoop) && !SE->isKnownNonNegative(S: SE->getSCEV(V: Val))) {
2663 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Can not prove the loop is finite.\n");
2664 return false;
2665 }
2666
2667 // Okay, idiom checks out.
2668 return true;
2669}
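
// An illustrative source-level form of the idiom detected above, with a
// loop-invariant extra offset (not from this file; names are hypothetical,
// and the sketch assumes the shift amount stays below the promoted bit
// width so the shift is well defined):
//
//   unsigned char countFrom(unsigned char Val, unsigned char Start,
//                           unsigned char ExtraOffset) {
//     unsigned char IV = Start;                  // %iv
//     while ((Val >> (IV + ExtraOffset)) != 0)   // %nbits / %val.shifted
//       ++IV;                                    // %iv.next = add i8 %iv, 1
//     return IV;                                 // %iv.res -> %iv.final
//   }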
2670
2671/// Look for the following loop:
2672/// \code
2673/// entry:
2674/// <...>
2675/// %start = <...>
2676/// %extraoffset = <...>
2677/// <...>
2678/// br label %for.cond
2679///
2680/// loop:
2681/// %iv = phi i8 [ %start, %entry ], [ %iv.next, %for.cond ]
2682/// %nbits = add nsw i8 %iv, %extraoffset
2683/// %val.shifted = {{l,a}shr,shl} i8 %val, %nbits
2684/// %val.shifted.iszero = icmp eq i8 %val.shifted, 0
2685/// %iv.next = add i8 %iv, 1
2686/// <...>
2687/// br i1 %val.shifted.iszero, label %end, label %loop
2688///
2689/// end:
2690/// %iv.res = phi i8 [ %iv, %loop ] <...>
2691/// %nbits.res = phi i8 [ %nbits, %loop ] <...>
2692/// %val.shifted.res = phi i8 [ %val.shifted, %loop ] <...>
2693/// %val.shifted.iszero.res = phi i1 [ %val.shifted.iszero, %loop ] <...>
2694/// %iv.next.res = phi i8 [ %iv.next, %loop ] <...>
2695/// <...>
2696/// \endcode
2697///
2698/// And transform it into:
2699/// \code
2700/// entry:
2701/// <...>
2702/// %start = <...>
2703/// %extraoffset = <...>
2704/// <...>
2705/// %val.numleadingzeros = call i8 @llvm.ct{l,t}z.i8(i8 %val, i1 0)
2706/// %val.numactivebits = sub i8 8, %val.numleadingzeros
2707/// %extraoffset.neg = sub i8 0, %extraoffset
2708/// %tmp = add i8 %val.numactivebits, %extraoffset.neg
2709/// %iv.final = call i8 @llvm.smax.i8(i8 %tmp, i8 %start)
2710/// %loop.tripcount = sub i8 %iv.final, %start
2711/// br label %loop
2712///
2713/// loop:
2714/// %loop.iv = phi i8 [ 0, %entry ], [ %loop.iv.next, %loop ]
2715/// %loop.iv.next = add i8 %loop.iv, 1
2716/// %loop.ivcheck = icmp eq i8 %loop.iv.next, %loop.tripcount
2717/// %iv = add i8 %loop.iv, %start
2718/// <...>
2719/// br i1 %loop.ivcheck, label %end, label %loop
2720///
2721/// end:
2722/// %iv.res = phi i8 [ %iv.final, %loop ] <...>
2723/// <...>
2724/// \endcode
2725bool LoopIdiomRecognize::recognizeShiftUntilZero() {
2726 bool MadeChange = false;
2727
2728 Instruction *ValShiftedIsZero;
2729 Intrinsic::ID IntrID;
2730 Instruction *IV;
2731 Value *Start, *Val;
2732 const SCEV *ExtraOffsetExpr;
2733 bool InvertedCond;
2734 if (!detectShiftUntilZeroIdiom(CurLoop, SE, ValShiftedIsZero, IntrinID&: IntrID, IV,
2735 Start, Val, ExtraOffsetExpr, InvertedCond)) {
2736 LLVM_DEBUG(dbgs() << DEBUG_TYPE
2737 " shift-until-zero idiom detection failed.\n");
2738 return MadeChange;
2739 }
2740 LLVM_DEBUG(dbgs() << DEBUG_TYPE " shift-until-zero idiom detected!\n");
2741
2742 // Ok, it is the idiom we were looking for, we *could* transform this loop,
2743 // but is it profitable to transform?
2744
2745 BasicBlock *LoopHeaderBB = CurLoop->getHeader();
2746 BasicBlock *LoopPreheaderBB = CurLoop->getLoopPreheader();
2747 assert(LoopPreheaderBB && "There is always a loop preheader.");
2748
2749 BasicBlock *SuccessorBB = CurLoop->getExitBlock();
2750 assert(SuccessorBB && "There is only a single successor.");
2751
2752 IRBuilder<> Builder(LoopPreheaderBB->getTerminator());
2753 Builder.SetCurrentDebugLocation(IV->getDebugLoc());
2754
2755 Type *Ty = Val->getType();
2756 unsigned Bitwidth = Ty->getScalarSizeInBits();
2757
2758 TargetTransformInfo::TargetCostKind CostKind =
2759 TargetTransformInfo::TCK_SizeAndLatency;
2760
  // The rewrite is considered to be unprofitable if and only if the
  // intrinsic we'll use is not cheap. Note that we are okay with *just*
2763 // making the loop countable, even if nothing else changes.
2764 IntrinsicCostAttributes Attrs(
2765 IntrID, Ty, {PoisonValue::get(T: Ty), /*is_zero_poison=*/Builder.getFalse()});
2766 InstructionCost Cost = TTI->getIntrinsicInstrCost(ICA: Attrs, CostKind);
2767 if (Cost > TargetTransformInfo::TCC_Basic) {
2768 LLVM_DEBUG(dbgs() << DEBUG_TYPE
2769 " Intrinsic is too costly, not beneficial\n");
2770 return MadeChange;
2771 }
2772
2773 // Ok, transform appears worthwhile.
2774 MadeChange = true;
2775
2776 bool OffsetIsZero = false;
2777 if (auto *ExtraOffsetExprC = dyn_cast<SCEVConstant>(Val: ExtraOffsetExpr))
2778 OffsetIsZero = ExtraOffsetExprC->isZero();
2779
2780 // Step 1: Compute the loop's final IV value / trip count.
2781
2782 CallInst *ValNumLeadingZeros = Builder.CreateIntrinsic(
2783 ID: IntrID, Types: Ty, Args: {Val, /*is_zero_poison=*/Builder.getFalse()},
2784 /*FMFSource=*/nullptr, Name: Val->getName() + ".numleadingzeros");
2785 Value *ValNumActiveBits = Builder.CreateSub(
2786 LHS: ConstantInt::get(Ty, V: Ty->getScalarSizeInBits()), RHS: ValNumLeadingZeros,
2787 Name: Val->getName() + ".numactivebits", /*HasNUW=*/true,
2788 /*HasNSW=*/Bitwidth != 2);
2789
2790 SCEVExpander Expander(*SE, *DL, "loop-idiom");
2791 Expander.setInsertPoint(&*Builder.GetInsertPoint());
2792 Value *ExtraOffset = Expander.expandCodeFor(SH: ExtraOffsetExpr);
2793
2794 Value *ValNumActiveBitsOffset = Builder.CreateAdd(
2795 LHS: ValNumActiveBits, RHS: ExtraOffset, Name: ValNumActiveBits->getName() + ".offset",
2796 /*HasNUW=*/OffsetIsZero, /*HasNSW=*/true);
2797 Value *IVFinal = Builder.CreateIntrinsic(Intrinsic::smax, {Ty},
2798 {ValNumActiveBitsOffset, Start},
2799 /*FMFSource=*/nullptr, "iv.final");
2800
2801 auto *LoopBackedgeTakenCount = cast<Instruction>(Val: Builder.CreateSub(
2802 LHS: IVFinal, RHS: Start, Name: CurLoop->getName() + ".backedgetakencount",
2803 /*HasNUW=*/OffsetIsZero, /*HasNSW=*/true));
2804 // FIXME: or when the offset was `add nuw`
2805
2806 // We know loop's backedge-taken count, but what's loop's trip count?
2807 Value *LoopTripCount =
2808 Builder.CreateAdd(LHS: LoopBackedgeTakenCount, RHS: ConstantInt::get(Ty, V: 1),
2809 Name: CurLoop->getName() + ".tripcount", /*HasNUW=*/true,
2810 /*HasNSW=*/Bitwidth != 2);
2811
  // Step 2: Adjust the successor basic block to receive the original
2813 // induction variable's final value instead of the orig. IV itself.
2814
2815 IV->replaceUsesOutsideBlock(V: IVFinal, BB: LoopHeaderBB);
2816
2817 // Step 3: Rewrite the loop into a countable form, with canonical IV.
2818
2819 // The new canonical induction variable.
2820 Builder.SetInsertPoint(TheBB: LoopHeaderBB, IP: LoopHeaderBB->begin());
2821 auto *CIV = Builder.CreatePHI(Ty, NumReservedValues: 2, Name: CurLoop->getName() + ".iv");
2822
2823 // The induction itself.
2824 Builder.SetInsertPoint(TheBB: LoopHeaderBB, IP: LoopHeaderBB->getFirstNonPHIIt());
2825 auto *CIVNext =
2826 Builder.CreateAdd(LHS: CIV, RHS: ConstantInt::get(Ty, V: 1), Name: CIV->getName() + ".next",
2827 /*HasNUW=*/true, /*HasNSW=*/Bitwidth != 2);
2828
2829 // The loop trip count check.
2830 auto *CIVCheck = Builder.CreateICmpEQ(LHS: CIVNext, RHS: LoopTripCount,
2831 Name: CurLoop->getName() + ".ivcheck");
2832 auto *NewIVCheck = CIVCheck;
2833 if (InvertedCond) {
2834 NewIVCheck = Builder.CreateNot(V: CIVCheck);
2835 NewIVCheck->takeName(ValShiftedIsZero);
2836 }
2837
2838 // The original IV, but rebased to be an offset to the CIV.
2839 auto *IVDePHId = Builder.CreateAdd(LHS: CIV, RHS: Start, Name: "", /*HasNUW=*/false,
2840 /*HasNSW=*/true); // FIXME: what about NUW?
2841 IVDePHId->takeName(V: IV);
2842
2843 // The loop terminator.
2844 Builder.SetInsertPoint(LoopHeaderBB->getTerminator());
2845 Builder.CreateCondBr(CIVCheck, SuccessorBB, LoopHeaderBB);
2846 LoopHeaderBB->getTerminator()->eraseFromParent();
2847
2848 // Populate the IV PHI.
2849 CIV->addIncoming(V: ConstantInt::get(Ty, V: 0), BB: LoopPreheaderBB);
2850 CIV->addIncoming(V: CIVNext, BB: LoopHeaderBB);
2851
2852 // Step 4: Forget the "non-computable" trip-count SCEV associated with the
2853 // loop. The loop would otherwise not be deleted even if it becomes empty.
2854
2855 SE->forgetLoop(L: CurLoop);
2856
2857 // Step 5: Try to cleanup the loop's body somewhat.
2858 IV->replaceAllUsesWith(V: IVDePHId);
2859 IV->eraseFromParent();
2860
2861 ValShiftedIsZero->replaceAllUsesWith(V: NewIVCheck);
2862 ValShiftedIsZero->eraseFromParent();
2863
2864 // Other passes will take care of actually deleting the loop if possible.
2865
2866 LLVM_DEBUG(dbgs() << DEBUG_TYPE " shift-until-zero idiom optimized!\n");
2867
2868 ++NumShiftUntilZero;
2869 return MadeChange;
2870}
2871
