1 | //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <cassert> |
10 | #include <cstdint> |
11 | #include <functional> |
12 | #include <utility> |
13 | |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/CommonFolders.h" |
16 | #include "mlir/Dialect/UB/IR/UBOps.h" |
17 | #include "mlir/IR/Builders.h" |
18 | #include "mlir/IR/BuiltinAttributeInterfaces.h" |
19 | #include "mlir/IR/BuiltinAttributes.h" |
20 | #include "mlir/IR/Matchers.h" |
21 | #include "mlir/IR/OpImplementation.h" |
22 | #include "mlir/IR/PatternMatch.h" |
23 | #include "mlir/IR/TypeUtilities.h" |
24 | #include "mlir/Support/LogicalResult.h" |
25 | |
26 | #include "llvm/ADT/APFloat.h" |
27 | #include "llvm/ADT/APInt.h" |
28 | #include "llvm/ADT/APSInt.h" |
29 | #include "llvm/ADT/FloatingPointMode.h" |
30 | #include "llvm/ADT/STLExtras.h" |
31 | #include "llvm/ADT/SmallString.h" |
32 | #include "llvm/ADT/SmallVector.h" |
33 | #include "llvm/ADT/TypeSwitch.h" |
34 | |
35 | using namespace mlir; |
36 | using namespace mlir::arith; |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Pattern helpers |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | static IntegerAttr |
43 | applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, |
44 | Attribute rhs, |
45 | function_ref<APInt(const APInt &, const APInt &)> binFn) { |
46 | APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue(); |
47 | APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue(); |
48 | APInt value = binFn(lhsVal, rhsVal); |
49 | return IntegerAttr::get(res.getType(), value); |
50 | } |
51 | |
52 | static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, |
53 | Attribute lhs, Attribute rhs) { |
54 | return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>()); |
55 | } |
56 | |
57 | static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, |
58 | Attribute lhs, Attribute rhs) { |
59 | return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>()); |
60 | } |
61 | |
62 | static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, |
63 | Attribute lhs, Attribute rhs) { |
64 | return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>()); |
65 | } |
66 | |
67 | /// Invert an integer comparison predicate. |
68 | arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { |
69 | switch (pred) { |
70 | case arith::CmpIPredicate::eq: |
71 | return arith::CmpIPredicate::ne; |
72 | case arith::CmpIPredicate::ne: |
73 | return arith::CmpIPredicate::eq; |
74 | case arith::CmpIPredicate::slt: |
75 | return arith::CmpIPredicate::sge; |
76 | case arith::CmpIPredicate::sle: |
77 | return arith::CmpIPredicate::sgt; |
78 | case arith::CmpIPredicate::sgt: |
79 | return arith::CmpIPredicate::sle; |
80 | case arith::CmpIPredicate::sge: |
81 | return arith::CmpIPredicate::slt; |
82 | case arith::CmpIPredicate::ult: |
83 | return arith::CmpIPredicate::uge; |
84 | case arith::CmpIPredicate::ule: |
85 | return arith::CmpIPredicate::ugt; |
86 | case arith::CmpIPredicate::ugt: |
87 | return arith::CmpIPredicate::ule; |
88 | case arith::CmpIPredicate::uge: |
89 | return arith::CmpIPredicate::ult; |
90 | } |
  llvm_unreachable("unknown cmpi predicate kind");
92 | } |
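
// For illustration: inversion is the logical negation of the predicate, not an
// operand swap. For example, invertPredicate(CmpIPredicate::slt) yields
// CmpIPredicate::sge, since !(a < b) is equivalent to (a >= b).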
93 | |
94 | /// Equivalent to |
95 | /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)). |
96 | /// |
97 | /// Not possible to implement as chain of calls as this would introduce a |
98 | /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend |
99 | /// on the LLVM dialect and on translation to LLVM. |
100 | static llvm::RoundingMode |
101 | convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) { |
102 | switch (roundingMode) { |
103 | case RoundingMode::downward: |
104 | return llvm::RoundingMode::TowardNegative; |
105 | case RoundingMode::to_nearest_away: |
106 | return llvm::RoundingMode::NearestTiesToAway; |
107 | case RoundingMode::to_nearest_even: |
108 | return llvm::RoundingMode::NearestTiesToEven; |
109 | case RoundingMode::toward_zero: |
110 | return llvm::RoundingMode::TowardZero; |
111 | case RoundingMode::upward: |
112 | return llvm::RoundingMode::TowardPositive; |
113 | } |
  llvm_unreachable("Unhandled rounding mode");
115 | } |
116 | |
117 | static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { |
118 | return arith::CmpIPredicateAttr::get(pred.getContext(), |
119 | invertPredicate(pred.getValue())); |
120 | } |
121 | |
122 | static int64_t getScalarOrElementWidth(Type type) { |
123 | Type elemTy = getElementTypeOrSelf(type); |
124 | if (elemTy.isIntOrFloat()) |
125 | return elemTy.getIntOrFloatBitWidth(); |
126 | |
127 | return -1; |
128 | } |
129 | |
130 | static int64_t getScalarOrElementWidth(Value value) { |
  return getScalarOrElementWidth(value.getType());
132 | } |
133 | |
134 | static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) { |
135 | APInt value; |
136 | if (matchPattern(attr, m_ConstantInt(&value))) |
137 | return value; |
138 | |
139 | return failure(); |
140 | } |
141 | |
142 | static Attribute getBoolAttribute(Type type, bool value) { |
  auto boolAttr = BoolAttr::get(type.getContext(), value);
144 | ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type); |
145 | if (!shapedType) |
146 | return boolAttr; |
147 | return DenseElementsAttr::get(shapedType, boolAttr); |
148 | } |
149 | |
150 | //===----------------------------------------------------------------------===// |
151 | // TableGen'd canonicalization patterns |
152 | //===----------------------------------------------------------------------===// |
153 | |
154 | namespace { |
155 | #include "ArithCanonicalization.inc" |
156 | } // namespace |
157 | |
158 | //===----------------------------------------------------------------------===// |
159 | // Common helpers |
160 | //===----------------------------------------------------------------------===// |
161 | |
162 | /// Return the type of the same shape (scalar, vector or tensor) containing i1. |
163 | static Type getI1SameShape(Type type) { |
164 | auto i1Type = IntegerType::get(type.getContext(), 1); |
165 | if (auto shapedType = llvm::dyn_cast<ShapedType>(type)) |
166 | return shapedType.cloneWith(std::nullopt, i1Type); |
167 | if (llvm::isa<UnrankedTensorType>(type)) |
168 | return UnrankedTensorType::get(i1Type); |
169 | return i1Type; |
170 | } |
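
// For illustration: i64 maps to i1, vector<4xf32> maps to vector<4xi1>,
// tensor<?x8xi32> maps to tensor<?x8xi1>, and an unranked tensor maps to
// tensor<*xi1>.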
171 | |
172 | //===----------------------------------------------------------------------===// |
173 | // ConstantOp |
174 | //===----------------------------------------------------------------------===// |
175 | |
176 | void arith::ConstantOp::getAsmResultNames( |
177 | function_ref<void(Value, StringRef)> setNameFn) { |
178 | auto type = getType(); |
179 | if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) { |
180 | auto intType = llvm::dyn_cast<IntegerType>(type); |
181 | |
182 | // Sugar i1 constants with 'true' and 'false'. |
183 | if (intType && intType.getWidth() == 1) |
      return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
185 | |
186 | // Otherwise, build a complex name with the value and type. |
187 | SmallString<32> specialNameBuffer; |
188 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
189 | specialName << 'c' << intCst.getValue(); |
190 | if (intType) |
191 | specialName << '_' << type; |
192 | setNameFn(getResult(), specialName.str()); |
193 | } else { |
    setNameFn(getResult(), "cst");
195 | } |
196 | } |
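
// Illustrative result names (example values): an i1 constant is suggested
// %true or %false, a "42 : i32" constant becomes %c42_i32, a "42 : index"
// constant becomes %c42, and floating-point constants fall back to %cst.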
197 | |
198 | /// TODO: disallow arith.constant to return anything other than signless integer |
199 | /// or float like. |
200 | LogicalResult arith::ConstantOp::verify() { |
201 | auto type = getType(); |
202 | // The value's type must match the return type. |
203 | if (getValue().getType() != type) { |
204 | return emitOpError() << "value type " << getValue().getType() |
205 | << " must match return type: " << type; |
206 | } |
207 | // Integer values must be signless. |
208 | if (llvm::isa<IntegerType>(type) && |
209 | !llvm::cast<IntegerType>(type).isSignless()) |
    return emitOpError("integer return type must be signless");
  // Integer, float, and elements attributes are acceptable.
212 | if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) { |
213 | return emitOpError( |
        "value must be an integer, float, or elements attribute");
215 | } |
216 | |
217 | // Note, we could relax this for vectors with 1 scalable dim, e.g.: |
218 | // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32> |
219 | // However, this would most likely require updating the lowerings to LLVM. |
220 | auto vecType = dyn_cast<VectorType>(type); |
221 | if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue())) |
    return emitOpError(
        "initializing scalable vectors with elements attribute is not supported"
        " unless it's a vector splat");
225 | return success(); |
226 | } |
227 | |
228 | bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { |
229 | // The value's type must be the same as the provided type. |
230 | auto typedAttr = llvm::dyn_cast<TypedAttr>(value); |
231 | if (!typedAttr || typedAttr.getType() != type) |
232 | return false; |
233 | // Integer values must be signless. |
234 | if (llvm::isa<IntegerType>(type) && |
235 | !llvm::cast<IntegerType>(type).isSignless()) |
236 | return false; |
237 | // Integer, float, and element attributes are buildable. |
238 | return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value); |
239 | } |
240 | |
241 | ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value, |
242 | Type type, Location loc) { |
243 | if (isBuildableWith(value, type)) |
244 | return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value)); |
245 | return nullptr; |
246 | } |
247 | |
248 | OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } |
249 | |
250 | void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, |
251 | int64_t value, unsigned width) { |
252 | auto type = builder.getIntegerType(width); |
253 | arith::ConstantOp::build(builder, result, type, |
254 | builder.getIntegerAttr(type, value)); |
255 | } |
256 | |
257 | void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, |
258 | int64_t value, Type type) { |
  assert(type.isSignlessInteger() &&
         "ConstantIntOp can only have signless integer type values");
261 | arith::ConstantOp::build(builder, result, type, |
262 | builder.getIntegerAttr(type, value)); |
263 | } |
264 | |
265 | bool arith::ConstantIntOp::classof(Operation *op) { |
266 | if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) |
267 | return constOp.getType().isSignlessInteger(); |
268 | return false; |
269 | } |
270 | |
271 | void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, |
272 | const APFloat &value, FloatType type) { |
273 | arith::ConstantOp::build(builder, result, type, |
274 | builder.getFloatAttr(type, value)); |
275 | } |
276 | |
277 | bool arith::ConstantFloatOp::classof(Operation *op) { |
278 | if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) |
279 | return llvm::isa<FloatType>(constOp.getType()); |
280 | return false; |
281 | } |
282 | |
283 | void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, |
284 | int64_t value) { |
285 | arith::ConstantOp::build(builder, result, builder.getIndexType(), |
286 | builder.getIndexAttr(value)); |
287 | } |
288 | |
289 | bool arith::ConstantIndexOp::classof(Operation *op) { |
290 | if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) |
291 | return constOp.getType().isIndex(); |
292 | return false; |
293 | } |
294 | |
295 | //===----------------------------------------------------------------------===// |
296 | // AddIOp |
297 | //===----------------------------------------------------------------------===// |
298 | |
299 | OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) { |
300 | // addi(x, 0) -> x |
301 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
302 | return getLhs(); |
303 | |
304 | // addi(subi(a, b), b) -> a |
305 | if (auto sub = getLhs().getDefiningOp<SubIOp>()) |
306 | if (getRhs() == sub.getRhs()) |
307 | return sub.getLhs(); |
308 | |
309 | // addi(b, subi(a, b)) -> a |
310 | if (auto sub = getRhs().getDefiningOp<SubIOp>()) |
311 | if (getLhs() == sub.getRhs()) |
312 | return sub.getLhs(); |
313 | |
314 | return constFoldBinaryOp<IntegerAttr>( |
315 | adaptor.getOperands(), |
316 | [](APInt a, const APInt &b) { return std::move(a) + b; }); |
317 | } |
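
// Illustrative IR for the addi(subi(a, b), b) -> a fold (hypothetical values
// %a, %b):
//   %0 = arith.subi %a, %b : i32
//   %1 = arith.addi %0, %b : i32
// Here %1 folds directly to %a.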
318 | |
319 | void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
320 | MLIRContext *context) { |
321 | patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS, |
322 | AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context); |
323 | } |
324 | |
325 | //===----------------------------------------------------------------------===// |
326 | // AddUIExtendedOp |
327 | //===----------------------------------------------------------------------===// |
328 | |
329 | std::optional<SmallVector<int64_t, 4>> |
330 | arith::AddUIExtendedOp::getShapeForUnroll() { |
331 | if (auto vt = llvm::dyn_cast<VectorType>(getType(0))) |
332 | return llvm::to_vector<4>(vt.getShape()); |
333 | return std::nullopt; |
334 | } |
335 | |
336 | // Returns the overflow bit, assuming that `sum` is the result of unsigned |
337 | // addition of `operand` and another number. |
338 | static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) { |
  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
340 | } |
341 | |
342 | LogicalResult |
343 | arith::AddUIExtendedOp::fold(FoldAdaptor adaptor, |
344 | SmallVectorImpl<OpFoldResult> &results) { |
345 | Type overflowTy = getOverflow().getType(); |
346 | // addui_extended(x, 0) -> x, false |
347 | if (matchPattern(getRhs(), m_Zero())) { |
348 | Builder builder(getContext()); |
349 | auto falseValue = builder.getZeroAttr(overflowTy); |
350 | |
351 | results.push_back(getLhs()); |
352 | results.push_back(falseValue); |
353 | return success(); |
354 | } |
355 | |
356 | // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry |
357 | // Let the `constFoldBinaryOp` utility attempt to fold the sum of both |
358 | // operands. If that succeeds, calculate the overflow bit based on the sum |
359 | // and the first (constant) operand, `lhs`. |
360 | if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>( |
361 | adaptor.getOperands(), |
362 | [](APInt a, const APInt &b) { return std::move(a) + b; })) { |
363 | Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>( |
364 | ArrayRef({sumAttr, adaptor.getLhs()}), |
365 | getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()), |
366 | calculateUnsignedOverflow); |
367 | if (!overflowAttr) |
368 | return failure(); |
369 | |
370 | results.push_back(sumAttr); |
371 | results.push_back(overflowAttr); |
372 | return success(); |
373 | } |
374 | |
375 | return failure(); |
376 | } |
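
// Worked example (illustrative, i8 operands): addui_extended(200, 100) folds
// to sum = 44 (300 wrapped modulo 256) and overflow = 1, because the wrapped
// sum compares unsigned-less-than the constant lhs operand.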
377 | |
378 | void arith::AddUIExtendedOp::getCanonicalizationPatterns( |
379 | RewritePatternSet &patterns, MLIRContext *context) { |
380 | patterns.add<AddUIExtendedToAddI>(context); |
381 | } |
382 | |
383 | //===----------------------------------------------------------------------===// |
384 | // SubIOp |
385 | //===----------------------------------------------------------------------===// |
386 | |
387 | OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) { |
388 | // subi(x,x) -> 0 |
389 | if (getOperand(0) == getOperand(1)) |
390 | return Builder(getContext()).getZeroAttr(getType()); |
391 | // subi(x,0) -> x |
392 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
393 | return getLhs(); |
394 | |
395 | if (auto add = getLhs().getDefiningOp<AddIOp>()) { |
396 | // subi(addi(a, b), b) -> a |
397 | if (getRhs() == add.getRhs()) |
398 | return add.getLhs(); |
399 | // subi(addi(a, b), a) -> b |
400 | if (getRhs() == add.getLhs()) |
401 | return add.getRhs(); |
402 | } |
403 | |
404 | return constFoldBinaryOp<IntegerAttr>( |
405 | adaptor.getOperands(), |
406 | [](APInt a, const APInt &b) { return std::move(a) - b; }); |
407 | } |
408 | |
409 | void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
410 | MLIRContext *context) { |
411 | patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, |
412 | SubIRHSSubConstantLHS, SubILHSSubConstantRHS, |
413 | SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context); |
414 | } |
415 | |
416 | //===----------------------------------------------------------------------===// |
417 | // MulIOp |
418 | //===----------------------------------------------------------------------===// |
419 | |
420 | OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) { |
421 | // muli(x, 0) -> 0 |
422 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
423 | return getRhs(); |
424 | // muli(x, 1) -> x |
425 | if (matchPattern(adaptor.getRhs(), m_One())) |
426 | return getLhs(); |
427 | // TODO: Handle the overflow case. |
428 | |
429 | // default folder |
430 | return constFoldBinaryOp<IntegerAttr>( |
431 | adaptor.getOperands(), |
432 | [](const APInt &a, const APInt &b) { return a * b; }); |
433 | } |
434 | |
435 | void arith::MulIOp::getAsmResultNames( |
436 | function_ref<void(Value, StringRef)> setNameFn) { |
437 | if (!isa<IndexType>(getType())) |
438 | return; |
439 | |
440 | // Match vector.vscale by name to avoid depending on the vector dialect (which |
441 | // is a circular dependency). |
442 | auto isVscale = [](Operation *op) { |
    return op && op->getName().getStringRef() == "vector.vscale";
444 | }; |
445 | |
446 | IntegerAttr baseValue; |
447 | auto isVscaleExpr = [&](Value a, Value b) { |
448 | return matchPattern(a, m_Constant(&baseValue)) && |
449 | isVscale(b.getDefiningOp()); |
450 | }; |
451 | |
452 | if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs())) |
453 | return; |
454 | |
455 | // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`. |
456 | SmallString<32> specialNameBuffer; |
457 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
  specialName << 'c' << baseValue.getInt() << "_vscale";
459 | setNameFn(getResult(), specialName.str()); |
460 | } |
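
// For illustration (hypothetical IR): if %c8 is "arith.constant 8 : index" and
// %vs is produced by a vector.vscale op, then "arith.muli %c8, %vs : index" is
// suggested the result name %c8_vscale.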
461 | |
462 | void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
463 | MLIRContext *context) { |
464 | patterns.add<MulIMulIConstant>(context); |
465 | } |
466 | |
467 | //===----------------------------------------------------------------------===// |
468 | // MulSIExtendedOp |
469 | //===----------------------------------------------------------------------===// |
470 | |
471 | std::optional<SmallVector<int64_t, 4>> |
472 | arith::MulSIExtendedOp::getShapeForUnroll() { |
473 | if (auto vt = llvm::dyn_cast<VectorType>(getType(0))) |
474 | return llvm::to_vector<4>(vt.getShape()); |
475 | return std::nullopt; |
476 | } |
477 | |
478 | LogicalResult |
479 | arith::MulSIExtendedOp::fold(FoldAdaptor adaptor, |
480 | SmallVectorImpl<OpFoldResult> &results) { |
481 | // mulsi_extended(x, 0) -> 0, 0 |
482 | if (matchPattern(adaptor.getRhs(), m_Zero())) { |
483 | Attribute zero = adaptor.getRhs(); |
484 | results.push_back(zero); |
485 | results.push_back(zero); |
486 | return success(); |
487 | } |
488 | |
489 | // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high |
490 | if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>( |
491 | adaptor.getOperands(), |
492 | [](const APInt &a, const APInt &b) { return a * b; })) { |
493 | // Invoke the constant fold helper again to calculate the 'high' result. |
494 | Attribute highAttr = constFoldBinaryOp<IntegerAttr>( |
495 | adaptor.getOperands(), [](const APInt &a, const APInt &b) { |
496 | return llvm::APIntOps::mulhs(a, b); |
497 | }); |
    assert(highAttr && "Unexpected constant-folding failure");
499 | |
500 | results.push_back(lowAttr); |
501 | results.push_back(highAttr); |
502 | return success(); |
503 | } |
504 | |
505 | return failure(); |
506 | } |
507 | |
508 | void arith::MulSIExtendedOp::getCanonicalizationPatterns( |
509 | RewritePatternSet &patterns, MLIRContext *context) { |
510 | patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context); |
511 | } |
512 | |
513 | //===----------------------------------------------------------------------===// |
514 | // MulUIExtendedOp |
515 | //===----------------------------------------------------------------------===// |
516 | |
517 | std::optional<SmallVector<int64_t, 4>> |
518 | arith::MulUIExtendedOp::getShapeForUnroll() { |
519 | if (auto vt = llvm::dyn_cast<VectorType>(getType(0))) |
520 | return llvm::to_vector<4>(vt.getShape()); |
521 | return std::nullopt; |
522 | } |
523 | |
524 | LogicalResult |
525 | arith::MulUIExtendedOp::fold(FoldAdaptor adaptor, |
526 | SmallVectorImpl<OpFoldResult> &results) { |
527 | // mului_extended(x, 0) -> 0, 0 |
528 | if (matchPattern(adaptor.getRhs(), m_Zero())) { |
529 | Attribute zero = adaptor.getRhs(); |
530 | results.push_back(zero); |
531 | results.push_back(zero); |
532 | return success(); |
533 | } |
534 | |
535 | // mului_extended(x, 1) -> x, 0 |
536 | if (matchPattern(adaptor.getRhs(), m_One())) { |
537 | Builder builder(getContext()); |
538 | Attribute zero = builder.getZeroAttr(getLhs().getType()); |
539 | results.push_back(getLhs()); |
540 | results.push_back(zero); |
541 | return success(); |
542 | } |
543 | |
544 | // mului_extended(cst_a, cst_b) -> cst_low, cst_high |
545 | if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>( |
546 | adaptor.getOperands(), |
547 | [](const APInt &a, const APInt &b) { return a * b; })) { |
548 | // Invoke the constant fold helper again to calculate the 'high' result. |
549 | Attribute highAttr = constFoldBinaryOp<IntegerAttr>( |
550 | adaptor.getOperands(), [](const APInt &a, const APInt &b) { |
551 | return llvm::APIntOps::mulhu(a, b); |
552 | }); |
    assert(highAttr && "Unexpected constant-folding failure");
554 | |
555 | results.push_back(lowAttr); |
556 | results.push_back(highAttr); |
557 | return success(); |
558 | } |
559 | |
560 | return failure(); |
561 | } |
562 | |
563 | void arith::MulUIExtendedOp::getCanonicalizationPatterns( |
564 | RewritePatternSet &patterns, MLIRContext *context) { |
565 | patterns.add<MulUIExtendedToMulI>(context); |
566 | } |
567 | |
568 | //===----------------------------------------------------------------------===// |
569 | // DivUIOp |
570 | //===----------------------------------------------------------------------===// |
571 | |
572 | OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) { |
573 | // divui (x, 1) -> x. |
574 | if (matchPattern(adaptor.getRhs(), m_One())) |
575 | return getLhs(); |
576 | |
577 | // Don't fold if it would require a division by zero. |
578 | bool div0 = false; |
579 | auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
580 | [&](APInt a, const APInt &b) { |
581 | if (div0 || !b) { |
582 | div0 = true; |
583 | return a; |
584 | } |
585 | return a.udiv(b); |
586 | }); |
587 | |
588 | return div0 ? Attribute() : result; |
589 | } |
590 | |
591 | Speculation::Speculatability arith::DivUIOp::getSpeculatability() { |
592 | // X / 0 => UB |
593 | return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable |
594 | : Speculation::NotSpeculatable; |
595 | } |
596 | |
597 | //===----------------------------------------------------------------------===// |
598 | // DivSIOp |
599 | //===----------------------------------------------------------------------===// |
600 | |
601 | OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) { |
602 | // divsi (x, 1) -> x. |
603 | if (matchPattern(adaptor.getRhs(), m_One())) |
604 | return getLhs(); |
605 | |
606 | // Don't fold if it would overflow or if it requires a division by zero. |
607 | bool overflowOrDiv0 = false; |
608 | auto result = constFoldBinaryOp<IntegerAttr>( |
609 | adaptor.getOperands(), [&](APInt a, const APInt &b) { |
610 | if (overflowOrDiv0 || !b) { |
611 | overflowOrDiv0 = true; |
612 | return a; |
613 | } |
614 | return a.sdiv_ov(b, overflowOrDiv0); |
615 | }); |
616 | |
617 | return overflowOrDiv0 ? Attribute() : result; |
618 | } |
619 | |
620 | Speculation::Speculatability arith::DivSIOp::getSpeculatability() { |
621 | bool mayHaveUB = true; |
622 | |
623 | APInt constRHS; |
624 | // X / 0 => UB |
625 | // INT_MIN / -1 => UB |
626 | if (matchPattern(getRhs(), m_ConstantInt(&constRHS))) |
627 | mayHaveUB = constRHS.isAllOnes() || constRHS.isZero(); |
628 | |
629 | return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable; |
630 | } |
631 | |
632 | //===----------------------------------------------------------------------===// |
633 | // Ceil and floor division folding helpers |
634 | //===----------------------------------------------------------------------===// |
635 | |
636 | static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, |
637 | bool &overflow) { |
638 | // Returns (a-1)/b + 1 |
639 | APInt one(a.getBitWidth(), 1, true); // Signed value 1. |
  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
  return val.sadd_ov(one, overflow);
642 | } |
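
// Worked example (illustrative): for a = 7, b = 2 this computes (7 - 1) / 2 + 1
// == 4 == ceil(7 / 2); for an exact division such as a = 6, b = 2 it computes
// (6 - 1) / 2 + 1 == 3, so no separate remainder check is needed.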
643 | |
644 | //===----------------------------------------------------------------------===// |
645 | // CeilDivUIOp |
646 | //===----------------------------------------------------------------------===// |
647 | |
648 | OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) { |
649 | // ceildivui (x, 1) -> x. |
650 | if (matchPattern(adaptor.getRhs(), m_One())) |
651 | return getLhs(); |
652 | |
653 | bool overflowOrDiv0 = false; |
654 | auto result = constFoldBinaryOp<IntegerAttr>( |
655 | adaptor.getOperands(), [&](APInt a, const APInt &b) { |
656 | if (overflowOrDiv0 || !b) { |
657 | overflowOrDiv0 = true; |
658 | return a; |
659 | } |
660 | APInt quotient = a.udiv(b); |
661 | if (!a.urem(b)) |
662 | return quotient; |
663 | APInt one(a.getBitWidth(), 1, true); |
664 | return quotient.uadd_ov(one, overflowOrDiv0); |
665 | }); |
666 | |
667 | return overflowOrDiv0 ? Attribute() : result; |
668 | } |
669 | |
670 | Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() { |
671 | // X / 0 => UB |
672 | return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable |
673 | : Speculation::NotSpeculatable; |
674 | } |
675 | |
676 | //===----------------------------------------------------------------------===// |
677 | // CeilDivSIOp |
678 | //===----------------------------------------------------------------------===// |
679 | |
680 | OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) { |
681 | // ceildivsi (x, 1) -> x. |
682 | if (matchPattern(adaptor.getRhs(), m_One())) |
683 | return getLhs(); |
684 | |
685 | // Don't fold if it would overflow or if it requires a division by zero. |
686 | bool overflowOrDiv0 = false; |
687 | auto result = constFoldBinaryOp<IntegerAttr>( |
688 | adaptor.getOperands(), [&](APInt a, const APInt &b) { |
689 | if (overflowOrDiv0 || !b) { |
690 | overflowOrDiv0 = true; |
691 | return a; |
692 | } |
693 | if (!a) |
694 | return a; |
        // After this point we know that neither a nor b is zero.
696 | unsigned bits = a.getBitWidth(); |
697 | APInt zero = APInt::getZero(bits); |
698 | bool aGtZero = a.sgt(zero); |
699 | bool bGtZero = b.sgt(zero); |
700 | if (aGtZero && bGtZero) { |
701 | // Both positive, return ceil(a, b). |
702 | return signedCeilNonnegInputs(a, b, overflowOrDiv0); |
703 | } |
704 | if (!aGtZero && !bGtZero) { |
705 | // Both negative, return ceil(-a, -b). |
706 | APInt posA = zero.ssub_ov(a, overflowOrDiv0); |
707 | APInt posB = zero.ssub_ov(b, overflowOrDiv0); |
708 | return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); |
709 | } |
710 | if (!aGtZero && bGtZero) { |
711 | // A is negative, b is positive, return - ( -a / b). |
712 | APInt posA = zero.ssub_ov(a, overflowOrDiv0); |
713 | APInt div = posA.sdiv_ov(b, overflowOrDiv0); |
714 | return zero.ssub_ov(div, overflowOrDiv0); |
715 | } |
716 | // A is positive, b is negative, return - (a / -b). |
717 | APInt posB = zero.ssub_ov(b, overflowOrDiv0); |
718 | APInt div = a.sdiv_ov(posB, overflowOrDiv0); |
719 | return zero.ssub_ov(div, overflowOrDiv0); |
720 | }); |
721 | |
722 | return overflowOrDiv0 ? Attribute() : result; |
723 | } |
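
// Worked examples (illustrative): ceildivsi(7, 2) folds to 4; ceildivsi(-7, 2)
// folds to -(7 / 2) == -3; ceildivsi(-7, -2) negates both operands and folds
// to ceil(7 / 2) == 4.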
724 | |
725 | Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() { |
726 | bool mayHaveUB = true; |
727 | |
728 | APInt constRHS; |
729 | // X / 0 => UB |
730 | // INT_MIN / -1 => UB |
731 | if (matchPattern(getRhs(), m_ConstantInt(&constRHS))) |
732 | mayHaveUB = constRHS.isAllOnes() || constRHS.isZero(); |
733 | |
734 | return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable; |
735 | } |
736 | |
737 | //===----------------------------------------------------------------------===// |
738 | // FloorDivSIOp |
739 | //===----------------------------------------------------------------------===// |
740 | |
741 | OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) { |
742 | // floordivsi (x, 1) -> x. |
743 | if (matchPattern(adaptor.getRhs(), m_One())) |
744 | return getLhs(); |
745 | |
746 | // Don't fold if it would overflow or if it requires a division by zero. |
747 | bool overflowOrDiv = false; |
748 | auto result = constFoldBinaryOp<IntegerAttr>( |
749 | adaptor.getOperands(), [&](APInt a, const APInt &b) { |
750 | if (b.isZero()) { |
751 | overflowOrDiv = true; |
752 | return a; |
753 | } |
754 | return a.sfloordiv_ov(b, overflowOrDiv); |
755 | }); |
756 | |
757 | return overflowOrDiv ? Attribute() : result; |
758 | } |
759 | |
760 | //===----------------------------------------------------------------------===// |
761 | // RemUIOp |
762 | //===----------------------------------------------------------------------===// |
763 | |
764 | OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) { |
765 | // remui (x, 1) -> 0. |
766 | if (matchPattern(adaptor.getRhs(), m_One())) |
767 | return Builder(getContext()).getZeroAttr(getType()); |
768 | |
769 | // Don't fold if it would require a division by zero. |
770 | bool div0 = false; |
771 | auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
772 | [&](APInt a, const APInt &b) { |
773 | if (div0 || b.isZero()) { |
774 | div0 = true; |
775 | return a; |
776 | } |
777 | return a.urem(b); |
778 | }); |
779 | |
780 | return div0 ? Attribute() : result; |
781 | } |
782 | |
783 | //===----------------------------------------------------------------------===// |
784 | // RemSIOp |
785 | //===----------------------------------------------------------------------===// |
786 | |
787 | OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) { |
788 | // remsi (x, 1) -> 0. |
789 | if (matchPattern(adaptor.getRhs(), m_One())) |
790 | return Builder(getContext()).getZeroAttr(getType()); |
791 | |
792 | // Don't fold if it would require a division by zero. |
793 | bool div0 = false; |
794 | auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
795 | [&](APInt a, const APInt &b) { |
796 | if (div0 || b.isZero()) { |
797 | div0 = true; |
798 | return a; |
799 | } |
800 | return a.srem(b); |
801 | }); |
802 | |
803 | return div0 ? Attribute() : result; |
804 | } |
805 | |
806 | //===----------------------------------------------------------------------===// |
807 | // AndIOp |
808 | //===----------------------------------------------------------------------===// |
809 | |
810 | /// Fold `and(a, and(a, b))` to `and(a, b)` |
811 | static Value foldAndIofAndI(arith::AndIOp op) { |
812 | for (bool reversePrev : {false, true}) { |
813 | auto prev = (reversePrev ? op.getRhs() : op.getLhs()) |
814 | .getDefiningOp<arith::AndIOp>(); |
815 | if (!prev) |
816 | continue; |
817 | |
818 | Value other = (reversePrev ? op.getLhs() : op.getRhs()); |
819 | if (other != prev.getLhs() && other != prev.getRhs()) |
820 | continue; |
821 | |
822 | return prev.getResult(); |
823 | } |
824 | return {}; |
825 | } |
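
// Illustrative IR (hypothetical values %a, %b):
//   %0 = arith.andi %a, %b : i32
//   %1 = arith.andi %a, %0 : i32
// Here %1 folds to %0, since and-ing with %a a second time changes nothing.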
826 | |
827 | OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) { |
828 | /// and(x, 0) -> 0 |
829 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
830 | return getRhs(); |
831 | /// and(x, allOnes) -> x |
832 | APInt intValue; |
833 | if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) && |
834 | intValue.isAllOnes()) |
835 | return getLhs(); |
836 | /// and(x, not(x)) -> 0 |
837 | if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()), |
838 | m_ConstantInt(&intValue))) && |
839 | intValue.isAllOnes()) |
840 | return Builder(getContext()).getZeroAttr(getType()); |
841 | /// and(not(x), x) -> 0 |
842 | if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()), |
843 | m_ConstantInt(&intValue))) && |
844 | intValue.isAllOnes()) |
845 | return Builder(getContext()).getZeroAttr(getType()); |
846 | |
847 | /// and(a, and(a, b)) -> and(a, b) |
848 | if (Value result = foldAndIofAndI(*this)) |
849 | return result; |
850 | |
851 | return constFoldBinaryOp<IntegerAttr>( |
852 | adaptor.getOperands(), |
853 | [](APInt a, const APInt &b) { return std::move(a) & b; }); |
854 | } |
855 | |
856 | //===----------------------------------------------------------------------===// |
857 | // OrIOp |
858 | //===----------------------------------------------------------------------===// |
859 | |
860 | OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) { |
861 | if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) { |
862 | /// or(x, 0) -> x |
863 | if (rhsVal.isZero()) |
864 | return getLhs(); |
865 | /// or(x, <all ones>) -> <all ones> |
866 | if (rhsVal.isAllOnes()) |
867 | return adaptor.getRhs(); |
868 | } |
869 | |
870 | APInt intValue; |
  /// or(x, xor(x, <all ones>)) -> <all ones>
872 | if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()), |
873 | m_ConstantInt(&intValue))) && |
874 | intValue.isAllOnes()) |
875 | return getRhs().getDefiningOp<XOrIOp>().getRhs(); |
  /// or(xor(x, <all ones>), x) -> <all ones>
877 | if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()), |
878 | m_ConstantInt(&intValue))) && |
879 | intValue.isAllOnes()) |
880 | return getLhs().getDefiningOp<XOrIOp>().getRhs(); |
881 | |
882 | return constFoldBinaryOp<IntegerAttr>( |
883 | adaptor.getOperands(), |
884 | [](APInt a, const APInt &b) { return std::move(a) | b; }); |
885 | } |
886 | |
887 | //===----------------------------------------------------------------------===// |
888 | // XOrIOp |
889 | //===----------------------------------------------------------------------===// |
890 | |
891 | OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) { |
892 | /// xor(x, 0) -> x |
893 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
894 | return getLhs(); |
895 | /// xor(x, x) -> 0 |
896 | if (getLhs() == getRhs()) |
897 | return Builder(getContext()).getZeroAttr(getType()); |
898 | /// xor(xor(x, a), a) -> x |
899 | /// xor(xor(a, x), a) -> x |
900 | if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) { |
901 | if (prev.getRhs() == getRhs()) |
902 | return prev.getLhs(); |
903 | if (prev.getLhs() == getRhs()) |
904 | return prev.getRhs(); |
905 | } |
906 | /// xor(a, xor(x, a)) -> x |
907 | /// xor(a, xor(a, x)) -> x |
908 | if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) { |
909 | if (prev.getRhs() == getLhs()) |
910 | return prev.getLhs(); |
911 | if (prev.getLhs() == getLhs()) |
912 | return prev.getRhs(); |
913 | } |
914 | |
915 | return constFoldBinaryOp<IntegerAttr>( |
916 | adaptor.getOperands(), |
917 | [](APInt a, const APInt &b) { return std::move(a) ^ b; }); |
918 | } |
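
// Illustrative IR for the xor(xor(x, a), a) -> x fold (hypothetical values):
//   %0 = arith.xori %x, %a : i32
//   %1 = arith.xori %0, %a : i32
// Here %1 folds directly to %x.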
919 | |
920 | void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
921 | MLIRContext *context) { |
922 | patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context); |
923 | } |
924 | |
925 | //===----------------------------------------------------------------------===// |
926 | // NegFOp |
927 | //===----------------------------------------------------------------------===// |
928 | |
929 | OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) { |
930 | /// negf(negf(x)) -> x |
931 | if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>()) |
932 | return op.getOperand(); |
933 | return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(), |
934 | [](const APFloat &a) { return -a; }); |
935 | } |
936 | |
937 | //===----------------------------------------------------------------------===// |
938 | // AddFOp |
939 | //===----------------------------------------------------------------------===// |
940 | |
941 | OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) { |
942 | // addf(x, -0) -> x |
943 | if (matchPattern(adaptor.getRhs(), m_NegZeroFloat())) |
944 | return getLhs(); |
945 | |
946 | return constFoldBinaryOp<FloatAttr>( |
947 | adaptor.getOperands(), |
948 | [](const APFloat &a, const APFloat &b) { return a + b; }); |
949 | } |
950 | |
951 | //===----------------------------------------------------------------------===// |
952 | // SubFOp |
953 | //===----------------------------------------------------------------------===// |
954 | |
955 | OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) { |
956 | // subf(x, +0) -> x |
957 | if (matchPattern(adaptor.getRhs(), m_PosZeroFloat())) |
958 | return getLhs(); |
959 | |
960 | return constFoldBinaryOp<FloatAttr>( |
961 | adaptor.getOperands(), |
962 | [](const APFloat &a, const APFloat &b) { return a - b; }); |
963 | } |
964 | |
965 | //===----------------------------------------------------------------------===// |
966 | // MaximumFOp |
967 | //===----------------------------------------------------------------------===// |
968 | |
969 | OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) { |
970 | // maximumf(x,x) -> x |
971 | if (getLhs() == getRhs()) |
972 | return getRhs(); |
973 | |
974 | // maximumf(x, -inf) -> x |
975 | if (matchPattern(adaptor.getRhs(), m_NegInfFloat())) |
976 | return getLhs(); |
977 | |
978 | return constFoldBinaryOp<FloatAttr>( |
979 | adaptor.getOperands(), |
980 | [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); |
981 | } |
982 | |
983 | //===----------------------------------------------------------------------===// |
984 | // MaxNumFOp |
985 | //===----------------------------------------------------------------------===// |
986 | |
987 | OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) { |
988 | // maxnumf(x,x) -> x |
989 | if (getLhs() == getRhs()) |
990 | return getRhs(); |
991 | |
992 | // maxnumf(x, -inf) -> x |
993 | if (matchPattern(adaptor.getRhs(), m_NegInfFloat())) |
994 | return getLhs(); |
995 | |
  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::maxnum(a, b); });
999 | } |
1000 | |
1001 | //===----------------------------------------------------------------------===// |
1002 | // MaxSIOp |
1003 | //===----------------------------------------------------------------------===// |
1004 | |
1005 | OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) { |
1006 | // maxsi(x,x) -> x |
1007 | if (getLhs() == getRhs()) |
1008 | return getRhs(); |
1009 | |
1010 | if (APInt intValue; |
1011 | matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { |
1012 | // maxsi(x,MAX_INT) -> MAX_INT |
1013 | if (intValue.isMaxSignedValue()) |
1014 | return getRhs(); |
1015 | // maxsi(x, MIN_INT) -> x |
1016 | if (intValue.isMinSignedValue()) |
1017 | return getLhs(); |
1018 | } |
1019 | |
1020 | return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
1021 | [](const APInt &a, const APInt &b) { |
1022 | return llvm::APIntOps::smax(a, b); |
1023 | }); |
1024 | } |
1025 | |
1026 | //===----------------------------------------------------------------------===// |
1027 | // MaxUIOp |
1028 | //===----------------------------------------------------------------------===// |
1029 | |
1030 | OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) { |
1031 | // maxui(x,x) -> x |
1032 | if (getLhs() == getRhs()) |
1033 | return getRhs(); |
1034 | |
1035 | if (APInt intValue; |
1036 | matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { |
1037 | // maxui(x,MAX_INT) -> MAX_INT |
1038 | if (intValue.isMaxValue()) |
1039 | return getRhs(); |
1040 | // maxui(x, MIN_INT) -> x |
1041 | if (intValue.isMinValue()) |
1042 | return getLhs(); |
1043 | } |
1044 | |
1045 | return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
1046 | [](const APInt &a, const APInt &b) { |
1047 | return llvm::APIntOps::umax(a, b); |
1048 | }); |
1049 | } |
1050 | |
1051 | //===----------------------------------------------------------------------===// |
1052 | // MinimumFOp |
1053 | //===----------------------------------------------------------------------===// |
1054 | |
1055 | OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) { |
1056 | // minimumf(x,x) -> x |
1057 | if (getLhs() == getRhs()) |
1058 | return getRhs(); |
1059 | |
1060 | // minimumf(x, +inf) -> x |
1061 | if (matchPattern(adaptor.getRhs(), m_PosInfFloat())) |
1062 | return getLhs(); |
1063 | |
1064 | return constFoldBinaryOp<FloatAttr>( |
1065 | adaptor.getOperands(), |
1066 | [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); |
1067 | } |
1068 | |
1069 | //===----------------------------------------------------------------------===// |
1070 | // MinNumFOp |
1071 | //===----------------------------------------------------------------------===// |
1072 | |
1073 | OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) { |
1074 | // minnumf(x,x) -> x |
1075 | if (getLhs() == getRhs()) |
1076 | return getRhs(); |
1077 | |
1078 | // minnumf(x, +inf) -> x |
1079 | if (matchPattern(adaptor.getRhs(), m_PosInfFloat())) |
1080 | return getLhs(); |
1081 | |
1082 | return constFoldBinaryOp<FloatAttr>( |
1083 | adaptor.getOperands(), |
1084 | [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); }); |
1085 | } |
1086 | |
1087 | //===----------------------------------------------------------------------===// |
1088 | // MinSIOp |
1089 | //===----------------------------------------------------------------------===// |
1090 | |
1091 | OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) { |
1092 | // minsi(x,x) -> x |
1093 | if (getLhs() == getRhs()) |
1094 | return getRhs(); |
1095 | |
1096 | if (APInt intValue; |
1097 | matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { |
1098 | // minsi(x,MIN_INT) -> MIN_INT |
1099 | if (intValue.isMinSignedValue()) |
1100 | return getRhs(); |
1101 | // minsi(x, MAX_INT) -> x |
1102 | if (intValue.isMaxSignedValue()) |
1103 | return getLhs(); |
1104 | } |
1105 | |
1106 | return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
1107 | [](const APInt &a, const APInt &b) { |
1108 | return llvm::APIntOps::smin(a, b); |
1109 | }); |
1110 | } |
1111 | |
1112 | //===----------------------------------------------------------------------===// |
1113 | // MinUIOp |
1114 | //===----------------------------------------------------------------------===// |
1115 | |
1116 | OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) { |
1117 | // minui(x,x) -> x |
1118 | if (getLhs() == getRhs()) |
1119 | return getRhs(); |
1120 | |
1121 | if (APInt intValue; |
1122 | matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { |
1123 | // minui(x,MIN_INT) -> MIN_INT |
1124 | if (intValue.isMinValue()) |
1125 | return getRhs(); |
1126 | // minui(x, MAX_INT) -> x |
1127 | if (intValue.isMaxValue()) |
1128 | return getLhs(); |
1129 | } |
1130 | |
1131 | return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
1132 | [](const APInt &a, const APInt &b) { |
1133 | return llvm::APIntOps::umin(a, b); |
1134 | }); |
1135 | } |
1136 | |
1137 | //===----------------------------------------------------------------------===// |
1138 | // MulFOp |
1139 | //===----------------------------------------------------------------------===// |
1140 | |
1141 | OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) { |
1142 | // mulf(x, 1) -> x |
1143 | if (matchPattern(adaptor.getRhs(), m_OneFloat())) |
1144 | return getLhs(); |
1145 | |
1146 | return constFoldBinaryOp<FloatAttr>( |
1147 | adaptor.getOperands(), |
1148 | [](const APFloat &a, const APFloat &b) { return a * b; }); |
1149 | } |
1150 | |
1151 | void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1152 | MLIRContext *context) { |
1153 | patterns.add<MulFOfNegF>(context); |
1154 | } |
1155 | |
1156 | //===----------------------------------------------------------------------===// |
1157 | // DivFOp |
1158 | //===----------------------------------------------------------------------===// |
1159 | |
1160 | OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) { |
1161 | // divf(x, 1) -> x |
1162 | if (matchPattern(adaptor.getRhs(), m_OneFloat())) |
1163 | return getLhs(); |
1164 | |
1165 | return constFoldBinaryOp<FloatAttr>( |
1166 | adaptor.getOperands(), |
1167 | [](const APFloat &a, const APFloat &b) { return a / b; }); |
1168 | } |
1169 | |
1170 | void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1171 | MLIRContext *context) { |
1172 | patterns.add<DivFOfNegF>(context); |
1173 | } |
1174 | |
1175 | //===----------------------------------------------------------------------===// |
1176 | // RemFOp |
1177 | //===----------------------------------------------------------------------===// |
1178 | |
1179 | OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) { |
1180 | return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), |
1181 | [](const APFloat &a, const APFloat &b) { |
1182 | APFloat result(a); |
1183 | (void)result.remainder(b); |
1184 | return result; |
1185 | }); |
1186 | } |
1187 | |
1188 | //===----------------------------------------------------------------------===// |
1189 | // Utility functions for verifying cast ops |
1190 | //===----------------------------------------------------------------------===// |
1191 | |
1192 | template <typename... Types> |
1193 | using type_list = std::tuple<Types...> *; |
1194 | |
1195 | /// Returns a non-null type only if the provided type is one of the allowed |
1196 | /// types or one of the allowed shaped types of the allowed types. Returns the |
1197 | /// element type if a valid shaped type is provided. |
1198 | template <typename... ShapedTypes, typename... ElementTypes> |
1199 | static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, |
1200 | type_list<ElementTypes...>) { |
1201 | if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type)) |
1202 | return {}; |
1203 | |
1204 | auto underlyingType = getElementTypeOrSelf(type); |
1205 | if (!llvm::isa<ElementTypes...>(underlyingType)) |
1206 | return {}; |
1207 | |
1208 | return underlyingType; |
1209 | } |
1210 | |
1211 | /// Get allowed underlying types for vectors and tensors. |
1212 | template <typename... ElementTypes> |
1213 | static Type getTypeIfLike(Type type) { |
1214 | return getUnderlyingType(type, type_list<VectorType, TensorType>(), |
1215 | type_list<ElementTypes...>()); |
1216 | } |
1217 | |
1218 | /// Get allowed underlying types for vectors, tensors, and memrefs. |
1219 | template <typename... ElementTypes> |
1220 | static Type getTypeIfLikeOrMemRef(Type type) { |
1221 | return getUnderlyingType(type, |
1222 | type_list<VectorType, TensorType, MemRefType>(), |
1223 | type_list<ElementTypes...>()); |
1224 | } |
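
// For illustration: getTypeIfLike<IntegerType> returns i32 for both i32 and
// vector<4xi32>, but a null Type for f32 (wrong element type) and for
// memref<4xi32> (memrefs are not in its shaped-type list); the *OrMemRef
// variant accepts the memref case as well.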
1225 | |
/// Return false if both types are ranked tensors with mismatching encodings.
1227 | static bool hasSameEncoding(Type typeA, Type typeB) { |
1228 | auto rankedTensorA = dyn_cast<RankedTensorType>(typeA); |
1229 | auto rankedTensorB = dyn_cast<RankedTensorType>(typeB); |
1230 | if (!rankedTensorA || !rankedTensorB) |
1231 | return true; |
1232 | return rankedTensorA.getEncoding() == rankedTensorB.getEncoding(); |
1233 | } |
1234 | |
1235 | static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { |
1236 | if (inputs.size() != 1 || outputs.size() != 1) |
1237 | return false; |
  if (!hasSameEncoding(inputs.front(), outputs.front()))
1239 | return false; |
  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1241 | } |
1242 | |
1243 | //===----------------------------------------------------------------------===// |
1244 | // Verifiers for integer and floating point extension/truncation ops |
1245 | //===----------------------------------------------------------------------===// |
1246 | |
1247 | // Extend ops can only extend to a wider type. |
1248 | template <typename ValType, typename Op> |
1249 | static LogicalResult verifyExtOp(Op op) { |
1250 | Type srcType = getElementTypeOrSelf(op.getIn().getType()); |
1251 | Type dstType = getElementTypeOrSelf(op.getType()); |
1252 | |
1253 | if (llvm::cast<ValType>(srcType).getWidth() >= |
1254 | llvm::cast<ValType>(dstType).getWidth()) |
    return op.emitError("result type ")
1256 | << dstType << " must be wider than operand type " << srcType; |
1257 | |
1258 | return success(); |
1259 | } |
1260 | |
1261 | // Truncate ops can only truncate to a shorter type. |
1262 | template <typename ValType, typename Op> |
1263 | static LogicalResult verifyTruncateOp(Op op) { |
1264 | Type srcType = getElementTypeOrSelf(op.getIn().getType()); |
1265 | Type dstType = getElementTypeOrSelf(op.getType()); |
1266 | |
1267 | if (llvm::cast<ValType>(srcType).getWidth() <= |
1268 | llvm::cast<ValType>(dstType).getWidth()) |
    return op.emitError("result type ")
1270 | << dstType << " must be shorter than operand type " << srcType; |
1271 | |
1272 | return success(); |
1273 | } |
1274 | |
1275 | /// Validate a cast that changes the width of a type. |
1276 | template <template <typename> class WidthComparator, typename... ElementTypes> |
1277 | static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { |
1278 | if (!areValidCastInputsAndOutputs(inputs, outputs)) |
1279 | return false; |
1280 | |
1281 | auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); |
1282 | auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); |
1283 | if (!srcType || !dstType) |
1284 | return false; |
1285 | |
1286 | return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), |
1287 | srcType.getIntOrFloatBitWidth()); |
1288 | } |
1289 | |
1290 | /// Attempts to convert `sourceValue` to an APFloat value with |
1291 | /// `targetSemantics` and `roundingMode`, without any information loss. |
1292 | static FailureOr<APFloat> convertFloatValue( |
1293 | APFloat sourceValue, const llvm::fltSemantics &targetSemantics, |
1294 | llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) { |
1295 | bool losesInfo = false; |
  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1297 | if (losesInfo || status != APFloat::opOK) |
1298 | return failure(); |
1299 | |
1300 | return sourceValue; |
1301 | } |
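
// For illustration (example values): converting 0.5 from f64 to f16 succeeds
// because the value is exactly representable, whereas converting 0.1 from f64
// to f16 fails since rounding would lose information.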
1302 | |
1303 | //===----------------------------------------------------------------------===// |
1304 | // ExtUIOp |
1305 | //===----------------------------------------------------------------------===// |
1306 | |
1307 | OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) { |
1308 | if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { |
1309 | getInMutable().assign(lhs.getIn()); |
1310 | return getResult(); |
1311 | } |
1312 | |
1313 | Type resType = getElementTypeOrSelf(getType()); |
1314 | unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth(); |
1315 | return constFoldCastOp<IntegerAttr, IntegerAttr>( |
1316 | adaptor.getOperands(), getType(), |
1317 | [bitWidth](const APInt &a, bool &castStatus) { |
1318 | return a.zext(bitWidth); |
1319 | }); |
1320 | } |
1321 | |
1322 | bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1323 | return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); |
1324 | } |
1325 | |
1326 | LogicalResult arith::ExtUIOp::verify() { |
1327 | return verifyExtOp<IntegerType>(*this); |
1328 | } |
1329 | |
1330 | //===----------------------------------------------------------------------===// |
1331 | // ExtSIOp |
1332 | //===----------------------------------------------------------------------===// |
1333 | |
1334 | OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) { |
1335 | if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { |
1336 | getInMutable().assign(lhs.getIn()); |
1337 | return getResult(); |
1338 | } |
1339 | |
1340 | Type resType = getElementTypeOrSelf(getType()); |
1341 | unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth(); |
1342 | return constFoldCastOp<IntegerAttr, IntegerAttr>( |
1343 | adaptor.getOperands(), getType(), |
1344 | [bitWidth](const APInt &a, bool &castStatus) { |
1345 | return a.sext(bitWidth); |
1346 | }); |
1347 | } |
1348 | |
1349 | bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1350 | return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); |
1351 | } |
1352 | |
1353 | void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1354 | MLIRContext *context) { |
1355 | patterns.add<ExtSIOfExtUI>(context); |
1356 | } |
1357 | |
1358 | LogicalResult arith::ExtSIOp::verify() { |
1359 | return verifyExtOp<IntegerType>(*this); |
1360 | } |
1361 | |
1362 | //===----------------------------------------------------------------------===// |
1363 | // ExtFOp |
1364 | //===----------------------------------------------------------------------===// |
1365 | |
/// Fold extension of float constants when there is no information loss due to
/// the difference in fp semantics.
1368 | OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) { |
1369 | auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType())); |
1370 | const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics(); |
1371 | return constFoldCastOp<FloatAttr, FloatAttr>( |
1372 | adaptor.getOperands(), getType(), |
1373 | [&targetSemantics](const APFloat &a, bool &castStatus) { |
1374 | FailureOr<APFloat> result = convertFloatValue(a, targetSemantics); |
1375 | if (failed(result)) { |
1376 | castStatus = false; |
1377 | return a; |
1378 | } |
1379 | return *result; |
1380 | }); |
1381 | } |
1382 | |
1383 | bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1384 | return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); |
1385 | } |
1386 | |
1387 | LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); } |
1388 | |
1389 | //===----------------------------------------------------------------------===// |
1390 | // TruncIOp |
1391 | //===----------------------------------------------------------------------===// |
1392 | |
1393 | OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) { |
1394 | if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || |
1395 | matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) { |
1396 | Value src = getOperand().getDefiningOp()->getOperand(0); |
1397 | Type srcType = getElementTypeOrSelf(src.getType()); |
1398 | Type dstType = getElementTypeOrSelf(getType()); |
1399 | // trunci(zexti(a)) -> trunci(a) |
1400 | // trunci(sexti(a)) -> trunci(a) |
1401 | if (llvm::cast<IntegerType>(srcType).getWidth() > |
1402 | llvm::cast<IntegerType>(dstType).getWidth()) { |
1403 | setOperand(src); |
1404 | return getResult(); |
1405 | } |
1406 | |
1407 | // trunci(zexti(a)) -> a |
1408 | // trunci(sexti(a)) -> a |
1409 | if (srcType == dstType) |
1410 | return src; |
1411 | } |
1412 | |
  // trunci(trunci(a)) -> trunci(a)
1414 | if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) { |
1415 | setOperand(getOperand().getDefiningOp()->getOperand(0)); |
1416 | return getResult(); |
1417 | } |
1418 | |
1419 | Type resType = getElementTypeOrSelf(getType()); |
1420 | unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth(); |
1421 | return constFoldCastOp<IntegerAttr, IntegerAttr>( |
1422 | adaptor.getOperands(), getType(), |
1423 | [bitWidth](const APInt &a, bool &castStatus) { |
1424 | return a.trunc(bitWidth); |
1425 | }); |
1426 | } |
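
// Illustrative IR (hypothetical values %a, %b):
//   %e = arith.extsi %a : i32 to i64
//   %t = arith.trunci %e : i64 to i16  // becomes arith.trunci %a : i32 to i16
//   %f = arith.extui %b : i8 to i32
//   %u = arith.trunci %f : i32 to i8   // folds directly to %b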
1427 | |
1428 | bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1429 | return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); |
1430 | } |
1431 | |
1432 | void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1433 | MLIRContext *context) { |
1434 | patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI, |
1435 | TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>( |
1436 | context); |
1437 | } |
1438 | |
1439 | LogicalResult arith::TruncIOp::verify() { |
1440 | return verifyTruncateOp<IntegerType>(*this); |
1441 | } |
1442 | |
1443 | //===----------------------------------------------------------------------===// |
1444 | // TruncFOp |
1445 | //===----------------------------------------------------------------------===// |
1446 | |
/// Perform safe const propagation for truncf, i.e., only propagate if the FP
/// value can be represented without precision loss.
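/// For instance (illustrative), truncating a constant `0.5 : f64` to `f32`
/// folds because 0.5 is exactly representable in f32, whereas a constant such
/// as `0.1 : f64` is left alone since the conversion would be inexact.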
1449 | OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) { |
1450 | auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType())); |
1451 | const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics(); |
1452 | return constFoldCastOp<FloatAttr, FloatAttr>( |
1453 | adaptor.getOperands(), getType(), |
1454 | [this, &targetSemantics](const APFloat &a, bool &castStatus) { |
1455 | RoundingMode roundingMode = |
1456 | getRoundingmode().value_or(RoundingMode::to_nearest_even); |
1457 | llvm::RoundingMode llvmRoundingMode = |
1458 | convertArithRoundingModeToLLVMIR(roundingMode); |
1459 | FailureOr<APFloat> result = |
1460 | convertFloatValue(a, targetSemantics, llvmRoundingMode); |
1461 | if (failed(result)) { |
1462 | castStatus = false; |
1463 | return a; |
1464 | } |
1465 | return *result; |
1466 | }); |
1467 | } |
1468 | |
1469 | bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1470 | return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); |
1471 | } |
1472 | |
1473 | LogicalResult arith::TruncFOp::verify() { |
1474 | return verifyTruncateOp<FloatType>(*this); |
1475 | } |
1476 | |
1477 | //===----------------------------------------------------------------------===// |
1478 | // AndIOp |
1479 | //===----------------------------------------------------------------------===// |
1480 | |
1481 | void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1482 | MLIRContext *context) { |
1483 | patterns.add<AndOfExtUI, AndOfExtSI>(context); |
1484 | } |
1485 | |
1486 | //===----------------------------------------------------------------------===// |
1487 | // OrIOp |
1488 | //===----------------------------------------------------------------------===// |
1489 | |
1490 | void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1491 | MLIRContext *context) { |
1492 | patterns.add<OrOfExtUI, OrOfExtSI>(context); |
1493 | } |
1494 | |
1495 | //===----------------------------------------------------------------------===// |
1496 | // Verifiers for casts between integers and floats. |
1497 | //===----------------------------------------------------------------------===// |
1498 | |
1499 | template <typename From, typename To> |
1500 | static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { |
1501 | if (!areValidCastInputsAndOutputs(inputs, outputs)) |
1502 | return false; |
1503 | |
1504 | auto srcType = getTypeIfLike<From>(inputs.front()); |
  auto dstType = getTypeIfLike<To>(outputs.front());
1506 | |
1507 | return srcType && dstType; |
1508 | } |
1509 | |
1510 | //===----------------------------------------------------------------------===// |
1511 | // UIToFPOp |
1512 | //===----------------------------------------------------------------------===// |
1513 | |
1514 | bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1515 | return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); |
1516 | } |
1517 | |
1518 | OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) { |
1519 | Type resEleType = getElementTypeOrSelf(getType()); |
1520 | return constFoldCastOp<IntegerAttr, FloatAttr>( |
1521 | adaptor.getOperands(), getType(), |
1522 | [&resEleType](const APInt &a, bool &castStatus) { |
1523 | FloatType floatTy = llvm::cast<FloatType>(resEleType); |
1524 | APFloat apf(floatTy.getFloatSemantics(), |
1525 | APInt::getZero(floatTy.getWidth())); |
1526 | apf.convertFromAPInt(a, /*IsSigned=*/false, |
1527 | APFloat::rmNearestTiesToEven); |
1528 | return apf; |
1529 | }); |
1530 | } |
1531 | |
1532 | //===----------------------------------------------------------------------===// |
1533 | // SIToFPOp |
1534 | //===----------------------------------------------------------------------===// |
1535 | |
1536 | bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1537 | return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); |
1538 | } |
1539 | |
1540 | OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) { |
1541 | Type resEleType = getElementTypeOrSelf(getType()); |
1542 | return constFoldCastOp<IntegerAttr, FloatAttr>( |
1543 | adaptor.getOperands(), getType(), |
1544 | [&resEleType](const APInt &a, bool &castStatus) { |
1545 | FloatType floatTy = llvm::cast<FloatType>(resEleType); |
1546 | APFloat apf(floatTy.getFloatSemantics(), |
1547 | APInt::getZero(floatTy.getWidth())); |
1548 | apf.convertFromAPInt(a, /*IsSigned=*/true, |
1549 | APFloat::rmNearestTiesToEven); |
1550 | return apf; |
1551 | }); |
1552 | } |
1553 | |
1554 | //===----------------------------------------------------------------------===// |
1555 | // FPToUIOp |
1556 | //===----------------------------------------------------------------------===// |
1557 | |
1558 | bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1559 | return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); |
1560 | } |
1561 | |
1562 | OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) { |
1563 | Type resType = getElementTypeOrSelf(getType()); |
1564 | unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth(); |
1565 | return constFoldCastOp<FloatAttr, IntegerAttr>( |
1566 | adaptor.getOperands(), getType(), |
1567 | [&bitWidth](const APFloat &a, bool &castStatus) { |
1568 | bool ignored; |
1569 | APSInt api(bitWidth, /*isUnsigned=*/true); |
1570 | castStatus = APFloat::opInvalidOp != |
1571 | a.convertToInteger(api, APFloat::rmTowardZero, &ignored); |
1572 | return api; |
1573 | }); |
1574 | } |
1575 | |
1576 | //===----------------------------------------------------------------------===// |
1577 | // FPToSIOp |
1578 | //===----------------------------------------------------------------------===// |
1579 | |
1580 | bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1581 | return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); |
1582 | } |
1583 | |
1584 | OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) { |
1585 | Type resType = getElementTypeOrSelf(getType()); |
1586 | unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth(); |
1587 | return constFoldCastOp<FloatAttr, IntegerAttr>( |
1588 | adaptor.getOperands(), getType(), |
1589 | [&bitWidth](const APFloat &a, bool &castStatus) { |
1590 | bool ignored; |
1591 | APSInt api(bitWidth, /*isUnsigned=*/false); |
1592 | castStatus = APFloat::opInvalidOp != |
1593 | a.convertToInteger(api, APFloat::rmTowardZero, &ignored); |
1594 | return api; |
1595 | }); |
1596 | } |
1597 | |
1598 | //===----------------------------------------------------------------------===// |
1599 | // IndexCastOp |
1600 | //===----------------------------------------------------------------------===// |
1601 | |
1602 | static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) { |
1603 | if (!areValidCastInputsAndOutputs(inputs, outputs)) |
1604 | return false; |
1605 | |
  auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
  auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1608 | if (!srcType || !dstType) |
1609 | return false; |
1610 | |
1611 | return (srcType.isIndex() && dstType.isSignlessInteger()) || |
1612 | (srcType.isSignlessInteger() && dstType.isIndex()); |
1613 | } |
1614 | |
1615 | bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, |
1616 | TypeRange outputs) { |
1617 | return areIndexCastCompatible(inputs, outputs); |
1618 | } |
1619 | |
1620 | OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) { |
1621 | // index_cast(constant) -> constant |
1622 | unsigned resultBitwidth = 64; // Default for index integer attributes. |
1623 | if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType()))) |
1624 | resultBitwidth = intTy.getWidth(); |
1625 | |
1626 | return constFoldCastOp<IntegerAttr, IntegerAttr>( |
1627 | adaptor.getOperands(), getType(), |
1628 | [resultBitwidth](const APInt &a, bool & /*castStatus*/) { |
1629 | return a.sextOrTrunc(resultBitwidth); |
1630 | }); |
1631 | } |
1632 | |
1633 | void arith::IndexCastOp::getCanonicalizationPatterns( |
1634 | RewritePatternSet &patterns, MLIRContext *context) { |
1635 | patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context); |
1636 | } |
1637 | |
1638 | //===----------------------------------------------------------------------===// |
1639 | // IndexCastUIOp |
1640 | //===----------------------------------------------------------------------===// |
1641 | |
1642 | bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs, |
1643 | TypeRange outputs) { |
1644 | return areIndexCastCompatible(inputs, outputs); |
1645 | } |
1646 | |
1647 | OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) { |
1648 | // index_castui(constant) -> constant |
1649 | unsigned resultBitwidth = 64; // Default for index integer attributes. |
1650 | if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType()))) |
1651 | resultBitwidth = intTy.getWidth(); |
1652 | |
1653 | return constFoldCastOp<IntegerAttr, IntegerAttr>( |
1654 | adaptor.getOperands(), getType(), |
1655 | [resultBitwidth](const APInt &a, bool & /*castStatus*/) { |
1656 | return a.zextOrTrunc(resultBitwidth); |
1657 | }); |
1658 | } |
1659 | |
1660 | void arith::IndexCastUIOp::getCanonicalizationPatterns( |
1661 | RewritePatternSet &patterns, MLIRContext *context) { |
1662 | patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context); |
1663 | } |
1664 | |
1665 | //===----------------------------------------------------------------------===// |
1666 | // BitcastOp |
1667 | //===----------------------------------------------------------------------===// |
1668 | |
1669 | bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1670 | if (!areValidCastInputsAndOutputs(inputs, outputs)) |
1671 | return false; |
1672 | |
1673 | auto srcType = |
1674 | getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front()); |
1675 | auto dstType = |
1676 | getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front()); |
1677 | if (!srcType || !dstType) |
1678 | return false; |
1679 | |
1680 | return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); |
1681 | } |
1682 | |
1683 | OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) { |
1684 | auto resType = getType(); |
1685 | auto operand = adaptor.getIn(); |
1686 | if (!operand) |
1687 | return {}; |
1688 | |
1689 | /// Bitcast dense elements. |
1690 | if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand)) |
1691 | return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType()); |
1692 | /// Other shaped types unhandled. |
1693 | if (llvm::isa<ShapedType>(resType)) |
1694 | return {}; |
1695 | |
1696 | /// Bitcast integer or float to integer or float. |
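  /// For instance (illustrative), bitcasting the f32 constant 1.0 produces the
  /// i32 constant 0x3F800000; the underlying bit pattern is left unchanged.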
1697 | APInt bits = llvm::isa<FloatAttr>(operand) |
1698 | ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt() |
1699 | : llvm::cast<IntegerAttr>(operand).getValue(); |
1700 | |
1701 | if (auto resFloatType = llvm::dyn_cast<FloatType>(resType)) |
1702 | return FloatAttr::get(resType, |
1703 | APFloat(resFloatType.getFloatSemantics(), bits)); |
1704 | return IntegerAttr::get(resType, bits); |
1705 | } |
1706 | |
1707 | void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1708 | MLIRContext *context) { |
1709 | patterns.add<BitcastOfBitcast>(context); |
1710 | } |
1711 | |
1712 | //===----------------------------------------------------------------------===// |
1713 | // CmpIOp |
1714 | //===----------------------------------------------------------------------===// |
1715 | |
1716 | /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer |
1717 | /// comparison predicates. |
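/// For instance (illustrative), for the 8-bit values 0xFF and 0x01, `slt`
/// returns true (signed: -1 < 1) while `ult` returns false (unsigned: 255 is
/// not less than 1).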
1718 | bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, |
1719 | const APInt &lhs, const APInt &rhs) { |
  switch (predicate) {
  case arith::CmpIPredicate::eq:
    return lhs.eq(rhs);
  case arith::CmpIPredicate::ne:
    return lhs.ne(rhs);
  case arith::CmpIPredicate::slt:
    return lhs.slt(rhs);
  case arith::CmpIPredicate::sle:
    return lhs.sle(rhs);
  case arith::CmpIPredicate::sgt:
    return lhs.sgt(rhs);
  case arith::CmpIPredicate::sge:
    return lhs.sge(rhs);
  case arith::CmpIPredicate::ult:
    return lhs.ult(rhs);
  case arith::CmpIPredicate::ule:
    return lhs.ule(rhs);
  case arith::CmpIPredicate::ugt:
    return lhs.ugt(rhs);
  case arith::CmpIPredicate::uge:
    return lhs.uge(rhs);
  }
  llvm_unreachable("unknown cmpi predicate kind");
1743 | } |
1744 | |
1745 | /// Returns true if the predicate is true for two equal operands. |
1746 | static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { |
1747 | switch (predicate) { |
1748 | case arith::CmpIPredicate::eq: |
1749 | case arith::CmpIPredicate::sle: |
1750 | case arith::CmpIPredicate::sge: |
1751 | case arith::CmpIPredicate::ule: |
1752 | case arith::CmpIPredicate::uge: |
1753 | return true; |
1754 | case arith::CmpIPredicate::ne: |
1755 | case arith::CmpIPredicate::slt: |
1756 | case arith::CmpIPredicate::sgt: |
1757 | case arith::CmpIPredicate::ult: |
1758 | case arith::CmpIPredicate::ugt: |
1759 | return false; |
1760 | } |
1761 | llvm_unreachable("unknown cmpi predicate kind" ); |
1762 | } |
1763 | |
1764 | static std::optional<int64_t> getIntegerWidth(Type t) { |
1765 | if (auto intType = llvm::dyn_cast<IntegerType>(t)) { |
1766 | return intType.getWidth(); |
1767 | } |
1768 | if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) { |
1769 | return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth(); |
1770 | } |
1771 | return std::nullopt; |
1772 | } |
1773 | |
1774 | OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) { |
1775 | // cmpi(pred, x, x) |
1776 | if (getLhs() == getRhs()) { |
1777 | auto val = applyCmpPredicateToEqualOperands(getPredicate()); |
1778 | return getBoolAttribute(getType(), val); |
1779 | } |
1780 | |
1781 | if (matchPattern(adaptor.getRhs(), m_Zero())) { |
1782 | if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { |
1783 | // extsi(%x : i1 -> iN) != 0 -> %x |
1784 | std::optional<int64_t> integerWidth = |
1785 | getIntegerWidth(extOp.getOperand().getType()); |
1786 | if (integerWidth && integerWidth.value() == 1 && |
1787 | getPredicate() == arith::CmpIPredicate::ne) |
1788 | return extOp.getOperand(); |
1789 | } |
1790 | if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { |
1791 | // extui(%x : i1 -> iN) != 0 -> %x |
1792 | std::optional<int64_t> integerWidth = |
1793 | getIntegerWidth(extOp.getOperand().getType()); |
1794 | if (integerWidth && integerWidth.value() == 1 && |
1795 | getPredicate() == arith::CmpIPredicate::ne) |
1796 | return extOp.getOperand(); |
1797 | } |
1798 | } |
1799 | |
1800 | // Move constant to the right side. |
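  // e.g. (illustrative): `arith.cmpi slt, %c5, %x` becomes
  // `arith.cmpi sgt, %x, %c5`, so later folds only need to handle a constant
  // rhs.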
1801 | if (adaptor.getLhs() && !adaptor.getRhs()) { |
    // Swapping the operands requires the swapped form of the predicate, not
    // the logical inverse; do not use invertPredicate, as it would change eq
    // to ne and vice versa.
1803 | using Pred = CmpIPredicate; |
1804 | const std::pair<Pred, Pred> invPreds[] = { |
1805 | {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge}, |
1806 | {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult}, |
1807 | {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq}, |
1808 | {Pred::ne, Pred::ne}, |
1809 | }; |
1810 | Pred origPred = getPredicate(); |
1811 | for (auto pred : invPreds) { |
1812 | if (origPred == pred.first) { |
1813 | setPredicate(pred.second); |
1814 | Value lhs = getLhs(); |
1815 | Value rhs = getRhs(); |
1816 | getLhsMutable().assign(rhs); |
1817 | getRhsMutable().assign(lhs); |
1818 | return getResult(); |
1819 | } |
1820 | } |
1821 | llvm_unreachable("unknown cmpi predicate kind" ); |
1822 | } |
1823 | |
  // We move constants to the right side, so if the lhs is a constant, the rhs
  // is guaranteed to be a constant as well.
1826 | if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) { |
1827 | return constFoldBinaryOp<IntegerAttr>( |
1828 | adaptor.getOperands(), getI1SameShape(lhs.getType()), |
1829 | [pred = getPredicate()](const APInt &lhs, const APInt &rhs) { |
1830 | return APInt(1, |
1831 | static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs))); |
1832 | }); |
1833 | } |
1834 | |
1835 | return {}; |
1836 | } |
1837 | |
1838 | void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1839 | MLIRContext *context) { |
1840 | patterns.insert<CmpIExtSI, CmpIExtUI>(context); |
1841 | } |
1842 | |
1843 | //===----------------------------------------------------------------------===// |
1844 | // CmpFOp |
1845 | //===----------------------------------------------------------------------===// |
1846 | |
1847 | /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point |
1848 | /// comparison predicates. |
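/// For instance (illustrative), if either operand is NaN the comparison is
/// unordered, so `OEQ` and `OLT` evaluate to false while `UNE` and `ULT`
/// evaluate to true.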
1849 | bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, |
1850 | const APFloat &lhs, const APFloat &rhs) { |
  auto cmpResult = lhs.compare(rhs);
1852 | switch (predicate) { |
1853 | case arith::CmpFPredicate::AlwaysFalse: |
1854 | return false; |
1855 | case arith::CmpFPredicate::OEQ: |
1856 | return cmpResult == APFloat::cmpEqual; |
1857 | case arith::CmpFPredicate::OGT: |
1858 | return cmpResult == APFloat::cmpGreaterThan; |
1859 | case arith::CmpFPredicate::OGE: |
1860 | return cmpResult == APFloat::cmpGreaterThan || |
1861 | cmpResult == APFloat::cmpEqual; |
1862 | case arith::CmpFPredicate::OLT: |
1863 | return cmpResult == APFloat::cmpLessThan; |
1864 | case arith::CmpFPredicate::OLE: |
1865 | return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; |
1866 | case arith::CmpFPredicate::ONE: |
1867 | return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; |
1868 | case arith::CmpFPredicate::ORD: |
1869 | return cmpResult != APFloat::cmpUnordered; |
1870 | case arith::CmpFPredicate::UEQ: |
1871 | return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; |
1872 | case arith::CmpFPredicate::UGT: |
1873 | return cmpResult == APFloat::cmpUnordered || |
1874 | cmpResult == APFloat::cmpGreaterThan; |
1875 | case arith::CmpFPredicate::UGE: |
1876 | return cmpResult == APFloat::cmpUnordered || |
1877 | cmpResult == APFloat::cmpGreaterThan || |
1878 | cmpResult == APFloat::cmpEqual; |
1879 | case arith::CmpFPredicate::ULT: |
1880 | return cmpResult == APFloat::cmpUnordered || |
1881 | cmpResult == APFloat::cmpLessThan; |
1882 | case arith::CmpFPredicate::ULE: |
1883 | return cmpResult == APFloat::cmpUnordered || |
1884 | cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; |
1885 | case arith::CmpFPredicate::UNE: |
1886 | return cmpResult != APFloat::cmpEqual; |
1887 | case arith::CmpFPredicate::UNO: |
1888 | return cmpResult == APFloat::cmpUnordered; |
1889 | case arith::CmpFPredicate::AlwaysTrue: |
1890 | return true; |
1891 | } |
1892 | llvm_unreachable("unknown cmpf predicate kind" ); |
1893 | } |
1894 | |
1895 | OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) { |
1896 | auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs()); |
1897 | auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs()); |
1898 | |
1899 | // If one operand is NaN, making them both NaN does not change the result. |
1900 | if (lhs && lhs.getValue().isNaN()) |
1901 | rhs = lhs; |
1902 | if (rhs && rhs.getValue().isNaN()) |
1903 | lhs = rhs; |
1904 | |
1905 | if (!lhs || !rhs) |
1906 | return {}; |
1907 | |
1908 | auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); |
1909 | return BoolAttr::get(getContext(), val); |
1910 | } |
1911 | |
1912 | class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> { |
1913 | public: |
1914 | using OpRewritePattern<CmpFOp>::OpRewritePattern; |
1915 | |
1916 | static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, |
1917 | bool isUnsigned) { |
1918 | using namespace arith; |
1919 | switch (pred) { |
1920 | case CmpFPredicate::UEQ: |
1921 | case CmpFPredicate::OEQ: |
1922 | return CmpIPredicate::eq; |
1923 | case CmpFPredicate::UGT: |
1924 | case CmpFPredicate::OGT: |
1925 | return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt; |
1926 | case CmpFPredicate::UGE: |
1927 | case CmpFPredicate::OGE: |
1928 | return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge; |
1929 | case CmpFPredicate::ULT: |
1930 | case CmpFPredicate::OLT: |
1931 | return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt; |
1932 | case CmpFPredicate::ULE: |
1933 | case CmpFPredicate::OLE: |
1934 | return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle; |
1935 | case CmpFPredicate::UNE: |
1936 | case CmpFPredicate::ONE: |
1937 | return CmpIPredicate::ne; |
1938 | default: |
1939 | llvm_unreachable("Unexpected predicate!" ); |
1940 | } |
1941 | } |
1942 | |
1943 | LogicalResult matchAndRewrite(CmpFOp op, |
1944 | PatternRewriter &rewriter) const override { |
1945 | FloatAttr flt; |
1946 | if (!matchPattern(op.getRhs(), m_Constant(&flt))) |
1947 | return failure(); |
1948 | |
1949 | const APFloat &rhs = flt.getValue(); |
1950 | |
1951 | // Don't attempt to fold a nan. |
1952 | if (rhs.isNaN()) |
1953 | return failure(); |
1954 | |
    // Get the width of the mantissa. We don't want to touch conversions that
    // might lose information from the integer, e.g. "i64 -> float".
1957 | FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType()); |
1958 | int mantissaWidth = floatTy.getFPMantissaWidth(); |
1959 | if (mantissaWidth <= 0) |
1960 | return failure(); |
1961 | |
1962 | bool isUnsigned; |
1963 | Value intVal; |
1964 | |
1965 | if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) { |
1966 | isUnsigned = false; |
1967 | intVal = si.getIn(); |
1968 | } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) { |
1969 | isUnsigned = true; |
1970 | intVal = ui.getIn(); |
1971 | } else { |
1972 | return failure(); |
1973 | } |
1974 | |
    // Check that the input is converted from an integer type small enough to
    // preserve all bits.
1977 | auto intTy = llvm::cast<IntegerType>(intVal.getType()); |
1978 | auto intWidth = intTy.getWidth(); |
1979 | |
    // Number of bits representing values, as opposed to the sign bit.
1981 | auto valueBits = isUnsigned ? intWidth : (intWidth - 1); |
1982 | |
    // The following test does NOT adjust intWidth downwards for signed
    // inputs, because the most negative value still requires all the mantissa
    // bits to distinguish it from one less than that value.
1986 | if ((int)intWidth > mantissaWidth) { |
1987 | // Conversion would lose accuracy. Check if loss can impact comparison. |
      int exponent = ilogb(rhs);
      if (exponent == APFloat::IEK_Inf) {
        int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1991 | if (maxExponent < (int)valueBits) { |
1992 | // Conversion could create infinity. |
1993 | return failure(); |
1994 | } |
1995 | } else { |
        // Note that if rhs is zero or NaN, then the exponent is negative and
        // the first condition is trivially false.
1998 | if (mantissaWidth <= exponent && exponent <= (int)valueBits) { |
1999 | // Conversion could affect comparison. |
2000 | return failure(); |
2001 | } |
2002 | } |
2003 | } |
2004 | |
2005 | // Convert to equivalent cmpi predicate |
2006 | CmpIPredicate pred; |
2007 | switch (op.getPredicate()) { |
2008 | case CmpFPredicate::ORD: |
2009 | // Int to fp conversion doesn't create a nan (ord checks neither is a nan) |
2010 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2011 | /*width=*/1); |
2012 | return success(); |
2013 | case CmpFPredicate::UNO: |
2014 | // Int to fp conversion doesn't create a nan (uno checks either is a nan) |
2015 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2016 | /*width=*/1); |
2017 | return success(); |
2018 | default: |
2019 | pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned); |
2020 | break; |
2021 | } |
2022 | |
2023 | if (!isUnsigned) { |
2024 | // If the rhs value is > SignedMax, fold the comparison. This handles |
2025 | // +INF and large values. |
2026 | APFloat signedMax(rhs.getSemantics()); |
      signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
2029 | if (signedMax < rhs) { // smax < 13123.0 |
2030 | if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt || |
2031 | pred == CmpIPredicate::sle) |
2032 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2033 | /*width=*/1); |
2034 | else |
2035 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2036 | /*width=*/1); |
2037 | return success(); |
2038 | } |
2039 | } else { |
2040 | // If the rhs value is > UnsignedMax, fold the comparison. This handles |
2041 | // +INF and large values. |
2042 | APFloat unsignedMax(rhs.getSemantics()); |
      unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
2045 | if (unsignedMax < rhs) { // umax < 13123.0 |
2046 | if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult || |
2047 | pred == CmpIPredicate::ule) |
2048 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2049 | /*width=*/1); |
2050 | else |
2051 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2052 | /*width=*/1); |
2053 | return success(); |
2054 | } |
2055 | } |
2056 | |
2057 | if (!isUnsigned) { |
2058 | // See if the rhs value is < SignedMin. |
2059 | APFloat signedMin(rhs.getSemantics()); |
      signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
2062 | if (signedMin > rhs) { // smin > 12312.0 |
2063 | if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt || |
2064 | pred == CmpIPredicate::sge) |
2065 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2066 | /*width=*/1); |
2067 | else |
2068 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2069 | /*width=*/1); |
2070 | return success(); |
2071 | } |
2072 | } else { |
2073 | // See if the rhs value is < UnsignedMin. |
2074 | APFloat unsignedMin(rhs.getSemantics()); |
      unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
2077 | if (unsignedMin > rhs) { // umin > 12312.0 |
2078 | if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt || |
2079 | pred == CmpIPredicate::uge) |
2080 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2081 | /*width=*/1); |
2082 | else |
2083 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2084 | /*width=*/1); |
2085 | return success(); |
2086 | } |
2087 | } |
2088 | |
2089 | // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or |
2090 | // [0, UMAX], but it may still be fractional. See if it is fractional by |
2091 | // casting the FP value to the integer value and back, checking for |
2092 | // equality. Don't do this for zero, because -0.0 is not fractional. |
2093 | bool ignored; |
2094 | APSInt rhsInt(intWidth, isUnsigned); |
    if (APFloat::opInvalidOp ==
        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2097 | // Undefined behavior invoked - the destination type can't represent |
2098 | // the input constant. |
2099 | return failure(); |
2100 | } |
2101 | |
2102 | if (!rhs.isZero()) { |
2103 | APFloat apf(floatTy.getFloatSemantics(), |
                  APInt::getZero(floatTy.getWidth()));
      apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2106 | |
2107 | bool equal = apf == rhs; |
2108 | if (!equal) { |
2109 | // If we had a comparison against a fractional value, we have to adjust |
2110 | // the compare predicate and sometimes the value. rhsInt is rounded |
2111 | // towards zero at this point. |
2112 | switch (pred) { |
2113 | case CmpIPredicate::ne: // (float)int != 4.4 --> true |
2114 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2115 | /*width=*/1); |
2116 | return success(); |
2117 | case CmpIPredicate::eq: // (float)int == 4.4 --> false |
2118 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2119 | /*width=*/1); |
2120 | return success(); |
2121 | case CmpIPredicate::ule: |
2122 | // (float)int <= 4.4 --> int <= 4 |
2123 | // (float)int <= -4.4 --> false |
2124 | if (rhs.isNegative()) { |
2125 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2126 | /*width=*/1); |
2127 | return success(); |
2128 | } |
2129 | break; |
2130 | case CmpIPredicate::sle: |
2131 | // (float)int <= 4.4 --> int <= 4 |
2132 | // (float)int <= -4.4 --> int < -4 |
2133 | if (rhs.isNegative()) |
2134 | pred = CmpIPredicate::slt; |
2135 | break; |
2136 | case CmpIPredicate::ult: |
2137 | // (float)int < -4.4 --> false |
2138 | // (float)int < 4.4 --> int <= 4 |
2139 | if (rhs.isNegative()) { |
2140 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2141 | /*width=*/1); |
2142 | return success(); |
2143 | } |
2144 | pred = CmpIPredicate::ule; |
2145 | break; |
2146 | case CmpIPredicate::slt: |
2147 | // (float)int < -4.4 --> int < -4 |
2148 | // (float)int < 4.4 --> int <= 4 |
2149 | if (!rhs.isNegative()) |
2150 | pred = CmpIPredicate::sle; |
2151 | break; |
2152 | case CmpIPredicate::ugt: |
2153 | // (float)int > 4.4 --> int > 4 |
2154 | // (float)int > -4.4 --> true |
2155 | if (rhs.isNegative()) { |
2156 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2157 | /*width=*/1); |
2158 | return success(); |
2159 | } |
2160 | break; |
2161 | case CmpIPredicate::sgt: |
2162 | // (float)int > 4.4 --> int > 4 |
2163 | // (float)int > -4.4 --> int >= -4 |
2164 | if (rhs.isNegative()) |
2165 | pred = CmpIPredicate::sge; |
2166 | break; |
2167 | case CmpIPredicate::uge: |
2168 | // (float)int >= -4.4 --> true |
2169 | // (float)int >= 4.4 --> int > 4 |
2170 | if (rhs.isNegative()) { |
2171 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2172 | /*width=*/1); |
2173 | return success(); |
2174 | } |
2175 | pred = CmpIPredicate::ugt; |
2176 | break; |
2177 | case CmpIPredicate::sge: |
2178 | // (float)int >= -4.4 --> int >= -4 |
2179 | // (float)int >= 4.4 --> int > 4 |
2180 | if (!rhs.isNegative()) |
2181 | pred = CmpIPredicate::sgt; |
2182 | break; |
2183 | } |
2184 | } |
2185 | } |
2186 | |
2187 | // Lower this FP comparison into an appropriate integer version of the |
2188 | // comparison. |
2189 | rewriter.replaceOpWithNewOp<CmpIOp>( |
2190 | op, pred, intVal, |
2191 | rewriter.create<ConstantOp>( |
2192 | op.getLoc(), intVal.getType(), |
2193 | rewriter.getIntegerAttr(intVal.getType(), rhsInt))); |
2194 | return success(); |
2195 | } |
2196 | }; |
2197 | |
2198 | void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
2199 | MLIRContext *context) { |
2200 | patterns.insert<CmpFIntToFPConst>(context); |
2201 | } |
2202 | |
2203 | //===----------------------------------------------------------------------===// |
2204 | // SelectOp |
2205 | //===----------------------------------------------------------------------===// |
2206 | |
2207 | // select %arg, %c1, %c0 => extui %arg |
2208 | struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> { |
2209 | using OpRewritePattern<arith::SelectOp>::OpRewritePattern; |
2210 | |
2211 | LogicalResult matchAndRewrite(arith::SelectOp op, |
2212 | PatternRewriter &rewriter) const override { |
2213 | // Cannot extui i1 to i1, or i1 to f32 |
2214 | if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1)) |
2215 | return failure(); |
2216 | |
    // select %arg, %c1, %c0 => extui %arg
2218 | if (matchPattern(op.getTrueValue(), m_One()) && |
2219 | matchPattern(op.getFalseValue(), m_Zero())) { |
2220 | rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(), |
2221 | op.getCondition()); |
2222 | return success(); |
2223 | } |
2224 | |
    // select %arg, %c0, %c1 => extui (xor %arg, true)
2226 | if (matchPattern(op.getTrueValue(), m_Zero()) && |
2227 | matchPattern(op.getFalseValue(), m_One())) { |
2228 | rewriter.replaceOpWithNewOp<arith::ExtUIOp>( |
2229 | op, op.getType(), |
2230 | rewriter.create<arith::XOrIOp>( |
2231 | op.getLoc(), op.getCondition(), |
2232 | rewriter.create<arith::ConstantIntOp>( |
2233 | op.getLoc(), 1, op.getCondition().getType()))); |
2234 | return success(); |
2235 | } |
2236 | |
2237 | return failure(); |
2238 | } |
2239 | }; |
2240 | |
2241 | void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results, |
2242 | MLIRContext *context) { |
2243 | results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond, |
2244 | SelectI1ToNot, SelectToExtUI>(context); |
2245 | } |
2246 | |
2247 | OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) { |
2248 | Value trueVal = getTrueValue(); |
2249 | Value falseVal = getFalseValue(); |
2250 | if (trueVal == falseVal) |
2251 | return trueVal; |
2252 | |
2253 | Value condition = getCondition(); |
2254 | |
2255 | // select true, %0, %1 => %0 |
2256 | if (matchPattern(adaptor.getCondition(), m_One())) |
2257 | return trueVal; |
2258 | |
2259 | // select false, %0, %1 => %1 |
2260 | if (matchPattern(adaptor.getCondition(), m_Zero())) |
2261 | return falseVal; |
2262 | |
2263 | // If either operand is fully poisoned, return the other. |
2264 | if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue())) |
2265 | return falseVal; |
2266 | |
2267 | if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue())) |
2268 | return trueVal; |
2269 | |
2270 | // select %x, true, false => %x |
2271 | if (getType().isInteger(1) && matchPattern(adaptor.getTrueValue(), m_One()) && |
2272 | matchPattern(adaptor.getFalseValue(), m_Zero())) |
2273 | return condition; |
2274 | |
2275 | if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) { |
2276 | auto pred = cmp.getPredicate(); |
2277 | if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { |
2278 | auto cmpLhs = cmp.getLhs(); |
2279 | auto cmpRhs = cmp.getRhs(); |
2280 | |
2281 | // %0 = arith.cmpi eq, %arg0, %arg1 |
2282 | // %1 = arith.select %0, %arg0, %arg1 => %arg1 |
2283 | |
2284 | // %0 = arith.cmpi ne, %arg0, %arg1 |
2285 | // %1 = arith.select %0, %arg0, %arg1 => %arg0 |
2286 | |
2287 | if ((cmpLhs == trueVal && cmpRhs == falseVal) || |
2288 | (cmpRhs == trueVal && cmpLhs == falseVal)) |
2289 | return pred == arith::CmpIPredicate::ne ? trueVal : falseVal; |
2290 | } |
2291 | } |
2292 | |
2293 | // Constant-fold constant operands over non-splat constant condition. |
2294 | // select %cst_vec, %cst0, %cst1 => %cst2 |
2295 | if (auto cond = |
2296 | llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) { |
2297 | if (auto lhs = |
2298 | llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) { |
2299 | if (auto rhs = |
2300 | llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) { |
2301 | SmallVector<Attribute> results; |
2302 | results.reserve(static_cast<size_t>(cond.getNumElements())); |
2303 | auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(), |
2304 | cond.value_end<BoolAttr>()); |
2305 | auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(), |
2306 | lhs.value_end<Attribute>()); |
2307 | auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(), |
2308 | rhs.value_end<Attribute>()); |
2309 | |
2310 | for (auto [condVal, lhsVal, rhsVal] : |
2311 | llvm::zip_equal(condVals, lhsVals, rhsVals)) |
2312 | results.push_back(condVal.getValue() ? lhsVal : rhsVal); |
2313 | |
2314 | return DenseElementsAttr::get(lhs.getType(), results); |
2315 | } |
2316 | } |
2317 | } |
2318 | |
2319 | return nullptr; |
2320 | } |
2321 | |
2322 | ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) { |
2323 | Type conditionType, resultType; |
2324 | SmallVector<OpAsmParser::UnresolvedOperand, 3> operands; |
2325 | if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || |
2326 | parser.parseOptionalAttrDict(result.attributes) || |
2327 | parser.parseColonType(resultType)) |
2328 | return failure(); |
2329 | |
2330 | // Check for the explicit condition type if this is a masked tensor or vector. |
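  // e.g. (illustrative):
  //   arith.select %cond, %a, %b : i32
  //   arith.select %mask, %va, %vb : vector<4xi1>, vector<4xf32>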
2331 | if (succeeded(parser.parseOptionalComma())) { |
2332 | conditionType = resultType; |
2333 | if (parser.parseType(resultType)) |
2334 | return failure(); |
2335 | } else { |
2336 | conditionType = parser.getBuilder().getI1Type(); |
2337 | } |
2338 | |
2339 | result.addTypes(resultType); |
2340 | return parser.resolveOperands(operands, |
2341 | {conditionType, resultType, resultType}, |
2342 | parser.getNameLoc(), result.operands); |
2343 | } |
2344 | |
2345 | void arith::SelectOp::print(OpAsmPrinter &p) { |
2346 | p << " " << getOperands(); |
2347 | p.printOptionalAttrDict((*this)->getAttrs()); |
2348 | p << " : " ; |
2349 | if (ShapedType condType = |
2350 | llvm::dyn_cast<ShapedType>(getCondition().getType())) |
2351 | p << condType << ", " ; |
2352 | p << getType(); |
2353 | } |
2354 | |
2355 | LogicalResult arith::SelectOp::verify() { |
2356 | Type conditionType = getCondition().getType(); |
2357 | if (conditionType.isSignlessInteger(1)) |
2358 | return success(); |
2359 | |
  // If the result type is a vector or tensor, the condition can be an i1 mask
  // with the same shape.
2362 | Type resultType = getType(); |
2363 | if (!llvm::isa<TensorType, VectorType>(resultType)) |
2364 | return emitOpError() << "expected condition to be a signless i1, but got " |
2365 | << conditionType; |
2366 | Type shapedConditionType = getI1SameShape(resultType); |
2367 | if (conditionType != shapedConditionType) { |
2368 | return emitOpError() << "expected condition type to have the same shape " |
2369 | "as the result type, expected " |
2370 | << shapedConditionType << ", but got " |
2371 | << conditionType; |
2372 | } |
2373 | return success(); |
2374 | } |
2375 | //===----------------------------------------------------------------------===// |
2376 | // ShLIOp |
2377 | //===----------------------------------------------------------------------===// |
2378 | |
2379 | OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) { |
2380 | // shli(x, 0) -> x |
2381 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
2382 | return getLhs(); |
  // Don't fold if shifting by more than or equal to the bit width.
2384 | bool bounded = false; |
2385 | auto result = constFoldBinaryOp<IntegerAttr>( |
2386 | adaptor.getOperands(), [&](const APInt &a, const APInt &b) { |
2387 | bounded = b.ult(b.getBitWidth()); |
2388 | return a.shl(b); |
2389 | }); |
2390 | return bounded ? result : Attribute(); |
2391 | } |
2392 | |
2393 | //===----------------------------------------------------------------------===// |
2394 | // ShRUIOp |
2395 | //===----------------------------------------------------------------------===// |
2396 | |
2397 | OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) { |
2398 | // shrui(x, 0) -> x |
2399 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
2400 | return getLhs(); |
  // Don't fold if shifting by more than or equal to the bit width.
2402 | bool bounded = false; |
2403 | auto result = constFoldBinaryOp<IntegerAttr>( |
2404 | adaptor.getOperands(), [&](const APInt &a, const APInt &b) { |
2405 | bounded = b.ult(b.getBitWidth()); |
2406 | return a.lshr(b); |
2407 | }); |
2408 | return bounded ? result : Attribute(); |
2409 | } |
2410 | |
2411 | //===----------------------------------------------------------------------===// |
2412 | // ShRSIOp |
2413 | //===----------------------------------------------------------------------===// |
2414 | |
2415 | OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) { |
2416 | // shrsi(x, 0) -> x |
2417 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
2418 | return getLhs(); |
  // Don't fold if shifting by more than or equal to the bit width.
2420 | bool bounded = false; |
2421 | auto result = constFoldBinaryOp<IntegerAttr>( |
2422 | adaptor.getOperands(), [&](const APInt &a, const APInt &b) { |
2423 | bounded = b.ult(b.getBitWidth()); |
2424 | return a.ashr(b); |
2425 | }); |
2426 | return bounded ? result : Attribute(); |
2427 | } |
2428 | |
2429 | //===----------------------------------------------------------------------===// |
2430 | // Atomic Enum |
2431 | //===----------------------------------------------------------------------===// |
2432 | |
2433 | /// Returns the identity value attribute associated with an AtomicRMWKind op. |
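/// For instance (illustrative): `addf`/`addi`/`ori` use 0, `muli` uses 1,
/// `andi` uses the all-ones value, `maxs` uses the minimum signed value, and
/// `maximumf` uses -infinity (or the most negative finite value when
/// `useOnlyFiniteValue` is set).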
2434 | TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, |
2435 | OpBuilder &builder, Location loc, |
2436 | bool useOnlyFiniteValue) { |
2437 | switch (kind) { |
2438 | case AtomicRMWKind::maximumf: { |
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, /*Negative=*/true)
                           : APFloat::getInf(semantic, /*Negative=*/true);
2444 | return builder.getFloatAttr(resultType, identity); |
2445 | } |
2446 | case AtomicRMWKind::addf: |
2447 | case AtomicRMWKind::addi: |
2448 | case AtomicRMWKind::maxu: |
2449 | case AtomicRMWKind::ori: |
2450 | return builder.getZeroAttr(resultType); |
2451 | case AtomicRMWKind::andi: |
2452 | return builder.getIntegerAttr( |
2453 | resultType, |
        APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2455 | case AtomicRMWKind::maxs: |
2456 | return builder.getIntegerAttr( |
        resultType, APInt::getSignedMinValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
2459 | case AtomicRMWKind::minimumf: { |
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, /*Negative=*/false)
                           : APFloat::getInf(semantic, /*Negative=*/false);
2465 | |
2466 | return builder.getFloatAttr(resultType, identity); |
2467 | } |
2468 | case AtomicRMWKind::mins: |
2469 | return builder.getIntegerAttr( |
        resultType, APInt::getSignedMaxValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
2472 | case AtomicRMWKind::minu: |
2473 | return builder.getIntegerAttr( |
2474 | resultType, |
        APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2476 | case AtomicRMWKind::muli: |
2477 | return builder.getIntegerAttr(resultType, 1); |
2478 | case AtomicRMWKind::mulf: |
2479 | return builder.getFloatAttr(resultType, 1); |
2480 | // TODO: Add remaining reduction operations. |
2481 | default: |
    (void)emitOptionalError(loc, "Reduction operation type not supported");
2483 | break; |
2484 | } |
2485 | return nullptr; |
2486 | } |
2487 | |
/// Return the identity numeric value associated with the given op.
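/// For instance (illustrative), an `arith.maximumf` op maps to the `maximumf`
/// kind, whose neutral element is -infinity, or the most negative finite value
/// when the op carries the `ninf` fast-math flag.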
2489 | std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) { |
2490 | std::optional<AtomicRMWKind> maybeKind = |
2491 | llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op) |
2492 | // Floating-point operations. |
2493 | .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; }) |
2494 | .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; }) |
2495 | .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; }) |
2496 | .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; }) |
2497 | // Integer operations. |
2498 | .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; }) |
2499 | .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; }) |
2500 | .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; }) |
2501 | .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; }) |
2502 | .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; }) |
2503 | .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; }) |
2504 | .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; }) |
2505 | .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; }) |
2506 | .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; }) |
2507 | .Default([](Operation *op) { return std::nullopt; }); |
2508 | if (!maybeKind) { |
2509 | op->emitError() << "Unknown neutral element for: " << *op; |
2510 | return std::nullopt; |
2511 | } |
2512 | |
2513 | bool useOnlyFiniteValue = false; |
2514 | auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op); |
2515 | if (fmfOpInterface) { |
2516 | arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr(); |
2517 | useOnlyFiniteValue = |
2518 | bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf); |
2519 | } |
2520 | |
  // The builder is only used as a helper for attribute creation.
2522 | OpBuilder b(op->getContext()); |
  Type resultType = op->getResult(0).getType();
2524 | |
2525 | return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(), |
2526 | useOnlyFiniteValue); |
2527 | } |
2528 | |
2529 | /// Returns the identity value associated with an AtomicRMWKind op. |
2530 | Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, |
2531 | OpBuilder &builder, Location loc, |
2532 | bool useOnlyFiniteValue) { |
2533 | auto attr = |
2534 | getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue); |
2535 | return builder.create<arith::ConstantOp>(loc, attr); |
2536 | } |
2537 | |
2538 | /// Return the value obtained by applying the reduction operation kind |
2539 | /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. |
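/// A minimal usage sketch (illustrative; `builder`, `loc`, `acc`, and `elem`
/// are assumed to be in scope):
///   acc = arith::getReductionOp(AtomicRMWKind::addf, builder, loc,
///                               acc, elem);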
2540 | Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, |
2541 | Location loc, Value lhs, Value rhs) { |
2542 | switch (op) { |
2543 | case AtomicRMWKind::addf: |
2544 | return builder.create<arith::AddFOp>(loc, lhs, rhs); |
2545 | case AtomicRMWKind::addi: |
2546 | return builder.create<arith::AddIOp>(loc, lhs, rhs); |
2547 | case AtomicRMWKind::mulf: |
2548 | return builder.create<arith::MulFOp>(loc, lhs, rhs); |
2549 | case AtomicRMWKind::muli: |
2550 | return builder.create<arith::MulIOp>(loc, lhs, rhs); |
2551 | case AtomicRMWKind::maximumf: |
2552 | return builder.create<arith::MaximumFOp>(loc, lhs, rhs); |
2553 | case AtomicRMWKind::minimumf: |
2554 | return builder.create<arith::MinimumFOp>(loc, lhs, rhs); |
2555 | case AtomicRMWKind::maxnumf: |
2556 | return builder.create<arith::MaxNumFOp>(loc, lhs, rhs); |
2557 | case AtomicRMWKind::minnumf: |
2558 | return builder.create<arith::MinNumFOp>(loc, lhs, rhs); |
2559 | case AtomicRMWKind::maxs: |
2560 | return builder.create<arith::MaxSIOp>(loc, lhs, rhs); |
2561 | case AtomicRMWKind::mins: |
2562 | return builder.create<arith::MinSIOp>(loc, lhs, rhs); |
2563 | case AtomicRMWKind::maxu: |
2564 | return builder.create<arith::MaxUIOp>(loc, lhs, rhs); |
2565 | case AtomicRMWKind::minu: |
2566 | return builder.create<arith::MinUIOp>(loc, lhs, rhs); |
2567 | case AtomicRMWKind::ori: |
2568 | return builder.create<arith::OrIOp>(loc, lhs, rhs); |
2569 | case AtomicRMWKind::andi: |
2570 | return builder.create<arith::AndIOp>(loc, lhs, rhs); |
2571 | // TODO: Add remaining reduction operations. |
2572 | default: |
    (void)emitOptionalError(loc, "Reduction operation type not supported");
2574 | break; |
2575 | } |
2576 | return nullptr; |
2577 | } |
2578 | |
2579 | //===----------------------------------------------------------------------===// |
2580 | // TableGen'd op method definitions |
2581 | //===----------------------------------------------------------------------===// |
2582 | |
2583 | #define GET_OP_CLASSES |
2584 | #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc" |
2585 | |
2586 | //===----------------------------------------------------------------------===// |
2587 | // TableGen'd enum attribute definitions |
2588 | //===----------------------------------------------------------------------===// |
2589 | |
2590 | #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc" |
2591 | |