1//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
10
11#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14#include "mlir/Conversion/LLVMCommon/Pattern.h"
15#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
16#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17#include "mlir/Dialect/Math/IR/Math.h"
18#include "mlir/IR/TypeUtilities.h"
19#include "mlir/Pass/Pass.h"
20
21#include "llvm/ADT/FloatingPointMode.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30namespace {
31
32template <typename SourceOp, typename TargetOp>
33using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
34
35template <typename SourceOp, typename TargetOp>
36using ConvertFMFMathToLLVMPattern =
37 VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
38
39using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
40using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
41using CopySignOpLowering =
42 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
43using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
44using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
45using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
46using CtPopFOpLowering =
47 VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
48using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
49using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
50using FloorOpLowering =
51 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
52using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
53using Log10OpLowering =
54 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
55using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
56using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
57using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
58using FPowIOpLowering =
59 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
60using RoundEvenOpLowering =
61 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
62using RoundOpLowering =
63 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
64using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
65using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
66using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
67using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
68using FTruncOpLowering =
69 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
70using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
71using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
72using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
73using ATan2OpLowering =
74 ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
75// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
76// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
77// may be better to separate the patterns.
78template <typename MathOp, typename LLVMOp>
79struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
80 using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
81 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
82
83 LogicalResult
84 matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
85 ConversionPatternRewriter &rewriter) const override {
86 const auto &typeConverter = *this->getTypeConverter();
87 auto operandType = adaptor.getOperand().getType();
88 auto llvmOperandType = typeConverter.convertType(operandType);
89 if (!llvmOperandType)
90 return failure();
91
92 auto loc = op.getLoc();
93 auto resultType = op.getResult().getType();
94 auto llvmResultType = typeConverter.convertType(resultType);
95 if (!llvmResultType)
96 return failure();
97
98 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
99 rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
100 adaptor.getOperand(), false);
101 return success();
102 }
103
104 if (!isa<VectorType>(llvmResultType))
105 return failure();
106
107 return LLVM::detail::handleMultidimensionalVectors(
108 op: op.getOperation(), operands: adaptor.getOperands(), typeConverter,
109 createOperand: [&](Type llvm1DVectorTy, ValueRange operands) {
110 return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
111 false);
112 },
113 rewriter);
114 }
115};
116
117using CountLeadingZerosOpLowering =
118 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
119using CountTrailingZerosOpLowering =
120 IntOpWithFlagLowering<math::CountTrailingZerosOp,
121 LLVM::CountTrailingZerosOp>;
122using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
123
124// A `expm1` is converted into `exp - 1`.
125struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
126 using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
127
128 LogicalResult
129 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
130 ConversionPatternRewriter &rewriter) const override {
131 const auto &typeConverter = *this->getTypeConverter();
132 auto operandType = adaptor.getOperand().getType();
133 auto llvmOperandType = typeConverter.convertType(operandType);
134 if (!llvmOperandType)
135 return failure();
136
137 auto loc = op.getLoc();
138 auto resultType = op.getResult().getType();
139 auto floatType = cast<FloatType>(
140 typeConverter.convertType(getElementTypeOrSelf(resultType)));
141 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
142 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
143 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
144
145 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
146 LLVM::ConstantOp one;
147 if (LLVM::isCompatibleVectorType(type: llvmOperandType)) {
148 one = rewriter.create<LLVM::ConstantOp>(
149 loc, llvmOperandType,
150 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
151 floatOne));
152 } else {
153 one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
154 }
155 auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
156 expAttrs.getAttrs());
157 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
158 op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
159 return success();
160 }
161
162 if (!isa<VectorType>(resultType))
163 return rewriter.notifyMatchFailure(op, "expected vector result type");
164
165 return LLVM::detail::handleMultidimensionalVectors(
166 op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: typeConverter,
167 createOperand: [&](Type llvm1DVectorTy, ValueRange operands) {
168 auto numElements = LLVM::getVectorNumElements(type: llvm1DVectorTy);
169 auto splatAttr = SplatElementsAttr::get(
170 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
171 {numElements.isScalable()}),
172 floatOne);
173 auto one =
174 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
175 auto exp = rewriter.create<LLVM::ExpOp>(
176 loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
177 return rewriter.create<LLVM::FSubOp>(
178 loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
179 },
180 rewriter);
181 }
182};
183
184// A `log1p` is converted into `log(1 + ...)`.
185struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
186 using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
187
188 LogicalResult
189 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
190 ConversionPatternRewriter &rewriter) const override {
191 const auto &typeConverter = *this->getTypeConverter();
192 auto operandType = adaptor.getOperand().getType();
193 auto llvmOperandType = typeConverter.convertType(operandType);
194 if (!llvmOperandType)
195 return rewriter.notifyMatchFailure(op, "unsupported operand type");
196
197 auto loc = op.getLoc();
198 auto resultType = op.getResult().getType();
199 auto floatType = cast<FloatType>(
200 typeConverter.convertType(getElementTypeOrSelf(resultType)));
201 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
202 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
203 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
204
205 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
206 LLVM::ConstantOp one =
207 isa<VectorType>(llvmOperandType)
208 ? rewriter.create<LLVM::ConstantOp>(
209 loc, llvmOperandType,
210 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
211 floatOne))
212 : rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType,
213 floatOne);
214
215 auto add = rewriter.create<LLVM::FAddOp>(
216 loc, llvmOperandType, ValueRange{one, adaptor.getOperand()},
217 addAttrs.getAttrs());
218 rewriter.replaceOpWithNewOp<LLVM::LogOp>(
219 op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
220 return success();
221 }
222
223 if (!isa<VectorType>(resultType))
224 return rewriter.notifyMatchFailure(op, "expected vector result type");
225
226 return LLVM::detail::handleMultidimensionalVectors(
227 op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: typeConverter,
228 createOperand: [&](Type llvm1DVectorTy, ValueRange operands) {
229 auto numElements = LLVM::getVectorNumElements(type: llvm1DVectorTy);
230 auto splatAttr = SplatElementsAttr::get(
231 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
232 {numElements.isScalable()}),
233 floatOne);
234 auto one =
235 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
236 auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
237 ValueRange{one, operands[0]},
238 addAttrs.getAttrs());
239 return rewriter.create<LLVM::LogOp>(
240 loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
241 },
242 rewriter);
243 }
244};
245
246// A `rsqrt` is converted into `1 / sqrt`.
247struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
248 using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
249
250 LogicalResult
251 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
252 ConversionPatternRewriter &rewriter) const override {
253 const auto &typeConverter = *this->getTypeConverter();
254 auto operandType = adaptor.getOperand().getType();
255 auto llvmOperandType = typeConverter.convertType(operandType);
256 if (!llvmOperandType)
257 return failure();
258
259 auto loc = op.getLoc();
260 auto resultType = op.getResult().getType();
261 auto floatType = cast<FloatType>(
262 typeConverter.convertType(getElementTypeOrSelf(resultType)));
263 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
264 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
265 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
266
267 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
268 LLVM::ConstantOp one;
269 if (isa<VectorType>(llvmOperandType)) {
270 one = rewriter.create<LLVM::ConstantOp>(
271 loc, llvmOperandType,
272 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
273 floatOne));
274 } else {
275 one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
276 }
277 auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
278 sqrtAttrs.getAttrs());
279 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
280 op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
281 return success();
282 }
283
284 if (!isa<VectorType>(resultType))
285 return failure();
286
287 return LLVM::detail::handleMultidimensionalVectors(
288 op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: typeConverter,
289 createOperand: [&](Type llvm1DVectorTy, ValueRange operands) {
290 auto numElements = LLVM::getVectorNumElements(type: llvm1DVectorTy);
291 auto splatAttr = SplatElementsAttr::get(
292 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
293 {numElements.isScalable()}),
294 floatOne);
295 auto one =
296 rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
297 auto sqrt = rewriter.create<LLVM::SqrtOp>(
298 loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
299 return rewriter.create<LLVM::FDivOp>(
300 loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
301 },
302 rewriter);
303 }
304};
305
306struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
307 using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
308
309 LogicalResult
310 matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
311 ConversionPatternRewriter &rewriter) const override {
312 const auto &typeConverter = *this->getTypeConverter();
313 auto operandType =
314 typeConverter.convertType(adaptor.getOperand().getType());
315 auto resultType = typeConverter.convertType(op.getResult().getType());
316 if (!operandType || !resultType)
317 return failure();
318
319 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
320 op, resultType, adaptor.getOperand(), llvm::fcNan);
321 return success();
322 }
323};
324
325struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
326 using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
327
328 LogicalResult
329 matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
330 ConversionPatternRewriter &rewriter) const override {
331 const auto &typeConverter = *this->getTypeConverter();
332 auto operandType =
333 typeConverter.convertType(adaptor.getOperand().getType());
334 auto resultType = typeConverter.convertType(op.getResult().getType());
335 if (!operandType || !resultType)
336 return failure();
337
338 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
339 op, resultType, adaptor.getOperand(), llvm::fcFinite);
340 return success();
341 }
342};
343
344struct ConvertMathToLLVMPass
345 : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
346 using Base::Base;
347
348 void runOnOperation() override {
349 RewritePatternSet patterns(&getContext());
350 LLVMTypeConverter converter(&getContext());
351 populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
352 LLVMConversionTarget target(getContext());
353 if (failed(applyPartialConversion(getOperation(), target,
354 std::move(patterns))))
355 signalPassFailure();
356 }
357};
358} // namespace
359
360void mlir::populateMathToLLVMConversionPatterns(
361 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
362 bool approximateLog1p, PatternBenefit benefit) {
363 if (approximateLog1p)
364 patterns.add<Log1pOpLowering>(arg: converter, args&: benefit);
365 // clang-format off
366 patterns.add<
367 IsNaNOpLowering,
368 IsFiniteOpLowering,
369 AbsFOpLowering,
370 AbsIOpLowering,
371 CeilOpLowering,
372 CopySignOpLowering,
373 CosOpLowering,
374 CoshOpLowering,
375 AcosOpLowering,
376 CountLeadingZerosOpLowering,
377 CountTrailingZerosOpLowering,
378 CtPopFOpLowering,
379 Exp2OpLowering,
380 ExpM1OpLowering,
381 ExpOpLowering,
382 FPowIOpLowering,
383 FloorOpLowering,
384 FmaOpLowering,
385 Log10OpLowering,
386 Log2OpLowering,
387 LogOpLowering,
388 PowFOpLowering,
389 RoundEvenOpLowering,
390 RoundOpLowering,
391 RsqrtOpLowering,
392 SinOpLowering,
393 SinhOpLowering,
394 ASinOpLowering,
395 SqrtOpLowering,
396 FTruncOpLowering,
397 TanOpLowering,
398 TanhOpLowering,
399 ATanOpLowering,
400 ATan2OpLowering
401 >(converter, benefit);
402 // clang-format on
403}
404
405//===----------------------------------------------------------------------===//
406// ConvertToLLVMPatternInterface implementation
407//===----------------------------------------------------------------------===//
408
409namespace {
410/// Implement the interface to convert Math to LLVM.
411struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
412 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
413 void loadDependentDialects(MLIRContext *context) const final {
414 context->loadDialect<LLVM::LLVMDialect>();
415 }
416
417 /// Hook for derived dialect interface to provide conversion patterns
418 /// and mark dialect legal for the conversion target.
419 void populateConvertToLLVMConversionPatterns(
420 ConversionTarget &target, LLVMTypeConverter &typeConverter,
421 RewritePatternSet &patterns) const final {
422 populateMathToLLVMConversionPatterns(converter: typeConverter, patterns);
423 }
424};
425} // namespace
426
427void mlir::registerConvertMathToLLVMInterface(DialectRegistry &registry) {
428 registry.addExtension(extensionFn: +[](MLIRContext *ctx, math::MathDialect *dialect) {
429 dialect->addInterfaces<MathToLLVMDialectInterface>();
430 });
431}
432

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp