1//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cassert>
10#include <cstdint>
11#include <functional>
12#include <utility>
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/CommonFolders.h"
16#include "mlir/Dialect/UB/IR/UBOps.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/BuiltinAttributeInterfaces.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/OpImplementation.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/IR/TypeUtilities.h"
24#include "mlir/Support/LogicalResult.h"
25
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallString.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/TypeSwitch.h"
34
35using namespace mlir;
36using namespace mlir::arith;
37
38//===----------------------------------------------------------------------===//
39// Pattern helpers
40//===----------------------------------------------------------------------===//
41
42static IntegerAttr
43applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
44 Attribute rhs,
45 function_ref<APInt(const APInt &, const APInt &)> binFn) {
46 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48 APInt value = binFn(lhsVal, rhsVal);
49 return IntegerAttr::get(res.getType(), value);
50}
51
52static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53 Attribute lhs, Attribute rhs) {
54 return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
55}
56
57static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58 Attribute lhs, Attribute rhs) {
59 return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
60}
61
62static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63 Attribute lhs, Attribute rhs) {
64 return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
65}
66
67// Merge overflow flags from 2 ops, selecting the most conservative combination.
68static IntegerOverflowFlagsAttr
69mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
70 IntegerOverflowFlagsAttr val2) {
71 return IntegerOverflowFlagsAttr::get(val1.getContext(),
72 val1.getValue() & val2.getValue());
73}
74
75/// Invert an integer comparison predicate.
76arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
77 switch (pred) {
78 case arith::CmpIPredicate::eq:
79 return arith::CmpIPredicate::ne;
80 case arith::CmpIPredicate::ne:
81 return arith::CmpIPredicate::eq;
82 case arith::CmpIPredicate::slt:
83 return arith::CmpIPredicate::sge;
84 case arith::CmpIPredicate::sle:
85 return arith::CmpIPredicate::sgt;
86 case arith::CmpIPredicate::sgt:
87 return arith::CmpIPredicate::sle;
88 case arith::CmpIPredicate::sge:
89 return arith::CmpIPredicate::slt;
90 case arith::CmpIPredicate::ult:
91 return arith::CmpIPredicate::uge;
92 case arith::CmpIPredicate::ule:
93 return arith::CmpIPredicate::ugt;
94 case arith::CmpIPredicate::ugt:
95 return arith::CmpIPredicate::ule;
96 case arith::CmpIPredicate::uge:
97 return arith::CmpIPredicate::ult;
98 }
99 llvm_unreachable("unknown cmpi predicate kind");
100}
101
102/// Equivalent to
103/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
104///
105/// Not possible to implement as chain of calls as this would introduce a
106/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
107/// on the LLVM dialect and on translation to LLVM.
108static llvm::RoundingMode
109convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
110 switch (roundingMode) {
111 case RoundingMode::downward:
112 return llvm::RoundingMode::TowardNegative;
113 case RoundingMode::to_nearest_away:
114 return llvm::RoundingMode::NearestTiesToAway;
115 case RoundingMode::to_nearest_even:
116 return llvm::RoundingMode::NearestTiesToEven;
117 case RoundingMode::toward_zero:
118 return llvm::RoundingMode::TowardZero;
119 case RoundingMode::upward:
120 return llvm::RoundingMode::TowardPositive;
121 }
122 llvm_unreachable("Unhandled rounding mode");
123}
124
125static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
126 return arith::CmpIPredicateAttr::get(pred.getContext(),
127 invertPredicate(pred.getValue()));
128}
129
130static int64_t getScalarOrElementWidth(Type type) {
131 Type elemTy = getElementTypeOrSelf(type);
132 if (elemTy.isIntOrFloat())
133 return elemTy.getIntOrFloatBitWidth();
134
135 return -1;
136}
137
138static int64_t getScalarOrElementWidth(Value value) {
139 return getScalarOrElementWidth(type: value.getType());
140}
141
142static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
143 APInt value;
144 if (matchPattern(attr, m_ConstantInt(&value)))
145 return value;
146
147 return failure();
148}
149
150static Attribute getBoolAttribute(Type type, bool value) {
151 auto boolAttr = BoolAttr::get(context: type.getContext(), value);
152 ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
153 if (!shapedType)
154 return boolAttr;
155 return DenseElementsAttr::get(shapedType, boolAttr);
156}
157
158//===----------------------------------------------------------------------===//
159// TableGen'd canonicalization patterns
160//===----------------------------------------------------------------------===//
161
162namespace {
163#include "ArithCanonicalization.inc"
164} // namespace
165
166//===----------------------------------------------------------------------===//
167// Common helpers
168//===----------------------------------------------------------------------===//
169
170/// Return the type of the same shape (scalar, vector or tensor) containing i1.
171static Type getI1SameShape(Type type) {
172 auto i1Type = IntegerType::get(type.getContext(), 1);
173 if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
174 return shapedType.cloneWith(std::nullopt, i1Type);
175 if (llvm::isa<UnrankedTensorType>(type))
176 return UnrankedTensorType::get(i1Type);
177 return i1Type;
178}
179
180//===----------------------------------------------------------------------===//
181// ConstantOp
182//===----------------------------------------------------------------------===//
183
/// Suggest readable SSA names for constant results: `true`/`false` for i1,
/// `c<value>_<type>` for other integers, and `cst` for everything else.
void arith::ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto type = getType();
  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
    auto intType = llvm::dyn_cast<IntegerType>(type);

    // Sugar i1 constants with 'true' and 'false'.
    if (intType && intType.getWidth() == 1)
      return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));

    // Otherwise, build a complex name with the value and type.
    SmallString<32> specialNameBuffer;
    llvm::raw_svector_ostream specialName(specialNameBuffer);
    specialName << 'c' << intCst.getValue();
    // Only suffix the type for plain integer types (e.g. `c42_i32`).
    if (intType)
      specialName << '_' << type;
    setNameFn(getResult(), specialName.str());
  } else {
    // Non-integer constants (floats, dense elements, ...) get a generic name.
    setNameFn(getResult(), "cst");
  }
}
205
206/// TODO: disallow arith.constant to return anything other than signless integer
207/// or float like.
208LogicalResult arith::ConstantOp::verify() {
209 auto type = getType();
210 // Integer values must be signless.
211 if (llvm::isa<IntegerType>(type) &&
212 !llvm::cast<IntegerType>(type).isSignless())
213 return emitOpError("integer return type must be signless");
214 // Any float or elements attribute are acceptable.
215 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
216 return emitOpError(
217 "value must be an integer, float, or elements attribute");
218 }
219
220 // Note, we could relax this for vectors with 1 scalable dim, e.g.:
221 // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
222 // However, this would most likely require updating the lowerings to LLVM.
223 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
224 return emitOpError(
225 "intializing scalable vectors with elements attribute is not supported"
226 " unless it's a vector splat");
227 return success();
228}
229
230bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
231 // The value's type must be the same as the provided type.
232 auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
233 if (!typedAttr || typedAttr.getType() != type)
234 return false;
235 // Integer values must be signless.
236 if (llvm::isa<IntegerType>(type) &&
237 !llvm::cast<IntegerType>(type).isSignless())
238 return false;
239 // Integer, float, and element attributes are buildable.
240 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
241}
242
243ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
244 Type type, Location loc) {
245 if (isBuildableWith(value, type))
246 return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
247 return nullptr;
248}
249
/// A constant folds to its value attribute.
OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
251
252void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
253 int64_t value, unsigned width) {
254 auto type = builder.getIntegerType(width);
255 arith::ConstantOp::build(builder, result, type,
256 builder.getIntegerAttr(type, value));
257}
258
259void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
260 int64_t value, Type type) {
261 assert(type.isSignlessInteger() &&
262 "ConstantIntOp can only have signless integer type values");
263 arith::ConstantOp::build(builder, result, type,
264 builder.getIntegerAttr(type, value));
265}
266
267bool arith::ConstantIntOp::classof(Operation *op) {
268 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
269 return constOp.getType().isSignlessInteger();
270 return false;
271}
272
273void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
274 const APFloat &value, FloatType type) {
275 arith::ConstantOp::build(builder, result, type,
276 builder.getFloatAttr(type, value));
277}
278
279bool arith::ConstantFloatOp::classof(Operation *op) {
280 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
281 return llvm::isa<FloatType>(constOp.getType());
282 return false;
283}
284
285void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
286 int64_t value) {
287 arith::ConstantOp::build(builder, result, builder.getIndexType(),
288 builder.getIndexAttr(value));
289}
290
291bool arith::ConstantIndexOp::classof(Operation *op) {
292 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
293 return constOp.getType().isIndex();
294 return false;
295}
296
297//===----------------------------------------------------------------------===//
298// AddIOp
299//===----------------------------------------------------------------------===//
300
/// Fold integer addition: additive identity, add/sub cancellation, and
/// constant folding of both operands.
OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
  // addi(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();

  // addi(subi(a, b), b) -> a
  if (auto sub = getLhs().getDefiningOp<SubIOp>())
    if (getRhs() == sub.getRhs())
      return sub.getLhs();

  // addi(b, subi(a, b)) -> a
  if (auto sub = getRhs().getDefiningOp<SubIOp>())
    if (getLhs() == sub.getRhs())
      return sub.getLhs();

  // Fall back to elementwise constant folding when both operands are
  // constant attributes.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) + b; });
}
320
/// Register the TableGen-generated canonicalization patterns for addi.
void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
               AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
}
326
327//===----------------------------------------------------------------------===//
328// AddUIExtendedOp
329//===----------------------------------------------------------------------===//
330
331std::optional<SmallVector<int64_t, 4>>
332arith::AddUIExtendedOp::getShapeForUnroll() {
333 if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
334 return llvm::to_vector<4>(vt.getShape());
335 return std::nullopt;
336}
337
338// Returns the overflow bit, assuming that `sum` is the result of unsigned
339// addition of `operand` and another number.
340static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
341 return sum.ult(RHS: operand) ? APInt::getAllOnes(numBits: 1) : APInt::getZero(numBits: 1);
342}
343
/// Fold addui_extended: x+0 yields (x, false); two constant operands fold
/// to a constant sum and a constant carry bit.
LogicalResult
arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  Type overflowTy = getOverflow().getType();
  // addui_extended(x, 0) -> x, false
  if (matchPattern(getRhs(), m_Zero())) {
    Builder builder(getContext());
    auto falseValue = builder.getZeroAttr(overflowTy);

    results.push_back(getLhs());
    results.push_back(falseValue);
    return success();
  }

  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
  // operands. If that succeeds, calculate the overflow bit based on the sum
  // and the first (constant) operand, `lhs`.
  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](APInt a, const APInt &b) { return std::move(a) + b; })) {
    Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
        ArrayRef({sumAttr, adaptor.getLhs()}),
        getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
        calculateUnsignedOverflow);
    // The overflow fold can fail (e.g. non-splat elements); bail out rather
    // than report a partial fold.
    if (!overflowAttr)
      return failure();

    results.push_back(sumAttr);
    results.push_back(overflowAttr);
    return success();
  }

  return failure();
}
379
/// Register the TableGen-generated canonicalization patterns for
/// addui_extended.
void arith::AddUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<AddUIExtendedToAddI>(context);
}
384
385//===----------------------------------------------------------------------===//
386// SubIOp
387//===----------------------------------------------------------------------===//
388
/// Fold integer subtraction: x-x, subtractive identity, add/sub
/// cancellation, and constant folding of both operands.
OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
  // subi(x,x) -> 0
  if (getOperand(0) == getOperand(1)) {
    auto shapedType = dyn_cast<ShapedType>(getType());
    // We can't generate a constant with a dynamic shaped tensor.
    if (!shapedType || shapedType.hasStaticShape())
      return Builder(getContext()).getZeroAttr(getType());
  }
  // subi(x,0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();

  if (auto add = getLhs().getDefiningOp<AddIOp>()) {
    // subi(addi(a, b), b) -> a
    if (getRhs() == add.getRhs())
      return add.getLhs();
    // subi(addi(a, b), a) -> b
    if (getRhs() == add.getLhs())
      return add.getRhs();
  }

  // Fall back to elementwise constant folding when both operands are
  // constant attributes.
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) - b; });
}
414
/// Register the TableGen-generated canonicalization patterns for subi.
void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
               SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
               SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
}
421
422//===----------------------------------------------------------------------===//
423// MulIOp
424//===----------------------------------------------------------------------===//
425
/// Fold integer multiplication: absorbing zero, multiplicative identity,
/// and constant folding of both operands.
OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
  // muli(x, 0) -> 0
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getRhs();
  // muli(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();
  // TODO: Handle the overflow case.

  // default folder
  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a * b; });
}
440
/// Name index-typed `constant * vscale` products `c<base_value>_vscale` to
/// make scalable-size computations readable in IR dumps.
void arith::MulIOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!isa<IndexType>(getType()))
    return;

  // Match vector.vscale by name to avoid depending on the vector dialect (which
  // is a circular dependency).
  auto isVscale = [](Operation *op) {
    return op && op->getName().getStringRef() == "vector.vscale";
  };

  IntegerAttr baseValue;
  // True when `a` is a constant and `b` is produced by vector.vscale.
  auto isVscaleExpr = [&](Value a, Value b) {
    return matchPattern(a, m_Constant(&baseValue)) &&
           isVscale(b.getDefiningOp());
  };

  // Check both operand orders; bail out for any other product.
  if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
    return;

  // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);
  specialName << 'c' << baseValue.getInt() << "_vscale";
  setNameFn(getResult(), specialName.str());
}
467
/// Register the TableGen-generated canonicalization patterns for muli.
void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulIMulIConstant>(context);
}
472
473//===----------------------------------------------------------------------===//
474// MulSIExtendedOp
475//===----------------------------------------------------------------------===//
476
477std::optional<SmallVector<int64_t, 4>>
478arith::MulSIExtendedOp::getShapeForUnroll() {
479 if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
480 return llvm::to_vector<4>(vt.getShape());
481 return std::nullopt;
482}
483
/// Fold mulsi_extended: x*0 yields (0, 0); two constant operands fold to
/// constant low and high halves of the signed product.
LogicalResult
arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  // mulsi_extended(x, 0) -> 0, 0
  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    Attribute zero = adaptor.getRhs();
    results.push_back(zero);
    results.push_back(zero);
    return success();
  }

  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](const APInt &a, const APInt &b) { return a * b; })) {
    // Invoke the constant fold helper again to calculate the 'high' result.
    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhs(a, b);
        });
    // If the low half folded, the high half must fold too (same operands).
    assert(highAttr && "Unexpected constant-folding failure");

    results.push_back(lowAttr);
    results.push_back(highAttr);
    return success();
  }

  return failure();
}
513
/// Register the TableGen-generated canonicalization patterns for
/// mulsi_extended.
void arith::MulSIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
}
518
519//===----------------------------------------------------------------------===//
520// MulUIExtendedOp
521//===----------------------------------------------------------------------===//
522
523std::optional<SmallVector<int64_t, 4>>
524arith::MulUIExtendedOp::getShapeForUnroll() {
525 if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
526 return llvm::to_vector<4>(vt.getShape());
527 return std::nullopt;
528}
529
/// Fold mului_extended: x*0 yields (0, 0); x*1 yields (x, 0); two constant
/// operands fold to constant low and high halves of the unsigned product.
LogicalResult
arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  // mului_extended(x, 0) -> 0, 0
  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    Attribute zero = adaptor.getRhs();
    results.push_back(zero);
    results.push_back(zero);
    return success();
  }

  // mului_extended(x, 1) -> x, 0
  if (matchPattern(adaptor.getRhs(), m_One())) {
    Builder builder(getContext());
    Attribute zero = builder.getZeroAttr(getLhs().getType());
    results.push_back(getLhs());
    results.push_back(zero);
    return success();
  }

  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](const APInt &a, const APInt &b) { return a * b; })) {
    // Invoke the constant fold helper again to calculate the 'high' result.
    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhu(a, b);
        });
    // If the low half folded, the high half must fold too (same operands).
    assert(highAttr && "Unexpected constant-folding failure");

    results.push_back(lowAttr);
    results.push_back(highAttr);
    return success();
  }

  return failure();
}
568
/// Register the TableGen-generated canonicalization patterns for
/// mului_extended.
void arith::MulUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulUIExtendedToMulI>(context);
}
573
574//===----------------------------------------------------------------------===//
575// DivUIOp
576//===----------------------------------------------------------------------===//
577
578/// Fold `(a * b) / b -> a`
579static Value foldDivMul(Value lhs, Value rhs,
580 arith::IntegerOverflowFlags ovfFlags) {
581 auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
582 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
583 return {};
584
585 if (mul.getLhs() == rhs)
586 return mul.getRhs();
587
588 if (mul.getRhs() == rhs)
589 return mul.getLhs();
590
591 return {};
592}
593
/// Fold unsigned division: division by one, (a*b)/b cancellation (requires
/// nuw), and constant folding — refusing to fold a division by zero.
OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
  // divui (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // (a * b) / b -> a
  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
    return val;

  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                               [&](APInt a, const APInt &b) {
                                                 if (div0 || !b) {
                                                   div0 = true;
                                                   return a;
                                                 }
                                                 return a.udiv(b);
                                               });

  // Discard the folded attribute if any element divided by zero.
  return div0 ? Attribute() : result;
}
616
617/// Returns whether an unsigned division by `divisor` is speculatable.
618static Speculation::Speculatability getDivUISpeculatability(Value divisor) {
619 // X / 0 => UB
620 if (matchPattern(value: divisor, pattern: m_IntRangeWithoutZeroU()))
621 return Speculation::Speculatable;
622
623 return Speculation::NotSpeculatable;
624}
625
/// divui is speculatable only when the divisor is known non-zero.
Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
  return getDivUISpeculatability(getRhs());
}
629
630//===----------------------------------------------------------------------===//
631// DivSIOp
632//===----------------------------------------------------------------------===//
633
/// Fold signed division: division by one, (a*b)/b cancellation (requires
/// nsw), and constant folding — refusing to fold on overflow or division by
/// zero.
OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
  // divsi (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // (a * b) / b -> a
  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
    return val;

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        // sdiv_ov flags INT_MIN / -1 overflow through `overflowOrDiv0`.
        return a.sdiv_ov(b, overflowOrDiv0);
      });

  return overflowOrDiv0 ? Attribute() : result;
}
656
657/// Returns whether a signed division by `divisor` is speculatable. This
658/// function conservatively assumes that all signed division by -1 are not
659/// speculatable.
660static Speculation::Speculatability getDivSISpeculatability(Value divisor) {
661 // X / 0 => UB
662 // INT_MIN / -1 => UB
663 if (matchPattern(value: divisor, pattern: m_IntRangeWithoutZeroS()) &&
664 matchPattern(value: divisor, pattern: m_IntRangeWithoutNegOneS()))
665 return Speculation::Speculatable;
666
667 return Speculation::NotSpeculatable;
668}
669
/// divsi is speculatable only when the divisor is known non-zero and
/// non-negative-one.
Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
  return getDivSISpeculatability(getRhs());
}
673
674//===----------------------------------------------------------------------===//
675// Ceil and floor division folding helpers
676//===----------------------------------------------------------------------===//
677
678static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
679 bool &overflow) {
680 // Returns (a-1)/b + 1
681 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
682 APInt val = a.ssub_ov(RHS: one, Overflow&: overflow).sdiv_ov(RHS: b, Overflow&: overflow);
683 return val.sadd_ov(RHS: one, Overflow&: overflow);
684}
685
686//===----------------------------------------------------------------------===//
687// CeilDivUIOp
688//===----------------------------------------------------------------------===//
689
/// Fold unsigned ceiling division: division by one and constant folding,
/// refusing to fold on overflow or division by zero.
OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
  // ceildivui (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        // ceil(a/b) = a/b when b divides a exactly, otherwise a/b + 1.
        APInt quotient = a.udiv(b);
        if (!a.urem(b))
          return quotient;
        APInt one(a.getBitWidth(), 1, true);
        return quotient.uadd_ov(one, overflowOrDiv0);
      });

  return overflowOrDiv0 ? Attribute() : result;
}
711
/// ceildivui is speculatable only when the divisor is known non-zero.
Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
  return getDivUISpeculatability(getRhs());
}
715
716//===----------------------------------------------------------------------===//
717// CeilDivSIOp
718//===----------------------------------------------------------------------===//
719
/// Fold signed ceiling division: division by one and constant folding by
/// sign-case analysis, refusing to fold on overflow or division by zero.
OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
  // ceildivsi (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // Don't fold if it would overflow or if it requires a division by zero.
  // TODO: This hook won't fold operations where a = MININT, because
  // negating MININT overflows. This can be improved.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        // 0 / b == 0 for any non-zero b.
        if (!a)
          return a;
        // After this point we know that neither a or b are zero.
        unsigned bits = a.getBitWidth();
        APInt zero = APInt::getZero(bits);
        bool aGtZero = a.sgt(zero);
        bool bGtZero = b.sgt(zero);
        if (aGtZero && bGtZero) {
          // Both positive, return ceil(a, b).
          return signedCeilNonnegInputs(a, b, overflowOrDiv0);
        }

        // No folding happens if any of the intermediate arithmetic operations
        // overflows.
        bool overflowNegA = false;
        bool overflowNegB = false;
        bool overflowDiv = false;
        bool overflowNegRes = false;
        if (!aGtZero && !bGtZero) {
          // Both negative, return ceil(-a, -b).
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt posB = zero.ssub_ov(b, overflowNegB);
          APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
          overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
          return res;
        }
        if (!aGtZero && bGtZero) {
          // A is negative, b is positive, return - ( -a / b).
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt div = posA.sdiv_ov(b, overflowDiv);
          APInt res = zero.ssub_ov(div, overflowNegRes);
          overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
          return res;
        }
        // A is positive, b is negative, return - (a / -b).
        APInt posB = zero.ssub_ov(b, overflowNegB);
        APInt div = a.sdiv_ov(posB, overflowDiv);
        APInt res = zero.ssub_ov(div, overflowNegRes);

        overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
        return res;
      });

  return overflowOrDiv0 ? Attribute() : result;
}
780
/// ceildivsi is speculatable only when the divisor is known non-zero and
/// non-negative-one.
Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
  return getDivSISpeculatability(getRhs());
}
784
785//===----------------------------------------------------------------------===//
786// FloorDivSIOp
787//===----------------------------------------------------------------------===//
788
/// Fold signed floor division: division by one and constant folding,
/// refusing to fold on overflow or division by zero.
OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
  // floordivsi (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (b.isZero()) {
          overflowOrDiv = true;
          return a;
        }
        // sfloordiv_ov reports INT_MIN / -1 overflow through `overflowOrDiv`.
        return a.sfloordiv_ov(b, overflowOrDiv);
      });

  return overflowOrDiv ? Attribute() : result;
}
807
808//===----------------------------------------------------------------------===//
809// RemUIOp
810//===----------------------------------------------------------------------===//
811
/// Fold unsigned remainder: x % 1 is zero; constant folding otherwise,
/// refusing to fold a division by zero.
OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
  // remui (x, 1) -> 0.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                               [&](APInt a, const APInt &b) {
                                                 if (div0 || b.isZero()) {
                                                   div0 = true;
                                                   return a;
                                                 }
                                                 return a.urem(b);
                                               });

  return div0 ? Attribute() : result;
}
830
831//===----------------------------------------------------------------------===//
832// RemSIOp
833//===----------------------------------------------------------------------===//
834
/// Fold signed remainder: x % 1 is zero; constant folding otherwise,
/// refusing to fold a division by zero.
OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
  // remsi (x, 1) -> 0.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // Don't fold if it would require a division by zero.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                               [&](APInt a, const APInt &b) {
                                                 if (div0 || b.isZero()) {
                                                   div0 = true;
                                                   return a;
                                                 }
                                                 return a.srem(b);
                                               });

  return div0 ? Attribute() : result;
}
853
854//===----------------------------------------------------------------------===//
855// AndIOp
856//===----------------------------------------------------------------------===//
857
858/// Fold `and(a, and(a, b))` to `and(a, b)`
859static Value foldAndIofAndI(arith::AndIOp op) {
860 for (bool reversePrev : {false, true}) {
861 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
862 .getDefiningOp<arith::AndIOp>();
863 if (!prev)
864 continue;
865
866 Value other = (reversePrev ? op.getLhs() : op.getRhs());
867 if (other != prev.getLhs() && other != prev.getRhs())
868 continue;
869
870 return prev.getResult();
871 }
872 return {};
873}
874
/// Fold bitwise and: absorbing zero, all-ones identity, x & ~x
/// annihilation, and(a, and(a, b)) idempotence, and constant folding.
OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
  /// and(x, 0) -> 0
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getRhs();
  /// and(x, allOnes) -> x
  APInt intValue;
  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
      intValue.isAllOnes())
    return getLhs();
  /// and(x, not(x)) -> 0
  // not(x) is matched as xor(x, allOnes).
  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return Builder(getContext()).getZeroAttr(getType());
  /// and(not(x), x) -> 0
  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return Builder(getContext()).getZeroAttr(getType());

  /// and(a, and(a, b)) -> and(a, b)
  if (Value result = foldAndIofAndI(*this))
    return result;

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) & b; });
}
903
904//===----------------------------------------------------------------------===//
905// OrIOp
906//===----------------------------------------------------------------------===//
907
/// Fold bitwise or: zero identity, absorbing all-ones, x | ~x saturation,
/// and constant folding.
OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
  if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
    /// or(x, 0) -> x
    if (rhsVal.isZero())
      return getLhs();
    /// or(x, <all ones>) -> <all ones>
    if (rhsVal.isAllOnes())
      return adaptor.getRhs();
  }

  APInt intValue;
  /// or(x, xor(x, 1)) -> 1
  // Return the all-ones constant operand of the matched xor.
  if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return getRhs().getDefiningOp<XOrIOp>().getRhs();
  /// or(xor(x, 1), x) -> 1
  if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
                                          m_ConstantInt(&intValue))) &&
      intValue.isAllOnes())
    return getLhs().getDefiningOp<XOrIOp>().getRhs();

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) | b; });
}
934
935//===----------------------------------------------------------------------===//
936// XOrIOp
937//===----------------------------------------------------------------------===//
938
939OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
940 /// xor(x, 0) -> x
941 if (matchPattern(adaptor.getRhs(), m_Zero()))
942 return getLhs();
943 /// xor(x, x) -> 0
944 if (getLhs() == getRhs())
945 return Builder(getContext()).getZeroAttr(getType());
946 /// xor(xor(x, a), a) -> x
947 /// xor(xor(a, x), a) -> x
948 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
949 if (prev.getRhs() == getRhs())
950 return prev.getLhs();
951 if (prev.getLhs() == getRhs())
952 return prev.getRhs();
953 }
954 /// xor(a, xor(x, a)) -> x
955 /// xor(a, xor(a, x)) -> x
956 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
957 if (prev.getRhs() == getLhs())
958 return prev.getLhs();
959 if (prev.getLhs() == getLhs())
960 return prev.getRhs();
961 }
962
963 return constFoldBinaryOp<IntegerAttr>(
964 adaptor.getOperands(),
965 [](APInt a, const APInt &b) { return std::move(a) ^ b; });
966}
967
/// Register DRR/declarative canonicalization patterns for `arith.xori`.
void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
}
972
973//===----------------------------------------------------------------------===//
974// NegFOp
975//===----------------------------------------------------------------------===//
976
977OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
978 /// negf(negf(x)) -> x
979 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
980 return op.getOperand();
981 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
982 [](const APFloat &a) { return -a; });
983}
984
985//===----------------------------------------------------------------------===//
986// AddFOp
987//===----------------------------------------------------------------------===//
988
/// Identity and constant folding for `arith.addf`.
OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
  // addf(x, -0) -> x (adding -0.0 is the IEEE-754 identity; +0.0 is not,
  // since (-0.0) + (+0.0) == +0.0)
  if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a + b; });
}
998
999//===----------------------------------------------------------------------===//
1000// SubFOp
1001//===----------------------------------------------------------------------===//
1002
/// Identity and constant folding for `arith.subf`.
OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
  // subf(x, +0) -> x (subtracting +0.0 is the IEEE-754 identity; -0.0 is
  // not, since (-0.0) - (-0.0) == +0.0)
  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a - b; });
}
1012
1013//===----------------------------------------------------------------------===//
1014// MaximumFOp
1015//===----------------------------------------------------------------------===//
1016
/// Identity and constant folding for `arith.maximumf` (NaN-propagating max).
OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
  // maximumf(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // maximumf(x, -inf) -> x
  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
    return getLhs();

  // llvm::maximum propagates NaN (IEEE-754 maximum semantics).
  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
}
1030
1031//===----------------------------------------------------------------------===//
1032// MaxNumFOp
1033//===----------------------------------------------------------------------===//
1034
/// Identity and constant folding for `arith.maxnumf` (NaN-ignoring max).
OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
  // maxnumf(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // maxnumf(x, NaN) -> x (maxnum returns the non-NaN operand)
  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
}
1046
1047//===----------------------------------------------------------------------===//
1048// MaxSIOp
1049//===----------------------------------------------------------------------===//
1050
1051OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1052 // maxsi(x,x) -> x
1053 if (getLhs() == getRhs())
1054 return getRhs();
1055
1056 if (APInt intValue;
1057 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1058 // maxsi(x,MAX_INT) -> MAX_INT
1059 if (intValue.isMaxSignedValue())
1060 return getRhs();
1061 // maxsi(x, MIN_INT) -> x
1062 if (intValue.isMinSignedValue())
1063 return getLhs();
1064 }
1065
1066 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1067 [](const APInt &a, const APInt &b) {
1068 return llvm::APIntOps::smax(a, b);
1069 });
1070}
1071
1072//===----------------------------------------------------------------------===//
1073// MaxUIOp
1074//===----------------------------------------------------------------------===//
1075
1076OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1077 // maxui(x,x) -> x
1078 if (getLhs() == getRhs())
1079 return getRhs();
1080
1081 if (APInt intValue;
1082 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1083 // maxui(x,MAX_INT) -> MAX_INT
1084 if (intValue.isMaxValue())
1085 return getRhs();
1086 // maxui(x, MIN_INT) -> x
1087 if (intValue.isMinValue())
1088 return getLhs();
1089 }
1090
1091 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1092 [](const APInt &a, const APInt &b) {
1093 return llvm::APIntOps::umax(a, b);
1094 });
1095}
1096
1097//===----------------------------------------------------------------------===//
1098// MinimumFOp
1099//===----------------------------------------------------------------------===//
1100
/// Identity and constant folding for `arith.minimumf` (NaN-propagating min).
OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
  // minimumf(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // minimumf(x, +inf) -> x
  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
    return getLhs();

  // llvm::minimum propagates NaN (IEEE-754 minimum semantics).
  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
}
1114
1115//===----------------------------------------------------------------------===//
1116// MinNumFOp
1117//===----------------------------------------------------------------------===//
1118
/// Identity and constant folding for `arith.minnumf` (NaN-ignoring min).
OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
  // minnumf(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // minnumf(x, NaN) -> x (minnum returns the non-NaN operand)
  if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
}
1132
1133//===----------------------------------------------------------------------===//
1134// MinSIOp
1135//===----------------------------------------------------------------------===//
1136
1137OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1138 // minsi(x,x) -> x
1139 if (getLhs() == getRhs())
1140 return getRhs();
1141
1142 if (APInt intValue;
1143 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1144 // minsi(x,MIN_INT) -> MIN_INT
1145 if (intValue.isMinSignedValue())
1146 return getRhs();
1147 // minsi(x, MAX_INT) -> x
1148 if (intValue.isMaxSignedValue())
1149 return getLhs();
1150 }
1151
1152 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1153 [](const APInt &a, const APInt &b) {
1154 return llvm::APIntOps::smin(a, b);
1155 });
1156}
1157
1158//===----------------------------------------------------------------------===//
1159// MinUIOp
1160//===----------------------------------------------------------------------===//
1161
1162OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1163 // minui(x,x) -> x
1164 if (getLhs() == getRhs())
1165 return getRhs();
1166
1167 if (APInt intValue;
1168 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1169 // minui(x,MIN_INT) -> MIN_INT
1170 if (intValue.isMinValue())
1171 return getRhs();
1172 // minui(x, MAX_INT) -> x
1173 if (intValue.isMaxValue())
1174 return getLhs();
1175 }
1176
1177 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1178 [](const APInt &a, const APInt &b) {
1179 return llvm::APIntOps::umin(a, b);
1180 });
1181}
1182
1183//===----------------------------------------------------------------------===//
1184// MulFOp
1185//===----------------------------------------------------------------------===//
1186
/// Identity and constant folding for `arith.mulf`.
OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
  // mulf(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a * b; });
}
1196
/// Register declarative canonicalization patterns for `arith.mulf`.
void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulFOfNegF>(context);
}
1201
1202//===----------------------------------------------------------------------===//
1203// DivFOp
1204//===----------------------------------------------------------------------===//
1205
/// Identity and constant folding for `arith.divf`.
OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
  // divf(x, 1) -> x
  if (matchPattern(adaptor.getRhs(), m_OneFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a / b; });
}
1215
/// Register declarative canonicalization patterns for `arith.divf`.
void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<DivFOfNegF>(context);
}
1220
1221//===----------------------------------------------------------------------===//
1222// RemFOp
1223//===----------------------------------------------------------------------===//
1224
/// Constant folding for `arith.remf` via APFloat::mod.
OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
                                      [](const APFloat &a, const APFloat &b) {
                                        APFloat result(a);
                                        // APFloat::mod() offers the remainder
                                        // behavior we want, i.e. the result has
                                        // the sign of LHS operand.
                                        (void)result.mod(b);
                                        return result;
                                      });
}
1236
1237//===----------------------------------------------------------------------===//
1238// Utility functions for verifying cast ops
1239//===----------------------------------------------------------------------===//
1240
/// Lightweight tag to pass a pack of types by value (as a null pointer); used
/// to select overloads of getUnderlyingType below.
template <typename... Types>
using type_list = std::tuple<Types...> *;
1243
/// Returns a non-null type only if the provided type is one of the allowed
/// types or one of the allowed shaped types of the allowed types. Returns the
/// element type if a valid shaped type is provided.
template <typename... ShapedTypes, typename... ElementTypes>
static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
                              type_list<ElementTypes...>) {
  // Reject shaped types that are not in the allowed shaped-type list.
  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
    return {};

  // Element type for shaped types; the type itself otherwise.
  auto underlyingType = getElementTypeOrSelf(type);
  if (!llvm::isa<ElementTypes...>(underlyingType))
    return {};

  return underlyingType;
}
1259
/// Get allowed underlying types for vectors and tensors (memrefs excluded).
template <typename... ElementTypes>
static Type getTypeIfLike(Type type) {
  return getUnderlyingType(type, type_list<VectorType, TensorType>(),
                           type_list<ElementTypes...>());
}
1266
/// Get allowed underlying types for vectors, tensors, and memrefs.
template <typename... ElementTypes>
static Type getTypeIfLikeOrMemRef(Type type) {
  return getUnderlyingType(type,
                           type_list<VectorType, TensorType, MemRefType>(),
                           type_list<ElementTypes...>());
}
1274
1275/// Return false if both types are ranked tensor with mismatching encoding.
1276static bool hasSameEncoding(Type typeA, Type typeB) {
1277 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1278 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1279 if (!rankedTensorA || !rankedTensorB)
1280 return true;
1281 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1282}
1283
1284static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1285 if (inputs.size() != 1 || outputs.size() != 1)
1286 return false;
1287 if (!hasSameEncoding(typeA: inputs.front(), typeB: outputs.front()))
1288 return false;
1289 return succeeded(Result: verifyCompatibleShapes(types1: inputs.front(), types2: outputs.front()));
1290}
1291
1292//===----------------------------------------------------------------------===//
1293// Verifiers for integer and floating point extension/truncation ops
1294//===----------------------------------------------------------------------===//
1295
// Extend ops can only extend to a wider type.
/// Emits an error unless the destination element width is strictly greater
/// than the source element width. `ValType` is the element type class to
/// query the width on (IntegerType or FloatType).
template <typename ValType, typename Op>
static LogicalResult verifyExtOp(Op op) {
  Type srcType = getElementTypeOrSelf(op.getIn().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  if (llvm::cast<ValType>(srcType).getWidth() >=
      llvm::cast<ValType>(dstType).getWidth())
    return op.emitError("result type ")
           << dstType << " must be wider than operand type " << srcType;

  return success();
}
1309
// Truncate ops can only truncate to a shorter type.
/// Emits an error unless the destination element width is strictly smaller
/// than the source element width. `ValType` is the element type class to
/// query the width on (IntegerType or FloatType).
template <typename ValType, typename Op>
static LogicalResult verifyTruncateOp(Op op) {
  Type srcType = getElementTypeOrSelf(op.getIn().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  if (llvm::cast<ValType>(srcType).getWidth() <=
      llvm::cast<ValType>(dstType).getWidth())
    return op.emitError("result type ")
           << dstType << " must be shorter than operand type " << srcType;

  return success();
}
1323
/// Validate a cast that changes the width of a type.
/// `WidthComparator` relates dst width to src width (std::greater for
/// extensions, std::less for truncations).
template <template <typename> class WidthComparator, typename... ElementTypes>
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  // Both sides must be one of the allowed element types (possibly wrapped in
  // a vector or tensor).
  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
  if (!srcType || !dstType)
    return false;

  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
                                     srcType.getIntOrFloatBitWidth());
}
1338
1339/// Attempts to convert `sourceValue` to an APFloat value with
1340/// `targetSemantics` and `roundingMode`, without any information loss.
1341static FailureOr<APFloat> convertFloatValue(
1342 APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1343 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1344 bool losesInfo = false;
1345 auto status = sourceValue.convert(ToSemantics: targetSemantics, RM: roundingMode, losesInfo: &losesInfo);
1346 if (losesInfo || status != APFloat::opOK)
1347 return failure();
1348
1349 return sourceValue;
1350}
1351
1352//===----------------------------------------------------------------------===//
1353// ExtUIOp
1354//===----------------------------------------------------------------------===//
1355
/// Fold chained zero-extensions and constant operands of `arith.extui`.
OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
  // extui(extui(x)) -> extui(x): rewrite this op in place to extend the
  // innermost value directly.
  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
    getInMutable().assign(lhs.getIn());
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  // Constant operand: zero-extend the APInt to the result width.
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.zext(bitWidth);
      });
}
1370
/// extui casts are compatible iff the integer element width strictly grows.
bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}
1374
/// Verify that extui widens the integer element type.
LogicalResult arith::ExtUIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}
1378
1379//===----------------------------------------------------------------------===//
1380// ExtSIOp
1381//===----------------------------------------------------------------------===//
1382
/// Fold chained sign-extensions and constant operands of `arith.extsi`.
OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
  // extsi(extsi(x)) -> extsi(x): rewrite this op in place to extend the
  // innermost value directly.
  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
    getInMutable().assign(lhs.getIn());
    return getResult();
  }

  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  // Constant operand: sign-extend the APInt to the result width.
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.sext(bitWidth);
      });
}
1397
/// extsi casts are compatible iff the integer element width strictly grows.
bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}
1401
/// Register declarative canonicalization patterns for `arith.extsi`.
void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  patterns.add<ExtSIOfExtUI>(context);
}
1406
/// Verify that extsi widens the integer element type.
LogicalResult arith::ExtSIOp::verify() {
  return verifyExtOp<IntegerType>(*this);
}
1410
1411//===----------------------------------------------------------------------===//
1412// ExtFOp
1413//===----------------------------------------------------------------------===//
1414
/// Fold extension of float constants when there is no information loss due the
/// difference in fp semantics.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
  // extf(truncf(x)) -> x, but only when the types round-trip exactly and
  // both ops carry the `contract` fast-math flag (the truncation may have
  // dropped precision, so the fold is only legal under contract semantics).
  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
    if (truncFOp.getOperand().getType() == getType()) {
      arith::FastMathFlags truncFMF =
          truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
      bool isTruncContract =
          bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
      arith::FastMathFlags extFMF =
          getFastmath().value_or(arith::FastMathFlags::none);
      bool isExtContract =
          bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
      if (isTruncContract && isExtContract) {
        return truncFOp.getOperand();
      }
    }
  }

  // Constant operand: convert to the result semantics; bail out (castStatus
  // = false) if the conversion would lose information.
  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  return constFoldCastOp<FloatAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&targetSemantics](const APFloat &a, bool &castStatus) {
        FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
        if (failed(result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}
1447
/// extf casts are compatible iff the float element width strictly grows.
bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
}
1451
/// Verify that extf widens the float element type.
LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1453
1454//===----------------------------------------------------------------------===//
1455// ScalingExtFOp
1456//===----------------------------------------------------------------------===//
1457
bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  // Only the first input (the value being extended) participates in the
  // width check; the scale operand's type is not considered here.
  return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
}
1462
/// Verify that scaling_extf widens the float element type.
LogicalResult arith::ScalingExtFOp::verify() {
  return verifyExtOp<FloatType>(*this);
}
1466
1467//===----------------------------------------------------------------------===//
1468// TruncIOp
1469//===----------------------------------------------------------------------===//
1470
/// Fold truncations of extensions/truncations and constant operands of
/// `arith.trunci`. Note that several cases rewrite this op in place via
/// setOperand() rather than returning a replacement value.
OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
    Value src = getOperand().getDefiningOp()->getOperand(0);
    Type srcType = getElementTypeOrSelf(src.getType());
    Type dstType = getElementTypeOrSelf(getType());
    // trunci(zexti(a)) -> trunci(a)
    // trunci(sexti(a)) -> trunci(a)
    // Only valid when the result is still narrower than the pre-extension
    // value, so the extension's bits are discarded anyway.
    if (llvm::cast<IntegerType>(srcType).getWidth() >
        llvm::cast<IntegerType>(dstType).getWidth()) {
      setOperand(src);
      return getResult();
    }

    // trunci(zexti(a)) -> a
    // trunci(sexti(a)) -> a
    // The truncation exactly undoes the extension.
    if (srcType == dstType)
      return src;
  }

  // trunci(trunci(a)) -> trunci(a))
  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
    setOperand(getOperand().getDefiningOp()->getOperand(0));
    return getResult();
  }

  // Constant operand: truncate the APInt to the result width.
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APInt &a, bool &castStatus) {
        return a.trunc(bitWidth);
      });
}
1505
/// trunci casts are compatible iff the integer element width strictly shrinks.
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
}
1509
/// Register declarative canonicalization patterns for `arith.trunci`.
void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
               TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
      context);
}
1516
/// Verify that trunci narrows the integer element type.
LogicalResult arith::TruncIOp::verify() {
  return verifyTruncateOp<IntegerType>(*this);
}
1520
1521//===----------------------------------------------------------------------===//
1522// TruncFOp
1523//===----------------------------------------------------------------------===//
1524
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
    Value src = extOp.getIn();
    auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
    auto intermediateType =
        cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
    // Check if the srcType is representable in the intermediateType.
    if (llvm::APFloatBase::isRepresentableBy(
            srcType.getFloatSemantics(),
            intermediateType.getFloatSemantics())) {
      // truncf(extf(a)) -> truncf(a)
      // Truncating through the intermediate type is lossless, so truncate
      // the original value directly (in-place rewrite of this op).
      if (srcType.getWidth() > resElemType.getWidth()) {
        setOperand(src);
        return getResult();
      }

      // truncf(extf(a)) -> a
      // The truncation exactly undoes the extension.
      if (srcType == resElemType)
        return src;
    }
  }

  // Constant operand: convert using the op's rounding mode; bail out
  // (castStatus = false) if the conversion would lose information.
  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  return constFoldCastOp<FloatAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [this, &targetSemantics](const APFloat &a, bool &castStatus) {
        RoundingMode roundingMode =
            getRoundingmode().value_or(RoundingMode::to_nearest_even);
        llvm::RoundingMode llvmRoundingMode =
            convertArithRoundingModeToLLVMIR(roundingMode);
        FailureOr<APFloat> result =
            convertFloatValue(a, targetSemantics, llvmRoundingMode);
        if (failed(result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}
1567
/// Register declarative canonicalization patterns for `arith.truncf`.
void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
}
1572
/// truncf casts are compatible iff the float element width strictly shrinks.
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
}
1576
/// Verify that truncf narrows the float element type.
LogicalResult arith::TruncFOp::verify() {
  return verifyTruncateOp<FloatType>(*this);
}
1580
1581//===----------------------------------------------------------------------===//
1582// ScalingTruncFOp
1583//===----------------------------------------------------------------------===//
1584
bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
                                               TypeRange outputs) {
  // Only the first input (the value being truncated) participates in the
  // width check; the scale operand's type is not considered here.
  return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
}
1589
/// Verify that scaling_truncf narrows the float element type.
LogicalResult arith::ScalingTruncFOp::verify() {
  return verifyTruncateOp<FloatType>(*this);
}
1593
1594//===----------------------------------------------------------------------===//
1595// AndIOp
1596//===----------------------------------------------------------------------===//
1597
/// Register declarative canonicalization patterns for `arith.andi`.
void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AndOfExtUI, AndOfExtSI>(context);
}
1602
1603//===----------------------------------------------------------------------===//
1604// OrIOp
1605//===----------------------------------------------------------------------===//
1606
/// Register declarative canonicalization patterns for `arith.ori`.
void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<OrOfExtUI, OrOfExtSI>(context);
}
1611
1612//===----------------------------------------------------------------------===//
1613// Verifiers for casts between integers and floats.
1614//===----------------------------------------------------------------------===//
1615
1616template <typename From, typename To>
1617static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1618 if (!areValidCastInputsAndOutputs(inputs, outputs))
1619 return false;
1620
1621 auto srcType = getTypeIfLike<From>(inputs.front());
1622 auto dstType = getTypeIfLike<To>(outputs.back());
1623
1624 return srcType && dstType;
1625}
1626
1627//===----------------------------------------------------------------------===//
1628// UIToFPOp
1629//===----------------------------------------------------------------------===//
1630
/// uitofp casts go from (shaped) integer to (shaped) float types.
bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
1634
/// Constant folding for `arith.uitofp`.
OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        // Build a zero of the target semantics, then convert the unsigned
        // integer into it, rounding to nearest-even. The conversion status
        // is ignored; castStatus is left untouched.
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, /*IsSigned=*/false,
                             APFloat::rmNearestTiesToEven);
        return apf;
      });
}
1648
1649//===----------------------------------------------------------------------===//
1650// SIToFPOp
1651//===----------------------------------------------------------------------===//
1652
/// sitofp casts go from (shaped) integer to (shaped) float types.
bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
1656
/// Constant folding for `arith.sitofp`.
OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(resEleType);
        // Build a zero of the target semantics, then convert the signed
        // integer into it, rounding to nearest-even. The conversion status
        // is ignored; castStatus is left untouched.
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(floatTy.getWidth()));
        apf.convertFromAPInt(a, /*IsSigned=*/true,
                             APFloat::rmNearestTiesToEven);
        return apf;
      });
}
1670
1671//===----------------------------------------------------------------------===//
1672// FPToUIOp
1673//===----------------------------------------------------------------------===//
1674
/// fptoui casts go from (shaped) float to (shaped) integer types.
bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
1678
/// Constant folding for `arith.fptoui`.
OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<FloatAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        bool ignored;
        APSInt api(bitWidth, /*isUnsigned=*/true);
        // Truncate toward zero; an invalid-op result (e.g. out-of-range or
        // NaN input — TODO confirm full set) aborts the fold.
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
        return api;
      });
}
1692
1693//===----------------------------------------------------------------------===//
1694// FPToSIOp
1695//===----------------------------------------------------------------------===//
1696
/// fptosi casts go from (shaped) float to (shaped) integer types.
bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
1700
/// Constant folding for `arith.fptosi`.
OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
  Type resType = getElementTypeOrSelf(getType());
  unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
  return constFoldCastOp<FloatAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [&bitWidth](const APFloat &a, bool &castStatus) {
        bool ignored;
        APSInt api(bitWidth, /*isUnsigned=*/false);
        // Truncate toward zero; an invalid-op result (e.g. out-of-range or
        // NaN input — TODO confirm full set) aborts the fold.
        castStatus = APFloat::opInvalidOp !=
                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
        return api;
      });
}
1714
1715//===----------------------------------------------------------------------===//
1716// IndexCastOp
1717//===----------------------------------------------------------------------===//
1718
1719static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1720 if (!areValidCastInputsAndOutputs(inputs, outputs))
1721 return false;
1722
1723 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(type: inputs.front());
1724 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(type: outputs.front());
1725 if (!srcType || !dstType)
1726 return false;
1727
1728 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1729 (srcType.isSignlessInteger() && dstType.isIndex());
1730}
1731
/// index_cast converts between `index` and signless integers.
bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
                                           TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}
1736
/// Constant folding for `arith.index_cast` (signed interpretation).
OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
  // index_cast(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  // Sign-extend or truncate the constant to the result width.
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
        return a.sextOrTrunc(resultBitwidth);
      });
}
1749
/// Register declarative canonicalization patterns for `arith.index_cast`.
void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
}
1754
1755//===----------------------------------------------------------------------===//
1756// IndexCastUIOp
1757//===----------------------------------------------------------------------===//
1758
/// index_castui converts between `index` and signless integers.
bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}
1763
/// Constant folding for `arith.index_castui` (unsigned interpretation).
OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
  // index_castui(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  // Zero-extend or truncate the constant to the result width.
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
        return a.zextOrTrunc(resultBitwidth);
      });
}
1776
void arith::IndexCastUIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Collapse chained index_castui ops and fold zero-extensions feeding an
  // index_castui into the cast itself.
  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
}
1781
1782//===----------------------------------------------------------------------===//
1783// BitcastOp
1784//===----------------------------------------------------------------------===//
1785
1786bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1787 if (!areValidCastInputsAndOutputs(inputs, outputs))
1788 return false;
1789
1790 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1791 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1792 if (!srcType || !dstType)
1793 return false;
1794
1795 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1796}
1797
1798OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1799 auto resType = getType();
1800 auto operand = adaptor.getIn();
1801 if (!operand)
1802 return {};
1803
1804 /// Bitcast dense elements.
1805 if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1806 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1807 /// Other shaped types unhandled.
1808 if (llvm::isa<ShapedType>(resType))
1809 return {};
1810
1811 /// Bitcast poison.
1812 if (llvm::isa<ub::PoisonAttr>(operand))
1813 return ub::PoisonAttr::get(getContext());
1814
1815 /// Bitcast integer or float to integer or float.
1816 APInt bits = llvm::isa<FloatAttr>(operand)
1817 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1818 : llvm::cast<IntegerAttr>(operand).getValue();
1819 assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1820 "trying to fold on broken IR: operands have incompatible types");
1821
1822 if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1823 return FloatAttr::get(resType,
1824 APFloat(resFloatType.getFloatSemantics(), bits));
1825 return IntegerAttr::get(resType, bits);
1826}
1827
void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  // bitcast(bitcast(x)) -> bitcast(x)
  patterns.add<BitcastOfBitcast>(context);
}
1832
1833//===----------------------------------------------------------------------===//
1834// CmpIOp
1835//===----------------------------------------------------------------------===//
1836
1837/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1838/// comparison predicates.
1839bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1840 const APInt &lhs, const APInt &rhs) {
1841 switch (predicate) {
1842 case arith::CmpIPredicate::eq:
1843 return lhs.eq(RHS: rhs);
1844 case arith::CmpIPredicate::ne:
1845 return lhs.ne(RHS: rhs);
1846 case arith::CmpIPredicate::slt:
1847 return lhs.slt(RHS: rhs);
1848 case arith::CmpIPredicate::sle:
1849 return lhs.sle(RHS: rhs);
1850 case arith::CmpIPredicate::sgt:
1851 return lhs.sgt(RHS: rhs);
1852 case arith::CmpIPredicate::sge:
1853 return lhs.sge(RHS: rhs);
1854 case arith::CmpIPredicate::ult:
1855 return lhs.ult(RHS: rhs);
1856 case arith::CmpIPredicate::ule:
1857 return lhs.ule(RHS: rhs);
1858 case arith::CmpIPredicate::ugt:
1859 return lhs.ugt(RHS: rhs);
1860 case arith::CmpIPredicate::uge:
1861 return lhs.uge(RHS: rhs);
1862 }
1863 llvm_unreachable("unknown cmpi predicate kind");
1864}
1865
1866/// Returns true if the predicate is true for two equal operands.
1867static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1868 switch (predicate) {
1869 case arith::CmpIPredicate::eq:
1870 case arith::CmpIPredicate::sle:
1871 case arith::CmpIPredicate::sge:
1872 case arith::CmpIPredicate::ule:
1873 case arith::CmpIPredicate::uge:
1874 return true;
1875 case arith::CmpIPredicate::ne:
1876 case arith::CmpIPredicate::slt:
1877 case arith::CmpIPredicate::sgt:
1878 case arith::CmpIPredicate::ult:
1879 case arith::CmpIPredicate::ugt:
1880 return false;
1881 }
1882 llvm_unreachable("unknown cmpi predicate kind");
1883}
1884
1885static std::optional<int64_t> getIntegerWidth(Type t) {
1886 if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1887 return intType.getWidth();
1888 }
1889 if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1890 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1891 }
1892 return std::nullopt;
1893}
1894
OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
  // cmpi(pred, x, x): identical SSA values fold to whether the predicate
  // holds on equal operands (eq/sle/... -> true, ne/slt/... -> false).
  if (getLhs() == getRhs()) {
    auto val = applyCmpPredicateToEqualOperands(getPredicate());
    return getBoolAttribute(getType(), val);
  }

  // Comparisons against a constant zero rhs.
  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
      // extsi(%x : i1 -> iN) != 0 -> %x
      std::optional<int64_t> integerWidth =
          getIntegerWidth(extOp.getOperand().getType());
      if (integerWidth && integerWidth.value() == 1 &&
          getPredicate() == arith::CmpIPredicate::ne)
        return extOp.getOperand();
    }
    if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
      // extui(%x : i1 -> iN) != 0 -> %x
      std::optional<int64_t> integerWidth =
          getIntegerWidth(extOp.getOperand().getType());
      if (integerWidth && integerWidth.value() == 1 &&
          getPredicate() == arith::CmpIPredicate::ne)
        return extOp.getOperand();
    }

    // arith.cmpi ne, %val, %zero : i1 -> %val
    if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
        getPredicate() == arith::CmpIPredicate::ne)
      return getLhs();
  }

  // Comparisons against a constant one rhs.
  if (matchPattern(adaptor.getRhs(), m_One())) {
    // arith.cmpi eq, %val, %one : i1 -> %val
    if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
        getPredicate() == arith::CmpIPredicate::eq)
      return getLhs();
  }

  // Move constant to the right side. This mutates the op in place (swaps the
  // operands and flips the predicate) and returns the op's own result to
  // signal that it was changed.
  if (adaptor.getLhs() && !adaptor.getRhs()) {
    // Do not use invertPredicate, as it will change eq to ne and vice versa.
    using Pred = CmpIPredicate;
    const std::pair<Pred, Pred> invPreds[] = {
        {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
        {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
        {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
        {Pred::ne, Pred::ne},
    };
    Pred origPred = getPredicate();
    for (auto pred : invPreds) {
      if (origPred == pred.first) {
        setPredicate(pred.second);
        Value lhs = getLhs();
        Value rhs = getRhs();
        getLhsMutable().assign(rhs);
        getRhsMutable().assign(lhs);
        return getResult();
      }
    }
    llvm_unreachable("unknown cmpi predicate kind");
  }

  // We are moving constants to the right side; So if lhs is constant rhs is
  // guaranteed to be a constant.
  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
    return constFoldBinaryOp<IntegerAttr>(
        adaptor.getOperands(), getI1SameShape(lhs.getType()),
        [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
          return APInt(1,
                       static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
        });
  }

  return {};
}
1970
void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Compare the narrow operands of sign-/zero-extensions directly.
  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
}
1975
1976//===----------------------------------------------------------------------===//
1977// CmpFOp
1978//===----------------------------------------------------------------------===//
1979
1980/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1981/// comparison predicates.
1982bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1983 const APFloat &lhs, const APFloat &rhs) {
1984 auto cmpResult = lhs.compare(RHS: rhs);
1985 switch (predicate) {
1986 case arith::CmpFPredicate::AlwaysFalse:
1987 return false;
1988 case arith::CmpFPredicate::OEQ:
1989 return cmpResult == APFloat::cmpEqual;
1990 case arith::CmpFPredicate::OGT:
1991 return cmpResult == APFloat::cmpGreaterThan;
1992 case arith::CmpFPredicate::OGE:
1993 return cmpResult == APFloat::cmpGreaterThan ||
1994 cmpResult == APFloat::cmpEqual;
1995 case arith::CmpFPredicate::OLT:
1996 return cmpResult == APFloat::cmpLessThan;
1997 case arith::CmpFPredicate::OLE:
1998 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1999 case arith::CmpFPredicate::ONE:
2000 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2001 case arith::CmpFPredicate::ORD:
2002 return cmpResult != APFloat::cmpUnordered;
2003 case arith::CmpFPredicate::UEQ:
2004 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2005 case arith::CmpFPredicate::UGT:
2006 return cmpResult == APFloat::cmpUnordered ||
2007 cmpResult == APFloat::cmpGreaterThan;
2008 case arith::CmpFPredicate::UGE:
2009 return cmpResult == APFloat::cmpUnordered ||
2010 cmpResult == APFloat::cmpGreaterThan ||
2011 cmpResult == APFloat::cmpEqual;
2012 case arith::CmpFPredicate::ULT:
2013 return cmpResult == APFloat::cmpUnordered ||
2014 cmpResult == APFloat::cmpLessThan;
2015 case arith::CmpFPredicate::ULE:
2016 return cmpResult == APFloat::cmpUnordered ||
2017 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2018 case arith::CmpFPredicate::UNE:
2019 return cmpResult != APFloat::cmpEqual;
2020 case arith::CmpFPredicate::UNO:
2021 return cmpResult == APFloat::cmpUnordered;
2022 case arith::CmpFPredicate::AlwaysTrue:
2023 return true;
2024 }
2025 llvm_unreachable("unknown cmpf predicate kind");
2026}
2027
2028OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2029 auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2030 auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2031
2032 // If one operand is NaN, making them both NaN does not change the result.
2033 if (lhs && lhs.getValue().isNaN())
2034 rhs = lhs;
2035 if (rhs && rhs.getValue().isNaN())
2036 lhs = rhs;
2037
2038 if (!lhs || !rhs)
2039 return {};
2040
2041 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2042 return BoolAttr::get(getContext(), val);
2043}
2044
2045class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2046public:
2047 using OpRewritePattern<CmpFOp>::OpRewritePattern;
2048
2049 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2050 bool isUnsigned) {
2051 using namespace arith;
2052 switch (pred) {
2053 case CmpFPredicate::UEQ:
2054 case CmpFPredicate::OEQ:
2055 return CmpIPredicate::eq;
2056 case CmpFPredicate::UGT:
2057 case CmpFPredicate::OGT:
2058 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2059 case CmpFPredicate::UGE:
2060 case CmpFPredicate::OGE:
2061 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2062 case CmpFPredicate::ULT:
2063 case CmpFPredicate::OLT:
2064 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2065 case CmpFPredicate::ULE:
2066 case CmpFPredicate::OLE:
2067 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2068 case CmpFPredicate::UNE:
2069 case CmpFPredicate::ONE:
2070 return CmpIPredicate::ne;
2071 default:
2072 llvm_unreachable("Unexpected predicate!");
2073 }
2074 }
2075
2076 LogicalResult matchAndRewrite(CmpFOp op,
2077 PatternRewriter &rewriter) const override {
2078 FloatAttr flt;
2079 if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2080 return failure();
2081
2082 const APFloat &rhs = flt.getValue();
2083
2084 // Don't attempt to fold a nan.
2085 if (rhs.isNaN())
2086 return failure();
2087
2088 // Get the width of the mantissa. We don't want to hack on conversions that
2089 // might lose information from the integer, e.g. "i64 -> float"
2090 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2091 int mantissaWidth = floatTy.getFPMantissaWidth();
2092 if (mantissaWidth <= 0)
2093 return failure();
2094
2095 bool isUnsigned;
2096 Value intVal;
2097
2098 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2099 isUnsigned = false;
2100 intVal = si.getIn();
2101 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2102 isUnsigned = true;
2103 intVal = ui.getIn();
2104 } else {
2105 return failure();
2106 }
2107
2108 // Check to see that the input is converted from an integer type that is
2109 // small enough that preserves all bits.
2110 auto intTy = llvm::cast<IntegerType>(intVal.getType());
2111 auto intWidth = intTy.getWidth();
2112
2113 // Number of bits representing values, as opposed to the sign
2114 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2115
2116 // Following test does NOT adjust intWidth downwards for signed inputs,
2117 // because the most negative value still requires all the mantissa bits
2118 // to distinguish it from one less than that value.
2119 if ((int)intWidth > mantissaWidth) {
2120 // Conversion would lose accuracy. Check if loss can impact comparison.
2121 int exponent = ilogb(Arg: rhs);
2122 if (exponent == APFloat::IEK_Inf) {
2123 int maxExponent = ilogb(Arg: APFloat::getLargest(Sem: rhs.getSemantics()));
2124 if (maxExponent < (int)valueBits) {
2125 // Conversion could create infinity.
2126 return failure();
2127 }
2128 } else {
2129 // Note that if rhs is zero or NaN, then Exp is negative
2130 // and first condition is trivially false.
2131 if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2132 // Conversion could affect comparison.
2133 return failure();
2134 }
2135 }
2136 }
2137
2138 // Convert to equivalent cmpi predicate
2139 CmpIPredicate pred;
2140 switch (op.getPredicate()) {
2141 case CmpFPredicate::ORD:
2142 // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2143 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2144 /*width=*/1);
2145 return success();
2146 case CmpFPredicate::UNO:
2147 // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2148 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2149 /*width=*/1);
2150 return success();
2151 default:
2152 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2153 break;
2154 }
2155
2156 if (!isUnsigned) {
2157 // If the rhs value is > SignedMax, fold the comparison. This handles
2158 // +INF and large values.
2159 APFloat signedMax(rhs.getSemantics());
2160 signedMax.convertFromAPInt(Input: APInt::getSignedMaxValue(numBits: intWidth), IsSigned: true,
2161 RM: APFloat::rmNearestTiesToEven);
2162 if (signedMax < rhs) { // smax < 13123.0
2163 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2164 pred == CmpIPredicate::sle)
2165 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2166 /*width=*/1);
2167 else
2168 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2169 /*width=*/1);
2170 return success();
2171 }
2172 } else {
2173 // If the rhs value is > UnsignedMax, fold the comparison. This handles
2174 // +INF and large values.
2175 APFloat unsignedMax(rhs.getSemantics());
2176 unsignedMax.convertFromAPInt(Input: APInt::getMaxValue(numBits: intWidth), IsSigned: false,
2177 RM: APFloat::rmNearestTiesToEven);
2178 if (unsignedMax < rhs) { // umax < 13123.0
2179 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2180 pred == CmpIPredicate::ule)
2181 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2182 /*width=*/1);
2183 else
2184 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2185 /*width=*/1);
2186 return success();
2187 }
2188 }
2189
2190 if (!isUnsigned) {
2191 // See if the rhs value is < SignedMin.
2192 APFloat signedMin(rhs.getSemantics());
2193 signedMin.convertFromAPInt(Input: APInt::getSignedMinValue(numBits: intWidth), IsSigned: true,
2194 RM: APFloat::rmNearestTiesToEven);
2195 if (signedMin > rhs) { // smin > 12312.0
2196 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2197 pred == CmpIPredicate::sge)
2198 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2199 /*width=*/1);
2200 else
2201 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2202 /*width=*/1);
2203 return success();
2204 }
2205 } else {
2206 // See if the rhs value is < UnsignedMin.
2207 APFloat unsignedMin(rhs.getSemantics());
2208 unsignedMin.convertFromAPInt(Input: APInt::getMinValue(numBits: intWidth), IsSigned: false,
2209 RM: APFloat::rmNearestTiesToEven);
2210 if (unsignedMin > rhs) { // umin > 12312.0
2211 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2212 pred == CmpIPredicate::uge)
2213 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2214 /*width=*/1);
2215 else
2216 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2217 /*width=*/1);
2218 return success();
2219 }
2220 }
2221
2222 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2223 // [0, UMAX], but it may still be fractional. See if it is fractional by
2224 // casting the FP value to the integer value and back, checking for
2225 // equality. Don't do this for zero, because -0.0 is not fractional.
2226 bool ignored;
2227 APSInt rhsInt(intWidth, isUnsigned);
2228 if (APFloat::opInvalidOp ==
2229 rhs.convertToInteger(Result&: rhsInt, RM: APFloat::rmTowardZero, IsExact: &ignored)) {
2230 // Undefined behavior invoked - the destination type can't represent
2231 // the input constant.
2232 return failure();
2233 }
2234
2235 if (!rhs.isZero()) {
2236 APFloat apf(floatTy.getFloatSemantics(),
2237 APInt::getZero(numBits: floatTy.getWidth()));
2238 apf.convertFromAPInt(Input: rhsInt, IsSigned: !isUnsigned, RM: APFloat::rmNearestTiesToEven);
2239
2240 bool equal = apf == rhs;
2241 if (!equal) {
2242 // If we had a comparison against a fractional value, we have to adjust
2243 // the compare predicate and sometimes the value. rhsInt is rounded
2244 // towards zero at this point.
2245 switch (pred) {
2246 case CmpIPredicate::ne: // (float)int != 4.4 --> true
2247 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2248 /*width=*/1);
2249 return success();
2250 case CmpIPredicate::eq: // (float)int == 4.4 --> false
2251 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2252 /*width=*/1);
2253 return success();
2254 case CmpIPredicate::ule:
2255 // (float)int <= 4.4 --> int <= 4
2256 // (float)int <= -4.4 --> false
2257 if (rhs.isNegative()) {
2258 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2259 /*width=*/1);
2260 return success();
2261 }
2262 break;
2263 case CmpIPredicate::sle:
2264 // (float)int <= 4.4 --> int <= 4
2265 // (float)int <= -4.4 --> int < -4
2266 if (rhs.isNegative())
2267 pred = CmpIPredicate::slt;
2268 break;
2269 case CmpIPredicate::ult:
2270 // (float)int < -4.4 --> false
2271 // (float)int < 4.4 --> int <= 4
2272 if (rhs.isNegative()) {
2273 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2274 /*width=*/1);
2275 return success();
2276 }
2277 pred = CmpIPredicate::ule;
2278 break;
2279 case CmpIPredicate::slt:
2280 // (float)int < -4.4 --> int < -4
2281 // (float)int < 4.4 --> int <= 4
2282 if (!rhs.isNegative())
2283 pred = CmpIPredicate::sle;
2284 break;
2285 case CmpIPredicate::ugt:
2286 // (float)int > 4.4 --> int > 4
2287 // (float)int > -4.4 --> true
2288 if (rhs.isNegative()) {
2289 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2290 /*width=*/1);
2291 return success();
2292 }
2293 break;
2294 case CmpIPredicate::sgt:
2295 // (float)int > 4.4 --> int > 4
2296 // (float)int > -4.4 --> int >= -4
2297 if (rhs.isNegative())
2298 pred = CmpIPredicate::sge;
2299 break;
2300 case CmpIPredicate::uge:
2301 // (float)int >= -4.4 --> true
2302 // (float)int >= 4.4 --> int > 4
2303 if (rhs.isNegative()) {
2304 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2305 /*width=*/1);
2306 return success();
2307 }
2308 pred = CmpIPredicate::ugt;
2309 break;
2310 case CmpIPredicate::sge:
2311 // (float)int >= -4.4 --> int >= -4
2312 // (float)int >= 4.4 --> int > 4
2313 if (!rhs.isNegative())
2314 pred = CmpIPredicate::sgt;
2315 break;
2316 }
2317 }
2318 }
2319
2320 // Lower this FP comparison into an appropriate integer version of the
2321 // comparison.
2322 rewriter.replaceOpWithNewOp<CmpIOp>(
2323 op, pred, intVal,
2324 rewriter.create<ConstantOp>(
2325 op.getLoc(), intVal.getType(),
2326 rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2327 return success();
2328 }
2329};
2330
2331void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2332 MLIRContext *context) {
2333 patterns.insert<CmpFIntToFPConst>(context);
2334}
2335
2336//===----------------------------------------------------------------------===//
2337// SelectOp
2338//===----------------------------------------------------------------------===//
2339
2340// select %arg, %c1, %c0 => extui %arg
2341struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2342 using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
2343
2344 LogicalResult matchAndRewrite(arith::SelectOp op,
2345 PatternRewriter &rewriter) const override {
2346 // Cannot extui i1 to i1, or i1 to f32
2347 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2348 return failure();
2349
2350 // select %x, c1, %c0 => extui %arg
2351 if (matchPattern(op.getTrueValue(), m_One()) &&
2352 matchPattern(op.getFalseValue(), m_Zero())) {
2353 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2354 op.getCondition());
2355 return success();
2356 }
2357
2358 // select %x, c0, %c1 => extui (xor %arg, true)
2359 if (matchPattern(op.getTrueValue(), m_Zero()) &&
2360 matchPattern(op.getFalseValue(), m_One())) {
2361 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2362 op, op.getType(),
2363 rewriter.create<arith::XOrIOp>(
2364 op.getLoc(), op.getCondition(),
2365 rewriter.create<arith::ConstantIntOp>(
2366 op.getLoc(), 1, op.getCondition().getType())));
2367 return success();
2368 }
2369
2370 return failure();
2371 }
2372};
2373
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  // Simplify selects with redundant arms, negated conditions, i1 logic, and
  // constant 0/1 arms (rewritten to extui).
  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
              SelectI1ToNot, SelectToExtUI>(context);
}
2379
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
  Value trueVal = getTrueValue();
  Value falseVal = getFalseValue();
  // Both arms identical: the condition is irrelevant.
  if (trueVal == falseVal)
    return trueVal;

  Value condition = getCondition();

  // select true, %0, %1 => %0
  if (matchPattern(adaptor.getCondition(), m_One()))
    return trueVal;

  // select false, %0, %1 => %1
  if (matchPattern(adaptor.getCondition(), m_Zero()))
    return falseVal;

  // If either operand is fully poisoned, return the other.
  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
    return falseVal;

  if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
    return trueVal;

  // select %x, true, false => %x
  if (getType().isSignlessInteger(1) &&
      matchPattern(adaptor.getTrueValue(), m_One()) &&
      matchPattern(adaptor.getFalseValue(), m_Zero()))
    return condition;

  // Selecting between the operands of an eq/ne comparison: either outcome of
  // the compare yields the same chosen value, so the select folds away.
  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
    auto pred = cmp.getPredicate();
    if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
      auto cmpLhs = cmp.getLhs();
      auto cmpRhs = cmp.getRhs();

      // %0 = arith.cmpi eq, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg1

      // %0 = arith.cmpi ne, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg0

      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
          (cmpRhs == trueVal && cmpLhs == falseVal))
        return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
    }
  }

  // Constant-fold constant operands over non-splat constant condition.
  // select %cst_vec, %cst0, %cst1 => %cst2
  if (auto cond =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
    if (auto lhs =
            llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
      if (auto rhs =
              llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
        SmallVector<Attribute> results;
        results.reserve(static_cast<size_t>(cond.getNumElements()));
        auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
                                         cond.value_end<BoolAttr>());
        auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
                                        lhs.value_end<Attribute>());
        auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
                                        rhs.value_end<Attribute>());

        // Pick elementwise from the true/false arms based on the mask.
        for (auto [condVal, lhsVal, rhsVal] :
             llvm::zip_equal(condVals, lhsVals, rhsVals))
          results.push_back(condVal.getValue() ? lhsVal : rhsVal);

        return DenseElementsAttr::get(lhs.getType(), results);
      }
    }
  }

  return nullptr;
}
2455
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  Type conditionType, resultType;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  // Syntax: select %cond, %true, %false attr-dict : [cond-type ,] result-type
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType))
    return failure();

  // Check for the explicit condition type if this is a masked tensor or vector.
  if (succeeded(parser.parseOptionalComma())) {
    // The first parsed type was actually the condition (mask) type; the real
    // result type follows the comma.
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  } else {
    // Scalar form: the condition is implicitly i1.
    conditionType = parser.getBuilder().getI1Type();
  }

  result.addTypes(resultType);
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}
2478
void arith::SelectOp::print(OpAsmPrinter &p) {
  p << " " << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : ";
  // Only shaped (vector/tensor mask) conditions print an explicit type;
  // scalar i1 conditions are implicit. Mirrors SelectOp::parse.
  if (ShapedType condType =
          llvm::dyn_cast<ShapedType>(getCondition().getType()))
    p << condType << ", ";
  p << getType();
}
2488
LogicalResult arith::SelectOp::verify() {
  Type conditionType = getCondition().getType();
  // A scalar i1 condition is always valid, regardless of result type.
  if (conditionType.isSignlessInteger(1))
    return success();

  // If the result type is a vector or tensor, the type can be a mask with the
  // same elements.
  Type resultType = getType();
  if (!llvm::isa<TensorType, VectorType>(resultType))
    return emitOpError() << "expected condition to be a signless i1, but got "
                         << conditionType;
  // The mask must be the i1-element type with the result's exact shape.
  Type shapedConditionType = getI1SameShape(resultType);
  if (conditionType != shapedConditionType) {
    return emitOpError() << "expected condition type to have the same shape "
                            "as the result type, expected "
                         << shapedConditionType << ", but got "
                         << conditionType;
  }
  return success();
}
2509//===----------------------------------------------------------------------===//
2510// ShLIOp
2511//===----------------------------------------------------------------------===//
2512
2513OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2514 // shli(x, 0) -> x
2515 if (matchPattern(adaptor.getRhs(), m_Zero()))
2516 return getLhs();
2517 // Don't fold if shifting more or equal than the bit width.
2518 bool bounded = false;
2519 auto result = constFoldBinaryOp<IntegerAttr>(
2520 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2521 bounded = b.ult(b.getBitWidth());
2522 return a.shl(b);
2523 });
2524 return bounded ? result : Attribute();
2525}
2526
2527//===----------------------------------------------------------------------===//
2528// ShRUIOp
2529//===----------------------------------------------------------------------===//
2530
2531OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2532 // shrui(x, 0) -> x
2533 if (matchPattern(adaptor.getRhs(), m_Zero()))
2534 return getLhs();
2535 // Don't fold if shifting more or equal than the bit width.
2536 bool bounded = false;
2537 auto result = constFoldBinaryOp<IntegerAttr>(
2538 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2539 bounded = b.ult(b.getBitWidth());
2540 return a.lshr(b);
2541 });
2542 return bounded ? result : Attribute();
2543}
2544
2545//===----------------------------------------------------------------------===//
2546// ShRSIOp
2547//===----------------------------------------------------------------------===//
2548
2549OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2550 // shrsi(x, 0) -> x
2551 if (matchPattern(adaptor.getRhs(), m_Zero()))
2552 return getLhs();
2553 // Don't fold if shifting more or equal than the bit width.
2554 bool bounded = false;
2555 auto result = constFoldBinaryOp<IntegerAttr>(
2556 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2557 bounded = b.ult(b.getBitWidth());
2558 return a.ashr(b);
2559 });
2560 return bounded ? result : Attribute();
2561}
2562
2563//===----------------------------------------------------------------------===//
2564// Atomic Enum
2565//===----------------------------------------------------------------------===//
2566
2567/// Returns the identity value attribute associated with an AtomicRMWKind op.
2568TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2569 OpBuilder &builder, Location loc,
2570 bool useOnlyFiniteValue) {
2571 switch (kind) {
2572 case AtomicRMWKind::maximumf: {
2573 const llvm::fltSemantics &semantic =
2574 llvm::cast<FloatType>(resultType).getFloatSemantics();
2575 APFloat identity = useOnlyFiniteValue
2576 ? APFloat::getLargest(Sem: semantic, /*Negative=*/true)
2577 : APFloat::getInf(Sem: semantic, /*Negative=*/true);
2578 return builder.getFloatAttr(resultType, identity);
2579 }
2580 case AtomicRMWKind::maxnumf: {
2581 const llvm::fltSemantics &semantic =
2582 llvm::cast<FloatType>(resultType).getFloatSemantics();
2583 APFloat identity = APFloat::getNaN(Sem: semantic, /*Negative=*/true);
2584 return builder.getFloatAttr(resultType, identity);
2585 }
2586 case AtomicRMWKind::addf:
2587 case AtomicRMWKind::addi:
2588 case AtomicRMWKind::maxu:
2589 case AtomicRMWKind::ori:
2590 return builder.getZeroAttr(resultType);
2591 case AtomicRMWKind::andi:
2592 return builder.getIntegerAttr(
2593 resultType,
2594 APInt::getAllOnes(numBits: llvm::cast<IntegerType>(resultType).getWidth()));
2595 case AtomicRMWKind::maxs:
2596 return builder.getIntegerAttr(
2597 resultType, APInt::getSignedMinValue(
2598 numBits: llvm::cast<IntegerType>(resultType).getWidth()));
2599 case AtomicRMWKind::minimumf: {
2600 const llvm::fltSemantics &semantic =
2601 llvm::cast<FloatType>(resultType).getFloatSemantics();
2602 APFloat identity = useOnlyFiniteValue
2603 ? APFloat::getLargest(Sem: semantic, /*Negative=*/false)
2604 : APFloat::getInf(Sem: semantic, /*Negative=*/false);
2605
2606 return builder.getFloatAttr(resultType, identity);
2607 }
2608 case AtomicRMWKind::minnumf: {
2609 const llvm::fltSemantics &semantic =
2610 llvm::cast<FloatType>(resultType).getFloatSemantics();
2611 APFloat identity = APFloat::getNaN(Sem: semantic, /*Negative=*/false);
2612 return builder.getFloatAttr(resultType, identity);
2613 }
2614 case AtomicRMWKind::mins:
2615 return builder.getIntegerAttr(
2616 resultType, APInt::getSignedMaxValue(
2617 numBits: llvm::cast<IntegerType>(resultType).getWidth()));
2618 case AtomicRMWKind::minu:
2619 return builder.getIntegerAttr(
2620 resultType,
2621 APInt::getMaxValue(numBits: llvm::cast<IntegerType>(resultType).getWidth()));
2622 case AtomicRMWKind::muli:
2623 return builder.getIntegerAttr(resultType, 1);
2624 case AtomicRMWKind::mulf:
2625 return builder.getFloatAttr(resultType, 1);
2626 // TODO: Add remaining reduction operations.
2627 default:
2628 (void)emitOptionalError(loc, args: "Reduction operation type not supported");
2629 break;
2630 }
2631 return nullptr;
2632}
2633
2634/// Return the identity numeric value associated to the give op.
2635std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2636 std::optional<AtomicRMWKind> maybeKind =
2637 llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
2638 // Floating-point operations.
2639 .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2640 .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2641 .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2642 .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2643 .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2644 .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2645 // Integer operations.
2646 .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2647 .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2648 .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2649 .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2650 .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2651 .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2652 .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2653 .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2654 .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2655 .Default([](Operation *op) { return std::nullopt; });
2656 if (!maybeKind) {
2657 return std::nullopt;
2658 }
2659
2660 bool useOnlyFiniteValue = false;
2661 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2662 if (fmfOpInterface) {
2663 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2664 useOnlyFiniteValue =
2665 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2666 }
2667
2668 // Builder only used as helper for attribute creation.
2669 OpBuilder b(op->getContext());
2670 Type resultType = op->getResult(idx: 0).getType();
2671
2672 return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2673 useOnlyFiniteValue);
2674}
2675
2676/// Returns the identity value associated with an AtomicRMWKind op.
2677Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2678 OpBuilder &builder, Location loc,
2679 bool useOnlyFiniteValue) {
2680 auto attr =
2681 getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2682 return builder.create<arith::ConstantOp>(loc, attr);
2683}
2684
2685/// Return the value obtained by applying the reduction operation kind
2686/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2687Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2688 Location loc, Value lhs, Value rhs) {
2689 switch (op) {
2690 case AtomicRMWKind::addf:
2691 return builder.create<arith::AddFOp>(loc, lhs, rhs);
2692 case AtomicRMWKind::addi:
2693 return builder.create<arith::AddIOp>(loc, lhs, rhs);
2694 case AtomicRMWKind::mulf:
2695 return builder.create<arith::MulFOp>(loc, lhs, rhs);
2696 case AtomicRMWKind::muli:
2697 return builder.create<arith::MulIOp>(loc, lhs, rhs);
2698 case AtomicRMWKind::maximumf:
2699 return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2700 case AtomicRMWKind::minimumf:
2701 return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2702 case AtomicRMWKind::maxnumf:
2703 return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2704 case AtomicRMWKind::minnumf:
2705 return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2706 case AtomicRMWKind::maxs:
2707 return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2708 case AtomicRMWKind::mins:
2709 return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2710 case AtomicRMWKind::maxu:
2711 return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2712 case AtomicRMWKind::minu:
2713 return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2714 case AtomicRMWKind::ori:
2715 return builder.create<arith::OrIOp>(loc, lhs, rhs);
2716 case AtomicRMWKind::andi:
2717 return builder.create<arith::AndIOp>(loc, lhs, rhs);
2718 // TODO: Add remaining reduction operations.
2719 default:
2720 (void)emitOptionalError(loc, args: "Reduction operation type not supported");
2721 break;
2722 }
2723 return nullptr;
2724}
2725
2726//===----------------------------------------------------------------------===//
2727// TableGen'd op method definitions
2728//===----------------------------------------------------------------------===//
2729
2730#define GET_OP_CLASSES
2731#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2732
2733//===----------------------------------------------------------------------===//
2734// TableGen'd enum attribute definitions
2735//===----------------------------------------------------------------------===//
2736
2737#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
2738

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Arith/IR/ArithOps.cpp