1//===- SeedCollector.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
10#include "llvm/Analysis/LoopAccessAnalysis.h"
11#include "llvm/Analysis/ValueTracking.h"
12#include "llvm/IR/Type.h"
13#include "llvm/SandboxIR/Instruction.h"
14#include "llvm/SandboxIR/Utils.h"
15#include "llvm/Support/Compiler.h"
16#include "llvm/Support/Debug.h"
17
18using namespace llvm;
19namespace llvm::sandboxir {
20
21static cl::opt<unsigned> SeedBundleSizeLimit(
22 "sbvec-seed-bundle-size-limit", cl::init(Val: 32), cl::Hidden,
23 cl::desc("Limit the size of the seed bundle to cap compilation time."));
24#define LoadSeedsDef "loads"
25#define StoreSeedsDef "stores"
26static cl::opt<std::string> CollectSeeds(
27 "sbvec-collect-seeds", cl::init(LoadSeedsDef "," StoreSeedsDef), cl::Hidden,
28 cl::desc("Collect these seeds. Use empty for none or a comma-separated "
29 "list of '" LoadSeedsDef "' and '" StoreSeedsDef "'."));
30static cl::opt<unsigned> SeedGroupsLimit(
31 "sbvec-seed-groups-limit", cl::init(Val: 256), cl::Hidden,
32 cl::desc("Limit the number of collected seeds groups in a BB to "
33 "cap compilation time."));
34
35ArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
36 unsigned MaxVecRegBits,
37 bool ForcePowerOf2) {
38 // Use uint32_t here for compatibility with IsPowerOf2_32
39
40 // BitCount tracks the size of the working slice. From that we can tell
41 // when the working slice's size is a power-of-two and when it exceeds
42 // the legal size in MaxVecBits.
43 uint32_t BitCount = 0;
44 uint32_t NumElements = 0;
45 // Tracks the most recent slice where NumElements gave a power-of-2 BitCount
46 uint32_t NumElementsPowerOfTwo = 0;
47 uint32_t BitCountPowerOfTwo = 0;
48 // Can't start a slice with a used instruction.
49 assert(!isUsed(StartIdx) && "Expected unused at StartIdx");
50 for (Instruction *S : drop_begin(RangeOrContainer&: Seeds, N: StartIdx)) {
51 // Stop if this instruction is used. This needs to be done before
52 // getNumBits() because a "used" instruction may have been erased.
53 if (isUsed(Element: StartIdx + NumElements))
54 break;
55 uint32_t InstBits = Utils::getNumBits(I: S);
56 // Stop if adding it puts the slice over the limit.
57 if (BitCount + InstBits > MaxVecRegBits)
58 break;
59 NumElements++;
60 BitCount += InstBits;
61 if (ForcePowerOf2 && isPowerOf2_32(Value: BitCount)) {
62 NumElementsPowerOfTwo = NumElements;
63 BitCountPowerOfTwo = BitCount;
64 }
65 }
66 if (ForcePowerOf2) {
67 NumElements = NumElementsPowerOfTwo;
68 BitCount = BitCountPowerOfTwo;
69 }
70
71 // Return any non-empty slice
72 if (NumElements > 1) {
73 assert((!ForcePowerOf2 || isPowerOf2_32(BitCount)) &&
74 "Must be a power of two");
75 return ArrayRef<Instruction *>(&Seeds[StartIdx], NumElements);
76 }
77 return {};
78}
79
80template <typename LoadOrStoreT>
81SeedContainer::KeyT SeedContainer::getKey(LoadOrStoreT *LSI) const {
82 assert((isa<LoadInst>(LSI) || isa<StoreInst>(LSI)) &&
83 "Expected Load or Store!");
84 Value *Ptr = Utils::getMemInstructionBase(LSI);
85 Instruction::Opcode Op = LSI->getOpcode();
86 Type *Ty = Utils::getExpectedType(V: LSI);
87 if (auto *VTy = dyn_cast<VectorType>(Val: Ty))
88 Ty = VTy->getElementType();
89 return {Ptr, Ty, Op};
90}
91
92// Explicit instantiations
93template SeedContainer::KeyT
94SeedContainer::getKey<LoadInst>(LoadInst *LSI) const;
95template SeedContainer::KeyT
96SeedContainer::getKey<StoreInst>(StoreInst *LSI) const;
97
98bool SeedContainer::erase(Instruction *I) {
99 assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && "Expected Load or Store!");
100 auto It = SeedLookupMap.find(Val: I);
101 if (It == SeedLookupMap.end())
102 return false;
103 SeedBundle *Bndl = It->second;
104 Bndl->setUsed(I);
105 return true;
106}
107
108template <typename LoadOrStoreT> void SeedContainer::insert(LoadOrStoreT *LSI) {
109 // Find the bundle containing seeds for this symbol and type-of-access.
110 auto &BundleVec = Bundles[getKey(LSI)];
111 // Fill this vector of bundles front to back so that only the last bundle in
112 // the vector may have available space. This avoids iteration to find one with
113 // space.
114 if (BundleVec.empty() || BundleVec.back()->size() == SeedBundleSizeLimit)
115 BundleVec.emplace_back(std::make_unique<MemSeedBundle<LoadOrStoreT>>(LSI));
116 else
117 BundleVec.back()->insert(LSI, SE);
118
119 SeedLookupMap[LSI] = BundleVec.back().get();
120}
121
122// Explicit instantiations
123template LLVM_EXPORT_TEMPLATE void SeedContainer::insert<LoadInst>(LoadInst *);
124template LLVM_EXPORT_TEMPLATE void
125SeedContainer::insert<StoreInst>(StoreInst *);
126
127#ifndef NDEBUG
128void SeedContainer::print(raw_ostream &OS) const {
129 for (const auto &Pair : Bundles) {
130 auto [I, Ty, Opc] = Pair.first;
131 const auto &SeedsVec = Pair.second;
132 std::string RefType = dyn_cast<LoadInst>(Val: I) ? "Load"
133 : dyn_cast<StoreInst>(Val: I) ? "Store"
134 : "Other";
135 OS << "[Inst=" << *I << " Ty=" << Ty << " " << RefType << "]\n";
136 for (const auto &SeedPtr : SeedsVec) {
137 SeedPtr->dump(OS);
138 OS << "\n";
139 }
140 }
141 OS << "\n";
142}
143
144LLVM_DUMP_METHOD void SeedContainer::dump() const { print(OS&: dbgs()); }
145#endif // NDEBUG
146
147template <typename LoadOrStoreT> static bool isValidMemSeed(LoadOrStoreT *LSI) {
148 if (!LSI->isSimple())
149 return false;
150 auto *Ty = Utils::getExpectedType(V: LSI);
151 // Omit types that are architecturally unvectorizable
152 if (Ty->isX86_FP80Ty() || Ty->isPPC_FP128Ty())
153 return false;
154 // Omit vector types without compile-time-known lane counts
155 if (isa<ScalableVectorType>(Ty))
156 return false;
157 if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
158 return VectorType::isValidElementType(ElemTy: VTy->getElementType());
159 return VectorType::isValidElementType(ElemTy: Ty);
160}
161
162template bool isValidMemSeed<LoadInst>(LoadInst *LSI);
163template bool isValidMemSeed<StoreInst>(StoreInst *LSI);
164
165SeedCollector::SeedCollector(BasicBlock *BB, ScalarEvolution &SE)
166 : StoreSeeds(SE), LoadSeeds(SE), Ctx(BB->getContext()) {
167
168 bool CollectStores = CollectSeeds.find(StoreSeedsDef) != std::string::npos;
169 bool CollectLoads = CollectSeeds.find(LoadSeedsDef) != std::string::npos;
170 if (!CollectStores && !CollectLoads)
171 return;
172
173 EraseCallbackID = Ctx.registerEraseInstrCallback(CB: [this](Instruction *I) {
174 if (auto SI = dyn_cast<StoreInst>(Val: I))
175 StoreSeeds.erase(I: SI);
176 else if (auto LI = dyn_cast<LoadInst>(Val: I))
177 LoadSeeds.erase(I: LI);
178 });
179
180 // Actually collect the seeds.
181 for (auto &I : *BB) {
182 if (StoreInst *SI = dyn_cast<StoreInst>(Val: &I))
183 if (CollectStores && isValidMemSeed(LSI: SI))
184 StoreSeeds.insert(LSI: SI);
185 if (LoadInst *LI = dyn_cast<LoadInst>(Val: &I))
186 if (CollectLoads && isValidMemSeed(LSI: LI))
187 LoadSeeds.insert(LSI: LI);
188 // Cap compilation time.
189 if (totalNumSeedGroups() > SeedGroupsLimit)
190 break;
191 }
192}
193
194SeedCollector::~SeedCollector() {
195 Ctx.unregisterEraseInstrCallback(ID: EraseCallbackID);
196}
197
198#ifndef NDEBUG
199void SeedCollector::print(raw_ostream &OS) const {
200 OS << "=== StoreSeeds ===\n";
201 StoreSeeds.print(OS);
202 OS << "=== LoadSeeds ===\n";
203 LoadSeeds.print(OS);
204}
205
206void SeedCollector::dump() const { print(OS&: dbgs()); }
207#endif
208
209} // namespace llvm::sandboxir
210

source code of llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp