1//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements expansion of math operations to fast approximations
10// that do not rely on any of the library functions.
11//
12//===----------------------------------------------------------------------===//
13
14#include <climits>
15#include <cmath>
16#include <cstddef>
17
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/Math/IR/Math.h"
20#include "mlir/Dialect/Math/Transforms/Approximation.h"
21#include "mlir/Dialect/Math/Transforms/Passes.h"
22#include "mlir/Dialect/Utils/IndexingUtils.h"
23#include "mlir/Dialect/Vector/IR/VectorOps.h"
24#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
25#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
26#include "mlir/IR/Builders.h"
27#include "mlir/IR/BuiltinTypes.h"
28#include "mlir/IR/ImplicitLocOpBuilder.h"
29#include "mlir/IR/OpDefinition.h"
30#include "mlir/IR/PatternMatch.h"
31#include "mlir/IR/TypeUtilities.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "llvm/ADT/ArrayRef.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/Support/MathExtras.h"
36
37using namespace mlir;
38using namespace mlir::math;
39using namespace mlir::vector;
40
41// Helper to encapsulate a vector's shape (including scalable dims).
struct VectorShape {
  // Static size of each vector dimension.
  ArrayRef<int64_t> sizes;
  // Per-dimension scalability flags, parallel to `sizes` (fed to/from
  // VectorType's scalable-dims list).
  ArrayRef<bool> scalableFlags;
};
46
47// Returns vector shape if the type is a vector, otherwise return nullopt.
48static std::optional<VectorShape> vectorShape(Type type) {
49 if (auto vectorType = dyn_cast<VectorType>(Val&: type)) {
50 return VectorShape{.sizes: vectorType.getShape(), .scalableFlags: vectorType.getScalableDims()};
51 }
52 return std::nullopt;
53}
54
55static std::optional<VectorShape> vectorShape(Value value) {
56 return vectorShape(type: value.getType());
57}
58
59//----------------------------------------------------------------------------//
60// Broadcast scalar types and values into vector types and values.
61//----------------------------------------------------------------------------//
62
63// Broadcasts scalar type into vector type (iff shape is non-scalar).
64static Type broadcast(Type type, std::optional<VectorShape> shape) {
65 assert(!isa<VectorType>(type) && "must be scalar type");
66 return shape ? VectorType::get(shape: shape->sizes, elementType: type, scalableDims: shape->scalableFlags)
67 : type;
68}
69
70// Broadcasts scalar value into vector (iff shape is non-scalar).
71static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
72 std::optional<VectorShape> shape) {
73 assert(!isa<VectorType>(value.getType()) && "must be scalar value");
74 auto type = broadcast(type: value.getType(), shape);
75 return shape ? builder.create<BroadcastOp>(args&: type, args&: value) : value;
76}
77
78//----------------------------------------------------------------------------//
79// Helper function to handle n-D vectors with 1-D operations.
80//----------------------------------------------------------------------------//
81
82// Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
83// and calls the compute function with 1-D vector operands. Stitches back all
84// results into the original n-D vector result.
85//
86// Examples: vectorWidth = 8
87// - vector<4x8xf32> unrolled 4 times
88// - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
89// - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
90//
91// Some math approximations rely on ISA-specific operations that only accept
92// fixed size 1-D vectors (e.g. AVX expects vectors of width 8).
93//
94// It is the caller's responsibility to verify that the inner dimension is
95// divisible by the vectorWidth, and that all operands have the same vector
96// shape.
static Value
handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
                              ValueRange operands, int64_t vectorWidth,
                              llvm::function_ref<Value(ValueRange)> compute) {
  assert(!operands.empty() && "operands must be not empty");
  assert(vectorWidth > 0 && "vector width must be larger than 0");

  // All operands are assumed to share the shape of operand 0 (see the
  // function-level comment above).
  VectorType inputType = cast<VectorType>(Val: operands[0].getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();

  // If input shape matches target vector width, we can just call the
  // user-provided compute function with the operands.
  if (inputShape == llvm::ArrayRef(vectorWidth))
    return compute(operands);

  // Check if the inner dimension has to be expanded, or we can directly iterate
  // over the outer dimensions of the vector.
  int64_t innerDim = inputShape.back();
  int64_t expansionDim = innerDim / vectorWidth;
  assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");

  // Maybe expand operands to the higher rank vector shape that we'll use to
  // iterate over and extract one dimensional vectors.
  SmallVector<int64_t> expandedShape(inputShape);
  SmallVector<Value> expandedOperands(operands);

  if (expansionDim > 1) {
    // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
    expandedShape.insert(I: expandedShape.end() - 1, Elt: expansionDim);
    expandedShape.back() = vectorWidth;

    // Shape-cast every operand to the expanded rank (a no-op data movement).
    for (unsigned i = 0; i < operands.size(); ++i) {
      auto operand = operands[i];
      auto eltType = cast<VectorType>(Val: operand.getType()).getElementType();
      auto expandedType = VectorType::get(shape: expandedShape, elementType: eltType);
      expandedOperands[i] =
          builder.create<vector::ShapeCastOp>(args&: expandedType, args&: operand);
    }
  }

  // Iterate over all outer dimensions of the compute shape vector type.
  auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
  int64_t maxIndex = computeMaxLinearIndex(basis: iterationDims);
  auto strides = computeStrides(sizes: iterationDims);

  // Compute results for each one dimensional vector.
  SmallVector<Value> results(maxIndex);

  for (int64_t i = 0; i < maxIndex; ++i) {
    // Map the linear index back to multi-dimensional outer offsets.
    auto offsets = delinearize(linearIndex: i, strides);

    // Extract the 1-D slice at `offsets` from each operand.
    SmallVector<Value> extracted(expandedOperands.size());
    for (const auto &tuple : llvm::enumerate(First&: expandedOperands))
      extracted[tuple.index()] =
          builder.create<vector::ExtractOp>(args&: tuple.value(), args&: offsets);

    results[i] = compute(extracted);
  }

  // Stitch results together into one large vector.
  Type resultEltType = cast<VectorType>(Val: results[0].getType()).getElementType();
  Type resultExpandedType = VectorType::get(shape: expandedShape, elementType: resultEltType);
  // Start from a zero constant and insert each computed 1-D slice in place.
  Value result = builder.create<arith::ConstantOp>(
      args&: resultExpandedType, args: builder.getZeroAttr(type: resultExpandedType));

  for (int64_t i = 0; i < maxIndex; ++i)
    result = builder.create<vector::InsertOp>(args&: results[i], args&: result,
                                              args: delinearize(linearIndex: i, strides));

  // Reshape back to the original vector shape.
  return builder.create<vector::ShapeCastOp>(
      args: VectorType::get(shape: inputShape, elementType: resultEltType), args&: result);
}
170
171//----------------------------------------------------------------------------//
172// Helper functions to create constants.
173//----------------------------------------------------------------------------//
174
175static Value boolCst(ImplicitLocOpBuilder &builder, bool value) {
176 return builder.create<arith::ConstantOp>(args: builder.getBoolAttr(value));
177}
178
179static Value floatCst(ImplicitLocOpBuilder &builder, float value,
180 Type elementType) {
181 assert((elementType.isF16() || elementType.isF32()) &&
182 "x must be f16 or f32 type.");
183 return builder.create<arith::ConstantOp>(
184 args: builder.getFloatAttr(type: elementType, value));
185}
186
187static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
188 return builder.create<arith::ConstantOp>(args: builder.getF32FloatAttr(value));
189}
190
191static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
192 return builder.create<arith::ConstantOp>(args: builder.getI32IntegerAttr(value));
193}
194
195static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
196 Value i32Value = i32Cst(builder, value: static_cast<int32_t>(bits));
197 return builder.create<arith::BitcastOp>(args: builder.getF32Type(), args&: i32Value);
198}
199
200//----------------------------------------------------------------------------//
201// Helper functions to build math functions approximations.
202//----------------------------------------------------------------------------//
203
204// Return the minimum of the two values or NaN if value is NaN
205static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
206 return builder.create<arith::SelectOp>(
207 args: builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::ULT, args&: value, args&: bound),
208 args&: value, args&: bound);
209}
210
211// Return the maximum of the two values or NaN if value is NaN
212static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
213 return builder.create<arith::SelectOp>(
214 args: builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::UGT, args&: value, args&: bound),
215 args&: value, args&: bound);
216}
217
218// Return the clamped value or NaN if value is NaN
219static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
220 Value upperBound) {
221 return max(builder, value: min(builder, value, bound: upperBound), bound: lowerBound);
222}
223
224// Decomposes given floating point value `arg` into a normalized fraction and
225// an integral power of two (see std::frexp). Returned values have float type.
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
                                     bool isPositive = false) {
  assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
  std::optional<VectorShape> shape = vectorShape(value: arg);

  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto i32 = builder.getIntegerType(width: 32);
  auto i32Vec = broadcast(type: i32, shape);
  auto f32Vec = broadcast(type: builder.getF32Type(), shape);

  Value cst126f = f32Cst(builder, value: 126.0f);
  Value cstHalf = f32Cst(builder, value: 0.5f);
  // ~0x7f800000 clears the 8 exponent bits and keeps sign + mantissa.
  Value cstInvMantMask = f32FromBits(builder, bits: ~0x7f800000u);

  // Bitcast to i32 for bitwise operations.
  Value i32Half = builder.create<arith::BitcastOp>(args&: i32, args&: cstHalf);
  Value i32InvMantMask = builder.create<arith::BitcastOp>(args&: i32, args&: cstInvMantMask);
  Value i32Arg = builder.create<arith::BitcastOp>(args&: i32Vec, args&: arg);

  // Compute normalized fraction: clear the exponent field, then OR in the
  // bits of 0.5 to substitute an exponent that puts the result in [0.5, 1).
  Value tmp0 = builder.create<arith::AndIOp>(args&: i32Arg, args: bcast(i32InvMantMask));
  Value tmp1 = builder.create<arith::OrIOp>(args&: tmp0, args: bcast(i32Half));
  Value normalizedFraction = builder.create<arith::BitcastOp>(args&: f32Vec, args&: tmp1);

  // Compute exponent: shift the (absolute value's) biased exponent field down
  // by the 23 mantissa bits, convert to float, and remove the 126 bias so the
  // pair satisfies arg == normalizedFraction * 2^exponent.
  Value arg0 = isPositive ? arg : builder.create<math::AbsFOp>(args&: arg);
  Value biasedExponentBits = builder.create<arith::ShRUIOp>(
      args: builder.create<arith::BitcastOp>(args&: i32Vec, args&: arg0),
      args: bcast(i32Cst(builder, value: 23)));
  Value biasedExponent =
      builder.create<arith::SIToFPOp>(args&: f32Vec, args&: biasedExponentBits);
  Value exponent =
      builder.create<arith::SubFOp>(args&: biasedExponent, args: bcast(cst126f));

  return {normalizedFraction, exponent};
}
265
266// Computes exp2 for an i32 argument.
267static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
268 assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
269 std::optional<VectorShape> shape = vectorShape(value: arg);
270
271 auto bcast = [&](Value value) -> Value {
272 return broadcast(builder, value, shape);
273 };
274
275 auto f32Vec = broadcast(type: builder.getF32Type(), shape);
276 // The exponent of f32 located at 23-bit.
277 auto exponetBitLocation = bcast(i32Cst(builder, value: 23));
278 // Set the exponent bias to zero.
279 auto bias = bcast(i32Cst(builder, value: 127));
280
281 Value biasedArg = builder.create<arith::AddIOp>(args&: arg, args&: bias);
282 Value exp2ValueInt =
283 builder.create<arith::ShLIOp>(args&: biasedArg, args&: exponetBitLocation);
284 Value exp2ValueF32 = builder.create<arith::BitcastOp>(args&: f32Vec, args&: exp2ValueInt);
285
286 return exp2ValueF32;
287}
288
289namespace {
290Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
291 llvm::ArrayRef<Value> coeffs, Value x) {
292 Type elementType = getElementTypeOrSelf(val: x);
293 assert((elementType.isF32() || elementType.isF16()) &&
294 "x must be f32 or f16 type");
295 std::optional<VectorShape> shape = vectorShape(value: x);
296
297 if (coeffs.empty())
298 return broadcast(builder, value: floatCst(builder, value: 0.0f, elementType), shape);
299
300 if (coeffs.size() == 1)
301 return coeffs[0];
302
303 Value res = builder.create<math::FmaOp>(args&: x, args: coeffs[coeffs.size() - 1],
304 args: coeffs[coeffs.size() - 2]);
305 for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
306 res = builder.create<math::FmaOp>(args&: x, args&: res, args: coeffs[i]);
307 }
308 return res;
309}
310} // namespace
311
312//----------------------------------------------------------------------------//
313// Helper function/pattern to insert casts for reusing F32 bit expansion.
314//----------------------------------------------------------------------------//
315
// Rewrites `op` (whose operands and results all share one element type that is
// a float narrower than 32 bits) into extf-to-f32 -> T on f32 -> truncf back
// to the original type, so the f32 approximation patterns can be reused.
template <typename T>
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
  // Conservatively only allow where the operand and result types are exactly 1.
  Type origType = op->getResultTypes().front();
  for (Type t : llvm::drop_begin(RangeOrContainer: op->getResultTypes()))
    if (origType != t)
      return rewriter.notifyMatchFailure(arg&: op, msg: "required all types to match");
  for (Type t : op->getOperandTypes())
    if (origType != t)
      return rewriter.notifyMatchFailure(arg&: op, msg: "required all types to match");

  // Skip if already F32 or larger than 32 bits.
  if (getElementTypeOrSelf(type: origType).isF32() ||
      getElementTypeOrSelf(type: origType).getIntOrFloatBitWidth() > 32)
    return failure();

  // Create F32 equivalent type, preserving the shape for shaped types.
  Type newType;
  if (auto shaped = dyn_cast<ShapedType>(Val&: origType)) {
    newType = shaped.clone(elementType: rewriter.getF32Type());
  } else if (isa<FloatType>(Val: origType)) {
    newType = rewriter.getF32Type();
  } else {
    return rewriter.notifyMatchFailure(arg&: op,
                                       msg: "unable to find F32 equivalent type");
  }

  // Extend all operands to f32, re-create the op at f32, and truncate the
  // result back to the original type in place of `op`.
  Location loc = op->getLoc();
  SmallVector<Value> operands;
  for (auto operand : op->getOperands())
    operands.push_back(Elt: rewriter.create<arith::ExtFOp>(location: loc, args&: newType, args&: operand));
  auto result =
      rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs());
  rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
  return success();
}
352
namespace {
// Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
// op.
// TODO: Consider revising to avoid adding multiple casts for a subgraph that is
// all in lower precision. Currently this is only fallback support and performs
// simplistic casting.
template <typename T>
struct ReuseF32Expansion : public OpRewritePattern<T> {
public:
  using OpRewritePattern<T>::OpRewritePattern;
  LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
    // insertCasts relies on every operand/result sharing one type; require the
    // op to guarantee that statically via its trait.
    static_assert(
        T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
        "requires same operands and result types");
    return insertCasts<T>(op, rewriter);
  }
};
} // namespace
371
372//----------------------------------------------------------------------------//
373// AtanOp approximation.
374//----------------------------------------------------------------------------//
375
namespace {
// Rewrites math.atan on f32 operands into a rational polynomial
// approximation; see the out-of-line matchAndRewrite below.
struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::AtanOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
385
// Approximates atan(x) by reducing |x| into [0, 0.66] via the identities
// atan(x) = pi/2 - atan(1/x) and atan(x) = pi/4 + atan((x-1)/(x+1)), then
// evaluating a rational polynomial on the reduced argument.
LogicalResult
AtanApproximation::matchAndRewrite(math::AtanOp op,
                                   PatternRewriter &rewriter) const {
  auto operand = op.getOperand();
  if (!getElementTypeOrSelf(val: operand).isF32())
    return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported operand type");

  std::optional<VectorShape> shape = vectorShape(value: op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Work on |x|; the sign is restored with copysign at the end.
  Value abs = builder.create<math::AbsFOp>(args&: operand);

  auto one = broadcast(builder, value: f32Cst(builder, value: 1.0), shape);

  // When 0.66 < x <= 2.41 we do (x-1) / (x+1):
  auto twoThirds = broadcast(builder, value: f32Cst(builder, value: 0.66), shape);
  Value cmp2 =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGT, args&: abs, args&: twoThirds);
  Value addone = builder.create<arith::AddFOp>(args&: abs, args&: one);
  Value subone = builder.create<arith::SubFOp>(args&: abs, args&: one);
  Value xnum = builder.create<arith::SelectOp>(args&: cmp2, args&: subone, args&: abs);
  Value xden = builder.create<arith::SelectOp>(args&: cmp2, args&: addone, args&: one);

  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Break into the <= 0.66 or > 2.41 we do x or 1/x:
  // (2.414... = tan(3*pi/8), the threshold for the 1/x reduction.)
  auto tan3pio8 = bcast(f32Cst(builder, value: 2.41421356237309504880));
  Value cmp1 =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGT, args&: abs, args&: tan3pio8);
  xnum = builder.create<arith::SelectOp>(args&: cmp1, args&: one, args&: xnum);
  xden = builder.create<arith::SelectOp>(args&: cmp1, args&: abs, args&: xden);

