1//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements expansion of math operations to fast approximations
10// that do not rely on any of the library functions.
11//
12//===----------------------------------------------------------------------===//
13
14#include <climits>
15#include <cmath>
16#include <cstddef>
17
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/Math/IR/Math.h"
20#include "mlir/Dialect/Math/Transforms/Approximation.h"
21#include "mlir/Dialect/Math/Transforms/Passes.h"
22#include "mlir/Dialect/Utils/IndexingUtils.h"
23#include "mlir/Dialect/Vector/IR/VectorOps.h"
24#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
25#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
26#include "mlir/IR/Builders.h"
27#include "mlir/IR/BuiltinTypes.h"
28#include "mlir/IR/ImplicitLocOpBuilder.h"
29#include "mlir/IR/OpDefinition.h"
30#include "mlir/IR/PatternMatch.h"
31#include "mlir/IR/TypeUtilities.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
34#include "llvm/ADT/ArrayRef.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/Support/MathExtras.h"
37
38using namespace mlir;
39using namespace mlir::math;
40using namespace mlir::vector;
41
42// Helper to encapsulate a vector's shape (including scalable dims).
43struct VectorShape {
44 ArrayRef<int64_t> sizes;
45 ArrayRef<bool> scalableFlags;
46
47 bool empty() const { return sizes.empty(); }
48};
49
50// Returns vector shape if the type is a vector. Returns an empty shape if it is
51// not a vector.
52static VectorShape vectorShape(Type type) {
53 auto vectorType = dyn_cast<VectorType>(type);
54 return vectorType
55 ? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
56 : VectorShape{};
57}
58
59static VectorShape vectorShape(Value value) {
60 return vectorShape(type: value.getType());
61}
62
63//----------------------------------------------------------------------------//
64// Broadcast scalar types and values into vector types and values.
65//----------------------------------------------------------------------------//
66
67// Broadcasts scalar type into vector type (iff shape is non-scalar).
68static Type broadcast(Type type, VectorShape shape) {
69 assert(!isa<VectorType>(type) && "must be scalar type");
70 return !shape.empty()
71 ? VectorType::get(shape.sizes, type, shape.scalableFlags)
72 : type;
73}
74
75// Broadcasts scalar value into vector (iff shape is non-scalar).
76static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
77 VectorShape shape) {
78 assert(!isa<VectorType>(value.getType()) && "must be scalar value");
79 auto type = broadcast(type: value.getType(), shape);
80 return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
81}
82
83//----------------------------------------------------------------------------//
84// Helper function to handle n-D vectors with 1-D operations.
85//----------------------------------------------------------------------------//
86
87// Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
88// and calls the compute function with 1-D vector operands. Stitches back all
89// results into the original n-D vector result.
90//
91// Examples: vectorWidth = 8
92// - vector<4x8xf32> unrolled 4 times
93// - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
94// - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
95//
96// Some math approximations rely on ISA-specific operations that only accept
97// fixed size 1-D vectors (e.g. AVX expects vectors of width 8).
98//
99// It is the caller's responsibility to verify that the inner dimension is
100// divisible by the vectorWidth, and that all operands have the same vector
101// shape.
102static Value
103handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
104 ValueRange operands, int64_t vectorWidth,
105 llvm::function_ref<Value(ValueRange)> compute) {
106 assert(!operands.empty() && "operands must be not empty");
107 assert(vectorWidth > 0 && "vector width must be larger than 0");
108
109 VectorType inputType = cast<VectorType>(operands[0].getType());
110 ArrayRef<int64_t> inputShape = inputType.getShape();
111
112 // If input shape matches target vector width, we can just call the
113 // user-provided compute function with the operands.
114 if (inputShape == llvm::ArrayRef(vectorWidth))
115 return compute(operands);
116
117 // Check if the inner dimension has to be expanded, or we can directly iterate
118 // over the outer dimensions of the vector.
119 int64_t innerDim = inputShape.back();
120 int64_t expansionDim = innerDim / vectorWidth;
121 assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");
122
123 // Maybe expand operands to the higher rank vector shape that we'll use to
124 // iterate over and extract one dimensional vectors.
125 SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end());
126 SmallVector<Value> expandedOperands(operands);
127
128 if (expansionDim > 1) {
129 // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
130 expandedShape.insert(I: expandedShape.end() - 1, Elt: expansionDim);
131 expandedShape.back() = vectorWidth;
132
133 for (unsigned i = 0; i < operands.size(); ++i) {
134 auto operand = operands[i];
135 auto eltType = cast<VectorType>(operand.getType()).getElementType();
136 auto expandedType = VectorType::get(expandedShape, eltType);
137 expandedOperands[i] =
138 builder.create<vector::ShapeCastOp>(expandedType, operand);
139 }
140 }
141
142 // Iterate over all outer dimensions of the compute shape vector type.
143 auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
144 int64_t maxIndex = computeMaxLinearIndex(iterationDims);
145 auto strides = computeStrides(iterationDims);
146
147 // Compute results for each one dimensional vector.
148 SmallVector<Value> results(maxIndex);
149
150 for (int64_t i = 0; i < maxIndex; ++i) {
151 auto offsets = delinearize(i, strides);
152
153 SmallVector<Value> extracted(expandedOperands.size());
154 for (const auto &tuple : llvm::enumerate(expandedOperands))
155 extracted[tuple.index()] =
156 builder.create<vector::ExtractOp>(tuple.value(), offsets);
157
158 results[i] = compute(extracted);
159 }
160
161 // Stitch results together into one large vector.
162 Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
163 Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
164 Value result = builder.create<arith::ConstantOp>(
165 resultExpandedType, builder.getZeroAttr(resultExpandedType));
166
167 for (int64_t i = 0; i < maxIndex; ++i)
168 result = builder.create<vector::InsertOp>(results[i], result,
169 delinearize(i, strides));
170
171 // Reshape back to the original vector shape.
172 return builder.create<vector::ShapeCastOp>(
173 VectorType::get(inputShape, resultEltType), result);
174}
175
176//----------------------------------------------------------------------------//
177// Helper functions to create constants.
178//----------------------------------------------------------------------------//
179
180static Value floatCst(ImplicitLocOpBuilder &builder, float value,
181 Type elementType) {
182 assert((elementType.isF16() || elementType.isF32()) &&
183 "x must be f16 or f32 type.");
184 return builder.create<arith::ConstantOp>(
185 builder.getFloatAttr(elementType, value));
186}
187
188static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
189 return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
190}
191
192static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
193 return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
194}
195
196static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
197 Value i32Value = i32Cst(builder, value: static_cast<int32_t>(bits));
198 return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value);
199}
200
201//----------------------------------------------------------------------------//
202// Helper functions to build math functions approximations.
203//----------------------------------------------------------------------------//
204
205// Return the minimum of the two values or NaN if value is NaN
206static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
207 return builder.create<arith::SelectOp>(
208 builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
209 value, bound);
210}
211
212// Return the maximum of the two values or NaN if value is NaN
213static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
214 return builder.create<arith::SelectOp>(
215 builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
216 value, bound);
217}
218
219// Return the clamped value or NaN if value is NaN
220static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
221 Value upperBound) {
222 return max(builder, value: min(builder, value, bound: upperBound), bound: lowerBound);
223}
224
225// Decomposes given floating point value `arg` into a normalized fraction and
226// an integral power of two (see std::frexp). Returned values have float type.
227static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
228 bool isPositive = false) {
229 assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
230 VectorShape shape = vectorShape(value: arg);
231
232 auto bcast = [&](Value value) -> Value {
233 return broadcast(builder, value, shape);
234 };
235
236 auto i32 = builder.getIntegerType(32);
237 auto i32Vec = broadcast(i32, shape);
238 auto f32Vec = broadcast(type: builder.getF32Type(), shape);
239
240 Value cst126f = f32Cst(builder, value: 126.0f);
241 Value cstHalf = f32Cst(builder, value: 0.5f);
242 Value cstInvMantMask = f32FromBits(builder, bits: ~0x7f800000u);
243
244 // Bitcast to i32 for bitwise operations.
245 Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf);
246 Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask);
247 Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg);
248
249 // Compute normalized fraction.
250 Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
251 Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half));
252 Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);
253
254 // Compute exponent.
255 Value arg0 = isPositive ? arg : builder.create<math::AbsFOp>(arg);
256 Value biasedExponentBits = builder.create<arith::ShRUIOp>(
257 builder.create<arith::BitcastOp>(i32Vec, arg0),
258 bcast(i32Cst(builder, 23)));
259 Value biasedExponent =
260 builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
261 Value exponent =
262 builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f));
263
264 return {normalizedFraction, exponent};
265}
266
267// Computes exp2 for an i32 argument.
268static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
269 assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
270 VectorShape shape = vectorShape(value: arg);
271
272 auto bcast = [&](Value value) -> Value {
273 return broadcast(builder, value, shape);
274 };
275
276 auto f32Vec = broadcast(type: builder.getF32Type(), shape);
277 // The exponent of f32 located at 23-bit.
278 auto exponetBitLocation = bcast(i32Cst(builder, value: 23));
279 // Set the exponent bias to zero.
280 auto bias = bcast(i32Cst(builder, value: 127));
281
282 Value biasedArg = builder.create<arith::AddIOp>(arg, bias);
283 Value exp2ValueInt =
284 builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation);
285 Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt);
286
287 return exp2ValueF32;
288}
289
290namespace {
291Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
292 llvm::ArrayRef<Value> coeffs, Value x) {
293 Type elementType = getElementTypeOrSelf(val: x);
294 assert((elementType.isF32() || elementType.isF16()) &&
295 "x must be f32 or f16 type");
296 VectorShape shape = vectorShape(value: x);
297
298 if (coeffs.empty())
299 return broadcast(builder, value: floatCst(builder, value: 0.0f, elementType), shape);
300
301 if (coeffs.size() == 1)
302 return coeffs[0];
303
304 Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
305 coeffs[coeffs.size() - 2]);
306 for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
307 res = builder.create<math::FmaOp>(x, res, coeffs[i]);
308 }
309 return res;
310}
311} // namespace
312
313//----------------------------------------------------------------------------//
314// Helper function/pattern to insert casts for reusing F32 bit expansion.
315//----------------------------------------------------------------------------//
316
317template <typename T>
318LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
319 // Conservatively only allow where the operand and result types are exactly 1.
320 Type origType = op->getResultTypes().front();
321 for (Type t : llvm::drop_begin(RangeOrContainer: op->getResultTypes()))
322 if (origType != t)
323 return rewriter.notifyMatchFailure(arg&: op, msg: "required all types to match");
324 for (Type t : op->getOperandTypes())
325 if (origType != t)
326 return rewriter.notifyMatchFailure(arg&: op, msg: "required all types to match");
327
328 // Skip if already F32 or larger than 32 bits.
329 if (getElementTypeOrSelf(type: origType).isF32() ||
330 getElementTypeOrSelf(type: origType).getIntOrFloatBitWidth() > 32)
331 return failure();
332
333 // Create F32 equivalent type.
334 Type newType;
335 if (auto shaped = dyn_cast<ShapedType>(origType)) {
336 newType = shaped.clone(rewriter.getF32Type());
337 } else if (isa<FloatType>(Val: origType)) {
338 newType = rewriter.getF32Type();
339 } else {
340 return rewriter.notifyMatchFailure(arg&: op,
341 msg: "unable to find F32 equivalent type");
342 }
343
344 Location loc = op->getLoc();
345 SmallVector<Value> operands;
346 for (auto operand : op->getOperands())
347 operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
348 auto result =
349 rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs());
350 rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
351 return success();
352}
353
354namespace {
355// Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
356// op.
357// TODO: Consider revising to avoid adding multiple casts for a subgraph that is
358// all in lower precision. Currently this is only fallback support and performs
359// simplistic casting.
360template <typename T>
361struct ReuseF32Expansion : public OpRewritePattern<T> {
362public:
363 using OpRewritePattern<T>::OpRewritePattern;
364 LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
365 static_assert(
366 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
367 "requires same operands and result types");
368 return insertCasts<T>(op, rewriter);
369 }
370};
371} // namespace
372
373//----------------------------------------------------------------------------//
374// AtanOp approximation.
375//----------------------------------------------------------------------------//
376
377namespace {
378struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
379public:
380 using OpRewritePattern::OpRewritePattern;
381
382 LogicalResult matchAndRewrite(math::AtanOp op,
383 PatternRewriter &rewriter) const final;
384};
385} // namespace
386
387LogicalResult
388AtanApproximation::matchAndRewrite(math::AtanOp op,
389 PatternRewriter &rewriter) const {
390 auto operand = op.getOperand();
391 if (!getElementTypeOrSelf(operand).isF32())
392 return rewriter.notifyMatchFailure(op, "unsupported operand type");
393
394 VectorShape shape = vectorShape(op.getOperand());
395
396 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
397 Value abs = builder.create<math::AbsFOp>(operand);
398
399 auto one = broadcast(builder, value: f32Cst(builder, value: 1.0), shape);
400
401 // When 0.66 < x <= 2.41 we do (x-1) / (x+1):
402 auto twoThirds = broadcast(builder, value: f32Cst(builder, value: 0.66), shape);
403 Value cmp2 =
404 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, twoThirds);
405 Value addone = builder.create<arith::AddFOp>(abs, one);
406 Value subone = builder.create<arith::SubFOp>(abs, one);
407 Value xnum = builder.create<arith::SelectOp>(cmp2, subone, abs);
408 Value xden = builder.create<arith::SelectOp>(cmp2, addone, one);
409
410 auto bcast = [&](Value value) -> Value {
411 return broadcast(builder, value, shape);
412 };
413
414 // Break into the <= 0.66 or > 2.41 we do x or 1/x:
415 auto tan3pio8 = bcast(f32Cst(builder, value: 2.41421356237309504880));
416 Value cmp1 =
417 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, tan3pio8);
418 xnum = builder.create<arith::SelectOp>(cmp1, one, xnum);
419 xden = builder.create<arith::SelectOp>(cmp1, abs, xden);
420
421 Value x = builder.create<arith::DivFOp>(xnum, xden);
422 Value xx = builder.create<arith::MulFOp>(x, x);
423
424 // Perform the Taylor series approximation for atan over the range
425 // [0.0, 0.66].
426 auto p0 = bcast(f32Cst(builder, value: -8.750608600031904122785e-01));
427 auto p1 = bcast(f32Cst(builder, value: -1.615753718733365076637e+01));
428 auto p2 = bcast(f32Cst(builder, value: -7.500855792314704667340e+01));
429 auto p3 = bcast(f32Cst(builder, value: -1.228866684490136173410e+02));
430 auto p4 = bcast(f32Cst(builder, value: -6.485021904942025371773e+01));
431 auto q0 = bcast(f32Cst(builder, value: +2.485846490142306297962e+01));
432 auto q1 = bcast(f32Cst(builder, value: +1.650270098316988542046e+02));
433 auto q2 = bcast(f32Cst(builder, value: +4.328810604912902668951e+02));
434 auto q3 = bcast(f32Cst(builder, value: +4.853903996359136964868e+02));
435 auto q4 = bcast(f32Cst(builder, value: +1.945506571482613964425e+02));
436
437 // Apply the polynomial approximation for the numerator:
438 Value n = p0;
439 n = builder.create<math::FmaOp>(xx, n, p1);
440 n = builder.create<math::FmaOp>(xx, n, p2);
441 n = builder.create<math::FmaOp>(xx, n, p3);
442 n = builder.create<math::FmaOp>(xx, n, p4);
443 n = builder.create<arith::MulFOp>(n, xx);
444
445 // Apply the polynomial approximation for the denominator:
446 Value d = q0;
447 d = builder.create<math::FmaOp>(xx, d, q1);
448 d = builder.create<math::FmaOp>(xx, d, q2);
449 d = builder.create<math::FmaOp>(xx, d, q3);
450 d = builder.create<math::FmaOp>(xx, d, q4);
451
452 // Compute approximation of theta:
453 Value ans0 = builder.create<arith::DivFOp>(n, d);
454 ans0 = builder.create<math::FmaOp>(ans0, x, x);
455
456 // Correct for the input mapping's angles:
457 Value mpi4 = bcast(f32Cst(builder, value: llvm::numbers::pi / 4));
458 Value ans2 = builder.create<arith::AddFOp>(mpi4, ans0);
459 Value ans = builder.create<arith::SelectOp>(cmp2, ans2, ans0);
460
461 Value mpi2 = bcast(f32Cst(builder, value: llvm::numbers::pi / 2));
462 Value ans1 = builder.create<arith::SubFOp>(mpi2, ans0);
463 ans = builder.create<arith::SelectOp>(cmp1, ans1, ans);
464
465 // Correct for signing of the input.
466 rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand);
467 return success();
468}
469
470//----------------------------------------------------------------------------//
// Atan2Op approximation.
472//----------------------------------------------------------------------------//
473
474namespace {
475struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
476public:
477 using OpRewritePattern::OpRewritePattern;
478
479 LogicalResult matchAndRewrite(math::Atan2Op op,
480 PatternRewriter &rewriter) const final;
481};
482} // namespace
483
484LogicalResult
485Atan2Approximation::matchAndRewrite(math::Atan2Op op,
486 PatternRewriter &rewriter) const {
487 auto y = op.getOperand(0);
488 auto x = op.getOperand(1);
489 if (!getElementTypeOrSelf(x).isF32())
490 return rewriter.notifyMatchFailure(op, "unsupported operand type");
491
492 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
493 VectorShape shape = vectorShape(op.getResult());
494
495 // Compute atan in the valid range.
496 auto div = builder.create<arith::DivFOp>(y, x);
497 auto atan = builder.create<math::AtanOp>(div);
498
499 // Determine what the atan would be for a 180 degree rotation.
500 auto zero = broadcast(builder, value: f32Cst(builder, value: 0.0f), shape);
501 auto pi = broadcast(builder, value: f32Cst(builder, value: 3.14159265359f), shape);
502 auto addPi = builder.create<arith::AddFOp>(atan, pi);
503 auto subPi = builder.create<arith::SubFOp>(atan, pi);
504 auto atanGt =
505 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
506 auto flippedAtan = builder.create<arith::SelectOp>(atanGt, subPi, addPi);
507
508 // Determine whether to directly use atan or use the 180 degree flip
509 auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
510 Value result = builder.create<arith::SelectOp>(xGt, atan, flippedAtan);
511
512 // Handle x = 0, y > 0
513 Value xZero =
514 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
515 Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
516 Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt);
517 auto halfPi = broadcast(builder, value: f32Cst(builder, value: 1.57079632679f), shape);
518 result = builder.create<arith::SelectOp>(isHalfPi, halfPi, result);
519
520 // Handle x = 0, y < 0
521 Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
522 Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt);
523 auto negativeHalfPiPi =
524 broadcast(builder, value: f32Cst(builder, value: -1.57079632679f), shape);
525 result = builder.create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
526 result);
527
528 // Handle x = 0, y = 0;
529 Value yZero =
530 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
531 Value isNan = builder.create<arith::AndIOp>(xZero, yZero);
532 Value cstNan = broadcast(builder, value: f32FromBits(builder, bits: 0x7fc00000), shape);
533 result = builder.create<arith::SelectOp>(isNan, cstNan, result);
534
535 rewriter.replaceOp(op, result);
536 return success();
537}
538
539//----------------------------------------------------------------------------//
540// TanhOp approximation.
541//----------------------------------------------------------------------------//
542
543namespace {
544struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
545public:
546 using OpRewritePattern::OpRewritePattern;
547
548 LogicalResult matchAndRewrite(math::TanhOp op,
549 PatternRewriter &rewriter) const final;
550};
551} // namespace
552
553LogicalResult
554TanhApproximation::matchAndRewrite(math::TanhOp op,
555 PatternRewriter &rewriter) const {
556 if (!getElementTypeOrSelf(op.getOperand()).isF32())
557 return rewriter.notifyMatchFailure(op, "unsupported operand type");
558
559 VectorShape shape = vectorShape(op.getOperand());
560
561 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
562 auto bcast = [&](Value value) -> Value {
563 return broadcast(builder, value, shape);
564 };
565
566 // Clamp operand into [plusClamp, minusClamp] range.
567 Value minusClamp = bcast(f32Cst(builder, value: -7.99881172180175781f));
568 Value plusClamp = bcast(f32Cst(builder, value: 7.99881172180175781f));
569 Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);
570
571 // Mask for tiny values that are approximated with `operand`.
572 Value tiny = bcast(f32Cst(builder, value: 0.0004f));
573 Value tinyMask = builder.create<arith::CmpFOp>(
574 arith::CmpFPredicate::OLT, builder.create<math::AbsFOp>(op.getOperand()),
575 tiny);
576
577 // The monomial coefficients of the numerator polynomial (odd).
578 Value alpha1 = bcast(f32Cst(builder, value: 4.89352455891786e-03f));
579 Value alpha3 = bcast(f32Cst(builder, value: 6.37261928875436e-04f));
580 Value alpha5 = bcast(f32Cst(builder, value: 1.48572235717979e-05f));
581 Value alpha7 = bcast(f32Cst(builder, value: 5.12229709037114e-08f));
582 Value alpha9 = bcast(f32Cst(builder, value: -8.60467152213735e-11f));
583 Value alpha11 = bcast(f32Cst(builder, value: 2.00018790482477e-13f));
584 Value alpha13 = bcast(f32Cst(builder, value: -2.76076847742355e-16f));
585
586 // The monomial coefficients of the denominator polynomial (even).
587 Value beta0 = bcast(f32Cst(builder, value: 4.89352518554385e-03f));
588 Value beta2 = bcast(f32Cst(builder, value: 2.26843463243900e-03f));
589 Value beta4 = bcast(f32Cst(builder, value: 1.18534705686654e-04f));
590 Value beta6 = bcast(f32Cst(builder, value: 1.19825839466702e-06f));
591
592 // Since the polynomials are odd/even, we need x^2.
593 Value x2 = builder.create<arith::MulFOp>(x, x);
594
595 // Evaluate the numerator polynomial p.
596 Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
597 p = builder.create<math::FmaOp>(x2, p, alpha9);
598 p = builder.create<math::FmaOp>(x2, p, alpha7);
599 p = builder.create<math::FmaOp>(x2, p, alpha5);
600 p = builder.create<math::FmaOp>(x2, p, alpha3);
601 p = builder.create<math::FmaOp>(x2, p, alpha1);
602 p = builder.create<arith::MulFOp>(x, p);
603
604 // Evaluate the denominator polynomial q.
605 Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
606 q = builder.create<math::FmaOp>(x2, q, beta2);
607 q = builder.create<math::FmaOp>(x2, q, beta0);
608
609 // Divide the numerator by the denominator.
610 Value res = builder.create<arith::SelectOp>(
611 tinyMask, x, builder.create<arith::DivFOp>(p, q));
612
613 rewriter.replaceOp(op, res);
614
615 return success();
616}
617
// ln(2) and log2(e) as long double literals. The log approximations cast them
// to float when combining the fraction and exponent parts of the result.
#define LN2_VALUE \
 0.693147180559945309417232121458176568075500134360255254120680009493393621L
#define LOG2E_VALUE \
 1.442695040888963407359924681001892137426645954152985934135449406931109219L
622
623//----------------------------------------------------------------------------//
624// LogOp and Log2Op approximation.
625//----------------------------------------------------------------------------//
626
627namespace {
628template <typename Op>
629struct LogApproximationBase : public OpRewritePattern<Op> {
630 using OpRewritePattern<Op>::OpRewritePattern;
631
632 /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
633 LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
634 bool base2) const;
635};
636} // namespace
637
638// This approximation comes from Julien Pommier's SSE math library.
639// Link: http://gruntthepeon.free.fr/ssemath
640template <typename Op>
641LogicalResult
642LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
643 bool base2) const {
644 if (!getElementTypeOrSelf(op.getOperand()).isF32())
645 return rewriter.notifyMatchFailure(op, "unsupported operand type");
646
647 VectorShape shape = vectorShape(op.getOperand());
648
649 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
650 auto bcast = [&](Value value) -> Value {
651 return broadcast(builder, value, shape);
652 };
653
654 Value cstZero = bcast(f32Cst(builder, value: 0.0f));
655 Value cstOne = bcast(f32Cst(builder, value: 1.0f));
656 Value cstNegHalf = bcast(f32Cst(builder, value: -0.5f));
657
658 // The smallest non denormalized float number.
659 Value cstMinNormPos = bcast(f32FromBits(builder, bits: 0x00800000u));
660 Value cstMinusInf = bcast(f32FromBits(builder, bits: 0xff800000u));
661 Value cstPosInf = bcast(f32FromBits(builder, bits: 0x7f800000u));
662 Value cstNan = bcast(f32FromBits(builder, bits: 0x7fc00000));
663
664 // Polynomial coefficients.
665 Value cstCephesSQRTHF = bcast(f32Cst(builder, value: 0.707106781186547524f));
666 Value cstCephesLogP0 = bcast(f32Cst(builder, value: 7.0376836292E-2f));
667 Value cstCephesLogP1 = bcast(f32Cst(builder, value: -1.1514610310E-1f));
668 Value cstCephesLogP2 = bcast(f32Cst(builder, value: 1.1676998740E-1f));
669 Value cstCephesLogP3 = bcast(f32Cst(builder, value: -1.2420140846E-1f));
670 Value cstCephesLogP4 = bcast(f32Cst(builder, value: +1.4249322787E-1f));
671 Value cstCephesLogP5 = bcast(f32Cst(builder, value: -1.6668057665E-1f));
672 Value cstCephesLogP6 = bcast(f32Cst(builder, value: +2.0000714765E-1f));
673 Value cstCephesLogP7 = bcast(f32Cst(builder, value: -2.4999993993E-1f));
674 Value cstCephesLogP8 = bcast(f32Cst(builder, value: +3.3333331174E-1f));
675
676 Value x = op.getOperand();
677
678 // Truncate input values to the minimum positive normal.
679 x = max(builder, value: x, bound: cstMinNormPos);
680
681 // Extract significant in the range [0.5,1) and exponent.
682 std::pair<Value, Value> pair = frexp(builder, arg: x, /*isPositive=*/true);
683 x = pair.first;
684 Value e = pair.second;
685
686 // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
687 // by -1.0. The values are then centered around 0, which improves the
688 // stability of the polynomial evaluation:
689 //
690 // if( x < SQRTHF ) {
691 // e -= 1;
692 // x = x + x - 1.0;
693 // } else { x = x - 1.0; }
694 Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
695 cstCephesSQRTHF);
696 Value tmp = builder.create<arith::SelectOp>(mask, x, cstZero);
697
698 x = builder.create<arith::SubFOp>(x, cstOne);
699 e = builder.create<arith::SubFOp>(
700 e, builder.create<arith::SelectOp>(mask, cstOne, cstZero));
701 x = builder.create<arith::AddFOp>(x, tmp);
702
703 Value x2 = builder.create<arith::MulFOp>(x, x);
704 Value x3 = builder.create<arith::MulFOp>(x2, x);
705
706 // Evaluate the polynomial approximant of degree 8 in three parts.
707 Value y0, y1, y2;
708 y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
709 y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
710 y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
711 y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
712 y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
713 y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
714 y0 = builder.create<math::FmaOp>(y0, x3, y1);
715 y0 = builder.create<math::FmaOp>(y0, x3, y2);
716 y0 = builder.create<arith::MulFOp>(y0, x3);
717
718 y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
719 x = builder.create<arith::AddFOp>(x, y0);
720
721 if (base2) {
722 Value cstLog2e = bcast(f32Cst(builder, value: static_cast<float>(LOG2E_VALUE)));
723 x = builder.create<math::FmaOp>(x, cstLog2e, e);
724 } else {
725 Value cstLn2 = bcast(f32Cst(builder, value: static_cast<float>(LN2_VALUE)));
726 x = builder.create<math::FmaOp>(e, cstLn2, x);
727 }
728
729 Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
730 op.getOperand(), cstZero);
731 Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
732 op.getOperand(), cstZero);
733 Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
734 op.getOperand(), cstPosInf);
735
736 // Filter out invalid values:
737 // • x == 0 -> -INF
738 // • x < 0 -> NAN
739 // • x == +INF -> +INF
740 Value aproximation = builder.create<arith::SelectOp>(
741 zeroMask, cstMinusInf,
742 builder.create<arith::SelectOp>(
743 invalidMask, cstNan,
744 builder.create<arith::SelectOp>(posInfMask, cstPosInf, x)));
745
746 rewriter.replaceOp(op, aproximation);
747
748 return success();
749}
750
751namespace {
752struct LogApproximation : public LogApproximationBase<math::LogOp> {
753 using LogApproximationBase::LogApproximationBase;
754
755 LogicalResult matchAndRewrite(math::LogOp op,
756 PatternRewriter &rewriter) const final {
757 return logMatchAndRewrite(op, rewriter, /*base2=*/false);
758 }
759};
760} // namespace
761
762namespace {
763struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
764 using LogApproximationBase::LogApproximationBase;
765
766 LogicalResult matchAndRewrite(math::Log2Op op,
767 PatternRewriter &rewriter) const final {
768 return logMatchAndRewrite(op, rewriter, /*base2=*/true);
769 }
770};
771} // namespace
772
773//----------------------------------------------------------------------------//
774// Log1p approximation.
775//----------------------------------------------------------------------------//
776
777namespace {
778struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
779public:
780 using OpRewritePattern::OpRewritePattern;
781
782 LogicalResult matchAndRewrite(math::Log1pOp op,
783 PatternRewriter &rewriter) const final;
784};
785} // namespace
786
787// Approximate log(1+x).
788LogicalResult
789Log1pApproximation::matchAndRewrite(math::Log1pOp op,
790 PatternRewriter &rewriter) const {
791 if (!getElementTypeOrSelf(op.getOperand()).isF32())
792 return rewriter.notifyMatchFailure(op, "unsupported operand type");
793
794 VectorShape shape = vectorShape(op.getOperand());
795
796 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
797 auto bcast = [&](Value value) -> Value {
798 return broadcast(builder, value, shape);
799 };
800
801 // Approximate log(1+x) using the following, due to W. Kahan:
802 // u = x + 1.0;
803 // if (u == 1.0 || u == inf) return x;
804 // return x * log(u) / (u - 1.0);
805 // ^^^^^^^^^^^^^^^^^^^^^^
806 // "logLarge" below.
807 Value cstOne = bcast(f32Cst(builder, value: 1.0f));
808 Value x = op.getOperand();
809 Value u = builder.create<arith::AddFOp>(x, cstOne);
810 Value uSmall =
811 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
812 Value logU = builder.create<math::LogOp>(u);
813 Value uInf =
814 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
815 Value logLarge = builder.create<arith::MulFOp>(
816 x, builder.create<arith::DivFOp>(
817 logU, builder.create<arith::SubFOp>(u, cstOne)));
818 Value approximation = builder.create<arith::SelectOp>(
819 builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
820 rewriter.replaceOp(op, approximation);
821 return success();
822}
823
824//----------------------------------------------------------------------------//
825// Erf approximation.
826//----------------------------------------------------------------------------//
827
828// Approximates erf(x) with
829// a - P(x)/Q(x)
830// where P and Q are polynomials of degree 4.
831// Different coefficients are chosen based on the value of x.
832// The approximation error is ~2.5e-07.
833// Boost's minimax tool that utilizes the Remez method was used to find the
834// coefficients.
835LogicalResult
836ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
837 PatternRewriter &rewriter) const {
838 Value operand = op.getOperand();
839 Type elementType = getElementTypeOrSelf(val: operand);
840
841 if (!(elementType.isF32() || elementType.isF16()))
842 return rewriter.notifyMatchFailure(op,
843 "only f32 and f16 type is supported.");
844 VectorShape shape = vectorShape(value: operand);
845
846 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
847 auto bcast = [&](Value value) -> Value {
848 return broadcast(builder, value, shape);
849 };
850
851 const int intervalsCount = 3;
852 const int polyDegree = 4;
853
854 Value zero = bcast(floatCst(builder, value: 0, elementType));
855 Value one = bcast(floatCst(builder, value: 1, elementType));
856 Value pp[intervalsCount][polyDegree + 1];
857 pp[0][0] = bcast(floatCst(builder, value: +0.00000000000000000e+00f, elementType));
858 pp[0][1] = bcast(floatCst(builder, value: +1.12837916222975858e+00f, elementType));
859 pp[0][2] = bcast(floatCst(builder, value: -5.23018562988006470e-01f, elementType));
860 pp[0][3] = bcast(floatCst(builder, value: +2.09741709609267072e-01f, elementType));
861 pp[0][4] = bcast(floatCst(builder, value: +2.58146801602987875e-02f, elementType));
862 pp[1][0] = bcast(floatCst(builder, value: +0.00000000000000000e+00f, elementType));
863 pp[1][1] = bcast(floatCst(builder, value: +1.12750687816789140e+00f, elementType));
864 pp[1][2] = bcast(floatCst(builder, value: -3.64721408487825775e-01f, elementType));
865 pp[1][3] = bcast(floatCst(builder, value: +1.18407396425136952e-01f, elementType));
866 pp[1][4] = bcast(floatCst(builder, value: +3.70645533056476558e-02f, elementType));
867 pp[2][0] = bcast(floatCst(builder, value: -3.30093071049483172e-03f, elementType));
868 pp[2][1] = bcast(floatCst(builder, value: +3.51961938357697011e-03f, elementType));
869 pp[2][2] = bcast(floatCst(builder, value: -1.41373622814988039e-03f, elementType));
870 pp[2][3] = bcast(floatCst(builder, value: +2.53447094961941348e-04f, elementType));
871 pp[2][4] = bcast(floatCst(builder, value: -1.71048029455037401e-05f, elementType));
872
873 Value qq[intervalsCount][polyDegree + 1];
874 qq[0][0] = bcast(floatCst(builder, value: +1.000000000000000000e+00f, elementType));
875 qq[0][1] = bcast(floatCst(builder, value: -4.635138185962547255e-01f, elementType));
876 qq[0][2] = bcast(floatCst(builder, value: +5.192301327279782447e-01f, elementType));
877 qq[0][3] = bcast(floatCst(builder, value: -1.318089722204810087e-01f, elementType));
878 qq[0][4] = bcast(floatCst(builder, value: +7.397964654672315005e-02f, elementType));
879 qq[1][0] = bcast(floatCst(builder, value: +1.00000000000000000e+00f, elementType));
880 qq[1][1] = bcast(floatCst(builder, value: -3.27607011824493086e-01f, elementType));
881 qq[1][2] = bcast(floatCst(builder, value: +4.48369090658821977e-01f, elementType));
882 qq[1][3] = bcast(floatCst(builder, value: -8.83462621207857930e-02f, elementType));
883 qq[1][4] = bcast(floatCst(builder, value: +5.72442770283176093e-02f, elementType));
884 qq[2][0] = bcast(floatCst(builder, value: +1.00000000000000000e+00f, elementType));
885 qq[2][1] = bcast(floatCst(builder, value: -2.06069165953913769e+00f, elementType));
886 qq[2][2] = bcast(floatCst(builder, value: +1.62705939945477759e+00f, elementType));
887 qq[2][3] = bcast(floatCst(builder, value: -5.83389859211130017e-01f, elementType));
888 qq[2][4] = bcast(floatCst(builder, value: +8.21908939856640930e-02f, elementType));
889
890 Value offsets[intervalsCount];
891 offsets[0] = bcast(floatCst(builder, value: 0.0f, elementType));
892 offsets[1] = bcast(floatCst(builder, value: 0.0f, elementType));
893 offsets[2] = bcast(floatCst(builder, value: 1.0f, elementType));
894
895 Value bounds[intervalsCount];
896 bounds[0] = bcast(floatCst(builder, value: 0.8f, elementType));
897 bounds[1] = bcast(floatCst(builder, value: 2.0f, elementType));
898 bounds[2] = bcast(floatCst(builder, value: 3.75f, elementType));
899
900 Value isNegativeArg =
901 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
902 Value negArg = builder.create<arith::NegFOp>(operand);
903 Value x = builder.create<arith::SelectOp>(isNegativeArg, negArg, operand);
904
905 Value offset = offsets[0];
906 Value p[polyDegree + 1];
907 Value q[polyDegree + 1];
908 for (int i = 0; i <= polyDegree; ++i) {
909 p[i] = pp[0][i];
910 q[i] = qq[0][i];
911 }
912
913 // TODO: maybe use vector stacking to reduce the number of selects.
914 Value isLessThanBound[intervalsCount];
915 for (int j = 0; j < intervalsCount - 1; ++j) {
916 isLessThanBound[j] =
917 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
918 for (int i = 0; i <= polyDegree; ++i) {
919 p[i] = builder.create<arith::SelectOp>(isLessThanBound[j], p[i],
920 pp[j + 1][i]);
921 q[i] = builder.create<arith::SelectOp>(isLessThanBound[j], q[i],
922 qq[j + 1][i]);
923 }
924 offset = builder.create<arith::SelectOp>(isLessThanBound[j], offset,
925 offsets[j + 1]);
926 }
927 isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
928 arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
929
930 Value pPoly = makePolynomialCalculation(builder, coeffs: p, x);
931 Value qPoly = makePolynomialCalculation(builder, coeffs: q, x);
932 Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
933 Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
934 formula = builder.create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
935 formula, one);
936
937 // erf is odd function: erf(x) = -erf(-x).
938 Value negFormula = builder.create<arith::NegFOp>(formula);
939 Value res =
940 builder.create<arith::SelectOp>(isNegativeArg, negFormula, formula);
941
942 rewriter.replaceOp(op, res);
943
944 return success();
945}
946
947//----------------------------------------------------------------------------//
948// Exp approximation.
949//----------------------------------------------------------------------------//
950
951namespace {
952
953Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
954 Value value, float lowerBound, float upperBound) {
955 assert(!std::isnan(lowerBound));
956 assert(!std::isnan(upperBound));
957
958 auto bcast = [&](Value value) -> Value {
959 return broadcast(builder, value, shape);
960 };
961
962 auto selectCmp = [&builder](auto pred, Value value, Value bound) {
963 return builder.create<arith::SelectOp>(
964 builder.create<arith::CmpFOp>(pred, value, bound), value, bound);
965 };
966
967 // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs.
968 // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with
969 // arith::{Max,Min}FOp.
970 value = selectCmp(arith::CmpFPredicate::UGE, value,
971 bcast(f32Cst(builder, lowerBound)));
972 value = selectCmp(arith::CmpFPredicate::ULE, value,
973 bcast(f32Cst(builder, upperBound)));
974 return value;
975}
976
977struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
978public:
979 using OpRewritePattern::OpRewritePattern;
980
981 LogicalResult matchAndRewrite(math::ExpOp op,
982 PatternRewriter &rewriter) const final;
983};
984
985LogicalResult
986ExpApproximation::matchAndRewrite(math::ExpOp op,
987 PatternRewriter &rewriter) const {
988 auto shape = vectorShape(op.getOperand().getType());
989 auto elementTy = getElementTypeOrSelf(op.getType());
990 if (!elementTy.isF32())
991 return rewriter.notifyMatchFailure(op, "unsupported operand type");
992
993 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
994
995 auto add = [&](Value a, Value b) -> Value {
996 return builder.create<arith::AddFOp>(a, b);
997 };
998 auto bcast = [&](Value value) -> Value {
999 return broadcast(builder, value, shape);
1000 };
1001 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
1002 auto fmla = [&](Value a, Value b, Value c) {
1003 return builder.create<math::FmaOp>(a, b, c);
1004 };
1005 auto mul = [&](Value a, Value b) -> Value {
1006 return builder.create<arith::MulFOp>(a, b);
1007 };
1008
1009 // Polynomial approximation from Cephes.
1010 //
1011 // To compute e^x, we re-express it as
1012 //
1013 // e^x = e^(a + b)
1014 // = e^(a + n log(2))
1015 // = e^a * 2^n.
1016 //
1017 // We choose n = round(x / log(2)), restricting the value of `a` to
1018 // (-log(2)/2, log(2)/2). We then use a polynomial to compute e^a. The
1019 // relative error between our approximation and the true value of e^a is less
1020 // than 2^-22.5 for all values of `a` within this range.
1021
1022 // Restrict input to a small range, including some values that evaluate to
1023 // +/- inf. Note that for our lower bound, we choose log(2^-126) instead of
1024 // log(F32_EPSILON). We do so because this routine always flushes denormal
1025 // floating points to 0. Therefore, we only need to worry about exponentiating
1026 // up to the smallest representable non-denormal floating point, which is
1027 // 2^-126.
1028
1029 // Constants.
1030 Value cstHalf = bcast(f32Cst(builder, value: 0.5f));
1031 Value cstOne = bcast(f32Cst(builder, value: 1.0f));
1032
1033 // 1/log(2)
1034 Value cstLog2ef = bcast(f32Cst(builder, value: 1.44269504088896341f));
1035
1036 Value cstExpC1 = bcast(f32Cst(builder, value: -0.693359375f));
1037 Value cstExpC2 = bcast(f32Cst(builder, value: 2.12194440e-4f));
1038 Value cstExpP0 = bcast(f32Cst(builder, value: 1.9875691500E-4f));
1039 Value cstExpP1 = bcast(f32Cst(builder, value: 1.3981999507E-3f));
1040 Value cstExpP2 = bcast(f32Cst(builder, value: 8.3334519073E-3f));
1041 Value cstExpP3 = bcast(f32Cst(builder, value: 4.1665795894E-2f));
1042 Value cstExpP4 = bcast(f32Cst(builder, value: 1.6666665459E-1f));
1043 Value cstExpP5 = bcast(f32Cst(builder, value: 5.0000001201E-1f));
1044
1045 // Our computations below aren't particularly sensitive to the exact choices
1046 // here, so we choose values a bit larger/smaller than
1047 //
1048 // log(F32_MAX) = 88.723...
1049 // log(2^-126) = -87.337...
1050 Value x = op.getOperand();
1051 x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
1052 Value n = floor(fmla(x, cstLog2ef, cstHalf));
1053
1054 // When we eventually do the multiplication in e^a * 2^n, we need to handle
1055 // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
1056 // (so e^a * 2^n != inf). There's a similar problem for n < -126, the
1057 // smallest fp32 exponent.
1058 //
1059 // A straightforward solution would be to detect n out of range and split it
1060 // up, doing
1061 //
1062 // e^a * 2^n = e^a * 2^(n1 + n2)
1063 // = (2^n1 * e^a) * 2^n2.
1064 //
1065 // But it turns out this approach is quite slow, probably because it
1066 // manipulates subnormal values.
1067 //
1068 // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
1069 // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow
1070 // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though
1071 // this value of `a` is outside our previously specified range, e^a will still
1072 // only have a relative error of approximately 2^-16 at worse. In practice
1073 // this seems to work well enough; it passes our exhaustive tests, breaking
1074 // only one result, and by one ulp (we return exp(88.7228394) = max-float but
1075 // we should return inf).
1076 //
1077 // In the case where n' = -127, the original input value of x is so small that
1078 // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
1079 // normal floating point, and since we flush denormals, we simply return 0. We
1080 // do this in a branchless way by observing that our code for constructing 2^n
1081 // produces 0 if n = -127.
1082 //
1083 // The proof that n' = -127 implies e^x < 2^-126 is as follows:
1084 //
1085 // n' = -127 implies n <= -127
1086 // implies round(x / log(2)) <= -127
1087 // implies x/log(2) < -126.5
1088 // implies x < -126.5 * log(2)
1089 // implies e^x < e^(-126.5 * log(2))
1090 // implies e^x < 2^-126.5 < 2^-126
1091 //
1092 // This proves that n' = -127 implies e^x < 2^-126.
1093 n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1094
1095 // Computes x = x - n' * log(2), the value for `a`
1096 x = fmla(cstExpC1, n, x);
1097 x = fmla(cstExpC2, n, x);
1098
1099 // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
1100 Value z = fmla(x, cstExpP0, cstExpP1);
1101 z = fmla(z, x, cstExpP2);
1102 z = fmla(z, x, cstExpP3);
1103 z = fmla(z, x, cstExpP4);
1104 z = fmla(z, x, cstExpP5);
1105 z = fmla(z, mul(x, x), x);
1106 z = add(cstOne, z);
1107
1108 // Convert n' to an i32. This is safe because we clamped it above.
1109 auto i32Vec = broadcast(builder.getI32Type(), shape);
1110 Value nI32 = builder.create<arith::FPToSIOp>(i32Vec, n);
1111
1112 // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
1113 Value pow2 = exp2I32(builder, arg: nI32);
1114
1115 // Return z * 2^n' if -126 <= n' <= 127 and 0 if n = -127.
1116 Value ret = mul(z, pow2);
1117
1118 rewriter.replaceOp(op, ret);
1119 return mlir::success();
1120}
1121
1122} // namespace
1123
1124//----------------------------------------------------------------------------//
1125// ExpM1 approximation.
1126//----------------------------------------------------------------------------//
1127
1128namespace {
1129
1130struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
1131public:
1132 using OpRewritePattern::OpRewritePattern;
1133
1134 LogicalResult matchAndRewrite(math::ExpM1Op op,
1135 PatternRewriter &rewriter) const final;
1136};
1137} // namespace
1138
1139LogicalResult
1140ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1141 PatternRewriter &rewriter) const {
1142 if (!getElementTypeOrSelf(op.getOperand()).isF32())
1143 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1144
1145 VectorShape shape = vectorShape(op.getOperand());
1146
1147 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1148 auto bcast = [&](Value value) -> Value {
1149 return broadcast(builder, value, shape);
1150 };
1151
1152 // expm1(x) = exp(x) - 1 = u - 1.
1153 // We have to handle it carefully when x is near 0, i.e. u ~= 1,
1154 // and when the input is ~= -inf, i.e. u - 1 ~= -1.
1155 Value cstOne = bcast(f32Cst(builder, value: 1.0f));
1156 Value cstNegOne = bcast(f32Cst(builder, value: -1.0f));
1157 Value x = op.getOperand();
1158 Value u = builder.create<math::ExpOp>(x);
1159 Value uEqOneOrNaN =
1160 builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1161 Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
1162 Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
1163 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1164 // logU = log(u) ~= x
1165 Value logU = builder.create<math::LogOp>(u);
1166
1167 // Detect exp(x) = +inf; written this way to avoid having to form +inf.
1168 Value isInf =
1169 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
1170
1171 // (u - 1) * (x / ~x)
1172 Value expm1 = builder.create<arith::MulFOp>(
1173 uMinusOne, builder.create<arith::DivFOp>(x, logU));
1174 expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
1175 Value approximation = builder.create<arith::SelectOp>(
1176 uEqOneOrNaN, x,
1177 builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
1178 rewriter.replaceOp(op, approximation);
1179 return success();
1180}
1181
1182//----------------------------------------------------------------------------//
1183// Sin and Cos approximation.
1184//----------------------------------------------------------------------------//
1185
1186namespace {
1187
1188template <bool isSine, typename OpTy>
1189struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
1190public:
1191 using OpRewritePattern<OpTy>::OpRewritePattern;
1192
1193 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
1194};
1195} // namespace
1196
// 2/pi and pi/2 as long double literals. Users narrow them to the precision
// they need, e.g. `static_cast<float>(TWO_OVER_PI)`.
#define TWO_OVER_PI                                                            \
  0.6366197723675813430755350534900574481378385829618257949906693762L
#define PI_OVER_2                                                              \
  1.5707963267948966192313216916397514420985846996875529104874722961L
1201
1202// Approximates sin(x) or cos(x) by finding the best approximation polynomial in
1203// the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the
1204// reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y).
1205template <bool isSine, typename OpTy>
1206LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1207 OpTy op, PatternRewriter &rewriter) const {
1208 static_assert(
1209 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1210 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1211
1212 if (!getElementTypeOrSelf(op.getOperand()).isF32())
1213 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1214
1215 VectorShape shape = vectorShape(op.getOperand());
1216
1217 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1218 auto bcast = [&](Value value) -> Value {
1219 return broadcast(builder, value, shape);
1220 };
1221 auto mul = [&](Value a, Value b) -> Value {
1222 return builder.create<arith::MulFOp>(a, b);
1223 };
1224 auto sub = [&](Value a, Value b) -> Value {
1225 return builder.create<arith::SubFOp>(a, b);
1226 };
1227 auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
1228
1229 auto i32Vec = broadcast(builder.getI32Type(), shape);
1230 auto fPToSingedInteger = [&](Value a) -> Value {
1231 return builder.create<arith::FPToSIOp>(i32Vec, a);
1232 };
1233
1234 auto modulo4 = [&](Value a) -> Value {
1235 return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3)));
1236 };
1237
1238 auto isEqualTo = [&](Value a, Value b) -> Value {
1239 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
1240 };
1241
1242 auto isGreaterThan = [&](Value a, Value b) -> Value {
1243 return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
1244 };
1245
1246 auto select = [&](Value cond, Value t, Value f) -> Value {
1247 return builder.create<arith::SelectOp>(cond, t, f);
1248 };
1249
1250 auto fmla = [&](Value a, Value b, Value c) {
1251 return builder.create<math::FmaOp>(a, b, c);
1252 };
1253
1254 auto bitwiseOr = [&](Value a, Value b) {
1255 return builder.create<arith::OrIOp>(a, b);
1256 };
1257
1258 Value twoOverPi = bcast(f32Cst(builder, value: (float)TWO_OVER_PI));
1259 Value piOverTwo = bcast(f32Cst(builder, value: (float)PI_OVER_2));
1260
1261 Value x = op.getOperand();
1262
1263 Value k = floor(mul(x, twoOverPi));
1264
1265 Value y = sub(x, mul(k, piOverTwo));
1266
1267 Value cstOne = bcast(f32Cst(builder, value: 1.0));
1268 Value cstNegativeOne = bcast(f32Cst(builder, value: -1.0));
1269
1270 Value cstSC2 = bcast(f32Cst(builder, value: -0.16666667163372039794921875f));
1271 Value cstSC4 = bcast(f32Cst(builder, value: 8.333347737789154052734375e-3f));
1272 Value cstSC6 = bcast(f32Cst(builder, value: -1.9842604524455964565277099609375e-4f));
1273 Value cstSC8 =
1274 bcast(f32Cst(builder, value: 2.760012648650445044040679931640625e-6f));
1275 Value cstSC10 =
1276 bcast(f32Cst(builder, value: -2.50293279435709337121807038784027099609375e-8f));
1277
1278 Value cstCC2 = bcast(f32Cst(builder, value: -0.5f));
1279 Value cstCC4 = bcast(f32Cst(builder, value: 4.166664183139801025390625e-2f));
1280 Value cstCC6 = bcast(f32Cst(builder, value: -1.388833043165504932403564453125e-3f));
1281 Value cstCC8 = bcast(f32Cst(builder, value: 2.47562347794882953166961669921875e-5f));
1282 Value cstCC10 =
1283 bcast(f32Cst(builder, value: -2.59630184018533327616751194000244140625e-7f));
1284
1285 Value kMod4 = modulo4(fPToSingedInteger(k));
1286
1287 Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 0)));
1288 Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 1)));
1289 Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 2)));
1290 Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 3)));
1291
1292 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1293 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, value: 1)))
1294 : bitwiseOr(kR1, kR2);
1295
1296 Value y2 = mul(y, y);
1297
1298 Value base = select(sinuseCos, cstOne, y);
1299 Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1300 Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1301 Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1302 Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1303 Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1304
1305 Value v1 = fmla(y2, cstC10, cstC8);
1306 Value v2 = fmla(y2, v1, cstC6);
1307 Value v3 = fmla(y2, v2, cstC4);
1308 Value v4 = fmla(y2, v3, cstC2);
1309 Value v5 = fmla(y2, v4, cstOne);
1310 Value v6 = mul(base, v5);
1311
1312 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1313
1314 rewriter.replaceOp(op, approximation);
1315
1316 return success();
1317}
1318
1319//----------------------------------------------------------------------------//
1320// Cbrt approximation.
1321//----------------------------------------------------------------------------//
1322
1323namespace {
1324struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
1325 using OpRewritePattern::OpRewritePattern;
1326
1327 LogicalResult matchAndRewrite(math::CbrtOp op,
1328 PatternRewriter &rewriter) const final;
1329};
1330} // namespace
1331
1332// Estimation of cube-root using an algorithm defined in
1333// Hacker's Delight 2nd Edition.
1334LogicalResult
1335CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1336 PatternRewriter &rewriter) const {
1337 auto operand = op.getOperand();
1338 if (!getElementTypeOrSelf(operand).isF32())
1339 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1340
1341 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1342 VectorShape shape = vectorShape(operand);
1343
1344 Type floatTy = getElementTypeOrSelf(operand.getType());
1345 Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
1346
1347 // Convert to vector types if necessary.
1348 floatTy = broadcast(type: floatTy, shape);
1349 intTy = broadcast(type: intTy, shape);
1350
1351 auto bconst = [&](TypedAttr attr) -> Value {
1352 Value value = b.create<arith::ConstantOp>(attr);
1353 return broadcast(builder&: b, value, shape);
1354 };
1355
1356 // Declare the initial values:
1357 Value intTwo = bconst(b.getI32IntegerAttr(2));
1358 Value intFour = bconst(b.getI32IntegerAttr(4));
1359 Value intEight = bconst(b.getI32IntegerAttr(8));
1360 Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1361 Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1362 Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1363 Value fpZero = bconst(b.getF32FloatAttr(0.0f));
1364
1365 // Compute an approximation of one third:
1366 // union {int ix; float x;};
1367 // x = x0;
1368 // ix = ix/4 + ix/16;
1369 Value absValue = b.create<math::AbsFOp>(operand);
1370 Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
1371 Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
1372 Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1373 intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
1374
1375 // ix = ix + ix/16;
1376 divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1377 intValue = b.create<arith::AddIOp>(intValue, divideBy16);
1378
1379 // ix = ix + ix/256;
1380 Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
1381 intValue = b.create<arith::AddIOp>(intValue, divideBy256);
1382
1383 // ix = 0x2a5137a0 + ix;
1384 intValue = b.create<arith::AddIOp>(intValue, intMagic);
1385
1386 // Perform one newtons step:
1387 // x = 0.33333333f*(2.0f*x + x0/(x*x));
1388 Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
1389 Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
1390 Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1391 Value divSquared = b.create<arith::DivFOp>(absValue, squared);
1392 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1393 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1394
1395 // x = 0.33333333f*(2.0f*x + x0/(x*x));
1396 squared = b.create<arith::MulFOp>(floatValue, floatValue);
1397 mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1398 divSquared = b.create<arith::DivFOp>(absValue, squared);
1399 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1400 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1401
1402 // Check for zero and restore sign.
1403 Value isZero =
1404 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
1405 floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
1406 floatValue = b.create<math::CopySignOp>(floatValue, operand);
1407
1408 rewriter.replaceOp(op, floatValue);
1409 return success();
1410}
1411
1412//----------------------------------------------------------------------------//
1413// Rsqrt approximation.
1414//----------------------------------------------------------------------------//
1415
1416namespace {
1417struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
1418 using OpRewritePattern::OpRewritePattern;
1419
1420 LogicalResult matchAndRewrite(math::RsqrtOp op,
1421 PatternRewriter &rewriter) const final;
1422};
1423} // namespace
1424
1425LogicalResult
1426RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1427 PatternRewriter &rewriter) const {
1428 if (!getElementTypeOrSelf(op.getOperand()).isF32())
1429 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1430
1431 VectorShape shape = vectorShape(op.getOperand());
1432
1433 // Only support already-vectorized rsqrt's.
1434 if (shape.empty() || shape.sizes.back() % 8 != 0)
1435 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1436
1437 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1438 auto bcast = [&](Value value) -> Value {
1439 return broadcast(builder, value, shape);
1440 };
1441
1442 Value cstPosInf = bcast(f32FromBits(builder, bits: 0x7f800000u));
1443 Value cstOnePointFive = bcast(f32Cst(builder, value: 1.5f));
1444 Value cstNegHalf = bcast(f32Cst(builder, value: -0.5f));
1445 Value cstMinNormPos = bcast(f32FromBits(builder, bits: 0x00800000u));
1446
1447 Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf);
1448
1449 // Select only the inverse sqrt of positive normals (denormals are
1450 // flushed to zero).
1451 Value ltMinMask = builder.create<arith::CmpFOp>(
1452 arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
1453 Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1454 op.getOperand(), cstPosInf);
1455 Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
1456
1457 // Compute an approximate result.
1458 Value yApprox = handleMultidimensionalVectors(
1459 builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
1460 return builder.create<x86vector::RsqrtOp>(operands);
1461 });
1462
1463 // Do a single step of Newton-Raphson iteration to improve the approximation.
1464 // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
1465 // It is essential to evaluate the inner term like this because forming
1466 // y_n^2 may over- or underflow.
1467 Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
1468 Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1469 Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);
1470
1471 // Select the result of the Newton-Raphson step for positive normal arguments.
1472 // For other arguments, choose the output of the intrinsic. This will
1473 // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
1474 // x is zero or a positive denormalized float (equivalent to flushing positive
1475 // denormalized inputs to zero).
1476 Value res =
1477 builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
1478 rewriter.replaceOp(op, res);
1479
1480 return success();
1481}
1482
1483//----------------------------------------------------------------------------//
1484
1485void mlir::populatePolynomialApproximateTanhPattern(
1486 RewritePatternSet &patterns) {
1487 patterns.add<TanhApproximation>(arg: patterns.getContext());
1488}
1489
1490void mlir::populatePolynomialApproximateErfPattern(
1491 RewritePatternSet &patterns) {
1492 patterns.add<ErfPolynomialApproximation>(arg: patterns.getContext());
1493}
1494
1495void mlir::populateMathPolynomialApproximationPatterns(
1496 RewritePatternSet &patterns,
1497 const MathPolynomialApproximationOptions &options) {
1498 // Patterns for leveraging existing f32 lowerings on other data types.
1499 patterns
1500 .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1501 ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1502 ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1503 ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1504 ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1505 ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1506 patterns.getContext());
1507
1508 patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
1509 LogApproximation, Log2Approximation, Log1pApproximation,
1510 ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1511 CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1512 SinAndCosApproximation<false, math::CosOp>>(
1513 patterns.getContext());
1514 if (options.enableAvx2) {
1515 patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1516 patterns.getContext());
1517 }
1518}
1519

source code of mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp