1 | //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements rewrites based on the basic rules of algebra |
10 | // (Commutativity, associativity, etc...) and strength reductions for math |
11 | // operations. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include <climits>
#include <cstdint>
23 | |
24 | using namespace mlir; |
25 | |
26 | //----------------------------------------------------------------------------// |
27 | // PowFOp strength reduction. |
28 | //----------------------------------------------------------------------------// |
29 | |
30 | namespace { |
31 | struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> { |
32 | public: |
33 | using OpRewritePattern::OpRewritePattern; |
34 | |
35 | LogicalResult matchAndRewrite(math::PowFOp op, |
36 | PatternRewriter &rewriter) const final; |
37 | }; |
38 | } // namespace |
39 | |
40 | LogicalResult |
41 | PowFStrengthReduction::matchAndRewrite(math::PowFOp op, |
42 | PatternRewriter &rewriter) const { |
43 | Location loc = op.getLoc(); |
44 | Value x = op.getLhs(); |
45 | |
46 | FloatAttr scalarExponent; |
47 | DenseFPElementsAttr vectorExponent; |
48 | |
49 | bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent)); |
50 | bool isVector = matchPattern(op.getRhs(), m_Constant(bind_value: &vectorExponent)); |
51 | |
52 | // Returns true if exponent is a constant equal to `value`. |
53 | auto isExponentValue = [&](double value) -> bool { |
54 | if (isScalar) |
55 | return scalarExponent.getValue().isExactlyValue(value); |
56 | |
57 | if (isVector && vectorExponent.isSplat()) |
58 | return vectorExponent.getSplatValue<FloatAttr>() |
59 | .getValue() |
60 | .isExactlyValue(value); |
61 | |
62 | return false; |
63 | }; |
64 | |
65 | // Maybe broadcasts scalar value into vector type compatible with `op`. |
66 | auto bcast = [&](Value value) -> Value { |
67 | if (auto vec = dyn_cast<VectorType>(op.getType())) |
68 | return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value); |
69 | return value; |
70 | }; |
71 | |
72 | // Replace `pow(x, 1.0)` with `x`. |
73 | if (isExponentValue(1.0)) { |
74 | rewriter.replaceOp(op, x); |
75 | return success(); |
76 | } |
77 | |
78 | // Replace `pow(x, 2.0)` with `x * x`. |
79 | if (isExponentValue(2.0)) { |
80 | rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x})); |
81 | return success(); |
82 | } |
83 | |
84 | // Replace `pow(x, 3.0)` with `x * x * x`. |
85 | if (isExponentValue(3.0)) { |
86 | Value square = |
87 | rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x})); |
88 | rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square})); |
89 | return success(); |
90 | } |
91 | |
92 | // Replace `pow(x, -1.0)` with `1.0 / x`. |
93 | if (isExponentValue(-1.0)) { |
94 | Value one = rewriter.create<arith::ConstantOp>( |
95 | loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0)); |
96 | rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x})); |
97 | return success(); |
98 | } |
99 | |
100 | // Replace `pow(x, 0.5)` with `sqrt(x)`. |
101 | if (isExponentValue(0.5)) { |
102 | rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x); |
103 | return success(); |
104 | } |
105 | |
106 | // Replace `pow(x, -0.5)` with `rsqrt(x)`. |
107 | if (isExponentValue(-0.5)) { |
108 | rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x); |
109 | return success(); |
110 | } |
111 | |
112 | // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`. |
113 | if (isExponentValue(0.75)) { |
114 | Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x); |
115 | Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf); |
116 | rewriter.replaceOpWithNewOp<arith::MulFOp>(op, |
117 | ValueRange{powHalf, powQuarter}); |
118 | return success(); |
119 | } |
120 | |
121 | return failure(); |
122 | } |
123 | |
124 | //----------------------------------------------------------------------------// |
125 | // FPowIOp/IPowIOp strength reduction. |
126 | //----------------------------------------------------------------------------// |
127 | |
128 | namespace { |
129 | template <typename PowIOpTy, typename DivOpTy, typename MulOpTy> |
130 | struct PowIStrengthReduction : public OpRewritePattern<PowIOpTy> { |
131 | |
132 | unsigned exponentThreshold; |
133 | |
134 | public: |
135 | PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3, |
136 | PatternBenefit benefit = 1, |
137 | ArrayRef<StringRef> generatedNames = {}) |
138 | : OpRewritePattern<PowIOpTy>(context, benefit, generatedNames), |
139 | exponentThreshold(exponentThreshold) {} |
140 | |
141 | LogicalResult matchAndRewrite(PowIOpTy op, |
142 | PatternRewriter &rewriter) const final; |
143 | }; |
144 | } // namespace |
145 | |
146 | template <typename PowIOpTy, typename DivOpTy, typename MulOpTy> |
147 | LogicalResult |
148 | PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite( |
149 | PowIOpTy op, PatternRewriter &rewriter) const { |
150 | Location loc = op.getLoc(); |
151 | Value base = op.getLhs(); |
152 | |
153 | IntegerAttr scalarExponent; |
154 | DenseIntElementsAttr vectorExponent; |
155 | |
156 | bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent)); |
157 | bool isVector = matchPattern(op.getRhs(), m_Constant(bind_value: &vectorExponent)); |
158 | |
159 | // Simplify cases with known exponent value. |
160 | int64_t exponentValue = 0; |
161 | if (isScalar) |
162 | exponentValue = scalarExponent.getInt(); |
163 | else if (isVector && vectorExponent.isSplat()) |
164 | exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt(); |
165 | else |
166 | return failure(); |
167 | |
168 | // Maybe broadcasts scalar value into vector type compatible with `op`. |
169 | auto bcast = [&loc, &op, &rewriter](Value value) -> Value { |
170 | if (auto vec = dyn_cast<VectorType>(op.getType())) |
171 | return rewriter.create<vector::BroadcastOp>(loc, vec, value); |
172 | return value; |
173 | }; |
174 | |
175 | Value one; |
176 | Type opType = getElementTypeOrSelf(op.getType()); |
177 | if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) |
178 | one = rewriter.create<arith::ConstantOp>( |
179 | loc, rewriter.getFloatAttr(opType, 1.0)); |
180 | else |
181 | one = rewriter.create<arith::ConstantOp>( |
182 | loc, rewriter.getIntegerAttr(opType, 1)); |
183 | |
184 | // Replace `[fi]powi(x, 0)` with `1`. |
185 | if (exponentValue == 0) { |
186 | rewriter.replaceOp(op, bcast(one)); |
187 | return success(); |
188 | } |
189 | |
190 | bool exponentIsNegative = false; |
191 | if (exponentValue < 0) { |
192 | exponentIsNegative = true; |
193 | exponentValue *= -1; |
194 | } |
195 | |
196 | // Bail out if `abs(exponent)` exceeds the threshold. |
197 | if (exponentValue > exponentThreshold) |
198 | return failure(); |
199 | |
200 | // Inverse the base for negative exponent, i.e. for |
201 | // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`. |
202 | if (exponentIsNegative) |
203 | base = rewriter.create<DivOpTy>(loc, bcast(one), base); |
204 | |
205 | Value result = base; |
206 | // Transform to naive sequence of multiplications: |
207 | // * For positive exponent case replace: |
208 | // `[fi]powi(x, positive_exponent)` |
209 | // with: |
210 | // x * x * x * ... |
211 | // * For negative exponent case replace: |
212 | // `[fi]powi(x, negative_exponent)` |
213 | // with: |
214 | // (1 / x) * (1 / x) * (1 / x) * ... |
215 | for (unsigned i = 1; i < exponentValue; ++i) |
216 | result = rewriter.create<MulOpTy>(loc, result, base); |
217 | |
218 | rewriter.replaceOp(op, result); |
219 | return success(); |
220 | } |
221 | |
222 | //----------------------------------------------------------------------------// |
223 | |
224 | void mlir::populateMathAlgebraicSimplificationPatterns( |
225 | RewritePatternSet &patterns) { |
226 | patterns |
227 | .add<PowFStrengthReduction, |
228 | PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>, |
229 | PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>( |
230 | patterns.getContext()); |
231 | } |
232 | |