  // Reduced argument and its square for the even-powers polynomial.
  Value x = builder.create<arith::DivFOp>(args&: xnum, args&: xden);
  Value xx = builder.create<arith::MulFOp>(args&: x, args&: x);

  // Perform the Taylor series approximation for atan over the range
  // [0.0, 0.66].
  auto p0 = bcast(f32Cst(builder, value: -8.750608600031904122785e-01));
  auto p1 = bcast(f32Cst(builder, value: -1.615753718733365076637e+01));
  auto p2 = bcast(f32Cst(builder, value: -7.500855792314704667340e+01));
  auto p3 = bcast(f32Cst(builder, value: -1.228866684490136173410e+02));
  auto p4 = bcast(f32Cst(builder, value: -6.485021904942025371773e+01));
  auto q0 = bcast(f32Cst(builder, value: +2.485846490142306297962e+01));
  auto q1 = bcast(f32Cst(builder, value: +1.650270098316988542046e+02));
  auto q2 = bcast(f32Cst(builder, value: +4.328810604912902668951e+02));
  auto q3 = bcast(f32Cst(builder, value: +4.853903996359136964868e+02));
  auto q4 = bcast(f32Cst(builder, value: +1.945506571482613964425e+02));

  // Apply the polynomial approximation for the numerator:
  Value n = p0;
  n = builder.create<math::FmaOp>(args&: xx, args&: n, args&: p1);
  n = builder.create<math::FmaOp>(args&: xx, args&: n, args&: p2);
  n = builder.create<math::FmaOp>(args&: xx, args&: n, args&: p3);
  n = builder.create<math::FmaOp>(args&: xx, args&: n, args&: p4);
  n = builder.create<arith::MulFOp>(args&: n, args&: xx);

  // Apply the polynomial approximation for the denominator:
  Value d = q0;
  d = builder.create<math::FmaOp>(args&: xx, args&: d, args&: q1);
  d = builder.create<math::FmaOp>(args&: xx, args&: d, args&: q2);
  d = builder.create<math::FmaOp>(args&: xx, args&: d, args&: q3);
  d = builder.create<math::FmaOp>(args&: xx, args&: d, args&: q4);

  // Compute approximation of theta: atan(x) ≈ x + x * n/d.
  Value ans0 = builder.create<arith::DivFOp>(args&: n, args&: d);
  ans0 = builder.create<math::FmaOp>(args&: ans0, args&: x, args&: x);

  // Correct for the input mapping's angles:
  // cmp2 branch used (x-1)/(x+1), so add back pi/4.
  Value mpi4 = bcast(f32Cst(builder, value: llvm::numbers::pi / 4));
  Value ans2 = builder.create<arith::AddFOp>(args&: mpi4, args&: ans0);
  Value ans = builder.create<arith::SelectOp>(args&: cmp2, args&: ans2, args&: ans0);

  // cmp1 branch used 1/x, so the result is pi/2 minus the approximation.
  Value mpi2 = bcast(f32Cst(builder, value: llvm::numbers::pi / 2));
  Value ans1 = builder.create<arith::SubFOp>(args&: mpi2, args&: ans0);
  ans = builder.create<arith::SelectOp>(args&: cmp1, args&: ans1, args&: ans);

  // Correct for signing of the input.
  rewriter.replaceOpWithNewOp<math::CopySignOp>(op, args&: ans, args&: operand);
  return success();
}
468
469//----------------------------------------------------------------------------//
// Atan2Op approximation.
471//----------------------------------------------------------------------------//
472
namespace {
// Rewrites math.atan2 on f32 operands in terms of math.atan plus quadrant
// corrections; see the out-of-line matchAndRewrite below.
struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Atan2Op op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
482
// Computes atan2(y, x) as atan(y/x) corrected by quadrant, with explicit
// handling of the x == 0 column (±pi/2 and the NaN at the origin).
LogicalResult
Atan2Approximation::matchAndRewrite(math::Atan2Op op,
                                    PatternRewriter &rewriter) const {
  auto y = op.getOperand(i: 0);
  auto x = op.getOperand(i: 1);
  if (!getElementTypeOrSelf(val: x).isF32())
    return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  std::optional<VectorShape> shape = vectorShape(value: op.getResult());

  // Compute atan in the valid range.
  auto div = builder.create<arith::DivFOp>(args&: y, args&: x);
  auto atan = builder.create<math::AtanOp>(args&: div);

  // Determine what the atan would be for a 180 degree rotation.
  auto zero = broadcast(builder, value: f32Cst(builder, value: 0.0f), shape);
  auto pi = broadcast(builder, value: f32Cst(builder, value: 3.14159265359f), shape);
  auto addPi = builder.create<arith::AddFOp>(args&: atan, args&: pi);
  auto subPi = builder.create<arith::SubFOp>(args&: atan, args&: pi);
  auto atanGt =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGT, args&: atan, args&: zero);
  auto flippedAtan = builder.create<arith::SelectOp>(args&: atanGt, args&: subPi, args&: addPi);

