1//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cassert>
10#include <cstdint>
11#include <functional>
12#include <utility>
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/CommonFolders.h"
16#include "mlir/Dialect/UB/IR/UBOps.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/BuiltinAttributeInterfaces.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/OpImplementation.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/IR/TypeUtilities.h"
24#include "mlir/Support/LogicalResult.h"
25
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallString.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/TypeSwitch.h"
34
35using namespace mlir;
36using namespace mlir::arith;
37
38//===----------------------------------------------------------------------===//
39// Pattern helpers
40//===----------------------------------------------------------------------===//
41
42static IntegerAttr
43applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
44 Attribute rhs,
45 function_ref<APInt(const APInt &, const APInt &)> binFn) {
46 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
47 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
48 APInt value = binFn(lhsVal, rhsVal);
49 return IntegerAttr::get(res.getType(), value);
50}
51
52static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53 Attribute lhs, Attribute rhs) {
54 return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
55}
56
57static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58 Attribute lhs, Attribute rhs) {
59 return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
60}
61
62static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63 Attribute lhs, Attribute rhs) {
64 return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
65}
66
67/// Invert an integer comparison predicate.
68arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
69 switch (pred) {
70 case arith::CmpIPredicate::eq:
71 return arith::CmpIPredicate::ne;
72 case arith::CmpIPredicate::ne:
73 return arith::CmpIPredicate::eq;
74 case arith::CmpIPredicate::slt:
75 return arith::CmpIPredicate::sge;
76 case arith::CmpIPredicate::sle:
77 return arith::CmpIPredicate::sgt;
78 case arith::CmpIPredicate::sgt:
79 return arith::CmpIPredicate::sle;
80 case arith::CmpIPredicate::sge:
81 return arith::CmpIPredicate::slt;
82 case arith::CmpIPredicate::ult:
83 return arith::CmpIPredicate::uge;
84 case arith::CmpIPredicate::ule:
85 return arith::CmpIPredicate::ugt;
86 case arith::CmpIPredicate::ugt:
87 return arith::CmpIPredicate::ule;
88 case arith::CmpIPredicate::uge:
89 return arith::CmpIPredicate::ult;
90 }
91 llvm_unreachable("unknown cmpi predicate kind");
92}
93
94/// Equivalent to
95/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
96///
97/// Not possible to implement as chain of calls as this would introduce a
98/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
99/// on the LLVM dialect and on translation to LLVM.
100static llvm::RoundingMode
101convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
102 switch (roundingMode) {
103 case RoundingMode::downward:
104 return llvm::RoundingMode::TowardNegative;
105 case RoundingMode::to_nearest_away:
106 return llvm::RoundingMode::NearestTiesToAway;
107 case RoundingMode::to_nearest_even:
108 return llvm::RoundingMode::NearestTiesToEven;
109 case RoundingMode::toward_zero:
110 return llvm::RoundingMode::TowardZero;
111 case RoundingMode::upward:
112 return llvm::RoundingMode::TowardPositive;
113 }
114 llvm_unreachable("Unhandled rounding mode");
115}
116
117static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
118 return arith::CmpIPredicateAttr::get(pred.getContext(),
119 invertPredicate(pred.getValue()));
120}
121
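/// Return the bit width of the scalar or element type of `type`, or -1 if that
/// type is not an integer or float.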
122static int64_t getScalarOrElementWidth(Type type) {
123 Type elemTy = getElementTypeOrSelf(type);
124 if (elemTy.isIntOrFloat())
125 return elemTy.getIntOrFloatBitWidth();
126
127 return -1;
128}
129
130static int64_t getScalarOrElementWidth(Value value) {
  return getScalarOrElementWidth(value.getType());
132}
133
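/// Return the scalar integer value held by `attr`, accepting both integer
/// attributes and integer splats; fail otherwise.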
134static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
135 APInt value;
136 if (matchPattern(attr, m_ConstantInt(&value)))
137 return value;
138
139 return failure();
140}
141
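/// Build a boolean constant attribute matching the shape of `type`: a BoolAttr
/// for scalars, or a splat DenseElementsAttr when `type` is shaped.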
142static Attribute getBoolAttribute(Type type, bool value) {
  auto boolAttr = BoolAttr::get(type.getContext(), value);
144 ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
145 if (!shapedType)
146 return boolAttr;
147 return DenseElementsAttr::get(shapedType, boolAttr);
148}
149
150//===----------------------------------------------------------------------===//
151// TableGen'd canonicalization patterns
152//===----------------------------------------------------------------------===//
153
154namespace {
155#include "ArithCanonicalization.inc"
156} // namespace
157
158//===----------------------------------------------------------------------===//
159// Common helpers
160//===----------------------------------------------------------------------===//
161
162/// Return the type of the same shape (scalar, vector or tensor) containing i1.
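/// For example, this maps vector<4xf32> to vector<4xi1> and i64 to i1.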
163static Type getI1SameShape(Type type) {
164 auto i1Type = IntegerType::get(type.getContext(), 1);
165 if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
166 return shapedType.cloneWith(std::nullopt, i1Type);
167 if (llvm::isa<UnrankedTensorType>(type))
168 return UnrankedTensorType::get(i1Type);
169 return i1Type;
170}
171
172//===----------------------------------------------------------------------===//
173// ConstantOp
174//===----------------------------------------------------------------------===//
175
176void arith::ConstantOp::getAsmResultNames(
177 function_ref<void(Value, StringRef)> setNameFn) {
178 auto type = getType();
179 if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
180 auto intType = llvm::dyn_cast<IntegerType>(type);
181
182 // Sugar i1 constants with 'true' and 'false'.
183 if (intType && intType.getWidth() == 1)
184 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
185
186 // Otherwise, build a complex name with the value and type.
187 SmallString<32> specialNameBuffer;
188 llvm::raw_svector_ostream specialName(specialNameBuffer);
189 specialName << 'c' << intCst.getValue();
190 if (intType)
191 specialName << '_' << type;
192 setNameFn(getResult(), specialName.str());
193 } else {
194 setNameFn(getResult(), "cst");
195 }
196}
197
/// TODO: disallow arith.constant to return anything other than signless
/// integer or float-like types.
200LogicalResult arith::ConstantOp::verify() {
201 auto type = getType();
202 // The value's type must match the return type.
203 if (getValue().getType() != type) {
204 return emitOpError() << "value type " << getValue().getType()
205 << " must match return type: " << type;
206 }
207 // Integer values must be signless.
208 if (llvm::isa<IntegerType>(type) &&
209 !llvm::cast<IntegerType>(type).isSignless())
210 return emitOpError("integer return type must be signless");
211 // Any float or elements attribute are acceptable.
212 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
213 return emitOpError(
214 "value must be an integer, float, or elements attribute");
215 }
216
217 // Note, we could relax this for vectors with 1 scalable dim, e.g.:
218 // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
219 // However, this would most likely require updating the lowerings to LLVM.
220 auto vecType = dyn_cast<VectorType>(type);
221 if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
    return emitOpError(
        "initializing scalable vectors with elements attribute is not supported"
        " unless it's a vector splat");
225 return success();
226}
227
228bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
229 // The value's type must be the same as the provided type.
230 auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
231 if (!typedAttr || typedAttr.getType() != type)
232 return false;
233 // Integer values must be signless.
234 if (llvm::isa<IntegerType>(type) &&
235 !llvm::cast<IntegerType>(type).isSignless())
236 return false;
237 // Integer, float, and element attributes are buildable.
238 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
239}
240
241ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
242 Type type, Location loc) {
243 if (isBuildableWith(value, type))
244 return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
245 return nullptr;
246}
247
248OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
249
250void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
251 int64_t value, unsigned width) {
252 auto type = builder.getIntegerType(width);
253 arith::ConstantOp::build(builder, result, type,
254 builder.getIntegerAttr(type, value));
255}
256
257void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
258 int64_t value, Type type) {
259 assert(type.isSignlessInteger() &&
260 "ConstantIntOp can only have signless integer type values");
261 arith::ConstantOp::build(builder, result, type,
262 builder.getIntegerAttr(type, value));
263}
264
265bool arith::ConstantIntOp::classof(Operation *op) {
266 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
267 return constOp.getType().isSignlessInteger();
268 return false;
269}
270
271void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
272 const APFloat &value, FloatType type) {
273 arith::ConstantOp::build(builder, result, type,
274 builder.getFloatAttr(type, value));
275}
276
277bool arith::ConstantFloatOp::classof(Operation *op) {
278 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
279 return llvm::isa<FloatType>(constOp.getType());
280 return false;
281}
282
283void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
284 int64_t value) {
285 arith::ConstantOp::build(builder, result, builder.getIndexType(),
286 builder.getIndexAttr(value));
287}
288
289bool arith::ConstantIndexOp::classof(Operation *op) {
290 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
291 return constOp.getType().isIndex();
292 return false;
293}
294
295//===----------------------------------------------------------------------===//
296// AddIOp
297//===----------------------------------------------------------------------===//
298
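// Illustration (hypothetical IR) of the addi(subi(a, b), b) -> a fold below:
//   %0 = arith.subi %a, %b : i32
//   %1 = arith.addi %0, %b : i32   // %1 folds to %a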
299OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
300 // addi(x, 0) -> x
301 if (matchPattern(adaptor.getRhs(), m_Zero()))
302 return getLhs();
303
304 // addi(subi(a, b), b) -> a
305 if (auto sub = getLhs().getDefiningOp<SubIOp>())
306 if (getRhs() == sub.getRhs())
307 return sub.getLhs();
308
309 // addi(b, subi(a, b)) -> a
310 if (auto sub = getRhs().getDefiningOp<SubIOp>())
311 if (getLhs() == sub.getRhs())
312 return sub.getLhs();
313
314 return constFoldBinaryOp<IntegerAttr>(
315 adaptor.getOperands(),
316 [](APInt a, const APInt &b) { return std::move(a) + b; });
317}
318
319void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
320 MLIRContext *context) {
321 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
322 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
323}
324
325//===----------------------------------------------------------------------===//
326// AddUIExtendedOp
327//===----------------------------------------------------------------------===//
328
329std::optional<SmallVector<int64_t, 4>>
330arith::AddUIExtendedOp::getShapeForUnroll() {
331 if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
332 return llvm::to_vector<4>(vt.getShape());
333 return std::nullopt;
334}
335
336// Returns the overflow bit, assuming that `sum` is the result of unsigned
337// addition of `operand` and another number.
338static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
340}
341
342LogicalResult
343arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
344 SmallVectorImpl<OpFoldResult> &results) {
345 Type overflowTy = getOverflow().getType();
346 // addui_extended(x, 0) -> x, false
347 if (matchPattern(getRhs(), m_Zero())) {
348 Builder builder(getContext());
349 auto falseValue = builder.getZeroAttr(overflowTy);
350
351 results.push_back(getLhs());
352 results.push_back(falseValue);
353 return success();
354 }
355
356 // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
357 // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
358 // operands. If that succeeds, calculate the overflow bit based on the sum
359 // and the first (constant) operand, `lhs`.
360 if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
361 adaptor.getOperands(),
362 [](APInt a, const APInt &b) { return std::move(a) + b; })) {
363 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
364 ArrayRef({sumAttr, adaptor.getLhs()}),
365 getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
366 calculateUnsignedOverflow);
367 if (!overflowAttr)
368 return failure();
369
370 results.push_back(sumAttr);
371 results.push_back(overflowAttr);
372 return success();
373 }
374
375 return failure();
376}
377
378void arith::AddUIExtendedOp::getCanonicalizationPatterns(
379 RewritePatternSet &patterns, MLIRContext *context) {
380 patterns.add<AddUIExtendedToAddI>(context);
381}
382
383//===----------------------------------------------------------------------===//
384// SubIOp
385//===----------------------------------------------------------------------===//
386
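// Illustration (hypothetical IR) of the subi(addi(a, b), a) -> b fold below:
//   %0 = arith.addi %a, %b : i32
//   %1 = arith.subi %0, %a : i32   // %1 folds to %b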
387OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
388 // subi(x,x) -> 0
389 if (getOperand(0) == getOperand(1))
390 return Builder(getContext()).getZeroAttr(getType());
391 // subi(x,0) -> x
392 if (matchPattern(adaptor.getRhs(), m_Zero()))
393 return getLhs();
394
395 if (auto add = getLhs().getDefiningOp<AddIOp>()) {
396 // subi(addi(a, b), b) -> a
397 if (getRhs() == add.getRhs())
398 return add.getLhs();
399 // subi(addi(a, b), a) -> b
400 if (getRhs() == add.getLhs())
401 return add.getRhs();
402 }
403
404 return constFoldBinaryOp<IntegerAttr>(
405 adaptor.getOperands(),
406 [](APInt a, const APInt &b) { return std::move(a) - b; });
407}
408
409void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
410 MLIRContext *context) {
411 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
412 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
413 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
414}
415
416//===----------------------------------------------------------------------===//
417// MulIOp
418//===----------------------------------------------------------------------===//
419
420OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
421 // muli(x, 0) -> 0
422 if (matchPattern(adaptor.getRhs(), m_Zero()))
423 return getRhs();
424 // muli(x, 1) -> x
425 if (matchPattern(adaptor.getRhs(), m_One()))
426 return getLhs();
427 // TODO: Handle the overflow case.
428
429 // default folder
430 return constFoldBinaryOp<IntegerAttr>(
431 adaptor.getOperands(),
432 [](const APInt &a, const APInt &b) { return a * b; });
433}
434
435void arith::MulIOp::getAsmResultNames(
436 function_ref<void(Value, StringRef)> setNameFn) {
437 if (!isa<IndexType>(getType()))
438 return;
439
  // Match vector.vscale by name to avoid depending on the vector dialect
  // (which would be a circular dependency).
442 auto isVscale = [](Operation *op) {
443 return op && op->getName().getStringRef() == "vector.vscale";
444 };
445
446 IntegerAttr baseValue;
447 auto isVscaleExpr = [&](Value a, Value b) {
448 return matchPattern(a, m_Constant(&baseValue)) &&
449 isVscale(b.getDefiningOp());
450 };
451
452 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
453 return;
454
455 // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
456 SmallString<32> specialNameBuffer;
457 llvm::raw_svector_ostream specialName(specialNameBuffer);
458 specialName << 'c' << baseValue.getInt() << "_vscale";
459 setNameFn(getResult(), specialName.str());
460}
461
462void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
463 MLIRContext *context) {
464 patterns.add<MulIMulIConstant>(context);
465}
466
467//===----------------------------------------------------------------------===//
468// MulSIExtendedOp
469//===----------------------------------------------------------------------===//
470
471std::optional<SmallVector<int64_t, 4>>
472arith::MulSIExtendedOp::getShapeForUnroll() {
473 if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
474 return llvm::to_vector<4>(vt.getShape());
475 return std::nullopt;
476}
477
478LogicalResult
479arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
480 SmallVectorImpl<OpFoldResult> &results) {
481 // mulsi_extended(x, 0) -> 0, 0
482 if (matchPattern(adaptor.getRhs(), m_Zero())) {
483 Attribute zero = adaptor.getRhs();
484 results.push_back(zero);
485 results.push_back(zero);
486 return success();
487 }
488
489 // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
490 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
491 adaptor.getOperands(),
492 [](const APInt &a, const APInt &b) { return a * b; })) {
493 // Invoke the constant fold helper again to calculate the 'high' result.
494 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
495 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
496 return llvm::APIntOps::mulhs(a, b);
497 });
498 assert(highAttr && "Unexpected constant-folding failure");
499
500 results.push_back(lowAttr);
501 results.push_back(highAttr);
502 return success();
503 }
504
505 return failure();
506}
507
508void arith::MulSIExtendedOp::getCanonicalizationPatterns(
509 RewritePatternSet &patterns, MLIRContext *context) {
510 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
511}
512
513//===----------------------------------------------------------------------===//
514// MulUIExtendedOp
515//===----------------------------------------------------------------------===//
516
517std::optional<SmallVector<int64_t, 4>>
518arith::MulUIExtendedOp::getShapeForUnroll() {
519 if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
520 return llvm::to_vector<4>(vt.getShape());
521 return std::nullopt;
522}
523
524LogicalResult
525arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
526 SmallVectorImpl<OpFoldResult> &results) {
527 // mului_extended(x, 0) -> 0, 0
528 if (matchPattern(adaptor.getRhs(), m_Zero())) {
529 Attribute zero = adaptor.getRhs();
530 results.push_back(zero);
531 results.push_back(zero);
532 return success();
533 }
534
535 // mului_extended(x, 1) -> x, 0
536 if (matchPattern(adaptor.getRhs(), m_One())) {
537 Builder builder(getContext());
538 Attribute zero = builder.getZeroAttr(getLhs().getType());
539 results.push_back(getLhs());
540 results.push_back(zero);
541 return success();
542 }
543
544 // mului_extended(cst_a, cst_b) -> cst_low, cst_high
545 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
546 adaptor.getOperands(),
547 [](const APInt &a, const APInt &b) { return a * b; })) {
548 // Invoke the constant fold helper again to calculate the 'high' result.
549 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
550 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
551 return llvm::APIntOps::mulhu(a, b);
552 });
553 assert(highAttr && "Unexpected constant-folding failure");
554
555 results.push_back(lowAttr);
556 results.push_back(highAttr);
557 return success();
558 }
559
560 return failure();
561}
562
563void arith::MulUIExtendedOp::getCanonicalizationPatterns(
564 RewritePatternSet &patterns, MLIRContext *context) {
565 patterns.add<MulUIExtendedToMulI>(context);
566}
567
568//===----------------------------------------------------------------------===//
569// DivUIOp
570//===----------------------------------------------------------------------===//
571
572OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
573 // divui (x, 1) -> x.
574 if (matchPattern(adaptor.getRhs(), m_One()))
575 return getLhs();
576
577 // Don't fold if it would require a division by zero.
578 bool div0 = false;
579 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
580 [&](APInt a, const APInt &b) {
581 if (div0 || !b) {
582 div0 = true;
583 return a;
584 }
585 return a.udiv(b);
586 });
587
588 return div0 ? Attribute() : result;
589}
590
591Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
592 // X / 0 => UB
593 return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
594 : Speculation::NotSpeculatable;
595}
596
597//===----------------------------------------------------------------------===//
598// DivSIOp
599//===----------------------------------------------------------------------===//
600
601OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
602 // divsi (x, 1) -> x.
603 if (matchPattern(adaptor.getRhs(), m_One()))
604 return getLhs();
605
606 // Don't fold if it would overflow or if it requires a division by zero.
607 bool overflowOrDiv0 = false;
608 auto result = constFoldBinaryOp<IntegerAttr>(
609 adaptor.getOperands(), [&](APInt a, const APInt &b) {
610 if (overflowOrDiv0 || !b) {
611 overflowOrDiv0 = true;
612 return a;
613 }
614 return a.sdiv_ov(b, overflowOrDiv0);
615 });
616
617 return overflowOrDiv0 ? Attribute() : result;
618}
619
620Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
621 bool mayHaveUB = true;
622
623 APInt constRHS;
624 // X / 0 => UB
625 // INT_MIN / -1 => UB
626 if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
627 mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
628
629 return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable;
630}
631
632//===----------------------------------------------------------------------===//
633// Ceil and floor division folding helpers
634//===----------------------------------------------------------------------===//
635
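/// Computes (a - 1) / b + 1, the signed ceiling of a / b for positive inputs
/// (the callers below guarantee positivity). Worked example: a = 7, b = 2
/// yields (7 - 1) / 2 + 1 = 4 = ceil(7 / 2).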
636static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
637 bool &overflow) {
638 // Returns (a-1)/b + 1
639 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
  return val.sadd_ov(one, overflow);
642}
643
644//===----------------------------------------------------------------------===//
645// CeilDivUIOp
646//===----------------------------------------------------------------------===//
647
648OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
649 // ceildivui (x, 1) -> x.
650 if (matchPattern(adaptor.getRhs(), m_One()))
651 return getLhs();
652
653 bool overflowOrDiv0 = false;
654 auto result = constFoldBinaryOp<IntegerAttr>(
655 adaptor.getOperands(), [&](APInt a, const APInt &b) {
656 if (overflowOrDiv0 || !b) {
657 overflowOrDiv0 = true;
658 return a;
659 }
660 APInt quotient = a.udiv(b);
661 if (!a.urem(b))
662 return quotient;
663 APInt one(a.getBitWidth(), 1, true);
664 return quotient.uadd_ov(one, overflowOrDiv0);
665 });
666
667 return overflowOrDiv0 ? Attribute() : result;
668}
669
670Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
671 // X / 0 => UB
672 return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
673 : Speculation::NotSpeculatable;
674}
675
676//===----------------------------------------------------------------------===//
677// CeilDivSIOp
678//===----------------------------------------------------------------------===//
679
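// Worked examples (illustrative) for the sign handling in the folder below:
//   ceildivsi(7, 2)   ->  (7 - 1) / 2 + 1 = 4
//   ceildivsi(-7, 2)  ->  -(7 / 2) = -3
//   ceildivsi(7, -2)  ->  -(7 / 2) = -3
//   ceildivsi(-7, -2) ->  (7 - 1) / 2 + 1 = 4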
680OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
681 // ceildivsi (x, 1) -> x.
682 if (matchPattern(adaptor.getRhs(), m_One()))
683 return getLhs();
684
685 // Don't fold if it would overflow or if it requires a division by zero.
686 bool overflowOrDiv0 = false;
687 auto result = constFoldBinaryOp<IntegerAttr>(
688 adaptor.getOperands(), [&](APInt a, const APInt &b) {
689 if (overflowOrDiv0 || !b) {
690 overflowOrDiv0 = true;
691 return a;
692 }
693 if (!a)
694 return a;
        // After this point we know that neither a nor b is zero.
696 unsigned bits = a.getBitWidth();
697 APInt zero = APInt::getZero(bits);
698 bool aGtZero = a.sgt(zero);
699 bool bGtZero = b.sgt(zero);
700 if (aGtZero && bGtZero) {
701 // Both positive, return ceil(a, b).
702 return signedCeilNonnegInputs(a, b, overflowOrDiv0);
703 }
704 if (!aGtZero && !bGtZero) {
705 // Both negative, return ceil(-a, -b).
706 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
707 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
708 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
709 }
710 if (!aGtZero && bGtZero) {
          // A is negative, b is positive, return -(-a / b).
712 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
713 APInt div = posA.sdiv_ov(b, overflowOrDiv0);
714 return zero.ssub_ov(div, overflowOrDiv0);
715 }
        // A is positive, b is negative, return -(a / -b).
717 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
718 APInt div = a.sdiv_ov(posB, overflowOrDiv0);
719 return zero.ssub_ov(div, overflowOrDiv0);
720 });
721
722 return overflowOrDiv0 ? Attribute() : result;
723}
724
725Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
726 bool mayHaveUB = true;
727
728 APInt constRHS;
729 // X / 0 => UB
730 // INT_MIN / -1 => UB
731 if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
732 mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
733
734 return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable;
735}
736
737//===----------------------------------------------------------------------===//
738// FloorDivSIOp
739//===----------------------------------------------------------------------===//
740
741OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
742 // floordivsi (x, 1) -> x.
743 if (matchPattern(adaptor.getRhs(), m_One()))
744 return getLhs();
745
746 // Don't fold if it would overflow or if it requires a division by zero.
747 bool overflowOrDiv = false;
748 auto result = constFoldBinaryOp<IntegerAttr>(
749 adaptor.getOperands(), [&](APInt a, const APInt &b) {
750 if (b.isZero()) {
751 overflowOrDiv = true;
752 return a;
753 }
754 return a.sfloordiv_ov(b, overflowOrDiv);
755 });
756
757 return overflowOrDiv ? Attribute() : result;
758}
759
760//===----------------------------------------------------------------------===//
761// RemUIOp
762//===----------------------------------------------------------------------===//
763
764OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
765 // remui (x, 1) -> 0.
766 if (matchPattern(adaptor.getRhs(), m_One()))
767 return Builder(getContext()).getZeroAttr(getType());
768
769 // Don't fold if it would require a division by zero.
770 bool div0 = false;
771 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
772 [&](APInt a, const APInt &b) {
773 if (div0 || b.isZero()) {
774 div0 = true;
775 return a;
776 }
777 return a.urem(b);
778 });
779
780 return div0 ? Attribute() : result;
781}
782
783//===----------------------------------------------------------------------===//
784// RemSIOp
785//===----------------------------------------------------------------------===//
786
787OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
788 // remsi (x, 1) -> 0.
789 if (matchPattern(adaptor.getRhs(), m_One()))
790 return Builder(getContext()).getZeroAttr(getType());
791
792 // Don't fold if it would require a division by zero.
793 bool div0 = false;
794 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
795 [&](APInt a, const APInt &b) {
796 if (div0 || b.isZero()) {
797 div0 = true;
798 return a;
799 }
800 return a.srem(b);
801 });
802
803 return div0 ? Attribute() : result;
804}
805
806//===----------------------------------------------------------------------===//
807// AndIOp
808//===----------------------------------------------------------------------===//
809
810/// Fold `and(a, and(a, b))` to `and(a, b)`
811static Value foldAndIofAndI(arith::AndIOp op) {
812 for (bool reversePrev : {false, true}) {
813 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
814 .getDefiningOp<arith::AndIOp>();
815 if (!prev)
816 continue;
817
818 Value other = (reversePrev ? op.getLhs() : op.getRhs());
819 if (other != prev.getLhs() && other != prev.getRhs())
820 continue;
821
822 return prev.getResult();
823 }
824 return {};
825}
826
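// Illustration (hypothetical IR) of the and(x, not(x)) -> 0 fold below, where
// not(x) is expressed as xor with an all-ones constant:
//   %m1 = arith.constant -1 : i32
//   %0  = arith.xori %x, %m1 : i32
//   %1  = arith.andi %x, %0 : i32   // %1 folds to 0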
827OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
828 /// and(x, 0) -> 0
829 if (matchPattern(adaptor.getRhs(), m_Zero()))
830 return getRhs();
831 /// and(x, allOnes) -> x
832 APInt intValue;
833 if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
834 intValue.isAllOnes())
835 return getLhs();
836 /// and(x, not(x)) -> 0
837 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
838 m_ConstantInt(&intValue))) &&
839 intValue.isAllOnes())
840 return Builder(getContext()).getZeroAttr(getType());
841 /// and(not(x), x) -> 0
842 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
843 m_ConstantInt(&intValue))) &&
844 intValue.isAllOnes())
845 return Builder(getContext()).getZeroAttr(getType());
846
847 /// and(a, and(a, b)) -> and(a, b)
848 if (Value result = foldAndIofAndI(*this))
849 return result;
850
851 return constFoldBinaryOp<IntegerAttr>(
852 adaptor.getOperands(),
853 [](APInt a, const APInt &b) { return std::move(a) & b; });
854}
855
856//===----------------------------------------------------------------------===//
857// OrIOp
858//===----------------------------------------------------------------------===//
859
860OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
861 if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
862 /// or(x, 0) -> x
863 if (rhsVal.isZero())
864 return getLhs();
865 /// or(x, <all ones>) -> <all ones>
866 if (rhsVal.isAllOnes())
867 return adaptor.getRhs();
868 }
869
870 APInt intValue;
  /// or(x, xor(x, <all ones>)) -> <all ones>
872 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
873 m_ConstantInt(&intValue))) &&
874 intValue.isAllOnes())
875 return getRhs().getDefiningOp<XOrIOp>().getRhs();
  /// or(xor(x, <all ones>), x) -> <all ones>
877 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
878 m_ConstantInt(&intValue))) &&
879 intValue.isAllOnes())
880 return getLhs().getDefiningOp<XOrIOp>().getRhs();
881
882 return constFoldBinaryOp<IntegerAttr>(
883 adaptor.getOperands(),
884 [](APInt a, const APInt &b) { return std::move(a) | b; });
885}
886
887//===----------------------------------------------------------------------===//
888// XOrIOp
889//===----------------------------------------------------------------------===//
890
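// Illustration (hypothetical IR) of the xor(xor(x, a), a) -> x fold below:
//   %0 = arith.xori %x, %a : i32
//   %1 = arith.xori %0, %a : i32   // %1 folds to %x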
891OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
892 /// xor(x, 0) -> x
893 if (matchPattern(adaptor.getRhs(), m_Zero()))
894 return getLhs();
895 /// xor(x, x) -> 0
896 if (getLhs() == getRhs())
897 return Builder(getContext()).getZeroAttr(getType());
898 /// xor(xor(x, a), a) -> x
899 /// xor(xor(a, x), a) -> x
900 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
901 if (prev.getRhs() == getRhs())
902 return prev.getLhs();
903 if (prev.getLhs() == getRhs())
904 return prev.getRhs();
905 }
906 /// xor(a, xor(x, a)) -> x
907 /// xor(a, xor(a, x)) -> x
908 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
909 if (prev.getRhs() == getLhs())
910 return prev.getLhs();
911 if (prev.getLhs() == getLhs())
912 return prev.getRhs();
913 }
914
915 return constFoldBinaryOp<IntegerAttr>(
916 adaptor.getOperands(),
917 [](APInt a, const APInt &b) { return std::move(a) ^ b; });
918}
919
920void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
921 MLIRContext *context) {
922 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
923}
924
925//===----------------------------------------------------------------------===//
926// NegFOp
927//===----------------------------------------------------------------------===//
928
929OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
930 /// negf(negf(x)) -> x
931 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
932 return op.getOperand();
933 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
934 [](const APFloat &a) { return -a; });
935}
936
937//===----------------------------------------------------------------------===//
938// AddFOp
939//===----------------------------------------------------------------------===//
940
941OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
942 // addf(x, -0) -> x
943 if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
944 return getLhs();
945
946 return constFoldBinaryOp<FloatAttr>(
947 adaptor.getOperands(),
948 [](const APFloat &a, const APFloat &b) { return a + b; });
949}
950
951//===----------------------------------------------------------------------===//
952// SubFOp
953//===----------------------------------------------------------------------===//
954
955OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
956 // subf(x, +0) -> x
957 if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
958 return getLhs();
959
960 return constFoldBinaryOp<FloatAttr>(
961 adaptor.getOperands(),
962 [](const APFloat &a, const APFloat &b) { return a - b; });
963}
964
965//===----------------------------------------------------------------------===//
966// MaximumFOp
967//===----------------------------------------------------------------------===//
968
969OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
970 // maximumf(x,x) -> x
971 if (getLhs() == getRhs())
972 return getRhs();
973
974 // maximumf(x, -inf) -> x
975 if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
976 return getLhs();
977
978 return constFoldBinaryOp<FloatAttr>(
979 adaptor.getOperands(),
980 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
981}
982
983//===----------------------------------------------------------------------===//
984// MaxNumFOp
985//===----------------------------------------------------------------------===//
986
987OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
988 // maxnumf(x,x) -> x
989 if (getLhs() == getRhs())
990 return getRhs();
991
992 // maxnumf(x, -inf) -> x
993 if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
994 return getLhs();
995
996 return constFoldBinaryOp<FloatAttr>(
997 adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::maxnum(a, b); });
999}
1000
1001//===----------------------------------------------------------------------===//
1002// MaxSIOp
1003//===----------------------------------------------------------------------===//
1004
1005OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1006 // maxsi(x,x) -> x
1007 if (getLhs() == getRhs())
1008 return getRhs();
1009
1010 if (APInt intValue;
1011 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1012 // maxsi(x,MAX_INT) -> MAX_INT
1013 if (intValue.isMaxSignedValue())
1014 return getRhs();
1015 // maxsi(x, MIN_INT) -> x
1016 if (intValue.isMinSignedValue())
1017 return getLhs();
1018 }
1019
1020 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1021 [](const APInt &a, const APInt &b) {
1022 return llvm::APIntOps::smax(a, b);
1023 });
1024}
1025
1026//===----------------------------------------------------------------------===//
1027// MaxUIOp
1028//===----------------------------------------------------------------------===//
1029
1030OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1031 // maxui(x,x) -> x
1032 if (getLhs() == getRhs())
1033 return getRhs();
1034
1035 if (APInt intValue;
1036 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1037 // maxui(x,MAX_INT) -> MAX_INT
1038 if (intValue.isMaxValue())
1039 return getRhs();
1040 // maxui(x, MIN_INT) -> x
1041 if (intValue.isMinValue())
1042 return getLhs();
1043 }
1044
1045 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1046 [](const APInt &a, const APInt &b) {
1047 return llvm::APIntOps::umax(a, b);
1048 });
1049}
1050
1051//===----------------------------------------------------------------------===//
1052// MinimumFOp
1053//===----------------------------------------------------------------------===//
1054
1055OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1056 // minimumf(x,x) -> x
1057 if (getLhs() == getRhs())
1058 return getRhs();
1059
1060 // minimumf(x, +inf) -> x
1061 if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1062 return getLhs();
1063
1064 return constFoldBinaryOp<FloatAttr>(
1065 adaptor.getOperands(),
1066 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1067}
1068
1069//===----------------------------------------------------------------------===//
1070// MinNumFOp
1071//===----------------------------------------------------------------------===//
1072
1073OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1074 // minnumf(x,x) -> x
1075 if (getLhs() == getRhs())
1076 return getRhs();
1077
1078 // minnumf(x, +inf) -> x
1079 if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1080 return getLhs();
1081
1082 return constFoldBinaryOp<FloatAttr>(
1083 adaptor.getOperands(),
1084 [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1085}
1086
1087//===----------------------------------------------------------------------===//
1088// MinSIOp
1089//===----------------------------------------------------------------------===//
1090
1091OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1092 // minsi(x,x) -> x
1093 if (getLhs() == getRhs())
1094 return getRhs();
1095
1096 if (APInt intValue;
1097 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1098 // minsi(x,MIN_INT) -> MIN_INT
1099 if (intValue.isMinSignedValue())
1100 return getRhs();
1101 // minsi(x, MAX_INT) -> x
1102 if (intValue.isMaxSignedValue())
1103 return getLhs();
1104 }
1105
1106 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1107 [](const APInt &a, const APInt &b) {
1108 return llvm::APIntOps::smin(a, b);
1109 });
1110}
1111
1112//===----------------------------------------------------------------------===//
1113// MinUIOp
1114//===----------------------------------------------------------------------===//
1115
1116OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1117 // minui(x,x) -> x
1118 if (getLhs() == getRhs())
1119 return getRhs();
1120
1121 if (APInt intValue;
1122 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1123 // minui(x,MIN_INT) -> MIN_INT
1124 if (intValue.isMinValue())
1125 return getRhs();
1126 // minui(x, MAX_INT) -> x
1127 if (intValue.isMaxValue())
1128 return getLhs();
1129 }
1130
1131 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1132 [](const APInt &a, const APInt &b) {
1133 return llvm::APIntOps::umin(a, b);
1134 });
1135}
1136
1137//===----------------------------------------------------------------------===//
1138// MulFOp
1139//===----------------------------------------------------------------------===//
1140
1141OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1142 // mulf(x, 1) -> x
1143 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1144 return getLhs();
1145
1146 return constFoldBinaryOp<FloatAttr>(
1147 adaptor.getOperands(),
1148 [](const APFloat &a, const APFloat &b) { return a * b; });
1149}
1150
1151void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1152 MLIRContext *context) {
1153 patterns.add<MulFOfNegF>(context);
1154}
1155
1156//===----------------------------------------------------------------------===//
1157// DivFOp
1158//===----------------------------------------------------------------------===//
1159
1160OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1161 // divf(x, 1) -> x
1162 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1163 return getLhs();
1164
1165 return constFoldBinaryOp<FloatAttr>(
1166 adaptor.getOperands(),
1167 [](const APFloat &a, const APFloat &b) { return a / b; });
1168}
1169
1170void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1171 MLIRContext *context) {
1172 patterns.add<DivFOfNegF>(context);
1173}
1174
1175//===----------------------------------------------------------------------===//
1176// RemFOp
1177//===----------------------------------------------------------------------===//
1178
1179OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1180 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1181 [](const APFloat &a, const APFloat &b) {
1182 APFloat result(a);
1183 (void)result.remainder(b);
1184 return result;
1185 });
1186}
1187
1188//===----------------------------------------------------------------------===//
1189// Utility functions for verifying cast ops
1190//===----------------------------------------------------------------------===//
1191
1192template <typename... Types>
1193using type_list = std::tuple<Types...> *;
1194
1195/// Returns a non-null type only if the provided type is one of the allowed
1196/// types or one of the allowed shaped types of the allowed types. Returns the
1197/// element type if a valid shaped type is provided.
1198template <typename... ShapedTypes, typename... ElementTypes>
1199static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
1200 type_list<ElementTypes...>) {
1201 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1202 return {};
1203
1204 auto underlyingType = getElementTypeOrSelf(type);
1205 if (!llvm::isa<ElementTypes...>(underlyingType))
1206 return {};
1207
1208 return underlyingType;
1209}
1210
1211/// Get allowed underlying types for vectors and tensors.
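/// For example, getTypeIfLike<IntegerType> returns i32 for both i32 and
/// vector<4xi32>, and a null Type for vector<4xf32>.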
1212template <typename... ElementTypes>
1213static Type getTypeIfLike(Type type) {
1214 return getUnderlyingType(type, type_list<VectorType, TensorType>(),
1215 type_list<ElementTypes...>());
1216}
1217
1218/// Get allowed underlying types for vectors, tensors, and memrefs.
1219template <typename... ElementTypes>
1220static Type getTypeIfLikeOrMemRef(Type type) {
1221 return getUnderlyingType(type,
1222 type_list<VectorType, TensorType, MemRefType>(),
1223 type_list<ElementTypes...>());
1224}
1225
/// Return false if both types are ranked tensors with mismatching encodings.
1227static bool hasSameEncoding(Type typeA, Type typeB) {
1228 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1229 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1230 if (!rankedTensorA || !rankedTensorB)
1231 return true;
1232 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1233}
1234
1235static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
1236 if (inputs.size() != 1 || outputs.size() != 1)
1237 return false;
  if (!hasSameEncoding(inputs.front(), outputs.front()))
    return false;
  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1241}
1242
1243//===----------------------------------------------------------------------===//
1244// Verifiers for integer and floating point extension/truncation ops
1245//===----------------------------------------------------------------------===//
1246
1247// Extend ops can only extend to a wider type.
1248template <typename ValType, typename Op>
1249static LogicalResult verifyExtOp(Op op) {
1250 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1251 Type dstType = getElementTypeOrSelf(op.getType());
1252
1253 if (llvm::cast<ValType>(srcType).getWidth() >=
1254 llvm::cast<ValType>(dstType).getWidth())
1255 return op.emitError("result type ")
1256 << dstType << " must be wider than operand type " << srcType;
1257
1258 return success();
1259}
1260
1261// Truncate ops can only truncate to a shorter type.
1262template <typename ValType, typename Op>
1263static LogicalResult verifyTruncateOp(Op op) {
1264 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1265 Type dstType = getElementTypeOrSelf(op.getType());
1266
1267 if (llvm::cast<ValType>(srcType).getWidth() <=
1268 llvm::cast<ValType>(dstType).getWidth())
1269 return op.emitError("result type ")
1270 << dstType << " must be shorter than operand type " << srcType;
1271
1272 return success();
1273}
1274
1275/// Validate a cast that changes the width of a type.
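/// For example, the extension ops below instantiate this as
/// checkWidthChangeCast<std::greater, ...> and the truncation ops as
/// checkWidthChangeCast<std::less, ...> in their areCastCompatible hooks.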
1276template <template <typename> class WidthComparator, typename... ElementTypes>
1277static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1278 if (!areValidCastInputsAndOutputs(inputs, outputs))
1279 return false;
1280
1281 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1282 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1283 if (!srcType || !dstType)
1284 return false;
1285
1286 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1287 srcType.getIntOrFloatBitWidth());
1288}
1289
1290/// Attempts to convert `sourceValue` to an APFloat value with
1291/// `targetSemantics` and `roundingMode`, without any information loss.
1292static FailureOr<APFloat> convertFloatValue(
1293 APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1294 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1295 bool losesInfo = false;
  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1297 if (losesInfo || status != APFloat::opOK)
1298 return failure();
1299
1300 return sourceValue;
1301}
1302
1303//===----------------------------------------------------------------------===//
1304// ExtUIOp
1305//===----------------------------------------------------------------------===//
1306
1307OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1308 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1309 getInMutable().assign(lhs.getIn());
1310 return getResult();
1311 }
1312
1313 Type resType = getElementTypeOrSelf(getType());
1314 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1315 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1316 adaptor.getOperands(), getType(),
1317 [bitWidth](const APInt &a, bool &castStatus) {
1318 return a.zext(bitWidth);
1319 });
1320}
1321
1322bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1323 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1324}
1325
1326LogicalResult arith::ExtUIOp::verify() {
1327 return verifyExtOp<IntegerType>(*this);
1328}
1329
1330//===----------------------------------------------------------------------===//
1331// ExtSIOp
1332//===----------------------------------------------------------------------===//
1333
1334OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1335 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1336 getInMutable().assign(lhs.getIn());
1337 return getResult();
1338 }
1339
1340 Type resType = getElementTypeOrSelf(getType());
1341 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1342 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1343 adaptor.getOperands(), getType(),
1344 [bitWidth](const APInt &a, bool &castStatus) {
1345 return a.sext(bitWidth);
1346 });
1347}
1348
1349bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1350 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
1351}
1352
1353void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1354 MLIRContext *context) {
1355 patterns.add<ExtSIOfExtUI>(context);
1356}
1357
1358LogicalResult arith::ExtSIOp::verify() {
1359 return verifyExtOp<IntegerType>(*this);
1360}
1361
1362//===----------------------------------------------------------------------===//
1363// ExtFOp
1364//===----------------------------------------------------------------------===//
1365
/// Fold extension of float constants when there is no information loss due to
/// the difference in fp semantics.
1368OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1369 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1370 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1371 return constFoldCastOp<FloatAttr, FloatAttr>(
1372 adaptor.getOperands(), getType(),
1373 [&targetSemantics](const APFloat &a, bool &castStatus) {
1374 FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1375 if (failed(result)) {
1376 castStatus = false;
1377 return a;
1378 }
1379 return *result;
1380 });
1381}
1382
1383bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1384 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1385}
1386
1387LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1388
1389//===----------------------------------------------------------------------===//
1390// TruncIOp
1391//===----------------------------------------------------------------------===//
1392
1393OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1394 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1395 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1396 Value src = getOperand().getDefiningOp()->getOperand(0);
1397 Type srcType = getElementTypeOrSelf(src.getType());
1398 Type dstType = getElementTypeOrSelf(getType());
1399 // trunci(zexti(a)) -> trunci(a)
1400 // trunci(sexti(a)) -> trunci(a)
1401 if (llvm::cast<IntegerType>(srcType).getWidth() >
1402 llvm::cast<IntegerType>(dstType).getWidth()) {
1403 setOperand(src);
1404 return getResult();
1405 }
1406
1407 // trunci(zexti(a)) -> a
1408 // trunci(sexti(a)) -> a
1409 if (srcType == dstType)
1410 return src;
1411 }
1412
  // trunci(trunci(a)) -> trunci(a)
1414 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1415 setOperand(getOperand().getDefiningOp()->getOperand(0));
1416 return getResult();
1417 }
1418
1419 Type resType = getElementTypeOrSelf(getType());
1420 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1421 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1422 adaptor.getOperands(), getType(),
1423 [bitWidth](const APInt &a, bool &castStatus) {
1424 return a.trunc(bitWidth);
1425 });
1426}
1427
1428bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1429 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1430}
1431
1432void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1433 MLIRContext *context) {
1434 patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1435 TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1436 context);
1437}
1438
1439LogicalResult arith::TruncIOp::verify() {
1440 return verifyTruncateOp<IntegerType>(*this);
1441}
1442
1443//===----------------------------------------------------------------------===//
1444// TruncFOp
1445//===----------------------------------------------------------------------===//
1446
1447/// Perform safe const propagation for truncf, i.e., only propagate if FP value
1448/// can be represented without precision loss.
1449OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1450 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1451 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1452 return constFoldCastOp<FloatAttr, FloatAttr>(
1453 adaptor.getOperands(), getType(),
1454 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1455 RoundingMode roundingMode =
1456 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1457 llvm::RoundingMode llvmRoundingMode =
1458 convertArithRoundingModeToLLVMIR(roundingMode);
1459 FailureOr<APFloat> result =
1460 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1461 if (failed(result)) {
1462 castStatus = false;
1463 return a;
1464 }
1465 return *result;
1466 });
1467}
1468
1469bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1470 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1471}
1472
1473LogicalResult arith::TruncFOp::verify() {
1474 return verifyTruncateOp<FloatType>(*this);
1475}
1476
1477//===----------------------------------------------------------------------===//
1478// AndIOp
1479//===----------------------------------------------------------------------===//
1480
1481void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1482 MLIRContext *context) {
1483 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1484}
1485
1486//===----------------------------------------------------------------------===//
1487// OrIOp
1488//===----------------------------------------------------------------------===//
1489
1490void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1491 MLIRContext *context) {
1492 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1493}
1494
1495//===----------------------------------------------------------------------===//
1496// Verifiers for casts between integers and floats.
1497//===----------------------------------------------------------------------===//
1498
1499template <typename From, typename To>
1500static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1501 if (!areValidCastInputsAndOutputs(inputs, outputs))
1502 return false;
1503
1504 auto srcType = getTypeIfLike<From>(inputs.front());
1505 auto dstType = getTypeIfLike<To>(outputs.back());
1506
1507 return srcType && dstType;
1508}
1509
1510//===----------------------------------------------------------------------===//
1511// UIToFPOp
1512//===----------------------------------------------------------------------===//
1513
1514bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1515 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1516}
1517
1518OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1519 Type resEleType = getElementTypeOrSelf(getType());
1520 return constFoldCastOp<IntegerAttr, FloatAttr>(
1521 adaptor.getOperands(), getType(),
1522 [&resEleType](const APInt &a, bool &castStatus) {
1523 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1524 APFloat apf(floatTy.getFloatSemantics(),
1525 APInt::getZero(floatTy.getWidth()));
1526 apf.convertFromAPInt(a, /*IsSigned=*/false,
1527 APFloat::rmNearestTiesToEven);
1528 return apf;
1529 });
1530}
1531
1532//===----------------------------------------------------------------------===//
1533// SIToFPOp
1534//===----------------------------------------------------------------------===//
1535
1536bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1537 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1538}
1539
1540OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1541 Type resEleType = getElementTypeOrSelf(getType());
1542 return constFoldCastOp<IntegerAttr, FloatAttr>(
1543 adaptor.getOperands(), getType(),
1544 [&resEleType](const APInt &a, bool &castStatus) {
1545 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1546 APFloat apf(floatTy.getFloatSemantics(),
1547 APInt::getZero(floatTy.getWidth()));
1548 apf.convertFromAPInt(a, /*IsSigned=*/true,
1549 APFloat::rmNearestTiesToEven);
1550 return apf;
1551 });
1552}
1553
1554//===----------------------------------------------------------------------===//
1555// FPToUIOp
1556//===----------------------------------------------------------------------===//
1557
1558bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1559 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1560}
1561
1562OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1563 Type resType = getElementTypeOrSelf(getType());
1564 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1565 return constFoldCastOp<FloatAttr, IntegerAttr>(
1566 adaptor.getOperands(), getType(),
1567 [&bitWidth](const APFloat &a, bool &castStatus) {
1568 bool ignored;
1569 APSInt api(bitWidth, /*isUnsigned=*/true);
1570 castStatus = APFloat::opInvalidOp !=
1571 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1572 return api;
1573 });
1574}
1575
1576//===----------------------------------------------------------------------===//
1577// FPToSIOp
1578//===----------------------------------------------------------------------===//
1579
1580bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1581 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1582}
1583
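// Fold `arith.fptosi` of a float constant by rounding it toward zero into a
// signed integer of the result width, again abandoning the fold when the
// value is not representable in the result type.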
1584OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1585 Type resType = getElementTypeOrSelf(getType());
1586 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1587 return constFoldCastOp<FloatAttr, IntegerAttr>(
1588 adaptor.getOperands(), getType(),
1589 [&bitWidth](const APFloat &a, bool &castStatus) {
1590 bool ignored;
1591 APSInt api(bitWidth, /*isUnsigned=*/false);
1592 castStatus = APFloat::opInvalidOp !=
1593 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1594 return api;
1595 });
1596}
1597
1598//===----------------------------------------------------------------------===//
1599// IndexCastOp
1600//===----------------------------------------------------------------------===//
1601
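/// Index casts are only compatible between `index` and a signless integer
/// type, in either direction, including the vector/tensor/memref-wrapped
/// forms accepted by getTypeIfLikeOrMemRef.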
1602static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1603 if (!areValidCastInputsAndOutputs(inputs, outputs))
1604 return false;
1605
1606  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1607  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1608 if (!srcType || !dstType)
1609 return false;
1610
1611 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1612 (srcType.isSignlessInteger() && dstType.isIndex());
1613}
1614
1615bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1616 TypeRange outputs) {
1617 return areIndexCastCompatible(inputs, outputs);
1618}
1619
1620OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1621 // index_cast(constant) -> constant
1622 unsigned resultBitwidth = 64; // Default for index integer attributes.
1623 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1624 resultBitwidth = intTy.getWidth();
1625
1626 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1627 adaptor.getOperands(), getType(),
1628 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1629 return a.sextOrTrunc(resultBitwidth);
1630 });
1631}
1632
1633void arith::IndexCastOp::getCanonicalizationPatterns(
1634 RewritePatternSet &patterns, MLIRContext *context) {
1635 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1636}
1637
1638//===----------------------------------------------------------------------===//
1639// IndexCastUIOp
1640//===----------------------------------------------------------------------===//
1641
1642bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1643 TypeRange outputs) {
1644 return areIndexCastCompatible(inputs, outputs);
1645}
1646
1647OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1648 // index_castui(constant) -> constant
1649 unsigned resultBitwidth = 64; // Default for index integer attributes.
1650 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1651 resultBitwidth = intTy.getWidth();
1652
1653 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1654 adaptor.getOperands(), getType(),
1655 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1656 return a.zextOrTrunc(resultBitwidth);
1657 });
1658}
1659
1660void arith::IndexCastUIOp::getCanonicalizationPatterns(
1661 RewritePatternSet &patterns, MLIRContext *context) {
1662 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1663}
1664
1665//===----------------------------------------------------------------------===//
1666// BitcastOp
1667//===----------------------------------------------------------------------===//
1668
1669bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1670 if (!areValidCastInputsAndOutputs(inputs, outputs))
1671 return false;
1672
1673 auto srcType =
1674 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1675 auto dstType =
1676 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1677 if (!srcType || !dstType)
1678 return false;
1679
1680 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1681}
1682
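// Fold `arith.bitcast` of a constant by reinterpreting its bits in the result
// type. For example (illustrative), bitcasting the f32 constant 1.0 to i32
// yields the constant 1065353216 (0x3F800000).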
1683OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1684 auto resType = getType();
1685 auto operand = adaptor.getIn();
1686 if (!operand)
1687 return {};
1688
1689 /// Bitcast dense elements.
1690 if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
1691 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1692 /// Other shaped types unhandled.
1693 if (llvm::isa<ShapedType>(resType))
1694 return {};
1695
1696 /// Bitcast integer or float to integer or float.
1697 APInt bits = llvm::isa<FloatAttr>(operand)
1698 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1699 : llvm::cast<IntegerAttr>(operand).getValue();
1700
1701 if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
1702 return FloatAttr::get(resType,
1703 APFloat(resFloatType.getFloatSemantics(), bits));
1704 return IntegerAttr::get(resType, bits);
1705}
1706
1707void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1708 MLIRContext *context) {
1709 patterns.add<BitcastOfBitcast>(context);
1710}
1711
1712//===----------------------------------------------------------------------===//
1713// CmpIOp
1714//===----------------------------------------------------------------------===//
1715
1716/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1717/// comparison predicates.
1718bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1719 const APInt &lhs, const APInt &rhs) {
1720 switch (predicate) {
1721 case arith::CmpIPredicate::eq:
1722    return lhs.eq(rhs);
1723  case arith::CmpIPredicate::ne:
1724    return lhs.ne(rhs);
1725  case arith::CmpIPredicate::slt:
1726    return lhs.slt(rhs);
1727  case arith::CmpIPredicate::sle:
1728    return lhs.sle(rhs);
1729  case arith::CmpIPredicate::sgt:
1730    return lhs.sgt(rhs);
1731  case arith::CmpIPredicate::sge:
1732    return lhs.sge(rhs);
1733  case arith::CmpIPredicate::ult:
1734    return lhs.ult(rhs);
1735  case arith::CmpIPredicate::ule:
1736    return lhs.ule(rhs);
1737  case arith::CmpIPredicate::ugt:
1738    return lhs.ugt(rhs);
1739  case arith::CmpIPredicate::uge:
1740    return lhs.uge(rhs);
1741 }
1742 llvm_unreachable("unknown cmpi predicate kind");
1743}
1744
1745/// Returns true if the predicate is true for two equal operands.
1746static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1747 switch (predicate) {
1748 case arith::CmpIPredicate::eq:
1749 case arith::CmpIPredicate::sle:
1750 case arith::CmpIPredicate::sge:
1751 case arith::CmpIPredicate::ule:
1752 case arith::CmpIPredicate::uge:
1753 return true;
1754 case arith::CmpIPredicate::ne:
1755 case arith::CmpIPredicate::slt:
1756 case arith::CmpIPredicate::sgt:
1757 case arith::CmpIPredicate::ult:
1758 case arith::CmpIPredicate::ugt:
1759 return false;
1760 }
1761 llvm_unreachable("unknown cmpi predicate kind");
1762}
1763
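/// Returns the bit width of `t` if it is an integer type or a vector of
/// integers, and std::nullopt otherwise.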
1764static std::optional<int64_t> getIntegerWidth(Type t) {
1765 if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
1766 return intType.getWidth();
1767 }
1768 if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
1769 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1770 }
1771 return std::nullopt;
1772}
1773
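// Folds handled below: comparing a value with itself, comparing a zero- or
// sign-extended i1 against zero, moving a lone constant operand to the
// right-hand side (swapping the predicate direction accordingly), and full
// constant folding once both operands are constant.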
1774OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1775 // cmpi(pred, x, x)
1776 if (getLhs() == getRhs()) {
1777 auto val = applyCmpPredicateToEqualOperands(getPredicate());
1778 return getBoolAttribute(getType(), val);
1779 }
1780
1781 if (matchPattern(adaptor.getRhs(), m_Zero())) {
1782 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1783 // extsi(%x : i1 -> iN) != 0 -> %x
1784 std::optional<int64_t> integerWidth =
1785 getIntegerWidth(extOp.getOperand().getType());
1786 if (integerWidth && integerWidth.value() == 1 &&
1787 getPredicate() == arith::CmpIPredicate::ne)
1788 return extOp.getOperand();
1789 }
1790 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1791 // extui(%x : i1 -> iN) != 0 -> %x
1792 std::optional<int64_t> integerWidth =
1793 getIntegerWidth(extOp.getOperand().getType());
1794 if (integerWidth && integerWidth.value() == 1 &&
1795 getPredicate() == arith::CmpIPredicate::ne)
1796 return extOp.getOperand();
1797 }
1798 }
1799
1800 // Move constant to the right side.
1801 if (adaptor.getLhs() && !adaptor.getRhs()) {
1802 // Do not use invertPredicate, as it will change eq to ne and vice versa.
1803 using Pred = CmpIPredicate;
1804 const std::pair<Pred, Pred> invPreds[] = {
1805 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1806 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1807 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1808 {Pred::ne, Pred::ne},
1809 };
1810 Pred origPred = getPredicate();
1811 for (auto pred : invPreds) {
1812 if (origPred == pred.first) {
1813 setPredicate(pred.second);
1814 Value lhs = getLhs();
1815 Value rhs = getRhs();
1816 getLhsMutable().assign(rhs);
1817 getRhsMutable().assign(lhs);
1818 return getResult();
1819 }
1820 }
1821 llvm_unreachable("unknown cmpi predicate kind");
1822 }
1823
1824  // We move constants to the right side above, so if the lhs is a constant,
1825  // the rhs is guaranteed to be a constant as well.
1826 if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
1827 return constFoldBinaryOp<IntegerAttr>(
1828 adaptor.getOperands(), getI1SameShape(lhs.getType()),
1829 [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
1830 return APInt(1,
1831 static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
1832 });
1833 }
1834
1835 return {};
1836}
1837
1838void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1839 MLIRContext *context) {
1840 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1841}
1842
1843//===----------------------------------------------------------------------===//
1844// CmpFOp
1845//===----------------------------------------------------------------------===//
1846
1847/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1848/// comparison predicates.
1849bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1850 const APFloat &lhs, const APFloat &rhs) {
1851  auto cmpResult = lhs.compare(rhs);
1852 switch (predicate) {
1853 case arith::CmpFPredicate::AlwaysFalse:
1854 return false;
1855 case arith::CmpFPredicate::OEQ:
1856 return cmpResult == APFloat::cmpEqual;
1857 case arith::CmpFPredicate::OGT:
1858 return cmpResult == APFloat::cmpGreaterThan;
1859 case arith::CmpFPredicate::OGE:
1860 return cmpResult == APFloat::cmpGreaterThan ||
1861 cmpResult == APFloat::cmpEqual;
1862 case arith::CmpFPredicate::OLT:
1863 return cmpResult == APFloat::cmpLessThan;
1864 case arith::CmpFPredicate::OLE:
1865 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1866 case arith::CmpFPredicate::ONE:
1867 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1868 case arith::CmpFPredicate::ORD:
1869 return cmpResult != APFloat::cmpUnordered;
1870 case arith::CmpFPredicate::UEQ:
1871 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1872 case arith::CmpFPredicate::UGT:
1873 return cmpResult == APFloat::cmpUnordered ||
1874 cmpResult == APFloat::cmpGreaterThan;
1875 case arith::CmpFPredicate::UGE:
1876 return cmpResult == APFloat::cmpUnordered ||
1877 cmpResult == APFloat::cmpGreaterThan ||
1878 cmpResult == APFloat::cmpEqual;
1879 case arith::CmpFPredicate::ULT:
1880 return cmpResult == APFloat::cmpUnordered ||
1881 cmpResult == APFloat::cmpLessThan;
1882 case arith::CmpFPredicate::ULE:
1883 return cmpResult == APFloat::cmpUnordered ||
1884 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1885 case arith::CmpFPredicate::UNE:
1886 return cmpResult != APFloat::cmpEqual;
1887 case arith::CmpFPredicate::UNO:
1888 return cmpResult == APFloat::cmpUnordered;
1889 case arith::CmpFPredicate::AlwaysTrue:
1890 return true;
1891 }
1892 llvm_unreachable("unknown cmpf predicate kind");
1893}
1894
1895OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
1896 auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
1897 auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
1898
1899 // If one operand is NaN, making them both NaN does not change the result.
1900 if (lhs && lhs.getValue().isNaN())
1901 rhs = lhs;
1902 if (rhs && rhs.getValue().isNaN())
1903 lhs = rhs;
1904
1905 if (!lhs || !rhs)
1906 return {};
1907
1908 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1909 return BoolAttr::get(getContext(), val);
1910}
1911
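/// Rewrite a `cmpf` between an integer-to-float conversion and a float
/// constant into an equivalent `cmpi` on the original integer operand, as long
/// as neither the conversion nor the constant loses information. For example
/// (illustrative), cmpf(olt, sitofp(%x : i8), 4.0) is rewritten to
/// cmpi(slt, %x, 4).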
1912class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1913public:
1914 using OpRewritePattern<CmpFOp>::OpRewritePattern;
1915
1916 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1917 bool isUnsigned) {
1918 using namespace arith;
1919 switch (pred) {
1920 case CmpFPredicate::UEQ:
1921 case CmpFPredicate::OEQ:
1922 return CmpIPredicate::eq;
1923 case CmpFPredicate::UGT:
1924 case CmpFPredicate::OGT:
1925 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1926 case CmpFPredicate::UGE:
1927 case CmpFPredicate::OGE:
1928 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1929 case CmpFPredicate::ULT:
1930 case CmpFPredicate::OLT:
1931 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1932 case CmpFPredicate::ULE:
1933 case CmpFPredicate::OLE:
1934 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1935 case CmpFPredicate::UNE:
1936 case CmpFPredicate::ONE:
1937 return CmpIPredicate::ne;
1938 default:
1939 llvm_unreachable("Unexpected predicate!");
1940 }
1941 }
1942
1943 LogicalResult matchAndRewrite(CmpFOp op,
1944 PatternRewriter &rewriter) const override {
1945 FloatAttr flt;
1946 if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1947 return failure();
1948
1949 const APFloat &rhs = flt.getValue();
1950
1951 // Don't attempt to fold a nan.
1952 if (rhs.isNaN())
1953 return failure();
1954
1955 // Get the width of the mantissa. We don't want to hack on conversions that
1956 // might lose information from the integer, e.g. "i64 -> float"
1957 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
1958 int mantissaWidth = floatTy.getFPMantissaWidth();
1959 if (mantissaWidth <= 0)
1960 return failure();
1961
1962 bool isUnsigned;
1963 Value intVal;
1964
1965 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1966 isUnsigned = false;
1967 intVal = si.getIn();
1968 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1969 isUnsigned = true;
1970 intVal = ui.getIn();
1971 } else {
1972 return failure();
1973 }
1974
1975    // Check that the input is converted from an integer type that is small
1976    // enough for the conversion to preserve all of its bits.
1977 auto intTy = llvm::cast<IntegerType>(intVal.getType());
1978 auto intWidth = intTy.getWidth();
1979
1980 // Number of bits representing values, as opposed to the sign
1981 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1982
1983 // Following test does NOT adjust intWidth downwards for signed inputs,
1984 // because the most negative value still requires all the mantissa bits
1985 // to distinguish it from one less than that value.
1986 if ((int)intWidth > mantissaWidth) {
1987 // Conversion would lose accuracy. Check if loss can impact comparison.
1988      int exponent = ilogb(rhs);
1989      if (exponent == APFloat::IEK_Inf) {
1990        int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1991 if (maxExponent < (int)valueBits) {
1992 // Conversion could create infinity.
1993 return failure();
1994 }
1995 } else {
1996        // Note that if rhs is zero or NaN, then the exponent is negative
1997        // and the first condition is trivially false.
1998 if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1999 // Conversion could affect comparison.
2000 return failure();
2001 }
2002 }
2003 }
2004
2005 // Convert to equivalent cmpi predicate
2006 CmpIPredicate pred;
2007 switch (op.getPredicate()) {
2008 case CmpFPredicate::ORD:
2009 // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2010 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2011 /*width=*/1);
2012 return success();
2013 case CmpFPredicate::UNO:
2014 // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2015 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2016 /*width=*/1);
2017 return success();
2018 default:
2019 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2020 break;
2021 }
2022
2023 if (!isUnsigned) {
2024 // If the rhs value is > SignedMax, fold the comparison. This handles
2025 // +INF and large values.
2026 APFloat signedMax(rhs.getSemantics());
2027      signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2028                                 APFloat::rmNearestTiesToEven);
2029 if (signedMax < rhs) { // smax < 13123.0
2030 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2031 pred == CmpIPredicate::sle)
2032 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2033 /*width=*/1);
2034 else
2035 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2036 /*width=*/1);
2037 return success();
2038 }
2039 } else {
2040 // If the rhs value is > UnsignedMax, fold the comparison. This handles
2041 // +INF and large values.
2042 APFloat unsignedMax(rhs.getSemantics());
2043      unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2044                                   APFloat::rmNearestTiesToEven);
2045 if (unsignedMax < rhs) { // umax < 13123.0
2046 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2047 pred == CmpIPredicate::ule)
2048 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2049 /*width=*/1);
2050 else
2051 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2052 /*width=*/1);
2053 return success();
2054 }
2055 }
2056
2057 if (!isUnsigned) {
2058 // See if the rhs value is < SignedMin.
2059 APFloat signedMin(rhs.getSemantics());
2060      signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2061                                 APFloat::rmNearestTiesToEven);
2062 if (signedMin > rhs) { // smin > 12312.0
2063 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2064 pred == CmpIPredicate::sge)
2065 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2066 /*width=*/1);
2067 else
2068 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2069 /*width=*/1);
2070 return success();
2071 }
2072 } else {
2073 // See if the rhs value is < UnsignedMin.
2074 APFloat unsignedMin(rhs.getSemantics());
2075      unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2076                                   APFloat::rmNearestTiesToEven);
2077 if (unsignedMin > rhs) { // umin > 12312.0
2078 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2079 pred == CmpIPredicate::uge)
2080 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2081 /*width=*/1);
2082 else
2083 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2084 /*width=*/1);
2085 return success();
2086 }
2087 }
2088
2089 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2090 // [0, UMAX], but it may still be fractional. See if it is fractional by
2091 // casting the FP value to the integer value and back, checking for
2092 // equality. Don't do this for zero, because -0.0 is not fractional.
2093 bool ignored;
2094 APSInt rhsInt(intWidth, isUnsigned);
2095 if (APFloat::opInvalidOp ==
2096        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2097 // Undefined behavior invoked - the destination type can't represent
2098 // the input constant.
2099 return failure();
2100 }
2101
2102 if (!rhs.isZero()) {
2103 APFloat apf(floatTy.getFloatSemantics(),
2104                  APInt::getZero(floatTy.getWidth()));
2105      apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2106
2107 bool equal = apf == rhs;
2108 if (!equal) {
2109 // If we had a comparison against a fractional value, we have to adjust
2110 // the compare predicate and sometimes the value. rhsInt is rounded
2111 // towards zero at this point.
2112 switch (pred) {
2113 case CmpIPredicate::ne: // (float)int != 4.4 --> true
2114 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2115 /*width=*/1);
2116 return success();
2117 case CmpIPredicate::eq: // (float)int == 4.4 --> false
2118 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2119 /*width=*/1);
2120 return success();
2121 case CmpIPredicate::ule:
2122 // (float)int <= 4.4 --> int <= 4
2123 // (float)int <= -4.4 --> false
2124 if (rhs.isNegative()) {
2125 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2126 /*width=*/1);
2127 return success();
2128 }
2129 break;
2130 case CmpIPredicate::sle:
2131 // (float)int <= 4.4 --> int <= 4
2132 // (float)int <= -4.4 --> int < -4
2133 if (rhs.isNegative())
2134 pred = CmpIPredicate::slt;
2135 break;
2136 case CmpIPredicate::ult:
2137 // (float)int < -4.4 --> false
2138 // (float)int < 4.4 --> int <= 4
2139 if (rhs.isNegative()) {
2140 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2141 /*width=*/1);
2142 return success();
2143 }
2144 pred = CmpIPredicate::ule;
2145 break;
2146 case CmpIPredicate::slt:
2147 // (float)int < -4.4 --> int < -4
2148 // (float)int < 4.4 --> int <= 4
2149 if (!rhs.isNegative())
2150 pred = CmpIPredicate::sle;
2151 break;
2152 case CmpIPredicate::ugt:
2153 // (float)int > 4.4 --> int > 4
2154 // (float)int > -4.4 --> true
2155 if (rhs.isNegative()) {
2156 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2157 /*width=*/1);
2158 return success();
2159 }
2160 break;
2161 case CmpIPredicate::sgt:
2162 // (float)int > 4.4 --> int > 4
2163 // (float)int > -4.4 --> int >= -4
2164 if (rhs.isNegative())
2165 pred = CmpIPredicate::sge;
2166 break;
2167 case CmpIPredicate::uge:
2168 // (float)int >= -4.4 --> true
2169 // (float)int >= 4.4 --> int > 4
2170 if (rhs.isNegative()) {
2171 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2172 /*width=*/1);
2173 return success();
2174 }
2175 pred = CmpIPredicate::ugt;
2176 break;
2177 case CmpIPredicate::sge:
2178 // (float)int >= -4.4 --> int >= -4
2179 // (float)int >= 4.4 --> int > 4
2180 if (!rhs.isNegative())
2181 pred = CmpIPredicate::sgt;
2182 break;
2183 }
2184 }
2185 }
2186
2187 // Lower this FP comparison into an appropriate integer version of the
2188 // comparison.
2189 rewriter.replaceOpWithNewOp<CmpIOp>(
2190 op, pred, intVal,
2191 rewriter.create<ConstantOp>(
2192 op.getLoc(), intVal.getType(),
2193 rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2194 return success();
2195 }
2196};
2197
2198void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2199 MLIRContext *context) {
2200 patterns.insert<CmpFIntToFPConst>(context);
2201}
2202
2203//===----------------------------------------------------------------------===//
2204// SelectOp
2205//===----------------------------------------------------------------------===//
2206
2207// select %arg, %c1, %c0 => extui %arg
2208struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2209 using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
2210
2211 LogicalResult matchAndRewrite(arith::SelectOp op,
2212 PatternRewriter &rewriter) const override {
2213 // Cannot extui i1 to i1, or i1 to f32
2214 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2215 return failure();
2216
2217    // select %x, %c1, %c0 => extui %x
2218 if (matchPattern(op.getTrueValue(), m_One()) &&
2219 matchPattern(op.getFalseValue(), m_Zero())) {
2220 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2221 op.getCondition());
2222 return success();
2223 }
2224
2225    // select %x, %c0, %c1 => extui (xor %x, true)
2226 if (matchPattern(op.getTrueValue(), m_Zero()) &&
2227 matchPattern(op.getFalseValue(), m_One())) {
2228 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2229 op, op.getType(),
2230 rewriter.create<arith::XOrIOp>(
2231 op.getLoc(), op.getCondition(),
2232 rewriter.create<arith::ConstantIntOp>(
2233 op.getLoc(), 1, op.getCondition().getType())));
2234 return success();
2235 }
2236
2237 return failure();
2238 }
2239};
2240
2241void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2242 MLIRContext *context) {
2243 results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2244 SelectI1ToNot, SelectToExtUI>(context);
2245}
2246
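// Folds handled below: identical true/false values, a constant i1 condition,
// a poison operand (which yields the other operand), an i1 select of the
// constants true/false (which yields the condition itself), a condition that
// is a cmpi eq/ne of the two selected values, and elementwise folding over a
// non-splat constant condition.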
2247OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2248 Value trueVal = getTrueValue();
2249 Value falseVal = getFalseValue();
2250 if (trueVal == falseVal)
2251 return trueVal;
2252
2253 Value condition = getCondition();
2254
2255 // select true, %0, %1 => %0
2256 if (matchPattern(adaptor.getCondition(), m_One()))
2257 return trueVal;
2258
2259 // select false, %0, %1 => %1
2260 if (matchPattern(adaptor.getCondition(), m_Zero()))
2261 return falseVal;
2262
2263 // If either operand is fully poisoned, return the other.
2264 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2265 return falseVal;
2266
2267 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2268 return trueVal;
2269
2270 // select %x, true, false => %x
2271 if (getType().isInteger(1) && matchPattern(adaptor.getTrueValue(), m_One()) &&
2272 matchPattern(adaptor.getFalseValue(), m_Zero()))
2273 return condition;
2274
2275 if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
2276 auto pred = cmp.getPredicate();
2277 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2278 auto cmpLhs = cmp.getLhs();
2279 auto cmpRhs = cmp.getRhs();
2280
2281 // %0 = arith.cmpi eq, %arg0, %arg1
2282 // %1 = arith.select %0, %arg0, %arg1 => %arg1
2283
2284 // %0 = arith.cmpi ne, %arg0, %arg1
2285 // %1 = arith.select %0, %arg0, %arg1 => %arg0
2286
2287 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2288 (cmpRhs == trueVal && cmpLhs == falseVal))
2289 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2290 }
2291 }
2292
2293 // Constant-fold constant operands over non-splat constant condition.
2294 // select %cst_vec, %cst0, %cst1 => %cst2
2295 if (auto cond =
2296 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2297 if (auto lhs =
2298 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2299 if (auto rhs =
2300 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2301 SmallVector<Attribute> results;
2302 results.reserve(static_cast<size_t>(cond.getNumElements()));
2303 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2304 cond.value_end<BoolAttr>());
2305 auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2306 lhs.value_end<Attribute>());
2307 auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2308 rhs.value_end<Attribute>());
2309
2310 for (auto [condVal, lhsVal, rhsVal] :
2311 llvm::zip_equal(condVals, lhsVals, rhsVals))
2312 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2313
2314 return DenseElementsAttr::get(lhs.getType(), results);
2315 }
2316 }
2317 }
2318
2319 return nullptr;
2320}
2321
2322ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2323 Type conditionType, resultType;
2324 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2325 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2326 parser.parseOptionalAttrDict(result.attributes) ||
2327 parser.parseColonType(resultType))
2328 return failure();
2329
2330 // Check for the explicit condition type if this is a masked tensor or vector.
2331 if (succeeded(parser.parseOptionalComma())) {
2332 conditionType = resultType;
2333 if (parser.parseType(resultType))
2334 return failure();
2335 } else {
2336 conditionType = parser.getBuilder().getI1Type();
2337 }
2338
2339 result.addTypes(resultType);
2340 return parser.resolveOperands(operands,
2341 {conditionType, resultType, resultType},
2342 parser.getNameLoc(), result.operands);
2343}
2344
2345void arith::SelectOp::print(OpAsmPrinter &p) {
2346 p << " " << getOperands();
2347 p.printOptionalAttrDict((*this)->getAttrs());
2348 p << " : ";
2349 if (ShapedType condType =
2350 llvm::dyn_cast<ShapedType>(getCondition().getType()))
2351 p << condType << ", ";
2352 p << getType();
2353}
2354
2355LogicalResult arith::SelectOp::verify() {
2356 Type conditionType = getCondition().getType();
2357 if (conditionType.isSignlessInteger(1))
2358 return success();
2359
2360 // If the result type is a vector or tensor, the type can be a mask with the
2361 // same elements.
2362 Type resultType = getType();
2363 if (!llvm::isa<TensorType, VectorType>(resultType))
2364 return emitOpError() << "expected condition to be a signless i1, but got "
2365 << conditionType;
2366 Type shapedConditionType = getI1SameShape(resultType);
2367 if (conditionType != shapedConditionType) {
2368 return emitOpError() << "expected condition type to have the same shape "
2369 "as the result type, expected "
2370 << shapedConditionType << ", but got "
2371 << conditionType;
2372 }
2373 return success();
2374}
2375//===----------------------------------------------------------------------===//
2376// ShLIOp
2377//===----------------------------------------------------------------------===//
2378
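// For example (illustrative), shli(5 : i8, 2) folds to 20 : i8, while a shift
// amount of 8 or more on an i8 value leaves the op untouched.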
2379OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2380 // shli(x, 0) -> x
2381 if (matchPattern(adaptor.getRhs(), m_Zero()))
2382 return getLhs();
2383  // Don't fold if the shift amount is greater than or equal to the bit width.
2384 bool bounded = false;
2385 auto result = constFoldBinaryOp<IntegerAttr>(
2386 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2387 bounded = b.ult(b.getBitWidth());
2388 return a.shl(b);
2389 });
2390 return bounded ? result : Attribute();
2391}
2392
2393//===----------------------------------------------------------------------===//
2394// ShRUIOp
2395//===----------------------------------------------------------------------===//
2396
2397OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2398 // shrui(x, 0) -> x
2399 if (matchPattern(adaptor.getRhs(), m_Zero()))
2400 return getLhs();
2401  // Don't fold if the shift amount is greater than or equal to the bit width.
2402 bool bounded = false;
2403 auto result = constFoldBinaryOp<IntegerAttr>(
2404 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2405 bounded = b.ult(b.getBitWidth());
2406 return a.lshr(b);
2407 });
2408 return bounded ? result : Attribute();
2409}
2410
2411//===----------------------------------------------------------------------===//
2412// ShRSIOp
2413//===----------------------------------------------------------------------===//
2414
2415OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2416 // shrsi(x, 0) -> x
2417 if (matchPattern(adaptor.getRhs(), m_Zero()))
2418 return getLhs();
2419  // Don't fold if the shift amount is greater than or equal to the bit width.
2420 bool bounded = false;
2421 auto result = constFoldBinaryOp<IntegerAttr>(
2422 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2423 bounded = b.ult(b.getBitWidth());
2424 return a.ashr(b);
2425 });
2426 return bounded ? result : Attribute();
2427}
2428
2429//===----------------------------------------------------------------------===//
2430// Atomic Enum
2431//===----------------------------------------------------------------------===//
2432
2433/// Returns the identity value attribute associated with an AtomicRMWKind op.
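/// For instance (illustrative), the identity for addf/addi is 0, for
/// mulf/muli it is 1, and for maximumf it is -infinity (or the most negative
/// finite value when `useOnlyFiniteValue` is set).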
2434TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2435 OpBuilder &builder, Location loc,
2436 bool useOnlyFiniteValue) {
2437 switch (kind) {
2438 case AtomicRMWKind::maximumf: {
2439 const llvm::fltSemantics &semantic =
2440        llvm::cast<FloatType>(resultType).getFloatSemantics();
2441 APFloat identity = useOnlyFiniteValue
2442                           ? APFloat::getLargest(semantic, /*Negative=*/true)
2443                           : APFloat::getInf(semantic, /*Negative=*/true);
2444 return builder.getFloatAttr(resultType, identity);
2445 }
2446 case AtomicRMWKind::addf:
2447 case AtomicRMWKind::addi:
2448 case AtomicRMWKind::maxu:
2449 case AtomicRMWKind::ori:
2450 return builder.getZeroAttr(resultType);
2451 case AtomicRMWKind::andi:
2452 return builder.getIntegerAttr(
2453 resultType,
2454        APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2455 case AtomicRMWKind::maxs:
2456 return builder.getIntegerAttr(
2457 resultType, APInt::getSignedMinValue(
2458                        llvm::cast<IntegerType>(resultType).getWidth()));
2459 case AtomicRMWKind::minimumf: {
2460 const llvm::fltSemantics &semantic =
2461        llvm::cast<FloatType>(resultType).getFloatSemantics();
2462 APFloat identity = useOnlyFiniteValue
2463                           ? APFloat::getLargest(semantic, /*Negative=*/false)
2464                           : APFloat::getInf(semantic, /*Negative=*/false);
2465
2466 return builder.getFloatAttr(resultType, identity);
2467 }
2468 case AtomicRMWKind::mins:
2469 return builder.getIntegerAttr(
2470 resultType, APInt::getSignedMaxValue(
2471                        llvm::cast<IntegerType>(resultType).getWidth()));
2472 case AtomicRMWKind::minu:
2473 return builder.getIntegerAttr(
2474 resultType,
2475        APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2476 case AtomicRMWKind::muli:
2477 return builder.getIntegerAttr(resultType, 1);
2478 case AtomicRMWKind::mulf:
2479 return builder.getFloatAttr(resultType, 1);
2480 // TODO: Add remaining reduction operations.
2481 default:
2482    (void)emitOptionalError(loc, "Reduction operation type not supported");
2483 break;
2484 }
2485 return nullptr;
2486}
2487
2488/// Return the identity numeric value associated with the given op.
2489std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2490 std::optional<AtomicRMWKind> maybeKind =
2491 llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
2492 // Floating-point operations.
2493 .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2494 .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2495 .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2496 .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2497 // Integer operations.
2498 .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2499 .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2500 .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2501 .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2502 .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2503 .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2504 .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2505 .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2506 .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2507 .Default([](Operation *op) { return std::nullopt; });
2508 if (!maybeKind) {
2509 op->emitError() << "Unknown neutral element for: " << *op;
2510 return std::nullopt;
2511 }
2512
2513 bool useOnlyFiniteValue = false;
2514 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2515 if (fmfOpInterface) {
2516 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2517 useOnlyFiniteValue =
2518 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2519 }
2520
2521 // Builder only used as helper for attribute creation.
2522 OpBuilder b(op->getContext());
2523  Type resultType = op->getResult(0).getType();
2524
2525 return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2526 useOnlyFiniteValue);
2527}
2528
2529/// Returns the identity value associated with an AtomicRMWKind op.
2530Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2531 OpBuilder &builder, Location loc,
2532 bool useOnlyFiniteValue) {
2533 auto attr =
2534 getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2535 return builder.create<arith::ConstantOp>(loc, attr);
2536}
2537
2538/// Return the value obtained by applying the reduction operation kind
2539/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
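/// For example (illustrative), calling this with AtomicRMWKind::addf builds
/// and returns an `arith.addf lhs, rhs` operation at `loc`.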
2540Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2541 Location loc, Value lhs, Value rhs) {
2542 switch (op) {
2543 case AtomicRMWKind::addf:
2544 return builder.create<arith::AddFOp>(loc, lhs, rhs);
2545 case AtomicRMWKind::addi:
2546 return builder.create<arith::AddIOp>(loc, lhs, rhs);
2547 case AtomicRMWKind::mulf:
2548 return builder.create<arith::MulFOp>(loc, lhs, rhs);
2549 case AtomicRMWKind::muli:
2550 return builder.create<arith::MulIOp>(loc, lhs, rhs);
2551 case AtomicRMWKind::maximumf:
2552 return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
2553 case AtomicRMWKind::minimumf:
2554 return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2555 case AtomicRMWKind::maxnumf:
2556 return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2557 case AtomicRMWKind::minnumf:
2558 return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
2559 case AtomicRMWKind::maxs:
2560 return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2561 case AtomicRMWKind::mins:
2562 return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2563 case AtomicRMWKind::maxu:
2564 return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2565 case AtomicRMWKind::minu:
2566 return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2567 case AtomicRMWKind::ori:
2568 return builder.create<arith::OrIOp>(loc, lhs, rhs);
2569 case AtomicRMWKind::andi:
2570 return builder.create<arith::AndIOp>(loc, lhs, rhs);
2571 // TODO: Add remaining reduction operations.
2572 default:
2573    (void)emitOptionalError(loc, "Reduction operation type not supported");
2574 break;
2575 }
2576 return nullptr;
2577}
2578
2579//===----------------------------------------------------------------------===//
2580// TableGen'd op method definitions
2581//===----------------------------------------------------------------------===//
2582
2583#define GET_OP_CLASSES
2584#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2585
2586//===----------------------------------------------------------------------===//
2587// TableGen'd enum attribute definitions
2588//===----------------------------------------------------------------------===//
2589
2590#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
2591
