| 1 | //===- IndexOps.cpp - Index operation definitions --------------------------==// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Index/IR/IndexOps.h" |
| 10 | #include "mlir/Dialect/Index/IR/IndexAttrs.h" |
| 11 | #include "mlir/Dialect/Index/IR/IndexDialect.h" |
| 12 | #include "mlir/IR/Builders.h" |
| 13 | #include "mlir/IR/Matchers.h" |
| 14 | #include "mlir/IR/OpImplementation.h" |
| 15 | #include "mlir/IR/PatternMatch.h" |
| 16 | #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" |
| 17 | #include "llvm/ADT/SmallString.h" |
| 18 | #include "llvm/ADT/TypeSwitch.h" |
| 19 | |
| 20 | using namespace mlir; |
| 21 | using namespace mlir::index; |
| 22 | |
| 23 | //===----------------------------------------------------------------------===// |
| 24 | // IndexDialect |
| 25 | //===----------------------------------------------------------------------===// |
| 26 | |
/// Registers every operation of the Index dialect with the dialect instance.
/// The list of operation classes is not written by hand: defining
/// `GET_OP_LIST` before including the ODS-generated `IndexOps.cpp.inc` makes
/// that file expand to the op class names, which become the template
/// arguments of `addOperations<...>()`.
void IndexDialect::registerOperations() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
      >();
}
| 33 | |
| 34 | Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value, |
| 35 | Type type, Location loc) { |
| 36 | // Materialize bool constants as `i1`. |
| 37 | if (auto boolValue = dyn_cast<BoolAttr>(value)) { |
| 38 | if (!type.isSignlessInteger(1)) |
| 39 | return nullptr; |
| 40 | return b.create<BoolConstantOp>(loc, type, boolValue); |
| 41 | } |
| 42 | |
| 43 | // Materialize integer attributes as `index`. |
| 44 | if (auto indexValue = dyn_cast<IntegerAttr>(value)) { |
| 45 | if (!llvm::isa<IndexType>(indexValue.getType()) || |
| 46 | !llvm::isa<IndexType>(type)) |
| 47 | return nullptr; |
| 48 | assert(indexValue.getValue().getBitWidth() == |
| 49 | IndexType::kInternalStorageBitWidth); |
| 50 | return b.create<ConstantOp>(loc, indexValue); |
| 51 | } |
| 52 | |
| 53 | return nullptr; |
| 54 | } |
| 55 | |
| 56 | //===----------------------------------------------------------------------===// |
| 57 | // Fold Utilities |
| 58 | //===----------------------------------------------------------------------===// |
| 59 | |
| 60 | /// Fold an index operation irrespective of the target bitwidth. The |
| 61 | /// operation must satisfy the property: |
| 62 | /// |
| 63 | /// ``` |
| 64 | /// trunc(f(a, b)) = f(trunc(a), trunc(b)) |
| 65 | /// ``` |
| 66 | /// |
| 67 | /// For all values of `a` and `b`. The function accepts a lambda that computes |
| 68 | /// the integer result, which in turn must satisfy the above property. |
| 69 | static OpFoldResult foldBinaryOpUnchecked( |
| 70 | ArrayRef<Attribute> operands, |
| 71 | function_ref<std::optional<APInt>(const APInt &, const APInt &)> |
| 72 | calculate) { |
| 73 | assert(operands.size() == 2 && "binary operation expected 2 operands" ); |
| 74 | auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]); |
| 75 | auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]); |
| 76 | if (!lhs || !rhs) |
| 77 | return {}; |
| 78 | |
| 79 | std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue()); |
| 80 | if (!result) |
| 81 | return {}; |
| 82 | assert(result->trunc(32) == |
| 83 | calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32))); |
| 84 | return IntegerAttr::get(IndexType::get(lhs.getContext()), *result); |
| 85 | } |
| 86 | |
| 87 | /// Fold an index operation only if the truncated 64-bit result matches the |
| 88 | /// 32-bit result for operations that don't satisfy the above property. These |
| 89 | /// are operations where the upper bits of the operands can affect the lower |
| 90 | /// bits of the results. |
| 91 | /// |
| 92 | /// The function accepts a lambda that computes the integer result in both |
| 93 | /// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is |
| 94 | /// not folded. |
| 95 | static OpFoldResult foldBinaryOpChecked( |
| 96 | ArrayRef<Attribute> operands, |
| 97 | function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)> |
| 98 | calculate) { |
| 99 | assert(operands.size() == 2 && "binary operation expected 2 operands" ); |
| 100 | auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]); |
| 101 | auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]); |
| 102 | // Only fold index operands. |
| 103 | if (!lhs || !rhs) |
| 104 | return {}; |
| 105 | |
| 106 | // Compute the 64-bit result and the 32-bit result. |
| 107 | std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue()); |
| 108 | if (!result64) |
| 109 | return {}; |
| 110 | std::optional<APInt> result32 = |
| 111 | calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)); |
| 112 | if (!result32) |
| 113 | return {}; |
| 114 | // Compare the truncated 64-bit result to the 32-bit result. |
| 115 | if (result64->trunc(width: 32) != *result32) |
| 116 | return {}; |
| 117 | // The operation can be folded for these particular operands. |
| 118 | return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64); |
| 119 | } |
| 120 | |
| 121 | /// Helper for associative and commutative binary ops that can be transformed: |
| 122 | /// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)` |
| 123 | /// where c1 and c2 are constants. It is expected that `tmp` will be folded. |
| 124 | template <typename BinaryOp> |
| 125 | LogicalResult |
| 126 | canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, |
| 127 | PatternRewriter &rewriter) { |
| 128 | if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant())) |
| 129 | return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant" ); |
| 130 | |
| 131 | auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>(); |
| 132 | if (!lhsOp) |
| 133 | return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp" ); |
| 134 | |
| 135 | if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant())) |
| 136 | return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant" ); |
| 137 | |
| 138 | Value c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(), |
| 139 | lhsOp.getRhs()); |
| 140 | if (c.getDefiningOp<BinaryOp>()) |
| 141 | return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded" ); |
| 142 | |
| 143 | rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c); |
| 144 | return success(); |
| 145 | } |
| 146 | |
| 147 | //===----------------------------------------------------------------------===// |
| 148 | // AddOp |
| 149 | //===----------------------------------------------------------------------===// |
| 150 | |
| 151 | OpFoldResult AddOp::fold(FoldAdaptor adaptor) { |
| 152 | if (OpFoldResult result = foldBinaryOpUnchecked( |
| 153 | adaptor.getOperands(), |
| 154 | [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; })) |
| 155 | return result; |
| 156 | |
| 157 | if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) { |
| 158 | // Fold `add(x, 0) -> x`. |
| 159 | if (rhs.getValue().isZero()) |
| 160 | return getLhs(); |
| 161 | } |
| 162 | |
| 163 | return {}; |
| 164 | } |
| 165 | |
| 166 | LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) { |
| 167 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 168 | } |
| 169 | |
| 170 | //===----------------------------------------------------------------------===// |
| 171 | // SubOp |
| 172 | //===----------------------------------------------------------------------===// |
| 173 | |
| 174 | OpFoldResult SubOp::fold(FoldAdaptor adaptor) { |
| 175 | if (OpFoldResult result = foldBinaryOpUnchecked( |
| 176 | adaptor.getOperands(), |
| 177 | [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; })) |
| 178 | return result; |
| 179 | |
| 180 | if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) { |
| 181 | // Fold `sub(x, 0) -> x`. |
| 182 | if (rhs.getValue().isZero()) |
| 183 | return getLhs(); |
| 184 | } |
| 185 | |
| 186 | return {}; |
| 187 | } |
| 188 | |
| 189 | //===----------------------------------------------------------------------===// |
| 190 | // MulOp |
| 191 | //===----------------------------------------------------------------------===// |
| 192 | |
| 193 | OpFoldResult MulOp::fold(FoldAdaptor adaptor) { |
| 194 | if (OpFoldResult result = foldBinaryOpUnchecked( |
| 195 | adaptor.getOperands(), |
| 196 | [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; })) |
| 197 | return result; |
| 198 | |
| 199 | if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) { |
| 200 | // Fold `mul(x, 1) -> x`. |
| 201 | if (rhs.getValue().isOne()) |
| 202 | return getLhs(); |
| 203 | // Fold `mul(x, 0) -> 0`. |
| 204 | if (rhs.getValue().isZero()) |
| 205 | return rhs; |
| 206 | } |
| 207 | |
| 208 | return {}; |
| 209 | } |
| 210 | |
| 211 | LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) { |
| 212 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 213 | } |
| 214 | |
| 215 | //===----------------------------------------------------------------------===// |
| 216 | // DivSOp |
| 217 | //===----------------------------------------------------------------------===// |
| 218 | |
| 219 | OpFoldResult DivSOp::fold(FoldAdaptor adaptor) { |
| 220 | return foldBinaryOpChecked( |
| 221 | adaptor.getOperands(), |
| 222 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 223 | // Don't fold division by zero. |
| 224 | if (rhs.isZero()) |
| 225 | return std::nullopt; |
| 226 | return lhs.sdiv(rhs); |
| 227 | }); |
| 228 | } |
| 229 | |
| 230 | //===----------------------------------------------------------------------===// |
| 231 | // DivUOp |
| 232 | //===----------------------------------------------------------------------===// |
| 233 | |
| 234 | OpFoldResult DivUOp::fold(FoldAdaptor adaptor) { |
| 235 | return foldBinaryOpChecked( |
| 236 | adaptor.getOperands(), |
| 237 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 238 | // Don't fold division by zero. |
| 239 | if (rhs.isZero()) |
| 240 | return std::nullopt; |
| 241 | return lhs.udiv(rhs); |
| 242 | }); |
| 243 | } |
| 244 | |
| 245 | //===----------------------------------------------------------------------===// |
| 246 | // CeilDivSOp |
| 247 | //===----------------------------------------------------------------------===// |
| 248 | |
| 249 | /// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then |
| 250 | /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. |
| 251 | static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) { |
| 252 | // Don't fold division by zero. |
| 253 | if (m.isZero()) |
| 254 | return std::nullopt; |
| 255 | // Short-circuit the zero case. |
| 256 | if (n.isZero()) |
| 257 | return n; |
| 258 | |
| 259 | bool mGtZ = m.sgt(RHS: 0); |
| 260 | if (n.sgt(RHS: 0) != mGtZ) { |
| 261 | // If the operands have different signs, compute the negative result. Signed |
| 262 | // division overflow is not possible, since if `m == -1`, `n` can be at most |
| 263 | // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement. |
| 264 | return -(-n).sdiv(RHS: m); |
| 265 | } |
| 266 | // Otherwise, compute the positive result. Signed division overflow is not |
| 267 | // possible since if `m == -1`, `x` will be `1`. |
| 268 | int64_t x = mGtZ ? -1 : 1; |
| 269 | return (n + x).sdiv(RHS: m) + 1; |
| 270 | } |
| 271 | |
| 272 | OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) { |
| 273 | return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS); |
| 274 | } |
| 275 | |
| 276 | //===----------------------------------------------------------------------===// |
| 277 | // CeilDivUOp |
| 278 | //===----------------------------------------------------------------------===// |
| 279 | |
| 280 | OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) { |
| 281 | // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`. |
| 282 | return foldBinaryOpChecked( |
| 283 | adaptor.getOperands(), |
| 284 | [](const APInt &n, const APInt &m) -> std::optional<APInt> { |
| 285 | // Don't fold division by zero. |
| 286 | if (m.isZero()) |
| 287 | return std::nullopt; |
| 288 | // Short-circuit the zero case. |
| 289 | if (n.isZero()) |
| 290 | return n; |
| 291 | |
| 292 | return (n - 1).udiv(m) + 1; |
| 293 | }); |
| 294 | } |
| 295 | |
| 296 | //===----------------------------------------------------------------------===// |
| 297 | // FloorDivSOp |
| 298 | //===----------------------------------------------------------------------===// |
| 299 | |
| 300 | /// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then |
| 301 | /// `n*m < 0 ? -1 - (x-n)/m : n/m`. |
| 302 | static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) { |
| 303 | // Don't fold division by zero. |
| 304 | if (m.isZero()) |
| 305 | return std::nullopt; |
| 306 | // Short-circuit the zero case. |
| 307 | if (n.isZero()) |
| 308 | return n; |
| 309 | |
| 310 | bool mLtZ = m.slt(RHS: 0); |
| 311 | if (n.slt(RHS: 0) == mLtZ) { |
| 312 | // If the operands have the same sign, compute the positive result. |
| 313 | return n.sdiv(RHS: m); |
| 314 | } |
| 315 | // If the operands have different signs, compute the negative result. Signed |
| 316 | // division overflow is not possible since if `m == -1`, `x` will be 1 and |
| 317 | // `n` can be at most `INT_MAX`. |
| 318 | int64_t x = mLtZ ? 1 : -1; |
| 319 | return -1 - (x - n).sdiv(RHS: m); |
| 320 | } |
| 321 | |
| 322 | OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) { |
| 323 | return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS); |
| 324 | } |
| 325 | |
| 326 | //===----------------------------------------------------------------------===// |
| 327 | // RemSOp |
| 328 | //===----------------------------------------------------------------------===// |
| 329 | |
| 330 | OpFoldResult RemSOp::fold(FoldAdaptor adaptor) { |
| 331 | return foldBinaryOpChecked( |
| 332 | adaptor.getOperands(), |
| 333 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 334 | // Don't fold division by zero. |
| 335 | if (rhs.isZero()) |
| 336 | return std::nullopt; |
| 337 | return lhs.srem(rhs); |
| 338 | }); |
| 339 | } |
| 340 | |
| 341 | //===----------------------------------------------------------------------===// |
| 342 | // RemUOp |
| 343 | //===----------------------------------------------------------------------===// |
| 344 | |
| 345 | OpFoldResult RemUOp::fold(FoldAdaptor adaptor) { |
| 346 | return foldBinaryOpChecked( |
| 347 | adaptor.getOperands(), |
| 348 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 349 | // Don't fold division by zero. |
| 350 | if (rhs.isZero()) |
| 351 | return std::nullopt; |
| 352 | return lhs.urem(rhs); |
| 353 | }); |
| 354 | } |
| 355 | |
| 356 | //===----------------------------------------------------------------------===// |
| 357 | // MaxSOp |
| 358 | //===----------------------------------------------------------------------===// |
| 359 | |
| 360 | OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) { |
| 361 | return foldBinaryOpChecked(adaptor.getOperands(), |
| 362 | [](const APInt &lhs, const APInt &rhs) { |
| 363 | return lhs.sgt(rhs) ? lhs : rhs; |
| 364 | }); |
| 365 | } |
| 366 | |
| 367 | LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) { |
| 368 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 369 | } |
| 370 | |
| 371 | //===----------------------------------------------------------------------===// |
| 372 | // MaxUOp |
| 373 | //===----------------------------------------------------------------------===// |
| 374 | |
| 375 | OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) { |
| 376 | return foldBinaryOpChecked(adaptor.getOperands(), |
| 377 | [](const APInt &lhs, const APInt &rhs) { |
| 378 | return lhs.ugt(rhs) ? lhs : rhs; |
| 379 | }); |
| 380 | } |
| 381 | |
| 382 | LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) { |
| 383 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 384 | } |
| 385 | |
| 386 | //===----------------------------------------------------------------------===// |
| 387 | // MinSOp |
| 388 | //===----------------------------------------------------------------------===// |
| 389 | |
| 390 | OpFoldResult MinSOp::fold(FoldAdaptor adaptor) { |
| 391 | return foldBinaryOpChecked(adaptor.getOperands(), |
| 392 | [](const APInt &lhs, const APInt &rhs) { |
| 393 | return lhs.slt(rhs) ? lhs : rhs; |
| 394 | }); |
| 395 | } |
| 396 | |
| 397 | LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) { |
| 398 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 399 | } |
| 400 | |
| 401 | //===----------------------------------------------------------------------===// |
| 402 | // MinUOp |
| 403 | //===----------------------------------------------------------------------===// |
| 404 | |
| 405 | OpFoldResult MinUOp::fold(FoldAdaptor adaptor) { |
| 406 | return foldBinaryOpChecked(adaptor.getOperands(), |
| 407 | [](const APInt &lhs, const APInt &rhs) { |
| 408 | return lhs.ult(rhs) ? lhs : rhs; |
| 409 | }); |
| 410 | } |
| 411 | |
| 412 | LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) { |
| 413 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 414 | } |
| 415 | |
| 416 | //===----------------------------------------------------------------------===// |
| 417 | // ShlOp |
| 418 | //===----------------------------------------------------------------------===// |
| 419 | |
| 420 | OpFoldResult ShlOp::fold(FoldAdaptor adaptor) { |
| 421 | return foldBinaryOpUnchecked( |
| 422 | adaptor.getOperands(), |
| 423 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 424 | // We cannot fold if the RHS is greater than or equal to 32 because |
| 425 | // this would be UB in 32-bit systems but not on 64-bit systems. RHS is |
| 426 | // already treated as unsigned. |
| 427 | if (rhs.uge(32)) |
| 428 | return {}; |
| 429 | return lhs << rhs; |
| 430 | }); |
| 431 | } |
| 432 | |
| 433 | //===----------------------------------------------------------------------===// |
| 434 | // ShrSOp |
| 435 | //===----------------------------------------------------------------------===// |
| 436 | |
| 437 | OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) { |
| 438 | return foldBinaryOpChecked( |
| 439 | adaptor.getOperands(), |
| 440 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 441 | // Don't fold if RHS is greater than or equal to 32. |
| 442 | if (rhs.uge(32)) |
| 443 | return {}; |
| 444 | return lhs.ashr(rhs); |
| 445 | }); |
| 446 | } |
| 447 | |
| 448 | //===----------------------------------------------------------------------===// |
| 449 | // ShrUOp |
| 450 | //===----------------------------------------------------------------------===// |
| 451 | |
| 452 | OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) { |
| 453 | return foldBinaryOpChecked( |
| 454 | adaptor.getOperands(), |
| 455 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
| 456 | // Don't fold if RHS is greater than or equal to 32. |
| 457 | if (rhs.uge(32)) |
| 458 | return {}; |
| 459 | return lhs.lshr(rhs); |
| 460 | }); |
| 461 | } |
| 462 | |
| 463 | //===----------------------------------------------------------------------===// |
| 464 | // AndOp |
| 465 | //===----------------------------------------------------------------------===// |
| 466 | |
| 467 | OpFoldResult AndOp::fold(FoldAdaptor adaptor) { |
| 468 | return foldBinaryOpUnchecked( |
| 469 | adaptor.getOperands(), |
| 470 | [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; }); |
| 471 | } |
| 472 | |
| 473 | LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) { |
| 474 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 475 | } |
| 476 | |
| 477 | //===----------------------------------------------------------------------===// |
| 478 | // OrOp |
| 479 | //===----------------------------------------------------------------------===// |
| 480 | |
| 481 | OpFoldResult OrOp::fold(FoldAdaptor adaptor) { |
| 482 | return foldBinaryOpUnchecked( |
| 483 | adaptor.getOperands(), |
| 484 | [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; }); |
| 485 | } |
| 486 | |
| 487 | LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) { |
| 488 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 489 | } |
| 490 | |
| 491 | //===----------------------------------------------------------------------===// |
| 492 | // XOrOp |
| 493 | //===----------------------------------------------------------------------===// |
| 494 | |
| 495 | OpFoldResult XOrOp::fold(FoldAdaptor adaptor) { |
| 496 | return foldBinaryOpUnchecked( |
| 497 | adaptor.getOperands(), |
| 498 | [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; }); |
| 499 | } |
| 500 | |
| 501 | LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) { |
| 502 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
| 503 | } |
| 504 | |
| 505 | //===----------------------------------------------------------------------===// |
| 506 | // CastSOp |
| 507 | //===----------------------------------------------------------------------===// |
| 508 | |
| 509 | static OpFoldResult |
| 510 | foldCastOp(Attribute input, Type type, |
| 511 | function_ref<APInt(const APInt &, unsigned)> extFn, |
| 512 | function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) { |
| 513 | auto attr = dyn_cast_if_present<IntegerAttr>(input); |
| 514 | if (!attr) |
| 515 | return {}; |
| 516 | const APInt &value = attr.getValue(); |
| 517 | |
| 518 | if (isa<IndexType>(Val: type)) { |
| 519 | // When casting to an index type, perform the cast assuming a 64-bit target. |
| 520 | // The result can be truncated to 32 bits as needed and always be correct. |
| 521 | // This is because `cast32(cast64(value)) == cast32(value)`. |
| 522 | APInt result = extOrTruncFn(value, 64); |
| 523 | return IntegerAttr::get(type, result); |
| 524 | } |
| 525 | |
| 526 | // When casting from an index type, we must ensure the results respect |
| 527 | // `cast_t(value) == cast_t(trunc32(value))`. |
| 528 | auto intType = cast<IntegerType>(type); |
| 529 | unsigned width = intType.getWidth(); |
| 530 | |
| 531 | // If the result type is at most 32 bits, then the cast can always be folded |
| 532 | // because it is always a truncation. |
| 533 | if (width <= 32) { |
| 534 | APInt result = value.trunc(width); |
| 535 | return IntegerAttr::get(type, result); |
| 536 | } |
| 537 | |
| 538 | // If the result type is at least 64 bits, then the cast is always a |
| 539 | // extension. The results will differ if `trunc32(value) != value)`. |
| 540 | if (width >= 64) { |
| 541 | if (extFn(value.trunc(width: 32), 64) != value) |
| 542 | return {}; |
| 543 | APInt result = extFn(value, width); |
| 544 | return IntegerAttr::get(type, result); |
| 545 | } |
| 546 | |
| 547 | // Otherwise, we just have to check the property directly. |
| 548 | APInt result = value.trunc(width); |
| 549 | if (result != extFn(value.trunc(width: 32), width)) |
| 550 | return {}; |
| 551 | return IntegerAttr::get(type, result); |
| 552 | } |
| 553 | |
| 554 | bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { |
| 555 | return llvm::isa<IndexType>(lhsTypes.front()) != |
| 556 | llvm::isa<IndexType>(rhsTypes.front()); |
| 557 | } |
| 558 | |
| 559 | OpFoldResult CastSOp::fold(FoldAdaptor adaptor) { |
| 560 | return foldCastOp( |
| 561 | adaptor.getInput(), getType(), |
| 562 | [](const APInt &x, unsigned width) { return x.sext(width); }, |
| 563 | [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); }); |
| 564 | } |
| 565 | |
| 566 | //===----------------------------------------------------------------------===// |
| 567 | // CastUOp |
| 568 | //===----------------------------------------------------------------------===// |
| 569 | |
| 570 | bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { |
| 571 | return llvm::isa<IndexType>(lhsTypes.front()) != |
| 572 | llvm::isa<IndexType>(rhsTypes.front()); |
| 573 | } |
| 574 | |
| 575 | OpFoldResult CastUOp::fold(FoldAdaptor adaptor) { |
| 576 | return foldCastOp( |
| 577 | adaptor.getInput(), getType(), |
| 578 | [](const APInt &x, unsigned width) { return x.zext(width); }, |
| 579 | [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); }); |
| 580 | } |
| 581 | |
| 582 | //===----------------------------------------------------------------------===// |
| 583 | // CmpOp |
| 584 | //===----------------------------------------------------------------------===// |
| 585 | |
| 586 | /// Compare two integers according to the comparison predicate. |
| 587 | bool compareIndices(const APInt &lhs, const APInt &rhs, |
| 588 | IndexCmpPredicate pred) { |
| 589 | switch (pred) { |
| 590 | case IndexCmpPredicate::EQ: |
| 591 | return lhs.eq(RHS: rhs); |
| 592 | case IndexCmpPredicate::NE: |
| 593 | return lhs.ne(RHS: rhs); |
| 594 | case IndexCmpPredicate::SGE: |
| 595 | return lhs.sge(RHS: rhs); |
| 596 | case IndexCmpPredicate::SGT: |
| 597 | return lhs.sgt(RHS: rhs); |
| 598 | case IndexCmpPredicate::SLE: |
| 599 | return lhs.sle(RHS: rhs); |
| 600 | case IndexCmpPredicate::SLT: |
| 601 | return lhs.slt(RHS: rhs); |
| 602 | case IndexCmpPredicate::UGE: |
| 603 | return lhs.uge(RHS: rhs); |
| 604 | case IndexCmpPredicate::UGT: |
| 605 | return lhs.ugt(RHS: rhs); |
| 606 | case IndexCmpPredicate::ULE: |
| 607 | return lhs.ule(RHS: rhs); |
| 608 | case IndexCmpPredicate::ULT: |
| 609 | return lhs.ult(RHS: rhs); |
| 610 | } |
| 611 | llvm_unreachable("unhandled IndexCmpPredicate predicate" ); |
| 612 | } |
| 613 | |
| 614 | /// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the |
| 615 | /// values of `cstA` and `cstB`, the max or min operation, and the comparison |
| 616 | /// predicate. Check whether the value folds in both 32-bit and 64-bit |
| 617 | /// arithmetic and to the same value. |
| 618 | static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp, |
| 619 | const APInt &cstA, |
| 620 | const APInt &cstB, unsigned width, |
| 621 | IndexCmpPredicate pred) { |
| 622 | ConstantIntRanges lhsRange = TypeSwitch<Operation *, ConstantIntRanges>(lhsOp) |
| 623 | .Case(caseFn: [&](MinSOp op) { |
| 624 | return ConstantIntRanges::fromSigned( |
| 625 | smin: APInt::getSignedMinValue(numBits: width), smax: cstA); |
| 626 | }) |
| 627 | .Case(caseFn: [&](MinUOp op) { |
| 628 | return ConstantIntRanges::fromUnsigned( |
| 629 | umin: APInt::getMinValue(numBits: width), umax: cstA); |
| 630 | }) |
| 631 | .Case(caseFn: [&](MaxSOp op) { |
| 632 | return ConstantIntRanges::fromSigned( |
| 633 | smin: cstA, smax: APInt::getSignedMaxValue(numBits: width)); |
| 634 | }) |
| 635 | .Case(caseFn: [&](MaxUOp op) { |
| 636 | return ConstantIntRanges::fromUnsigned( |
| 637 | umin: cstA, umax: APInt::getMaxValue(numBits: width)); |
| 638 | }); |
| 639 | return intrange::evaluatePred(pred: static_cast<intrange::CmpPredicate>(pred), |
| 640 | lhs: lhsRange, rhs: ConstantIntRanges::constant(value: cstB)); |
| 641 | } |
| 642 | |
| 643 | /// Return the result of `cmp(pred, x, x)` |
| 644 | static bool compareSameArgs(IndexCmpPredicate pred) { |
| 645 | switch (pred) { |
| 646 | case IndexCmpPredicate::EQ: |
| 647 | case IndexCmpPredicate::SGE: |
| 648 | case IndexCmpPredicate::SLE: |
| 649 | case IndexCmpPredicate::UGE: |
| 650 | case IndexCmpPredicate::ULE: |
| 651 | return true; |
| 652 | case IndexCmpPredicate::NE: |
| 653 | case IndexCmpPredicate::SGT: |
| 654 | case IndexCmpPredicate::SLT: |
| 655 | case IndexCmpPredicate::UGT: |
| 656 | case IndexCmpPredicate::ULT: |
| 657 | return false; |
| 658 | } |
| 659 | llvm_unreachable("unknown predicate in compareSameArgs" ); |
| 660 | } |
| 661 | |
| 662 | OpFoldResult CmpOp::fold(FoldAdaptor adaptor) { |
| 663 | // Attempt to fold if both inputs are constant. |
| 664 | auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs()); |
| 665 | auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()); |
| 666 | if (lhs && rhs) { |
| 667 | // Perform the comparison in 64-bit and 32-bit. |
| 668 | bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred()); |
| 669 | bool result32 = compareIndices(lhs.getValue().trunc(32), |
| 670 | rhs.getValue().trunc(32), getPred()); |
| 671 | if (result64 == result32) |
| 672 | return BoolAttr::get(getContext(), result64); |
| 673 | } |
| 674 | |
| 675 | // Fold `cmp(max/min(x, cstA), cstB)`. |
| 676 | Operation *lhsOp = getLhs().getDefiningOp(); |
| 677 | IntegerAttr cstA; |
| 678 | if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) && |
| 679 | matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) { |
| 680 | std::optional<bool> result64 = foldCmpOfMaxOrMin( |
| 681 | lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred()); |
| 682 | std::optional<bool> result32 = |
| 683 | foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32), |
| 684 | rhs.getValue().trunc(32), 32, getPred()); |
| 685 | // Fold if the 32-bit and 64-bit results are the same. |
| 686 | if (result64 && result32 && *result64 == *result32) |
| 687 | return BoolAttr::get(getContext(), *result64); |
| 688 | } |
| 689 | |
| 690 | // Fold `cmp(x, x)` |
| 691 | if (getLhs() == getRhs()) |
| 692 | return BoolAttr::get(getContext(), compareSameArgs(getPred())); |
| 693 | |
| 694 | return {}; |
| 695 | } |
| 696 | |
| 697 | /// Canonicalize |
| 698 | /// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`. |
| 699 | /// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`. |
| 700 | LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) { |
| 701 | IntegerAttr cmpRhs; |
| 702 | IntegerAttr cmpLhs; |
| 703 | |
| 704 | bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) && |
| 705 | cmpRhs.getValue().isZero(); |
| 706 | bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) && |
| 707 | cmpLhs.getValue().isZero(); |
| 708 | if (!rhsIsZero && !lhsIsZero) |
| 709 | return rewriter.notifyMatchFailure(op.getLoc(), |
| 710 | "cmp is not comparing something with 0" ); |
| 711 | SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>() |
| 712 | : op.getRhs().getDefiningOp<index::SubOp>(); |
| 713 | if (!subOp) |
| 714 | return rewriter.notifyMatchFailure( |
| 715 | op.getLoc(), "non-zero operand is not a result of subtraction" ); |
| 716 | |
| 717 | index::CmpOp newCmp; |
| 718 | if (rhsIsZero) |
| 719 | newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(), |
| 720 | subOp.getLhs(), subOp.getRhs()); |
| 721 | else |
| 722 | newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(), |
| 723 | subOp.getRhs(), subOp.getLhs()); |
| 724 | rewriter.replaceOp(op, newCmp); |
| 725 | return success(); |
| 726 | } |
| 727 | |
| 728 | //===----------------------------------------------------------------------===// |
| 729 | // ConstantOp |
| 730 | //===----------------------------------------------------------------------===// |
| 731 | |
| 732 | void ConstantOp::getAsmResultNames( |
| 733 | function_ref<void(Value, StringRef)> setNameFn) { |
| 734 | SmallString<32> specialNameBuffer; |
| 735 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
| 736 | specialName << "idx" << getValueAttr().getValue(); |
| 737 | setNameFn(getResult(), specialName.str()); |
| 738 | } |
| 739 | |
| 740 | OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } |
| 741 | |
| 742 | void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) { |
| 743 | build(b, state, b.getIndexType(), b.getIndexAttr(value)); |
| 744 | } |
| 745 | |
| 746 | //===----------------------------------------------------------------------===// |
| 747 | // BoolConstantOp |
| 748 | //===----------------------------------------------------------------------===// |
| 749 | |
| 750 | OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) { |
| 751 | return getValueAttr(); |
| 752 | } |
| 753 | |
| 754 | void BoolConstantOp::getAsmResultNames( |
| 755 | function_ref<void(Value, StringRef)> setNameFn) { |
| 756 | setNameFn(getResult(), getValue() ? "true" : "false" ); |
| 757 | } |
| 758 | |
| 759 | //===----------------------------------------------------------------------===// |
| 760 | // ODS-Generated Definitions |
| 761 | //===----------------------------------------------------------------------===// |
| 762 | |
| 763 | #define GET_OP_CLASSES |
| 764 | #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc" |
| 765 | |