//===- ExpandReductions.cpp - Expand reduction intrinsics -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements IR expansion for reduction intrinsics, allowing targets
// to enable the intrinsics until just before codegen.
//
//===----------------------------------------------------------------------===//
13 | |
14 | #include "llvm/CodeGen/ExpandReductions.h" |
15 | #include "llvm/Analysis/TargetTransformInfo.h" |
16 | #include "llvm/CodeGen/Passes.h" |
17 | #include "llvm/IR/IRBuilder.h" |
18 | #include "llvm/IR/InstIterator.h" |
19 | #include "llvm/IR/IntrinsicInst.h" |
20 | #include "llvm/IR/Intrinsics.h" |
21 | #include "llvm/InitializePasses.h" |
22 | #include "llvm/Pass.h" |
23 | #include "llvm/Transforms/Utils/LoopUtils.h" |
24 | |
25 | using namespace llvm; |
26 | |
27 | namespace { |
28 | |
29 | bool expandReductions(Function &F, const TargetTransformInfo *TTI) { |
30 | bool Changed = false; |
31 | SmallVector<IntrinsicInst *, 4> Worklist; |
32 | for (auto &I : instructions(F)) { |
33 | if (auto *II = dyn_cast<IntrinsicInst>(Val: &I)) { |
34 | switch (II->getIntrinsicID()) { |
35 | default: break; |
36 | case Intrinsic::vector_reduce_fadd: |
37 | case Intrinsic::vector_reduce_fmul: |
38 | case Intrinsic::vector_reduce_add: |
39 | case Intrinsic::vector_reduce_mul: |
40 | case Intrinsic::vector_reduce_and: |
41 | case Intrinsic::vector_reduce_or: |
42 | case Intrinsic::vector_reduce_xor: |
43 | case Intrinsic::vector_reduce_smax: |
44 | case Intrinsic::vector_reduce_smin: |
45 | case Intrinsic::vector_reduce_umax: |
46 | case Intrinsic::vector_reduce_umin: |
47 | case Intrinsic::vector_reduce_fmax: |
48 | case Intrinsic::vector_reduce_fmin: |
49 | if (TTI->shouldExpandReduction(II)) |
50 | Worklist.push_back(Elt: II); |
51 | |
52 | break; |
53 | } |
54 | } |
55 | } |
56 | |
57 | for (auto *II : Worklist) { |
58 | FastMathFlags FMF = |
59 | isa<FPMathOperator>(Val: II) ? II->getFastMathFlags() : FastMathFlags{}; |
60 | Intrinsic::ID ID = II->getIntrinsicID(); |
61 | RecurKind RK = getMinMaxReductionRecurKind(RdxID: ID); |
62 | |
63 | Value *Rdx = nullptr; |
64 | IRBuilder<> Builder(II); |
65 | IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); |
66 | Builder.setFastMathFlags(FMF); |
67 | switch (ID) { |
68 | default: llvm_unreachable("Unexpected intrinsic!" ); |
69 | case Intrinsic::vector_reduce_fadd: |
70 | case Intrinsic::vector_reduce_fmul: { |
71 | // FMFs must be attached to the call, otherwise it's an ordered reduction |
72 | // and it can't be handled by generating a shuffle sequence. |
73 | Value *Acc = II->getArgOperand(i: 0); |
74 | Value *Vec = II->getArgOperand(i: 1); |
75 | unsigned RdxOpcode = getArithmeticReductionInstruction(RdxID: ID); |
76 | if (!FMF.allowReassoc()) |
77 | Rdx = getOrderedReduction(Builder, Acc, Src: Vec, Op: RdxOpcode, MinMaxKind: RK); |
78 | else { |
79 | if (!isPowerOf2_32( |
80 | Value: cast<FixedVectorType>(Val: Vec->getType())->getNumElements())) |
81 | continue; |
82 | Rdx = getShuffleReduction(Builder, Src: Vec, Op: RdxOpcode, MinMaxKind: RK); |
83 | Rdx = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)RdxOpcode, LHS: Acc, RHS: Rdx, |
84 | Name: "bin.rdx" ); |
85 | } |
86 | break; |
87 | } |
88 | case Intrinsic::vector_reduce_and: |
89 | case Intrinsic::vector_reduce_or: { |
90 | // Canonicalize logical or/and reductions: |
91 | // Or reduction for i1 is represented as: |
92 | // %val = bitcast <ReduxWidth x i1> to iReduxWidth |
93 | // %res = cmp ne iReduxWidth %val, 0 |
94 | // And reduction for i1 is represented as: |
95 | // %val = bitcast <ReduxWidth x i1> to iReduxWidth |
96 | // %res = cmp eq iReduxWidth %val, 11111 |
97 | Value *Vec = II->getArgOperand(i: 0); |
98 | auto *FTy = cast<FixedVectorType>(Val: Vec->getType()); |
99 | unsigned NumElts = FTy->getNumElements(); |
100 | if (!isPowerOf2_32(Value: NumElts)) |
101 | continue; |
102 | |
103 | if (FTy->getElementType() == Builder.getInt1Ty()) { |
104 | Rdx = Builder.CreateBitCast(V: Vec, DestTy: Builder.getIntNTy(N: NumElts)); |
105 | if (ID == Intrinsic::vector_reduce_and) { |
106 | Rdx = Builder.CreateICmpEQ( |
107 | LHS: Rdx, RHS: ConstantInt::getAllOnesValue(Ty: Rdx->getType())); |
108 | } else { |
109 | assert(ID == Intrinsic::vector_reduce_or && "Expected or reduction." ); |
110 | Rdx = Builder.CreateIsNotNull(Arg: Rdx); |
111 | } |
112 | break; |
113 | } |
114 | unsigned RdxOpcode = getArithmeticReductionInstruction(RdxID: ID); |
115 | Rdx = getShuffleReduction(Builder, Src: Vec, Op: RdxOpcode, MinMaxKind: RK); |
116 | break; |
117 | } |
118 | case Intrinsic::vector_reduce_add: |
119 | case Intrinsic::vector_reduce_mul: |
120 | case Intrinsic::vector_reduce_xor: |
121 | case Intrinsic::vector_reduce_smax: |
122 | case Intrinsic::vector_reduce_smin: |
123 | case Intrinsic::vector_reduce_umax: |
124 | case Intrinsic::vector_reduce_umin: { |
125 | Value *Vec = II->getArgOperand(i: 0); |
126 | if (!isPowerOf2_32( |
127 | Value: cast<FixedVectorType>(Val: Vec->getType())->getNumElements())) |
128 | continue; |
129 | unsigned RdxOpcode = getArithmeticReductionInstruction(RdxID: ID); |
130 | Rdx = getShuffleReduction(Builder, Src: Vec, Op: RdxOpcode, MinMaxKind: RK); |
131 | break; |
132 | } |
133 | case Intrinsic::vector_reduce_fmax: |
134 | case Intrinsic::vector_reduce_fmin: { |
135 | // We require "nnan" to use a shuffle reduction; "nsz" is implied by the |
136 | // semantics of the reduction. |
137 | Value *Vec = II->getArgOperand(i: 0); |
138 | if (!isPowerOf2_32( |
139 | Value: cast<FixedVectorType>(Val: Vec->getType())->getNumElements()) || |
140 | !FMF.noNaNs()) |
141 | continue; |
142 | unsigned RdxOpcode = getArithmeticReductionInstruction(RdxID: ID); |
143 | Rdx = getShuffleReduction(Builder, Src: Vec, Op: RdxOpcode, MinMaxKind: RK); |
144 | break; |
145 | } |
146 | } |
147 | II->replaceAllUsesWith(V: Rdx); |
148 | II->eraseFromParent(); |
149 | Changed = true; |
150 | } |
151 | return Changed; |
152 | } |
153 | |
154 | class ExpandReductions : public FunctionPass { |
155 | public: |
156 | static char ID; |
157 | ExpandReductions() : FunctionPass(ID) { |
158 | initializeExpandReductionsPass(*PassRegistry::getPassRegistry()); |
159 | } |
160 | |
161 | bool runOnFunction(Function &F) override { |
162 | const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); |
163 | return expandReductions(F, TTI); |
164 | } |
165 | |
166 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
167 | AU.addRequired<TargetTransformInfoWrapperPass>(); |
168 | AU.setPreservesCFG(); |
169 | } |
170 | }; |
171 | } |
172 | |
173 | char ExpandReductions::ID; |
174 | INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions" , |
175 | "Expand reduction intrinsics" , false, false) |
176 | INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) |
177 | INITIALIZE_PASS_END(ExpandReductions, "expand-reductions" , |
178 | "Expand reduction intrinsics" , false, false) |
179 | |
180 | FunctionPass *llvm::createExpandReductionsPass() { |
181 | return new ExpandReductions(); |
182 | } |
183 | |
184 | PreservedAnalyses ExpandReductionsPass::run(Function &F, |
185 | FunctionAnalysisManager &AM) { |
186 | const auto &TTI = AM.getResult<TargetIRAnalysis>(IR&: F); |
187 | if (!expandReductions(F, TTI: &TTI)) |
188 | return PreservedAnalyses::all(); |
189 | PreservedAnalyses PA; |
190 | PA.preserveSet<CFGAnalyses>(); |
191 | return PA; |
192 | } |
193 | |