1 | //===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Performs general IR level optimizations on SVE intrinsics. |
10 | // |
// This pass performs the following optimizations:
//
// - removes unnecessary ptrue intrinsics (llvm.aarch64.sve.ptrue), e.g.:
//     %1 = @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
//     %2 = @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
//     ; (%1 can be replaced with a reinterpret of %2)
//
// - optimizes loads and stores of fixed-length predicate vectors that are
//   needlessly converted to and from svbool_t (via bitcasts and
//   llvm.vector.{insert,extract}), turning them into direct svbool_t loads and
//   stores when the function's vscale_range fixes vscale to a single value.
//
21 | //===----------------------------------------------------------------------===// |
22 | |
23 | #include "AArch64.h" |
24 | #include "Utils/AArch64BaseInfo.h" |
25 | #include "llvm/ADT/PostOrderIterator.h" |
26 | #include "llvm/ADT/SetVector.h" |
27 | #include "llvm/IR/Constants.h" |
28 | #include "llvm/IR/Dominators.h" |
29 | #include "llvm/IR/IRBuilder.h" |
30 | #include "llvm/IR/Instructions.h" |
31 | #include "llvm/IR/IntrinsicInst.h" |
32 | #include "llvm/IR/IntrinsicsAArch64.h" |
33 | #include "llvm/IR/LLVMContext.h" |
34 | #include "llvm/IR/PatternMatch.h" |
35 | #include "llvm/InitializePasses.h" |
36 | #include "llvm/Support/Debug.h" |
37 | #include <optional> |
38 | |
39 | using namespace llvm; |
40 | using namespace llvm::PatternMatch; |
41 | |
42 | #define DEBUG_TYPE "aarch64-sve-intrinsic-opts" |
43 | |
44 | namespace { |
45 | struct SVEIntrinsicOpts : public ModulePass { |
46 | static char ID; // Pass identification, replacement for typeid |
47 | SVEIntrinsicOpts() : ModulePass(ID) { |
48 | initializeSVEIntrinsicOptsPass(*PassRegistry::getPassRegistry()); |
49 | } |
50 | |
51 | bool runOnModule(Module &M) override; |
52 | void getAnalysisUsage(AnalysisUsage &AU) const override; |
53 | |
54 | private: |
55 | bool coalescePTrueIntrinsicCalls(BasicBlock &BB, |
56 | SmallSetVector<IntrinsicInst *, 4> &PTrues); |
57 | bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions); |
58 | bool optimizePredicateStore(Instruction *I); |
59 | bool optimizePredicateLoad(Instruction *I); |
60 | |
61 | bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions); |
62 | |
63 | /// Operates at the function-scope. I.e., optimizations are applied local to |
64 | /// the functions themselves. |
65 | bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions); |
66 | }; |
67 | } // end anonymous namespace |
68 | |
69 | void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const { |
70 | AU.addRequired<DominatorTreeWrapperPass>(); |
71 | AU.setPreservesCFG(); |
72 | } |
73 | |
74 | char SVEIntrinsicOpts::ID = 0; |
75 | static const char *name = "SVE intrinsics optimizations" ; |
76 | INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false) |
77 | INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass); |
78 | INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false) |
79 | |
80 | ModulePass *llvm::createSVEIntrinsicOptsPass() { |
81 | return new SVEIntrinsicOpts(); |
82 | } |
83 | |
84 | /// Checks if a ptrue intrinsic call is promoted. The act of promoting a |
85 | /// ptrue will introduce zeroing. For example: |
86 | /// |
87 | /// %1 = <vscale x 4 x i1> call @llvm.aarch64.sve.ptrue.nxv4i1(i32 31) |
88 | /// %2 = <vscale x 16 x i1> call @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1) |
89 | /// %3 = <vscale x 8 x i1> call @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2) |
90 | /// |
91 | /// %1 is promoted, because it is converted: |
92 | /// |
93 | /// <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1> |
94 | /// |
95 | /// via a sequence of the SVE reinterpret intrinsics convert.{to,from}.svbool. |
96 | static bool isPTruePromoted(IntrinsicInst *PTrue) { |
97 | // Find all users of this intrinsic that are calls to convert-to-svbool |
98 | // reinterpret intrinsics. |
99 | SmallVector<IntrinsicInst *, 4> ConvertToUses; |
100 | for (User *User : PTrue->users()) { |
101 | if (match(User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) { |
102 | ConvertToUses.push_back(Elt: cast<IntrinsicInst>(Val: User)); |
103 | } |
104 | } |
105 | |
106 | // If no such calls were found, this is ptrue is not promoted. |
107 | if (ConvertToUses.empty()) |
108 | return false; |
109 | |
110 | // Otherwise, try to find users of the convert-to-svbool intrinsics that are |
111 | // calls to the convert-from-svbool intrinsic, and would result in some lanes |
112 | // being zeroed. |
113 | const auto *PTrueVTy = cast<ScalableVectorType>(Val: PTrue->getType()); |
114 | for (IntrinsicInst *ConvertToUse : ConvertToUses) { |
115 | for (User *User : ConvertToUse->users()) { |
116 | auto *IntrUser = dyn_cast<IntrinsicInst>(Val: User); |
117 | if (IntrUser && IntrUser->getIntrinsicID() == |
118 | Intrinsic::aarch64_sve_convert_from_svbool) { |
119 | const auto *IntrUserVTy = cast<ScalableVectorType>(Val: IntrUser->getType()); |
120 | |
121 | // Would some lanes become zeroed by the conversion? |
122 | if (IntrUserVTy->getElementCount().getKnownMinValue() > |
123 | PTrueVTy->getElementCount().getKnownMinValue()) |
124 | // This is a promoted ptrue. |
125 | return true; |
126 | } |
127 | } |
128 | } |
129 | |
130 | // If no matching calls were found, this is not a promoted ptrue. |
131 | return false; |
132 | } |
133 | |
134 | /// Attempts to coalesce ptrues in a basic block. |
135 | bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls( |
136 | BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) { |
137 | if (PTrues.size() <= 1) |
138 | return false; |
139 | |
140 | // Find the ptrue with the most lanes. |
141 | auto *MostEncompassingPTrue = |
142 | *llvm::max_element(Range&: PTrues, C: [](auto *PTrue1, auto *PTrue2) { |
143 | auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType()); |
144 | auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType()); |
145 | return PTrue1VTy->getElementCount().getKnownMinValue() < |
146 | PTrue2VTy->getElementCount().getKnownMinValue(); |
147 | }); |
148 | |
149 | // Remove the most encompassing ptrue, as well as any promoted ptrues, leaving |
150 | // behind only the ptrues to be coalesced. |
151 | PTrues.remove(X: MostEncompassingPTrue); |
152 | PTrues.remove_if(P: isPTruePromoted); |
153 | |
154 | // Hoist MostEncompassingPTrue to the start of the basic block. It is always |
155 | // safe to do this, since ptrue intrinsic calls are guaranteed to have no |
156 | // predecessors. |
157 | MostEncompassingPTrue->moveBefore(BB, I: BB.getFirstInsertionPt()); |
158 | |
159 | LLVMContext &Ctx = BB.getContext(); |
160 | IRBuilder<> Builder(Ctx); |
161 | Builder.SetInsertPoint(TheBB: &BB, IP: ++MostEncompassingPTrue->getIterator()); |
162 | |
163 | auto *MostEncompassingPTrueVTy = |
164 | cast<VectorType>(Val: MostEncompassingPTrue->getType()); |
165 | auto *ConvertToSVBool = Builder.CreateIntrinsic( |
166 | Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy}, |
167 | {MostEncompassingPTrue}); |
168 | |
169 | bool ConvertFromCreated = false; |
170 | for (auto *PTrue : PTrues) { |
171 | auto *PTrueVTy = cast<VectorType>(Val: PTrue->getType()); |
172 | |
173 | // Only create the converts if the types are not already the same, otherwise |
174 | // just use the most encompassing ptrue. |
175 | if (MostEncompassingPTrueVTy != PTrueVTy) { |
176 | ConvertFromCreated = true; |
177 | |
178 | Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator()); |
179 | auto *ConvertFromSVBool = |
180 | Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool, |
181 | {PTrueVTy}, {ConvertToSVBool}); |
182 | PTrue->replaceAllUsesWith(V: ConvertFromSVBool); |
183 | } else |
184 | PTrue->replaceAllUsesWith(V: MostEncompassingPTrue); |
185 | |
186 | PTrue->eraseFromParent(); |
187 | } |
188 | |
189 | // We never used the ConvertTo so remove it |
190 | if (!ConvertFromCreated) |
191 | ConvertToSVBool->eraseFromParent(); |
192 | |
193 | return true; |
194 | } |
195 | |
196 | /// The goal of this function is to remove redundant calls to the SVE ptrue |
197 | /// intrinsic in each basic block within the given functions. |
198 | /// |
199 | /// SVE ptrues have two representations in LLVM IR: |
200 | /// - a logical representation -- an arbitrary-width scalable vector of i1s, |
201 | /// i.e. <vscale x N x i1>. |
202 | /// - a physical representation (svbool, <vscale x 16 x i1>) -- a 16-element |
203 | /// scalable vector of i1s, i.e. <vscale x 16 x i1>. |
204 | /// |
205 | /// The SVE ptrue intrinsic is used to create a logical representation of an SVE |
206 | /// predicate. Suppose that we have two SVE ptrue intrinsic calls: P1 and P2. If |
207 | /// P1 creates a logical SVE predicate that is at least as wide as the logical |
208 | /// SVE predicate created by P2, then all of the bits that are true in the |
209 | /// physical representation of P2 are necessarily also true in the physical |
210 | /// representation of P1. P1 'encompasses' P2, therefore, the intrinsic call to |
211 | /// P2 is redundant and can be replaced by an SVE reinterpret of P1 via |
212 | /// convert.{to,from}.svbool. |
213 | /// |
214 | /// Currently, this pass only coalesces calls to SVE ptrue intrinsics |
215 | /// if they match the following conditions: |
216 | /// |
217 | /// - the call to the intrinsic uses either the SV_ALL or SV_POW2 patterns. |
218 | /// SV_ALL indicates that all bits of the predicate vector are to be set to |
219 | /// true. SV_POW2 indicates that all bits of the predicate vector up to the |
220 | /// largest power-of-two are to be set to true. |
221 | /// - the result of the call to the intrinsic is not promoted to a wider |
222 | /// predicate. In this case, keeping the extra ptrue leads to better codegen |
223 | /// -- coalescing here would create an irreducible chain of SVE reinterprets |
224 | /// via convert.{to,from}.svbool. |
225 | /// |
226 | /// EXAMPLE: |
227 | /// |
228 | /// %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL) |
229 | /// ; Logical: <1, 1, 1, 1, 1, 1, 1, 1> |
230 | /// ; Physical: <1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0> |
231 | /// ... |
232 | /// |
233 | /// %2 = <vscale x 4 x i1> ptrue(i32 SV_ALL) |
234 | /// ; Logical: <1, 1, 1, 1> |
235 | /// ; Physical: <1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0> |
236 | /// ... |
237 | /// |
238 | /// Here, %2 can be replaced by an SVE reinterpret of %1, giving, for instance: |
239 | /// |
240 | /// %1 = <vscale x 8 x i1> ptrue(i32 i31) |
241 | /// %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 8 x i1> %1) |
242 | /// %3 = <vscale x 4 x i1> convert.from.svbool(<vscale x 16 x i1> %2) |
243 | /// |
244 | bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls( |
245 | SmallSetVector<Function *, 4> &Functions) { |
246 | bool Changed = false; |
247 | |
248 | for (auto *F : Functions) { |
249 | for (auto &BB : *F) { |
250 | SmallSetVector<IntrinsicInst *, 4> SVAllPTrues; |
251 | SmallSetVector<IntrinsicInst *, 4> SVPow2PTrues; |
252 | |
253 | // For each basic block, collect the used ptrues and try to coalesce them. |
254 | for (Instruction &I : BB) { |
255 | if (I.use_empty()) |
256 | continue; |
257 | |
258 | auto *IntrI = dyn_cast<IntrinsicInst>(Val: &I); |
259 | if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue) |
260 | continue; |
261 | |
262 | const auto PTruePattern = |
263 | cast<ConstantInt>(Val: IntrI->getOperand(i_nocapture: 0))->getZExtValue(); |
264 | |
265 | if (PTruePattern == AArch64SVEPredPattern::all) |
266 | SVAllPTrues.insert(X: IntrI); |
267 | if (PTruePattern == AArch64SVEPredPattern::pow2) |
268 | SVPow2PTrues.insert(X: IntrI); |
269 | } |
270 | |
271 | Changed |= coalescePTrueIntrinsicCalls(BB, PTrues&: SVAllPTrues); |
272 | Changed |= coalescePTrueIntrinsicCalls(BB, PTrues&: SVPow2PTrues); |
273 | } |
274 | } |
275 | |
276 | return Changed; |
277 | } |
278 | |
279 | // This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce |
280 | // scalable stores as late as possible |
281 | bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) { |
282 | auto *F = I->getFunction(); |
283 | auto Attr = F->getFnAttribute(Attribute::VScaleRange); |
284 | if (!Attr.isValid()) |
285 | return false; |
286 | |
287 | unsigned MinVScale = Attr.getVScaleRangeMin(); |
288 | std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax(); |
289 | // The transform needs to know the exact runtime length of scalable vectors |
290 | if (!MaxVScale || MinVScale != MaxVScale) |
291 | return false; |
292 | |
293 | auto *PredType = |
294 | ScalableVectorType::get(ElementType: Type::getInt1Ty(C&: I->getContext()), MinNumElts: 16); |
295 | auto *FixedPredType = |
296 | FixedVectorType::get(ElementType: Type::getInt8Ty(C&: I->getContext()), NumElts: MinVScale * 2); |
297 | |
298 | // If we have a store.. |
299 | auto *Store = dyn_cast<StoreInst>(Val: I); |
300 | if (!Store || !Store->isSimple()) |
301 | return false; |
302 | |
303 | // ..that is storing a predicate vector sized worth of bits.. |
304 | if (Store->getOperand(i_nocapture: 0)->getType() != FixedPredType) |
305 | return false; |
306 | |
307 | // ..where the value stored comes from a vector extract.. |
308 | auto *IntrI = dyn_cast<IntrinsicInst>(Val: Store->getOperand(i_nocapture: 0)); |
309 | if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract) |
310 | return false; |
311 | |
312 | // ..that is extracting from index 0.. |
313 | if (!cast<ConstantInt>(Val: IntrI->getOperand(i_nocapture: 1))->isZero()) |
314 | return false; |
315 | |
316 | // ..where the value being extract from comes from a bitcast |
317 | auto *BitCast = dyn_cast<BitCastInst>(Val: IntrI->getOperand(i_nocapture: 0)); |
318 | if (!BitCast) |
319 | return false; |
320 | |
321 | // ..and the bitcast is casting from predicate type |
322 | if (BitCast->getOperand(i_nocapture: 0)->getType() != PredType) |
323 | return false; |
324 | |
325 | IRBuilder<> Builder(I->getContext()); |
326 | Builder.SetInsertPoint(I); |
327 | |
328 | Builder.CreateStore(Val: BitCast->getOperand(i_nocapture: 0), Ptr: Store->getPointerOperand()); |
329 | |
330 | Store->eraseFromParent(); |
331 | if (IntrI->getNumUses() == 0) |
332 | IntrI->eraseFromParent(); |
333 | if (BitCast->getNumUses() == 0) |
334 | BitCast->eraseFromParent(); |
335 | |
336 | return true; |
337 | } |
338 | |
339 | // This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce |
340 | // scalable loads as late as possible |
341 | bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) { |
342 | auto *F = I->getFunction(); |
343 | auto Attr = F->getFnAttribute(Attribute::VScaleRange); |
344 | if (!Attr.isValid()) |
345 | return false; |
346 | |
347 | unsigned MinVScale = Attr.getVScaleRangeMin(); |
348 | std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax(); |
349 | // The transform needs to know the exact runtime length of scalable vectors |
350 | if (!MaxVScale || MinVScale != MaxVScale) |
351 | return false; |
352 | |
353 | auto *PredType = |
354 | ScalableVectorType::get(ElementType: Type::getInt1Ty(C&: I->getContext()), MinNumElts: 16); |
355 | auto *FixedPredType = |
356 | FixedVectorType::get(ElementType: Type::getInt8Ty(C&: I->getContext()), NumElts: MinVScale * 2); |
357 | |
358 | // If we have a bitcast.. |
359 | auto *BitCast = dyn_cast<BitCastInst>(Val: I); |
360 | if (!BitCast || BitCast->getType() != PredType) |
361 | return false; |
362 | |
363 | // ..whose operand is a vector_insert.. |
364 | auto *IntrI = dyn_cast<IntrinsicInst>(Val: BitCast->getOperand(i_nocapture: 0)); |
365 | if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert) |
366 | return false; |
367 | |
368 | // ..that is inserting into index zero of an undef vector.. |
369 | if (!isa<UndefValue>(Val: IntrI->getOperand(i_nocapture: 0)) || |
370 | !cast<ConstantInt>(Val: IntrI->getOperand(i_nocapture: 2))->isZero()) |
371 | return false; |
372 | |
373 | // ..where the value inserted comes from a load.. |
374 | auto *Load = dyn_cast<LoadInst>(Val: IntrI->getOperand(i_nocapture: 1)); |
375 | if (!Load || !Load->isSimple()) |
376 | return false; |
377 | |
378 | // ..that is loading a predicate vector sized worth of bits.. |
379 | if (Load->getType() != FixedPredType) |
380 | return false; |
381 | |
382 | IRBuilder<> Builder(I->getContext()); |
383 | Builder.SetInsertPoint(Load); |
384 | |
385 | auto *LoadPred = Builder.CreateLoad(Ty: PredType, Ptr: Load->getPointerOperand()); |
386 | |
387 | BitCast->replaceAllUsesWith(V: LoadPred); |
388 | BitCast->eraseFromParent(); |
389 | if (IntrI->getNumUses() == 0) |
390 | IntrI->eraseFromParent(); |
391 | if (Load->getNumUses() == 0) |
392 | Load->eraseFromParent(); |
393 | |
394 | return true; |
395 | } |
396 | |
397 | bool SVEIntrinsicOpts::optimizeInstructions( |
398 | SmallSetVector<Function *, 4> &Functions) { |
399 | bool Changed = false; |
400 | |
401 | for (auto *F : Functions) { |
402 | DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(F&: *F).getDomTree(); |
403 | |
404 | // Traverse the DT with an rpo walk so we see defs before uses, allowing |
405 | // simplification to be done incrementally. |
406 | BasicBlock *Root = DT->getRoot(); |
407 | ReversePostOrderTraversal<BasicBlock *> RPOT(Root); |
408 | for (auto *BB : RPOT) { |
409 | for (Instruction &I : make_early_inc_range(Range&: *BB)) { |
410 | switch (I.getOpcode()) { |
411 | case Instruction::Store: |
412 | Changed |= optimizePredicateStore(I: &I); |
413 | break; |
414 | case Instruction::BitCast: |
415 | Changed |= optimizePredicateLoad(I: &I); |
416 | break; |
417 | } |
418 | } |
419 | } |
420 | } |
421 | |
422 | return Changed; |
423 | } |
424 | |
425 | bool SVEIntrinsicOpts::optimizeFunctions( |
426 | SmallSetVector<Function *, 4> &Functions) { |
427 | bool Changed = false; |
428 | |
429 | Changed |= optimizePTrueIntrinsicCalls(Functions); |
430 | Changed |= optimizeInstructions(Functions); |
431 | |
432 | return Changed; |
433 | } |
434 | |
435 | bool SVEIntrinsicOpts::runOnModule(Module &M) { |
436 | bool Changed = false; |
437 | SmallSetVector<Function *, 4> Functions; |
438 | |
439 | // Check for SVE intrinsic declarations first so that we only iterate over |
440 | // relevant functions. Where an appropriate declaration is found, store the |
441 | // function(s) where it is used so we can target these only. |
442 | for (auto &F : M.getFunctionList()) { |
443 | if (!F.isDeclaration()) |
444 | continue; |
445 | |
446 | switch (F.getIntrinsicID()) { |
447 | case Intrinsic::vector_extract: |
448 | case Intrinsic::vector_insert: |
449 | case Intrinsic::aarch64_sve_ptrue: |
450 | for (User *U : F.users()) |
451 | Functions.insert(X: cast<Instruction>(Val: U)->getFunction()); |
452 | break; |
453 | default: |
454 | break; |
455 | } |
456 | } |
457 | |
458 | if (!Functions.empty()) |
459 | Changed |= optimizeFunctions(Functions); |
460 | |
461 | return Changed; |
462 | } |
463 | |