1 | //===-- HlfirIntrinsics.cpp -----------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Lower/HlfirIntrinsics.h" |
14 | |
15 | #include "flang/Optimizer/Builder/BoxValue.h" |
16 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
17 | #include "flang/Optimizer/Builder/HLFIRTools.h" |
18 | #include "flang/Optimizer/Builder/IntrinsicCall.h" |
19 | #include "flang/Optimizer/Builder/MutableBox.h" |
20 | #include "flang/Optimizer/Builder/Todo.h" |
21 | #include "flang/Optimizer/Dialect/FIRType.h" |
22 | #include "flang/Optimizer/HLFIR/HLFIRDialect.h" |
23 | #include "flang/Optimizer/HLFIR/HLFIROps.h" |
24 | #include "mlir/IR/Value.h" |
25 | #include "llvm/ADT/SmallVector.h" |
26 | #include <mlir/IR/ValueRange.h> |
27 | |
28 | namespace { |
29 | |
30 | class HlfirTransformationalIntrinsic { |
31 | public: |
32 | explicit HlfirTransformationalIntrinsic(fir::FirOpBuilder &builder, |
33 | mlir::Location loc) |
34 | : builder(builder), loc(loc) {} |
35 | |
36 | virtual ~HlfirTransformationalIntrinsic() = default; |
37 | |
38 | hlfir::EntityWithAttributes |
39 | lower(const Fortran::lower::PreparedActualArguments &loweredActuals, |
40 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
41 | mlir::Type stmtResultType) { |
42 | mlir::Value res = lowerImpl(loweredActuals, argLowering, stmtResultType); |
43 | for (const hlfir::CleanupFunction &fn : cleanupFns) |
44 | fn(); |
45 | return {hlfir::EntityWithAttributes{res}}; |
46 | } |
47 | |
48 | protected: |
49 | fir::FirOpBuilder &builder; |
50 | mlir::Location loc; |
51 | llvm::SmallVector<hlfir::CleanupFunction, 3> cleanupFns; |
52 | |
53 | virtual mlir::Value |
54 | lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, |
55 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
56 | mlir::Type stmtResultType) = 0; |
57 | |
58 | llvm::SmallVector<mlir::Value> getOperandVector( |
59 | const Fortran::lower::PreparedActualArguments &loweredActuals, |
60 | const fir::IntrinsicArgumentLoweringRules *argLowering); |
61 | |
62 | mlir::Type computeResultType(mlir::Value argArray, mlir::Type stmtResultType); |
63 | |
64 | template <typename OP, typename... BUILD_ARGS> |
65 | inline OP createOp(BUILD_ARGS... args) { |
66 | return builder.create<OP>(loc, args...); |
67 | } |
68 | |
69 | mlir::Value loadBoxAddress( |
70 | const std::optional<Fortran::lower::PreparedActualArgument> &arg); |
71 | |
72 | void addCleanup(std::optional<hlfir::CleanupFunction> cleanup) { |
73 | if (cleanup) |
74 | cleanupFns.emplace_back(std::move(*cleanup)); |
75 | } |
76 | }; |
77 | |
78 | template <typename OP, bool HAS_MASK> |
79 | class HlfirReductionIntrinsic : public HlfirTransformationalIntrinsic { |
80 | public: |
81 | using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; |
82 | |
83 | protected: |
84 | mlir::Value |
85 | lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, |
86 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
87 | mlir::Type stmtResultType) override; |
88 | }; |
89 | using HlfirSumLowering = HlfirReductionIntrinsic<hlfir::SumOp, true>; |
90 | using HlfirProductLowering = HlfirReductionIntrinsic<hlfir::ProductOp, true>; |
91 | using HlfirMaxvalLowering = HlfirReductionIntrinsic<hlfir::MaxvalOp, true>; |
92 | using HlfirMinvalLowering = HlfirReductionIntrinsic<hlfir::MinvalOp, true>; |
93 | using HlfirAnyLowering = HlfirReductionIntrinsic<hlfir::AnyOp, false>; |
94 | using HlfirAllLowering = HlfirReductionIntrinsic<hlfir::AllOp, false>; |
95 | |
96 | template <typename OP> |
97 | class HlfirMinMaxLocIntrinsic : public HlfirTransformationalIntrinsic { |
98 | public: |
99 | using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; |
100 | |
101 | protected: |
102 | mlir::Value |
103 | lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, |
104 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
105 | mlir::Type stmtResultType) override; |
106 | }; |
107 | using HlfirMinlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MinlocOp>; |
108 | using HlfirMaxlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MaxlocOp>; |
109 | |
110 | template <typename OP> |
111 | class HlfirProductIntrinsic : public HlfirTransformationalIntrinsic { |
112 | public: |
113 | using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; |
114 | |
115 | protected: |
116 | mlir::Value |
117 | lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, |
118 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
119 | mlir::Type stmtResultType) override; |
120 | }; |
121 | using HlfirMatmulLowering = HlfirProductIntrinsic<hlfir::MatmulOp>; |
122 | using HlfirDotProductLowering = HlfirProductIntrinsic<hlfir::DotProductOp>; |
123 | |
124 | class HlfirTransposeLowering : public HlfirTransformationalIntrinsic { |
125 | public: |
126 | using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; |
127 | |
128 | protected: |
129 | mlir::Value |
130 | lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, |
131 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
132 | mlir::Type stmtResultType) override; |
133 | }; |
134 | |
135 | class HlfirCountLowering : public HlfirTransformationalIntrinsic { |
136 | public: |
137 | using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; |
138 | |
139 | protected: |
140 | mlir::Value |
141 | lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, |
142 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
143 | mlir::Type stmtResultType) override; |
144 | }; |
145 | |
146 | class HlfirCharExtremumLowering : public HlfirTransformationalIntrinsic { |
147 | public: |
148 | HlfirCharExtremumLowering(fir::FirOpBuilder &builder, mlir::Location loc, |
149 | hlfir::CharExtremumPredicate pred) |
150 | : HlfirTransformationalIntrinsic(builder, loc), pred{pred} {} |
151 | |
152 | protected: |
153 | mlir::Value |
154 | lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, |
155 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
156 | mlir::Type stmtResultType) override; |
157 | |
158 | protected: |
159 | hlfir::CharExtremumPredicate pred; |
160 | }; |
161 | |
162 | } // namespace |
163 | |
164 | mlir::Value HlfirTransformationalIntrinsic::loadBoxAddress( |
165 | const std::optional<Fortran::lower::PreparedActualArgument> &arg) { |
166 | if (!arg) |
167 | return mlir::Value{}; |
168 | |
169 | hlfir::Entity actual = arg->getActual(loc, builder); |
170 | |
171 | if (!arg->handleDynamicOptional()) { |
172 | if (actual.isMutableBox()) { |
173 | // this is a box address type but is not dynamically optional. Just load |
174 | // the box, assuming it is well formed (!fir.ref<!fir.box<...>> -> |
175 | // !fir.box<...>) |
176 | return builder.create<fir::LoadOp>(loc, actual.getBase()); |
177 | } |
178 | return actual; |
179 | } |
180 | |
181 | auto [exv, cleanup] = hlfir::translateToExtendedValue(loc, builder, actual); |
182 | addCleanup(cleanup); |
183 | |
184 | mlir::Value isPresent = arg->getIsPresent(); |
185 | // createBox will not do create any invalid memory dereferences if exv is |
186 | // absent. The created fir.box will not be usable, but the SelectOp below |
187 | // ensures it won't be. |
188 | mlir::Value box = builder.createBox(loc, exv); |
189 | mlir::Type boxType = box.getType(); |
190 | auto absent = builder.create<fir::AbsentOp>(loc, boxType); |
191 | auto boxOrAbsent = builder.create<mlir::arith::SelectOp>( |
192 | loc, boxType, isPresent, box, absent); |
193 | |
194 | return boxOrAbsent; |
195 | } |
196 | |
197 | static mlir::Value loadOptionalValue( |
198 | mlir::Location loc, fir::FirOpBuilder &builder, |
199 | const std::optional<Fortran::lower::PreparedActualArgument> &arg, |
200 | hlfir::Entity actual) { |
201 | if (!arg->handleDynamicOptional()) |
202 | return hlfir::loadTrivialScalar(loc, builder, actual); |
203 | |
204 | mlir::Value isPresent = arg->getIsPresent(); |
205 | mlir::Type eleType = hlfir::getFortranElementType(actual.getType()); |
206 | return builder |
207 | .genIfOp(loc, {eleType}, isPresent, |
208 | /*withElseRegion=*/true) |
209 | .genThen([&]() { |
210 | assert(actual.isScalar() && fir::isa_trivial(eleType) && |
211 | "must be a numerical or logical scalar" ); |
212 | hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, actual); |
213 | builder.create<fir::ResultOp>(loc, val); |
214 | }) |
215 | .genElse([&]() { |
216 | mlir::Value zero = fir::factory::createZeroValue(builder, loc, eleType); |
217 | builder.create<fir::ResultOp>(loc, zero); |
218 | }) |
219 | .getResults()[0]; |
220 | } |
221 | |
222 | llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector( |
223 | const Fortran::lower::PreparedActualArguments &loweredActuals, |
224 | const fir::IntrinsicArgumentLoweringRules *argLowering) { |
225 | llvm::SmallVector<mlir::Value> operands; |
226 | operands.reserve(loweredActuals.size()); |
227 | |
228 | for (size_t i = 0; i < loweredActuals.size(); ++i) { |
229 | std::optional<Fortran::lower::PreparedActualArgument> arg = |
230 | loweredActuals[i]; |
231 | if (!arg) { |
232 | operands.emplace_back(); |
233 | continue; |
234 | } |
235 | hlfir::Entity actual = arg->getActual(loc, builder); |
236 | mlir::Value valArg; |
237 | |
238 | if (!argLowering) { |
239 | valArg = hlfir::loadTrivialScalar(loc, builder, actual); |
240 | } else { |
241 | fir::ArgLoweringRule argRules = |
242 | fir::lowerIntrinsicArgumentAs(*argLowering, i); |
243 | if (argRules.lowerAs == fir::LowerIntrinsicArgAs::Box) |
244 | valArg = loadBoxAddress(arg); |
245 | else if (!argRules.handleDynamicOptional && |
246 | argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired) |
247 | valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual); |
248 | else if (argRules.handleDynamicOptional && |
249 | argRules.lowerAs == fir::LowerIntrinsicArgAs::Value) |
250 | valArg = loadOptionalValue(loc, builder, arg, actual); |
251 | else if (argRules.handleDynamicOptional) |
252 | TODO(loc, "hlfir transformational intrinsic dynamically optional " |
253 | "argument without box lowering" ); |
254 | else |
255 | valArg = actual.getBase(); |
256 | } |
257 | |
258 | operands.emplace_back(valArg); |
259 | } |
260 | return operands; |
261 | } |
262 | |
263 | mlir::Type |
264 | HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray, |
265 | mlir::Type stmtResultType) { |
266 | mlir::Type normalisedResult = |
267 | hlfir::getFortranElementOrSequenceType(stmtResultType); |
268 | if (auto array = normalisedResult.dyn_cast<fir::SequenceType>()) { |
269 | hlfir::ExprType::Shape resultShape = |
270 | hlfir::ExprType::Shape{array.getShape()}; |
271 | mlir::Type elementType = array.getEleTy(); |
272 | return hlfir::ExprType::get(builder.getContext(), resultShape, elementType, |
273 | /*polymorphic=*/false); |
274 | } else if (auto resCharType = |
275 | mlir::dyn_cast<fir::CharacterType>(stmtResultType)) { |
276 | normalisedResult = hlfir::ExprType::get( |
277 | builder.getContext(), hlfir::ExprType::Shape{}, resCharType, false); |
278 | } |
279 | return normalisedResult; |
280 | } |
281 | |
282 | template <typename OP, bool HAS_MASK> |
283 | mlir::Value HlfirReductionIntrinsic<OP, HAS_MASK>::lowerImpl( |
284 | const Fortran::lower::PreparedActualArguments &loweredActuals, |
285 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
286 | mlir::Type stmtResultType) { |
287 | auto operands = getOperandVector(loweredActuals, argLowering); |
288 | mlir::Value array = operands[0]; |
289 | mlir::Value dim = operands[1]; |
290 | // dim, mask can be NULL if these arguments are not given |
291 | if (dim) |
292 | dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); |
293 | |
294 | mlir::Type resultTy = computeResultType(array, stmtResultType); |
295 | |
296 | OP op; |
297 | if constexpr (HAS_MASK) |
298 | op = createOp<OP>(resultTy, array, dim, |
299 | /*mask=*/operands[2]); |
300 | else |
301 | op = createOp<OP>(resultTy, array, dim); |
302 | return op; |
303 | } |
304 | |
305 | template <typename OP> |
306 | mlir::Value HlfirMinMaxLocIntrinsic<OP>::lowerImpl( |
307 | const Fortran::lower::PreparedActualArguments &loweredActuals, |
308 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
309 | mlir::Type stmtResultType) { |
310 | auto operands = getOperandVector(loweredActuals, argLowering); |
311 | mlir::Value array = operands[0]; |
312 | mlir::Value dim = operands[1]; |
313 | mlir::Value mask = operands[2]; |
314 | mlir::Value back = operands[4]; |
315 | // dim, mask and back can be NULL if these arguments are not given. |
316 | if (dim) |
317 | dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); |
318 | if (back) |
319 | back = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{back}); |
320 | |
321 | mlir::Type resultTy = computeResultType(array, stmtResultType); |
322 | |
323 | return createOp<OP>(resultTy, array, dim, mask, back); |
324 | } |
325 | |
326 | template <typename OP> |
327 | mlir::Value HlfirProductIntrinsic<OP>::lowerImpl( |
328 | const Fortran::lower::PreparedActualArguments &loweredActuals, |
329 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
330 | mlir::Type stmtResultType) { |
331 | auto operands = getOperandVector(loweredActuals, argLowering); |
332 | mlir::Type resultType = computeResultType(operands[0], stmtResultType); |
333 | return createOp<OP>(resultType, operands[0], operands[1]); |
334 | } |
335 | |
336 | mlir::Value HlfirTransposeLowering::lowerImpl( |
337 | const Fortran::lower::PreparedActualArguments &loweredActuals, |
338 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
339 | mlir::Type stmtResultType) { |
340 | auto operands = getOperandVector(loweredActuals, argLowering); |
341 | hlfir::ExprType::Shape resultShape; |
342 | mlir::Type normalisedResult = |
343 | hlfir::getFortranElementOrSequenceType(stmtResultType); |
344 | auto array = normalisedResult.cast<fir::SequenceType>(); |
345 | llvm::ArrayRef<int64_t> arrayShape = array.getShape(); |
346 | assert(arrayShape.size() == 2 && "arguments to transpose have a rank of 2" ); |
347 | mlir::Type elementType = array.getEleTy(); |
348 | resultShape.push_back(arrayShape[0]); |
349 | resultShape.push_back(arrayShape[1]); |
350 | if (auto resCharType = mlir::dyn_cast<fir::CharacterType>(elementType)) |
351 | if (!resCharType.hasConstantLen()) { |
352 | // The FunctionRef expression might have imprecise character |
353 | // type at this point, and we can improve it by propagating |
354 | // the constant length from the argument. |
355 | auto argCharType = mlir::dyn_cast<fir::CharacterType>( |
356 | hlfir::getFortranElementType(operands[0].getType())); |
357 | if (argCharType && argCharType.hasConstantLen()) |
358 | elementType = fir::CharacterType::get( |
359 | builder.getContext(), resCharType.getFKind(), argCharType.getLen()); |
360 | } |
361 | |
362 | mlir::Type resultTy = |
363 | hlfir::ExprType::get(builder.getContext(), resultShape, elementType, |
364 | fir::isPolymorphicType(stmtResultType)); |
365 | return createOp<hlfir::TransposeOp>(resultTy, operands[0]); |
366 | } |
367 | |
368 | mlir::Value HlfirCountLowering::lowerImpl( |
369 | const Fortran::lower::PreparedActualArguments &loweredActuals, |
370 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
371 | mlir::Type stmtResultType) { |
372 | auto operands = getOperandVector(loweredActuals, argLowering); |
373 | mlir::Value array = operands[0]; |
374 | mlir::Value dim = operands[1]; |
375 | if (dim) |
376 | dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); |
377 | mlir::Type resultType = computeResultType(array, stmtResultType); |
378 | return createOp<hlfir::CountOp>(resultType, array, dim); |
379 | } |
380 | |
381 | mlir::Value HlfirCharExtremumLowering::lowerImpl( |
382 | const Fortran::lower::PreparedActualArguments &loweredActuals, |
383 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
384 | mlir::Type stmtResultType) { |
385 | auto operands = getOperandVector(loweredActuals, argLowering); |
386 | assert(operands.size() >= 2); |
387 | return createOp<hlfir::CharExtremumOp>(pred, mlir::ValueRange{operands}); |
388 | } |
389 | |
390 | std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic( |
391 | fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name, |
392 | const Fortran::lower::PreparedActualArguments &loweredActuals, |
393 | const fir::IntrinsicArgumentLoweringRules *argLowering, |
394 | mlir::Type stmtResultType) { |
395 | // If the result is of a derived type that may need finalization, |
396 | // we have to use DestroyOp with 'finalize' attribute for the result |
397 | // of the intrinsic operation. |
398 | if (name == "sum" ) |
399 | return HlfirSumLowering{builder, loc}.lower(loweredActuals, argLowering, |
400 | stmtResultType); |
401 | if (name == "product" ) |
402 | return HlfirProductLowering{builder, loc}.lower(loweredActuals, argLowering, |
403 | stmtResultType); |
404 | if (name == "any" ) |
405 | return HlfirAnyLowering{builder, loc}.lower(loweredActuals, argLowering, |
406 | stmtResultType); |
407 | if (name == "all" ) |
408 | return HlfirAllLowering{builder, loc}.lower(loweredActuals, argLowering, |
409 | stmtResultType); |
410 | if (name == "matmul" ) |
411 | return HlfirMatmulLowering{builder, loc}.lower(loweredActuals, argLowering, |
412 | stmtResultType); |
413 | if (name == "dot_product" ) |
414 | return HlfirDotProductLowering{builder, loc}.lower( |
415 | loweredActuals, argLowering, stmtResultType); |
416 | // FIXME: the result may need finalization. |
417 | if (name == "transpose" ) |
418 | return HlfirTransposeLowering{builder, loc}.lower( |
419 | loweredActuals, argLowering, stmtResultType); |
420 | if (name == "count" ) |
421 | return HlfirCountLowering{builder, loc}.lower(loweredActuals, argLowering, |
422 | stmtResultType); |
423 | if (name == "maxval" ) |
424 | return HlfirMaxvalLowering{builder, loc}.lower(loweredActuals, argLowering, |
425 | stmtResultType); |
426 | if (name == "minval" ) |
427 | return HlfirMinvalLowering{builder, loc}.lower(loweredActuals, argLowering, |
428 | stmtResultType); |
429 | if (name == "minloc" ) |
430 | return HlfirMinlocLowering{builder, loc}.lower(loweredActuals, argLowering, |
431 | stmtResultType); |
432 | if (name == "maxloc" ) |
433 | return HlfirMaxlocLowering{builder, loc}.lower(loweredActuals, argLowering, |
434 | stmtResultType); |
435 | if (mlir::isa<fir::CharacterType>(stmtResultType)) { |
436 | if (name == "min" ) |
437 | return HlfirCharExtremumLowering{builder, loc, |
438 | hlfir::CharExtremumPredicate::min} |
439 | .lower(loweredActuals, argLowering, stmtResultType); |
440 | if (name == "max" ) |
441 | return HlfirCharExtremumLowering{builder, loc, |
442 | hlfir::CharExtremumPredicate::max} |
443 | .lower(loweredActuals, argLowering, stmtResultType); |
444 | } |
445 | return std::nullopt; |
446 | } |
447 | |