1 | //===- ARMParallelDSP.cpp - Parallel DSP Pass -----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | /// \file |
10 | /// Armv6 introduced instructions to perform 32-bit SIMD operations. The |
11 | /// purpose of this pass is do some IR pattern matching to create ACLE |
12 | /// DSP intrinsics, which map on these 32-bit SIMD operations. |
13 | /// This pass runs only when unaligned accesses is supported/enabled. |
14 | // |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | #include "ARM.h" |
18 | #include "ARMSubtarget.h" |
19 | #include "llvm/ADT/SmallPtrSet.h" |
20 | #include "llvm/ADT/Statistic.h" |
21 | #include "llvm/Analysis/AliasAnalysis.h" |
22 | #include "llvm/Analysis/AssumptionCache.h" |
23 | #include "llvm/Analysis/GlobalsModRef.h" |
24 | #include "llvm/Analysis/LoopAccessAnalysis.h" |
25 | #include "llvm/Analysis/TargetLibraryInfo.h" |
26 | #include "llvm/CodeGen/TargetPassConfig.h" |
27 | #include "llvm/IR/Instructions.h" |
28 | #include "llvm/IR/IntrinsicsARM.h" |
29 | #include "llvm/IR/IRBuilder.h" |
30 | #include "llvm/IR/NoFolder.h" |
31 | #include "llvm/IR/PatternMatch.h" |
32 | #include "llvm/Pass.h" |
33 | #include "llvm/PassRegistry.h" |
34 | #include "llvm/Support/Debug.h" |
35 | #include "llvm/Transforms/Scalar.h" |
36 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
37 | |
38 | using namespace llvm; |
39 | using namespace PatternMatch; |
40 | |
41 | #define DEBUG_TYPE "arm-parallel-dsp" |
42 | |
// Counts the number of dual-MAC (smlad/smladx/smlald/smlaldx) intrinsic
// calls this pass creates.
STATISTIC(NumSMLAD , "Number of smlad instructions generated" );

// Escape hatch to switch the whole transform off, e.g. for triaging
// miscompiles.
static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp" , cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass" ));

// Caps the number of loads considered per basic block; the pairing and
// alias analysis in RecordMemoryOps is quadratic in this number.
static cl::opt<unsigned>
NumLoadLimit("arm-parallel-dsp-load-limit" , cl::Hidden, cl::init(16),
             cl::desc("Limit the number of loads analysed" ));
52 | |
53 | namespace { |
54 | struct MulCandidate; |
55 | class Reduction; |
56 | |
57 | using MulCandList = SmallVector<std::unique_ptr<MulCandidate>, 8>; |
58 | using MemInstList = SmallVectorImpl<LoadInst*>; |
59 | using MulPairList = SmallVector<std::pair<MulCandidate*, MulCandidate*>, 8>; |
60 | |
// 'MulCandidate' holds the multiplication instructions that are candidates
// for parallel execution.
struct MulCandidate {
  // The mul instruction at the root of this candidate.
  Instruction *Root;
  // The two multiplicands, with any sign-extension stripped.
  Value* LHS;
  Value* RHS;
  // Whether the generated intrinsic should exchange the halfwords of the
  // second operand (SMLADX / SMLALDX form).
  bool Exchange = false;
  // Set once this candidate has been paired with another for a dual MAC.
  bool Paired = false;
  SmallVector<LoadInst*, 2> VecLd;    // Container for loads to widen.

  MulCandidate(Instruction *I, Value *lhs, Value *rhs) :
    Root(I), LHS(lhs), RHS(rhs) { }

  /// True when both multiplicands are loads, which is required before the
  /// loads can be widened.
  bool HasTwoLoadInputs() const {
    return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
  }

  /// The base (first) load recorded in VecLd by AreSequentialLoads.
  LoadInst *getBaseLoad() const {
    return VecLd.front();
  }
};
82 | |
/// Represent a sequence of multiply-accumulate operations with the aim to
/// perform the multiplications in parallel.
class Reduction {
  Instruction     *Root = nullptr;  // The final add of the reduction chain.
  Value           *Acc = nullptr;   // Incoming accumulator value; may be null.
  MulCandList     Muls;             // All muls feeding this reduction.
  MulPairList     MulPairs;         // Muls paired for parallel execution.
  SetVector<Instruction*> Adds;     // The adds that form the chain.

public:
  Reduction() = delete;

  Reduction (Instruction *Add) : Root(Add) { }

  /// Record an Add instruction that is a part of the this reduction.
  void InsertAdd(Instruction *I) { Adds.insert(I); }

  /// Create MulCandidates, each rooted at a Mul instruction, that is a part
  /// of this reduction.
  void InsertMuls() {
    // Return the mul feeding V, looking through a sign-extension if one is
    // present; nullptr if V is not (a sext of) a mul.
    auto GetMulOperand = [](Value *V) -> Instruction* {
      if (auto *SExt = dyn_cast<SExtInst>(V)) {
        if (auto *I = dyn_cast<Instruction>(SExt->getOperand(0)))
          if (I->getOpcode() == Instruction::Mul)
            return I;
      } else if (auto *I = dyn_cast<Instruction>(V)) {
        if (I->getOpcode() == Instruction::Mul)
          return I;
      }
      return nullptr;
    };

    // Record a candidate, stripping the sexts from the mul's two operands.
    auto InsertMul = [this](Instruction *I) {
      Value *LHS = cast<Instruction>(I->getOperand(0))->getOperand(0);
      Value *RHS = cast<Instruction>(I->getOperand(1))->getOperand(0);
      Muls.push_back(std::make_unique<MulCandidate>(I, LHS, RHS));
    };

    for (auto *Add : Adds) {
      // The accumulator itself is not part of the mul chain.
      if (Add == Acc)
        continue;
      if (auto *Mul = GetMulOperand(Add->getOperand(0)))
        InsertMul(Mul);
      if (auto *Mul = GetMulOperand(Add->getOperand(1)))
        InsertMul(Mul);
    }
  }

  /// Add the incoming accumulator value, returns true if a value had not
  /// already been added. Returning false signals to the user that this
  /// reduction already has a value to initialise the accumulator.
  bool InsertAcc(Value *V) {
    if (Acc)
      return false;
    Acc = V;
    return true;
  }

  /// Set two MulCandidates, rooted at muls, that can be executed as a single
  /// parallel operation.
  void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1,
                  bool Exchange = false) {
    LLVM_DEBUG(dbgs() << "Pairing:\n"
               << *Mul0->Root << "\n"
               << *Mul1->Root << "\n");
    Mul0->Paired = true;
    Mul1->Paired = true;
    // Only the second operand of the dual-MAC can be exchanged.
    if (Exchange)
      Mul1->Exchange = true;
    MulPairs.push_back(std::make_pair(Mul0, Mul1));
  }

  /// Return the add instruction which is the root of the reduction.
  Instruction *getRoot() { return Root; }

  /// Whether this reduction produces a 64-bit result (SMLALD family).
  bool is64Bit() const { return Root->getType()->isIntegerTy(64); }

  /// The integer type of the reduction's result.
  Type *getType() const { return Root->getType(); }

  /// Return the incoming value to be accumulated. This maybe null.
  Value *getAccumulator() { return Acc; }

  /// Return the set of adds that comprise the reduction.
  SetVector<Instruction*> &getAdds() { return Adds; }

  /// Return the MulCandidate, rooted at mul instruction, that comprise the
  /// the reduction.
  MulCandList &getMuls() { return Muls; }

  /// Return the MulCandidate, rooted at mul instructions, that have been
  /// paired for parallel execution.
  MulPairList &getMulPairs() { return MulPairs; }

  /// To finalise, replace the uses of the root with the intrinsic call.
  void UpdateRoot(Instruction *SMLAD) {
    Root->replaceAllUsesWith(SMLAD);
  }

  void dump() {
    LLVM_DEBUG(dbgs() << "Reduction:\n";
      for (auto *Add : Adds)
        LLVM_DEBUG(dbgs() << *Add << "\n");
      for (auto &Mul : Muls)
        LLVM_DEBUG(dbgs() << *Mul->Root << "\n"
                   << "  " << *Mul->LHS << "\n"
                   << "  " << *Mul->RHS << "\n");
      LLVM_DEBUG(if (Acc) dbgs() << "Acc in: " << *Acc << "\n")
    );
  }
};
193 | |
194 | class WidenedLoad { |
195 | LoadInst *NewLd = nullptr; |
196 | SmallVector<LoadInst*, 4> Loads; |
197 | |
198 | public: |
199 | WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide) |
200 | : NewLd(Wide) { |
201 | append_range(C&: Loads, R&: Lds); |
202 | } |
203 | LoadInst *getLoad() { |
204 | return NewLd; |
205 | } |
206 | }; |
207 | |
class ARMParallelDSP : public FunctionPass {
  ScalarEvolution   *SE;    // Used to prove loads are consecutive.
  AliasAnalysis     *AA;    // Used to check loads can be merged past writes.
  TargetLibraryInfo *TLI;
  DominatorTree     *DT;    // Used to choose legal insertion points.
  const DataLayout  *DL;
  Module            *M;
  // Maps each base load to the consecutive (offset) load paired with it.
  std::map<LoadInst*, LoadInst*> LoadPairs;
  // Loads acting as the offset element of some recorded pair.
  SmallPtrSet<LoadInst*, 4> OffsetLoads;
  // Maps a base load to the wide load that replaced its pair.
  std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

  template<unsigned>
  bool IsNarrowSequence(Value *V);
  bool Search(Value *V, BasicBlock *BB, Reduction &R);
  bool RecordMemoryOps(BasicBlock *BB);
  void InsertParallelMACs(Reduction &Reduction);
  bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
  LoadInst* CreateWideLoad(MemInstList &Loads, IntegerType *LoadTy);
  bool CreateParallelPairs(Reduction &R);

  /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
  /// Dual performs two signed 16x16-bit multiplications. It adds the
  /// products to a 32-bit accumulate operand. Optionally, the instruction can
  /// exchange the halfwords of the second operand before performing the
  /// arithmetic.
  bool MatchSMLAD(Function &F);

public:
  static char ID;

  ARMParallelDSP() : FunctionPass(ID) { }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    FunctionPass::getAnalysisUsage(AU);
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetPassConfig>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    // Only instructions within blocks are touched; the CFG is unchanged.
    AU.setPreservesCFG();
  }

  bool runOnFunction(Function &F) override {
    if (DisableParallelDSP)
      return false;
    if (skipFunction(F))
      return false;

    SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
    TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    auto &TPC = getAnalysis<TargetPassConfig>();

    M = F.getParent();
    DL = &M->getDataLayout();

    auto &TM = TPC.getTM<TargetMachine>();
    auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

    // The wide loads created by this pass may be misaligned, so bail if the
    // target cannot tolerate unaligned accesses.
    if (!ST->allowsUnalignedMem()) {
      LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                           "running pass ARMParallelDSP\n");
      return false;
    }

    // The smlad/smlald intrinsics require the DSP extension.
    if (!ST->hasDSP()) {
      LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                           "ARMParallelDSP\n");
      return false;
    }

    // The trunc/lshr halfword extraction in CreateWideLoad assumes
    // little-endian lane ordering.
    if (!ST->isLittle()) {
      LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                        << "ARMParallelDSP\n");
      return false;
    }

    LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
    LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

    bool Changes = MatchSMLAD(F);
    return Changes;
  }
};
296 | } |
297 | |
298 | bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, |
299 | MemInstList &VecMem) { |
300 | if (!Ld0 || !Ld1) |
301 | return false; |
302 | |
303 | if (!LoadPairs.count(x: Ld0) || LoadPairs[Ld0] != Ld1) |
304 | return false; |
305 | |
306 | LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n" ; |
307 | dbgs() << "Ld0:" ; Ld0->dump(); |
308 | dbgs() << "Ld1:" ; Ld1->dump(); |
309 | ); |
310 | |
311 | VecMem.clear(); |
312 | VecMem.push_back(Elt: Ld0); |
313 | VecMem.push_back(Elt: Ld1); |
314 | return true; |
315 | } |
316 | |
317 | // MaxBitwidth: the maximum supported bitwidth of the elements in the DSP |
318 | // instructions, which is set to 16. So here we should collect all i8 and i16 |
319 | // narrow operations. |
320 | // TODO: we currently only collect i16, and will support i8 later, so that's |
321 | // why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth. |
322 | template<unsigned MaxBitWidth> |
323 | bool ARMParallelDSP::IsNarrowSequence(Value *V) { |
324 | if (auto *SExt = dyn_cast<SExtInst>(Val: V)) { |
325 | if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth) |
326 | return false; |
327 | |
328 | if (auto *Ld = dyn_cast<LoadInst>(Val: SExt->getOperand(i_nocapture: 0))) { |
329 | // Check that this load could be paired. |
330 | return LoadPairs.count(x: Ld) || OffsetLoads.count(Ptr: Ld); |
331 | } |
332 | } |
333 | return false; |
334 | } |
335 | |
336 | /// Iterate through the block and record base, offset pairs of loads which can |
337 | /// be widened into a single load. |
338 | bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) { |
339 | SmallVector<LoadInst*, 8> Loads; |
340 | SmallVector<Instruction*, 8> Writes; |
341 | LoadPairs.clear(); |
342 | WideLoads.clear(); |
343 | |
344 | // Collect loads and instruction that may write to memory. For now we only |
345 | // record loads which are simple, sign-extended and have a single user. |
346 | // TODO: Allow zero-extended loads. |
347 | for (auto &I : *BB) { |
348 | if (I.mayWriteToMemory()) |
349 | Writes.push_back(Elt: &I); |
350 | auto *Ld = dyn_cast<LoadInst>(Val: &I); |
351 | if (!Ld || !Ld->isSimple() || |
352 | !Ld->hasOneUse() || !isa<SExtInst>(Val: Ld->user_back())) |
353 | continue; |
354 | Loads.push_back(Elt: Ld); |
355 | } |
356 | |
357 | if (Loads.empty() || Loads.size() > NumLoadLimit) |
358 | return false; |
359 | |
360 | using InstSet = std::set<Instruction*>; |
361 | using DepMap = std::map<Instruction*, InstSet>; |
362 | DepMap RAWDeps; |
363 | |
364 | // Record any writes that may alias a load. |
365 | const auto Size = LocationSize::beforeOrAfterPointer(); |
366 | for (auto *Write : Writes) { |
367 | for (auto *Read : Loads) { |
368 | MemoryLocation ReadLoc = |
369 | MemoryLocation(Read->getPointerOperand(), Size); |
370 | |
371 | if (!isModOrRefSet(MRI: AA->getModRefInfo(I: Write, OptLoc: ReadLoc))) |
372 | continue; |
373 | if (Write->comesBefore(Other: Read)) |
374 | RAWDeps[Read].insert(x: Write); |
375 | } |
376 | } |
377 | |
378 | // Check whether there's not a write between the two loads which would |
379 | // prevent them from being safely merged. |
380 | auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) { |
381 | bool BaseFirst = Base->comesBefore(Other: Offset); |
382 | LoadInst *Dominator = BaseFirst ? Base : Offset; |
383 | LoadInst *Dominated = BaseFirst ? Offset : Base; |
384 | |
385 | if (RAWDeps.count(x: Dominated)) { |
386 | InstSet &WritesBefore = RAWDeps[Dominated]; |
387 | |
388 | for (auto *Before : WritesBefore) { |
389 | // We can't move the second load backward, past a write, to merge |
390 | // with the first load. |
391 | if (Dominator->comesBefore(Other: Before)) |
392 | return false; |
393 | } |
394 | } |
395 | return true; |
396 | }; |
397 | |
398 | // Record base, offset load pairs. |
399 | for (auto *Base : Loads) { |
400 | for (auto *Offset : Loads) { |
401 | if (Base == Offset || OffsetLoads.count(Ptr: Offset)) |
402 | continue; |
403 | |
404 | if (isConsecutiveAccess(A: Base, B: Offset, DL: *DL, SE&: *SE) && |
405 | SafeToPair(Base, Offset)) { |
406 | LoadPairs[Base] = Offset; |
407 | OffsetLoads.insert(Ptr: Offset); |
408 | break; |
409 | } |
410 | } |
411 | } |
412 | |
413 | LLVM_DEBUG(if (!LoadPairs.empty()) { |
414 | dbgs() << "Consecutive load pairs:\n" ; |
415 | for (auto &MapIt : LoadPairs) { |
416 | LLVM_DEBUG(dbgs() << *MapIt.first << ", " |
417 | << *MapIt.second << "\n" ); |
418 | } |
419 | }); |
420 | return LoadPairs.size() > 1; |
421 | } |
422 | |
// Search recursively back through the operands to find a tree of values that
// form a multiply-accumulate chain. The search records the Add and Mul
// instructions that form the reduction and allows us to find a single value
// to be used as the initial input to the accumlator.
bool ARMParallelDSP::Search(Value *V, BasicBlock *BB, Reduction &R) {
  // If we find a non-instruction, try to use it as the initial accumulator
  // value. This may have already been found during the search in which case
  // this function will return false, signaling a search fail.
  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return R.InsertAcc(V);

  // The whole chain must live within a single block.
  if (I->getParent() != BB)
    return false;

  switch (I->getOpcode()) {
  default:
    break;
  case Instruction::PHI:
    // Could be the accumulator value.
    return R.InsertAcc(V);
  case Instruction::Add: {
    // Adds should be adding together two muls, or another add and a mul to
    // be within the mac chain. One of the operands may also be the
    // accumulator value at which point we should stop searching.
    R.InsertAdd(I);
    Value *LHS = I->getOperand(0);
    Value *RHS = I->getOperand(1);
    bool ValidLHS = Search(LHS, BB, R);
    bool ValidRHS = Search(RHS, BB, R);

    if (ValidLHS && ValidRHS)
      return true;

    // Ensure we don't add the root as the incoming accumulator.
    if (R.getRoot() == I)
      return false;

    // One side failed: this add itself may be the incoming accumulator.
    return R.InsertAcc(I);
  }
  case Instruction::Mul: {
    // A mul terminates a branch of the chain; both of its operands must be
    // sign-extended, pairable 16-bit loads (see IsNarrowSequence).
    Value *MulOp0 = I->getOperand(0);
    Value *MulOp1 = I->getOperand(1);
    return IsNarrowSequence<16>(MulOp0) && IsNarrowSequence<16>(MulOp1);
  }
  case Instruction::SExt:
    // Look through sign-extensions of the chain's adds/muls.
    return Search(I->getOperand(0), BB, R);
  }
  return false;
}
473 | |
// The pass needs to identify integer add/sub reductions of 16-bit vector
// multiplications.
// To use SMLAD:
// 1) we first need to find integer add then look for this pattern:
//
// acc0 = ...
// ld0 = load i16
// sext0 = sext i16 %ld0 to i32
// ld1 = load i16
// sext1 = sext i16 %ld1 to i32
// mul0 = mul %sext0, %sext1
// ld2 = load i16
// sext2 = sext i16 %ld2 to i32
// ld3 = load i16
// sext3 = sext i16 %ld3 to i32
// mul1 = mul i32 %sext2, %sext3
// add0 = add i32 %mul0, %acc0
// acc1 = add i32 %add0, %mul1
//
// Which can be selected to:
//
// ldr r0
// ldr r1
// smlad r2, r0, r1, r2
//
// If constants are used instead of loads, these will need to be hoisted
// out and into a register.
//
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
//
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  bool Changed = false;

  for (auto &BB : F) {
    SmallPtrSet<Instruction*, 4> AllAdds;
    // Only search blocks that contain at least two pairable loads.
    if (!RecordMemoryOps(&BB))
      continue;

    // Walk bottom-up so the root add of each reduction is visited before
    // the adds that feed it.
    for (Instruction &I : reverse(BB)) {
      if (I.getOpcode() != Instruction::Add)
        continue;

      // Skip adds already consumed by a previously matched reduction.
      if (AllAdds.count(&I))
        continue;

      const auto *Ty = I.getType();
      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
        continue;

      Reduction R(&I);
      if (!Search(&I, &BB, R))
        continue;

      R.InsertMuls();
      LLVM_DEBUG(dbgs() << "After search, Reduction:\n"; R.dump());

      if (!CreateParallelPairs(R))
        continue;

      InsertParallelMACs(R);
      Changed = true;
      AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
      LLVM_DEBUG(dbgs() << "BB after inserting parallel MACs:\n" << BB);
    }
  }

  return Changed;
}
543 | |
/// Pair up the reduction's MulCandidates whose loads are consecutive in
/// memory; each pair can become one dual-MAC intrinsic. Returns true if at
/// least one pair was formed.
bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {

  // Not enough mul operations to make a pair.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly upon sign extended loads.
  for (auto &MulCand : R.getMuls()) {
    if (!MulCand->HasTwoLoadInputs())
      return false;
  }

  auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
    // The first elements of each vector should be loads with sexts. If we
    // find that its two pairs of consecutive loads, then these can be
    // transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
    auto Ld0 = static_cast<LoadInst*>(PMul0->LHS);
    auto Ld1 = static_cast<LoadInst*>(PMul1->LHS);
    auto Ld2 = static_cast<LoadInst*>(PMul0->RHS);
    auto Ld3 = static_cast<LoadInst*>(PMul1->RHS);

    // Check that each mul is operating on two different loads.
    if (Ld0 == Ld2 || Ld1 == Ld3)
      return false;

    if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
      if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        R.AddMulPair(PMul0, PMul1);
        return true;
      } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
        // The second mul's loads appear in reversed order; the X-form of
        // the intrinsic exchanges the halfwords of the second operand.
        R.AddMulPair(PMul0, PMul1, true);
        return true;
      }
    } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
               AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
      LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
      LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
      LLVM_DEBUG(dbgs() << "    and swapping muls\n");
      // Only the second operand can be exchanged, so swap the muls.
      R.AddMulPair(PMul1, PMul0, true);
      return true;
    }
    return false;
  };

  // Greedily try to pair each not-yet-paired candidate with any other.
  MulCandList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  for (unsigned i = 0; i < Elems; ++i) {
    MulCandidate *PMul0 = static_cast<MulCandidate*>(Muls[i].get());
    if (PMul0->Paired)
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      MulCandidate *PMul1 = static_cast<MulCandidate*>(Muls[j].get());
      if (PMul1->Paired)
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      // Distinct candidates can share a root mul; such pairs are invalid.
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1))
        break;
    }
  }
  return !R.getMulPairs().empty();
}
621 | |
/// Materialise the matched reduction: widen the paired loads, emit one
/// smlad/smlald intrinsic per pair (chained through the accumulator), fold
/// any unpaired muls back in with plain adds, and finally RAUW the root.
void ARMParallelDSP::InsertParallelMACs(Reduction &R) {

  // Emit one dual-MAC intrinsic call: smlad(x) for 32-bit accumulators,
  // smlald(x) for 64-bit ones.
  auto CreateSMLAD = [&](LoadInst* WideLd0, LoadInst *WideLd1,
                         Value *Acc, bool Exchange,
                         Instruction *InsertAfter) {
    // Replace the reduction chain with an intrinsic call

    Value* Args[] = { WideLd0, WideLd1, Acc };
    Function *SMLAD = nullptr;
    if (Exchange)
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
    else
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                BasicBlock::iterator(InsertAfter));
    Instruction *Call = Builder.CreateCall(SMLAD, Args);
    NumSMLAD++;
    return Call;
  };

  // Return the instruction after the dominated instruction.
  auto GetInsertPoint = [this](Value *A, Value *B) {
    assert((isa<Instruction>(A) || isa<Instruction>(B)) &&
           "expected at least one instruction");

    Value *V = nullptr;
    if (!isa<Instruction>(A))
      V = B;
    else if (!isa<Instruction>(B))
      V = A;
    else
      V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A;

    return &*++BasicBlock::iterator(cast<Instruction>(V));
  };

  Value *Acc = R.getAccumulator();

  // For any muls that were discovered but not paired, accumulate their values
  // as before.
  IRBuilder<NoFolder> Builder(R.getRoot()->getParent());
  MulCandList &MulCands = R.getMuls();
  for (auto &MulCand : MulCands) {
    if (MulCand->Paired)
      continue;

    Instruction *Mul = cast<Instruction>(MulCand->Root);
    LLVM_DEBUG(dbgs() << "Accumulating unpaired mul: " << *Mul << "\n");

    // A 64-bit reduction needs the narrower mul widened before it can be
    // accumulated.
    if (R.getType() != Mul->getType()) {
      assert(R.is64Bit() && "expected 64-bit result");
      Builder.SetInsertPoint(&*++BasicBlock::iterator(Mul));
      Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType()));
    }

    if (!Acc) {
      Acc = Mul;
      continue;
    }

    // If Acc is the original incoming value to the reduction, it could be a
    // phi. But the phi will dominate Mul, meaning that Mul will be the
    // insertion point.
    Builder.SetInsertPoint(GetInsertPoint(Mul, Acc));
    Acc = Builder.CreateAdd(Mul, Acc);
  }

  // No incoming accumulator and no unpaired muls: start the chain at zero,
  // sized to match the reduction's result type.
  if (!Acc) {
    Acc = R.is64Bit() ?
      ConstantInt::get(IntegerType::get(M->getContext(), 64), 0) :
      ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
  } else if (Acc->getType() != R.getType()) {
    Builder.SetInsertPoint(R.getRoot());
    Acc = Builder.CreateSExt(Acc, R.getType());
  }

  // Roughly sort the mul pairs in their program order.
  llvm::sort(R.getMulPairs(), [](auto &PairA, auto &PairB) {
    const Instruction *A = PairA.first->Root;
    const Instruction *B = PairB.first->Root;
    return A->comesBefore(B);
  });

  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
  for (auto &Pair : R.getMulPairs()) {
    MulCandidate *LHSMul = Pair.first;
    MulCandidate *RHSMul = Pair.second;
    LoadInst *BaseLHS = LHSMul->getBaseLoad();
    LoadInst *BaseRHS = RHSMul->getBaseLoad();
    // Reuse an existing wide load if another pair already widened this base.
    LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
      WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty);
    LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
      WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);

    // The call must follow both wide loads and the current accumulator.
    Instruction *InsertAfter = GetInsertPoint(WideLHS, WideRHS);
    InsertAfter = GetInsertPoint(InsertAfter, Acc);
    Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}
