1//===- ExpandReductions.cpp - Expand reduction intrinsics -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements IR expansion for reduction intrinsics, allowing targets
10// to enable the intrinsics until just before codegen.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/CodeGen/ExpandReductions.h"
15#include "llvm/Analysis/TargetTransformInfo.h"
16#include "llvm/CodeGen/Passes.h"
17#include "llvm/IR/IRBuilder.h"
18#include "llvm/IR/InstIterator.h"
19#include "llvm/IR/IntrinsicInst.h"
20#include "llvm/IR/Intrinsics.h"
21#include "llvm/InitializePasses.h"
22#include "llvm/Pass.h"
23#include "llvm/Transforms/Utils/LoopUtils.h"
24
25using namespace llvm;
26
27namespace {
28
29bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
30 bool Changed = false;
31 SmallVector<IntrinsicInst *, 4> Worklist;
32 for (auto &I : instructions(F)) {
33 if (auto *II = dyn_cast<IntrinsicInst>(Val: &I)) {
34 switch (II->getIntrinsicID()) {
35 default: break;
36 case Intrinsic::vector_reduce_fadd:
37 case Intrinsic::vector_reduce_fmul:
38 case Intrinsic::vector_reduce_add:
39 case Intrinsic::vector_reduce_mul:
40 case Intrinsic::vector_reduce_and:
41 case Intrinsic::vector_reduce_or:
42 case Intrinsic::vector_reduce_xor:
43 case Intrinsic::vector_reduce_smax:
44 case Intrinsic::vector_reduce_smin:
45 case Intrinsic::vector_reduce_umax:
46 case Intrinsic::vector_reduce_umin:
47 case Intrinsic::vector_reduce_fmax:
48 case Intrinsic::vector_reduce_fmin:
49 if (TTI->shouldExpandReduction(II))
50 Worklist.push_back(Elt: II);
51
52 break;
53 }
54 }
55 }
56
57 for (auto *II : Worklist) {
58 FastMathFlags FMF =
59 isa<FPMathOperator>(Val: II) ? II->getFastMathFlags() : FastMathFlags{};
60 Intrinsic::ID ID = II->getIntrinsicID();
61 RecurKind RK = getMinMaxReductionRecurKind(RdxID: ID);
62
63 Value *Rdx = nullptr;
64 IRBuilder<> Builder(II);
65 IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
66 Builder.setFastMathFlags(FMF);
67 switch (ID) {
68 default: llvm_unreachable("Unexpected intrinsic!");
69 case Intrinsic::vector_reduce_fadd:
70 case Intrinsic::vector_reduce_fmul: {
71 // FMFs must be attached to the call, otherwise it's an ordered reduction
72 // and it can't be handled by generating a shuffle sequence.
73 Value *Acc = II->getArgOperand(i: 0);
74 Value *Vec = II->getArgOperand(i: 1);
75 unsigned RdxOpcode = getArithmeticReductionInstruction(RdxID: ID);
76 if (!FMF.allowReassoc())
77 Rdx = getOrderedReduction(Builder, Acc, Src: Vec, Op: RdxOpcode, MinMaxKind: RK);
78 else {
79 if (!isPowerOf2_32(
80 Value: cast<FixedVectorType>(Val: Vec->getType())->getNumElements()))
81 continue;
82 Rdx = getShuffleReduction(Builder, Src: Vec, Op: RdxOpcode, MinMaxKind: RK);
83 Rdx = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)RdxOpcode, LHS: Acc, RHS: Rdx,
84 Name: "bin.rdx");
85 }
86 break;
87 }
88 case Intrinsic::vector_reduce_and:
89 case Intrinsic::vector_reduce_or: {
90 // Canonicalize logical or/and reductions:
91 // Or reduction for i1 is represented as:
92 // %val = bitcast <ReduxWidth x i1> to iReduxWidth
93 // %res = cmp ne iReduxWidth %val, 0
94 // And reduction for i1 is represented as:
95 // %val = bitcast <ReduxWidth x i1> to iReduxWidth
96 // %res = cmp eq iReduxWidth %val, 11111
97 Value *Vec = II->getArgOperand(i: 0);
98 auto *FTy = cast<FixedVectorType>(Val: Vec->getType());
99 unsigned NumElts = FTy->getNumElements();
100 if (!isPowerOf2_32(Value: NumElts))
101 continue;
102
103 if (FTy->getElementType() == Builder.getInt1Ty()) {
104 Rdx = Builder.CreateBitCast(V: Vec, DestTy: Builder.getIntNTy(N: NumElts));
105 if (ID == Intrinsic::vector_reduce_and) {
106 Rdx = Builder.CreateICmpEQ(
107 LHS: Rdx, RHS: ConstantInt::getAllOnesValue(Ty: Rdx->getType()));
108 } else {
109 assert(ID == Intrinsic::vector_reduce_or && "Expected or reduction.");
110 Rdx = Builder.CreateIsNotNull(Arg: Rdx);
111 }
112 break;
113 }
114 unsigned RdxOpcode = getArithmeticReductionInstruction(RdxID: ID);
115 Rdx = getShuffleReduction(Builder, Src: Vec, Op: RdxOpcode, MinMaxKind: RK);
116 break;
117 }
118 case Intrinsic::vector_reduce_add:
119 case Intrinsic::vector_reduce_mul:
120 case Intrinsic::vector_reduce_xor:
121 case Intrinsic::vector_reduce_smax:
122 case Intrinsic::vector_reduce_smin:
123 case Intrinsic::vector_reduce_umax:
124 case Intrinsic::vector_reduce_umin: {
125 Value *Vec = II->getArgOperand(i: 0);
126 if (!isPowerOf2_32(
127 Value: cast<FixedVectorType>(Val: Vec->getType())->getNumElements()))
128 continue;
129 unsigned RdxOpcode = getArithmeticReductionInstruction(RdxID: ID);
130 Rdx = getShuffleReduction(Builder, Src: Vec, Op: RdxOpcode, MinMaxKind: RK);
131 break;
132 }
133 case Intrinsic::vector_reduce_fmax:
134 case Intrinsic::vector_reduce_fmin: {
135 // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
136 // semantics of the reduction.
137 Value *Vec = II->getArgOperand(i: 0);
138 if (!isPowerOf2_32(
139 Value: cast<FixedVectorType>(Val: Vec->getType())->getNumElements()) ||
140 !FMF.noNaNs())
141 continue;
142 unsigned RdxOpcode = getArithmeticReductionInstruction(RdxID: ID);
143 Rdx = getShuffleReduction(Builder, Src: Vec, Op: RdxOpcode, MinMaxKind: RK);
144 break;
145 }
146 }
147 II->replaceAllUsesWith(V: Rdx);
148 II->eraseFromParent();
149 Changed = true;
150 }
151 return Changed;
152}
153
154class ExpandReductions : public FunctionPass {
155public:
156 static char ID;
157 ExpandReductions() : FunctionPass(ID) {
158 initializeExpandReductionsPass(*PassRegistry::getPassRegistry());
159 }
160
161 bool runOnFunction(Function &F) override {
162 const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
163 return expandReductions(F, TTI);
164 }
165
166 void getAnalysisUsage(AnalysisUsage &AU) const override {
167 AU.addRequired<TargetTransformInfoWrapperPass>();
168 AU.setPreservesCFG();
169 }
170};
171}
172
// Out-of-line definition of the pass ID declared inside ExpandReductions.
char ExpandReductions::ID;
// Pass registration under the name "expand-reductions", listing
// TargetTransformInfoWrapperPass as a dependency.
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags of INITIALIZE_PASS_* - confirm against
// PassSupport.h.
INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
                      "Expand reduction intrinsics", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
                    "Expand reduction intrinsics", false, false)
179
180FunctionPass *llvm::createExpandReductionsPass() {
181 return new ExpandReductions();
182}
183
184PreservedAnalyses ExpandReductionsPass::run(Function &F,
185 FunctionAnalysisManager &AM) {
186 const auto &TTI = AM.getResult<TargetIRAnalysis>(IR&: F);
187 if (!expandReductions(F, TTI: &TTI))
188 return PreservedAnalyses::all();
189 PreservedAnalyses PA;
190 PA.preserveSet<CFGAnalyses>();
191 return PA;
192}
193

// Source: llvm/lib/CodeGen/ExpandReductions.cpp