1 | //===- MathToEmitC.cpp - Math to EmitC Patterns -----------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/MathToEmitC/MathToEmitC.h" |
10 | |
11 | #include "mlir/Dialect/EmitC/IR/EmitC.h" |
12 | #include "mlir/Dialect/Math/IR/Math.h" |
13 | #include "mlir/Transforms/DialectConversion.h" |
14 | |
15 | using namespace mlir; |
16 | |
17 | namespace { |
18 | template <typename OpType> |
19 | class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> { |
20 | std::string calleeStr; |
21 | emitc::LanguageTarget languageTarget; |
22 | |
23 | public: |
24 | LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr, |
25 | emitc::LanguageTarget languageTarget) |
26 | : OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)), |
27 | languageTarget(languageTarget) {} |
28 | |
29 | LogicalResult matchAndRewrite(OpType op, |
30 | PatternRewriter &rewriter) const override; |
31 | }; |
32 | |
33 | template <typename OpType> |
34 | LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite( |
35 | OpType op, PatternRewriter &rewriter) const { |
36 | if (!llvm::all_of(op->getOperandTypes(), |
37 | llvm::IsaPred<Float32Type, Float64Type>) || |
38 | !llvm::all_of(op->getResultTypes(), |
39 | llvm::IsaPred<Float32Type, Float64Type>)) |
40 | return rewriter.notifyMatchFailure( |
41 | op.getLoc(), |
42 | "expected all operands and results to be of type f32 or f64" ); |
43 | std::string modifiedCalleeStr = calleeStr; |
44 | if (languageTarget == emitc::LanguageTarget::cpp11) { |
45 | modifiedCalleeStr = "std::" + calleeStr; |
46 | } else if (languageTarget == emitc::LanguageTarget::c99) { |
47 | auto operandType = op->getOperandTypes()[0]; |
48 | if (operandType.isF32()) |
49 | modifiedCalleeStr = calleeStr + "f" ; |
50 | } |
51 | rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>( |
52 | op, op.getType(), modifiedCalleeStr, op->getOperands()); |
53 | return success(); |
54 | } |
55 | |
56 | } // namespace |
57 | |
58 | // Populates patterns to replace `math` operations with `emitc.call_opaque`, |
59 | // using function names consistent with those in <math.h>. |
60 | void mlir::populateConvertMathToEmitCPatterns( |
61 | RewritePatternSet &patterns, emitc::LanguageTarget languageTarget) { |
62 | auto *context = patterns.getContext(); |
63 | patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor" , |
64 | languageTarget); |
65 | patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "round" , |
66 | languageTarget); |
67 | patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp" , |
68 | languageTarget); |
69 | patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos" , |
70 | languageTarget); |
71 | patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin" , |
72 | languageTarget); |
73 | patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos" , |
74 | languageTarget); |
75 | patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin" , |
76 | languageTarget); |
77 | patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2" , |
78 | languageTarget); |
79 | patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil" , |
80 | languageTarget); |
81 | patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs" , |
82 | languageTarget); |
83 | patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow" , |
84 | languageTarget); |
85 | } |
86 | |