| 1 | //===- ExpandPatterns.cpp - Code to expand various math operations. -------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements expansion of various math operations. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 14 | #include "mlir/Dialect/Math/IR/Math.h" |
| 15 | #include "mlir/Dialect/Math/Transforms/Passes.h" |
| 16 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 17 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 18 | #include "mlir/IR/Builders.h" |
| 19 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
| 20 | #include "mlir/IR/TypeUtilities.h" |
| 21 | #include "mlir/Transforms/DialectConversion.h" |
| 22 | |
| 23 | using namespace mlir; |
| 24 | |
| 25 | /// Create a float constant. |
| 26 | static Value createFloatConst(Location loc, Type type, APFloat value, |
| 27 | OpBuilder &b) { |
| 28 | bool losesInfo = false; |
| 29 | auto eltType = getElementTypeOrSelf(type); |
| 30 | // Convert double to the given `FloatType` with round-to-nearest-ties-to-even. |
| 31 | value.convert(ToSemantics: cast<FloatType>(Val&: eltType).getFloatSemantics(), |
| 32 | RM: APFloat::rmNearestTiesToEven, losesInfo: &losesInfo); |
| 33 | auto attr = b.getFloatAttr(type: eltType, value); |
| 34 | if (auto shapedTy = dyn_cast<ShapedType>(Val&: type)) { |
| 35 | return b.create<arith::ConstantOp>(location: loc, |
| 36 | args: DenseElementsAttr::get(type: shapedTy, values: attr)); |
| 37 | } |
| 38 | |
| 39 | return b.create<arith::ConstantOp>(location: loc, args&: attr); |
| 40 | } |
| 41 | |
| 42 | static Value createFloatConst(Location loc, Type type, double value, |
| 43 | OpBuilder &b) { |
| 44 | return createFloatConst(loc, type, value: APFloat(value), b); |
| 45 | } |
| 46 | |
| 47 | /// Create an integer constant. |
| 48 | static Value createIntConst(Location loc, Type type, int64_t value, |
| 49 | OpBuilder &b) { |
| 50 | auto attr = b.getIntegerAttr(type: getElementTypeOrSelf(type), value); |
| 51 | if (auto shapedTy = dyn_cast<ShapedType>(Val&: type)) { |
| 52 | return b.create<arith::ConstantOp>(location: loc, |
| 53 | args: DenseElementsAttr::get(type: shapedTy, values: attr)); |
| 54 | } |
| 55 | |
| 56 | return b.create<arith::ConstantOp>(location: loc, args&: attr); |
| 57 | } |
| 58 | |
| 59 | static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { |
| 60 | Type opType = operand.getType(); |
| 61 | Type i64Ty = b.getI64Type(); |
| 62 | if (auto shapedTy = dyn_cast<ShapedType>(Val&: opType)) |
| 63 | i64Ty = shapedTy.clone(elementType: i64Ty); |
| 64 | Value fixedConvert = b.create<arith::FPToSIOp>(args&: i64Ty, args&: operand); |
| 65 | Value fpFixedConvert = b.create<arith::SIToFPOp>(args&: opType, args&: fixedConvert); |
| 66 | // The truncation does not preserve the sign when the truncated |
| 67 | // value is -0. So here the sign is copied again. |
| 68 | return b.create<math::CopySignOp>(args&: fpFixedConvert, args&: operand); |
| 69 | } |
| 70 | |
| 71 | // sinhf(float x) -> (exp(x) - exp(-x)) / 2 |
| 72 | static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) { |
| 73 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 74 | Value operand = op.getOperand(); |
| 75 | Type opType = operand.getType(); |
| 76 | |
| 77 | Value exp = b.create<math::ExpOp>(args&: operand); |
| 78 | Value neg = b.create<arith::NegFOp>(args&: operand); |
| 79 | Value nexp = b.create<math::ExpOp>(args&: neg); |
| 80 | Value sub = b.create<arith::SubFOp>(args&: exp, args&: nexp); |
| 81 | Value half = createFloatConst(loc: op->getLoc(), type: opType, value: 0.5, b&: rewriter); |
| 82 | Value res = b.create<arith::MulFOp>(args&: sub, args&: half); |
| 83 | rewriter.replaceOp(op, newValues: res); |
| 84 | return success(); |
| 85 | } |
| 86 | |
| 87 | // coshf(float x) -> (exp(x) + exp(-x)) / 2 |
| 88 | static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) { |
| 89 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 90 | Value operand = op.getOperand(); |
| 91 | Type opType = operand.getType(); |
| 92 | |
| 93 | Value exp = b.create<math::ExpOp>(args&: operand); |
| 94 | Value neg = b.create<arith::NegFOp>(args&: operand); |
| 95 | Value nexp = b.create<math::ExpOp>(args&: neg); |
| 96 | Value add = b.create<arith::AddFOp>(args&: exp, args&: nexp); |
| 97 | Value half = createFloatConst(loc: op->getLoc(), type: opType, value: 0.5, b&: rewriter); |
| 98 | Value res = b.create<arith::MulFOp>(args&: add, args&: half); |
| 99 | rewriter.replaceOp(op, newValues: res); |
| 100 | return success(); |
| 101 | } |
| 102 | |
| 103 | /// Expands tanh op into |
| 104 | /// 1-exp^{-2x} / 1+exp^{-2x} |
| 105 | /// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`. |
| 106 | /// We compute a "signs" value which is -1 if input is negative and +1 if input |
| 107 | /// is positive. Then multiply the input by this value, guaranteeing that the |
| 108 | /// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0, |
| 109 | /// 1]. Expand the computation on the input `x * sign(x)`, then multiply the |
| 110 | /// result by `sign(x)` to retain sign of the real result. |
| 111 | static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { |
| 112 | auto floatType = op.getOperand().getType(); |
| 113 | Location loc = op.getLoc(); |
| 114 | Value zero = createFloatConst(loc, type: floatType, value: 0.0, b&: rewriter); |
| 115 | Value one = createFloatConst(loc, type: floatType, value: 1.0, b&: rewriter); |
| 116 | Value negTwo = createFloatConst(loc, type: floatType, value: -2.0, b&: rewriter); |
| 117 | |
| 118 | // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1 |
| 119 | Value isNegative = rewriter.create<arith::CmpFOp>( |
| 120 | location: loc, args: arith::CmpFPredicate::OLT, args: op.getOperand(), args&: zero); |
| 121 | Value isNegativeFloat = |
| 122 | rewriter.create<arith::UIToFPOp>(location: loc, args&: floatType, args&: isNegative); |
| 123 | Value isNegativeTimesNegTwo = |
| 124 | rewriter.create<arith::MulFOp>(location: loc, args&: isNegativeFloat, args&: negTwo); |
| 125 | Value sign = rewriter.create<arith::AddFOp>(location: loc, args&: isNegativeTimesNegTwo, args&: one); |
| 126 | |
| 127 | // Normalize input to positive value: y = sign(x) * x |
| 128 | Value positiveX = rewriter.create<arith::MulFOp>(location: loc, args&: sign, args: op.getOperand()); |
| 129 | |
| 130 | // Decompose on normalized input |
| 131 | Value negDoubledX = rewriter.create<arith::MulFOp>(location: loc, args&: negTwo, args&: positiveX); |
| 132 | Value exp2x = rewriter.create<math::ExpOp>(location: loc, args&: negDoubledX); |
| 133 | Value dividend = rewriter.create<arith::SubFOp>(location: loc, args&: one, args&: exp2x); |
| 134 | Value divisor = rewriter.create<arith::AddFOp>(location: loc, args&: one, args&: exp2x); |
| 135 | Value positiveRes = rewriter.create<arith::DivFOp>(location: loc, args&: dividend, args&: divisor); |
| 136 | |
| 137 | // Multiply result by sign(x) to retain signs from negative inputs |
| 138 | rewriter.replaceOpWithNewOp<arith::MulFOp>(op, args&: sign, args&: positiveRes); |
| 139 | |
| 140 | return success(); |
| 141 | } |
| 142 | |
| 143 | // Converts math.tan to math.sin, math.cos, and arith.divf. |
| 144 | static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { |
| 145 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 146 | Value operand = op.getOperand(); |
| 147 | Type type = operand.getType(); |
| 148 | Value sin = b.create<math::SinOp>(args&: type, args&: operand); |
| 149 | Value cos = b.create<math::CosOp>(args&: type, args&: operand); |
| 150 | Value div = b.create<arith::DivFOp>(args&: type, args&: sin, args&: cos); |
| 151 | rewriter.replaceOp(op, newValues: div); |
| 152 | return success(); |
| 153 | } |
| 154 | |
| 155 | // asinh(float x) -> log(x + sqrt(x**2 + 1)) |
| 156 | static LogicalResult convertAsinhOp(math::AsinhOp op, |
| 157 | PatternRewriter &rewriter) { |
| 158 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 159 | Value operand = op.getOperand(); |
| 160 | Type opType = operand.getType(); |
| 161 | |
| 162 | Value one = createFloatConst(loc: op->getLoc(), type: opType, value: 1.0, b&: rewriter); |
| 163 | Value fma = b.create<math::FmaOp>(args&: operand, args&: operand, args&: one); |
| 164 | Value sqrt = b.create<math::SqrtOp>(args&: fma); |
| 165 | Value add = b.create<arith::AddFOp>(args&: operand, args&: sqrt); |
| 166 | Value res = b.create<math::LogOp>(args&: add); |
| 167 | rewriter.replaceOp(op, newValues: res); |
| 168 | return success(); |
| 169 | } |
| 170 | |
| 171 | // acosh(float x) -> log(x + sqrt(x**2 - 1)) |
| 172 | static LogicalResult convertAcoshOp(math::AcoshOp op, |
| 173 | PatternRewriter &rewriter) { |
| 174 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 175 | Value operand = op.getOperand(); |
| 176 | Type opType = operand.getType(); |
| 177 | |
| 178 | Value negOne = createFloatConst(loc: op->getLoc(), type: opType, value: -1.0, b&: rewriter); |
| 179 | Value fma = b.create<math::FmaOp>(args&: operand, args&: operand, args&: negOne); |
| 180 | Value sqrt = b.create<math::SqrtOp>(args&: fma); |
| 181 | Value add = b.create<arith::AddFOp>(args&: operand, args&: sqrt); |
| 182 | Value res = b.create<math::LogOp>(args&: add); |
| 183 | rewriter.replaceOp(op, newValues: res); |
| 184 | return success(); |
| 185 | } |
| 186 | |
| 187 | // atanh(float x) -> log((1 + x) / (1 - x)) / 2 |
| 188 | static LogicalResult convertAtanhOp(math::AtanhOp op, |
| 189 | PatternRewriter &rewriter) { |
| 190 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 191 | Value operand = op.getOperand(); |
| 192 | Type opType = operand.getType(); |
| 193 | |
| 194 | Value one = createFloatConst(loc: op->getLoc(), type: opType, value: 1.0, b&: rewriter); |
| 195 | Value add = b.create<arith::AddFOp>(args&: operand, args&: one); |
| 196 | Value neg = b.create<arith::NegFOp>(args&: operand); |
| 197 | Value sub = b.create<arith::AddFOp>(args&: neg, args&: one); |
| 198 | Value div = b.create<arith::DivFOp>(args&: add, args&: sub); |
| 199 | Value log = b.create<math::LogOp>(args&: div); |
| 200 | Value half = createFloatConst(loc: op->getLoc(), type: opType, value: 0.5, b&: rewriter); |
| 201 | Value res = b.create<arith::MulFOp>(args&: log, args&: half); |
| 202 | rewriter.replaceOp(op, newValues: res); |
| 203 | return success(); |
| 204 | } |
| 205 | |
| 206 | static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) { |
| 207 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 208 | Value operandA = op.getOperand(i: 0); |
| 209 | Value operandB = op.getOperand(i: 1); |
| 210 | Value operandC = op.getOperand(i: 2); |
| 211 | Type type = op.getType(); |
| 212 | Value mult = b.create<arith::MulFOp>(args&: type, args&: operandA, args&: operandB); |
| 213 | Value add = b.create<arith::AddFOp>(args&: type, args&: mult, args&: operandC); |
| 214 | rewriter.replaceOp(op, newValues: add); |
| 215 | return success(); |
| 216 | } |
| 217 | |
| 218 | // Converts a ceilf() function to the following: |
| 219 | // ceilf(float x) -> |
| 220 | // y = (float)(int) x |
| 221 | // if (x > y) then incr = 1 else incr = 0 |
| 222 | // y = y + incr <= replace this op with the ceilf op. |
| 223 | static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { |
| 224 | // Creating constants assumes the static shaped type. |
| 225 | auto shapedType = dyn_cast<ShapedType>(Val: op.getType()); |
| 226 | if (shapedType && !shapedType.hasStaticShape()) |
| 227 | return failure(); |
| 228 | |
| 229 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 230 | Value operand = op.getOperand(); |
| 231 | Type opType = operand.getType(); |
| 232 | Value fpFixedConvert = createTruncatedFPValue(operand, b); |
| 233 | |
| 234 | // Creating constants for later use. |
| 235 | Value zero = createFloatConst(loc: op->getLoc(), type: opType, value: 0.00, b&: rewriter); |
| 236 | Value one = createFloatConst(loc: op->getLoc(), type: opType, value: 1.00, b&: rewriter); |
| 237 | |
| 238 | Value gtCheck = b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OGT, args&: operand, |
| 239 | args&: fpFixedConvert); |
| 240 | Value incrValue = b.create<arith::SelectOp>(location: op->getLoc(), args&: gtCheck, args&: one, args&: zero); |
| 241 | |
| 242 | Value ret = b.create<arith::AddFOp>(args&: opType, args&: fpFixedConvert, args&: incrValue); |
| 243 | rewriter.replaceOp(op, newValues: ret); |
| 244 | return success(); |
| 245 | } |
| 246 | |
| 247 | // Convert `math.fpowi` to a series of `arith.mulf` operations. |
| 248 | // If the power is negative, we divide one by the result. |
| 249 | // If both the base and power are zero, the result is 1. |
| 250 | // In the case of non constant power, we convert the operation to `math.powf`. |
| 251 | static LogicalResult convertFPowIOp(math::FPowIOp op, |
| 252 | PatternRewriter &rewriter) { |
| 253 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 254 | Value base = op.getOperand(i: 0); |
| 255 | Value power = op.getOperand(i: 1); |
| 256 | Type baseType = base.getType(); |
| 257 | |
| 258 | auto convertFPowItoPowf = [&]() -> LogicalResult { |
| 259 | Value castPowerToFp = |
| 260 | rewriter.create<arith::SIToFPOp>(location: op.getLoc(), args&: baseType, args&: power); |
| 261 | Value res = rewriter.create<math::PowFOp>(location: op.getLoc(), args&: baseType, args&: base, |
| 262 | args&: castPowerToFp); |
| 263 | rewriter.replaceOp(op, newValues: res); |
| 264 | return success(); |
| 265 | }; |
| 266 | |
| 267 | Attribute cstAttr; |
| 268 | if (!matchPattern(value: power, pattern: m_Constant(bind_value: &cstAttr))) |
| 269 | return convertFPowItoPowf(); |
| 270 | |
| 271 | APInt value; |
| 272 | if (!matchPattern(attr: cstAttr, pattern: m_ConstantInt(bind_value: &value))) |
| 273 | return convertFPowItoPowf(); |
| 274 | |
| 275 | int64_t powerInt = value.getSExtValue(); |
| 276 | bool isNegative = powerInt < 0; |
| 277 | int64_t absPower = std::abs(i: powerInt); |
| 278 | Value one = createFloatConst(loc: op->getLoc(), type: baseType, value: 1.00, b&: rewriter); |
| 279 | Value res = createFloatConst(loc: op->getLoc(), type: baseType, value: 1.00, b&: rewriter); |
| 280 | |
| 281 | while (absPower > 0) { |
| 282 | if (absPower & 1) |
| 283 | res = b.create<arith::MulFOp>(args&: baseType, args&: base, args&: res); |
| 284 | absPower >>= 1; |
| 285 | base = b.create<arith::MulFOp>(args&: baseType, args&: base, args&: base); |
| 286 | } |
| 287 | |
| 288 | // Make sure not to introduce UB in case of negative power. |
| 289 | if (isNegative) { |
| 290 | auto &sem = dyn_cast<mlir::FloatType>(Val: getElementTypeOrSelf(type: baseType)) |
| 291 | .getFloatSemantics(); |
| 292 | Value zero = |
| 293 | createFloatConst(loc: op->getLoc(), type: baseType, |
| 294 | value: APFloat::getZero(Sem: sem, /*Negative=*/false), b&: rewriter); |
| 295 | Value negZero = |
| 296 | createFloatConst(loc: op->getLoc(), type: baseType, |
| 297 | value: APFloat::getZero(Sem: sem, /*Negative=*/true), b&: rewriter); |
| 298 | Value posInfinity = |
| 299 | createFloatConst(loc: op->getLoc(), type: baseType, |
| 300 | value: APFloat::getInf(Sem: sem, /*Negative=*/false), b&: rewriter); |
| 301 | Value negInfinity = |
| 302 | createFloatConst(loc: op->getLoc(), type: baseType, |
| 303 | value: APFloat::getInf(Sem: sem, /*Negative=*/true), b&: rewriter); |
| 304 | Value zeroEqCheck = |
| 305 | b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: res, args&: zero); |
| 306 | Value negZeroEqCheck = |
| 307 | b.create<arith::CmpFOp>(args: arith::CmpFPredicate::OEQ, args&: res, args&: negZero); |
| 308 | res = b.create<arith::DivFOp>(args&: baseType, args&: one, args&: res); |
| 309 | res = |
| 310 | b.create<arith::SelectOp>(location: op->getLoc(), args&: zeroEqCheck, args&: posInfinity, args&: res); |
| 311 | res = b.create<arith::SelectOp>(location: op->getLoc(), args&: negZeroEqCheck, args&: negInfinity, |
| 312 | args&: res); |
| 313 | } |
| 314 | |
| 315 | rewriter.replaceOp(op, newValues: res); |
| 316 | return success(); |
| 317 | } |
| 318 | |
| 319 | // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) |
| 320 | // Some special cases where b is constant are handled separately: |
| 321 | // when b == 0, or |b| == 0.5, 1.0, or 2.0. |
| 322 | static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { |
| 323 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 324 | Value operandA = op.getOperand(i: 0); |
| 325 | Value operandB = op.getOperand(i: 1); |
| 326 | auto typeA = operandA.getType(); |
| 327 | auto typeB = operandB.getType(); |
| 328 | |
| 329 | auto &sem = |
| 330 | cast<mlir::FloatType>(Val: getElementTypeOrSelf(type: typeB)).getFloatSemantics(); |
| 331 | APFloat valueB(sem); |
| 332 | auto mulf = [&](Value x, Value y) -> Value { |
| 333 | return b.create<arith::MulFOp>(args&: x, args&: y); |
| 334 | }; |
| 335 | if (matchPattern(value: operandB, pattern: m_ConstantFloat(bind_value: &valueB))) { |
| 336 | if (valueB.isZero()) { |
| 337 | // a^0 -> 1 |
| 338 | Value one = createFloatConst(loc: op->getLoc(), type: typeA, value: 1.0, b&: rewriter); |
| 339 | rewriter.replaceOp(op, newValues: one); |
| 340 | return success(); |
| 341 | } |
| 342 | if (valueB.isExactlyValue(V: 1.0)) { |
| 343 | // a^1 -> a |
| 344 | rewriter.replaceOp(op, newValues: operandA); |
| 345 | return success(); |
| 346 | } |
| 347 | if (valueB.isExactlyValue(V: -1.0)) { |
| 348 | // a^(-1) -> 1 / a |
| 349 | Value one = createFloatConst(loc: op->getLoc(), type: typeA, value: 1.0, b&: rewriter); |
| 350 | Value div = b.create<arith::DivFOp>(args&: one, args&: operandA); |
| 351 | rewriter.replaceOp(op, newValues: div); |
| 352 | return success(); |
| 353 | } |
| 354 | if (valueB.isExactlyValue(V: 0.5)) { |
| 355 | // a^(1/2) -> sqrt(a) |
| 356 | Value sqrt = b.create<math::SqrtOp>(args&: operandA); |
| 357 | rewriter.replaceOp(op, newValues: sqrt); |
| 358 | return success(); |
| 359 | } |
| 360 | if (valueB.isExactlyValue(V: -0.5)) { |
| 361 | // a^(-1/2) -> 1 / sqrt(a) |
| 362 | Value rsqrt = b.create<math::RsqrtOp>(args&: operandA); |
| 363 | rewriter.replaceOp(op, newValues: rsqrt); |
| 364 | return success(); |
| 365 | } |
| 366 | if (valueB.isExactlyValue(V: 2.0)) { |
| 367 | // a^2 -> a * a |
| 368 | rewriter.replaceOp(op, newValues: mulf(operandA, operandA)); |
| 369 | return success(); |
| 370 | } |
| 371 | if (valueB.isExactlyValue(V: -2.0)) { |
| 372 | // a^(-2) -> 1 / (a * a) |
| 373 | Value one = |
| 374 | createFloatConst(loc: op->getLoc(), type: operandA.getType(), value: 1.0, b&: rewriter); |
| 375 | Value div = b.create<arith::DivFOp>(args&: one, args: mulf(operandA, operandA)); |
| 376 | rewriter.replaceOp(op, newValues: div); |
| 377 | return success(); |
| 378 | } |
| 379 | if (valueB.isExactlyValue(V: 3.0)) { |
| 380 | rewriter.replaceOp(op, newValues: mulf(mulf(operandA, operandA), operandA)); |
| 381 | return success(); |
| 382 | } |
| 383 | } |
| 384 | |
| 385 | Value logA = b.create<math::LogOp>(args&: operandA); |
| 386 | Value mult = b.create<arith::MulFOp>(args&: operandB, args&: logA); |
| 387 | Value expResult = b.create<math::ExpOp>(args&: mult); |
| 388 | rewriter.replaceOp(op, newValues: expResult); |
| 389 | return success(); |
| 390 | } |
| 391 | |
| 392 | // exp2f(float x) -> exp(x * ln(2)) |
| 393 | // Proof: Let's say 2^x = y |
| 394 | // ln(2^x) = ln(y) |
| 395 | // x * ln(2) = ln(y) => e ^(x*ln(2)) = y |
| 396 | static LogicalResult convertExp2fOp(math::Exp2Op op, |
| 397 | PatternRewriter &rewriter) { |
| 398 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 399 | Value operand = op.getOperand(); |
| 400 | Type opType = operand.getType(); |
| 401 | Value ln2 = createFloatConst(loc: op->getLoc(), type: opType, value: llvm::numbers::ln2, b); |
| 402 | Value mult = b.create<arith::MulFOp>(args&: opType, args&: operand, args&: ln2); |
| 403 | Value exp = b.create<math::ExpOp>(location: op->getLoc(), args&: mult); |
| 404 | rewriter.replaceOp(op, newValues: exp); |
| 405 | return success(); |
| 406 | } |
| 407 | |
| 408 | static LogicalResult convertRoundOp(math::RoundOp op, |
| 409 | PatternRewriter &rewriter) { |
| 410 | Location loc = op.getLoc(); |
| 411 | ImplicitLocOpBuilder b(loc, rewriter); |
| 412 | Value operand = op.getOperand(); |
| 413 | Type opType = operand.getType(); |
| 414 | Type opEType = getElementTypeOrSelf(type: opType); |
| 415 | |
| 416 | if (!opEType.isF32()) { |
| 417 | return rewriter.notifyMatchFailure(arg&: op, msg: "not a round of f32." ); |
| 418 | } |
| 419 | |
| 420 | Type i32Ty = b.getI32Type(); |
| 421 | if (auto shapedTy = dyn_cast<ShapedType>(Val&: opType)) |
| 422 | i32Ty = shapedTy.clone(elementType: i32Ty); |
| 423 | |
| 424 | Value half = createFloatConst(loc, type: opType, value: 0.5, b); |
| 425 | Value c23 = createIntConst(loc, type: i32Ty, value: 23, b); |
| 426 | Value c127 = createIntConst(loc, type: i32Ty, value: 127, b); |
| 427 | Value expMask = createIntConst(loc, type: i32Ty, value: (1 << 8) - 1, b); |
| 428 | |
| 429 | Value incrValue = b.create<math::CopySignOp>(args&: half, args&: operand); |
| 430 | Value add = b.create<arith::AddFOp>(args&: opType, args&: operand, args&: incrValue); |
| 431 | Value fpFixedConvert = createTruncatedFPValue(operand: add, b); |
| 432 | |
| 433 | // There are three cases where adding 0.5 to the value and truncating by |
| 434 | // converting to an i64 does not result in the correct behavior: |
| 435 | // |
| 436 | // 1. Special values: +-inf and +-nan |
| 437 | // Casting these special values to i64 has undefined behavior. To identify |
| 438 | // these values, we use the fact that these values are the only float |
| 439 | // values with the maximum possible biased exponent. |
| 440 | // |
| 441 | // 2. Large values: 2^23 <= |x| <= INT_64_MAX |
| 442 | // Adding 0.5 to a float larger than or equal to 2^23 results in precision |
| 443 | // errors that sometimes round the value up and sometimes round the value |
| 444 | // down. For example: |
| 445 | // 8388608.0 + 0.5 = 8388608.0 |
| 446 | // 8388609.0 + 0.5 = 8388610.0 |
| 447 | // |
| 448 | // 3. Very large values: |x| > INT_64_MAX |
| 449 | // Casting to i64 a value greater than the max i64 value will overflow the |
| 450 | // i64 leading to wrong outputs. |
| 451 | // |
| 452 | // All three cases satisfy the property `biasedExp >= 23`. |
| 453 | Value operandBitcast = b.create<arith::BitcastOp>(args&: i32Ty, args&: operand); |
| 454 | Value operandExp = b.create<arith::AndIOp>( |
| 455 | args: b.create<arith::ShRUIOp>(args&: operandBitcast, args&: c23), args&: expMask); |
| 456 | Value operandBiasedExp = b.create<arith::SubIOp>(args&: operandExp, args&: c127); |
| 457 | Value isSpecialValOrLargeVal = |
| 458 | b.create<arith::CmpIOp>(args: arith::CmpIPredicate::sge, args&: operandBiasedExp, args&: c23); |
| 459 | |
| 460 | Value result = b.create<arith::SelectOp>(args&: isSpecialValOrLargeVal, args&: operand, |
| 461 | args&: fpFixedConvert); |
| 462 | rewriter.replaceOp(op, newValues: result); |
| 463 | return success(); |
| 464 | } |
| 465 | |
| 466 | // Converts math.ctlz to scf and arith operations. This is done |
| 467 | // by performing a binary search on the bits. |
| 468 | static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, |
| 469 | PatternRewriter &rewriter) { |
| 470 | auto operand = op.getOperand(); |
| 471 | auto operandTy = operand.getType(); |
| 472 | auto eTy = getElementTypeOrSelf(type: operandTy); |
| 473 | Location loc = op.getLoc(); |
| 474 | |
| 475 | int32_t bitwidth = eTy.getIntOrFloatBitWidth(); |
| 476 | if (bitwidth > 64) |
| 477 | return failure(); |
| 478 | |
| 479 | uint64_t allbits = -1; |
| 480 | if (bitwidth < 64) { |
| 481 | allbits = allbits >> (64 - bitwidth); |
| 482 | } |
| 483 | |
| 484 | Value x = operand; |
| 485 | Value count = createIntConst(loc, type: operandTy, value: 0, b&: rewriter); |
| 486 | for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) { |
| 487 | auto half = bw / 2; |
| 488 | auto bits = createIntConst(loc, type: operandTy, value: half, b&: rewriter); |
| 489 | auto mask = createIntConst(loc, type: operandTy, value: allbits >> half, b&: rewriter); |
| 490 | |
| 491 | Value pred = |
| 492 | rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::ule, args&: x, args&: mask); |
| 493 | Value add = rewriter.create<arith::AddIOp>(location: loc, args&: count, args&: bits); |
| 494 | Value shift = rewriter.create<arith::ShLIOp>(location: loc, args&: x, args&: bits); |
| 495 | |
| 496 | x = rewriter.create<arith::SelectOp>(location: loc, args&: pred, args&: shift, args&: x); |
| 497 | count = rewriter.create<arith::SelectOp>(location: loc, args&: pred, args&: add, args&: count); |
| 498 | } |
| 499 | |
| 500 | Value zero = createIntConst(loc, type: operandTy, value: 0, b&: rewriter); |
| 501 | Value pred = rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::eq, |
| 502 | args&: operand, args&: zero); |
| 503 | |
| 504 | Value bwval = createIntConst(loc, type: operandTy, value: bitwidth, b&: rewriter); |
| 505 | Value sel = rewriter.create<arith::SelectOp>(location: loc, args&: pred, args&: bwval, args&: count); |
| 506 | rewriter.replaceOp(op, newValues: sel); |
| 507 | return success(); |
| 508 | } |
| 509 | |
| 510 | // Convert `math.roundeven` into `math.round` + arith ops |
| 511 | static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, |
| 512 | PatternRewriter &rewriter) { |
| 513 | Location loc = op.getLoc(); |
| 514 | ImplicitLocOpBuilder b(loc, rewriter); |
| 515 | auto operand = op.getOperand(); |
| 516 | Type operandTy = operand.getType(); |
| 517 | Type resultTy = op.getType(); |
| 518 | Type operandETy = getElementTypeOrSelf(type: operandTy); |
| 519 | Type resultETy = getElementTypeOrSelf(type: resultTy); |
| 520 | |
| 521 | if (!isa<FloatType>(Val: operandETy) || !isa<FloatType>(Val: resultETy)) { |
| 522 | return rewriter.notifyMatchFailure(arg&: op, msg: "not a roundeven of f16 or f32." ); |
| 523 | } |
| 524 | |
| 525 | Type fTy = operandTy; |
| 526 | Type iTy = rewriter.getIntegerType(width: operandETy.getIntOrFloatBitWidth()); |
| 527 | if (auto shapedTy = dyn_cast<ShapedType>(Val&: fTy)) { |
| 528 | iTy = shapedTy.clone(elementType: iTy); |
| 529 | } |
| 530 | |
| 531 | unsigned bitWidth = operandETy.getIntOrFloatBitWidth(); |
| 532 | // The width returned by getFPMantissaWidth includes the integer bit. |
| 533 | unsigned mantissaWidth = |
| 534 | llvm::cast<FloatType>(Val&: operandETy).getFPMantissaWidth() - 1; |
| 535 | unsigned exponentWidth = bitWidth - mantissaWidth - 1; |
| 536 | |
| 537 | // The names of the variables correspond to f32. |
| 538 | // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa. |
| 539 | // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa. |
| 540 | // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa. |
| 541 | Value c1Float = createFloatConst(loc, type: fTy, value: 1.0, b); |
| 542 | Value c0 = createIntConst(loc, type: iTy, value: 0, b); |
| 543 | Value c1 = createIntConst(loc, type: iTy, value: 1, b); |
| 544 | Value cNeg1 = createIntConst(loc, type: iTy, value: -1, b); |
| 545 | Value c23 = createIntConst(loc, type: iTy, value: mantissaWidth, b); |
| 546 | Value c31 = createIntConst(loc, type: iTy, value: bitWidth - 1, b); |
| 547 | Value c127 = createIntConst(loc, type: iTy, value: (1ull << (exponentWidth - 1)) - 1, b); |
| 548 | Value c2To22 = createIntConst(loc, type: iTy, value: 1ull << (mantissaWidth - 1), b); |
| 549 | Value c23Mask = createIntConst(loc, type: iTy, value: (1ull << mantissaWidth) - 1, b); |
| 550 | Value expMask = createIntConst(loc, type: iTy, value: (1ull << exponentWidth) - 1, b); |
| 551 | |
| 552 | Value operandBitcast = b.create<arith::BitcastOp>(args&: iTy, args&: operand); |
| 553 | Value round = b.create<math::RoundOp>(args&: operand); |
| 554 | Value roundBitcast = b.create<arith::BitcastOp>(args&: iTy, args&: round); |
| 555 | |
| 556 | // Get biased exponents for operand and round(operand) |
| 557 | Value operandExp = b.create<arith::AndIOp>( |
| 558 | args: b.create<arith::ShRUIOp>(args&: operandBitcast, args&: c23), args&: expMask); |
| 559 | Value operandBiasedExp = b.create<arith::SubIOp>(args&: operandExp, args&: c127); |
| 560 | Value roundExp = b.create<arith::AndIOp>( |
| 561 | args: b.create<arith::ShRUIOp>(args&: roundBitcast, args&: c23), args&: expMask); |
| 562 | Value roundBiasedExp = b.create<arith::SubIOp>(args&: roundExp, args&: c127); |
| 563 | |
| 564 | auto safeShiftRight = [&](Value x, Value shift) -> Value { |
| 565 | // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior |
| 566 | Value clampedShift = b.create<arith::MaxSIOp>(args&: shift, args&: c0); |
| 567 | clampedShift = b.create<arith::MinSIOp>(args&: clampedShift, args&: c31); |
| 568 | return b.create<arith::ShRUIOp>(args&: x, args&: clampedShift); |
| 569 | }; |
| 570 | |
| 571 | auto maskMantissa = [&](Value mantissa, |
| 572 | Value mantissaMaskRightShift) -> Value { |
| 573 | Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift); |
| 574 | return b.create<arith::AndIOp>(args&: mantissa, args&: shiftedMantissaMask); |
| 575 | }; |
| 576 | |
| 577 | // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring |
| 578 | // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers |
| 579 | // with `biasedExp > 23` (numbers where there is not enough precision to store |
| 580 | // decimals) are always even, and they satisfy the even condition trivially |
| 581 | // since the mantissa without all its bits is zero. The even condition |
| 582 | // is also true for +-0, since they have `biasedExp = -127` and the entire |
| 583 | // mantissa is zero. The case of +-1 has to be handled separately. Here |
| 584 | // we identify these values by noting that +-1 are the only whole numbers with |
| 585 | // `biasedExp == 0`. |
| 586 | // |
| 587 | // The special values +-inf and +-nan also satisfy the same property that |
| 588 | // whole non-unit even numbers satisfy. In particular, the special values have |
| 589 | // `biasedExp > 23`, so they get treated as large numbers with no room for |
| 590 | // decimals, which are always even. |
| 591 | Value roundBiasedExpEq0 = |
| 592 | b.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: roundBiasedExp, args&: c0); |
| 593 | Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(args&: roundBiasedExp, args&: c1); |
| 594 | Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1); |
| 595 | Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>( |
| 596 | args: arith::CmpIPredicate::ne, args&: roundMaskedMantissa, args&: c0); |
| 597 | roundIsNotEvenOrSpecialVal = |
| 598 | b.create<arith::OrIOp>(args&: roundIsNotEvenOrSpecialVal, args&: roundBiasedExpEq0); |
| 599 | |
| 600 | // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive |
| 601 | // integers if the bit at index `biasedExp` starting from the left in the |
| 602 | // mantissa is 1 and all the bits to the right are zero. Values with |
| 603 | // `biasedExp >= 23` don't have decimals, so they are never halfway. The |
| 604 | // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`, |
| 605 | // so these are handled separately. In particular, if `biasedExp == -1`, the |
| 606 | // value is halfway if the entire mantissa is zero. |
| 607 | Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>( |
| 608 | args: arith::CmpIPredicate::eq, args&: operandBiasedExp, args&: cNeg1); |
| 609 | Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>( |
| 610 | args&: operandBiasedExpEqNeg1, args&: c0, args: safeShiftRight(c2To22, operandBiasedExp)); |
| 611 | Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp); |
| 612 | Value operandIsHalfway = |
| 613 | b.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: operandMaskedMantissa, |
| 614 | args&: expectedOperandMaskedMantissa); |
| 615 | // Ensure `biasedExp` is in the valid range for half values. |
| 616 | Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>( |
| 617 | args: arith::CmpIPredicate::sge, args&: operandBiasedExp, args&: cNeg1); |
| 618 | Value operandBiasedExpLt23 = |
| 619 | b.create<arith::CmpIOp>(args: arith::CmpIPredicate::slt, args&: operandBiasedExp, args&: c23); |
| 620 | operandIsHalfway = |
| 621 | b.create<arith::AndIOp>(args&: operandIsHalfway, args&: operandBiasedExpLt23); |
| 622 | operandIsHalfway = |
| 623 | b.create<arith::AndIOp>(args&: operandIsHalfway, args&: operandBiasedExpGeNeg1); |
| 624 | |
| 625 | // Adjust rounded operand with `round(operand) - sign(operand)` to correct the |
| 626 | // case where `round` rounded in the opposite direction of `roundeven`. |
| 627 | Value sign = b.create<math::CopySignOp>(args&: c1Float, args&: operand); |
| 628 | Value roundShifted = b.create<arith::SubFOp>(args&: round, args&: sign); |
| 629 | // If the rounded value is even or a special value, we default to the behavior |
| 630 | // of `math.round`. |
| 631 | Value needsShift = |
| 632 | b.create<arith::AndIOp>(args&: roundIsNotEvenOrSpecialVal, args&: operandIsHalfway); |
| 633 | Value result = b.create<arith::SelectOp>(args&: needsShift, args&: roundShifted, args&: round); |
| 634 | // The `x - sign` adjustment does not preserve the sign when we are adjusting |
| 635 | // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is |
| 636 | // rounded to -0.0. |
| 637 | result = b.create<math::CopySignOp>(args&: result, args&: operand); |
| 638 | rewriter.replaceOp(op, newValues: result); |
| 639 | return success(); |
| 640 | } |
| 641 | |
| 642 | // Convert `math.rsqrt` into `arith.divf` + `math.sqrt` |
| 643 | static LogicalResult convertRsqrtOp(math::RsqrtOp op, |
| 644 | PatternRewriter &rewriter) { |
| 645 | |
| 646 | auto operand = op.getOperand(); |
| 647 | auto operandTy = operand.getType(); |
| 648 | // Operand type must be shatic shaped type to create const float. |
| 649 | auto shapedOperandType = dyn_cast<ShapedType>(Val&: operandTy); |
| 650 | if (shapedOperandType && !shapedOperandType.hasStaticShape()) |
| 651 | return failure(); |
| 652 | |
| 653 | auto eTy = getElementTypeOrSelf(type: operandTy); |
| 654 | if (!isa<FloatType>(Val: eTy)) |
| 655 | return failure(); |
| 656 | |
| 657 | Location loc = op->getLoc(); |
| 658 | auto constOneFloat = createFloatConst(loc, type: operandTy, value: 1.0, b&: rewriter); |
| 659 | auto sqrtOp = rewriter.create<math::SqrtOp>(location: loc, args&: operand); |
| 660 | rewriter.replaceOpWithNewOp<arith::DivFOp>(op, args&: constOneFloat, args&: sqrtOp); |
| 661 | return success(); |
| 662 | } |
| 663 | |
| 664 | void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { |
| 665 | patterns.add(implFn: convertCtlzOp); |
| 666 | } |
| 667 | |
| 668 | void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) { |
| 669 | patterns.add(implFn: convertSinhOp); |
| 670 | } |
| 671 | |
| 672 | void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) { |
| 673 | patterns.add(implFn: convertCoshOp); |
| 674 | } |
| 675 | |
| 676 | void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { |
| 677 | patterns.add(implFn: convertTanOp); |
| 678 | } |
| 679 | |
| 680 | void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { |
| 681 | patterns.add(implFn: convertTanhOp); |
| 682 | } |
| 683 | |
| 684 | void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) { |
| 685 | patterns.add(implFn: convertAsinhOp); |
| 686 | } |
| 687 | |
| 688 | void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) { |
| 689 | patterns.add(implFn: convertAcoshOp); |
| 690 | } |
| 691 | |
| 692 | void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) { |
| 693 | patterns.add(implFn: convertAtanhOp); |
| 694 | } |
| 695 | |
| 696 | void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { |
| 697 | patterns.add(implFn: convertFmaFOp); |
| 698 | } |
| 699 | |
| 700 | void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { |
| 701 | patterns.add(implFn: convertCeilOp); |
| 702 | } |
| 703 | |
| 704 | void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { |
| 705 | patterns.add(implFn: convertExp2fOp); |
| 706 | } |
| 707 | |
| 708 | void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { |
| 709 | patterns.add(implFn: convertPowfOp); |
| 710 | } |
| 711 | |
| 712 | void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) { |
| 713 | patterns.add(implFn: convertFPowIOp); |
| 714 | } |
| 715 | |
| 716 | void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { |
| 717 | patterns.add(implFn: convertRoundOp); |
| 718 | } |
| 719 | |
| 720 | void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) { |
| 721 | patterns.add(implFn: convertRoundEvenOp); |
| 722 | } |
| 723 | |
| 724 | void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) { |
| 725 | patterns.add(implFn: convertRsqrtOp); |
| 726 | } |
| 727 | |