1 | //===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" |
10 | |
11 | #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" |
12 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
13 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
14 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
15 | #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
16 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
17 | #include "mlir/Dialect/Math/IR/Math.h" |
18 | #include "mlir/IR/TypeUtilities.h" |
19 | #include "mlir/Pass/Pass.h" |
20 | |
21 | namespace mlir { |
22 | #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS |
23 | #include "mlir/Conversion/Passes.h.inc" |
24 | } // namespace mlir |
25 | |
26 | using namespace mlir; |
27 | |
28 | namespace { |
29 | |
30 | template <typename SourceOp, typename TargetOp> |
31 | using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>; |
32 | |
33 | template <typename SourceOp, typename TargetOp> |
34 | using ConvertFMFMathToLLVMPattern = |
35 | VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>; |
36 | |
37 | using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>; |
38 | using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>; |
39 | using CopySignOpLowering = |
40 | ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>; |
41 | using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>; |
42 | using CtPopFOpLowering = |
43 | VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>; |
44 | using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>; |
45 | using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>; |
46 | using FloorOpLowering = |
47 | ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>; |
48 | using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>; |
49 | using Log10OpLowering = |
50 | ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>; |
51 | using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>; |
52 | using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>; |
53 | using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>; |
54 | using FPowIOpLowering = |
55 | ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>; |
56 | using RoundEvenOpLowering = |
57 | ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>; |
58 | using RoundOpLowering = |
59 | ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>; |
60 | using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>; |
61 | using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>; |
62 | using FTruncOpLowering = |
63 | ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>; |
64 | |
65 | // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`. |
66 | template <typename MathOp, typename LLVMOp> |
67 | struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> { |
68 | using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern; |
69 | using Super = IntOpWithFlagLowering<MathOp, LLVMOp>; |
70 | |
71 | LogicalResult |
72 | matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, |
73 | ConversionPatternRewriter &rewriter) const override { |
74 | auto operandType = adaptor.getOperand().getType(); |
75 | |
76 | if (!operandType || !LLVM::isCompatibleType(type: operandType)) |
77 | return failure(); |
78 | |
79 | auto loc = op.getLoc(); |
80 | auto resultType = op.getResult().getType(); |
81 | |
82 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
83 | rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(), |
84 | false); |
85 | return success(); |
86 | } |
87 | |
88 | auto vectorType = dyn_cast<VectorType>(resultType); |
89 | if (!vectorType) |
90 | return failure(); |
91 | |
92 | return LLVM::detail::handleMultidimensionalVectors( |
93 | op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *this->getTypeConverter(), |
94 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) { |
95 | return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0], |
96 | false); |
97 | }, |
98 | rewriter); |
99 | } |
100 | }; |
101 | |
102 | using CountLeadingZerosOpLowering = |
103 | IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>; |
104 | using CountTrailingZerosOpLowering = |
105 | IntOpWithFlagLowering<math::CountTrailingZerosOp, |
106 | LLVM::CountTrailingZerosOp>; |
107 | using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>; |
108 | |
109 | // A `expm1` is converted into `exp - 1`. |
110 | struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { |
111 | using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; |
112 | |
113 | LogicalResult |
114 | matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, |
115 | ConversionPatternRewriter &rewriter) const override { |
116 | auto operandType = adaptor.getOperand().getType(); |
117 | |
118 | if (!operandType || !LLVM::isCompatibleType(type: operandType)) |
119 | return failure(); |
120 | |
121 | auto loc = op.getLoc(); |
122 | auto resultType = op.getResult().getType(); |
123 | auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType)); |
124 | auto floatOne = rewriter.getFloatAttr(floatType, 1.0); |
125 | ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op); |
126 | ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op); |
127 | |
128 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
129 | LLVM::ConstantOp one; |
130 | if (LLVM::isCompatibleVectorType(type: operandType)) { |
131 | one = rewriter.create<LLVM::ConstantOp>( |
132 | loc, operandType, |
133 | SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne)); |
134 | } else { |
135 | one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); |
136 | } |
137 | auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(), |
138 | expAttrs.getAttrs()); |
139 | rewriter.replaceOpWithNewOp<LLVM::FSubOp>( |
140 | op, operandType, ValueRange{exp, one}, subAttrs.getAttrs()); |
141 | return success(); |
142 | } |
143 | |
144 | auto vectorType = dyn_cast<VectorType>(resultType); |
145 | if (!vectorType) |
146 | return rewriter.notifyMatchFailure(op, "expected vector result type" ); |
147 | |
148 | return LLVM::detail::handleMultidimensionalVectors( |
149 | op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(), |
150 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) { |
151 | auto numElements = LLVM::getVectorNumElements(type: llvm1DVectorTy); |
152 | auto splatAttr = SplatElementsAttr::get( |
153 | mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, |
154 | {numElements.isScalable()}), |
155 | floatOne); |
156 | auto one = |
157 | rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); |
158 | auto exp = rewriter.create<LLVM::ExpOp>( |
159 | loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs()); |
160 | return rewriter.create<LLVM::FSubOp>( |
161 | loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs()); |
162 | }, |
163 | rewriter); |
164 | } |
165 | }; |
166 | |
167 | // A `log1p` is converted into `log(1 + ...)`. |
168 | struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { |
169 | using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern; |
170 | |
171 | LogicalResult |
172 | matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, |
173 | ConversionPatternRewriter &rewriter) const override { |
174 | auto operandType = adaptor.getOperand().getType(); |
175 | |
176 | if (!operandType || !LLVM::isCompatibleType(type: operandType)) |
177 | return rewriter.notifyMatchFailure(op, "unsupported operand type" ); |
178 | |
179 | auto loc = op.getLoc(); |
180 | auto resultType = op.getResult().getType(); |
181 | auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType)); |
182 | auto floatOne = rewriter.getFloatAttr(floatType, 1.0); |
183 | ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op); |
184 | ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op); |
185 | |
186 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
187 | LLVM::ConstantOp one = |
188 | LLVM::isCompatibleVectorType(operandType) |
189 | ? rewriter.create<LLVM::ConstantOp>( |
190 | loc, operandType, |
191 | SplatElementsAttr::get(cast<ShapedType>(resultType), |
192 | floatOne)) |
193 | : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); |
194 | |
195 | auto add = rewriter.create<LLVM::FAddOp>( |
196 | loc, operandType, ValueRange{one, adaptor.getOperand()}, |
197 | addAttrs.getAttrs()); |
198 | rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add}, |
199 | logAttrs.getAttrs()); |
200 | return success(); |
201 | } |
202 | |
203 | auto vectorType = dyn_cast<VectorType>(resultType); |
204 | if (!vectorType) |
205 | return rewriter.notifyMatchFailure(op, "expected vector result type" ); |
206 | |
207 | return LLVM::detail::handleMultidimensionalVectors( |
208 | op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(), |
209 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) { |
210 | auto numElements = LLVM::getVectorNumElements(type: llvm1DVectorTy); |
211 | auto splatAttr = SplatElementsAttr::get( |
212 | mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, |
213 | {numElements.isScalable()}), |
214 | floatOne); |
215 | auto one = |
216 | rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); |
217 | auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, |
218 | ValueRange{one, operands[0]}, |
219 | addAttrs.getAttrs()); |
220 | return rewriter.create<LLVM::LogOp>( |
221 | loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs()); |
222 | }, |
223 | rewriter); |
224 | } |
225 | }; |
226 | |
227 | // A `rsqrt` is converted into `1 / sqrt`. |
228 | struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { |
229 | using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern; |
230 | |
231 | LogicalResult |
232 | matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, |
233 | ConversionPatternRewriter &rewriter) const override { |
234 | auto operandType = adaptor.getOperand().getType(); |
235 | |
236 | if (!operandType || !LLVM::isCompatibleType(type: operandType)) |
237 | return failure(); |
238 | |
239 | auto loc = op.getLoc(); |
240 | auto resultType = op.getResult().getType(); |
241 | auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType)); |
242 | auto floatOne = rewriter.getFloatAttr(floatType, 1.0); |
243 | ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op); |
244 | ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op); |
245 | |
246 | if (!isa<LLVM::LLVMArrayType>(operandType)) { |
247 | LLVM::ConstantOp one; |
248 | if (LLVM::isCompatibleVectorType(type: operandType)) { |
249 | one = rewriter.create<LLVM::ConstantOp>( |
250 | loc, operandType, |
251 | SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne)); |
252 | } else { |
253 | one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne); |
254 | } |
255 | auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(), |
256 | sqrtAttrs.getAttrs()); |
257 | rewriter.replaceOpWithNewOp<LLVM::FDivOp>( |
258 | op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs()); |
259 | return success(); |
260 | } |
261 | |
262 | auto vectorType = dyn_cast<VectorType>(resultType); |
263 | if (!vectorType) |
264 | return failure(); |
265 | |
266 | return LLVM::detail::handleMultidimensionalVectors( |
267 | op: op.getOperation(), operands: adaptor.getOperands(), typeConverter: *getTypeConverter(), |
268 | createOperand: [&](Type llvm1DVectorTy, ValueRange operands) { |
269 | auto numElements = LLVM::getVectorNumElements(type: llvm1DVectorTy); |
270 | auto splatAttr = SplatElementsAttr::get( |
271 | mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, |
272 | {numElements.isScalable()}), |
273 | floatOne); |
274 | auto one = |
275 | rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); |
276 | auto sqrt = rewriter.create<LLVM::SqrtOp>( |
277 | loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs()); |
278 | return rewriter.create<LLVM::FDivOp>( |
279 | loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs()); |
280 | }, |
281 | rewriter); |
282 | } |
283 | }; |
284 | |
285 | struct ConvertMathToLLVMPass |
286 | : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> { |
287 | using Base::Base; |
288 | |
289 | void runOnOperation() override { |
290 | RewritePatternSet patterns(&getContext()); |
291 | LLVMTypeConverter converter(&getContext()); |
292 | populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p); |
293 | LLVMConversionTarget target(getContext()); |
294 | if (failed(applyPartialConversion(getOperation(), target, |
295 | std::move(patterns)))) |
296 | signalPassFailure(); |
297 | } |
298 | }; |
299 | } // namespace |
300 | |
301 | void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, |
302 | RewritePatternSet &patterns, |
303 | bool approximateLog1p) { |
304 | if (approximateLog1p) |
305 | patterns.add<Log1pOpLowering>(arg&: converter); |
306 | // clang-format off |
307 | patterns.add< |
308 | AbsFOpLowering, |
309 | AbsIOpLowering, |
310 | CeilOpLowering, |
311 | CopySignOpLowering, |
312 | CosOpLowering, |
313 | CountLeadingZerosOpLowering, |
314 | CountTrailingZerosOpLowering, |
315 | CtPopFOpLowering, |
316 | Exp2OpLowering, |
317 | ExpM1OpLowering, |
318 | ExpOpLowering, |
319 | FPowIOpLowering, |
320 | FloorOpLowering, |
321 | FmaOpLowering, |
322 | Log10OpLowering, |
323 | Log2OpLowering, |
324 | LogOpLowering, |
325 | PowFOpLowering, |
326 | RoundEvenOpLowering, |
327 | RoundOpLowering, |
328 | RsqrtOpLowering, |
329 | SinOpLowering, |
330 | SqrtOpLowering, |
331 | FTruncOpLowering |
332 | >(converter); |
333 | // clang-format on |
334 | } |
335 | |
336 | //===----------------------------------------------------------------------===// |
337 | // ConvertToLLVMPatternInterface implementation |
338 | //===----------------------------------------------------------------------===// |
339 | |
340 | namespace { |
341 | /// Implement the interface to convert Math to LLVM. |
342 | struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface { |
343 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
344 | void loadDependentDialects(MLIRContext *context) const final { |
345 | context->loadDialect<LLVM::LLVMDialect>(); |
346 | } |
347 | |
348 | /// Hook for derived dialect interface to provide conversion patterns |
349 | /// and mark dialect legal for the conversion target. |
350 | void populateConvertToLLVMConversionPatterns( |
351 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
352 | RewritePatternSet &patterns) const final { |
353 | populateMathToLLVMConversionPatterns(converter&: typeConverter, patterns); |
354 | } |
355 | }; |
356 | } // namespace |
357 | |
358 | void mlir::registerConvertMathToLLVMInterface(DialectRegistry ®istry) { |
359 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, math::MathDialect *dialect) { |
360 | dialect->addInterfaces<MathToLLVMDialectInterface>(); |
361 | }); |
362 | } |
363 | |