  // Determine whether to directly use atan or use the 180 degree flip
  auto xGt = builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGT, args&: x, args&: zero);
  Value result = builder.create<arith::SelectOp>(args&: xGt, args&: atan, args&: flippedAtan);

  // Handle x = 0, y > 0: the result is +pi/2.
  Value xZero =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: x, args&: zero);
  Value yGt = builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGT, args&: y, args&: zero);
  Value isHalfPi = builder.create<arith::AndIOp>(args&: xZero, args&: yGt);
  auto halfPi = broadcast(builder, value: f32Cst(builder, value: 1.57079632679f), shape);
  result = builder.create<arith::SelectOp>(args&: isHalfPi, args&: halfPi, args&: result);

  // Handle x = 0, y < 0: the result is -pi/2.
  Value yLt = builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OLT, args&: y, args&: zero);
  Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(args&: xZero, args&: yLt);
  auto negativeHalfPiPi =
      broadcast(builder, value: f32Cst(builder, value: -1.57079632679f), shape);
  result = builder.create<arith::SelectOp>(args&: isNegativeHalfPiPi, args&: negativeHalfPiPi,
                                           args&: result);

  // Handle x = 0, y = 0: the result is a quiet NaN (bit pattern 0x7fc00000).
  Value yZero =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: y, args&: zero);
  Value isNan = builder.create<arith::AndIOp>(args&: xZero, args&: yZero);
  Value cstNan = broadcast(builder, value: f32FromBits(builder, bits: 0x7fc00000), shape);
  result = builder.create<arith::SelectOp>(args&: isNan, args&: cstNan, args&: result);

  rewriter.replaceOp(op, newValues: result);
  return success();
}
537
538//----------------------------------------------------------------------------//
539// TanhOp approximation.
540//----------------------------------------------------------------------------//
541
namespace {
// Rewrites math.tanh on f32 operands into a ratio of polynomials; see the
// out-of-line matchAndRewrite below.
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
551
// Approximates tanh(x) as the ratio of an odd polynomial p(x) and an even
// polynomial q(x) in x^2, on the input clamped to ~[-7.9988, 7.9988]; tiny
// inputs fall back to the identity tanh(x) ≈ x.
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(val: op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported operand type");

  std::optional<VectorShape> shape = vectorShape(value: op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Clamp operand into [plusClamp, minusClamp] range.
  Value minusClamp = bcast(f32Cst(builder, value: -7.99881172180175781f));
  Value plusClamp = bcast(f32Cst(builder, value: 7.99881172180175781f));
  Value x = clamp(builder, value: op.getOperand(), lowerBound: minusClamp, upperBound: plusClamp);

  // Mask for tiny values that are approximated with `operand`.
  Value tiny = bcast(f32Cst(builder, value: 0.0004f));
  Value tinyMask = builder.create<arith::CmpFOp>(
      args: arith::CmpFPredicate::OLT, args: builder.create<math::AbsFOp>(args: op.getOperand()),
      args&: tiny);

  // The monomial coefficients of the numerator polynomial (odd).
  Value alpha1 = bcast(f32Cst(builder, value: 4.89352455891786e-03f));
  Value alpha3 = bcast(f32Cst(builder, value: 6.37261928875436e-04f));
  Value alpha5 = bcast(f32Cst(builder, value: 1.48572235717979e-05f));
  Value alpha7 = bcast(f32Cst(builder, value: 5.12229709037114e-08f));
  Value alpha9 = bcast(f32Cst(builder, value: -8.60467152213735e-11f));
  Value alpha11 = bcast(f32Cst(builder, value: 2.00018790482477e-13f));
  Value alpha13 = bcast(f32Cst(builder, value: -2.76076847742355e-16f));

  // The monomial coefficients of the denominator polynomial (even).
  Value beta0 = bcast(f32Cst(builder, value: 4.89352518554385e-03f));
  Value beta2 = bcast(f32Cst(builder, value: 2.26843463243900e-03f));
  Value beta4 = bcast(f32Cst(builder, value: 1.18534705686654e-04f));
  Value beta6 = bcast(f32Cst(builder, value: 1.19825839466702e-06f));

  // Since the polynomials are odd/even, we need x^2.
  Value x2 = builder.create<arith::MulFOp>(args&: x, args&: x);

  // Evaluate the numerator polynomial p (Horner in x^2, final multiply by x
  // makes it odd).
  Value p = builder.create<math::FmaOp>(args&: x2, args&: alpha13, args&: alpha11);
  p = builder.create<math::FmaOp>(args&: x2, args&: p, args&: alpha9);
  p = builder.create<math::FmaOp>(args&: x2, args&: p, args&: alpha7);
  p = builder.create<math::FmaOp>(args&: x2, args&: p, args&: alpha5);
  p = builder.create<math::FmaOp>(args&: x2, args&: p, args&: alpha3);
  p = builder.create<math::FmaOp>(args&: x2, args&: p, args&: alpha1);
  p = builder.create<arith::MulFOp>(args&: x, args&: p);

  // Evaluate the denominator polynomial q.
  Value q = builder.create<math::FmaOp>(args&: x2, args&: beta6, args&: beta4);
  q = builder.create<math::FmaOp>(args&: x2, args&: q, args&: beta2);
  q = builder.create<math::FmaOp>(args&: x2, args&: q, args&: beta0);

  // Divide the numerator by the denominator.
  Value res = builder.create<arith::SelectOp>(
      args&: tinyMask, args&: x, args: builder.create<arith::DivFOp>(args&: p, args&: q));

  rewriter.replaceOp(op, newValues: res);

  return success();
}
616
617#define LN2_VALUE \
618 0.693147180559945309417232121458176568075500134360255254120680009493393621L
619#define LOG2E_VALUE \
620 1.442695040888963407359924681001892137426645954152985934135449406931109219L
621
622//----------------------------------------------------------------------------//
623// LogOp and Log2Op approximation.
624//----------------------------------------------------------------------------//
625
namespace {
// Shared base for the natural-log and log2 approximation patterns; the
// derived pattern picks the base via the `base2` flag.
template <typename Op>
struct LogApproximationBase : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                   bool base2) const;
};
} // namespace
636
637// This approximation comes from Julien Pommier's SSE math library.
638// Link: http://gruntthepeon.free.fr/ssemath
639template <typename Op>
640LogicalResult
641LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
642 bool base2) const {
643 if (!getElementTypeOrSelf(op.getOperand()).isF32())
644 return rewriter.notifyMatchFailure(op, "unsupported operand type");
645
646 std::optional<VectorShape> shape = vectorShape(op.getOperand());
647
648 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
649 auto bcast = [&](Value value) -> Value {
650 return broadcast(builder, value, shape);
651 };
652
653 Value cstZero = bcast(f32Cst(builder, value: 0.0f));
654 Value cstOne = bcast(f32Cst(builder, value: 1.0f));
655 Value cstNegHalf = bcast(f32Cst(builder, value: -0.5f));
656
657 // The smallest non denormalized float number.
658 Value cstMinNormPos = bcast(f32FromBits(builder, bits: 0x00800000u));
659 Value cstMinusInf = bcast(f32FromBits(builder, bits: 0xff800000u));
660 Value cstPosInf = bcast(f32FromBits(builder, bits: 0x7f800000u));
661 Value cstNan = bcast(f32FromBits(builder, bits: 0x7fc00000));
662
663 // Polynomial coefficients.
664 Value cstCephesSQRTHF = bcast(f32Cst(builder, value: 0.707106781186547524f));
665 Value cstCephesLogP0 = bcast(f32Cst(builder, value: 7.0376836292E-2f));
666 Value cstCephesLogP1 = bcast(f32Cst(builder, value: -1.1514610310E-1f));
667 Value cstCephesLogP2 = bcast(f32Cst(builder, value: 1.1676998740E-1f));
668 Value cstCephesLogP3 = bcast(f32Cst(builder, value: -1.2420140846E-1f));
669 Value cstCephesLogP4 = bcast(f32Cst(builder, value: +1.4249322787E-1f));
670 Value cstCephesLogP5 = bcast(f32Cst(builder, value: -1.6668057665E-1f));
671 Value cstCephesLogP6 = bcast(f32Cst(builder, value: +2.0000714765E-1f));
672 Value cstCephesLogP7 = bcast(f32Cst(builder, value: -2.4999993993E-1f));
673 Value cstCephesLogP8 = bcast(f32Cst(builder, value: +3.3333331174E-1f));
674
675 Value x = op.getOperand();
676
677 // Truncate input values to the minimum positive normal.
678 x = max(builder, value: x, bound: cstMinNormPos);
679
680 // Extract significant in the range [0.5,1) and exponent.
681 std::pair<Value, Value> pair = frexp(builder, arg: x, /*isPositive=*/true);
682 x = pair.first;
683 Value e = pair.second;
684
685 // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
686 // by -1.0. The values are then centered around 0, which improves the
687 // stability of the polynomial evaluation:
688 //
689 // if( x < SQRTHF ) {
690 // e -= 1;
691 // x = x + x - 1.0;
692 // } else { x = x - 1.0; }
693 Value mask = builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OLT, args&: x,
694 args&: cstCephesSQRTHF);
695 Value tmp = builder.create<arith::SelectOp>(args&: mask, args&: x, args&: cstZero);
696
697 x = builder.create<arith::SubFOp>(args&: x, args&: cstOne);
698 e = builder.create<arith::SubFOp>(
699 args&: e, args: builder.create<arith::SelectOp>(args&: mask, args&: cstOne, args&: cstZero));
700 x = builder.create<arith::AddFOp>(args&: x, args&: tmp);
701
702 Value x2 = builder.create<arith::MulFOp>(args&: x, args&: x);
703 Value x3 = builder.create<arith::MulFOp>(args&: x2, args&: x);
704
705 // Evaluate the polynomial approximant of degree 8 in three parts.
706 Value y0, y1, y2;
707 y0 = builder.create<math::FmaOp>(args&: cstCephesLogP0, args&: x, args&: cstCephesLogP1);
708 y1 = builder.create<math::FmaOp>(args&: cstCephesLogP3, args&: x, args&: cstCephesLogP4);
709 y2 = builder.create<math::FmaOp>(args&: cstCephesLogP6, args&: x, args&: cstCephesLogP7);
710 y0 = builder.create<math::FmaOp>(args&: y0, args&: x, args&: cstCephesLogP2);
711 y1 = builder.create<math::FmaOp>(args&: y1, args&: x, args&: cstCephesLogP5);
712 y2 = builder.create<math::FmaOp>(args&: y2, args&: x, args&: cstCephesLogP8);
713 y0 = builder.create<math::FmaOp>(args&: y0, args&: x3, args&: y1);
714 y0 = builder.create<math::FmaOp>(args&: y0, args&: x3, args&: y2);
715 y0 = builder.create<arith::MulFOp>(args&: y0, args&: x3);
716
717 y0 = builder.create<math::FmaOp>(args&: cstNegHalf, args&: x2, args&: y0);
718 x = builder.create<arith::AddFOp>(args&: x, args&: y0);
719
720 if (base2) {
721 Value cstLog2e = bcast(f32Cst(builder, value: static_cast<float>(LOG2E_VALUE)));
722 x = builder.create<math::FmaOp>(args&: x, args&: cstLog2e, args&: e);
723 } else {
724 Value cstLn2 = bcast(f32Cst(builder, value: static_cast<float>(LN2_VALUE)));
725 x = builder.create<math::FmaOp>(args&: e, args&: cstLn2, args&: x);
726 }
727
728 Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
729 op.getOperand(), cstZero);
730 Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
731 op.getOperand(), cstZero);
732 Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
733 op.getOperand(), cstPosInf);
734
735 // Filter out invalid values:
736 // • x == 0 -> -INF
737 // • x < 0 -> NAN
738 // • x == +INF -> +INF
739 Value aproximation = builder.create<arith::SelectOp>(
740 args&: zeroMask, args&: cstMinusInf,
741 args: builder.create<arith::SelectOp>(
742 args&: invalidMask, args&: cstNan,
743 args: builder.create<arith::SelectOp>(args&: posInfMask, args&: cstPosInf, args&: x)));
744
745 rewriter.replaceOp(op, aproximation);
746
747 return success();
748}
749
namespace {
// Expands math.log into a Cephes-style polynomial approximation of the
// natural logarithm. All the work happens in the shared logMatchAndRewrite
// helper; this pattern only selects the natural-log flavor (base2 = false).
struct LogApproximation : public LogApproximationBase<math::LogOp> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::LogOp op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/false);
  }
};
} // namespace
760
namespace {
// Expands math.log2 using the same shared Cephes-style expansion as
// math.log, with the result rescaled to base 2 (base2 = true).
struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::Log2Op op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/true);
  }
};
} // namespace
771
772//----------------------------------------------------------------------------//
773// Log1p approximation.
774//----------------------------------------------------------------------------//
775
namespace {
// Expands math.log1p into a numerically stable formula built on math.log
// (W. Kahan's log(1+x) reformulation); see the out-of-line matchAndRewrite
// definition for details.
struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Log1pOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
785
786// Approximate log(1+x).
787LogicalResult
788Log1pApproximation::matchAndRewrite(math::Log1pOp op,
789 PatternRewriter &rewriter) const {
790 if (!getElementTypeOrSelf(val: op.getOperand()).isF32())
791 return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported operand type");
792
793 std::optional<VectorShape> shape = vectorShape(value: op.getOperand());
794
795 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
796 auto bcast = [&](Value value) -> Value {
797 return broadcast(builder, value, shape);
798 };
799
800 // Approximate log(1+x) using the following, due to W. Kahan:
801 // u = x + 1.0;
802 // if (u == 1.0 || u == inf) return x;
803 // return x * log(u) / (u - 1.0);
804 // ^^^^^^^^^^^^^^^^^^^^^^
805 // "logLarge" below.
806 Value cstOne = bcast(f32Cst(builder, value: 1.0f));
807 Value x = op.getOperand();
808 Value u = builder.create<arith::AddFOp>(args&: x, args&: cstOne);
809 Value uSmall =
810 builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: u, args&: cstOne);
811 Value logU = builder.create<math::LogOp>(args&: u);
812 Value uInf =
813 builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: u, args&: logU);
814 Value logLarge = builder.create<arith::MulFOp>(
815 args&: x, args: builder.create<arith::DivFOp>(
816 args&: logU, args: builder.create<arith::SubFOp>(args&: u, args&: cstOne)));
817 Value approximation = builder.create<arith::SelectOp>(
818 args: builder.create<arith::OrIOp>(args&: uSmall, args&: uInf), args&: x, args&: logLarge);
819 rewriter.replaceOp(op, newValues: approximation);
820 return success();
821}
822
823//----------------------------------------------------------------------------//
824// Asin approximation.
825//----------------------------------------------------------------------------//
826
827// Approximates asin(x).
828// This approximation is based on the following stackoverflow post:
829// https://stackoverflow.com/a/42683455
namespace {
// Expands math.asin into an odd-polynomial approximation with argument
// reduction for |x| > sqrt(1/2); see the out-of-line matchAndRewrite.
struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::AsinOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
// Expands math.asin: reduces the argument to the small range, evaluates a
// polynomial there, and undoes the reduction / restores the sign at the end.
LogicalResult
AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
                                             PatternRewriter &rewriter) const {
  Value operand = op.getOperand();
  Type elementType = getElementTypeOrSelf(val: operand);

  if (!(elementType.isF32() || elementType.isF16()))
    return rewriter.notifyMatchFailure(arg&: op,
                                       msg: "only f32 and f16 type is supported.");
  std::optional<VectorShape> shape = vectorShape(value: operand);

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Splat a scalar constant to the operand's shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Short-hands for the math/arith ops used below.
  auto fma = [&](Value a, Value b, Value c) -> Value {
    return builder.create<math::FmaOp>(args&: a, args&: b, args&: c);
  };

  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(args&: a, args&: b);
  };

  auto sub = [&](Value a, Value b) -> Value {
    return builder.create<arith::SubFOp>(args&: a, args&: b);
  };

  auto abs = [&](Value a) -> Value { return builder.create<math::AbsFOp>(args&: a); };

  auto sqrt = [&](Value a) -> Value { return builder.create<math::SqrtOp>(args&: a); };

  auto scopy = [&](Value a, Value b) -> Value {
    return builder.create<math::CopySignOp>(args&: a, args&: b);
  };

  auto sel = [&](Value a, Value b, Value c) -> Value {
    return builder.create<arith::SelectOp>(args&: a, args&: b, args&: c);
  };

  // abso = |x|, aa = x^2, opp = sqrt(1 - x^2).
  Value abso = abs(operand);
  Value aa = mul(operand, operand);
  Value opp = sqrt(sub(bcast(floatCst(builder, value: 1.0, elementType)), aa));

  // When x^2 > 0.5, i.e. |x| > sqrt(1/2), reduce via the identity
  // asin(|x|) = pi/2 - asin(sqrt(1 - x^2)) so the polynomial below is only
  // evaluated on small arguments.
  Value gt =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGT, args&: aa,
                                    args: bcast(floatCst(builder, value: 0.5, elementType)));

  Value x = sel(gt, opp, abso);

  // Asin(x) approximation for x = [-9/16, 9/16]:
  // Odd polynomial evaluated as two interleaved Horner chains (`r` and `t`)
  // in q = x^4, which halves the FMA dependency depth; the chains are merged
  // below via r = fma(r, s, t).
  Value s = mul(x, x);
  Value q = mul(s, s);
  Value r = bcast(floatCst(builder, value: 5.5579749017470502e-2, elementType));
  Value t = bcast(floatCst(builder, value: -6.2027913464120114e-2, elementType));

  r = fma(r, q, bcast(floatCst(builder, value: 5.4224464349245036e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, value: -1.1326992890324464e-2, elementType)));
  r = fma(r, q, bcast(floatCst(builder, value: 1.5268872539397656e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, value: 1.0493798473372081e-2, elementType)));
  r = fma(r, q, bcast(floatCst(builder, value: 1.4106045900607047e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, value: 1.7339776384962050e-2, elementType)));
  r = fma(r, q, bcast(floatCst(builder, value: 2.2372961589651054e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, value: 3.0381912707941005e-2, elementType)));
  r = fma(r, q, bcast(floatCst(builder, value: 4.4642857881094775e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, value: 7.4999999991367292e-2, elementType)));
  // Merge the even/odd chains and apply the leading x * (1/6) * x^2 + x term.
  r = fma(r, s, t);
  r = fma(r, s, bcast(floatCst(builder, value: 1.6666666666670193e-1, elementType)));
  t = mul(x, s);
  r = fma(r, t, x);

  // Undo the argument reduction (pi/2 - r for the large-|x| branch) and
  // restore the sign of the input: asin is an odd function.
  Value rsub = sub(bcast(floatCst(builder, value: 1.57079632679, elementType)), r);
  r = sel(gt, rsub, r);
  r = scopy(r, operand);

  rewriter.replaceOp(op, newValues: r);
  return success();
}
917
918//----------------------------------------------------------------------------//
919// Acos approximation.
920//----------------------------------------------------------------------------//
921
922// Approximates acos(x).
923// This approximation is based on the following stackoverflow post:
924// https://stackoverflow.com/a/42683455
namespace {
// Expands math.acos in terms of math.asin using standard trigonometric
// identities; see the out-of-line matchAndRewrite definition.
struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::AcosOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
// Expands math.acos via asin: acos(|x|) = pi/2 - asin(|x|) for small |x|,
// a half-angle identity for large |x|, and a final reflection for x < 0.
LogicalResult
AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
                                             PatternRewriter &rewriter) const {
  Value operand = op.getOperand();
  Type elementType = getElementTypeOrSelf(val: operand);

  if (!(elementType.isF32() || elementType.isF16()))
    return rewriter.notifyMatchFailure(arg&: op,
                                       msg: "only f32 and f16 type is supported.");
  std::optional<VectorShape> shape = vectorShape(value: operand);

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Splat a scalar constant to the operand's shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto fma = [&](Value a, Value b, Value c) -> Value {
    return builder.create<math::FmaOp>(args&: a, args&: b, args&: c);
  };

  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(args&: a, args&: b);
  };

  // r = -|x|: fold positive operands into [-1, 0] so the branches below only
  // need to handle non-positive r.
  Value negOperand = builder.create<arith::NegFOp>(args&: operand);
  Value zero = bcast(floatCst(builder, value: 0.0, elementType));
  Value half = bcast(floatCst(builder, value: 0.5, elementType));
  Value negOne = bcast(floatCst(builder, value: -1.0, elementType));
  Value selR =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGT, args&: operand, args&: zero);
  Value r = builder.create<arith::SelectOp>(args&: selR, args&: negOperand, args&: operand);
  // firstPred <=> r > -0.5625, i.e. |x| < 0.5625 (small-argument branch).
  Value chkConst = bcast(floatCst(builder, value: -0.5625, elementType));
  Value firstPred =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGT, args&: r, args&: chkConst);

  // Small |x|: acos(|x|) = pi/2 + asin(r) (r = -|x|). The two constants
  // multiply to ~pi/2; the fma folds the product and the addition of asin(r)
  // into a single rounding for extra accuracy.
  Value trueVal =
      fma(bcast(floatCst(builder, value: 9.3282184640716537e-1, elementType)),
          bcast(floatCst(builder, value: 1.6839188885261840e+0, elementType)),
          builder.create<math::AsinOp>(args&: r));

  // Large |x|: half-angle identity, acos(|x|) = 2 * asin(sqrt((1 + r) / 2)).
  Value falseVal = builder.create<math::SqrtOp>(args: fma(half, r, half));
  falseVal = builder.create<math::AsinOp>(args&: falseVal);
  falseVal = mul(bcast(floatCst(builder, value: 2.0, elementType)), falseVal);

  r = builder.create<arith::SelectOp>(args&: firstPred, args&: trueVal, args&: falseVal);

  // Check whether the operand lies in between [-1.0, 0.0).
  Value greaterThanNegOne =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGE, args&: operand, args&: negOne);

  Value lessThanZero =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OLT, args&: operand, args&: zero);

  Value betweenNegOneZero =
      builder.create<arith::AndIOp>(args&: greaterThanNegOne, args&: lessThanZero);

  // For x in [-1, 0): acos(x) = pi - acos(-x). The two constants multiply to
  // ~pi, again fused with the subtraction of r via fma.
  trueVal = fma(bcast(floatCst(builder, value: 1.8656436928143307e+0, elementType)),
                bcast(floatCst(builder, value: 1.6839188885261840e+0, elementType)),
                builder.create<arith::NegFOp>(args&: r));

  Value finalVal =
      builder.create<arith::SelectOp>(args&: betweenNegOneZero, args&: trueVal, args&: r);

  rewriter.replaceOp(op, newValues: finalVal);
  return success();
}
1000
1001//----------------------------------------------------------------------------//
1002// Erf approximation.
1003//----------------------------------------------------------------------------//
1004
1005// Approximates erf(x) with
1006// a - P(x)/Q(x)
1007// where P and Q are polynomials of degree 4.
1008// Different coefficients are chosen based on the value of x.
1009// The approximation error is ~2.5e-07.
1010// Boost's minimax tool that utilizes the Remez method was used to find the
1011// coefficients.
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                            PatternRewriter &rewriter) const {
  Value operand = op.getOperand();
  Type elementType = getElementTypeOrSelf(val: operand);

  if (!(elementType.isF32() || elementType.isF16()))
    return rewriter.notifyMatchFailure(arg&: op,
                                       msg: "only f32 and f16 type is supported.");
  std::optional<VectorShape> shape = vectorShape(value: operand);

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Splat a scalar constant to the operand's shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // Piecewise rational approximation: three intervals on |x| (bounded by
  // 0.8, 2.0 and 3.75), each with its own degree-4 numerator (pp row) and
  // denominator (qq row) plus an additive offset; |x| >= 3.75 saturates to 1.
  const int intervalsCount = 3;
  const int polyDegree = 4;

  Value zero = bcast(floatCst(builder, value: 0, elementType));
  Value one = bcast(floatCst(builder, value: 1, elementType));
  // Numerator coefficients P per interval, lowest degree first.
  Value pp[intervalsCount][polyDegree + 1];
  pp[0][0] = bcast(floatCst(builder, value: +0.00000000000000000e+00f, elementType));
  pp[0][1] = bcast(floatCst(builder, value: +1.12837916222975858e+00f, elementType));
  pp[0][2] = bcast(floatCst(builder, value: -5.23018562988006470e-01f, elementType));
  pp[0][3] = bcast(floatCst(builder, value: +2.09741709609267072e-01f, elementType));
  pp[0][4] = bcast(floatCst(builder, value: +2.58146801602987875e-02f, elementType));
  pp[1][0] = bcast(floatCst(builder, value: +0.00000000000000000e+00f, elementType));
  pp[1][1] = bcast(floatCst(builder, value: +1.12750687816789140e+00f, elementType));
  pp[1][2] = bcast(floatCst(builder, value: -3.64721408487825775e-01f, elementType));
  pp[1][3] = bcast(floatCst(builder, value: +1.18407396425136952e-01f, elementType));
  pp[1][4] = bcast(floatCst(builder, value: +3.70645533056476558e-02f, elementType));
  pp[2][0] = bcast(floatCst(builder, value: -3.30093071049483172e-03f, elementType));
  pp[2][1] = bcast(floatCst(builder, value: +3.51961938357697011e-03f, elementType));
  pp[2][2] = bcast(floatCst(builder, value: -1.41373622814988039e-03f, elementType));
  pp[2][3] = bcast(floatCst(builder, value: +2.53447094961941348e-04f, elementType));
  pp[2][4] = bcast(floatCst(builder, value: -1.71048029455037401e-05f, elementType));

  // Denominator coefficients Q per interval, lowest degree first.
  Value qq[intervalsCount][polyDegree + 1];
  qq[0][0] = bcast(floatCst(builder, value: +1.000000000000000000e+00f, elementType));
  qq[0][1] = bcast(floatCst(builder, value: -4.635138185962547255e-01f, elementType));
  qq[0][2] = bcast(floatCst(builder, value: +5.192301327279782447e-01f, elementType));
  qq[0][3] = bcast(floatCst(builder, value: -1.318089722204810087e-01f, elementType));
  qq[0][4] = bcast(floatCst(builder, value: +7.397964654672315005e-02f, elementType));
  qq[1][0] = bcast(floatCst(builder, value: +1.00000000000000000e+00f, elementType));
  qq[1][1] = bcast(floatCst(builder, value: -3.27607011824493086e-01f, elementType));
  qq[1][2] = bcast(floatCst(builder, value: +4.48369090658821977e-01f, elementType));
  qq[1][3] = bcast(floatCst(builder, value: -8.83462621207857930e-02f, elementType));
  qq[1][4] = bcast(floatCst(builder, value: +5.72442770283176093e-02f, elementType));
  qq[2][0] = bcast(floatCst(builder, value: +1.00000000000000000e+00f, elementType));
  qq[2][1] = bcast(floatCst(builder, value: -2.06069165953913769e+00f, elementType));
  qq[2][2] = bcast(floatCst(builder, value: +1.62705939945477759e+00f, elementType));
  qq[2][3] = bcast(floatCst(builder, value: -5.83389859211130017e-01f, elementType));
  qq[2][4] = bcast(floatCst(builder, value: +8.21908939856640930e-02f, elementType));

  // Additive offset per interval: the result is offset + P(x)/Q(x).
  Value offsets[intervalsCount];
  offsets[0] = bcast(floatCst(builder, value: 0.0f, elementType));
  offsets[1] = bcast(floatCst(builder, value: 0.0f, elementType));
  offsets[2] = bcast(floatCst(builder, value: 1.0f, elementType));

  // Upper bound of each interval on |x|.
  Value bounds[intervalsCount];
  bounds[0] = bcast(floatCst(builder, value: 0.8f, elementType));
  bounds[1] = bcast(floatCst(builder, value: 2.0f, elementType));
  bounds[2] = bcast(floatCst(builder, value: 3.75f, elementType));

  // erf is odd: evaluate on |x| and restore the sign at the end.
  Value isNegativeArg =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OLT, args&: operand, args&: zero);
  Value negArg = builder.create<arith::NegFOp>(args&: operand);
  Value x = builder.create<arith::SelectOp>(args&: isNegativeArg, args&: negArg, args&: operand);

  // Start from interval 0's data; the select chain below switches to later
  // intervals for larger |x|.
  Value offset = offsets[0];
  Value p[polyDegree + 1];
  Value q[polyDegree + 1];
  for (int i = 0; i <= polyDegree; ++i) {
    p[i] = pp[0][i];
    q[i] = qq[0][i];
  }

  // TODO: maybe use vector stacking to reduce the number of selects.
  Value isLessThanBound[intervalsCount];
  for (int j = 0; j < intervalsCount - 1; ++j) {
    isLessThanBound[j] =
        builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OLT, args&: x, args&: bounds[j]);
    for (int i = 0; i <= polyDegree; ++i) {
      p[i] = builder.create<arith::SelectOp>(args&: isLessThanBound[j], args&: p[i],
                                             args&: pp[j + 1][i]);
      q[i] = builder.create<arith::SelectOp>(args&: isLessThanBound[j], args&: q[i],
                                             args&: qq[j + 1][i]);
    }
    offset = builder.create<arith::SelectOp>(args&: isLessThanBound[j], args&: offset,
                                             args&: offsets[j + 1]);
  }
  // Last bound uses ULT (unordered): a NaN input compares true, so the NaN
  // propagates through the rational formula instead of saturating to 1.
  isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
      args: arith::CmpFPredicate::ULT, args&: x, args&: bounds[intervalsCount - 1]);

  // Evaluate offset + P(x)/Q(x), saturating to 1 for |x| >= 3.75.
  Value pPoly = makePolynomialCalculation(builder, coeffs: p, x);
  Value qPoly = makePolynomialCalculation(builder, coeffs: q, x);
  Value rationalPoly = builder.create<arith::DivFOp>(args&: pPoly, args&: qPoly);
  Value formula = builder.create<arith::AddFOp>(args&: offset, args&: rationalPoly);
  formula = builder.create<arith::SelectOp>(args&: isLessThanBound[intervalsCount - 1],
                                            args&: formula, args&: one);

  // erf is odd function: erf(x) = -erf(-x).
  Value negFormula = builder.create<arith::NegFOp>(args&: formula);
  Value res =
      builder.create<arith::SelectOp>(args&: isNegativeArg, args&: negFormula, args&: formula);

  rewriter.replaceOp(op, newValues: res);

  return success();
}
1123
1124// Approximates erfc(x) with p((x - 2) / (x + 2)), where p is a 9 degree
// polynomial. This approximation is based on the following stackoverflow post:
1126// https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf
1127// The stackoverflow post is in turn based on:
1128// M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of
1129// (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36,
1130// No. 153, January 1981, pp. 249-253.
1131//
1132// Maximum error: 2.65 ulps
LogicalResult
ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op,
                                             PatternRewriter &rewriter) const {
  Value x = op.getOperand();
  Type et = getElementTypeOrSelf(val: x);

  if (!et.isF32())
    return rewriter.notifyMatchFailure(arg&: op, msg: "only f32 type is supported.");
  std::optional<VectorShape> shape = vectorShape(value: x);

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Splat a scalar constant to the operand's shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  Value trueValue = bcast(boolCst(builder, value: true));
  Value zero = bcast(floatCst(builder, value: 0.0f, elementType: et));
  Value one = bcast(floatCst(builder, value: 1.0f, elementType: et));
  Value onehalf = bcast(floatCst(builder, value: 0.5f, elementType: et));
  Value neg4 = bcast(floatCst(builder, value: -4.0f, elementType: et));
  Value neg2 = bcast(floatCst(builder, value: -2.0f, elementType: et));
  Value pos2 = bcast(floatCst(builder, value: 2.0f, elementType: et));
  Value posInf = bcast(floatCst(builder, INFINITY, elementType: et));
  // Above this threshold erfc(|x|) underflows to 0 in f32.
  Value clampVal = bcast(floatCst(builder, value: 10.0546875f, elementType: et));

  // Compute q = (a - 2) / (a + 2) on a = |x|, with an extra correction step
  // (e is the residual of the division, folded back in via q = q + r * e).
  Value a = builder.create<math::AbsFOp>(args&: x);
  Value p = builder.create<arith::AddFOp>(args&: a, args&: pos2);
  Value r = builder.create<arith::DivFOp>(args&: one, args&: p);
  Value q = builder.create<math::FmaOp>(args&: neg4, args&: r, args&: one);
  Value t = builder.create<math::FmaOp>(args: builder.create<arith::AddFOp>(args&: q, args&: one),
                                        args&: neg2, args&: a);
  Value e = builder.create<math::FmaOp>(args: builder.create<arith::NegFOp>(args&: a), args&: q, args&: t);
  q = builder.create<math::FmaOp>(args&: r, args&: e, args&: q);

  // Degree-9 Chebyshev-derived polynomial p(q), evaluated by Horner's rule.
  p = bcast(floatCst(builder, value: -0x1.a4a000p-12f, elementType: et)); // -4.01139259e-4
  Value c1 = bcast(floatCst(builder, value: -0x1.42a260p-10f, elementType: et)); // -1.23075210e-3
  p = builder.create<math::FmaOp>(args&: p, args&: q, args&: c1);
  Value c2 = bcast(floatCst(builder, value: 0x1.585714p-10f, elementType: et)); // 1.31355342e-3
  p = builder.create<math::FmaOp>(args&: p, args&: q, args&: c2);
  Value c3 = bcast(floatCst(builder, value: 0x1.1adcc4p-07f, elementType: et)); // 8.63227434e-3
  p = builder.create<math::FmaOp>(args&: p, args&: q, args&: c3);
  Value c4 = bcast(floatCst(builder, value: -0x1.081b82p-07f, elementType: et)); // -8.05991981e-3
  p = builder.create<math::FmaOp>(args&: p, args&: q, args&: c4);
  Value c5 = bcast(floatCst(builder, value: -0x1.bc0b6ap-05f, elementType: et)); // -5.42046614e-2
  p = builder.create<math::FmaOp>(args&: p, args&: q, args&: c5);
  Value c6 = bcast(floatCst(builder, value: 0x1.4ffc46p-03f, elementType: et)); // 1.64055392e-1
  p = builder.create<math::FmaOp>(args&: p, args&: q, args&: c6);
  Value c7 = bcast(floatCst(builder, value: -0x1.540840p-03f, elementType: et)); // -1.66031361e-1
  p = builder.create<math::FmaOp>(args&: p, args&: q, args&: c7);
  Value c8 = bcast(floatCst(builder, value: -0x1.7bf616p-04f, elementType: et)); // -9.27639827e-2
  p = builder.create<math::FmaOp>(args&: p, args&: q, args&: c8);
  Value c9 = bcast(floatCst(builder, value: 0x1.1ba03ap-02f, elementType: et)); // 2.76978403e-1
  p = builder.create<math::FmaOp>(args&: p, args&: q, args&: c9);

  // Divide by (2a + 1) with another residual-correction step (e), yielding
  // r ~= (p(q) + 1) / (2a + 1).
  Value d = builder.create<math::FmaOp>(args&: pos2, args&: a, args&: one);
  r = builder.create<arith::DivFOp>(args&: one, args&: d);
  q = builder.create<math::FmaOp>(args&: p, args&: r, args&: r);
  Value negfa = builder.create<arith::NegFOp>(args&: a);
  Value fmaqah = builder.create<math::FmaOp>(args&: q, args&: negfa, args&: onehalf);
  Value psubq = builder.create<arith::SubFOp>(args&: p, args&: q);
  e = builder.create<math::FmaOp>(args&: fmaqah, args&: pos2, args&: psubq);
  r = builder.create<math::FmaOp>(args&: e, args&: r, args&: q);

  // Scale by exp(-a^2); t captures the rounding error of squaring a so it
  // can be folded into the product.
  Value s = builder.create<arith::MulFOp>(args&: a, args&: a);
  e = builder.create<math::ExpOp>(args: builder.create<arith::NegFOp>(args&: s));

  t = builder.create<math::FmaOp>(args: builder.create<arith::NegFOp>(args&: a), args&: a, args&: s);
  r = builder.create<math::FmaOp>(
      args&: r, args&: e,
      args: builder.create<arith::MulFOp>(args: builder.create<arith::MulFOp>(args&: r, args&: e), args&: t));

  // !(a < inf) catches a == inf and NaN; x + x then yields inf -> erfc = ...
  // well-defined inf handling and propagates NaN payloads.
  Value isNotLessThanInf = builder.create<arith::XOrIOp>(
      args: builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OLT, args&: a, args&: posInf),
      args&: trueValue);
  r = builder.create<arith::SelectOp>(args&: isNotLessThanInf,
                                      args: builder.create<arith::AddFOp>(args&: x, args&: x), args&: r);
  // erfc underflows to 0 beyond the clamp threshold.
  Value isGreaterThanClamp =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGT, args&: a, args&: clampVal);
  r = builder.create<arith::SelectOp>(args&: isGreaterThanClamp, args&: zero, args&: r);

  // Reflection for negative inputs: erfc(-x) = 2 - erfc(x).
  Value isNegative =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OLT, args&: x, args&: zero);
  r = builder.create<arith::SelectOp>(
      args&: isNegative, args: builder.create<arith::SubFOp>(args&: pos2, args&: r), args&: r);

  rewriter.replaceOp(op, newValues: r);
  return success();
}
1221//----------------------------------------------------------------------------//
1222// Exp approximation.
1223//----------------------------------------------------------------------------//
1224
1225namespace {
1226
1227Value clampWithNormals(ImplicitLocOpBuilder &builder,
1228 const std::optional<VectorShape> shape, Value value,
1229 float lowerBound, float upperBound) {
1230 assert(!std::isnan(lowerBound));
1231 assert(!std::isnan(upperBound));
1232
1233 auto bcast = [&](Value value) -> Value {
1234 return broadcast(builder, value, shape);
1235 };
1236
1237 auto selectCmp = [&builder](auto pred, Value value, Value bound) {
1238 return builder.create<arith::SelectOp>(
1239 builder.create<arith::CmpFOp>(pred, value, bound), value, bound);
1240 };
1241
1242 // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs.
1243 // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with
1244 // arith::{Max,Min}FOp.
1245 value = selectCmp(arith::CmpFPredicate::UGE, value,
1246 bcast(f32Cst(builder, value: lowerBound)));
1247 value = selectCmp(arith::CmpFPredicate::ULE, value,
1248 bcast(f32Cst(builder, value: upperBound)));
1249 return value;
1250}
1251
// Expands math.exp into a Cephes-style polynomial approximation; see the
// out-of-line matchAndRewrite definition for the full derivation.
struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpOp op,
                                PatternRewriter &rewriter) const final;
};
1259
// Expands math.exp as e^x = e^a * 2^n with n = round(x / log 2), using a
// degree-5 polynomial for e^a; flushes denormal results to zero.
LogicalResult
ExpApproximation::matchAndRewrite(math::ExpOp op,
                                  PatternRewriter &rewriter) const {
  auto shape = vectorShape(type: op.getOperand().getType());
  auto elementTy = getElementTypeOrSelf(type: op.getType());
  if (!elementTy.isF32())
    return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

  // Short-hands for the arith/math ops used below.
  auto add = [&](Value a, Value b) -> Value {
    return builder.create<arith::AddFOp>(args&: a, args&: b);
  };
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(args&: a); };
  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(args&: a, args&: b, args&: c);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(args&: a, args&: b);
  };

  // Polynomial approximation from Cephes.
  //
  // To compute e^x, we re-express it as
  //
  //   e^x = e^(a + b)
  //       = e^(a + n log(2))
  //       = e^a * 2^n.
  //
  // We choose n = round(x / log(2)), restricting the value of `a` to
  // (-log(2)/2, log(2)/2).  We then use a polynomial to compute e^a. The
  // relative error between our approximation and the true value of e^a is less
  // than 2^-22.5 for all values of `a` within this range.

  // Restrict input to a small range, including some values that evaluate to
  // +/- inf.  Note that for our lower bound, we choose log(2^-126) instead of
  // log(F32_EPSILON). We do so because this routine always flushes denormal
  // floating points to 0. Therefore, we only need to worry about exponentiating
  // up to the smallest representable non-denormal floating point, which is
  // 2^-126.

  // Constants.
  Value cstHalf = bcast(f32Cst(builder, value: 0.5f));
  Value cstOne = bcast(f32Cst(builder, value: 1.0f));

  // 1/log(2)
  Value cstLog2ef = bcast(f32Cst(builder, value: 1.44269504088896341f));

  // C1 and C2 together form a double-word split of -log(2), used below to
  // compute x - n*log(2) accurately; P0..P5 are the polynomial coefficients
  // for e^a.
  Value cstExpC1 = bcast(f32Cst(builder, value: -0.693359375f));
  Value cstExpC2 = bcast(f32Cst(builder, value: 2.12194440e-4f));
  Value cstExpP0 = bcast(f32Cst(builder, value: 1.9875691500E-4f));
  Value cstExpP1 = bcast(f32Cst(builder, value: 1.3981999507E-3f));
  Value cstExpP2 = bcast(f32Cst(builder, value: 8.3334519073E-3f));
  Value cstExpP3 = bcast(f32Cst(builder, value: 4.1665795894E-2f));
  Value cstExpP4 = bcast(f32Cst(builder, value: 1.6666665459E-1f));
  Value cstExpP5 = bcast(f32Cst(builder, value: 5.0000001201E-1f));

  // Our computations below aren't particularly sensitive to the exact choices
  // here, so we choose values a bit larger/smaller than
  //
  //   log(F32_MAX) = 88.723...
  //   log(2^-126) = -87.337...
  Value x = op.getOperand();
  x = clampWithNormals(builder, shape, value: x, lowerBound: -87.8f, upperBound: 88.8f);
  // n = round(x / log(2)), computed as floor(x / log(2) + 1/2).
  Value n = floor(fmla(x, cstLog2ef, cstHalf));

  // When we eventually do the multiplication in e^a * 2^n, we need to handle
  // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
  // (so e^a * 2^n != inf).  There's a similar problem for n < -126, the
  // smallest fp32 exponent.
  //
  // A straightforward solution would be to detect n out of range and split it
  // up, doing
  //
  //   e^a * 2^n = e^a * 2^(n1 + n2)
  //             = (2^n1 * e^a) * 2^n2.
  //
  // But it turns out this approach is quite slow, probably because it
  // manipulates subnormal values.
  //
  // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
  // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow
  // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though
  // this value of `a` is outside our previously specified range, e^a will still
  // only have a relative error of approximately 2^-16 at worse. In practice
  // this seems to work well enough; it passes our exhaustive tests, breaking
  // only one result, and by one ulp (we return exp(88.7228394) = max-float but
  // we should return inf).
  //
  // In the case where n' = -127, the original input value of x is so small that
  // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
  // normal floating point, and since we flush denormals, we simply return 0. We
  // do this in a branchless way by observing that our code for constructing 2^n
  // produces 0 if n = -127.
  //
  // The proof that n' = -127 implies e^x < 2^-126 is as follows:
  //
  //    n' = -127 implies n <= -127
  //              implies round(x / log(2)) <= -127
  //              implies x/log(2) < -126.5
  //              implies x < -126.5 * log(2)
  //              implies e^x < e^(-126.5 * log(2))
  //              implies e^x < 2^-126.5 < 2^-126
  //
  //    This proves that n' = -127 implies e^x < 2^-126.
  n = clampWithNormals(builder, shape, value: n, lowerBound: -127.0f, upperBound: 127.0f);

  // Computes x = x - n' * log(2), the value for `a`
  x = fmla(cstExpC1, n, x);
  x = fmla(cstExpC2, n, x);

  // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
  Value z = fmla(x, cstExpP0, cstExpP1);
  z = fmla(z, x, cstExpP2);
  z = fmla(z, x, cstExpP3);
  z = fmla(z, x, cstExpP4);
  z = fmla(z, x, cstExpP5);
  z = fmla(z, mul(x, x), x);
  z = add(cstOne, z);

  // Convert n' to an i32.  This is safe because we clamped it above.
  auto i32Vec = broadcast(type: builder.getI32Type(), shape);
  Value nI32 = builder.create<arith::FPToSIOp>(args&: i32Vec, args&: n);

  // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
  Value pow2 = exp2I32(builder, arg: nI32);

  // Return z * 2^n' if -126 <= n' <= 127 and 0 if n = -127.
  Value ret = mul(z, pow2);

  rewriter.replaceOp(op, newValues: ret);
  return mlir::success();
}
1396
1397} // namespace
1398
1399//----------------------------------------------------------------------------//
1400// ExpM1 approximation.
1401//----------------------------------------------------------------------------//
1402
1403namespace {
1404
1405struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
1406public:
1407 using OpRewritePattern::OpRewritePattern;
1408
1409 LogicalResult matchAndRewrite(math::ExpM1Op op,
1410 PatternRewriter &rewriter) const final;
1411};
1412} // namespace
1413
LogicalResult
ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
                                    PatternRewriter &rewriter) const {
  // Only f32 elements (scalar or vector) are supported.
  if (!getElementTypeOrSelf(val: op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported operand type");

  std::optional<VectorShape> shape = vectorShape(value: op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Splat scalar constants to the operand's vector shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // expm1(x) = exp(x) - 1 = u - 1.
  // We have to handle it carefully when x is near 0, i.e. u ~= 1,
  // and when the input is ~= -inf, i.e. u - 1 ~= -1.
  Value cstOne = bcast(f32Cst(builder, value: 1.0f));
  Value cstNegOne = bcast(f32Cst(builder, value: -1.0f));
  Value x = op.getOperand();
  Value u = builder.create<math::ExpOp>(args&: x);
  // UEQ (unordered or equal) is also true when u is NaN, so NaN inputs take
  // the "return x" select arm below and the NaN propagates.
  Value uEqOneOrNaN =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::UEQ, args&: u, args&: cstOne);
  Value uMinusOne = builder.create<arith::SubFOp>(args&: u, args&: cstOne);
  // u - 1 == -1 means exp(x) underflowed to 0 (x ~= -inf): answer is -1.
  Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
      args: arith::CmpFPredicate::OEQ, args&: uMinusOne, args&: cstNegOne);
  // logU = log(u) ~= x
  Value logU = builder.create<math::LogOp>(args&: u);

  // Detect exp(x) = +inf; written this way to avoid having to form +inf.
  Value isInf =
      builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: logU, args&: u);

  // (u - 1) * (x / ~x): scale u - 1 by x / log(u) to compensate the rounding
  // error committed when computing u = exp(x).
  Value expm1 = builder.create<arith::MulFOp>(
      args&: uMinusOne, args: builder.create<arith::DivFOp>(args&: x, args&: logU));
  // For exp(x) = +inf the correction above is invalid; return u (= +inf).
  expm1 = builder.create<arith::SelectOp>(args&: isInf, args&: u, args&: expm1);
  // Final selection: x where u ~= 1 (or NaN), -1 where u - 1 == -1,
  // otherwise the corrected product.
  Value approximation = builder.create<arith::SelectOp>(
      args&: uEqOneOrNaN, args&: x,
      args: builder.create<arith::SelectOp>(args&: uMinusOneEqNegOne, args&: cstNegOne, args&: expm1));
  rewriter.replaceOp(op, newValues: approximation);
  return success();
}
1456
1457//----------------------------------------------------------------------------//
1458// Sin and Cos approximation.
1459//----------------------------------------------------------------------------//
1460
1461namespace {
1462
1463template <bool isSine, typename OpTy>
1464struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
1465public:
1466 using OpRewritePattern<OpTy>::OpRewritePattern;
1467
1468 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
1469};
1470} // namespace
1471
// 2/pi and pi/2 as high-precision long-double literals; they are truncated
// to f32 where materialized as constants in the sin/cos range reduction.
#define TWO_OVER_PI                                                            \
  0.6366197723675813430755350534900574481378385829618257949906693762L
