1//===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements expansion of various math operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Math/IR/Math.h"
15#include "mlir/Dialect/Math/Transforms/Passes.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/Vector/IR/VectorOps.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/ImplicitLocOpBuilder.h"
20#include "mlir/IR/TypeUtilities.h"
21#include "mlir/Transforms/DialectConversion.h"
22
23using namespace mlir;
24
25/// Create a float constant.
26static Value createFloatConst(Location loc, Type type, APFloat value,
27 OpBuilder &b) {
28 bool losesInfo = false;
29 auto eltType = getElementTypeOrSelf(type);
30 // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
31 value.convert(ToSemantics: cast<FloatType>(Val&: eltType).getFloatSemantics(),
32 RM: APFloat::rmNearestTiesToEven, losesInfo: &losesInfo);
33 auto attr = b.getFloatAttr(eltType, value);
34 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
35 return b.create<arith::ConstantOp>(loc,
36 DenseElementsAttr::get(shapedTy, attr));
37 }
38
39 return b.create<arith::ConstantOp>(loc, attr);
40}
41
42static Value createFloatConst(Location loc, Type type, double value,
43 OpBuilder &b) {
44 return createFloatConst(loc, type, value: APFloat(value), b);
45}
46
47/// Create an integer constant.
48static Value createIntConst(Location loc, Type type, int64_t value,
49 OpBuilder &b) {
50 auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
51 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
52 return b.create<arith::ConstantOp>(loc,
53 DenseElementsAttr::get(shapedTy, attr));
54 }
55
56 return b.create<arith::ConstantOp>(loc, attr);
57}
58
59static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
60 Type opType = operand.getType();
61 Type i64Ty = b.getI64Type();
62 if (auto shapedTy = dyn_cast<ShapedType>(opType))
63 i64Ty = shapedTy.clone(i64Ty);
64 Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
65 Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
66 // The truncation does not preserve the sign when the truncated
67 // value is -0. So here the sign is copied again.
68 return b.create<math::CopySignOp>(fpFixedConvert, operand);
69}
70
71// sinhf(float x) -> (exp(x) - exp(-x)) / 2
72static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
73 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
74 Value operand = op.getOperand();
75 Type opType = operand.getType();
76 Value exp = b.create<math::ExpOp>(operand);
77
78 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
79 Value nexp = b.create<arith::DivFOp>(one, exp);
80 Value sub = b.create<arith::SubFOp>(exp, nexp);
81 Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
82 Value div = b.create<arith::DivFOp>(sub, two);
83 rewriter.replaceOp(op, div);
84 return success();
85}
86
87// coshf(float x) -> (exp(x) + exp(-x)) / 2
88static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
89 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
90 Value operand = op.getOperand();
91 Type opType = operand.getType();
92 Value exp = b.create<math::ExpOp>(operand);
93
94 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
95 Value nexp = b.create<arith::DivFOp>(one, exp);
96 Value add = b.create<arith::AddFOp>(exp, nexp);
97 Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
98 Value div = b.create<arith::DivFOp>(add, two);
99 rewriter.replaceOp(op, div);
100 return success();
101}
102
103/// Expands tanh op into
104/// 1-exp^{-2x} / 1+exp^{-2x}
105/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
106/// We compute a "signs" value which is -1 if input is negative and +1 if input
107/// is positive. Then multiply the input by this value, guaranteeing that the
108/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
109/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
110/// result by `sign(x)` to retain sign of the real result.
111static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
112 auto floatType = op.getOperand().getType();
113 Location loc = op.getLoc();
114 Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
115 Value one = createFloatConst(loc, floatType, 1.0, rewriter);
116 Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
117
118 // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
119 Value isNegative = rewriter.create<arith::CmpFOp>(
120 loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
121 Value isNegativeFloat =
122 rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
123 Value isNegativeTimesNegTwo =
124 rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
125 Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
126
127 // Normalize input to positive value: y = sign(x) * x
128 Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
129
130 // Decompose on normalized input
131 Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
132 Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
133 Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
134 Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
135 Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
136
137 // Multiply result by sign(x) to retain signs from negative inputs
138 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
139
140 return success();
141}
142
143// Converts math.tan to math.sin, math.cos, and arith.divf.
144static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
145 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
146 Value operand = op.getOperand();
147 Type type = operand.getType();
148 Value sin = b.create<math::SinOp>(type, operand);
149 Value cos = b.create<math::CosOp>(type, operand);
150 Value div = b.create<arith::DivFOp>(type, sin, cos);
151 rewriter.replaceOp(op, div);
152 return success();
153}
154
155static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
156 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
157 Value operandA = op.getOperand(0);
158 Value operandB = op.getOperand(1);
159 Value operandC = op.getOperand(2);
160 Type type = op.getType();
161 Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
162 Value add = b.create<arith::AddFOp>(type, mult, operandC);
163 rewriter.replaceOp(op, add);
164 return success();
165}
166
167// Converts a floorf() function to the following:
168// floorf(float x) ->
169// y = (float)(int) x
170// if (x < 0) then incr = -1 else incr = 0
171// y = y + incr <= replace this op with the floorf op.
172static LogicalResult convertFloorOp(math::FloorOp op,
173 PatternRewriter &rewriter) {
174 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
175 Value operand = op.getOperand();
176 Type opType = operand.getType();
177 Value fpFixedConvert = createTruncatedFPValue(operand, b);
178
179 // Creating constants for later use.
180 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
181 Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
182
183 Value negCheck =
184 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
185 Value incrValue =
186 b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero);
187 Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
188 rewriter.replaceOp(op, ret);
189 return success();
190}
191
192// Converts a ceilf() function to the following:
193// ceilf(float x) ->
194// y = (float)(int) x
195// if (x > y) then incr = 1 else incr = 0
196// y = y + incr <= replace this op with the ceilf op.
197static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
198 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
199 Value operand = op.getOperand();
200 Type opType = operand.getType();
201 Value fpFixedConvert = createTruncatedFPValue(operand, b);
202
203 // Creating constants for later use.
204 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
205 Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
206
207 Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
208 fpFixedConvert);
209 Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
210
211 Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
212 rewriter.replaceOp(op, ret);
213 return success();
214}
215
216// Convert `math.fpowi` to a series of `arith.mulf` operations.
217// If the power is negative, we divide one by the result.
218// If both the base and power are zero, the result is 1.
219// In the case of non constant power, we convert the operation to `math.powf`.
220static LogicalResult convertFPowIOp(math::FPowIOp op,
221 PatternRewriter &rewriter) {
222 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
223 Value base = op.getOperand(0);
224 Value power = op.getOperand(1);
225 Type baseType = base.getType();
226
227 auto convertFPowItoPowf = [&]() -> LogicalResult {
228 Value castPowerToFp =
229 rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
230 Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
231 castPowerToFp);
232 rewriter.replaceOp(op, res);
233 return success();
234 };
235
236 Attribute cstAttr;
237 if (!matchPattern(value: power, pattern: m_Constant(bind_value: &cstAttr)))
238 return convertFPowItoPowf();
239
240 APInt value;
241 if (!matchPattern(cstAttr, m_ConstantInt(&value)))
242 return convertFPowItoPowf();
243
244 int64_t powerInt = value.getSExtValue();
245 bool isNegative = powerInt < 0;
246 int64_t absPower = std::abs(i: powerInt);
247 Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
248 Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
249
250 while (absPower > 0) {
251 if (absPower & 1)
252 res = b.create<arith::MulFOp>(baseType, base, res);
253 absPower >>= 1;
254 base = b.create<arith::MulFOp>(baseType, base, base);
255 }
256
257 // Make sure not to introduce UB in case of negative power.
258 if (isNegative) {
259 auto &sem = dyn_cast<mlir::FloatType>(Val: getElementTypeOrSelf(type: baseType))
260 .getFloatSemantics();
261 Value zero =
262 createFloatConst(op->getLoc(), baseType,
263 APFloat::getZero(Sem: sem, /*Negative=*/false), rewriter);
264 Value negZero =
265 createFloatConst(op->getLoc(), baseType,
266 APFloat::getZero(Sem: sem, /*Negative=*/true), rewriter);
267 Value posInfinity =
268 createFloatConst(op->getLoc(), baseType,
269 APFloat::getInf(Sem: sem, /*Negative=*/false), rewriter);
270 Value negInfinity =
271 createFloatConst(op->getLoc(), baseType,
272 APFloat::getInf(Sem: sem, /*Negative=*/true), rewriter);
273 Value zeroEqCheck =
274 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
275 Value negZeroEqCheck =
276 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
277 res = b.create<arith::DivFOp>(baseType, one, res);
278 res =
279 b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
280 res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
281 res);
282 }
283
284 rewriter.replaceOp(op, res);
285 return success();
286}
287
288// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
289static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
290 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
291 Value operandA = op.getOperand(0);
292 Value operandB = op.getOperand(1);
293 Type opType = operandA.getType();
294 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
295 Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
296 Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
297 Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
298 Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
299
300 Value logA = b.create<math::LogOp>(opType, opASquared);
301 Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
302 Value expResult = b.create<math::ExpOp>(opType, mult);
303 Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
304 Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
305 Value negCheck =
306 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
307 Value oddPower =
308 b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
309 Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
310
311 Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
312 expResult);
313 rewriter.replaceOp(op, res);
314 return success();
315}
316
317// exp2f(float x) -> exp(x * ln(2))
318// Proof: Let's say 2^x = y
319// ln(2^x) = ln(y)
320// x * ln(2) = ln(y) => e ^(x*ln(2)) = y
321static LogicalResult convertExp2fOp(math::Exp2Op op,
322 PatternRewriter &rewriter) {
323 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
324 Value operand = op.getOperand();
325 Type opType = operand.getType();
326 Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
327 Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
328 Value exp = b.create<math::ExpOp>(op->getLoc(), mult);
329 rewriter.replaceOp(op, exp);
330 return success();
331}
332
333static LogicalResult convertRoundOp(math::RoundOp op,
334 PatternRewriter &rewriter) {
335 Location loc = op.getLoc();
336 ImplicitLocOpBuilder b(loc, rewriter);
337 Value operand = op.getOperand();
338 Type opType = operand.getType();
339 Type opEType = getElementTypeOrSelf(type: opType);
340
341 if (!opEType.isF32()) {
342 return rewriter.notifyMatchFailure(op, "not a round of f32.");
343 }
344
345 Type i32Ty = b.getI32Type();
346 if (auto shapedTy = dyn_cast<ShapedType>(opType))
347 i32Ty = shapedTy.clone(i32Ty);
348
349 Value half = createFloatConst(loc, type: opType, value: 0.5, b);
350 Value c23 = createIntConst(loc, type: i32Ty, value: 23, b);
351 Value c127 = createIntConst(loc, type: i32Ty, value: 127, b);
352 Value expMask = createIntConst(loc, type: i32Ty, value: (1 << 8) - 1, b);
353
354 Value incrValue = b.create<math::CopySignOp>(half, operand);
355 Value add = b.create<arith::AddFOp>(opType, operand, incrValue);
356 Value fpFixedConvert = createTruncatedFPValue(operand: add, b);
357
358 // There are three cases where adding 0.5 to the value and truncating by
359 // converting to an i64 does not result in the correct behavior:
360 //
361 // 1. Special values: +-inf and +-nan
362 // Casting these special values to i64 has undefined behavior. To identify
363 // these values, we use the fact that these values are the only float
364 // values with the maximum possible biased exponent.
365 //
366 // 2. Large values: 2^23 <= |x| <= INT_64_MAX
367 // Adding 0.5 to a float larger than or equal to 2^23 results in precision
368 // errors that sometimes round the value up and sometimes round the value
369 // down. For example:
370 // 8388608.0 + 0.5 = 8388608.0
371 // 8388609.0 + 0.5 = 8388610.0
372 //
373 // 3. Very large values: |x| > INT_64_MAX
374 // Casting to i64 a value greater than the max i64 value will overflow the
375 // i64 leading to wrong outputs.
376 //
377 // All three cases satisfy the property `biasedExp >= 23`.
378 Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
379 Value operandExp = b.create<arith::AndIOp>(
380 b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
381 Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
382 Value isSpecialValOrLargeVal =
383 b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
384
385 Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
386 fpFixedConvert);
387 rewriter.replaceOp(op, result);
388 return success();
389}
390
391// Converts math.ctlz to scf and arith operations. This is done
392// by performing a binary search on the bits.
393static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
394 PatternRewriter &rewriter) {
395 auto operand = op.getOperand();
396 auto operandTy = operand.getType();
397 auto eTy = getElementTypeOrSelf(operandTy);
398 Location loc = op.getLoc();
399
400 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
401 if (bitwidth > 64)
402 return failure();
403
404 uint64_t allbits = -1;
405 if (bitwidth < 64) {
406 allbits = allbits >> (64 - bitwidth);
407 }
408
409 Value x = operand;
410 Value count = createIntConst(loc, operandTy, 0, rewriter);
411 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
412 auto half = bw / 2;
413 auto bits = createIntConst(loc, operandTy, half, rewriter);
414 auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
415
416 Value pred =
417 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
418 Value add = rewriter.create<arith::AddIOp>(loc, count, bits);
419 Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits);
420
421 x = rewriter.create<arith::SelectOp>(loc, pred, shift, x);
422 count = rewriter.create<arith::SelectOp>(loc, pred, add, count);
423 }
424
425 Value zero = createIntConst(loc, operandTy, 0, rewriter);
426 Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
427 operand, zero);
428
429 Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
430 Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count);
431 rewriter.replaceOp(op, sel);
432 return success();
433}
434
435// Convert `math.roundeven` into `math.round` + arith ops
436static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
437 PatternRewriter &rewriter) {
438 Location loc = op.getLoc();
439 ImplicitLocOpBuilder b(loc, rewriter);
440 auto operand = op.getOperand();
441 Type operandTy = operand.getType();
442 Type resultTy = op.getType();
443 Type operandETy = getElementTypeOrSelf(type: operandTy);
444 Type resultETy = getElementTypeOrSelf(type: resultTy);
445
446 if (!isa<FloatType>(Val: operandETy) || !isa<FloatType>(Val: resultETy)) {
447 return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
448 }
449
450 Type fTy = operandTy;
451 Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
452 if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
453 iTy = shapedTy.clone(iTy);
454 }
455
456 unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
457 // The width returned by getFPMantissaWidth includes the integer bit.
458 unsigned mantissaWidth =
459 llvm::cast<FloatType>(Val&: operandETy).getFPMantissaWidth() - 1;
460 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
461
462 // The names of the variables correspond to f32.
463 // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
464 // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
465 // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
466 Value c1Float = createFloatConst(loc, type: fTy, value: 1.0, b);
467 Value c0 = createIntConst(loc, type: iTy, value: 0, b);
468 Value c1 = createIntConst(loc, type: iTy, value: 1, b);
469 Value cNeg1 = createIntConst(loc, type: iTy, value: -1, b);
470 Value c23 = createIntConst(loc, type: iTy, value: mantissaWidth, b);
471 Value c31 = createIntConst(loc, type: iTy, value: bitWidth - 1, b);
472 Value c127 = createIntConst(loc, type: iTy, value: (1ull << (exponentWidth - 1)) - 1, b);
473 Value c2To22 = createIntConst(loc, type: iTy, value: 1ull << (mantissaWidth - 1), b);
474 Value c23Mask = createIntConst(loc, type: iTy, value: (1ull << mantissaWidth) - 1, b);
475 Value expMask = createIntConst(loc, type: iTy, value: (1ull << exponentWidth) - 1, b);
476
477 Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand);
478 Value round = b.create<math::RoundOp>(operand);
479 Value roundBitcast = b.create<arith::BitcastOp>(iTy, round);
480
481 // Get biased exponents for operand and round(operand)
482 Value operandExp = b.create<arith::AndIOp>(
483 b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
484 Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
485 Value roundExp = b.create<arith::AndIOp>(
486 b.create<arith::ShRUIOp>(roundBitcast, c23), expMask);
487 Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
488
489 auto safeShiftRight = [&](Value x, Value shift) -> Value {
490 // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
491 Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
492 clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
493 return b.create<arith::ShRUIOp>(x, clampedShift);
494 };
495
496 auto maskMantissa = [&](Value mantissa,
497 Value mantissaMaskRightShift) -> Value {
498 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
499 return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask);
500 };
501
502 // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
503 // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
504 // with `biasedExp > 23` (numbers where there is not enough precision to store
505 // decimals) are always even, and they satisfy the even condition trivially
506 // since the mantissa without all its bits is zero. The even condition
507 // is also true for +-0, since they have `biasedExp = -127` and the entire
508 // mantissa is zero. The case of +-1 has to be handled separately. Here
509 // we identify these values by noting that +-1 are the only whole numbers with
510 // `biasedExp == 0`.
511 //
512 // The special values +-inf and +-nan also satisfy the same property that
513 // whole non-unit even numbers satisfy. In particular, the special values have
514 // `biasedExp > 23`, so they get treated as large numbers with no room for
515 // decimals, which are always even.
516 Value roundBiasedExpEq0 =
517 b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
518 Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1);
519 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
520 Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>(
521 arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
522 roundIsNotEvenOrSpecialVal =
523 b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
524
525 // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
526 // integers if the bit at index `biasedExp` starting from the left in the
527 // mantissa is 1 and all the bits to the right are zero. Values with
528 // `biasedExp >= 23` don't have decimals, so they are never halfway. The
529 // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
530 // so these are handled separately. In particular, if `biasedExp == -1`, the
531 // value is halfway if the entire mantissa is zero.
532 Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>(
533 arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
534 Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>(
535 operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
536 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
537 Value operandIsHalfway =
538 b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
539 expectedOperandMaskedMantissa);
540 // Ensure `biasedExp` is in the valid range for half values.
541 Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>(
542 arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
543 Value operandBiasedExpLt23 =
544 b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
545 operandIsHalfway =
546 b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
547 operandIsHalfway =
548 b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
549
550 // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
551 // case where `round` rounded in the opposite direction of `roundeven`.
552 Value sign = b.create<math::CopySignOp>(c1Float, operand);
553 Value roundShifted = b.create<arith::SubFOp>(round, sign);
554 // If the rounded value is even or a special value, we default to the behavior
555 // of `math.round`.
556 Value needsShift =
557 b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
558 Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round);
559 // The `x - sign` adjustment does not preserve the sign when we are adjusting
560 // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
561 // rounded to -0.0.
562 result = b.create<math::CopySignOp>(result, operand);
563 rewriter.replaceOp(op, result);
564 return success();
565}
566
567void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
568 patterns.add(convertCtlzOp);
569}
570
571void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
572 patterns.add(convertSinhOp);
573}
574
575void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
576 patterns.add(convertCoshOp);
577}
578
579void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
580 patterns.add(convertTanOp);
581}
582
583void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
584 patterns.add(convertTanhOp);
585}
586
587void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
588 patterns.add(convertFmaFOp);
589}
590
591void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
592 patterns.add(convertCeilOp);
593}
594
595void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
596 patterns.add(convertExp2fOp);
597}
598
599void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
600 patterns.add(convertPowfOp);
601}
602
603void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
604 patterns.add(convertFPowIOp);
605}
606
607void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
608 patterns.add(convertRoundOp);
609}
610
611void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
612 patterns.add(convertFloorOp);
613}
614
615void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
616 patterns.add(convertRoundEvenOp);
617}
618

source code of mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp