//===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrites based on the basic rules of algebra
// (commutativity, associativity, etc.) and strength reductions for math
// operations.
//
//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Math/IR/Math.h"
17#include "mlir/Dialect/Math/Transforms/Passes.h"
18#include "mlir/Dialect/Vector/IR/VectorOps.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/TypeUtilities.h"
22#include <climits>
23
24using namespace mlir;
25
//----------------------------------------------------------------------------//
// PowFOp strength reduction.
//----------------------------------------------------------------------------//
29
30namespace {
31struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
32public:
33 using OpRewritePattern::OpRewritePattern;
34
35 LogicalResult matchAndRewrite(math::PowFOp op,
36 PatternRewriter &rewriter) const final;
37};
38} // namespace
39
40LogicalResult
41PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
42 PatternRewriter &rewriter) const {
43 Location loc = op.getLoc();
44 Value x = op.getLhs();
45
46 FloatAttr scalarExponent;
47 DenseFPElementsAttr vectorExponent;
48
49 bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
50 bool isVector = matchPattern(op.getRhs(), m_Constant(bind_value: &vectorExponent));
51
52 // Returns true if exponent is a constant equal to `value`.
53 auto isExponentValue = [&](double value) -> bool {
54 if (isScalar)
55 return scalarExponent.getValue().isExactlyValue(value);
56
57 if (isVector && vectorExponent.isSplat())
58 return vectorExponent.getSplatValue<FloatAttr>()
59 .getValue()
60 .isExactlyValue(value);
61
62 return false;
63 };
64
65 // Maybe broadcasts scalar value into vector type compatible with `op`.
66 auto bcast = [&](Value value) -> Value {
67 if (auto vec = dyn_cast<VectorType>(op.getType()))
68 return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
69 return value;
70 };
71
72 // Replace `pow(x, 1.0)` with `x`.
73 if (isExponentValue(1.0)) {
74 rewriter.replaceOp(op, x);
75 return success();
76 }
77
78 // Replace `pow(x, 2.0)` with `x * x`.
79 if (isExponentValue(2.0)) {
80 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
81 return success();
82 }
83
84 // Replace `pow(x, 3.0)` with `x * x * x`.
85 if (isExponentValue(3.0)) {
86 Value square =
87 rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
88 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
89 return success();
90 }
91
92 // Replace `pow(x, -1.0)` with `1.0 / x`.
93 if (isExponentValue(-1.0)) {
94 Value one = rewriter.create<arith::ConstantOp>(
95 loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
96 rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
97 return success();
98 }
99
100 // Replace `pow(x, 0.5)` with `sqrt(x)`.
101 if (isExponentValue(0.5)) {
102 rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
103 return success();
104 }
105
106 // Replace `pow(x, -0.5)` with `rsqrt(x)`.
107 if (isExponentValue(-0.5)) {
108 rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
109 return success();
110 }
111
112 // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
113 if (isExponentValue(0.75)) {
114 Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x);
115 Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf);
116 rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
117 ValueRange{powHalf, powQuarter});
118 return success();
119 }
120
121 return failure();
122}
123
//----------------------------------------------------------------------------//
// FPowIOp/IPowIOp strength reduction.
//----------------------------------------------------------------------------//
127
128namespace {
129template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
130struct PowIStrengthReduction : public OpRewritePattern<PowIOpTy> {
131
132 unsigned exponentThreshold;
133
134public:
135 PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
136 PatternBenefit benefit = 1,
137 ArrayRef<StringRef> generatedNames = {})
138 : OpRewritePattern<PowIOpTy>(context, benefit, generatedNames),
139 exponentThreshold(exponentThreshold) {}
140
141 LogicalResult matchAndRewrite(PowIOpTy op,
142 PatternRewriter &rewriter) const final;
143};
144} // namespace
145
146template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
147LogicalResult
148PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
149 PowIOpTy op, PatternRewriter &rewriter) const {
150 Location loc = op.getLoc();
151 Value base = op.getLhs();
152
153 IntegerAttr scalarExponent;
154 DenseIntElementsAttr vectorExponent;
155
156 bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
157 bool isVector = matchPattern(op.getRhs(), m_Constant(bind_value: &vectorExponent));
158
159 // Simplify cases with known exponent value.
160 int64_t exponentValue = 0;
161 if (isScalar)
162 exponentValue = scalarExponent.getInt();
163 else if (isVector && vectorExponent.isSplat())
164 exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
165 else
166 return failure();
167
168 // Maybe broadcasts scalar value into vector type compatible with `op`.
169 auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
170 if (auto vec = dyn_cast<VectorType>(op.getType()))
171 return rewriter.create<vector::BroadcastOp>(loc, vec, value);
172 return value;
173 };
174
175 Value one;
176 Type opType = getElementTypeOrSelf(op.getType());
177 if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
178 one = rewriter.create<arith::ConstantOp>(
179 loc, rewriter.getFloatAttr(opType, 1.0));
180 else
181 one = rewriter.create<arith::ConstantOp>(
182 loc, rewriter.getIntegerAttr(opType, 1));
183
184 // Replace `[fi]powi(x, 0)` with `1`.
185 if (exponentValue == 0) {
186 rewriter.replaceOp(op, bcast(one));
187 return success();
188 }
189
190 bool exponentIsNegative = false;
191 if (exponentValue < 0) {
192 exponentIsNegative = true;
193 exponentValue *= -1;
194 }
195
196 // Bail out if `abs(exponent)` exceeds the threshold.
197 if (exponentValue > exponentThreshold)
198 return failure();
199
200 // Inverse the base for negative exponent, i.e. for
201 // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
202 if (exponentIsNegative)
203 base = rewriter.create<DivOpTy>(loc, bcast(one), base);
204
205 Value result = base;
206 // Transform to naive sequence of multiplications:
207 // * For positive exponent case replace:
208 // `[fi]powi(x, positive_exponent)`
209 // with:
210 // x * x * x * ...
211 // * For negative exponent case replace:
212 // `[fi]powi(x, negative_exponent)`
213 // with:
214 // (1 / x) * (1 / x) * (1 / x) * ...
215 for (unsigned i = 1; i < exponentValue; ++i)
216 result = rewriter.create<MulOpTy>(loc, result, base);
217
218 rewriter.replaceOp(op, result);
219 return success();
220}
221
//----------------------------------------------------------------------------//
223
224void mlir::populateMathAlgebraicSimplificationPatterns(
225 RewritePatternSet &patterns) {
226 patterns
227 .add<PowFStrengthReduction,
228 PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
229 PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
230 patterns.getContext());
231}
232

// Source: mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp