1//===- IndexOps.cpp - Index operation definitions --------------------------==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Index/IR/IndexOps.h"
10#include "mlir/Dialect/Index/IR/IndexAttrs.h"
11#include "mlir/Dialect/Index/IR/IndexDialect.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
16#include "llvm/ADT/SmallString.h"
17#include "llvm/ADT/TypeSwitch.h"
18
19using namespace mlir;
20using namespace mlir::index;
21
22//===----------------------------------------------------------------------===//
23// IndexDialect
24//===----------------------------------------------------------------------===//
25
/// Register all operations of the index dialect with the dialect instance.
/// With `GET_OP_LIST` defined, the ODS-generated `.inc` file expands to the
/// list of op classes, which is spliced in as the template arguments of
/// `addOperations`.
void IndexDialect::registerOperations() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
      >();
}
32
33Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
34 Type type, Location loc) {
35 // Materialize bool constants as `i1`.
36 if (auto boolValue = dyn_cast<BoolAttr>(Val&: value)) {
37 if (!type.isSignlessInteger(width: 1))
38 return nullptr;
39 return b.create<BoolConstantOp>(location: loc, args&: type, args&: boolValue);
40 }
41
42 // Materialize integer attributes as `index`.
43 if (auto indexValue = dyn_cast<IntegerAttr>(Val&: value)) {
44 if (!llvm::isa<IndexType>(Val: indexValue.getType()) ||
45 !llvm::isa<IndexType>(Val: type))
46 return nullptr;
47 assert(indexValue.getValue().getBitWidth() ==
48 IndexType::kInternalStorageBitWidth);
49 return b.create<ConstantOp>(location: loc, args&: indexValue);
50 }
51
52 return nullptr;
53}
54
55//===----------------------------------------------------------------------===//
56// Fold Utilities
57//===----------------------------------------------------------------------===//
58
59/// Fold an index operation irrespective of the target bitwidth. The
60/// operation must satisfy the property:
61///
62/// ```
63/// trunc(f(a, b)) = f(trunc(a), trunc(b))
64/// ```
65///
66/// For all values of `a` and `b`. The function accepts a lambda that computes
67/// the integer result, which in turn must satisfy the above property.
68static OpFoldResult foldBinaryOpUnchecked(
69 ArrayRef<Attribute> operands,
70 function_ref<std::optional<APInt>(const APInt &, const APInt &)>
71 calculate) {
72 assert(operands.size() == 2 && "binary operation expected 2 operands");
73 auto lhs = dyn_cast_if_present<IntegerAttr>(Val: operands[0]);
74 auto rhs = dyn_cast_if_present<IntegerAttr>(Val: operands[1]);
75 if (!lhs || !rhs)
76 return {};
77
78 std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
79 if (!result)
80 return {};
81 assert(result->trunc(32) ==
82 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
83 return IntegerAttr::get(type: IndexType::get(context: lhs.getContext()), value: *result);
84}
85
86/// Fold an index operation only if the truncated 64-bit result matches the
87/// 32-bit result for operations that don't satisfy the above property. These
88/// are operations where the upper bits of the operands can affect the lower
89/// bits of the results.
90///
91/// The function accepts a lambda that computes the integer result in both
92/// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is
93/// not folded.
94static OpFoldResult foldBinaryOpChecked(
95 ArrayRef<Attribute> operands,
96 function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)>
97 calculate) {
98 assert(operands.size() == 2 && "binary operation expected 2 operands");
99 auto lhs = dyn_cast_if_present<IntegerAttr>(Val: operands[0]);
100 auto rhs = dyn_cast_if_present<IntegerAttr>(Val: operands[1]);
101 // Only fold index operands.
102 if (!lhs || !rhs)
103 return {};
104
105 // Compute the 64-bit result and the 32-bit result.
106 std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
107 if (!result64)
108 return {};
109 std::optional<APInt> result32 =
110 calculate(lhs.getValue().trunc(width: 32), rhs.getValue().trunc(width: 32));
111 if (!result32)
112 return {};
113 // Compare the truncated 64-bit result to the 32-bit result.
114 if (result64->trunc(width: 32) != *result32)
115 return {};
116 // The operation can be folded for these particular operands.
117 return IntegerAttr::get(type: IndexType::get(context: lhs.getContext()), value: *result64);
118}
119
120/// Helper for associative and commutative binary ops that can be transformed:
121/// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)`
122/// where c1 and c2 are constants. It is expected that `tmp` will be folded.
123template <typename BinaryOp>
124LogicalResult
125canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op,
126 PatternRewriter &rewriter) {
127 if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant()))
128 return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
129
130 auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
131 if (!lhsOp)
132 return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp");
133
134 if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant()))
135 return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant");
136
137 Value c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(),
138 lhsOp.getRhs());
139 if (c.getDefiningOp<BinaryOp>())
140 return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded");
141
142 rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c);
143 return success();
144}
145
146//===----------------------------------------------------------------------===//
147// AddOp
148//===----------------------------------------------------------------------===//
149
150OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
151 if (OpFoldResult result = foldBinaryOpUnchecked(
152 operands: adaptor.getOperands(),
153 calculate: [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; }))
154 return result;
155
156 if (auto rhs = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getRhs())) {
157 // Fold `add(x, 0) -> x`.
158 if (rhs.getValue().isZero())
159 return getLhs();
160 }
161
162 return {};
163}
164
165LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
166 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
167}
168
169//===----------------------------------------------------------------------===//
170// SubOp
171//===----------------------------------------------------------------------===//
172
173OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
174 if (OpFoldResult result = foldBinaryOpUnchecked(
175 operands: adaptor.getOperands(),
176 calculate: [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; }))
177 return result;
178
179 if (auto rhs = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getRhs())) {
180 // Fold `sub(x, 0) -> x`.
181 if (rhs.getValue().isZero())
182 return getLhs();
183 }
184
185 return {};
186}
187
188//===----------------------------------------------------------------------===//
189// MulOp
190//===----------------------------------------------------------------------===//
191
192OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
193 if (OpFoldResult result = foldBinaryOpUnchecked(
194 operands: adaptor.getOperands(),
195 calculate: [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; }))
196 return result;
197
198 if (auto rhs = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getRhs())) {
199 // Fold `mul(x, 1) -> x`.
200 if (rhs.getValue().isOne())
201 return getLhs();
202 // Fold `mul(x, 0) -> 0`.
203 if (rhs.getValue().isZero())
204 return rhs;
205 }
206
207 return {};
208}
209
210LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
211 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
212}
213
214//===----------------------------------------------------------------------===//
215// DivSOp
216//===----------------------------------------------------------------------===//
217
218OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
219 return foldBinaryOpChecked(
220 operands: adaptor.getOperands(),
221 calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
222 // Don't fold division by zero.
223 if (rhs.isZero())
224 return std::nullopt;
225 return lhs.sdiv(RHS: rhs);
226 });
227}
228
229//===----------------------------------------------------------------------===//
230// DivUOp
231//===----------------------------------------------------------------------===//
232
233OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
234 return foldBinaryOpChecked(
235 operands: adaptor.getOperands(),
236 calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
237 // Don't fold division by zero.
238 if (rhs.isZero())
239 return std::nullopt;
240 return lhs.udiv(RHS: rhs);
241 });
242}
243
244//===----------------------------------------------------------------------===//
245// CeilDivSOp
246//===----------------------------------------------------------------------===//
247
248/// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
249/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
250static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
251 // Don't fold division by zero.
252 if (m.isZero())
253 return std::nullopt;
254 // Short-circuit the zero case.
255 if (n.isZero())
256 return n;
257
258 bool mGtZ = m.sgt(RHS: 0);
259 if (n.sgt(RHS: 0) != mGtZ) {
260 // If the operands have different signs, compute the negative result. Signed
261 // division overflow is not possible, since if `m == -1`, `n` can be at most
262 // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
263 return -(-n).sdiv(RHS: m);
264 }
265 // Otherwise, compute the positive result. Signed division overflow is not
266 // possible since if `m == -1`, `x` will be `1`.
267 int64_t x = mGtZ ? -1 : 1;
268 return (n + x).sdiv(RHS: m) + 1;
269}
270
271OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
272 return foldBinaryOpChecked(operands: adaptor.getOperands(), calculate: calculateCeilDivS);
273}
274
275//===----------------------------------------------------------------------===//
276// CeilDivUOp
277//===----------------------------------------------------------------------===//
278
279OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
280 // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
281 return foldBinaryOpChecked(
282 operands: adaptor.getOperands(),
283 calculate: [](const APInt &n, const APInt &m) -> std::optional<APInt> {
284 // Don't fold division by zero.
285 if (m.isZero())
286 return std::nullopt;
287 // Short-circuit the zero case.
288 if (n.isZero())
289 return n;
290
291 return (n - 1).udiv(RHS: m) + 1;
292 });
293}
294
295//===----------------------------------------------------------------------===//
296// FloorDivSOp
297//===----------------------------------------------------------------------===//
298
299/// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
300/// `n*m < 0 ? -1 - (x-n)/m : n/m`.
301static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
302 // Don't fold division by zero.
303 if (m.isZero())
304 return std::nullopt;
305 // Short-circuit the zero case.
306 if (n.isZero())
307 return n;
308
309 bool mLtZ = m.slt(RHS: 0);
310 if (n.slt(RHS: 0) == mLtZ) {
311 // If the operands have the same sign, compute the positive result.
312 return n.sdiv(RHS: m);
313 }
314 // If the operands have different signs, compute the negative result. Signed
315 // division overflow is not possible since if `m == -1`, `x` will be 1 and
316 // `n` can be at most `INT_MAX`.
317 int64_t x = mLtZ ? 1 : -1;
318 return -1 - (x - n).sdiv(RHS: m);
319}
320
321OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
322 return foldBinaryOpChecked(operands: adaptor.getOperands(), calculate: calculateFloorDivS);
323}
324
325//===----------------------------------------------------------------------===//
326// RemSOp
327//===----------------------------------------------------------------------===//
328
329OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
330 return foldBinaryOpChecked(
331 operands: adaptor.getOperands(),
332 calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
333 // Don't fold division by zero.
334 if (rhs.isZero())
335 return std::nullopt;
336 return lhs.srem(RHS: rhs);
337 });
338}
339
340//===----------------------------------------------------------------------===//
341// RemUOp
342//===----------------------------------------------------------------------===//
343
344OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
345 return foldBinaryOpChecked(
346 operands: adaptor.getOperands(),
347 calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
348 // Don't fold division by zero.
349 if (rhs.isZero())
350 return std::nullopt;
351 return lhs.urem(RHS: rhs);
352 });
353}
354
355//===----------------------------------------------------------------------===//
356// MaxSOp
357//===----------------------------------------------------------------------===//
358
359OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
360 return foldBinaryOpChecked(operands: adaptor.getOperands(),
361 calculate: [](const APInt &lhs, const APInt &rhs) {
362 return lhs.sgt(RHS: rhs) ? lhs : rhs;
363 });
364}
365
366LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) {
367 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
368}
369
370//===----------------------------------------------------------------------===//
371// MaxUOp
372//===----------------------------------------------------------------------===//
373
374OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
375 return foldBinaryOpChecked(operands: adaptor.getOperands(),
376 calculate: [](const APInt &lhs, const APInt &rhs) {
377 return lhs.ugt(RHS: rhs) ? lhs : rhs;
378 });
379}
380
381LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) {
382 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
383}
384
385//===----------------------------------------------------------------------===//
386// MinSOp
387//===----------------------------------------------------------------------===//
388
389OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
390 return foldBinaryOpChecked(operands: adaptor.getOperands(),
391 calculate: [](const APInt &lhs, const APInt &rhs) {
392 return lhs.slt(RHS: rhs) ? lhs : rhs;
393 });
394}
395
396LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) {
397 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
398}
399
400//===----------------------------------------------------------------------===//
401// MinUOp
402//===----------------------------------------------------------------------===//
403
404OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
405 return foldBinaryOpChecked(operands: adaptor.getOperands(),
406 calculate: [](const APInt &lhs, const APInt &rhs) {
407 return lhs.ult(RHS: rhs) ? lhs : rhs;
408 });
409}
410
411LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) {
412 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
413}
414
415//===----------------------------------------------------------------------===//
416// ShlOp
417//===----------------------------------------------------------------------===//
418
419OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
420 return foldBinaryOpUnchecked(
421 operands: adaptor.getOperands(),
422 calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
423 // We cannot fold if the RHS is greater than or equal to 32 because
424 // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
425 // already treated as unsigned.
426 if (rhs.uge(RHS: 32))
427 return {};
428 return lhs << rhs;
429 });
430}
431
432//===----------------------------------------------------------------------===//
433// ShrSOp
434//===----------------------------------------------------------------------===//
435
436OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
437 return foldBinaryOpChecked(
438 operands: adaptor.getOperands(),
439 calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
440 // Don't fold if RHS is greater than or equal to 32.
441 if (rhs.uge(RHS: 32))
442 return {};
443 return lhs.ashr(ShiftAmt: rhs);
444 });
445}
446
447//===----------------------------------------------------------------------===//
448// ShrUOp
449//===----------------------------------------------------------------------===//
450
451OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
452 return foldBinaryOpChecked(
453 operands: adaptor.getOperands(),
454 calculate: [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
455 // Don't fold if RHS is greater than or equal to 32.
456 if (rhs.uge(RHS: 32))
457 return {};
458 return lhs.lshr(ShiftAmt: rhs);
459 });
460}
461
462//===----------------------------------------------------------------------===//
463// AndOp
464//===----------------------------------------------------------------------===//
465
466OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
467 return foldBinaryOpUnchecked(
468 operands: adaptor.getOperands(),
469 calculate: [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
470}
471
472LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
473 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
474}
475
476//===----------------------------------------------------------------------===//
477// OrOp
478//===----------------------------------------------------------------------===//
479
480OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
481 return foldBinaryOpUnchecked(
482 operands: adaptor.getOperands(),
483 calculate: [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
484}
485
486LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
487 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
488}
489
490//===----------------------------------------------------------------------===//
491// XOrOp
492//===----------------------------------------------------------------------===//
493
494OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
495 return foldBinaryOpUnchecked(
496 operands: adaptor.getOperands(),
497 calculate: [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
498}
499
500LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) {
501 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
502}
503
504//===----------------------------------------------------------------------===//
505// CastSOp
506//===----------------------------------------------------------------------===//
507
508static OpFoldResult
509foldCastOp(Attribute input, Type type,
510 function_ref<APInt(const APInt &, unsigned)> extFn,
511 function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
512 auto attr = dyn_cast_if_present<IntegerAttr>(Val&: input);
513 if (!attr)
514 return {};
515 const APInt &value = attr.getValue();
516
517 if (isa<IndexType>(Val: type)) {
518 // When casting to an index type, perform the cast assuming a 64-bit target.
519 // The result can be truncated to 32 bits as needed and always be correct.
520 // This is because `cast32(cast64(value)) == cast32(value)`.
521 APInt result = extOrTruncFn(value, 64);
522 return IntegerAttr::get(type, value: result);
523 }
524
525 // When casting from an index type, we must ensure the results respect
526 // `cast_t(value) == cast_t(trunc32(value))`.
527 auto intType = cast<IntegerType>(Val&: type);
528 unsigned width = intType.getWidth();
529
530 // If the result type is at most 32 bits, then the cast can always be folded
531 // because it is always a truncation.
532 if (width <= 32) {
533 APInt result = value.trunc(width);
534 return IntegerAttr::get(type, value: result);
535 }
536
537 // If the result type is at least 64 bits, then the cast is always a
538 // extension. The results will differ if `trunc32(value) != value)`.
539 if (width >= 64) {
540 if (extFn(value.trunc(width: 32), 64) != value)
541 return {};
542 APInt result = extFn(value, width);
543 return IntegerAttr::get(type, value: result);
544 }
545
546 // Otherwise, we just have to check the property directly.
547 APInt result = value.trunc(width);
548 if (result != extFn(value.trunc(width: 32), width))
549 return {};
550 return IntegerAttr::get(type, value: result);
551}
552
553bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
554 return llvm::isa<IndexType>(Val: lhsTypes.front()) !=
555 llvm::isa<IndexType>(Val: rhsTypes.front());
556}
557
558OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
559 return foldCastOp(
560 input: adaptor.getInput(), type: getType(),
561 extFn: [](const APInt &x, unsigned width) { return x.sext(width); },
562 extOrTruncFn: [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
563}
564
565//===----------------------------------------------------------------------===//
566// CastUOp
567//===----------------------------------------------------------------------===//
568
569bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
570 return llvm::isa<IndexType>(Val: lhsTypes.front()) !=
571 llvm::isa<IndexType>(Val: rhsTypes.front());
572}
573
574OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
575 return foldCastOp(
576 input: adaptor.getInput(), type: getType(),
577 extFn: [](const APInt &x, unsigned width) { return x.zext(width); },
578 extOrTruncFn: [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
579}
580
581//===----------------------------------------------------------------------===//
582// CmpOp
583//===----------------------------------------------------------------------===//
584
585/// Compare two integers according to the comparison predicate.
586bool compareIndices(const APInt &lhs, const APInt &rhs,
587 IndexCmpPredicate pred) {
588 switch (pred) {
589 case IndexCmpPredicate::EQ:
590 return lhs.eq(RHS: rhs);
591 case IndexCmpPredicate::NE:
592 return lhs.ne(RHS: rhs);
593 case IndexCmpPredicate::SGE:
594 return lhs.sge(RHS: rhs);
595 case IndexCmpPredicate::SGT:
596 return lhs.sgt(RHS: rhs);
597 case IndexCmpPredicate::SLE:
598 return lhs.sle(RHS: rhs);
599 case IndexCmpPredicate::SLT:
600 return lhs.slt(RHS: rhs);
601 case IndexCmpPredicate::UGE:
602 return lhs.uge(RHS: rhs);
603 case IndexCmpPredicate::UGT:
604 return lhs.ugt(RHS: rhs);
605 case IndexCmpPredicate::ULE:
606 return lhs.ule(RHS: rhs);
607 case IndexCmpPredicate::ULT:
608 return lhs.ult(RHS: rhs);
609 }
610 llvm_unreachable("unhandled IndexCmpPredicate predicate");
611}
612
613/// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the
614/// values of `cstA` and `cstB`, the max or min operation, and the comparison
615/// predicate. Check whether the value folds in both 32-bit and 64-bit
616/// arithmetic and to the same value.
617static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
618 const APInt &cstA,
619 const APInt &cstB, unsigned width,
620 IndexCmpPredicate pred) {
621 ConstantIntRanges lhsRange = TypeSwitch<Operation *, ConstantIntRanges>(lhsOp)
622 .Case(caseFn: [&](MinSOp op) {
623 return ConstantIntRanges::fromSigned(
624 smin: APInt::getSignedMinValue(numBits: width), smax: cstA);
625 })
626 .Case(caseFn: [&](MinUOp op) {
627 return ConstantIntRanges::fromUnsigned(
628 umin: APInt::getMinValue(numBits: width), umax: cstA);
629 })
630 .Case(caseFn: [&](MaxSOp op) {
631 return ConstantIntRanges::fromSigned(
632 smin: cstA, smax: APInt::getSignedMaxValue(numBits: width));
633 })
634 .Case(caseFn: [&](MaxUOp op) {
635 return ConstantIntRanges::fromUnsigned(
636 umin: cstA, umax: APInt::getMaxValue(numBits: width));
637 });
638 return intrange::evaluatePred(pred: static_cast<intrange::CmpPredicate>(pred),
639 lhs: lhsRange, rhs: ConstantIntRanges::constant(value: cstB));
640}
641
642/// Return the result of `cmp(pred, x, x)`
643static bool compareSameArgs(IndexCmpPredicate pred) {
644 switch (pred) {
645 case IndexCmpPredicate::EQ:
646 case IndexCmpPredicate::SGE:
647 case IndexCmpPredicate::SLE:
648 case IndexCmpPredicate::UGE:
649 case IndexCmpPredicate::ULE:
650 return true;
651 case IndexCmpPredicate::NE:
652 case IndexCmpPredicate::SGT:
653 case IndexCmpPredicate::SLT:
654 case IndexCmpPredicate::UGT:
655 case IndexCmpPredicate::ULT:
656 return false;
657 }
658 llvm_unreachable("unknown predicate in compareSameArgs");
659}
660
661OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
662 // Attempt to fold if both inputs are constant.
663 auto lhs = dyn_cast_if_present<IntegerAttr>(Val: adaptor.getLhs());
664 auto rhs = dyn_cast_if_present<IntegerAttr>(Val: adaptor.getRhs());
665 if (lhs && rhs) {
666 // Perform the comparison in 64-bit and 32-bit.
667 bool result64 = compareIndices(lhs: lhs.getValue(), rhs: rhs.getValue(), pred: getPred());
668 bool result32 = compareIndices(lhs: lhs.getValue().trunc(width: 32),
669 rhs: rhs.getValue().trunc(width: 32), pred: getPred());
670 if (result64 == result32)
671 return BoolAttr::get(context: getContext(), value: result64);
672 }
673
674 // Fold `cmp(max/min(x, cstA), cstB)`.
675 Operation *lhsOp = getLhs().getDefiningOp();
676 IntegerAttr cstA;
677 if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(Val: lhsOp) &&
678 matchPattern(value: lhsOp->getOperand(idx: 1), pattern: m_Constant(bind_value: &cstA)) && rhs) {
679 std::optional<bool> result64 = foldCmpOfMaxOrMin(
680 lhsOp, cstA: cstA.getValue(), cstB: rhs.getValue(), width: 64, pred: getPred());
681 std::optional<bool> result32 =
682 foldCmpOfMaxOrMin(lhsOp, cstA: cstA.getValue().trunc(width: 32),
683 cstB: rhs.getValue().trunc(width: 32), width: 32, pred: getPred());
684 // Fold if the 32-bit and 64-bit results are the same.
685 if (result64 && result32 && *result64 == *result32)
686 return BoolAttr::get(context: getContext(), value: *result64);
687 }
688
689 // Fold `cmp(x, x)`
690 if (getLhs() == getRhs())
691 return BoolAttr::get(context: getContext(), value: compareSameArgs(pred: getPred()));
692
693 return {};
694}
695
696/// Canonicalize
697/// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`.
698/// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`.
699LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
700 IntegerAttr cmpRhs;
701 IntegerAttr cmpLhs;
702
703 bool rhsIsZero = matchPattern(value: op.getRhs(), pattern: m_Constant(bind_value: &cmpRhs)) &&
704 cmpRhs.getValue().isZero();
705 bool lhsIsZero = matchPattern(value: op.getLhs(), pattern: m_Constant(bind_value: &cmpLhs)) &&
706 cmpLhs.getValue().isZero();
707 if (!rhsIsZero && !lhsIsZero)
708 return rewriter.notifyMatchFailure(arg: op.getLoc(),
709 msg: "cmp is not comparing something with 0");
710 SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
711 : op.getRhs().getDefiningOp<index::SubOp>();
712 if (!subOp)
713 return rewriter.notifyMatchFailure(
714 arg: op.getLoc(), msg: "non-zero operand is not a result of subtraction");
715
716 index::CmpOp newCmp;
717 if (rhsIsZero)
718 newCmp = rewriter.create<index::CmpOp>(location: op.getLoc(), args: op.getPred(),
719 args: subOp.getLhs(), args: subOp.getRhs());
720 else
721 newCmp = rewriter.create<index::CmpOp>(location: op.getLoc(), args: op.getPred(),
722 args: subOp.getRhs(), args: subOp.getLhs());
723 rewriter.replaceOp(op, newOp: newCmp);
724 return success();
725}
726
727//===----------------------------------------------------------------------===//
728// ConstantOp
729//===----------------------------------------------------------------------===//
730
731void ConstantOp::getAsmResultNames(
732 function_ref<void(Value, StringRef)> setNameFn) {
733 SmallString<32> specialNameBuffer;
734 llvm::raw_svector_ostream specialName(specialNameBuffer);
735 specialName << "idx" << getValueAttr().getValue();
736 setNameFn(getResult(), specialName.str());
737}
738
739OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
740
741void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
742 build(odsBuilder&: b, odsState&: state, result: b.getIndexType(), value: b.getIndexAttr(value));
743}
744
745//===----------------------------------------------------------------------===//
746// BoolConstantOp
747//===----------------------------------------------------------------------===//
748
749OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
750 return getValueAttr();
751}
752
753void BoolConstantOp::getAsmResultNames(
754 function_ref<void(Value, StringRef)> setNameFn) {
755 setNameFn(getResult(), getValue() ? "true" : "false");
756}
757
758//===----------------------------------------------------------------------===//
759// ODS-Generated Definitions
760//===----------------------------------------------------------------------===//
761
762#define GET_OP_CLASSES
763#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
764

source code of mlir/lib/Dialect/Index/IR/IndexOps.cpp