//===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of various math operations.
//
//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"

#include <cstdint>

using namespace mlir;
25
26/// Create a float constant.
27static Value createFloatConst(Location loc, Type type, APFloat value,
28 OpBuilder &b) {
29 bool losesInfo = false;
30 auto eltType = getElementTypeOrSelf(type);
31 // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
32 value.convert(ToSemantics: cast<FloatType>(eltType).getFloatSemantics(),
33 RM: APFloat::rmNearestTiesToEven, losesInfo: &losesInfo);
34 auto attr = b.getFloatAttr(eltType, value);
35 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
36 return b.create<arith::ConstantOp>(loc,
37 DenseElementsAttr::get(shapedTy, attr));
38 }
39
40 return b.create<arith::ConstantOp>(loc, attr);
41}
42
43static Value createFloatConst(Location loc, Type type, double value,
44 OpBuilder &b) {
45 return createFloatConst(loc, type, value: APFloat(value), b);
46}
47
48/// Create an integer constant.
49static Value createIntConst(Location loc, Type type, int64_t value,
50 OpBuilder &b) {
51 auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
52 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
53 return b.create<arith::ConstantOp>(loc,
54 DenseElementsAttr::get(shapedTy, attr));
55 }
56
57 return b.create<arith::ConstantOp>(loc, attr);
58}
59
60static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
61 Type opType = operand.getType();
62 Type i64Ty = b.getI64Type();
63 if (auto shapedTy = dyn_cast<ShapedType>(opType))
64 i64Ty = shapedTy.clone(i64Ty);
65 Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
66 Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
67 // The truncation does not preserve the sign when the truncated
68 // value is -0. So here the sign is copied again.
69 return b.create<math::CopySignOp>(fpFixedConvert, operand);
70}
71
72// sinhf(float x) -> (exp(x) - exp(-x)) / 2
73static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
74 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
75 Value operand = op.getOperand();
76 Type opType = operand.getType();
77
78 Value exp = b.create<math::ExpOp>(operand);
79 Value neg = b.create<arith::NegFOp>(operand);
80 Value nexp = b.create<math::ExpOp>(neg);
81 Value sub = b.create<arith::SubFOp>(exp, nexp);
82 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
83 Value res = b.create<arith::MulFOp>(sub, half);
84 rewriter.replaceOp(op, res);
85 return success();
86}
87
88// coshf(float x) -> (exp(x) + exp(-x)) / 2
89static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
90 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
91 Value operand = op.getOperand();
92 Type opType = operand.getType();
93
94 Value exp = b.create<math::ExpOp>(operand);
95 Value neg = b.create<arith::NegFOp>(operand);
96 Value nexp = b.create<math::ExpOp>(neg);
97 Value add = b.create<arith::AddFOp>(exp, nexp);
98 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
99 Value res = b.create<arith::MulFOp>(add, half);
100 rewriter.replaceOp(op, res);
101 return success();
102}
103
104/// Expands tanh op into
105/// 1-exp^{-2x} / 1+exp^{-2x}
106/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
107/// We compute a "signs" value which is -1 if input is negative and +1 if input
108/// is positive. Then multiply the input by this value, guaranteeing that the
109/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
110/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
111/// result by `sign(x)` to retain sign of the real result.
112static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
113 auto floatType = op.getOperand().getType();
114 Location loc = op.getLoc();
115 Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
116 Value one = createFloatConst(loc, floatType, 1.0, rewriter);
117 Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
118
119 // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
120 Value isNegative = rewriter.create<arith::CmpFOp>(
121 loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
122 Value isNegativeFloat =
123 rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
124 Value isNegativeTimesNegTwo =
125 rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
126 Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
127
128 // Normalize input to positive value: y = sign(x) * x
129 Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
130
131 // Decompose on normalized input
132 Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
133 Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
134 Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
135 Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
136 Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
137
138 // Multiply result by sign(x) to retain signs from negative inputs
139 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
140
141 return success();
142}
143
144// Converts math.tan to math.sin, math.cos, and arith.divf.
145static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
146 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
147 Value operand = op.getOperand();
148 Type type = operand.getType();
149 Value sin = b.create<math::SinOp>(type, operand);
150 Value cos = b.create<math::CosOp>(type, operand);
151 Value div = b.create<arith::DivFOp>(type, sin, cos);
152 rewriter.replaceOp(op, div);
153 return success();
154}
155
156// asinh(float x) -> log(x + sqrt(x**2 + 1))
157static LogicalResult convertAsinhOp(math::AsinhOp op,
158 PatternRewriter &rewriter) {
159 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
160 Value operand = op.getOperand();
161 Type opType = operand.getType();
162
163 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
164 Value fma = b.create<math::FmaOp>(operand, operand, one);
165 Value sqrt = b.create<math::SqrtOp>(fma);
166 Value add = b.create<arith::AddFOp>(operand, sqrt);
167 Value res = b.create<math::LogOp>(add);
168 rewriter.replaceOp(op, res);
169 return success();
170}
171
172// acosh(float x) -> log(x + sqrt(x**2 - 1))
173static LogicalResult convertAcoshOp(math::AcoshOp op,
174 PatternRewriter &rewriter) {
175 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
176 Value operand = op.getOperand();
177 Type opType = operand.getType();
178
179 Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
180 Value fma = b.create<math::FmaOp>(operand, operand, negOne);
181 Value sqrt = b.create<math::SqrtOp>(fma);
182 Value add = b.create<arith::AddFOp>(operand, sqrt);
183 Value res = b.create<math::LogOp>(add);
184 rewriter.replaceOp(op, res);
185 return success();
186}
187
188// atanh(float x) -> log((1 + x) / (1 - x)) / 2
189static LogicalResult convertAtanhOp(math::AtanhOp op,
190 PatternRewriter &rewriter) {
191 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
192 Value operand = op.getOperand();
193 Type opType = operand.getType();
194
195 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
196 Value add = b.create<arith::AddFOp>(operand, one);
197 Value neg = b.create<arith::NegFOp>(operand);
198 Value sub = b.create<arith::AddFOp>(neg, one);
199 Value div = b.create<arith::DivFOp>(add, sub);
200 Value log = b.create<math::LogOp>(div);
201 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
202 Value res = b.create<arith::MulFOp>(log, half);
203 rewriter.replaceOp(op, res);
204 return success();
205}
206
207static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
208 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
209 Value operandA = op.getOperand(0);
210 Value operandB = op.getOperand(1);
211 Value operandC = op.getOperand(2);
212 Type type = op.getType();
213 Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
214 Value add = b.create<arith::AddFOp>(type, mult, operandC);
215 rewriter.replaceOp(op, add);
216 return success();
217}
218
219// Converts a ceilf() function to the following:
220// ceilf(float x) ->
221// y = (float)(int) x
222// if (x > y) then incr = 1 else incr = 0
223// y = y + incr <= replace this op with the ceilf op.
224static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
225 // Creating constants assumes the static shaped type.
226 auto shapedType = dyn_cast<ShapedType>(op.getType());
227 if (shapedType && !shapedType.hasStaticShape())
228 return failure();
229
230 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
231 Value operand = op.getOperand();
232 Type opType = operand.getType();
233 Value fpFixedConvert = createTruncatedFPValue(operand, b);
234
235 // Creating constants for later use.
236 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
237 Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
238
239 Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
240 fpFixedConvert);
241 Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
242
243 Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
244 rewriter.replaceOp(op, ret);
245 return success();
246}
247
248// Convert `math.fpowi` to a series of `arith.mulf` operations.
249// If the power is negative, we divide one by the result.
250// If both the base and power are zero, the result is 1.
251// In the case of non constant power, we convert the operation to `math.powf`.
252static LogicalResult convertFPowIOp(math::FPowIOp op,
253 PatternRewriter &rewriter) {
254 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
255 Value base = op.getOperand(0);
256 Value power = op.getOperand(1);
257 Type baseType = base.getType();
258
259 auto convertFPowItoPowf = [&]() -> LogicalResult {
260 Value castPowerToFp =
261 rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
262 Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
263 castPowerToFp);
264 rewriter.replaceOp(op, res);
265 return success();
266 };
267
268 Attribute cstAttr;
269 if (!matchPattern(value: power, pattern: m_Constant(bind_value: &cstAttr)))
270 return convertFPowItoPowf();
271
272 APInt value;
273 if (!matchPattern(cstAttr, m_ConstantInt(&value)))
274 return convertFPowItoPowf();
275
276 int64_t powerInt = value.getSExtValue();
277 bool isNegative = powerInt < 0;
278 int64_t absPower = std::abs(i: powerInt);
279 Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
280 Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
281
282 while (absPower > 0) {
283 if (absPower & 1)
284 res = b.create<arith::MulFOp>(baseType, base, res);
285 absPower >>= 1;
286 base = b.create<arith::MulFOp>(baseType, base, base);
287 }
288
289 // Make sure not to introduce UB in case of negative power.
290 if (isNegative) {
291 auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(type: baseType))
292 .getFloatSemantics();
293 Value zero =
294 createFloatConst(op->getLoc(), baseType,
295 APFloat::getZero(Sem: sem, /*Negative=*/false), rewriter);
296 Value negZero =
297 createFloatConst(op->getLoc(), baseType,
298 APFloat::getZero(Sem: sem, /*Negative=*/true), rewriter);
299 Value posInfinity =
300 createFloatConst(op->getLoc(), baseType,
301 APFloat::getInf(Sem: sem, /*Negative=*/false), rewriter);
302 Value negInfinity =
303 createFloatConst(op->getLoc(), baseType,
304 APFloat::getInf(Sem: sem, /*Negative=*/true), rewriter);
305 Value zeroEqCheck =
306 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
307 Value negZeroEqCheck =
308 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
309 res = b.create<arith::DivFOp>(baseType, one, res);
310 res =
311 b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
312 res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
313 res);
314 }
315
316 rewriter.replaceOp(op, res);
317 return success();
318}
319
320// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
321// Some special cases where b is constant are handled separately:
322// when b == 0, or |b| == 0.5, 1.0, or 2.0.
323static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
324 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
325 Value operandA = op.getOperand(0);
326 Value operandB = op.getOperand(1);
327 auto typeA = operandA.getType();
328 auto typeB = operandB.getType();
329
330 auto &sem =
331 cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
332 APFloat valueB(sem);
333 auto mulf = [&](Value x, Value y) -> Value {
334 return b.create<arith::MulFOp>(x, y);
335 };
336 if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
337 if (valueB.isZero()) {
338 // a^0 -> 1
339 Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
340 rewriter.replaceOp(op, one);
341 return success();
342 }
343 if (valueB.isExactlyValue(V: 1.0)) {
344 // a^1 -> a
345 rewriter.replaceOp(op, operandA);
346 return success();
347 }
348 if (valueB.isExactlyValue(V: -1.0)) {
349 // a^(-1) -> 1 / a
350 Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
351 Value div = b.create<arith::DivFOp>(one, operandA);
352 rewriter.replaceOp(op, div);
353 return success();
354 }
355 if (valueB.isExactlyValue(V: 0.5)) {
356 // a^(1/2) -> sqrt(a)
357 Value sqrt = b.create<math::SqrtOp>(operandA);
358 rewriter.replaceOp(op, sqrt);
359 return success();
360 }
361 if (valueB.isExactlyValue(V: -0.5)) {
362 // a^(-1/2) -> 1 / sqrt(a)
363 Value rsqrt = b.create<math::RsqrtOp>(operandA);
364 rewriter.replaceOp(op, rsqrt);
365 return success();
366 }
367 if (valueB.isExactlyValue(V: 2.0)) {
368 // a^2 -> a * a
369 rewriter.replaceOp(op, mulf(operandA, operandA));
370 return success();
371 }
372 if (valueB.isExactlyValue(V: -2.0)) {
373 // a^(-2) -> 1 / (a * a)
374 Value one =
375 createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
376 Value div = b.create<arith::DivFOp>(one, mulf(operandA, operandA));
377 rewriter.replaceOp(op, div);
378 return success();
379 }
380 if (valueB.isExactlyValue(V: 3.0)) {
381 rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
382 return success();
383 }
384 }
385
386 Value logA = b.create<math::LogOp>(operandA);
387 Value mult = b.create<arith::MulFOp>(operandB, logA);
388 Value expResult = b.create<math::ExpOp>(mult);
389 rewriter.replaceOp(op, expResult);
390 return success();
391}
392
393// exp2f(float x) -> exp(x * ln(2))
394// Proof: Let's say 2^x = y
395// ln(2^x) = ln(y)
396// x * ln(2) = ln(y) => e ^(x*ln(2)) = y
397static LogicalResult convertExp2fOp(math::Exp2Op op,
398 PatternRewriter &rewriter) {
399 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
400 Value operand = op.getOperand();
401 Type opType = operand.getType();
402 Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
403 Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
404 Value exp = b.create<math::ExpOp>(op->getLoc(), mult);
405 rewriter.replaceOp(op, exp);
406 return success();
407}
408
409static LogicalResult convertRoundOp(math::RoundOp op,
410 PatternRewriter &rewriter) {
411 Location loc = op.getLoc();
412 ImplicitLocOpBuilder b(loc, rewriter);
413 Value operand = op.getOperand();
414 Type opType = operand.getType();
415 Type opEType = getElementTypeOrSelf(type: opType);
416
417 if (!opEType.isF32()) {
418 return rewriter.notifyMatchFailure(op, "not a round of f32.");
419 }
420
421 Type i32Ty = b.getI32Type();
422 if (auto shapedTy = dyn_cast<ShapedType>(opType))
423 i32Ty = shapedTy.clone(i32Ty);
424
425 Value half = createFloatConst(loc, type: opType, value: 0.5, b);
426 Value c23 = createIntConst(loc, type: i32Ty, value: 23, b);
427 Value c127 = createIntConst(loc, type: i32Ty, value: 127, b);
428 Value expMask = createIntConst(loc, type: i32Ty, value: (1 << 8) - 1, b);
429
430 Value incrValue = b.create<math::CopySignOp>(half, operand);
431 Value add = b.create<arith::AddFOp>(opType, operand, incrValue);
432 Value fpFixedConvert = createTruncatedFPValue(operand: add, b);
433
434 // There are three cases where adding 0.5 to the value and truncating by
435 // converting to an i64 does not result in the correct behavior:
436 //
437 // 1. Special values: +-inf and +-nan
438 // Casting these special values to i64 has undefined behavior. To identify
439 // these values, we use the fact that these values are the only float
440 // values with the maximum possible biased exponent.
441 //
442 // 2. Large values: 2^23 <= |x| <= INT_64_MAX
443 // Adding 0.5 to a float larger than or equal to 2^23 results in precision
444 // errors that sometimes round the value up and sometimes round the value
445 // down. For example:
446 // 8388608.0 + 0.5 = 8388608.0
447 // 8388609.0 + 0.5 = 8388610.0
448 //
449 // 3. Very large values: |x| > INT_64_MAX
450 // Casting to i64 a value greater than the max i64 value will overflow the
451 // i64 leading to wrong outputs.
452 //
453 // All three cases satisfy the property `biasedExp >= 23`.
454 Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
455 Value operandExp = b.create<arith::AndIOp>(
456 b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
457 Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
458 Value isSpecialValOrLargeVal =
459 b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
460
461 Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
462 fpFixedConvert);
463 rewriter.replaceOp(op, result);
464 return success();
465}
466
467// Converts math.ctlz to scf and arith operations. This is done
468// by performing a binary search on the bits.
469static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
470 PatternRewriter &rewriter) {
471 auto operand = op.getOperand();
472 auto operandTy = operand.getType();
473 auto eTy = getElementTypeOrSelf(operandTy);
474 Location loc = op.getLoc();
475
476 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
477 if (bitwidth > 64)
478 return failure();
479
480 uint64_t allbits = -1;
481 if (bitwidth < 64) {
482 allbits = allbits >> (64 - bitwidth);
483 }
484
485 Value x = operand;
486 Value count = createIntConst(loc, operandTy, 0, rewriter);
487 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
488 auto half = bw / 2;
489 auto bits = createIntConst(loc, operandTy, half, rewriter);
490 auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
491
492 Value pred =
493 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
494 Value add = rewriter.create<arith::AddIOp>(loc, count, bits);
495 Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits);
496
497 x = rewriter.create<arith::SelectOp>(loc, pred, shift, x);
498 count = rewriter.create<arith::SelectOp>(loc, pred, add, count);
499 }
500
501 Value zero = createIntConst(loc, operandTy, 0, rewriter);
502 Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
503 operand, zero);
504
505 Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
506 Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count);
507 rewriter.replaceOp(op, sel);
508 return success();
509}
510
511// Convert `math.roundeven` into `math.round` + arith ops
512static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
513 PatternRewriter &rewriter) {
514 Location loc = op.getLoc();
515 ImplicitLocOpBuilder b(loc, rewriter);
516 auto operand = op.getOperand();
517 Type operandTy = operand.getType();
518 Type resultTy = op.getType();
519 Type operandETy = getElementTypeOrSelf(type: operandTy);
520 Type resultETy = getElementTypeOrSelf(type: resultTy);
521
522 if (!isa<FloatType>(Val: operandETy) || !isa<FloatType>(Val: resultETy)) {
523 return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
524 }
525
526 Type fTy = operandTy;
527 Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
528 if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
529 iTy = shapedTy.clone(iTy);
530 }
531
532 unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
533 // The width returned by getFPMantissaWidth includes the integer bit.
534 unsigned mantissaWidth =
535 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
536 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
537
538 // The names of the variables correspond to f32.
539 // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
540 // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
541 // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
542 Value c1Float = createFloatConst(loc, type: fTy, value: 1.0, b);
543 Value c0 = createIntConst(loc, type: iTy, value: 0, b);
544 Value c1 = createIntConst(loc, type: iTy, value: 1, b);
545 Value cNeg1 = createIntConst(loc, type: iTy, value: -1, b);
546 Value c23 = createIntConst(loc, type: iTy, value: mantissaWidth, b);
547 Value c31 = createIntConst(loc, type: iTy, value: bitWidth - 1, b);
548 Value c127 = createIntConst(loc, type: iTy, value: (1ull << (exponentWidth - 1)) - 1, b);
549 Value c2To22 = createIntConst(loc, type: iTy, value: 1ull << (mantissaWidth - 1), b);
550 Value c23Mask = createIntConst(loc, type: iTy, value: (1ull << mantissaWidth) - 1, b);
551 Value expMask = createIntConst(loc, type: iTy, value: (1ull << exponentWidth) - 1, b);
552
553 Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand);
554 Value round = b.create<math::RoundOp>(operand);
555 Value roundBitcast = b.create<arith::BitcastOp>(iTy, round);
556
557 // Get biased exponents for operand and round(operand)
558 Value operandExp = b.create<arith::AndIOp>(
559 b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
560 Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
561 Value roundExp = b.create<arith::AndIOp>(
562 b.create<arith::ShRUIOp>(roundBitcast, c23), expMask);
563 Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
564
565 auto safeShiftRight = [&](Value x, Value shift) -> Value {
566 // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
567 Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
568 clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
569 return b.create<arith::ShRUIOp>(x, clampedShift);
570 };
571
572 auto maskMantissa = [&](Value mantissa,
573 Value mantissaMaskRightShift) -> Value {
574 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
575 return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask);
576 };
577
578 // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
579 // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
580 // with `biasedExp > 23` (numbers where there is not enough precision to store
581 // decimals) are always even, and they satisfy the even condition trivially
582 // since the mantissa without all its bits is zero. The even condition
583 // is also true for +-0, since they have `biasedExp = -127` and the entire
584 // mantissa is zero. The case of +-1 has to be handled separately. Here
585 // we identify these values by noting that +-1 are the only whole numbers with
586 // `biasedExp == 0`.
587 //
588 // The special values +-inf and +-nan also satisfy the same property that
589 // whole non-unit even numbers satisfy. In particular, the special values have
590 // `biasedExp > 23`, so they get treated as large numbers with no room for
591 // decimals, which are always even.
592 Value roundBiasedExpEq0 =
593 b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
594 Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1);
595 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
596 Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>(
597 arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
598 roundIsNotEvenOrSpecialVal =
599 b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
600
601 // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
602 // integers if the bit at index `biasedExp` starting from the left in the
603 // mantissa is 1 and all the bits to the right are zero. Values with
604 // `biasedExp >= 23` don't have decimals, so they are never halfway. The
605 // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
606 // so these are handled separately. In particular, if `biasedExp == -1`, the
607 // value is halfway if the entire mantissa is zero.
608 Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>(
609 arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
610 Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>(
611 operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
612 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
613 Value operandIsHalfway =
614 b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
615 expectedOperandMaskedMantissa);
616 // Ensure `biasedExp` is in the valid range for half values.
617 Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>(
618 arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
619 Value operandBiasedExpLt23 =
620 b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
621 operandIsHalfway =
622 b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
623 operandIsHalfway =
624 b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
625
626 // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
627 // case where `round` rounded in the opposite direction of `roundeven`.
628 Value sign = b.create<math::CopySignOp>(c1Float, operand);
629 Value roundShifted = b.create<arith::SubFOp>(round, sign);
630 // If the rounded value is even or a special value, we default to the behavior
631 // of `math.round`.
632 Value needsShift =
633 b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
634 Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round);
635 // The `x - sign` adjustment does not preserve the sign when we are adjusting
636 // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
637 // rounded to -0.0.
638 result = b.create<math::CopySignOp>(result, operand);
639 rewriter.replaceOp(op, result);
640 return success();
641}
642
643// Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
644static LogicalResult convertRsqrtOp(math::RsqrtOp op,
645 PatternRewriter &rewriter) {
646
647 auto operand = op.getOperand();
648 auto operandTy = operand.getType();
649 // Operand type must be shatic shaped type to create const float.
650 auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
651 if (shapedOperandType && !shapedOperandType.hasStaticShape())
652 return failure();
653
654 auto eTy = getElementTypeOrSelf(operandTy);
655 if (!isa<FloatType>(eTy))
656 return failure();
657
658 Location loc = op->getLoc();
659 auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
660 auto sqrtOp = rewriter.create<math::SqrtOp>(loc, operand);
661 rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp);
662 return success();
663}
664
665void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
666 patterns.add(convertCtlzOp);
667}
668
669void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
670 patterns.add(convertSinhOp);
671}
672
673void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
674 patterns.add(convertCoshOp);
675}
676
677void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
678 patterns.add(convertTanOp);
679}
680
681void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
682 patterns.add(convertTanhOp);
683}
684
685void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
686 patterns.add(convertAsinhOp);
687}
688
689void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
690 patterns.add(convertAcoshOp);
691}
692
693void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
694 patterns.add(convertAtanhOp);
695}
696
697void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
698 patterns.add(convertFmaFOp);
699}
700
701void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
702 patterns.add(convertCeilOp);
703}
704
705void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
706 patterns.add(convertExp2fOp);
707}
708
709void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
710 patterns.add(convertPowfOp);
711}
712
713void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
714 patterns.add(convertFPowIOp);
715}
716
717void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
718 patterns.add(convertRoundOp);
719}
720
721void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
722 patterns.add(convertRoundEvenOp);
723}
724
725void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
726 patterns.add(convertRsqrtOp);
727}
728

// Source code of mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
// (listing provided by KDAB).