1 | //===- IndexOps.cpp - Index operation definitions --------------------------==// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Index/IR/IndexOps.h" |
10 | #include "mlir/Dialect/Index/IR/IndexAttrs.h" |
11 | #include "mlir/Dialect/Index/IR/IndexDialect.h" |
12 | #include "mlir/IR/Builders.h" |
13 | #include "mlir/IR/Matchers.h" |
14 | #include "mlir/IR/OpImplementation.h" |
15 | #include "mlir/IR/PatternMatch.h" |
16 | #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" |
17 | #include "llvm/ADT/SmallString.h" |
18 | #include "llvm/ADT/TypeSwitch.h" |
19 | |
20 | using namespace mlir; |
21 | using namespace mlir::index; |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // IndexDialect |
25 | //===----------------------------------------------------------------------===// |
26 | |
/// Register every operation of the Index dialect.
///
/// The operation classes are not listed by hand. With `GET_OP_LIST` defined,
/// the generated `IndexOps.cpp.inc` expands in place to the list of op classes
/// that forms the template argument list of `addOperations`.
void IndexDialect::registerOperations() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
      >();
}
33 | |
34 | Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value, |
35 | Type type, Location loc) { |
36 | // Materialize bool constants as `i1`. |
37 | if (auto boolValue = dyn_cast<BoolAttr>(value)) { |
38 | if (!type.isSignlessInteger(1)) |
39 | return nullptr; |
40 | return b.create<BoolConstantOp>(loc, type, boolValue); |
41 | } |
42 | |
43 | // Materialize integer attributes as `index`. |
44 | if (auto indexValue = dyn_cast<IntegerAttr>(value)) { |
45 | if (!llvm::isa<IndexType>(indexValue.getType()) || |
46 | !llvm::isa<IndexType>(type)) |
47 | return nullptr; |
48 | assert(indexValue.getValue().getBitWidth() == |
49 | IndexType::kInternalStorageBitWidth); |
50 | return b.create<ConstantOp>(loc, indexValue); |
51 | } |
52 | |
53 | return nullptr; |
54 | } |
55 | |
56 | //===----------------------------------------------------------------------===// |
57 | // Fold Utilities |
58 | //===----------------------------------------------------------------------===// |
59 | |
60 | /// Fold an index operation irrespective of the target bitwidth. The |
61 | /// operation must satisfy the property: |
62 | /// |
63 | /// ``` |
64 | /// trunc(f(a, b)) = f(trunc(a), trunc(b)) |
65 | /// ``` |
66 | /// |
67 | /// For all values of `a` and `b`. The function accepts a lambda that computes |
68 | /// the integer result, which in turn must satisfy the above property. |
69 | static OpFoldResult foldBinaryOpUnchecked( |
70 | ArrayRef<Attribute> operands, |
71 | function_ref<std::optional<APInt>(const APInt &, const APInt &)> |
72 | calculate) { |
73 | assert(operands.size() == 2 && "binary operation expected 2 operands" ); |
74 | auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]); |
75 | auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]); |
76 | if (!lhs || !rhs) |
77 | return {}; |
78 | |
79 | std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue()); |
80 | if (!result) |
81 | return {}; |
82 | assert(result->trunc(32) == |
83 | calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32))); |
84 | return IntegerAttr::get(IndexType::get(lhs.getContext()), *result); |
85 | } |
86 | |
87 | /// Fold an index operation only if the truncated 64-bit result matches the |
88 | /// 32-bit result for operations that don't satisfy the above property. These |
89 | /// are operations where the upper bits of the operands can affect the lower |
90 | /// bits of the results. |
91 | /// |
92 | /// The function accepts a lambda that computes the integer result in both |
93 | /// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is |
94 | /// not folded. |
95 | static OpFoldResult foldBinaryOpChecked( |
96 | ArrayRef<Attribute> operands, |
97 | function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)> |
98 | calculate) { |
99 | assert(operands.size() == 2 && "binary operation expected 2 operands" ); |
100 | auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]); |
101 | auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]); |
102 | // Only fold index operands. |
103 | if (!lhs || !rhs) |
104 | return {}; |
105 | |
106 | // Compute the 64-bit result and the 32-bit result. |
107 | std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue()); |
108 | if (!result64) |
109 | return {}; |
110 | std::optional<APInt> result32 = |
111 | calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)); |
112 | if (!result32) |
113 | return {}; |
114 | // Compare the truncated 64-bit result to the 32-bit result. |
115 | if (result64->trunc(width: 32) != *result32) |
116 | return {}; |
117 | // The operation can be folded for these particular operands. |
118 | return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64); |
119 | } |
120 | |
121 | /// Helper for associative and commutative binary ops that can be transformed: |
122 | /// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)` |
123 | /// where c1 and c2 are constants. It is expected that `tmp` will be folded. |
124 | template <typename BinaryOp> |
125 | LogicalResult |
126 | canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, |
127 | PatternRewriter &rewriter) { |
128 | if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant())) |
129 | return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant" ); |
130 | |
131 | auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>(); |
132 | if (!lhsOp) |
133 | return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp" ); |
134 | |
135 | if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant())) |
136 | return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant" ); |
137 | |
138 | Value c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(), |
139 | lhsOp.getRhs()); |
140 | if (c.getDefiningOp<BinaryOp>()) |
141 | return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded" ); |
142 | |
143 | rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c); |
144 | return success(); |
145 | } |
146 | |
147 | //===----------------------------------------------------------------------===// |
148 | // AddOp |
149 | //===----------------------------------------------------------------------===// |
150 | |
151 | OpFoldResult AddOp::fold(FoldAdaptor adaptor) { |
152 | if (OpFoldResult result = foldBinaryOpUnchecked( |
153 | adaptor.getOperands(), |
154 | [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; })) |
155 | return result; |
156 | |
157 | if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) { |
158 | // Fold `add(x, 0) -> x`. |
159 | if (rhs.getValue().isZero()) |
160 | return getLhs(); |
161 | } |
162 | |
163 | return {}; |
164 | } |
165 | |
166 | LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) { |
167 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
168 | } |
169 | |
170 | //===----------------------------------------------------------------------===// |
171 | // SubOp |
172 | //===----------------------------------------------------------------------===// |
173 | |
174 | OpFoldResult SubOp::fold(FoldAdaptor adaptor) { |
175 | if (OpFoldResult result = foldBinaryOpUnchecked( |
176 | adaptor.getOperands(), |
177 | [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; })) |
178 | return result; |
179 | |
180 | if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) { |
181 | // Fold `sub(x, 0) -> x`. |
182 | if (rhs.getValue().isZero()) |
183 | return getLhs(); |
184 | } |
185 | |
186 | return {}; |
187 | } |
188 | |
189 | //===----------------------------------------------------------------------===// |
190 | // MulOp |
191 | //===----------------------------------------------------------------------===// |
192 | |
193 | OpFoldResult MulOp::fold(FoldAdaptor adaptor) { |
194 | if (OpFoldResult result = foldBinaryOpUnchecked( |
195 | adaptor.getOperands(), |
196 | [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; })) |
197 | return result; |
198 | |
199 | if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) { |
200 | // Fold `mul(x, 1) -> x`. |
201 | if (rhs.getValue().isOne()) |
202 | return getLhs(); |
203 | // Fold `mul(x, 0) -> 0`. |
204 | if (rhs.getValue().isZero()) |
205 | return rhs; |
206 | } |
207 | |
208 | return {}; |
209 | } |
210 | |
211 | LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) { |
212 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
213 | } |
214 | |
215 | //===----------------------------------------------------------------------===// |
216 | // DivSOp |
217 | //===----------------------------------------------------------------------===// |
218 | |
219 | OpFoldResult DivSOp::fold(FoldAdaptor adaptor) { |
220 | return foldBinaryOpChecked( |
221 | adaptor.getOperands(), |
222 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
223 | // Don't fold division by zero. |
224 | if (rhs.isZero()) |
225 | return std::nullopt; |
226 | return lhs.sdiv(rhs); |
227 | }); |
228 | } |
229 | |
230 | //===----------------------------------------------------------------------===// |
231 | // DivUOp |
232 | //===----------------------------------------------------------------------===// |
233 | |
234 | OpFoldResult DivUOp::fold(FoldAdaptor adaptor) { |
235 | return foldBinaryOpChecked( |
236 | adaptor.getOperands(), |
237 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
238 | // Don't fold division by zero. |
239 | if (rhs.isZero()) |
240 | return std::nullopt; |
241 | return lhs.udiv(rhs); |
242 | }); |
243 | } |
244 | |
245 | //===----------------------------------------------------------------------===// |
246 | // CeilDivSOp |
247 | //===----------------------------------------------------------------------===// |
248 | |
249 | /// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then |
250 | /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. |
251 | static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) { |
252 | // Don't fold division by zero. |
253 | if (m.isZero()) |
254 | return std::nullopt; |
255 | // Short-circuit the zero case. |
256 | if (n.isZero()) |
257 | return n; |
258 | |
259 | bool mGtZ = m.sgt(RHS: 0); |
260 | if (n.sgt(RHS: 0) != mGtZ) { |
261 | // If the operands have different signs, compute the negative result. Signed |
262 | // division overflow is not possible, since if `m == -1`, `n` can be at most |
263 | // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement. |
264 | return -(-n).sdiv(RHS: m); |
265 | } |
266 | // Otherwise, compute the positive result. Signed division overflow is not |
267 | // possible since if `m == -1`, `x` will be `1`. |
268 | int64_t x = mGtZ ? -1 : 1; |
269 | return (n + x).sdiv(RHS: m) + 1; |
270 | } |
271 | |
272 | OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) { |
273 | return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS); |
274 | } |
275 | |
276 | //===----------------------------------------------------------------------===// |
277 | // CeilDivUOp |
278 | //===----------------------------------------------------------------------===// |
279 | |
280 | OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) { |
281 | // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`. |
282 | return foldBinaryOpChecked( |
283 | adaptor.getOperands(), |
284 | [](const APInt &n, const APInt &m) -> std::optional<APInt> { |
285 | // Don't fold division by zero. |
286 | if (m.isZero()) |
287 | return std::nullopt; |
288 | // Short-circuit the zero case. |
289 | if (n.isZero()) |
290 | return n; |
291 | |
292 | return (n - 1).udiv(m) + 1; |
293 | }); |
294 | } |
295 | |
296 | //===----------------------------------------------------------------------===// |
297 | // FloorDivSOp |
298 | //===----------------------------------------------------------------------===// |
299 | |
300 | /// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then |
301 | /// `n*m < 0 ? -1 - (x-n)/m : n/m`. |
302 | static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) { |
303 | // Don't fold division by zero. |
304 | if (m.isZero()) |
305 | return std::nullopt; |
306 | // Short-circuit the zero case. |
307 | if (n.isZero()) |
308 | return n; |
309 | |
310 | bool mLtZ = m.slt(RHS: 0); |
311 | if (n.slt(RHS: 0) == mLtZ) { |
312 | // If the operands have the same sign, compute the positive result. |
313 | return n.sdiv(RHS: m); |
314 | } |
315 | // If the operands have different signs, compute the negative result. Signed |
316 | // division overflow is not possible since if `m == -1`, `x` will be 1 and |
317 | // `n` can be at most `INT_MAX`. |
318 | int64_t x = mLtZ ? 1 : -1; |
319 | return -1 - (x - n).sdiv(RHS: m); |
320 | } |
321 | |
322 | OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) { |
323 | return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS); |
324 | } |
325 | |
326 | //===----------------------------------------------------------------------===// |
327 | // RemSOp |
328 | //===----------------------------------------------------------------------===// |
329 | |
330 | OpFoldResult RemSOp::fold(FoldAdaptor adaptor) { |
331 | return foldBinaryOpChecked( |
332 | adaptor.getOperands(), |
333 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
334 | // Don't fold division by zero. |
335 | if (rhs.isZero()) |
336 | return std::nullopt; |
337 | return lhs.srem(rhs); |
338 | }); |
339 | } |
340 | |
341 | //===----------------------------------------------------------------------===// |
342 | // RemUOp |
343 | //===----------------------------------------------------------------------===// |
344 | |
345 | OpFoldResult RemUOp::fold(FoldAdaptor adaptor) { |
346 | return foldBinaryOpChecked( |
347 | adaptor.getOperands(), |
348 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
349 | // Don't fold division by zero. |
350 | if (rhs.isZero()) |
351 | return std::nullopt; |
352 | return lhs.urem(rhs); |
353 | }); |
354 | } |
355 | |
356 | //===----------------------------------------------------------------------===// |
357 | // MaxSOp |
358 | //===----------------------------------------------------------------------===// |
359 | |
360 | OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) { |
361 | return foldBinaryOpChecked(adaptor.getOperands(), |
362 | [](const APInt &lhs, const APInt &rhs) { |
363 | return lhs.sgt(rhs) ? lhs : rhs; |
364 | }); |
365 | } |
366 | |
367 | LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) { |
368 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
369 | } |
370 | |
371 | //===----------------------------------------------------------------------===// |
372 | // MaxUOp |
373 | //===----------------------------------------------------------------------===// |
374 | |
375 | OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) { |
376 | return foldBinaryOpChecked(adaptor.getOperands(), |
377 | [](const APInt &lhs, const APInt &rhs) { |
378 | return lhs.ugt(rhs) ? lhs : rhs; |
379 | }); |
380 | } |
381 | |
382 | LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) { |
383 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
384 | } |
385 | |
386 | //===----------------------------------------------------------------------===// |
387 | // MinSOp |
388 | //===----------------------------------------------------------------------===// |
389 | |
390 | OpFoldResult MinSOp::fold(FoldAdaptor adaptor) { |
391 | return foldBinaryOpChecked(adaptor.getOperands(), |
392 | [](const APInt &lhs, const APInt &rhs) { |
393 | return lhs.slt(rhs) ? lhs : rhs; |
394 | }); |
395 | } |
396 | |
397 | LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) { |
398 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
399 | } |
400 | |
401 | //===----------------------------------------------------------------------===// |
402 | // MinUOp |
403 | //===----------------------------------------------------------------------===// |
404 | |
405 | OpFoldResult MinUOp::fold(FoldAdaptor adaptor) { |
406 | return foldBinaryOpChecked(adaptor.getOperands(), |
407 | [](const APInt &lhs, const APInt &rhs) { |
408 | return lhs.ult(rhs) ? lhs : rhs; |
409 | }); |
410 | } |
411 | |
412 | LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) { |
413 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
414 | } |
415 | |
416 | //===----------------------------------------------------------------------===// |
417 | // ShlOp |
418 | //===----------------------------------------------------------------------===// |
419 | |
420 | OpFoldResult ShlOp::fold(FoldAdaptor adaptor) { |
421 | return foldBinaryOpUnchecked( |
422 | adaptor.getOperands(), |
423 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
424 | // We cannot fold if the RHS is greater than or equal to 32 because |
425 | // this would be UB in 32-bit systems but not on 64-bit systems. RHS is |
426 | // already treated as unsigned. |
427 | if (rhs.uge(32)) |
428 | return {}; |
429 | return lhs << rhs; |
430 | }); |
431 | } |
432 | |
433 | //===----------------------------------------------------------------------===// |
434 | // ShrSOp |
435 | //===----------------------------------------------------------------------===// |
436 | |
437 | OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) { |
438 | return foldBinaryOpChecked( |
439 | adaptor.getOperands(), |
440 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
441 | // Don't fold if RHS is greater than or equal to 32. |
442 | if (rhs.uge(32)) |
443 | return {}; |
444 | return lhs.ashr(rhs); |
445 | }); |
446 | } |
447 | |
448 | //===----------------------------------------------------------------------===// |
449 | // ShrUOp |
450 | //===----------------------------------------------------------------------===// |
451 | |
452 | OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) { |
453 | return foldBinaryOpChecked( |
454 | adaptor.getOperands(), |
455 | [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> { |
456 | // Don't fold if RHS is greater than or equal to 32. |
457 | if (rhs.uge(32)) |
458 | return {}; |
459 | return lhs.lshr(rhs); |
460 | }); |
461 | } |
462 | |
463 | //===----------------------------------------------------------------------===// |
464 | // AndOp |
465 | //===----------------------------------------------------------------------===// |
466 | |
467 | OpFoldResult AndOp::fold(FoldAdaptor adaptor) { |
468 | return foldBinaryOpUnchecked( |
469 | adaptor.getOperands(), |
470 | [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; }); |
471 | } |
472 | |
473 | LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) { |
474 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
475 | } |
476 | |
477 | //===----------------------------------------------------------------------===// |
478 | // OrOp |
479 | //===----------------------------------------------------------------------===// |
480 | |
481 | OpFoldResult OrOp::fold(FoldAdaptor adaptor) { |
482 | return foldBinaryOpUnchecked( |
483 | adaptor.getOperands(), |
484 | [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; }); |
485 | } |
486 | |
487 | LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) { |
488 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
489 | } |
490 | |
491 | //===----------------------------------------------------------------------===// |
492 | // XOrOp |
493 | //===----------------------------------------------------------------------===// |
494 | |
495 | OpFoldResult XOrOp::fold(FoldAdaptor adaptor) { |
496 | return foldBinaryOpUnchecked( |
497 | adaptor.getOperands(), |
498 | [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; }); |
499 | } |
500 | |
501 | LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) { |
502 | return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter); |
503 | } |
504 | |
505 | //===----------------------------------------------------------------------===// |
506 | // CastSOp |
507 | //===----------------------------------------------------------------------===// |
508 | |
509 | static OpFoldResult |
510 | foldCastOp(Attribute input, Type type, |
511 | function_ref<APInt(const APInt &, unsigned)> extFn, |
512 | function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) { |
513 | auto attr = dyn_cast_if_present<IntegerAttr>(input); |
514 | if (!attr) |
515 | return {}; |
516 | const APInt &value = attr.getValue(); |
517 | |
518 | if (isa<IndexType>(Val: type)) { |
519 | // When casting to an index type, perform the cast assuming a 64-bit target. |
520 | // The result can be truncated to 32 bits as needed and always be correct. |
521 | // This is because `cast32(cast64(value)) == cast32(value)`. |
522 | APInt result = extOrTruncFn(value, 64); |
523 | return IntegerAttr::get(type, result); |
524 | } |
525 | |
526 | // When casting from an index type, we must ensure the results respect |
527 | // `cast_t(value) == cast_t(trunc32(value))`. |
528 | auto intType = cast<IntegerType>(type); |
529 | unsigned width = intType.getWidth(); |
530 | |
531 | // If the result type is at most 32 bits, then the cast can always be folded |
532 | // because it is always a truncation. |
533 | if (width <= 32) { |
534 | APInt result = value.trunc(width); |
535 | return IntegerAttr::get(type, result); |
536 | } |
537 | |
538 | // If the result type is at least 64 bits, then the cast is always a |
539 | // extension. The results will differ if `trunc32(value) != value)`. |
540 | if (width >= 64) { |
541 | if (extFn(value.trunc(width: 32), 64) != value) |
542 | return {}; |
543 | APInt result = extFn(value, width); |
544 | return IntegerAttr::get(type, result); |
545 | } |
546 | |
547 | // Otherwise, we just have to check the property directly. |
548 | APInt result = value.trunc(width); |
549 | if (result != extFn(value.trunc(width: 32), width)) |
550 | return {}; |
551 | return IntegerAttr::get(type, result); |
552 | } |
553 | |
554 | bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { |
555 | return llvm::isa<IndexType>(lhsTypes.front()) != |
556 | llvm::isa<IndexType>(rhsTypes.front()); |
557 | } |
558 | |
559 | OpFoldResult CastSOp::fold(FoldAdaptor adaptor) { |
560 | return foldCastOp( |
561 | adaptor.getInput(), getType(), |
562 | [](const APInt &x, unsigned width) { return x.sext(width); }, |
563 | [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); }); |
564 | } |
565 | |
566 | //===----------------------------------------------------------------------===// |
567 | // CastUOp |
568 | //===----------------------------------------------------------------------===// |
569 | |
570 | bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { |
571 | return llvm::isa<IndexType>(lhsTypes.front()) != |
572 | llvm::isa<IndexType>(rhsTypes.front()); |
573 | } |
574 | |
575 | OpFoldResult CastUOp::fold(FoldAdaptor adaptor) { |
576 | return foldCastOp( |
577 | adaptor.getInput(), getType(), |
578 | [](const APInt &x, unsigned width) { return x.zext(width); }, |
579 | [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); }); |
580 | } |
581 | |
582 | //===----------------------------------------------------------------------===// |
583 | // CmpOp |
584 | //===----------------------------------------------------------------------===// |
585 | |
586 | /// Compare two integers according to the comparison predicate. |
587 | bool compareIndices(const APInt &lhs, const APInt &rhs, |
588 | IndexCmpPredicate pred) { |
589 | switch (pred) { |
590 | case IndexCmpPredicate::EQ: |
591 | return lhs.eq(RHS: rhs); |
592 | case IndexCmpPredicate::NE: |
593 | return lhs.ne(RHS: rhs); |
594 | case IndexCmpPredicate::SGE: |
595 | return lhs.sge(RHS: rhs); |
596 | case IndexCmpPredicate::SGT: |
597 | return lhs.sgt(RHS: rhs); |
598 | case IndexCmpPredicate::SLE: |
599 | return lhs.sle(RHS: rhs); |
600 | case IndexCmpPredicate::SLT: |
601 | return lhs.slt(RHS: rhs); |
602 | case IndexCmpPredicate::UGE: |
603 | return lhs.uge(RHS: rhs); |
604 | case IndexCmpPredicate::UGT: |
605 | return lhs.ugt(RHS: rhs); |
606 | case IndexCmpPredicate::ULE: |
607 | return lhs.ule(RHS: rhs); |
608 | case IndexCmpPredicate::ULT: |
609 | return lhs.ult(RHS: rhs); |
610 | } |
611 | llvm_unreachable("unhandled IndexCmpPredicate predicate" ); |
612 | } |
613 | |
614 | /// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the |
615 | /// values of `cstA` and `cstB`, the max or min operation, and the comparison |
616 | /// predicate. Check whether the value folds in both 32-bit and 64-bit |
617 | /// arithmetic and to the same value. |
618 | static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp, |
619 | const APInt &cstA, |
620 | const APInt &cstB, unsigned width, |
621 | IndexCmpPredicate pred) { |
622 | ConstantIntRanges lhsRange = TypeSwitch<Operation *, ConstantIntRanges>(lhsOp) |
623 | .Case(caseFn: [&](MinSOp op) { |
624 | return ConstantIntRanges::fromSigned( |
625 | smin: APInt::getSignedMinValue(numBits: width), smax: cstA); |
626 | }) |
627 | .Case(caseFn: [&](MinUOp op) { |
628 | return ConstantIntRanges::fromUnsigned( |
629 | umin: APInt::getMinValue(numBits: width), umax: cstA); |
630 | }) |
631 | .Case(caseFn: [&](MaxSOp op) { |
632 | return ConstantIntRanges::fromSigned( |
633 | smin: cstA, smax: APInt::getSignedMaxValue(numBits: width)); |
634 | }) |
635 | .Case(caseFn: [&](MaxUOp op) { |
636 | return ConstantIntRanges::fromUnsigned( |
637 | umin: cstA, umax: APInt::getMaxValue(numBits: width)); |
638 | }); |
639 | return intrange::evaluatePred(pred: static_cast<intrange::CmpPredicate>(pred), |
640 | lhs: lhsRange, rhs: ConstantIntRanges::constant(value: cstB)); |
641 | } |
642 | |
643 | /// Return the result of `cmp(pred, x, x)` |
644 | static bool compareSameArgs(IndexCmpPredicate pred) { |
645 | switch (pred) { |
646 | case IndexCmpPredicate::EQ: |
647 | case IndexCmpPredicate::SGE: |
648 | case IndexCmpPredicate::SLE: |
649 | case IndexCmpPredicate::UGE: |
650 | case IndexCmpPredicate::ULE: |
651 | return true; |
652 | case IndexCmpPredicate::NE: |
653 | case IndexCmpPredicate::SGT: |
654 | case IndexCmpPredicate::SLT: |
655 | case IndexCmpPredicate::UGT: |
656 | case IndexCmpPredicate::ULT: |
657 | return false; |
658 | } |
659 | llvm_unreachable("unknown predicate in compareSameArgs" ); |
660 | } |
661 | |
662 | OpFoldResult CmpOp::fold(FoldAdaptor adaptor) { |
663 | // Attempt to fold if both inputs are constant. |
664 | auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs()); |
665 | auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()); |
666 | if (lhs && rhs) { |
667 | // Perform the comparison in 64-bit and 32-bit. |
668 | bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred()); |
669 | bool result32 = compareIndices(lhs.getValue().trunc(32), |
670 | rhs.getValue().trunc(32), getPred()); |
671 | if (result64 == result32) |
672 | return BoolAttr::get(getContext(), result64); |
673 | } |
674 | |
675 | // Fold `cmp(max/min(x, cstA), cstB)`. |
676 | Operation *lhsOp = getLhs().getDefiningOp(); |
677 | IntegerAttr cstA; |
678 | if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) && |
679 | matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) { |
680 | std::optional<bool> result64 = foldCmpOfMaxOrMin( |
681 | lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred()); |
682 | std::optional<bool> result32 = |
683 | foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32), |
684 | rhs.getValue().trunc(32), 32, getPred()); |
685 | // Fold if the 32-bit and 64-bit results are the same. |
686 | if (result64 && result32 && *result64 == *result32) |
687 | return BoolAttr::get(getContext(), *result64); |
688 | } |
689 | |
690 | // Fold `cmp(x, x)` |
691 | if (getLhs() == getRhs()) |
692 | return BoolAttr::get(getContext(), compareSameArgs(getPred())); |
693 | |
694 | return {}; |
695 | } |
696 | |
697 | /// Canonicalize |
698 | /// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`. |
699 | /// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`. |
700 | LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) { |
701 | IntegerAttr cmpRhs; |
702 | IntegerAttr cmpLhs; |
703 | |
704 | bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) && |
705 | cmpRhs.getValue().isZero(); |
706 | bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) && |
707 | cmpLhs.getValue().isZero(); |
708 | if (!rhsIsZero && !lhsIsZero) |
709 | return rewriter.notifyMatchFailure(op.getLoc(), |
710 | "cmp is not comparing something with 0" ); |
711 | SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>() |
712 | : op.getRhs().getDefiningOp<index::SubOp>(); |
713 | if (!subOp) |
714 | return rewriter.notifyMatchFailure( |
715 | op.getLoc(), "non-zero operand is not a result of subtraction" ); |
716 | |
717 | index::CmpOp newCmp; |
718 | if (rhsIsZero) |
719 | newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(), |
720 | subOp.getLhs(), subOp.getRhs()); |
721 | else |
722 | newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(), |
723 | subOp.getRhs(), subOp.getLhs()); |
724 | rewriter.replaceOp(op, newCmp); |
725 | return success(); |
726 | } |
727 | |
728 | //===----------------------------------------------------------------------===// |
729 | // ConstantOp |
730 | //===----------------------------------------------------------------------===// |
731 | |
732 | void ConstantOp::getAsmResultNames( |
733 | function_ref<void(Value, StringRef)> setNameFn) { |
734 | SmallString<32> specialNameBuffer; |
735 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
736 | specialName << "idx" << getValueAttr().getValue(); |
737 | setNameFn(getResult(), specialName.str()); |
738 | } |
739 | |
740 | OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } |
741 | |
742 | void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) { |
743 | build(b, state, b.getIndexType(), b.getIndexAttr(value)); |
744 | } |
745 | |
746 | //===----------------------------------------------------------------------===// |
747 | // BoolConstantOp |
748 | //===----------------------------------------------------------------------===// |
749 | |
750 | OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) { |
751 | return getValueAttr(); |
752 | } |
753 | |
754 | void BoolConstantOp::getAsmResultNames( |
755 | function_ref<void(Value, StringRef)> setNameFn) { |
756 | setNameFn(getResult(), getValue() ? "true" : "false" ); |
757 | } |
758 | |
759 | //===----------------------------------------------------------------------===// |
760 | // ODS-Generated Definitions |
761 | //===----------------------------------------------------------------------===// |
762 | |
763 | #define GET_OP_CLASSES |
764 | #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc" |
765 | |