1//===- MathToEmitC.cpp - Math to EmitC Patterns -----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
10
11#include "mlir/Dialect/EmitC/IR/EmitC.h"
12#include "mlir/Dialect/Math/IR/Math.h"
13#include "mlir/Transforms/DialectConversion.h"
14
15using namespace mlir;
16
17namespace {
18template <typename OpType>
19class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> {
20 std::string calleeStr;
21 emitc::LanguageTarget languageTarget;
22
23public:
24 LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr,
25 emitc::LanguageTarget languageTarget)
26 : OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)),
27 languageTarget(languageTarget) {}
28
29 LogicalResult matchAndRewrite(OpType op,
30 PatternRewriter &rewriter) const override;
31};
32
33template <typename OpType>
34LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
35 OpType op, PatternRewriter &rewriter) const {
36 if (!llvm::all_of(op->getOperandTypes(),
37 llvm::IsaPred<Float32Type, Float64Type>) ||
38 !llvm::all_of(op->getResultTypes(),
39 llvm::IsaPred<Float32Type, Float64Type>))
40 return rewriter.notifyMatchFailure(
41 op.getLoc(),
42 "expected all operands and results to be of type f32 or f64");
43 std::string modifiedCalleeStr = calleeStr;
44 if (languageTarget == emitc::LanguageTarget::cpp11) {
45 modifiedCalleeStr = "std::" + calleeStr;
46 } else if (languageTarget == emitc::LanguageTarget::c99) {
47 auto operandType = op->getOperandTypes()[0];
48 if (operandType.isF32())
49 modifiedCalleeStr = calleeStr + "f";
50 }
51 rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
52 op, op.getType(), modifiedCalleeStr, op->getOperands());
53 return success();
54}
55
56} // namespace
57
58// Populates patterns to replace `math` operations with `emitc.call_opaque`,
59// using function names consistent with those in <math.h>.
60void mlir::populateConvertMathToEmitCPatterns(
61 RewritePatternSet &patterns, emitc::LanguageTarget languageTarget) {
62 auto *context = patterns.getContext();
63 patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor",
64 languageTarget);
65 patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "round",
66 languageTarget);
67 patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp",
68 languageTarget);
69 patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos",
70 languageTarget);
71 patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin",
72 languageTarget);
73 patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos",
74 languageTarget);
75 patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin",
76 languageTarget);
77 patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2",
78 languageTarget);
79 patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil",
80 languageTarget);
81 patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs",
82 languageTarget);
83 patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow",
84 languageTarget);
85}
86

source code of mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp