1//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
10
11#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14#include "mlir/Conversion/LLVMCommon/Pattern.h"
15#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
16#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17#include "mlir/Dialect/Math/IR/Math.h"
18#include "mlir/IR/TypeUtilities.h"
19#include "mlir/Pass/Pass.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27
28namespace {
29
30template <typename SourceOp, typename TargetOp>
31using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
32
33template <typename SourceOp, typename TargetOp>
34using ConvertFMFMathToLLVMPattern =
35 VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
36
37using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
38using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
39using CopySignOpLowering =
40 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
41using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
42using CtPopFOpLowering =
43 VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
44using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
45using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
46using FloorOpLowering =
47 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
48using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
49using Log10OpLowering =
50 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
51using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
52using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
53using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
54using FPowIOpLowering =
55 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
56using RoundEvenOpLowering =
57 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
58using RoundOpLowering =
59 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
60using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
61using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
62using FTruncOpLowering =
63 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
64
65// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
66template <typename MathOp, typename LLVMOp>
67struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
68 using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
69 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
70
71 LogicalResult
72 matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
73 ConversionPatternRewriter &rewriter) const override {
74 auto operandType = adaptor.getOperand().getType();
75
76 if (!operandType || !LLVM::isCompatibleType(type: operandType))
77 return failure();
78
79 auto loc = op.getLoc();
80 auto resultType = op.getResult().getType();
81
82 if (!isa<LLVM::LLVMArrayType>(operandType)) {
83 rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
84 false);
85 return success();
86 }
87
88 auto vectorType = dyn_cast<VectorType>(resultType);
89 if (!vectorType)
90 return failure();
91
92 return LLVM::detail::handleMultidimensionalVectors(
93 op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *this->getTypeConverter(),
94 createOperand: [&](Type llvm1DVectorTy, ValueRange operands) {
95 return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
96 false);
97 },
98 rewriter);
99 }
100};
101
102using CountLeadingZerosOpLowering =
103 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
104using CountTrailingZerosOpLowering =
105 IntOpWithFlagLowering<math::CountTrailingZerosOp,
106 LLVM::CountTrailingZerosOp>;
107using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
108
109// A `expm1` is converted into `exp - 1`.
110struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
111 using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
112
113 LogicalResult
114 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
115 ConversionPatternRewriter &rewriter) const override {
116 auto operandType = adaptor.getOperand().getType();
117
118 if (!operandType || !LLVM::isCompatibleType(type: operandType))
119 return failure();
120
121 auto loc = op.getLoc();
122 auto resultType = op.getResult().getType();
123 auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
124 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
125 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
126 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
127
128 if (!isa<LLVM::LLVMArrayType>(operandType)) {
129 LLVM::ConstantOp one;
130 if (LLVM::isCompatibleVectorType(type: operandType)) {
131 one = rewriter.create<LLVM::ConstantOp>(
132 loc, operandType,
133 SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
134 } else {
135 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
136 }
137 auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
138 expAttrs.getAttrs());
139 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
140 op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
141 return success();
142 }
143
144 auto vectorType = dyn_cast<VectorType>(resultType);
145 if (!vectorType)
146 return rewriter.notifyMatchFailure(op, "expected vector result type");
147
148 return LLVM::detail::handleMultidimensionalVectors(
149 op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(),
150 createOperand: [&](Type llvm1DVectorTy, ValueRange operands) {
151 auto numElements = LLVM::getVectorNumElements(type: llvm1DVectorTy);
152 auto splatAttr = SplatElementsAttr::get(
153 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
154 {numElements.isScalable()}),
155 floatOne);
156 auto one =
157 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
158 auto exp = rewriter.create<LLVM::ExpOp>(
159 loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
160 return rewriter.create<LLVM::FSubOp>(
161 loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
162 },
163 rewriter);
164 }
165};
166
167// A `log1p` is converted into `log(1 + ...)`.
168struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
169 using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
170
171 LogicalResult
172 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
173 ConversionPatternRewriter &rewriter) const override {
174 auto operandType = adaptor.getOperand().getType();
175
176 if (!operandType || !LLVM::isCompatibleType(type: operandType))
177 return rewriter.notifyMatchFailure(op, "unsupported operand type");
178
179 auto loc = op.getLoc();
180 auto resultType = op.getResult().getType();
181 auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
182 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
183 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
184 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
185
186 if (!isa<LLVM::LLVMArrayType>(operandType)) {
187 LLVM::ConstantOp one =
188 LLVM::isCompatibleVectorType(operandType)
189 ? rewriter.create<LLVM::ConstantOp>(
190 loc, operandType,
191 SplatElementsAttr::get(cast<ShapedType>(resultType),
192 floatOne))
193 : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
194
195 auto add = rewriter.create<LLVM::FAddOp>(
196 loc, operandType, ValueRange{one, adaptor.getOperand()},
197 addAttrs.getAttrs());
198 rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
199 logAttrs.getAttrs());
200 return success();
201 }
202
203 auto vectorType = dyn_cast<VectorType>(resultType);
204 if (!vectorType)
205 return rewriter.notifyMatchFailure(op, "expected vector result type");
206
207 return LLVM::detail::handleMultidimensionalVectors(
208 op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(),
209 createOperand: [&](Type llvm1DVectorTy, ValueRange operands) {
210 auto numElements = LLVM::getVectorNumElements(type: llvm1DVectorTy);
211 auto splatAttr = SplatElementsAttr::get(
212 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
213 {numElements.isScalable()}),
214 floatOne);
215 auto one =
216 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
217 auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
218 ValueRange{one, operands[0]},
219 addAttrs.getAttrs());
220 return rewriter.create<LLVM::LogOp>(
221 loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
222 },
223 rewriter);
224 }
225};
226
227// A `rsqrt` is converted into `1 / sqrt`.
228struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
229 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
230
231 LogicalResult
232 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
233 ConversionPatternRewriter &rewriter) const override {
234 auto operandType = adaptor.getOperand().getType();
235
236 if (!operandType || !LLVM::isCompatibleType(type: operandType))
237 return failure();
238
239 auto loc = op.getLoc();
240 auto resultType = op.getResult().getType();
241 auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
242 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
243 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
244 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
245
246 if (!isa<LLVM::LLVMArrayType>(operandType)) {
247 LLVM::ConstantOp one;
248 if (LLVM::isCompatibleVectorType(type: operandType)) {
249 one = rewriter.create<LLVM::ConstantOp>(
250 loc, operandType,
251 SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
252 } else {
253 one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
254 }
255 auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
256 sqrtAttrs.getAttrs());
257 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
258 op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
259 return success();
260 }
261
262 auto vectorType = dyn_cast<VectorType>(resultType);
263 if (!vectorType)
264 return failure();
265
266 return LLVM::detail::handleMultidimensionalVectors(
267 op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(),
268 createOperand: [&](Type llvm1DVectorTy, ValueRange operands) {
269 auto numElements = LLVM::getVectorNumElements(type: llvm1DVectorTy);
270 auto splatAttr = SplatElementsAttr::get(
271 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
272 {numElements.isScalable()}),
273 floatOne);
274 auto one =
275 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
276 auto sqrt = rewriter.create<LLVM::SqrtOp>(
277 loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
278 return rewriter.create<LLVM::FDivOp>(
279 loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
280 },
281 rewriter);
282 }
283};
284
285struct ConvertMathToLLVMPass
286 : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
287 using Base::Base;
288
289 void runOnOperation() override {
290 RewritePatternSet patterns(&getContext());
291 LLVMTypeConverter converter(&getContext());
292 populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
293 LLVMConversionTarget target(getContext());
294 if (failed(applyPartialConversion(getOperation(), target,
295 std::move(patterns))))
296 signalPassFailure();
297 }
298};
299} // namespace
300
301void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
302 RewritePatternSet &patterns,
303 bool approximateLog1p) {
304 if (approximateLog1p)
305 patterns.add<Log1pOpLowering>(arg&: converter);
306 // clang-format off
307 patterns.add<
308 AbsFOpLowering,
309 AbsIOpLowering,
310 CeilOpLowering,
311 CopySignOpLowering,
312 CosOpLowering,
313 CountLeadingZerosOpLowering,
314 CountTrailingZerosOpLowering,
315 CtPopFOpLowering,
316 Exp2OpLowering,
317 ExpM1OpLowering,
318 ExpOpLowering,
319 FPowIOpLowering,
320 FloorOpLowering,
321 FmaOpLowering,
322 Log10OpLowering,
323 Log2OpLowering,
324 LogOpLowering,
325 PowFOpLowering,
326 RoundEvenOpLowering,
327 RoundOpLowering,
328 RsqrtOpLowering,
329 SinOpLowering,
330 SqrtOpLowering,
331 FTruncOpLowering
332 >(converter);
333 // clang-format on
334}
335
336//===----------------------------------------------------------------------===//
337// ConvertToLLVMPatternInterface implementation
338//===----------------------------------------------------------------------===//
339
340namespace {
341/// Implement the interface to convert Math to LLVM.
342struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
343 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
344 void loadDependentDialects(MLIRContext *context) const final {
345 context->loadDialect<LLVM::LLVMDialect>();
346 }
347
348 /// Hook for derived dialect interface to provide conversion patterns
349 /// and mark dialect legal for the conversion target.
350 void populateConvertToLLVMConversionPatterns(
351 ConversionTarget &target, LLVMTypeConverter &typeConverter,
352 RewritePatternSet &patterns) const final {
353 populateMathToLLVMConversionPatterns(converter&: typeConverter, patterns);
354 }
355};
356} // namespace
357
358void mlir::registerConvertMathToLLVMInterface(DialectRegistry &registry) {
359 registry.addExtension(extensionFn: +[](MLIRContext *ctx, math::MathDialect *dialect) {
360 dialect->addInterfaces<MathToLLVMDialectInterface>();
361 });
362}
363

source code of mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp