1//===- MathOps.cpp - MLIR operations for math implementation --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/CommonFolders.h"
11#include "mlir/Dialect/Math/IR/Math.h"
12#include "mlir/Dialect/UB/IR/UBOps.h"
13#include "mlir/IR/Builders.h"
14#include <optional>
15
16using namespace mlir;
17using namespace mlir::math;
18
19//===----------------------------------------------------------------------===//
20// TableGen'd op method definitions
21//===----------------------------------------------------------------------===//
22
23#define GET_OP_CLASSES
24#include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
25
26//===----------------------------------------------------------------------===//
27// AbsFOp folder
28//===----------------------------------------------------------------------===//
29
30OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) {
31 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
32 [](const APFloat &a) { return abs(a); });
33}
34
35//===----------------------------------------------------------------------===//
36// AbsIOp folder
37//===----------------------------------------------------------------------===//
38
39OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
40 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
41 [](const APInt &a) { return a.abs(); });
42}
43
44//===----------------------------------------------------------------------===//
45// AcosOp folder
46//===----------------------------------------------------------------------===//
47
48OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
49 return constFoldUnaryOpConditional<FloatAttr>(
50 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
51 switch (a.getSizeInBits(a.getSemantics())) {
52 case 64:
53 return APFloat(acos(a.convertToDouble()));
54 case 32:
55 return APFloat(acosf(a.convertToFloat()));
56 default:
57 return {};
58 }
59 });
60}
61
62//===----------------------------------------------------------------------===//
63// AcoshOp folder
64//===----------------------------------------------------------------------===//
65
66OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
67 return constFoldUnaryOpConditional<FloatAttr>(
68 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
69 switch (a.getSizeInBits(a.getSemantics())) {
70 case 64:
71 return APFloat(acosh(a.convertToDouble()));
72 case 32:
73 return APFloat(acoshf(a.convertToFloat()));
74 default:
75 return {};
76 }
77 });
78}
79
80//===----------------------------------------------------------------------===//
81// AsinOp folder
82//===----------------------------------------------------------------------===//
83
84OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
85 return constFoldUnaryOpConditional<FloatAttr>(
86 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
87 switch (a.getSizeInBits(a.getSemantics())) {
88 case 64:
89 return APFloat(asin(a.convertToDouble()));
90 case 32:
91 return APFloat(asinf(a.convertToFloat()));
92 default:
93 return {};
94 }
95 });
96}
97
98//===----------------------------------------------------------------------===//
99// AsinhOp folder
100//===----------------------------------------------------------------------===//
101
102OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
103 return constFoldUnaryOpConditional<FloatAttr>(
104 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
105 switch (a.getSizeInBits(a.getSemantics())) {
106 case 64:
107 return APFloat(asinh(a.convertToDouble()));
108 case 32:
109 return APFloat(asinhf(a.convertToFloat()));
110 default:
111 return {};
112 }
113 });
114}
115
116//===----------------------------------------------------------------------===//
117// AtanOp folder
118//===----------------------------------------------------------------------===//
119
120OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
121 return constFoldUnaryOpConditional<FloatAttr>(
122 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
123 switch (a.getSizeInBits(a.getSemantics())) {
124 case 64:
125 return APFloat(atan(a.convertToDouble()));
126 case 32:
127 return APFloat(atanf(a.convertToFloat()));
128 default:
129 return {};
130 }
131 });
132}
133
134//===----------------------------------------------------------------------===//
135// AtanhOp folder
136//===----------------------------------------------------------------------===//
137
138OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) {
139 return constFoldUnaryOpConditional<FloatAttr>(
140 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
141 switch (a.getSizeInBits(a.getSemantics())) {
142 case 64:
143 return APFloat(atanh(a.convertToDouble()));
144 case 32:
145 return APFloat(atanhf(a.convertToFloat()));
146 default:
147 return {};
148 }
149 });
150}
151
152//===----------------------------------------------------------------------===//
153// Atan2Op folder
154//===----------------------------------------------------------------------===//
155
156OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
157 return constFoldBinaryOpConditional<FloatAttr>(
158 adaptor.getOperands(),
159 [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
160 if (a.isZero() && b.isZero())
161 return llvm::APFloat::getNaN(a.getSemantics());
162
163 if (a.getSizeInBits(a.getSemantics()) == 64 &&
164 b.getSizeInBits(b.getSemantics()) == 64)
165 return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
166
167 if (a.getSizeInBits(a.getSemantics()) == 32 &&
168 b.getSizeInBits(b.getSemantics()) == 32)
169 return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
170
171 return {};
172 });
173}
174
175//===----------------------------------------------------------------------===//
176// CeilOp folder
177//===----------------------------------------------------------------------===//
178
179OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) {
180 return constFoldUnaryOp<FloatAttr>(
181 adaptor.getOperands(), [](const APFloat &a) {
182 APFloat result(a);
183 result.roundToIntegral(llvm::RoundingMode::TowardPositive);
184 return result;
185 });
186}
187
188//===----------------------------------------------------------------------===//
189// CopySignOp folder
190//===----------------------------------------------------------------------===//
191
192OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
193 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
194 [](const APFloat &a, const APFloat &b) {
195 APFloat result(a);
196 result.copySign(b);
197 return result;
198 });
199}
200
201//===----------------------------------------------------------------------===//
202// CosOp folder
203//===----------------------------------------------------------------------===//
204
205OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
206 return constFoldUnaryOpConditional<FloatAttr>(
207 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
208 switch (a.getSizeInBits(a.getSemantics())) {
209 case 64:
210 return APFloat(cos(a.convertToDouble()));
211 case 32:
212 return APFloat(cosf(a.convertToFloat()));
213 default:
214 return {};
215 }
216 });
217}
218
219//===----------------------------------------------------------------------===//
220// CoshOp folder
221//===----------------------------------------------------------------------===//
222
223OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
224 return constFoldUnaryOpConditional<FloatAttr>(
225 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
226 switch (a.getSizeInBits(a.getSemantics())) {
227 case 64:
228 return APFloat(cosh(a.convertToDouble()));
229 case 32:
230 return APFloat(coshf(a.convertToFloat()));
231 default:
232 return {};
233 }
234 });
235}
236
237//===----------------------------------------------------------------------===//
238// SinOp folder
239//===----------------------------------------------------------------------===//
240
241OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
242 return constFoldUnaryOpConditional<FloatAttr>(
243 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
244 switch (a.getSizeInBits(a.getSemantics())) {
245 case 64:
246 return APFloat(sin(a.convertToDouble()));
247 case 32:
248 return APFloat(sinf(a.convertToFloat()));
249 default:
250 return {};
251 }
252 });
253}
254
255//===----------------------------------------------------------------------===//
256// SinhOp folder
257//===----------------------------------------------------------------------===//
258
259OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
260 return constFoldUnaryOpConditional<FloatAttr>(
261 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
262 switch (a.getSizeInBits(a.getSemantics())) {
263 case 64:
264 return APFloat(sinh(a.convertToDouble()));
265 case 32:
266 return APFloat(sinhf(a.convertToFloat()));
267 default:
268 return {};
269 }
270 });
271}
272
273//===----------------------------------------------------------------------===//
274// CountLeadingZerosOp folder
275//===----------------------------------------------------------------------===//
276
277OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
278 return constFoldUnaryOp<IntegerAttr>(
279 adaptor.getOperands(),
280 [](const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); });
281}
282
283//===----------------------------------------------------------------------===//
284// CountTrailingZerosOp folder
285//===----------------------------------------------------------------------===//
286
287OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
288 return constFoldUnaryOp<IntegerAttr>(
289 adaptor.getOperands(),
290 [](const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); });
291}
292
293//===----------------------------------------------------------------------===//
294// CtPopOp folder
295//===----------------------------------------------------------------------===//
296
297OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
298 return constFoldUnaryOp<IntegerAttr>(
299 adaptor.getOperands(),
300 [](const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); });
301}
302
303//===----------------------------------------------------------------------===//
304// ErfOp folder
305//===----------------------------------------------------------------------===//
306
307OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
308 return constFoldUnaryOpConditional<FloatAttr>(
309 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
310 switch (a.getSizeInBits(a.getSemantics())) {
311 case 64:
312 return APFloat(erf(a.convertToDouble()));
313 case 32:
314 return APFloat(erff(a.convertToFloat()));
315 default:
316 return {};
317 }
318 });
319}
320
321//===----------------------------------------------------------------------===//
322// IPowIOp folder
323//===----------------------------------------------------------------------===//
324
325OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) {
326 return constFoldBinaryOpConditional<IntegerAttr>(
327 adaptor.getOperands(),
328 [](const APInt &base, const APInt &power) -> std::optional<APInt> {
329 unsigned width = base.getBitWidth();
330 auto zeroValue = APInt::getZero(width);
331 APInt oneValue{width, 1ULL, /*isSigned=*/true};
332 APInt minusOneValue{width, -1ULL, /*isSigned=*/true};
333
334 if (power.isZero())
335 return oneValue;
336
337 if (power.isNegative()) {
338 // Leave 0 raised to negative power not folded.
339 if (base.isZero())
340 return {};
341 if (base.eq(oneValue))
342 return oneValue;
343 // If abs(base) > 1, then the result is zero.
344 if (base.ne(minusOneValue))
345 return zeroValue;
346 // base == -1:
347 // -1: power is odd
348 // 1: power is even
349 if (power[0] == 1)
350 return minusOneValue;
351
352 return oneValue;
353 }
354
355 // power is positive.
356 APInt result = oneValue;
357 APInt curBase = base;
358 APInt curPower = power;
359 while (true) {
360 if (curPower[0] == 1)
361 result *= curBase;
362 curPower.lshrInPlace(1);
363 if (curPower.isZero())
364 return result;
365 curBase *= curBase;
366 }
367 });
368
369 return Attribute();
370}
371
372//===----------------------------------------------------------------------===//
373// LogOp folder
374//===----------------------------------------------------------------------===//
375
376OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
377 return constFoldUnaryOpConditional<FloatAttr>(
378 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
379 if (a.isNegative())
380 return {};
381
382 if (a.getSizeInBits(a.getSemantics()) == 64)
383 return APFloat(log(a.convertToDouble()));
384
385 if (a.getSizeInBits(a.getSemantics()) == 32)
386 return APFloat(logf(a.convertToFloat()));
387
388 return {};
389 });
390}
391
392//===----------------------------------------------------------------------===//
393// Log2Op folder
394//===----------------------------------------------------------------------===//
395
396OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
397 return constFoldUnaryOpConditional<FloatAttr>(
398 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
399 if (a.isNegative())
400 return {};
401
402 if (a.getSizeInBits(a.getSemantics()) == 64)
403 return APFloat(log2(a.convertToDouble()));
404
405 if (a.getSizeInBits(a.getSemantics()) == 32)
406 return APFloat(log2f(a.convertToFloat()));
407
408 return {};
409 });
410}
411
412//===----------------------------------------------------------------------===//
413// Log10Op folder
414//===----------------------------------------------------------------------===//
415
416OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
417 return constFoldUnaryOpConditional<FloatAttr>(
418 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
419 if (a.isNegative())
420 return {};
421
422 switch (a.getSizeInBits(a.getSemantics())) {
423 case 64:
424 return APFloat(log10(a.convertToDouble()));
425 case 32:
426 return APFloat(log10f(a.convertToFloat()));
427 default:
428 return {};
429 }
430 });
431}
432
433//===----------------------------------------------------------------------===//
434// Log1pOp folder
435//===----------------------------------------------------------------------===//
436
437OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
438 return constFoldUnaryOpConditional<FloatAttr>(
439 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
440 switch (a.getSizeInBits(a.getSemantics())) {
441 case 64:
442 if ((a + APFloat(1.0)).isNegative())
443 return {};
444 return APFloat(log1p(a.convertToDouble()));
445 case 32:
446 if ((a + APFloat(1.0f)).isNegative())
447 return {};
448 return APFloat(log1pf(a.convertToFloat()));
449 default:
450 return {};
451 }
452 });
453}
454
455//===----------------------------------------------------------------------===//
456// PowFOp folder
457//===----------------------------------------------------------------------===//
458
459OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
460 return constFoldBinaryOpConditional<FloatAttr>(
461 adaptor.getOperands(),
462 [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
463 if (a.getSizeInBits(a.getSemantics()) == 64 &&
464 b.getSizeInBits(b.getSemantics()) == 64)
465 return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
466
467 if (a.getSizeInBits(a.getSemantics()) == 32 &&
468 b.getSizeInBits(b.getSemantics()) == 32)
469 return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
470
471 return {};
472 });
473}
474
475//===----------------------------------------------------------------------===//
476// SqrtOp folder
477//===----------------------------------------------------------------------===//
478
479OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
480 return constFoldUnaryOpConditional<FloatAttr>(
481 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
482 if (a.isNegative())
483 return {};
484
485 switch (a.getSizeInBits(a.getSemantics())) {
486 case 64:
487 return APFloat(sqrt(a.convertToDouble()));
488 case 32:
489 return APFloat(sqrtf(a.convertToFloat()));
490 default:
491 return {};
492 }
493 });
494}
495
496//===----------------------------------------------------------------------===//
497// ExpOp folder
498//===----------------------------------------------------------------------===//
499
500OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
501 return constFoldUnaryOpConditional<FloatAttr>(
502 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
503 switch (a.getSizeInBits(a.getSemantics())) {
504 case 64:
505 return APFloat(exp(a.convertToDouble()));
506 case 32:
507 return APFloat(expf(a.convertToFloat()));
508 default:
509 return {};
510 }
511 });
512}
513
514//===----------------------------------------------------------------------===//
515// Exp2Op folder
516//===----------------------------------------------------------------------===//
517
518OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
519 return constFoldUnaryOpConditional<FloatAttr>(
520 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
521 switch (a.getSizeInBits(a.getSemantics())) {
522 case 64:
523 return APFloat(exp2(a.convertToDouble()));
524 case 32:
525 return APFloat(exp2f(a.convertToFloat()));
526 default:
527 return {};
528 }
529 });
530}
531
532//===----------------------------------------------------------------------===//
533// ExpM1Op folder
534//===----------------------------------------------------------------------===//
535
536OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
537 return constFoldUnaryOpConditional<FloatAttr>(
538 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
539 switch (a.getSizeInBits(a.getSemantics())) {
540 case 64:
541 return APFloat(expm1(a.convertToDouble()));
542 case 32:
543 return APFloat(expm1f(a.convertToFloat()));
544 default:
545 return {};
546 }
547 });
548}
549
550//===----------------------------------------------------------------------===//
551// TanOp folder
552//===----------------------------------------------------------------------===//
553
554OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
555 return constFoldUnaryOpConditional<FloatAttr>(
556 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
557 switch (a.getSizeInBits(a.getSemantics())) {
558 case 64:
559 return APFloat(tan(a.convertToDouble()));
560 case 32:
561 return APFloat(tanf(a.convertToFloat()));
562 default:
563 return {};
564 }
565 });
566}
567
568//===----------------------------------------------------------------------===//
569// TanhOp folder
570//===----------------------------------------------------------------------===//
571
572OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
573 return constFoldUnaryOpConditional<FloatAttr>(
574 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
575 switch (a.getSizeInBits(a.getSemantics())) {
576 case 64:
577 return APFloat(tanh(a.convertToDouble()));
578 case 32:
579 return APFloat(tanhf(a.convertToFloat()));
580 default:
581 return {};
582 }
583 });
584}
585
586//===----------------------------------------------------------------------===//
587// RoundEvenOp folder
588//===----------------------------------------------------------------------===//
589
590OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
591 return constFoldUnaryOp<FloatAttr>(
592 adaptor.getOperands(), [](const APFloat &a) {
593 APFloat result(a);
594 result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
595 return result;
596 });
597}
598
599//===----------------------------------------------------------------------===//
600// FloorOp folder
601//===----------------------------------------------------------------------===//
602
603OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
604 return constFoldUnaryOp<FloatAttr>(
605 adaptor.getOperands(), [](const APFloat &a) {
606 APFloat result(a);
607 result.roundToIntegral(llvm::RoundingMode::TowardNegative);
608 return result;
609 });
610}
611
612//===----------------------------------------------------------------------===//
613// RoundOp folder
614//===----------------------------------------------------------------------===//
615
616OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
617 return constFoldUnaryOpConditional<FloatAttr>(
618 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
619 switch (a.getSizeInBits(a.getSemantics())) {
620 case 64:
621 return APFloat(round(a.convertToDouble()));
622 case 32:
623 return APFloat(roundf(a.convertToFloat()));
624 default:
625 return {};
626 }
627 });
628}
629
630//===----------------------------------------------------------------------===//
631// TruncOp folder
632//===----------------------------------------------------------------------===//
633
634OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
635 return constFoldUnaryOpConditional<FloatAttr>(
636 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
637 switch (a.getSizeInBits(a.getSemantics())) {
638 case 64:
639 return APFloat(trunc(a.convertToDouble()));
640 case 32:
641 return APFloat(truncf(a.convertToFloat()));
642 default:
643 return {};
644 }
645 });
646}
647
648/// Materialize an integer or floating point constant.
649Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
650 Attribute value, Type type,
651 Location loc) {
652 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
653 return builder.create<ub::PoisonOp>(loc, type, poison);
654
655 return arith::ConstantOp::materialize(builder, value, type, loc);
656}
657

// Source: mlir/lib/Dialect/Math/IR/MathOps.cpp