1 | //===- ExpandPatterns.cpp - Code to expand various math operations. -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements expansion of various math operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | #include "mlir/Dialect/Math/IR/Math.h" |
15 | #include "mlir/Dialect/Math/Transforms/Passes.h" |
16 | #include "mlir/Dialect/SCF/IR/SCF.h" |
17 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
18 | #include "mlir/IR/Builders.h" |
19 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
20 | #include "mlir/IR/TypeUtilities.h" |
21 | #include "mlir/Transforms/DialectConversion.h" |
22 | |
23 | using namespace mlir; |
24 | |
25 | /// Create a float constant. |
26 | static Value createFloatConst(Location loc, Type type, APFloat value, |
27 | OpBuilder &b) { |
28 | bool losesInfo = false; |
29 | auto eltType = getElementTypeOrSelf(type); |
30 | // Convert double to the given `FloatType` with round-to-nearest-ties-to-even. |
31 | value.convert(ToSemantics: cast<FloatType>(Val&: eltType).getFloatSemantics(), |
32 | RM: APFloat::rmNearestTiesToEven, losesInfo: &losesInfo); |
33 | auto attr = b.getFloatAttr(eltType, value); |
34 | if (auto shapedTy = dyn_cast<ShapedType>(type)) { |
35 | return b.create<arith::ConstantOp>(loc, |
36 | DenseElementsAttr::get(shapedTy, attr)); |
37 | } |
38 | |
39 | return b.create<arith::ConstantOp>(loc, attr); |
40 | } |
41 | |
42 | static Value createFloatConst(Location loc, Type type, double value, |
43 | OpBuilder &b) { |
44 | return createFloatConst(loc, type, value: APFloat(value), b); |
45 | } |
46 | |
47 | /// Create an integer constant. |
48 | static Value createIntConst(Location loc, Type type, int64_t value, |
49 | OpBuilder &b) { |
50 | auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value); |
51 | if (auto shapedTy = dyn_cast<ShapedType>(type)) { |
52 | return b.create<arith::ConstantOp>(loc, |
53 | DenseElementsAttr::get(shapedTy, attr)); |
54 | } |
55 | |
56 | return b.create<arith::ConstantOp>(loc, attr); |
57 | } |
58 | |
59 | static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { |
60 | Type opType = operand.getType(); |
61 | Type i64Ty = b.getI64Type(); |
62 | if (auto shapedTy = dyn_cast<ShapedType>(opType)) |
63 | i64Ty = shapedTy.clone(i64Ty); |
64 | Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand); |
65 | Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert); |
66 | // The truncation does not preserve the sign when the truncated |
67 | // value is -0. So here the sign is copied again. |
68 | return b.create<math::CopySignOp>(fpFixedConvert, operand); |
69 | } |
70 | |
71 | // sinhf(float x) -> (exp(x) - exp(-x)) / 2 |
72 | static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) { |
73 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
74 | Value operand = op.getOperand(); |
75 | Type opType = operand.getType(); |
76 | Value exp = b.create<math::ExpOp>(operand); |
77 | |
78 | Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); |
79 | Value nexp = b.create<arith::DivFOp>(one, exp); |
80 | Value sub = b.create<arith::SubFOp>(exp, nexp); |
81 | Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter); |
82 | Value div = b.create<arith::DivFOp>(sub, two); |
83 | rewriter.replaceOp(op, div); |
84 | return success(); |
85 | } |
86 | |
87 | // coshf(float x) -> (exp(x) + exp(-x)) / 2 |
88 | static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) { |
89 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
90 | Value operand = op.getOperand(); |
91 | Type opType = operand.getType(); |
92 | Value exp = b.create<math::ExpOp>(operand); |
93 | |
94 | Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter); |
95 | Value nexp = b.create<arith::DivFOp>(one, exp); |
96 | Value add = b.create<arith::AddFOp>(exp, nexp); |
97 | Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter); |
98 | Value div = b.create<arith::DivFOp>(add, two); |
99 | rewriter.replaceOp(op, div); |
100 | return success(); |
101 | } |
102 | |
103 | /// Expands tanh op into |
104 | /// 1-exp^{-2x} / 1+exp^{-2x} |
105 | /// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`. |
106 | /// We compute a "signs" value which is -1 if input is negative and +1 if input |
107 | /// is positive. Then multiply the input by this value, guaranteeing that the |
108 | /// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0, |
109 | /// 1]. Expand the computation on the input `x * sign(x)`, then multiply the |
110 | /// result by `sign(x)` to retain sign of the real result. |
111 | static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { |
112 | auto floatType = op.getOperand().getType(); |
113 | Location loc = op.getLoc(); |
114 | Value zero = createFloatConst(loc, floatType, 0.0, rewriter); |
115 | Value one = createFloatConst(loc, floatType, 1.0, rewriter); |
116 | Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter); |
117 | |
118 | // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1 |
119 | Value isNegative = rewriter.create<arith::CmpFOp>( |
120 | loc, arith::CmpFPredicate::OLT, op.getOperand(), zero); |
121 | Value isNegativeFloat = |
122 | rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative); |
123 | Value isNegativeTimesNegTwo = |
124 | rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo); |
125 | Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one); |
126 | |
127 | // Normalize input to positive value: y = sign(x) * x |
128 | Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand()); |
129 | |
130 | // Decompose on normalized input |
131 | Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX); |
132 | Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX); |
133 | Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x); |
134 | Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x); |
135 | Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor); |
136 | |
137 | // Multiply result by sign(x) to retain signs from negative inputs |
138 | rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes); |
139 | |
140 | return success(); |
141 | } |
142 | |
143 | // Converts math.tan to math.sin, math.cos, and arith.divf. |
144 | static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { |
145 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
146 | Value operand = op.getOperand(); |
147 | Type type = operand.getType(); |
148 | Value sin = b.create<math::SinOp>(type, operand); |
149 | Value cos = b.create<math::CosOp>(type, operand); |
150 | Value div = b.create<arith::DivFOp>(type, sin, cos); |
151 | rewriter.replaceOp(op, div); |
152 | return success(); |
153 | } |
154 | |
155 | static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) { |
156 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
157 | Value operandA = op.getOperand(0); |
158 | Value operandB = op.getOperand(1); |
159 | Value operandC = op.getOperand(2); |
160 | Type type = op.getType(); |
161 | Value mult = b.create<arith::MulFOp>(type, operandA, operandB); |
162 | Value add = b.create<arith::AddFOp>(type, mult, operandC); |
163 | rewriter.replaceOp(op, add); |
164 | return success(); |
165 | } |
166 | |
167 | // Converts a floorf() function to the following: |
168 | // floorf(float x) -> |
169 | // y = (float)(int) x |
170 | // if (x < 0) then incr = -1 else incr = 0 |
171 | // y = y + incr <= replace this op with the floorf op. |
172 | static LogicalResult convertFloorOp(math::FloorOp op, |
173 | PatternRewriter &rewriter) { |
174 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
175 | Value operand = op.getOperand(); |
176 | Type opType = operand.getType(); |
177 | Value fpFixedConvert = createTruncatedFPValue(operand, b); |
178 | |
179 | // Creating constants for later use. |
180 | Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); |
181 | Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter); |
182 | |
183 | Value negCheck = |
184 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero); |
185 | Value incrValue = |
186 | b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero); |
187 | Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue); |
188 | rewriter.replaceOp(op, ret); |
189 | return success(); |
190 | } |
191 | |
192 | // Converts a ceilf() function to the following: |
193 | // ceilf(float x) -> |
194 | // y = (float)(int) x |
195 | // if (x > y) then incr = 1 else incr = 0 |
196 | // y = y + incr <= replace this op with the ceilf op. |
197 | static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { |
198 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
199 | Value operand = op.getOperand(); |
200 | Type opType = operand.getType(); |
201 | Value fpFixedConvert = createTruncatedFPValue(operand, b); |
202 | |
203 | // Creating constants for later use. |
204 | Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); |
205 | Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); |
206 | |
207 | Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, |
208 | fpFixedConvert); |
209 | Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero); |
210 | |
211 | Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue); |
212 | rewriter.replaceOp(op, ret); |
213 | return success(); |
214 | } |
215 | |
216 | // Convert `math.fpowi` to a series of `arith.mulf` operations. |
217 | // If the power is negative, we divide one by the result. |
218 | // If both the base and power are zero, the result is 1. |
219 | // In the case of non constant power, we convert the operation to `math.powf`. |
220 | static LogicalResult convertFPowIOp(math::FPowIOp op, |
221 | PatternRewriter &rewriter) { |
222 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
223 | Value base = op.getOperand(0); |
224 | Value power = op.getOperand(1); |
225 | Type baseType = base.getType(); |
226 | |
227 | auto convertFPowItoPowf = [&]() -> LogicalResult { |
228 | Value castPowerToFp = |
229 | rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power); |
230 | Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base, |
231 | castPowerToFp); |
232 | rewriter.replaceOp(op, res); |
233 | return success(); |
234 | }; |
235 | |
236 | Attribute cstAttr; |
237 | if (!matchPattern(value: power, pattern: m_Constant(bind_value: &cstAttr))) |
238 | return convertFPowItoPowf(); |
239 | |
240 | APInt value; |
241 | if (!matchPattern(cstAttr, m_ConstantInt(&value))) |
242 | return convertFPowItoPowf(); |
243 | |
244 | int64_t powerInt = value.getSExtValue(); |
245 | bool isNegative = powerInt < 0; |
246 | int64_t absPower = std::abs(i: powerInt); |
247 | Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter); |
248 | Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter); |
249 | |
250 | while (absPower > 0) { |
251 | if (absPower & 1) |
252 | res = b.create<arith::MulFOp>(baseType, base, res); |
253 | absPower >>= 1; |
254 | base = b.create<arith::MulFOp>(baseType, base, base); |
255 | } |
256 | |
257 | // Make sure not to introduce UB in case of negative power. |
258 | if (isNegative) { |
259 | auto &sem = dyn_cast<mlir::FloatType>(Val: getElementTypeOrSelf(type: baseType)) |
260 | .getFloatSemantics(); |
261 | Value zero = |
262 | createFloatConst(op->getLoc(), baseType, |
263 | APFloat::getZero(Sem: sem, /*Negative=*/false), rewriter); |
264 | Value negZero = |
265 | createFloatConst(op->getLoc(), baseType, |
266 | APFloat::getZero(Sem: sem, /*Negative=*/true), rewriter); |
267 | Value posInfinity = |
268 | createFloatConst(op->getLoc(), baseType, |
269 | APFloat::getInf(Sem: sem, /*Negative=*/false), rewriter); |
270 | Value negInfinity = |
271 | createFloatConst(op->getLoc(), baseType, |
272 | APFloat::getInf(Sem: sem, /*Negative=*/true), rewriter); |
273 | Value zeroEqCheck = |
274 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero); |
275 | Value negZeroEqCheck = |
276 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero); |
277 | res = b.create<arith::DivFOp>(baseType, one, res); |
278 | res = |
279 | b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res); |
280 | res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity, |
281 | res); |
282 | } |
283 | |
284 | rewriter.replaceOp(op, res); |
285 | return success(); |
286 | } |
287 | |
288 | // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) |
289 | static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { |
290 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
291 | Value operandA = op.getOperand(0); |
292 | Value operandB = op.getOperand(1); |
293 | Type opType = operandA.getType(); |
294 | Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); |
295 | Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter); |
296 | Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter); |
297 | Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA); |
298 | Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two); |
299 | |
300 | Value logA = b.create<math::LogOp>(opType, opASquared); |
301 | Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA); |
302 | Value expResult = b.create<math::ExpOp>(opType, mult); |
303 | Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne); |
304 | Value remainder = b.create<arith::RemFOp>(opType, operandB, two); |
305 | Value negCheck = |
306 | b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero); |
307 | Value oddPower = |
308 | b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero); |
309 | Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck); |
310 | |
311 | Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult, |
312 | expResult); |
313 | rewriter.replaceOp(op, res); |
314 | return success(); |
315 | } |
316 | |
317 | // exp2f(float x) -> exp(x * ln(2)) |
318 | // Proof: Let's say 2^x = y |
319 | // ln(2^x) = ln(y) |
320 | // x * ln(2) = ln(y) => e ^(x*ln(2)) = y |
321 | static LogicalResult convertExp2fOp(math::Exp2Op op, |
322 | PatternRewriter &rewriter) { |
323 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
324 | Value operand = op.getOperand(); |
325 | Type opType = operand.getType(); |
326 | Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b); |
327 | Value mult = b.create<arith::MulFOp>(opType, operand, ln2); |
328 | Value exp = b.create<math::ExpOp>(op->getLoc(), mult); |
329 | rewriter.replaceOp(op, exp); |
330 | return success(); |
331 | } |
332 | |
333 | static LogicalResult convertRoundOp(math::RoundOp op, |
334 | PatternRewriter &rewriter) { |
335 | Location loc = op.getLoc(); |
336 | ImplicitLocOpBuilder b(loc, rewriter); |
337 | Value operand = op.getOperand(); |
338 | Type opType = operand.getType(); |
339 | Type opEType = getElementTypeOrSelf(type: opType); |
340 | |
341 | if (!opEType.isF32()) { |
342 | return rewriter.notifyMatchFailure(op, "not a round of f32." ); |
343 | } |
344 | |
345 | Type i32Ty = b.getI32Type(); |
346 | if (auto shapedTy = dyn_cast<ShapedType>(opType)) |
347 | i32Ty = shapedTy.clone(i32Ty); |
348 | |
349 | Value half = createFloatConst(loc, type: opType, value: 0.5, b); |
350 | Value c23 = createIntConst(loc, type: i32Ty, value: 23, b); |
351 | Value c127 = createIntConst(loc, type: i32Ty, value: 127, b); |
352 | Value expMask = createIntConst(loc, type: i32Ty, value: (1 << 8) - 1, b); |
353 | |
354 | Value incrValue = b.create<math::CopySignOp>(half, operand); |
355 | Value add = b.create<arith::AddFOp>(opType, operand, incrValue); |
356 | Value fpFixedConvert = createTruncatedFPValue(operand: add, b); |
357 | |
358 | // There are three cases where adding 0.5 to the value and truncating by |
359 | // converting to an i64 does not result in the correct behavior: |
360 | // |
361 | // 1. Special values: +-inf and +-nan |
362 | // Casting these special values to i64 has undefined behavior. To identify |
363 | // these values, we use the fact that these values are the only float |
364 | // values with the maximum possible biased exponent. |
365 | // |
366 | // 2. Large values: 2^23 <= |x| <= INT_64_MAX |
367 | // Adding 0.5 to a float larger than or equal to 2^23 results in precision |
368 | // errors that sometimes round the value up and sometimes round the value |
369 | // down. For example: |
370 | // 8388608.0 + 0.5 = 8388608.0 |
371 | // 8388609.0 + 0.5 = 8388610.0 |
372 | // |
373 | // 3. Very large values: |x| > INT_64_MAX |
374 | // Casting to i64 a value greater than the max i64 value will overflow the |
375 | // i64 leading to wrong outputs. |
376 | // |
377 | // All three cases satisfy the property `biasedExp >= 23`. |
378 | Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand); |
379 | Value operandExp = b.create<arith::AndIOp>( |
380 | b.create<arith::ShRUIOp>(operandBitcast, c23), expMask); |
381 | Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127); |
382 | Value isSpecialValOrLargeVal = |
383 | b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23); |
384 | |
385 | Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand, |
386 | fpFixedConvert); |
387 | rewriter.replaceOp(op, result); |
388 | return success(); |
389 | } |
390 | |
391 | // Converts math.ctlz to scf and arith operations. This is done |
392 | // by performing a binary search on the bits. |
393 | static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, |
394 | PatternRewriter &rewriter) { |
395 | auto operand = op.getOperand(); |
396 | auto operandTy = operand.getType(); |
397 | auto eTy = getElementTypeOrSelf(operandTy); |
398 | Location loc = op.getLoc(); |
399 | |
400 | int32_t bitwidth = eTy.getIntOrFloatBitWidth(); |
401 | if (bitwidth > 64) |
402 | return failure(); |
403 | |
404 | uint64_t allbits = -1; |
405 | if (bitwidth < 64) { |
406 | allbits = allbits >> (64 - bitwidth); |
407 | } |
408 | |
409 | Value x = operand; |
410 | Value count = createIntConst(loc, operandTy, 0, rewriter); |
411 | for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) { |
412 | auto half = bw / 2; |
413 | auto bits = createIntConst(loc, operandTy, half, rewriter); |
414 | auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter); |
415 | |
416 | Value pred = |
417 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask); |
418 | Value add = rewriter.create<arith::AddIOp>(loc, count, bits); |
419 | Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits); |
420 | |
421 | x = rewriter.create<arith::SelectOp>(loc, pred, shift, x); |
422 | count = rewriter.create<arith::SelectOp>(loc, pred, add, count); |
423 | } |
424 | |
425 | Value zero = createIntConst(loc, operandTy, 0, rewriter); |
426 | Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
427 | operand, zero); |
428 | |
429 | Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter); |
430 | Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count); |
431 | rewriter.replaceOp(op, sel); |
432 | return success(); |
433 | } |
434 | |
435 | // Convert `math.roundeven` into `math.round` + arith ops |
436 | static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, |
437 | PatternRewriter &rewriter) { |
438 | Location loc = op.getLoc(); |
439 | ImplicitLocOpBuilder b(loc, rewriter); |
440 | auto operand = op.getOperand(); |
441 | Type operandTy = operand.getType(); |
442 | Type resultTy = op.getType(); |
443 | Type operandETy = getElementTypeOrSelf(type: operandTy); |
444 | Type resultETy = getElementTypeOrSelf(type: resultTy); |
445 | |
446 | if (!isa<FloatType>(Val: operandETy) || !isa<FloatType>(Val: resultETy)) { |
447 | return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32." ); |
448 | } |
449 | |
450 | Type fTy = operandTy; |
451 | Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth()); |
452 | if (auto shapedTy = dyn_cast<ShapedType>(fTy)) { |
453 | iTy = shapedTy.clone(iTy); |
454 | } |
455 | |
456 | unsigned bitWidth = operandETy.getIntOrFloatBitWidth(); |
457 | // The width returned by getFPMantissaWidth includes the integer bit. |
458 | unsigned mantissaWidth = |
459 | llvm::cast<FloatType>(Val&: operandETy).getFPMantissaWidth() - 1; |
460 | unsigned exponentWidth = bitWidth - mantissaWidth - 1; |
461 | |
462 | // The names of the variables correspond to f32. |
463 | // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa. |
464 | // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa. |
465 | // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa. |
466 | Value c1Float = createFloatConst(loc, type: fTy, value: 1.0, b); |
467 | Value c0 = createIntConst(loc, type: iTy, value: 0, b); |
468 | Value c1 = createIntConst(loc, type: iTy, value: 1, b); |
469 | Value cNeg1 = createIntConst(loc, type: iTy, value: -1, b); |
470 | Value c23 = createIntConst(loc, type: iTy, value: mantissaWidth, b); |
471 | Value c31 = createIntConst(loc, type: iTy, value: bitWidth - 1, b); |
472 | Value c127 = createIntConst(loc, type: iTy, value: (1ull << (exponentWidth - 1)) - 1, b); |
473 | Value c2To22 = createIntConst(loc, type: iTy, value: 1ull << (mantissaWidth - 1), b); |
474 | Value c23Mask = createIntConst(loc, type: iTy, value: (1ull << mantissaWidth) - 1, b); |
475 | Value expMask = createIntConst(loc, type: iTy, value: (1ull << exponentWidth) - 1, b); |
476 | |
477 | Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand); |
478 | Value round = b.create<math::RoundOp>(operand); |
479 | Value roundBitcast = b.create<arith::BitcastOp>(iTy, round); |
480 | |
481 | // Get biased exponents for operand and round(operand) |
482 | Value operandExp = b.create<arith::AndIOp>( |
483 | b.create<arith::ShRUIOp>(operandBitcast, c23), expMask); |
484 | Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127); |
485 | Value roundExp = b.create<arith::AndIOp>( |
486 | b.create<arith::ShRUIOp>(roundBitcast, c23), expMask); |
487 | Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127); |
488 | |
489 | auto safeShiftRight = [&](Value x, Value shift) -> Value { |
490 | // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior |
491 | Value clampedShift = b.create<arith::MaxSIOp>(shift, c0); |
492 | clampedShift = b.create<arith::MinSIOp>(clampedShift, c31); |
493 | return b.create<arith::ShRUIOp>(x, clampedShift); |
494 | }; |
495 | |
496 | auto maskMantissa = [&](Value mantissa, |
497 | Value mantissaMaskRightShift) -> Value { |
498 | Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift); |
499 | return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask); |
500 | }; |
501 | |
502 | // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring |
503 | // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers |
504 | // with `biasedExp > 23` (numbers where there is not enough precision to store |
505 | // decimals) are always even, and they satisfy the even condition trivially |
506 | // since the mantissa without all its bits is zero. The even condition |
507 | // is also true for +-0, since they have `biasedExp = -127` and the entire |
508 | // mantissa is zero. The case of +-1 has to be handled separately. Here |
509 | // we identify these values by noting that +-1 are the only whole numbers with |
510 | // `biasedExp == 0`. |
511 | // |
512 | // The special values +-inf and +-nan also satisfy the same property that |
513 | // whole non-unit even numbers satisfy. In particular, the special values have |
514 | // `biasedExp > 23`, so they get treated as large numbers with no room for |
515 | // decimals, which are always even. |
516 | Value roundBiasedExpEq0 = |
517 | b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0); |
518 | Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1); |
519 | Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1); |
520 | Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>( |
521 | arith::CmpIPredicate::ne, roundMaskedMantissa, c0); |
522 | roundIsNotEvenOrSpecialVal = |
523 | b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0); |
524 | |
525 | // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive |
526 | // integers if the bit at index `biasedExp` starting from the left in the |
527 | // mantissa is 1 and all the bits to the right are zero. Values with |
528 | // `biasedExp >= 23` don't have decimals, so they are never halfway. The |
529 | // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`, |
530 | // so these are handled separately. In particular, if `biasedExp == -1`, the |
531 | // value is halfway if the entire mantissa is zero. |
532 | Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>( |
533 | arith::CmpIPredicate::eq, operandBiasedExp, cNeg1); |
534 | Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>( |
535 | operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp)); |
536 | Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp); |
537 | Value operandIsHalfway = |
538 | b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa, |
539 | expectedOperandMaskedMantissa); |
540 | // Ensure `biasedExp` is in the valid range for half values. |
541 | Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>( |
542 | arith::CmpIPredicate::sge, operandBiasedExp, cNeg1); |
543 | Value operandBiasedExpLt23 = |
544 | b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23); |
545 | operandIsHalfway = |
546 | b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23); |
547 | operandIsHalfway = |
548 | b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1); |
549 | |
550 | // Adjust rounded operand with `round(operand) - sign(operand)` to correct the |
551 | // case where `round` rounded in the opposite direction of `roundeven`. |
552 | Value sign = b.create<math::CopySignOp>(c1Float, operand); |
553 | Value roundShifted = b.create<arith::SubFOp>(round, sign); |
554 | // If the rounded value is even or a special value, we default to the behavior |
555 | // of `math.round`. |
556 | Value needsShift = |
557 | b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway); |
558 | Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round); |
559 | // The `x - sign` adjustment does not preserve the sign when we are adjusting |
560 | // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is |
561 | // rounded to -0.0. |
562 | result = b.create<math::CopySignOp>(result, operand); |
563 | rewriter.replaceOp(op, result); |
564 | return success(); |
565 | } |
566 | |
567 | void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { |
568 | patterns.add(convertCtlzOp); |
569 | } |
570 | |
571 | void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) { |
572 | patterns.add(convertSinhOp); |
573 | } |
574 | |
575 | void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) { |
576 | patterns.add(convertCoshOp); |
577 | } |
578 | |
579 | void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { |
580 | patterns.add(convertTanOp); |
581 | } |
582 | |
583 | void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { |
584 | patterns.add(convertTanhOp); |
585 | } |
586 | |
587 | void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { |
588 | patterns.add(convertFmaFOp); |
589 | } |
590 | |
591 | void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { |
592 | patterns.add(convertCeilOp); |
593 | } |
594 | |
595 | void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { |
596 | patterns.add(convertExp2fOp); |
597 | } |
598 | |
599 | void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { |
600 | patterns.add(convertPowfOp); |
601 | } |
602 | |
603 | void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) { |
604 | patterns.add(convertFPowIOp); |
605 | } |
606 | |
607 | void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { |
608 | patterns.add(convertRoundOp); |
609 | } |
610 | |
611 | void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) { |
612 | patterns.add(convertFloorOp); |
613 | } |
614 | |
615 | void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) { |
616 | patterns.add(convertRoundEvenOp); |
617 | } |
618 | |