1 | //===- IndexOps.cpp - Index operation definitions --------------------------==// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Index/IR/IndexOps.h" |
10 | #include "mlir/Dialect/Index/IR/IndexAttrs.h" |
11 | #include "mlir/Dialect/Index/IR/IndexDialect.h" |
12 | #include "mlir/IR/Builders.h" |
13 | #include "mlir/IR/Matchers.h" |
14 | #include "mlir/IR/OpImplementation.h" |
15 | #include "mlir/IR/PatternMatch.h" |
16 | #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" |
17 | #include "llvm/ADT/SmallString.h" |
18 | #include "llvm/ADT/TypeSwitch.h" |
19 | |
20 | using namespace mlir; |
21 | using namespace mlir::index; |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // IndexDialect |
25 | //===----------------------------------------------------------------------===// |
26 | |
/// Register every operation of the Index dialect with the dialect instance.
/// The list of operation classes is not written by hand: it is expanded from
/// the ODS-generated `IndexOps.cpp.inc` file under the `GET_OP_LIST` guard.
void IndexDialect::registerOperations() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
      >();
}
33 | |
34 | Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value, |
35 | Type type, Location loc) { |
36 | // Materialize bool constants as `i1`. |
37 | if (auto boolValue = dyn_cast<BoolAttr>(value)) { |
38 | if (!type.isSignlessInteger(1)) |
39 | return nullptr; |
40 | return b.create<BoolConstantOp>(loc, type, boolValue); |
41 | } |
42 | |
43 | // Materialize integer attributes as `index`. |
44 | if (auto indexValue = dyn_cast<IntegerAttr>(value)) { |
45 | if (!llvm::isa<IndexType>(indexValue.getType()) || |
46 | !llvm::isa<IndexType>(type)) |
47 | return nullptr; |
48 | assert(indexValue.getValue().getBitWidth() == |
49 | IndexType::kInternalStorageBitWidth); |
50 | return b.create<ConstantOp>(loc, indexValue); |
51 | } |
52 | |
53 | return nullptr; |
54 | } |
55 | |
56 | //===----------------------------------------------------------------------===// |
57 | // Fold Utilities |
58 | //===----------------------------------------------------------------------===// |
59 | |
60 | /// Fold an index operation irrespective of the target bitwidth. The |
61 | /// operation must satisfy the property: |
62 | /// |
63 | /// ``` |
64 | /// trunc(f(a, b)) = f(trunc(a), trunc(b)) |
65 | /// ``` |
66 | /// |
67 | /// For all values of `a` and `b`. The function accepts a lambda that computes |
68 | /// the integer result, which in turn must satisfy the above property. |
69 | static OpFoldResult foldBinaryOpUnchecked( |
70 | ArrayRef<Attribute> operands, |
71 | function_ref<std::optional<APInt>(const APInt &, const APInt &)> |
72 | calculate) { |
73 | assert(operands.size() == 2 && "binary operation expected 2 operands" ); |
74 | auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]); |
75 | auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]); |
76 | if (!lhs || !rhs) |
77 | return {}; |
78 | |
79 | std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue()); |
80 | if (!result) |
81 | return {}; |
82 | assert(result->trunc(32) == |
83 | calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32))); |
84 | return IntegerAttr::get(IndexType::get(lhs.getContext()), *result); |
85 | } |
86 | |
87 | /// Fold an index operation only if the truncated 64-bit result matches the |
88 | /// 32-bit result for operations that don't satisfy the above property. These |
89 | /// are operations where the upper bits of the operands can affect the lower |
90 | /// bits of the results. |
91 | /// |
92 | /// The function accepts a lambda that computes the integer result in both |
93 | /// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is |
94 | /// not folded. |
95 | static OpFoldResult foldBinaryOpChecked( |
96 | ArrayRef<Attribute> operands, |
97 | function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)> |
98 | calculate) { |
99 | assert(operands.size() == 2 && "binary operation expected 2 operands" ); |
100 | auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]); |
101 | auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]); |
102 | // Only fold index operands. |
103 | if (!lhs || !rhs) |
104 | return {}; |
105 | |
106 | // Compute the 64-bit result and the 32-bit result. |
107 | std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue()); |
108 | if (!result64) |
109 | return {}; |
110 | std::optional<APInt> result32 = |
111 | calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)); |
112 | if (!result32) |
113 | return {}; |
114 | // Compare the truncated 64-bit result to the 32-bit result. |
115 | if (result64->trunc(width: 32) != *result32) |
116 | return {}; |
117 | // The operation can be folded for these particular operands. |
118 | return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64); |
119 | } |
120 | |
121 | //===----------------------------------------------------------------------===// |
122 | // AddOp |
123 | //===----------------------------------------------------------------------===// |
124 | |
125 | OpFoldResult AddOp::fold(FoldAdaptor adaptor) { |
126 | if (OpFoldResult result = foldBinaryOpUnchecked( |
127 | adaptor.getOperands(), |
128 | [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; })) |
129 | return result; |
130 | |
131 | if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) { |
132 | // Fold `add(x, 0) -> x`. |
133 | if (rhs.getValue().isZero()) |
134 | return getLhs(); |
135 | } |
136 | |
137 | return {}; |
138 | } |
139 | |
140 | //===----------------------------------------------------------------------===// |
141 | // SubOp |
142 | //===----------------------------------------------------------------------===// |
143 | |
144 | OpFoldResult SubOp::fold(FoldAdaptor adaptor) { |
145 | if (OpFoldResult result = foldBinaryOpUnchecked( |
146 | adaptor.getOperands(), |
147 | [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; })) |
148 | return result; |
149 | |
150 | if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) { |
151 | // Fold `sub(x, 0) -> x`. |
152 | if (rhs.getValue().isZero()) |
153 | return getLhs(); |
154 | } |
155 | |
156 | return {}; |
157 | } |
158 | |
159 | //===----------------------------------------------------------------------===// |
160 | // MulOp |
161 | //===----------------------------------------------------------------------===// |
162 | |
163 | OpFoldResult MulOp::fold(FoldAdaptor adaptor) { |
164 | if (OpFoldResult result = foldBinaryOpUnchecked( |
165 | adaptor.getOperands(), |
166 | [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; })) |
167 | return result; |
168 | |
169 | if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) { |
170 | // Fold `mul(x, 1) -> x`. |
171 | if (rhs.getValue().isOne()) |
172 | return getLhs(); |
173 | // Fold `mul(x, 0) -> 0`. |
174 | if (rhs.getValue().isZero()) |
175 | return rhs; |
176 | } |
177 | |
178 | return {}; |
179 | } |
180 | |
181 | //===----------------------------------------------------------------------===// |
182 | // DivSOp |
183 | //===----------------------------------------------------------------------===// |
184 | |
185 | OpFoldResult DivSOp::fold(FoldAdaptor adaptor) { |
186 | return foldBinaryOpChecked( |
187 | adaptor.getOperands(), |
188 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
189 | // Don't fold division by zero. |
190 | if (rhs.isZero()) |
191 | return std::nullopt; |
192 | return lhs.sdiv(rhs); |
193 | }); |
194 | } |
195 | |
196 | //===----------------------------------------------------------------------===// |
197 | // DivUOp |
198 | //===----------------------------------------------------------------------===// |
199 | |
200 | OpFoldResult DivUOp::fold(FoldAdaptor adaptor) { |
201 | return foldBinaryOpChecked( |
202 | adaptor.getOperands(), |
203 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
204 | // Don't fold division by zero. |
205 | if (rhs.isZero()) |
206 | return std::nullopt; |
207 | return lhs.udiv(rhs); |
208 | }); |
209 | } |
210 | |
211 | //===----------------------------------------------------------------------===// |
212 | // CeilDivSOp |
213 | //===----------------------------------------------------------------------===// |
214 | |
215 | /// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then |
216 | /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. |
217 | static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) { |
218 | // Don't fold division by zero. |
219 | if (m.isZero()) |
220 | return std::nullopt; |
221 | // Short-circuit the zero case. |
222 | if (n.isZero()) |
223 | return n; |
224 | |
225 | bool mGtZ = m.sgt(RHS: 0); |
226 | if (n.sgt(RHS: 0) != mGtZ) { |
227 | // If the operands have different signs, compute the negative result. Signed |
228 | // division overflow is not possible, since if `m == -1`, `n` can be at most |
229 | // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement. |
230 | return -(-n).sdiv(RHS: m); |
231 | } |
232 | // Otherwise, compute the positive result. Signed division overflow is not |
233 | // possible since if `m == -1`, `x` will be `1`. |
234 | int64_t x = mGtZ ? -1 : 1; |
235 | return (n + x).sdiv(RHS: m) + 1; |
236 | } |
237 | |
238 | OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) { |
239 | return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS); |
240 | } |
241 | |
242 | //===----------------------------------------------------------------------===// |
243 | // CeilDivUOp |
244 | //===----------------------------------------------------------------------===// |
245 | |
246 | OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) { |
247 | // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`. |
248 | return foldBinaryOpChecked( |
249 | adaptor.getOperands(), |
250 | [](const APInt &n, const APInt &m) -> std::optional<APInt> { |
251 | // Don't fold division by zero. |
252 | if (m.isZero()) |
253 | return std::nullopt; |
254 | // Short-circuit the zero case. |
255 | if (n.isZero()) |
256 | return n; |
257 | |
258 | return (n - 1).udiv(m) + 1; |
259 | }); |
260 | } |
261 | |
262 | //===----------------------------------------------------------------------===// |
263 | // FloorDivSOp |
264 | //===----------------------------------------------------------------------===// |
265 | |
266 | /// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then |
267 | /// `n*m < 0 ? -1 - (x-n)/m : n/m`. |
268 | static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) { |
269 | // Don't fold division by zero. |
270 | if (m.isZero()) |
271 | return std::nullopt; |
272 | // Short-circuit the zero case. |
273 | if (n.isZero()) |
274 | return n; |
275 | |
276 | bool mLtZ = m.slt(RHS: 0); |
277 | if (n.slt(RHS: 0) == mLtZ) { |
278 | // If the operands have the same sign, compute the positive result. |
279 | return n.sdiv(RHS: m); |
280 | } |
281 | // If the operands have different signs, compute the negative result. Signed |
282 | // division overflow is not possible since if `m == -1`, `x` will be 1 and |
283 | // `n` can be at most `INT_MAX`. |
284 | int64_t x = mLtZ ? 1 : -1; |
285 | return -1 - (x - n).sdiv(RHS: m); |
286 | } |
287 | |
288 | OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) { |
289 | return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS); |
290 | } |
291 | |
292 | //===----------------------------------------------------------------------===// |
293 | // RemSOp |
294 | //===----------------------------------------------------------------------===// |
295 | |
296 | OpFoldResult RemSOp::fold(FoldAdaptor adaptor) { |
297 | return foldBinaryOpChecked( |
298 | adaptor.getOperands(), |
299 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
300 | // Don't fold division by zero. |
301 | if (rhs.isZero()) |
302 | return std::nullopt; |
303 | return lhs.srem(rhs); |
304 | }); |
305 | } |
306 | |
307 | //===----------------------------------------------------------------------===// |
308 | // RemUOp |
309 | //===----------------------------------------------------------------------===// |
310 | |
311 | OpFoldResult RemUOp::fold(FoldAdaptor adaptor) { |
312 | return foldBinaryOpChecked( |
313 | adaptor.getOperands(), |
314 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
315 | // Don't fold division by zero. |
316 | if (rhs.isZero()) |
317 | return std::nullopt; |
318 | return lhs.urem(rhs); |
319 | }); |
320 | } |
321 | |
322 | //===----------------------------------------------------------------------===// |
323 | // MaxSOp |
324 | //===----------------------------------------------------------------------===// |
325 | |
326 | OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) { |
327 | return foldBinaryOpChecked(adaptor.getOperands(), |
328 | [](const APInt &lhs, const APInt &rhs) { |
329 | return lhs.sgt(rhs) ? lhs : rhs; |
330 | }); |
331 | } |
332 | |
333 | //===----------------------------------------------------------------------===// |
334 | // MaxUOp |
335 | //===----------------------------------------------------------------------===// |
336 | |
337 | OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) { |
338 | return foldBinaryOpChecked(adaptor.getOperands(), |
339 | [](const APInt &lhs, const APInt &rhs) { |
340 | return lhs.ugt(rhs) ? lhs : rhs; |
341 | }); |
342 | } |
343 | |
344 | //===----------------------------------------------------------------------===// |
345 | // MinSOp |
346 | //===----------------------------------------------------------------------===// |
347 | |
348 | OpFoldResult MinSOp::fold(FoldAdaptor adaptor) { |
349 | return foldBinaryOpChecked(adaptor.getOperands(), |
350 | [](const APInt &lhs, const APInt &rhs) { |
351 | return lhs.slt(rhs) ? lhs : rhs; |
352 | }); |
353 | } |
354 | |
355 | //===----------------------------------------------------------------------===// |
356 | // MinUOp |
357 | //===----------------------------------------------------------------------===// |
358 | |
359 | OpFoldResult MinUOp::fold(FoldAdaptor adaptor) { |
360 | return foldBinaryOpChecked(adaptor.getOperands(), |
361 | [](const APInt &lhs, const APInt &rhs) { |
362 | return lhs.ult(rhs) ? lhs : rhs; |
363 | }); |
364 | } |
365 | |
366 | //===----------------------------------------------------------------------===// |
367 | // ShlOp |
368 | //===----------------------------------------------------------------------===// |
369 | |
370 | OpFoldResult ShlOp::fold(FoldAdaptor adaptor) { |
371 | return foldBinaryOpUnchecked( |
372 | adaptor.getOperands(), |
373 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
374 | // We cannot fold if the RHS is greater than or equal to 32 because |
375 | // this would be UB in 32-bit systems but not on 64-bit systems. RHS is |
376 | // already treated as unsigned. |
377 | if (rhs.uge(32)) |
378 | return {}; |
379 | return lhs << rhs; |
380 | }); |
381 | } |
382 | |
383 | //===----------------------------------------------------------------------===// |
384 | // ShrSOp |
385 | //===----------------------------------------------------------------------===// |
386 | |
387 | OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) { |
388 | return foldBinaryOpChecked( |
389 | adaptor.getOperands(), |
390 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
391 | // Don't fold if RHS is greater than or equal to 32. |
392 | if (rhs.uge(32)) |
393 | return {}; |
394 | return lhs.ashr(rhs); |
395 | }); |
396 | } |
397 | |
398 | //===----------------------------------------------------------------------===// |
399 | // ShrUOp |
400 | //===----------------------------------------------------------------------===// |
401 | |
402 | OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) { |
403 | return foldBinaryOpChecked( |
404 | adaptor.getOperands(), |
405 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
406 | // Don't fold if RHS is greater than or equal to 32. |
407 | if (rhs.uge(32)) |
408 | return {}; |
409 | return lhs.lshr(rhs); |
410 | }); |
411 | } |
412 | |
413 | //===----------------------------------------------------------------------===// |
414 | // AndOp |
415 | //===----------------------------------------------------------------------===// |
416 | |
417 | OpFoldResult AndOp::fold(FoldAdaptor adaptor) { |
418 | return foldBinaryOpUnchecked( |
419 | adaptor.getOperands(), |
420 | [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; }); |
421 | } |
422 | |
423 | //===----------------------------------------------------------------------===// |
424 | // OrOp |
425 | //===----------------------------------------------------------------------===// |
426 | |
427 | OpFoldResult OrOp::fold(FoldAdaptor adaptor) { |
428 | return foldBinaryOpUnchecked( |
429 | adaptor.getOperands(), |
430 | [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; }); |
431 | } |
432 | |
433 | //===----------------------------------------------------------------------===// |
434 | // XOrOp |
435 | //===----------------------------------------------------------------------===// |
436 | |
437 | OpFoldResult XOrOp::fold(FoldAdaptor adaptor) { |
438 | return foldBinaryOpUnchecked( |
439 | adaptor.getOperands(), |
440 | [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; }); |
441 | } |
442 | |
443 | //===----------------------------------------------------------------------===// |
444 | // CastSOp |
445 | //===----------------------------------------------------------------------===// |
446 | |
447 | static OpFoldResult |
448 | foldCastOp(Attribute input, Type type, |
449 | function_ref<APInt(const APInt &, unsigned)> extFn, |
450 | function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) { |
451 | auto attr = dyn_cast_if_present<IntegerAttr>(input); |
452 | if (!attr) |
453 | return {}; |
454 | const APInt &value = attr.getValue(); |
455 | |
456 | if (isa<IndexType>(Val: type)) { |
457 | // When casting to an index type, perform the cast assuming a 64-bit target. |
458 | // The result can be truncated to 32 bits as needed and always be correct. |
459 | // This is because `cast32(cast64(value)) == cast32(value)`. |
460 | APInt result = extOrTruncFn(value, 64); |
461 | return IntegerAttr::get(type, result); |
462 | } |
463 | |
464 | // When casting from an index type, we must ensure the results respect |
465 | // `cast_t(value) == cast_t(trunc32(value))`. |
466 | auto intType = cast<IntegerType>(type); |
467 | unsigned width = intType.getWidth(); |
468 | |
469 | // If the result type is at most 32 bits, then the cast can always be folded |
470 | // because it is always a truncation. |
471 | if (width <= 32) { |
472 | APInt result = value.trunc(width); |
473 | return IntegerAttr::get(type, result); |
474 | } |
475 | |
476 | // If the result type is at least 64 bits, then the cast is always a |
477 | // extension. The results will differ if `trunc32(value) != value)`. |
478 | if (width >= 64) { |
479 | if (extFn(value.trunc(width: 32), 64) != value) |
480 | return {}; |
481 | APInt result = extFn(value, width); |
482 | return IntegerAttr::get(type, result); |
483 | } |
484 | |
485 | // Otherwise, we just have to check the property directly. |
486 | APInt result = value.trunc(width); |
487 | if (result != extFn(value.trunc(width: 32), width)) |
488 | return {}; |
489 | return IntegerAttr::get(type, result); |
490 | } |
491 | |
492 | bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { |
493 | return llvm::isa<IndexType>(lhsTypes.front()) != |
494 | llvm::isa<IndexType>(rhsTypes.front()); |
495 | } |
496 | |
497 | OpFoldResult CastSOp::fold(FoldAdaptor adaptor) { |
498 | return foldCastOp( |
499 | adaptor.getInput(), getType(), |
500 | [](const APInt &x, unsigned width) { return x.sext(width); }, |
501 | [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); }); |
502 | } |
503 | |
504 | //===----------------------------------------------------------------------===// |
505 | // CastUOp |
506 | //===----------------------------------------------------------------------===// |
507 | |
508 | bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { |
509 | return llvm::isa<IndexType>(lhsTypes.front()) != |
510 | llvm::isa<IndexType>(rhsTypes.front()); |
511 | } |
512 | |
513 | OpFoldResult CastUOp::fold(FoldAdaptor adaptor) { |
514 | return foldCastOp( |
515 | adaptor.getInput(), getType(), |
516 | [](const APInt &x, unsigned width) { return x.zext(width); }, |
517 | [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); }); |
518 | } |
519 | |
520 | //===----------------------------------------------------------------------===// |
521 | // CmpOp |
522 | //===----------------------------------------------------------------------===// |
523 | |
524 | /// Compare two integers according to the comparison predicate. |
525 | bool compareIndices(const APInt &lhs, const APInt &rhs, |
526 | IndexCmpPredicate pred) { |
527 | switch (pred) { |
528 | case IndexCmpPredicate::EQ: |
529 | return lhs.eq(RHS: rhs); |
530 | case IndexCmpPredicate::NE: |
531 | return lhs.ne(RHS: rhs); |
532 | case IndexCmpPredicate::SGE: |
533 | return lhs.sge(RHS: rhs); |
534 | case IndexCmpPredicate::SGT: |
535 | return lhs.sgt(RHS: rhs); |
536 | case IndexCmpPredicate::SLE: |
537 | return lhs.sle(RHS: rhs); |
538 | case IndexCmpPredicate::SLT: |
539 | return lhs.slt(RHS: rhs); |
540 | case IndexCmpPredicate::UGE: |
541 | return lhs.uge(RHS: rhs); |
542 | case IndexCmpPredicate::UGT: |
543 | return lhs.ugt(RHS: rhs); |
544 | case IndexCmpPredicate::ULE: |
545 | return lhs.ule(RHS: rhs); |
546 | case IndexCmpPredicate::ULT: |
547 | return lhs.ult(RHS: rhs); |
548 | } |
549 | llvm_unreachable("unhandled IndexCmpPredicate predicate" ); |
550 | } |
551 | |
552 | /// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the |
553 | /// values of `cstA` and `cstB`, the max or min operation, and the comparison |
554 | /// predicate. Check whether the value folds in both 32-bit and 64-bit |
555 | /// arithmetic and to the same value. |
556 | static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp, |
557 | const APInt &cstA, |
558 | const APInt &cstB, unsigned width, |
559 | IndexCmpPredicate pred) { |
560 | ConstantIntRanges lhsRange = TypeSwitch<Operation *, ConstantIntRanges>(lhsOp) |
561 | .Case(caseFn: [&](MinSOp op) { |
562 | return ConstantIntRanges::fromSigned( |
563 | smin: APInt::getSignedMinValue(numBits: width), smax: cstA); |
564 | }) |
565 | .Case(caseFn: [&](MinUOp op) { |
566 | return ConstantIntRanges::fromUnsigned( |
567 | umin: APInt::getMinValue(numBits: width), umax: cstA); |
568 | }) |
569 | .Case(caseFn: [&](MaxSOp op) { |
570 | return ConstantIntRanges::fromSigned( |
571 | smin: cstA, smax: APInt::getSignedMaxValue(numBits: width)); |
572 | }) |
573 | .Case(caseFn: [&](MaxUOp op) { |
574 | return ConstantIntRanges::fromUnsigned( |
575 | umin: cstA, umax: APInt::getMaxValue(numBits: width)); |
576 | }); |
577 | return intrange::evaluatePred(pred: static_cast<intrange::CmpPredicate>(pred), |
578 | lhs: lhsRange, rhs: ConstantIntRanges::constant(value: cstB)); |
579 | } |
580 | |
581 | /// Return the result of `cmp(pred, x, x)` |
582 | static bool compareSameArgs(IndexCmpPredicate pred) { |
583 | switch (pred) { |
584 | case IndexCmpPredicate::EQ: |
585 | case IndexCmpPredicate::SGE: |
586 | case IndexCmpPredicate::SLE: |
587 | case IndexCmpPredicate::UGE: |
588 | case IndexCmpPredicate::ULE: |
589 | return true; |
590 | case IndexCmpPredicate::NE: |
591 | case IndexCmpPredicate::SGT: |
592 | case IndexCmpPredicate::SLT: |
593 | case IndexCmpPredicate::UGT: |
594 | case IndexCmpPredicate::ULT: |
595 | return false; |
596 | } |
597 | } |
598 | |
599 | OpFoldResult CmpOp::fold(FoldAdaptor adaptor) { |
600 | // Attempt to fold if both inputs are constant. |
601 | auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs()); |
602 | auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()); |
603 | if (lhs && rhs) { |
604 | // Perform the comparison in 64-bit and 32-bit. |
605 | bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred()); |
606 | bool result32 = compareIndices(lhs.getValue().trunc(32), |
607 | rhs.getValue().trunc(32), getPred()); |
608 | if (result64 == result32) |
609 | return BoolAttr::get(getContext(), result64); |
610 | } |
611 | |
612 | // Fold `cmp(max/min(x, cstA), cstB)`. |
613 | Operation *lhsOp = getLhs().getDefiningOp(); |
614 | IntegerAttr cstA; |
615 | if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) && |
616 | matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) { |
617 | std::optional<bool> result64 = foldCmpOfMaxOrMin( |
618 | lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred()); |
619 | std::optional<bool> result32 = |
620 | foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32), |
621 | rhs.getValue().trunc(32), 32, getPred()); |
622 | // Fold if the 32-bit and 64-bit results are the same. |
623 | if (result64 && result32 && *result64 == *result32) |
624 | return BoolAttr::get(getContext(), *result64); |
625 | } |
626 | |
627 | // Fold `cmp(x, x)` |
628 | if (getLhs() == getRhs()) |
629 | return BoolAttr::get(getContext(), compareSameArgs(getPred())); |
630 | |
631 | return {}; |
632 | } |
633 | |
634 | /// Canonicalize |
635 | /// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`. |
636 | /// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`. |
637 | LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) { |
638 | IntegerAttr cmpRhs; |
639 | IntegerAttr cmpLhs; |
640 | |
641 | bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) && |
642 | cmpRhs.getValue().isZero(); |
643 | bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) && |
644 | cmpLhs.getValue().isZero(); |
645 | if (!rhsIsZero && !lhsIsZero) |
646 | return rewriter.notifyMatchFailure(op.getLoc(), |
647 | "cmp is not comparing something with 0" ); |
648 | SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>() |
649 | : op.getRhs().getDefiningOp<index::SubOp>(); |
650 | if (!subOp) |
651 | return rewriter.notifyMatchFailure( |
652 | op.getLoc(), "non-zero operand is not a result of subtraction" ); |
653 | |
654 | index::CmpOp newCmp; |
655 | if (rhsIsZero) |
656 | newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(), |
657 | subOp.getLhs(), subOp.getRhs()); |
658 | else |
659 | newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(), |
660 | subOp.getRhs(), subOp.getLhs()); |
661 | rewriter.replaceOp(op, newCmp); |
662 | return success(); |
663 | } |
664 | |
665 | //===----------------------------------------------------------------------===// |
666 | // ConstantOp |
667 | //===----------------------------------------------------------------------===// |
668 | |
669 | void ConstantOp::getAsmResultNames( |
670 | function_ref<void(Value, StringRef)> setNameFn) { |
671 | SmallString<32> specialNameBuffer; |
672 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
673 | specialName << "idx" << getValueAttr().getValue(); |
674 | setNameFn(getResult(), specialName.str()); |
675 | } |
676 | |
677 | OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } |
678 | |
679 | void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) { |
680 | build(b, state, b.getIndexType(), b.getIndexAttr(value)); |
681 | } |
682 | |
683 | //===----------------------------------------------------------------------===// |
684 | // BoolConstantOp |
685 | //===----------------------------------------------------------------------===// |
686 | |
687 | OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) { |
688 | return getValueAttr(); |
689 | } |
690 | |
691 | void BoolConstantOp::getAsmResultNames( |
692 | function_ref<void(Value, StringRef)> setNameFn) { |
693 | setNameFn(getResult(), getValue() ? "true" : "false" ); |
694 | } |
695 | |
696 | //===----------------------------------------------------------------------===// |
697 | // ODS-Generated Definitions |
698 | //===----------------------------------------------------------------------===// |
699 | |
700 | #define GET_OP_CLASSES |
701 | #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc" |
702 | |