1//===- IndexOps.cpp - Index operation definitions --------------------------==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Index/IR/IndexOps.h"
10#include "mlir/Dialect/Index/IR/IndexAttrs.h"
11#include "mlir/Dialect/Index/IR/IndexDialect.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/OpImplementation.h"
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
17#include "llvm/ADT/SmallString.h"
18#include "llvm/ADT/TypeSwitch.h"
19
20using namespace mlir;
21using namespace mlir::index;
22
23//===----------------------------------------------------------------------===//
24// IndexDialect
25//===----------------------------------------------------------------------===//
26
/// Registers all operations of the Index dialect with the dialect instance.
/// The template argument list of `addOperations` is not written by hand:
/// defining `GET_OP_LIST` before including the ODS-generated
/// `IndexOps.cpp.inc` makes that include expand, in place, to the list of op
/// classes.
void IndexDialect::registerOperations() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
      >();
}
33
34Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
35 Type type, Location loc) {
36 // Materialize bool constants as `i1`.
37 if (auto boolValue = dyn_cast<BoolAttr>(value)) {
38 if (!type.isSignlessInteger(1))
39 return nullptr;
40 return b.create<BoolConstantOp>(loc, type, boolValue);
41 }
42
43 // Materialize integer attributes as `index`.
44 if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
45 if (!llvm::isa<IndexType>(indexValue.getType()) ||
46 !llvm::isa<IndexType>(type))
47 return nullptr;
48 assert(indexValue.getValue().getBitWidth() ==
49 IndexType::kInternalStorageBitWidth);
50 return b.create<ConstantOp>(loc, indexValue);
51 }
52
53 return nullptr;
54}
55
56//===----------------------------------------------------------------------===//
57// Fold Utilities
58//===----------------------------------------------------------------------===//
59
60/// Fold an index operation irrespective of the target bitwidth. The
61/// operation must satisfy the property:
62///
63/// ```
64/// trunc(f(a, b)) = f(trunc(a), trunc(b))
65/// ```
66///
67/// For all values of `a` and `b`. The function accepts a lambda that computes
68/// the integer result, which in turn must satisfy the above property.
69static OpFoldResult foldBinaryOpUnchecked(
70 ArrayRef<Attribute> operands,
71 function_ref<std::optional<APInt>(const APInt &, const APInt &)>
72 calculate) {
73 assert(operands.size() == 2 && "binary operation expected 2 operands");
74 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
75 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
76 if (!lhs || !rhs)
77 return {};
78
79 std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
80 if (!result)
81 return {};
82 assert(result->trunc(32) ==
83 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
84 return IntegerAttr::get(IndexType::get(lhs.getContext()), *result);
85}
86
87/// Fold an index operation only if the truncated 64-bit result matches the
88/// 32-bit result for operations that don't satisfy the above property. These
89/// are operations where the upper bits of the operands can affect the lower
90/// bits of the results.
91///
92/// The function accepts a lambda that computes the integer result in both
93/// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is
94/// not folded.
95static OpFoldResult foldBinaryOpChecked(
96 ArrayRef<Attribute> operands,
97 function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)>
98 calculate) {
99 assert(operands.size() == 2 && "binary operation expected 2 operands");
100 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
101 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
102 // Only fold index operands.
103 if (!lhs || !rhs)
104 return {};
105
106 // Compute the 64-bit result and the 32-bit result.
107 std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
108 if (!result64)
109 return {};
110 std::optional<APInt> result32 =
111 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
112 if (!result32)
113 return {};
114 // Compare the truncated 64-bit result to the 32-bit result.
115 if (result64->trunc(width: 32) != *result32)
116 return {};
117 // The operation can be folded for these particular operands.
118 return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64);
119}
120
121//===----------------------------------------------------------------------===//
122// AddOp
123//===----------------------------------------------------------------------===//
124
125OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
126 if (OpFoldResult result = foldBinaryOpUnchecked(
127 adaptor.getOperands(),
128 [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; }))
129 return result;
130
131 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
132 // Fold `add(x, 0) -> x`.
133 if (rhs.getValue().isZero())
134 return getLhs();
135 }
136
137 return {};
138}
139
140//===----------------------------------------------------------------------===//
141// SubOp
142//===----------------------------------------------------------------------===//
143
144OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
145 if (OpFoldResult result = foldBinaryOpUnchecked(
146 adaptor.getOperands(),
147 [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; }))
148 return result;
149
150 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
151 // Fold `sub(x, 0) -> x`.
152 if (rhs.getValue().isZero())
153 return getLhs();
154 }
155
156 return {};
157}
158
159//===----------------------------------------------------------------------===//
160// MulOp
161//===----------------------------------------------------------------------===//
162
163OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
164 if (OpFoldResult result = foldBinaryOpUnchecked(
165 adaptor.getOperands(),
166 [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; }))
167 return result;
168
169 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
170 // Fold `mul(x, 1) -> x`.
171 if (rhs.getValue().isOne())
172 return getLhs();
173 // Fold `mul(x, 0) -> 0`.
174 if (rhs.getValue().isZero())
175 return rhs;
176 }
177
178 return {};
179}
180
181//===----------------------------------------------------------------------===//
182// DivSOp
183//===----------------------------------------------------------------------===//
184
185OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
186 return foldBinaryOpChecked(
187 adaptor.getOperands(),
188 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
189 // Don't fold division by zero.
190 if (rhs.isZero())
191 return std::nullopt;
192 return lhs.sdiv(rhs);
193 });
194}
195
196//===----------------------------------------------------------------------===//
197// DivUOp
198//===----------------------------------------------------------------------===//
199
200OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
201 return foldBinaryOpChecked(
202 adaptor.getOperands(),
203 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
204 // Don't fold division by zero.
205 if (rhs.isZero())
206 return std::nullopt;
207 return lhs.udiv(rhs);
208 });
209}
210
211//===----------------------------------------------------------------------===//
212// CeilDivSOp
213//===----------------------------------------------------------------------===//
214
215/// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
216/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
217static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
218 // Don't fold division by zero.
219 if (m.isZero())
220 return std::nullopt;
221 // Short-circuit the zero case.
222 if (n.isZero())
223 return n;
224
225 bool mGtZ = m.sgt(RHS: 0);
226 if (n.sgt(RHS: 0) != mGtZ) {
227 // If the operands have different signs, compute the negative result. Signed
228 // division overflow is not possible, since if `m == -1`, `n` can be at most
229 // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
230 return -(-n).sdiv(RHS: m);
231 }
232 // Otherwise, compute the positive result. Signed division overflow is not
233 // possible since if `m == -1`, `x` will be `1`.
234 int64_t x = mGtZ ? -1 : 1;
235 return (n + x).sdiv(RHS: m) + 1;
236}
237
238OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
239 return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS);
240}
241
242//===----------------------------------------------------------------------===//
243// CeilDivUOp
244//===----------------------------------------------------------------------===//
245
246OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
247 // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
248 return foldBinaryOpChecked(
249 adaptor.getOperands(),
250 [](const APInt &n, const APInt &m) -> std::optional<APInt> {
251 // Don't fold division by zero.
252 if (m.isZero())
253 return std::nullopt;
254 // Short-circuit the zero case.
255 if (n.isZero())
256 return n;
257
258 return (n - 1).udiv(m) + 1;
259 });
260}
261
262//===----------------------------------------------------------------------===//
263// FloorDivSOp
264//===----------------------------------------------------------------------===//
265
266/// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
267/// `n*m < 0 ? -1 - (x-n)/m : n/m`.
268static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
269 // Don't fold division by zero.
270 if (m.isZero())
271 return std::nullopt;
272 // Short-circuit the zero case.
273 if (n.isZero())
274 return n;
275
276 bool mLtZ = m.slt(RHS: 0);
277 if (n.slt(RHS: 0) == mLtZ) {
278 // If the operands have the same sign, compute the positive result.
279 return n.sdiv(RHS: m);
280 }
281 // If the operands have different signs, compute the negative result. Signed
282 // division overflow is not possible since if `m == -1`, `x` will be 1 and
283 // `n` can be at most `INT_MAX`.
284 int64_t x = mLtZ ? 1 : -1;
285 return -1 - (x - n).sdiv(RHS: m);
286}
287
288OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
289 return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS);
290}
291
292//===----------------------------------------------------------------------===//
293// RemSOp
294//===----------------------------------------------------------------------===//
295
296OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
297 return foldBinaryOpChecked(
298 adaptor.getOperands(),
299 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
300 // Don't fold division by zero.
301 if (rhs.isZero())
302 return std::nullopt;
303 return lhs.srem(rhs);
304 });
305}
306
307//===----------------------------------------------------------------------===//
308// RemUOp
309//===----------------------------------------------------------------------===//
310
311OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
312 return foldBinaryOpChecked(
313 adaptor.getOperands(),
314 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
315 // Don't fold division by zero.
316 if (rhs.isZero())
317 return std::nullopt;
318 return lhs.urem(rhs);
319 });
320}
321
322//===----------------------------------------------------------------------===//
323// MaxSOp
324//===----------------------------------------------------------------------===//
325
326OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
327 return foldBinaryOpChecked(adaptor.getOperands(),
328 [](const APInt &lhs, const APInt &rhs) {
329 return lhs.sgt(rhs) ? lhs : rhs;
330 });
331}
332
333//===----------------------------------------------------------------------===//
334// MaxUOp
335//===----------------------------------------------------------------------===//
336
337OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
338 return foldBinaryOpChecked(adaptor.getOperands(),
339 [](const APInt &lhs, const APInt &rhs) {
340 return lhs.ugt(rhs) ? lhs : rhs;
341 });
342}
343
344//===----------------------------------------------------------------------===//
345// MinSOp
346//===----------------------------------------------------------------------===//
347
348OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
349 return foldBinaryOpChecked(adaptor.getOperands(),
350 [](const APInt &lhs, const APInt &rhs) {
351 return lhs.slt(rhs) ? lhs : rhs;
352 });
353}
354
355//===----------------------------------------------------------------------===//
356// MinUOp
357//===----------------------------------------------------------------------===//
358
359OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
360 return foldBinaryOpChecked(adaptor.getOperands(),
361 [](const APInt &lhs, const APInt &rhs) {
362 return lhs.ult(rhs) ? lhs : rhs;
363 });
364}
365
366//===----------------------------------------------------------------------===//
367// ShlOp
368//===----------------------------------------------------------------------===//
369
370OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
371 return foldBinaryOpUnchecked(
372 adaptor.getOperands(),
373 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
374 // We cannot fold if the RHS is greater than or equal to 32 because
375 // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
376 // already treated as unsigned.
377 if (rhs.uge(32))
378 return {};
379 return lhs << rhs;
380 });
381}
382
383//===----------------------------------------------------------------------===//
384// ShrSOp
385//===----------------------------------------------------------------------===//
386
387OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
388 return foldBinaryOpChecked(
389 adaptor.getOperands(),
390 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
391 // Don't fold if RHS is greater than or equal to 32.
392 if (rhs.uge(32))
393 return {};
394 return lhs.ashr(rhs);
395 });
396}
397
398//===----------------------------------------------------------------------===//
399// ShrUOp
400//===----------------------------------------------------------------------===//
401
402OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
403 return foldBinaryOpChecked(
404 adaptor.getOperands(),
405 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
406 // Don't fold if RHS is greater than or equal to 32.
407 if (rhs.uge(32))
408 return {};
409 return lhs.lshr(rhs);
410 });
411}
412
413//===----------------------------------------------------------------------===//
414// AndOp
415//===----------------------------------------------------------------------===//
416
417OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
418 return foldBinaryOpUnchecked(
419 adaptor.getOperands(),
420 [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
421}
422
423//===----------------------------------------------------------------------===//
424// OrOp
425//===----------------------------------------------------------------------===//
426
427OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
428 return foldBinaryOpUnchecked(
429 adaptor.getOperands(),
430 [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
431}
432
433//===----------------------------------------------------------------------===//
434// XOrOp
435//===----------------------------------------------------------------------===//
436
437OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
438 return foldBinaryOpUnchecked(
439 adaptor.getOperands(),
440 [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
441}
442
443//===----------------------------------------------------------------------===//
444// CastSOp
445//===----------------------------------------------------------------------===//
446
447static OpFoldResult
448foldCastOp(Attribute input, Type type,
449 function_ref<APInt(const APInt &, unsigned)> extFn,
450 function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
451 auto attr = dyn_cast_if_present<IntegerAttr>(input);
452 if (!attr)
453 return {};
454 const APInt &value = attr.getValue();
455
456 if (isa<IndexType>(Val: type)) {
457 // When casting to an index type, perform the cast assuming a 64-bit target.
458 // The result can be truncated to 32 bits as needed and always be correct.
459 // This is because `cast32(cast64(value)) == cast32(value)`.
460 APInt result = extOrTruncFn(value, 64);
461 return IntegerAttr::get(type, result);
462 }
463
464 // When casting from an index type, we must ensure the results respect
465 // `cast_t(value) == cast_t(trunc32(value))`.
466 auto intType = cast<IntegerType>(type);
467 unsigned width = intType.getWidth();
468
469 // If the result type is at most 32 bits, then the cast can always be folded
470 // because it is always a truncation.
471 if (width <= 32) {
472 APInt result = value.trunc(width);
473 return IntegerAttr::get(type, result);
474 }
475
476 // If the result type is at least 64 bits, then the cast is always a
477 // extension. The results will differ if `trunc32(value) != value)`.
478 if (width >= 64) {
479 if (extFn(value.trunc(width: 32), 64) != value)
480 return {};
481 APInt result = extFn(value, width);
482 return IntegerAttr::get(type, result);
483 }
484
485 // Otherwise, we just have to check the property directly.
486 APInt result = value.trunc(width);
487 if (result != extFn(value.trunc(width: 32), width))
488 return {};
489 return IntegerAttr::get(type, result);
490}
491
492bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
493 return llvm::isa<IndexType>(lhsTypes.front()) !=
494 llvm::isa<IndexType>(rhsTypes.front());
495}
496
497OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
498 return foldCastOp(
499 adaptor.getInput(), getType(),
500 [](const APInt &x, unsigned width) { return x.sext(width); },
501 [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
502}
503
504//===----------------------------------------------------------------------===//
505// CastUOp
506//===----------------------------------------------------------------------===//
507
508bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
509 return llvm::isa<IndexType>(lhsTypes.front()) !=
510 llvm::isa<IndexType>(rhsTypes.front());
511}
512
513OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
514 return foldCastOp(
515 adaptor.getInput(), getType(),
516 [](const APInt &x, unsigned width) { return x.zext(width); },
517 [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
518}
519
520//===----------------------------------------------------------------------===//
521// CmpOp
522//===----------------------------------------------------------------------===//
523
524/// Compare two integers according to the comparison predicate.
525bool compareIndices(const APInt &lhs, const APInt &rhs,
526 IndexCmpPredicate pred) {
527 switch (pred) {
528 case IndexCmpPredicate::EQ:
529 return lhs.eq(RHS: rhs);
530 case IndexCmpPredicate::NE:
531 return lhs.ne(RHS: rhs);
532 case IndexCmpPredicate::SGE:
533 return lhs.sge(RHS: rhs);
534 case IndexCmpPredicate::SGT:
535 return lhs.sgt(RHS: rhs);
536 case IndexCmpPredicate::SLE:
537 return lhs.sle(RHS: rhs);
538 case IndexCmpPredicate::SLT:
539 return lhs.slt(RHS: rhs);
540 case IndexCmpPredicate::UGE:
541 return lhs.uge(RHS: rhs);
542 case IndexCmpPredicate::UGT:
543 return lhs.ugt(RHS: rhs);
544 case IndexCmpPredicate::ULE:
545 return lhs.ule(RHS: rhs);
546 case IndexCmpPredicate::ULT:
547 return lhs.ult(RHS: rhs);
548 }
549 llvm_unreachable("unhandled IndexCmpPredicate predicate");
550}
551
552/// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the
553/// values of `cstA` and `cstB`, the max or min operation, and the comparison
554/// predicate. Check whether the value folds in both 32-bit and 64-bit
555/// arithmetic and to the same value.
556static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
557 const APInt &cstA,
558 const APInt &cstB, unsigned width,
559 IndexCmpPredicate pred) {
560 ConstantIntRanges lhsRange = TypeSwitch<Operation *, ConstantIntRanges>(lhsOp)
561 .Case(caseFn: [&](MinSOp op) {
562 return ConstantIntRanges::fromSigned(
563 smin: APInt::getSignedMinValue(numBits: width), smax: cstA);
564 })
565 .Case(caseFn: [&](MinUOp op) {
566 return ConstantIntRanges::fromUnsigned(
567 umin: APInt::getMinValue(numBits: width), umax: cstA);
568 })
569 .Case(caseFn: [&](MaxSOp op) {
570 return ConstantIntRanges::fromSigned(
571 smin: cstA, smax: APInt::getSignedMaxValue(numBits: width));
572 })
573 .Case(caseFn: [&](MaxUOp op) {
574 return ConstantIntRanges::fromUnsigned(
575 umin: cstA, umax: APInt::getMaxValue(numBits: width));
576 });
577 return intrange::evaluatePred(pred: static_cast<intrange::CmpPredicate>(pred),
578 lhs: lhsRange, rhs: ConstantIntRanges::constant(value: cstB));
579}
580
581/// Return the result of `cmp(pred, x, x)`
582static bool compareSameArgs(IndexCmpPredicate pred) {
583 switch (pred) {
584 case IndexCmpPredicate::EQ:
585 case IndexCmpPredicate::SGE:
586 case IndexCmpPredicate::SLE:
587 case IndexCmpPredicate::UGE:
588 case IndexCmpPredicate::ULE:
589 return true;
590 case IndexCmpPredicate::NE:
591 case IndexCmpPredicate::SGT:
592 case IndexCmpPredicate::SLT:
593 case IndexCmpPredicate::UGT:
594 case IndexCmpPredicate::ULT:
595 return false;
596 }
597}
598
599OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
600 // Attempt to fold if both inputs are constant.
601 auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
602 auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
603 if (lhs && rhs) {
604 // Perform the comparison in 64-bit and 32-bit.
605 bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
606 bool result32 = compareIndices(lhs.getValue().trunc(32),
607 rhs.getValue().trunc(32), getPred());
608 if (result64 == result32)
609 return BoolAttr::get(getContext(), result64);
610 }
611
612 // Fold `cmp(max/min(x, cstA), cstB)`.
613 Operation *lhsOp = getLhs().getDefiningOp();
614 IntegerAttr cstA;
615 if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
616 matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) {
617 std::optional<bool> result64 = foldCmpOfMaxOrMin(
618 lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
619 std::optional<bool> result32 =
620 foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32),
621 rhs.getValue().trunc(32), 32, getPred());
622 // Fold if the 32-bit and 64-bit results are the same.
623 if (result64 && result32 && *result64 == *result32)
624 return BoolAttr::get(getContext(), *result64);
625 }
626
627 // Fold `cmp(x, x)`
628 if (getLhs() == getRhs())
629 return BoolAttr::get(getContext(), compareSameArgs(getPred()));
630
631 return {};
632}
633
634/// Canonicalize
635/// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`.
636/// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`.
637LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
638 IntegerAttr cmpRhs;
639 IntegerAttr cmpLhs;
640
641 bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) &&
642 cmpRhs.getValue().isZero();
643 bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) &&
644 cmpLhs.getValue().isZero();
645 if (!rhsIsZero && !lhsIsZero)
646 return rewriter.notifyMatchFailure(op.getLoc(),
647 "cmp is not comparing something with 0");
648 SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
649 : op.getRhs().getDefiningOp<index::SubOp>();
650 if (!subOp)
651 return rewriter.notifyMatchFailure(
652 op.getLoc(), "non-zero operand is not a result of subtraction");
653
654 index::CmpOp newCmp;
655 if (rhsIsZero)
656 newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
657 subOp.getLhs(), subOp.getRhs());
658 else
659 newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
660 subOp.getRhs(), subOp.getLhs());
661 rewriter.replaceOp(op, newCmp);
662 return success();
663}
664
665//===----------------------------------------------------------------------===//
666// ConstantOp
667//===----------------------------------------------------------------------===//
668
669void ConstantOp::getAsmResultNames(
670 function_ref<void(Value, StringRef)> setNameFn) {
671 SmallString<32> specialNameBuffer;
672 llvm::raw_svector_ostream specialName(specialNameBuffer);
673 specialName << "idx" << getValueAttr().getValue();
674 setNameFn(getResult(), specialName.str());
675}
676
677OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
678
679void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
680 build(b, state, b.getIndexType(), b.getIndexAttr(value));
681}
682
683//===----------------------------------------------------------------------===//
684// BoolConstantOp
685//===----------------------------------------------------------------------===//
686
687OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
688 return getValueAttr();
689}
690
691void BoolConstantOp::getAsmResultNames(
692 function_ref<void(Value, StringRef)> setNameFn) {
693 setNameFn(getResult(), getValue() ? "true" : "false");
694}
695
696//===----------------------------------------------------------------------===//
697// ODS-Generated Definitions
698//===----------------------------------------------------------------------===//
699
700#define GET_OP_CLASSES
701#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
702

source code of mlir/lib/Dialect/Index/IR/IndexOps.cpp