| 1 | //===- IndexOps.cpp - Index operation definitions --------------------------==// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Index/IR/IndexOps.h" |
| 10 | #include "mlir/Dialect/Index/IR/IndexAttrs.h" |
| 11 | #include "mlir/Dialect/Index/IR/IndexDialect.h" |
| 12 | #include "mlir/IR/Builders.h" |
| 13 | #include "mlir/IR/Matchers.h" |
| 14 | #include "mlir/IR/PatternMatch.h" |
| 15 | #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" |
| 16 | #include "llvm/ADT/SmallString.h" |
| 17 | #include "llvm/ADT/TypeSwitch.h" |
| 18 | |
| 19 | using namespace mlir; |
| 20 | using namespace mlir::index; |
| 21 | |
| 22 | //===----------------------------------------------------------------------===// |
| 23 | // IndexDialect |
| 24 | //===----------------------------------------------------------------------===// |
| 25 | |
| 26 | void IndexDialect::registerOperations() { |
| 27 | addOperations< |
| 28 | #define GET_OP_LIST |
| 29 | #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc" |
| 30 | >(); |
| 31 | } |
| 32 | |
| 33 | Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value, |
| 34 | Type type, Location loc) { |
| 35 | // Materialize bool constants as `i1`. |
| 36 | if (auto boolValue = dyn_cast<BoolAttr>(Val&: value)) { |
| 37 | if (!type.isSignlessInteger(width: 1)) |
| 38 | return nullptr; |
| 39 | return b.create<BoolConstantOp>(location: loc, args&: type, args&: boolValue); |
| 40 | } |
| 41 | |
| 42 | // Materialize integer attributes as `index`. |
| 43 | if (auto indexValue = dyn_cast<IntegerAttr>(Val&: value)) { |
| 44 | if (!llvm::isa<IndexType>(Val: indexValue.getType()) || |
| 45 | !llvm::isa<IndexType>(Val: type)) |
| 46 | return nullptr; |
| 47 | assert(indexValue.getValue().getBitWidth() == |
| 48 | IndexType::kInternalStorageBitWidth); |
| 49 | return b.create<ConstantOp>(location: loc, args&: indexValue); |
| 50 | } |
| 51 | |
| 52 | return nullptr; |
| 53 | } |
| 54 | |
| 55 | //===----------------------------------------------------------------------===// |
| 56 | // Fold Utilities |
| 57 | //===----------------------------------------------------------------------===// |
| 58 | |
| 59 | /// Fold an index operation irrespective of the target bitwidth. The |
| 60 | /// operation must satisfy the property: |
| 61 | /// |
| 62 | /// ``` |
| 63 | /// trunc(f(a, b)) = f(trunc(a), trunc(b)) |
| 64 | /// ``` |
| 65 | /// |
| 66 | /// For all values of `a` and `b`. The function accepts a lambda that computes |
| 67 | /// the integer result, which in turn must satisfy the above property. |
| 68 | static OpFoldResult foldBinaryOpUnchecked( |
| 69 | ArrayRef<Attribute> operands, |
| 70 | function_ref<std::optional<APInt>(const APInt &, const APInt &)> |
| 71 | calculate) { |
| 72 | assert(operands.size() == 2 && "binary operation expected 2 operands" ); |
| 73 | auto lhs = dyn_cast_if_present<IntegerAttr>(Val: operands[0]); |
| 74 | auto rhs = dyn_cast_if_present<IntegerAttr>(Val: operands[1]); |
| 75 | if (!lhs || !rhs) |
| 76 | return {}; |
| 77 | |
| 78 | std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue()); |
| 79 | if (!result) |
| 80 | return {}; |
| 81 | assert(result->trunc(32) == |
| 82 | calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32))); |
| 83 | return IntegerAttr::get(type: IndexType::get(context: lhs.getContext()), value: *result); |
| 84 | } |
| 85 | |
| 86 | /// Fold an index operation only if the truncated 64-bit result matches the |
| 87 | /// 32-bit result for operations that don't satisfy the above property. These |
| 88 | /// are operations where the upper bits of the operands can affect the lower |
| 89 | /// bits of the results. |
| 90 | /// |
| 91 | /// The function accepts a lambda that computes the integer result in both |
| 92 | /// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is |
| 93 | /// not folded. |
| 94 | static OpFoldResult foldBinaryOpChecked( |
| 95 | ArrayRef<Attribute> operands, |
| 96 | function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)> |
| 97 | calculate) { |
| 98 | assert(operands.size() == 2 && "binary operation expected 2 operands" ); |
| 99 | auto lhs = dyn_cast_if_present<IntegerAttr>(Val: operands[0]); |
| 100 | auto rhs = dyn_cast_if_present<IntegerAttr>(Val: operands[1]); |
| 101 | // Only fold index operands. |
| 102 | if (!lhs || !rhs) |
| 103 | return {}; |
| 104 | |
| 105 | // Compute the 64-bit result and the 32-bit result. |
| 106 | std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue()); |
| 107 | if (!result64) |
| 108 | return {}; |
| 109 | std::optional<APInt> result32 = |
| 110 | calculate(lhs.getValue().trunc(width: 32), rhs.getValue().trunc(width: 32)); |
| 111 | if (!result32) |
| 112 | return {}; |
| 113 | // Compare the truncated 64-bit result to the 32-bit result. |
| 114 | if (result64->trunc(width: 32) != *result32) |
| 115 | return {}; |
| 116 | // The operation can be folded for these particular operands. |
| 117 | return IntegerAttr::get(type: IndexType::get(context: lhs.getContext()), value: *result64); |
| 118 | } |
| 119 | |
| 120 | /// Helper for associative and commutative binary ops that can be transformed: |
| 121 | /// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)` |
| 122 | /// where c1 and c2 are constants. It is expected that `tmp` will be folded. |
| 123 | template <typename BinaryOp> |
| 124 | LogicalResult |
| 125 | canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, |
| 126 | PatternRewriter &rewriter) { |
| 127 | if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant())) |
| 128 | return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant" ); |
| 129 | |
| 130 | auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>(); |
| 131 | if (!lhsOp) |
| 132 | return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp" ); |
| 133 | |
| 134 | if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant())) |
| 135 | return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant" ); |
| 136 | |
| 137 | Value c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(), |
| 138 | lhsOp.getRhs()); |
| 139 | if (c.getDefiningOp<BinaryOp>()) |
| 140 | return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded" ); |
| 141 | |
| 142 | rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c); |
| 143 | return success(); |
| 144 | } |
| 145 | |
| 146 | //===----------------------------------------------------------------------===// |
| 147 | // AddOp |
| 148 | //===----------------------------------------------------------------------===// |
| 149 | |
| 150 | OpFoldResult AddOp::fold(FoldAdaptor adaptor) { |
| 151 | if (OpFoldResult result = foldBinaryOpUnchecked( |
| 152 | operands: adaptor.getOperands(), |
| 153 | calculate: [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; })) |
| 154 | return result; |
| 155 | |
| 156 | if (auto rhs = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getRhs())) { |
| 157 | // Fold `add(x, 0) -> x`. |
| 158 | if (rhs.getValue().isZero()) |
| 159 | return getLhs(); |
| 160 | } |
| 161 | |
| 162 | return {}; |
| 163 | } |
| 164 | |
| 165 | LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) { |
| 166 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 167 | } |
| 168 | |
| 169 | //===----------------------------------------------------------------------===// |
| 170 | // SubOp |
| 171 | //===----------------------------------------------------------------------===// |
| 172 | |
| 173 | OpFoldResult SubOp::fold(FoldAdaptor adaptor) { |
| 174 | if (OpFoldResult result = foldBinaryOpUnchecked( |
| 175 | operands: adaptor.getOperands(), |
| 176 | calculate: [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; })) |
| 177 | return result; |
| 178 | |
| 179 | if (auto rhs = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getRhs())) { |
| 180 | // Fold `sub(x, 0) -> x`. |
| 181 | if (rhs.getValue().isZero()) |
| 182 | return getLhs(); |
| 183 | } |
| 184 | |
| 185 | return {}; |
| 186 | } |
| 187 | |
| 188 | //===----------------------------------------------------------------------===// |
| 189 | // MulOp |
| 190 | //===----------------------------------------------------------------------===// |
| 191 | |
| 192 | OpFoldResult MulOp::fold(FoldAdaptor adaptor) { |
| 193 | if (OpFoldResult result = foldBinaryOpUnchecked( |
| 194 | operands: adaptor.getOperands(), |
| 195 | calculate: [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; })) |
| 196 | return result; |
| 197 | |
| 198 | if (auto rhs = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getRhs())) { |
| 199 | // Fold `mul(x, 1) -> x`. |
| 200 | if (rhs.getValue().isOne()) |
| 201 | return getLhs(); |
| 202 | // Fold `mul(x, 0) -> 0`. |
| 203 | if (rhs.getValue().isZero()) |
| 204 | return rhs; |
| 205 | } |
| 206 | |
| 207 | return {}; |
| 208 | } |
| 209 | |
| 210 | LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) { |
| 211 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 212 | } |
| 213 | |
| 214 | //===----------------------------------------------------------------------===// |
| 215 | // DivSOp |
| 216 | //===----------------------------------------------------------------------===// |
| 217 | |
| 218 | OpFoldResult DivSOp::fold(FoldAdaptor adaptor) { |
| 219 | return foldBinaryOpChecked( |
| 220 | operands: adaptor.getOperands(), |
| 221 | calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 222 | // Don't fold division by zero. |
| 223 | if (rhs.isZero()) |
| 224 | return std::nullopt; |
| 225 | return lhs.sdiv(RHS: rhs); |
| 226 | }); |
| 227 | } |
| 228 | |
| 229 | //===----------------------------------------------------------------------===// |
| 230 | // DivUOp |
| 231 | //===----------------------------------------------------------------------===// |
| 232 | |
| 233 | OpFoldResult DivUOp::fold(FoldAdaptor adaptor) { |
| 234 | return foldBinaryOpChecked( |
| 235 | operands: adaptor.getOperands(), |
| 236 | calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 237 | // Don't fold division by zero. |
| 238 | if (rhs.isZero()) |
| 239 | return std::nullopt; |
| 240 | return lhs.udiv(RHS: rhs); |
| 241 | }); |
| 242 | } |
| 243 | |
| 244 | //===----------------------------------------------------------------------===// |
| 245 | // CeilDivSOp |
| 246 | //===----------------------------------------------------------------------===// |
| 247 | |
| 248 | /// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then |
| 249 | /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. |
| 250 | static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) { |
| 251 | // Don't fold division by zero. |
| 252 | if (m.isZero()) |
| 253 | return std::nullopt; |
| 254 | // Short-circuit the zero case. |
| 255 | if (n.isZero()) |
| 256 | return n; |
| 257 | |
| 258 | bool mGtZ = m.sgt(RHS: 0); |
| 259 | if (n.sgt(RHS: 0) != mGtZ) { |
| 260 | // If the operands have different signs, compute the negative result. Signed |
| 261 | // division overflow is not possible, since if `m == -1`, `n` can be at most |
| 262 | // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement. |
| 263 | return -(-n).sdiv(RHS: m); |
| 264 | } |
| 265 | // Otherwise, compute the positive result. Signed division overflow is not |
| 266 | // possible since if `m == -1`, `x` will be `1`. |
| 267 | int64_t x = mGtZ ? -1 : 1; |
| 268 | return (n + x).sdiv(RHS: m) + 1; |
| 269 | } |
| 270 | |
| 271 | OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) { |
| 272 | return foldBinaryOpChecked(operands: adaptor.getOperands(), calculate: calculateCeilDivS); |
| 273 | } |
| 274 | |
| 275 | //===----------------------------------------------------------------------===// |
| 276 | // CeilDivUOp |
| 277 | //===----------------------------------------------------------------------===// |
| 278 | |
| 279 | OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) { |
| 280 | // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`. |
| 281 | return foldBinaryOpChecked( |
| 282 | operands: adaptor.getOperands(), |
| 283 | calculate: [](const APInt &n, const APInt &m) -> std::optional<APInt> { |
| 284 | // Don't fold division by zero. |
| 285 | if (m.isZero()) |
| 286 | return std::nullopt; |
| 287 | // Short-circuit the zero case. |
| 288 | if (n.isZero()) |
| 289 | return n; |
| 290 | |
| 291 | return (n - 1).udiv(RHS: m) + 1; |
| 292 | }); |
| 293 | } |
| 294 | |
| 295 | //===----------------------------------------------------------------------===// |
| 296 | // FloorDivSOp |
| 297 | //===----------------------------------------------------------------------===// |
| 298 | |
| 299 | /// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then |
| 300 | /// `n*m < 0 ? -1 - (x-n)/m : n/m`. |
| 301 | static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) { |
| 302 | // Don't fold division by zero. |
| 303 | if (m.isZero()) |
| 304 | return std::nullopt; |
| 305 | // Short-circuit the zero case. |
| 306 | if (n.isZero()) |
| 307 | return n; |
| 308 | |
| 309 | bool mLtZ = m.slt(RHS: 0); |
| 310 | if (n.slt(RHS: 0) == mLtZ) { |
| 311 | // If the operands have the same sign, compute the positive result. |
| 312 | return n.sdiv(RHS: m); |
| 313 | } |
| 314 | // If the operands have different signs, compute the negative result. Signed |
| 315 | // division overflow is not possible since if `m == -1`, `x` will be 1 and |
| 316 | // `n` can be at most `INT_MAX`. |
| 317 | int64_t x = mLtZ ? 1 : -1; |
| 318 | return -1 - (x - n).sdiv(RHS: m); |
| 319 | } |
| 320 | |
| 321 | OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) { |
| 322 | return foldBinaryOpChecked(operands: adaptor.getOperands(), calculate: calculateFloorDivS); |
| 323 | } |
| 324 | |
| 325 | //===----------------------------------------------------------------------===// |
| 326 | // RemSOp |
| 327 | //===----------------------------------------------------------------------===// |
| 328 | |
| 329 | OpFoldResult RemSOp::fold(FoldAdaptor adaptor) { |
| 330 | return foldBinaryOpChecked( |
| 331 | operands: adaptor.getOperands(), |
| 332 | calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 333 | // Don't fold division by zero. |
| 334 | if (rhs.isZero()) |
| 335 | return std::nullopt; |
| 336 | return lhs.srem(RHS: rhs); |
| 337 | }); |
| 338 | } |
| 339 | |
| 340 | //===----------------------------------------------------------------------===// |
| 341 | // RemUOp |
| 342 | //===----------------------------------------------------------------------===// |
| 343 | |
| 344 | OpFoldResult RemUOp::fold(FoldAdaptor adaptor) { |
| 345 | return foldBinaryOpChecked( |
| 346 | operands: adaptor.getOperands(), |
| 347 | calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 348 | // Don't fold division by zero. |
| 349 | if (rhs.isZero()) |
| 350 | return std::nullopt; |
| 351 | return lhs.urem(RHS: rhs); |
| 352 | }); |
| 353 | } |
| 354 | |
| 355 | //===----------------------------------------------------------------------===// |
| 356 | // MaxSOp |
| 357 | //===----------------------------------------------------------------------===// |
| 358 | |
| 359 | OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) { |
| 360 | return foldBinaryOpChecked(operands: adaptor.getOperands(), |
| 361 | calculate: [](const APInt &lhs, const APInt &rhs) { |
| 362 | return lhs.sgt(RHS: rhs) ? lhs : rhs; |
| 363 | }); |
| 364 | } |
| 365 | |
| 366 | LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) { |
| 367 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 368 | } |
| 369 | |
| 370 | //===----------------------------------------------------------------------===// |
| 371 | // MaxUOp |
| 372 | //===----------------------------------------------------------------------===// |
| 373 | |
| 374 | OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) { |
| 375 | return foldBinaryOpChecked(operands: adaptor.getOperands(), |
| 376 | calculate: [](const APInt &lhs, const APInt &rhs) { |
| 377 | return lhs.ugt(RHS: rhs) ? lhs : rhs; |
| 378 | }); |
| 379 | } |
| 380 | |
| 381 | LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) { |
| 382 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 383 | } |
| 384 | |
| 385 | //===----------------------------------------------------------------------===// |
| 386 | // MinSOp |
| 387 | //===----------------------------------------------------------------------===// |
| 388 | |
| 389 | OpFoldResult MinSOp::fold(FoldAdaptor adaptor) { |
| 390 | return foldBinaryOpChecked(operands: adaptor.getOperands(), |
| 391 | calculate: [](const APInt &lhs, const APInt &rhs) { |
| 392 | return lhs.slt(RHS: rhs) ? lhs : rhs; |
| 393 | }); |
| 394 | } |
| 395 | |
| 396 | LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) { |
| 397 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 398 | } |
| 399 | |
| 400 | //===----------------------------------------------------------------------===// |
| 401 | // MinUOp |
| 402 | //===----------------------------------------------------------------------===// |
| 403 | |
| 404 | OpFoldResult MinUOp::fold(FoldAdaptor adaptor) { |
| 405 | return foldBinaryOpChecked(operands: adaptor.getOperands(), |
| 406 | calculate: [](const APInt &lhs, const APInt &rhs) { |
| 407 | return lhs.ult(RHS: rhs) ? lhs : rhs; |
| 408 | }); |
| 409 | } |
| 410 | |
| 411 | LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) { |
| 412 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 413 | } |
| 414 | |
| 415 | //===----------------------------------------------------------------------===// |
| 416 | // ShlOp |
| 417 | //===----------------------------------------------------------------------===// |
| 418 | |
| 419 | OpFoldResult ShlOp::fold(FoldAdaptor adaptor) { |
| 420 | return foldBinaryOpUnchecked( |
| 421 | operands: adaptor.getOperands(), |
| 422 | calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 423 | // We cannot fold if the RHS is greater than or equal to 32 because |
| 424 | // this would be UB in 32-bit systems but not on 64-bit systems. RHS is |
| 425 | // already treated as unsigned. |
| 426 | if (rhs.uge(RHS: 32)) |
| 427 | return {}; |
| 428 | return lhs << rhs; |
| 429 | }); |
| 430 | } |
| 431 | |
| 432 | //===----------------------------------------------------------------------===// |
| 433 | // ShrSOp |
| 434 | //===----------------------------------------------------------------------===// |
| 435 | |
| 436 | OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) { |
| 437 | return foldBinaryOpChecked( |
| 438 | operands: adaptor.getOperands(), |
| 439 | calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 440 | // Don't fold if RHS is greater than or equal to 32. |
| 441 | if (rhs.uge(RHS: 32)) |
| 442 | return {}; |
| 443 | return lhs.ashr(ShiftAmt: rhs); |
| 444 | }); |
| 445 | } |
| 446 | |
| 447 | //===----------------------------------------------------------------------===// |
| 448 | // ShrUOp |
| 449 | //===----------------------------------------------------------------------===// |
| 450 | |
| 451 | OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) { |
| 452 | return foldBinaryOpChecked( |
| 453 | operands: adaptor.getOperands(), |
| 454 | calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 455 | // Don't fold if RHS is greater than or equal to 32. |
| 456 | if (rhs.uge(RHS: 32)) |
| 457 | return {}; |
| 458 | return lhs.lshr(ShiftAmt: rhs); |
| 459 | }); |
| 460 | } |
| 461 | |
| 462 | //===----------------------------------------------------------------------===// |
| 463 | // AndOp |
| 464 | //===----------------------------------------------------------------------===// |
| 465 | |
| 466 | OpFoldResult AndOp::fold(FoldAdaptor adaptor) { |
| 467 | return foldBinaryOpUnchecked( |
| 468 | operands: adaptor.getOperands(), |
| 469 | calculate: [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; }); |
| 470 | } |
| 471 | |
| 472 | LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) { |
| 473 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 474 | } |
| 475 | |
| 476 | //===----------------------------------------------------------------------===// |
| 477 | // OrOp |
| 478 | //===----------------------------------------------------------------------===// |
| 479 | |
| 480 | OpFoldResult OrOp::fold(FoldAdaptor adaptor) { |
| 481 | return foldBinaryOpUnchecked( |
| 482 | operands: adaptor.getOperands(), |
| 483 | calculate: [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; }); |
| 484 | } |
| 485 | |
| 486 | LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) { |
| 487 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 488 | } |
| 489 | |
| 490 | //===----------------------------------------------------------------------===// |
| 491 | // XOrOp |
| 492 | //===----------------------------------------------------------------------===// |
| 493 | |
| 494 | OpFoldResult XOrOp::fold(FoldAdaptor adaptor) { |
| 495 | return foldBinaryOpUnchecked( |
| 496 | operands: adaptor.getOperands(), |
| 497 | calculate: [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; }); |
| 498 | } |
| 499 | |
| 500 | LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) { |
| 501 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 502 | } |
| 503 | |
| 504 | //===----------------------------------------------------------------------===// |
| 505 | // CastSOp |
| 506 | //===----------------------------------------------------------------------===// |
| 507 | |
| 508 | static OpFoldResult |
| 509 | foldCastOp(Attribute input, Type type, |
| 510 | function_ref<APInt(const APInt &, unsigned)> extFn, |
| 511 | function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) { |
| 512 | auto attr = dyn_cast_if_present<IntegerAttr>(Val&: input); |
| 513 | if (!attr) |
| 514 | return {}; |
| 515 | const APInt &value = attr.getValue(); |
| 516 | |
| 517 | if (isa<IndexType>(Val: type)) { |
| 518 | // When casting to an index type, perform the cast assuming a 64-bit target. |
| 519 | // The result can be truncated to 32 bits as needed and always be correct. |
| 520 | // This is because `cast32(cast64(value)) == cast32(value)`. |
| 521 | APInt result = extOrTruncFn(value, 64); |
| 522 | return IntegerAttr::get(type, value: result); |
| 523 | } |
| 524 | |
| 525 | // When casting from an index type, we must ensure the results respect |
| 526 | // `cast_t(value) == cast_t(trunc32(value))`. |
| 527 | auto intType = cast<IntegerType>(Val&: type); |
| 528 | unsigned width = intType.getWidth(); |
| 529 | |
| 530 | // If the result type is at most 32 bits, then the cast can always be folded |
| 531 | // because it is always a truncation. |
| 532 | if (width <= 32) { |
| 533 | APInt result = value.trunc(width); |
| 534 | return IntegerAttr::get(type, value: result); |
| 535 | } |
| 536 | |
| 537 | // If the result type is at least 64 bits, then the cast is always a |
| 538 | // extension. The results will differ if `trunc32(value) != value)`. |
| 539 | if (width >= 64) { |
| 540 | if (extFn(value.trunc(width: 32), 64) != value) |
| 541 | return {}; |
| 542 | APInt result = extFn(value, width); |
| 543 | return IntegerAttr::get(type, value: result); |
| 544 | } |
| 545 | |
| 546 | // Otherwise, we just have to check the property directly. |
| 547 | APInt result = value.trunc(width); |
| 548 | if (result != extFn(value.trunc(width: 32), width)) |
| 549 | return {}; |
| 550 | return IntegerAttr::get(type, value: result); |
| 551 | } |
| 552 | |
| 553 | bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { |
| 554 | return llvm::isa<IndexType>(Val: lhsTypes.front()) != |
| 555 | llvm::isa<IndexType>(Val: rhsTypes.front()); |
| 556 | } |
| 557 | |
| 558 | OpFoldResult CastSOp::fold(FoldAdaptor adaptor) { |
| 559 | return foldCastOp( |
| 560 | input: adaptor.getInput(), type: getType(), |
| 561 | extFn: [](const APInt &x, unsigned width) { return x.sext(width); }, |
| 562 | extOrTruncFn: [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); }); |
| 563 | } |
| 564 | |
| 565 | //===----------------------------------------------------------------------===// |
| 566 | // CastUOp |
| 567 | //===----------------------------------------------------------------------===// |
| 568 | |
| 569 | bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { |
| 570 | return llvm::isa<IndexType>(Val: lhsTypes.front()) != |
| 571 | llvm::isa<IndexType>(Val: rhsTypes.front()); |
| 572 | } |
| 573 | |
| 574 | OpFoldResult CastUOp::fold(FoldAdaptor adaptor) { |
| 575 | return foldCastOp( |
| 576 | input: adaptor.getInput(), type: getType(), |
| 577 | extFn: [](const APInt &x, unsigned width) { return x.zext(width); }, |
| 578 | extOrTruncFn: [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); }); |
| 579 | } |
| 580 | |
| 581 | //===----------------------------------------------------------------------===// |
| 582 | // CmpOp |
| 583 | //===----------------------------------------------------------------------===// |
| 584 | |
| 585 | /// Compare two integers according to the comparison predicate. |
| 586 | bool compareIndices(const APInt &lhs, const APInt &rhs, |
| 587 | IndexCmpPredicate pred) { |
| 588 | switch (pred) { |
| 589 | case IndexCmpPredicate::EQ: |
| 590 | return lhs.eq(RHS: rhs); |
| 591 | case IndexCmpPredicate::NE: |
| 592 | return lhs.ne(RHS: rhs); |
| 593 | case IndexCmpPredicate::SGE: |
| 594 | return lhs.sge(RHS: rhs); |
| 595 | case IndexCmpPredicate::SGT: |
| 596 | return lhs.sgt(RHS: rhs); |
| 597 | case IndexCmpPredicate::SLE: |
| 598 | return lhs.sle(RHS: rhs); |
| 599 | case IndexCmpPredicate::SLT: |
| 600 | return lhs.slt(RHS: rhs); |
| 601 | case IndexCmpPredicate::UGE: |
| 602 | return lhs.uge(RHS: rhs); |
| 603 | case IndexCmpPredicate::UGT: |
| 604 | return lhs.ugt(RHS: rhs); |
| 605 | case IndexCmpPredicate::ULE: |
| 606 | return lhs.ule(RHS: rhs); |
| 607 | case IndexCmpPredicate::ULT: |
| 608 | return lhs.ult(RHS: rhs); |
| 609 | } |
| 610 | llvm_unreachable("unhandled IndexCmpPredicate predicate" ); |
| 611 | } |
| 612 | |
| 613 | /// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the |
| 614 | /// values of `cstA` and `cstB`, the max or min operation, and the comparison |
| 615 | /// predicate. Check whether the value folds in both 32-bit and 64-bit |
| 616 | /// arithmetic and to the same value. |
| 617 | static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp, |
| 618 | const APInt &cstA, |
| 619 | const APInt &cstB, unsigned width, |
| 620 | IndexCmpPredicate pred) { |
| 621 | ConstantIntRanges lhsRange = TypeSwitch<Operation *, ConstantIntRanges>(lhsOp) |
| 622 | .Case(caseFn: [&](MinSOp op) { |
| 623 | return ConstantIntRanges::fromSigned( |
| 624 | smin: APInt::getSignedMinValue(numBits: width), smax: cstA); |
| 625 | }) |
| 626 | .Case(caseFn: [&](MinUOp op) { |
| 627 | return ConstantIntRanges::fromUnsigned( |
| 628 | umin: APInt::getMinValue(numBits: width), umax: cstA); |
| 629 | }) |
| 630 | .Case(caseFn: [&](MaxSOp op) { |
| 631 | return ConstantIntRanges::fromSigned( |
| 632 | smin: cstA, smax: APInt::getSignedMaxValue(numBits: width)); |
| 633 | }) |
| 634 | .Case(caseFn: [&](MaxUOp op) { |
| 635 | return ConstantIntRanges::fromUnsigned( |
| 636 | umin: cstA, umax: APInt::getMaxValue(numBits: width)); |
| 637 | }); |
| 638 | return intrange::evaluatePred(pred: static_cast<intrange::CmpPredicate>(pred), |
| 639 | lhs: lhsRange, rhs: ConstantIntRanges::constant(value: cstB)); |
| 640 | } |
| 641 | |
| 642 | /// Return the result of `cmp(pred, x, x)` |
| 643 | static bool compareSameArgs(IndexCmpPredicate pred) { |
| 644 | switch (pred) { |
| 645 | case IndexCmpPredicate::EQ: |
| 646 | case IndexCmpPredicate::SGE: |
| 647 | case IndexCmpPredicate::SLE: |
| 648 | case IndexCmpPredicate::UGE: |
| 649 | case IndexCmpPredicate::ULE: |
| 650 | return true; |
| 651 | case IndexCmpPredicate::NE: |
| 652 | case IndexCmpPredicate::SGT: |
| 653 | case IndexCmpPredicate::SLT: |
| 654 | case IndexCmpPredicate::UGT: |
| 655 | case IndexCmpPredicate::ULT: |
| 656 | return false; |
| 657 | } |
| 658 | llvm_unreachable("unknown predicate in compareSameArgs" ); |
| 659 | } |
| 660 | |
| 661 | OpFoldResult CmpOp::fold(FoldAdaptor adaptor) { |
| 662 | // Attempt to fold if both inputs are constant. |
| 663 | auto lhs = dyn_cast_if_present<IntegerAttr>(Val: adaptor.getLhs()); |
| 664 | auto rhs = dyn_cast_if_present<IntegerAttr>(Val: adaptor.getRhs()); |
| 665 | if (lhs && rhs) { |
| 666 | // Perform the comparison in 64-bit and 32-bit. |
| 667 | bool result64 = compareIndices(lhs: lhs.getValue(), rhs: rhs.getValue(), pred: getPred()); |
| 668 | bool result32 = compareIndices(lhs: lhs.getValue().trunc(width: 32), |
| 669 | rhs: rhs.getValue().trunc(width: 32), pred: getPred()); |
| 670 | if (result64 == result32) |
| 671 | return BoolAttr::get(context: getContext(), value: result64); |
| 672 | } |
| 673 | |
| 674 | // Fold `cmp(max/min(x, cstA), cstB)`. |
| 675 | Operation *lhsOp = getLhs().getDefiningOp(); |
| 676 | IntegerAttr cstA; |
| 677 | if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(Val: lhsOp) && |
| 678 | matchPattern(value: lhsOp->getOperand(idx: 1), pattern: m_Constant(bind_value: &cstA)) && rhs) { |
| 679 | std::optional<bool> result64 = foldCmpOfMaxOrMin( |
| 680 | lhsOp, cstA: cstA.getValue(), cstB: rhs.getValue(), width: 64, pred: getPred()); |
| 681 | std::optional<bool> result32 = |
| 682 | foldCmpOfMaxOrMin(lhsOp, cstA: cstA.getValue().trunc(width: 32), |
| 683 | cstB: rhs.getValue().trunc(width: 32), width: 32, pred: getPred()); |
| 684 | // Fold if the 32-bit and 64-bit results are the same. |
| 685 | if (result64 && result32 && *result64 == *result32) |
| 686 | return BoolAttr::get(context: getContext(), value: *result64); |
| 687 | } |
| 688 | |
| 689 | // Fold `cmp(x, x)` |
| 690 | if (getLhs() == getRhs()) |
| 691 | return BoolAttr::get(context: getContext(), value: compareSameArgs(pred: getPred())); |
| 692 | |
| 693 | return {}; |
| 694 | } |
| 695 | |
| 696 | /// Canonicalize |
| 697 | /// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`. |
| 698 | /// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`. |
| 699 | LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) { |
| 700 | IntegerAttr cmpRhs; |
| 701 | IntegerAttr cmpLhs; |
| 702 | |
| 703 | bool rhsIsZero = matchPattern(value: op.getRhs(), pattern: m_Constant(bind_value: &cmpRhs)) && |
| 704 | cmpRhs.getValue().isZero(); |
| 705 | bool lhsIsZero = matchPattern(value: op.getLhs(), pattern: m_Constant(bind_value: &cmpLhs)) && |
| 706 | cmpLhs.getValue().isZero(); |
| 707 | if (!rhsIsZero && !lhsIsZero) |
| 708 | return rewriter.notifyMatchFailure(arg: op.getLoc(), |
| 709 | msg: "cmp is not comparing something with 0" ); |
| 710 | SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>() |
| 711 | : op.getRhs().getDefiningOp<index::SubOp>(); |
| 712 | if (!subOp) |
| 713 | return rewriter.notifyMatchFailure( |
| 714 | arg: op.getLoc(), msg: "non-zero operand is not a result of subtraction" ); |
| 715 | |
| 716 | index::CmpOp newCmp; |
| 717 | if (rhsIsZero) |
| 718 | newCmp = rewriter.create<index::CmpOp>(location: op.getLoc(), args: op.getPred(), |
| 719 | args: subOp.getLhs(), args: subOp.getRhs()); |
| 720 | else |
| 721 | newCmp = rewriter.create<index::CmpOp>(location: op.getLoc(), args: op.getPred(), |
| 722 | args: subOp.getRhs(), args: subOp.getLhs()); |
| 723 | rewriter.replaceOp(op, newOp: newCmp); |
| 724 | return success(); |
| 725 | } |
| 726 | |
| 727 | //===----------------------------------------------------------------------===// |
| 728 | // ConstantOp |
| 729 | //===----------------------------------------------------------------------===// |
| 730 | |
| 731 | void ConstantOp::getAsmResultNames( |
| 732 | function_ref<void(Value, StringRef)> setNameFn) { |
| 733 | SmallString<32> specialNameBuffer; |
| 734 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
| 735 | specialName << "idx" << getValueAttr().getValue(); |
| 736 | setNameFn(getResult(), specialName.str()); |
| 737 | } |
| 738 | |
| 739 | OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } |
| 740 | |
| 741 | void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) { |
| 742 | build(odsBuilder&: b, odsState&: state, result: b.getIndexType(), value: b.getIndexAttr(value)); |
| 743 | } |
| 744 | |
| 745 | //===----------------------------------------------------------------------===// |
| 746 | // BoolConstantOp |
| 747 | //===----------------------------------------------------------------------===// |
| 748 | |
| 749 | OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) { |
| 750 | return getValueAttr(); |
| 751 | } |
| 752 | |
| 753 | void BoolConstantOp::getAsmResultNames( |
| 754 | function_ref<void(Value, StringRef)> setNameFn) { |
| 755 | setNameFn(getResult(), getValue() ? "true" : "false" ); |
| 756 | } |
| 757 | |
| 758 | //===----------------------------------------------------------------------===// |
| 759 | // ODS-Generated Definitions |
| 760 | //===----------------------------------------------------------------------===// |
| 761 | |
| 762 | #define GET_OP_CLASSES |
| 763 | #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc" |
| 764 | |