1//===- IndexOps.cpp - Index operation definitions --------------------------==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Index/IR/IndexOps.h"
10#include "mlir/Dialect/Index/IR/IndexAttrs.h"
11#include "mlir/Dialect/Index/IR/IndexDialect.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/OpImplementation.h"
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
17#include "llvm/ADT/SmallString.h"
18#include "llvm/ADT/TypeSwitch.h"
19
20using namespace mlir;
21using namespace mlir::index;
22
23//===----------------------------------------------------------------------===//
24// IndexDialect
25//===----------------------------------------------------------------------===//
26
27void IndexDialect::registerOperations() {
28 addOperations<
29#define GET_OP_LIST
30#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
31 >();
32}
33
34Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
35 Type type, Location loc) {
36 // Materialize bool constants as `i1`.
37 if (auto boolValue = dyn_cast<BoolAttr>(value)) {
38 if (!type.isSignlessInteger(1))
39 return nullptr;
40 return b.create<BoolConstantOp>(loc, type, boolValue);
41 }
42
43 // Materialize integer attributes as `index`.
44 if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
45 if (!llvm::isa<IndexType>(indexValue.getType()) ||
46 !llvm::isa<IndexType>(type))
47 return nullptr;
48 assert(indexValue.getValue().getBitWidth() ==
49 IndexType::kInternalStorageBitWidth);
50 return b.create<ConstantOp>(loc, indexValue);
51 }
52
53 return nullptr;
54}
55
56//===----------------------------------------------------------------------===//
57// Fold Utilities
58//===----------------------------------------------------------------------===//
59
60/// Fold an index operation irrespective of the target bitwidth. The
61/// operation must satisfy the property:
62///
63/// ```
64/// trunc(f(a, b)) = f(trunc(a), trunc(b))
65/// ```
66///
67/// For all values of `a` and `b`. The function accepts a lambda that computes
68/// the integer result, which in turn must satisfy the above property.
69static OpFoldResult foldBinaryOpUnchecked(
70 ArrayRef<Attribute> operands,
71 function_ref<std::optional<APInt>(const APInt &, const APInt &)>
72 calculate) {
73 assert(operands.size() == 2 && "binary operation expected 2 operands");
74 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
75 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
76 if (!lhs || !rhs)
77 return {};
78
79 std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
80 if (!result)
81 return {};
82 assert(result->trunc(32) ==
83 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
84 return IntegerAttr::get(IndexType::get(lhs.getContext()), *result);
85}
86
87/// Fold an index operation only if the truncated 64-bit result matches the
88/// 32-bit result for operations that don't satisfy the above property. These
89/// are operations where the upper bits of the operands can affect the lower
90/// bits of the results.
91///
92/// The function accepts a lambda that computes the integer result in both
93/// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is
94/// not folded.
95static OpFoldResult foldBinaryOpChecked(
96 ArrayRef<Attribute> operands,
97 function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)>
98 calculate) {
99 assert(operands.size() == 2 && "binary operation expected 2 operands");
100 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
101 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
102 // Only fold index operands.
103 if (!lhs || !rhs)
104 return {};
105
106 // Compute the 64-bit result and the 32-bit result.
107 std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
108 if (!result64)
109 return {};
110 std::optional<APInt> result32 =
111 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
112 if (!result32)
113 return {};
114 // Compare the truncated 64-bit result to the 32-bit result.
115 if (result64->trunc(width: 32) != *result32)
116 return {};
117 // The operation can be folded for these particular operands.
118 return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64);
119}
120
121/// Helper for associative and commutative binary ops that can be transformed:
122/// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)`
123/// where c1 and c2 are constants. It is expected that `tmp` will be folded.
124template <typename BinaryOp>
125LogicalResult
126canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op,
127 PatternRewriter &rewriter) {
128 if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant()))
129 return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
130
131 auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
132 if (!lhsOp)
133 return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp");
134
135 if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant()))
136 return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant");
137
138 Value c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(),
139 lhsOp.getRhs());
140 if (c.getDefiningOp<BinaryOp>())
141 return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded");
142
143 rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c);
144 return success();
145}
146
147//===----------------------------------------------------------------------===//
148// AddOp
149//===----------------------------------------------------------------------===//
150
151OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
152 if (OpFoldResult result = foldBinaryOpUnchecked(
153 adaptor.getOperands(),
154 [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; }))
155 return result;
156
157 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
158 // Fold `add(x, 0) -> x`.
159 if (rhs.getValue().isZero())
160 return getLhs();
161 }
162
163 return {};
164}
165
166LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
167 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
168}
169
170//===----------------------------------------------------------------------===//
171// SubOp
172//===----------------------------------------------------------------------===//
173
174OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
175 if (OpFoldResult result = foldBinaryOpUnchecked(
176 adaptor.getOperands(),
177 [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; }))
178 return result;
179
180 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
181 // Fold `sub(x, 0) -> x`.
182 if (rhs.getValue().isZero())
183 return getLhs();
184 }
185
186 return {};
187}
188
189//===----------------------------------------------------------------------===//
190// MulOp
191//===----------------------------------------------------------------------===//
192
193OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
194 if (OpFoldResult result = foldBinaryOpUnchecked(
195 adaptor.getOperands(),
196 [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; }))
197 return result;
198
199 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
200 // Fold `mul(x, 1) -> x`.
201 if (rhs.getValue().isOne())
202 return getLhs();
203 // Fold `mul(x, 0) -> 0`.
204 if (rhs.getValue().isZero())
205 return rhs;
206 }
207
208 return {};
209}
210
211LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
212 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
213}
214
215//===----------------------------------------------------------------------===//
216// DivSOp
217//===----------------------------------------------------------------------===//
218
219OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
220 return foldBinaryOpChecked(
221 adaptor.getOperands(),
222 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
223 // Don't fold division by zero.
224 if (rhs.isZero())
225 return std::nullopt;
226 return lhs.sdiv(rhs);
227 });
228}
229
230//===----------------------------------------------------------------------===//
231// DivUOp
232//===----------------------------------------------------------------------===//
233
234OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
235 return foldBinaryOpChecked(
236 adaptor.getOperands(),
237 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
238 // Don't fold division by zero.
239 if (rhs.isZero())
240 return std::nullopt;
241 return lhs.udiv(rhs);
242 });
243}
244
245//===----------------------------------------------------------------------===//
246// CeilDivSOp
247//===----------------------------------------------------------------------===//
248
249/// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
250/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
251static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
252 // Don't fold division by zero.
253 if (m.isZero())
254 return std::nullopt;
255 // Short-circuit the zero case.
256 if (n.isZero())
257 return n;
258
259 bool mGtZ = m.sgt(RHS: 0);
260 if (n.sgt(RHS: 0) != mGtZ) {
261 // If the operands have different signs, compute the negative result. Signed
262 // division overflow is not possible, since if `m == -1`, `n` can be at most
263 // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
264 return -(-n).sdiv(RHS: m);
265 }
266 // Otherwise, compute the positive result. Signed division overflow is not
267 // possible since if `m == -1`, `x` will be `1`.
268 int64_t x = mGtZ ? -1 : 1;
269 return (n + x).sdiv(RHS: m) + 1;
270}
271
272OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
273 return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS);
274}
275
276//===----------------------------------------------------------------------===//
277// CeilDivUOp
278//===----------------------------------------------------------------------===//
279
280OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
281 // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
282 return foldBinaryOpChecked(
283 adaptor.getOperands(),
284 [](const APInt &n, const APInt &m) -> std::optional<APInt> {
285 // Don't fold division by zero.
286 if (m.isZero())
287 return std::nullopt;
288 // Short-circuit the zero case.
289 if (n.isZero())
290 return n;
291
292 return (n - 1).udiv(m) + 1;
293 });
294}
295
296//===----------------------------------------------------------------------===//
297// FloorDivSOp
298//===----------------------------------------------------------------------===//
299
300/// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
301/// `n*m < 0 ? -1 - (x-n)/m : n/m`.
302static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
303 // Don't fold division by zero.
304 if (m.isZero())
305 return std::nullopt;
306 // Short-circuit the zero case.
307 if (n.isZero())
308 return n;
309
310 bool mLtZ = m.slt(RHS: 0);
311 if (n.slt(RHS: 0) == mLtZ) {
312 // If the operands have the same sign, compute the positive result.
313 return n.sdiv(RHS: m);
314 }
315 // If the operands have different signs, compute the negative result. Signed
316 // division overflow is not possible since if `m == -1`, `x` will be 1 and
317 // `n` can be at most `INT_MAX`.
318 int64_t x = mLtZ ? 1 : -1;
319 return -1 - (x - n).sdiv(RHS: m);
320}
321
322OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
323 return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS);
324}
325
326//===----------------------------------------------------------------------===//
327// RemSOp
328//===----------------------------------------------------------------------===//
329
330OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
331 return foldBinaryOpChecked(
332 adaptor.getOperands(),
333 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
334 // Don't fold division by zero.
335 if (rhs.isZero())
336 return std::nullopt;
337 return lhs.srem(rhs);
338 });
339}
340
341//===----------------------------------------------------------------------===//
342// RemUOp
343//===----------------------------------------------------------------------===//
344
345OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
346 return foldBinaryOpChecked(
347 adaptor.getOperands(),
348 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
349 // Don't fold division by zero.
350 if (rhs.isZero())
351 return std::nullopt;
352 return lhs.urem(rhs);
353 });
354}
355
356//===----------------------------------------------------------------------===//
357// MaxSOp
358//===----------------------------------------------------------------------===//
359
360OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
361 return foldBinaryOpChecked(adaptor.getOperands(),
362 [](const APInt &lhs, const APInt &rhs) {
363 return lhs.sgt(rhs) ? lhs : rhs;
364 });
365}
366
367LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) {
368 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
369}
370
371//===----------------------------------------------------------------------===//
372// MaxUOp
373//===----------------------------------------------------------------------===//
374
375OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
376 return foldBinaryOpChecked(adaptor.getOperands(),
377 [](const APInt &lhs, const APInt &rhs) {
378 return lhs.ugt(rhs) ? lhs : rhs;
379 });
380}
381
382LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) {
383 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
384}
385
386//===----------------------------------------------------------------------===//
387// MinSOp
388//===----------------------------------------------------------------------===//
389
390OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
391 return foldBinaryOpChecked(adaptor.getOperands(),
392 [](const APInt &lhs, const APInt &rhs) {
393 return lhs.slt(rhs) ? lhs : rhs;
394 });
395}
396
397LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) {
398 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
399}
400
401//===----------------------------------------------------------------------===//
402// MinUOp
403//===----------------------------------------------------------------------===//
404
405OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
406 return foldBinaryOpChecked(adaptor.getOperands(),
407 [](const APInt &lhs, const APInt &rhs) {
408 return lhs.ult(rhs) ? lhs : rhs;
409 });
410}
411
412LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) {
413 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
414}
415
416//===----------------------------------------------------------------------===//
417// ShlOp
418//===----------------------------------------------------------------------===//
419
420OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
421 return foldBinaryOpUnchecked(
422 adaptor.getOperands(),
423 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
424 // We cannot fold if the RHS is greater than or equal to 32 because
425 // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
426 // already treated as unsigned.
427 if (rhs.uge(32))
428 return {};
429 return lhs << rhs;
430 });
431}
432
433//===----------------------------------------------------------------------===//
434// ShrSOp
435//===----------------------------------------------------------------------===//
436
437OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
438 return foldBinaryOpChecked(
439 adaptor.getOperands(),
440 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
441 // Don't fold if RHS is greater than or equal to 32.
442 if (rhs.uge(32))
443 return {};
444 return lhs.ashr(rhs);
445 });
446}
447
448//===----------------------------------------------------------------------===//
449// ShrUOp
450//===----------------------------------------------------------------------===//
451
452OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
453 return foldBinaryOpChecked(
454 adaptor.getOperands(),
455 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
456 // Don't fold if RHS is greater than or equal to 32.
457 if (rhs.uge(32))
458 return {};
459 return lhs.lshr(rhs);
460 });
461}
462
463//===----------------------------------------------------------------------===//
464// AndOp
465//===----------------------------------------------------------------------===//
466
467OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
468 return foldBinaryOpUnchecked(
469 adaptor.getOperands(),
470 [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
471}
472
473LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
474 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
475}
476
477//===----------------------------------------------------------------------===//
478// OrOp
479//===----------------------------------------------------------------------===//
480
481OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
482 return foldBinaryOpUnchecked(
483 adaptor.getOperands(),
484 [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
485}
486
487LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
488 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
489}
490
491//===----------------------------------------------------------------------===//
492// XOrOp
493//===----------------------------------------------------------------------===//
494
495OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
496 return foldBinaryOpUnchecked(
497 adaptor.getOperands(),
498 [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
499}
500
501LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) {
502 return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
503}
504
505//===----------------------------------------------------------------------===//
506// CastSOp
507//===----------------------------------------------------------------------===//
508
509static OpFoldResult
510foldCastOp(Attribute input, Type type,
511 function_ref<APInt(const APInt &, unsigned)> extFn,
512 function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
513 auto attr = dyn_cast_if_present<IntegerAttr>(input);
514 if (!attr)
515 return {};
516 const APInt &value = attr.getValue();
517
518 if (isa<IndexType>(Val: type)) {
519 // When casting to an index type, perform the cast assuming a 64-bit target.
520 // The result can be truncated to 32 bits as needed and always be correct.
521 // This is because `cast32(cast64(value)) == cast32(value)`.
522 APInt result = extOrTruncFn(value, 64);
523 return IntegerAttr::get(type, result);
524 }
525
526 // When casting from an index type, we must ensure the results respect
527 // `cast_t(value) == cast_t(trunc32(value))`.
528 auto intType = cast<IntegerType>(type);
529 unsigned width = intType.getWidth();
530
531 // If the result type is at most 32 bits, then the cast can always be folded
532 // because it is always a truncation.
533 if (width <= 32) {
534 APInt result = value.trunc(width);
535 return IntegerAttr::get(type, result);
536 }
537
538 // If the result type is at least 64 bits, then the cast is always a
539 // extension. The results will differ if `trunc32(value) != value)`.
540 if (width >= 64) {
541 if (extFn(value.trunc(width: 32), 64) != value)
542 return {};
543 APInt result = extFn(value, width);
544 return IntegerAttr::get(type, result);
545 }
546
547 // Otherwise, we just have to check the property directly.
548 APInt result = value.trunc(width);
549 if (result != extFn(value.trunc(width: 32), width))
550 return {};
551 return IntegerAttr::get(type, result);
552}
553
554bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
555 return llvm::isa<IndexType>(lhsTypes.front()) !=
556 llvm::isa<IndexType>(rhsTypes.front());
557}
558
559OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
560 return foldCastOp(
561 adaptor.getInput(), getType(),
562 [](const APInt &x, unsigned width) { return x.sext(width); },
563 [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
564}
565
566//===----------------------------------------------------------------------===//
567// CastUOp
568//===----------------------------------------------------------------------===//
569
570bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
571 return llvm::isa<IndexType>(lhsTypes.front()) !=
572 llvm::isa<IndexType>(rhsTypes.front());
573}
574
575OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
576 return foldCastOp(
577 adaptor.getInput(), getType(),
578 [](const APInt &x, unsigned width) { return x.zext(width); },
579 [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
580}
581
582//===----------------------------------------------------------------------===//
583// CmpOp
584//===----------------------------------------------------------------------===//
585
586/// Compare two integers according to the comparison predicate.
587bool compareIndices(const APInt &lhs, const APInt &rhs,
588 IndexCmpPredicate pred) {
589 switch (pred) {
590 case IndexCmpPredicate::EQ:
591 return lhs.eq(RHS: rhs);
592 case IndexCmpPredicate::NE:
593 return lhs.ne(RHS: rhs);
594 case IndexCmpPredicate::SGE:
595 return lhs.sge(RHS: rhs);
596 case IndexCmpPredicate::SGT:
597 return lhs.sgt(RHS: rhs);
598 case IndexCmpPredicate::SLE:
599 return lhs.sle(RHS: rhs);
600 case IndexCmpPredicate::SLT:
601 return lhs.slt(RHS: rhs);
602 case IndexCmpPredicate::UGE:
603 return lhs.uge(RHS: rhs);
604 case IndexCmpPredicate::UGT:
605 return lhs.ugt(RHS: rhs);
606 case IndexCmpPredicate::ULE:
607 return lhs.ule(RHS: rhs);
608 case IndexCmpPredicate::ULT:
609 return lhs.ult(RHS: rhs);
610 }
611 llvm_unreachable("unhandled IndexCmpPredicate predicate");
612}
613
614/// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the
615/// values of `cstA` and `cstB`, the max or min operation, and the comparison
616/// predicate. Check whether the value folds in both 32-bit and 64-bit
617/// arithmetic and to the same value.
618static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
619 const APInt &cstA,
620 const APInt &cstB, unsigned width,
621 IndexCmpPredicate pred) {
622 ConstantIntRanges lhsRange = TypeSwitch<Operation *, ConstantIntRanges>(lhsOp)
623 .Case(caseFn: [&](MinSOp op) {
624 return ConstantIntRanges::fromSigned(
625 smin: APInt::getSignedMinValue(numBits: width), smax: cstA);
626 })
627 .Case(caseFn: [&](MinUOp op) {
628 return ConstantIntRanges::fromUnsigned(
629 umin: APInt::getMinValue(numBits: width), umax: cstA);
630 })
631 .Case(caseFn: [&](MaxSOp op) {
632 return ConstantIntRanges::fromSigned(
633 smin: cstA, smax: APInt::getSignedMaxValue(numBits: width));
634 })
635 .Case(caseFn: [&](MaxUOp op) {
636 return ConstantIntRanges::fromUnsigned(
637 umin: cstA, umax: APInt::getMaxValue(numBits: width));
638 });
639 return intrange::evaluatePred(pred: static_cast<intrange::CmpPredicate>(pred),
640 lhs: lhsRange, rhs: ConstantIntRanges::constant(value: cstB));
641}
642
643/// Return the result of `cmp(pred, x, x)`
644static bool compareSameArgs(IndexCmpPredicate pred) {
645 switch (pred) {
646 case IndexCmpPredicate::EQ:
647 case IndexCmpPredicate::SGE:
648 case IndexCmpPredicate::SLE:
649 case IndexCmpPredicate::UGE:
650 case IndexCmpPredicate::ULE:
651 return true;
652 case IndexCmpPredicate::NE:
653 case IndexCmpPredicate::SGT:
654 case IndexCmpPredicate::SLT:
655 case IndexCmpPredicate::UGT:
656 case IndexCmpPredicate::ULT:
657 return false;
658 }
659 llvm_unreachable("unknown predicate in compareSameArgs");
660}
661
662OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
663 // Attempt to fold if both inputs are constant.
664 auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
665 auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
666 if (lhs && rhs) {
667 // Perform the comparison in 64-bit and 32-bit.
668 bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
669 bool result32 = compareIndices(lhs.getValue().trunc(32),
670 rhs.getValue().trunc(32), getPred());
671 if (result64 == result32)
672 return BoolAttr::get(getContext(), result64);
673 }
674
675 // Fold `cmp(max/min(x, cstA), cstB)`.
676 Operation *lhsOp = getLhs().getDefiningOp();
677 IntegerAttr cstA;
678 if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
679 matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) {
680 std::optional<bool> result64 = foldCmpOfMaxOrMin(
681 lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
682 std::optional<bool> result32 =
683 foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32),
684 rhs.getValue().trunc(32), 32, getPred());
685 // Fold if the 32-bit and 64-bit results are the same.
686 if (result64 && result32 && *result64 == *result32)
687 return BoolAttr::get(getContext(), *result64);
688 }
689
690 // Fold `cmp(x, x)`
691 if (getLhs() == getRhs())
692 return BoolAttr::get(getContext(), compareSameArgs(getPred()));
693
694 return {};
695}
696
697/// Canonicalize
698/// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`.
699/// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`.
700LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
701 IntegerAttr cmpRhs;
702 IntegerAttr cmpLhs;
703
704 bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) &&
705 cmpRhs.getValue().isZero();
706 bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) &&
707 cmpLhs.getValue().isZero();
708 if (!rhsIsZero && !lhsIsZero)
709 return rewriter.notifyMatchFailure(op.getLoc(),
710 "cmp is not comparing something with 0");
711 SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
712 : op.getRhs().getDefiningOp<index::SubOp>();
713 if (!subOp)
714 return rewriter.notifyMatchFailure(
715 op.getLoc(), "non-zero operand is not a result of subtraction");
716
717 index::CmpOp newCmp;
718 if (rhsIsZero)
719 newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
720 subOp.getLhs(), subOp.getRhs());
721 else
722 newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
723 subOp.getRhs(), subOp.getLhs());
724 rewriter.replaceOp(op, newCmp);
725 return success();
726}
727
728//===----------------------------------------------------------------------===//
729// ConstantOp
730//===----------------------------------------------------------------------===//
731
732void ConstantOp::getAsmResultNames(
733 function_ref<void(Value, StringRef)> setNameFn) {
734 SmallString<32> specialNameBuffer;
735 llvm::raw_svector_ostream specialName(specialNameBuffer);
736 specialName << "idx" << getValueAttr().getValue();
737 setNameFn(getResult(), specialName.str());
738}
739
740OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
741
742void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
743 build(b, state, b.getIndexType(), b.getIndexAttr(value));
744}
745
746//===----------------------------------------------------------------------===//
747// BoolConstantOp
748//===----------------------------------------------------------------------===//
749
750OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
751 return getValueAttr();
752}
753
754void BoolConstantOp::getAsmResultNames(
755 function_ref<void(Value, StringRef)> setNameFn) {
756 setNameFn(getResult(), getValue() ? "true" : "false");
757}
758
759//===----------------------------------------------------------------------===//
760// ODS-Generated Definitions
761//===----------------------------------------------------------------------===//
762
763#define GET_OP_CLASSES
764#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
765

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Index/IR/IndexOps.cpp