1//===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements expansion of various math operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Math/IR/Math.h"
15#include "mlir/Dialect/Math/Transforms/Passes.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/Vector/IR/VectorOps.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/ImplicitLocOpBuilder.h"
20#include "mlir/IR/TypeUtilities.h"
21#include "mlir/Transforms/DialectConversion.h"
22
23using namespace mlir;
24
25/// Create a float constant.
26static Value createFloatConst(Location loc, Type type, APFloat value,
27 OpBuilder &b) {
28 bool losesInfo = false;
29 auto eltType = getElementTypeOrSelf(type);
30 // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
31 value.convert(ToSemantics: cast<FloatType>(Val&: eltType).getFloatSemantics(),
32 RM: APFloat::rmNearestTiesToEven, losesInfo: &losesInfo);
33 auto attr = b.getFloatAttr(type: eltType, value);
34 if (auto shapedTy = dyn_cast<ShapedType>(Val&: type)) {
35 return b.create<arith::ConstantOp>(location: loc,
36 args: DenseElementsAttr::get(type: shapedTy, values: attr));
37 }
38
39 return b.create<arith::ConstantOp>(location: loc, args&: attr);
40}
41
42static Value createFloatConst(Location loc, Type type, double value,
43 OpBuilder &b) {
44 return createFloatConst(loc, type, value: APFloat(value), b);
45}
46
47/// Create an integer constant.
48static Value createIntConst(Location loc, Type type, int64_t value,
49 OpBuilder &b) {
50 auto attr = b.getIntegerAttr(type: getElementTypeOrSelf(type), value);
51 if (auto shapedTy = dyn_cast<ShapedType>(Val&: type)) {
52 return b.create<arith::ConstantOp>(location: loc,
53 args: DenseElementsAttr::get(type: shapedTy, values: attr));
54 }
55
56 return b.create<arith::ConstantOp>(location: loc, args&: attr);
57}
58
59static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
60 Type opType = operand.getType();
61 Type i64Ty = b.getI64Type();
62 if (auto shapedTy = dyn_cast<ShapedType>(Val&: opType))
63 i64Ty = shapedTy.clone(elementType: i64Ty);
64 Value fixedConvert = b.create<arith::FPToSIOp>(args&: i64Ty, args&: operand);
65 Value fpFixedConvert = b.create<arith::SIToFPOp>(args&: opType, args&: fixedConvert);
66 // The truncation does not preserve the sign when the truncated
67 // value is -0. So here the sign is copied again.
68 return b.create<math::CopySignOp>(args&: fpFixedConvert, args&: operand);
69}
70
71// sinhf(float x) -> (exp(x) - exp(-x)) / 2
72static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
73 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
74 Value operand = op.getOperand();
75 Type opType = operand.getType();
76
77 Value exp = b.create<math::ExpOp>(args&: operand);
78 Value neg = b.create<arith::NegFOp>(args&: operand);
79 Value nexp = b.create<math::ExpOp>(args&: neg);
80 Value sub = b.create<arith::SubFOp>(args&: exp, args&: nexp);
81 Value half = createFloatConst(loc: op->getLoc(), type: opType, value: 0.5, b&: rewriter);
82 Value res = b.create<arith::MulFOp>(args&: sub, args&: half);
83 rewriter.replaceOp(op, newValues: res);
84 return success();
85}
86
87// coshf(float x) -> (exp(x) + exp(-x)) / 2
88static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
89 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
90 Value operand = op.getOperand();
91 Type opType = operand.getType();
92
93 Value exp = b.create<math::ExpOp>(args&: operand);
94 Value neg = b.create<arith::NegFOp>(args&: operand);
95 Value nexp = b.create<math::ExpOp>(args&: neg);
96 Value add = b.create<arith::AddFOp>(args&: exp, args&: nexp);
97 Value half = createFloatConst(loc: op->getLoc(), type: opType, value: 0.5, b&: rewriter);
98 Value res = b.create<arith::MulFOp>(args&: add, args&: half);
99 rewriter.replaceOp(op, newValues: res);
100 return success();
101}
102
103/// Expands tanh op into
104/// 1-exp^{-2x} / 1+exp^{-2x}
105/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
106/// We compute a "signs" value which is -1 if input is negative and +1 if input
107/// is positive. Then multiply the input by this value, guaranteeing that the
108/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
109/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
110/// result by `sign(x)` to retain sign of the real result.
111static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
112 auto floatType = op.getOperand().getType();
113 Location loc = op.getLoc();
114 Value zero = createFloatConst(loc, type: floatType, value: 0.0, b&: rewriter);
115 Value one = createFloatConst(loc, type: floatType, value: 1.0, b&: rewriter);
116 Value negTwo = createFloatConst(loc, type: floatType, value: -2.0, b&: rewriter);
117
118 // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
119 Value isNegative = rewriter.create<arith::CmpFOp>(
120 location: loc, args: arith::CmpFPredicate::OLT, args: op.getOperand(), args&: zero);
121 Value isNegativeFloat =
122 rewriter.create<arith::UIToFPOp>(location: loc, args&: floatType, args&: isNegative);
123 Value isNegativeTimesNegTwo =
124 rewriter.create<arith::MulFOp>(location: loc, args&: isNegativeFloat, args&: negTwo);
125 Value sign = rewriter.create<arith::AddFOp>(location: loc, args&: isNegativeTimesNegTwo, args&: one);
126
127 // Normalize input to positive value: y = sign(x) * x
128 Value positiveX = rewriter.create<arith::MulFOp>(location: loc, args&: sign, args: op.getOperand());
129
130 // Decompose on normalized input
131 Value negDoubledX = rewriter.create<arith::MulFOp>(location: loc, args&: negTwo, args&: positiveX);
132 Value exp2x = rewriter.create<math::ExpOp>(location: loc, args&: negDoubledX);
133 Value dividend = rewriter.create<arith::SubFOp>(location: loc, args&: one, args&: exp2x);
134 Value divisor = rewriter.create<arith::AddFOp>(location: loc, args&: one, args&: exp2x);
135 Value positiveRes = rewriter.create<arith::DivFOp>(location: loc, args&: dividend, args&: divisor);
136
137 // Multiply result by sign(x) to retain signs from negative inputs
138 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, args&: sign, args&: positiveRes);
139
140 return success();
141}
142
143// Converts math.tan to math.sin, math.cos, and arith.divf.
144static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
145 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
146 Value operand = op.getOperand();
147 Type type = operand.getType();
148 Value sin = b.create<math::SinOp>(args&: type, args&: operand);
149 Value cos = b.create<math::CosOp>(args&: type, args&: operand);
150 Value div = b.create<arith::DivFOp>(args&: type, args&: sin, args&: cos);
151 rewriter.replaceOp(op, newValues: div);
152 return success();
153}
154
155// asinh(float x) -> log(x + sqrt(x**2 + 1))
156static LogicalResult convertAsinhOp(math::AsinhOp op,
157 PatternRewriter &rewriter) {
158 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
159 Value operand = op.getOperand();
160 Type opType = operand.getType();
161
162 Value one = createFloatConst(loc: op->getLoc(), type: opType, value: 1.0, b&: rewriter);
163 Value fma = b.create<math::FmaOp>(args&: operand, args&: operand, args&: one);
164 Value sqrt = b.create<math::SqrtOp>(args&: fma);
165 Value add = b.create<arith::AddFOp>(args&: operand, args&: sqrt);
166 Value res = b.create<math::LogOp>(args&: add);
167 rewriter.replaceOp(op, newValues: res);
168 return success();
169}
170
171// acosh(float x) -> log(x + sqrt(x**2 - 1))
172static LogicalResult convertAcoshOp(math::AcoshOp op,
173 PatternRewriter &rewriter) {
174 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
175 Value operand = op.getOperand();
176 Type opType = operand.getType();
177
178 Value negOne = createFloatConst(loc: op->getLoc(), type: opType, value: -1.0, b&: rewriter);
179 Value fma = b.create<math::FmaOp>(args&: operand, args&: operand, args&: negOne);
180 Value sqrt = b.create<math::SqrtOp>(args&: fma);
181 Value add = b.create<arith::AddFOp>(args&: operand, args&: sqrt);
182 Value res = b.create<math::LogOp>(args&: add);
183 rewriter.replaceOp(op, newValues: res);
184 return success();
185}
186
187// atanh(float x) -> log((1 + x) / (1 - x)) / 2
188static LogicalResult convertAtanhOp(math::AtanhOp op,
189 PatternRewriter &rewriter) {
190 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
191 Value operand = op.getOperand();
192 Type opType = operand.getType();
193
194 Value one = createFloatConst(loc: op->getLoc(), type: opType, value: 1.0, b&: rewriter);
195 Value add = b.create<arith::AddFOp>(args&: operand, args&: one);
196 Value neg = b.create<arith::NegFOp>(args&: operand);
197 Value sub = b.create<arith::AddFOp>(args&: neg, args&: one);
198 Value div = b.create<arith::DivFOp>(args&: add, args&: sub);
199 Value log = b.create<math::LogOp>(args&: div);
200 Value half = createFloatConst(loc: op->getLoc(), type: opType, value: 0.5, b&: rewriter);
201 Value res = b.create<arith::MulFOp>(args&: log, args&: half);
202 rewriter.replaceOp(op, newValues: res);
203 return success();
204}
205
206static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
207 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
208 Value operandA = op.getOperand(i: 0);
209 Value operandB = op.getOperand(i: 1);
210 Value operandC = op.getOperand(i: 2);
211 Type type = op.getType();
212 Value mult = b.create<arith::MulFOp>(args&: type, args&: operandA, args&: operandB);
213 Value add = b.create<arith::AddFOp>(args&: type, args&: mult, args&: operandC);
214 rewriter.replaceOp(op, newValues: add);
215 return success();
216}
217
218// Converts a ceilf() function to the following:
219// ceilf(float x) ->
220// y = (float)(int) x
221// if (x > y) then incr = 1 else incr = 0
222// y = y + incr <= replace this op with the ceilf op.
223static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
224 // Creating constants assumes the static shaped type.
225 auto shapedType = dyn_cast<ShapedType>(Val: op.getType());
226 if (shapedType && !shapedType.hasStaticShape())
227 return failure();
228
229 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
230 Value operand = op.getOperand();
231 Type opType = operand.getType();
232 Value fpFixedConvert = createTruncatedFPValue(operand, b);
233
234 // Creating constants for later use.
235 Value zero = createFloatConst(loc: op->getLoc(), type: opType, value: 0.00, b&: rewriter);
236 Value one = createFloatConst(loc: op->getLoc(), type: opType, value: 1.00, b&: rewriter);
237
238 Value gtCheck = b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGT, args&: operand,
239 args&: fpFixedConvert);
240 Value incrValue = b.create<arith::SelectOp>(location: op->getLoc(), args&: gtCheck, args&: one, args&: zero);
241
242 Value ret = b.create<arith::AddFOp>(args&: opType, args&: fpFixedConvert, args&: incrValue);
243 rewriter.replaceOp(op, newValues: ret);
244 return success();
245}
246
247// Convert `math.fpowi` to a series of `arith.mulf` operations.
248// If the power is negative, we divide one by the result.
249// If both the base and power are zero, the result is 1.
250// In the case of non constant power, we convert the operation to `math.powf`.
251static LogicalResult convertFPowIOp(math::FPowIOp op,
252 PatternRewriter &rewriter) {
253 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
254 Value base = op.getOperand(i: 0);
255 Value power = op.getOperand(i: 1);
256 Type baseType = base.getType();
257
258 auto convertFPowItoPowf = [&]() -> LogicalResult {
259 Value castPowerToFp =
260 rewriter.create<arith::SIToFPOp>(location: op.getLoc(), args&: baseType, args&: power);
261 Value res = rewriter.create<math::PowFOp>(location: op.getLoc(), args&: baseType, args&: base,
262 args&: castPowerToFp);
263 rewriter.replaceOp(op, newValues: res);
264 return success();
265 };
266
267 Attribute cstAttr;
268 if (!matchPattern(value: power, pattern: m_Constant(bind_value: &cstAttr)))
269 return convertFPowItoPowf();
270
271 APInt value;
272 if (!matchPattern(attr: cstAttr, pattern: m_ConstantInt(bind_value: &value)))
273 return convertFPowItoPowf();
274
275 int64_t powerInt = value.getSExtValue();
276 bool isNegative = powerInt < 0;
277 int64_t absPower = std::abs(i: powerInt);
278 Value one = createFloatConst(loc: op->getLoc(), type: baseType, value: 1.00, b&: rewriter);
279 Value res = createFloatConst(loc: op->getLoc(), type: baseType, value: 1.00, b&: rewriter);
280
281 while (absPower > 0) {
282 if (absPower & 1)
283 res = b.create<arith::MulFOp>(args&: baseType, args&: base, args&: res);
284 absPower >>= 1;
285 base = b.create<arith::MulFOp>(args&: baseType, args&: base, args&: base);
286 }
287
288 // Make sure not to introduce UB in case of negative power.
289 if (isNegative) {
290 auto &sem = dyn_cast<mlir::FloatType>(Val: getElementTypeOrSelf(type: baseType))
291 .getFloatSemantics();
292 Value zero =
293 createFloatConst(loc: op->getLoc(), type: baseType,
294 value: APFloat::getZero(Sem: sem, /*Negative=*/false), b&: rewriter);
295 Value negZero =
296 createFloatConst(loc: op->getLoc(), type: baseType,
297 value: APFloat::getZero(Sem: sem, /*Negative=*/true), b&: rewriter);
298 Value posInfinity =
299 createFloatConst(loc: op->getLoc(), type: baseType,
300 value: APFloat::getInf(Sem: sem, /*Negative=*/false), b&: rewriter);
301 Value negInfinity =
302 createFloatConst(loc: op->getLoc(), type: baseType,
303 value: APFloat::getInf(Sem: sem, /*Negative=*/true), b&: rewriter);
304 Value zeroEqCheck =
305 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: res, args&: zero);
306 Value negZeroEqCheck =
307 b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: res, args&: negZero);
308 res = b.create<arith::DivFOp>(args&: baseType, args&: one, args&: res);
309 res =
310 b.create<arith::SelectOp>(location: op->getLoc(), args&: zeroEqCheck, args&: posInfinity, args&: res);
311 res = b.create<arith::SelectOp>(location: op->getLoc(), args&: negZeroEqCheck, args&: negInfinity,
312 args&: res);
313 }
314
315 rewriter.replaceOp(op, newValues: res);
316 return success();
317}
318
319// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
320// Some special cases where b is constant are handled separately:
321// when b == 0, or |b| == 0.5, 1.0, or 2.0.
322static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
323 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
324 Value operandA = op.getOperand(i: 0);
325 Value operandB = op.getOperand(i: 1);
326 auto typeA = operandA.getType();
327 auto typeB = operandB.getType();
328
329 auto &sem =
330 cast<mlir::FloatType>(Val: getElementTypeOrSelf(type: typeB)).getFloatSemantics();
331 APFloat valueB(sem);
332 auto mulf = [&](Value x, Value y) -> Value {
333 return b.create<arith::MulFOp>(args&: x, args&: y);
334 };
335 if (matchPattern(value: operandB, pattern: m_ConstantFloat(bind_value: &valueB))) {
336 if (valueB.isZero()) {
337 // a^0 -> 1
338 Value one = createFloatConst(loc: op->getLoc(), type: typeA, value: 1.0, b&: rewriter);
339 rewriter.replaceOp(op, newValues: one);
340 return success();
341 }
342 if (valueB.isExactlyValue(V: 1.0)) {
343 // a^1 -> a
344 rewriter.replaceOp(op, newValues: operandA);
345 return success();
346 }
347 if (valueB.isExactlyValue(V: -1.0)) {
348 // a^(-1) -> 1 / a
349 Value one = createFloatConst(loc: op->getLoc(), type: typeA, value: 1.0, b&: rewriter);
350 Value div = b.create<arith::DivFOp>(args&: one, args&: operandA);
351 rewriter.replaceOp(op, newValues: div);
352 return success();
353 }
354 if (valueB.isExactlyValue(V: 0.5)) {
355 // a^(1/2) -> sqrt(a)
356 Value sqrt = b.create<math::SqrtOp>(args&: operandA);
357 rewriter.replaceOp(op, newValues: sqrt);
358 return success();
359 }
360 if (valueB.isExactlyValue(V: -0.5)) {
361 // a^(-1/2) -> 1 / sqrt(a)
362 Value rsqrt = b.create<math::RsqrtOp>(args&: operandA);
363 rewriter.replaceOp(op, newValues: rsqrt);
364 return success();
365 }
366 if (valueB.isExactlyValue(V: 2.0)) {
367 // a^2 -> a * a
368 rewriter.replaceOp(op, newValues: mulf(operandA, operandA));
369 return success();
370 }
371 if (valueB.isExactlyValue(V: -2.0)) {
372 // a^(-2) -> 1 / (a * a)
373 Value one =
374 createFloatConst(loc: op->getLoc(), type: operandA.getType(), value: 1.0, b&: rewriter);
375 Value div = b.create<arith::DivFOp>(args&: one, args: mulf(operandA, operandA));
376 rewriter.replaceOp(op, newValues: div);
377 return success();
378 }
379 if (valueB.isExactlyValue(V: 3.0)) {
380 rewriter.replaceOp(op, newValues: mulf(mulf(operandA, operandA), operandA));
381 return success();
382 }
383 }
384
385 Value logA = b.create<math::LogOp>(args&: operandA);
386 Value mult = b.create<arith::MulFOp>(args&: operandB, args&: logA);
387 Value expResult = b.create<math::ExpOp>(args&: mult);
388 rewriter.replaceOp(op, newValues: expResult);
389 return success();
390}
391
392// exp2f(float x) -> exp(x * ln(2))
393// Proof: Let's say 2^x = y
394// ln(2^x) = ln(y)
395// x * ln(2) = ln(y) => e ^(x*ln(2)) = y
396static LogicalResult convertExp2fOp(math::Exp2Op op,
397 PatternRewriter &rewriter) {
398 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
399 Value operand = op.getOperand();
400 Type opType = operand.getType();
401 Value ln2 = createFloatConst(loc: op->getLoc(), type: opType, value: llvm::numbers::ln2, b);
402 Value mult = b.create<arith::MulFOp>(args&: opType, args&: operand, args&: ln2);
403 Value exp = b.create<math::ExpOp>(location: op->getLoc(), args&: mult);
404 rewriter.replaceOp(op, newValues: exp);
405 return success();
406}
407
408static LogicalResult convertRoundOp(math::RoundOp op,
409 PatternRewriter &rewriter) {
410 Location loc = op.getLoc();
411 ImplicitLocOpBuilder b(loc, rewriter);
412 Value operand = op.getOperand();
413 Type opType = operand.getType();
414 Type opEType = getElementTypeOrSelf(type: opType);
415
416 if (!opEType.isF32()) {
417 return rewriter.notifyMatchFailure(arg&: op, msg: "not a round of f32.");
418 }
419
420 Type i32Ty = b.getI32Type();
421 if (auto shapedTy = dyn_cast<ShapedType>(Val&: opType))
422 i32Ty = shapedTy.clone(elementType: i32Ty);
423
424 Value half = createFloatConst(loc, type: opType, value: 0.5, b);
425 Value c23 = createIntConst(loc, type: i32Ty, value: 23, b);
426 Value c127 = createIntConst(loc, type: i32Ty, value: 127, b);
427 Value expMask = createIntConst(loc, type: i32Ty, value: (1 << 8) - 1, b);
428
429 Value incrValue = b.create<math::CopySignOp>(args&: half, args&: operand);
430 Value add = b.create<arith::AddFOp>(args&: opType, args&: operand, args&: incrValue);
431 Value fpFixedConvert = createTruncatedFPValue(operand: add, b);
432
433 // There are three cases where adding 0.5 to the value and truncating by
434 // converting to an i64 does not result in the correct behavior:
435 //
436 // 1. Special values: +-inf and +-nan
437 // Casting these special values to i64 has undefined behavior. To identify
438 // these values, we use the fact that these values are the only float
439 // values with the maximum possible biased exponent.
440 //
441 // 2. Large values: 2^23 <= |x| <= INT_64_MAX
442 // Adding 0.5 to a float larger than or equal to 2^23 results in precision
443 // errors that sometimes round the value up and sometimes round the value
444 // down. For example:
445 // 8388608.0 + 0.5 = 8388608.0
446 // 8388609.0 + 0.5 = 8388610.0
447 //
448 // 3. Very large values: |x| > INT_64_MAX
449 // Casting to i64 a value greater than the max i64 value will overflow the
450 // i64 leading to wrong outputs.
451 //
452 // All three cases satisfy the property `biasedExp >= 23`.
453 Value operandBitcast = b.create<arith::BitcastOp>(args&: i32Ty, args&: operand);
454 Value operandExp = b.create<arith::AndIOp>(
455 args: b.create<arith::ShRUIOp>(args&: operandBitcast, args&: c23), args&: expMask);
456 Value operandBiasedExp = b.create<arith::SubIOp>(args&: operandExp, args&: c127);
457 Value isSpecialValOrLargeVal =
458 b.create<arith::CmpIOp>(args: arith::CmpIPredicate::sge, args&: operandBiasedExp, args&: c23);
459
460 Value result = b.create<arith::SelectOp>(args&: isSpecialValOrLargeVal, args&: operand,
461 args&: fpFixedConvert);
462 rewriter.replaceOp(op, newValues: result);
463 return success();
464}
465
466// Converts math.ctlz to scf and arith operations. This is done
467// by performing a binary search on the bits.
468static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
469 PatternRewriter &rewriter) {
470 auto operand = op.getOperand();
471 auto operandTy = operand.getType();
472 auto eTy = getElementTypeOrSelf(type: operandTy);
473 Location loc = op.getLoc();
474
475 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
476 if (bitwidth > 64)
477 return failure();
478
479 uint64_t allbits = -1;
480 if (bitwidth < 64) {
481 allbits = allbits >> (64 - bitwidth);
482 }
483
484 Value x = operand;
485 Value count = createIntConst(loc, type: operandTy, value: 0, b&: rewriter);
486 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
487 auto half = bw / 2;
488 auto bits = createIntConst(loc, type: operandTy, value: half, b&: rewriter);
489 auto mask = createIntConst(loc, type: operandTy, value: allbits >> half, b&: rewriter);
490
491 Value pred =
492 rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::ule, args&: x, args&: mask);
493 Value add = rewriter.create<arith::AddIOp>(location: loc, args&: count, args&: bits);
494 Value shift = rewriter.create<arith::ShLIOp>(location: loc, args&: x, args&: bits);
495
496 x = rewriter.create<arith::SelectOp>(location: loc, args&: pred, args&: shift, args&: x);
497 count = rewriter.create<arith::SelectOp>(location: loc, args&: pred, args&: add, args&: count);
498 }
499
500 Value zero = createIntConst(loc, type: operandTy, value: 0, b&: rewriter);
501 Value pred = rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::eq,
502 args&: operand, args&: zero);
503
504 Value bwval = createIntConst(loc, type: operandTy, value: bitwidth, b&: rewriter);
505 Value sel = rewriter.create<arith::SelectOp>(location: loc, args&: pred, args&: bwval, args&: count);
506 rewriter.replaceOp(op, newValues: sel);
507 return success();
508}
509
510// Convert `math.roundeven` into `math.round` + arith ops
511static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
512 PatternRewriter &rewriter) {
513 Location loc = op.getLoc();
514 ImplicitLocOpBuilder b(loc, rewriter);
515 auto operand = op.getOperand();
516 Type operandTy = operand.getType();
517 Type resultTy = op.getType();
518 Type operandETy = getElementTypeOrSelf(type: operandTy);
519 Type resultETy = getElementTypeOrSelf(type: resultTy);
520
521 if (!isa<FloatType>(Val: operandETy) || !isa<FloatType>(Val: resultETy)) {
522 return rewriter.notifyMatchFailure(arg&: op, msg: "not a roundeven of f16 or f32.");
523 }
524
525 Type fTy = operandTy;
526 Type iTy = rewriter.getIntegerType(width: operandETy.getIntOrFloatBitWidth());
527 if (auto shapedTy = dyn_cast<ShapedType>(Val&: fTy)) {
528 iTy = shapedTy.clone(elementType: iTy);
529 }
530
531 unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
532 // The width returned by getFPMantissaWidth includes the integer bit.
533 unsigned mantissaWidth =
534 llvm::cast<FloatType>(Val&: operandETy).getFPMantissaWidth() - 1;
535 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
536
537 // The names of the variables correspond to f32.
538 // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
539 // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
540 // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
541 Value c1Float = createFloatConst(loc, type: fTy, value: 1.0, b);
542 Value c0 = createIntConst(loc, type: iTy, value: 0, b);
543 Value c1 = createIntConst(loc, type: iTy, value: 1, b);
544 Value cNeg1 = createIntConst(loc, type: iTy, value: -1, b);
545 Value c23 = createIntConst(loc, type: iTy, value: mantissaWidth, b);
546 Value c31 = createIntConst(loc, type: iTy, value: bitWidth - 1, b);
547 Value c127 = createIntConst(loc, type: iTy, value: (1ull << (exponentWidth - 1)) - 1, b);
548 Value c2To22 = createIntConst(loc, type: iTy, value: 1ull << (mantissaWidth - 1), b);
549 Value c23Mask = createIntConst(loc, type: iTy, value: (1ull << mantissaWidth) - 1, b);
550 Value expMask = createIntConst(loc, type: iTy, value: (1ull << exponentWidth) - 1, b);
551
552 Value operandBitcast = b.create<arith::BitcastOp>(args&: iTy, args&: operand);
553 Value round = b.create<math::RoundOp>(args&: operand);
554 Value roundBitcast = b.create<arith::BitcastOp>(args&: iTy, args&: round);
555
556 // Get biased exponents for operand and round(operand)
557 Value operandExp = b.create<arith::AndIOp>(
558 args: b.create<arith::ShRUIOp>(args&: operandBitcast, args&: c23), args&: expMask);
559 Value operandBiasedExp = b.create<arith::SubIOp>(args&: operandExp, args&: c127);
560 Value roundExp = b.create<arith::AndIOp>(
561 args: b.create<arith::ShRUIOp>(args&: roundBitcast, args&: c23), args&: expMask);
562 Value roundBiasedExp = b.create<arith::SubIOp>(args&: roundExp, args&: c127);
563
564 auto safeShiftRight = [&](Value x, Value shift) -> Value {
565 // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
566 Value clampedShift = b.create<arith::MaxSIOp>(args&: shift, args&: c0);
567 clampedShift = b.create<arith::MinSIOp>(args&: clampedShift, args&: c31);
568 return b.create<arith::ShRUIOp>(args&: x, args&: clampedShift);
569 };
570
571 auto maskMantissa = [&](Value mantissa,
572 Value mantissaMaskRightShift) -> Value {
573 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
574 return b.create<arith::AndIOp>(args&: mantissa, args&: shiftedMantissaMask);
575 };
576
577 // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
578 // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
579 // with `biasedExp > 23` (numbers where there is not enough precision to store
580 // decimals) are always even, and they satisfy the even condition trivially
581 // since the mantissa without all its bits is zero. The even condition
582 // is also true for +-0, since they have `biasedExp = -127` and the entire
583 // mantissa is zero. The case of +-1 has to be handled separately. Here
584 // we identify these values by noting that +-1 are the only whole numbers with
585 // `biasedExp == 0`.
586 //
587 // The special values +-inf and +-nan also satisfy the same property that
588 // whole non-unit even numbers satisfy. In particular, the special values have
589 // `biasedExp > 23`, so they get treated as large numbers with no room for
590 // decimals, which are always even.
591 Value roundBiasedExpEq0 =
592 b.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: roundBiasedExp, args&: c0);
593 Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(args&: roundBiasedExp, args&: c1);
594 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
595 Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>(
596 args: arith::CmpIPredicate::ne, args&: roundMaskedMantissa, args&: c0);
597 roundIsNotEvenOrSpecialVal =
598 b.create<arith::OrIOp>(args&: roundIsNotEvenOrSpecialVal, args&: roundBiasedExpEq0);
599
600 // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
601 // integers if the bit at index `biasedExp` starting from the left in the
602 // mantissa is 1 and all the bits to the right are zero. Values with
603 // `biasedExp >= 23` don't have decimals, so they are never halfway. The
604 // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
605 // so these are handled separately. In particular, if `biasedExp == -1`, the
606 // value is halfway if the entire mantissa is zero.
607 Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>(
608 args: arith::CmpIPredicate::eq, args&: operandBiasedExp, args&: cNeg1);
609 Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>(
610 args&: operandBiasedExpEqNeg1, args&: c0, args: safeShiftRight(c2To22, operandBiasedExp));
611 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
612 Value operandIsHalfway =
613 b.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: operandMaskedMantissa,
614 args&: expectedOperandMaskedMantissa);
615 // Ensure `biasedExp` is in the valid range for half values.
616 Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>(
617 args: arith::CmpIPredicate::sge, args&: operandBiasedExp, args&: cNeg1);
618 Value operandBiasedExpLt23 =
619 b.create<arith::CmpIOp>(args: arith::CmpIPredicate::slt, args&: operandBiasedExp, args&: c23);
620 operandIsHalfway =
621 b.create<arith::AndIOp>(args&: operandIsHalfway, args&: operandBiasedExpLt23);
622 operandIsHalfway =
623 b.create<arith::AndIOp>(args&: operandIsHalfway, args&: operandBiasedExpGeNeg1);
624
625 // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
626 // case where `round` rounded in the opposite direction of `roundeven`.
627 Value sign = b.create<math::CopySignOp>(args&: c1Float, args&: operand);
628 Value roundShifted = b.create<arith::SubFOp>(args&: round, args&: sign);
629 // If the rounded value is even or a special value, we default to the behavior
630 // of `math.round`.
631 Value needsShift =
632 b.create<arith::AndIOp>(args&: roundIsNotEvenOrSpecialVal, args&: operandIsHalfway);
633 Value result = b.create<arith::SelectOp>(args&: needsShift, args&: roundShifted, args&: round);
634 // The `x - sign` adjustment does not preserve the sign when we are adjusting
635 // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
636 // rounded to -0.0.
637 result = b.create<math::CopySignOp>(args&: result, args&: operand);
638 rewriter.replaceOp(op, newValues: result);
639 return success();
640}
641
642// Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
643static LogicalResult convertRsqrtOp(math::RsqrtOp op,
644 PatternRewriter &rewriter) {
645
646 auto operand = op.getOperand();
647 auto operandTy = operand.getType();
648 // Operand type must be shatic shaped type to create const float.
649 auto shapedOperandType = dyn_cast<ShapedType>(Val&: operandTy);
650 if (shapedOperandType && !shapedOperandType.hasStaticShape())
651 return failure();
652
653 auto eTy = getElementTypeOrSelf(type: operandTy);
654 if (!isa<FloatType>(Val: eTy))
655 return failure();
656
657 Location loc = op->getLoc();
658 auto constOneFloat = createFloatConst(loc, type: operandTy, value: 1.0, b&: rewriter);
659 auto sqrtOp = rewriter.create<math::SqrtOp>(location: loc, args&: operand);
660 rewriter.replaceOpWithNewOp<arith::DivFOp>(op, args&: constOneFloat, args&: sqrtOp);
661 return success();
662}
663
664void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
665 patterns.add(implFn: convertCtlzOp);
666}
667
668void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
669 patterns.add(implFn: convertSinhOp);
670}
671
672void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
673 patterns.add(implFn: convertCoshOp);
674}
675
676void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
677 patterns.add(implFn: convertTanOp);
678}
679
680void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
681 patterns.add(implFn: convertTanhOp);
682}
683
684void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
685 patterns.add(implFn: convertAsinhOp);
686}
687
688void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
689 patterns.add(implFn: convertAcoshOp);
690}
691
692void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
693 patterns.add(implFn: convertAtanhOp);
694}
695
696void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
697 patterns.add(implFn: convertFmaFOp);
698}
699
700void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
701 patterns.add(implFn: convertCeilOp);
702}
703
704void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
705 patterns.add(implFn: convertExp2fOp);
706}
707
708void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
709 patterns.add(implFn: convertPowfOp);
710}
711
712void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
713 patterns.add(implFn: convertFPowIOp);
714}
715
716void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
717 patterns.add(implFn: convertRoundOp);
718}
719
720void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
721 patterns.add(implFn: convertRoundEvenOp);
722}
723
724void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
725 patterns.add(implFn: convertRsqrtOp);
726}
727

source code of mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp