1//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cassert>
10#include <cstdint>
11#include <functional>
12#include <utility>
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/CommonFolders.h"
16#include "mlir/Dialect/UB/IR/UBOps.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/BuiltinAttributeInterfaces.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/OpImplementation.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/IR/TypeUtilities.h"
24#include "mlir/Support/LogicalResult.h"
25
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34using namespace mlir;
35using namespace mlir::arith;
36
37//===----------------------------------------------------------------------===//
38// Pattern helpers
39//===----------------------------------------------------------------------===//
40
41static IntegerAttr
42applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
43 Attribute rhs,
44 function_ref<APInt(const APInt &, const APInt &)> binFn) {
45 APInt lhsVal = llvm::cast<IntegerAttr>(Val&: lhs).getValue();
46 APInt rhsVal = llvm::cast<IntegerAttr>(Val&: rhs).getValue();
47 APInt value = binFn(lhsVal, rhsVal);
48 return IntegerAttr::get(type: res.getType(), value);
49}
50
51static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
52 Attribute lhs, Attribute rhs) {
53 return applyToIntegerAttrs(builder, res, lhs, rhs, binFn: std::plus<APInt>());
54}
55
56static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
57 Attribute lhs, Attribute rhs) {
58 return applyToIntegerAttrs(builder, res, lhs, rhs, binFn: std::minus<APInt>());
59}
60
61static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
62 Attribute lhs, Attribute rhs) {
63 return applyToIntegerAttrs(builder, res, lhs, rhs, binFn: std::multiplies<APInt>());
64}
65
66// Merge overflow flags from 2 ops, selecting the most conservative combination.
67static IntegerOverflowFlagsAttr
68mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
69 IntegerOverflowFlagsAttr val2) {
70 return IntegerOverflowFlagsAttr::get(context: val1.getContext(),
71 value: val1.getValue() & val2.getValue());
72}
73
74/// Invert an integer comparison predicate.
75arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
76 switch (pred) {
77 case arith::CmpIPredicate::eq:
78 return arith::CmpIPredicate::ne;
79 case arith::CmpIPredicate::ne:
80 return arith::CmpIPredicate::eq;
81 case arith::CmpIPredicate::slt:
82 return arith::CmpIPredicate::sge;
83 case arith::CmpIPredicate::sle:
84 return arith::CmpIPredicate::sgt;
85 case arith::CmpIPredicate::sgt:
86 return arith::CmpIPredicate::sle;
87 case arith::CmpIPredicate::sge:
88 return arith::CmpIPredicate::slt;
89 case arith::CmpIPredicate::ult:
90 return arith::CmpIPredicate::uge;
91 case arith::CmpIPredicate::ule:
92 return arith::CmpIPredicate::ugt;
93 case arith::CmpIPredicate::ugt:
94 return arith::CmpIPredicate::ule;
95 case arith::CmpIPredicate::uge:
96 return arith::CmpIPredicate::ult;
97 }
98 llvm_unreachable("unknown cmpi predicate kind");
99}
100
101/// Equivalent to
102/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
103///
104/// Not possible to implement as chain of calls as this would introduce a
105/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
106/// on the LLVM dialect and on translation to LLVM.
107static llvm::RoundingMode
108convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
109 switch (roundingMode) {
110 case RoundingMode::downward:
111 return llvm::RoundingMode::TowardNegative;
112 case RoundingMode::to_nearest_away:
113 return llvm::RoundingMode::NearestTiesToAway;
114 case RoundingMode::to_nearest_even:
115 return llvm::RoundingMode::NearestTiesToEven;
116 case RoundingMode::toward_zero:
117 return llvm::RoundingMode::TowardZero;
118 case RoundingMode::upward:
119 return llvm::RoundingMode::TowardPositive;
120 }
121 llvm_unreachable("Unhandled rounding mode");
122}
123
124static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
125 return arith::CmpIPredicateAttr::get(context: pred.getContext(),
126 val: invertPredicate(pred: pred.getValue()));
127}
128
129static int64_t getScalarOrElementWidth(Type type) {
130 Type elemTy = getElementTypeOrSelf(type);
131 if (elemTy.isIntOrFloat())
132 return elemTy.getIntOrFloatBitWidth();
133
134 return -1;
135}
136
137static int64_t getScalarOrElementWidth(Value value) {
138 return getScalarOrElementWidth(type: value.getType());
139}
140
141static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
142 APInt value;
143 if (matchPattern(attr, pattern: m_ConstantInt(bind_value: &value)))
144 return value;
145
146 return failure();
147}
148
149static Attribute getBoolAttribute(Type type, bool value) {
150 auto boolAttr = BoolAttr::get(context: type.getContext(), value);
151 ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(Val&: type);
152 if (!shapedType)
153 return boolAttr;
154 return DenseElementsAttr::get(type: shapedType, values: boolAttr);
155}
156
157//===----------------------------------------------------------------------===//
158// TableGen'd canonicalization patterns
159//===----------------------------------------------------------------------===//
160
161namespace {
162#include "ArithCanonicalization.inc"
163} // namespace
164
165//===----------------------------------------------------------------------===//
166// Common helpers
167//===----------------------------------------------------------------------===//
168
169/// Return the type of the same shape (scalar, vector or tensor) containing i1.
170static Type getI1SameShape(Type type) {
171 auto i1Type = IntegerType::get(context: type.getContext(), width: 1);
172 if (auto shapedType = llvm::dyn_cast<ShapedType>(Val&: type))
173 return shapedType.cloneWith(shape: std::nullopt, elementType: i1Type);
174 if (llvm::isa<UnrankedTensorType>(Val: type))
175 return UnrankedTensorType::get(elementType: i1Type);
176 return i1Type;
177}
178
179//===----------------------------------------------------------------------===//
180// ConstantOp
181//===----------------------------------------------------------------------===//
182
183void arith::ConstantOp::getAsmResultNames(
184 function_ref<void(Value, StringRef)> setNameFn) {
185 auto type = getType();
186 if (auto intCst = llvm::dyn_cast<IntegerAttr>(Val: getValue())) {
187 auto intType = llvm::dyn_cast<IntegerType>(Val&: type);
188
189 // Sugar i1 constants with 'true' and 'false'.
190 if (intType && intType.getWidth() == 1)
191 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
192
193 // Otherwise, build a complex name with the value and type.
194 SmallString<32> specialNameBuffer;
195 llvm::raw_svector_ostream specialName(specialNameBuffer);
196 specialName << 'c' << intCst.getValue();
197 if (intType)
198 specialName << '_' << type;
199 setNameFn(getResult(), specialName.str());
200 } else {
201 setNameFn(getResult(), "cst");
202 }
203}
204
205/// TODO: disallow arith.constant to return anything other than signless integer
206/// or float like.
207LogicalResult arith::ConstantOp::verify() {
208 auto type = getType();
209 // Integer values must be signless.
210 if (llvm::isa<IntegerType>(Val: type) &&
211 !llvm::cast<IntegerType>(Val&: type).isSignless())
212 return emitOpError(message: "integer return type must be signless");
213 // Any float or elements attribute are acceptable.
214 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(Val: getValue())) {
215 return emitOpError(
216 message: "value must be an integer, float, or elements attribute");
217 }
218
219 // Note, we could relax this for vectors with 1 scalable dim, e.g.:
220 // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
221 // However, this would most likely require updating the lowerings to LLVM.
222 if (isa<ScalableVectorType>(Val: type) && !isa<SplatElementsAttr>(Val: getValue()))
223 return emitOpError(
224 message: "intializing scalable vectors with elements attribute is not supported"
225 " unless it's a vector splat");
226 return success();
227}
228
229bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
230 // The value's type must be the same as the provided type.
231 auto typedAttr = llvm::dyn_cast<TypedAttr>(Val&: value);
232 if (!typedAttr || typedAttr.getType() != type)
233 return false;
234 // Integer values must be signless.
235 if (llvm::isa<IntegerType>(Val: type) &&
236 !llvm::cast<IntegerType>(Val&: type).isSignless())
237 return false;
238 // Integer, float, and element attributes are buildable.
239 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(Val: value);
240}
241
242ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
243 Type type, Location loc) {
244 if (isBuildableWith(value, type))
245 return builder.create<arith::ConstantOp>(location: loc, args: cast<TypedAttr>(Val&: value));
246 return nullptr;
247}
248
249OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
250
251void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
252 int64_t value, unsigned width) {
253 auto type = builder.getIntegerType(width);
254 arith::ConstantOp::build(odsBuilder&: builder, odsState&: result, result: type,
255 value: builder.getIntegerAttr(type, value));
256}
257
258void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
259 Type type, int64_t value) {
260 arith::ConstantOp::build(odsBuilder&: builder, odsState&: result, result: type,
261 value: builder.getIntegerAttr(type, value));
262}
263
264void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
265 Type type, const APInt &value) {
266 arith::ConstantOp::build(odsBuilder&: builder, odsState&: result, result: type,
267 value: builder.getIntegerAttr(type, value));
268}
269
270bool arith::ConstantIntOp::classof(Operation *op) {
271 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(Val: op))
272 return constOp.getType().isSignlessInteger();
273 return false;
274}
275
276void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
277 FloatType type, const APFloat &value) {
278 arith::ConstantOp::build(odsBuilder&: builder, odsState&: result, result: type,
279 value: builder.getFloatAttr(type, value));
280}
281
282bool arith::ConstantFloatOp::classof(Operation *op) {
283 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(Val: op))
284 return llvm::isa<FloatType>(Val: constOp.getType());
285 return false;
286}
287
288void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
289 int64_t value) {
290 arith::ConstantOp::build(odsBuilder&: builder, odsState&: result, result: builder.getIndexType(),
291 value: builder.getIndexAttr(value));
292}
293
294bool arith::ConstantIndexOp::classof(Operation *op) {
295 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(Val: op))
296 return constOp.getType().isIndex();
297 return false;
298}
299
300Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
301 Type type) {
302 // TODO: Incorporate this check to `FloatAttr::get*`.
303 assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
304 "type doesn't have a zero representation");
305 TypedAttr zeroAttr = builder.getZeroAttr(type);
306 assert(zeroAttr && "unsupported type for zero attribute");
307 return builder.create<arith::ConstantOp>(location: loc, args&: zeroAttr);
308}
309
310//===----------------------------------------------------------------------===//
311// AddIOp
312//===----------------------------------------------------------------------===//
313
314OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
315 // addi(x, 0) -> x
316 if (matchPattern(attr: adaptor.getRhs(), pattern: m_Zero()))
317 return getLhs();
318
319 // addi(subi(a, b), b) -> a
320 if (auto sub = getLhs().getDefiningOp<SubIOp>())
321 if (getRhs() == sub.getRhs())
322 return sub.getLhs();
323
324 // addi(b, subi(a, b)) -> a
325 if (auto sub = getRhs().getDefiningOp<SubIOp>())
326 if (getLhs() == sub.getRhs())
327 return sub.getLhs();
328
329 return constFoldBinaryOp<IntegerAttr>(
330 operands: adaptor.getOperands(),
331 calculate: [](APInt a, const APInt &b) { return std::move(a) + b; });
332}
333
334void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
335 MLIRContext *context) {
336 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
337 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(arg&: context);
338}
339
340//===----------------------------------------------------------------------===//
341// AddUIExtendedOp
342//===----------------------------------------------------------------------===//
343
344std::optional<SmallVector<int64_t, 4>>
345arith::AddUIExtendedOp::getShapeForUnroll() {
346 if (auto vt = llvm::dyn_cast<VectorType>(Val: getType(i: 0)))
347 return llvm::to_vector<4>(Range: vt.getShape());
348 return std::nullopt;
349}
350
351// Returns the overflow bit, assuming that `sum` is the result of unsigned
352// addition of `operand` and another number.
353static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
354 return sum.ult(RHS: operand) ? APInt::getAllOnes(numBits: 1) : APInt::getZero(numBits: 1);
355}
356
/// Fold hook for `arith.addui_extended`, which produces two results: the
/// wrapped sum and an overflow bit of i1-shaped type. On success, pushes one
/// OpFoldResult per result into `results`.
LogicalResult
arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  Type overflowTy = getOverflow().getType();
  // addui_extended(x, 0) -> x, false
  if (matchPattern(value: getRhs(), pattern: m_Zero())) {
    Builder builder(getContext());
    // The zero attribute of the overflow type encodes `false`.
    auto falseValue = builder.getZeroAttr(type: overflowTy);

    results.push_back(Elt: getLhs());
    results.push_back(Elt: falseValue);
    return success();
  }

  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
  // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
  // operands. If that succeeds, calculate the overflow bit based on the sum
  // and the first (constant) operand, `lhs`.
  if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
          operands: adaptor.getOperands(),
          calculate: [](APInt a, const APInt &b) { return std::move(a) + b; })) {
    // The carry is derived elementwise from (sum < lhs), folded into an
    // i1-shaped attribute matching the sum's type.
    Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
        operands: ArrayRef({sumAttr, adaptor.getLhs()}),
        resultType: getI1SameShape(type: llvm::cast<TypedAttr>(Val&: sumAttr).getType()),
        calculate&: calculateUnsignedOverflow);
    if (!overflowAttr)
      return failure();

    results.push_back(Elt: sumAttr);
    results.push_back(Elt: overflowAttr);
    return success();
  }

  return failure();
}
392
393void arith::AddUIExtendedOp::getCanonicalizationPatterns(
394 RewritePatternSet &patterns, MLIRContext *context) {
395 patterns.add<AddUIExtendedToAddI>(arg&: context);
396}
397
398//===----------------------------------------------------------------------===//
399// SubIOp
400//===----------------------------------------------------------------------===//
401
402OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
403 // subi(x,x) -> 0
404 if (getOperand(i: 0) == getOperand(i: 1)) {
405 auto shapedType = dyn_cast<ShapedType>(Val: getType());
406 // We can't generate a constant with a dynamic shaped tensor.
407 if (!shapedType || shapedType.hasStaticShape())
408 return Builder(getContext()).getZeroAttr(type: getType());
409 }
410 // subi(x,0) -> x
411 if (matchPattern(attr: adaptor.getRhs(), pattern: m_Zero()))
412 return getLhs();
413
414 if (auto add = getLhs().getDefiningOp<AddIOp>()) {
415 // subi(addi(a, b), b) -> a
416 if (getRhs() == add.getRhs())
417 return add.getLhs();
418 // subi(addi(a, b), a) -> b
419 if (getRhs() == add.getLhs())
420 return add.getRhs();
421 }
422
423 return constFoldBinaryOp<IntegerAttr>(
424 operands: adaptor.getOperands(),
425 calculate: [](APInt a, const APInt &b) { return std::move(a) - b; });
426}
427
428void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
429 MLIRContext *context) {
430 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
431 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
432 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(arg&: context);
433}
434
435//===----------------------------------------------------------------------===//
436// MulIOp
437//===----------------------------------------------------------------------===//
438
439OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
440 // muli(x, 0) -> 0
441 if (matchPattern(attr: adaptor.getRhs(), pattern: m_Zero()))
442 return getRhs();
443 // muli(x, 1) -> x
444 if (matchPattern(attr: adaptor.getRhs(), pattern: m_One()))
445 return getLhs();
446 // TODO: Handle the overflow case.
447
448 // default folder
449 return constFoldBinaryOp<IntegerAttr>(
450 operands: adaptor.getOperands(),
451 calculate: [](const APInt &a, const APInt &b) { return a * b; });
452}
453
454void arith::MulIOp::getAsmResultNames(
455 function_ref<void(Value, StringRef)> setNameFn) {
456 if (!isa<IndexType>(Val: getType()))
457 return;
458
459 // Match vector.vscale by name to avoid depending on the vector dialect (which
460 // is a circular dependency).
461 auto isVscale = [](Operation *op) {
462 return op && op->getName().getStringRef() == "vector.vscale";
463 };
464
465 IntegerAttr baseValue;
466 auto isVscaleExpr = [&](Value a, Value b) {
467 return matchPattern(value: a, pattern: m_Constant(bind_value: &baseValue)) &&
468 isVscale(b.getDefiningOp());
469 };
470
471 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
472 return;
473
474 // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
475 SmallString<32> specialNameBuffer;
476 llvm::raw_svector_ostream specialName(specialNameBuffer);
477 specialName << 'c' << baseValue.getInt() << "_vscale";
478 setNameFn(getResult(), specialName.str());
479}
480
481void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
482 MLIRContext *context) {
483 patterns.add<MulIMulIConstant>(arg&: context);
484}
485
486//===----------------------------------------------------------------------===//
487// MulSIExtendedOp
488//===----------------------------------------------------------------------===//
489
490std::optional<SmallVector<int64_t, 4>>
491arith::MulSIExtendedOp::getShapeForUnroll() {
492 if (auto vt = llvm::dyn_cast<VectorType>(Val: getType(i: 0)))
493 return llvm::to_vector<4>(Range: vt.getShape());
494 return std::nullopt;
495}
496
/// Fold hook for `arith.mulsi_extended`: on success pushes the low and high
/// halves of the signed product into `results`.
LogicalResult
arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  // mulsi_extended(x, 0) -> 0, 0
  if (matchPattern(attr: adaptor.getRhs(), pattern: m_Zero())) {
    // Reuse the rhs zero attribute for both results.
    Attribute zero = adaptor.getRhs();
    results.push_back(Elt: zero);
    results.push_back(Elt: zero);
    return success();
  }

  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
          operands: adaptor.getOperands(),
          calculate: [](const APInt &a, const APInt &b) { return a * b; })) {
    // Invoke the constant fold helper again to calculate the 'high' result.
    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
        operands: adaptor.getOperands(), calculate: [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhs(C1: a, C2: b);
        });
    // If the low half folded over these operands, the high half must too.
    assert(highAttr && "Unexpected constant-folding failure");

    results.push_back(Elt: lowAttr);
    results.push_back(Elt: highAttr);
    return success();
  }

  return failure();
}
526
527void arith::MulSIExtendedOp::getCanonicalizationPatterns(
528 RewritePatternSet &patterns, MLIRContext *context) {
529 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(arg&: context);
530}
531
532//===----------------------------------------------------------------------===//
533// MulUIExtendedOp
534//===----------------------------------------------------------------------===//
535
536std::optional<SmallVector<int64_t, 4>>
537arith::MulUIExtendedOp::getShapeForUnroll() {
538 if (auto vt = llvm::dyn_cast<VectorType>(Val: getType(i: 0)))
539 return llvm::to_vector<4>(Range: vt.getShape());
540 return std::nullopt;
541}
542
/// Fold hook for `arith.mului_extended`: on success pushes the low and high
/// halves of the unsigned product into `results`.
LogicalResult
arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  // mului_extended(x, 0) -> 0, 0
  if (matchPattern(attr: adaptor.getRhs(), pattern: m_Zero())) {
    // Reuse the rhs zero attribute for both results.
    Attribute zero = adaptor.getRhs();
    results.push_back(Elt: zero);
    results.push_back(Elt: zero);
    return success();
  }

  // mului_extended(x, 1) -> x, 0
  if (matchPattern(attr: adaptor.getRhs(), pattern: m_One())) {
    Builder builder(getContext());
    // High half of x * 1 is always zero for unsigned multiplication.
    Attribute zero = builder.getZeroAttr(type: getLhs().getType());
    results.push_back(Elt: getLhs());
    results.push_back(Elt: zero);
    return success();
  }

  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
          operands: adaptor.getOperands(),
          calculate: [](const APInt &a, const APInt &b) { return a * b; })) {
    // Invoke the constant fold helper again to calculate the 'high' result.
    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
        operands: adaptor.getOperands(), calculate: [](const APInt &a, const APInt &b) {
          return llvm::APIntOps::mulhu(C1: a, C2: b);
        });
    // If the low half folded over these operands, the high half must too.
    assert(highAttr && "Unexpected constant-folding failure");

    results.push_back(Elt: lowAttr);
    results.push_back(Elt: highAttr);
    return success();
  }

  return failure();
}
581
582void arith::MulUIExtendedOp::getCanonicalizationPatterns(
583 RewritePatternSet &patterns, MLIRContext *context) {
584 patterns.add<MulUIExtendedToMulI>(arg&: context);
585}
586
587//===----------------------------------------------------------------------===//
588// DivUIOp
589//===----------------------------------------------------------------------===//
590
591/// Fold `(a * b) / b -> a`
592static Value foldDivMul(Value lhs, Value rhs,
593 arith::IntegerOverflowFlags ovfFlags) {
594 auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
595 if (!mul || !bitEnumContainsAll(bits: mul.getOverflowFlags(), bit: ovfFlags))
596 return {};
597
598 if (mul.getLhs() == rhs)
599 return mul.getRhs();
600
601 if (mul.getRhs() == rhs)
602 return mul.getLhs();
603
604 return {};
605}
606
607OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
608 // divui (x, 1) -> x.
609 if (matchPattern(attr: adaptor.getRhs(), pattern: m_One()))
610 return getLhs();
611
612 // (a * b) / b -> a
613 if (Value val = foldDivMul(lhs: getLhs(), rhs: getRhs(), ovfFlags: IntegerOverflowFlags::nuw))
614 return val;
615
616 // Don't fold if it would require a division by zero.
617 bool div0 = false;
618 auto result = constFoldBinaryOp<IntegerAttr>(operands: adaptor.getOperands(),
619 calculate: [&](APInt a, const APInt &b) {
620 if (div0 || !b) {
621 div0 = true;
622 return a;
623 }
624 return a.udiv(RHS: b);
625 });
626
627 return div0 ? Attribute() : result;
628}
629
630/// Returns whether an unsigned division by `divisor` is speculatable.
631static Speculation::Speculatability getDivUISpeculatability(Value divisor) {
632 // X / 0 => UB
633 if (matchPattern(value: divisor, pattern: m_IntRangeWithoutZeroU()))
634 return Speculation::Speculatable;
635
636 return Speculation::NotSpeculatable;
637}
638
639Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
640 return getDivUISpeculatability(divisor: getRhs());
641}
642
643//===----------------------------------------------------------------------===//
644// DivSIOp
645//===----------------------------------------------------------------------===//
646
647OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
648 // divsi (x, 1) -> x.
649 if (matchPattern(attr: adaptor.getRhs(), pattern: m_One()))
650 return getLhs();
651
652 // (a * b) / b -> a
653 if (Value val = foldDivMul(lhs: getLhs(), rhs: getRhs(), ovfFlags: IntegerOverflowFlags::nsw))
654 return val;
655
656 // Don't fold if it would overflow or if it requires a division by zero.
657 bool overflowOrDiv0 = false;
658 auto result = constFoldBinaryOp<IntegerAttr>(
659 operands: adaptor.getOperands(), calculate: [&](APInt a, const APInt &b) {
660 if (overflowOrDiv0 || !b) {
661 overflowOrDiv0 = true;
662 return a;
663 }
664 return a.sdiv_ov(RHS: b, Overflow&: overflowOrDiv0);
665 });
666
667 return overflowOrDiv0 ? Attribute() : result;
668}
669
670/// Returns whether a signed division by `divisor` is speculatable. This
671/// function conservatively assumes that all signed division by -1 are not
672/// speculatable.
673static Speculation::Speculatability getDivSISpeculatability(Value divisor) {
674 // X / 0 => UB
675 // INT_MIN / -1 => UB
676 if (matchPattern(value: divisor, pattern: m_IntRangeWithoutZeroS()) &&
677 matchPattern(value: divisor, pattern: m_IntRangeWithoutNegOneS()))
678 return Speculation::Speculatable;
679
680 return Speculation::NotSpeculatable;
681}
682
683Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
684 return getDivSISpeculatability(divisor: getRhs());
685}
686
687//===----------------------------------------------------------------------===//
688// Ceil and floor division folding helpers
689//===----------------------------------------------------------------------===//
690
691static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
692 bool &overflow) {
693 // Returns (a-1)/b + 1
694 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
695 APInt val = a.ssub_ov(RHS: one, Overflow&: overflow).sdiv_ov(RHS: b, Overflow&: overflow);
696 return val.sadd_ov(RHS: one, Overflow&: overflow);
697}
698
699//===----------------------------------------------------------------------===//
700// CeilDivUIOp
701//===----------------------------------------------------------------------===//
702
/// Fold hook for `arith.ceildivui`: ceil(a / b) with unsigned semantics.
OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
  // ceildivui (x, 1) -> x.
  if (matchPattern(attr: adaptor.getRhs(), pattern: m_One()))
    return getLhs();

  // Any division by zero or overflow in any element cancels the whole fold.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      operands: adaptor.getOperands(), calculate: [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        // ceil(a/b) is a/b, rounded up by one when there is a remainder.
        APInt quotient = a.udiv(RHS: b);
        if (!a.urem(RHS: b))
          return quotient;
        APInt one(a.getBitWidth(), 1, true);
        return quotient.uadd_ov(RHS: one, Overflow&: overflowOrDiv0);
      });

  return overflowOrDiv0 ? Attribute() : result;
}
724
725Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
726 return getDivUISpeculatability(divisor: getRhs());
727}
728
729//===----------------------------------------------------------------------===//
730// CeilDivSIOp
731//===----------------------------------------------------------------------===//
732
/// Fold hook for `arith.ceildivsi`: ceil(a / b) with signed semantics,
/// handled by sign cases so every intermediate uses only nonnegative ceil
/// division. Any overflow or division by zero cancels the whole fold.
OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
  // ceildivsi (x, 1) -> x.
  if (matchPattern(attr: adaptor.getRhs(), pattern: m_One()))
    return getLhs();

  // Don't fold if it would overflow or if it requires a division by zero.
  // TODO: This hook won't fold operations where a = MININT, because
  // negating MININT overflows. This can be improved.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      operands: adaptor.getOperands(), calculate: [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        // ceil(0 / b) is 0 for any nonzero b.
        if (!a)
          return a;
        // After this point we know that neither a or b are zero.
        unsigned bits = a.getBitWidth();
        APInt zero = APInt::getZero(numBits: bits);
        bool aGtZero = a.sgt(RHS: zero);
        bool bGtZero = b.sgt(RHS: zero);
        if (aGtZero && bGtZero) {
          // Both positive, return ceil(a, b).
          return signedCeilNonnegInputs(a, b, overflow&: overflowOrDiv0);
        }

        // No folding happens if any of the intermediate arithmetic operations
        // overflows.
        bool overflowNegA = false;
        bool overflowNegB = false;
        bool overflowDiv = false;
        bool overflowNegRes = false;
        if (!aGtZero && !bGtZero) {
          // Both negative, return ceil(-a, -b).
          APInt posA = zero.ssub_ov(RHS: a, Overflow&: overflowNegA);
          APInt posB = zero.ssub_ov(RHS: b, Overflow&: overflowNegB);
          APInt res = signedCeilNonnegInputs(a: posA, b: posB, overflow&: overflowDiv);
          overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
          return res;
        }
        if (!aGtZero && bGtZero) {
          // A is negative, b is positive, return - ( -a / b).
          APInt posA = zero.ssub_ov(RHS: a, Overflow&: overflowNegA);
          APInt div = posA.sdiv_ov(RHS: b, Overflow&: overflowDiv);
          APInt res = zero.ssub_ov(RHS: div, Overflow&: overflowNegRes);
          overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
          return res;
        }
        // A is positive, b is negative, return - (a / -b).
        APInt posB = zero.ssub_ov(RHS: b, Overflow&: overflowNegB);
        APInt div = a.sdiv_ov(RHS: posB, Overflow&: overflowDiv);
        APInt res = zero.ssub_ov(RHS: div, Overflow&: overflowNegRes);

        overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
        return res;
      });

  return overflowOrDiv0 ? Attribute() : result;
}
793
794Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
795 return getDivSISpeculatability(divisor: getRhs());
796}
797
798//===----------------------------------------------------------------------===//
799// FloorDivSIOp
800//===----------------------------------------------------------------------===//
801
802OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
803 // floordivsi (x, 1) -> x.
804 if (matchPattern(attr: adaptor.getRhs(), pattern: m_One()))
805 return getLhs();
806
807 // Don't fold if it would overflow or if it requires a division by zero.
808 bool overflowOrDiv = false;
809 auto result = constFoldBinaryOp<IntegerAttr>(
810 operands: adaptor.getOperands(), calculate: [&](APInt a, const APInt &b) {
811 if (b.isZero()) {
812 overflowOrDiv = true;
813 return a;
814 }
815 return a.sfloordiv_ov(RHS: b, Overflow&: overflowOrDiv);
816 });
817
818 return overflowOrDiv ? Attribute() : result;
819}
820
821//===----------------------------------------------------------------------===//
822// RemUIOp
823//===----------------------------------------------------------------------===//
824
825OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
826 // remui (x, 1) -> 0.
827 if (matchPattern(attr: adaptor.getRhs(), pattern: m_One()))
828 return Builder(getContext()).getZeroAttr(type: getType());
829
830 // Don't fold if it would require a division by zero.
831 bool div0 = false;
832 auto result = constFoldBinaryOp<IntegerAttr>(operands: adaptor.getOperands(),
833 calculate: [&](APInt a, const APInt &b) {
834 if (div0 || b.isZero()) {
835 div0 = true;
836 return a;
837 }
838 return a.urem(RHS: b);
839 });
840
841 return div0 ? Attribute() : result;
842}
843
844//===----------------------------------------------------------------------===//
845// RemSIOp
846//===----------------------------------------------------------------------===//
847
/// Constant-fold signed remainder. Declines to fold (returns null) when any
/// element would require a division by zero.
OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
  // remsi (x, 1) -> 0.
  if (matchPattern(attr: adaptor.getRhs(), pattern: m_One()))
    return Builder(getContext()).getZeroAttr(type: getType());

  // Don't fold if it would require a division by zero. The `div0 ||` guard
  // short-circuits the remaining elements once one bad divisor is seen.
  bool div0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(operands: adaptor.getOperands(),
                                               calculate: [&](APInt a, const APInt &b) {
                                                 if (div0 || b.isZero()) {
                                                   div0 = true;
                                                   return a;
                                                 }
                                                 return a.srem(RHS: b);
                                               });

  // A null Attribute signals "no fold".
  return div0 ? Attribute() : result;
}
866
867//===----------------------------------------------------------------------===//
868// AndIOp
869//===----------------------------------------------------------------------===//
870
871/// Fold `and(a, and(a, b))` to `and(a, b)`
872static Value foldAndIofAndI(arith::AndIOp op) {
873 for (bool reversePrev : {false, true}) {
874 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
875 .getDefiningOp<arith::AndIOp>();
876 if (!prev)
877 continue;
878
879 Value other = (reversePrev ? op.getLhs() : op.getRhs());
880 if (other != prev.getLhs() && other != prev.getRhs())
881 continue;
882
883 return prev.getResult();
884 }
885 return {};
886}
887
/// Fold andi: annihilator, identity, complement, and idempotence patterns,
/// then constant folding. Note `not(x)` is modeled as `xor(x, allOnes)`.
OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
  /// and(x, 0) -> 0
  if (matchPattern(attr: adaptor.getRhs(), pattern: m_Zero()))
    return getRhs();
  /// and(x, allOnes) -> x
  APInt intValue;
  if (matchPattern(attr: adaptor.getRhs(), pattern: m_ConstantInt(bind_value: &intValue)) &&
      intValue.isAllOnes())
    return getLhs();
  /// and(x, not(x)) -> 0, where not(x) is xor(x, allOnes)
  if (matchPattern(value: getRhs(), pattern: m_Op<XOrIOp>(matchers: matchers::m_Val(v: getLhs()),
                                      matchers: m_ConstantInt(bind_value: &intValue))) &&
      intValue.isAllOnes())
    return Builder(getContext()).getZeroAttr(type: getType());
  /// and(not(x), x) -> 0
  if (matchPattern(value: getLhs(), pattern: m_Op<XOrIOp>(matchers: matchers::m_Val(v: getRhs()),
                                      matchers: m_ConstantInt(bind_value: &intValue))) &&
      intValue.isAllOnes())
    return Builder(getContext()).getZeroAttr(type: getType());

  /// and(a, and(a, b)) -> and(a, b)
  if (Value result = foldAndIofAndI(op: *this))
    return result;

  // Fall back to elementwise constant folding.
  return constFoldBinaryOp<IntegerAttr>(
      operands: adaptor.getOperands(),
      calculate: [](APInt a, const APInt &b) { return std::move(a) & b; });
}
916
917//===----------------------------------------------------------------------===//
918// OrIOp
919//===----------------------------------------------------------------------===//
920
/// Fold ori: identity, annihilator, and or-with-complement patterns, then
/// constant folding. Note `not(x)` is modeled as `xor(x, allOnes)`.
OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
  if (APInt rhsVal; matchPattern(attr: adaptor.getRhs(), pattern: m_ConstantInt(bind_value: &rhsVal))) {
    /// or(x, 0) -> x
    if (rhsVal.isZero())
      return getLhs();
    /// or(x, <all ones>) -> <all ones>
    if (rhsVal.isAllOnes())
      return adaptor.getRhs();
  }

  APInt intValue;
  /// or(x, xor(x, allOnes)) -> allOnes; the all-ones constant is reused from
  /// the matched xor's rhs operand.
  if (matchPattern(value: getRhs(), pattern: m_Op<XOrIOp>(matchers: matchers::m_Val(v: getLhs()),
                                      matchers: m_ConstantInt(bind_value: &intValue))) &&
      intValue.isAllOnes())
    return getRhs().getDefiningOp<XOrIOp>().getRhs();
  /// or(xor(x, allOnes), x) -> allOnes
  if (matchPattern(value: getLhs(), pattern: m_Op<XOrIOp>(matchers: matchers::m_Val(v: getRhs()),
                                      matchers: m_ConstantInt(bind_value: &intValue))) &&
      intValue.isAllOnes())
    return getLhs().getDefiningOp<XOrIOp>().getRhs();

  // Fall back to elementwise constant folding.
  return constFoldBinaryOp<IntegerAttr>(
      operands: adaptor.getOperands(),
      calculate: [](APInt a, const APInt &b) { return std::move(a) | b; });
}
947
948//===----------------------------------------------------------------------===//
949// XOrIOp
950//===----------------------------------------------------------------------===//
951
/// Fold xori: identity, self-cancellation, and cancellation through a nested
/// xor (xor is associative and self-inverse), then constant folding.
OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
  /// xor(x, 0) -> x
  if (matchPattern(attr: adaptor.getRhs(), pattern: m_Zero()))
    return getLhs();
  /// xor(x, x) -> 0
  if (getLhs() == getRhs())
    return Builder(getContext()).getZeroAttr(type: getType());
  /// xor(xor(x, a), a) -> x
  /// xor(xor(a, x), a) -> x
  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getRhs())
      return prev.getLhs();
    if (prev.getLhs() == getRhs())
      return prev.getRhs();
  }
  /// xor(a, xor(x, a)) -> x
  /// xor(a, xor(a, x)) -> x
  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
    if (prev.getRhs() == getLhs())
      return prev.getLhs();
    if (prev.getLhs() == getLhs())
      return prev.getRhs();
  }

  // Fall back to elementwise constant folding.
  return constFoldBinaryOp<IntegerAttr>(
      operands: adaptor.getOperands(),
      calculate: [](APInt a, const APInt &b) { return std::move(a) ^ b; });
}
980
/// Register additional rewrite patterns for xori canonicalization.
void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(arg&: context);
}
985
986//===----------------------------------------------------------------------===//
987// NegFOp
988//===----------------------------------------------------------------------===//
989
990OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
991 /// negf(negf(x)) -> x
992 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
993 return op.getOperand();
994 return constFoldUnaryOp<FloatAttr>(operands: adaptor.getOperands(),
995 calculate: [](const APFloat &a) { return -a; });
996}
997
998//===----------------------------------------------------------------------===//
999// AddFOp
1000//===----------------------------------------------------------------------===//
1001
1002OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1003 // addf(x, -0) -> x
1004 if (matchPattern(attr: adaptor.getRhs(), pattern: m_NegZeroFloat()))
1005 return getLhs();
1006
1007 return constFoldBinaryOp<FloatAttr>(
1008 operands: adaptor.getOperands(),
1009 calculate: [](const APFloat &a, const APFloat &b) { return a + b; });
1010}
1011
1012//===----------------------------------------------------------------------===//
1013// SubFOp
1014//===----------------------------------------------------------------------===//
1015
1016OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1017 // subf(x, +0) -> x
1018 if (matchPattern(attr: adaptor.getRhs(), pattern: m_PosZeroFloat()))
1019 return getLhs();
1020
1021 return constFoldBinaryOp<FloatAttr>(
1022 operands: adaptor.getOperands(),
1023 calculate: [](const APFloat &a, const APFloat &b) { return a - b; });
1024}
1025
1026//===----------------------------------------------------------------------===//
1027// MaximumFOp
1028//===----------------------------------------------------------------------===//
1029
1030OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1031 // maximumf(x,x) -> x
1032 if (getLhs() == getRhs())
1033 return getRhs();
1034
1035 // maximumf(x, -inf) -> x
1036 if (matchPattern(attr: adaptor.getRhs(), pattern: m_NegInfFloat()))
1037 return getLhs();
1038
1039 return constFoldBinaryOp<FloatAttr>(
1040 operands: adaptor.getOperands(),
1041 calculate: [](const APFloat &a, const APFloat &b) { return llvm::maximum(A: a, B: b); });
1042}
1043
1044//===----------------------------------------------------------------------===//
1045// MaxNumFOp
1046//===----------------------------------------------------------------------===//
1047
1048OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1049 // maxnumf(x,x) -> x
1050 if (getLhs() == getRhs())
1051 return getRhs();
1052
1053 // maxnumf(x, NaN) -> x
1054 if (matchPattern(attr: adaptor.getRhs(), pattern: m_NaNFloat()))
1055 return getLhs();
1056
1057 return constFoldBinaryOp<FloatAttr>(operands: adaptor.getOperands(), calculate&: llvm::maxnum);
1058}
1059
1060//===----------------------------------------------------------------------===//
1061// MaxSIOp
1062//===----------------------------------------------------------------------===//
1063
1064OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1065 // maxsi(x,x) -> x
1066 if (getLhs() == getRhs())
1067 return getRhs();
1068
1069 if (APInt intValue;
1070 matchPattern(attr: adaptor.getRhs(), pattern: m_ConstantInt(bind_value: &intValue))) {
1071 // maxsi(x,MAX_INT) -> MAX_INT
1072 if (intValue.isMaxSignedValue())
1073 return getRhs();
1074 // maxsi(x, MIN_INT) -> x
1075 if (intValue.isMinSignedValue())
1076 return getLhs();
1077 }
1078
1079 return constFoldBinaryOp<IntegerAttr>(operands: adaptor.getOperands(),
1080 calculate: [](const APInt &a, const APInt &b) {
1081 return llvm::APIntOps::smax(A: a, B: b);
1082 });
1083}
1084
1085//===----------------------------------------------------------------------===//
1086// MaxUIOp
1087//===----------------------------------------------------------------------===//
1088
1089OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1090 // maxui(x,x) -> x
1091 if (getLhs() == getRhs())
1092 return getRhs();
1093
1094 if (APInt intValue;
1095 matchPattern(attr: adaptor.getRhs(), pattern: m_ConstantInt(bind_value: &intValue))) {
1096 // maxui(x,MAX_INT) -> MAX_INT
1097 if (intValue.isMaxValue())
1098 return getRhs();
1099 // maxui(x, MIN_INT) -> x
1100 if (intValue.isMinValue())
1101 return getLhs();
1102 }
1103
1104 return constFoldBinaryOp<IntegerAttr>(operands: adaptor.getOperands(),
1105 calculate: [](const APInt &a, const APInt &b) {
1106 return llvm::APIntOps::umax(A: a, B: b);
1107 });
1108}
1109
1110//===----------------------------------------------------------------------===//
1111// MinimumFOp
1112//===----------------------------------------------------------------------===//
1113
1114OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1115 // minimumf(x,x) -> x
1116 if (getLhs() == getRhs())
1117 return getRhs();
1118
1119 // minimumf(x, +inf) -> x
1120 if (matchPattern(attr: adaptor.getRhs(), pattern: m_PosInfFloat()))
1121 return getLhs();
1122
1123 return constFoldBinaryOp<FloatAttr>(
1124 operands: adaptor.getOperands(),
1125 calculate: [](const APFloat &a, const APFloat &b) { return llvm::minimum(A: a, B: b); });
1126}
1127
1128//===----------------------------------------------------------------------===//
1129// MinNumFOp
1130//===----------------------------------------------------------------------===//
1131
1132OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1133 // minnumf(x,x) -> x
1134 if (getLhs() == getRhs())
1135 return getRhs();
1136
1137 // minnumf(x, NaN) -> x
1138 if (matchPattern(attr: adaptor.getRhs(), pattern: m_NaNFloat()))
1139 return getLhs();
1140
1141 return constFoldBinaryOp<FloatAttr>(
1142 operands: adaptor.getOperands(),
1143 calculate: [](const APFloat &a, const APFloat &b) { return llvm::minnum(A: a, B: b); });
1144}
1145
1146//===----------------------------------------------------------------------===//
1147// MinSIOp
1148//===----------------------------------------------------------------------===//
1149
1150OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1151 // minsi(x,x) -> x
1152 if (getLhs() == getRhs())
1153 return getRhs();
1154
1155 if (APInt intValue;
1156 matchPattern(attr: adaptor.getRhs(), pattern: m_ConstantInt(bind_value: &intValue))) {
1157 // minsi(x,MIN_INT) -> MIN_INT
1158 if (intValue.isMinSignedValue())
1159 return getRhs();
1160 // minsi(x, MAX_INT) -> x
1161 if (intValue.isMaxSignedValue())
1162 return getLhs();
1163 }
1164
1165 return constFoldBinaryOp<IntegerAttr>(operands: adaptor.getOperands(),
1166 calculate: [](const APInt &a, const APInt &b) {
1167 return llvm::APIntOps::smin(A: a, B: b);
1168 });
1169}
1170
1171//===----------------------------------------------------------------------===//
1172// MinUIOp
1173//===----------------------------------------------------------------------===//
1174
1175OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1176 // minui(x,x) -> x
1177 if (getLhs() == getRhs())
1178 return getRhs();
1179
1180 if (APInt intValue;
1181 matchPattern(attr: adaptor.getRhs(), pattern: m_ConstantInt(bind_value: &intValue))) {
1182 // minui(x,MIN_INT) -> MIN_INT
1183 if (intValue.isMinValue())
1184 return getRhs();
1185 // minui(x, MAX_INT) -> x
1186 if (intValue.isMaxValue())
1187 return getLhs();
1188 }
1189
1190 return constFoldBinaryOp<IntegerAttr>(operands: adaptor.getOperands(),
1191 calculate: [](const APInt &a, const APInt &b) {
1192 return llvm::APIntOps::umin(A: a, B: b);
1193 });
1194}
1195
1196//===----------------------------------------------------------------------===//
1197// MulFOp
1198//===----------------------------------------------------------------------===//
1199
1200OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1201 // mulf(x, 1) -> x
1202 if (matchPattern(attr: adaptor.getRhs(), pattern: m_OneFloat()))
1203 return getLhs();
1204
1205 return constFoldBinaryOp<FloatAttr>(
1206 operands: adaptor.getOperands(),
1207 calculate: [](const APFloat &a, const APFloat &b) { return a * b; });
1208}
1209
/// Register additional rewrite patterns for mulf canonicalization.
void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulFOfNegF>(arg&: context);
}
1214
1215//===----------------------------------------------------------------------===//
1216// DivFOp
1217//===----------------------------------------------------------------------===//
1218
1219OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1220 // divf(x, 1) -> x
1221 if (matchPattern(attr: adaptor.getRhs(), pattern: m_OneFloat()))
1222 return getLhs();
1223
1224 return constFoldBinaryOp<FloatAttr>(
1225 operands: adaptor.getOperands(),
1226 calculate: [](const APFloat &a, const APFloat &b) { return a / b; });
1227}
1228
/// Register additional rewrite patterns for divf canonicalization.
void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<DivFOfNegF>(arg&: context);
}
1233
1234//===----------------------------------------------------------------------===//
1235// RemFOp
1236//===----------------------------------------------------------------------===//
1237
1238OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1239 return constFoldBinaryOp<FloatAttr>(operands: adaptor.getOperands(),
1240 calculate: [](const APFloat &a, const APFloat &b) {
1241 APFloat result(a);
1242 // APFloat::mod() offers the remainder
1243 // behavior we want, i.e. the result has
1244 // the sign of LHS operand.
1245 (void)result.mod(RHS: b);
1246 return result;
1247 });
1248}
1249
1250//===----------------------------------------------------------------------===//
1251// Utility functions for verifying cast ops
1252//===----------------------------------------------------------------------===//
1253
/// Lightweight compile-time list of types, passed by (null) pointer so no
/// tuple object is ever constructed.
template <typename... Types>
using type_list = std::tuple<Types...> *;

/// Returns a non-null type only if the provided type is one of the allowed
/// types or one of the allowed shaped types of the allowed types. Returns the
/// element type if a valid shaped type is provided.
template <typename... ShapedTypes, typename... ElementTypes>
static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
                              type_list<ElementTypes...>) {
  // A shaped type that is not one of the allowed shaped kinds is rejected
  // outright; scalars fall through to the element-type check.
  if (llvm::isa<ShapedType>(Val: type) && !llvm::isa<ShapedTypes...>(type))
    return {};

  auto underlyingType = getElementTypeOrSelf(type);
  if (!llvm::isa<ElementTypes...>(underlyingType))
    return {};

  return underlyingType;
}
1272
/// Get allowed underlying types for vectors and tensors. Returns the element
/// type (or the type itself for scalars), or null if disallowed.
template <typename... ElementTypes>
static Type getTypeIfLike(Type type) {
  return getUnderlyingType(type, type_list<VectorType, TensorType>(),
                           type_list<ElementTypes...>());
}
1279
/// Get allowed underlying types for vectors, tensors, and memrefs. Returns
/// the element type (or the type itself for scalars), or null if disallowed.
template <typename... ElementTypes>
static Type getTypeIfLikeOrMemRef(Type type) {
  return getUnderlyingType(type,
                           type_list<VectorType, TensorType, MemRefType>(),
                           type_list<ElementTypes...>());
}
1287
1288/// Return false if both types are ranked tensor with mismatching encoding.
1289static bool hasSameEncoding(Type typeA, Type typeB) {
1290 auto rankedTensorA = dyn_cast<RankedTensorType>(Val&: typeA);
1291 auto rankedTensorB = dyn_cast<RankedTensorType>(Val&: typeB);
1292 if (!rankedTensorA || !rankedTensorB)
1293 return true;
1294 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1295}
1296
/// Common structural check for casts: exactly one input and one output, with
/// matching tensor encodings and compatible shapes.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  if (!hasSameEncoding(typeA: inputs.front(), typeB: outputs.front()))
    return false;
  return succeeded(Result: verifyCompatibleShapes(types1: inputs.front(), types2: outputs.front()));
}
1304
1305//===----------------------------------------------------------------------===//
1306// Verifiers for integer and floating point extension/truncation ops
1307//===----------------------------------------------------------------------===//
1308
// Extend ops can only extend to a wider type.
// `ValType` is the element-type class (IntegerType/FloatType) whose widths
// are compared; works elementwise for shaped operands.
template <typename ValType, typename Op>
static LogicalResult verifyExtOp(Op op) {
  Type srcType = getElementTypeOrSelf(op.getIn().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  // Equal widths are also rejected: the result must be strictly wider.
  if (llvm::cast<ValType>(srcType).getWidth() >=
      llvm::cast<ValType>(dstType).getWidth())
    return op.emitError("result type ")
           << dstType << " must be wider than operand type " << srcType;

  return success();
}
1322
// Truncate ops can only truncate to a shorter type.
// `ValType` is the element-type class (IntegerType/FloatType) whose widths
// are compared; works elementwise for shaped operands.
template <typename ValType, typename Op>
static LogicalResult verifyTruncateOp(Op op) {
  Type srcType = getElementTypeOrSelf(op.getIn().getType());
  Type dstType = getElementTypeOrSelf(op.getType());

  // Equal widths are also rejected: the result must be strictly narrower.
  if (llvm::cast<ValType>(srcType).getWidth() <=
      llvm::cast<ValType>(dstType).getWidth())
    return op.emitError("result type ")
           << dstType << " must be shorter than operand type " << srcType;

  return success();
}
1336
/// Validate a cast that changes the width of a type.
/// `WidthComparator` (e.g. std::greater for extensions, std::less for
/// truncations) is applied as comparator(dstWidth, srcWidth).
template <template <typename> class WidthComparator, typename... ElementTypes>
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
  auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
  if (!srcType || !dstType)
    return false;

  return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
                                     srcType.getIntOrFloatBitWidth());
}
1351
/// Attempts to convert `sourceValue` to an APFloat value with
/// `targetSemantics` and `roundingMode`, without any information loss.
/// Returns failure() when the conversion loses information or reports any
/// non-OK status (inexact, overflow, invalid, ...).
static FailureOr<APFloat> convertFloatValue(
    APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
    llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
  bool losesInfo = false;
  auto status = sourceValue.convert(ToSemantics: targetSemantics, RM: roundingMode, losesInfo: &losesInfo);
  if (losesInfo || status != APFloat::opOK)
    return failure();

  // `sourceValue` was taken by value and converted in place.
  return sourceValue;
}
1364
1365//===----------------------------------------------------------------------===//
1366// ExtUIOp
1367//===----------------------------------------------------------------------===//
1368
/// Fold extui: collapse chained zero-extensions in place, and zero-extend
/// constant operands.
OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
  // extui(extui(x)) -> extui(x): rewrite this op's operand to the inner
  // source and return our own result to signal an in-place update.
  if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
    getInMutable().assign(value: lhs.getIn());
    return getResult();
  }

  Type resType = getElementTypeOrSelf(type: getType());
  unsigned bitWidth = llvm::cast<IntegerType>(Val&: resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      operands: adaptor.getOperands(), resType: getType(),
      calculate: [bitWidth](const APInt &a, bool &castStatus) {
        return a.zext(width: bitWidth);
      });
}
1383
/// extui is compatible only when it widens an integer(-like) type.
bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}
1387
/// Verify that extui strictly widens its integer operand.
LogicalResult arith::ExtUIOp::verify() {
  return verifyExtOp<IntegerType>(op: *this);
}
1391
1392//===----------------------------------------------------------------------===//
1393// ExtSIOp
1394//===----------------------------------------------------------------------===//
1395
/// Fold extsi: collapse chained sign-extensions in place, and sign-extend
/// constant operands.
OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
  // extsi(extsi(x)) -> extsi(x): rewrite this op's operand to the inner
  // source and return our own result to signal an in-place update.
  if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
    getInMutable().assign(value: lhs.getIn());
    return getResult();
  }

  Type resType = getElementTypeOrSelf(type: getType());
  unsigned bitWidth = llvm::cast<IntegerType>(Val&: resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      operands: adaptor.getOperands(), resType: getType(),
      calculate: [bitWidth](const APInt &a, bool &castStatus) {
        return a.sext(width: bitWidth);
      });
}
1410
/// extsi is compatible only when it widens an integer(-like) type.
bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
}
1414
/// Register additional rewrite patterns for extsi canonicalization.
void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  patterns.add<ExtSIOfExtUI>(arg&: context);
}
1419
/// Verify that extsi strictly widens its integer operand.
LogicalResult arith::ExtSIOp::verify() {
  return verifyExtOp<IntegerType>(op: *this);
}
1423
1424//===----------------------------------------------------------------------===//
1425// ExtFOp
1426//===----------------------------------------------------------------------===//
1427
/// Fold extension of float constants when there is no information loss due the
/// difference in fp semantics.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
  // extf(truncf(x)) -> x, but only when both ops carry the `contract`
  // fastmath flag: the round trip may otherwise change the value.
  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
    if (truncFOp.getOperand().getType() == getType()) {
      arith::FastMathFlags truncFMF =
          truncFOp.getFastmath().value_or(u: arith::FastMathFlags::none);
      bool isTruncContract =
          bitEnumContainsAll(bits: truncFMF, bit: arith::FastMathFlags::contract);
      arith::FastMathFlags extFMF =
          getFastmath().value_or(u: arith::FastMathFlags::none);
      bool isExtContract =
          bitEnumContainsAll(bits: extFMF, bit: arith::FastMathFlags::contract);
      if (isTruncContract && isExtContract) {
        return truncFOp.getOperand();
      }
    }
  }

  // Otherwise, constant-fold by converting to the wider semantics; decline
  // the fold (castStatus = false) if the conversion would lose information.
  auto resElemType = cast<FloatType>(Val: getElementTypeOrSelf(type: getType()));
  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  return constFoldCastOp<FloatAttr, FloatAttr>(
      operands: adaptor.getOperands(), resType: getType(),
      calculate: [&targetSemantics](const APFloat &a, bool &castStatus) {
        FailureOr<APFloat> result = convertFloatValue(sourceValue: a, targetSemantics);
        if (failed(Result: result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}
1460
/// extf is compatible only when it widens a float(-like) type.
bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
}
1464
/// Verify that extf strictly widens its float operand.
LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(op: *this); }
1466
1467//===----------------------------------------------------------------------===//
1468// ScalingExtFOp
1469//===----------------------------------------------------------------------===//
1470
/// Only the first input (the value being extended; the op also takes a scale
/// operand) participates in the width-change check.
bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  return checkWidthChangeCast<std::greater, FloatType>(inputs: inputs.front(), outputs);
}
1475
/// Verify that scaling_extf strictly widens its float operand.
LogicalResult arith::ScalingExtFOp::verify() {
  return verifyExtOp<FloatType>(op: *this);
}
1479
1480//===----------------------------------------------------------------------===//
1481// TruncIOp
1482//===----------------------------------------------------------------------===//
1483
/// Fold trunci: cancel or shorten ext/trunc chains (rewriting the operand in
/// place where possible) and truncate constant operands.
OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
  if (matchPattern(value: getOperand(), pattern: m_Op<arith::ExtUIOp>()) ||
      matchPattern(value: getOperand(), pattern: m_Op<arith::ExtSIOp>())) {
    Value src = getOperand().getDefiningOp()->getOperand(idx: 0);
    Type srcType = getElementTypeOrSelf(type: src.getType());
    Type dstType = getElementTypeOrSelf(type: getType());
    // trunci(zexti(a)) -> trunci(a)
    // trunci(sexti(a)) -> trunci(a)
    // Only valid when we still truncate below the original source width; the
    // extension then had no effect on the retained bits.
    if (llvm::cast<IntegerType>(Val&: srcType).getWidth() >
        llvm::cast<IntegerType>(Val&: dstType).getWidth()) {
      setOperand(src);
      return getResult();
    }

    // trunci(zexti(a)) -> a
    // trunci(sexti(a)) -> a
    if (srcType == dstType)
      return src;
  }

  // trunci(trunci(a)) -> trunci(a))
  if (matchPattern(value: getOperand(), pattern: m_Op<arith::TruncIOp>())) {
    setOperand(getOperand().getDefiningOp()->getOperand(idx: 0));
    return getResult();
  }

  Type resType = getElementTypeOrSelf(type: getType());
  unsigned bitWidth = llvm::cast<IntegerType>(Val&: resType).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      operands: adaptor.getOperands(), resType: getType(),
      calculate: [bitWidth](const APInt &a, bool &castStatus) {
        return a.trunc(width: bitWidth);
      });
}
1518
/// trunci is compatible only when it narrows an integer(-like) type.
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
}
1522
/// Register additional rewrite patterns for trunci canonicalization.
void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns
      .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
          arg&: context);
}
1529
/// Verify that trunci strictly narrows its integer operand.
LogicalResult arith::TruncIOp::verify() {
  return verifyTruncateOp<IntegerType>(op: *this);
}
1533
1534//===----------------------------------------------------------------------===//
1535// TruncFOp
1536//===----------------------------------------------------------------------===//
1537
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
  auto resElemType = cast<FloatType>(Val: getElementTypeOrSelf(type: getType()));
  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
    Value src = extOp.getIn();
    auto srcType = cast<FloatType>(Val: getElementTypeOrSelf(type: src.getType()));
    auto intermediateType =
        cast<FloatType>(Val: getElementTypeOrSelf(type: extOp.getType()));
    // Check if the srcType is representable in the intermediateType. Only
    // then did the extf preserve the value exactly and can be elided.
    if (llvm::APFloatBase::isRepresentableBy(
            A: srcType.getFloatSemantics(),
            B: intermediateType.getFloatSemantics())) {
      // truncf(extf(a)) -> truncf(a): still truncating below the source
      // width, so rewrite the operand in place.
      if (srcType.getWidth() > resElemType.getWidth()) {
        setOperand(src);
        return getResult();
      }

      // truncf(extf(a)) -> a
      if (srcType == resElemType)
        return src;
    }
  }

  // Constant-fold by converting to the narrower semantics, honoring this
  // op's rounding-mode attribute (default: round to nearest even). Decline
  // the fold (castStatus = false) on any information loss.
  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
  return constFoldCastOp<FloatAttr, FloatAttr>(
      operands: adaptor.getOperands(), resType: getType(),
      calculate: [this, &targetSemantics](const APFloat &a, bool &castStatus) {
        RoundingMode roundingMode =
            getRoundingmode().value_or(u: RoundingMode::to_nearest_even);
        llvm::RoundingMode llvmRoundingMode =
            convertArithRoundingModeToLLVMIR(roundingMode);
        FailureOr<APFloat> result =
            convertFloatValue(sourceValue: a, targetSemantics, roundingMode: llvmRoundingMode);
        if (failed(Result: result)) {
          castStatus = false;
          return a;
        }
        return *result;
      });
}
1580
/// Register additional rewrite patterns for truncf canonicalization.
void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(arg&: context);
}
1585
/// truncf is compatible only when it narrows a float(-like) type.
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
}
1589
/// Verify that truncf strictly narrows its float operand.
LogicalResult arith::TruncFOp::verify() {
  return verifyTruncateOp<FloatType>(op: *this);
}
1593
1594//===----------------------------------------------------------------------===//
1595// ScalingTruncFOp
1596//===----------------------------------------------------------------------===//
1597
/// Only the first input (the value being truncated; the op also takes a scale
/// operand) participates in the width-change check.
bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
                                               TypeRange outputs) {
  return checkWidthChangeCast<std::less, FloatType>(inputs: inputs.front(), outputs);
}
1602
/// Verify that scaling_truncf strictly narrows its float operand.
LogicalResult arith::ScalingTruncFOp::verify() {
  return verifyTruncateOp<FloatType>(op: *this);
}
1606
1607//===----------------------------------------------------------------------===//
1608// AndIOp
1609//===----------------------------------------------------------------------===//
1610
/// Register additional rewrite patterns for andi canonicalization.
void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AndOfExtUI, AndOfExtSI>(arg&: context);
}
1615
1616//===----------------------------------------------------------------------===//
1617// OrIOp
1618//===----------------------------------------------------------------------===//
1619
/// Register additional rewrite patterns for ori canonicalization.
void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<OrOfExtUI, OrOfExtSI>(arg&: context);
}
1624
1625//===----------------------------------------------------------------------===//
1626// Verifiers for casts between integers and floats.
1627//===----------------------------------------------------------------------===//
1628
1629template <typename From, typename To>
1630static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1631 if (!areValidCastInputsAndOutputs(inputs, outputs))
1632 return false;
1633
1634 auto srcType = getTypeIfLike<From>(inputs.front());
1635 auto dstType = getTypeIfLike<To>(outputs.back());
1636
1637 return srcType && dstType;
1638}
1639
1640//===----------------------------------------------------------------------===//
1641// UIToFPOp
1642//===----------------------------------------------------------------------===//
1643
/// uitofp casts an integer(-like) type to a float(-like) type.
bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
1647
/// Constant-fold uitofp: convert integer constants to float, treating the
/// bits as unsigned and rounding to nearest even.
OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(type: getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      operands: adaptor.getOperands(), resType: getType(),
      calculate: [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(Val&: resEleType);
        // Start from +0.0 in the target semantics, then convert in place.
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(numBits: floatTy.getWidth()));
        apf.convertFromAPInt(Input: a, /*IsSigned=*/false,
                             RM: APFloat::rmNearestTiesToEven);
        return apf;
      });
}
1661
1662//===----------------------------------------------------------------------===//
1663// SIToFPOp
1664//===----------------------------------------------------------------------===//
1665
/// sitofp casts an integer(-like) type to a float(-like) type.
bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
1669
/// Constant-fold sitofp: convert integer constants to float, treating the
/// bits as signed and rounding to nearest even.
OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(type: getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      operands: adaptor.getOperands(), resType: getType(),
      calculate: [&resEleType](const APInt &a, bool &castStatus) {
        FloatType floatTy = llvm::cast<FloatType>(Val&: resEleType);
        // Start from +0.0 in the target semantics, then convert in place.
        APFloat apf(floatTy.getFloatSemantics(),
                    APInt::getZero(numBits: floatTy.getWidth()));
        apf.convertFromAPInt(Input: a, /*IsSigned=*/true,
                             RM: APFloat::rmNearestTiesToEven);
        return apf;
      });
}
1683
1684//===----------------------------------------------------------------------===//
1685// FPToUIOp
1686//===----------------------------------------------------------------------===//
1687
/// fptoui casts a float(-like) type to an integer(-like) type.
bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
1691
1692OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1693 Type resType = getElementTypeOrSelf(type: getType());
1694 unsigned bitWidth = llvm::cast<IntegerType>(Val&: resType).getWidth();
1695 return constFoldCastOp<FloatAttr, IntegerAttr>(
1696 operands: adaptor.getOperands(), resType: getType(),
1697 calculate: [&bitWidth](const APFloat &a, bool &castStatus) {
1698 bool ignored;
1699 APSInt api(bitWidth, /*isUnsigned=*/true);
1700 castStatus = APFloat::opInvalidOp !=
1701 a.convertToInteger(Result&: api, RM: APFloat::rmTowardZero, IsExact: &ignored);
1702 return api;
1703 });
1704}
1705
1706//===----------------------------------------------------------------------===//
1707// FPToSIOp
1708//===----------------------------------------------------------------------===//
1709
/// `fptosi` is a valid cast between a float-like source and an integer-like
/// result of matching shape.
bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
1713
1714OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1715 Type resType = getElementTypeOrSelf(type: getType());
1716 unsigned bitWidth = llvm::cast<IntegerType>(Val&: resType).getWidth();
1717 return constFoldCastOp<FloatAttr, IntegerAttr>(
1718 operands: adaptor.getOperands(), resType: getType(),
1719 calculate: [&bitWidth](const APFloat &a, bool &castStatus) {
1720 bool ignored;
1721 APSInt api(bitWidth, /*isUnsigned=*/false);
1722 castStatus = APFloat::opInvalidOp !=
1723 a.convertToInteger(Result&: api, RM: APFloat::rmTowardZero, IsExact: &ignored);
1724 return api;
1725 });
1726}
1727
1728//===----------------------------------------------------------------------===//
1729// IndexCastOp
1730//===----------------------------------------------------------------------===//
1731
1732static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1733 if (!areValidCastInputsAndOutputs(inputs, outputs))
1734 return false;
1735
1736 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(type: inputs.front());
1737 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(type: outputs.front());
1738 if (!srcType || !dstType)
1739 return false;
1740
1741 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1742 (srcType.isSignlessInteger() && dstType.isIndex());
1743}
1744
/// Delegates to the shared helper: valid between `index` and a signless
/// integer, in either direction.
bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
                                           TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}
1749
1750OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1751 // index_cast(constant) -> constant
1752 unsigned resultBitwidth = 64; // Default for index integer attributes.
1753 if (auto intTy = dyn_cast<IntegerType>(Val: getElementTypeOrSelf(type: getType())))
1754 resultBitwidth = intTy.getWidth();
1755
1756 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1757 operands: adaptor.getOperands(), resType: getType(),
1758 calculate: [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1759 return a.sextOrTrunc(width: resultBitwidth);
1760 });
1761}
1762
/// Registers patterns that simplify chained index casts and index casts of
/// sign extensions (per the pattern names; definitions live elsewhere).
void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(arg&: context);
}
1767
1768//===----------------------------------------------------------------------===//
1769// IndexCastUIOp
1770//===----------------------------------------------------------------------===//
1771
/// Delegates to the shared helper: valid between `index` and a signless
/// integer, in either direction.
bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}
1776
1777OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1778 // index_castui(constant) -> constant
1779 unsigned resultBitwidth = 64; // Default for index integer attributes.
1780 if (auto intTy = dyn_cast<IntegerType>(Val: getElementTypeOrSelf(type: getType())))
1781 resultBitwidth = intTy.getWidth();
1782
1783 return constFoldCastOp<IntegerAttr, IntegerAttr>(
1784 operands: adaptor.getOperands(), resType: getType(),
1785 calculate: [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1786 return a.zextOrTrunc(width: resultBitwidth);
1787 });
1788}
1789
/// Registers patterns that simplify chained unsigned index casts and unsigned
/// index casts of zero extensions (per the pattern names).
void arith::IndexCastUIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(arg&: context);
}
1794
1795//===----------------------------------------------------------------------===//
1796// BitcastOp
1797//===----------------------------------------------------------------------===//
1798
/// A bitcast is valid between integer/float-like types (including shaped and
/// memref forms) of identical bit width — it only reinterprets the bits.
bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(type: inputs.front());
  auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(type: outputs.front());
  if (!srcType || !dstType)
    return false;

  // Widths must match exactly since no value conversion takes place.
  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
}
1810
/// Fold `bitcast(constant)` by reinterpreting the constant's bits in the
/// result type. Handles dense-elements constants, poison, and scalar
/// integer/float constants; other shaped constants are left alone.
OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
  auto resType = getType();
  auto operand = adaptor.getIn();
  if (!operand)
    return {};

  /// Bitcast dense elements.
  if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(Val&: operand))
    return denseAttr.bitcast(newElType: llvm::cast<ShapedType>(Val&: resType).getElementType());
  /// Other shaped types unhandled.
  if (llvm::isa<ShapedType>(Val: resType))
    return {};

  /// Bitcast poison: reinterpreting poison bits is still poison.
  if (llvm::isa<ub::PoisonAttr>(Val: operand))
    return ub::PoisonAttr::get(ctx: getContext());

  /// Bitcast integer or float to integer or float: extract the raw bits from
  /// the source attribute, then rewrap them in the result type.
  APInt bits = llvm::isa<FloatAttr>(Val: operand)
                   ? llvm::cast<FloatAttr>(Val&: operand).getValue().bitcastToAPInt()
                   : llvm::cast<IntegerAttr>(Val&: operand).getValue();
  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
         "trying to fold on broken IR: operands have incompatible types");

  if (auto resFloatType = llvm::dyn_cast<FloatType>(Val&: resType))
    return FloatAttr::get(type: resType,
                          value: APFloat(resFloatType.getFloatSemantics(), bits));
  return IntegerAttr::get(type: resType, value: bits);
}
1840
/// Registers the pattern that collapses a bitcast of a bitcast (per the
/// pattern name).
void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add<BitcastOfBitcast>(arg&: context);
}
1845
1846//===----------------------------------------------------------------------===//
1847// CmpIOp
1848//===----------------------------------------------------------------------===//
1849
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates. Dispatches to the matching signed/unsigned APInt
/// comparison.
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
                                    const APInt &lhs, const APInt &rhs) {
  switch (predicate) {
  // Equality is signedness-agnostic.
  case arith::CmpIPredicate::eq:
    return lhs.eq(RHS: rhs);
  case arith::CmpIPredicate::ne:
    return lhs.ne(RHS: rhs);
  // Signed orderings.
  case arith::CmpIPredicate::slt:
    return lhs.slt(RHS: rhs);
  case arith::CmpIPredicate::sle:
    return lhs.sle(RHS: rhs);
  case arith::CmpIPredicate::sgt:
    return lhs.sgt(RHS: rhs);
  case arith::CmpIPredicate::sge:
    return lhs.sge(RHS: rhs);
  // Unsigned orderings.
  case arith::CmpIPredicate::ult:
    return lhs.ult(RHS: rhs);
  case arith::CmpIPredicate::ule:
    return lhs.ule(RHS: rhs);
  case arith::CmpIPredicate::ugt:
    return lhs.ugt(RHS: rhs);
  case arith::CmpIPredicate::uge:
    return lhs.uge(RHS: rhs);
  }
  llvm_unreachable("unknown cmpi predicate kind");
}
1878
/// Returns true if the predicate is true for two equal operands, i.e. whether
/// the predicate is reflexive (contains the "equal" case).
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
  switch (predicate) {
  // Reflexive predicates: x <pred> x holds.
  case arith::CmpIPredicate::eq:
  case arith::CmpIPredicate::sle:
  case arith::CmpIPredicate::sge:
  case arith::CmpIPredicate::ule:
  case arith::CmpIPredicate::uge:
    return true;
  // Irreflexive predicates: x <pred> x never holds.
  case arith::CmpIPredicate::ne:
  case arith::CmpIPredicate::slt:
  case arith::CmpIPredicate::sgt:
  case arith::CmpIPredicate::ult:
  case arith::CmpIPredicate::ugt:
    return false;
  }
  llvm_unreachable("unknown cmpi predicate kind");
}
1897
1898static std::optional<int64_t> getIntegerWidth(Type t) {
1899 if (auto intType = llvm::dyn_cast<IntegerType>(Val&: t)) {
1900 return intType.getWidth();
1901 }
1902 if (auto vectorIntType = llvm::dyn_cast<VectorType>(Val&: t)) {
1903 return llvm::cast<IntegerType>(Val: vectorIntType.getElementType()).getWidth();
1904 }
1905 return std::nullopt;
1906}
1907
/// Folds cmpi: equal operands, comparisons against 0/1 of i1-derived values,
/// constant-constant comparisons, and canonicalizes a constant lhs to the
/// rhs side by mutating the op in place.
OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
  // cmpi(pred, x, x) is a constant decided purely by the predicate.
  if (getLhs() == getRhs()) {
    auto val = applyCmpPredicateToEqualOperands(predicate: getPredicate());
    return getBoolAttribute(type: getType(), value: val);
  }

  if (matchPattern(attr: adaptor.getRhs(), pattern: m_Zero())) {
    if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
      // extsi(%x : i1 -> iN) != 0 -> %x
      std::optional<int64_t> integerWidth =
          getIntegerWidth(t: extOp.getOperand().getType());
      if (integerWidth && integerWidth.value() == 1 &&
          getPredicate() == arith::CmpIPredicate::ne)
        return extOp.getOperand();
    }
    if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
      // extui(%x : i1 -> iN) != 0 -> %x
      std::optional<int64_t> integerWidth =
          getIntegerWidth(t: extOp.getOperand().getType());
      if (integerWidth && integerWidth.value() == 1 &&
          getPredicate() == arith::CmpIPredicate::ne)
        return extOp.getOperand();
    }

    // arith.cmpi ne, %val, %zero : i1 -> %val
    if (getElementTypeOrSelf(type: getLhs().getType()).isInteger(width: 1) &&
        getPredicate() == arith::CmpIPredicate::ne)
      return getLhs();
  }

  if (matchPattern(attr: adaptor.getRhs(), pattern: m_One())) {
    // arith.cmpi eq, %val, %one : i1 -> %val
    if (getElementTypeOrSelf(type: getLhs().getType()).isInteger(width: 1) &&
        getPredicate() == arith::CmpIPredicate::eq)
      return getLhs();
  }

  // Move constant to the right side. Note: this mutates the op in place
  // (swaps operands and mirrors the predicate) and returns the op's own
  // result to signal that a change was made.
  if (adaptor.getLhs() && !adaptor.getRhs()) {
    // Do not use invertPredicate, as it will change eq to ne and vice versa.
    using Pred = CmpIPredicate;
    // Table of mirrored predicates: `a pred b` <=> `b mirror(pred) a`.
    const std::pair<Pred, Pred> invPreds[] = {
        {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
        {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
        {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
        {Pred::ne, Pred::ne},
    };
    Pred origPred = getPredicate();
    for (auto pred : invPreds) {
      if (origPred == pred.first) {
        setPredicate(pred.second);
        Value lhs = getLhs();
        Value rhs = getRhs();
        getLhsMutable().assign(value: rhs);
        getRhsMutable().assign(value: lhs);
        return getResult();
      }
    }
    llvm_unreachable("unknown cmpi predicate kind");
  }

  // We are moving constants to the right side; So if lhs is constant rhs is
  // guaranteed to be a constant.
  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(Val: adaptor.getLhs())) {
    return constFoldBinaryOp<IntegerAttr>(
        operands: adaptor.getOperands(), resultType: getI1SameShape(type: lhs.getType()),
        calculate: [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
          return APInt(1,
                       static_cast<int64_t>(applyCmpPredicate(predicate: pred, lhs, rhs)));
        });
  }

  return {};
}
1983
1984void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1985 MLIRContext *context) {
1986 patterns.insert<CmpIExtSI, CmpIExtUI>(arg&: context);
1987}
1988
1989//===----------------------------------------------------------------------===//
1990// CmpFOp
1991//===----------------------------------------------------------------------===//
1992
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates. A single APFloat::compare supplies the four-way
/// outcome (less/equal/greater/unordered); each predicate tests for the
/// outcomes it accepts.
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
                                    const APFloat &lhs, const APFloat &rhs) {
  auto cmpResult = lhs.compare(RHS: rhs);
  switch (predicate) {
  case arith::CmpFPredicate::AlwaysFalse:
    return false;
  // Ordered predicates: false when either operand is NaN (cmpUnordered).
  case arith::CmpFPredicate::OEQ:
    return cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OGT:
    return cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::OGE:
    return cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OLT:
    return cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::OLE:
    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ONE:
    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::ORD:
    return cmpResult != APFloat::cmpUnordered;
  // Unordered predicates: true when either operand is NaN.
  case arith::CmpFPredicate::UEQ:
    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UGT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::UGE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ULT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::ULE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UNE:
    return cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::UNO:
    return cmpResult == APFloat::cmpUnordered;
  case arith::CmpFPredicate::AlwaysTrue:
    return true;
  }
  llvm_unreachable("unknown cmpf predicate kind");
}
2040
2041OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2042 auto lhs = llvm::dyn_cast_if_present<FloatAttr>(Val: adaptor.getLhs());
2043 auto rhs = llvm::dyn_cast_if_present<FloatAttr>(Val: adaptor.getRhs());
2044
2045 // If one operand is NaN, making them both NaN does not change the result.
2046 if (lhs && lhs.getValue().isNaN())
2047 rhs = lhs;
2048 if (rhs && rhs.getValue().isNaN())
2049 lhs = rhs;
2050
2051 if (!lhs || !rhs)
2052 return {};
2053
2054 auto val = applyCmpPredicate(predicate: getPredicate(), lhs: lhs.getValue(), rhs: rhs.getValue());
2055 return BoolAttr::get(context: getContext(), value: val);
2056}
2057
/// Rewrites `cmpf(pred, sitofp/uitofp(%x), %cst)` into an integer `cmpi` on
/// %x when the effect of the int-to-float conversion on the comparison can be
/// fully accounted for (or the result is a provable constant).
class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
public:
  using OpRewritePattern<CmpFOp>::OpRewritePattern;

  /// Maps a cmpf predicate to the integer predicate with the same meaning,
  /// choosing the signed or unsigned form as requested. Ordered and unordered
  /// variants collapse to the same cmpi predicate because an int-to-fp
  /// conversion never produces a NaN operand.
  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
                                                 bool isUnsigned) {
    using namespace arith;
    switch (pred) {
    case CmpFPredicate::UEQ:
    case CmpFPredicate::OEQ:
      return CmpIPredicate::eq;
    case CmpFPredicate::UGT:
    case CmpFPredicate::OGT:
      return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
    case CmpFPredicate::UGE:
    case CmpFPredicate::OGE:
      return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
    case CmpFPredicate::ULT:
    case CmpFPredicate::OLT:
      return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
    case CmpFPredicate::ULE:
    case CmpFPredicate::OLE:
      return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
    case CmpFPredicate::UNE:
    case CmpFPredicate::ONE:
      return CmpIPredicate::ne;
    default:
      llvm_unreachable("Unexpected predicate!");
    }
  }

  LogicalResult matchAndRewrite(CmpFOp op,
                                PatternRewriter &rewriter) const override {
    // Only handle comparisons against a float constant on the rhs.
    FloatAttr flt;
    if (!matchPattern(value: op.getRhs(), pattern: m_Constant(bind_value: &flt)))
      return failure();

    const APFloat &rhs = flt.getValue();

    // Don't attempt to fold a nan.
    if (rhs.isNaN())
      return failure();

    // Get the width of the mantissa. We don't want to hack on conversions that
    // might lose information from the integer, e.g. "i64 -> float"
    FloatType floatTy = llvm::cast<FloatType>(Val: op.getRhs().getType());
    int mantissaWidth = floatTy.getFPMantissaWidth();
    if (mantissaWidth <= 0)
      return failure();

    bool isUnsigned;
    Value intVal;

    // The lhs must be an int-to-float conversion; remember its signedness and
    // the original integer value.
    if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
      isUnsigned = false;
      intVal = si.getIn();
    } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
      isUnsigned = true;
      intVal = ui.getIn();
    } else {
      return failure();
    }

    // Check to see that the input is converted from an integer type that is
    // small enough that preserves all bits.
    auto intTy = llvm::cast<IntegerType>(Val: intVal.getType());
    auto intWidth = intTy.getWidth();

    // Number of bits representing values, as opposed to the sign
    auto valueBits = isUnsigned ? intWidth : (intWidth - 1);

    // Following test does NOT adjust intWidth downwards for signed inputs,
    // because the most negative value still requires all the mantissa bits
    // to distinguish it from one less than that value.
    if ((int)intWidth > mantissaWidth) {
      // Conversion would lose accuracy. Check if loss can impact comparison.
      int exponent = ilogb(Arg: rhs);
      if (exponent == APFloat::IEK_Inf) {
        int maxExponent = ilogb(Arg: APFloat::getLargest(Sem: rhs.getSemantics()));
        if (maxExponent < (int)valueBits) {
          // Conversion could create infinity.
          return failure();
        }
      } else {
        // Note that if rhs is zero or NaN, then Exp is negative
        // and first condition is trivially false.
        if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
          // Conversion could affect comparison.
          return failure();
        }
      }
    }

    // Convert to equivalent cmpi predicate
    CmpIPredicate pred;
    switch (op.getPredicate()) {
    case CmpFPredicate::ORD:
      // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: true,
                                                 /*width=*/args: 1);
      return success();
    case CmpFPredicate::UNO:
      // Int to fp conversion doesn't create a nan (uno checks either is a nan)
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: false,
                                                 /*width=*/args: 1);
      return success();
    default:
      pred = convertToIntegerPredicate(pred: op.getPredicate(), isUnsigned);
      break;
    }

    // Range checks: if the constant lies outside the representable range of
    // the integer type, the comparison result is a constant.
    if (!isUnsigned) {
      // If the rhs value is > SignedMax, fold the comparison. This handles
      // +INF and large values.
      APFloat signedMax(rhs.getSemantics());
      signedMax.convertFromAPInt(Input: APInt::getSignedMaxValue(numBits: intWidth), IsSigned: true,
                                 RM: APFloat::rmNearestTiesToEven);
      if (signedMax < rhs) { // smax < 13123.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
            pred == CmpIPredicate::sle)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: true,
                                                     /*width=*/args: 1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: false,
                                                     /*width=*/args: 1);
        return success();
      }
    } else {
      // If the rhs value is > UnsignedMax, fold the comparison. This handles
      // +INF and large values.
      APFloat unsignedMax(rhs.getSemantics());
      unsignedMax.convertFromAPInt(Input: APInt::getMaxValue(numBits: intWidth), IsSigned: false,
                                   RM: APFloat::rmNearestTiesToEven);
      if (unsignedMax < rhs) { // umax < 13123.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
            pred == CmpIPredicate::ule)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: true,
                                                     /*width=*/args: 1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: false,
                                                     /*width=*/args: 1);
        return success();
      }
    }

    if (!isUnsigned) {
      // See if the rhs value is < SignedMin.
      APFloat signedMin(rhs.getSemantics());
      signedMin.convertFromAPInt(Input: APInt::getSignedMinValue(numBits: intWidth), IsSigned: true,
                                 RM: APFloat::rmNearestTiesToEven);
      if (signedMin > rhs) { // smin > 12312.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
            pred == CmpIPredicate::sge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: true,
                                                     /*width=*/args: 1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: false,
                                                     /*width=*/args: 1);
        return success();
      }
    } else {
      // See if the rhs value is < UnsignedMin.
      APFloat unsignedMin(rhs.getSemantics());
      unsignedMin.convertFromAPInt(Input: APInt::getMinValue(numBits: intWidth), IsSigned: false,
                                   RM: APFloat::rmNearestTiesToEven);
      if (unsignedMin > rhs) { // umin > 12312.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
            pred == CmpIPredicate::uge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: true,
                                                     /*width=*/args: 1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: false,
                                                     /*width=*/args: 1);
        return success();
      }
    }

    // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
    // [0, UMAX], but it may still be fractional. See if it is fractional by
    // casting the FP value to the integer value and back, checking for
    // equality. Don't do this for zero, because -0.0 is not fractional.
    bool ignored;
    APSInt rhsInt(intWidth, isUnsigned);
    if (APFloat::opInvalidOp ==
        rhs.convertToInteger(Result&: rhsInt, RM: APFloat::rmTowardZero, IsExact: &ignored)) {
      // Undefined behavior invoked - the destination type can't represent
      // the input constant.
      return failure();
    }

    if (!rhs.isZero()) {
      APFloat apf(floatTy.getFloatSemantics(),
                  APInt::getZero(numBits: floatTy.getWidth()));
      apf.convertFromAPInt(Input: rhsInt, IsSigned: !isUnsigned, RM: APFloat::rmNearestTiesToEven);

      bool equal = apf == rhs;
      if (!equal) {
        // If we had a comparison against a fractional value, we have to adjust
        // the compare predicate and sometimes the value. rhsInt is rounded
        // towards zero at this point.
        switch (pred) {
        case CmpIPredicate::ne: // (float)int != 4.4 --> true
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: true,
                                                     /*width=*/args: 1);
          return success();
        case CmpIPredicate::eq: // (float)int == 4.4 --> false
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: false,
                                                     /*width=*/args: 1);
          return success();
        case CmpIPredicate::ule:
          // (float)int <= 4.4 --> int <= 4
          // (float)int <= -4.4 --> false
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: false,
                                                       /*width=*/args: 1);
            return success();
          }
          break;
        case CmpIPredicate::sle:
          // (float)int <= 4.4 --> int <= 4
          // (float)int <= -4.4 --> int < -4
          if (rhs.isNegative())
            pred = CmpIPredicate::slt;
          break;
        case CmpIPredicate::ult:
          // (float)int < -4.4 --> false
          // (float)int < 4.4 --> int <= 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: false,
                                                       /*width=*/args: 1);
            return success();
          }
          pred = CmpIPredicate::ule;
          break;
        case CmpIPredicate::slt:
          // (float)int < -4.4 --> int < -4
          // (float)int < 4.4 --> int <= 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sle;
          break;
        case CmpIPredicate::ugt:
          // (float)int > 4.4 --> int > 4
          // (float)int > -4.4 --> true
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: true,
                                                       /*width=*/args: 1);
            return success();
          }
          break;
        case CmpIPredicate::sgt:
          // (float)int > 4.4 --> int > 4
          // (float)int > -4.4 --> int >= -4
          if (rhs.isNegative())
            pred = CmpIPredicate::sge;
          break;
        case CmpIPredicate::uge:
          // (float)int >= -4.4 --> true
          // (float)int >= 4.4 --> int > 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/args: true,
                                                       /*width=*/args: 1);
            return success();
          }
          pred = CmpIPredicate::ugt;
          break;
        case CmpIPredicate::sge:
          // (float)int >= -4.4 --> int >= -4
          // (float)int >= 4.4 --> int > 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sgt;
          break;
        }
      }
    }

    // Lower this FP comparison into an appropriate integer version of the
    // comparison.
    rewriter.replaceOpWithNewOp<CmpIOp>(
        op, args&: pred, args&: intVal,
        args: rewriter.create<ConstantOp>(
            location: op.getLoc(), args: intVal.getType(),
            args: rewriter.getIntegerAttr(type: intVal.getType(), value: rhsInt)));
    return success();
  }
};
2343
2344void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2345 MLIRContext *context) {
2346 patterns.insert<CmpFIntToFPConst>(arg&: context);
2347}
2348
2349//===----------------------------------------------------------------------===//
2350// SelectOp
2351//===----------------------------------------------------------------------===//
2352
2353// select %arg, %c1, %c0 => extui %arg
// select %arg, %c1, %c0 => extui %arg
// Rewrites a select of the integer constants 1/0 (in either order) into a
// zero-extension of the i1 condition, XOR-inverting the condition when the
// branches are swapped.
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    // Cannot extui i1 to i1, or i1 to f32
    if (!llvm::isa<IntegerType>(Val: op.getType()) || op.getType().isInteger(width: 1))
      return failure();

    // select %x, c1, %c0 => extui %arg
    if (matchPattern(value: op.getTrueValue(), pattern: m_One()) &&
        matchPattern(value: op.getFalseValue(), pattern: m_Zero())) {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, args: op.getType(),
                                                  args: op.getCondition());
      return success();
    }

    // select %x, c0, %c1 => extui (xor %arg, true)
    if (matchPattern(value: op.getTrueValue(), pattern: m_Zero()) &&
        matchPattern(value: op.getFalseValue(), pattern: m_One())) {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
          op, args: op.getType(),
          args: rewriter.create<arith::XOrIOp>(
              location: op.getLoc(), args: op.getCondition(),
              args: rewriter.create<arith::ConstantIntOp>(
                  location: op.getLoc(), args: op.getCondition().getType(), args: 1)));
      return success();
    }

    return failure();
  }
};
2386
/// Registers select canonicalizations: redundant-branch elimination,
/// condition-negation simplification, i1 logic rewrites (per the pattern
/// names), and the local SelectToExtUI pattern above.
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
              SelectI1ToNot, SelectToExtUI>(arg&: context);
}
2392
/// Folds select: identical branches, constant/poison conditions and branches,
/// boolean identities, selects fed by eq/ne compares of the same values, and
/// element-wise folding over a non-splat constant condition.
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
  Value trueVal = getTrueValue();
  Value falseVal = getFalseValue();
  // Both branches identical: the condition is irrelevant.
  if (trueVal == falseVal)
    return trueVal;

  Value condition = getCondition();

  // select true, %0, %1 => %0
  if (matchPattern(attr: adaptor.getCondition(), pattern: m_One()))
    return trueVal;

  // select false, %0, %1 => %1
  if (matchPattern(attr: adaptor.getCondition(), pattern: m_Zero()))
    return falseVal;

  // If either operand is fully poisoned, return the other.
  if (isa_and_nonnull<ub::PoisonAttr>(Val: adaptor.getTrueValue()))
    return falseVal;

  if (isa_and_nonnull<ub::PoisonAttr>(Val: adaptor.getFalseValue()))
    return trueVal;

  // select %x, true, false => %x
  if (getType().isSignlessInteger(width: 1) &&
      matchPattern(attr: adaptor.getTrueValue(), pattern: m_One()) &&
      matchPattern(attr: adaptor.getFalseValue(), pattern: m_Zero()))
    return condition;

  if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(Val: condition.getDefiningOp())) {
    auto pred = cmp.getPredicate();
    if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
      auto cmpLhs = cmp.getLhs();
      auto cmpRhs = cmp.getRhs();

      // %0 = arith.cmpi eq, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg1

      // %0 = arith.cmpi ne, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg0

      // Either pairing of compare operands to branches works: when the values
      // compare equal the branches are interchangeable, and otherwise the
      // predicate determines which branch is taken.
      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
          (cmpRhs == trueVal && cmpLhs == falseVal))
        return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
    }
  }

  // Constant-fold constant operands over non-splat constant condition.
  // select %cst_vec, %cst0, %cst1 => %cst2
  if (auto cond =
          llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getCondition())) {
    if (auto lhs =
            llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getTrueValue())) {
      if (auto rhs =
              llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getFalseValue())) {
        SmallVector<Attribute> results;
        results.reserve(N: static_cast<size_t>(cond.getNumElements()));
        auto condVals = llvm::make_range(x: cond.value_begin<BoolAttr>(),
                                         y: cond.value_end<BoolAttr>());
        auto lhsVals = llvm::make_range(x: lhs.value_begin<Attribute>(),
                                        y: lhs.value_end<Attribute>());
        auto rhsVals = llvm::make_range(x: rhs.value_begin<Attribute>(),
                                        y: rhs.value_end<Attribute>());

        // Pick each element from the matching branch constant.
        for (auto [condVal, lhsVal, rhsVal] :
             llvm::zip_equal(t&: condVals, u&: lhsVals, args&: rhsVals))
          results.push_back(Elt: condVal.getValue() ? lhsVal : rhsVal);

        return DenseElementsAttr::get(type: lhs.getType(), values: results);
      }
    }
  }

  return nullptr;
}
2468
/// Parses `select` in the form:
///   %r = arith.select %cond, %true, %false {attrs} : [condType, ]resultType
/// The optional leading type is the condition's shaped (mask) type; when
/// absent the condition defaults to scalar i1.
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  Type conditionType, resultType;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  if (parser.parseOperandList(result&: operands, /*requiredOperandCount=*/3) ||
      parser.parseOptionalAttrDict(result&: result.attributes) ||
      parser.parseColonType(result&: resultType))
    return failure();

  // Check for the explicit condition type if this is a masked tensor or vector.
  if (succeeded(Result: parser.parseOptionalComma())) {
    // The type parsed so far was actually the condition type; the result type
    // follows the comma.
    conditionType = resultType;
    if (parser.parseType(result&: resultType))
      return failure();
  } else {
    conditionType = parser.getBuilder().getI1Type();
  }

  result.addTypes(newTypes: resultType);
  return parser.resolveOperands(operands,
                                types: {conditionType, resultType, resultType},
                                loc: parser.getNameLoc(), result&: result.operands);
}
2491
2492void arith::SelectOp::print(OpAsmPrinter &p) {
2493 p << " " << getOperands();
2494 p.printOptionalAttrDict(attrs: (*this)->getAttrs());
2495 p << " : ";
2496 if (ShapedType condType =
2497 llvm::dyn_cast<ShapedType>(Val: getCondition().getType()))
2498 p << condType << ", ";
2499 p << getType();
2500}
2501
2502LogicalResult arith::SelectOp::verify() {
2503 Type conditionType = getCondition().getType();
2504 if (conditionType.isSignlessInteger(width: 1))
2505 return success();
2506
2507 // If the result type is a vector or tensor, the type can be a mask with the
2508 // same elements.
2509 Type resultType = getType();
2510 if (!llvm::isa<TensorType, VectorType>(Val: resultType))
2511 return emitOpError() << "expected condition to be a signless i1, but got "
2512 << conditionType;
2513 Type shapedConditionType = getI1SameShape(type: resultType);
2514 if (conditionType != shapedConditionType) {
2515 return emitOpError() << "expected condition type to have the same shape "
2516 "as the result type, expected "
2517 << shapedConditionType << ", but got "
2518 << conditionType;
2519 }
2520 return success();
2521}
2522//===----------------------------------------------------------------------===//
2523// ShLIOp
2524//===----------------------------------------------------------------------===//
2525
2526OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2527 // shli(x, 0) -> x
2528 if (matchPattern(attr: adaptor.getRhs(), pattern: m_Zero()))
2529 return getLhs();
2530 // Don't fold if shifting more or equal than the bit width.
2531 bool bounded = false;
2532 auto result = constFoldBinaryOp<IntegerAttr>(
2533 operands: adaptor.getOperands(), calculate: [&](const APInt &a, const APInt &b) {
2534 bounded = b.ult(RHS: b.getBitWidth());
2535 return a.shl(ShiftAmt: b);
2536 });
2537 return bounded ? result : Attribute();
2538}
2539
2540//===----------------------------------------------------------------------===//
2541// ShRUIOp
2542//===----------------------------------------------------------------------===//
2543
2544OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2545 // shrui(x, 0) -> x
2546 if (matchPattern(attr: adaptor.getRhs(), pattern: m_Zero()))
2547 return getLhs();
2548 // Don't fold if shifting more or equal than the bit width.
2549 bool bounded = false;
2550 auto result = constFoldBinaryOp<IntegerAttr>(
2551 operands: adaptor.getOperands(), calculate: [&](const APInt &a, const APInt &b) {
2552 bounded = b.ult(RHS: b.getBitWidth());
2553 return a.lshr(ShiftAmt: b);
2554 });
2555 return bounded ? result : Attribute();
2556}
2557
2558//===----------------------------------------------------------------------===//
2559// ShRSIOp
2560//===----------------------------------------------------------------------===//
2561
2562OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2563 // shrsi(x, 0) -> x
2564 if (matchPattern(attr: adaptor.getRhs(), pattern: m_Zero()))
2565 return getLhs();
2566 // Don't fold if shifting more or equal than the bit width.
2567 bool bounded = false;
2568 auto result = constFoldBinaryOp<IntegerAttr>(
2569 operands: adaptor.getOperands(), calculate: [&](const APInt &a, const APInt &b) {
2570 bounded = b.ult(RHS: b.getBitWidth());
2571 return a.ashr(ShiftAmt: b);
2572 });
2573 return bounded ? result : Attribute();
2574}
2575
2576//===----------------------------------------------------------------------===//
2577// Atomic Enum
2578//===----------------------------------------------------------------------===//
2579
2580/// Returns the identity value attribute associated with an AtomicRMWKind op.
2581TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2582 OpBuilder &builder, Location loc,
2583 bool useOnlyFiniteValue) {
2584 switch (kind) {
2585 case AtomicRMWKind::maximumf: {
2586 const llvm::fltSemantics &semantic =
2587 llvm::cast<FloatType>(Val&: resultType).getFloatSemantics();
2588 APFloat identity = useOnlyFiniteValue
2589 ? APFloat::getLargest(Sem: semantic, /*Negative=*/true)
2590 : APFloat::getInf(Sem: semantic, /*Negative=*/true);
2591 return builder.getFloatAttr(type: resultType, value: identity);
2592 }
2593 case AtomicRMWKind::maxnumf: {
2594 const llvm::fltSemantics &semantic =
2595 llvm::cast<FloatType>(Val&: resultType).getFloatSemantics();
2596 APFloat identity = APFloat::getNaN(Sem: semantic, /*Negative=*/true);
2597 return builder.getFloatAttr(type: resultType, value: identity);
2598 }
2599 case AtomicRMWKind::addf:
2600 case AtomicRMWKind::addi:
2601 case AtomicRMWKind::maxu:
2602 case AtomicRMWKind::ori:
2603 return builder.getZeroAttr(type: resultType);
2604 case AtomicRMWKind::andi:
2605 return builder.getIntegerAttr(
2606 type: resultType,
2607 value: APInt::getAllOnes(numBits: llvm::cast<IntegerType>(Val&: resultType).getWidth()));
2608 case AtomicRMWKind::maxs:
2609 return builder.getIntegerAttr(
2610 type: resultType, value: APInt::getSignedMinValue(
2611 numBits: llvm::cast<IntegerType>(Val&: resultType).getWidth()));
2612 case AtomicRMWKind::minimumf: {
2613 const llvm::fltSemantics &semantic =
2614 llvm::cast<FloatType>(Val&: resultType).getFloatSemantics();
2615 APFloat identity = useOnlyFiniteValue
2616 ? APFloat::getLargest(Sem: semantic, /*Negative=*/false)
2617 : APFloat::getInf(Sem: semantic, /*Negative=*/false);
2618
2619 return builder.getFloatAttr(type: resultType, value: identity);
2620 }
2621 case AtomicRMWKind::minnumf: {
2622 const llvm::fltSemantics &semantic =
2623 llvm::cast<FloatType>(Val&: resultType).getFloatSemantics();
2624 APFloat identity = APFloat::getNaN(Sem: semantic, /*Negative=*/false);
2625 return builder.getFloatAttr(type: resultType, value: identity);
2626 }
2627 case AtomicRMWKind::mins:
2628 return builder.getIntegerAttr(
2629 type: resultType, value: APInt::getSignedMaxValue(
2630 numBits: llvm::cast<IntegerType>(Val&: resultType).getWidth()));
2631 case AtomicRMWKind::minu:
2632 return builder.getIntegerAttr(
2633 type: resultType,
2634 value: APInt::getMaxValue(numBits: llvm::cast<IntegerType>(Val&: resultType).getWidth()));
2635 case AtomicRMWKind::muli:
2636 return builder.getIntegerAttr(type: resultType, value: 1);
2637 case AtomicRMWKind::mulf:
2638 return builder.getFloatAttr(type: resultType, value: 1);
2639 // TODO: Add remaining reduction operations.
2640 default:
2641 (void)emitOptionalError(loc, args: "Reduction operation type not supported");
2642 break;
2643 }
2644 return nullptr;
2645}
2646
2647/// Return the identity numeric value associated to the give op.
2648std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2649 std::optional<AtomicRMWKind> maybeKind =
2650 llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
2651 // Floating-point operations.
2652 .Case(caseFn: [](arith::AddFOp op) { return AtomicRMWKind::addf; })
2653 .Case(caseFn: [](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2654 .Case(caseFn: [](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2655 .Case(caseFn: [](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2656 .Case(caseFn: [](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2657 .Case(caseFn: [](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2658 // Integer operations.
2659 .Case(caseFn: [](arith::AddIOp op) { return AtomicRMWKind::addi; })
2660 .Case(caseFn: [](arith::OrIOp op) { return AtomicRMWKind::ori; })
2661 .Case(caseFn: [](arith::XOrIOp op) { return AtomicRMWKind::ori; })
2662 .Case(caseFn: [](arith::AndIOp op) { return AtomicRMWKind::andi; })
2663 .Case(caseFn: [](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2664 .Case(caseFn: [](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2665 .Case(caseFn: [](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2666 .Case(caseFn: [](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2667 .Case(caseFn: [](arith::MulIOp op) { return AtomicRMWKind::muli; })
2668 .Default(defaultFn: [](Operation *op) { return std::nullopt; });
2669 if (!maybeKind) {
2670 return std::nullopt;
2671 }
2672
2673 bool useOnlyFiniteValue = false;
2674 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(Val: op);
2675 if (fmfOpInterface) {
2676 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2677 useOnlyFiniteValue =
2678 bitEnumContainsAny(bits: fmfAttr.getValue(), bit: arith::FastMathFlags::ninf);
2679 }
2680
2681 // Builder only used as helper for attribute creation.
2682 OpBuilder b(op->getContext());
2683 Type resultType = op->getResult(idx: 0).getType();
2684
2685 return getIdentityValueAttr(kind: *maybeKind, resultType, builder&: b, loc: op->getLoc(),
2686 useOnlyFiniteValue);
2687}
2688
2689/// Returns the identity value associated with an AtomicRMWKind op.
2690Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2691 OpBuilder &builder, Location loc,
2692 bool useOnlyFiniteValue) {
2693 auto attr =
2694 getIdentityValueAttr(kind: op, resultType, builder, loc, useOnlyFiniteValue);
2695 return builder.create<arith::ConstantOp>(location: loc, args&: attr);
2696}
2697
2698/// Return the value obtained by applying the reduction operation kind
2699/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2700Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2701 Location loc, Value lhs, Value rhs) {
2702 switch (op) {
2703 case AtomicRMWKind::addf:
2704 return builder.create<arith::AddFOp>(location: loc, args&: lhs, args&: rhs);
2705 case AtomicRMWKind::addi:
2706 return builder.create<arith::AddIOp>(location: loc, args&: lhs, args&: rhs);
2707 case AtomicRMWKind::mulf:
2708 return builder.create<arith::MulFOp>(location: loc, args&: lhs, args&: rhs);
2709 case AtomicRMWKind::muli:
2710 return builder.create<arith::MulIOp>(location: loc, args&: lhs, args&: rhs);
2711 case AtomicRMWKind::maximumf:
2712 return builder.create<arith::MaximumFOp>(location: loc, args&: lhs, args&: rhs);
2713 case AtomicRMWKind::minimumf:
2714 return builder.create<arith::MinimumFOp>(location: loc, args&: lhs, args&: rhs);
2715 case AtomicRMWKind::maxnumf:
2716 return builder.create<arith::MaxNumFOp>(location: loc, args&: lhs, args&: rhs);
2717 case AtomicRMWKind::minnumf:
2718 return builder.create<arith::MinNumFOp>(location: loc, args&: lhs, args&: rhs);
2719 case AtomicRMWKind::maxs:
2720 return builder.create<arith::MaxSIOp>(location: loc, args&: lhs, args&: rhs);
2721 case AtomicRMWKind::mins:
2722 return builder.create<arith::MinSIOp>(location: loc, args&: lhs, args&: rhs);
2723 case AtomicRMWKind::maxu:
2724 return builder.create<arith::MaxUIOp>(location: loc, args&: lhs, args&: rhs);
2725 case AtomicRMWKind::minu:
2726 return builder.create<arith::MinUIOp>(location: loc, args&: lhs, args&: rhs);
2727 case AtomicRMWKind::ori:
2728 return builder.create<arith::OrIOp>(location: loc, args&: lhs, args&: rhs);
2729 case AtomicRMWKind::andi:
2730 return builder.create<arith::AndIOp>(location: loc, args&: lhs, args&: rhs);
2731 // TODO: Add remaining reduction operations.
2732 default:
2733 (void)emitOptionalError(loc, args: "Reduction operation type not supported");
2734 break;
2735 }
2736 return nullptr;
2737}
2738
2739//===----------------------------------------------------------------------===//
2740// TableGen'd op method definitions
2741//===----------------------------------------------------------------------===//
2742
2743#define GET_OP_CLASSES
2744#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2745
2746//===----------------------------------------------------------------------===//
2747// TableGen'd enum attribute definitions
2748//===----------------------------------------------------------------------===//
2749
2750#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
2751

// Source: mlir/lib/Dialect/Arith/IR/ArithOps.cpp