1//===- MathOps.cpp - MLIR operations for math implementation --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/CommonFolders.h"
11#include "mlir/Dialect/Math/IR/Math.h"
12#include "mlir/Dialect/UB/IR/UBOps.h"
13#include "mlir/IR/Builders.h"
14#include <optional>
15
16using namespace mlir;
17using namespace mlir::math;
18
19//===----------------------------------------------------------------------===//
20// Common helpers
21//===----------------------------------------------------------------------===//
22
23/// Return the type of the same shape (scalar, vector or tensor) containing i1.
24static Type getI1SameShape(Type type) {
25 auto i1Type = IntegerType::get(type.getContext(), 1);
26 if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
27 return shapedType.cloneWith(std::nullopt, i1Type);
28 if (llvm::isa<UnrankedTensorType>(type))
29 return UnrankedTensorType::get(i1Type);
30 return i1Type;
31}
32
33//===----------------------------------------------------------------------===//
34// TableGen'd op method definitions
35//===----------------------------------------------------------------------===//
36
37#define GET_OP_CLASSES
38#include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
39
40//===----------------------------------------------------------------------===//
41// AbsFOp folder
42//===----------------------------------------------------------------------===//
43
44OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) {
45 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
46 [](const APFloat &a) { return abs(a); });
47}
48
49//===----------------------------------------------------------------------===//
50// AbsIOp folder
51//===----------------------------------------------------------------------===//
52
53OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
54 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
55 [](const APInt &a) { return a.abs(); });
56}
57
58//===----------------------------------------------------------------------===//
59// AcosOp folder
60//===----------------------------------------------------------------------===//
61
62OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
63 return constFoldUnaryOpConditional<FloatAttr>(
64 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
65 switch (a.getSizeInBits(a.getSemantics())) {
66 case 64:
67 return APFloat(acos(a.convertToDouble()));
68 case 32:
69 return APFloat(acosf(a.convertToFloat()));
70 default:
71 return {};
72 }
73 });
74}
75
76//===----------------------------------------------------------------------===//
77// AcoshOp folder
78//===----------------------------------------------------------------------===//
79
80OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
81 return constFoldUnaryOpConditional<FloatAttr>(
82 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
83 switch (a.getSizeInBits(a.getSemantics())) {
84 case 64:
85 return APFloat(acosh(a.convertToDouble()));
86 case 32:
87 return APFloat(acoshf(a.convertToFloat()));
88 default:
89 return {};
90 }
91 });
92}
93
94//===----------------------------------------------------------------------===//
95// AsinOp folder
96//===----------------------------------------------------------------------===//
97
98OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
99 return constFoldUnaryOpConditional<FloatAttr>(
100 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
101 switch (a.getSizeInBits(a.getSemantics())) {
102 case 64:
103 return APFloat(asin(a.convertToDouble()));
104 case 32:
105 return APFloat(asinf(a.convertToFloat()));
106 default:
107 return {};
108 }
109 });
110}
111
112//===----------------------------------------------------------------------===//
113// AsinhOp folder
114//===----------------------------------------------------------------------===//
115
116OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
117 return constFoldUnaryOpConditional<FloatAttr>(
118 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
119 switch (a.getSizeInBits(a.getSemantics())) {
120 case 64:
121 return APFloat(asinh(a.convertToDouble()));
122 case 32:
123 return APFloat(asinhf(a.convertToFloat()));
124 default:
125 return {};
126 }
127 });
128}
129
130//===----------------------------------------------------------------------===//
131// AtanOp folder
132//===----------------------------------------------------------------------===//
133
134OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
135 return constFoldUnaryOpConditional<FloatAttr>(
136 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
137 switch (a.getSizeInBits(a.getSemantics())) {
138 case 64:
139 return APFloat(atan(a.convertToDouble()));
140 case 32:
141 return APFloat(atanf(a.convertToFloat()));
142 default:
143 return {};
144 }
145 });
146}
147
148//===----------------------------------------------------------------------===//
149// AtanhOp folder
150//===----------------------------------------------------------------------===//
151
152OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) {
153 return constFoldUnaryOpConditional<FloatAttr>(
154 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
155 switch (a.getSizeInBits(a.getSemantics())) {
156 case 64:
157 return APFloat(atanh(a.convertToDouble()));
158 case 32:
159 return APFloat(atanhf(a.convertToFloat()));
160 default:
161 return {};
162 }
163 });
164}
165
166//===----------------------------------------------------------------------===//
167// Atan2Op folder
168//===----------------------------------------------------------------------===//
169
170OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
171 return constFoldBinaryOpConditional<FloatAttr>(
172 adaptor.getOperands(),
173 [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
174 if (a.isZero() && b.isZero())
175 return llvm::APFloat::getNaN(a.getSemantics());
176
177 if (a.getSizeInBits(a.getSemantics()) == 64 &&
178 b.getSizeInBits(b.getSemantics()) == 64)
179 return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
180
181 if (a.getSizeInBits(a.getSemantics()) == 32 &&
182 b.getSizeInBits(b.getSemantics()) == 32)
183 return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
184
185 return {};
186 });
187}
188
189//===----------------------------------------------------------------------===//
190// CeilOp folder
191//===----------------------------------------------------------------------===//
192
193OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) {
194 return constFoldUnaryOp<FloatAttr>(
195 adaptor.getOperands(), [](const APFloat &a) {
196 APFloat result(a);
197 result.roundToIntegral(llvm::RoundingMode::TowardPositive);
198 return result;
199 });
200}
201
202//===----------------------------------------------------------------------===//
203// CopySignOp folder
204//===----------------------------------------------------------------------===//
205
206OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
207 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
208 [](const APFloat &a, const APFloat &b) {
209 APFloat result(a);
210 result.copySign(b);
211 return result;
212 });
213}
214
215//===----------------------------------------------------------------------===//
216// CosOp folder
217//===----------------------------------------------------------------------===//
218
219OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
220 return constFoldUnaryOpConditional<FloatAttr>(
221 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
222 switch (a.getSizeInBits(a.getSemantics())) {
223 case 64:
224 return APFloat(cos(a.convertToDouble()));
225 case 32:
226 return APFloat(cosf(a.convertToFloat()));
227 default:
228 return {};
229 }
230 });
231}
232
233//===----------------------------------------------------------------------===//
234// CoshOp folder
235//===----------------------------------------------------------------------===//
236
237OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
238 return constFoldUnaryOpConditional<FloatAttr>(
239 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
240 switch (a.getSizeInBits(a.getSemantics())) {
241 case 64:
242 return APFloat(cosh(a.convertToDouble()));
243 case 32:
244 return APFloat(coshf(a.convertToFloat()));
245 default:
246 return {};
247 }
248 });
249}
250
251//===----------------------------------------------------------------------===//
252// SinOp folder
253//===----------------------------------------------------------------------===//
254
255OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
256 return constFoldUnaryOpConditional<FloatAttr>(
257 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
258 switch (a.getSizeInBits(a.getSemantics())) {
259 case 64:
260 return APFloat(sin(a.convertToDouble()));
261 case 32:
262 return APFloat(sinf(a.convertToFloat()));
263 default:
264 return {};
265 }
266 });
267}
268
269//===----------------------------------------------------------------------===//
270// SinhOp folder
271//===----------------------------------------------------------------------===//
272
273OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
274 return constFoldUnaryOpConditional<FloatAttr>(
275 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
276 switch (a.getSizeInBits(a.getSemantics())) {
277 case 64:
278 return APFloat(sinh(a.convertToDouble()));
279 case 32:
280 return APFloat(sinhf(a.convertToFloat()));
281 default:
282 return {};
283 }
284 });
285}
286
287//===----------------------------------------------------------------------===//
288// CountLeadingZerosOp folder
289//===----------------------------------------------------------------------===//
290
291OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
292 return constFoldUnaryOp<IntegerAttr>(
293 adaptor.getOperands(),
294 [](const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); });
295}
296
297//===----------------------------------------------------------------------===//
298// CountTrailingZerosOp folder
299//===----------------------------------------------------------------------===//
300
301OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
302 return constFoldUnaryOp<IntegerAttr>(
303 adaptor.getOperands(),
304 [](const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); });
305}
306
307//===----------------------------------------------------------------------===//
308// CtPopOp folder
309//===----------------------------------------------------------------------===//
310
311OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
312 return constFoldUnaryOp<IntegerAttr>(
313 adaptor.getOperands(),
314 [](const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); });
315}
316
317//===----------------------------------------------------------------------===//
318// ErfOp folder
319//===----------------------------------------------------------------------===//
320
321OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
322 return constFoldUnaryOpConditional<FloatAttr>(
323 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
324 switch (a.getSizeInBits(a.getSemantics())) {
325 case 64:
326 return APFloat(erf(a.convertToDouble()));
327 case 32:
328 return APFloat(erff(a.convertToFloat()));
329 default:
330 return {};
331 }
332 });
333}
334
335//===----------------------------------------------------------------------===//
336// ErfcOp folder
337//===----------------------------------------------------------------------===//
338
339OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) {
340 return constFoldUnaryOpConditional<FloatAttr>(
341 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
342 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
343 case APFloat::Semantics::S_IEEEdouble:
344 return APFloat(erfc(a.convertToDouble()));
345 case APFloat::Semantics::S_IEEEsingle:
346 return APFloat(erfcf(a.convertToFloat()));
347 default:
348 return {};
349 }
350 });
351}
352
353//===----------------------------------------------------------------------===//
354// IPowIOp folder
355//===----------------------------------------------------------------------===//
356
357OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) {
358 return constFoldBinaryOpConditional<IntegerAttr>(
359 adaptor.getOperands(),
360 [](const APInt &base, const APInt &power) -> std::optional<APInt> {
361 unsigned width = base.getBitWidth();
362 auto zeroValue = APInt::getZero(width);
363 APInt oneValue{width, 1ULL, /*isSigned=*/true};
364 APInt minusOneValue{width, -1ULL, /*isSigned=*/true};
365
366 if (power.isZero())
367 return oneValue;
368
369 if (power.isNegative()) {
370 // Leave 0 raised to negative power not folded.
371 if (base.isZero())
372 return {};
373 if (base.eq(oneValue))
374 return oneValue;
375 // If abs(base) > 1, then the result is zero.
376 if (base.ne(minusOneValue))
377 return zeroValue;
378 // base == -1:
379 // -1: power is odd
380 // 1: power is even
381 if (power[0] == 1)
382 return minusOneValue;
383
384 return oneValue;
385 }
386
387 // power is positive.
388 APInt result = oneValue;
389 APInt curBase = base;
390 APInt curPower = power;
391 while (true) {
392 if (curPower[0] == 1)
393 result *= curBase;
394 curPower.lshrInPlace(1);
395 if (curPower.isZero())
396 return result;
397 curBase *= curBase;
398 }
399 });
400
401 return Attribute();
402}
403
404//===----------------------------------------------------------------------===//
405// LogOp folder
406//===----------------------------------------------------------------------===//
407
408OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
409 return constFoldUnaryOpConditional<FloatAttr>(
410 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
411 if (a.isNegative())
412 return {};
413
414 if (a.getSizeInBits(a.getSemantics()) == 64)
415 return APFloat(log(a.convertToDouble()));
416
417 if (a.getSizeInBits(a.getSemantics()) == 32)
418 return APFloat(logf(a.convertToFloat()));
419
420 return {};
421 });
422}
423
424//===----------------------------------------------------------------------===//
425// Log2Op folder
426//===----------------------------------------------------------------------===//
427
428OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
429 return constFoldUnaryOpConditional<FloatAttr>(
430 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
431 if (a.isNegative())
432 return {};
433
434 if (a.getSizeInBits(a.getSemantics()) == 64)
435 return APFloat(log2(a.convertToDouble()));
436
437 if (a.getSizeInBits(a.getSemantics()) == 32)
438 return APFloat(log2f(a.convertToFloat()));
439
440 return {};
441 });
442}
443
444//===----------------------------------------------------------------------===//
445// Log10Op folder
446//===----------------------------------------------------------------------===//
447
448OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
449 return constFoldUnaryOpConditional<FloatAttr>(
450 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
451 if (a.isNegative())
452 return {};
453
454 switch (a.getSizeInBits(a.getSemantics())) {
455 case 64:
456 return APFloat(log10(a.convertToDouble()));
457 case 32:
458 return APFloat(log10f(a.convertToFloat()));
459 default:
460 return {};
461 }
462 });
463}
464
465//===----------------------------------------------------------------------===//
466// Log1pOp folder
467//===----------------------------------------------------------------------===//
468
469OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
470 return constFoldUnaryOpConditional<FloatAttr>(
471 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
472 switch (a.getSizeInBits(a.getSemantics())) {
473 case 64:
474 if ((a + APFloat(1.0)).isNegative())
475 return {};
476 return APFloat(log1p(a.convertToDouble()));
477 case 32:
478 if ((a + APFloat(1.0f)).isNegative())
479 return {};
480 return APFloat(log1pf(a.convertToFloat()));
481 default:
482 return {};
483 }
484 });
485}
486
487//===----------------------------------------------------------------------===//
488// PowFOp folder
489//===----------------------------------------------------------------------===//
490
491OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
492 return constFoldBinaryOpConditional<FloatAttr>(
493 adaptor.getOperands(),
494 [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
495 if (a.getSizeInBits(a.getSemantics()) == 64 &&
496 b.getSizeInBits(b.getSemantics()) == 64)
497 return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
498
499 if (a.getSizeInBits(a.getSemantics()) == 32 &&
500 b.getSizeInBits(b.getSemantics()) == 32)
501 return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
502
503 return {};
504 });
505}
506
507//===----------------------------------------------------------------------===//
508// SqrtOp folder
509//===----------------------------------------------------------------------===//
510
511OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
512 return constFoldUnaryOpConditional<FloatAttr>(
513 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
514 if (a.isNegative())
515 return {};
516
517 switch (a.getSizeInBits(a.getSemantics())) {
518 case 64:
519 return APFloat(sqrt(a.convertToDouble()));
520 case 32:
521 return APFloat(sqrtf(a.convertToFloat()));
522 default:
523 return {};
524 }
525 });
526}
527
528//===----------------------------------------------------------------------===//
529// ExpOp folder
530//===----------------------------------------------------------------------===//
531
532OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
533 return constFoldUnaryOpConditional<FloatAttr>(
534 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
535 switch (a.getSizeInBits(a.getSemantics())) {
536 case 64:
537 return APFloat(exp(a.convertToDouble()));
538 case 32:
539 return APFloat(expf(a.convertToFloat()));
540 default:
541 return {};
542 }
543 });
544}
545
546//===----------------------------------------------------------------------===//
547// Exp2Op folder
548//===----------------------------------------------------------------------===//
549
550OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
551 return constFoldUnaryOpConditional<FloatAttr>(
552 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
553 switch (a.getSizeInBits(a.getSemantics())) {
554 case 64:
555 return APFloat(exp2(a.convertToDouble()));
556 case 32:
557 return APFloat(exp2f(a.convertToFloat()));
558 default:
559 return {};
560 }
561 });
562}
563
564//===----------------------------------------------------------------------===//
565// ExpM1Op folder
566//===----------------------------------------------------------------------===//
567
568OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
569 return constFoldUnaryOpConditional<FloatAttr>(
570 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
571 switch (a.getSizeInBits(a.getSemantics())) {
572 case 64:
573 return APFloat(expm1(a.convertToDouble()));
574 case 32:
575 return APFloat(expm1f(a.convertToFloat()));
576 default:
577 return {};
578 }
579 });
580}
581
582//===----------------------------------------------------------------------===//
583// IsFiniteOp folder
584//===----------------------------------------------------------------------===//
585
586OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) {
587 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
588 return BoolAttr::get(val.getContext(), val.getValue().isFinite());
589 }
590 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
591 return DenseElementsAttr::get(
592 cast<ShapedType>(getType()),
593 APInt(1, splat.getSplatValue<APFloat>().isFinite()));
594 }
595 return {};
596}
597
598//===----------------------------------------------------------------------===//
599// IsInfOp folder
600//===----------------------------------------------------------------------===//
601
602OpFoldResult math::IsInfOp::fold(FoldAdaptor adaptor) {
603 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
604 return BoolAttr::get(val.getContext(), val.getValue().isInfinity());
605 }
606 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
607 return DenseElementsAttr::get(
608 cast<ShapedType>(getType()),
609 APInt(1, splat.getSplatValue<APFloat>().isInfinity()));
610 }
611 return {};
612}
613
614//===----------------------------------------------------------------------===//
615// IsNaNOp folder
616//===----------------------------------------------------------------------===//
617
618OpFoldResult math::IsNaNOp::fold(FoldAdaptor adaptor) {
619 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
620 return BoolAttr::get(val.getContext(), val.getValue().isNaN());
621 }
622 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
623 return DenseElementsAttr::get(
624 cast<ShapedType>(getType()),
625 APInt(1, splat.getSplatValue<APFloat>().isNaN()));
626 }
627 return {};
628}
629
630//===----------------------------------------------------------------------===//
631// IsNormalOp folder
632//===----------------------------------------------------------------------===//
633
634OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
635 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
636 return BoolAttr::get(val.getContext(), val.getValue().isNormal());
637 }
638 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
639 return DenseElementsAttr::get(
640 cast<ShapedType>(getType()),
641 APInt(1, splat.getSplatValue<APFloat>().isNormal()));
642 }
643 return {};
644}
645
646//===----------------------------------------------------------------------===//
647// TanOp folder
648//===----------------------------------------------------------------------===//
649
650OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
651 return constFoldUnaryOpConditional<FloatAttr>(
652 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
653 switch (a.getSizeInBits(a.getSemantics())) {
654 case 64:
655 return APFloat(tan(a.convertToDouble()));
656 case 32:
657 return APFloat(tanf(a.convertToFloat()));
658 default:
659 return {};
660 }
661 });
662}
663
664//===----------------------------------------------------------------------===//
665// TanhOp folder
666//===----------------------------------------------------------------------===//
667
668OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
669 return constFoldUnaryOpConditional<FloatAttr>(
670 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
671 switch (a.getSizeInBits(a.getSemantics())) {
672 case 64:
673 return APFloat(tanh(a.convertToDouble()));
674 case 32:
675 return APFloat(tanhf(a.convertToFloat()));
676 default:
677 return {};
678 }
679 });
680}
681
682//===----------------------------------------------------------------------===//
683// RoundEvenOp folder
684//===----------------------------------------------------------------------===//
685
686OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
687 return constFoldUnaryOp<FloatAttr>(
688 adaptor.getOperands(), [](const APFloat &a) {
689 APFloat result(a);
690 result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
691 return result;
692 });
693}
694
695//===----------------------------------------------------------------------===//
696// FloorOp folder
697//===----------------------------------------------------------------------===//
698
699OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
700 return constFoldUnaryOp<FloatAttr>(
701 adaptor.getOperands(), [](const APFloat &a) {
702 APFloat result(a);
703 result.roundToIntegral(llvm::RoundingMode::TowardNegative);
704 return result;
705 });
706}
707
708//===----------------------------------------------------------------------===//
709// RoundOp folder
710//===----------------------------------------------------------------------===//
711
712OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
713 return constFoldUnaryOpConditional<FloatAttr>(
714 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
715 switch (a.getSizeInBits(a.getSemantics())) {
716 case 64:
717 return APFloat(round(a.convertToDouble()));
718 case 32:
719 return APFloat(roundf(a.convertToFloat()));
720 default:
721 return {};
722 }
723 });
724}
725
726//===----------------------------------------------------------------------===//
727// TruncOp folder
728//===----------------------------------------------------------------------===//
729
730OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
731 return constFoldUnaryOpConditional<FloatAttr>(
732 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
733 switch (a.getSizeInBits(a.getSemantics())) {
734 case 64:
735 return APFloat(trunc(a.convertToDouble()));
736 case 32:
737 return APFloat(truncf(a.convertToFloat()));
738 default:
739 return {};
740 }
741 });
742}
743
744/// Materialize an integer or floating point constant.
745Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
746 Attribute value, Type type,
747 Location loc) {
748 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
749 return builder.create<ub::PoisonOp>(loc, type, poison);
750
751 return arith::ConstantOp::materialize(builder, value, type, loc);
752}
753

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Math/IR/MathOps.cpp