1//===-- HlfirIntrinsics.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
#include "flang/Lower/HlfirIntrinsics.h"

#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/MutableBox.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include <mlir/IR/ValueRange.h>
#include <optional>
#include <utility>
27
28namespace {
29
30class HlfirTransformationalIntrinsic {
31public:
32 explicit HlfirTransformationalIntrinsic(fir::FirOpBuilder &builder,
33 mlir::Location loc)
34 : builder(builder), loc(loc) {}
35
36 virtual ~HlfirTransformationalIntrinsic() = default;
37
38 hlfir::EntityWithAttributes
39 lower(const Fortran::lower::PreparedActualArguments &loweredActuals,
40 const fir::IntrinsicArgumentLoweringRules *argLowering,
41 mlir::Type stmtResultType) {
42 mlir::Value res = lowerImpl(loweredActuals, argLowering, stmtResultType);
43 for (const hlfir::CleanupFunction &fn : cleanupFns)
44 fn();
45 return {hlfir::EntityWithAttributes{res}};
46 }
47
48protected:
49 fir::FirOpBuilder &builder;
50 mlir::Location loc;
51 llvm::SmallVector<hlfir::CleanupFunction, 3> cleanupFns;
52
53 virtual mlir::Value
54 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
55 const fir::IntrinsicArgumentLoweringRules *argLowering,
56 mlir::Type stmtResultType) = 0;
57
58 llvm::SmallVector<mlir::Value> getOperandVector(
59 const Fortran::lower::PreparedActualArguments &loweredActuals,
60 const fir::IntrinsicArgumentLoweringRules *argLowering);
61
62 mlir::Type computeResultType(mlir::Value argArray, mlir::Type stmtResultType);
63
64 template <typename OP, typename... BUILD_ARGS>
65 inline OP createOp(BUILD_ARGS... args) {
66 return builder.create<OP>(loc, args...);
67 }
68
69 mlir::Value loadBoxAddress(
70 const std::optional<Fortran::lower::PreparedActualArgument> &arg);
71
72 void addCleanup(std::optional<hlfir::CleanupFunction> cleanup) {
73 if (cleanup)
74 cleanupFns.emplace_back(std::move(*cleanup));
75 }
76};
77
78template <typename OP, bool HAS_MASK>
79class HlfirReductionIntrinsic : public HlfirTransformationalIntrinsic {
80public:
81 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
82
83protected:
84 mlir::Value
85 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
86 const fir::IntrinsicArgumentLoweringRules *argLowering,
87 mlir::Type stmtResultType) override;
88};
89using HlfirSumLowering = HlfirReductionIntrinsic<hlfir::SumOp, true>;
90using HlfirProductLowering = HlfirReductionIntrinsic<hlfir::ProductOp, true>;
91using HlfirMaxvalLowering = HlfirReductionIntrinsic<hlfir::MaxvalOp, true>;
92using HlfirMinvalLowering = HlfirReductionIntrinsic<hlfir::MinvalOp, true>;
93using HlfirAnyLowering = HlfirReductionIntrinsic<hlfir::AnyOp, false>;
94using HlfirAllLowering = HlfirReductionIntrinsic<hlfir::AllOp, false>;
95
96template <typename OP>
97class HlfirMinMaxLocIntrinsic : public HlfirTransformationalIntrinsic {
98public:
99 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
100
101protected:
102 mlir::Value
103 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
104 const fir::IntrinsicArgumentLoweringRules *argLowering,
105 mlir::Type stmtResultType) override;
106};
107using HlfirMinlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MinlocOp>;
108using HlfirMaxlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MaxlocOp>;
109
110template <typename OP>
111class HlfirProductIntrinsic : public HlfirTransformationalIntrinsic {
112public:
113 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
114
115protected:
116 mlir::Value
117 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
118 const fir::IntrinsicArgumentLoweringRules *argLowering,
119 mlir::Type stmtResultType) override;
120};
121using HlfirMatmulLowering = HlfirProductIntrinsic<hlfir::MatmulOp>;
122using HlfirDotProductLowering = HlfirProductIntrinsic<hlfir::DotProductOp>;
123
124class HlfirTransposeLowering : public HlfirTransformationalIntrinsic {
125public:
126 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
127
128protected:
129 mlir::Value
130 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
131 const fir::IntrinsicArgumentLoweringRules *argLowering,
132 mlir::Type stmtResultType) override;
133};
134
135class HlfirCountLowering : public HlfirTransformationalIntrinsic {
136public:
137 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
138
139protected:
140 mlir::Value
141 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
142 const fir::IntrinsicArgumentLoweringRules *argLowering,
143 mlir::Type stmtResultType) override;
144};
145
146class HlfirCharExtremumLowering : public HlfirTransformationalIntrinsic {
147public:
148 HlfirCharExtremumLowering(fir::FirOpBuilder &builder, mlir::Location loc,
149 hlfir::CharExtremumPredicate pred)
150 : HlfirTransformationalIntrinsic(builder, loc), pred{pred} {}
151
152protected:
153 mlir::Value
154 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
155 const fir::IntrinsicArgumentLoweringRules *argLowering,
156 mlir::Type stmtResultType) override;
157
158protected:
159 hlfir::CharExtremumPredicate pred;
160};
161
162class HlfirCShiftLowering : public HlfirTransformationalIntrinsic {
163public:
164 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
165
166protected:
167 mlir::Value
168 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
169 const fir::IntrinsicArgumentLoweringRules *argLowering,
170 mlir::Type stmtResultType) override;
171};
172
173class HlfirReshapeLowering : public HlfirTransformationalIntrinsic {
174public:
175 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
176
177protected:
178 mlir::Value
179 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
180 const fir::IntrinsicArgumentLoweringRules *argLowering,
181 mlir::Type stmtResultType) override;
182};
183
184} // namespace
185
186mlir::Value HlfirTransformationalIntrinsic::loadBoxAddress(
187 const std::optional<Fortran::lower::PreparedActualArgument> &arg) {
188 if (!arg)
189 return mlir::Value{};
190
191 hlfir::Entity actual = arg->getActual(loc, builder);
192
193 if (!arg->handleDynamicOptional()) {
194 if (actual.isMutableBox()) {
195 // this is a box address type but is not dynamically optional. Just load
196 // the box, assuming it is well formed (!fir.ref<!fir.box<...>> ->
197 // !fir.box<...>)
198 return builder.create<fir::LoadOp>(loc, actual.getBase());
199 }
200 return actual;
201 }
202
203 auto [exv, cleanup] = hlfir::translateToExtendedValue(loc, builder, actual);
204 addCleanup(cleanup);
205
206 mlir::Value isPresent = arg->getIsPresent();
207 // createBox will not do create any invalid memory dereferences if exv is
208 // absent. The created fir.box will not be usable, but the SelectOp below
209 // ensures it won't be.
210 mlir::Value box = builder.createBox(loc, exv);
211 mlir::Type boxType = box.getType();
212 auto absent = builder.create<fir::AbsentOp>(loc, boxType);
213 auto boxOrAbsent = builder.create<mlir::arith::SelectOp>(
214 loc, boxType, isPresent, box, absent);
215
216 return boxOrAbsent;
217}
218
219static mlir::Value loadOptionalValue(
220 mlir::Location loc, fir::FirOpBuilder &builder,
221 const std::optional<Fortran::lower::PreparedActualArgument> &arg,
222 hlfir::Entity actual) {
223 if (!arg->handleDynamicOptional())
224 return hlfir::loadTrivialScalar(loc, builder, actual);
225
226 mlir::Value isPresent = arg->getIsPresent();
227 mlir::Type eleType = hlfir::getFortranElementType(actual.getType());
228 return builder
229 .genIfOp(loc, {eleType}, isPresent,
230 /*withElseRegion=*/true)
231 .genThen([&]() {
232 assert(actual.isScalar() && fir::isa_trivial(eleType) &&
233 "must be a numerical or logical scalar");
234 hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, actual);
235 builder.create<fir::ResultOp>(loc, val);
236 })
237 .genElse([&]() {
238 mlir::Value zero = fir::factory::createZeroValue(builder, loc, eleType);
239 builder.create<fir::ResultOp>(loc, zero);
240 })
241 .getResults()[0];
242}
243
244llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector(
245 const Fortran::lower::PreparedActualArguments &loweredActuals,
246 const fir::IntrinsicArgumentLoweringRules *argLowering) {
247 llvm::SmallVector<mlir::Value> operands;
248 operands.reserve(loweredActuals.size());
249
250 for (size_t i = 0; i < loweredActuals.size(); ++i) {
251 std::optional<Fortran::lower::PreparedActualArgument> arg =
252 loweredActuals[i];
253 if (!arg) {
254 operands.emplace_back();
255 continue;
256 }
257 hlfir::Entity actual = arg->getActual(loc, builder);
258 mlir::Value valArg;
259
260 if (!argLowering) {
261 valArg = hlfir::loadTrivialScalar(loc, builder, actual);
262 } else {
263 fir::ArgLoweringRule argRules =
264 fir::lowerIntrinsicArgumentAs(*argLowering, i);
265 if (argRules.lowerAs == fir::LowerIntrinsicArgAs::Box)
266 valArg = loadBoxAddress(arg);
267 else if (!argRules.handleDynamicOptional &&
268 argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired)
269 valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual);
270 else if (argRules.handleDynamicOptional &&
271 argRules.lowerAs == fir::LowerIntrinsicArgAs::Value)
272 valArg = loadOptionalValue(loc, builder, arg, actual);
273 else if (argRules.handleDynamicOptional)
274 TODO(loc, "hlfir transformational intrinsic dynamically optional "
275 "argument without box lowering");
276 else
277 valArg = actual.getBase();
278 }
279
280 operands.emplace_back(valArg);
281 }
282 return operands;
283}
284
285mlir::Type
286HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray,
287 mlir::Type stmtResultType) {
288 mlir::Type normalisedResult =
289 hlfir::getFortranElementOrSequenceType(stmtResultType);
290 if (auto array = mlir::dyn_cast<fir::SequenceType>(normalisedResult)) {
291 hlfir::ExprType::Shape resultShape =
292 hlfir::ExprType::Shape{array.getShape()};
293 mlir::Type elementType = array.getEleTy();
294 return hlfir::ExprType::get(builder.getContext(), resultShape, elementType,
295 fir::isPolymorphicType(stmtResultType));
296 } else if (auto resCharType =
297 mlir::dyn_cast<fir::CharacterType>(stmtResultType)) {
298 normalisedResult = hlfir::ExprType::get(
299 builder.getContext(), hlfir::ExprType::Shape{}, resCharType,
300 /*polymorphic=*/false);
301 }
302 return normalisedResult;
303}
304
305template <typename OP, bool HAS_MASK>
306mlir::Value HlfirReductionIntrinsic<OP, HAS_MASK>::lowerImpl(
307 const Fortran::lower::PreparedActualArguments &loweredActuals,
308 const fir::IntrinsicArgumentLoweringRules *argLowering,
309 mlir::Type stmtResultType) {
310 auto operands = getOperandVector(loweredActuals, argLowering);
311 mlir::Value array = operands[0];
312 mlir::Value dim = operands[1];
313 // dim, mask can be NULL if these arguments are not given
314 if (dim)
315 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
316
317 mlir::Type resultTy = computeResultType(array, stmtResultType);
318
319 OP op;
320 if constexpr (HAS_MASK)
321 op = createOp<OP>(resultTy, array, dim,
322 /*mask=*/operands[2]);
323 else
324 op = createOp<OP>(resultTy, array, dim);
325 return op;
326}
327
328template <typename OP>
329mlir::Value HlfirMinMaxLocIntrinsic<OP>::lowerImpl(
330 const Fortran::lower::PreparedActualArguments &loweredActuals,
331 const fir::IntrinsicArgumentLoweringRules *argLowering,
332 mlir::Type stmtResultType) {
333 auto operands = getOperandVector(loweredActuals, argLowering);
334 mlir::Value array = operands[0];
335 mlir::Value dim = operands[1];
336 mlir::Value mask = operands[2];
337 mlir::Value back = operands[4];
338 // dim, mask and back can be NULL if these arguments are not given.
339 if (dim)
340 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
341 if (back)
342 back = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{back});
343
344 mlir::Type resultTy = computeResultType(array, stmtResultType);
345
346 return createOp<OP>(resultTy, array, dim, mask, back);
347}
348
349template <typename OP>
350mlir::Value HlfirProductIntrinsic<OP>::lowerImpl(
351 const Fortran::lower::PreparedActualArguments &loweredActuals,
352 const fir::IntrinsicArgumentLoweringRules *argLowering,
353 mlir::Type stmtResultType) {
354 auto operands = getOperandVector(loweredActuals, argLowering);
355 mlir::Type resultType = computeResultType(operands[0], stmtResultType);
356 return createOp<OP>(resultType, operands[0], operands[1]);
357}
358
359mlir::Value HlfirTransposeLowering::lowerImpl(
360 const Fortran::lower::PreparedActualArguments &loweredActuals,
361 const fir::IntrinsicArgumentLoweringRules *argLowering,
362 mlir::Type stmtResultType) {
363 auto operands = getOperandVector(loweredActuals, argLowering);
364 hlfir::ExprType::Shape resultShape;
365 mlir::Type normalisedResult =
366 hlfir::getFortranElementOrSequenceType(stmtResultType);
367 auto array = mlir::cast<fir::SequenceType>(normalisedResult);
368 llvm::ArrayRef<int64_t> arrayShape = array.getShape();
369 assert(arrayShape.size() == 2 && "arguments to transpose have a rank of 2");
370 mlir::Type elementType = array.getEleTy();
371 resultShape.push_back(arrayShape[0]);
372 resultShape.push_back(arrayShape[1]);
373 if (auto resCharType = mlir::dyn_cast<fir::CharacterType>(elementType))
374 if (!resCharType.hasConstantLen()) {
375 // The FunctionRef expression might have imprecise character
376 // type at this point, and we can improve it by propagating
377 // the constant length from the argument.
378 auto argCharType = mlir::dyn_cast<fir::CharacterType>(
379 hlfir::getFortranElementType(operands[0].getType()));
380 if (argCharType && argCharType.hasConstantLen())
381 elementType = fir::CharacterType::get(
382 builder.getContext(), resCharType.getFKind(), argCharType.getLen());
383 }
384
385 mlir::Type resultTy =
386 hlfir::ExprType::get(builder.getContext(), resultShape, elementType,
387 fir::isPolymorphicType(stmtResultType));
388 return createOp<hlfir::TransposeOp>(resultTy, operands[0]);
389}
390
391mlir::Value HlfirCountLowering::lowerImpl(
392 const Fortran::lower::PreparedActualArguments &loweredActuals,
393 const fir::IntrinsicArgumentLoweringRules *argLowering,
394 mlir::Type stmtResultType) {
395 auto operands = getOperandVector(loweredActuals, argLowering);
396 mlir::Value array = operands[0];
397 mlir::Value dim = operands[1];
398 if (dim)
399 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
400 mlir::Type resultType = computeResultType(array, stmtResultType);
401 return createOp<hlfir::CountOp>(resultType, array, dim);
402}
403
404mlir::Value HlfirCharExtremumLowering::lowerImpl(
405 const Fortran::lower::PreparedActualArguments &loweredActuals,
406 const fir::IntrinsicArgumentLoweringRules *argLowering,
407 mlir::Type stmtResultType) {
408 auto operands = getOperandVector(loweredActuals, argLowering);
409 assert(operands.size() >= 2);
410 return createOp<hlfir::CharExtremumOp>(pred, mlir::ValueRange{operands});
411}
412
413mlir::Value HlfirCShiftLowering::lowerImpl(
414 const Fortran::lower::PreparedActualArguments &loweredActuals,
415 const fir::IntrinsicArgumentLoweringRules *argLowering,
416 mlir::Type stmtResultType) {
417 auto operands = getOperandVector(loweredActuals, argLowering);
418 assert(operands.size() == 3);
419 mlir::Value dim = operands[2];
420 if (!dim) {
421 // If DIM is not present, drop the last element which is a null Value.
422 operands.truncate(2);
423 } else {
424 // If DIM is present, then dereference it if it is a ref.
425 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
426 operands[2] = dim;
427 }
428
429 mlir::Type resultType = computeResultType(operands[0], stmtResultType);
430 return createOp<hlfir::CShiftOp>(resultType, operands);
431}
432
433mlir::Value HlfirReshapeLowering::lowerImpl(
434 const Fortran::lower::PreparedActualArguments &loweredActuals,
435 const fir::IntrinsicArgumentLoweringRules *argLowering,
436 mlir::Type stmtResultType) {
437 auto operands = getOperandVector(loweredActuals, argLowering);
438 assert(operands.size() == 4);
439 mlir::Type resultType = computeResultType(operands[0], stmtResultType);
440 return createOp<hlfir::ReshapeOp>(resultType, operands[0], operands[1],
441 operands[2], operands[3]);
442}
443
444std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
445 fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name,
446 const Fortran::lower::PreparedActualArguments &loweredActuals,
447 const fir::IntrinsicArgumentLoweringRules *argLowering,
448 mlir::Type stmtResultType) {
449 // If the result is of a derived type that may need finalization,
450 // we have to use DestroyOp with 'finalize' attribute for the result
451 // of the intrinsic operation.
452 if (name == "sum")
453 return HlfirSumLowering{builder, loc}.lower(loweredActuals, argLowering,
454 stmtResultType);
455 if (name == "product")
456 return HlfirProductLowering{builder, loc}.lower(loweredActuals, argLowering,
457 stmtResultType);
458 if (name == "any")
459 return HlfirAnyLowering{builder, loc}.lower(loweredActuals, argLowering,
460 stmtResultType);
461 if (name == "all")
462 return HlfirAllLowering{builder, loc}.lower(loweredActuals, argLowering,
463 stmtResultType);
464 if (name == "matmul")
465 return HlfirMatmulLowering{builder, loc}.lower(loweredActuals, argLowering,
466 stmtResultType);
467 if (name == "dot_product")
468 return HlfirDotProductLowering{builder, loc}.lower(
469 loweredActuals, argLowering, stmtResultType);
470 // FIXME: the result may need finalization.
471 if (name == "transpose")
472 return HlfirTransposeLowering{builder, loc}.lower(
473 loweredActuals, argLowering, stmtResultType);
474 if (name == "count")
475 return HlfirCountLowering{builder, loc}.lower(loweredActuals, argLowering,
476 stmtResultType);
477 if (name == "maxval")
478 return HlfirMaxvalLowering{builder, loc}.lower(loweredActuals, argLowering,
479 stmtResultType);
480 if (name == "minval")
481 return HlfirMinvalLowering{builder, loc}.lower(loweredActuals, argLowering,
482 stmtResultType);
483 if (name == "minloc")
484 return HlfirMinlocLowering{builder, loc}.lower(loweredActuals, argLowering,
485 stmtResultType);
486 if (name == "maxloc")
487 return HlfirMaxlocLowering{builder, loc}.lower(loweredActuals, argLowering,
488 stmtResultType);
489 if (name == "cshift")
490 return HlfirCShiftLowering{builder, loc}.lower(loweredActuals, argLowering,
491 stmtResultType);
492 if (name == "reshape")
493 return HlfirReshapeLowering{builder, loc}.lower(loweredActuals, argLowering,
494 stmtResultType);
495 if (mlir::isa<fir::CharacterType>(stmtResultType)) {
496 if (name == "min")
497 return HlfirCharExtremumLowering{builder, loc,
498 hlfir::CharExtremumPredicate::min}
499 .lower(loweredActuals, argLowering, stmtResultType);
500 if (name == "max")
501 return HlfirCharExtremumLowering{builder, loc,
502 hlfir::CharExtremumPredicate::max}
503 .lower(loweredActuals, argLowering, stmtResultType);
504 }
505 return std::nullopt;
506}
507

source code of flang/lib/Lower/HlfirIntrinsics.cpp