1//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements expansion of math operations to fast approximations
10// that do not rely on any of the library functions.
11//
12//===----------------------------------------------------------------------===//
13
14#include <climits>
15#include <cmath>
16#include <cstddef>
17
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/Math/IR/Math.h"
20#include "mlir/Dialect/Math/Transforms/Approximation.h"
21#include "mlir/Dialect/Math/Transforms/Passes.h"
22#include "mlir/Dialect/Utils/IndexingUtils.h"
23#include "mlir/Dialect/Vector/IR/VectorOps.h"
24#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
25#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
26#include "mlir/IR/Builders.h"
27#include "mlir/IR/BuiltinTypes.h"
28#include "mlir/IR/ImplicitLocOpBuilder.h"
29#include "mlir/IR/OpDefinition.h"
30#include "mlir/IR/PatternMatch.h"
31#include "mlir/IR/TypeUtilities.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
34#include "llvm/ADT/ArrayRef.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/Support/MathExtras.h"
37
38using namespace mlir;
39using namespace mlir::math;
40using namespace mlir::vector;
41
42// Helper to encapsulate a vector's shape (including scalable dims).
43struct VectorShape {
44 ArrayRef<int64_t> sizes;
45 ArrayRef<bool> scalableFlags;
46};
47
48// Returns vector shape if the type is a vector, otherwise return nullopt.
49static std::optional<VectorShape> vectorShape(Type type) {
50 if (auto vectorType = dyn_cast<VectorType>(type)) {
51 return VectorShape{vectorType.getShape(), vectorType.getScalableDims()};
52 }
53 return std::nullopt;
54}
55
56static std::optional<VectorShape> vectorShape(Value value) {
57 return vectorShape(type: value.getType());
58}
59
60//----------------------------------------------------------------------------//
61// Broadcast scalar types and values into vector types and values.
62//----------------------------------------------------------------------------//
63
64// Broadcasts scalar type into vector type (iff shape is non-scalar).
65static Type broadcast(Type type, std::optional<VectorShape> shape) {
66 assert(!isa<VectorType>(type) && "must be scalar type");
67 return shape ? VectorType::get(shape->sizes, type, shape->scalableFlags)
68 : type;
69}
70
71// Broadcasts scalar value into vector (iff shape is non-scalar).
72static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
73 std::optional<VectorShape> shape) {
74 assert(!isa<VectorType>(value.getType()) && "must be scalar value");
75 auto type = broadcast(type: value.getType(), shape);
76 return shape ? builder.create<BroadcastOp>(type, value) : value;
77}
78
79//----------------------------------------------------------------------------//
80// Helper function to handle n-D vectors with 1-D operations.
81//----------------------------------------------------------------------------//
82
83// Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
84// and calls the compute function with 1-D vector operands. Stitches back all
85// results into the original n-D vector result.
86//
87// Examples: vectorWidth = 8
88// - vector<4x8xf32> unrolled 4 times
89// - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
90// - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
91//
92// Some math approximations rely on ISA-specific operations that only accept
93// fixed size 1-D vectors (e.g. AVX expects vectors of width 8).
94//
95// It is the caller's responsibility to verify that the inner dimension is
96// divisible by the vectorWidth, and that all operands have the same vector
97// shape.
98static Value
99handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
100 ValueRange operands, int64_t vectorWidth,
101 llvm::function_ref<Value(ValueRange)> compute) {
102 assert(!operands.empty() && "operands must be not empty");
103 assert(vectorWidth > 0 && "vector width must be larger than 0");
104
105 VectorType inputType = cast<VectorType>(operands[0].getType());
106 ArrayRef<int64_t> inputShape = inputType.getShape();
107
108 // If input shape matches target vector width, we can just call the
109 // user-provided compute function with the operands.
110 if (inputShape == llvm::ArrayRef(vectorWidth))
111 return compute(operands);
112
113 // Check if the inner dimension has to be expanded, or we can directly iterate
114 // over the outer dimensions of the vector.
115 int64_t innerDim = inputShape.back();
116 int64_t expansionDim = innerDim / vectorWidth;
117 assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");
118
119 // Maybe expand operands to the higher rank vector shape that we'll use to
120 // iterate over and extract one dimensional vectors.
121 SmallVector<int64_t> expandedShape(inputShape);
122 SmallVector<Value> expandedOperands(operands);
123
124 if (expansionDim > 1) {
125 // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
126 expandedShape.insert(I: expandedShape.end() - 1, Elt: expansionDim);
127 expandedShape.back() = vectorWidth;
128
129 for (unsigned i = 0; i < operands.size(); ++i) {
130 auto operand = operands[i];
131 auto eltType = cast<VectorType>(operand.getType()).getElementType();
132 auto expandedType = VectorType::get(expandedShape, eltType);
133 expandedOperands[i] =
134 builder.create<vector::ShapeCastOp>(expandedType, operand);
135 }
136 }
137
138 // Iterate over all outer dimensions of the compute shape vector type.
139 auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
140 int64_t maxIndex = computeMaxLinearIndex(iterationDims);
141 auto strides = computeStrides(iterationDims);
142
143 // Compute results for each one dimensional vector.
144 SmallVector<Value> results(maxIndex);
145
146 for (int64_t i = 0; i < maxIndex; ++i) {
147 auto offsets = delinearize(i, strides);
148
149 SmallVector<Value> extracted(expandedOperands.size());
150 for (const auto &tuple : llvm::enumerate(expandedOperands))
151 extracted[tuple.index()] =
152 builder.create<vector::ExtractOp>(tuple.value(), offsets);
153
154 results[i] = compute(extracted);
155 }
156
157 // Stitch results together into one large vector.
158 Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
159 Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
160 Value result = builder.create<arith::ConstantOp>(
161 resultExpandedType, builder.getZeroAttr(resultExpandedType));
162
163 for (int64_t i = 0; i < maxIndex; ++i)
164 result = builder.create<vector::InsertOp>(results[i], result,
165 delinearize(i, strides));
166
167 // Reshape back to the original vector shape.
168 return builder.create<vector::ShapeCastOp>(
169 VectorType::get(inputShape, resultEltType), result);
170}
171
172//----------------------------------------------------------------------------//
173// Helper functions to create constants.
174//----------------------------------------------------------------------------//
175
176static Value boolCst(ImplicitLocOpBuilder &builder, bool value) {
177 return builder.create<arith::ConstantOp>(builder.getBoolAttr(value));
178}
179
180static Value floatCst(ImplicitLocOpBuilder &builder, float value,
181 Type elementType) {
182 assert((elementType.isF16() || elementType.isF32()) &&
183 "x must be f16 or f32 type.");
184 return builder.create<arith::ConstantOp>(
185 builder.getFloatAttr(elementType, value));
186}
187
188static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
189 return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
190}
191
192static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
193 return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
194}
195
196static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
197 Value i32Value = i32Cst(builder, value: static_cast<int32_t>(bits));
198 return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value);
199}
200
201//----------------------------------------------------------------------------//
202// Helper functions to build math functions approximations.
203//----------------------------------------------------------------------------//
204
205// Return the minimum of the two values or NaN if value is NaN
206static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
207 return builder.create<arith::SelectOp>(
208 builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
209 value, bound);
210}
211
212// Return the maximum of the two values or NaN if value is NaN
213static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
214 return builder.create<arith::SelectOp>(
215 builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
216 value, bound);
217}
218
219// Return the clamped value or NaN if value is NaN
220static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
221 Value upperBound) {
222 return max(builder, value: min(builder, value, bound: upperBound), bound: lowerBound);
223}
224
225// Decomposes given floating point value `arg` into a normalized fraction and
226// an integral power of two (see std::frexp). Returned values have float type.
227static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
228 bool isPositive = false) {
229 assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
230 std::optional<VectorShape> shape = vectorShape(value: arg);
231
232 auto bcast = [&](Value value) -> Value {
233 return broadcast(builder, value, shape);
234 };
235
236 auto i32 = builder.getIntegerType(32);
237 auto i32Vec = broadcast(i32, shape);
238 auto f32Vec = broadcast(builder.getF32Type(), shape);
239
240 Value cst126f = f32Cst(builder, value: 126.0f);
241 Value cstHalf = f32Cst(builder, value: 0.5f);
242 Value cstInvMantMask = f32FromBits(builder, bits: ~0x7f800000u);
243
244 // Bitcast to i32 for bitwise operations.
245 Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf);
246 Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask);
247 Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg);
248
249 // Compute normalized fraction.
250 Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
251 Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half));
252 Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);
253
254 // Compute exponent.
255 Value arg0 = isPositive ? arg : builder.create<math::AbsFOp>(arg);
256 Value biasedExponentBits = builder.create<arith::ShRUIOp>(
257 builder.create<arith::BitcastOp>(i32Vec, arg0),
258 bcast(i32Cst(builder, 23)));
259 Value biasedExponent =
260 builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
261 Value exponent =
262 builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f));
263
264 return {normalizedFraction, exponent};
265}
266
267// Computes exp2 for an i32 argument.
268static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
269 assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
270 std::optional<VectorShape> shape = vectorShape(value: arg);
271
272 auto bcast = [&](Value value) -> Value {
273 return broadcast(builder, value, shape);
274 };
275
276 auto f32Vec = broadcast(builder.getF32Type(), shape);
277 // The exponent of f32 located at 23-bit.
278 auto exponetBitLocation = bcast(i32Cst(builder, value: 23));
279 // Set the exponent bias to zero.
280 auto bias = bcast(i32Cst(builder, value: 127));
281
282 Value biasedArg = builder.create<arith::AddIOp>(arg, bias);
283 Value exp2ValueInt =
284 builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation);
285 Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt);
286
287 return exp2ValueF32;
288}
289
290namespace {
291Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
292 llvm::ArrayRef<Value> coeffs, Value x) {
293 Type elementType = getElementTypeOrSelf(val: x);
294 assert((elementType.isF32() || elementType.isF16()) &&
295 "x must be f32 or f16 type");
296 std::optional<VectorShape> shape = vectorShape(value: x);
297
298 if (coeffs.empty())
299 return broadcast(builder, value: floatCst(builder, value: 0.0f, elementType), shape);
300
301 if (coeffs.size() == 1)
302 return coeffs[0];
303
304 Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
305 coeffs[coeffs.size() - 2]);
306 for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
307 res = builder.create<math::FmaOp>(x, res, coeffs[i]);
308 }
309 return res;
310}
311} // namespace
312
313//----------------------------------------------------------------------------//
314// Helper function/pattern to insert casts for reusing F32 bit expansion.
315//----------------------------------------------------------------------------//
316
317template <typename T>
318LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
319 // Conservatively only allow where the operand and result types are exactly 1.
320 Type origType = op->getResultTypes().front();
321 for (Type t : llvm::drop_begin(RangeOrContainer: op->getResultTypes()))
322 if (origType != t)
323 return rewriter.notifyMatchFailure(arg&: op, msg: "required all types to match");
324 for (Type t : op->getOperandTypes())
325 if (origType != t)
326 return rewriter.notifyMatchFailure(arg&: op, msg: "required all types to match");
327
328 // Skip if already F32 or larger than 32 bits.
329 if (getElementTypeOrSelf(type: origType).isF32() ||
330 getElementTypeOrSelf(type: origType).getIntOrFloatBitWidth() > 32)
331 return failure();
332
333 // Create F32 equivalent type.
334 Type newType;
335 if (auto shaped = dyn_cast<ShapedType>(origType)) {
336 newType = shaped.clone(rewriter.getF32Type());
337 } else if (isa<FloatType>(Val: origType)) {
338 newType = rewriter.getF32Type();
339 } else {
340 return rewriter.notifyMatchFailure(arg&: op,
341 msg: "unable to find F32 equivalent type");
342 }
343
344 Location loc = op->getLoc();
345 SmallVector<Value> operands;
346 for (auto operand : op->getOperands())
347 operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
348 auto result =
349 rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs());
350 rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
351 return success();
352}
353
354namespace {
355// Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
356// op.
357// TODO: Consider revising to avoid adding multiple casts for a subgraph that is
358// all in lower precision. Currently this is only fallback support and performs
359// simplistic casting.
360template <typename T>
361struct ReuseF32Expansion : public OpRewritePattern<T> {
362public:
363 using OpRewritePattern<T>::OpRewritePattern;
364 LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
365 static_assert(
366 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
367 "requires same operands and result types");
368 return insertCasts<T>(op, rewriter);
369 }
370};
371} // namespace
372
373//----------------------------------------------------------------------------//
374// AtanOp approximation.
375//----------------------------------------------------------------------------//
376
377namespace {
378struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
379public:
380 using OpRewritePattern::OpRewritePattern;
381
382 LogicalResult matchAndRewrite(math::AtanOp op,
383 PatternRewriter &rewriter) const final;
384};
385} // namespace
386
387LogicalResult
388AtanApproximation::matchAndRewrite(math::AtanOp op,
389 PatternRewriter &rewriter) const {
390 auto operand = op.getOperand();
391 if (!getElementTypeOrSelf(operand).isF32())
392 return rewriter.notifyMatchFailure(op, "unsupported operand type");
393
394 std::optional<VectorShape> shape = vectorShape(op.getOperand());
395
396 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
397 Value abs = builder.create<math::AbsFOp>(operand);
398
399 auto one = broadcast(builder, value: f32Cst(builder, value: 1.0), shape);
400
401 // When 0.66 < x <= 2.41 we do (x-1) / (x+1):
402 auto twoThirds = broadcast(builder, value: f32Cst(builder, value: 0.66), shape);
403 Value cmp2 =
404 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, twoThirds);
405 Value addone = builder.create<arith::AddFOp>(abs, one);
406 Value subone = builder.create<arith::SubFOp>(abs, one);
407 Value xnum = builder.create<arith::SelectOp>(cmp2, subone, abs);
408 Value xden = builder.create<arith::SelectOp>(cmp2, addone, one);
409
410 auto bcast = [&](Value value) -> Value {
411 return broadcast(builder, value, shape);
412 };
413
414 // Break into the <= 0.66 or > 2.41 we do x or 1/x:
415 auto tan3pio8 = bcast(f32Cst(builder, value: 2.41421356237309504880));
416 Value cmp1 =
417 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, tan3pio8);
418 xnum = builder.create<arith::SelectOp>(cmp1, one, xnum);
419 xden = builder.create<arith::SelectOp>(cmp1, abs, xden);
420
421 Value x = builder.create<arith::DivFOp>(xnum, xden);
422 Value xx = builder.create<arith::MulFOp>(x, x);
423
424 // Perform the Taylor series approximation for atan over the range
425 // [0.0, 0.66].
426 auto p0 = bcast(f32Cst(builder, value: -8.750608600031904122785e-01));
427 auto p1 = bcast(f32Cst(builder, value: -1.615753718733365076637e+01));
428 auto p2 = bcast(f32Cst(builder, value: -7.500855792314704667340e+01));
429 auto p3 = bcast(f32Cst(builder, value: -1.228866684490136173410e+02));
430 auto p4 = bcast(f32Cst(builder, value: -6.485021904942025371773e+01));
431 auto q0 = bcast(f32Cst(builder, value: +2.485846490142306297962e+01));
432 auto q1 = bcast(f32Cst(builder, value: +1.650270098316988542046e+02));
433 auto q2 = bcast(f32Cst(builder, value: +4.328810604912902668951e+02));
434 auto q3 = bcast(f32Cst(builder, value: +4.853903996359136964868e+02));
435 auto q4 = bcast(f32Cst(builder, value: +1.945506571482613964425e+02));
436
437 // Apply the polynomial approximation for the numerator:
438 Value n = p0;
439 n = builder.create<math::FmaOp>(xx, n, p1);
440 n = builder.create<math::FmaOp>(xx, n, p2);
441 n = builder.create<math::FmaOp>(xx, n, p3);
442 n = builder.create<math::FmaOp>(xx, n, p4);
443 n = builder.create<arith::MulFOp>(n, xx);
444
445 // Apply the polynomial approximation for the denominator:
446 Value d = q0;
447 d = builder.create<math::FmaOp>(xx, d, q1);
448 d = builder.create<math::FmaOp>(xx, d, q2);
449 d = builder.create<math::FmaOp>(xx, d, q3);
450 d = builder.create<math::FmaOp>(xx, d, q4);
451
452 // Compute approximation of theta:
453 Value ans0 = builder.create<arith::DivFOp>(n, d);
454 ans0 = builder.create<math::FmaOp>(ans0, x, x);
455
456 // Correct for the input mapping's angles:
457 Value mpi4 = bcast(f32Cst(builder, value: llvm::numbers::pi / 4));
458 Value ans2 = builder.create<arith::AddFOp>(mpi4, ans0);
459 Value ans = builder.create<arith::SelectOp>(cmp2, ans2, ans0);
460
461 Value mpi2 = bcast(f32Cst(builder, value: llvm::numbers::pi / 2));
462 Value ans1 = builder.create<arith::SubFOp>(mpi2, ans0);
463 ans = builder.create<arith::SelectOp>(cmp1, ans1, ans);
464
465 // Correct for signing of the input.
466 rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand);
467 return success();
468}
469
470//----------------------------------------------------------------------------//
// Atan2Op approximation.
472//----------------------------------------------------------------------------//
473
474namespace {
475struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
476public:
477 using OpRewritePattern::OpRewritePattern;
478
479 LogicalResult matchAndRewrite(math::Atan2Op op,
480 PatternRewriter &rewriter) const final;
481};
482} // namespace
483
484LogicalResult
485Atan2Approximation::matchAndRewrite(math::Atan2Op op,
486 PatternRewriter &rewriter) const {
487 auto y = op.getOperand(0);
488 auto x = op.getOperand(1);
489 if (!getElementTypeOrSelf(x).isF32())
490 return rewriter.notifyMatchFailure(op, "unsupported operand type");
491
492 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
493 std::optional<VectorShape> shape = vectorShape(op.getResult());
494
495 // Compute atan in the valid range.
496 auto div = builder.create<arith::DivFOp>(y, x);
497 auto atan = builder.create<math::AtanOp>(div);
498
499 // Determine what the atan would be for a 180 degree rotation.
500 auto zero = broadcast(builder, value: f32Cst(builder, value: 0.0f), shape);
501 auto pi = broadcast(builder, value: f32Cst(builder, value: 3.14159265359f), shape);
502 auto addPi = builder.create<arith::AddFOp>(atan, pi);
503 auto subPi = builder.create<arith::SubFOp>(atan, pi);
504 auto atanGt =
505 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
506 auto flippedAtan = builder.create<arith::SelectOp>(atanGt, subPi, addPi);
507
508 // Determine whether to directly use atan or use the 180 degree flip
509 auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
510 Value result = builder.create<arith::SelectOp>(xGt, atan, flippedAtan);
511
512 // Handle x = 0, y > 0
513 Value xZero =
514 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
515 Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
516 Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt);
517 auto halfPi = broadcast(builder, value: f32Cst(builder, value: 1.57079632679f), shape);
518 result = builder.create<arith::SelectOp>(isHalfPi, halfPi, result);
519
520 // Handle x = 0, y < 0
521 Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
522 Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt);
523 auto negativeHalfPiPi =
524 broadcast(builder, value: f32Cst(builder, value: -1.57079632679f), shape);
525 result = builder.create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
526 result);
527
528 // Handle x = 0, y = 0;
529 Value yZero =
530 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
531 Value isNan = builder.create<arith::AndIOp>(xZero, yZero);
532 Value cstNan = broadcast(builder, value: f32FromBits(builder, bits: 0x7fc00000), shape);
533 result = builder.create<arith::SelectOp>(isNan, cstNan, result);
534
535 rewriter.replaceOp(op, result);
536 return success();
537}
538
539//----------------------------------------------------------------------------//
540// TanhOp approximation.
541//----------------------------------------------------------------------------//
542
543namespace {
544struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
545public:
546 using OpRewritePattern::OpRewritePattern;
547
548 LogicalResult matchAndRewrite(math::TanhOp op,
549 PatternRewriter &rewriter) const final;
550};
551} // namespace
552
553LogicalResult
554TanhApproximation::matchAndRewrite(math::TanhOp op,
555 PatternRewriter &rewriter) const {
556 if (!getElementTypeOrSelf(op.getOperand()).isF32())
557 return rewriter.notifyMatchFailure(op, "unsupported operand type");
558
559 std::optional<VectorShape> shape = vectorShape(op.getOperand());
560
561 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
562 auto bcast = [&](Value value) -> Value {
563 return broadcast(builder, value, shape);
564 };
565
566 // Clamp operand into [plusClamp, minusClamp] range.
567 Value minusClamp = bcast(f32Cst(builder, value: -7.99881172180175781f));
568 Value plusClamp = bcast(f32Cst(builder, value: 7.99881172180175781f));
569 Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);
570
571 // Mask for tiny values that are approximated with `operand`.
572 Value tiny = bcast(f32Cst(builder, value: 0.0004f));
573 Value tinyMask = builder.create<arith::CmpFOp>(
574 arith::CmpFPredicate::OLT, builder.create<math::AbsFOp>(op.getOperand()),
575 tiny);
576
577 // The monomial coefficients of the numerator polynomial (odd).
578 Value alpha1 = bcast(f32Cst(builder, value: 4.89352455891786e-03f));
579 Value alpha3 = bcast(f32Cst(builder, value: 6.37261928875436e-04f));
580 Value alpha5 = bcast(f32Cst(builder, value: 1.48572235717979e-05f));
581 Value alpha7 = bcast(f32Cst(builder, value: 5.12229709037114e-08f));
582 Value alpha9 = bcast(f32Cst(builder, value: -8.60467152213735e-11f));
583 Value alpha11 = bcast(f32Cst(builder, value: 2.00018790482477e-13f));
584 Value alpha13 = bcast(f32Cst(builder, value: -2.76076847742355e-16f));
585
586 // The monomial coefficients of the denominator polynomial (even).
587 Value beta0 = bcast(f32Cst(builder, value: 4.89352518554385e-03f));
588 Value beta2 = bcast(f32Cst(builder, value: 2.26843463243900e-03f));
589 Value beta4 = bcast(f32Cst(builder, value: 1.18534705686654e-04f));
590 Value beta6 = bcast(f32Cst(builder, value: 1.19825839466702e-06f));
591
592 // Since the polynomials are odd/even, we need x^2.
593 Value x2 = builder.create<arith::MulFOp>(x, x);
594
595 // Evaluate the numerator polynomial p.
596 Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
597 p = builder.create<math::FmaOp>(x2, p, alpha9);
598 p = builder.create<math::FmaOp>(x2, p, alpha7);
599 p = builder.create<math::FmaOp>(x2, p, alpha5);
600 p = builder.create<math::FmaOp>(x2, p, alpha3);
601 p = builder.create<math::FmaOp>(x2, p, alpha1);
602 p = builder.create<arith::MulFOp>(x, p);
603
604 // Evaluate the denominator polynomial q.
605 Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
606 q = builder.create<math::FmaOp>(x2, q, beta2);
607 q = builder.create<math::FmaOp>(x2, q, beta0);
608
609 // Divide the numerator by the denominator.
610 Value res = builder.create<arith::SelectOp>(
611 tinyMask, x, builder.create<arith::DivFOp>(p, q));
612
613 rewriter.replaceOp(op, res);
614
615 return success();
616}
617
// Natural logarithm of 2 as a long double literal.
#define LN2_VALUE 0.693147180559945309417232121458176568075500134360255254120680009493393621L
// Base-2 logarithm of e as a long double literal.
#define LOG2E_VALUE 1.442695040888963407359924681001892137426645954152985934135449406931109219L
622
623//----------------------------------------------------------------------------//
624// LogOp and Log2Op approximation.
625//----------------------------------------------------------------------------//
626
627namespace {
628template <typename Op>
629struct LogApproximationBase : public OpRewritePattern<Op> {
630 using OpRewritePattern<Op>::OpRewritePattern;
631
632 /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
633 LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
634 bool base2) const;
635};
636} // namespace
637
638// This approximation comes from Julien Pommier's SSE math library.
639// Link: http://gruntthepeon.free.fr/ssemath
640template <typename Op>
641LogicalResult
642LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
643 bool base2) const {
644 if (!getElementTypeOrSelf(op.getOperand()).isF32())
645 return rewriter.notifyMatchFailure(op, "unsupported operand type");
646
647 std::optional<VectorShape> shape = vectorShape(op.getOperand());
648
649 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
650 auto bcast = [&](Value value) -> Value {
651 return broadcast(builder, value, shape);
652 };
653
654 Value cstZero = bcast(f32Cst(builder, value: 0.0f));
655 Value cstOne = bcast(f32Cst(builder, value: 1.0f));
656 Value cstNegHalf = bcast(f32Cst(builder, value: -0.5f));
657
658 // The smallest non denormalized float number.
659 Value cstMinNormPos = bcast(f32FromBits(builder, bits: 0x00800000u));
660 Value cstMinusInf = bcast(f32FromBits(builder, bits: 0xff800000u));
661 Value cstPosInf = bcast(f32FromBits(builder, bits: 0x7f800000u));
662 Value cstNan = bcast(f32FromBits(builder, bits: 0x7fc00000));
663
664 // Polynomial coefficients.
665 Value cstCephesSQRTHF = bcast(f32Cst(builder, value: 0.707106781186547524f));
666 Value cstCephesLogP0 = bcast(f32Cst(builder, value: 7.0376836292E-2f));
667 Value cstCephesLogP1 = bcast(f32Cst(builder, value: -1.1514610310E-1f));
668 Value cstCephesLogP2 = bcast(f32Cst(builder, value: 1.1676998740E-1f));
669 Value cstCephesLogP3 = bcast(f32Cst(builder, value: -1.2420140846E-1f));
670 Value cstCephesLogP4 = bcast(f32Cst(builder, value: +1.4249322787E-1f));
671 Value cstCephesLogP5 = bcast(f32Cst(builder, value: -1.6668057665E-1f));
672 Value cstCephesLogP6 = bcast(f32Cst(builder, value: +2.0000714765E-1f));
673 Value cstCephesLogP7 = bcast(f32Cst(builder, value: -2.4999993993E-1f));
674 Value cstCephesLogP8 = bcast(f32Cst(builder, value: +3.3333331174E-1f));
675
676 Value x = op.getOperand();
677
678 // Truncate input values to the minimum positive normal.
679 x = max(builder, value: x, bound: cstMinNormPos);
680
681 // Extract significant in the range [0.5,1) and exponent.
682 std::pair<Value, Value> pair = frexp(builder, arg: x, /*isPositive=*/true);
683 x = pair.first;
684 Value e = pair.second;
685
686 // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
687 // by -1.0. The values are then centered around 0, which improves the
688 // stability of the polynomial evaluation:
689 //
690 // if( x < SQRTHF ) {
691 // e -= 1;
692 // x = x + x - 1.0;
693 // } else { x = x - 1.0; }
694 Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
695 cstCephesSQRTHF);
696 Value tmp = builder.create<arith::SelectOp>(mask, x, cstZero);
697
698 x = builder.create<arith::SubFOp>(x, cstOne);
699 e = builder.create<arith::SubFOp>(
700 e, builder.create<arith::SelectOp>(mask, cstOne, cstZero));
701 x = builder.create<arith::AddFOp>(x, tmp);
702
703 Value x2 = builder.create<arith::MulFOp>(x, x);
704 Value x3 = builder.create<arith::MulFOp>(x2, x);
705
706 // Evaluate the polynomial approximant of degree 8 in three parts.
707 Value y0, y1, y2;
708 y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
709 y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
710 y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
711 y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
712 y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
713 y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
714 y0 = builder.create<math::FmaOp>(y0, x3, y1);
715 y0 = builder.create<math::FmaOp>(y0, x3, y2);
716 y0 = builder.create<arith::MulFOp>(y0, x3);
717
718 y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
719 x = builder.create<arith::AddFOp>(x, y0);
720
721 if (base2) {
722 Value cstLog2e = bcast(f32Cst(builder, value: static_cast<float>(LOG2E_VALUE)));
723 x = builder.create<math::FmaOp>(x, cstLog2e, e);
724 } else {
725 Value cstLn2 = bcast(f32Cst(builder, value: static_cast<float>(LN2_VALUE)));
726 x = builder.create<math::FmaOp>(e, cstLn2, x);
727 }
728
729 Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
730 op.getOperand(), cstZero);
731 Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
732 op.getOperand(), cstZero);
733 Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
734 op.getOperand(), cstPosInf);
735
736 // Filter out invalid values:
737 // • x == 0 -> -INF
738 // • x < 0 -> NAN
739 // • x == +INF -> +INF
740 Value aproximation = builder.create<arith::SelectOp>(
741 zeroMask, cstMinusInf,
742 builder.create<arith::SelectOp>(
743 invalidMask, cstNan,
744 builder.create<arith::SelectOp>(posInfMask, cstPosInf, x)));
745
746 rewriter.replaceOp(op, aproximation);
747
748 return success();
749}
750
751namespace {
752struct LogApproximation : public LogApproximationBase<math::LogOp> {
753 using LogApproximationBase::LogApproximationBase;
754
755 LogicalResult matchAndRewrite(math::LogOp op,
756 PatternRewriter &rewriter) const final {
757 return logMatchAndRewrite(op, rewriter, /*base2=*/false);
758 }
759};
760} // namespace
761
762namespace {
763struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
764 using LogApproximationBase::LogApproximationBase;
765
766 LogicalResult matchAndRewrite(math::Log2Op op,
767 PatternRewriter &rewriter) const final {
768 return logMatchAndRewrite(op, rewriter, /*base2=*/true);
769 }
770};
771} // namespace
772
773//----------------------------------------------------------------------------//
774// Log1p approximation.
775//----------------------------------------------------------------------------//
776
777namespace {
778struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
779public:
780 using OpRewritePattern::OpRewritePattern;
781
782 LogicalResult matchAndRewrite(math::Log1pOp op,
783 PatternRewriter &rewriter) const final;
784};
785} // namespace
786
787// Approximate log(1+x).
788LogicalResult
789Log1pApproximation::matchAndRewrite(math::Log1pOp op,
790 PatternRewriter &rewriter) const {
791 if (!getElementTypeOrSelf(op.getOperand()).isF32())
792 return rewriter.notifyMatchFailure(op, "unsupported operand type");
793
794 std::optional<VectorShape> shape = vectorShape(op.getOperand());
795
796 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
797 auto bcast = [&](Value value) -> Value {
798 return broadcast(builder, value, shape);
799 };
800
801 // Approximate log(1+x) using the following, due to W. Kahan:
802 // u = x + 1.0;
803 // if (u == 1.0 || u == inf) return x;
804 // return x * log(u) / (u - 1.0);
805 // ^^^^^^^^^^^^^^^^^^^^^^
806 // "logLarge" below.
807 Value cstOne = bcast(f32Cst(builder, value: 1.0f));
808 Value x = op.getOperand();
809 Value u = builder.create<arith::AddFOp>(x, cstOne);
810 Value uSmall =
811 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
812 Value logU = builder.create<math::LogOp>(u);
813 Value uInf =
814 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
815 Value logLarge = builder.create<arith::MulFOp>(
816 x, builder.create<arith::DivFOp>(
817 logU, builder.create<arith::SubFOp>(u, cstOne)));
818 Value approximation = builder.create<arith::SelectOp>(
819 builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
820 rewriter.replaceOp(op, approximation);
821 return success();
822}
823
824//----------------------------------------------------------------------------//
825// Asin approximation.
826//----------------------------------------------------------------------------//
827
828// Approximates asin(x).
829// This approximation is based on the following stackoverflow post:
830// https://stackoverflow.com/a/42683455
831namespace {
832struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
833public:
834 using OpRewritePattern::OpRewritePattern;
835
836 LogicalResult matchAndRewrite(math::AsinOp op,
837 PatternRewriter &rewriter) const final;
838};
839} // namespace
840LogicalResult
841AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
842 PatternRewriter &rewriter) const {
843 Value operand = op.getOperand();
844 Type elementType = getElementTypeOrSelf(val: operand);
845
846 if (!(elementType.isF32() || elementType.isF16()))
847 return rewriter.notifyMatchFailure(op,
848 "only f32 and f16 type is supported.");
849 std::optional<VectorShape> shape = vectorShape(value: operand);
850
851 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
852 auto bcast = [&](Value value) -> Value {
853 return broadcast(builder, value, shape);
854 };
855
856 auto fma = [&](Value a, Value b, Value c) -> Value {
857 return builder.create<math::FmaOp>(a, b, c);
858 };
859
860 auto mul = [&](Value a, Value b) -> Value {
861 return builder.create<arith::MulFOp>(a, b);
862 };
863
864 auto sub = [&](Value a, Value b) -> Value {
865 return builder.create<arith::SubFOp>(a, b);
866 };
867
868 auto abs = [&](Value a) -> Value { return builder.create<math::AbsFOp>(a); };
869
870 auto sqrt = [&](Value a) -> Value { return builder.create<math::SqrtOp>(a); };
871
872 auto scopy = [&](Value a, Value b) -> Value {
873 return builder.create<math::CopySignOp>(a, b);
874 };
875
876 auto sel = [&](Value a, Value b, Value c) -> Value {
877 return builder.create<arith::SelectOp>(a, b, c);
878 };
879
880 Value abso = abs(operand);
881 Value aa = mul(operand, operand);
882 Value opp = sqrt(sub(bcast(floatCst(builder, value: 1.0, elementType)), aa));
883
884 Value gt =
885 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, aa,
886 bcast(floatCst(builder, 0.5, elementType)));
887
888 Value x = sel(gt, opp, abso);
889
890 // Asin(x) approximation for x = [-9/16, 9/16]:
891 Value s = mul(x, x);
892 Value q = mul(s, s);
893 Value r = bcast(floatCst(builder, value: 5.5579749017470502e-2, elementType));
894 Value t = bcast(floatCst(builder, value: -6.2027913464120114e-2, elementType));
895
896 r = fma(r, q, bcast(floatCst(builder, value: 5.4224464349245036e-2, elementType)));
897 t = fma(t, q, bcast(floatCst(builder, value: -1.1326992890324464e-2, elementType)));
898 r = fma(r, q, bcast(floatCst(builder, value: 1.5268872539397656e-2, elementType)));
899 t = fma(t, q, bcast(floatCst(builder, value: 1.0493798473372081e-2, elementType)));
900 r = fma(r, q, bcast(floatCst(builder, value: 1.4106045900607047e-2, elementType)));
901 t = fma(t, q, bcast(floatCst(builder, value: 1.7339776384962050e-2, elementType)));
902 r = fma(r, q, bcast(floatCst(builder, value: 2.2372961589651054e-2, elementType)));
903 t = fma(t, q, bcast(floatCst(builder, value: 3.0381912707941005e-2, elementType)));
904 r = fma(r, q, bcast(floatCst(builder, value: 4.4642857881094775e-2, elementType)));
905 t = fma(t, q, bcast(floatCst(builder, value: 7.4999999991367292e-2, elementType)));
906 r = fma(r, s, t);
907 r = fma(r, s, bcast(floatCst(builder, value: 1.6666666666670193e-1, elementType)));
908 t = mul(x, s);
909 r = fma(r, t, x);
910
911 Value rsub = sub(bcast(floatCst(builder, value: 1.57079632679, elementType)), r);
912 r = sel(gt, rsub, r);
913 r = scopy(r, operand);
914
915 rewriter.replaceOp(op, r);
916 return success();
917}
918
919//----------------------------------------------------------------------------//
920// Acos approximation.
921//----------------------------------------------------------------------------//
922
923// Approximates acos(x).
924// This approximation is based on the following stackoverflow post:
925// https://stackoverflow.com/a/42683455
926namespace {
927struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> {
928public:
929 using OpRewritePattern::OpRewritePattern;
930
931 LogicalResult matchAndRewrite(math::AcosOp op,
932 PatternRewriter &rewriter) const final;
933};
934} // namespace
935LogicalResult
936AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
937 PatternRewriter &rewriter) const {
938 Value operand = op.getOperand();
939 Type elementType = getElementTypeOrSelf(val: operand);
940
941 if (!(elementType.isF32() || elementType.isF16()))
942 return rewriter.notifyMatchFailure(op,
943 "only f32 and f16 type is supported.");
944 std::optional<VectorShape> shape = vectorShape(value: operand);
945
946 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
947 auto bcast = [&](Value value) -> Value {
948 return broadcast(builder, value, shape);
949 };
950
951 auto fma = [&](Value a, Value b, Value c) -> Value {
952 return builder.create<math::FmaOp>(a, b, c);
953 };
954
955 auto mul = [&](Value a, Value b) -> Value {
956 return builder.create<arith::MulFOp>(a, b);
957 };
958
959 Value negOperand = builder.create<arith::NegFOp>(operand);
960 Value zero = bcast(floatCst(builder, value: 0.0, elementType));
961 Value half = bcast(floatCst(builder, value: 0.5, elementType));
962 Value negOne = bcast(floatCst(builder, value: -1.0, elementType));
963 Value selR =
964 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
965 Value r = builder.create<arith::SelectOp>(selR, negOperand, operand);
966 Value chkConst = bcast(floatCst(builder, value: -0.5625, elementType));
967 Value firstPred =
968 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
969
970 Value trueVal =
971 fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)),
972 bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
973 builder.create<math::AsinOp>(r));
974
975 Value falseVal = builder.create<math::SqrtOp>(fma(half, r, half));
976 falseVal = builder.create<math::AsinOp>(falseVal);
977 falseVal = mul(bcast(floatCst(builder, value: 2.0, elementType)), falseVal);
978
979 r = builder.create<arith::SelectOp>(firstPred, trueVal, falseVal);
980
981 // Check whether the operand lies in between [-1.0, 0.0).
982 Value greaterThanNegOne =
983 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
984
985 Value lessThanZero =
986 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
987
988 Value betweenNegOneZero =
989 builder.create<arith::AndIOp>(greaterThanNegOne, lessThanZero);
990
991 trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)),
992 bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
993 builder.create<arith::NegFOp>(r));
994
995 Value finalVal =
996 builder.create<arith::SelectOp>(betweenNegOneZero, trueVal, r);
997
998 rewriter.replaceOp(op, finalVal);
999 return success();
1000}
1001
1002//----------------------------------------------------------------------------//
1003// Erf approximation.
1004//----------------------------------------------------------------------------//
1005
1006// Approximates erf(x) with
1007// a - P(x)/Q(x)
1008// where P and Q are polynomials of degree 4.
1009// Different coefficients are chosen based on the value of x.
1010// The approximation error is ~2.5e-07.
1011// Boost's minimax tool that utilizes the Remez method was used to find the
1012// coefficients.
1013LogicalResult
1014ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
1015 PatternRewriter &rewriter) const {
1016 Value operand = op.getOperand();
1017 Type elementType = getElementTypeOrSelf(val: operand);
1018
1019 if (!(elementType.isF32() || elementType.isF16()))
1020 return rewriter.notifyMatchFailure(op,
1021 "only f32 and f16 type is supported.");
1022 std::optional<VectorShape> shape = vectorShape(value: operand);
1023
1024 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1025 auto bcast = [&](Value value) -> Value {
1026 return broadcast(builder, value, shape);
1027 };
1028
1029 const int intervalsCount = 3;
1030 const int polyDegree = 4;
1031
1032 Value zero = bcast(floatCst(builder, value: 0, elementType));
1033 Value one = bcast(floatCst(builder, value: 1, elementType));
1034 Value pp[intervalsCount][polyDegree + 1];
1035 pp[0][0] = bcast(floatCst(builder, value: +0.00000000000000000e+00f, elementType));
1036 pp[0][1] = bcast(floatCst(builder, value: +1.12837916222975858e+00f, elementType));
1037 pp[0][2] = bcast(floatCst(builder, value: -5.23018562988006470e-01f, elementType));
1038 pp[0][3] = bcast(floatCst(builder, value: +2.09741709609267072e-01f, elementType));
1039 pp[0][4] = bcast(floatCst(builder, value: +2.58146801602987875e-02f, elementType));
1040 pp[1][0] = bcast(floatCst(builder, value: +0.00000000000000000e+00f, elementType));
1041 pp[1][1] = bcast(floatCst(builder, value: +1.12750687816789140e+00f, elementType));
1042 pp[1][2] = bcast(floatCst(builder, value: -3.64721408487825775e-01f, elementType));
1043 pp[1][3] = bcast(floatCst(builder, value: +1.18407396425136952e-01f, elementType));
1044 pp[1][4] = bcast(floatCst(builder, value: +3.70645533056476558e-02f, elementType));
1045 pp[2][0] = bcast(floatCst(builder, value: -3.30093071049483172e-03f, elementType));
1046 pp[2][1] = bcast(floatCst(builder, value: +3.51961938357697011e-03f, elementType));
1047 pp[2][2] = bcast(floatCst(builder, value: -1.41373622814988039e-03f, elementType));
1048 pp[2][3] = bcast(floatCst(builder, value: +2.53447094961941348e-04f, elementType));
1049 pp[2][4] = bcast(floatCst(builder, value: -1.71048029455037401e-05f, elementType));
1050
1051 Value qq[intervalsCount][polyDegree + 1];
1052 qq[0][0] = bcast(floatCst(builder, value: +1.000000000000000000e+00f, elementType));
1053 qq[0][1] = bcast(floatCst(builder, value: -4.635138185962547255e-01f, elementType));
1054 qq[0][2] = bcast(floatCst(builder, value: +5.192301327279782447e-01f, elementType));
1055 qq[0][3] = bcast(floatCst(builder, value: -1.318089722204810087e-01f, elementType));
1056 qq[0][4] = bcast(floatCst(builder, value: +7.397964654672315005e-02f, elementType));
1057 qq[1][0] = bcast(floatCst(builder, value: +1.00000000000000000e+00f, elementType));
1058 qq[1][1] = bcast(floatCst(builder, value: -3.27607011824493086e-01f, elementType));
1059 qq[1][2] = bcast(floatCst(builder, value: +4.48369090658821977e-01f, elementType));
1060 qq[1][3] = bcast(floatCst(builder, value: -8.83462621207857930e-02f, elementType));
1061 qq[1][4] = bcast(floatCst(builder, value: +5.72442770283176093e-02f, elementType));
1062 qq[2][0] = bcast(floatCst(builder, value: +1.00000000000000000e+00f, elementType));
1063 qq[2][1] = bcast(floatCst(builder, value: -2.06069165953913769e+00f, elementType));
1064 qq[2][2] = bcast(floatCst(builder, value: +1.62705939945477759e+00f, elementType));
1065 qq[2][3] = bcast(floatCst(builder, value: -5.83389859211130017e-01f, elementType));
1066 qq[2][4] = bcast(floatCst(builder, value: +8.21908939856640930e-02f, elementType));
1067
1068 Value offsets[intervalsCount];
1069 offsets[0] = bcast(floatCst(builder, value: 0.0f, elementType));
1070 offsets[1] = bcast(floatCst(builder, value: 0.0f, elementType));
1071 offsets[2] = bcast(floatCst(builder, value: 1.0f, elementType));
1072
1073 Value bounds[intervalsCount];
1074 bounds[0] = bcast(floatCst(builder, value: 0.8f, elementType));
1075 bounds[1] = bcast(floatCst(builder, value: 2.0f, elementType));
1076 bounds[2] = bcast(floatCst(builder, value: 3.75f, elementType));
1077
1078 Value isNegativeArg =
1079 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
1080 Value negArg = builder.create<arith::NegFOp>(operand);
1081 Value x = builder.create<arith::SelectOp>(isNegativeArg, negArg, operand);
1082
1083 Value offset = offsets[0];
1084 Value p[polyDegree + 1];
1085 Value q[polyDegree + 1];
1086 for (int i = 0; i <= polyDegree; ++i) {
1087 p[i] = pp[0][i];
1088 q[i] = qq[0][i];
1089 }
1090
1091 // TODO: maybe use vector stacking to reduce the number of selects.
1092 Value isLessThanBound[intervalsCount];
1093 for (int j = 0; j < intervalsCount - 1; ++j) {
1094 isLessThanBound[j] =
1095 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
1096 for (int i = 0; i <= polyDegree; ++i) {
1097 p[i] = builder.create<arith::SelectOp>(isLessThanBound[j], p[i],
1098 pp[j + 1][i]);
1099 q[i] = builder.create<arith::SelectOp>(isLessThanBound[j], q[i],
1100 qq[j + 1][i]);
1101 }
1102 offset = builder.create<arith::SelectOp>(isLessThanBound[j], offset,
1103 offsets[j + 1]);
1104 }
1105 isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
1106 arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
1107
1108 Value pPoly = makePolynomialCalculation(builder, coeffs: p, x);
1109 Value qPoly = makePolynomialCalculation(builder, coeffs: q, x);
1110 Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
1111 Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
1112 formula = builder.create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
1113 formula, one);
1114
1115 // erf is odd function: erf(x) = -erf(-x).
1116 Value negFormula = builder.create<arith::NegFOp>(formula);
1117 Value res =
1118 builder.create<arith::SelectOp>(isNegativeArg, negFormula, formula);
1119
1120 rewriter.replaceOp(op, res);
1121
1122 return success();
1123}
1124
1125// Approximates erfc(x) with p((x - 2) / (x + 2)), where p is a 9 degree
1126// polynomial.This approximation is based on the following stackoverflow post:
1127// https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf
1128// The stackoverflow post is in turn based on:
1129// M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of
1130// (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36,
1131// No. 153, January 1981, pp. 249-253.
1132//
1133// Maximum error: 2.65 ulps
1134LogicalResult
1135ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op,
1136 PatternRewriter &rewriter) const {
1137 Value x = op.getOperand();
1138 Type et = getElementTypeOrSelf(val: x);
1139
1140 if (!et.isF32())
1141 return rewriter.notifyMatchFailure(op, "only f32 type is supported.");
1142 std::optional<VectorShape> shape = vectorShape(value: x);
1143
1144 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1145 auto bcast = [&](Value value) -> Value {
1146 return broadcast(builder, value, shape);
1147 };
1148
1149 Value trueValue = bcast(boolCst(builder, value: true));
1150 Value zero = bcast(floatCst(builder, value: 0.0f, elementType: et));
1151 Value one = bcast(floatCst(builder, value: 1.0f, elementType: et));
1152 Value onehalf = bcast(floatCst(builder, value: 0.5f, elementType: et));
1153 Value neg4 = bcast(floatCst(builder, value: -4.0f, elementType: et));
1154 Value neg2 = bcast(floatCst(builder, value: -2.0f, elementType: et));
1155 Value pos2 = bcast(floatCst(builder, value: 2.0f, elementType: et));
1156 Value posInf = bcast(floatCst(builder, INFINITY, elementType: et));
1157 Value clampVal = bcast(floatCst(builder, value: 10.0546875f, elementType: et));
1158
1159 Value a = builder.create<math::AbsFOp>(x);
1160 Value p = builder.create<arith::AddFOp>(a, pos2);
1161 Value r = builder.create<arith::DivFOp>(one, p);
1162 Value q = builder.create<math::FmaOp>(neg4, r, one);
1163 Value t = builder.create<math::FmaOp>(builder.create<arith::AddFOp>(q, one),
1164 neg2, a);
1165 Value e = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), q, t);
1166 q = builder.create<math::FmaOp>(r, e, q);
1167
1168 p = bcast(floatCst(builder, value: -0x1.a4a000p-12f, elementType: et)); // -4.01139259e-4
1169 Value c1 = bcast(floatCst(builder, value: -0x1.42a260p-10f, elementType: et)); // -1.23075210e-3
1170 p = builder.create<math::FmaOp>(p, q, c1);
1171 Value c2 = bcast(floatCst(builder, value: 0x1.585714p-10f, elementType: et)); // 1.31355342e-3
1172 p = builder.create<math::FmaOp>(p, q, c2);
1173 Value c3 = bcast(floatCst(builder, value: 0x1.1adcc4p-07f, elementType: et)); // 8.63227434e-3
1174 p = builder.create<math::FmaOp>(p, q, c3);
1175 Value c4 = bcast(floatCst(builder, value: -0x1.081b82p-07f, elementType: et)); // -8.05991981e-3
1176 p = builder.create<math::FmaOp>(p, q, c4);
1177 Value c5 = bcast(floatCst(builder, value: -0x1.bc0b6ap-05f, elementType: et)); // -5.42046614e-2
1178 p = builder.create<math::FmaOp>(p, q, c5);
1179 Value c6 = bcast(floatCst(builder, value: 0x1.4ffc46p-03f, elementType: et)); // 1.64055392e-1
1180 p = builder.create<math::FmaOp>(p, q, c6);
1181 Value c7 = bcast(floatCst(builder, value: -0x1.540840p-03f, elementType: et)); // -1.66031361e-1
1182 p = builder.create<math::FmaOp>(p, q, c7);
1183 Value c8 = bcast(floatCst(builder, value: -0x1.7bf616p-04f, elementType: et)); // -9.27639827e-2
1184 p = builder.create<math::FmaOp>(p, q, c8);
1185 Value c9 = bcast(floatCst(builder, value: 0x1.1ba03ap-02f, elementType: et)); // 2.76978403e-1
1186 p = builder.create<math::FmaOp>(p, q, c9);
1187
1188 Value d = builder.create<math::FmaOp>(pos2, a, one);
1189 r = builder.create<arith::DivFOp>(one, d);
1190 q = builder.create<math::FmaOp>(p, r, r);
1191 Value negfa = builder.create<arith::NegFOp>(a);
1192 Value fmaqah = builder.create<math::FmaOp>(q, negfa, onehalf);
1193 Value psubq = builder.create<arith::SubFOp>(p, q);
1194 e = builder.create<math::FmaOp>(fmaqah, pos2, psubq);
1195 r = builder.create<math::FmaOp>(e, r, q);
1196
1197 Value s = builder.create<arith::MulFOp>(a, a);
1198 e = builder.create<math::ExpOp>(builder.create<arith::NegFOp>(s));
1199
1200 t = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), a, s);
1201 r = builder.create<math::FmaOp>(
1202 r, e,
1203 builder.create<arith::MulFOp>(builder.create<arith::MulFOp>(r, e), t));
1204
1205 Value isNotLessThanInf = builder.create<arith::XOrIOp>(
1206 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, posInf),
1207 trueValue);
1208 r = builder.create<arith::SelectOp>(isNotLessThanInf,
1209 builder.create<arith::AddFOp>(x, x), r);
1210 Value isGreaterThanClamp =
1211 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, clampVal);
1212 r = builder.create<arith::SelectOp>(isGreaterThanClamp, zero, r);
1213
1214 Value isNegative =
1215 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
1216 r = builder.create<arith::SelectOp>(
1217 isNegative, builder.create<arith::SubFOp>(pos2, r), r);
1218
1219 rewriter.replaceOp(op, r);
1220 return success();
1221}
1222//----------------------------------------------------------------------------//
1223// Exp approximation.
1224//----------------------------------------------------------------------------//
1225
1226namespace {
1227
1228Value clampWithNormals(ImplicitLocOpBuilder &builder,
1229 const std::optional<VectorShape> shape, Value value,
1230 float lowerBound, float upperBound) {
1231 assert(!std::isnan(lowerBound));
1232 assert(!std::isnan(upperBound));
1233
1234 auto bcast = [&](Value value) -> Value {
1235 return broadcast(builder, value, shape);
1236 };
1237
1238 auto selectCmp = [&builder](auto pred, Value value, Value bound) {
1239 return builder.create<arith::SelectOp>(
1240 builder.create<arith::CmpFOp>(pred, value, bound), value, bound);
1241 };
1242
1243 // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs.
1244 // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with
1245 // arith::{Max,Min}FOp.
1246 value = selectCmp(arith::CmpFPredicate::UGE, value,
1247 bcast(f32Cst(builder, lowerBound)));
1248 value = selectCmp(arith::CmpFPredicate::ULE, value,
1249 bcast(f32Cst(builder, upperBound)));
1250 return value;
1251}
1252
1253struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
1254public:
1255 using OpRewritePattern::OpRewritePattern;
1256
1257 LogicalResult matchAndRewrite(math::ExpOp op,
1258 PatternRewriter &rewriter) const final;
1259};
1260
1261LogicalResult
1262ExpApproximation::matchAndRewrite(math::ExpOp op,
1263 PatternRewriter &rewriter) const {
1264 auto shape = vectorShape(op.getOperand().getType());
1265 auto elementTy = getElementTypeOrSelf(op.getType());
1266 if (!elementTy.isF32())
1267 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1268
1269 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1270
1271 auto add = [&](Value a, Value b) -> Value {
1272 return builder.create<arith::AddFOp>(a, b);
1273 };
1274 auto bcast = [&](Value value) -> Value {
1275 return broadcast(builder, value, shape);
1276 };
1277 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
1278 auto fmla = [&](Value a, Value b, Value c) {
1279 return builder.create<math::FmaOp>(a, b, c);
1280 };
1281 auto mul = [&](Value a, Value b) -> Value {
1282 return builder.create<arith::MulFOp>(a, b);
1283 };
1284
1285 // Polynomial approximation from Cephes.
1286 //
1287 // To compute e^x, we re-express it as
1288 //
1289 // e^x = e^(a + b)
1290 // = e^(a + n log(2))
1291 // = e^a * 2^n.
1292 //
1293 // We choose n = round(x / log(2)), restricting the value of `a` to
1294 // (-log(2)/2, log(2)/2). We then use a polynomial to compute e^a. The
1295 // relative error between our approximation and the true value of e^a is less
1296 // than 2^-22.5 for all values of `a` within this range.
1297
1298 // Restrict input to a small range, including some values that evaluate to
1299 // +/- inf. Note that for our lower bound, we choose log(2^-126) instead of
1300 // log(F32_EPSILON). We do so because this routine always flushes denormal
1301 // floating points to 0. Therefore, we only need to worry about exponentiating
1302 // up to the smallest representable non-denormal floating point, which is
1303 // 2^-126.
1304
1305 // Constants.
1306 Value cstHalf = bcast(f32Cst(builder, value: 0.5f));
1307 Value cstOne = bcast(f32Cst(builder, value: 1.0f));
1308
1309 // 1/log(2)
1310 Value cstLog2ef = bcast(f32Cst(builder, value: 1.44269504088896341f));
1311
1312 Value cstExpC1 = bcast(f32Cst(builder, value: -0.693359375f));
1313 Value cstExpC2 = bcast(f32Cst(builder, value: 2.12194440e-4f));
1314 Value cstExpP0 = bcast(f32Cst(builder, value: 1.9875691500E-4f));
1315 Value cstExpP1 = bcast(f32Cst(builder, value: 1.3981999507E-3f));
1316 Value cstExpP2 = bcast(f32Cst(builder, value: 8.3334519073E-3f));
1317 Value cstExpP3 = bcast(f32Cst(builder, value: 4.1665795894E-2f));
1318 Value cstExpP4 = bcast(f32Cst(builder, value: 1.6666665459E-1f));
1319 Value cstExpP5 = bcast(f32Cst(builder, value: 5.0000001201E-1f));
1320
1321 // Our computations below aren't particularly sensitive to the exact choices
1322 // here, so we choose values a bit larger/smaller than
1323 //
1324 // log(F32_MAX) = 88.723...
1325 // log(2^-126) = -87.337...
1326 Value x = op.getOperand();
1327 x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
1328 Value n = floor(fmla(x, cstLog2ef, cstHalf));
1329
1330 // When we eventually do the multiplication in e^a * 2^n, we need to handle
1331 // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
1332 // (so e^a * 2^n != inf). There's a similar problem for n < -126, the
1333 // smallest fp32 exponent.
1334 //
1335 // A straightforward solution would be to detect n out of range and split it
1336 // up, doing
1337 //
1338 // e^a * 2^n = e^a * 2^(n1 + n2)
1339 // = (2^n1 * e^a) * 2^n2.
1340 //
1341 // But it turns out this approach is quite slow, probably because it
1342 // manipulates subnormal values.
1343 //
1344 // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
1345 // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow
1346 // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though
1347 // this value of `a` is outside our previously specified range, e^a will still
1348 // only have a relative error of approximately 2^-16 at worse. In practice
1349 // this seems to work well enough; it passes our exhaustive tests, breaking
1350 // only one result, and by one ulp (we return exp(88.7228394) = max-float but
1351 // we should return inf).
1352 //
1353 // In the case where n' = -127, the original input value of x is so small that
1354 // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
1355 // normal floating point, and since we flush denormals, we simply return 0. We
1356 // do this in a branchless way by observing that our code for constructing 2^n
1357 // produces 0 if n = -127.
1358 //
1359 // The proof that n' = -127 implies e^x < 2^-126 is as follows:
1360 //
1361 // n' = -127 implies n <= -127
1362 // implies round(x / log(2)) <= -127
1363 // implies x/log(2) < -126.5
1364 // implies x < -126.5 * log(2)
1365 // implies e^x < e^(-126.5 * log(2))
1366 // implies e^x < 2^-126.5 < 2^-126
1367 //
1368 // This proves that n' = -127 implies e^x < 2^-126.
1369 n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1370
1371 // Computes x = x - n' * log(2), the value for `a`
1372 x = fmla(cstExpC1, n, x);
1373 x = fmla(cstExpC2, n, x);
1374
1375 // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
1376 Value z = fmla(x, cstExpP0, cstExpP1);
1377 z = fmla(z, x, cstExpP2);
1378 z = fmla(z, x, cstExpP3);
1379 z = fmla(z, x, cstExpP4);
1380 z = fmla(z, x, cstExpP5);
1381 z = fmla(z, mul(x, x), x);
1382 z = add(cstOne, z);
1383
1384 // Convert n' to an i32. This is safe because we clamped it above.
1385 auto i32Vec = broadcast(builder.getI32Type(), shape);
1386 Value nI32 = builder.create<arith::FPToSIOp>(i32Vec, n);
1387
1388 // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
1389 Value pow2 = exp2I32(builder, arg: nI32);
1390
1391 // Return z * 2^n' if -126 <= n' <= 127 and 0 if n = -127.
1392 Value ret = mul(z, pow2);
1393
1394 rewriter.replaceOp(op, ret);
1395 return mlir::success();
1396}
1397
1398} // namespace
1399
1400//----------------------------------------------------------------------------//
1401// ExpM1 approximation.
1402//----------------------------------------------------------------------------//
1403
1404namespace {
1405
1406struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
1407public:
1408 using OpRewritePattern::OpRewritePattern;
1409
1410 LogicalResult matchAndRewrite(math::ExpM1Op op,
1411 PatternRewriter &rewriter) const final;
1412};
1413} // namespace
1414
1415LogicalResult
1416ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1417 PatternRewriter &rewriter) const {
1418 if (!getElementTypeOrSelf(op.getOperand()).isF32())
1419 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1420
1421 std::optional<VectorShape> shape = vectorShape(op.getOperand());
1422
1423 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1424 auto bcast = [&](Value value) -> Value {
1425 return broadcast(builder, value, shape);
1426 };
1427
1428 // expm1(x) = exp(x) - 1 = u - 1.
1429 // We have to handle it carefully when x is near 0, i.e. u ~= 1,
1430 // and when the input is ~= -inf, i.e. u - 1 ~= -1.
1431 Value cstOne = bcast(f32Cst(builder, value: 1.0f));
1432 Value cstNegOne = bcast(f32Cst(builder, value: -1.0f));
1433 Value x = op.getOperand();
1434 Value u = builder.create<math::ExpOp>(x);
1435 Value uEqOneOrNaN =
1436 builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1437 Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
1438 Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
1439 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1440 // logU = log(u) ~= x
1441 Value logU = builder.create<math::LogOp>(u);
1442
1443 // Detect exp(x) = +inf; written this way to avoid having to form +inf.
1444 Value isInf =
1445 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
1446
1447 // (u - 1) * (x / ~x)
1448 Value expm1 = builder.create<arith::MulFOp>(
1449 uMinusOne, builder.create<arith::DivFOp>(x, logU));
1450 expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
1451 Value approximation = builder.create<arith::SelectOp>(
1452 uEqOneOrNaN, x,
1453 builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
1454 rewriter.replaceOp(op, approximation);
1455 return success();
1456}
1457
1458//----------------------------------------------------------------------------//
1459// Sin and Cos approximation.
1460//----------------------------------------------------------------------------//
1461
1462namespace {
1463
1464template <bool isSine, typename OpTy>
1465struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
1466public:
1467 using OpRewritePattern<OpTy>::OpRewritePattern;
1468
1469 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
1470};
1471} // namespace
1472
// 2/pi as a long double literal; cast to float where it is used below.
#define TWO_OVER_PI \
 0.6366197723675813430755350534900574481378385829618257949906693762L
// pi/2 as a long double literal; cast to float where it is used below.
#define PI_OVER_2 \
 1.5707963267948966192313216916397514420985846996875529104874722961L
1477
1478// Approximates sin(x) or cos(x) by finding the best approximation polynomial in
1479// the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the
1480// reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y).
1481template <bool isSine, typename OpTy>
1482LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1483 OpTy op, PatternRewriter &rewriter) const {
1484 static_assert(
1485 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1486 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1487
1488 if (!getElementTypeOrSelf(op.getOperand()).isF32())
1489 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1490
1491 std::optional<VectorShape> shape = vectorShape(op.getOperand());
1492
1493 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1494 auto bcast = [&](Value value) -> Value {
1495 return broadcast(builder, value, shape);
1496 };
1497 auto mul = [&](Value a, Value b) -> Value {
1498 return builder.create<arith::MulFOp>(a, b);
1499 };
1500 auto sub = [&](Value a, Value b) -> Value {
1501 return builder.create<arith::SubFOp>(a, b);
1502 };
1503 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
1504
1505 auto i32Vec = broadcast(builder.getI32Type(), shape);
1506 auto fPToSingedInteger = [&](Value a) -> Value {
1507 return builder.create<arith::FPToSIOp>(i32Vec, a);
1508 };
1509
1510 auto modulo4 = [&](Value a) -> Value {
1511 return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3)));
1512 };
1513
1514 auto isEqualTo = [&](Value a, Value b) -> Value {
1515 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
1516 };
1517
1518 auto isGreaterThan = [&](Value a, Value b) -> Value {
1519 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
1520 };
1521
1522 auto select = [&](Value cond, Value t, Value f) -> Value {
1523 return builder.create<arith::SelectOp>(cond, t, f);
1524 };
1525
1526 auto fmla = [&](Value a, Value b, Value c) {
1527 return builder.create<math::FmaOp>(a, b, c);
1528 };
1529
1530 auto bitwiseOr = [&](Value a, Value b) {
1531 return builder.create<arith::OrIOp>(a, b);
1532 };
1533
1534 Value twoOverPi = bcast(f32Cst(builder, value: (float)TWO_OVER_PI));
1535 Value piOverTwo = bcast(f32Cst(builder, value: (float)PI_OVER_2));
1536
1537 Value x = op.getOperand();
1538
1539 Value k = floor(mul(x, twoOverPi));
1540
1541 Value y = sub(x, mul(k, piOverTwo));
1542
1543 Value cstOne = bcast(f32Cst(builder, value: 1.0));
1544 Value cstNegativeOne = bcast(f32Cst(builder, value: -1.0));
1545
1546 Value cstSC2 = bcast(f32Cst(builder, value: -0.16666667163372039794921875f));
1547 Value cstSC4 = bcast(f32Cst(builder, value: 8.333347737789154052734375e-3f));
1548 Value cstSC6 = bcast(f32Cst(builder, value: -1.9842604524455964565277099609375e-4f));
1549 Value cstSC8 =
1550 bcast(f32Cst(builder, value: 2.760012648650445044040679931640625e-6f));
1551 Value cstSC10 =
1552 bcast(f32Cst(builder, value: -2.50293279435709337121807038784027099609375e-8f));
1553
1554 Value cstCC2 = bcast(f32Cst(builder, value: -0.5f));
1555 Value cstCC4 = bcast(f32Cst(builder, value: 4.166664183139801025390625e-2f));
1556 Value cstCC6 = bcast(f32Cst(builder, value: -1.388833043165504932403564453125e-3f));
1557 Value cstCC8 = bcast(f32Cst(builder, value: 2.47562347794882953166961669921875e-5f));
1558 Value cstCC10 =
1559 bcast(f32Cst(builder, value: -2.59630184018533327616751194000244140625e-7f));
1560
1561 Value kMod4 = modulo4(fPToSingedInteger(k));
1562
1563 Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 0)));
1564 Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 1)));
1565 Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 2)));
1566 Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 3)));
1567
1568 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1569 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, value: 1)))
1570 : bitwiseOr(kR1, kR2);
1571
1572 Value y2 = mul(y, y);
1573
1574 Value base = select(sinuseCos, cstOne, y);
1575 Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1576 Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1577 Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1578 Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1579 Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1580
1581 Value v1 = fmla(y2, cstC10, cstC8);
1582 Value v2 = fmla(y2, v1, cstC6);
1583 Value v3 = fmla(y2, v2, cstC4);
1584 Value v4 = fmla(y2, v3, cstC2);
1585 Value v5 = fmla(y2, v4, cstOne);
1586 Value v6 = mul(base, v5);
1587
1588 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1589
1590 rewriter.replaceOp(op, approximation);
1591
1592 return success();
1593}
1594
1595//----------------------------------------------------------------------------//
1596// Cbrt approximation.
1597//----------------------------------------------------------------------------//
1598
1599namespace {
1600struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
1601 using OpRewritePattern::OpRewritePattern;
1602
1603 LogicalResult matchAndRewrite(math::CbrtOp op,
1604 PatternRewriter &rewriter) const final;
1605};
1606} // namespace
1607
1608// Estimation of cube-root using an algorithm defined in
1609// Hacker's Delight 2nd Edition.
1610LogicalResult
1611CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1612 PatternRewriter &rewriter) const {
1613 auto operand = op.getOperand();
1614 if (!getElementTypeOrSelf(operand).isF32())
1615 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1616
1617 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1618 std::optional<VectorShape> shape = vectorShape(operand);
1619
1620 Type floatTy = getElementTypeOrSelf(operand.getType());
1621 Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
1622
1623 // Convert to vector types if necessary.
1624 floatTy = broadcast(type: floatTy, shape);
1625 intTy = broadcast(type: intTy, shape);
1626
1627 auto bconst = [&](TypedAttr attr) -> Value {
1628 Value value = b.create<arith::ConstantOp>(attr);
1629 return broadcast(builder&: b, value, shape);
1630 };
1631
1632 // Declare the initial values:
1633 Value intTwo = bconst(b.getI32IntegerAttr(2));
1634 Value intFour = bconst(b.getI32IntegerAttr(4));
1635 Value intEight = bconst(b.getI32IntegerAttr(8));
1636 Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1637 Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1638 Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1639 Value fpZero = bconst(b.getF32FloatAttr(0.0f));
1640
1641 // Compute an approximation of one third:
1642 // union {int ix; float x;};
1643 // x = x0;
1644 // ix = ix/4 + ix/16;
1645 Value absValue = b.create<math::AbsFOp>(operand);
1646 Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
1647 Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
1648 Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1649 intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
1650
1651 // ix = ix + ix/16;
1652 divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1653 intValue = b.create<arith::AddIOp>(intValue, divideBy16);
1654
1655 // ix = ix + ix/256;
1656 Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
1657 intValue = b.create<arith::AddIOp>(intValue, divideBy256);
1658
1659 // ix = 0x2a5137a0 + ix;
1660 intValue = b.create<arith::AddIOp>(intValue, intMagic);
1661
1662 // Perform one newtons step:
1663 // x = 0.33333333f*(2.0f*x + x0/(x*x));
1664 Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
1665 Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
1666 Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1667 Value divSquared = b.create<arith::DivFOp>(absValue, squared);
1668 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1669 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1670
1671 // x = 0.33333333f*(2.0f*x + x0/(x*x));
1672 squared = b.create<arith::MulFOp>(floatValue, floatValue);
1673 mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1674 divSquared = b.create<arith::DivFOp>(absValue, squared);
1675 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1676 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1677
1678 // Check for zero and restore sign.
1679 Value isZero =
1680 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
1681 floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
1682 floatValue = b.create<math::CopySignOp>(floatValue, operand);
1683
1684 rewriter.replaceOp(op, floatValue);
1685 return success();
1686}
1687
1688//----------------------------------------------------------------------------//
1689// Rsqrt approximation.
1690//----------------------------------------------------------------------------//
1691
1692namespace {
1693struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
1694 using OpRewritePattern::OpRewritePattern;
1695
1696 LogicalResult matchAndRewrite(math::RsqrtOp op,
1697 PatternRewriter &rewriter) const final;
1698};
1699} // namespace
1700
1701LogicalResult
1702RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1703 PatternRewriter &rewriter) const {
1704 if (!getElementTypeOrSelf(op.getOperand()).isF32())
1705 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1706
1707 std::optional<VectorShape> shape = vectorShape(op.getOperand());
1708
1709 // Only support already-vectorized rsqrt's.
1710 if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0)
1711 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1712
1713 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1714 auto bcast = [&](Value value) -> Value {
1715 return broadcast(builder, value, shape);
1716 };
1717
1718 Value cstPosInf = bcast(f32FromBits(builder, bits: 0x7f800000u));
1719 Value cstOnePointFive = bcast(f32Cst(builder, value: 1.5f));
1720 Value cstNegHalf = bcast(f32Cst(builder, value: -0.5f));
1721 Value cstMinNormPos = bcast(f32FromBits(builder, bits: 0x00800000u));
1722
1723 Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf);
1724
1725 // Select only the inverse sqrt of positive normals (denormals are
1726 // flushed to zero).
1727 Value ltMinMask = builder.create<arith::CmpFOp>(
1728 arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
1729 Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1730 op.getOperand(), cstPosInf);
1731 Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
1732
1733 // Compute an approximate result.
1734 Value yApprox = handleMultidimensionalVectors(
1735 builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
1736 return builder.create<x86vector::RsqrtOp>(operands);
1737 });
1738
1739 // Do a single step of Newton-Raphson iteration to improve the approximation.
1740 // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
1741 // It is essential to evaluate the inner term like this because forming
1742 // y_n^2 may over- or underflow.
1743 Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
1744 Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1745 Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);
1746
1747 // Select the result of the Newton-Raphson step for positive normal arguments.
1748 // For other arguments, choose the output of the intrinsic. This will
1749 // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
1750 // x is zero or a positive denormalized float (equivalent to flushing positive
1751 // denormalized inputs to zero).
1752 Value res =
1753 builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
1754 rewriter.replaceOp(op, res);
1755
1756 return success();
1757}
1758
1759//----------------------------------------------------------------------------//
1760
1761void mlir::populatePolynomialApproximateTanhPattern(
1762 RewritePatternSet &patterns) {
1763 patterns.add<TanhApproximation>(arg: patterns.getContext());
1764}
1765
1766void mlir::populatePolynomialApproximateErfPattern(
1767 RewritePatternSet &patterns) {
1768 patterns.add<ErfPolynomialApproximation>(arg: patterns.getContext());
1769}
1770
1771void mlir::populatePolynomialApproximateErfcPattern(
1772 RewritePatternSet &patterns) {
1773 patterns.add<ErfcPolynomialApproximation>(arg: patterns.getContext());
1774}
1775
1776template <typename OpType>
1777static void
1778populateMathF32ExpansionPattern(RewritePatternSet &patterns,
1779 llvm::function_ref<bool(StringRef)> predicate,
1780 PatternBenefit benefit) {
1781 if (predicate(OpType::getOperationName())) {
1782 patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext(), benefit);
1783 }
1784}
1785
1786void mlir::populateMathF32ExpansionPatterns(
1787 RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
1788 PatternBenefit benefit) {
1789 populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate, benefit);
1790 populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate, benefit);
1791 populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate, benefit);
1792 populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate, benefit);
1793 populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate, benefit);
1794 populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate, benefit);
1795 populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate, benefit);
1796 populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate, benefit);
1797 populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate, benefit);
1798 populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate, benefit);
1799 populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate, benefit);
1800 populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate, benefit);
1801 populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate, benefit);
1802 populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate, benefit);
1803 populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate, benefit);
1804 populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate, benefit);
1805 populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate, benefit);
1806 populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate, benefit);
1807 populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate, benefit);
1808 populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate, benefit);
1809 populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate, benefit);
1810 populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate, benefit);
1811 populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate, benefit);
1812 populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate, benefit);
1813 populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate, benefit);
1814 populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate, benefit);
1815}
1816
1817template <typename OpType, typename PatternType>
1818static void populateMathPolynomialApproximationPattern(
1819 RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
1820 PatternBenefit benefit) {
1821 if (predicate(OpType::getOperationName())) {
1822 patterns.add<PatternType>(patterns.getContext(), benefit);
1823 }
1824}
1825
1826void mlir::populateMathPolynomialApproximationPatterns(
1827 RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
1828 PatternBenefit benefit) {
1829 populateMathPolynomialApproximationPattern<AcosOp,
1830 AcosPolynomialApproximation>(
1831 patterns, predicate, benefit);
1832 populateMathPolynomialApproximationPattern<AsinOp,
1833 AsinPolynomialApproximation>(
1834 patterns, predicate, benefit);
1835 populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
1836 patterns, predicate, benefit);
1837 populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
1838 patterns, predicate, benefit);
1839 populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
1840 patterns, predicate, benefit);
1841 populateMathPolynomialApproximationPattern<
1842 CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate,
1843 benefit);
1844 populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
1845 patterns, predicate, benefit);
1846 populateMathPolynomialApproximationPattern<ErfcOp,
1847 ErfcPolynomialApproximation>(
1848 patterns, predicate, benefit);
1849 populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
1850 patterns, predicate, benefit);
1851 populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
1852 patterns, predicate, benefit);
1853 populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
1854 patterns, predicate, benefit);
1855 populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
1856 patterns, predicate, benefit);
1857 populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
1858 patterns, predicate, benefit);
1859 populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
1860 patterns, predicate, benefit);
1861 populateMathPolynomialApproximationPattern<
1862 SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate,
1863 benefit);
1864 populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
1865 patterns, predicate, benefit);
1866}
1867
1868void mlir::populateMathPolynomialApproximationPatterns(
1869 RewritePatternSet &patterns,
1870 const MathPolynomialApproximationOptions &options) {
1871 mlir::populateMathF32ExpansionPatterns(patterns, predicate: [](StringRef name) -> bool {
1872 return llvm::is_contained(
1873 {math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
1874 math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1875 math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
1876 math::ErfOp::getOperationName(), math::ErfcOp::getOperationName(),
1877 math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
1878 math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1879 math::CosOp::getOperationName()},
1880 name);
1881 });
1882
1883 populateMathPolynomialApproximationPatterns(
1884 patterns, predicate: [](StringRef name) -> bool {
1885 return llvm::is_contained(
1886 {math::AtanOp::getOperationName(),
1887 math::Atan2Op::getOperationName(),
1888 math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1889 math::Log2Op::getOperationName(),
1890 math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
1891 math::ErfcOp::getOperationName(), math::AsinOp::getOperationName(),
1892 math::AcosOp::getOperationName(), math::ExpOp::getOperationName(),
1893 math::ExpM1Op::getOperationName(),
1894 math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1895 math::CosOp::getOperationName()},
1896 name);
1897 });
1898
1899 if (options.enableAvx2) {
1900 auto predicateRsqrt = [](StringRef name) {
1901 return name == math::RsqrtOp::getOperationName();
1902 };
1903 mlir::populateMathF32ExpansionPatterns(patterns, predicateRsqrt);
1904 mlir::populateMathPolynomialApproximationPatterns(patterns, predicateRsqrt);
1905 }
1906}
1907

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp