1 | //===- PolynomialApproximation.cpp - Approximate math operations ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements expansion of math operations to fast approximations |
10 | // that do not rely on any of the library functions. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include <climits> |
15 | #include <cmath> |
16 | #include <cstddef> |
17 | |
18 | #include "mlir/Dialect/Arith/IR/Arith.h" |
19 | #include "mlir/Dialect/Math/IR/Math.h" |
20 | #include "mlir/Dialect/Math/Transforms/Approximation.h" |
21 | #include "mlir/Dialect/Math/Transforms/Passes.h" |
22 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
23 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
24 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
25 | #include "mlir/Dialect/X86Vector/X86VectorDialect.h" |
26 | #include "mlir/IR/Builders.h" |
27 | #include "mlir/IR/BuiltinTypes.h" |
28 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
29 | #include "mlir/IR/OpDefinition.h" |
30 | #include "mlir/IR/PatternMatch.h" |
31 | #include "mlir/IR/TypeUtilities.h" |
32 | #include "mlir/Transforms/DialectConversion.h" |
33 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
34 | #include "llvm/ADT/ArrayRef.h" |
35 | #include "llvm/ADT/STLExtras.h" |
36 | #include "llvm/Support/MathExtras.h" |
37 | |
38 | using namespace mlir; |
39 | using namespace mlir::math; |
40 | using namespace mlir::vector; |
41 | |
// Helper to encapsulate a vector's shape (including scalable dims).
// Holds non-owning views into a VectorType's shape/scalable-dims storage, so a
// VectorShape must not outlive the type it was created from.
struct VectorShape {
  // Static sizes of each vector dimension.
  ArrayRef<int64_t> sizes;
  // Per-dimension flag: true if that dimension is scalable.
  ArrayRef<bool> scalableFlags;

  // True when this shape describes a scalar (non-vector) value.
  bool empty() const { return sizes.empty(); }
};
49 | |
50 | // Returns vector shape if the type is a vector. Returns an empty shape if it is |
51 | // not a vector. |
52 | static VectorShape vectorShape(Type type) { |
53 | auto vectorType = dyn_cast<VectorType>(type); |
54 | return vectorType |
55 | ? VectorShape{vectorType.getShape(), vectorType.getScalableDims()} |
56 | : VectorShape{}; |
57 | } |
58 | |
59 | static VectorShape vectorShape(Value value) { |
60 | return vectorShape(type: value.getType()); |
61 | } |
62 | |
63 | //----------------------------------------------------------------------------// |
64 | // Broadcast scalar types and values into vector types and values. |
65 | //----------------------------------------------------------------------------// |
66 | |
67 | // Broadcasts scalar type into vector type (iff shape is non-scalar). |
68 | static Type broadcast(Type type, VectorShape shape) { |
69 | assert(!isa<VectorType>(type) && "must be scalar type" ); |
70 | return !shape.empty() |
71 | ? VectorType::get(shape.sizes, type, shape.scalableFlags) |
72 | : type; |
73 | } |
74 | |
75 | // Broadcasts scalar value into vector (iff shape is non-scalar). |
76 | static Value broadcast(ImplicitLocOpBuilder &builder, Value value, |
77 | VectorShape shape) { |
78 | assert(!isa<VectorType>(value.getType()) && "must be scalar value" ); |
79 | auto type = broadcast(type: value.getType(), shape); |
80 | return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value; |
81 | } |
82 | |
83 | //----------------------------------------------------------------------------// |
84 | // Helper function to handle n-D vectors with 1-D operations. |
85 | //----------------------------------------------------------------------------// |
86 | |
87 | // Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors |
88 | // and calls the compute function with 1-D vector operands. Stitches back all |
89 | // results into the original n-D vector result. |
90 | // |
91 | // Examples: vectorWidth = 8 |
92 | // - vector<4x8xf32> unrolled 4 times |
93 | // - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times |
94 | // - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times |
95 | // |
96 | // Some math approximations rely on ISA-specific operations that only accept |
97 | // fixed size 1-D vectors (e.g. AVX expects vectors of width 8). |
98 | // |
99 | // It is the caller's responsibility to verify that the inner dimension is |
100 | // divisible by the vectorWidth, and that all operands have the same vector |
101 | // shape. |
102 | static Value |
103 | handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, |
104 | ValueRange operands, int64_t vectorWidth, |
105 | llvm::function_ref<Value(ValueRange)> compute) { |
106 | assert(!operands.empty() && "operands must be not empty" ); |
107 | assert(vectorWidth > 0 && "vector width must be larger than 0" ); |
108 | |
109 | VectorType inputType = cast<VectorType>(operands[0].getType()); |
110 | ArrayRef<int64_t> inputShape = inputType.getShape(); |
111 | |
112 | // If input shape matches target vector width, we can just call the |
113 | // user-provided compute function with the operands. |
114 | if (inputShape == llvm::ArrayRef(vectorWidth)) |
115 | return compute(operands); |
116 | |
117 | // Check if the inner dimension has to be expanded, or we can directly iterate |
118 | // over the outer dimensions of the vector. |
119 | int64_t innerDim = inputShape.back(); |
120 | int64_t expansionDim = innerDim / vectorWidth; |
121 | assert((innerDim % vectorWidth == 0) && "invalid inner dimension size" ); |
122 | |
123 | // Maybe expand operands to the higher rank vector shape that we'll use to |
124 | // iterate over and extract one dimensional vectors. |
125 | SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end()); |
126 | SmallVector<Value> expandedOperands(operands); |
127 | |
128 | if (expansionDim > 1) { |
129 | // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth]. |
130 | expandedShape.insert(I: expandedShape.end() - 1, Elt: expansionDim); |
131 | expandedShape.back() = vectorWidth; |
132 | |
133 | for (unsigned i = 0; i < operands.size(); ++i) { |
134 | auto operand = operands[i]; |
135 | auto eltType = cast<VectorType>(operand.getType()).getElementType(); |
136 | auto expandedType = VectorType::get(expandedShape, eltType); |
137 | expandedOperands[i] = |
138 | builder.create<vector::ShapeCastOp>(expandedType, operand); |
139 | } |
140 | } |
141 | |
142 | // Iterate over all outer dimensions of the compute shape vector type. |
143 | auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back(); |
144 | int64_t maxIndex = computeMaxLinearIndex(iterationDims); |
145 | auto strides = computeStrides(iterationDims); |
146 | |
147 | // Compute results for each one dimensional vector. |
148 | SmallVector<Value> results(maxIndex); |
149 | |
150 | for (int64_t i = 0; i < maxIndex; ++i) { |
151 | auto offsets = delinearize(i, strides); |
152 | |
153 | SmallVector<Value> (expandedOperands.size()); |
154 | for (const auto &tuple : llvm::enumerate(expandedOperands)) |
155 | extracted[tuple.index()] = |
156 | builder.create<vector::ExtractOp>(tuple.value(), offsets); |
157 | |
158 | results[i] = compute(extracted); |
159 | } |
160 | |
161 | // Stitch results together into one large vector. |
162 | Type resultEltType = cast<VectorType>(results[0].getType()).getElementType(); |
163 | Type resultExpandedType = VectorType::get(expandedShape, resultEltType); |
164 | Value result = builder.create<arith::ConstantOp>( |
165 | resultExpandedType, builder.getZeroAttr(resultExpandedType)); |
166 | |
167 | for (int64_t i = 0; i < maxIndex; ++i) |
168 | result = builder.create<vector::InsertOp>(results[i], result, |
169 | delinearize(i, strides)); |
170 | |
171 | // Reshape back to the original vector shape. |
172 | return builder.create<vector::ShapeCastOp>( |
173 | VectorType::get(inputShape, resultEltType), result); |
174 | } |
175 | |
176 | //----------------------------------------------------------------------------// |
177 | // Helper functions to create constants. |
178 | //----------------------------------------------------------------------------// |
179 | |
180 | static Value floatCst(ImplicitLocOpBuilder &builder, float value, |
181 | Type elementType) { |
182 | assert((elementType.isF16() || elementType.isF32()) && |
183 | "x must be f16 or f32 type." ); |
184 | return builder.create<arith::ConstantOp>( |
185 | builder.getFloatAttr(elementType, value)); |
186 | } |
187 | |
188 | static Value f32Cst(ImplicitLocOpBuilder &builder, double value) { |
189 | return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value)); |
190 | } |
191 | |
192 | static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { |
193 | return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value)); |
194 | } |
195 | |
196 | static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { |
197 | Value i32Value = i32Cst(builder, value: static_cast<int32_t>(bits)); |
198 | return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value); |
199 | } |
200 | |
201 | //----------------------------------------------------------------------------// |
202 | // Helper functions to build math functions approximations. |
203 | //----------------------------------------------------------------------------// |
204 | |
205 | // Return the minimum of the two values or NaN if value is NaN |
206 | static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) { |
207 | return builder.create<arith::SelectOp>( |
208 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound), |
209 | value, bound); |
210 | } |
211 | |
212 | // Return the maximum of the two values or NaN if value is NaN |
213 | static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) { |
214 | return builder.create<arith::SelectOp>( |
215 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound), |
216 | value, bound); |
217 | } |
218 | |
219 | // Return the clamped value or NaN if value is NaN |
220 | static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, |
221 | Value upperBound) { |
222 | return max(builder, value: min(builder, value, bound: upperBound), bound: lowerBound); |
223 | } |
224 | |
225 | // Decomposes given floating point value `arg` into a normalized fraction and |
226 | // an integral power of two (see std::frexp). Returned values have float type. |
227 | static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg, |
228 | bool isPositive = false) { |
229 | assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type" ); |
230 | VectorShape shape = vectorShape(value: arg); |
231 | |
232 | auto bcast = [&](Value value) -> Value { |
233 | return broadcast(builder, value, shape); |
234 | }; |
235 | |
236 | auto i32 = builder.getIntegerType(32); |
237 | auto i32Vec = broadcast(i32, shape); |
238 | auto f32Vec = broadcast(type: builder.getF32Type(), shape); |
239 | |
240 | Value cst126f = f32Cst(builder, value: 126.0f); |
241 | Value cstHalf = f32Cst(builder, value: 0.5f); |
242 | Value cstInvMantMask = f32FromBits(builder, bits: ~0x7f800000u); |
243 | |
244 | // Bitcast to i32 for bitwise operations. |
245 | Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf); |
246 | Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask); |
247 | Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg); |
248 | |
249 | // Compute normalized fraction. |
250 | Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask)); |
251 | Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half)); |
252 | Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1); |
253 | |
254 | // Compute exponent. |
255 | Value arg0 = isPositive ? arg : builder.create<math::AbsFOp>(arg); |
256 | Value biasedExponentBits = builder.create<arith::ShRUIOp>( |
257 | builder.create<arith::BitcastOp>(i32Vec, arg0), |
258 | bcast(i32Cst(builder, 23))); |
259 | Value biasedExponent = |
260 | builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits); |
261 | Value exponent = |
262 | builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f)); |
263 | |
264 | return {normalizedFraction, exponent}; |
265 | } |
266 | |
267 | // Computes exp2 for an i32 argument. |
268 | static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) { |
269 | assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type" ); |
270 | VectorShape shape = vectorShape(value: arg); |
271 | |
272 | auto bcast = [&](Value value) -> Value { |
273 | return broadcast(builder, value, shape); |
274 | }; |
275 | |
276 | auto f32Vec = broadcast(type: builder.getF32Type(), shape); |
277 | // The exponent of f32 located at 23-bit. |
278 | auto exponetBitLocation = bcast(i32Cst(builder, value: 23)); |
279 | // Set the exponent bias to zero. |
280 | auto bias = bcast(i32Cst(builder, value: 127)); |
281 | |
282 | Value biasedArg = builder.create<arith::AddIOp>(arg, bias); |
283 | Value exp2ValueInt = |
284 | builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation); |
285 | Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt); |
286 | |
287 | return exp2ValueF32; |
288 | } |
289 | |
290 | namespace { |
291 | Value makePolynomialCalculation(ImplicitLocOpBuilder &builder, |
292 | llvm::ArrayRef<Value> coeffs, Value x) { |
293 | Type elementType = getElementTypeOrSelf(val: x); |
294 | assert((elementType.isF32() || elementType.isF16()) && |
295 | "x must be f32 or f16 type" ); |
296 | VectorShape shape = vectorShape(value: x); |
297 | |
298 | if (coeffs.empty()) |
299 | return broadcast(builder, value: floatCst(builder, value: 0.0f, elementType), shape); |
300 | |
301 | if (coeffs.size() == 1) |
302 | return coeffs[0]; |
303 | |
304 | Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1], |
305 | coeffs[coeffs.size() - 2]); |
306 | for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) { |
307 | res = builder.create<math::FmaOp>(x, res, coeffs[i]); |
308 | } |
309 | return res; |
310 | } |
311 | } // namespace |
312 | |
313 | //----------------------------------------------------------------------------// |
314 | // Helper function/pattern to insert casts for reusing F32 bit expansion. |
315 | //----------------------------------------------------------------------------// |
316 | |
317 | template <typename T> |
318 | LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) { |
319 | // Conservatively only allow where the operand and result types are exactly 1. |
320 | Type origType = op->getResultTypes().front(); |
321 | for (Type t : llvm::drop_begin(RangeOrContainer: op->getResultTypes())) |
322 | if (origType != t) |
323 | return rewriter.notifyMatchFailure(arg&: op, msg: "required all types to match" ); |
324 | for (Type t : op->getOperandTypes()) |
325 | if (origType != t) |
326 | return rewriter.notifyMatchFailure(arg&: op, msg: "required all types to match" ); |
327 | |
328 | // Skip if already F32 or larger than 32 bits. |
329 | if (getElementTypeOrSelf(type: origType).isF32() || |
330 | getElementTypeOrSelf(type: origType).getIntOrFloatBitWidth() > 32) |
331 | return failure(); |
332 | |
333 | // Create F32 equivalent type. |
334 | Type newType; |
335 | if (auto shaped = dyn_cast<ShapedType>(origType)) { |
336 | newType = shaped.clone(rewriter.getF32Type()); |
337 | } else if (isa<FloatType>(Val: origType)) { |
338 | newType = rewriter.getF32Type(); |
339 | } else { |
340 | return rewriter.notifyMatchFailure(arg&: op, |
341 | msg: "unable to find F32 equivalent type" ); |
342 | } |
343 | |
344 | Location loc = op->getLoc(); |
345 | SmallVector<Value> operands; |
346 | for (auto operand : op->getOperands()) |
347 | operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand)); |
348 | auto result = |
349 | rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs()); |
350 | rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result); |
351 | return success(); |
352 | } |
353 | |
namespace {
// Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
// op.
// TODO: Consider revising to avoid adding multiple casts for a subgraph that is
// all in lower precision. Currently this is only fallback support and performs
// simplistic casting.
template <typename T>
struct ReuseF32Expansion : public OpRewritePattern<T> {
public:
  using OpRewritePattern<T>::OpRewritePattern;
  // Delegates to insertCasts; the trait check guarantees the single-type
  // requirement insertCasts relies on.
  LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
    static_assert(
        T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
        "requires same operands and result types");
    return insertCasts<T>(op, rewriter);
  }
};
} // namespace
372 | |
373 | //----------------------------------------------------------------------------// |
374 | // AtanOp approximation. |
375 | //----------------------------------------------------------------------------// |
376 | |
namespace {
// Approximates math.atan with a rational (Padé-style) polynomial evaluation;
// see the out-of-line matchAndRewrite below for the algorithm.
struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::AtanOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
386 | |
387 | LogicalResult |
388 | AtanApproximation::matchAndRewrite(math::AtanOp op, |
389 | PatternRewriter &rewriter) const { |
390 | auto operand = op.getOperand(); |
391 | if (!getElementTypeOrSelf(operand).isF32()) |
392 | return rewriter.notifyMatchFailure(op, "unsupported operand type" ); |
393 | |
394 | VectorShape shape = vectorShape(op.getOperand()); |
395 | |
396 | ImplicitLocOpBuilder builder(op->getLoc(), rewriter); |
397 | Value abs = builder.create<math::AbsFOp>(operand); |
398 | |
399 | auto one = broadcast(builder, value: f32Cst(builder, value: 1.0), shape); |
400 | |
401 | // When 0.66 < x <= 2.41 we do (x-1) / (x+1): |
402 | auto twoThirds = broadcast(builder, value: f32Cst(builder, value: 0.66), shape); |
403 | Value cmp2 = |
404 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, twoThirds); |
405 | Value addone = builder.create<arith::AddFOp>(abs, one); |
406 | Value subone = builder.create<arith::SubFOp>(abs, one); |
407 | Value xnum = builder.create<arith::SelectOp>(cmp2, subone, abs); |
408 | Value xden = builder.create<arith::SelectOp>(cmp2, addone, one); |
409 | |
410 | auto bcast = [&](Value value) -> Value { |
411 | return broadcast(builder, value, shape); |
412 | }; |
413 | |
414 | // Break into the <= 0.66 or > 2.41 we do x or 1/x: |
415 | auto tan3pio8 = bcast(f32Cst(builder, value: 2.41421356237309504880)); |
416 | Value cmp1 = |
417 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, tan3pio8); |
418 | xnum = builder.create<arith::SelectOp>(cmp1, one, xnum); |
419 | xden = builder.create<arith::SelectOp>(cmp1, abs, xden); |
420 | |
421 | Value x = builder.create<arith::DivFOp>(xnum, xden); |
422 | Value xx = builder.create<arith::MulFOp>(x, x); |
423 | |
424 | // Perform the Taylor series approximation for atan over the range |
425 | // [0.0, 0.66]. |
426 | auto p0 = bcast(f32Cst(builder, value: -8.750608600031904122785e-01)); |
427 | auto p1 = bcast(f32Cst(builder, value: -1.615753718733365076637e+01)); |
428 | auto p2 = bcast(f32Cst(builder, value: -7.500855792314704667340e+01)); |
429 | auto p3 = bcast(f32Cst(builder, value: -1.228866684490136173410e+02)); |
430 | auto p4 = bcast(f32Cst(builder, value: -6.485021904942025371773e+01)); |
431 | auto q0 = bcast(f32Cst(builder, value: +2.485846490142306297962e+01)); |
432 | auto q1 = bcast(f32Cst(builder, value: +1.650270098316988542046e+02)); |
433 | auto q2 = bcast(f32Cst(builder, value: +4.328810604912902668951e+02)); |
434 | auto q3 = bcast(f32Cst(builder, value: +4.853903996359136964868e+02)); |
435 | auto q4 = bcast(f32Cst(builder, value: +1.945506571482613964425e+02)); |
436 | |
437 | // Apply the polynomial approximation for the numerator: |
438 | Value n = p0; |
439 | n = builder.create<math::FmaOp>(xx, n, p1); |
440 | n = builder.create<math::FmaOp>(xx, n, p2); |
441 | n = builder.create<math::FmaOp>(xx, n, p3); |
442 | n = builder.create<math::FmaOp>(xx, n, p4); |
443 | n = builder.create<arith::MulFOp>(n, xx); |
444 | |
445 | // Apply the polynomial approximation for the denominator: |
446 | Value d = q0; |
447 | d = builder.create<math::FmaOp>(xx, d, q1); |
448 | d = builder.create<math::FmaOp>(xx, d, q2); |
449 | d = builder.create<math::FmaOp>(xx, d, q3); |
450 | d = builder.create<math::FmaOp>(xx, d, q4); |
451 | |
452 | // Compute approximation of theta: |
453 | Value ans0 = builder.create<arith::DivFOp>(n, d); |
454 | ans0 = builder.create<math::FmaOp>(ans0, x, x); |
455 | |
456 | // Correct for the input mapping's angles: |
457 | Value mpi4 = bcast(f32Cst(builder, value: llvm::numbers::pi / 4)); |
458 | Value ans2 = builder.create<arith::AddFOp>(mpi4, ans0); |
459 | Value ans = builder.create<arith::SelectOp>(cmp2, ans2, ans0); |
460 | |
461 | Value mpi2 = bcast(f32Cst(builder, value: llvm::numbers::pi / 2)); |
462 | Value ans1 = builder.create<arith::SubFOp>(mpi2, ans0); |
463 | ans = builder.create<arith::SelectOp>(cmp1, ans1, ans); |
464 | |
465 | // Correct for signing of the input. |
466 | rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand); |
467 | return success(); |
468 | } |
469 | |
//----------------------------------------------------------------------------//
// Atan2Op approximation.
//----------------------------------------------------------------------------//
473 | |
namespace {
// Approximates math.atan2 in terms of the AtanApproximation-compatible
// math.atan, with explicit handling of quadrant and x == 0 special cases.
struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Atan2Op op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
483 | |
484 | LogicalResult |
485 | Atan2Approximation::matchAndRewrite(math::Atan2Op op, |
486 | PatternRewriter &rewriter) const { |
487 | auto y = op.getOperand(0); |
488 | auto x = op.getOperand(1); |
489 | if (!getElementTypeOrSelf(x).isF32()) |
490 | return rewriter.notifyMatchFailure(op, "unsupported operand type" ); |
491 | |
492 | ImplicitLocOpBuilder builder(op->getLoc(), rewriter); |
493 | VectorShape shape = vectorShape(op.getResult()); |
494 | |
495 | // Compute atan in the valid range. |
496 | auto div = builder.create<arith::DivFOp>(y, x); |
497 | auto atan = builder.create<math::AtanOp>(div); |
498 | |
499 | // Determine what the atan would be for a 180 degree rotation. |
500 | auto zero = broadcast(builder, value: f32Cst(builder, value: 0.0f), shape); |
501 | auto pi = broadcast(builder, value: f32Cst(builder, value: 3.14159265359f), shape); |
502 | auto addPi = builder.create<arith::AddFOp>(atan, pi); |
503 | auto subPi = builder.create<arith::SubFOp>(atan, pi); |
504 | auto atanGt = |
505 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero); |
506 | auto flippedAtan = builder.create<arith::SelectOp>(atanGt, subPi, addPi); |
507 | |
508 | // Determine whether to directly use atan or use the 180 degree flip |
509 | auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero); |
510 | Value result = builder.create<arith::SelectOp>(xGt, atan, flippedAtan); |
511 | |
512 | // Handle x = 0, y > 0 |
513 | Value xZero = |
514 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero); |
515 | Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero); |
516 | Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt); |
517 | auto halfPi = broadcast(builder, value: f32Cst(builder, value: 1.57079632679f), shape); |
518 | result = builder.create<arith::SelectOp>(isHalfPi, halfPi, result); |
519 | |
520 | // Handle x = 0, y < 0 |
521 | Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero); |
522 | Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt); |
523 | auto negativeHalfPiPi = |
524 | broadcast(builder, value: f32Cst(builder, value: -1.57079632679f), shape); |
525 | result = builder.create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi, |
526 | result); |
527 | |
528 | // Handle x = 0, y = 0; |
529 | Value yZero = |
530 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero); |
531 | Value isNan = builder.create<arith::AndIOp>(xZero, yZero); |
532 | Value cstNan = broadcast(builder, value: f32FromBits(builder, bits: 0x7fc00000), shape); |
533 | result = builder.create<arith::SelectOp>(isNan, cstNan, result); |
534 | |
535 | rewriter.replaceOp(op, result); |
536 | return success(); |
537 | } |
538 | |
539 | //----------------------------------------------------------------------------// |
540 | // TanhOp approximation. |
541 | //----------------------------------------------------------------------------// |
542 | |
namespace {
// Approximates math.tanh with a clamped rational polynomial (odd numerator /
// even denominator); see the out-of-line matchAndRewrite for details.
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
552 | |
553 | LogicalResult |
554 | TanhApproximation::matchAndRewrite(math::TanhOp op, |
555 | PatternRewriter &rewriter) const { |
556 | if (!getElementTypeOrSelf(op.getOperand()).isF32()) |
557 | return rewriter.notifyMatchFailure(op, "unsupported operand type" ); |
558 | |
559 | VectorShape shape = vectorShape(op.getOperand()); |
560 | |
561 | ImplicitLocOpBuilder builder(op->getLoc(), rewriter); |
562 | auto bcast = [&](Value value) -> Value { |
563 | return broadcast(builder, value, shape); |
564 | }; |
565 | |
566 | // Clamp operand into [plusClamp, minusClamp] range. |
567 | Value minusClamp = bcast(f32Cst(builder, value: -7.99881172180175781f)); |
568 | Value plusClamp = bcast(f32Cst(builder, value: 7.99881172180175781f)); |
569 | Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp); |
570 | |
571 | // Mask for tiny values that are approximated with `operand`. |
572 | Value tiny = bcast(f32Cst(builder, value: 0.0004f)); |
573 | Value tinyMask = builder.create<arith::CmpFOp>( |
574 | arith::CmpFPredicate::OLT, builder.create<math::AbsFOp>(op.getOperand()), |
575 | tiny); |
576 | |
577 | // The monomial coefficients of the numerator polynomial (odd). |
578 | Value alpha1 = bcast(f32Cst(builder, value: 4.89352455891786e-03f)); |
579 | Value alpha3 = bcast(f32Cst(builder, value: 6.37261928875436e-04f)); |
580 | Value alpha5 = bcast(f32Cst(builder, value: 1.48572235717979e-05f)); |
581 | Value alpha7 = bcast(f32Cst(builder, value: 5.12229709037114e-08f)); |
582 | Value alpha9 = bcast(f32Cst(builder, value: -8.60467152213735e-11f)); |
583 | Value alpha11 = bcast(f32Cst(builder, value: 2.00018790482477e-13f)); |
584 | Value alpha13 = bcast(f32Cst(builder, value: -2.76076847742355e-16f)); |
585 | |
586 | // The monomial coefficients of the denominator polynomial (even). |
587 | Value beta0 = bcast(f32Cst(builder, value: 4.89352518554385e-03f)); |
588 | Value beta2 = bcast(f32Cst(builder, value: 2.26843463243900e-03f)); |
589 | Value beta4 = bcast(f32Cst(builder, value: 1.18534705686654e-04f)); |
590 | Value beta6 = bcast(f32Cst(builder, value: 1.19825839466702e-06f)); |
591 | |
592 | // Since the polynomials are odd/even, we need x^2. |
593 | Value x2 = builder.create<arith::MulFOp>(x, x); |
594 | |
595 | // Evaluate the numerator polynomial p. |
596 | Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11); |
597 | p = builder.create<math::FmaOp>(x2, p, alpha9); |
598 | p = builder.create<math::FmaOp>(x2, p, alpha7); |
599 | p = builder.create<math::FmaOp>(x2, p, alpha5); |
600 | p = builder.create<math::FmaOp>(x2, p, alpha3); |
601 | p = builder.create<math::FmaOp>(x2, p, alpha1); |
602 | p = builder.create<arith::MulFOp>(x, p); |
603 | |
604 | // Evaluate the denominator polynomial q. |
605 | Value q = builder.create<math::FmaOp>(x2, beta6, beta4); |
606 | q = builder.create<math::FmaOp>(x2, q, beta2); |
607 | q = builder.create<math::FmaOp>(x2, q, beta0); |
608 | |
609 | // Divide the numerator by the denominator. |
610 | Value res = builder.create<arith::SelectOp>( |
611 | tinyMask, x, builder.create<arith::DivFOp>(p, q)); |
612 | |
613 | rewriter.replaceOp(op, res); |
614 | |
615 | return success(); |
616 | } |
617 | |
618 | #define LN2_VALUE \ |
619 | 0.693147180559945309417232121458176568075500134360255254120680009493393621L |
620 | #define LOG2E_VALUE \ |
621 | 1.442695040888963407359924681001892137426645954152985934135449406931109219L |
622 | |
623 | //----------------------------------------------------------------------------// |
624 | // LogOp and Log2Op approximation. |
625 | //----------------------------------------------------------------------------// |
626 | |
namespace {
// Shared base for LogOp/Log2Op approximation patterns: the derived pattern
// selects the logarithm base via the `base2` flag.
template <typename Op>
struct LogApproximationBase : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                   bool base2) const;
};
} // namespace
637 | |
638 | // This approximation comes from Julien Pommier's SSE math library. |
639 | // Link: http://gruntthepeon.free.fr/ssemath |
640 | template <typename Op> |
641 | LogicalResult |
642 | LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter, |
643 | bool base2) const { |
644 | if (!getElementTypeOrSelf(op.getOperand()).isF32()) |
645 | return rewriter.notifyMatchFailure(op, "unsupported operand type" ); |
646 | |
647 | VectorShape shape = vectorShape(op.getOperand()); |
648 | |
649 | ImplicitLocOpBuilder builder(op->getLoc(), rewriter); |
650 | auto bcast = [&](Value value) -> Value { |
651 | return broadcast(builder, value, shape); |
652 | }; |
653 | |
654 | Value cstZero = bcast(f32Cst(builder, value: 0.0f)); |
655 | Value cstOne = bcast(f32Cst(builder, value: 1.0f)); |
656 | Value cstNegHalf = bcast(f32Cst(builder, value: -0.5f)); |
657 | |
658 | // The smallest non denormalized float number. |
659 | Value cstMinNormPos = bcast(f32FromBits(builder, bits: 0x00800000u)); |
660 | Value cstMinusInf = bcast(f32FromBits(builder, bits: 0xff800000u)); |
661 | Value cstPosInf = bcast(f32FromBits(builder, bits: 0x7f800000u)); |
662 | Value cstNan = bcast(f32FromBits(builder, bits: 0x7fc00000)); |
663 | |
664 | // Polynomial coefficients. |
665 | Value cstCephesSQRTHF = bcast(f32Cst(builder, value: 0.707106781186547524f)); |
666 | Value cstCephesLogP0 = bcast(f32Cst(builder, value: 7.0376836292E-2f)); |
667 | Value cstCephesLogP1 = bcast(f32Cst(builder, value: -1.1514610310E-1f)); |
668 | Value cstCephesLogP2 = bcast(f32Cst(builder, value: 1.1676998740E-1f)); |
669 | Value cstCephesLogP3 = bcast(f32Cst(builder, value: -1.2420140846E-1f)); |
670 | Value cstCephesLogP4 = bcast(f32Cst(builder, value: +1.4249322787E-1f)); |
671 | Value cstCephesLogP5 = bcast(f32Cst(builder, value: -1.6668057665E-1f)); |
672 | Value cstCephesLogP6 = bcast(f32Cst(builder, value: +2.0000714765E-1f)); |
673 | Value cstCephesLogP7 = bcast(f32Cst(builder, value: -2.4999993993E-1f)); |
674 | Value cstCephesLogP8 = bcast(f32Cst(builder, value: +3.3333331174E-1f)); |
675 | |
676 | Value x = op.getOperand(); |
677 | |
678 | // Truncate input values to the minimum positive normal. |
679 | x = max(builder, value: x, bound: cstMinNormPos); |
680 | |
681 | // Extract significant in the range [0.5,1) and exponent. |
682 | std::pair<Value, Value> pair = frexp(builder, arg: x, /*isPositive=*/true); |
683 | x = pair.first; |
684 | Value e = pair.second; |
685 | |
686 | // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift |
687 | // by -1.0. The values are then centered around 0, which improves the |
688 | // stability of the polynomial evaluation: |
689 | // |
690 | // if( x < SQRTHF ) { |
691 | // e -= 1; |
692 | // x = x + x - 1.0; |
693 | // } else { x = x - 1.0; } |
694 | Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, |
695 | cstCephesSQRTHF); |
696 | Value tmp = builder.create<arith::SelectOp>(mask, x, cstZero); |
697 | |
698 | x = builder.create<arith::SubFOp>(x, cstOne); |
699 | e = builder.create<arith::SubFOp>( |
700 | e, builder.create<arith::SelectOp>(mask, cstOne, cstZero)); |
701 | x = builder.create<arith::AddFOp>(x, tmp); |
702 | |
703 | Value x2 = builder.create<arith::MulFOp>(x, x); |
704 | Value x3 = builder.create<arith::MulFOp>(x2, x); |
705 | |
706 | // Evaluate the polynomial approximant of degree 8 in three parts. |
707 | Value y0, y1, y2; |
708 | y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1); |
709 | y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4); |
710 | y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7); |
711 | y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2); |
712 | y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5); |
713 | y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8); |
714 | y0 = builder.create<math::FmaOp>(y0, x3, y1); |
715 | y0 = builder.create<math::FmaOp>(y0, x3, y2); |
716 | y0 = builder.create<arith::MulFOp>(y0, x3); |
717 | |
718 | y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0); |
719 | x = builder.create<arith::AddFOp>(x, y0); |
720 | |
721 | if (base2) { |
722 | Value cstLog2e = bcast(f32Cst(builder, value: static_cast<float>(LOG2E_VALUE))); |
723 | x = builder.create<math::FmaOp>(x, cstLog2e, e); |
724 | } else { |
725 | Value cstLn2 = bcast(f32Cst(builder, value: static_cast<float>(LN2_VALUE))); |
726 | x = builder.create<math::FmaOp>(e, cstLn2, x); |
727 | } |
728 | |
729 | Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, |
730 | op.getOperand(), cstZero); |
731 | Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, |
732 | op.getOperand(), cstZero); |
733 | Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, |
734 | op.getOperand(), cstPosInf); |
735 | |
736 | // Filter out invalid values: |
737 | // • x == 0 -> -INF |
738 | // • x < 0 -> NAN |
739 | // • x == +INF -> +INF |
740 | Value aproximation = builder.create<arith::SelectOp>( |
741 | zeroMask, cstMinusInf, |
742 | builder.create<arith::SelectOp>( |
743 | invalidMask, cstNan, |
744 | builder.create<arith::SelectOp>(posInfMask, cstPosInf, x))); |
745 | |
746 | rewriter.replaceOp(op, aproximation); |
747 | |
748 | return success(); |
749 | } |
750 | |
namespace {
// Rewrites math.log (natural logarithm) with a vectorized Cephes-style
// polynomial approximation. The shared lowering lives in
// LogApproximationBase::logMatchAndRewrite; this pattern selects the
// natural-log flavor via base2=false.
struct LogApproximation : public LogApproximationBase<math::LogOp> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::LogOp op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/false);
  }
};
} // namespace
761 | |
namespace {
// Rewrites math.log2 using the same shared lowering as LogApproximation;
// base2=true makes logMatchAndRewrite rescale the natural-log result by
// log2(e).
struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
  using LogApproximationBase::LogApproximationBase;

  LogicalResult matchAndRewrite(math::Log2Op op,
                                PatternRewriter &rewriter) const final {
    return logMatchAndRewrite(op, rewriter, /*base2=*/true);
  }
};
} // namespace
772 | |
773 | //----------------------------------------------------------------------------// |
774 | // Log1p approximation. |
775 | //----------------------------------------------------------------------------// |
776 | |
namespace {
// Approximates math.log1p, i.e. log(1 + x), for f32 element types. Uses
// W. Kahan's correction trick (see the out-of-line matchAndRewrite) to stay
// accurate when x is close to zero.
struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::Log1pOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
786 | |
787 | // Approximate log(1+x). |
788 | LogicalResult |
789 | Log1pApproximation::matchAndRewrite(math::Log1pOp op, |
790 | PatternRewriter &rewriter) const { |
791 | if (!getElementTypeOrSelf(op.getOperand()).isF32()) |
792 | return rewriter.notifyMatchFailure(op, "unsupported operand type" ); |
793 | |
794 | VectorShape shape = vectorShape(op.getOperand()); |
795 | |
796 | ImplicitLocOpBuilder builder(op->getLoc(), rewriter); |
797 | auto bcast = [&](Value value) -> Value { |
798 | return broadcast(builder, value, shape); |
799 | }; |
800 | |
801 | // Approximate log(1+x) using the following, due to W. Kahan: |
802 | // u = x + 1.0; |
803 | // if (u == 1.0 || u == inf) return x; |
804 | // return x * log(u) / (u - 1.0); |
805 | // ^^^^^^^^^^^^^^^^^^^^^^ |
806 | // "logLarge" below. |
807 | Value cstOne = bcast(f32Cst(builder, value: 1.0f)); |
808 | Value x = op.getOperand(); |
809 | Value u = builder.create<arith::AddFOp>(x, cstOne); |
810 | Value uSmall = |
811 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne); |
812 | Value logU = builder.create<math::LogOp>(u); |
813 | Value uInf = |
814 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU); |
815 | Value logLarge = builder.create<arith::MulFOp>( |
816 | x, builder.create<arith::DivFOp>( |
817 | logU, builder.create<arith::SubFOp>(u, cstOne))); |
818 | Value approximation = builder.create<arith::SelectOp>( |
819 | builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge); |
820 | rewriter.replaceOp(op, approximation); |
821 | return success(); |
822 | } |
823 | |
824 | //----------------------------------------------------------------------------// |
825 | // Erf approximation. |
826 | //----------------------------------------------------------------------------// |
827 | |
828 | // Approximates erf(x) with |
829 | // a - P(x)/Q(x) |
830 | // where P and Q are polynomials of degree 4. |
831 | // Different coefficients are chosen based on the value of x. |
832 | // The approximation error is ~2.5e-07. |
833 | // Boost's minimax tool that utilizes the Remez method was used to find the |
834 | // coefficients. |
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                            PatternRewriter &rewriter) const {
  Value operand = op.getOperand();
  Type elementType = getElementTypeOrSelf(val: operand);

  if (!(elementType.isF32() || elementType.isF16()))
    return rewriter.notifyMatchFailure(op,
                                       "only f32 and f16 type is supported." );
  VectorShape shape = vectorShape(value: operand);

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Broadcasts scalar constants to the operand's vector shape (no-op for
  // scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // |x| is split into three intervals (upper bounds in `bounds` below); on
  // each interval erf(|x|) is approximated by offset + P(|x|)/Q(|x|) where
  // P and Q are degree-4 polynomials.
  const int intervalsCount = 3;
  const int polyDegree = 4;

  Value zero = bcast(floatCst(builder, value: 0, elementType));
  Value one = bcast(floatCst(builder, value: 1, elementType));
  // Numerator polynomial coefficients, one row per interval:
  // pp[interval][power of x].
  Value pp[intervalsCount][polyDegree + 1];
  pp[0][0] = bcast(floatCst(builder, value: +0.00000000000000000e+00f, elementType));
  pp[0][1] = bcast(floatCst(builder, value: +1.12837916222975858e+00f, elementType));
  pp[0][2] = bcast(floatCst(builder, value: -5.23018562988006470e-01f, elementType));
  pp[0][3] = bcast(floatCst(builder, value: +2.09741709609267072e-01f, elementType));
  pp[0][4] = bcast(floatCst(builder, value: +2.58146801602987875e-02f, elementType));
  pp[1][0] = bcast(floatCst(builder, value: +0.00000000000000000e+00f, elementType));
  pp[1][1] = bcast(floatCst(builder, value: +1.12750687816789140e+00f, elementType));
  pp[1][2] = bcast(floatCst(builder, value: -3.64721408487825775e-01f, elementType));
  pp[1][3] = bcast(floatCst(builder, value: +1.18407396425136952e-01f, elementType));
  pp[1][4] = bcast(floatCst(builder, value: +3.70645533056476558e-02f, elementType));
  pp[2][0] = bcast(floatCst(builder, value: -3.30093071049483172e-03f, elementType));
  pp[2][1] = bcast(floatCst(builder, value: +3.51961938357697011e-03f, elementType));
  pp[2][2] = bcast(floatCst(builder, value: -1.41373622814988039e-03f, elementType));
  pp[2][3] = bcast(floatCst(builder, value: +2.53447094961941348e-04f, elementType));
  pp[2][4] = bcast(floatCst(builder, value: -1.71048029455037401e-05f, elementType));

  // Denominator polynomial coefficients, one row per interval.
  Value qq[intervalsCount][polyDegree + 1];
  qq[0][0] = bcast(floatCst(builder, value: +1.000000000000000000e+00f, elementType));
  qq[0][1] = bcast(floatCst(builder, value: -4.635138185962547255e-01f, elementType));
  qq[0][2] = bcast(floatCst(builder, value: +5.192301327279782447e-01f, elementType));
  qq[0][3] = bcast(floatCst(builder, value: -1.318089722204810087e-01f, elementType));
  qq[0][4] = bcast(floatCst(builder, value: +7.397964654672315005e-02f, elementType));
  qq[1][0] = bcast(floatCst(builder, value: +1.00000000000000000e+00f, elementType));
  qq[1][1] = bcast(floatCst(builder, value: -3.27607011824493086e-01f, elementType));
  qq[1][2] = bcast(floatCst(builder, value: +4.48369090658821977e-01f, elementType));
  qq[1][3] = bcast(floatCst(builder, value: -8.83462621207857930e-02f, elementType));
  qq[1][4] = bcast(floatCst(builder, value: +5.72442770283176093e-02f, elementType));
  qq[2][0] = bcast(floatCst(builder, value: +1.00000000000000000e+00f, elementType));
  qq[2][1] = bcast(floatCst(builder, value: -2.06069165953913769e+00f, elementType));
  qq[2][2] = bcast(floatCst(builder, value: +1.62705939945477759e+00f, elementType));
  qq[2][3] = bcast(floatCst(builder, value: -5.83389859211130017e-01f, elementType));
  qq[2][4] = bcast(floatCst(builder, value: +8.21908939856640930e-02f, elementType));

  // Additive offset used on each interval.
  Value offsets[intervalsCount];
  offsets[0] = bcast(floatCst(builder, value: 0.0f, elementType));
  offsets[1] = bcast(floatCst(builder, value: 0.0f, elementType));
  offsets[2] = bcast(floatCst(builder, value: 1.0f, elementType));

  // Upper bound of each interval; |x| >= 3.75 saturates to 1.
  Value bounds[intervalsCount];
  bounds[0] = bcast(floatCst(builder, value: 0.8f, elementType));
  bounds[1] = bcast(floatCst(builder, value: 2.0f, elementType));
  bounds[2] = bcast(floatCst(builder, value: 3.75f, elementType));

  // erf is odd: evaluate the approximation on |x| and restore the sign at the
  // very end.
  Value isNegativeArg =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
  Value negArg = builder.create<arith::NegFOp>(operand);
  Value x = builder.create<arith::SelectOp>(isNegativeArg, negArg, operand);

  // Start from interval 0's data; a cascade of selects below overrides it for
  // the larger intervals.
  Value offset = offsets[0];
  Value p[polyDegree + 1];
  Value q[polyDegree + 1];
  for (int i = 0; i <= polyDegree; ++i) {
    p[i] = pp[0][i];
    q[i] = qq[0][i];
  }

  // TODO: maybe use vector stacking to reduce the number of selects.
  Value isLessThanBound[intervalsCount];
  for (int j = 0; j < intervalsCount - 1; ++j) {
    isLessThanBound[j] =
        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
    for (int i = 0; i <= polyDegree; ++i) {
      p[i] = builder.create<arith::SelectOp>(isLessThanBound[j], p[i],
                                             pp[j + 1][i]);
      q[i] = builder.create<arith::SelectOp>(isLessThanBound[j], q[i],
                                             qq[j + 1][i]);
    }
    offset = builder.create<arith::SelectOp>(isLessThanBound[j], offset,
                                             offsets[j + 1]);
  }
  // ULT (unordered) so that a NaN input takes the formula path and propagates
  // instead of saturating to 1.
  isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);

  // Evaluate the rational approximant offset + P(x)/Q(x); saturate to 1 past
  // the last bound.
  Value pPoly = makePolynomialCalculation(builder, coeffs: p, x);
  Value qPoly = makePolynomialCalculation(builder, coeffs: q, x);
  Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
  Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
  formula = builder.create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
                                            formula, one);

  // erf is odd function: erf(x) = -erf(-x).
  Value negFormula = builder.create<arith::NegFOp>(formula);
  Value res =
      builder.create<arith::SelectOp>(isNegativeArg, negFormula, formula);

  rewriter.replaceOp(op, res);

  return success();
}
946 | |
947 | //----------------------------------------------------------------------------// |
948 | // Exp approximation. |
949 | //----------------------------------------------------------------------------// |
950 | |
951 | namespace { |
952 | |
953 | Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape, |
954 | Value value, float lowerBound, float upperBound) { |
955 | assert(!std::isnan(lowerBound)); |
956 | assert(!std::isnan(upperBound)); |
957 | |
958 | auto bcast = [&](Value value) -> Value { |
959 | return broadcast(builder, value, shape); |
960 | }; |
961 | |
962 | auto selectCmp = [&builder](auto pred, Value value, Value bound) { |
963 | return builder.create<arith::SelectOp>( |
964 | builder.create<arith::CmpFOp>(pred, value, bound), value, bound); |
965 | }; |
966 | |
967 | // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs. |
968 | // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with |
969 | // arith::{Max,Min}FOp. |
970 | value = selectCmp(arith::CmpFPredicate::UGE, value, |
971 | bcast(f32Cst(builder, lowerBound))); |
972 | value = selectCmp(arith::CmpFPredicate::ULE, value, |
973 | bcast(f32Cst(builder, upperBound))); |
974 | return value; |
975 | } |
976 | |
// Approximates math.exp for f32 element types with a Cephes-style range
// reduction plus polynomial; see the out-of-line matchAndRewrite below.
struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpOp op,
                                PatternRewriter &rewriter) const final;
};
984 | |
LogicalResult
ExpApproximation::matchAndRewrite(math::ExpOp op,
                                  PatternRewriter &rewriter) const {
  auto shape = vectorShape(op.getOperand().getType());
  auto elementTy = getElementTypeOrSelf(op.getType());
  if (!elementTy.isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type" );

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

  // Short-hands for the arith/math ops emitted below.
  auto add = [&](Value a, Value b) -> Value {
    return builder.create<arith::AddFOp>(a, b);
  };
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };

  // Polynomial approximation from Cephes.
  //
  // To compute e^x, we re-express it as
  //
  //   e^x = e^(a + b)
  //       = e^(a + n log(2))
  //       = e^a * 2^n.
  //
  // We choose n = round(x / log(2)), restricting the value of `a` to
  // (-log(2)/2, log(2)/2).  We then use a polynomial to compute e^a. The
  // relative error between our approximation and the true value of e^a is less
  // than 2^-22.5 for all values of `a` within this range.

  // Restrict input to a small range, including some values that evaluate to
  // +/- inf.  Note that for our lower bound, we choose log(2^-126) instead of
  // log(F32_EPSILON). We do so because this routine always flushes denormal
  // floating points to 0. Therefore, we only need to worry about exponentiating
  // up to the smallest representable non-denormal floating point, which is
  // 2^-126.

  // Constants.
  Value cstHalf = bcast(f32Cst(builder, value: 0.5f));
  Value cstOne = bcast(f32Cst(builder, value: 1.0f));

  // 1/log(2)
  Value cstLog2ef = bcast(f32Cst(builder, value: 1.44269504088896341f));

  // Two-term split of -log(2): cstExpC1 + cstExpC2 ~= -log(2), which keeps
  // extra precision when computing a = x - n * log(2) further down.
  Value cstExpC1 = bcast(f32Cst(builder, value: -0.693359375f));
  Value cstExpC2 = bcast(f32Cst(builder, value: 2.12194440e-4f));
  // Degree-5 polynomial coefficients for e^a on (-0.5, 0.5).
  Value cstExpP0 = bcast(f32Cst(builder, value: 1.9875691500E-4f));
  Value cstExpP1 = bcast(f32Cst(builder, value: 1.3981999507E-3f));
  Value cstExpP2 = bcast(f32Cst(builder, value: 8.3334519073E-3f));
  Value cstExpP3 = bcast(f32Cst(builder, value: 4.1665795894E-2f));
  Value cstExpP4 = bcast(f32Cst(builder, value: 1.6666665459E-1f));
  Value cstExpP5 = bcast(f32Cst(builder, value: 5.0000001201E-1f));

  // Our computations below aren't particularly sensitive to the exact choices
  // here, so we choose values a bit larger/smaller than
  //
  //   log(F32_MAX) = 88.723...
  //   log(2^-126) = -87.337...
  Value x = op.getOperand();
  x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
  // n = round(x / log(2)), computed as floor(x / log(2) + 0.5).
  Value n = floor(fmla(x, cstLog2ef, cstHalf));

  // When we eventually do the multiplication in e^a * 2^n, we need to handle
  // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
  // (so e^a * 2^n != inf).  There's a similar problem for n < -126, the
  // smallest fp32 exponent.
  //
  // A straightforward solution would be to detect n out of range and split it
  // up, doing
  //
  //   e^a * 2^n = e^a * 2^(n1 + n2)
  //             = (2^n1 * e^a) * 2^n2.
  //
  // But it turns out this approach is quite slow, probably because it
  // manipulates subnormal values.
  //
  // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
  // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow
  // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though
  // this value of `a` is outside our previously specified range, e^a will still
  // only have a relative error of approximately 2^-16 at worse. In practice
  // this seems to work well enough; it passes our exhaustive tests, breaking
  // only one result, and by one ulp (we return exp(88.7228394) = max-float but
  // we should return inf).
  //
  // In the case where n' = -127, the original input value of x is so small that
  // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
  // normal floating point, and since we flush denormals, we simply return 0. We
  // do this in a branchless way by observing that our code for constructing 2^n
  // produces 0 if n = -127.
  //
  // The proof that n' = -127 implies e^x < 2^-126 is as follows:
  //
  //    n' = -127 implies n <= -127
  //              implies round(x / log(2)) <= -127
  //              implies x/log(2) < -126.5
  //              implies x < -126.5 * log(2)
  //              implies e^x < e^(-126.5 * log(2))
  //              implies e^x < 2^-126.5 < 2^-126
  //
  //    This proves that n' = -127 implies e^x < 2^-126.
  n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);

  // Computes x = x - n' * log(2), the value for `a`
  x = fmla(cstExpC1, n, x);
  x = fmla(cstExpC2, n, x);

  // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
  Value z = fmla(x, cstExpP0, cstExpP1);
  z = fmla(z, x, cstExpP2);
  z = fmla(z, x, cstExpP3);
  z = fmla(z, x, cstExpP4);
  z = fmla(z, x, cstExpP5);
  // z = a^2 * P(a) + a, then + 1: completes e^a ~= 1 + a + a^2 * P(a).
  z = fmla(z, mul(x, x), x);
  z = add(cstOne, z);

  // Convert n' to an i32.  This is safe because we clamped it above.
  auto i32Vec = broadcast(builder.getI32Type(), shape);
  Value nI32 = builder.create<arith::FPToSIOp>(i32Vec, n);

  // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
  Value pow2 = exp2I32(builder, arg: nI32);

  // Return z * 2^n' if -126 <= n' <= 127 and 0 if n = -127.
  Value ret = mul(z, pow2);

  rewriter.replaceOp(op, ret);
  return mlir::success();
}
1121 | |
1122 | } // namespace |
1123 | |
1124 | //----------------------------------------------------------------------------// |
1125 | // ExpM1 approximation. |
1126 | //----------------------------------------------------------------------------// |
1127 | |
namespace {

// Approximates math.expm1 (exp(x) - 1) for f32 element types while preserving
// accuracy near x = 0 and saturating correctly near x = -inf; see the
// out-of-line matchAndRewrite below.
struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpM1Op op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
1138 | |
1139 | LogicalResult |
1140 | ExpM1Approximation::matchAndRewrite(math::ExpM1Op op, |
1141 | PatternRewriter &rewriter) const { |
1142 | if (!getElementTypeOrSelf(op.getOperand()).isF32()) |
1143 | return rewriter.notifyMatchFailure(op, "unsupported operand type" ); |
1144 | |
1145 | VectorShape shape = vectorShape(op.getOperand()); |
1146 | |
1147 | ImplicitLocOpBuilder builder(op->getLoc(), rewriter); |
1148 | auto bcast = [&](Value value) -> Value { |
1149 | return broadcast(builder, value, shape); |
1150 | }; |
1151 | |
1152 | // expm1(x) = exp(x) - 1 = u - 1. |
1153 | // We have to handle it carefully when x is near 0, i.e. u ~= 1, |
1154 | // and when the input is ~= -inf, i.e. u - 1 ~= -1. |
1155 | Value cstOne = bcast(f32Cst(builder, value: 1.0f)); |
1156 | Value cstNegOne = bcast(f32Cst(builder, value: -1.0f)); |
1157 | Value x = op.getOperand(); |
1158 | Value u = builder.create<math::ExpOp>(x); |
1159 | Value uEqOneOrNaN = |
1160 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne); |
1161 | Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne); |
1162 | Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>( |
1163 | arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); |
1164 | // logU = log(u) ~= x |
1165 | Value logU = builder.create<math::LogOp>(u); |
1166 | |
1167 | // Detect exp(x) = +inf; written this way to avoid having to form +inf. |
1168 | Value isInf = |
1169 | builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u); |
1170 | |
1171 | // (u - 1) * (x / ~x) |
1172 | Value expm1 = builder.create<arith::MulFOp>( |
1173 | uMinusOne, builder.create<arith::DivFOp>(x, logU)); |
1174 | expm1 = builder.create<arith::SelectOp>(isInf, u, expm1); |
1175 | Value approximation = builder.create<arith::SelectOp>( |
1176 | uEqOneOrNaN, x, |
1177 | builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1)); |
1178 | rewriter.replaceOp(op, approximation); |
1179 | return success(); |
1180 | } |
1181 | |
1182 | //----------------------------------------------------------------------------// |
1183 | // Sin and Cos approximation. |
1184 | //----------------------------------------------------------------------------// |
1185 | |
namespace {

// Shared pattern for math.sin (isSine = true) and math.cos (isSine = false).
// Both reduce the argument to [0, pi/2] and pick between a shared pair of
// polynomials plus a sign; see the out-of-line matchAndRewrite below.
template <bool isSine, typename OpTy>
struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
};
} // namespace
1196 | |
1197 | #define TWO_OVER_PI \ |
1198 | 0.6366197723675813430755350534900574481378385829618257949906693762L |
1199 | #define PI_OVER_2 \ |
1200 | 1.5707963267948966192313216916397514420985846996875529104874722961L |
1201 | |
1202 | // Approximates sin(x) or cos(x) by finding the best approximation polynomial in |
1203 | // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the |
1204 | // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y). |
template <bool isSine, typename OpTy>
LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
    OpTy op, PatternRewriter &rewriter) const {
  static_assert(
      llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
      "SinAndCosApproximation pattern expects math::SinOp or math::CosOp" );

  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type" );

  VectorShape shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Short-hand builders for the ops emitted below.
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };
  auto sub = [&](Value a, Value b) -> Value {
    return builder.create<arith::SubFOp>(a, b);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };

  auto i32Vec = broadcast(builder.getI32Type(), shape);
  auto fPToSingedInteger = [&](Value a) -> Value {
    return builder.create<arith::FPToSIOp>(i32Vec, a);
  };

  auto modulo4 = [&](Value a) -> Value {
    return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3)));
  };

  auto isEqualTo = [&](Value a, Value b) -> Value {
    return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
  };

  auto isGreaterThan = [&](Value a, Value b) -> Value {
    return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
  };

  auto select = [&](Value cond, Value t, Value f) -> Value {
    return builder.create<arith::SelectOp>(cond, t, f);
  };

  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };

  auto bitwiseOr = [&](Value a, Value b) {
    return builder.create<arith::OrIOp>(a, b);
  };

  Value twoOverPi = bcast(f32Cst(builder, value: (float)TWO_OVER_PI));
  Value piOverTwo = bcast(f32Cst(builder, value: (float)PI_OVER_2));

  Value x = op.getOperand();

  // k = floor(x * 2/pi): index of the quarter period containing x.
  Value k = floor(mul(x, twoOverPi));

  // y = x - k * pi/2: argument reduced to [0, pi/2).
  Value y = sub(x, mul(k, piOverTwo));

  Value cstOne = bcast(f32Cst(builder, value: 1.0));
  Value cstNegativeOne = bcast(f32Cst(builder, value: -1.0));

  // Polynomial coefficients for sin(y) on the reduced range, evaluated in
  // y^2; the final multiply by `base` (= y) supplies the odd term.
  Value cstSC2 = bcast(f32Cst(builder, value: -0.16666667163372039794921875f));
  Value cstSC4 = bcast(f32Cst(builder, value: 8.333347737789154052734375e-3f));
  Value cstSC6 = bcast(f32Cst(builder, value: -1.9842604524455964565277099609375e-4f));
  Value cstSC8 =
      bcast(f32Cst(builder, value: 2.760012648650445044040679931640625e-6f));
  Value cstSC10 =
      bcast(f32Cst(builder, value: -2.50293279435709337121807038784027099609375e-8f));

  // Polynomial coefficients for cos(y) on the reduced range, evaluated in y^2
  // (base = 1 in that case).
  Value cstCC2 = bcast(f32Cst(builder, value: -0.5f));
  Value cstCC4 = bcast(f32Cst(builder, value: 4.166664183139801025390625e-2f));
  Value cstCC6 = bcast(f32Cst(builder, value: -1.388833043165504932403564453125e-3f));
  Value cstCC8 = bcast(f32Cst(builder, value: 2.47562347794882953166961669921875e-5f));
  Value cstCC10 =
      bcast(f32Cst(builder, value: -2.59630184018533327616751194000244140625e-7f));

  // Quadrant index k mod 4 decides which polynomial and which sign to use.
  Value kMod4 = modulo4(fPToSingedInteger(k));

  Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 0)));
  Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 1)));
  Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 2)));
  Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, value: 3)));

  // sin uses the cos polynomial in quadrants 1 and 3; cos uses it in
  // quadrants 0 and 2 (sin(x) = +/-cos(y) resp. cos(x) = +/-cos(y) there).
  Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
  // Result is negated in quadrants 2,3 for sin and in quadrants 1,2 for cos.
  Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, value: 1)))
                               : bitwiseOr(kR1, kR2);

  Value y2 = mul(y, y);

  // Select the polynomial (sin vs. cos flavor) and its base factor.
  Value base = select(sinuseCos, cstOne, y);
  Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
  Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
  Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
  Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
  Value cstC10 = select(sinuseCos, cstCC10, cstSC10);

  // Horner evaluation of the degree-10 even polynomial in y^2, times base.
  Value v1 = fmla(y2, cstC10, cstC8);
  Value v2 = fmla(y2, v1, cstC6);
  Value v3 = fmla(y2, v2, cstC4);
  Value v4 = fmla(y2, v3, cstC2);
  Value v5 = fmla(y2, v4, cstOne);
  Value v6 = mul(base, v5);

  Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);

  rewriter.replaceOp(op, approximation);

  return success();
}
1318 | |
1319 | //----------------------------------------------------------------------------// |
1320 | // Cbrt approximation. |
1321 | //----------------------------------------------------------------------------// |
1322 | |
namespace {
// Approximates math.cbrt for f32 element types via a bit-manipulation initial
// guess refined by Newton steps (see the out-of-line matchAndRewrite).
struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::CbrtOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
1331 | |
1332 | // Estimation of cube-root using an algorithm defined in |
1333 | // Hacker's Delight 2nd Edition. |
1334 | LogicalResult |
1335 | CbrtApproximation::matchAndRewrite(math::CbrtOp op, |
1336 | PatternRewriter &rewriter) const { |
1337 | auto operand = op.getOperand(); |
1338 | if (!getElementTypeOrSelf(operand).isF32()) |
1339 | return rewriter.notifyMatchFailure(op, "unsupported operand type" ); |
1340 | |
1341 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
1342 | VectorShape shape = vectorShape(operand); |
1343 | |
1344 | Type floatTy = getElementTypeOrSelf(operand.getType()); |
1345 | Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth()); |
1346 | |
1347 | // Convert to vector types if necessary. |
1348 | floatTy = broadcast(type: floatTy, shape); |
1349 | intTy = broadcast(type: intTy, shape); |
1350 | |
1351 | auto bconst = [&](TypedAttr attr) -> Value { |
1352 | Value value = b.create<arith::ConstantOp>(attr); |
1353 | return broadcast(builder&: b, value, shape); |
1354 | }; |
1355 | |
1356 | // Declare the initial values: |
1357 | Value intTwo = bconst(b.getI32IntegerAttr(2)); |
1358 | Value intFour = bconst(b.getI32IntegerAttr(4)); |
1359 | Value intEight = bconst(b.getI32IntegerAttr(8)); |
1360 | Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0)); |
1361 | Value fpThird = bconst(b.getF32FloatAttr(0.33333333f)); |
1362 | Value fpTwo = bconst(b.getF32FloatAttr(2.0f)); |
1363 | Value fpZero = bconst(b.getF32FloatAttr(0.0f)); |
1364 | |
1365 | // Compute an approximation of one third: |
1366 | // union {int ix; float x;}; |
1367 | // x = x0; |
1368 | // ix = ix/4 + ix/16; |
1369 | Value absValue = b.create<math::AbsFOp>(operand); |
1370 | Value intValue = b.create<arith::BitcastOp>(intTy, absValue); |
1371 | Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo); |
1372 | Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour); |
1373 | intValue = b.create<arith::AddIOp>(divideBy4, divideBy16); |
1374 | |
1375 | // ix = ix + ix/16; |
1376 | divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour); |
1377 | intValue = b.create<arith::AddIOp>(intValue, divideBy16); |
1378 | |
1379 | // ix = ix + ix/256; |
1380 | Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight); |
1381 | intValue = b.create<arith::AddIOp>(intValue, divideBy256); |
1382 | |
1383 | // ix = 0x2a5137a0 + ix; |
1384 | intValue = b.create<arith::AddIOp>(intValue, intMagic); |
1385 | |
1386 | // Perform one newtons step: |
1387 | // x = 0.33333333f*(2.0f*x + x0/(x*x)); |
1388 | Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue); |
1389 | Value squared = b.create<arith::MulFOp>(floatValue, floatValue); |
1390 | Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo); |
1391 | Value divSquared = b.create<arith::DivFOp>(absValue, squared); |
1392 | floatValue = b.create<arith::AddFOp>(mulTwo, divSquared); |
1393 | floatValue = b.create<arith::MulFOp>(floatValue, fpThird); |
1394 | |
1395 | // x = 0.33333333f*(2.0f*x + x0/(x*x)); |
1396 | squared = b.create<arith::MulFOp>(floatValue, floatValue); |
1397 | mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo); |
1398 | divSquared = b.create<arith::DivFOp>(absValue, squared); |
1399 | floatValue = b.create<arith::AddFOp>(mulTwo, divSquared); |
1400 | floatValue = b.create<arith::MulFOp>(floatValue, fpThird); |
1401 | |
1402 | // Check for zero and restore sign. |
1403 | Value isZero = |
1404 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero); |
1405 | floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue); |
1406 | floatValue = b.create<math::CopySignOp>(floatValue, operand); |
1407 | |
1408 | rewriter.replaceOp(op, floatValue); |
1409 | return success(); |
1410 | } |
1411 | |
1412 | //----------------------------------------------------------------------------// |
1413 | // Rsqrt approximation. |
1414 | //----------------------------------------------------------------------------// |
1415 | |
namespace {
// Rewrites `math.rsqrt` on f32 (vectors) into an `x86vector.rsqrt` hardware
// estimate refined by one Newton-Raphson step; registered only when AVX2 is
// enabled (see populateMathPolynomialApproximationPatterns).
struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::RsqrtOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
1424 | |
1425 | LogicalResult |
1426 | RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, |
1427 | PatternRewriter &rewriter) const { |
1428 | if (!getElementTypeOrSelf(op.getOperand()).isF32()) |
1429 | return rewriter.notifyMatchFailure(op, "unsupported operand type" ); |
1430 | |
1431 | VectorShape shape = vectorShape(op.getOperand()); |
1432 | |
1433 | // Only support already-vectorized rsqrt's. |
1434 | if (shape.empty() || shape.sizes.back() % 8 != 0) |
1435 | return rewriter.notifyMatchFailure(op, "unsupported operand type" ); |
1436 | |
1437 | ImplicitLocOpBuilder builder(op->getLoc(), rewriter); |
1438 | auto bcast = [&](Value value) -> Value { |
1439 | return broadcast(builder, value, shape); |
1440 | }; |
1441 | |
1442 | Value cstPosInf = bcast(f32FromBits(builder, bits: 0x7f800000u)); |
1443 | Value cstOnePointFive = bcast(f32Cst(builder, value: 1.5f)); |
1444 | Value cstNegHalf = bcast(f32Cst(builder, value: -0.5f)); |
1445 | Value cstMinNormPos = bcast(f32FromBits(builder, bits: 0x00800000u)); |
1446 | |
1447 | Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf); |
1448 | |
1449 | // Select only the inverse sqrt of positive normals (denormals are |
1450 | // flushed to zero). |
1451 | Value ltMinMask = builder.create<arith::CmpFOp>( |
1452 | arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos); |
1453 | Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, |
1454 | op.getOperand(), cstPosInf); |
1455 | Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask); |
1456 | |
1457 | // Compute an approximate result. |
1458 | Value yApprox = handleMultidimensionalVectors( |
1459 | builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value { |
1460 | return builder.create<x86vector::RsqrtOp>(operands); |
1461 | }); |
1462 | |
1463 | // Do a single step of Newton-Raphson iteration to improve the approximation. |
1464 | // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). |
1465 | // It is essential to evaluate the inner term like this because forming |
1466 | // y_n^2 may over- or underflow. |
1467 | Value inner = builder.create<arith::MulFOp>(negHalf, yApprox); |
1468 | Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive); |
1469 | Value yNewton = builder.create<arith::MulFOp>(yApprox, fma); |
1470 | |
1471 | // Select the result of the Newton-Raphson step for positive normal arguments. |
1472 | // For other arguments, choose the output of the intrinsic. This will |
1473 | // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if |
1474 | // x is zero or a positive denormalized float (equivalent to flushing positive |
1475 | // denormalized inputs to zero). |
1476 | Value res = |
1477 | builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton); |
1478 | rewriter.replaceOp(op, res); |
1479 | |
1480 | return success(); |
1481 | } |
1482 | |
1483 | //----------------------------------------------------------------------------// |
1484 | |
1485 | void mlir::populatePolynomialApproximateTanhPattern( |
1486 | RewritePatternSet &patterns) { |
1487 | patterns.add<TanhApproximation>(arg: patterns.getContext()); |
1488 | } |
1489 | |
1490 | void mlir::populatePolynomialApproximateErfPattern( |
1491 | RewritePatternSet &patterns) { |
1492 | patterns.add<ErfPolynomialApproximation>(arg: patterns.getContext()); |
1493 | } |
1494 | |
1495 | void mlir::populateMathPolynomialApproximationPatterns( |
1496 | RewritePatternSet &patterns, |
1497 | const MathPolynomialApproximationOptions &options) { |
1498 | // Patterns for leveraging existing f32 lowerings on other data types. |
1499 | patterns |
1500 | .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>, |
1501 | ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>, |
1502 | ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>, |
1503 | ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>, |
1504 | ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>, |
1505 | ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>( |
1506 | patterns.getContext()); |
1507 | |
1508 | patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation, |
1509 | LogApproximation, Log2Approximation, Log1pApproximation, |
1510 | ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation, |
1511 | CbrtApproximation, SinAndCosApproximation<true, math::SinOp>, |
1512 | SinAndCosApproximation<false, math::CosOp>>( |
1513 | patterns.getContext()); |
1514 | if (options.enableAvx2) { |
1515 | patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>( |
1516 | patterns.getContext()); |
1517 | } |
1518 | } |
1519 | |