1 | //===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <cassert> |
10 | #include <cstdint> |
11 | #include <functional> |
12 | #include <utility> |
13 | |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/CommonFolders.h" |
16 | #include "mlir/Dialect/UB/IR/UBOps.h" |
17 | #include "mlir/IR/Builders.h" |
18 | #include "mlir/IR/BuiltinAttributeInterfaces.h" |
19 | #include "mlir/IR/BuiltinAttributes.h" |
20 | #include "mlir/IR/Matchers.h" |
21 | #include "mlir/IR/OpImplementation.h" |
22 | #include "mlir/IR/PatternMatch.h" |
23 | #include "mlir/IR/TypeUtilities.h" |
24 | #include "mlir/Support/LogicalResult.h" |
25 | |
26 | #include "llvm/ADT/APFloat.h" |
27 | #include "llvm/ADT/APInt.h" |
28 | #include "llvm/ADT/APSInt.h" |
29 | #include "llvm/ADT/FloatingPointMode.h" |
30 | #include "llvm/ADT/STLExtras.h" |
31 | #include "llvm/ADT/SmallString.h" |
32 | #include "llvm/ADT/SmallVector.h" |
33 | #include "llvm/ADT/TypeSwitch.h" |
34 | |
35 | using namespace mlir; |
36 | using namespace mlir::arith; |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Pattern helpers |
40 | //===----------------------------------------------------------------------===// |
41 | |
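// The following helpers fold pairs of integer constant attributes; they are
// intended for use by the TableGen'd canonicalization patterns included
// further below (ArithCanonicalization.inc).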
42 | static IntegerAttr |
43 | applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, |
44 | Attribute rhs, |
45 | function_ref<APInt(const APInt &, const APInt &)> binFn) { |
46 | APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue(); |
47 | APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue(); |
48 | APInt value = binFn(lhsVal, rhsVal); |
49 | return IntegerAttr::get(res.getType(), value); |
50 | } |
51 | |
52 | static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, |
53 | Attribute lhs, Attribute rhs) { |
54 | return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>()); |
55 | } |
56 | |
57 | static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, |
58 | Attribute lhs, Attribute rhs) { |
59 | return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>()); |
60 | } |
61 | |
62 | static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, |
63 | Attribute lhs, Attribute rhs) { |
64 | return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>()); |
65 | } |
66 | |
67 | // Merge overflow flags from 2 ops, selecting the most conservative combination. |
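// For example, merging {nsw, nuw} with {nsw} yields {nsw}: a flag is kept
// only if both operations guarantee it, which the bitwise AND below computes.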
68 | static IntegerOverflowFlagsAttr |
69 | mergeOverflowFlags(IntegerOverflowFlagsAttr val1, |
70 | IntegerOverflowFlagsAttr val2) { |
71 | return IntegerOverflowFlagsAttr::get(val1.getContext(), |
72 | val1.getValue() & val2.getValue()); |
73 | } |
74 | |
75 | /// Invert an integer comparison predicate. |
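/// For example, invertPredicate(slt) is sge, since !(a < b) is a >= b. Note
/// that this is logical negation of the predicate, not operand swapping.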
76 | arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { |
77 | switch (pred) { |
78 | case arith::CmpIPredicate::eq: |
79 | return arith::CmpIPredicate::ne; |
80 | case arith::CmpIPredicate::ne: |
81 | return arith::CmpIPredicate::eq; |
82 | case arith::CmpIPredicate::slt: |
83 | return arith::CmpIPredicate::sge; |
84 | case arith::CmpIPredicate::sle: |
85 | return arith::CmpIPredicate::sgt; |
86 | case arith::CmpIPredicate::sgt: |
87 | return arith::CmpIPredicate::sle; |
88 | case arith::CmpIPredicate::sge: |
89 | return arith::CmpIPredicate::slt; |
90 | case arith::CmpIPredicate::ult: |
91 | return arith::CmpIPredicate::uge; |
92 | case arith::CmpIPredicate::ule: |
93 | return arith::CmpIPredicate::ugt; |
94 | case arith::CmpIPredicate::ugt: |
95 | return arith::CmpIPredicate::ule; |
96 | case arith::CmpIPredicate::uge: |
97 | return arith::CmpIPredicate::ult; |
98 | } |
99 | llvm_unreachable("unknown cmpi predicate kind"); |
100 | } |
101 | |
102 | /// Equivalent to |
103 | /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)). |
104 | /// |
/// Not possible to implement as a chain of calls, as this would introduce a
/// circular dependency with MLIRArithAttrToLLVMConversion and make arith
/// depend on the LLVM dialect and on translation to LLVM.
108 | static llvm::RoundingMode |
109 | convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) { |
110 | switch (roundingMode) { |
111 | case RoundingMode::downward: |
112 | return llvm::RoundingMode::TowardNegative; |
113 | case RoundingMode::to_nearest_away: |
114 | return llvm::RoundingMode::NearestTiesToAway; |
115 | case RoundingMode::to_nearest_even: |
116 | return llvm::RoundingMode::NearestTiesToEven; |
117 | case RoundingMode::toward_zero: |
118 | return llvm::RoundingMode::TowardZero; |
119 | case RoundingMode::upward: |
120 | return llvm::RoundingMode::TowardPositive; |
121 | } |
122 | llvm_unreachable("Unhandled rounding mode"); |
123 | } |
124 | |
125 | static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { |
126 | return arith::CmpIPredicateAttr::get(pred.getContext(), |
127 | invertPredicate(pred.getValue())); |
128 | } |
129 | |
130 | static int64_t getScalarOrElementWidth(Type type) { |
131 | Type elemTy = getElementTypeOrSelf(type); |
132 | if (elemTy.isIntOrFloat()) |
133 | return elemTy.getIntOrFloatBitWidth(); |
134 | |
135 | return -1; |
136 | } |
137 | |
static int64_t getScalarOrElementWidth(Value value) {
  return getScalarOrElementWidth(value.getType());
}
141 | |
142 | static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) { |
143 | APInt value; |
144 | if (matchPattern(attr, m_ConstantInt(&value))) |
145 | return value; |
146 | |
147 | return failure(); |
148 | } |
149 | |
static Attribute getBoolAttribute(Type type, bool value) {
  auto boolAttr = BoolAttr::get(type.getContext(), value);
  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
  if (!shapedType)
    return boolAttr;
  return DenseElementsAttr::get(shapedType, boolAttr);
}
157 | |
158 | //===----------------------------------------------------------------------===// |
159 | // TableGen'd canonicalization patterns |
160 | //===----------------------------------------------------------------------===// |
161 | |
162 | namespace { |
163 | #include "ArithCanonicalization.inc" |
164 | } // namespace |
165 | |
166 | //===----------------------------------------------------------------------===// |
167 | // Common helpers |
168 | //===----------------------------------------------------------------------===// |
169 | |
170 | /// Return the type of the same shape (scalar, vector or tensor) containing i1. |
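/// For example, f32 maps to i1, vector<4xf32> to vector<4xi1>, and
/// tensor<?xf32> to tensor<?xi1>.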
171 | static Type getI1SameShape(Type type) { |
172 | auto i1Type = IntegerType::get(type.getContext(), 1); |
173 | if (auto shapedType = llvm::dyn_cast<ShapedType>(type)) |
174 | return shapedType.cloneWith(std::nullopt, i1Type); |
175 | if (llvm::isa<UnrankedTensorType>(type)) |
176 | return UnrankedTensorType::get(i1Type); |
177 | return i1Type; |
178 | } |
179 | |
180 | //===----------------------------------------------------------------------===// |
181 | // ConstantOp |
182 | //===----------------------------------------------------------------------===// |
183 | |
184 | void arith::ConstantOp::getAsmResultNames( |
185 | function_ref<void(Value, StringRef)> setNameFn) { |
186 | auto type = getType(); |
187 | if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) { |
188 | auto intType = llvm::dyn_cast<IntegerType>(type); |
189 | |
190 | // Sugar i1 constants with 'true' and 'false'. |
191 | if (intType && intType.getWidth() == 1) |
      return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
193 | |
194 | // Otherwise, build a complex name with the value and type. |
195 | SmallString<32> specialNameBuffer; |
196 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
197 | specialName << 'c' << intCst.getValue(); |
198 | if (intType) |
199 | specialName << '_' << type; |
200 | setNameFn(getResult(), specialName.str()); |
201 | } else { |
202 | setNameFn(getResult(), "cst"); |
203 | } |
204 | } |
205 | |
206 | /// TODO: disallow arith.constant to return anything other than signless integer |
207 | /// or float like. |
208 | LogicalResult arith::ConstantOp::verify() { |
209 | auto type = getType(); |
210 | // Integer values must be signless. |
211 | if (llvm::isa<IntegerType>(type) && |
212 | !llvm::cast<IntegerType>(type).isSignless()) |
213 | return emitOpError("integer return type must be signless"); |
214 | // Any float or elements attribute are acceptable. |
215 | if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) { |
216 | return emitOpError( |
217 | "value must be an integer, float, or elements attribute"); |
218 | } |
219 | |
220 | // Note, we could relax this for vectors with 1 scalable dim, e.g.: |
221 | // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32> |
222 | // However, this would most likely require updating the lowerings to LLVM. |
  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
    return emitOpError(
        "initializing scalable vectors with elements attribute is not supported"
        " unless it's a vector splat");
227 | return success(); |
228 | } |
229 | |
230 | bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { |
231 | // The value's type must be the same as the provided type. |
232 | auto typedAttr = llvm::dyn_cast<TypedAttr>(value); |
233 | if (!typedAttr || typedAttr.getType() != type) |
234 | return false; |
235 | // Integer values must be signless. |
236 | if (llvm::isa<IntegerType>(type) && |
237 | !llvm::cast<IntegerType>(type).isSignless()) |
238 | return false; |
239 | // Integer, float, and element attributes are buildable. |
240 | return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value); |
241 | } |
242 | |
243 | ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value, |
244 | Type type, Location loc) { |
245 | if (isBuildableWith(value, type)) |
246 | return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value)); |
247 | return nullptr; |
248 | } |
249 | |
250 | OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } |
251 | |
252 | void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, |
253 | int64_t value, unsigned width) { |
254 | auto type = builder.getIntegerType(width); |
255 | arith::ConstantOp::build(builder, result, type, |
256 | builder.getIntegerAttr(type, value)); |
257 | } |
258 | |
259 | void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, |
260 | int64_t value, Type type) { |
261 | assert(type.isSignlessInteger() && |
262 | "ConstantIntOp can only have signless integer type values"); |
263 | arith::ConstantOp::build(builder, result, type, |
264 | builder.getIntegerAttr(type, value)); |
265 | } |
266 | |
267 | bool arith::ConstantIntOp::classof(Operation *op) { |
268 | if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) |
269 | return constOp.getType().isSignlessInteger(); |
270 | return false; |
271 | } |
272 | |
273 | void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, |
274 | const APFloat &value, FloatType type) { |
275 | arith::ConstantOp::build(builder, result, type, |
276 | builder.getFloatAttr(type, value)); |
277 | } |
278 | |
279 | bool arith::ConstantFloatOp::classof(Operation *op) { |
280 | if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) |
281 | return llvm::isa<FloatType>(constOp.getType()); |
282 | return false; |
283 | } |
284 | |
285 | void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, |
286 | int64_t value) { |
287 | arith::ConstantOp::build(builder, result, builder.getIndexType(), |
288 | builder.getIndexAttr(value)); |
289 | } |
290 | |
291 | bool arith::ConstantIndexOp::classof(Operation *op) { |
292 | if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) |
293 | return constOp.getType().isIndex(); |
294 | return false; |
295 | } |
296 | |
297 | //===----------------------------------------------------------------------===// |
298 | // AddIOp |
299 | //===----------------------------------------------------------------------===// |
300 | |
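/// A minimal illustration of the non-constant folds below (value names are
/// arbitrary):
///   %0 = arith.subi %a, %b : i32
///   %1 = arith.addi %0, %b : i32     // folds to %a
///   %2 = arith.addi %a, %zero : i32  // folds to %a when %zero is constant 0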
301 | OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) { |
302 | // addi(x, 0) -> x |
303 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
304 | return getLhs(); |
305 | |
306 | // addi(subi(a, b), b) -> a |
307 | if (auto sub = getLhs().getDefiningOp<SubIOp>()) |
308 | if (getRhs() == sub.getRhs()) |
309 | return sub.getLhs(); |
310 | |
311 | // addi(b, subi(a, b)) -> a |
312 | if (auto sub = getRhs().getDefiningOp<SubIOp>()) |
313 | if (getLhs() == sub.getRhs()) |
314 | return sub.getLhs(); |
315 | |
316 | return constFoldBinaryOp<IntegerAttr>( |
317 | adaptor.getOperands(), |
318 | [](APInt a, const APInt &b) { return std::move(a) + b; }); |
319 | } |
320 | |
321 | void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
322 | MLIRContext *context) { |
323 | patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS, |
324 | AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context); |
325 | } |
326 | |
327 | //===----------------------------------------------------------------------===// |
328 | // AddUIExtendedOp |
329 | //===----------------------------------------------------------------------===// |
330 | |
331 | std::optional<SmallVector<int64_t, 4>> |
332 | arith::AddUIExtendedOp::getShapeForUnroll() { |
333 | if (auto vt = llvm::dyn_cast<VectorType>(getType(0))) |
334 | return llvm::to_vector<4>(vt.getShape()); |
335 | return std::nullopt; |
336 | } |
337 | |
338 | // Returns the overflow bit, assuming that `sum` is the result of unsigned |
339 | // addition of `operand` and another number. |
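// For example, with i8 operands, 250 + 10 wraps to 4; since 4 is
// unsigned-less-than 250, the overflow bit is 1.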
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
  return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
}
343 | |
344 | LogicalResult |
345 | arith::AddUIExtendedOp::fold(FoldAdaptor adaptor, |
346 | SmallVectorImpl<OpFoldResult> &results) { |
347 | Type overflowTy = getOverflow().getType(); |
348 | // addui_extended(x, 0) -> x, false |
349 | if (matchPattern(getRhs(), m_Zero())) { |
350 | Builder builder(getContext()); |
351 | auto falseValue = builder.getZeroAttr(overflowTy); |
352 | |
353 | results.push_back(getLhs()); |
354 | results.push_back(falseValue); |
355 | return success(); |
356 | } |
357 | |
358 | // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry |
359 | // Let the `constFoldBinaryOp` utility attempt to fold the sum of both |
360 | // operands. If that succeeds, calculate the overflow bit based on the sum |
361 | // and the first (constant) operand, `lhs`. |
362 | if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>( |
363 | adaptor.getOperands(), |
364 | [](APInt a, const APInt &b) { return std::move(a) + b; })) { |
365 | Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>( |
366 | ArrayRef({sumAttr, adaptor.getLhs()}), |
367 | getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()), |
368 | calculateUnsignedOverflow); |
369 | if (!overflowAttr) |
370 | return failure(); |
371 | |
372 | results.push_back(sumAttr); |
373 | results.push_back(overflowAttr); |
374 | return success(); |
375 | } |
376 | |
377 | return failure(); |
378 | } |
379 | |
380 | void arith::AddUIExtendedOp::getCanonicalizationPatterns( |
381 | RewritePatternSet &patterns, MLIRContext *context) { |
382 | patterns.add<AddUIExtendedToAddI>(context); |
383 | } |
384 | |
385 | //===----------------------------------------------------------------------===// |
386 | // SubIOp |
387 | //===----------------------------------------------------------------------===// |
388 | |
389 | OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) { |
390 | // subi(x,x) -> 0 |
391 | if (getOperand(0) == getOperand(1)) { |
392 | auto shapedType = dyn_cast<ShapedType>(getType()); |
    // We can't generate a constant with a dynamically shaped tensor.
394 | if (!shapedType || shapedType.hasStaticShape()) |
395 | return Builder(getContext()).getZeroAttr(getType()); |
396 | } |
397 | // subi(x,0) -> x |
398 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
399 | return getLhs(); |
400 | |
401 | if (auto add = getLhs().getDefiningOp<AddIOp>()) { |
402 | // subi(addi(a, b), b) -> a |
403 | if (getRhs() == add.getRhs()) |
404 | return add.getLhs(); |
405 | // subi(addi(a, b), a) -> b |
406 | if (getRhs() == add.getLhs()) |
407 | return add.getRhs(); |
408 | } |
409 | |
410 | return constFoldBinaryOp<IntegerAttr>( |
411 | adaptor.getOperands(), |
412 | [](APInt a, const APInt &b) { return std::move(a) - b; }); |
413 | } |
414 | |
415 | void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
416 | MLIRContext *context) { |
417 | patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS, |
418 | SubIRHSSubConstantLHS, SubILHSSubConstantRHS, |
419 | SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context); |
420 | } |
421 | |
422 | //===----------------------------------------------------------------------===// |
423 | // MulIOp |
424 | //===----------------------------------------------------------------------===// |
425 | |
426 | OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) { |
427 | // muli(x, 0) -> 0 |
428 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
429 | return getRhs(); |
430 | // muli(x, 1) -> x |
431 | if (matchPattern(adaptor.getRhs(), m_One())) |
432 | return getLhs(); |
433 | // TODO: Handle the overflow case. |
434 | |
435 | // default folder |
436 | return constFoldBinaryOp<IntegerAttr>( |
437 | adaptor.getOperands(), |
438 | [](const APInt &a, const APInt &b) { return a * b; }); |
439 | } |
440 | |
441 | void arith::MulIOp::getAsmResultNames( |
442 | function_ref<void(Value, StringRef)> setNameFn) { |
443 | if (!isa<IndexType>(getType())) |
444 | return; |
445 | |
  // Match vector.vscale by name to avoid depending on the vector dialect
  // (which would create a circular dependency).
448 | auto isVscale = [](Operation *op) { |
449 | return op && op->getName().getStringRef() == "vector.vscale"; |
450 | }; |
451 | |
452 | IntegerAttr baseValue; |
453 | auto isVscaleExpr = [&](Value a, Value b) { |
454 | return matchPattern(a, m_Constant(&baseValue)) && |
455 | isVscale(b.getDefiningOp()); |
456 | }; |
457 | |
458 | if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs())) |
459 | return; |
460 | |
461 | // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`. |
462 | SmallString<32> specialNameBuffer; |
463 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
464 | specialName << 'c' << baseValue.getInt() << "_vscale"; |
465 | setNameFn(getResult(), specialName.str()); |
466 | } |
467 | |
468 | void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
469 | MLIRContext *context) { |
470 | patterns.add<MulIMulIConstant>(context); |
471 | } |
472 | |
473 | //===----------------------------------------------------------------------===// |
474 | // MulSIExtendedOp |
475 | //===----------------------------------------------------------------------===// |
476 | |
477 | std::optional<SmallVector<int64_t, 4>> |
478 | arith::MulSIExtendedOp::getShapeForUnroll() { |
479 | if (auto vt = llvm::dyn_cast<VectorType>(getType(0))) |
480 | return llvm::to_vector<4>(vt.getShape()); |
481 | return std::nullopt; |
482 | } |
483 | |
484 | LogicalResult |
485 | arith::MulSIExtendedOp::fold(FoldAdaptor adaptor, |
486 | SmallVectorImpl<OpFoldResult> &results) { |
487 | // mulsi_extended(x, 0) -> 0, 0 |
488 | if (matchPattern(adaptor.getRhs(), m_Zero())) { |
489 | Attribute zero = adaptor.getRhs(); |
490 | results.push_back(zero); |
491 | results.push_back(zero); |
492 | return success(); |
493 | } |
494 | |
495 | // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high |
496 | if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>( |
497 | adaptor.getOperands(), |
498 | [](const APInt &a, const APInt &b) { return a * b; })) { |
499 | // Invoke the constant fold helper again to calculate the 'high' result. |
500 | Attribute highAttr = constFoldBinaryOp<IntegerAttr>( |
501 | adaptor.getOperands(), [](const APInt &a, const APInt &b) { |
502 | return llvm::APIntOps::mulhs(a, b); |
503 | }); |
504 | assert(highAttr && "Unexpected constant-folding failure"); |
505 | |
506 | results.push_back(lowAttr); |
507 | results.push_back(highAttr); |
508 | return success(); |
509 | } |
510 | |
511 | return failure(); |
512 | } |
513 | |
514 | void arith::MulSIExtendedOp::getCanonicalizationPatterns( |
515 | RewritePatternSet &patterns, MLIRContext *context) { |
516 | patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context); |
517 | } |
518 | |
519 | //===----------------------------------------------------------------------===// |
520 | // MulUIExtendedOp |
521 | //===----------------------------------------------------------------------===// |
522 | |
523 | std::optional<SmallVector<int64_t, 4>> |
524 | arith::MulUIExtendedOp::getShapeForUnroll() { |
525 | if (auto vt = llvm::dyn_cast<VectorType>(getType(0))) |
526 | return llvm::to_vector<4>(vt.getShape()); |
527 | return std::nullopt; |
528 | } |
529 | |
530 | LogicalResult |
531 | arith::MulUIExtendedOp::fold(FoldAdaptor adaptor, |
532 | SmallVectorImpl<OpFoldResult> &results) { |
533 | // mului_extended(x, 0) -> 0, 0 |
534 | if (matchPattern(adaptor.getRhs(), m_Zero())) { |
535 | Attribute zero = adaptor.getRhs(); |
536 | results.push_back(zero); |
537 | results.push_back(zero); |
538 | return success(); |
539 | } |
540 | |
541 | // mului_extended(x, 1) -> x, 0 |
542 | if (matchPattern(adaptor.getRhs(), m_One())) { |
543 | Builder builder(getContext()); |
544 | Attribute zero = builder.getZeroAttr(getLhs().getType()); |
545 | results.push_back(getLhs()); |
546 | results.push_back(zero); |
547 | return success(); |
548 | } |
549 | |
550 | // mului_extended(cst_a, cst_b) -> cst_low, cst_high |
551 | if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>( |
552 | adaptor.getOperands(), |
553 | [](const APInt &a, const APInt &b) { return a * b; })) { |
554 | // Invoke the constant fold helper again to calculate the 'high' result. |
555 | Attribute highAttr = constFoldBinaryOp<IntegerAttr>( |
556 | adaptor.getOperands(), [](const APInt &a, const APInt &b) { |
557 | return llvm::APIntOps::mulhu(a, b); |
558 | }); |
559 | assert(highAttr && "Unexpected constant-folding failure"); |
560 | |
561 | results.push_back(lowAttr); |
562 | results.push_back(highAttr); |
563 | return success(); |
564 | } |
565 | |
566 | return failure(); |
567 | } |
568 | |
569 | void arith::MulUIExtendedOp::getCanonicalizationPatterns( |
570 | RewritePatternSet &patterns, MLIRContext *context) { |
571 | patterns.add<MulUIExtendedToMulI>(context); |
572 | } |
573 | |
574 | //===----------------------------------------------------------------------===// |
575 | // DivUIOp |
576 | //===----------------------------------------------------------------------===// |
577 | |
578 | /// Fold `(a * b) / b -> a` |
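/// The fold is only valid when the multiplication cannot wrap, so callers pass
/// the overflow flag the `muli` must carry (nuw for divui, nsw for divsi).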
579 | static Value foldDivMul(Value lhs, Value rhs, |
580 | arith::IntegerOverflowFlags ovfFlags) { |
581 | auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>(); |
582 | if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags)) |
583 | return {}; |
584 | |
585 | if (mul.getLhs() == rhs) |
586 | return mul.getRhs(); |
587 | |
588 | if (mul.getRhs() == rhs) |
589 | return mul.getLhs(); |
590 | |
591 | return {}; |
592 | } |
593 | |
594 | OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) { |
595 | // divui (x, 1) -> x. |
596 | if (matchPattern(adaptor.getRhs(), m_One())) |
597 | return getLhs(); |
598 | |
599 | // (a * b) / b -> a |
600 | if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw)) |
601 | return val; |
602 | |
603 | // Don't fold if it would require a division by zero. |
604 | bool div0 = false; |
605 | auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
606 | [&](APInt a, const APInt &b) { |
607 | if (div0 || !b) { |
608 | div0 = true; |
609 | return a; |
610 | } |
611 | return a.udiv(b); |
612 | }); |
613 | |
614 | return div0 ? Attribute() : result; |
615 | } |
616 | |
617 | /// Returns whether an unsigned division by `divisor` is speculatable. |
618 | static Speculation::Speculatability getDivUISpeculatability(Value divisor) { |
619 | // X / 0 => UB |
  if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
621 | return Speculation::Speculatable; |
622 | |
623 | return Speculation::NotSpeculatable; |
624 | } |
625 | |
626 | Speculation::Speculatability arith::DivUIOp::getSpeculatability() { |
627 | return getDivUISpeculatability(getRhs()); |
628 | } |
629 | |
630 | //===----------------------------------------------------------------------===// |
631 | // DivSIOp |
632 | //===----------------------------------------------------------------------===// |
633 | |
634 | OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) { |
635 | // divsi (x, 1) -> x. |
636 | if (matchPattern(adaptor.getRhs(), m_One())) |
637 | return getLhs(); |
638 | |
639 | // (a * b) / b -> a |
640 | if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw)) |
641 | return val; |
642 | |
643 | // Don't fold if it would overflow or if it requires a division by zero. |
644 | bool overflowOrDiv0 = false; |
645 | auto result = constFoldBinaryOp<IntegerAttr>( |
646 | adaptor.getOperands(), [&](APInt a, const APInt &b) { |
647 | if (overflowOrDiv0 || !b) { |
648 | overflowOrDiv0 = true; |
649 | return a; |
650 | } |
651 | return a.sdiv_ov(b, overflowOrDiv0); |
652 | }); |
653 | |
654 | return overflowOrDiv0 ? Attribute() : result; |
655 | } |
656 | |
/// Returns whether a signed division by `divisor` is speculatable. This
/// function conservatively assumes that any signed division by -1 is not
/// speculatable.
660 | static Speculation::Speculatability getDivSISpeculatability(Value divisor) { |
661 | // X / 0 => UB |
662 | // INT_MIN / -1 => UB |
  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
      matchPattern(divisor, m_IntRangeWithoutNegOneS()))
665 | return Speculation::Speculatable; |
666 | |
667 | return Speculation::NotSpeculatable; |
668 | } |
669 | |
670 | Speculation::Speculatability arith::DivSIOp::getSpeculatability() { |
671 | return getDivSISpeculatability(getRhs()); |
672 | } |
673 | |
674 | //===----------------------------------------------------------------------===// |
675 | // Ceil and floor division folding helpers |
676 | //===----------------------------------------------------------------------===// |
677 | |
678 | static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, |
679 | bool &overflow) { |
680 | // Returns (a-1)/b + 1 |
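  // For example, a = 7, b = 2: (7 - 1) / 2 + 1 = 4, which equals ceil(7 / 2).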
  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
  APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
  return val.sadd_ov(one, overflow);
}
685 | |
686 | //===----------------------------------------------------------------------===// |
687 | // CeilDivUIOp |
688 | //===----------------------------------------------------------------------===// |
689 | |
690 | OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) { |
691 | // ceildivui (x, 1) -> x. |
692 | if (matchPattern(adaptor.getRhs(), m_One())) |
693 | return getLhs(); |
694 | |
695 | bool overflowOrDiv0 = false; |
696 | auto result = constFoldBinaryOp<IntegerAttr>( |
697 | adaptor.getOperands(), [&](APInt a, const APInt &b) { |
698 | if (overflowOrDiv0 || !b) { |
699 | overflowOrDiv0 = true; |
700 | return a; |
701 | } |
702 | APInt quotient = a.udiv(b); |
703 | if (!a.urem(b)) |
704 | return quotient; |
705 | APInt one(a.getBitWidth(), 1, true); |
706 | return quotient.uadd_ov(one, overflowOrDiv0); |
707 | }); |
708 | |
709 | return overflowOrDiv0 ? Attribute() : result; |
710 | } |
711 | |
712 | Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() { |
713 | return getDivUISpeculatability(getRhs()); |
714 | } |
715 | |
716 | //===----------------------------------------------------------------------===// |
717 | // CeilDivSIOp |
718 | //===----------------------------------------------------------------------===// |
719 | |
720 | OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) { |
721 | // ceildivsi (x, 1) -> x. |
722 | if (matchPattern(adaptor.getRhs(), m_One())) |
723 | return getLhs(); |
724 | |
725 | // Don't fold if it would overflow or if it requires a division by zero. |
726 | // TODO: This hook won't fold operations where a = MININT, because |
727 | // negating MININT overflows. This can be improved. |
728 | bool overflowOrDiv0 = false; |
729 | auto result = constFoldBinaryOp<IntegerAttr>( |
730 | adaptor.getOperands(), [&](APInt a, const APInt &b) { |
731 | if (overflowOrDiv0 || !b) { |
732 | overflowOrDiv0 = true; |
733 | return a; |
734 | } |
735 | if (!a) |
736 | return a; |
        // After this point we know that neither a nor b is zero.
738 | unsigned bits = a.getBitWidth(); |
739 | APInt zero = APInt::getZero(bits); |
740 | bool aGtZero = a.sgt(zero); |
741 | bool bGtZero = b.sgt(zero); |
742 | if (aGtZero && bGtZero) { |
743 | // Both positive, return ceil(a, b). |
744 | return signedCeilNonnegInputs(a, b, overflowOrDiv0); |
745 | } |
746 | |
747 | // No folding happens if any of the intermediate arithmetic operations |
748 | // overflows. |
749 | bool overflowNegA = false; |
750 | bool overflowNegB = false; |
751 | bool overflowDiv = false; |
752 | bool overflowNegRes = false; |
753 | if (!aGtZero && !bGtZero) { |
754 | // Both negative, return ceil(-a, -b). |
755 | APInt posA = zero.ssub_ov(a, overflowNegA); |
756 | APInt posB = zero.ssub_ov(b, overflowNegB); |
757 | APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv); |
758 | overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv); |
759 | return res; |
760 | } |
761 | if (!aGtZero && bGtZero) { |
762 | // A is negative, b is positive, return - ( -a / b). |
763 | APInt posA = zero.ssub_ov(a, overflowNegA); |
764 | APInt div = posA.sdiv_ov(b, overflowDiv); |
765 | APInt res = zero.ssub_ov(div, overflowNegRes); |
766 | overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes); |
767 | return res; |
768 | } |
769 | // A is positive, b is negative, return - (a / -b). |
770 | APInt posB = zero.ssub_ov(b, overflowNegB); |
771 | APInt div = a.sdiv_ov(posB, overflowDiv); |
772 | APInt res = zero.ssub_ov(div, overflowNegRes); |
773 | |
774 | overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes); |
775 | return res; |
776 | }); |
777 | |
778 | return overflowOrDiv0 ? Attribute() : result; |
779 | } |
780 | |
781 | Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() { |
782 | return getDivSISpeculatability(getRhs()); |
783 | } |
784 | |
785 | //===----------------------------------------------------------------------===// |
786 | // FloorDivSIOp |
787 | //===----------------------------------------------------------------------===// |
788 | |
789 | OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) { |
790 | // floordivsi (x, 1) -> x. |
791 | if (matchPattern(adaptor.getRhs(), m_One())) |
792 | return getLhs(); |
793 | |
794 | // Don't fold if it would overflow or if it requires a division by zero. |
795 | bool overflowOrDiv = false; |
796 | auto result = constFoldBinaryOp<IntegerAttr>( |
797 | adaptor.getOperands(), [&](APInt a, const APInt &b) { |
798 | if (b.isZero()) { |
799 | overflowOrDiv = true; |
800 | return a; |
801 | } |
802 | return a.sfloordiv_ov(b, overflowOrDiv); |
803 | }); |
804 | |
805 | return overflowOrDiv ? Attribute() : result; |
806 | } |
807 | |
808 | //===----------------------------------------------------------------------===// |
809 | // RemUIOp |
810 | //===----------------------------------------------------------------------===// |
811 | |
812 | OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) { |
813 | // remui (x, 1) -> 0. |
814 | if (matchPattern(adaptor.getRhs(), m_One())) |
815 | return Builder(getContext()).getZeroAttr(getType()); |
816 | |
817 | // Don't fold if it would require a division by zero. |
818 | bool div0 = false; |
819 | auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
820 | [&](APInt a, const APInt &b) { |
821 | if (div0 || b.isZero()) { |
822 | div0 = true; |
823 | return a; |
824 | } |
825 | return a.urem(b); |
826 | }); |
827 | |
828 | return div0 ? Attribute() : result; |
829 | } |
830 | |
831 | //===----------------------------------------------------------------------===// |
832 | // RemSIOp |
833 | //===----------------------------------------------------------------------===// |
834 | |
835 | OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) { |
836 | // remsi (x, 1) -> 0. |
837 | if (matchPattern(adaptor.getRhs(), m_One())) |
838 | return Builder(getContext()).getZeroAttr(getType()); |
839 | |
840 | // Don't fold if it would require a division by zero. |
841 | bool div0 = false; |
842 | auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
843 | [&](APInt a, const APInt &b) { |
844 | if (div0 || b.isZero()) { |
845 | div0 = true; |
846 | return a; |
847 | } |
848 | return a.srem(b); |
849 | }); |
850 | |
851 | return div0 ? Attribute() : result; |
852 | } |
853 | |
854 | //===----------------------------------------------------------------------===// |
855 | // AndIOp |
856 | //===----------------------------------------------------------------------===// |
857 | |
858 | /// Fold `and(a, and(a, b))` to `and(a, b)` |
859 | static Value foldAndIofAndI(arith::AndIOp op) { |
860 | for (bool reversePrev : {false, true}) { |
861 | auto prev = (reversePrev ? op.getRhs() : op.getLhs()) |
862 | .getDefiningOp<arith::AndIOp>(); |
863 | if (!prev) |
864 | continue; |
865 | |
866 | Value other = (reversePrev ? op.getLhs() : op.getRhs()); |
867 | if (other != prev.getLhs() && other != prev.getRhs()) |
868 | continue; |
869 | |
870 | return prev.getResult(); |
871 | } |
872 | return {}; |
873 | } |
874 | |
875 | OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) { |
876 | /// and(x, 0) -> 0 |
877 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
878 | return getRhs(); |
879 | /// and(x, allOnes) -> x |
880 | APInt intValue; |
881 | if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) && |
882 | intValue.isAllOnes()) |
883 | return getLhs(); |
884 | /// and(x, not(x)) -> 0 |
885 | if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()), |
886 | m_ConstantInt(&intValue))) && |
887 | intValue.isAllOnes()) |
888 | return Builder(getContext()).getZeroAttr(getType()); |
889 | /// and(not(x), x) -> 0 |
890 | if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()), |
891 | m_ConstantInt(&intValue))) && |
892 | intValue.isAllOnes()) |
893 | return Builder(getContext()).getZeroAttr(getType()); |
894 | |
895 | /// and(a, and(a, b)) -> and(a, b) |
896 | if (Value result = foldAndIofAndI(*this)) |
897 | return result; |
898 | |
899 | return constFoldBinaryOp<IntegerAttr>( |
900 | adaptor.getOperands(), |
901 | [](APInt a, const APInt &b) { return std::move(a) & b; }); |
902 | } |
903 | |
904 | //===----------------------------------------------------------------------===// |
905 | // OrIOp |
906 | //===----------------------------------------------------------------------===// |
907 | |
908 | OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) { |
909 | if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) { |
910 | /// or(x, 0) -> x |
911 | if (rhsVal.isZero()) |
912 | return getLhs(); |
913 | /// or(x, <all ones>) -> <all ones> |
914 | if (rhsVal.isAllOnes()) |
915 | return adaptor.getRhs(); |
916 | } |
917 | |
918 | APInt intValue; |
  /// or(x, xor(x, <all ones>)) -> <all ones>
920 | if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()), |
921 | m_ConstantInt(&intValue))) && |
922 | intValue.isAllOnes()) |
923 | return getRhs().getDefiningOp<XOrIOp>().getRhs(); |
  /// or(xor(x, <all ones>), x) -> <all ones>
925 | if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()), |
926 | m_ConstantInt(&intValue))) && |
927 | intValue.isAllOnes()) |
928 | return getLhs().getDefiningOp<XOrIOp>().getRhs(); |
929 | |
930 | return constFoldBinaryOp<IntegerAttr>( |
931 | adaptor.getOperands(), |
932 | [](APInt a, const APInt &b) { return std::move(a) | b; }); |
933 | } |
934 | |
935 | //===----------------------------------------------------------------------===// |
936 | // XOrIOp |
937 | //===----------------------------------------------------------------------===// |
938 | |
939 | OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) { |
940 | /// xor(x, 0) -> x |
941 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
942 | return getLhs(); |
943 | /// xor(x, x) -> 0 |
944 | if (getLhs() == getRhs()) |
945 | return Builder(getContext()).getZeroAttr(getType()); |
946 | /// xor(xor(x, a), a) -> x |
947 | /// xor(xor(a, x), a) -> x |
948 | if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) { |
949 | if (prev.getRhs() == getRhs()) |
950 | return prev.getLhs(); |
951 | if (prev.getLhs() == getRhs()) |
952 | return prev.getRhs(); |
953 | } |
954 | /// xor(a, xor(x, a)) -> x |
955 | /// xor(a, xor(a, x)) -> x |
956 | if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) { |
957 | if (prev.getRhs() == getLhs()) |
958 | return prev.getLhs(); |
959 | if (prev.getLhs() == getLhs()) |
960 | return prev.getRhs(); |
961 | } |
962 | |
963 | return constFoldBinaryOp<IntegerAttr>( |
964 | adaptor.getOperands(), |
965 | [](APInt a, const APInt &b) { return std::move(a) ^ b; }); |
966 | } |
967 | |
968 | void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
969 | MLIRContext *context) { |
970 | patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context); |
971 | } |
972 | |
973 | //===----------------------------------------------------------------------===// |
974 | // NegFOp |
975 | //===----------------------------------------------------------------------===// |
976 | |
977 | OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) { |
978 | /// negf(negf(x)) -> x |
979 | if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>()) |
980 | return op.getOperand(); |
981 | return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(), |
982 | [](const APFloat &a) { return -a; }); |
983 | } |
984 | |
985 | //===----------------------------------------------------------------------===// |
986 | // AddFOp |
987 | //===----------------------------------------------------------------------===// |
988 | |
989 | OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) { |
990 | // addf(x, -0) -> x |
991 | if (matchPattern(adaptor.getRhs(), m_NegZeroFloat())) |
992 | return getLhs(); |
993 | |
994 | return constFoldBinaryOp<FloatAttr>( |
995 | adaptor.getOperands(), |
996 | [](const APFloat &a, const APFloat &b) { return a + b; }); |
997 | } |
998 | |
999 | //===----------------------------------------------------------------------===// |
1000 | // SubFOp |
1001 | //===----------------------------------------------------------------------===// |
1002 | |
1003 | OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) { |
1004 | // subf(x, +0) -> x |
1005 | if (matchPattern(adaptor.getRhs(), m_PosZeroFloat())) |
1006 | return getLhs(); |
1007 | |
1008 | return constFoldBinaryOp<FloatAttr>( |
1009 | adaptor.getOperands(), |
1010 | [](const APFloat &a, const APFloat &b) { return a - b; }); |
1011 | } |
1012 | |
1013 | //===----------------------------------------------------------------------===// |
1014 | // MaximumFOp |
1015 | //===----------------------------------------------------------------------===// |
1016 | |
1017 | OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) { |
1018 | // maximumf(x,x) -> x |
1019 | if (getLhs() == getRhs()) |
1020 | return getRhs(); |
1021 | |
1022 | // maximumf(x, -inf) -> x |
1023 | if (matchPattern(adaptor.getRhs(), m_NegInfFloat())) |
1024 | return getLhs(); |
1025 | |
1026 | return constFoldBinaryOp<FloatAttr>( |
1027 | adaptor.getOperands(), |
1028 | [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); |
1029 | } |
1030 | |
1031 | //===----------------------------------------------------------------------===// |
1032 | // MaxNumFOp |
1033 | //===----------------------------------------------------------------------===// |
1034 | |
1035 | OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) { |
1036 | // maxnumf(x,x) -> x |
1037 | if (getLhs() == getRhs()) |
1038 | return getRhs(); |
1039 | |
1040 | // maxnumf(x, NaN) -> x |
1041 | if (matchPattern(adaptor.getRhs(), m_NaNFloat())) |
1042 | return getLhs(); |
1043 | |
1044 | return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum); |
1045 | } |
1046 | |
1047 | //===----------------------------------------------------------------------===// |
1048 | // MaxSIOp |
1049 | //===----------------------------------------------------------------------===// |
1050 | |
1051 | OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) { |
1052 | // maxsi(x,x) -> x |
1053 | if (getLhs() == getRhs()) |
1054 | return getRhs(); |
1055 | |
1056 | if (APInt intValue; |
1057 | matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { |
1058 | // maxsi(x,MAX_INT) -> MAX_INT |
1059 | if (intValue.isMaxSignedValue()) |
1060 | return getRhs(); |
1061 | // maxsi(x, MIN_INT) -> x |
1062 | if (intValue.isMinSignedValue()) |
1063 | return getLhs(); |
1064 | } |
1065 | |
1066 | return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
1067 | [](const APInt &a, const APInt &b) { |
1068 | return llvm::APIntOps::smax(a, b); |
1069 | }); |
1070 | } |
1071 | |
1072 | //===----------------------------------------------------------------------===// |
1073 | // MaxUIOp |
1074 | //===----------------------------------------------------------------------===// |
1075 | |
1076 | OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) { |
1077 | // maxui(x,x) -> x |
1078 | if (getLhs() == getRhs()) |
1079 | return getRhs(); |
1080 | |
1081 | if (APInt intValue; |
1082 | matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { |
1083 | // maxui(x,MAX_INT) -> MAX_INT |
1084 | if (intValue.isMaxValue()) |
1085 | return getRhs(); |
1086 | // maxui(x, MIN_INT) -> x |
1087 | if (intValue.isMinValue()) |
1088 | return getLhs(); |
1089 | } |
1090 | |
1091 | return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
1092 | [](const APInt &a, const APInt &b) { |
1093 | return llvm::APIntOps::umax(a, b); |
1094 | }); |
1095 | } |
1096 | |
1097 | //===----------------------------------------------------------------------===// |
1098 | // MinimumFOp |
1099 | //===----------------------------------------------------------------------===// |
1100 | |
1101 | OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) { |
1102 | // minimumf(x,x) -> x |
1103 | if (getLhs() == getRhs()) |
1104 | return getRhs(); |
1105 | |
1106 | // minimumf(x, +inf) -> x |
1107 | if (matchPattern(adaptor.getRhs(), m_PosInfFloat())) |
1108 | return getLhs(); |
1109 | |
1110 | return constFoldBinaryOp<FloatAttr>( |
1111 | adaptor.getOperands(), |
1112 | [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); |
1113 | } |
1114 | |
1115 | //===----------------------------------------------------------------------===// |
1116 | // MinNumFOp |
1117 | //===----------------------------------------------------------------------===// |
1118 | |
1119 | OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) { |
1120 | // minnumf(x,x) -> x |
1121 | if (getLhs() == getRhs()) |
1122 | return getRhs(); |
1123 | |
1124 | // minnumf(x, NaN) -> x |
1125 | if (matchPattern(adaptor.getRhs(), m_NaNFloat())) |
1126 | return getLhs(); |
1127 | |
1128 | return constFoldBinaryOp<FloatAttr>( |
1129 | adaptor.getOperands(), |
1130 | [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); }); |
1131 | } |
1132 | |
1133 | //===----------------------------------------------------------------------===// |
1134 | // MinSIOp |
1135 | //===----------------------------------------------------------------------===// |
1136 | |
1137 | OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) { |
1138 | // minsi(x,x) -> x |
1139 | if (getLhs() == getRhs()) |
1140 | return getRhs(); |
1141 | |
1142 | if (APInt intValue; |
1143 | matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { |
1144 | // minsi(x,MIN_INT) -> MIN_INT |
1145 | if (intValue.isMinSignedValue()) |
1146 | return getRhs(); |
1147 | // minsi(x, MAX_INT) -> x |
1148 | if (intValue.isMaxSignedValue()) |
1149 | return getLhs(); |
1150 | } |
1151 | |
1152 | return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
1153 | [](const APInt &a, const APInt &b) { |
1154 | return llvm::APIntOps::smin(a, b); |
1155 | }); |
1156 | } |
1157 | |
1158 | //===----------------------------------------------------------------------===// |
1159 | // MinUIOp |
1160 | //===----------------------------------------------------------------------===// |
1161 | |
1162 | OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) { |
1163 | // minui(x,x) -> x |
1164 | if (getLhs() == getRhs()) |
1165 | return getRhs(); |
1166 | |
1167 | if (APInt intValue; |
1168 | matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { |
1169 | // minui(x,MIN_INT) -> MIN_INT |
1170 | if (intValue.isMinValue()) |
1171 | return getRhs(); |
1172 | // minui(x, MAX_INT) -> x |
1173 | if (intValue.isMaxValue()) |
1174 | return getLhs(); |
1175 | } |
1176 | |
1177 | return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), |
1178 | [](const APInt &a, const APInt &b) { |
1179 | return llvm::APIntOps::umin(a, b); |
1180 | }); |
1181 | } |
1182 | |
1183 | //===----------------------------------------------------------------------===// |
1184 | // MulFOp |
1185 | //===----------------------------------------------------------------------===// |
1186 | |
1187 | OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) { |
1188 | // mulf(x, 1) -> x |
1189 | if (matchPattern(adaptor.getRhs(), m_OneFloat())) |
1190 | return getLhs(); |
1191 | |
1192 | return constFoldBinaryOp<FloatAttr>( |
1193 | adaptor.getOperands(), |
1194 | [](const APFloat &a, const APFloat &b) { return a * b; }); |
1195 | } |
1196 | |
1197 | void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1198 | MLIRContext *context) { |
1199 | patterns.add<MulFOfNegF>(context); |
1200 | } |
1201 | |
1202 | //===----------------------------------------------------------------------===// |
1203 | // DivFOp |
1204 | //===----------------------------------------------------------------------===// |
1205 | |
1206 | OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) { |
1207 | // divf(x, 1) -> x |
1208 | if (matchPattern(adaptor.getRhs(), m_OneFloat())) |
1209 | return getLhs(); |
1210 | |
1211 | return constFoldBinaryOp<FloatAttr>( |
1212 | adaptor.getOperands(), |
1213 | [](const APFloat &a, const APFloat &b) { return a / b; }); |
1214 | } |
1215 | |
1216 | void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1217 | MLIRContext *context) { |
1218 | patterns.add<DivFOfNegF>(context); |
1219 | } |
1220 | |
1221 | //===----------------------------------------------------------------------===// |
1222 | // RemFOp |
1223 | //===----------------------------------------------------------------------===// |
1224 | |
1225 | OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) { |
1226 | return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), |
1227 | [](const APFloat &a, const APFloat &b) { |
1228 | APFloat result(a); |
1229 | // APFloat::mod() offers the remainder |
1230 | // behavior we want, i.e. the result has |
1231 | // the sign of LHS operand. |
1232 | (void)result.mod(b); |
1233 | return result; |
1234 | }); |
1235 | } |
1236 | |
1237 | //===----------------------------------------------------------------------===// |
1238 | // Utility functions for verifying cast ops |
1239 | //===----------------------------------------------------------------------===// |
1240 | |
1241 | template <typename... Types> |
1242 | using type_list = std::tuple<Types...> *; |
1243 | |
1244 | /// Returns a non-null type only if the provided type is one of the allowed |
1245 | /// types or one of the allowed shaped types of the allowed types. Returns the |
1246 | /// element type if a valid shaped type is provided. |
1247 | template <typename... ShapedTypes, typename... ElementTypes> |
1248 | static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, |
1249 | type_list<ElementTypes...>) { |
1250 | if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type)) |
1251 | return {}; |
1252 | |
1253 | auto underlyingType = getElementTypeOrSelf(type); |
1254 | if (!llvm::isa<ElementTypes...>(underlyingType)) |
1255 | return {}; |
1256 | |
1257 | return underlyingType; |
1258 | } |
1259 | |
1260 | /// Get allowed underlying types for vectors and tensors. |
1261 | template <typename... ElementTypes> |
1262 | static Type getTypeIfLike(Type type) { |
1263 | return getUnderlyingType(type, type_list<VectorType, TensorType>(), |
1264 | type_list<ElementTypes...>()); |
1265 | } |
1266 | |
1267 | /// Get allowed underlying types for vectors, tensors, and memrefs. |
1268 | template <typename... ElementTypes> |
1269 | static Type getTypeIfLikeOrMemRef(Type type) { |
1270 | return getUnderlyingType(type, |
1271 | type_list<VectorType, TensorType, MemRefType>(), |
1272 | type_list<ElementTypes...>()); |
1273 | } |
1274 | |
/// Return false if both types are ranked tensors with mismatching encodings.
1276 | static bool hasSameEncoding(Type typeA, Type typeB) { |
1277 | auto rankedTensorA = dyn_cast<RankedTensorType>(typeA); |
1278 | auto rankedTensorB = dyn_cast<RankedTensorType>(typeB); |
1279 | if (!rankedTensorA || !rankedTensorB) |
1280 | return true; |
1281 | return rankedTensorA.getEncoding() == rankedTensorB.getEncoding(); |
1282 | } |
1283 | |
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  if (!hasSameEncoding(inputs.front(), outputs.front()))
    return false;
  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
}
1291 | |
1292 | //===----------------------------------------------------------------------===// |
1293 | // Verifiers for integer and floating point extension/truncation ops |
1294 | //===----------------------------------------------------------------------===// |
1295 | |
1296 | // Extend ops can only extend to a wider type. |
1297 | template <typename ValType, typename Op> |
1298 | static LogicalResult verifyExtOp(Op op) { |
1299 | Type srcType = getElementTypeOrSelf(op.getIn().getType()); |
1300 | Type dstType = getElementTypeOrSelf(op.getType()); |
1301 | |
1302 | if (llvm::cast<ValType>(srcType).getWidth() >= |
1303 | llvm::cast<ValType>(dstType).getWidth()) |
1304 | return op.emitError("result type ") |
           << dstType << " must be wider than operand type " << srcType;
1306 | |
1307 | return success(); |
1308 | } |
1309 | |
1310 | // Truncate ops can only truncate to a shorter type. |
1311 | template <typename ValType, typename Op> |
1312 | static LogicalResult verifyTruncateOp(Op op) { |
1313 | Type srcType = getElementTypeOrSelf(op.getIn().getType()); |
1314 | Type dstType = getElementTypeOrSelf(op.getType()); |
1315 | |
1316 | if (llvm::cast<ValType>(srcType).getWidth() <= |
1317 | llvm::cast<ValType>(dstType).getWidth()) |
1318 | return op.emitError("result type ") |
           << dstType << " must be shorter than operand type " << srcType;
1320 | |
1321 | return success(); |
1322 | } |
1323 | |
1324 | /// Validate a cast that changes the width of a type. |
1325 | template <template <typename> class WidthComparator, typename... ElementTypes> |
1326 | static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { |
1327 | if (!areValidCastInputsAndOutputs(inputs, outputs)) |
1328 | return false; |
1329 | |
1330 | auto srcType = getTypeIfLike<ElementTypes...>(inputs.front()); |
1331 | auto dstType = getTypeIfLike<ElementTypes...>(outputs.front()); |
1332 | if (!srcType || !dstType) |
1333 | return false; |
1334 | |
1335 | return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(), |
1336 | srcType.getIntOrFloatBitWidth()); |
1337 | } |
1338 | |
1339 | /// Attempts to convert `sourceValue` to an APFloat value with |
1340 | /// `targetSemantics` and `roundingMode`, without any information loss. |
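/// For example, extending 1.0 from f32 to f64 is exact and succeeds, while
/// truncating 0.1 from f64 to f32 is inexact and therefore fails.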
1341 | static FailureOr<APFloat> convertFloatValue( |
1342 | APFloat sourceValue, const llvm::fltSemantics &targetSemantics, |
1343 | llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) { |
1344 | bool losesInfo = false; |
  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1346 | if (losesInfo || status != APFloat::opOK) |
1347 | return failure(); |
1348 | |
1349 | return sourceValue; |
1350 | } |
1351 | |
1352 | //===----------------------------------------------------------------------===// |
1353 | // ExtUIOp |
1354 | //===----------------------------------------------------------------------===// |
1355 | |
1356 | OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) { |
1357 | if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { |
1358 | getInMutable().assign(lhs.getIn()); |
1359 | return getResult(); |
1360 | } |
1361 | |
1362 | Type resType = getElementTypeOrSelf(getType()); |
1363 | unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth(); |
1364 | return constFoldCastOp<IntegerAttr, IntegerAttr>( |
1365 | adaptor.getOperands(), getType(), |
1366 | [bitWidth](const APInt &a, bool &castStatus) { |
1367 | return a.zext(bitWidth); |
1368 | }); |
1369 | } |
1370 | |
1371 | bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1372 | return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); |
1373 | } |
1374 | |
1375 | LogicalResult arith::ExtUIOp::verify() { |
1376 | return verifyExtOp<IntegerType>(*this); |
1377 | } |
1378 | |
1379 | //===----------------------------------------------------------------------===// |
1380 | // ExtSIOp |
1381 | //===----------------------------------------------------------------------===// |
1382 | |
1383 | OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) { |
1384 | if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { |
1385 | getInMutable().assign(lhs.getIn()); |
1386 | return getResult(); |
1387 | } |
1388 | |
1389 | Type resType = getElementTypeOrSelf(getType()); |
1390 | unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth(); |
1391 | return constFoldCastOp<IntegerAttr, IntegerAttr>( |
1392 | adaptor.getOperands(), getType(), |
1393 | [bitWidth](const APInt &a, bool &castStatus) { |
1394 | return a.sext(bitWidth); |
1395 | }); |
1396 | } |
1397 | |
1398 | bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1399 | return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs); |
1400 | } |
1401 | |
1402 | void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1403 | MLIRContext *context) { |
1404 | patterns.add<ExtSIOfExtUI>(context); |
1405 | } |
1406 | |
1407 | LogicalResult arith::ExtSIOp::verify() { |
1408 | return verifyExtOp<IntegerType>(*this); |
1409 | } |
1410 | |
1411 | //===----------------------------------------------------------------------===// |
1412 | // ExtFOp |
1413 | //===----------------------------------------------------------------------===// |
1414 | |
/// Fold extension of float constants when there is no information loss due to
/// the difference in fp semantics.
1417 | OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) { |
1418 | if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) { |
1419 | if (truncFOp.getOperand().getType() == getType()) { |
1420 | arith::FastMathFlags truncFMF = |
1421 | truncFOp.getFastmath().value_or(arith::FastMathFlags::none); |
1422 | bool isTruncContract = |
1423 | bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract); |
1424 | arith::FastMathFlags extFMF = |
1425 | getFastmath().value_or(arith::FastMathFlags::none); |
1426 | bool isExtContract = |
1427 | bitEnumContainsAll(extFMF, arith::FastMathFlags::contract); |
1428 | if (isTruncContract && isExtContract) { |
1429 | return truncFOp.getOperand(); |
1430 | } |
1431 | } |
1432 | } |
1433 | |
1434 | auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType())); |
1435 | const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics(); |
1436 | return constFoldCastOp<FloatAttr, FloatAttr>( |
1437 | adaptor.getOperands(), getType(), |
1438 | [&targetSemantics](const APFloat &a, bool &castStatus) { |
1439 | FailureOr<APFloat> result = convertFloatValue(a, targetSemantics); |
1440 | if (failed(result)) { |
1441 | castStatus = false; |
1442 | return a; |
1443 | } |
1444 | return *result; |
1445 | }); |
1446 | } |
1447 | |
1448 | bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1449 | return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs); |
1450 | } |
1451 | |
1452 | LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); } |
1453 | |
1454 | //===----------------------------------------------------------------------===// |
1455 | // ScalingExtFOp |
1456 | //===----------------------------------------------------------------------===// |
1457 | |
1458 | bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs, |
1459 | TypeRange outputs) { |
1460 | return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs); |
1461 | } |
1462 | |
1463 | LogicalResult arith::ScalingExtFOp::verify() { |
1464 | return verifyExtOp<FloatType>(*this); |
1465 | } |
1466 | |
1467 | //===----------------------------------------------------------------------===// |
1468 | // TruncIOp |
1469 | //===----------------------------------------------------------------------===// |
1470 | |
1471 | OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) { |
1472 | if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) || |
1473 | matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) { |
1474 | Value src = getOperand().getDefiningOp()->getOperand(0); |
1475 | Type srcType = getElementTypeOrSelf(src.getType()); |
1476 | Type dstType = getElementTypeOrSelf(getType()); |
1477 | // trunci(zexti(a)) -> trunci(a) |
1478 | // trunci(sexti(a)) -> trunci(a) |
1479 | if (llvm::cast<IntegerType>(srcType).getWidth() > |
1480 | llvm::cast<IntegerType>(dstType).getWidth()) { |
1481 | setOperand(src); |
1482 | return getResult(); |
1483 | } |
1484 | |
1485 | // trunci(zexti(a)) -> a |
1486 | // trunci(sexti(a)) -> a |
1487 | if (srcType == dstType) |
1488 | return src; |
1489 | } |
1490 | |
1491 | // trunci(trunci(a)) -> trunci(a)
1492 | if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) { |
1493 | setOperand(getOperand().getDefiningOp()->getOperand(0)); |
1494 | return getResult(); |
1495 | } |
1496 | |
1497 | Type resType = getElementTypeOrSelf(getType()); |
1498 | unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth(); |
1499 | return constFoldCastOp<IntegerAttr, IntegerAttr>( |
1500 | adaptor.getOperands(), getType(), |
1501 | [bitWidth](const APInt &a, bool &castStatus) { |
1502 | return a.trunc(bitWidth); |
1503 | }); |
1504 | } |
1505 | |
1506 | bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1507 | return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs); |
1508 | } |
1509 | |
1510 | void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1511 | MLIRContext *context) { |
1512 | patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI, |
1513 | TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>( |
1514 | context); |
1515 | } |
1516 | |
1517 | LogicalResult arith::TruncIOp::verify() { |
1518 | return verifyTruncateOp<IntegerType>(*this); |
1519 | } |
1520 | |
1521 | //===----------------------------------------------------------------------===// |
1522 | // TruncFOp |
1523 | //===----------------------------------------------------------------------===// |
1524 | |
1525 | /// Perform safe const propagation for truncf, i.e., only propagate if the FP
1526 | /// value can be represented without precision loss.
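///
/// For example (illustrative), a constant 2.0 : f64 truncates exactly to
/// 2.0 : f32 and is folded, while 0.1 : f64 is not exactly representable in
/// f32 and is left unfolded.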
1527 | OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) { |
1528 | auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType())); |
1529 | if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) { |
1530 | Value src = extOp.getIn(); |
1531 | auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType())); |
1532 | auto intermediateType = |
1533 | cast<FloatType>(getElementTypeOrSelf(extOp.getType())); |
1534 | // Check if the srcType is representable in the intermediateType. |
1535 | if (llvm::APFloatBase::isRepresentableBy( |
1536 | srcType.getFloatSemantics(), |
1537 | intermediateType.getFloatSemantics())) { |
1538 | // truncf(extf(a)) -> truncf(a) |
1539 | if (srcType.getWidth() > resElemType.getWidth()) { |
1540 | setOperand(src); |
1541 | return getResult(); |
1542 | } |
1543 | |
1544 | // truncf(extf(a)) -> a |
1545 | if (srcType == resElemType) |
1546 | return src; |
1547 | } |
1548 | } |
1549 | |
1550 | const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics(); |
1551 | return constFoldCastOp<FloatAttr, FloatAttr>( |
1552 | adaptor.getOperands(), getType(), |
1553 | [this, &targetSemantics](const APFloat &a, bool &castStatus) { |
1554 | RoundingMode roundingMode = |
1555 | getRoundingmode().value_or(RoundingMode::to_nearest_even); |
1556 | llvm::RoundingMode llvmRoundingMode = |
1557 | convertArithRoundingModeToLLVMIR(roundingMode); |
1558 | FailureOr<APFloat> result = |
1559 | convertFloatValue(a, targetSemantics, llvmRoundingMode); |
1560 | if (failed(result)) { |
1561 | castStatus = false; |
1562 | return a; |
1563 | } |
1564 | return *result; |
1565 | }); |
1566 | } |
1567 | |
1568 | void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1569 | MLIRContext *context) { |
1570 | patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context); |
1571 | } |
1572 | |
1573 | bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1574 | return checkWidthChangeCast<std::less, FloatType>(inputs, outputs); |
1575 | } |
1576 | |
1577 | LogicalResult arith::TruncFOp::verify() { |
1578 | return verifyTruncateOp<FloatType>(*this); |
1579 | } |
1580 | |
1581 | //===----------------------------------------------------------------------===// |
1582 | // ScalingTruncFOp |
1583 | //===----------------------------------------------------------------------===// |
1584 | |
1585 | bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs, |
1586 | TypeRange outputs) { |
1587 | return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs); |
1588 | } |
1589 | |
1590 | LogicalResult arith::ScalingTruncFOp::verify() { |
1591 | return verifyTruncateOp<FloatType>(*this); |
1592 | } |
1593 | |
1594 | //===----------------------------------------------------------------------===// |
1595 | // AndIOp |
1596 | //===----------------------------------------------------------------------===// |
1597 | |
1598 | void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1599 | MLIRContext *context) { |
1600 | patterns.add<AndOfExtUI, AndOfExtSI>(context); |
1601 | } |
1602 | |
1603 | //===----------------------------------------------------------------------===// |
1604 | // OrIOp |
1605 | //===----------------------------------------------------------------------===// |
1606 | |
1607 | void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1608 | MLIRContext *context) { |
1609 | patterns.add<OrOfExtUI, OrOfExtSI>(context); |
1610 | } |
1611 | |
1612 | //===----------------------------------------------------------------------===// |
1613 | // Verifiers for casts between integers and floats. |
1614 | //===----------------------------------------------------------------------===// |
1615 | |
1616 | template <typename From, typename To> |
1617 | static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { |
1618 | if (!areValidCastInputsAndOutputs(inputs, outputs)) |
1619 | return false; |
1620 | |
1621 | auto srcType = getTypeIfLike<From>(inputs.front()); |
1622 | auto dstType = getTypeIfLike<To>(outputs.back()); |
1623 | |
1624 | return srcType && dstType; |
1625 | } |
1626 | |
1627 | //===----------------------------------------------------------------------===// |
1628 | // UIToFPOp |
1629 | //===----------------------------------------------------------------------===// |
1630 | |
1631 | bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1632 | return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); |
1633 | } |
1634 | |
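// Illustrative constant fold: `arith.uitofp` of a constant 255 : i8 yields
// 255.0 : f32, converting the value as unsigned with round-to-nearest-even.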
1635 | OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) { |
1636 | Type resEleType = getElementTypeOrSelf(getType()); |
1637 | return constFoldCastOp<IntegerAttr, FloatAttr>( |
1638 | adaptor.getOperands(), getType(), |
1639 | [&resEleType](const APInt &a, bool &castStatus) { |
1640 | FloatType floatTy = llvm::cast<FloatType>(resEleType); |
1641 | APFloat apf(floatTy.getFloatSemantics(), |
1642 | APInt::getZero(floatTy.getWidth())); |
1643 | apf.convertFromAPInt(a, /*IsSigned=*/false, |
1644 | APFloat::rmNearestTiesToEven); |
1645 | return apf; |
1646 | }); |
1647 | } |
1648 | |
1649 | //===----------------------------------------------------------------------===// |
1650 | // SIToFPOp |
1651 | //===----------------------------------------------------------------------===// |
1652 | |
1653 | bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1654 | return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs); |
1655 | } |
1656 | |
1657 | OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) { |
1658 | Type resEleType = getElementTypeOrSelf(getType()); |
1659 | return constFoldCastOp<IntegerAttr, FloatAttr>( |
1660 | adaptor.getOperands(), getType(), |
1661 | [&resEleType](const APInt &a, bool &castStatus) { |
1662 | FloatType floatTy = llvm::cast<FloatType>(resEleType); |
1663 | APFloat apf(floatTy.getFloatSemantics(), |
1664 | APInt::getZero(floatTy.getWidth())); |
1665 | apf.convertFromAPInt(a, /*IsSigned=*/true, |
1666 | APFloat::rmNearestTiesToEven); |
1667 | return apf; |
1668 | }); |
1669 | } |
1670 | |
1671 | //===----------------------------------------------------------------------===// |
1672 | // FPToUIOp |
1673 | //===----------------------------------------------------------------------===// |
1674 | |
1675 | bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1676 | return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); |
1677 | } |
1678 | |
1679 | OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) { |
1680 | Type resType = getElementTypeOrSelf(getType()); |
1681 | unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth(); |
1682 | return constFoldCastOp<FloatAttr, IntegerAttr>( |
1683 | adaptor.getOperands(), getType(), |
1684 | [&bitWidth](const APFloat &a, bool &castStatus) { |
1685 | bool ignored; |
1686 | APSInt api(bitWidth, /*isUnsigned=*/true); |
1687 | castStatus = APFloat::opInvalidOp != |
1688 | a.convertToInteger(api, APFloat::rmTowardZero, &ignored); |
1689 | return api; |
1690 | }); |
1691 | } |
1692 | |
1693 | //===----------------------------------------------------------------------===// |
1694 | // FPToSIOp |
1695 | //===----------------------------------------------------------------------===// |
1696 | |
1697 | bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1698 | return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs); |
1699 | } |
1700 | |
1701 | OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) { |
1702 | Type resType = getElementTypeOrSelf(getType()); |
1703 | unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth(); |
1704 | return constFoldCastOp<FloatAttr, IntegerAttr>( |
1705 | adaptor.getOperands(), getType(), |
1706 | [&bitWidth](const APFloat &a, bool &castStatus) { |
1707 | bool ignored; |
1708 | APSInt api(bitWidth, /*isUnsigned=*/false); |
1709 | castStatus = APFloat::opInvalidOp != |
1710 | a.convertToInteger(api, APFloat::rmTowardZero, &ignored); |
1711 | return api; |
1712 | }); |
1713 | } |
1714 | |
1715 | //===----------------------------------------------------------------------===// |
1716 | // IndexCastOp |
1717 | //===----------------------------------------------------------------------===// |
1718 | |
1719 | static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) { |
1720 | if (!areValidCastInputsAndOutputs(inputs, outputs)) |
1721 | return false; |
1722 | |
1723 | auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1724 | auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1725 | if (!srcType || !dstType) |
1726 | return false; |
1727 | |
1728 | return (srcType.isIndex() && dstType.isSignlessInteger()) || |
1729 | (srcType.isSignlessInteger() && dstType.isIndex()); |
1730 | } |
1731 | |
1732 | bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, |
1733 | TypeRange outputs) { |
1734 | return areIndexCastCompatible(inputs, outputs); |
1735 | } |
1736 | |
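// Illustrative fold: `arith.index_cast` of a constant 42 : index to i32 yields
// a constant 42 : i32; the value is sign-extended or truncated to the result
// width (64 bits is assumed when the result is of `index` type).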
1737 | OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) { |
1738 | // index_cast(constant) -> constant |
1739 | unsigned resultBitwidth = 64; // Default for index integer attributes. |
1740 | if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType()))) |
1741 | resultBitwidth = intTy.getWidth(); |
1742 | |
1743 | return constFoldCastOp<IntegerAttr, IntegerAttr>( |
1744 | adaptor.getOperands(), getType(), |
1745 | [resultBitwidth](const APInt &a, bool & /*castStatus*/) { |
1746 | return a.sextOrTrunc(resultBitwidth); |
1747 | }); |
1748 | } |
1749 | |
1750 | void arith::IndexCastOp::getCanonicalizationPatterns( |
1751 | RewritePatternSet &patterns, MLIRContext *context) { |
1752 | patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context); |
1753 | } |
1754 | |
1755 | //===----------------------------------------------------------------------===// |
1756 | // IndexCastUIOp |
1757 | //===----------------------------------------------------------------------===// |
1758 | |
1759 | bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs, |
1760 | TypeRange outputs) { |
1761 | return areIndexCastCompatible(inputs, outputs); |
1762 | } |
1763 | |
1764 | OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) { |
1765 | // index_castui(constant) -> constant |
1766 | unsigned resultBitwidth = 64; // Default for index integer attributes. |
1767 | if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType()))) |
1768 | resultBitwidth = intTy.getWidth(); |
1769 | |
1770 | return constFoldCastOp<IntegerAttr, IntegerAttr>( |
1771 | adaptor.getOperands(), getType(), |
1772 | [resultBitwidth](const APInt &a, bool & /*castStatus*/) { |
1773 | return a.zextOrTrunc(resultBitwidth); |
1774 | }); |
1775 | } |
1776 | |
1777 | void arith::IndexCastUIOp::getCanonicalizationPatterns( |
1778 | RewritePatternSet &patterns, MLIRContext *context) { |
1779 | patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context); |
1780 | } |
1781 | |
1782 | //===----------------------------------------------------------------------===// |
1783 | // BitcastOp |
1784 | //===----------------------------------------------------------------------===// |
1785 | |
1786 | bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { |
1787 | if (!areValidCastInputsAndOutputs(inputs, outputs)) |
1788 | return false; |
1789 | |
1790 | auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front()); |
1791 | auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front()); |
1792 | if (!srcType || !dstType) |
1793 | return false; |
1794 | |
1795 | return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); |
1796 | } |
1797 | |
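// Illustrative fold: bitcasting the constant 1.0 : f32 yields the constant
// 0x3F800000 : i32 (its raw IEEE-754 bits); dense element constants are
// bitcast element-wise.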
1798 | OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) { |
1799 | auto resType = getType(); |
1800 | auto operand = adaptor.getIn(); |
1801 | if (!operand) |
1802 | return {}; |
1803 | |
1804 | /// Bitcast dense elements. |
1805 | if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand)) |
1806 | return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType()); |
1807 | /// Other shaped types unhandled. |
1808 | if (llvm::isa<ShapedType>(resType)) |
1809 | return {}; |
1810 | |
1811 | /// Bitcast poison. |
1812 | if (llvm::isa<ub::PoisonAttr>(operand)) |
1813 | return ub::PoisonAttr::get(getContext()); |
1814 | |
1815 | /// Bitcast integer or float to integer or float. |
1816 | APInt bits = llvm::isa<FloatAttr>(operand) |
1817 | ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt() |
1818 | : llvm::cast<IntegerAttr>(operand).getValue(); |
1819 | assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() && |
1820 | "trying to fold on broken IR: operands have incompatible types"); |
1821 | |
1822 | if (auto resFloatType = llvm::dyn_cast<FloatType>(resType)) |
1823 | return FloatAttr::get(resType, |
1824 | APFloat(resFloatType.getFloatSemantics(), bits)); |
1825 | return IntegerAttr::get(resType, bits); |
1826 | } |
1827 | |
1828 | void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1829 | MLIRContext *context) { |
1830 | patterns.add<BitcastOfBitcast>(context); |
1831 | } |
1832 | |
1833 | //===----------------------------------------------------------------------===// |
1834 | // CmpIOp |
1835 | //===----------------------------------------------------------------------===// |
1836 | |
1837 | /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer |
1838 | /// comparison predicates. |
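/// A small worked example (illustrative): for the 8-bit operands 0xFF and
/// 0x01, `slt` treats 0xFF as -1 and returns true, while `ult` treats it as
/// 255 and returns false.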
1839 | bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, |
1840 | const APInt &lhs, const APInt &rhs) { |
1841 | switch (predicate) { |
1842 | case arith::CmpIPredicate::eq: |
1843 | return lhs.eq(rhs);
1844 | case arith::CmpIPredicate::ne:
1845 | return lhs.ne(rhs);
1846 | case arith::CmpIPredicate::slt:
1847 | return lhs.slt(rhs);
1848 | case arith::CmpIPredicate::sle:
1849 | return lhs.sle(rhs);
1850 | case arith::CmpIPredicate::sgt:
1851 | return lhs.sgt(rhs);
1852 | case arith::CmpIPredicate::sge:
1853 | return lhs.sge(rhs);
1854 | case arith::CmpIPredicate::ult:
1855 | return lhs.ult(rhs);
1856 | case arith::CmpIPredicate::ule:
1857 | return lhs.ule(rhs);
1858 | case arith::CmpIPredicate::ugt:
1859 | return lhs.ugt(rhs);
1860 | case arith::CmpIPredicate::uge:
1861 | return lhs.uge(rhs);
1862 | } |
1863 | llvm_unreachable("unknown cmpi predicate kind"); |
1864 | } |
1865 | |
1866 | /// Returns true if the predicate is true for two equal operands. |
1867 | static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { |
1868 | switch (predicate) { |
1869 | case arith::CmpIPredicate::eq: |
1870 | case arith::CmpIPredicate::sle: |
1871 | case arith::CmpIPredicate::sge: |
1872 | case arith::CmpIPredicate::ule: |
1873 | case arith::CmpIPredicate::uge: |
1874 | return true; |
1875 | case arith::CmpIPredicate::ne: |
1876 | case arith::CmpIPredicate::slt: |
1877 | case arith::CmpIPredicate::sgt: |
1878 | case arith::CmpIPredicate::ult: |
1879 | case arith::CmpIPredicate::ugt: |
1880 | return false; |
1881 | } |
1882 | llvm_unreachable("unknown cmpi predicate kind"); |
1883 | } |
1884 | |
1885 | static std::optional<int64_t> getIntegerWidth(Type t) { |
1886 | if (auto intType = llvm::dyn_cast<IntegerType>(t)) { |
1887 | return intType.getWidth(); |
1888 | } |
1889 | if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) { |
1890 | return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth(); |
1891 | } |
1892 | return std::nullopt; |
1893 | } |
1894 | |
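// Sketch of the folds below (illustrative): `arith.cmpi ne, %x, %x` folds to
// false and `arith.cmpi sge, %x, %x` to true; a comparison with the constant
// on the left, e.g. `arith.cmpi slt, %c5, %x`, is rewritten as
// `arith.cmpi sgt, %x, %c5` so that later folds only have to handle constants
// on the right-hand side.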
1895 | OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) { |
1896 | // cmpi(pred, x, x) |
1897 | if (getLhs() == getRhs()) { |
1898 | auto val = applyCmpPredicateToEqualOperands(getPredicate()); |
1899 | return getBoolAttribute(getType(), val); |
1900 | } |
1901 | |
1902 | if (matchPattern(adaptor.getRhs(), m_Zero())) { |
1903 | if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) { |
1904 | // extsi(%x : i1 -> iN) != 0 -> %x |
1905 | std::optional<int64_t> integerWidth = |
1906 | getIntegerWidth(extOp.getOperand().getType()); |
1907 | if (integerWidth && integerWidth.value() == 1 && |
1908 | getPredicate() == arith::CmpIPredicate::ne) |
1909 | return extOp.getOperand(); |
1910 | } |
1911 | if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) { |
1912 | // extui(%x : i1 -> iN) != 0 -> %x |
1913 | std::optional<int64_t> integerWidth = |
1914 | getIntegerWidth(extOp.getOperand().getType()); |
1915 | if (integerWidth && integerWidth.value() == 1 && |
1916 | getPredicate() == arith::CmpIPredicate::ne) |
1917 | return extOp.getOperand(); |
1918 | } |
1919 | |
1920 | // arith.cmpi ne, %val, %zero : i1 -> %val |
1921 | if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) && |
1922 | getPredicate() == arith::CmpIPredicate::ne) |
1923 | return getLhs(); |
1924 | } |
1925 | |
1926 | if (matchPattern(adaptor.getRhs(), m_One())) { |
1927 | // arith.cmpi eq, %val, %one : i1 -> %val |
1928 | if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) && |
1929 | getPredicate() == arith::CmpIPredicate::eq) |
1930 | return getLhs(); |
1931 | } |
1932 | |
1933 | // Move constant to the right side. |
1934 | if (adaptor.getLhs() && !adaptor.getRhs()) { |
1935 | // Do not use invertPredicate here; swapping the operands needs the mirrored predicate, and negation would also turn eq into ne.
1936 | using Pred = CmpIPredicate; |
1937 | const std::pair<Pred, Pred> invPreds[] = { |
1938 | {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge}, |
1939 | {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult}, |
1940 | {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq}, |
1941 | {Pred::ne, Pred::ne}, |
1942 | }; |
1943 | Pred origPred = getPredicate(); |
1944 | for (auto pred : invPreds) { |
1945 | if (origPred == pred.first) { |
1946 | setPredicate(pred.second); |
1947 | Value lhs = getLhs(); |
1948 | Value rhs = getRhs(); |
1949 | getLhsMutable().assign(rhs); |
1950 | getRhsMutable().assign(lhs); |
1951 | return getResult(); |
1952 | } |
1953 | } |
1954 | llvm_unreachable("unknown cmpi predicate kind"); |
1955 | } |
1956 | |
1957 | // Constants are moved to the right side above, so if lhs is still a constant
1958 | // here, rhs is guaranteed to be a constant as well.
1959 | if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) { |
1960 | return constFoldBinaryOp<IntegerAttr>( |
1961 | adaptor.getOperands(), getI1SameShape(lhs.getType()), |
1962 | [pred = getPredicate()](const APInt &lhs, const APInt &rhs) { |
1963 | return APInt(1, |
1964 | static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs))); |
1965 | }); |
1966 | } |
1967 | |
1968 | return {}; |
1969 | } |
1970 | |
1971 | void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1972 | MLIRContext *context) { |
1973 | patterns.insert<CmpIExtSI, CmpIExtUI>(context); |
1974 | } |
1975 | |
1976 | //===----------------------------------------------------------------------===// |
1977 | // CmpFOp |
1978 | //===----------------------------------------------------------------------===// |
1979 | |
1980 | /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point |
1981 | /// comparison predicates. |
1982 | bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, |
1983 | const APFloat &lhs, const APFloat &rhs) { |
1984 | auto cmpResult = lhs.compare(rhs);
1985 | switch (predicate) { |
1986 | case arith::CmpFPredicate::AlwaysFalse: |
1987 | return false; |
1988 | case arith::CmpFPredicate::OEQ: |
1989 | return cmpResult == APFloat::cmpEqual; |
1990 | case arith::CmpFPredicate::OGT: |
1991 | return cmpResult == APFloat::cmpGreaterThan; |
1992 | case arith::CmpFPredicate::OGE: |
1993 | return cmpResult == APFloat::cmpGreaterThan || |
1994 | cmpResult == APFloat::cmpEqual; |
1995 | case arith::CmpFPredicate::OLT: |
1996 | return cmpResult == APFloat::cmpLessThan; |
1997 | case arith::CmpFPredicate::OLE: |
1998 | return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; |
1999 | case arith::CmpFPredicate::ONE: |
2000 | return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; |
2001 | case arith::CmpFPredicate::ORD: |
2002 | return cmpResult != APFloat::cmpUnordered; |
2003 | case arith::CmpFPredicate::UEQ: |
2004 | return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; |
2005 | case arith::CmpFPredicate::UGT: |
2006 | return cmpResult == APFloat::cmpUnordered || |
2007 | cmpResult == APFloat::cmpGreaterThan; |
2008 | case arith::CmpFPredicate::UGE: |
2009 | return cmpResult == APFloat::cmpUnordered || |
2010 | cmpResult == APFloat::cmpGreaterThan || |
2011 | cmpResult == APFloat::cmpEqual; |
2012 | case arith::CmpFPredicate::ULT: |
2013 | return cmpResult == APFloat::cmpUnordered || |
2014 | cmpResult == APFloat::cmpLessThan; |
2015 | case arith::CmpFPredicate::ULE: |
2016 | return cmpResult == APFloat::cmpUnordered || |
2017 | cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; |
2018 | case arith::CmpFPredicate::UNE: |
2019 | return cmpResult != APFloat::cmpEqual; |
2020 | case arith::CmpFPredicate::UNO: |
2021 | return cmpResult == APFloat::cmpUnordered; |
2022 | case arith::CmpFPredicate::AlwaysTrue: |
2023 | return true; |
2024 | } |
2025 | llvm_unreachable("unknown cmpf predicate kind"); |
2026 | } |
2027 | |
2028 | OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) { |
2029 | auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs()); |
2030 | auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs()); |
2031 | |
2032 | // If one operand is NaN, making them both NaN does not change the result. |
2033 | if (lhs && lhs.getValue().isNaN()) |
2034 | rhs = lhs; |
2035 | if (rhs && rhs.getValue().isNaN()) |
2036 | lhs = rhs; |
2037 | |
2038 | if (!lhs || !rhs) |
2039 | return {}; |
2040 | |
2041 | auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); |
2042 | return BoolAttr::get(getContext(), val); |
2043 | } |
2044 | |
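// Illustrative rewrite performed by the pattern below (hypothetical SSA
// names), with %cst holding 4.4:
//
//   %f = arith.sitofp %x : i8 to f32
//   %r = arith.cmpf olt, %f, %cst : f32
//
// becomes `arith.cmpi sle, %x, %c4 : i8`, since comparing against the
// fractional constant 4.4 is equivalent to comparing the integer against 4.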
2045 | class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> { |
2046 | public: |
2047 | using OpRewritePattern<CmpFOp>::OpRewritePattern; |
2048 | |
2049 | static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, |
2050 | bool isUnsigned) { |
2051 | using namespace arith; |
2052 | switch (pred) { |
2053 | case CmpFPredicate::UEQ: |
2054 | case CmpFPredicate::OEQ: |
2055 | return CmpIPredicate::eq; |
2056 | case CmpFPredicate::UGT: |
2057 | case CmpFPredicate::OGT: |
2058 | return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt; |
2059 | case CmpFPredicate::UGE: |
2060 | case CmpFPredicate::OGE: |
2061 | return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge; |
2062 | case CmpFPredicate::ULT: |
2063 | case CmpFPredicate::OLT: |
2064 | return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt; |
2065 | case CmpFPredicate::ULE: |
2066 | case CmpFPredicate::OLE: |
2067 | return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle; |
2068 | case CmpFPredicate::UNE: |
2069 | case CmpFPredicate::ONE: |
2070 | return CmpIPredicate::ne; |
2071 | default: |
2072 | llvm_unreachable("Unexpected predicate!"); |
2073 | } |
2074 | } |
2075 | |
2076 | LogicalResult matchAndRewrite(CmpFOp op, |
2077 | PatternRewriter &rewriter) const override { |
2078 | FloatAttr flt; |
2079 | if (!matchPattern(op.getRhs(), m_Constant(&flt))) |
2080 | return failure(); |
2081 | |
2082 | const APFloat &rhs = flt.getValue(); |
2083 | |
2084 | // Don't attempt to fold a nan. |
2085 | if (rhs.isNaN()) |
2086 | return failure(); |
2087 | |
2088 | // Get the width of the mantissa. We don't want to hack on conversions that |
2089 | // might lose information from the integer, e.g. "i64 -> float" |
2090 | FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType()); |
2091 | int mantissaWidth = floatTy.getFPMantissaWidth(); |
2092 | if (mantissaWidth <= 0) |
2093 | return failure(); |
2094 | |
2095 | bool isUnsigned; |
2096 | Value intVal; |
2097 | |
2098 | if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) { |
2099 | isUnsigned = false; |
2100 | intVal = si.getIn(); |
2101 | } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) { |
2102 | isUnsigned = true; |
2103 | intVal = ui.getIn(); |
2104 | } else { |
2105 | return failure(); |
2106 | } |
2107 | |
2108 | // Check that the input is converted from an integer type that is small
2109 | // enough to preserve all bits.
2110 | auto intTy = llvm::cast<IntegerType>(intVal.getType()); |
2111 | auto intWidth = intTy.getWidth(); |
2112 | |
2113 | // Number of bits representing values, as opposed to the sign |
2114 | auto valueBits = isUnsigned ? intWidth : (intWidth - 1); |
2115 | |
2116 | // Following test does NOT adjust intWidth downwards for signed inputs, |
2117 | // because the most negative value still requires all the mantissa bits |
2118 | // to distinguish it from one less than that value. |
2119 | if ((int)intWidth > mantissaWidth) { |
2120 | // Conversion would lose accuracy. Check if loss can impact comparison. |
2121 | int exponent = ilogb(rhs);
2122 | if (exponent == APFloat::IEK_Inf) {
2123 | int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2124 | if (maxExponent < (int)valueBits) { |
2125 | // Conversion could create infinity. |
2126 | return failure(); |
2127 | } |
2128 | } else { |
2129 | // Note that if rhs is zero or NaN, then exponent is negative
2130 | // and the first condition is trivially false.
2131 | if (mantissaWidth <= exponent && exponent <= (int)valueBits) { |
2132 | // Conversion could affect comparison. |
2133 | return failure(); |
2134 | } |
2135 | } |
2136 | } |
2137 | |
2138 | // Convert to equivalent cmpi predicate |
2139 | CmpIPredicate pred; |
2140 | switch (op.getPredicate()) { |
2141 | case CmpFPredicate::ORD: |
2142 | // Int to fp conversion doesn't create a nan (ord checks neither is a nan) |
2143 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2144 | /*width=*/1); |
2145 | return success(); |
2146 | case CmpFPredicate::UNO: |
2147 | // Int to fp conversion doesn't create a nan (uno checks either is a nan) |
2148 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2149 | /*width=*/1); |
2150 | return success(); |
2151 | default: |
2152 | pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned); |
2153 | break; |
2154 | } |
2155 | |
2156 | if (!isUnsigned) { |
2157 | // If the rhs value is > SignedMax, fold the comparison. This handles |
2158 | // +INF and large values. |
2159 | APFloat signedMax(rhs.getSemantics()); |
2160 | signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2161 | APFloat::rmNearestTiesToEven);
2162 | if (signedMax < rhs) { // smax < 13123.0 |
2163 | if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt || |
2164 | pred == CmpIPredicate::sle) |
2165 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2166 | /*width=*/1); |
2167 | else |
2168 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2169 | /*width=*/1); |
2170 | return success(); |
2171 | } |
2172 | } else { |
2173 | // If the rhs value is > UnsignedMax, fold the comparison. This handles |
2174 | // +INF and large values. |
2175 | APFloat unsignedMax(rhs.getSemantics()); |
2176 | unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2177 | APFloat::rmNearestTiesToEven);
2178 | if (unsignedMax < rhs) { // umax < 13123.0 |
2179 | if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult || |
2180 | pred == CmpIPredicate::ule) |
2181 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2182 | /*width=*/1); |
2183 | else |
2184 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2185 | /*width=*/1); |
2186 | return success(); |
2187 | } |
2188 | } |
2189 | |
2190 | if (!isUnsigned) { |
2191 | // See if the rhs value is < SignedMin. |
2192 | APFloat signedMin(rhs.getSemantics()); |
2193 | signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2194 | APFloat::rmNearestTiesToEven);
2195 | if (signedMin > rhs) { // smin > 12312.0 |
2196 | if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt || |
2197 | pred == CmpIPredicate::sge) |
2198 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2199 | /*width=*/1); |
2200 | else |
2201 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2202 | /*width=*/1); |
2203 | return success(); |
2204 | } |
2205 | } else { |
2206 | // See if the rhs value is < UnsignedMin. |
2207 | APFloat unsignedMin(rhs.getSemantics()); |
2208 | unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2209 | APFloat::rmNearestTiesToEven);
2210 | if (unsignedMin > rhs) { // umin > 12312.0 |
2211 | if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt || |
2212 | pred == CmpIPredicate::uge) |
2213 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2214 | /*width=*/1); |
2215 | else |
2216 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2217 | /*width=*/1); |
2218 | return success(); |
2219 | } |
2220 | } |
2221 | |
2222 | // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or |
2223 | // [0, UMAX], but it may still be fractional. See if it is fractional by |
2224 | // casting the FP value to the integer value and back, checking for |
2225 | // equality. Don't do this for zero, because -0.0 is not fractional. |
2226 | bool ignored; |
2227 | APSInt rhsInt(intWidth, isUnsigned); |
2228 | if (APFloat::opInvalidOp == |
2229 | rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2230 | // Undefined behavior invoked - the destination type can't represent |
2231 | // the input constant. |
2232 | return failure(); |
2233 | } |
2234 | |
2235 | if (!rhs.isZero()) { |
2236 | APFloat apf(floatTy.getFloatSemantics(), |
2237 | APInt::getZero(floatTy.getWidth()));
2238 | apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2239 | |
2240 | bool equal = apf == rhs; |
2241 | if (!equal) { |
2242 | // If we had a comparison against a fractional value, we have to adjust |
2243 | // the compare predicate and sometimes the value. rhsInt is rounded |
2244 | // towards zero at this point. |
2245 | switch (pred) { |
2246 | case CmpIPredicate::ne: // (float)int != 4.4 --> true |
2247 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2248 | /*width=*/1); |
2249 | return success(); |
2250 | case CmpIPredicate::eq: // (float)int == 4.4 --> false |
2251 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2252 | /*width=*/1); |
2253 | return success(); |
2254 | case CmpIPredicate::ule: |
2255 | // (float)int <= 4.4 --> int <= 4 |
2256 | // (float)int <= -4.4 --> false |
2257 | if (rhs.isNegative()) { |
2258 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2259 | /*width=*/1); |
2260 | return success(); |
2261 | } |
2262 | break; |
2263 | case CmpIPredicate::sle: |
2264 | // (float)int <= 4.4 --> int <= 4 |
2265 | // (float)int <= -4.4 --> int < -4 |
2266 | if (rhs.isNegative()) |
2267 | pred = CmpIPredicate::slt; |
2268 | break; |
2269 | case CmpIPredicate::ult: |
2270 | // (float)int < -4.4 --> false |
2271 | // (float)int < 4.4 --> int <= 4 |
2272 | if (rhs.isNegative()) { |
2273 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false, |
2274 | /*width=*/1); |
2275 | return success(); |
2276 | } |
2277 | pred = CmpIPredicate::ule; |
2278 | break; |
2279 | case CmpIPredicate::slt: |
2280 | // (float)int < -4.4 --> int < -4 |
2281 | // (float)int < 4.4 --> int <= 4 |
2282 | if (!rhs.isNegative()) |
2283 | pred = CmpIPredicate::sle; |
2284 | break; |
2285 | case CmpIPredicate::ugt: |
2286 | // (float)int > 4.4 --> int > 4 |
2287 | // (float)int > -4.4 --> true |
2288 | if (rhs.isNegative()) { |
2289 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2290 | /*width=*/1); |
2291 | return success(); |
2292 | } |
2293 | break; |
2294 | case CmpIPredicate::sgt: |
2295 | // (float)int > 4.4 --> int > 4 |
2296 | // (float)int > -4.4 --> int >= -4 |
2297 | if (rhs.isNegative()) |
2298 | pred = CmpIPredicate::sge; |
2299 | break; |
2300 | case CmpIPredicate::uge: |
2301 | // (float)int >= -4.4 --> true |
2302 | // (float)int >= 4.4 --> int > 4 |
2303 | if (rhs.isNegative()) { |
2304 | rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true, |
2305 | /*width=*/1); |
2306 | return success(); |
2307 | } |
2308 | pred = CmpIPredicate::ugt; |
2309 | break; |
2310 | case CmpIPredicate::sge: |
2311 | // (float)int >= -4.4 --> int >= -4 |
2312 | // (float)int >= 4.4 --> int > 4 |
2313 | if (!rhs.isNegative()) |
2314 | pred = CmpIPredicate::sgt; |
2315 | break; |
2316 | } |
2317 | } |
2318 | } |
2319 | |
2320 | // Lower this FP comparison into an appropriate integer version of the |
2321 | // comparison. |
2322 | rewriter.replaceOpWithNewOp<CmpIOp>( |
2323 | op, pred, intVal, |
2324 | rewriter.create<ConstantOp>( |
2325 | op.getLoc(), intVal.getType(), |
2326 | rewriter.getIntegerAttr(intVal.getType(), rhsInt))); |
2327 | return success(); |
2328 | } |
2329 | }; |
2330 | |
2331 | void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
2332 | MLIRContext *context) { |
2333 | patterns.insert<CmpFIntToFPConst>(context); |
2334 | } |
2335 | |
2336 | //===----------------------------------------------------------------------===// |
2337 | // SelectOp |
2338 | //===----------------------------------------------------------------------===// |
2339 | |
2340 | // select %arg, %c1, %c0 => extui %arg |
2341 | struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> { |
2342 | using OpRewritePattern<arith::SelectOp>::OpRewritePattern; |
2343 | |
2344 | LogicalResult matchAndRewrite(arith::SelectOp op, |
2345 | PatternRewriter &rewriter) const override { |
2346 | // Cannot extui i1 to i1, or i1 to f32 |
2347 | if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1)) |
2348 | return failure(); |
2349 | |
2350 | // select %x, %c1, %c0 => extui %x
2351 | if (matchPattern(op.getTrueValue(), m_One()) && |
2352 | matchPattern(op.getFalseValue(), m_Zero())) { |
2353 | rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(), |
2354 | op.getCondition()); |
2355 | return success(); |
2356 | } |
2357 | |
2358 | // select %x, %c0, %c1 => extui (xor %x, true)
2359 | if (matchPattern(op.getTrueValue(), m_Zero()) && |
2360 | matchPattern(op.getFalseValue(), m_One())) { |
2361 | rewriter.replaceOpWithNewOp<arith::ExtUIOp>( |
2362 | op, op.getType(), |
2363 | rewriter.create<arith::XOrIOp>( |
2364 | op.getLoc(), op.getCondition(), |
2365 | rewriter.create<arith::ConstantIntOp>( |
2366 | op.getLoc(), 1, op.getCondition().getType()))); |
2367 | return success(); |
2368 | } |
2369 | |
2370 | return failure(); |
2371 | } |
2372 | }; |
2373 | |
2374 | void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results, |
2375 | MLIRContext *context) { |
2376 | results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond, |
2377 | SelectI1ToNot, SelectToExtUI>(context); |
2378 | } |
2379 | |
2380 | OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) { |
2381 | Value trueVal = getTrueValue(); |
2382 | Value falseVal = getFalseValue(); |
2383 | if (trueVal == falseVal) |
2384 | return trueVal; |
2385 | |
2386 | Value condition = getCondition(); |
2387 | |
2388 | // select true, %0, %1 => %0 |
2389 | if (matchPattern(adaptor.getCondition(), m_One())) |
2390 | return trueVal; |
2391 | |
2392 | // select false, %0, %1 => %1 |
2393 | if (matchPattern(adaptor.getCondition(), m_Zero())) |
2394 | return falseVal; |
2395 | |
2396 | // If either operand is fully poisoned, return the other. |
2397 | if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue())) |
2398 | return falseVal; |
2399 | |
2400 | if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue())) |
2401 | return trueVal; |
2402 | |
2403 | // select %x, true, false => %x |
2404 | if (getType().isSignlessInteger(1) && |
2405 | matchPattern(adaptor.getTrueValue(), m_One()) && |
2406 | matchPattern(adaptor.getFalseValue(), m_Zero())) |
2407 | return condition; |
2408 | |
2409 | if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) { |
2410 | auto pred = cmp.getPredicate(); |
2411 | if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { |
2412 | auto cmpLhs = cmp.getLhs(); |
2413 | auto cmpRhs = cmp.getRhs(); |
2414 | |
2415 | // %0 = arith.cmpi eq, %arg0, %arg1 |
2416 | // %1 = arith.select %0, %arg0, %arg1 => %arg1 |
2417 | |
2418 | // %0 = arith.cmpi ne, %arg0, %arg1 |
2419 | // %1 = arith.select %0, %arg0, %arg1 => %arg0 |
2420 | |
2421 | if ((cmpLhs == trueVal && cmpRhs == falseVal) || |
2422 | (cmpRhs == trueVal && cmpLhs == falseVal)) |
2423 | return pred == arith::CmpIPredicate::ne ? trueVal : falseVal; |
2424 | } |
2425 | } |
2426 | |
2427 | // Constant-fold constant operands over non-splat constant condition. |
2428 | // select %cst_vec, %cst0, %cst1 => %cst2 |
2429 | if (auto cond = |
2430 | llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) { |
2431 | if (auto lhs = |
2432 | llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) { |
2433 | if (auto rhs = |
2434 | llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) { |
2435 | SmallVector<Attribute> results; |
2436 | results.reserve(static_cast<size_t>(cond.getNumElements())); |
2437 | auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(), |
2438 | cond.value_end<BoolAttr>()); |
2439 | auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(), |
2440 | lhs.value_end<Attribute>()); |
2441 | auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(), |
2442 | rhs.value_end<Attribute>()); |
2443 | |
2444 | for (auto [condVal, lhsVal, rhsVal] : |
2445 | llvm::zip_equal(condVals, lhsVals, rhsVals)) |
2446 | results.push_back(condVal.getValue() ? lhsVal : rhsVal); |
2447 | |
2448 | return DenseElementsAttr::get(lhs.getType(), results); |
2449 | } |
2450 | } |
2451 | } |
2452 | |
2453 | return nullptr; |
2454 | } |
2455 | |
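// The custom assembly handled below looks like (illustrative):
//
//   %r = arith.select %cond, %a, %b : i32
//   %v = arith.select %mask, %va, %vb : vector<4xi1>, vector<4xi32>
//
// where the optional leading type is the condition (mask) type for shaped
// operands; otherwise an i1 condition is assumed.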
2456 | ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) { |
2457 | Type conditionType, resultType; |
2458 | SmallVector<OpAsmParser::UnresolvedOperand, 3> operands; |
2459 | if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || |
2460 | parser.parseOptionalAttrDict(result.attributes) || |
2461 | parser.parseColonType(resultType)) |
2462 | return failure(); |
2463 | |
2464 | // Check for the explicit condition type if this is a masked tensor or vector. |
2465 | if (succeeded(parser.parseOptionalComma())) { |
2466 | conditionType = resultType; |
2467 | if (parser.parseType(resultType)) |
2468 | return failure(); |
2469 | } else { |
2470 | conditionType = parser.getBuilder().getI1Type(); |
2471 | } |
2472 | |
2473 | result.addTypes(resultType); |
2474 | return parser.resolveOperands(operands, |
2475 | {conditionType, resultType, resultType}, |
2476 | parser.getNameLoc(), result.operands); |
2477 | } |
2478 | |
2479 | void arith::SelectOp::print(OpAsmPrinter &p) { |
2480 | p << " "<< getOperands(); |
2481 | p.printOptionalAttrDict((*this)->getAttrs()); |
2482 | p << " : "; |
2483 | if (ShapedType condType = |
2484 | llvm::dyn_cast<ShapedType>(getCondition().getType())) |
2485 | p << condType << ", "; |
2486 | p << getType(); |
2487 | } |
2488 | |
2489 | LogicalResult arith::SelectOp::verify() { |
2490 | Type conditionType = getCondition().getType(); |
2491 | if (conditionType.isSignlessInteger(1)) |
2492 | return success(); |
2493 | |
2494 | // If the result type is a vector or tensor, the type can be a mask with the |
2495 | // same elements. |
2496 | Type resultType = getType(); |
2497 | if (!llvm::isa<TensorType, VectorType>(resultType)) |
2498 | return emitOpError() << "expected condition to be a signless i1, but got " |
2499 | << conditionType; |
2500 | Type shapedConditionType = getI1SameShape(resultType); |
2501 | if (conditionType != shapedConditionType) { |
2502 | return emitOpError() << "expected condition type to have the same shape " |
2503 | "as the result type, expected " |
2504 | << shapedConditionType << ", but got " |
2505 | << conditionType; |
2506 | } |
2507 | return success(); |
2508 | } |
2509 | //===----------------------------------------------------------------------===// |
2510 | // ShLIOp |
2511 | //===----------------------------------------------------------------------===// |
2512 | |
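// Illustrative folds: `arith.shli %x, %c0` folds to %x, and shifting the
// constant 1 : i8 left by 3 folds to 8 : i8; a shift amount of 8 or more on an
// i8 value is not folded because the shift amount is out of range.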
2513 | OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) { |
2514 | // shli(x, 0) -> x |
2515 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
2516 | return getLhs(); |
2517 | // Don't fold if the shift amount is greater than or equal to the bit width.
2518 | bool bounded = false; |
2519 | auto result = constFoldBinaryOp<IntegerAttr>( |
2520 | adaptor.getOperands(), [&](const APInt &a, const APInt &b) { |
2521 | bounded = b.ult(b.getBitWidth()); |
2522 | return a.shl(b); |
2523 | }); |
2524 | return bounded ? result : Attribute(); |
2525 | } |
2526 | |
2527 | //===----------------------------------------------------------------------===// |
2528 | // ShRUIOp |
2529 | //===----------------------------------------------------------------------===// |
2530 | |
2531 | OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) { |
2532 | // shrui(x, 0) -> x |
2533 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
2534 | return getLhs(); |
2535 | // Don't fold if the shift amount is greater than or equal to the bit width.
2536 | bool bounded = false; |
2537 | auto result = constFoldBinaryOp<IntegerAttr>( |
2538 | adaptor.getOperands(), [&](const APInt &a, const APInt &b) { |
2539 | bounded = b.ult(b.getBitWidth()); |
2540 | return a.lshr(b); |
2541 | }); |
2542 | return bounded ? result : Attribute(); |
2543 | } |
2544 | |
2545 | //===----------------------------------------------------------------------===// |
2546 | // ShRSIOp |
2547 | //===----------------------------------------------------------------------===// |
2548 | |
2549 | OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) { |
2550 | // shrsi(x, 0) -> x |
2551 | if (matchPattern(adaptor.getRhs(), m_Zero())) |
2552 | return getLhs(); |
2553 | // Don't fold if the shift amount is greater than or equal to the bit width.
2554 | bool bounded = false; |
2555 | auto result = constFoldBinaryOp<IntegerAttr>( |
2556 | adaptor.getOperands(), [&](const APInt &a, const APInt &b) { |
2557 | bounded = b.ult(b.getBitWidth()); |
2558 | return a.ashr(b); |
2559 | }); |
2560 | return bounded ? result : Attribute(); |
2561 | } |
2562 | |
2563 | //===----------------------------------------------------------------------===// |
2564 | // Atomic Enum |
2565 | //===----------------------------------------------------------------------===// |
2566 | |
2567 | /// Returns the identity value attribute associated with an AtomicRMWKind op. |
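/// For instance (illustrative), addf/addi get 0, mulf/muli get 1, andi gets
/// the all-ones value, maxs gets the smallest signed value, and maximumf gets
/// -infinity (or the most negative finite value when only finite values may
/// occur).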
2568 | TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, |
2569 | OpBuilder &builder, Location loc, |
2570 | bool useOnlyFiniteValue) { |
2571 | switch (kind) { |
2572 | case AtomicRMWKind::maximumf: { |
2573 | const llvm::fltSemantics &semantic = |
2574 | llvm::cast<FloatType>(resultType).getFloatSemantics(); |
2575 | APFloat identity = useOnlyFiniteValue |
2576 | ? APFloat::getLargest(semantic, /*Negative=*/true)
2577 | : APFloat::getInf(semantic, /*Negative=*/true);
2578 | return builder.getFloatAttr(resultType, identity); |
2579 | } |
2580 | case AtomicRMWKind::maxnumf: { |
2581 | const llvm::fltSemantics &semantic = |
2582 | llvm::cast<FloatType>(resultType).getFloatSemantics(); |
2583 | APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2584 | return builder.getFloatAttr(resultType, identity); |
2585 | } |
2586 | case AtomicRMWKind::addf: |
2587 | case AtomicRMWKind::addi: |
2588 | case AtomicRMWKind::maxu: |
2589 | case AtomicRMWKind::ori: |
2590 | return builder.getZeroAttr(resultType); |
2591 | case AtomicRMWKind::andi: |
2592 | return builder.getIntegerAttr( |
2593 | resultType, |
2594 | APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2595 | case AtomicRMWKind::maxs: |
2596 | return builder.getIntegerAttr( |
2597 | resultType, APInt::getSignedMinValue( |
2598 | llvm::cast<IntegerType>(resultType).getWidth()));
2599 | case AtomicRMWKind::minimumf: { |
2600 | const llvm::fltSemantics &semantic = |
2601 | llvm::cast<FloatType>(resultType).getFloatSemantics(); |
2602 | APFloat identity = useOnlyFiniteValue |
2603 | ? APFloat::getLargest(semantic, /*Negative=*/false)
2604 | : APFloat::getInf(semantic, /*Negative=*/false);
2605 | |
2606 | return builder.getFloatAttr(resultType, identity); |
2607 | } |
2608 | case AtomicRMWKind::minnumf: { |
2609 | const llvm::fltSemantics &semantic = |
2610 | llvm::cast<FloatType>(resultType).getFloatSemantics(); |
2611 | APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2612 | return builder.getFloatAttr(resultType, identity); |
2613 | } |
2614 | case AtomicRMWKind::mins: |
2615 | return builder.getIntegerAttr( |
2616 | resultType, APInt::getSignedMaxValue( |
2617 | llvm::cast<IntegerType>(resultType).getWidth()));
2618 | case AtomicRMWKind::minu: |
2619 | return builder.getIntegerAttr( |
2620 | resultType, |
2621 | APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2622 | case AtomicRMWKind::muli: |
2623 | return builder.getIntegerAttr(resultType, 1); |
2624 | case AtomicRMWKind::mulf: |
2625 | return builder.getFloatAttr(resultType, 1); |
2626 | // TODO: Add remaining reduction operations. |
2627 | default: |
2628 | (void)emitOptionalError(loc, "Reduction operation type not supported");
2629 | break; |
2630 | } |
2631 | return nullptr; |
2632 | } |
2633 | |
2634 | /// Return the identity numeric value associated with the given op.
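/// For example (illustrative), an `arith.maximumf` carrying the `ninf`
/// fast-math flag gets the most negative finite value as its neutral element
/// rather than -infinity.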
2635 | std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) { |
2636 | std::optional<AtomicRMWKind> maybeKind = |
2637 | llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op) |
2638 | // Floating-point operations. |
2639 | .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; }) |
2640 | .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; }) |
2641 | .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; }) |
2642 | .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; }) |
2643 | .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; }) |
2644 | .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; }) |
2645 | // Integer operations. |
2646 | .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; }) |
2647 | .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; }) |
2648 | .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; }) |
2649 | .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; }) |
2650 | .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; }) |
2651 | .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; }) |
2652 | .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; }) |
2653 | .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; }) |
2654 | .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; }) |
2655 | .Default([](Operation *op) { return std::nullopt; }); |
2656 | if (!maybeKind) { |
2657 | return std::nullopt; |
2658 | } |
2659 | |
2660 | bool useOnlyFiniteValue = false; |
2661 | auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op); |
2662 | if (fmfOpInterface) { |
2663 | arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr(); |
2664 | useOnlyFiniteValue = |
2665 | bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf); |
2666 | } |
2667 | |
2668 | // Builder only used as helper for attribute creation. |
2669 | OpBuilder b(op->getContext()); |
2670 | Type resultType = op->getResult(0).getType();
2671 | |
2672 | return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(), |
2673 | useOnlyFiniteValue); |
2674 | } |
2675 | |
2676 | /// Returns the identity value associated with an AtomicRMWKind op. |
2677 | Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, |
2678 | OpBuilder &builder, Location loc, |
2679 | bool useOnlyFiniteValue) { |
2680 | auto attr = |
2681 | getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue); |
2682 | return builder.create<arith::ConstantOp>(loc, attr); |
2683 | } |
2684 | |
2685 | /// Return the value obtained by applying the reduction operation kind |
2686 | /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. |
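///
/// Usage sketch (assuming a `builder`, a `loc`, and values `acc` and `x` are
/// in scope):
///
///   Value sum = arith::getReductionOp(AtomicRMWKind::addf, builder, loc,
///                                     acc, x);
///
/// creates an `arith.addf %acc, %x` at `loc`.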
2687 | Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, |
2688 | Location loc, Value lhs, Value rhs) { |
2689 | switch (op) { |
2690 | case AtomicRMWKind::addf: |
2691 | return builder.create<arith::AddFOp>(loc, lhs, rhs); |
2692 | case AtomicRMWKind::addi: |
2693 | return builder.create<arith::AddIOp>(loc, lhs, rhs); |
2694 | case AtomicRMWKind::mulf: |
2695 | return builder.create<arith::MulFOp>(loc, lhs, rhs); |
2696 | case AtomicRMWKind::muli: |
2697 | return builder.create<arith::MulIOp>(loc, lhs, rhs); |
2698 | case AtomicRMWKind::maximumf: |
2699 | return builder.create<arith::MaximumFOp>(loc, lhs, rhs); |
2700 | case AtomicRMWKind::minimumf: |
2701 | return builder.create<arith::MinimumFOp>(loc, lhs, rhs); |
2702 | case AtomicRMWKind::maxnumf: |
2703 | return builder.create<arith::MaxNumFOp>(loc, lhs, rhs); |
2704 | case AtomicRMWKind::minnumf: |
2705 | return builder.create<arith::MinNumFOp>(loc, lhs, rhs); |
2706 | case AtomicRMWKind::maxs: |
2707 | return builder.create<arith::MaxSIOp>(loc, lhs, rhs); |
2708 | case AtomicRMWKind::mins: |
2709 | return builder.create<arith::MinSIOp>(loc, lhs, rhs); |
2710 | case AtomicRMWKind::maxu: |
2711 | return builder.create<arith::MaxUIOp>(loc, lhs, rhs); |
2712 | case AtomicRMWKind::minu: |
2713 | return builder.create<arith::MinUIOp>(loc, lhs, rhs); |
2714 | case AtomicRMWKind::ori: |
2715 | return builder.create<arith::OrIOp>(loc, lhs, rhs); |
2716 | case AtomicRMWKind::andi: |
2717 | return builder.create<arith::AndIOp>(loc, lhs, rhs); |
2718 | // TODO: Add remaining reduction operations. |
2719 | default: |
2720 | (void)emitOptionalError(loc, "Reduction operation type not supported");
2721 | break; |
2722 | } |
2723 | return nullptr; |
2724 | } |
2725 | |
2726 | //===----------------------------------------------------------------------===// |
2727 | // TableGen'd op method definitions |
2728 | //===----------------------------------------------------------------------===// |
2729 | |
2730 | #define GET_OP_CLASSES |
2731 | #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc" |
2732 | |
2733 | //===----------------------------------------------------------------------===// |
2734 | // TableGen'd enum attribute definitions |
2735 | //===----------------------------------------------------------------------===// |
2736 | |
2737 | #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc" |
2738 |