1 | //===- MathOps.cpp - MLIR operations for math implementation --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Arith/IR/Arith.h" |
10 | #include "mlir/Dialect/CommonFolders.h" |
11 | #include "mlir/Dialect/Math/IR/Math.h" |
12 | #include "mlir/Dialect/UB/IR/UBOps.h" |
13 | #include "mlir/IR/Builders.h" |
14 | #include <optional> |
15 | |
16 | using namespace mlir; |
17 | using namespace mlir::math; |
18 | |
19 | //===----------------------------------------------------------------------===// |
20 | // TableGen'd op method definitions |
21 | //===----------------------------------------------------------------------===// |
22 | |
23 | #define GET_OP_CLASSES |
24 | #include "mlir/Dialect/Math/IR/MathOps.cpp.inc" |
25 | |
26 | //===----------------------------------------------------------------------===// |
27 | // AbsFOp folder |
28 | //===----------------------------------------------------------------------===// |
29 | |
30 | OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) { |
31 | return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(), |
32 | [](const APFloat &a) { return abs(a); }); |
33 | } |
34 | |
35 | //===----------------------------------------------------------------------===// |
36 | // AbsIOp folder |
37 | //===----------------------------------------------------------------------===// |
38 | |
39 | OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) { |
40 | return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), |
41 | [](const APInt &a) { return a.abs(); }); |
42 | } |
43 | |
44 | //===----------------------------------------------------------------------===// |
45 | // AcosOp folder |
46 | //===----------------------------------------------------------------------===// |
47 | |
48 | OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) { |
49 | return constFoldUnaryOpConditional<FloatAttr>( |
50 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
51 | switch (a.getSizeInBits(a.getSemantics())) { |
52 | case 64: |
53 | return APFloat(acos(a.convertToDouble())); |
54 | case 32: |
55 | return APFloat(acosf(a.convertToFloat())); |
56 | default: |
57 | return {}; |
58 | } |
59 | }); |
60 | } |
61 | |
62 | //===----------------------------------------------------------------------===// |
63 | // AcoshOp folder |
64 | //===----------------------------------------------------------------------===// |
65 | |
66 | OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) { |
67 | return constFoldUnaryOpConditional<FloatAttr>( |
68 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
69 | switch (a.getSizeInBits(a.getSemantics())) { |
70 | case 64: |
71 | return APFloat(acosh(a.convertToDouble())); |
72 | case 32: |
73 | return APFloat(acoshf(a.convertToFloat())); |
74 | default: |
75 | return {}; |
76 | } |
77 | }); |
78 | } |
79 | |
80 | //===----------------------------------------------------------------------===// |
81 | // AsinOp folder |
82 | //===----------------------------------------------------------------------===// |
83 | |
84 | OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) { |
85 | return constFoldUnaryOpConditional<FloatAttr>( |
86 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
87 | switch (a.getSizeInBits(a.getSemantics())) { |
88 | case 64: |
89 | return APFloat(asin(a.convertToDouble())); |
90 | case 32: |
91 | return APFloat(asinf(a.convertToFloat())); |
92 | default: |
93 | return {}; |
94 | } |
95 | }); |
96 | } |
97 | |
98 | //===----------------------------------------------------------------------===// |
99 | // AsinhOp folder |
100 | //===----------------------------------------------------------------------===// |
101 | |
102 | OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) { |
103 | return constFoldUnaryOpConditional<FloatAttr>( |
104 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
105 | switch (a.getSizeInBits(a.getSemantics())) { |
106 | case 64: |
107 | return APFloat(asinh(a.convertToDouble())); |
108 | case 32: |
109 | return APFloat(asinhf(a.convertToFloat())); |
110 | default: |
111 | return {}; |
112 | } |
113 | }); |
114 | } |
115 | |
116 | //===----------------------------------------------------------------------===// |
117 | // AtanOp folder |
118 | //===----------------------------------------------------------------------===// |
119 | |
120 | OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) { |
121 | return constFoldUnaryOpConditional<FloatAttr>( |
122 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
123 | switch (a.getSizeInBits(a.getSemantics())) { |
124 | case 64: |
125 | return APFloat(atan(a.convertToDouble())); |
126 | case 32: |
127 | return APFloat(atanf(a.convertToFloat())); |
128 | default: |
129 | return {}; |
130 | } |
131 | }); |
132 | } |
133 | |
134 | //===----------------------------------------------------------------------===// |
135 | // AtanhOp folder |
136 | //===----------------------------------------------------------------------===// |
137 | |
138 | OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) { |
139 | return constFoldUnaryOpConditional<FloatAttr>( |
140 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
141 | switch (a.getSizeInBits(a.getSemantics())) { |
142 | case 64: |
143 | return APFloat(atanh(a.convertToDouble())); |
144 | case 32: |
145 | return APFloat(atanhf(a.convertToFloat())); |
146 | default: |
147 | return {}; |
148 | } |
149 | }); |
150 | } |
151 | |
152 | //===----------------------------------------------------------------------===// |
153 | // Atan2Op folder |
154 | //===----------------------------------------------------------------------===// |
155 | |
156 | OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) { |
157 | return constFoldBinaryOpConditional<FloatAttr>( |
158 | adaptor.getOperands(), |
159 | [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> { |
160 | if (a.isZero() && b.isZero()) |
161 | return llvm::APFloat::getNaN(a.getSemantics()); |
162 | |
163 | if (a.getSizeInBits(a.getSemantics()) == 64 && |
164 | b.getSizeInBits(b.getSemantics()) == 64) |
165 | return APFloat(atan2(a.convertToDouble(), b.convertToDouble())); |
166 | |
167 | if (a.getSizeInBits(a.getSemantics()) == 32 && |
168 | b.getSizeInBits(b.getSemantics()) == 32) |
169 | return APFloat(atan2f(a.convertToFloat(), b.convertToFloat())); |
170 | |
171 | return {}; |
172 | }); |
173 | } |
174 | |
175 | //===----------------------------------------------------------------------===// |
176 | // CeilOp folder |
177 | //===----------------------------------------------------------------------===// |
178 | |
179 | OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) { |
180 | return constFoldUnaryOp<FloatAttr>( |
181 | adaptor.getOperands(), [](const APFloat &a) { |
182 | APFloat result(a); |
183 | result.roundToIntegral(llvm::RoundingMode::TowardPositive); |
184 | return result; |
185 | }); |
186 | } |
187 | |
188 | //===----------------------------------------------------------------------===// |
189 | // CopySignOp folder |
190 | //===----------------------------------------------------------------------===// |
191 | |
192 | OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) { |
193 | return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), |
194 | [](const APFloat &a, const APFloat &b) { |
195 | APFloat result(a); |
196 | result.copySign(b); |
197 | return result; |
198 | }); |
199 | } |
200 | |
201 | //===----------------------------------------------------------------------===// |
202 | // CosOp folder |
203 | //===----------------------------------------------------------------------===// |
204 | |
205 | OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) { |
206 | return constFoldUnaryOpConditional<FloatAttr>( |
207 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
208 | switch (a.getSizeInBits(a.getSemantics())) { |
209 | case 64: |
210 | return APFloat(cos(a.convertToDouble())); |
211 | case 32: |
212 | return APFloat(cosf(a.convertToFloat())); |
213 | default: |
214 | return {}; |
215 | } |
216 | }); |
217 | } |
218 | |
219 | //===----------------------------------------------------------------------===// |
220 | // CoshOp folder |
221 | //===----------------------------------------------------------------------===// |
222 | |
223 | OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) { |
224 | return constFoldUnaryOpConditional<FloatAttr>( |
225 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
226 | switch (a.getSizeInBits(a.getSemantics())) { |
227 | case 64: |
228 | return APFloat(cosh(a.convertToDouble())); |
229 | case 32: |
230 | return APFloat(coshf(a.convertToFloat())); |
231 | default: |
232 | return {}; |
233 | } |
234 | }); |
235 | } |
236 | |
237 | //===----------------------------------------------------------------------===// |
238 | // SinOp folder |
239 | //===----------------------------------------------------------------------===// |
240 | |
241 | OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) { |
242 | return constFoldUnaryOpConditional<FloatAttr>( |
243 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
244 | switch (a.getSizeInBits(a.getSemantics())) { |
245 | case 64: |
246 | return APFloat(sin(a.convertToDouble())); |
247 | case 32: |
248 | return APFloat(sinf(a.convertToFloat())); |
249 | default: |
250 | return {}; |
251 | } |
252 | }); |
253 | } |
254 | |
255 | //===----------------------------------------------------------------------===// |
256 | // SinhOp folder |
257 | //===----------------------------------------------------------------------===// |
258 | |
259 | OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) { |
260 | return constFoldUnaryOpConditional<FloatAttr>( |
261 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
262 | switch (a.getSizeInBits(a.getSemantics())) { |
263 | case 64: |
264 | return APFloat(sinh(a.convertToDouble())); |
265 | case 32: |
266 | return APFloat(sinhf(a.convertToFloat())); |
267 | default: |
268 | return {}; |
269 | } |
270 | }); |
271 | } |
272 | |
273 | //===----------------------------------------------------------------------===// |
274 | // CountLeadingZerosOp folder |
275 | //===----------------------------------------------------------------------===// |
276 | |
277 | OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) { |
278 | return constFoldUnaryOp<IntegerAttr>( |
279 | adaptor.getOperands(), |
280 | [](const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); }); |
281 | } |
282 | |
283 | //===----------------------------------------------------------------------===// |
284 | // CountTrailingZerosOp folder |
285 | //===----------------------------------------------------------------------===// |
286 | |
287 | OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) { |
288 | return constFoldUnaryOp<IntegerAttr>( |
289 | adaptor.getOperands(), |
290 | [](const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); }); |
291 | } |
292 | |
293 | //===----------------------------------------------------------------------===// |
294 | // CtPopOp folder |
295 | //===----------------------------------------------------------------------===// |
296 | |
297 | OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) { |
298 | return constFoldUnaryOp<IntegerAttr>( |
299 | adaptor.getOperands(), |
300 | [](const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); }); |
301 | } |
302 | |
303 | //===----------------------------------------------------------------------===// |
304 | // ErfOp folder |
305 | //===----------------------------------------------------------------------===// |
306 | |
307 | OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) { |
308 | return constFoldUnaryOpConditional<FloatAttr>( |
309 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
310 | switch (a.getSizeInBits(a.getSemantics())) { |
311 | case 64: |
312 | return APFloat(erf(a.convertToDouble())); |
313 | case 32: |
314 | return APFloat(erff(a.convertToFloat())); |
315 | default: |
316 | return {}; |
317 | } |
318 | }); |
319 | } |
320 | |
321 | //===----------------------------------------------------------------------===// |
322 | // IPowIOp folder |
323 | //===----------------------------------------------------------------------===// |
324 | |
325 | OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) { |
326 | return constFoldBinaryOpConditional<IntegerAttr>( |
327 | adaptor.getOperands(), |
328 | [](const APInt &base, const APInt &power) -> std::optional<APInt> { |
329 | unsigned width = base.getBitWidth(); |
330 | auto zeroValue = APInt::getZero(width); |
331 | APInt oneValue{width, 1ULL, /*isSigned=*/true}; |
332 | APInt minusOneValue{width, -1ULL, /*isSigned=*/true}; |
333 | |
334 | if (power.isZero()) |
335 | return oneValue; |
336 | |
337 | if (power.isNegative()) { |
338 | // Leave 0 raised to negative power not folded. |
339 | if (base.isZero()) |
340 | return {}; |
341 | if (base.eq(oneValue)) |
342 | return oneValue; |
343 | // If abs(base) > 1, then the result is zero. |
344 | if (base.ne(minusOneValue)) |
345 | return zeroValue; |
346 | // base == -1: |
347 | // -1: power is odd |
348 | // 1: power is even |
349 | if (power[0] == 1) |
350 | return minusOneValue; |
351 | |
352 | return oneValue; |
353 | } |
354 | |
355 | // power is positive. |
356 | APInt result = oneValue; |
357 | APInt curBase = base; |
358 | APInt curPower = power; |
359 | while (true) { |
360 | if (curPower[0] == 1) |
361 | result *= curBase; |
362 | curPower.lshrInPlace(1); |
363 | if (curPower.isZero()) |
364 | return result; |
365 | curBase *= curBase; |
366 | } |
367 | }); |
368 | |
369 | return Attribute(); |
370 | } |
371 | |
372 | //===----------------------------------------------------------------------===// |
373 | // LogOp folder |
374 | //===----------------------------------------------------------------------===// |
375 | |
376 | OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) { |
377 | return constFoldUnaryOpConditional<FloatAttr>( |
378 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
379 | if (a.isNegative()) |
380 | return {}; |
381 | |
382 | if (a.getSizeInBits(a.getSemantics()) == 64) |
383 | return APFloat(log(a.convertToDouble())); |
384 | |
385 | if (a.getSizeInBits(a.getSemantics()) == 32) |
386 | return APFloat(logf(a.convertToFloat())); |
387 | |
388 | return {}; |
389 | }); |
390 | } |
391 | |
392 | //===----------------------------------------------------------------------===// |
393 | // Log2Op folder |
394 | //===----------------------------------------------------------------------===// |
395 | |
396 | OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) { |
397 | return constFoldUnaryOpConditional<FloatAttr>( |
398 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
399 | if (a.isNegative()) |
400 | return {}; |
401 | |
402 | if (a.getSizeInBits(a.getSemantics()) == 64) |
403 | return APFloat(log2(a.convertToDouble())); |
404 | |
405 | if (a.getSizeInBits(a.getSemantics()) == 32) |
406 | return APFloat(log2f(a.convertToFloat())); |
407 | |
408 | return {}; |
409 | }); |
410 | } |
411 | |
412 | //===----------------------------------------------------------------------===// |
413 | // Log10Op folder |
414 | //===----------------------------------------------------------------------===// |
415 | |
416 | OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) { |
417 | return constFoldUnaryOpConditional<FloatAttr>( |
418 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
419 | if (a.isNegative()) |
420 | return {}; |
421 | |
422 | switch (a.getSizeInBits(a.getSemantics())) { |
423 | case 64: |
424 | return APFloat(log10(a.convertToDouble())); |
425 | case 32: |
426 | return APFloat(log10f(a.convertToFloat())); |
427 | default: |
428 | return {}; |
429 | } |
430 | }); |
431 | } |
432 | |
433 | //===----------------------------------------------------------------------===// |
434 | // Log1pOp folder |
435 | //===----------------------------------------------------------------------===// |
436 | |
437 | OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) { |
438 | return constFoldUnaryOpConditional<FloatAttr>( |
439 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
440 | switch (a.getSizeInBits(a.getSemantics())) { |
441 | case 64: |
442 | if ((a + APFloat(1.0)).isNegative()) |
443 | return {}; |
444 | return APFloat(log1p(a.convertToDouble())); |
445 | case 32: |
446 | if ((a + APFloat(1.0f)).isNegative()) |
447 | return {}; |
448 | return APFloat(log1pf(a.convertToFloat())); |
449 | default: |
450 | return {}; |
451 | } |
452 | }); |
453 | } |
454 | |
455 | //===----------------------------------------------------------------------===// |
456 | // PowFOp folder |
457 | //===----------------------------------------------------------------------===// |
458 | |
459 | OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) { |
460 | return constFoldBinaryOpConditional<FloatAttr>( |
461 | adaptor.getOperands(), |
462 | [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> { |
463 | if (a.getSizeInBits(a.getSemantics()) == 64 && |
464 | b.getSizeInBits(b.getSemantics()) == 64) |
465 | return APFloat(pow(a.convertToDouble(), b.convertToDouble())); |
466 | |
467 | if (a.getSizeInBits(a.getSemantics()) == 32 && |
468 | b.getSizeInBits(b.getSemantics()) == 32) |
469 | return APFloat(powf(a.convertToFloat(), b.convertToFloat())); |
470 | |
471 | return {}; |
472 | }); |
473 | } |
474 | |
475 | //===----------------------------------------------------------------------===// |
476 | // SqrtOp folder |
477 | //===----------------------------------------------------------------------===// |
478 | |
479 | OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) { |
480 | return constFoldUnaryOpConditional<FloatAttr>( |
481 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
482 | if (a.isNegative()) |
483 | return {}; |
484 | |
485 | switch (a.getSizeInBits(a.getSemantics())) { |
486 | case 64: |
487 | return APFloat(sqrt(a.convertToDouble())); |
488 | case 32: |
489 | return APFloat(sqrtf(a.convertToFloat())); |
490 | default: |
491 | return {}; |
492 | } |
493 | }); |
494 | } |
495 | |
496 | //===----------------------------------------------------------------------===// |
497 | // ExpOp folder |
498 | //===----------------------------------------------------------------------===// |
499 | |
500 | OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) { |
501 | return constFoldUnaryOpConditional<FloatAttr>( |
502 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
503 | switch (a.getSizeInBits(a.getSemantics())) { |
504 | case 64: |
505 | return APFloat(exp(a.convertToDouble())); |
506 | case 32: |
507 | return APFloat(expf(a.convertToFloat())); |
508 | default: |
509 | return {}; |
510 | } |
511 | }); |
512 | } |
513 | |
514 | //===----------------------------------------------------------------------===// |
515 | // Exp2Op folder |
516 | //===----------------------------------------------------------------------===// |
517 | |
518 | OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) { |
519 | return constFoldUnaryOpConditional<FloatAttr>( |
520 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
521 | switch (a.getSizeInBits(a.getSemantics())) { |
522 | case 64: |
523 | return APFloat(exp2(a.convertToDouble())); |
524 | case 32: |
525 | return APFloat(exp2f(a.convertToFloat())); |
526 | default: |
527 | return {}; |
528 | } |
529 | }); |
530 | } |
531 | |
532 | //===----------------------------------------------------------------------===// |
533 | // ExpM1Op folder |
534 | //===----------------------------------------------------------------------===// |
535 | |
536 | OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) { |
537 | return constFoldUnaryOpConditional<FloatAttr>( |
538 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
539 | switch (a.getSizeInBits(a.getSemantics())) { |
540 | case 64: |
541 | return APFloat(expm1(a.convertToDouble())); |
542 | case 32: |
543 | return APFloat(expm1f(a.convertToFloat())); |
544 | default: |
545 | return {}; |
546 | } |
547 | }); |
548 | } |
549 | |
550 | //===----------------------------------------------------------------------===// |
551 | // TanOp folder |
552 | //===----------------------------------------------------------------------===// |
553 | |
554 | OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) { |
555 | return constFoldUnaryOpConditional<FloatAttr>( |
556 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
557 | switch (a.getSizeInBits(a.getSemantics())) { |
558 | case 64: |
559 | return APFloat(tan(a.convertToDouble())); |
560 | case 32: |
561 | return APFloat(tanf(a.convertToFloat())); |
562 | default: |
563 | return {}; |
564 | } |
565 | }); |
566 | } |
567 | |
568 | //===----------------------------------------------------------------------===// |
569 | // TanhOp folder |
570 | //===----------------------------------------------------------------------===// |
571 | |
572 | OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) { |
573 | return constFoldUnaryOpConditional<FloatAttr>( |
574 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
575 | switch (a.getSizeInBits(a.getSemantics())) { |
576 | case 64: |
577 | return APFloat(tanh(a.convertToDouble())); |
578 | case 32: |
579 | return APFloat(tanhf(a.convertToFloat())); |
580 | default: |
581 | return {}; |
582 | } |
583 | }); |
584 | } |
585 | |
586 | //===----------------------------------------------------------------------===// |
587 | // RoundEvenOp folder |
588 | //===----------------------------------------------------------------------===// |
589 | |
590 | OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) { |
591 | return constFoldUnaryOp<FloatAttr>( |
592 | adaptor.getOperands(), [](const APFloat &a) { |
593 | APFloat result(a); |
594 | result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven); |
595 | return result; |
596 | }); |
597 | } |
598 | |
599 | //===----------------------------------------------------------------------===// |
600 | // FloorOp folder |
601 | //===----------------------------------------------------------------------===// |
602 | |
603 | OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) { |
604 | return constFoldUnaryOp<FloatAttr>( |
605 | adaptor.getOperands(), [](const APFloat &a) { |
606 | APFloat result(a); |
607 | result.roundToIntegral(llvm::RoundingMode::TowardNegative); |
608 | return result; |
609 | }); |
610 | } |
611 | |
612 | //===----------------------------------------------------------------------===// |
613 | // RoundOp folder |
614 | //===----------------------------------------------------------------------===// |
615 | |
616 | OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) { |
617 | return constFoldUnaryOpConditional<FloatAttr>( |
618 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
619 | switch (a.getSizeInBits(a.getSemantics())) { |
620 | case 64: |
621 | return APFloat(round(a.convertToDouble())); |
622 | case 32: |
623 | return APFloat(roundf(a.convertToFloat())); |
624 | default: |
625 | return {}; |
626 | } |
627 | }); |
628 | } |
629 | |
630 | //===----------------------------------------------------------------------===// |
631 | // TruncOp folder |
632 | //===----------------------------------------------------------------------===// |
633 | |
634 | OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) { |
635 | return constFoldUnaryOpConditional<FloatAttr>( |
636 | adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> { |
637 | switch (a.getSizeInBits(a.getSemantics())) { |
638 | case 64: |
639 | return APFloat(trunc(a.convertToDouble())); |
640 | case 32: |
641 | return APFloat(truncf(a.convertToFloat())); |
642 | default: |
643 | return {}; |
644 | } |
645 | }); |
646 | } |
647 | |
648 | /// Materialize an integer or floating point constant. |
649 | Operation *math::MathDialect::materializeConstant(OpBuilder &builder, |
650 | Attribute value, Type type, |
651 | Location loc) { |
652 | if (auto poison = dyn_cast<ub::PoisonAttr>(value)) |
653 | return builder.create<ub::PoisonOp>(loc, type, poison); |
654 | |
655 | return arith::ConstantOp::materialize(builder, value, type, loc); |
656 | } |
657 | |