727 | |
/// Replace two consecutive narrow loads with a single load of LoadTy, then
/// recreate the original narrow values via trunc (low half) and lshr+trunc
/// (high half) and rewire their sext users onto the new values.
LoadInst* ARMParallelDSP::CreateWideLoad(MemInstList &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");

  // Move A (and, recursively, its operand chain) in front of B, so that
  // everything the wide load depends on is defined before it. PHIs and
  // values already dominating (or outside) the block are left alone.
  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &Op : Source->operands())
        MoveBefore(Op, Source);
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Create the wide load, while making sure to maintain the original alignment
  // as this prevents ldrd from being generated when it could be illegal due to
  // memory alignment.
  Value *VecPtr = Base->getPointerOperand();
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr, Base->getAlign());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads.
  // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
  // TODO: Support big-endian as well.
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  Value *NewBaseSExt = IRB.CreateSExt(Bottom, BaseSExt->getType());
  BaseSExt->replaceAllUsesWith(NewBaseSExt);

  // The offset load's value lives in the high bits, so shift it down by the
  // offset element's width before truncating.
  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  Value *NewOffsetSExt = IRB.CreateSExt(Trunc, OffsetSExt->getType());
  OffsetSExt->replaceAllUsesWith(NewOffsetSExt);

  LLVM_DEBUG(dbgs() << "From Base and Offset:\n"
             << *Base << "\n" << *Offset << "\n"
             << "Created Wide Load:\n"
             << *WideLoad << "\n"
             << *Bottom << "\n"
             << *NewBaseSExt << "\n"
             << *Top << "\n"
             << *Trunc << "\n"
             << *NewOffsetSExt << "\n");
  // Record the widening so other pairs sharing this base load reuse it.
  WideLoads.emplace(std::make_pair(Base,
                                   std::make_unique<WidenedLoad>(Loads, WideLoad)));
  return WideLoad;
}
801 | |
802 | Pass *llvm::createARMParallelDSPPass() { |
803 | return new ARMParallelDSP(); |
804 | } |
805 | |
// Pass identification; the address of ID is what uniquely identifies the
// pass to the legacy pass manager.
char ARMParallelDSP::ID = 0;

// Register the pass so it can be referenced by name (-arm-parallel-dsp).
INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                "Transform functions to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                "Transform functions to use DSP intrinsics", false, false)
812 | |