#define PI_OVER_2                                                              \
  1.5707963267948966192313216916397514420985846996875529104874722961L
1476
// Approximates sin(x) or cos(x) by finding the best approximation polynomial in
// the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the
// reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y).
template <bool isSine, typename OpTy>
LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
    OpTy op, PatternRewriter &rewriter) const {
  static_assert(
      llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
      "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");

  // Only f32 elements (scalar or vector) are supported.
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  std::optional<VectorShape> shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Helpers that splat scalars to the operand's shape and build the arith /
  // math ops used below.
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(args&: a, args&: b);
  };
  auto sub = [&](Value a, Value b) -> Value {
    return builder.create<arith::SubFOp>(args&: a, args&: b);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(args&: a); };

  auto i32Vec = broadcast(type: builder.getI32Type(), shape);
  // NOTE(review): "Singed" is a typo for "Signed"; kept as-is in this
  // byte-preserving pass.
  auto fPToSingedInteger = [&](Value a) -> Value {
    return builder.create<arith::FPToSIOp>(args&: i32Vec, args&: a);
  };

  // a & 3, i.e. the quadrant index modulo 4.
  auto modulo4 = [&](Value a) -> Value {
    return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, value: 3)));
  };

  auto isEqualTo = [&](Value a, Value b) -> Value {
    return builder.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: a, args&: b);
  };

  auto isGreaterThan = [&](Value a, Value b) -> Value {
    return builder.create<arith::CmpIOp>(args: arith::CmpIPredicate::sgt, args&: a, args&: b);
  };

  auto select = [&](Value cond, Value t, Value f) -> Value {
    return builder.create<arith::SelectOp>(args&: cond, args&: t, args&: f);
  };

  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(args&: a, args&: b, args&: c);
  };

  auto bitwiseOr = [&](Value a, Value b) {
    return builder.create<arith::OrIOp>(args&: a, args&: b);
  };

  Value twoOverPi = bcast(f32Cst(builder, value: (float)TWO_OVER_PI));
  Value piOverTwo = bcast(f32Cst(builder, value: (float)PI_OVER_2));

  Value x = op.getOperand();

  // Range reduction: k = floor(x * 2/pi) counts quarter-periods, ...
  Value k = floor(mul(x, twoOverPi));

  // ... and y = x - k * pi/2 is the residual argument in [0, pi/2).
  Value y = sub(x, mul(k, piOverTwo));

  Value cstOne = bcast(f32Cst(builder, value: 1.0));
  Value cstNegativeOne = bcast(f32Cst(builder, value: -1.0));

  // Sine polynomial coefficients (powers of y^2; the result is scaled by y).
  Value cstSC2 = bcast(f32Cst(builder, value: -0.16666667163372039794921875f));
  Value cstSC4 = bcast(f32Cst(builder, value: 8.333347737789154052734375e-3f));
  Value cstSC6 = bcast(f32Cst(builder, value: -1.9842604524455964565277099609375e-4f));
  Value cstSC8 =
      bcast(f32Cst(builder, value: 2.760012648650445044040679931640625e-6f));
  Value cstSC10 =
      bcast(f32Cst(builder, value: -2.50293279435709337121807038784027099609375e-8f));

  // Cosine polynomial coefficients (powers of y^2; the result is scaled by 1).
  Value cstCC2 = bcast(f32Cst(builder, value: -0.5f));
  Value cstCC4 = bcast(f32Cst(builder, value: 4.166664183139801025390625e-2f));
  Value cstCC6 = bcast(f32Cst(builder, value: -1.388833043165504932403564453125e-3f));
  Value cstCC8 = bcast(f32Cst(builder, value: 2.47562347794882953166961669921875e-5f));
  Value cstCC10 =
      bcast(f32Cst(builder, value: -2.59630184018533327616751194000244140625e-7f));

  Value kMod4 = modulo4(fPToSingedInteger(k));

  // Quadrant predicates: kR<i> holds where k mod 4 == i.
  Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 0)));
  Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 1)));
  Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 2)));
  Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 3)));

  // sin uses the cosine polynomial in quadrants 1 and 3; cos uses it in
  // quadrants 0 and 2.
  Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
  // The result is negated in quadrants 2 and 3 for sin, and 1 and 2 for cos.
  Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, value: 1)))
                               : bitwiseOr(kR1, kR2);

  Value y2 = mul(y, y);

  // Pick the per-lane polynomial and evaluate it in y^2 via Horner's scheme;
  // `base` is the leading factor (y for sine, 1 for cosine).
  Value base = select(sinuseCos, cstOne, y);
  Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
  Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
  Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
  Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
  Value cstC10 = select(sinuseCos, cstCC10, cstSC10);

  Value v1 = fmla(y2, cstC10, cstC8);
  Value v2 = fmla(y2, v1, cstC6);
  Value v3 = fmla(y2, v2, cstC4);
  Value v4 = fmla(y2, v3, cstC2);
  Value v5 = fmla(y2, v4, cstOne);
  Value v6 = mul(base, v5);

  // Apply the quadrant sign.
  Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);

  rewriter.replaceOp(op, approximation);

  return success();
}
1593
1594//----------------------------------------------------------------------------//
1595// Cbrt approximation.
1596//----------------------------------------------------------------------------//
1597
1598namespace {
// Approximates math.cbrt via an integer bit-manipulation initial guess
// refined by two Newton-Raphson steps; see matchAndRewrite below.
struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::CbrtOp op,
                                PatternRewriter &rewriter) const final;
};
1605} // namespace
1606
// Estimation of cube-root using an algorithm defined in
// Hacker's Delight 2nd Edition.
LogicalResult
CbrtApproximation::matchAndRewrite(math::CbrtOp op,
                                   PatternRewriter &rewriter) const {
  auto operand = op.getOperand();
  // Only f32 elements (scalar or vector) are supported.
  if (!getElementTypeOrSelf(val: operand).isF32())
    return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported operand type");

  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
  std::optional<VectorShape> shape = vectorShape(value: operand);

  // Same-width integer type used for the bitcast trick below.
  Type floatTy = getElementTypeOrSelf(type: operand.getType());
  Type intTy = b.getIntegerType(width: floatTy.getIntOrFloatBitWidth());

  // Convert to vector types if necessary.
  floatTy = broadcast(type: floatTy, shape);
  intTy = broadcast(type: intTy, shape);

  // Builds a constant splatted to the operand's shape.
  auto bconst = [&](TypedAttr attr) -> Value {
    Value value = b.create<arith::ConstantOp>(args&: attr);
    return broadcast(builder&: b, value, shape);
  };

  // Declare the initial values:
  Value intTwo = bconst(b.getI32IntegerAttr(value: 2));
  Value intFour = bconst(b.getI32IntegerAttr(value: 4));
  Value intEight = bconst(b.getI32IntegerAttr(value: 8));
  Value intMagic = bconst(b.getI32IntegerAttr(value: 0x2a5137a0));
  Value fpThird = bconst(b.getF32FloatAttr(value: 0.33333333f));
  Value fpTwo = bconst(b.getF32FloatAttr(value: 2.0f));
  Value fpZero = bconst(b.getF32FloatAttr(value: 0.0f));

  // The estimate works on |x|; the original sign is restored at the end.
  // Compute an approximation of one third:
  // union {int ix; float x;};
  // x = x0;
  // ix = ix/4 + ix/16;
  Value absValue = b.create<math::AbsFOp>(args&: operand);
  Value intValue = b.create<arith::BitcastOp>(args&: intTy, args&: absValue);
  Value divideBy4 = b.create<arith::ShRSIOp>(args&: intValue, args&: intTwo);
  Value divideBy16 = b.create<arith::ShRSIOp>(args&: intValue, args&: intFour);
  intValue = b.create<arith::AddIOp>(args&: divideBy4, args&: divideBy16);

  // ix = ix + ix/16;
  divideBy16 = b.create<arith::ShRSIOp>(args&: intValue, args&: intFour);
  intValue = b.create<arith::AddIOp>(args&: intValue, args&: divideBy16);

  // ix = ix + ix/256;
  Value divideBy256 = b.create<arith::ShRSIOp>(args&: intValue, args&: intEight);
  intValue = b.create<arith::AddIOp>(args&: intValue, args&: divideBy256);

  // ix = 0x2a5137a0 + ix;
  intValue = b.create<arith::AddIOp>(args&: intValue, args&: intMagic);

  // Perform one newtons step:
  // x = 0.33333333f*(2.0f*x + x0/(x*x));
  Value floatValue = b.create<arith::BitcastOp>(args&: floatTy, args&: intValue);
  Value squared = b.create<arith::MulFOp>(args&: floatValue, args&: floatValue);
  Value mulTwo = b.create<arith::MulFOp>(args&: floatValue, args&: fpTwo);
  Value divSquared = b.create<arith::DivFOp>(args&: absValue, args&: squared);
  floatValue = b.create<arith::AddFOp>(args&: mulTwo, args&: divSquared);
  floatValue = b.create<arith::MulFOp>(args&: floatValue, args&: fpThird);

  // A second Newton step for extra accuracy:
  // x = 0.33333333f*(2.0f*x + x0/(x*x));
  squared = b.create<arith::MulFOp>(args&: floatValue, args&: floatValue);
  mulTwo = b.create<arith::MulFOp>(args&: floatValue, args&: fpTwo);
  divSquared = b.create<arith::DivFOp>(args&: absValue, args&: squared);
  floatValue = b.create<arith::AddFOp>(args&: mulTwo, args&: divSquared);
  floatValue = b.create<arith::MulFOp>(args&: floatValue, args&: fpThird);

  // Check for zero and restore sign: for x == 0 the Newton steps compute
  // 0/0 = NaN, so select 0 explicitly, then copy the input's sign back
  // (cbrt is an odd function, so cbrt(-x) = -cbrt(x)).
  Value isZero =
      b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: absValue, args&: fpZero);
  floatValue = b.create<arith::SelectOp>(args&: isZero, args&: fpZero, args&: floatValue);
  floatValue = b.create<math::CopySignOp>(args&: floatValue, args&: operand);

  rewriter.replaceOp(op, newValues: floatValue);
  return success();
}
1686
1687//----------------------------------------------------------------------------//
1688// Rsqrt approximation.
1689//----------------------------------------------------------------------------//
1690
1691namespace {
// Approximates math.rsqrt using the x86 AVX2 rsqrt intrinsic refined by one
// Newton-Raphson step. Only applies to f32 vectors whose innermost dimension
// is a multiple of 8; see matchAndRewrite below.
struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::RsqrtOp op,
                                PatternRewriter &rewriter) const final;
};
1698} // namespace
1699
1700LogicalResult
1701RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1702 PatternRewriter &rewriter) const {
1703 if (!getElementTypeOrSelf(val: op.getOperand()).isF32())
1704 return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported operand type");
1705
1706 std::optional<VectorShape> shape = vectorShape(value: op.getOperand());
1707
1708 // Only support already-vectorized rsqrt's.
1709 if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0)
1710 return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported operand type");
1711
1712 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1713 auto bcast = [&](Value value) -> Value {
1714 return broadcast(builder, value, shape);
1715 };
1716
1717 Value cstPosInf = bcast(f32FromBits(builder, bits: 0x7f800000u));
1718 Value cstOnePointFive = bcast(f32Cst(builder, value: 1.5f));
1719 Value cstNegHalf = bcast(f32Cst(builder, value: -0.5f));
1720 Value cstMinNormPos = bcast(f32FromBits(builder, bits: 0x00800000u));
1721
1722 Value negHalf = builder.create<arith::MulFOp>(args: op.getOperand(), args&: cstNegHalf);
1723
1724 // Select only the inverse sqrt of positive normals (denormals are
1725 // flushed to zero).
1726 Value ltMinMask = builder.create<arith::CmpFOp>(
1727 args: arith::CmpFPredicate::OLT, args: op.getOperand(), args&: cstMinNormPos);
1728 Value infMask = builder.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ,
1729 args: op.getOperand(), args&: cstPosInf);
1730 Value notNormalFiniteMask = builder.create<arith::OrIOp>(args&: ltMinMask, args&: infMask);
1731
1732 // Compute an approximate result.
1733 Value yApprox = handleMultidimensionalVectors(
1734 builder, operands: op->getOperands(), vectorWidth: 8, compute: [&builder](ValueRange operands) -> Value {
1735 return builder.create<x86vector::RsqrtOp>(args&: operands);
1736 });
1737
1738 // Do a single step of Newton-Raphson iteration to improve the approximation.
1739 // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
1740 // It is essential to evaluate the inner term like this because forming
1741 // y_n^2 may over- or underflow.
1742 Value inner = builder.create<arith::MulFOp>(args&: negHalf, args&: yApprox);
1743 Value fma = builder.create<math::FmaOp>(args&: yApprox, args&: inner, args&: cstOnePointFive);
1744 Value yNewton = builder.create<arith::MulFOp>(args&: yApprox, args&: fma);
1745
1746 // Select the result of the Newton-Raphson step for positive normal arguments.
1747 // For other arguments, choose the output of the intrinsic. This will
1748 // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
1749 // x is zero or a positive denormalized float (equivalent to flushing positive
1750 // denormalized inputs to zero).
1751 Value res =
1752 builder.create<arith::SelectOp>(args&: notNormalFiniteMask, args&: yApprox, args&: yNewton);
1753 rewriter.replaceOp(op, newValues: res);
1754
1755 return success();
1756}
1757
1758//----------------------------------------------------------------------------//
1759
1760void mlir::populatePolynomialApproximateTanhPattern(
1761 RewritePatternSet &patterns) {
1762 patterns.add<TanhApproximation>(arg: patterns.getContext());
1763}
1764
1765void mlir::populatePolynomialApproximateErfPattern(
1766 RewritePatternSet &patterns) {
1767 patterns.add<ErfPolynomialApproximation>(arg: patterns.getContext());
1768}
1769
1770void mlir::populatePolynomialApproximateErfcPattern(
1771 RewritePatternSet &patterns) {
1772 patterns.add<ErfcPolynomialApproximation>(arg: patterns.getContext());
1773}
1774
1775template <typename OpType>
1776static void
1777populateMathF32ExpansionPattern(RewritePatternSet &patterns,
1778 llvm::function_ref<bool(StringRef)> predicate,
1779 PatternBenefit benefit) {
1780 if (predicate(OpType::getOperationName())) {
1781 patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext(), benefit);
1782 }
1783}
1784
// Registers a ReuseF32Expansion pattern for each math op listed below whose
// operation name is accepted by `predicate`, with the given pattern benefit.
void mlir::populateMathF32ExpansionPatterns(
    RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
    PatternBenefit benefit) {
  populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate, benefit);
  populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate, benefit);
}
1815
1816template <typename OpType, typename PatternType>
1817static void populateMathPolynomialApproximationPattern(
1818 RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
1819 PatternBenefit benefit) {
1820 if (predicate(OpType::getOperationName())) {
1821 patterns.add<PatternType>(patterns.getContext(), benefit);
1822 }
1823}
1824
// Registers the polynomial-approximation pattern for each math op listed
// below whose operation name is accepted by `predicate`, with the given
// pattern benefit.
void mlir::populateMathPolynomialApproximationPatterns(
    RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
    PatternBenefit benefit) {
  populateMathPolynomialApproximationPattern<AcosOp,
                                             AcosPolynomialApproximation>(
      patterns, predicate, benefit);
  populateMathPolynomialApproximationPattern<AsinOp,
                                             AsinPolynomialApproximation>(
      patterns, predicate, benefit);
  populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
      patterns, predicate, benefit);
  populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
      patterns, predicate, benefit);
  populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
      patterns, predicate, benefit);
  populateMathPolynomialApproximationPattern<
      CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate,
                                                         benefit);
  populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
      patterns, predicate, benefit);
  populateMathPolynomialApproximationPattern<ErfcOp,
                                             ErfcPolynomialApproximation>(
      patterns, predicate, benefit);
  populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
      patterns, predicate, benefit);
  populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
      patterns, predicate, benefit);
  populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
      patterns, predicate, benefit);
  populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
      patterns, predicate, benefit);
  populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
      patterns, predicate, benefit);
  populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
      patterns, predicate, benefit);
  populateMathPolynomialApproximationPattern<
      SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate,
                                                        benefit);
  populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
      patterns, predicate, benefit);
}
1866
// Options-driven overload: registers the default set of f32-expansion and
// polynomial-approximation patterns, plus the AVX2 rsqrt patterns when
// `options.enableAvx2` is set.
void mlir::populateMathPolynomialApproximationPatterns(
    RewritePatternSet &patterns,
    const MathPolynomialApproximationOptions &options) {
  // Default set of ops handled by the f32 expansion patterns.
  mlir::populateMathF32ExpansionPatterns(patterns, predicate: [](StringRef name) -> bool {
    return llvm::is_contained(
        Set: {math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
         math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
         math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
         math::ErfOp::getOperationName(), math::ErfcOp::getOperationName(),
         math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
         math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
         math::CosOp::getOperationName()},
        Element: name);
  });

  // Default set of ops handled by the polynomial approximations. This list
  // additionally includes asin/acos, which have no f32 expansion above.
  populateMathPolynomialApproximationPatterns(
      patterns, predicate: [](StringRef name) -> bool {
        return llvm::is_contained(
            Set: {math::AtanOp::getOperationName(),
             math::Atan2Op::getOperationName(),
             math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
             math::Log2Op::getOperationName(),
             math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
             math::ErfcOp::getOperationName(), math::AsinOp::getOperationName(),
             math::AcosOp::getOperationName(), math::ExpOp::getOperationName(),
             math::ExpM1Op::getOperationName(),
             math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
             math::CosOp::getOperationName()},
            Element: name);
      });

  // Rsqrt is only registered behind the AVX2 flag, since its approximation
  // emits x86vector intrinsics.
  if (options.enableAvx2) {
    auto predicateRsqrt = [](StringRef name) {
      return name == math::RsqrtOp::getOperationName();
    };
    mlir::populateMathF32ExpansionPatterns(patterns, predicate: predicateRsqrt);
    mlir::populateMathPolynomialApproximationPatterns(patterns, predicate: predicateRsqrt);
  }
}
1906

// source: mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp