1 | //===- MathOps.cpp - MLIR operations for math implementation --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Arith/IR/Arith.h" |
10 | #include "mlir/Dialect/CommonFolders.h" |
11 | #include "mlir/Dialect/Math/IR/Math.h" |
12 | #include "mlir/Dialect/UB/IR/UBOps.h" |
13 | #include "mlir/IR/Builders.h" |
14 | #include <optional> |
15 | |
16 | using namespace mlir; |
17 | using namespace mlir::math; |
18 | |
19 | //===----------------------------------------------------------------------===// |
20 | // Common helpers |
21 | //===----------------------------------------------------------------------===// |
22 | |
23 | /// Return the type of the same shape (scalar, vector or tensor) containing i1. |
24 | static Type getI1SameShape(Type type) { |
25 | auto i1Type = IntegerType::get(type.getContext(), 1); |
26 | if (auto shapedType = llvm::dyn_cast<ShapedType>(type)) |
27 | return shapedType.cloneWith(std::nullopt, i1Type); |
28 | if (llvm::isa<UnrankedTensorType>(type)) |
29 | return UnrankedTensorType::get(i1Type); |
30 | return i1Type; |
31 | } |
32 | |
33 | //===----------------------------------------------------------------------===// |
34 | // TableGen'd op method definitions |
35 | //===----------------------------------------------------------------------===// |
36 | |
37 | #define GET_OP_CLASSES |
38 | #include "mlir/Dialect/Math/IR/MathOps.cpp.inc" |
39 | |
40 | //===----------------------------------------------------------------------===// |
41 | // AbsFOp folder |
42 | //===----------------------------------------------------------------------===// |
43 | |
44 | OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) { |
45 | return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(), |
46 | [](const APFloat &a) { return abs(a); }); |
47 | } |
48 | |
49 | //===----------------------------------------------------------------------===// |
50 | // AbsIOp folder |
51 | //===----------------------------------------------------------------------===// |
52 | |
53 | OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) { |
54 | return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), |
55 | [](const APInt &a) { return a.abs(); }); |
56 | } |
57 | |
58 | //===----------------------------------------------------------------------===// |
59 | // AcosOp folder |
60 | //===----------------------------------------------------------------------===// |
61 | |
62 | OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) { |
63 | return constFoldUnaryOpConditional<FloatAttr>( |
64 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
65 | switch (a.getSizeInBits(a.getSemantics())) { |
66 | case 64: |
67 | return APFloat(acos(a.convertToDouble())); |
68 | case 32: |
69 | return APFloat(acosf(a.convertToFloat())); |
70 | default: |
71 | return {}; |
72 | } |
73 | }); |
74 | } |
75 | |
76 | //===----------------------------------------------------------------------===// |
77 | // AcoshOp folder |
78 | //===----------------------------------------------------------------------===// |
79 | |
80 | OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) { |
81 | return constFoldUnaryOpConditional<FloatAttr>( |
82 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
83 | switch (a.getSizeInBits(a.getSemantics())) { |
84 | case 64: |
85 | return APFloat(acosh(a.convertToDouble())); |
86 | case 32: |
87 | return APFloat(acoshf(a.convertToFloat())); |
88 | default: |
89 | return {}; |
90 | } |
91 | }); |
92 | } |
93 | |
94 | //===----------------------------------------------------------------------===// |
95 | // AsinOp folder |
96 | //===----------------------------------------------------------------------===// |
97 | |
98 | OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) { |
99 | return constFoldUnaryOpConditional<FloatAttr>( |
100 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
101 | switch (a.getSizeInBits(a.getSemantics())) { |
102 | case 64: |
103 | return APFloat(asin(a.convertToDouble())); |
104 | case 32: |
105 | return APFloat(asinf(a.convertToFloat())); |
106 | default: |
107 | return {}; |
108 | } |
109 | }); |
110 | } |
111 | |
112 | //===----------------------------------------------------------------------===// |
113 | // AsinhOp folder |
114 | //===----------------------------------------------------------------------===// |
115 | |
116 | OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) { |
117 | return constFoldUnaryOpConditional<FloatAttr>( |
118 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
119 | switch (a.getSizeInBits(a.getSemantics())) { |
120 | case 64: |
121 | return APFloat(asinh(a.convertToDouble())); |
122 | case 32: |
123 | return APFloat(asinhf(a.convertToFloat())); |
124 | default: |
125 | return {}; |
126 | } |
127 | }); |
128 | } |
129 | |
130 | //===----------------------------------------------------------------------===// |
131 | // AtanOp folder |
132 | //===----------------------------------------------------------------------===// |
133 | |
134 | OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) { |
135 | return constFoldUnaryOpConditional<FloatAttr>( |
136 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
137 | switch (a.getSizeInBits(a.getSemantics())) { |
138 | case 64: |
139 | return APFloat(atan(a.convertToDouble())); |
140 | case 32: |
141 | return APFloat(atanf(a.convertToFloat())); |
142 | default: |
143 | return {}; |
144 | } |
145 | }); |
146 | } |
147 | |
148 | //===----------------------------------------------------------------------===// |
149 | // AtanhOp folder |
150 | //===----------------------------------------------------------------------===// |
151 | |
152 | OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) { |
153 | return constFoldUnaryOpConditional<FloatAttr>( |
154 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
155 | switch (a.getSizeInBits(a.getSemantics())) { |
156 | case 64: |
157 | return APFloat(atanh(a.convertToDouble())); |
158 | case 32: |
159 | return APFloat(atanhf(a.convertToFloat())); |
160 | default: |
161 | return {}; |
162 | } |
163 | }); |
164 | } |
165 | |
166 | //===----------------------------------------------------------------------===// |
167 | // Atan2Op folder |
168 | //===----------------------------------------------------------------------===// |
169 | |
170 | OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) { |
171 | return constFoldBinaryOpConditional<FloatAttr>( |
172 | adaptor.getOperands(), |
173 | [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> { |
174 | if (a.isZero() && b.isZero()) |
175 | return llvm::APFloat::getNaN(a.getSemantics()); |
176 | |
177 | if (a.getSizeInBits(a.getSemantics()) == 64 && |
178 | b.getSizeInBits(b.getSemantics()) == 64) |
179 | return APFloat(atan2(a.convertToDouble(), b.convertToDouble())); |
180 | |
181 | if (a.getSizeInBits(a.getSemantics()) == 32 && |
182 | b.getSizeInBits(b.getSemantics()) == 32) |
183 | return APFloat(atan2f(a.convertToFloat(), b.convertToFloat())); |
184 | |
185 | return {}; |
186 | }); |
187 | } |
188 | |
189 | //===----------------------------------------------------------------------===// |
190 | // CeilOp folder |
191 | //===----------------------------------------------------------------------===// |
192 | |
193 | OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) { |
194 | return constFoldUnaryOp<FloatAttr>( |
195 | adaptor.getOperands(), [](const APFloat &a) { |
196 | APFloat result(a); |
197 | result.roundToIntegral(llvm::RoundingMode::TowardPositive); |
198 | return result; |
199 | }); |
200 | } |
201 | |
202 | //===----------------------------------------------------------------------===// |
203 | // CopySignOp folder |
204 | //===----------------------------------------------------------------------===// |
205 | |
206 | OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) { |
207 | return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), |
208 | [](const APFloat &a, const APFloat &b) { |
209 | APFloat result(a); |
210 | result.copySign(b); |
211 | return result; |
212 | }); |
213 | } |
214 | |
215 | //===----------------------------------------------------------------------===// |
216 | // CosOp folder |
217 | //===----------------------------------------------------------------------===// |
218 | |
219 | OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) { |
220 | return constFoldUnaryOpConditional<FloatAttr>( |
221 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
222 | switch (a.getSizeInBits(a.getSemantics())) { |
223 | case 64: |
224 | return APFloat(cos(a.convertToDouble())); |
225 | case 32: |
226 | return APFloat(cosf(a.convertToFloat())); |
227 | default: |
228 | return {}; |
229 | } |
230 | }); |
231 | } |
232 | |
233 | //===----------------------------------------------------------------------===// |
234 | // CoshOp folder |
235 | //===----------------------------------------------------------------------===// |
236 | |
237 | OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) { |
238 | return constFoldUnaryOpConditional<FloatAttr>( |
239 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
240 | switch (a.getSizeInBits(a.getSemantics())) { |
241 | case 64: |
242 | return APFloat(cosh(a.convertToDouble())); |
243 | case 32: |
244 | return APFloat(coshf(a.convertToFloat())); |
245 | default: |
246 | return {}; |
247 | } |
248 | }); |
249 | } |
250 | |
251 | //===----------------------------------------------------------------------===// |
252 | // SinOp folder |
253 | //===----------------------------------------------------------------------===// |
254 | |
255 | OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) { |
256 | return constFoldUnaryOpConditional<FloatAttr>( |
257 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
258 | switch (a.getSizeInBits(a.getSemantics())) { |
259 | case 64: |
260 | return APFloat(sin(a.convertToDouble())); |
261 | case 32: |
262 | return APFloat(sinf(a.convertToFloat())); |
263 | default: |
264 | return {}; |
265 | } |
266 | }); |
267 | } |
268 | |
269 | //===----------------------------------------------------------------------===// |
270 | // SinhOp folder |
271 | //===----------------------------------------------------------------------===// |
272 | |
273 | OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) { |
274 | return constFoldUnaryOpConditional<FloatAttr>( |
275 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
276 | switch (a.getSizeInBits(a.getSemantics())) { |
277 | case 64: |
278 | return APFloat(sinh(a.convertToDouble())); |
279 | case 32: |
280 | return APFloat(sinhf(a.convertToFloat())); |
281 | default: |
282 | return {}; |
283 | } |
284 | }); |
285 | } |
286 | |
287 | //===----------------------------------------------------------------------===// |
288 | // CountLeadingZerosOp folder |
289 | //===----------------------------------------------------------------------===// |
290 | |
291 | OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) { |
292 | return constFoldUnaryOp<IntegerAttr>( |
293 | adaptor.getOperands(), |
294 | [](const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); }); |
295 | } |
296 | |
297 | //===----------------------------------------------------------------------===// |
298 | // CountTrailingZerosOp folder |
299 | //===----------------------------------------------------------------------===// |
300 | |
301 | OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) { |
302 | return constFoldUnaryOp<IntegerAttr>( |
303 | adaptor.getOperands(), |
304 | [](const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); }); |
305 | } |
306 | |
307 | //===----------------------------------------------------------------------===// |
308 | // CtPopOp folder |
309 | //===----------------------------------------------------------------------===// |
310 | |
311 | OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) { |
312 | return constFoldUnaryOp<IntegerAttr>( |
313 | adaptor.getOperands(), |
314 | [](const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); }); |
315 | } |
316 | |
317 | //===----------------------------------------------------------------------===// |
318 | // ErfOp folder |
319 | //===----------------------------------------------------------------------===// |
320 | |
321 | OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) { |
322 | return constFoldUnaryOpConditional<FloatAttr>( |
323 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
324 | switch (a.getSizeInBits(a.getSemantics())) { |
325 | case 64: |
326 | return APFloat(erf(a.convertToDouble())); |
327 | case 32: |
328 | return APFloat(erff(a.convertToFloat())); |
329 | default: |
330 | return {}; |
331 | } |
332 | }); |
333 | } |
334 | |
335 | //===----------------------------------------------------------------------===// |
336 | // ErfcOp folder |
337 | //===----------------------------------------------------------------------===// |
338 | |
339 | OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) { |
340 | return constFoldUnaryOpConditional<FloatAttr>( |
341 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
342 | switch (APFloat::SemanticsToEnum(a.getSemantics())) { |
343 | case APFloat::Semantics::S_IEEEdouble: |
344 | return APFloat(erfc(a.convertToDouble())); |
345 | case APFloat::Semantics::S_IEEEsingle: |
346 | return APFloat(erfcf(a.convertToFloat())); |
347 | default: |
348 | return {}; |
349 | } |
350 | }); |
351 | } |
352 | |
353 | //===----------------------------------------------------------------------===// |
354 | // IPowIOp folder |
355 | //===----------------------------------------------------------------------===// |
356 | |
357 | OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) { |
358 | return constFoldBinaryOpConditional<IntegerAttr>( |
359 | adaptor.getOperands(), |
360 | [](const APInt &base, const APInt &power) -> std::optional<APInt> { |
361 | unsigned width = base.getBitWidth(); |
362 | auto zeroValue = APInt::getZero(width); |
363 | APInt oneValue{width, 1ULL, /*isSigned=*/true}; |
364 | APInt minusOneValue{width, -1ULL, /*isSigned=*/true}; |
365 | |
366 | if (power.isZero()) |
367 | return oneValue; |
368 | |
369 | if (power.isNegative()) { |
370 | // Leave 0 raised to negative power not folded. |
371 | if (base.isZero()) |
372 | return {}; |
373 | if (base.eq(oneValue)) |
374 | return oneValue; |
375 | // If abs(base) > 1, then the result is zero. |
376 | if (base.ne(minusOneValue)) |
377 | return zeroValue; |
378 | // base == -1: |
379 | // -1: power is odd |
380 | // 1: power is even |
381 | if (power[0] == 1) |
382 | return minusOneValue; |
383 | |
384 | return oneValue; |
385 | } |
386 | |
387 | // power is positive. |
388 | APInt result = oneValue; |
389 | APInt curBase = base; |
390 | APInt curPower = power; |
391 | while (true) { |
392 | if (curPower[0] == 1) |
393 | result *= curBase; |
394 | curPower.lshrInPlace(1); |
395 | if (curPower.isZero()) |
396 | return result; |
397 | curBase *= curBase; |
398 | } |
399 | }); |
400 | |
401 | return Attribute(); |
402 | } |
403 | |
404 | //===----------------------------------------------------------------------===// |
405 | // LogOp folder |
406 | //===----------------------------------------------------------------------===// |
407 | |
408 | OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) { |
409 | return constFoldUnaryOpConditional<FloatAttr>( |
410 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
411 | if (a.isNegative()) |
412 | return {}; |
413 | |
414 | if (a.getSizeInBits(a.getSemantics()) == 64) |
415 | return APFloat(log(a.convertToDouble())); |
416 | |
417 | if (a.getSizeInBits(a.getSemantics()) == 32) |
418 | return APFloat(logf(a.convertToFloat())); |
419 | |
420 | return {}; |
421 | }); |
422 | } |
423 | |
424 | //===----------------------------------------------------------------------===// |
425 | // Log2Op folder |
426 | //===----------------------------------------------------------------------===// |
427 | |
428 | OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) { |
429 | return constFoldUnaryOpConditional<FloatAttr>( |
430 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
431 | if (a.isNegative()) |
432 | return {}; |
433 | |
434 | if (a.getSizeInBits(a.getSemantics()) == 64) |
435 | return APFloat(log2(a.convertToDouble())); |
436 | |
437 | if (a.getSizeInBits(a.getSemantics()) == 32) |
438 | return APFloat(log2f(a.convertToFloat())); |
439 | |
440 | return {}; |
441 | }); |
442 | } |
443 | |
444 | //===----------------------------------------------------------------------===// |
445 | // Log10Op folder |
446 | //===----------------------------------------------------------------------===// |
447 | |
448 | OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) { |
449 | return constFoldUnaryOpConditional<FloatAttr>( |
450 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
451 | if (a.isNegative()) |
452 | return {}; |
453 | |
454 | switch (a.getSizeInBits(a.getSemantics())) { |
455 | case 64: |
456 | return APFloat(log10(a.convertToDouble())); |
457 | case 32: |
458 | return APFloat(log10f(a.convertToFloat())); |
459 | default: |
460 | return {}; |
461 | } |
462 | }); |
463 | } |
464 | |
465 | //===----------------------------------------------------------------------===// |
466 | // Log1pOp folder |
467 | //===----------------------------------------------------------------------===// |
468 | |
469 | OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) { |
470 | return constFoldUnaryOpConditional<FloatAttr>( |
471 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
472 | switch (a.getSizeInBits(a.getSemantics())) { |
473 | case 64: |
474 | if ((a + APFloat(1.0)).isNegative()) |
475 | return {}; |
476 | return APFloat(log1p(a.convertToDouble())); |
477 | case 32: |
478 | if ((a + APFloat(1.0f)).isNegative()) |
479 | return {}; |
480 | return APFloat(log1pf(a.convertToFloat())); |
481 | default: |
482 | return {}; |
483 | } |
484 | }); |
485 | } |
486 | |
487 | //===----------------------------------------------------------------------===// |
488 | // PowFOp folder |
489 | //===----------------------------------------------------------------------===// |
490 | |
491 | OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) { |
492 | return constFoldBinaryOpConditional<FloatAttr>( |
493 | adaptor.getOperands(), |
494 | [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> { |
495 | if (a.getSizeInBits(a.getSemantics()) == 64 && |
496 | b.getSizeInBits(b.getSemantics()) == 64) |
497 | return APFloat(pow(a.convertToDouble(), b.convertToDouble())); |
498 | |
499 | if (a.getSizeInBits(a.getSemantics()) == 32 && |
500 | b.getSizeInBits(b.getSemantics()) == 32) |
501 | return APFloat(powf(a.convertToFloat(), b.convertToFloat())); |
502 | |
503 | return {}; |
504 | }); |
505 | } |
506 | |
507 | //===----------------------------------------------------------------------===// |
508 | // SqrtOp folder |
509 | //===----------------------------------------------------------------------===// |
510 | |
511 | OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) { |
512 | return constFoldUnaryOpConditional<FloatAttr>( |
513 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
514 | if (a.isNegative()) |
515 | return {}; |
516 | |
517 | switch (a.getSizeInBits(a.getSemantics())) { |
518 | case 64: |
519 | return APFloat(sqrt(a.convertToDouble())); |
520 | case 32: |
521 | return APFloat(sqrtf(a.convertToFloat())); |
522 | default: |
523 | return {}; |
524 | } |
525 | }); |
526 | } |
527 | |
528 | //===----------------------------------------------------------------------===// |
529 | // ExpOp folder |
530 | //===----------------------------------------------------------------------===// |
531 | |
532 | OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) { |
533 | return constFoldUnaryOpConditional<FloatAttr>( |
534 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
535 | switch (a.getSizeInBits(a.getSemantics())) { |
536 | case 64: |
537 | return APFloat(exp(a.convertToDouble())); |
538 | case 32: |
539 | return APFloat(expf(a.convertToFloat())); |
540 | default: |
541 | return {}; |
542 | } |
543 | }); |
544 | } |
545 | |
546 | //===----------------------------------------------------------------------===// |
547 | // Exp2Op folder |
548 | //===----------------------------------------------------------------------===// |
549 | |
550 | OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) { |
551 | return constFoldUnaryOpConditional<FloatAttr>( |
552 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
553 | switch (a.getSizeInBits(a.getSemantics())) { |
554 | case 64: |
555 | return APFloat(exp2(a.convertToDouble())); |
556 | case 32: |
557 | return APFloat(exp2f(a.convertToFloat())); |
558 | default: |
559 | return {}; |
560 | } |
561 | }); |
562 | } |
563 | |
564 | //===----------------------------------------------------------------------===// |
565 | // ExpM1Op folder |
566 | //===----------------------------------------------------------------------===// |
567 | |
568 | OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) { |
569 | return constFoldUnaryOpConditional<FloatAttr>( |
570 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
571 | switch (a.getSizeInBits(a.getSemantics())) { |
572 | case 64: |
573 | return APFloat(expm1(a.convertToDouble())); |
574 | case 32: |
575 | return APFloat(expm1f(a.convertToFloat())); |
576 | default: |
577 | return {}; |
578 | } |
579 | }); |
580 | } |
581 | |
582 | //===----------------------------------------------------------------------===// |
583 | // IsFiniteOp folder |
584 | //===----------------------------------------------------------------------===// |
585 | |
586 | OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) { |
587 | if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) { |
588 | return BoolAttr::get(val.getContext(), val.getValue().isFinite()); |
589 | } |
590 | if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) { |
591 | return DenseElementsAttr::get( |
592 | cast<ShapedType>(getType()), |
593 | APInt(1, splat.getSplatValue<APFloat>().isFinite())); |
594 | } |
595 | return {}; |
596 | } |
597 | |
598 | //===----------------------------------------------------------------------===// |
599 | // IsInfOp folder |
600 | //===----------------------------------------------------------------------===// |
601 | |
602 | OpFoldResult math::IsInfOp::fold(FoldAdaptor adaptor) { |
603 | if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) { |
604 | return BoolAttr::get(val.getContext(), val.getValue().isInfinity()); |
605 | } |
606 | if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) { |
607 | return DenseElementsAttr::get( |
608 | cast<ShapedType>(getType()), |
609 | APInt(1, splat.getSplatValue<APFloat>().isInfinity())); |
610 | } |
611 | return {}; |
612 | } |
613 | |
614 | //===----------------------------------------------------------------------===// |
615 | // IsNaNOp folder |
616 | //===----------------------------------------------------------------------===// |
617 | |
618 | OpFoldResult math::IsNaNOp::fold(FoldAdaptor adaptor) { |
619 | if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) { |
620 | return BoolAttr::get(val.getContext(), val.getValue().isNaN()); |
621 | } |
622 | if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) { |
623 | return DenseElementsAttr::get( |
624 | cast<ShapedType>(getType()), |
625 | APInt(1, splat.getSplatValue<APFloat>().isNaN())); |
626 | } |
627 | return {}; |
628 | } |
629 | |
630 | //===----------------------------------------------------------------------===// |
631 | // IsNormalOp folder |
632 | //===----------------------------------------------------------------------===// |
633 | |
634 | OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) { |
635 | if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) { |
636 | return BoolAttr::get(val.getContext(), val.getValue().isNormal()); |
637 | } |
638 | if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) { |
639 | return DenseElementsAttr::get( |
640 | cast<ShapedType>(getType()), |
641 | APInt(1, splat.getSplatValue<APFloat>().isNormal())); |
642 | } |
643 | return {}; |
644 | } |
645 | |
646 | //===----------------------------------------------------------------------===// |
647 | // TanOp folder |
648 | //===----------------------------------------------------------------------===// |
649 | |
650 | OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) { |
651 | return constFoldUnaryOpConditional<FloatAttr>( |
652 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
653 | switch (a.getSizeInBits(a.getSemantics())) { |
654 | case 64: |
655 | return APFloat(tan(a.convertToDouble())); |
656 | case 32: |
657 | return APFloat(tanf(a.convertToFloat())); |
658 | default: |
659 | return {}; |
660 | } |
661 | }); |
662 | } |
663 | |
664 | //===----------------------------------------------------------------------===// |
665 | // TanhOp folder |
666 | //===----------------------------------------------------------------------===// |
667 | |
668 | OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) { |
669 | return constFoldUnaryOpConditional<FloatAttr>( |
670 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
671 | switch (a.getSizeInBits(a.getSemantics())) { |
672 | case 64: |
673 | return APFloat(tanh(a.convertToDouble())); |
674 | case 32: |
675 | return APFloat(tanhf(a.convertToFloat())); |
676 | default: |
677 | return {}; |
678 | } |
679 | }); |
680 | } |
681 | |
682 | //===----------------------------------------------------------------------===// |
683 | // RoundEvenOp folder |
684 | //===----------------------------------------------------------------------===// |
685 | |
686 | OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) { |
687 | return constFoldUnaryOp<FloatAttr>( |
688 | adaptor.getOperands(), [](const APFloat &a) { |
689 | APFloat result(a); |
690 | result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven); |
691 | return result; |
692 | }); |
693 | } |
694 | |
695 | //===----------------------------------------------------------------------===// |
696 | // FloorOp folder |
697 | //===----------------------------------------------------------------------===// |
698 | |
699 | OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) { |
700 | return constFoldUnaryOp<FloatAttr>( |
701 | adaptor.getOperands(), [](const APFloat &a) { |
702 | APFloat result(a); |
703 | result.roundToIntegral(llvm::RoundingMode::TowardNegative); |
704 | return result; |
705 | }); |
706 | } |
707 | |
708 | //===----------------------------------------------------------------------===// |
709 | // RoundOp folder |
710 | //===----------------------------------------------------------------------===// |
711 | |
712 | OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) { |
713 | return constFoldUnaryOpConditional<FloatAttr>( |
714 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
715 | switch (a.getSizeInBits(a.getSemantics())) { |
716 | case 64: |
717 | return APFloat(round(a.convertToDouble())); |
718 | case 32: |
719 | return APFloat(roundf(a.convertToFloat())); |
720 | default: |
721 | return {}; |
722 | } |
723 | }); |
724 | } |
725 | |
726 | //===----------------------------------------------------------------------===// |
727 | // TruncOp folder |
728 | //===----------------------------------------------------------------------===// |
729 | |
730 | OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) { |
731 | return constFoldUnaryOpConditional<FloatAttr>( |
732 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
733 | switch (a.getSizeInBits(a.getSemantics())) { |
734 | case 64: |
735 | return APFloat(trunc(a.convertToDouble())); |
736 | case 32: |
737 | return APFloat(truncf(a.convertToFloat())); |
738 | default: |
739 | return {}; |
740 | } |
741 | }); |
742 | } |
743 | |
744 | /// Materialize an integer or floating point constant. |
745 | Operation *math::MathDialect::materializeConstant(OpBuilder &builder, |
746 | Attribute value, Type type, |
747 | Location loc) { |
748 | if (auto poison = dyn_cast<ub::PoisonAttr>(value)) |
749 | return builder.create<ub::PoisonOp>(loc, type, poison); |
750 | |
751 | return arith::ConstantOp::materialize(builder, value, type, loc); |
752 | } |
753 | |