1//===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9//===----------------------------------------------------------------------===//
10/// \file
11/// This pass looks for suitable calls to runtime library for intrinsics that
12/// can be simplified/specialized and replaces with a specialized function.
13///
14/// For example, SUM(arr) can be specialized as a simple function with one loop,
15/// compared to the three arguments (plus file & line info) that the runtime
16/// call has - when the argument is a 1D-array (multiple loops may be needed
/// for higher dimension arrays, of course).
18///
19/// The general idea is that besides making the call simpler, it can also be
20/// inlined by other passes that run after this pass, which further improves
21/// performance, particularly when the work done in the function is trivial
22/// and small in size.
23//===----------------------------------------------------------------------===//
24
25#include "flang/Common/Fortran.h"
26#include "flang/Optimizer/Builder/BoxValue.h"
27#include "flang/Optimizer/Builder/FIRBuilder.h"
28#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
29#include "flang/Optimizer/Builder/Todo.h"
30#include "flang/Optimizer/Dialect/FIROps.h"
31#include "flang/Optimizer/Dialect/FIRType.h"
32#include "flang/Optimizer/Dialect/Support/FIRContext.h"
33#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
34#include "flang/Optimizer/Transforms/Passes.h"
35#include "flang/Optimizer/Transforms/Utils.h"
36#include "flang/Runtime/entry-names.h"
37#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
38#include "mlir/IR/Matchers.h"
39#include "mlir/IR/Operation.h"
40#include "mlir/Pass/Pass.h"
41#include "mlir/Transforms/DialectConversion.h"
42#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
43#include "mlir/Transforms/RegionUtils.h"
44#include "llvm/Support/Debug.h"
45#include "llvm/Support/raw_ostream.h"
46#include <llvm/Support/ErrorHandling.h>
47#include <mlir/Dialect/Arith/IR/Arith.h>
48#include <mlir/IR/BuiltinTypes.h>
49#include <mlir/IR/Location.h>
50#include <mlir/IR/MLIRContext.h>
51#include <mlir/IR/Value.h>
52#include <mlir/Support/LLVM.h>
53#include <optional>
54
55namespace fir {
56#define GEN_PASS_DEF_SIMPLIFYINTRINSICS
57#include "flang/Optimizer/Transforms/Passes.h.inc"
58} // namespace fir
59
60#define DEBUG_TYPE "flang-simplify-intrinsics"
61
62namespace {
63
64class SimplifyIntrinsicsPass
65 : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
66 using FunctionTypeGeneratorTy =
67 llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
68 using FunctionBodyGeneratorTy =
69 llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
70 using GenReductionBodyTy = llvm::function_ref<void(
71 fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank,
72 mlir::Type elementType)>;
73
74public:
75 /// Generate a new function implementing a simplified version
76 /// of a Fortran runtime function defined by \p basename name.
77 /// \p typeGenerator is a callback that generates the new function's type.
78 /// \p bodyGenerator is a callback that generates the new function's body.
79 /// The new function is created in the \p builder's Module.
80 mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
81 const mlir::StringRef &basename,
82 FunctionTypeGeneratorTy typeGenerator,
83 FunctionBodyGeneratorTy bodyGenerator);
84 void runOnOperation() override;
85 void getDependentDialects(mlir::DialectRegistry &registry) const override;
86
87private:
88 /// Helper functions to replace a reduction type of call with its
89 /// simplified form. The actual function is generated using a callback
90 /// function.
91 /// \p call is the call to be replaced
92 /// \p kindMap is used to create FIROpBuilder
93 /// \p genBodyFunc is the callback that builds the replacement function
94 void simplifyIntOrFloatReduction(fir::CallOp call,
95 const fir::KindMapping &kindMap,
96 GenReductionBodyTy genBodyFunc);
97 void simplifyLogicalDim0Reduction(fir::CallOp call,
98 const fir::KindMapping &kindMap,
99 GenReductionBodyTy genBodyFunc);
100 void simplifyLogicalDim1Reduction(fir::CallOp call,
101 const fir::KindMapping &kindMap,
102 GenReductionBodyTy genBodyFunc);
103 void simplifyMinMaxlocReduction(fir::CallOp call,
104 const fir::KindMapping &kindMap, bool isMax);
105 void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap,
106 GenReductionBodyTy genBodyFunc,
107 fir::FirOpBuilder &builder,
108 const mlir::StringRef &basename,
109 mlir::Type elementType);
110};
111
112} // namespace
113
114/// Create FirOpBuilder with the provided \p op insertion point
115/// and \p kindMap additionally inheriting FastMathFlags from \p op.
116static fir::FirOpBuilder
117getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) {
118 fir::FirOpBuilder builder{op, kindMap};
119 auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
120 if (!fmi)
121 return builder;
122
123 // Regardless of what default FastMathFlags are used by FirOpBuilder,
124 // override them with FastMathFlags attached to the operation.
125 builder.setFastMathFlags(fmi.getFastMathFlagsAttr().getValue());
126 return builder;
127}
128
129/// Generate function type for the simplified version of RTNAME(Sum) and
130/// similar functions with a fir.box<none> type returning \p elementType.
131static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
132 const mlir::Type &elementType) {
133 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
134 return mlir::FunctionType::get(builder.getContext(), {boxType},
135 {elementType});
136}
137
138template <typename Op>
139Op expectOp(mlir::Value val) {
140 if (Op op = mlir::dyn_cast_or_null<Op>(val.getDefiningOp()))
141 return op;
142 LLVM_DEBUG(llvm::dbgs() << "Didn't find expected " << Op::getOperationName()
143 << '\n');
144 return nullptr;
145}
146
147template <typename Op>
148static mlir::Value findDefSingle(fir::ConvertOp op) {
149 if (auto defOp = expectOp<Op>(op->getOperand(0))) {
150 return defOp.getResult();
151 }
152 return {};
153}
154
155template <typename... Ops>
156static mlir::Value findDef(fir::ConvertOp op) {
157 mlir::Value defOp;
158 // Loop over the operation types given to see if any match, exiting once
159 // a match is found. Cast to void is needed to avoid compiler complaining
160 // that the result of expression is unused
161 (void)((defOp = findDefSingle<Ops>(op), (defOp)) || ...);
162 return defOp;
163}
164
165static bool isOperandAbsent(mlir::Value val) {
166 if (auto op = expectOp<fir::ConvertOp>(val)) {
167 assert(op->getOperands().size() != 0);
168 return mlir::isa_and_nonnull<fir::AbsentOp>(
169 op->getOperand(0).getDefiningOp());
170 }
171 return false;
172}
173
174static bool isTrueOrNotConstant(mlir::Value val) {
175 if (auto op = expectOp<mlir::arith::ConstantOp>(val)) {
176 return !mlir::matchPattern(val, mlir::m_Zero());
177 }
178 return true;
179}
180
181static bool isZero(mlir::Value val) {
182 if (auto op = expectOp<fir::ConvertOp>(val)) {
183 assert(op->getOperands().size() != 0);
184 if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
185 return mlir::matchPattern(defOp, mlir::m_Zero());
186 }
187 return false;
188}
189
190static mlir::Value findBoxDef(mlir::Value val) {
191 if (auto op = expectOp<fir::ConvertOp>(val)) {
192 assert(op->getOperands().size() != 0);
193 return findDef<fir::EmboxOp, fir::ReboxOp>(op);
194 }
195 return {};
196}
197
198static mlir::Value findMaskDef(mlir::Value val) {
199 if (auto op = expectOp<fir::ConvertOp>(val)) {
200 assert(op->getOperands().size() != 0);
201 return findDef<fir::EmboxOp, fir::ReboxOp, fir::AbsentOp>(op);
202 }
203 return {};
204}
205
206static unsigned getDimCount(mlir::Value val) {
207 // In order to find the dimensions count, we look for EmboxOp/ReboxOp
208 // and take the count from its *result* type. Note that in case
209 // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
210 // have different types.
211 // Actually, we can take the box type from the operand of
212 // the first ConvertOp that has non-opaque box type that we meet
213 // going through the ConvertOp chain.
214 if (mlir::Value emboxVal = findBoxDef(val))
215 if (auto boxTy = emboxVal.getType().dyn_cast<fir::BoxType>())
216 if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
217 return seqTy.getDimension();
218 return 0;
219}
220
221/// Given the call operation's box argument \p val, discover
222/// the element type of the underlying array object.
223/// \returns the element type or std::nullopt if the type cannot
224/// be reliably found.
225/// We expect that the argument is a result of fir.convert
226/// with the destination type of !fir.box<none>.
227static std::optional<mlir::Type> getArgElementType(mlir::Value val) {
228 mlir::Operation *defOp;
229 do {
230 defOp = val.getDefiningOp();
231 // Analyze only sequences of convert operations.
232 if (!mlir::isa<fir::ConvertOp>(defOp))
233 return std::nullopt;
234 val = defOp->getOperand(0);
235 // The convert operation is expected to convert from one
236 // box type to another box type.
237 auto boxType = val.getType().cast<fir::BoxType>();
238 auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
239 if (!elementType.isa<mlir::NoneType>())
240 return elementType;
241 } while (true);
242}
243
244using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
245 fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
246 mlir::Value)>;
247using ContinueLoopGenTy = llvm::function_ref<llvm::SmallVector<mlir::Value>(
248 fir::FirOpBuilder &, mlir::Location, mlir::Value)>;
249
250/// Generate the reduction loop into \p funcOp.
251///
252/// \p initVal is a function, called to get the initial value for
253/// the reduction value
254/// \p genBody is called to fill in the actual reduciton operation
255/// for example add for SUM, MAX for MAXVAL, etc.
256/// \p rank is the rank of the input argument.
257/// \p elementType is the type of the elements in the input array,
258/// which may be different to the return type.
259/// \p loopCond is called to generate the condition to continue or
260/// not for IterWhile loops
261/// \p unorderedOrInitalLoopCond contains either a boolean or bool
262/// mlir constant, and controls the inital value for while loops
263/// or if DoLoop is ordered/unordered.
264
265template <typename OP, typename T, int resultIndex>
266static void
267genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
268 fir::InitValGeneratorTy initVal, ContinueLoopGenTy loopCond,
269 T unorderedOrInitialLoopCond, BodyOpGeneratorTy genBody,
270 unsigned rank, mlir::Type elementType, mlir::Location loc) {
271
272 mlir::IndexType idxTy = builder.getIndexType();
273
274 mlir::Block::BlockArgListType args = funcOp.front().getArguments();
275 mlir::Value arg = args[0];
276
277 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
278
279 fir::SequenceType::Shape flatShape(rank,
280 fir::SequenceType::getUnknownExtent());
281 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
282 mlir::Type boxArrTy = fir::BoxType::get(arrTy);
283 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
284 mlir::Type resultType = funcOp.getResultTypes()[0];
285 mlir::Value init = initVal(builder, loc, resultType);
286
287 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
288
289 assert(rank > 0 && "rank cannot be zero");
290 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
291
292 // Compute all the upper bounds before the loop nest.
293 // It is not strictly necessary for performance, since the loop nest
294 // does not have any store operations and any LICM optimization
295 // should be able to optimize the redundancy.
296 for (unsigned i = 0; i < rank; ++i) {
297 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
298 auto dims =
299 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
300 mlir::Value len = dims.getResult(1);
301 // We use C indexing here, so len-1 as loopcount
302 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
303 bounds.push_back(loopCount);
304 }
305 // Create a loop nest consisting of OP operations.
306 // Collect the loops' induction variables into indices array,
307 // which will be used in the innermost loop to load the input
308 // array's element.
309 // The loops are generated such that the innermost loop processes
310 // the 0 dimension.
311 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
312 for (unsigned i = rank; 0 < i; --i) {
313 mlir::Value step = one;
314 mlir::Value loopCount = bounds[i - 1];
315 auto loop = builder.create<OP>(loc, zeroIdx, loopCount, step,
316 unorderedOrInitialLoopCond,
317 /*finalCountValue=*/false, init);
318 init = loop.getRegionIterArgs()[resultIndex];
319 indices.push_back(loop.getInductionVar());
320 // Set insertion point to the loop body so that the next loop
321 // is inserted inside the current one.
322 builder.setInsertionPointToStart(loop.getBody());
323 }
324
325 // Reverse the indices such that they are ordered as:
326 // <dim-0-idx, dim-1-idx, ...>
327 std::reverse(indices.begin(), indices.end());
328 // We are in the innermost loop: generate the reduction body.
329 mlir::Type eleRefTy = builder.getRefType(elementType);
330 mlir::Value addr =
331 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
332 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
333 mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
334 // Generate vector with condition to continue while loop at [0] and result
335 // from current loop at [1] for IterWhileOp loops, just result at [0] for
336 // DoLoopOp loops.
337 llvm::SmallVector<mlir::Value> results = loopCond(builder, loc, reductionVal);
338
339 // Unwind the loop nest and insert ResultOp on each level
340 // to return the updated value of the reduction to the enclosing
341 // loops.
342 for (unsigned i = 0; i < rank; ++i) {
343 auto result = builder.create<fir::ResultOp>(loc, results);
344 // Proceed to the outer loop.
345 auto loop = mlir::cast<OP>(result->getParentOp());
346 results = loop.getResults();
347 // Set insertion point after the loop operation that we have
348 // just processed.
349 builder.setInsertionPointAfter(loop.getOperation());
350 }
351 // End of loop nest. The insertion point is after the outermost loop.
352 // Return the reduction value from the function.
353 builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]);
354}
355
356static llvm::SmallVector<mlir::Value> nopLoopCond(fir::FirOpBuilder &builder,
357 mlir::Location loc,
358 mlir::Value reductionVal) {
359 return {reductionVal};
360}
361
362/// Generate function body of the simplified version of RTNAME(Sum)
363/// with signature provided by \p funcOp. The caller is responsible
364/// for saving/restoring the original insertion point of \p builder.
365/// \p funcOp is expected to be empty on entry to this function.
366/// \p rank specifies the rank of the input argument.
367static void genRuntimeSumBody(fir::FirOpBuilder &builder,
368 mlir::func::FuncOp &funcOp, unsigned rank,
369 mlir::Type elementType) {
370 // function RTNAME(Sum)<T>x<rank>_simplified(arr)
371 // T, dimension(:) :: arr
372 // T sum = 0
373 // integer iter
374 // do iter = 0, extent(arr)
375 // sum = sum + arr[iter]
376 // end do
377 // RTNAME(Sum)<T>x<rank>_simplified = sum
378 // end function RTNAME(Sum)<T>x<rank>_simplified
379 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
380 mlir::Type elementType) {
381 if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
382 const llvm::fltSemantics &sem = ty.getFloatSemantics();
383 return builder.createRealConstant(loc, elementType,
384 llvm::APFloat::getZero(sem));
385 }
386 return builder.createIntegerConstant(loc, elementType, 0);
387 };
388
389 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
390 mlir::Type elementType, mlir::Value elem1,
391 mlir::Value elem2) -> mlir::Value {
392 if (elementType.isa<mlir::FloatType>())
393 return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2);
394 if (elementType.isa<mlir::IntegerType>())
395 return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2);
396
397 llvm_unreachable("unsupported type");
398 return {};
399 };
400
401 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
402 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
403
404 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
405 false, genBodyOp, rank, elementType,
406 loc);
407}
408
409static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
410 mlir::func::FuncOp &funcOp, unsigned rank,
411 mlir::Type elementType) {
412 auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
413 mlir::Type elementType) {
414 if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
415 const llvm::fltSemantics &sem = ty.getFloatSemantics();
416 return builder.createRealConstant(
417 loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true));
418 }
419 unsigned bits = elementType.getIntOrFloatBitWidth();
420 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
421 return builder.createIntegerConstant(loc, elementType, minInt);
422 };
423
424 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
425 mlir::Type elementType, mlir::Value elem1,
426 mlir::Value elem2) -> mlir::Value {
427 if (elementType.isa<mlir::FloatType>()) {
428 // arith.maxf later converted to llvm.intr.maxnum does not work
429 // correctly for NaNs and -0.0 (see maxnum/minnum pattern matching
430 // in LLVM's InstCombine pass). Moreover, llvm.intr.maxnum
431 // for F128 operands is lowered into fmaxl call by LLVM.
432 // This libm function may not work properly for F128 arguments
433 // on targets where long double is not F128. It is an LLVM issue,
434 // but we just use normal select here to resolve all the cases.
435 auto compare = builder.create<mlir::arith::CmpFOp>(
436 loc, mlir::arith::CmpFPredicate::OGT, elem1, elem2);
437 return builder.create<mlir::arith::SelectOp>(loc, compare, elem1, elem2);
438 }
439 if (elementType.isa<mlir::IntegerType>())
440 return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2);
441
442 llvm_unreachable("unsupported type");
443 return {};
444 };
445
446 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
447 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
448
449 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, init, nopLoopCond,
450 false, genBodyOp, rank, elementType,
451 loc);
452}
453
454static void genRuntimeCountBody(fir::FirOpBuilder &builder,
455 mlir::func::FuncOp &funcOp, unsigned rank,
456 mlir::Type elementType) {
457 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
458 mlir::Type elementType) {
459 unsigned bits = elementType.getIntOrFloatBitWidth();
460 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
461 return builder.createIntegerConstant(loc, elementType, zeroInt);
462 };
463
464 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
465 mlir::Type elementType, mlir::Value elem1,
466 mlir::Value elem2) -> mlir::Value {
467 auto zero32 = builder.createIntegerConstant(loc, elementType, 0);
468 auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0);
469 auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1);
470
471 auto compare = builder.create<mlir::arith::CmpIOp>(
472 loc, mlir::arith::CmpIPredicate::eq, elem1, zero32);
473 auto select =
474 builder.create<mlir::arith::SelectOp>(loc, compare, zero64, one64);
475 return builder.create<mlir::arith::AddIOp>(loc, select, elem2);
476 };
477
478 // Count always gets I32 for elementType as it converts logical input to
479 // logical<4> before passing to the function.
480 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
481 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
482
483 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
484 false, genBodyOp, rank, elementType,
485 loc);
486}
487
488static void genRuntimeAnyBody(fir::FirOpBuilder &builder,
489 mlir::func::FuncOp &funcOp, unsigned rank,
490 mlir::Type elementType) {
491 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
492 mlir::Type elementType) {
493 return builder.createIntegerConstant(loc, elementType, 0);
494 };
495
496 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
497 mlir::Type elementType, mlir::Value elem1,
498 mlir::Value elem2) -> mlir::Value {
499 auto zero = builder.createIntegerConstant(loc, elementType, 0);
500 return builder.create<mlir::arith::CmpIOp>(
501 loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
502 };
503
504 auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
505 mlir::Value reductionVal) {
506 auto one1 = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
507 auto eor = builder.create<mlir::arith::XOrIOp>(loc, reductionVal, one1);
508 llvm::SmallVector<mlir::Value> results = {eor, reductionVal};
509 return results;
510 };
511
512 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
513 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
514 mlir::Value ok = builder.createBool(loc, true);
515
516 genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
517 builder, funcOp, zero, continueCond, ok, genBodyOp, rank, elementType,
518 loc);
519}
520
521static void genRuntimeAllBody(fir::FirOpBuilder &builder,
522 mlir::func::FuncOp &funcOp, unsigned rank,
523 mlir::Type elementType) {
524 auto one = [](fir::FirOpBuilder builder, mlir::Location loc,
525 mlir::Type elementType) {
526 return builder.createIntegerConstant(loc, elementType, 1);
527 };
528
529 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
530 mlir::Type elementType, mlir::Value elem1,
531 mlir::Value elem2) -> mlir::Value {
532 auto zero = builder.createIntegerConstant(loc, elementType, 0);
533 return builder.create<mlir::arith::CmpIOp>(
534 loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
535 };
536
537 auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
538 mlir::Value reductionVal) {
539 llvm::SmallVector<mlir::Value> results = {reductionVal, reductionVal};
540 return results;
541 };
542
543 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
544 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
545 mlir::Value ok = builder.createBool(loc, true);
546
547 genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
548 builder, funcOp, one, continueCond, ok, genBodyOp, rank, elementType,
549 loc);
550}
551
552static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder,
553 unsigned int rank) {
554 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
555 mlir::Type boxRefType = builder.getRefType(boxType);
556
557 return mlir::FunctionType::get(builder.getContext(),
558 {boxRefType, boxType, boxType}, {});
559}
560
561// Produces a loop nest for a Minloc intrinsic.
562void fir::genMinMaxlocReductionLoop(
563 fir::FirOpBuilder &builder, mlir::Value array,
564 fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody,
565 fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType,
566 mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr,
567 bool maskMayBeLogicalScalar) {
568 mlir::IndexType idxTy = builder.getIndexType();
569
570 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
571
572 fir::SequenceType::Shape flatShape(rank,
573 fir::SequenceType::getUnknownExtent());
574 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
575 mlir::Type boxArrTy = fir::BoxType::get(arrTy);
576 array = builder.create<fir::ConvertOp>(loc, boxArrTy, array);
577
578 mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType());
579 mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
580 mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0);
581 mlir::Value flagRef = builder.createTemporary(loc, resultElemType);
582 builder.create<fir::StoreOp>(loc, zero, flagRef);
583
584 mlir::Value init = initVal(builder, loc, elementType);
585 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
586
587 assert(rank > 0 && "rank cannot be zero");
588 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
589
590 // Compute all the upper bounds before the loop nest.
591 // It is not strictly necessary for performance, since the loop nest
592 // does not have any store operations and any LICM optimization
593 // should be able to optimize the redundancy.
594 for (unsigned i = 0; i < rank; ++i) {
595 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
596 auto dims =
597 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
598 mlir::Value len = dims.getResult(1);
599 // We use C indexing here, so len-1 as loopcount
600 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
601 bounds.push_back(loopCount);
602 }
603 // Create a loop nest consisting of OP operations.
604 // Collect the loops' induction variables into indices array,
605 // which will be used in the innermost loop to load the input
606 // array's element.
607 // The loops are generated such that the innermost loop processes
608 // the 0 dimension.
609 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
610 for (unsigned i = rank; 0 < i; --i) {
611 mlir::Value step = one;
612 mlir::Value loopCount = bounds[i - 1];
613 auto loop =
614 builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
615 /*finalCountValue=*/false, init);
616 init = loop.getRegionIterArgs()[0];
617 indices.push_back(loop.getInductionVar());
618 // Set insertion point to the loop body so that the next loop
619 // is inserted inside the current one.
620 builder.setInsertionPointToStart(loop.getBody());
621 }
622
623 // Reverse the indices such that they are ordered as:
624 // <dim-0-idx, dim-1-idx, ...>
625 std::reverse(indices.begin(), indices.end());
626 mlir::Value reductionVal =
627 genBody(builder, loc, elementType, array, flagRef, init, indices);
628
629 // Unwind the loop nest and insert ResultOp on each level
630 // to return the updated value of the reduction to the enclosing
631 // loops.
632 for (unsigned i = 0; i < rank; ++i) {
633 auto result = builder.create<fir::ResultOp>(loc, reductionVal);
634 // Proceed to the outer loop.
635 auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
636 reductionVal = loop.getResult(0);
637 // Set insertion point after the loop operation that we have
638 // just processed.
639 builder.setInsertionPointAfter(loop.getOperation());
640 }
641 // End of loop nest. The insertion point is after the outermost loop.
642 if (maskMayBeLogicalScalar) {
643 if (fir::IfOp ifOp =
644 mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) {
645 builder.create<fir::ResultOp>(loc, reductionVal);
646 builder.setInsertionPointAfter(ifOp);
647 // Redefine flagSet to escape scope of ifOp
648 flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
649 reductionVal = ifOp.getResult(0);
650 }
651 }
652}
653
654static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
655 mlir::func::FuncOp &funcOp, bool isMax,
656 unsigned rank, int maskRank,
657 mlir::Type elementType,
658 mlir::Type maskElemType,
659 mlir::Type resultElemTy, bool isDim) {
660 auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
661 mlir::Type elementType) {
662 if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
663 const llvm::fltSemantics &sem = ty.getFloatSemantics();
664 llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
665 return builder.createRealConstant(loc, elementType, limit);
666 }
667 unsigned bits = elementType.getIntOrFloatBitWidth();
668 int64_t initValue = (isMax ? llvm::APInt::getSignedMinValue(bits)
669 : llvm::APInt::getSignedMaxValue(bits))
670 .getSExtValue();
671 return builder.createIntegerConstant(loc, elementType, initValue);
672 };
673
674 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
675 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
676
677 mlir::Value mask = funcOp.front().getArgument(2);
678
679 // Set up result array in case of early exit / 0 length array
680 mlir::IndexType idxTy = builder.getIndexType();
681 mlir::Type resultTy = fir::SequenceType::get(rank, resultElemTy);
682 mlir::Type resultHeapTy = fir::HeapType::get(resultTy);
683 mlir::Type resultBoxTy = fir::BoxType::get(resultHeapTy);
684
685 mlir::Value returnValue = builder.createIntegerConstant(loc, resultElemTy, 0);
686 mlir::Value resultArrSize = builder.createIntegerConstant(loc, idxTy, rank);
687
688 mlir::Value resultArrInit = builder.create<fir::AllocMemOp>(loc, resultTy);
689 mlir::Value resultArrShape = builder.create<fir::ShapeOp>(loc, resultArrSize);
690 mlir::Value resultArr = builder.create<fir::EmboxOp>(
691 loc, resultBoxTy, resultArrInit, resultArrShape);
692
693 mlir::Type resultRefTy = builder.getRefType(resultElemTy);
694
695 if (maskRank > 0) {
696 fir::SequenceType::Shape flatShape(rank,
697 fir::SequenceType::getUnknownExtent());
698 mlir::Type maskTy = fir::SequenceType::get(flatShape, maskElemType);
699 mlir::Type boxMaskTy = fir::BoxType::get(maskTy);
700 mask = builder.create<fir::ConvertOp>(loc, boxMaskTy, mask);
701 }
702
703 for (unsigned int i = 0; i < rank; ++i) {
704 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
705 mlir::Value resultElemAddr =
706 builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index);
707 builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
708 }
709
710 auto genBodyOp =
711 [&rank, &resultArr, isMax, &mask, &maskElemType, &maskRank](
712 fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType,
713 mlir::Value array, mlir::Value flagRef, mlir::Value reduction,
714 const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
715 // We are in the innermost loop: generate the reduction body.
716 if (maskRank > 0) {
717 mlir::Type logicalRef = builder.getRefType(maskElemType);
718 mlir::Value maskAddr =
719 builder.create<fir::CoordinateOp>(loc, logicalRef, mask, indices);
720 mlir::Value maskElem = builder.create<fir::LoadOp>(loc, maskAddr);
721
722 // fir::IfOp requires argument to be I1 - won't accept logical or any
723 // other Integer.
724 mlir::Type ifCompatType = builder.getI1Type();
725 mlir::Value ifCompatElem =
726 builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);
727
728 llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
729 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
730 /*withElseRegion=*/true);
731 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
732 }
733
734 // Set flag that mask was true at some point
735 mlir::Value flagSet = builder.createIntegerConstant(
736 loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
737 mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
738 mlir::Type eleRefTy = builder.getRefType(elementType);
739 mlir::Value addr =
740 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
741 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
742
743 mlir::Value cmp;
744 if (elementType.isa<mlir::FloatType>()) {
745 // For FP reductions we want the first smallest value to be used, that
746 // is not NaN. An OGT/OLT condition will usually work for this unless all
747 // the values are NaN or Inf. This follows the same logic as
748 // NumericCompare for Minloc/Maxloc in extrema.cpp.
749 cmp = builder.create<mlir::arith::CmpFOp>(
750 loc,
751 isMax ? mlir::arith::CmpFPredicate::OGT
752 : mlir::arith::CmpFPredicate::OLT,
753 elem, reduction);
754
755 mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
756 loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
757 mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
758 loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
759 cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
760 cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
761 } else if (elementType.isa<mlir::IntegerType>()) {
762 cmp = builder.create<mlir::arith::CmpIOp>(
763 loc,
764 isMax ? mlir::arith::CmpIPredicate::sgt
765 : mlir::arith::CmpIPredicate::slt,
766 elem, reduction);
767 } else {
768 llvm_unreachable("unsupported type");
769 }
770
771 // The condition used for the loop is isFirst || <the condition above>.
772 isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
773 isFirst = builder.create<mlir::arith::XOrIOp>(
774 loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
775 cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);
776 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
777 /*withElseRegion*/ true);
778
779 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
780 builder.create<fir::StoreOp>(loc, flagSet, flagRef);
781 mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
782 mlir::Type returnRefTy = builder.getRefType(resultElemTy);
783 mlir::IndexType idxTy = builder.getIndexType();
784
785 mlir::Value one = builder.createIntegerConstant(loc, resultElemTy, 1);
786
787 for (unsigned int i = 0; i < rank; ++i) {
788 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
789 mlir::Value resultElemAddr =
790 builder.create<fir::CoordinateOp>(loc, returnRefTy, resultArr, index);
791 mlir::Value convert =
792 builder.create<fir::ConvertOp>(loc, resultElemTy, indices[i]);
793 mlir::Value fortranIndex =
794 builder.create<mlir::arith::AddIOp>(loc, convert, one);
795 builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
796 }
797 builder.create<fir::ResultOp>(loc, elem);
798 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
799 builder.create<fir::ResultOp>(loc, reduction);
800 builder.setInsertionPointAfter(ifOp);
801 mlir::Value reductionVal = ifOp.getResult(0);
802
803 // Close the mask if needed
804 if (maskRank > 0) {
805 fir::IfOp ifOp =
806 mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp());
807 builder.create<fir::ResultOp>(loc, reductionVal);
808 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
809 builder.create<fir::ResultOp>(loc, reduction);
810 reductionVal = ifOp.getResult(0);
811 builder.setInsertionPointAfter(ifOp);
812 }
813
814 return reductionVal;
815 };
816
817 // if mask is a logical scalar, we can check its value before the main loop
818 // and either ignore the fact it is there or exit early.
819 if (maskRank == 0) {
820 mlir::Type logical = builder.getI1Type();
821 mlir::IndexType idxTy = builder.getIndexType();
822
823 fir::SequenceType::Shape singleElement(1, 1);
824 mlir::Type arrTy = fir::SequenceType::get(singleElement, logical);
825 mlir::Type boxArrTy = fir::BoxType::get(arrTy);
826 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, mask);
827
828 mlir::Value indx = builder.createIntegerConstant(loc, idxTy, 0);
829 mlir::Type logicalRefTy = builder.getRefType(logical);
830 mlir::Value condAddr =
831 builder.create<fir::CoordinateOp>(loc, logicalRefTy, array, indx);
832 mlir::Value cond = builder.create<fir::LoadOp>(loc, condAddr);
833
834 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cond,
835 /*withElseRegion=*/true);
836
837 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
838 mlir::Value basicValue;
839 if (elementType.isa<mlir::IntegerType>()) {
840 basicValue = builder.createIntegerConstant(loc, elementType, 0);
841 } else {
842 basicValue = builder.createRealConstant(loc, elementType, 0);
843 }
844 builder.create<fir::ResultOp>(loc, basicValue);
845
846 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
847 }
848 auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
849 const mlir::Type &resultElemType, mlir::Value resultArr,
850 mlir::Value index) {
851 mlir::Type resultRefTy = builder.getRefType(resultElemType);
852 return builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr,
853 index);
854 };
855
856 genMinMaxlocReductionLoop(builder, funcOp.front().getArgument(1), init,
857 genBodyOp, getAddrFn, rank, elementType, loc,
858 maskElemType, resultArr, maskRank == 0);
859
860 // Store newly created output array to the reference passed in
861 if (isDim) {
862 mlir::Type resultBoxTy =
863 fir::BoxType::get(fir::HeapType::get(resultElemTy));
864 mlir::Value outputArr = builder.create<fir::ConvertOp>(
865 loc, builder.getRefType(resultBoxTy), funcOp.front().getArgument(0));
866 mlir::Value resultArrScalar = builder.create<fir::ConvertOp>(
867 loc, fir::HeapType::get(resultElemTy), resultArrInit);
868 mlir::Value resultBox =
869 builder.create<fir::EmboxOp>(loc, resultBoxTy, resultArrScalar);
870 builder.create<fir::StoreOp>(loc, resultBox, outputArr);
871 } else {
872 fir::SequenceType::Shape resultShape(1, rank);
873 mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy);
874 mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy);
875 mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy);
876 mlir::Type outputRefTy = builder.getRefType(outputBoxTy);
877 mlir::Value outputArr = builder.create<fir::ConvertOp>(
878 loc, outputRefTy, funcOp.front().getArgument(0));
879 builder.create<fir::StoreOp>(loc, resultArr, outputArr);
880 }
881
882 builder.create<mlir::func::ReturnOp>(loc);
883}
884
885/// Generate function type for the simplified version of RTNAME(DotProduct)
886/// operating on the given \p elementType.
887static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
888 const mlir::Type &elementType) {
889 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
890 return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
891 {elementType});
892}
893
894/// Generate function body of the simplified version of RTNAME(DotProduct)
895/// with signature provided by \p funcOp. The caller is responsible
896/// for saving/restoring the original insertion point of \p builder.
897/// \p funcOp is expected to be empty on entry to this function.
898/// \p arg1ElementTy and \p arg2ElementTy specify elements types
899/// of the underlying array objects - they are used to generate proper
900/// element accesses.
901static void genRuntimeDotBody(fir::FirOpBuilder &builder,
902 mlir::func::FuncOp &funcOp,
903 mlir::Type arg1ElementTy,
904 mlir::Type arg2ElementTy) {
905 // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
906 // T, dimension(:) :: arr1, arr2
907 // T product = 0
908 // integer iter
909 // do iter = 0, extent(arr1)
910 // product = product + arr1[iter] * arr2[iter]
911 // end do
912 // RTNAME(ADotProduct)<T>_simplified = product
913 // end function RTNAME(DotProduct)<T>_simplified
914 auto loc = mlir::UnknownLoc::get(builder.getContext());
915 mlir::Type resultElementType = funcOp.getResultTypes()[0];
916 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
917
918 mlir::IndexType idxTy = builder.getIndexType();
919
920 mlir::Value zero =
921 resultElementType.isa<mlir::FloatType>()
922 ? builder.createRealConstant(loc, resultElementType, 0.0)
923 : builder.createIntegerConstant(loc, resultElementType, 0);
924
925 mlir::Block::BlockArgListType args = funcOp.front().getArguments();
926 mlir::Value arg1 = args[0];
927 mlir::Value arg2 = args[1];
928
929 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
930
931 fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
932 mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
933 mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
934 mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
935 mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
936 mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
937 mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
938 // This version takes the loop trip count from the first argument.
939 // If the first argument's box has unknown (at compilation time)
940 // extent, then it may be better to take the extent from the second
941 // argument - so that after inlining the loop may be better optimized, e.g.
942 // fully unrolled. This requires generating two versions of the simplified
943 // function and some analysis at the call site to choose which version
944 // is more profitable to call.
945 // Note that we can assume that both arguments have the same extent.
946 auto dims =
947 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
948 mlir::Value len = dims.getResult(1);
949 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
950 mlir::Value step = one;
951
952 // We use C indexing here, so len-1 as loopcount
953 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
954 auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
955 /*unordered=*/false,
956 /*finalCountValue=*/false, zero);
957 mlir::Value sumVal = loop.getRegionIterArgs()[0];
958
959 // Begin loop code
960 mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
961 builder.setInsertionPointToStart(loop.getBody());
962
963 mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
964 mlir::Value index = loop.getInductionVar();
965 mlir::Value addr1 =
966 builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
967 mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
968 // Convert to the result type.
969 elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
970
971 mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
972 mlir::Value addr2 =
973 builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
974 mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
975 // Convert to the result type.
976 elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
977
978 if (resultElementType.isa<mlir::FloatType>())
979 sumVal = builder.create<mlir::arith::AddFOp>(
980 loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
981 else if (resultElementType.isa<mlir::IntegerType>())
982 sumVal = builder.create<mlir::arith::AddIOp>(
983 loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
984 else
985 llvm_unreachable("unsupported type");
986
987 builder.create<fir::ResultOp>(loc, sumVal);
988 // End of loop.
989 builder.restoreInsertionPoint(loopEndPt);
990
991 mlir::Value resultVal = loop.getResult(0);
992 builder.create<mlir::func::ReturnOp>(loc, resultVal);
993}
994
995mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
996 fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
997 FunctionTypeGeneratorTy typeGenerator,
998 FunctionBodyGeneratorTy bodyGenerator) {
999 // WARNING: if the function generated here changes its signature
1000 // or behavior (the body code), we should probably embed some
1001 // versioning information into its name, otherwise libraries
1002 // statically linked with older versions of Flang may stop
1003 // working with object files created with newer Flang.
1004 // We can also avoid this by using internal linkage, but
1005 // this may increase the size of final executable/shared library.
1006 std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
1007 // If we already have a function, just return it.
1008 mlir::func::FuncOp newFunc = builder.getNamedFunction(replacementName);
1009 mlir::FunctionType fType = typeGenerator(builder);
1010 if (newFunc) {
1011 assert(newFunc.getFunctionType() == fType &&
1012 "type mismatch for simplified function");
1013 return newFunc;
1014 }
1015
1016 // Need to build the function!
1017 auto loc = mlir::UnknownLoc::get(builder.getContext());
1018 newFunc = builder.createFunction(loc, replacementName, fType);
1019 auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
1020 auto linkage =
1021 mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
1022 newFunc->setAttr("llvm.linkage", linkage);
1023
1024 // Save the position of the original call.
1025 mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
1026
1027 bodyGenerator(builder, newFunc);
1028
1029 // Now back to where we were adding code earlier...
1030 builder.restoreInsertionPoint(insertPt);
1031
1032 return newFunc;
1033}
1034
1035void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction(
1036 fir::CallOp call, const fir::KindMapping &kindMap,
1037 GenReductionBodyTy genBodyFunc) {
1038 // args[1] and args[2] are source filename and line number, ignored.
1039 mlir::Operation::operand_range args = call.getArgs();
1040
1041 const mlir::Value &dim = args[3];
1042 const mlir::Value &mask = args[4];
1043 // dim is zero when it is absent, which is an implementation
1044 // detail in the runtime library.
1045
1046 bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
1047 unsigned rank = getDimCount(args[0]);
1048
1049 // Rank is set to 0 for assumed shape arrays, don't simplify
1050 // in these cases
1051 if (!(dimAndMaskAbsent && rank > 0))
1052 return;
1053
1054 mlir::Type resultType = call.getResult(0).getType();
1055
1056 if (!resultType.isa<mlir::FloatType>() &&
1057 !resultType.isa<mlir::IntegerType>())
1058 return;
1059
1060 auto argType = getArgElementType(args[0]);
1061 if (!argType)
1062 return;
1063 assert(*argType == resultType &&
1064 "Argument/result types mismatch in reduction");
1065
1066 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1067
1068 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1069 std::string fmfString{builder.getFastMathFlagsString()};
1070 std::string funcName =
1071 (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1072 mlir::Twine{rank} +
1073 // We must mangle the generated function name with FastMathFlags
1074 // value.
1075 (fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString}))
1076 .str();
1077
1078 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1079 resultType);
1080}
1081
1082void SimplifyIntrinsicsPass::simplifyLogicalDim0Reduction(
1083 fir::CallOp call, const fir::KindMapping &kindMap,
1084 GenReductionBodyTy genBodyFunc) {
1085
1086 mlir::Operation::operand_range args = call.getArgs();
1087 const mlir::Value &dim = args[3];
1088 unsigned rank = getDimCount(args[0]);
1089
1090 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1091 // these cases.
1092 if (!(isZero(dim) && rank > 0))
1093 return;
1094
1095 mlir::Value inputBox = findBoxDef(args[0]);
1096
1097 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1098 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1099
1100 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1101
1102 // Treating logicals as integers makes things a lot easier
1103 fir::LogicalType logicalType = {elementType.dyn_cast<fir::LogicalType>()};
1104 fir::KindTy kind = logicalType.getFKind();
1105 mlir::Type intElementType = builder.getIntegerType(kind * 8);
1106
1107 // Mangle kind into function name as it is not done by default
1108 std::string funcName =
1109 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1110 mlir::Twine{kind} + "x" + mlir::Twine{rank})
1111 .str();
1112
1113 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1114 intElementType);
1115}
1116
1117void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction(
1118 fir::CallOp call, const fir::KindMapping &kindMap,
1119 GenReductionBodyTy genBodyFunc) {
1120
1121 mlir::Operation::operand_range args = call.getArgs();
1122 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1123 mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1124 unsigned rank = getDimCount(args[0]);
1125
1126 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1127 // these cases. We check for Dim at the end as some logical functions (Any,
1128 // All) set dim to 1 instead of 0 when the argument is not present.
1129 if (funcNameBase.ends_with("Dim") || !(rank > 0))
1130 return;
1131
1132 mlir::Value inputBox = findBoxDef(args[0]);
1133 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1134
1135 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1136
1137 // Treating logicals as integers makes things a lot easier
1138 fir::LogicalType logicalType = {elementType.dyn_cast<fir::LogicalType>()};
1139 fir::KindTy kind = logicalType.getFKind();
1140 mlir::Type intElementType = builder.getIntegerType(kind * 8);
1141
1142 // Mangle kind into function name as it is not done by default
1143 std::string funcName =
1144 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1145 mlir::Twine{kind} + "x" + mlir::Twine{rank})
1146 .str();
1147
1148 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1149 intElementType);
1150}
1151
1152void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
1153 fir::CallOp call, const fir::KindMapping &kindMap, bool isMax) {
1154
1155 mlir::Operation::operand_range args = call.getArgs();
1156
1157 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1158 mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1159 bool isDim = funcNameBase.ends_with("Dim");
1160 mlir::Value back = args[isDim ? 7 : 6];
1161 if (isTrueOrNotConstant(back))
1162 return;
1163
1164 mlir::Value mask = args[isDim ? 6 : 5];
1165 mlir::Value maskDef = findMaskDef(mask);
1166
1167 // maskDef is set to NULL when the defining op is not one we accept.
1168 // This tends to be because it is a selectOp, in which case let the
1169 // runtime deal with it.
1170 if (maskDef == NULL)
1171 return;
1172
1173 unsigned rank = getDimCount(args[1]);
1174 if ((isDim && rank != 1) || !(rank > 0))
1175 return;
1176
1177 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1178 mlir::Location loc = call.getLoc();
1179 auto inputBox = findBoxDef(args[1]);
1180 mlir::Type inputType = hlfir::getFortranElementType(inputBox.getType());
1181
1182 if (inputType.isa<fir::CharacterType>())
1183 return;
1184
1185 int maskRank;
1186 fir::KindTy kind = 0;
1187 mlir::Type logicalElemType = builder.getI1Type();
1188 if (isOperandAbsent(mask)) {
1189 maskRank = -1;
1190 } else {
1191 maskRank = getDimCount(mask);
1192 mlir::Type maskElemTy = hlfir::getFortranElementType(maskDef.getType());
1193 fir::LogicalType logicalFirType = {maskElemTy.dyn_cast<fir::LogicalType>()};
1194 kind = logicalFirType.getFKind();
1195 // Convert fir::LogicalType to mlir::Type
1196 logicalElemType = logicalFirType;
1197 }
1198
1199 mlir::Operation *outputDef = args[0].getDefiningOp();
1200 mlir::Value outputAlloc = outputDef->getOperand(0);
1201 mlir::Type outType = hlfir::getFortranElementType(outputAlloc.getType());
1202
1203 std::string fmfString{builder.getFastMathFlagsString()};
1204 std::string funcName =
1205 (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1206 mlir::Twine{rank} +
1207 (maskRank >= 0
1208 ? "_Logical" + mlir::Twine{kind} + "x" + mlir::Twine{maskRank}
1209 : "") +
1210 "_")
1211 .str();
1212
1213 llvm::raw_string_ostream nameOS(funcName);
1214 outType.print(nameOS);
1215 if (isDim)
1216 nameOS << '_' << inputType;
1217 nameOS << '_' << fmfString;
1218
1219 auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
1220 return genRuntimeMinlocType(builder, rank);
1221 };
1222 auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType,
1223 isMax, isDim](fir::FirOpBuilder &builder,
1224 mlir::func::FuncOp &funcOp) {
1225 genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType,
1226 logicalElemType, outType, isDim);
1227 };
1228
1229 mlir::func::FuncOp newFunc =
1230 getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
1231 builder.create<fir::CallOp>(loc, newFunc,
1232 mlir::ValueRange{args[0], args[1], mask});
1233 call->dropAllReferences();
1234 call->erase();
1235}
1236
1237void SimplifyIntrinsicsPass::simplifyReductionBody(
1238 fir::CallOp call, const fir::KindMapping &kindMap,
1239 GenReductionBodyTy genBodyFunc, fir::FirOpBuilder &builder,
1240 const mlir::StringRef &funcName, mlir::Type elementType) {
1241
1242 mlir::Operation::operand_range args = call.getArgs();
1243
1244 mlir::Type resultType = call.getResult(0).getType();
1245 unsigned rank = getDimCount(args[0]);
1246
1247 mlir::Location loc = call.getLoc();
1248
1249 auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
1250 return genNoneBoxType(builder, resultType);
1251 };
1252 auto bodyGenerator = [&rank, &genBodyFunc,
1253 &elementType](fir::FirOpBuilder &builder,
1254 mlir::func::FuncOp &funcOp) {
1255 genBodyFunc(builder, funcOp, rank, elementType);
1256 };
1257 // Mangle the function name with the rank value as "x<rank>".
1258 mlir::func::FuncOp newFunc =
1259 getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
1260 auto newCall =
1261 builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
1262 call->replaceAllUsesWith(newCall.getResults());
1263 call->dropAllReferences();
1264 call->erase();
1265}
1266
1267void SimplifyIntrinsicsPass::runOnOperation() {
1268 LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
1269 mlir::ModuleOp module = getOperation();
1270 fir::KindMapping kindMap = fir::getKindMapping(module);
1271 module.walk([&](mlir::Operation *op) {
1272 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
1273 if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
1274 mlir::StringRef funcName = callee.getLeafReference().getValue();
1275 // Replace call to runtime function for SUM when it has single
1276 // argument (no dim or mask argument) for 1D arrays with either
1277 // Integer4 or Real8 types. Other forms are ignored.
1278 // The new function is added to the module.
1279 //
1280 // Prototype for runtime call (from sum.cpp):
1281 // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
1282 // int dim, const Descriptor *mask)
1283 //
1284 if (funcName.starts_with(RTNAME_STRING(Sum))) {
1285 simplifyIntOrFloatReduction(call, kindMap, genRuntimeSumBody);
1286 return;
1287 }
1288 if (funcName.starts_with(RTNAME_STRING(DotProduct))) {
1289 LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
1290 LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
1291 llvm::dbgs() << "\n");
1292 mlir::Operation::operand_range args = call.getArgs();
1293 const mlir::Value &v1 = args[0];
1294 const mlir::Value &v2 = args[1];
1295 mlir::Location loc = call.getLoc();
1296 fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)};
1297 // Stringize the builder's FastMathFlags flags for mangling
1298 // the generated function name.
1299 std::string fmfString{builder.getFastMathFlagsString()};
1300
1301 mlir::Type type = call.getResult(0).getType();
1302 if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
1303 return;
1304
1305 // Try to find the element types of the boxed arguments.
1306 auto arg1Type = getArgElementType(v1);
1307 auto arg2Type = getArgElementType(v2);
1308
1309 if (!arg1Type || !arg2Type)
1310 return;
1311
1312 // Support only floating point and integer arguments
1313 // now (e.g. logical is skipped here).
1314 if (!arg1Type->isa<mlir::FloatType>() &&
1315 !arg1Type->isa<mlir::IntegerType>())
1316 return;
1317 if (!arg2Type->isa<mlir::FloatType>() &&
1318 !arg2Type->isa<mlir::IntegerType>())
1319 return;
1320
1321 auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
1322 return genRuntimeDotType(builder, type);
1323 };
1324 auto bodyGenerator = [&arg1Type,
1325 &arg2Type](fir::FirOpBuilder &builder,
1326 mlir::func::FuncOp &funcOp) {
1327 genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type);
1328 };
1329
1330 // Suffix the function name with the element types
1331 // of the arguments.
1332 std::string typedFuncName(funcName);
1333 llvm::raw_string_ostream nameOS(typedFuncName);
1334 // We must mangle the generated function name with FastMathFlags
1335 // value.
1336 if (!fmfString.empty())
1337 nameOS << '_' << fmfString;
1338 nameOS << '_';
1339 arg1Type->print(nameOS);
1340 nameOS << '_';
1341 arg2Type->print(nameOS);
1342
1343 mlir::func::FuncOp newFunc = getOrCreateFunction(
1344 builder, typedFuncName, typeGenerator, bodyGenerator);
1345 auto newCall = builder.create<fir::CallOp>(loc, newFunc,
1346 mlir::ValueRange{v1, v2});
1347 call->replaceAllUsesWith(newCall.getResults());
1348 call->dropAllReferences();
1349 call->erase();
1350
1351 LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
1352 llvm::dbgs() << "\n");
1353 return;
1354 }
1355 if (funcName.starts_with(RTNAME_STRING(Maxval))) {
1356 simplifyIntOrFloatReduction(call, kindMap, genRuntimeMaxvalBody);
1357 return;
1358 }
1359 if (funcName.starts_with(RTNAME_STRING(Count))) {
1360 simplifyLogicalDim0Reduction(call, kindMap, genRuntimeCountBody);
1361 return;
1362 }
1363 if (funcName.starts_with(RTNAME_STRING(Any))) {
1364 simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAnyBody);
1365 return;
1366 }
1367 if (funcName.ends_with(RTNAME_STRING(All))) {
1368 simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAllBody);
1369 return;
1370 }
1371 if (funcName.starts_with(RTNAME_STRING(Minloc))) {
1372 simplifyMinMaxlocReduction(call, kindMap, false);
1373 return;
1374 }
1375 if (funcName.starts_with(RTNAME_STRING(Maxloc))) {
1376 simplifyMinMaxlocReduction(call, kindMap, true);
1377 return;
1378 }
1379 }
1380 }
1381 });
1382 LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
1383}
1384
1385void SimplifyIntrinsicsPass::getDependentDialects(
1386 mlir::DialectRegistry &registry) const {
1387 // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
1388 registry.insert<mlir::LLVM::LLVMDialect>();
1389}
1390std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
1391 return std::make_unique<SimplifyIntrinsicsPass>();
1392}
1393

source code of flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp