1//===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/MathToLibm/MathToLibm.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14#include "mlir/Dialect/Math/IR/Math.h"
15#include "mlir/Dialect/Utils/IndexingUtils.h"
16#include "mlir/Dialect/Vector/IR/VectorOps.h"
17#include "mlir/IR/BuiltinDialect.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Transforms/DialectConversion.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_CONVERTMATHTOLIBMPASS
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27
28namespace {
29// Pattern to convert vector operations to scalar operations. This is needed as
30// libm calls require scalars.
31template <typename Op>
32struct VecOpToScalarOp : public OpRewritePattern<Op> {
33public:
34 using OpRewritePattern<Op>::OpRewritePattern;
35
36 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
37};
38// Pattern to promote an op of a smaller floating point type to F32.
39template <typename Op>
40struct PromoteOpToF32 : public OpRewritePattern<Op> {
41public:
42 using OpRewritePattern<Op>::OpRewritePattern;
43
44 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
45};
46// Pattern to convert scalar math operations to calls to libm functions.
47// Additionally the libm function signatures are declared.
48template <typename Op>
49struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
50public:
51 using OpRewritePattern<Op>::OpRewritePattern;
52 ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit,
53 StringRef floatFunc, StringRef doubleFunc)
54 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
55 doubleFunc(doubleFunc) {};
56
57 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
58
59private:
60 std::string floatFunc, doubleFunc;
61};
62
63template <typename OpTy>
64void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit,
65 MLIRContext *ctx, StringRef floatFunc,
66 StringRef doubleFunc) {
67 patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
68 patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
69}
70
71} // namespace
72
73template <typename Op>
74LogicalResult
75VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
76 auto opType = op.getType();
77 auto loc = op.getLoc();
78 auto vecType = dyn_cast<VectorType>(opType);
79
80 if (!vecType)
81 return failure();
82 if (!vecType.hasRank())
83 return failure();
84 auto shape = vecType.getShape();
85 int64_t numElements = vecType.getNumElements();
86
87 Value result = rewriter.create<arith::ConstantOp>(
88 loc, DenseElementsAttr::get(
89 vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
90 SmallVector<int64_t> strides = computeStrides(shape);
91 for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
92 SmallVector<int64_t> positions = delinearize(linearIndex, strides);
93 SmallVector<Value> operands;
94 for (auto input : op->getOperands())
95 operands.push_back(
96 Elt: rewriter.create<vector::ExtractOp>(loc, input, positions));
97 Value scalarOp =
98 rewriter.create<Op>(loc, vecType.getElementType(), operands);
99 result =
100 rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
101 }
102 rewriter.replaceOp(op, {result});
103 return success();
104}
105
106template <typename Op>
107LogicalResult
108PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
109 auto opType = op.getType();
110 if (!isa<Float16Type, BFloat16Type>(opType))
111 return failure();
112
113 auto loc = op.getLoc();
114 auto f32 = rewriter.getF32Type();
115 auto extendedOperands = llvm::to_vector(
116 llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
117 return rewriter.create<arith::ExtFOp>(loc, f32, operand);
118 }));
119 auto newOp = rewriter.create<Op>(loc, f32, extendedOperands);
120 rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp);
121 return success();
122}
123
124template <typename Op>
125LogicalResult
126ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
127 PatternRewriter &rewriter) const {
128 auto module = SymbolTable::getNearestSymbolTable(from: op);
129 auto type = op.getType();
130 if (!isa<Float32Type, Float64Type>(type))
131 return failure();
132
133 auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
134 auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
135 SymbolTable::lookupSymbolIn(module, name));
136 // Forward declare function if it hasn't already been
137 if (!opFunc) {
138 OpBuilder::InsertionGuard guard(rewriter);
139 rewriter.setInsertionPointToStart(&module->getRegion(0).front());
140 auto opFunctionTy = FunctionType::get(
141 context: rewriter.getContext(), inputs: op->getOperandTypes(), results: op->getResultTypes());
142 opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
143 opFunctionTy);
144 opFunc.setPrivate();
145
146 // By definition Math dialect operations imply LLVM's "readnone"
147 // function attribute, so we can set it here to provide more
148 // optimization opportunities (e.g. LICM) for backends targeting LLVM IR.
149 // This will have to be changed, when strict FP behavior is supported
150 // by Math dialect.
151 opFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(),
152 UnitAttr::get(context: rewriter.getContext()));
153 }
154 assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
155
156 rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
157 op->getOperands());
158
159 return success();
160}
161
162void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
163 PatternBenefit benefit) {
164 MLIRContext *ctx = patterns.getContext();
165
166 populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, floatFunc: "fabsf", doubleFunc: "fabs");
167 populatePatternsForOp<math::AcosOp>(patterns, benefit, ctx, floatFunc: "acosf", doubleFunc: "acos");
168 populatePatternsForOp<math::AcoshOp>(patterns, benefit, ctx, floatFunc: "acoshf",
169 doubleFunc: "acosh");
170 populatePatternsForOp<math::AsinOp>(patterns, benefit, ctx, floatFunc: "asinf", doubleFunc: "asin");
171 populatePatternsForOp<math::AsinhOp>(patterns, benefit, ctx, floatFunc: "asinhf",
172 doubleFunc: "asinh");
173 populatePatternsForOp<math::Atan2Op>(patterns, benefit, ctx, floatFunc: "atan2f",
174 doubleFunc: "atan2");
175 populatePatternsForOp<math::AtanOp>(patterns, benefit, ctx, floatFunc: "atanf", doubleFunc: "atan");
176 populatePatternsForOp<math::AtanhOp>(patterns, benefit, ctx, floatFunc: "atanhf",
177 doubleFunc: "atanh");
178 populatePatternsForOp<math::CbrtOp>(patterns, benefit, ctx, floatFunc: "cbrtf", doubleFunc: "cbrt");
179 populatePatternsForOp<math::CeilOp>(patterns, benefit, ctx, floatFunc: "ceilf", doubleFunc: "ceil");
180 populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, floatFunc: "cosf", doubleFunc: "cos");
181 populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, floatFunc: "coshf", doubleFunc: "cosh");
182 populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, floatFunc: "erff", doubleFunc: "erf");
183 populatePatternsForOp<math::ErfcOp>(patterns, benefit, ctx, floatFunc: "erfcf", doubleFunc: "erfc");
184 populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, floatFunc: "expf", doubleFunc: "exp");
185 populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, floatFunc: "exp2f", doubleFunc: "exp2");
186 populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, floatFunc: "expm1f",
187 doubleFunc: "expm1");
188 populatePatternsForOp<math::FloorOp>(patterns, benefit, ctx, floatFunc: "floorf",
189 doubleFunc: "floor");
190 populatePatternsForOp<math::FmaOp>(patterns, benefit, ctx, floatFunc: "fmaf", doubleFunc: "fma");
191 populatePatternsForOp<math::LogOp>(patterns, benefit, ctx, floatFunc: "logf", doubleFunc: "log");
192 populatePatternsForOp<math::Log2Op>(patterns, benefit, ctx, floatFunc: "log2f", doubleFunc: "log2");
193 populatePatternsForOp<math::Log10Op>(patterns, benefit, ctx, floatFunc: "log10f",
194 doubleFunc: "log10");
195 populatePatternsForOp<math::Log1pOp>(patterns, benefit, ctx, floatFunc: "log1pf",
196 doubleFunc: "log1p");
197 populatePatternsForOp<math::PowFOp>(patterns, benefit, ctx, floatFunc: "powf", doubleFunc: "pow");
198 populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, floatFunc: "roundevenf",
199 doubleFunc: "roundeven");
200 populatePatternsForOp<math::RoundOp>(patterns, benefit, ctx, floatFunc: "roundf",
201 doubleFunc: "round");
202 populatePatternsForOp<math::SinOp>(patterns, benefit, ctx, floatFunc: "sinf", doubleFunc: "sin");
203 populatePatternsForOp<math::SinhOp>(patterns, benefit, ctx, floatFunc: "sinhf", doubleFunc: "sinh");
204 populatePatternsForOp<math::SqrtOp>(patterns, benefit, ctx, floatFunc: "sqrtf", doubleFunc: "sqrt");
205 populatePatternsForOp<math::RsqrtOp>(patterns, benefit, ctx, floatFunc: "rsqrtf",
206 doubleFunc: "rsqrt");
207 populatePatternsForOp<math::TanOp>(patterns, benefit, ctx, floatFunc: "tanf", doubleFunc: "tan");
208 populatePatternsForOp<math::TanhOp>(patterns, benefit, ctx, floatFunc: "tanhf", doubleFunc: "tanh");
209 populatePatternsForOp<math::TruncOp>(patterns, benefit, ctx, floatFunc: "truncf",
210 doubleFunc: "trunc");
211}
212
213namespace {
214struct ConvertMathToLibmPass
215 : public impl::ConvertMathToLibmPassBase<ConvertMathToLibmPass> {
216 void runOnOperation() override;
217};
218} // namespace
219
220void ConvertMathToLibmPass::runOnOperation() {
221 auto module = getOperation();
222
223 RewritePatternSet patterns(&getContext());
224 populateMathToLibmConversionPatterns(patterns);
225
226 ConversionTarget target(getContext());
227 target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,
228 vector::VectorDialect>();
229 target.addIllegalDialect<math::MathDialect>();
230 if (failed(Result: applyPartialConversion(op: module, target, patterns: std::move(patterns))))
231 signalPassFailure();
232}
233

// Source: mlir/lib/Conversion/MathToLibm/MathToLibm.cpp