1//===-- HlfirIntrinsics.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
#include "flang/Lower/HlfirIntrinsics.h"

#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/MutableBox.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include <mlir/IR/ValueRange.h>
#include <utility>
27
28namespace {
29
30class HlfirTransformationalIntrinsic {
31public:
32 explicit HlfirTransformationalIntrinsic(fir::FirOpBuilder &builder,
33 mlir::Location loc)
34 : builder(builder), loc(loc) {}
35
36 virtual ~HlfirTransformationalIntrinsic() = default;
37
38 hlfir::EntityWithAttributes
39 lower(const Fortran::lower::PreparedActualArguments &loweredActuals,
40 const fir::IntrinsicArgumentLoweringRules *argLowering,
41 mlir::Type stmtResultType) {
42 mlir::Value res = lowerImpl(loweredActuals, argLowering, stmtResultType);
43 for (const hlfir::CleanupFunction &fn : cleanupFns)
44 fn();
45 return {hlfir::EntityWithAttributes{res}};
46 }
47
48protected:
49 fir::FirOpBuilder &builder;
50 mlir::Location loc;
51 llvm::SmallVector<hlfir::CleanupFunction, 3> cleanupFns;
52
53 virtual mlir::Value
54 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
55 const fir::IntrinsicArgumentLoweringRules *argLowering,
56 mlir::Type stmtResultType) = 0;
57
58 llvm::SmallVector<mlir::Value> getOperandVector(
59 const Fortran::lower::PreparedActualArguments &loweredActuals,
60 const fir::IntrinsicArgumentLoweringRules *argLowering);
61
62 mlir::Type computeResultType(mlir::Value argArray, mlir::Type stmtResultType);
63
64 template <typename OP, typename... BUILD_ARGS>
65 inline OP createOp(BUILD_ARGS... args) {
66 return builder.create<OP>(loc, args...);
67 }
68
69 mlir::Value loadBoxAddress(
70 const std::optional<Fortran::lower::PreparedActualArgument> &arg);
71
72 void addCleanup(std::optional<hlfir::CleanupFunction> cleanup) {
73 if (cleanup)
74 cleanupFns.emplace_back(std::move(*cleanup));
75 }
76};
77
78template <typename OP, bool HAS_MASK>
79class HlfirReductionIntrinsic : public HlfirTransformationalIntrinsic {
80public:
81 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
82
83protected:
84 mlir::Value
85 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
86 const fir::IntrinsicArgumentLoweringRules *argLowering,
87 mlir::Type stmtResultType) override;
88};
89using HlfirSumLowering = HlfirReductionIntrinsic<hlfir::SumOp, true>;
90using HlfirProductLowering = HlfirReductionIntrinsic<hlfir::ProductOp, true>;
91using HlfirMaxvalLowering = HlfirReductionIntrinsic<hlfir::MaxvalOp, true>;
92using HlfirMinvalLowering = HlfirReductionIntrinsic<hlfir::MinvalOp, true>;
93using HlfirAnyLowering = HlfirReductionIntrinsic<hlfir::AnyOp, false>;
94using HlfirAllLowering = HlfirReductionIntrinsic<hlfir::AllOp, false>;
95
96template <typename OP>
97class HlfirMinMaxLocIntrinsic : public HlfirTransformationalIntrinsic {
98public:
99 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
100
101protected:
102 mlir::Value
103 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
104 const fir::IntrinsicArgumentLoweringRules *argLowering,
105 mlir::Type stmtResultType) override;
106};
107using HlfirMinlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MinlocOp>;
108using HlfirMaxlocLowering = HlfirMinMaxLocIntrinsic<hlfir::MaxlocOp>;
109
110template <typename OP>
111class HlfirProductIntrinsic : public HlfirTransformationalIntrinsic {
112public:
113 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
114
115protected:
116 mlir::Value
117 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
118 const fir::IntrinsicArgumentLoweringRules *argLowering,
119 mlir::Type stmtResultType) override;
120};
121using HlfirMatmulLowering = HlfirProductIntrinsic<hlfir::MatmulOp>;
122using HlfirDotProductLowering = HlfirProductIntrinsic<hlfir::DotProductOp>;
123
124class HlfirTransposeLowering : public HlfirTransformationalIntrinsic {
125public:
126 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
127
128protected:
129 mlir::Value
130 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
131 const fir::IntrinsicArgumentLoweringRules *argLowering,
132 mlir::Type stmtResultType) override;
133};
134
135class HlfirCountLowering : public HlfirTransformationalIntrinsic {
136public:
137 using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic;
138
139protected:
140 mlir::Value
141 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
142 const fir::IntrinsicArgumentLoweringRules *argLowering,
143 mlir::Type stmtResultType) override;
144};
145
146class HlfirCharExtremumLowering : public HlfirTransformationalIntrinsic {
147public:
148 HlfirCharExtremumLowering(fir::FirOpBuilder &builder, mlir::Location loc,
149 hlfir::CharExtremumPredicate pred)
150 : HlfirTransformationalIntrinsic(builder, loc), pred{pred} {}
151
152protected:
153 mlir::Value
154 lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals,
155 const fir::IntrinsicArgumentLoweringRules *argLowering,
156 mlir::Type stmtResultType) override;
157
158protected:
159 hlfir::CharExtremumPredicate pred;
160};
161
162} // namespace
163
164mlir::Value HlfirTransformationalIntrinsic::loadBoxAddress(
165 const std::optional<Fortran::lower::PreparedActualArgument> &arg) {
166 if (!arg)
167 return mlir::Value{};
168
169 hlfir::Entity actual = arg->getActual(loc, builder);
170
171 if (!arg->handleDynamicOptional()) {
172 if (actual.isMutableBox()) {
173 // this is a box address type but is not dynamically optional. Just load
174 // the box, assuming it is well formed (!fir.ref<!fir.box<...>> ->
175 // !fir.box<...>)
176 return builder.create<fir::LoadOp>(loc, actual.getBase());
177 }
178 return actual;
179 }
180
181 auto [exv, cleanup] = hlfir::translateToExtendedValue(loc, builder, actual);
182 addCleanup(cleanup);
183
184 mlir::Value isPresent = arg->getIsPresent();
185 // createBox will not do create any invalid memory dereferences if exv is
186 // absent. The created fir.box will not be usable, but the SelectOp below
187 // ensures it won't be.
188 mlir::Value box = builder.createBox(loc, exv);
189 mlir::Type boxType = box.getType();
190 auto absent = builder.create<fir::AbsentOp>(loc, boxType);
191 auto boxOrAbsent = builder.create<mlir::arith::SelectOp>(
192 loc, boxType, isPresent, box, absent);
193
194 return boxOrAbsent;
195}
196
197static mlir::Value loadOptionalValue(
198 mlir::Location loc, fir::FirOpBuilder &builder,
199 const std::optional<Fortran::lower::PreparedActualArgument> &arg,
200 hlfir::Entity actual) {
201 if (!arg->handleDynamicOptional())
202 return hlfir::loadTrivialScalar(loc, builder, actual);
203
204 mlir::Value isPresent = arg->getIsPresent();
205 mlir::Type eleType = hlfir::getFortranElementType(actual.getType());
206 return builder
207 .genIfOp(loc, {eleType}, isPresent,
208 /*withElseRegion=*/true)
209 .genThen([&]() {
210 assert(actual.isScalar() && fir::isa_trivial(eleType) &&
211 "must be a numerical or logical scalar");
212 hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, actual);
213 builder.create<fir::ResultOp>(loc, val);
214 })
215 .genElse([&]() {
216 mlir::Value zero = fir::factory::createZeroValue(builder, loc, eleType);
217 builder.create<fir::ResultOp>(loc, zero);
218 })
219 .getResults()[0];
220}
221
222llvm::SmallVector<mlir::Value> HlfirTransformationalIntrinsic::getOperandVector(
223 const Fortran::lower::PreparedActualArguments &loweredActuals,
224 const fir::IntrinsicArgumentLoweringRules *argLowering) {
225 llvm::SmallVector<mlir::Value> operands;
226 operands.reserve(loweredActuals.size());
227
228 for (size_t i = 0; i < loweredActuals.size(); ++i) {
229 std::optional<Fortran::lower::PreparedActualArgument> arg =
230 loweredActuals[i];
231 if (!arg) {
232 operands.emplace_back();
233 continue;
234 }
235 hlfir::Entity actual = arg->getActual(loc, builder);
236 mlir::Value valArg;
237
238 if (!argLowering) {
239 valArg = hlfir::loadTrivialScalar(loc, builder, actual);
240 } else {
241 fir::ArgLoweringRule argRules =
242 fir::lowerIntrinsicArgumentAs(*argLowering, i);
243 if (argRules.lowerAs == fir::LowerIntrinsicArgAs::Box)
244 valArg = loadBoxAddress(arg);
245 else if (!argRules.handleDynamicOptional &&
246 argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired)
247 valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual);
248 else if (argRules.handleDynamicOptional &&
249 argRules.lowerAs == fir::LowerIntrinsicArgAs::Value)
250 valArg = loadOptionalValue(loc, builder, arg, actual);
251 else if (argRules.handleDynamicOptional)
252 TODO(loc, "hlfir transformational intrinsic dynamically optional "
253 "argument without box lowering");
254 else
255 valArg = actual.getBase();
256 }
257
258 operands.emplace_back(valArg);
259 }
260 return operands;
261}
262
263mlir::Type
264HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray,
265 mlir::Type stmtResultType) {
266 mlir::Type normalisedResult =
267 hlfir::getFortranElementOrSequenceType(stmtResultType);
268 if (auto array = normalisedResult.dyn_cast<fir::SequenceType>()) {
269 hlfir::ExprType::Shape resultShape =
270 hlfir::ExprType::Shape{array.getShape()};
271 mlir::Type elementType = array.getEleTy();
272 return hlfir::ExprType::get(builder.getContext(), resultShape, elementType,
273 /*polymorphic=*/false);
274 } else if (auto resCharType =
275 mlir::dyn_cast<fir::CharacterType>(stmtResultType)) {
276 normalisedResult = hlfir::ExprType::get(
277 builder.getContext(), hlfir::ExprType::Shape{}, resCharType, false);
278 }
279 return normalisedResult;
280}
281
282template <typename OP, bool HAS_MASK>
283mlir::Value HlfirReductionIntrinsic<OP, HAS_MASK>::lowerImpl(
284 const Fortran::lower::PreparedActualArguments &loweredActuals,
285 const fir::IntrinsicArgumentLoweringRules *argLowering,
286 mlir::Type stmtResultType) {
287 auto operands = getOperandVector(loweredActuals, argLowering);
288 mlir::Value array = operands[0];
289 mlir::Value dim = operands[1];
290 // dim, mask can be NULL if these arguments are not given
291 if (dim)
292 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
293
294 mlir::Type resultTy = computeResultType(array, stmtResultType);
295
296 OP op;
297 if constexpr (HAS_MASK)
298 op = createOp<OP>(resultTy, array, dim,
299 /*mask=*/operands[2]);
300 else
301 op = createOp<OP>(resultTy, array, dim);
302 return op;
303}
304
305template <typename OP>
306mlir::Value HlfirMinMaxLocIntrinsic<OP>::lowerImpl(
307 const Fortran::lower::PreparedActualArguments &loweredActuals,
308 const fir::IntrinsicArgumentLoweringRules *argLowering,
309 mlir::Type stmtResultType) {
310 auto operands = getOperandVector(loweredActuals, argLowering);
311 mlir::Value array = operands[0];
312 mlir::Value dim = operands[1];
313 mlir::Value mask = operands[2];
314 mlir::Value back = operands[4];
315 // dim, mask and back can be NULL if these arguments are not given.
316 if (dim)
317 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
318 if (back)
319 back = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{back});
320
321 mlir::Type resultTy = computeResultType(array, stmtResultType);
322
323 return createOp<OP>(resultTy, array, dim, mask, back);
324}
325
326template <typename OP>
327mlir::Value HlfirProductIntrinsic<OP>::lowerImpl(
328 const Fortran::lower::PreparedActualArguments &loweredActuals,
329 const fir::IntrinsicArgumentLoweringRules *argLowering,
330 mlir::Type stmtResultType) {
331 auto operands = getOperandVector(loweredActuals, argLowering);
332 mlir::Type resultType = computeResultType(operands[0], stmtResultType);
333 return createOp<OP>(resultType, operands[0], operands[1]);
334}
335
336mlir::Value HlfirTransposeLowering::lowerImpl(
337 const Fortran::lower::PreparedActualArguments &loweredActuals,
338 const fir::IntrinsicArgumentLoweringRules *argLowering,
339 mlir::Type stmtResultType) {
340 auto operands = getOperandVector(loweredActuals, argLowering);
341 hlfir::ExprType::Shape resultShape;
342 mlir::Type normalisedResult =
343 hlfir::getFortranElementOrSequenceType(stmtResultType);
344 auto array = normalisedResult.cast<fir::SequenceType>();
345 llvm::ArrayRef<int64_t> arrayShape = array.getShape();
346 assert(arrayShape.size() == 2 && "arguments to transpose have a rank of 2");
347 mlir::Type elementType = array.getEleTy();
348 resultShape.push_back(arrayShape[0]);
349 resultShape.push_back(arrayShape[1]);
350 if (auto resCharType = mlir::dyn_cast<fir::CharacterType>(elementType))
351 if (!resCharType.hasConstantLen()) {
352 // The FunctionRef expression might have imprecise character
353 // type at this point, and we can improve it by propagating
354 // the constant length from the argument.
355 auto argCharType = mlir::dyn_cast<fir::CharacterType>(
356 hlfir::getFortranElementType(operands[0].getType()));
357 if (argCharType && argCharType.hasConstantLen())
358 elementType = fir::CharacterType::get(
359 builder.getContext(), resCharType.getFKind(), argCharType.getLen());
360 }
361
362 mlir::Type resultTy =
363 hlfir::ExprType::get(builder.getContext(), resultShape, elementType,
364 fir::isPolymorphicType(stmtResultType));
365 return createOp<hlfir::TransposeOp>(resultTy, operands[0]);
366}
367
368mlir::Value HlfirCountLowering::lowerImpl(
369 const Fortran::lower::PreparedActualArguments &loweredActuals,
370 const fir::IntrinsicArgumentLoweringRules *argLowering,
371 mlir::Type stmtResultType) {
372 auto operands = getOperandVector(loweredActuals, argLowering);
373 mlir::Value array = operands[0];
374 mlir::Value dim = operands[1];
375 if (dim)
376 dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
377 mlir::Type resultType = computeResultType(array, stmtResultType);
378 return createOp<hlfir::CountOp>(resultType, array, dim);
379}
380
381mlir::Value HlfirCharExtremumLowering::lowerImpl(
382 const Fortran::lower::PreparedActualArguments &loweredActuals,
383 const fir::IntrinsicArgumentLoweringRules *argLowering,
384 mlir::Type stmtResultType) {
385 auto operands = getOperandVector(loweredActuals, argLowering);
386 assert(operands.size() >= 2);
387 return createOp<hlfir::CharExtremumOp>(pred, mlir::ValueRange{operands});
388}
389
390std::optional<hlfir::EntityWithAttributes> Fortran::lower::lowerHlfirIntrinsic(
391 fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name,
392 const Fortran::lower::PreparedActualArguments &loweredActuals,
393 const fir::IntrinsicArgumentLoweringRules *argLowering,
394 mlir::Type stmtResultType) {
395 // If the result is of a derived type that may need finalization,
396 // we have to use DestroyOp with 'finalize' attribute for the result
397 // of the intrinsic operation.
398 if (name == "sum")
399 return HlfirSumLowering{builder, loc}.lower(loweredActuals, argLowering,
400 stmtResultType);
401 if (name == "product")
402 return HlfirProductLowering{builder, loc}.lower(loweredActuals, argLowering,
403 stmtResultType);
404 if (name == "any")
405 return HlfirAnyLowering{builder, loc}.lower(loweredActuals, argLowering,
406 stmtResultType);
407 if (name == "all")
408 return HlfirAllLowering{builder, loc}.lower(loweredActuals, argLowering,
409 stmtResultType);
410 if (name == "matmul")
411 return HlfirMatmulLowering{builder, loc}.lower(loweredActuals, argLowering,
412 stmtResultType);
413 if (name == "dot_product")
414 return HlfirDotProductLowering{builder, loc}.lower(
415 loweredActuals, argLowering, stmtResultType);
416 // FIXME: the result may need finalization.
417 if (name == "transpose")
418 return HlfirTransposeLowering{builder, loc}.lower(
419 loweredActuals, argLowering, stmtResultType);
420 if (name == "count")
421 return HlfirCountLowering{builder, loc}.lower(loweredActuals, argLowering,
422 stmtResultType);
423 if (name == "maxval")
424 return HlfirMaxvalLowering{builder, loc}.lower(loweredActuals, argLowering,
425 stmtResultType);
426 if (name == "minval")
427 return HlfirMinvalLowering{builder, loc}.lower(loweredActuals, argLowering,
428 stmtResultType);
429 if (name == "minloc")
430 return HlfirMinlocLowering{builder, loc}.lower(loweredActuals, argLowering,
431 stmtResultType);
432 if (name == "maxloc")
433 return HlfirMaxlocLowering{builder, loc}.lower(loweredActuals, argLowering,
434 stmtResultType);
435 if (mlir::isa<fir::CharacterType>(stmtResultType)) {
436 if (name == "min")
437 return HlfirCharExtremumLowering{builder, loc,
438 hlfir::CharExtremumPredicate::min}
439 .lower(loweredActuals, argLowering, stmtResultType);
440 if (name == "max")
441 return HlfirCharExtremumLowering{builder, loc,
442 hlfir::CharExtremumPredicate::max}
443 .lower(loweredActuals, argLowering, stmtResultType);
444 }
445 return std::nullopt;
446}
447

source code of flang/lib/Lower/HlfirIntrinsics.cpp