1//===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9//===----------------------------------------------------------------------===//
10/// \file
/// This pass looks for calls to the runtime library that implement intrinsics
/// which can be simplified/specialized, and replaces each such call with a
/// call to a specialized function.
///
/// For example, SUM(arr) can be specialized as a simple function with one loop,
/// compared to the three arguments (plus file & line info) that the runtime
/// call has - when the argument is a 1D-array (multiple loops may be needed
/// for higher dimension arrays, of course).
18///
19/// The general idea is that besides making the call simpler, it can also be
20/// inlined by other passes that run after this pass, which further improves
21/// performance, particularly when the work done in the function is trivial
22/// and small in size.
23//===----------------------------------------------------------------------===//
24
25#include "flang/Optimizer/Builder/BoxValue.h"
26#include "flang/Optimizer/Builder/CUFCommon.h"
27#include "flang/Optimizer/Builder/FIRBuilder.h"
28#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
29#include "flang/Optimizer/Builder/Todo.h"
30#include "flang/Optimizer/Dialect/FIROps.h"
31#include "flang/Optimizer/Dialect/FIRType.h"
32#include "flang/Optimizer/Dialect/Support/FIRContext.h"
33#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
34#include "flang/Optimizer/Transforms/Passes.h"
35#include "flang/Optimizer/Transforms/Utils.h"
36#include "flang/Runtime/entry-names.h"
37#include "flang/Support/Fortran.h"
38#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
39#include "mlir/IR/Matchers.h"
40#include "mlir/IR/Operation.h"
41#include "mlir/Pass/Pass.h"
42#include "mlir/Transforms/DialectConversion.h"
43#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
44#include "mlir/Transforms/RegionUtils.h"
45#include "llvm/Support/Debug.h"
46#include "llvm/Support/raw_ostream.h"
47#include <llvm/Support/ErrorHandling.h>
48#include <mlir/Dialect/Arith/IR/Arith.h>
49#include <mlir/IR/BuiltinTypes.h>
50#include <mlir/IR/Location.h>
51#include <mlir/IR/MLIRContext.h>
52#include <mlir/IR/Value.h>
53#include <mlir/Support/LLVM.h>
54#include <optional>
55
56namespace fir {
57#define GEN_PASS_DEF_SIMPLIFYINTRINSICS
58#include "flang/Optimizer/Transforms/Passes.h.inc"
59} // namespace fir
60
61#define DEBUG_TYPE "flang-simplify-intrinsics"
62
63namespace {
64
/// Pass that replaces suitable calls to Fortran runtime intrinsics with
/// calls to simplified, specialized functions.
class SimplifyIntrinsicsPass
    : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
  /// Callback returning the type of a function that is about to be created.
  using FunctionTypeGeneratorTy =
      llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
  /// Callback filling in the body of the given function.
  using FunctionBodyGeneratorTy =
      llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
  /// Callback filling in the body of a reduction function, given the rank
  /// and the element type of the reduced array.
  using GenReductionBodyTy = llvm::function_ref<void(
      fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank,
      mlir::Type elementType)>;

public:
  // Reuse the constructors of the generated base class.
  using fir::impl::SimplifyIntrinsicsBase<
      SimplifyIntrinsicsPass>::SimplifyIntrinsicsBase;

  /// Generate a new function implementing a simplified version
  /// of a Fortran runtime function defined by \p basename name.
  /// \p typeGenerator is a callback that generates the new function's type.
  /// \p bodyGenerator is a callback that generates the new function's body.
  /// The new function is created in the \p builder's Module.
  /// NOTE(review): the name suggests an existing function is reused when
  /// present; the definition is not visible here - confirm.
  mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
                                         const mlir::StringRef &basename,
                                         FunctionTypeGeneratorTy typeGenerator,
                                         FunctionBodyGeneratorTy bodyGenerator);
  /// Pass entry point.
  void runOnOperation() override;
  /// Register the dialects this pass depends on into \p registry.
  void getDependentDialects(mlir::DialectRegistry &registry) const override;

private:
  /// Helper functions to replace a reduction type of call with its
  /// simplified form. The actual function is generated using a callback
  /// function.
  /// \p call is the call to be replaced
  /// \p kindMap is used to create FIROpBuilder
  /// \p genBodyFunc is the callback that builds the replacement function
  void simplifyIntOrFloatReduction(fir::CallOp call,
                                   const fir::KindMapping &kindMap,
                                   GenReductionBodyTy genBodyFunc);
  void simplifyLogicalDim0Reduction(fir::CallOp call,
                                    const fir::KindMapping &kindMap,
                                    GenReductionBodyTy genBodyFunc);
  void simplifyLogicalDim1Reduction(fir::CallOp call,
                                    const fir::KindMapping &kindMap,
                                    GenReductionBodyTy genBodyFunc);
  /// Variant for MINLOC/MAXLOC calls; \p isMax selects between the two.
  void simplifyMinMaxlocReduction(fir::CallOp call,
                                  const fir::KindMapping &kindMap, bool isMax);
  /// Shared helper of the reduction simplifications above. It additionally
  /// receives the \p builder to use, the \p basename of the runtime function
  /// and the \p elementType of the reduced array.
  void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap,
                             GenReductionBodyTy genBodyFunc,
                             fir::FirOpBuilder &builder,
                             const mlir::StringRef &basename,
                             mlir::Type elementType);
};
115
116} // namespace
117
118/// Create FirOpBuilder with the provided \p op insertion point
119/// and \p kindMap additionally inheriting FastMathFlags from \p op.
120static fir::FirOpBuilder
121getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) {
122 fir::FirOpBuilder builder{op, kindMap};
123 auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
124 if (!fmi)
125 return builder;
126
127 // Regardless of what default FastMathFlags are used by FirOpBuilder,
128 // override them with FastMathFlags attached to the operation.
129 builder.setFastMathFlags(fmi.getFastMathFlagsAttr().getValue());
130 return builder;
131}
132
133/// Generate function type for the simplified version of RTNAME(Sum) and
134/// similar functions with a fir.box<none> type returning \p elementType.
135static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
136 const mlir::Type &elementType) {
137 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
138 return mlir::FunctionType::get(builder.getContext(), {boxType},
139 {elementType});
140}
141
142template <typename Op>
143Op expectOp(mlir::Value val) {
144 if (Op op = mlir::dyn_cast_or_null<Op>(val.getDefiningOp()))
145 return op;
146 LLVM_DEBUG(llvm::dbgs() << "Didn't find expected " << Op::getOperationName()
147 << '\n');
148 return nullptr;
149}
150
151template <typename Op>
152static mlir::Value findDefSingle(fir::ConvertOp op) {
153 if (auto defOp = expectOp<Op>(op->getOperand(0))) {
154 return defOp.getResult();
155 }
156 return {};
157}
158
159template <typename... Ops>
160static mlir::Value findDef(fir::ConvertOp op) {
161 mlir::Value defOp;
162 // Loop over the operation types given to see if any match, exiting once
163 // a match is found. Cast to void is needed to avoid compiler complaining
164 // that the result of expression is unused
165 (void)((defOp = findDefSingle<Ops>(op), (defOp)) || ...);
166 return defOp;
167}
168
169static bool isOperandAbsent(mlir::Value val) {
170 if (auto op = expectOp<fir::ConvertOp>(val)) {
171 assert(op->getOperands().size() != 0);
172 return mlir::isa_and_nonnull<fir::AbsentOp>(
173 op->getOperand(0).getDefiningOp());
174 }
175 return false;
176}
177
178static bool isTrueOrNotConstant(mlir::Value val) {
179 if (auto op = expectOp<mlir::arith::ConstantOp>(val)) {
180 return !mlir::matchPattern(val, mlir::m_Zero());
181 }
182 return true;
183}
184
185static bool isZero(mlir::Value val) {
186 if (auto op = expectOp<fir::ConvertOp>(val)) {
187 assert(op->getOperands().size() != 0);
188 if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
189 return mlir::matchPattern(defOp, mlir::m_Zero());
190 }
191 return false;
192}
193
194static mlir::Value findBoxDef(mlir::Value val) {
195 if (auto op = expectOp<fir::ConvertOp>(val)) {
196 assert(op->getOperands().size() != 0);
197 return findDef<fir::EmboxOp, fir::ReboxOp>(op);
198 }
199 return {};
200}
201
202static mlir::Value findMaskDef(mlir::Value val) {
203 if (auto op = expectOp<fir::ConvertOp>(val)) {
204 assert(op->getOperands().size() != 0);
205 return findDef<fir::EmboxOp, fir::ReboxOp, fir::AbsentOp>(op);
206 }
207 return {};
208}
209
210static unsigned getDimCount(mlir::Value val) {
211 // In order to find the dimensions count, we look for EmboxOp/ReboxOp
212 // and take the count from its *result* type. Note that in case
213 // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
214 // have different types.
215 // Actually, we can take the box type from the operand of
216 // the first ConvertOp that has non-opaque box type that we meet
217 // going through the ConvertOp chain.
218 if (mlir::Value emboxVal = findBoxDef(val))
219 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(emboxVal.getType()))
220 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy()))
221 return seqTy.getDimension();
222 return 0;
223}
224
225/// Given the call operation's box argument \p val, discover
226/// the element type of the underlying array object.
227/// \returns the element type or std::nullopt if the type cannot
228/// be reliably found.
229/// We expect that the argument is a result of fir.convert
230/// with the destination type of !fir.box<none>.
231static std::optional<mlir::Type> getArgElementType(mlir::Value val) {
232 mlir::Operation *defOp;
233 do {
234 defOp = val.getDefiningOp();
235 // Analyze only sequences of convert operations.
236 if (!mlir::isa<fir::ConvertOp>(defOp))
237 return std::nullopt;
238 val = defOp->getOperand(0);
239 // The convert operation is expected to convert from one
240 // box type to another box type.
241 auto boxType = mlir::cast<fir::BoxType>(val.getType());
242 auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
243 if (!mlir::isa<mlir::NoneType>(elementType))
244 return elementType;
245 } while (true);
246}
247
/// Callback generating one step of a reduction. genReductionLoop calls it
/// with the builder, the location, the array element type, the element
/// loaded in the innermost loop and the current reduction value; the
/// returned value is the updated reduction value.
using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
    fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
    mlir::Value)>;
/// Callback producing the values yielded by a loop iteration.
/// genReductionLoop calls it with the builder, the location and the updated
/// reduction value, and passes the returned values to fir.result.
using ContinueLoopGenTy = llvm::function_ref<llvm::SmallVector<mlir::Value>(
    fir::FirOpBuilder &, mlir::Location, mlir::Value)>;
253
254/// Generate the reduction loop into \p funcOp.
255///
256/// \p initVal is a function, called to get the initial value for
257/// the reduction value
258/// \p genBody is called to fill in the actual reduciton operation
259/// for example add for SUM, MAX for MAXVAL, etc.
260/// \p rank is the rank of the input argument.
261/// \p elementType is the type of the elements in the input array,
262/// which may be different to the return type.
263/// \p loopCond is called to generate the condition to continue or
264/// not for IterWhile loops
265/// \p unorderedOrInitalLoopCond contains either a boolean or bool
266/// mlir constant, and controls the inital value for while loops
267/// or if DoLoop is ordered/unordered.
268
269template <typename OP, typename T, int resultIndex>
270static void
271genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
272 fir::InitValGeneratorTy initVal, ContinueLoopGenTy loopCond,
273 T unorderedOrInitialLoopCond, BodyOpGeneratorTy genBody,
274 unsigned rank, mlir::Type elementType, mlir::Location loc) {
275
276 mlir::IndexType idxTy = builder.getIndexType();
277
278 mlir::Block::BlockArgListType args = funcOp.front().getArguments();
279 mlir::Value arg = args[0];
280
281 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
282
283 fir::SequenceType::Shape flatShape(rank,
284 fir::SequenceType::getUnknownExtent());
285 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
286 mlir::Type boxArrTy = fir::BoxType::get(arrTy);
287 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
288 mlir::Type resultType = funcOp.getResultTypes()[0];
289 mlir::Value init = initVal(builder, loc, resultType);
290
291 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
292
293 assert(rank > 0 && "rank cannot be zero");
294 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
295
296 // Compute all the upper bounds before the loop nest.
297 // It is not strictly necessary for performance, since the loop nest
298 // does not have any store operations and any LICM optimization
299 // should be able to optimize the redundancy.
300 for (unsigned i = 0; i < rank; ++i) {
301 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
302 auto dims =
303 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
304 mlir::Value len = dims.getResult(1);
305 // We use C indexing here, so len-1 as loopcount
306 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
307 bounds.push_back(loopCount);
308 }
309 // Create a loop nest consisting of OP operations.
310 // Collect the loops' induction variables into indices array,
311 // which will be used in the innermost loop to load the input
312 // array's element.
313 // The loops are generated such that the innermost loop processes
314 // the 0 dimension.
315 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
316 for (unsigned i = rank; 0 < i; --i) {
317 mlir::Value step = one;
318 mlir::Value loopCount = bounds[i - 1];
319 auto loop = builder.create<OP>(loc, zeroIdx, loopCount, step,
320 unorderedOrInitialLoopCond,
321 /*finalCountValue=*/false, init);
322 init = loop.getRegionIterArgs()[resultIndex];
323 indices.push_back(loop.getInductionVar());
324 // Set insertion point to the loop body so that the next loop
325 // is inserted inside the current one.
326 builder.setInsertionPointToStart(loop.getBody());
327 }
328
329 // Reverse the indices such that they are ordered as:
330 // <dim-0-idx, dim-1-idx, ...>
331 std::reverse(indices.begin(), indices.end());
332 // We are in the innermost loop: generate the reduction body.
333 mlir::Type eleRefTy = builder.getRefType(elementType);
334 mlir::Value addr =
335 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
336 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
337 mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
338 // Generate vector with condition to continue while loop at [0] and result
339 // from current loop at [1] for IterWhileOp loops, just result at [0] for
340 // DoLoopOp loops.
341 llvm::SmallVector<mlir::Value> results = loopCond(builder, loc, reductionVal);
342
343 // Unwind the loop nest and insert ResultOp on each level
344 // to return the updated value of the reduction to the enclosing
345 // loops.
346 for (unsigned i = 0; i < rank; ++i) {
347 auto result = builder.create<fir::ResultOp>(loc, results);
348 // Proceed to the outer loop.
349 auto loop = mlir::cast<OP>(result->getParentOp());
350 results = loop.getResults();
351 // Set insertion point after the loop operation that we have
352 // just processed.
353 builder.setInsertionPointAfter(loop.getOperation());
354 }
355 // End of loop nest. The insertion point is after the outermost loop.
356 // Return the reduction value from the function.
357 builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]);
358}
359
360static llvm::SmallVector<mlir::Value> nopLoopCond(fir::FirOpBuilder &builder,
361 mlir::Location loc,
362 mlir::Value reductionVal) {
363 return {reductionVal};
364}
365
366/// Generate function body of the simplified version of RTNAME(Sum)
367/// with signature provided by \p funcOp. The caller is responsible
368/// for saving/restoring the original insertion point of \p builder.
369/// \p funcOp is expected to be empty on entry to this function.
370/// \p rank specifies the rank of the input argument.
371static void genRuntimeSumBody(fir::FirOpBuilder &builder,
372 mlir::func::FuncOp &funcOp, unsigned rank,
373 mlir::Type elementType) {
374 // function RTNAME(Sum)<T>x<rank>_simplified(arr)
375 // T, dimension(:) :: arr
376 // T sum = 0
377 // integer iter
378 // do iter = 0, extent(arr)
379 // sum = sum + arr[iter]
380 // end do
381 // RTNAME(Sum)<T>x<rank>_simplified = sum
382 // end function RTNAME(Sum)<T>x<rank>_simplified
383 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
384 mlir::Type elementType) {
385 if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
386 const llvm::fltSemantics &sem = ty.getFloatSemantics();
387 return builder.createRealConstant(loc, elementType,
388 llvm::APFloat::getZero(sem));
389 }
390 return builder.createIntegerConstant(loc, elementType, 0);
391 };
392
393 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
394 mlir::Type elementType, mlir::Value elem1,
395 mlir::Value elem2) -> mlir::Value {
396 if (mlir::isa<mlir::FloatType>(elementType))
397 return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2);
398 if (mlir::isa<mlir::IntegerType>(elementType))
399 return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2);
400
401 llvm_unreachable("unsupported type");
402 return {};
403 };
404
405 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
406 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
407
408 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
409 false, genBodyOp, rank, elementType,
410 loc);
411}
412
413static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
414 mlir::func::FuncOp &funcOp, unsigned rank,
415 mlir::Type elementType) {
416 auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
417 mlir::Type elementType) {
418 if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
419 const llvm::fltSemantics &sem = ty.getFloatSemantics();
420 return builder.createRealConstant(
421 loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true));
422 }
423 unsigned bits = elementType.getIntOrFloatBitWidth();
424 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
425 return builder.createIntegerConstant(loc, elementType, minInt);
426 };
427
428 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
429 mlir::Type elementType, mlir::Value elem1,
430 mlir::Value elem2) -> mlir::Value {
431 if (mlir::isa<mlir::FloatType>(elementType)) {
432 // arith.maxf later converted to llvm.intr.maxnum does not work
433 // correctly for NaNs and -0.0 (see maxnum/minnum pattern matching
434 // in LLVM's InstCombine pass). Moreover, llvm.intr.maxnum
435 // for F128 operands is lowered into fmaxl call by LLVM.
436 // This libm function may not work properly for F128 arguments
437 // on targets where long double is not F128. It is an LLVM issue,
438 // but we just use normal select here to resolve all the cases.
439 auto compare = builder.create<mlir::arith::CmpFOp>(
440 loc, mlir::arith::CmpFPredicate::OGT, elem1, elem2);
441 return builder.create<mlir::arith::SelectOp>(loc, compare, elem1, elem2);
442 }
443 if (mlir::isa<mlir::IntegerType>(elementType))
444 return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2);
445
446 llvm_unreachable("unsupported type");
447 return {};
448 };
449
450 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
451 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
452
453 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, init, nopLoopCond,
454 false, genBodyOp, rank, elementType,
455 loc);
456}
457
458static void genRuntimeCountBody(fir::FirOpBuilder &builder,
459 mlir::func::FuncOp &funcOp, unsigned rank,
460 mlir::Type elementType) {
461 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
462 mlir::Type elementType) {
463 unsigned bits = elementType.getIntOrFloatBitWidth();
464 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
465 return builder.createIntegerConstant(loc, elementType, zeroInt);
466 };
467
468 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
469 mlir::Type elementType, mlir::Value elem1,
470 mlir::Value elem2) -> mlir::Value {
471 auto zero32 = builder.createIntegerConstant(loc, elementType, 0);
472 auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0);
473 auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1);
474
475 auto compare = builder.create<mlir::arith::CmpIOp>(
476 loc, mlir::arith::CmpIPredicate::eq, elem1, zero32);
477 auto select =
478 builder.create<mlir::arith::SelectOp>(loc, compare, zero64, one64);
479 return builder.create<mlir::arith::AddIOp>(loc, select, elem2);
480 };
481
482 // Count always gets I32 for elementType as it converts logical input to
483 // logical<4> before passing to the function.
484 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
485 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
486
487 genReductionLoop<fir::DoLoopOp, bool, 0>(builder, funcOp, zero, nopLoopCond,
488 false, genBodyOp, rank, elementType,
489 loc);
490}
491
492static void genRuntimeAnyBody(fir::FirOpBuilder &builder,
493 mlir::func::FuncOp &funcOp, unsigned rank,
494 mlir::Type elementType) {
495 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
496 mlir::Type elementType) {
497 return builder.createIntegerConstant(loc, elementType, 0);
498 };
499
500 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
501 mlir::Type elementType, mlir::Value elem1,
502 mlir::Value elem2) -> mlir::Value {
503 auto zero = builder.createIntegerConstant(loc, elementType, 0);
504 return builder.create<mlir::arith::CmpIOp>(
505 loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
506 };
507
508 auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
509 mlir::Value reductionVal) {
510 auto one1 = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
511 auto eor = builder.create<mlir::arith::XOrIOp>(loc, reductionVal, one1);
512 llvm::SmallVector<mlir::Value> results = {eor, reductionVal};
513 return results;
514 };
515
516 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
517 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
518 mlir::Value ok = builder.createBool(loc, true);
519
520 genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
521 builder, funcOp, zero, continueCond, ok, genBodyOp, rank, elementType,
522 loc);
523}
524
525static void genRuntimeAllBody(fir::FirOpBuilder &builder,
526 mlir::func::FuncOp &funcOp, unsigned rank,
527 mlir::Type elementType) {
528 auto one = [](fir::FirOpBuilder builder, mlir::Location loc,
529 mlir::Type elementType) {
530 return builder.createIntegerConstant(loc, elementType, 1);
531 };
532
533 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
534 mlir::Type elementType, mlir::Value elem1,
535 mlir::Value elem2) -> mlir::Value {
536 auto zero = builder.createIntegerConstant(loc, elementType, 0);
537 return builder.create<mlir::arith::CmpIOp>(
538 loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
539 };
540
541 auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
542 mlir::Value reductionVal) {
543 llvm::SmallVector<mlir::Value> results = {reductionVal, reductionVal};
544 return results;
545 };
546
547 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
548 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
549 mlir::Value ok = builder.createBool(loc, true);
550
551 genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
552 builder, funcOp, one, continueCond, ok, genBodyOp, rank, elementType,
553 loc);
554}
555
556static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder,
557 unsigned int rank) {
558 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
559 mlir::Type boxRefType = builder.getRefType(boxType);
560
561 return mlir::FunctionType::get(builder.getContext(),
562 {boxRefType, boxType, boxType}, {});
563}
564
// Produces a loop nest for a Minloc intrinsic.
//
// \p array is the box of the input array; it is converted to a box of
// a \p rank dimensional array of \p elementType.
// \p initVal is called to get the initial reduction value.
// \p genBody is called in the innermost loop with the converted array,
// the reference to a temporary flag (initialized to zero), the current
// reduction value and the loop indices; it returns the new reduction value.
// \p resultArr is used here only to derive the element type of the flag.
// \p maskMayBeLogicalScalar tells that the loop nest may have been generated
// inside a fir.if, which is then terminated and stepped out of.
// \p getAddrFn and \p maskElemType are not referenced in this function body.
void fir::genMinMaxlocReductionLoop(
    fir::FirOpBuilder &builder, mlir::Value array,
    fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody,
    fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType,
    mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr,
    bool maskMayBeLogicalScalar) {
  mlir::IndexType idxTy = builder.getIndexType();

  mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);

  // View the input as a box of a rank-dimensional array with unknown extents.
  fir::SequenceType::Shape flatShape(rank,
                                     fir::SequenceType::getUnknownExtent());
  mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
  mlir::Type boxArrTy = fir::BoxType::get(arrTy);
  array = builder.create<fir::ConvertOp>(loc, boxArrTy, array);

  // Temporary flag of the result's element type, initialized to zero.
  // Its reference is handed to genBody.
  // NOTE(review): flagSet is assigned here and below, but never read in this
  // function; the constants it holds are still emitted.
  mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType());
  mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
  mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0);
  mlir::Value flagRef = builder.createTemporary(loc, resultElemType);
  builder.create<fir::StoreOp>(loc, zero, flagRef);

  mlir::Value init = initVal(builder, loc, elementType);
  llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;

  assert(rank > 0 && "rank cannot be zero");
  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);

  // Compute all the upper bounds before the loop nest.
  // It is not strictly necessary for performance, since the loop nest
  // does not have any store operations and any LICM optimization
  // should be able to optimize the redundancy.
  for (unsigned i = 0; i < rank; ++i) {
    mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
    auto dims =
        builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
    mlir::Value len = dims.getResult(1);
    // We use C indexing here, so len-1 as loopcount
    mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
    bounds.push_back(loopCount);
  }
  // Create a loop nest consisting of fir.do_loop operations.
  // Collect the loops' induction variables into indices array,
  // which will be used in the innermost loop to load the input
  // array's element.
  // The loops are generated such that the innermost loop processes
  // the 0 dimension.
  llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
  for (unsigned i = rank; 0 < i; --i) {
    mlir::Value step = one;
    mlir::Value loopCount = bounds[i - 1];
    auto loop =
        builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
                                      /*finalCountValue=*/false, init);
    // The reduction value is carried through the loop as its iteration
    // argument.
    init = loop.getRegionIterArgs()[0];
    indices.push_back(loop.getInductionVar());
    // Set insertion point to the loop body so that the next loop
    // is inserted inside the current one.
    builder.setInsertionPointToStart(loop.getBody());
  }

  // Reverse the indices such that they are ordered as:
  //   <dim-0-idx, dim-1-idx, ...>
  std::reverse(indices.begin(), indices.end());
  // We are in the innermost loop: let the callback generate the body.
  mlir::Value reductionVal =
      genBody(builder, loc, elementType, array, flagRef, init, indices);

  // Unwind the loop nest and insert ResultOp on each level
  // to return the updated value of the reduction to the enclosing
  // loops.
  for (unsigned i = 0; i < rank; ++i) {
    auto result = builder.create<fir::ResultOp>(loc, reductionVal);
    // Proceed to the outer loop.
    auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
    reductionVal = loop.getResult(0);
    // Set insertion point after the loop operation that we have
    // just processed.
    builder.setInsertionPointAfter(loop.getOperation());
  }
  // End of loop nest. The insertion point is after the outermost loop.
  if (maskMayBeLogicalScalar) {
    // If the loop nest was generated inside a fir.if, terminate the current
    // block of that fir.if with the reduction value and continue after it.
    if (fir::IfOp ifOp =
            mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) {
      builder.create<fir::ResultOp>(loc, reductionVal);
      builder.setInsertionPointAfter(ifOp);
      // Redefine flagSet to escape scope of ifOp
      flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
      // NOTE(review): reductionVal is not read after this assignment.
      reductionVal = ifOp.getResult(0);
    }
  }
}
657
658static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
659 mlir::func::FuncOp &funcOp, bool isMax,
660 unsigned rank, int maskRank,
661 mlir::Type elementType,
662 mlir::Type maskElemType,
663 mlir::Type resultElemTy, bool isDim) {
664 auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
665 mlir::Type elementType) {
666 if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
667 const llvm::fltSemantics &sem = ty.getFloatSemantics();
668 llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
669 return builder.createRealConstant(loc, elementType, limit);
670 }
671 unsigned bits = elementType.getIntOrFloatBitWidth();
672 int64_t initValue = (isMax ? llvm::APInt::getSignedMinValue(bits)
673 : llvm::APInt::getSignedMaxValue(bits))
674 .getSExtValue();
675 return builder.createIntegerConstant(loc, elementType, initValue);
676 };
677
678 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
679 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
680
681 mlir::Value mask = funcOp.front().getArgument(2);
682
683 // Set up result array in case of early exit / 0 length array
684 mlir::IndexType idxTy = builder.getIndexType();
685 mlir::Type resultTy = fir::SequenceType::get(rank, resultElemTy);
686 mlir::Type resultHeapTy = fir::HeapType::get(resultTy);
687 mlir::Type resultBoxTy = fir::BoxType::get(resultHeapTy);
688
689 mlir::Value returnValue = builder.createIntegerConstant(loc, resultElemTy, 0);
690 mlir::Value resultArrSize = builder.createIntegerConstant(loc, idxTy, rank);
691
692 mlir::Value resultArrInit = builder.create<fir::AllocMemOp>(loc, resultTy);
693 mlir::Value resultArrShape = builder.create<fir::ShapeOp>(loc, resultArrSize);
694 mlir::Value resultArr = builder.create<fir::EmboxOp>(
695 loc, resultBoxTy, resultArrInit, resultArrShape);
696
697 mlir::Type resultRefTy = builder.getRefType(resultElemTy);
698
699 if (maskRank > 0) {
700 fir::SequenceType::Shape flatShape(rank,
701 fir::SequenceType::getUnknownExtent());
702 mlir::Type maskTy = fir::SequenceType::get(flatShape, maskElemType);
703 mlir::Type boxMaskTy = fir::BoxType::get(maskTy);
704 mask = builder.create<fir::ConvertOp>(loc, boxMaskTy, mask);
705 }
706
707 for (unsigned int i = 0; i < rank; ++i) {
708 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
709 mlir::Value resultElemAddr =
710 builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr, index);
711 builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
712 }
713
714 auto genBodyOp =
715 [&rank, &resultArr, isMax, &mask, &maskElemType, &maskRank](
716 fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType,
717 mlir::Value array, mlir::Value flagRef, mlir::Value reduction,
718 const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
719 // We are in the innermost loop: generate the reduction body.
720 if (maskRank > 0) {
721 mlir::Type logicalRef = builder.getRefType(maskElemType);
722 mlir::Value maskAddr =
723 builder.create<fir::CoordinateOp>(loc, logicalRef, mask, indices);
724 mlir::Value maskElem = builder.create<fir::LoadOp>(loc, maskAddr);
725
726 // fir::IfOp requires argument to be I1 - won't accept logical or any
727 // other Integer.
728 mlir::Type ifCompatType = builder.getI1Type();
729 mlir::Value ifCompatElem =
730 builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);
731
732 llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
733 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
734 /*withElseRegion=*/true);
735 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
736 }
737
738 // Set flag that mask was true at some point
739 mlir::Value flagSet = builder.createIntegerConstant(
740 loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
741 mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
742 mlir::Type eleRefTy = builder.getRefType(elementType);
743 mlir::Value addr =
744 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
745 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
746
747 mlir::Value cmp;
748 if (mlir::isa<mlir::FloatType>(elementType)) {
      // For FP reductions we want the first smallest value that is not NaN
      // to be used. An OGT/OLT condition will usually work for this unless
      // all the values are NaN or Inf. This follows the same logic as
      // NumericCompare for Minloc/Maxloc in extrema.cpp.
753 cmp = builder.create<mlir::arith::CmpFOp>(
754 loc,
755 isMax ? mlir::arith::CmpFPredicate::OGT
756 : mlir::arith::CmpFPredicate::OLT,
757 elem, reduction);
758
759 mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
760 loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
761 mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
762 loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
763 cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
764 cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
765 } else if (mlir::isa<mlir::IntegerType>(elementType)) {
766 cmp = builder.create<mlir::arith::CmpIOp>(
767 loc,
768 isMax ? mlir::arith::CmpIPredicate::sgt
769 : mlir::arith::CmpIPredicate::slt,
770 elem, reduction);
771 } else {
772 llvm_unreachable("unsupported type");
773 }
774
775 // The condition used for the loop is isFirst || <the condition above>.
776 isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
777 isFirst = builder.create<mlir::arith::XOrIOp>(
778 loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
779 cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);
780 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
781 /*withElseRegion*/ true);
782
783 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
784 builder.create<fir::StoreOp>(loc, flagSet, flagRef);
785 mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
786 mlir::Type returnRefTy = builder.getRefType(resultElemTy);
787 mlir::IndexType idxTy = builder.getIndexType();
788
789 mlir::Value one = builder.createIntegerConstant(loc, resultElemTy, 1);
790
791 for (unsigned int i = 0; i < rank; ++i) {
792 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
793 mlir::Value resultElemAddr =
794 builder.create<fir::CoordinateOp>(loc, returnRefTy, resultArr, index);
795 mlir::Value convert =
796 builder.create<fir::ConvertOp>(loc, resultElemTy, indices[i]);
797 mlir::Value fortranIndex =
798 builder.create<mlir::arith::AddIOp>(loc, convert, one);
799 builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
800 }
801 builder.create<fir::ResultOp>(loc, elem);
802 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
803 builder.create<fir::ResultOp>(loc, reduction);
804 builder.setInsertionPointAfter(ifOp);
805 mlir::Value reductionVal = ifOp.getResult(0);
806
807 // Close the mask if needed
808 if (maskRank > 0) {
809 fir::IfOp ifOp =
810 mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp());
811 builder.create<fir::ResultOp>(loc, reductionVal);
812 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
813 builder.create<fir::ResultOp>(loc, reduction);
814 reductionVal = ifOp.getResult(0);
815 builder.setInsertionPointAfter(ifOp);
816 }
817
818 return reductionVal;
819 };
820
821 // if mask is a logical scalar, we can check its value before the main loop
822 // and either ignore the fact it is there or exit early.
823 if (maskRank == 0) {
824 mlir::Type i1Type = builder.getI1Type();
825 mlir::Type logical = maskElemType;
826 mlir::Type logicalRefTy = builder.getRefType(logical);
827 mlir::Value condAddr =
828 builder.create<fir::BoxAddrOp>(loc, logicalRefTy, mask);
829 mlir::Value cond = builder.create<fir::LoadOp>(loc, condAddr);
830 mlir::Value condI1 = builder.create<fir::ConvertOp>(loc, i1Type, cond);
831
832 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, condI1,
833 /*withElseRegion=*/true);
834
835 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
836 mlir::Value basicValue;
837 if (mlir::isa<mlir::IntegerType>(elementType)) {
838 basicValue = builder.createIntegerConstant(loc, elementType, 0);
839 } else {
840 basicValue = builder.createRealConstant(loc, elementType, 0);
841 }
842 builder.create<fir::ResultOp>(loc, basicValue);
843
844 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
845 }
846 auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
847 const mlir::Type &resultElemType, mlir::Value resultArr,
848 mlir::Value index) {
849 mlir::Type resultRefTy = builder.getRefType(resultElemType);
850 return builder.create<fir::CoordinateOp>(loc, resultRefTy, resultArr,
851 index);
852 };
853
854 genMinMaxlocReductionLoop(builder, funcOp.front().getArgument(1), init,
855 genBodyOp, getAddrFn, rank, elementType, loc,
856 maskElemType, resultArr, maskRank == 0);
857
858 // Store newly created output array to the reference passed in
859 if (isDim) {
860 mlir::Type resultBoxTy =
861 fir::BoxType::get(fir::HeapType::get(resultElemTy));
862 mlir::Value outputArr = builder.create<fir::ConvertOp>(
863 loc, builder.getRefType(resultBoxTy), funcOp.front().getArgument(0));
864 mlir::Value resultArrScalar = builder.create<fir::ConvertOp>(
865 loc, fir::HeapType::get(resultElemTy), resultArrInit);
866 mlir::Value resultBox =
867 builder.create<fir::EmboxOp>(loc, resultBoxTy, resultArrScalar);
868 builder.create<fir::StoreOp>(loc, resultBox, outputArr);
869 } else {
870 fir::SequenceType::Shape resultShape(1, rank);
871 mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy);
872 mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy);
873 mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy);
874 mlir::Type outputRefTy = builder.getRefType(outputBoxTy);
875 mlir::Value outputArr = builder.create<fir::ConvertOp>(
876 loc, outputRefTy, funcOp.front().getArgument(0));
877 builder.create<fir::StoreOp>(loc, resultArr, outputArr);
878 }
879
880 builder.create<mlir::func::ReturnOp>(loc);
881}
882
883/// Generate function type for the simplified version of RTNAME(DotProduct)
884/// operating on the given \p elementType.
885static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
886 const mlir::Type &elementType) {
887 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
888 return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
889 {elementType});
890}
891
892/// Generate function body of the simplified version of RTNAME(DotProduct)
893/// with signature provided by \p funcOp. The caller is responsible
894/// for saving/restoring the original insertion point of \p builder.
895/// \p funcOp is expected to be empty on entry to this function.
896/// \p arg1ElementTy and \p arg2ElementTy specify elements types
897/// of the underlying array objects - they are used to generate proper
898/// element accesses.
899static void genRuntimeDotBody(fir::FirOpBuilder &builder,
900 mlir::func::FuncOp &funcOp,
901 mlir::Type arg1ElementTy,
902 mlir::Type arg2ElementTy) {
903 // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
904 // T, dimension(:) :: arr1, arr2
905 // T product = 0
906 // integer iter
907 // do iter = 0, extent(arr1)
908 // product = product + arr1[iter] * arr2[iter]
909 // end do
910 // RTNAME(ADotProduct)<T>_simplified = product
911 // end function RTNAME(DotProduct)<T>_simplified
912 auto loc = mlir::UnknownLoc::get(builder.getContext());
913 mlir::Type resultElementType = funcOp.getResultTypes()[0];
914 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
915
916 mlir::IndexType idxTy = builder.getIndexType();
917
918 mlir::Value zero =
919 mlir::isa<mlir::FloatType>(resultElementType)
920 ? builder.createRealConstant(loc, resultElementType, 0.0)
921 : builder.createIntegerConstant(loc, resultElementType, 0);
922
923 mlir::Block::BlockArgListType args = funcOp.front().getArguments();
924 mlir::Value arg1 = args[0];
925 mlir::Value arg2 = args[1];
926
927 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
928
929 fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
930 mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
931 mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
932 mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
933 mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
934 mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
935 mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
936 // This version takes the loop trip count from the first argument.
937 // If the first argument's box has unknown (at compilation time)
938 // extent, then it may be better to take the extent from the second
939 // argument - so that after inlining the loop may be better optimized, e.g.
940 // fully unrolled. This requires generating two versions of the simplified
941 // function and some analysis at the call site to choose which version
942 // is more profitable to call.
943 // Note that we can assume that both arguments have the same extent.
944 auto dims =
945 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
946 mlir::Value len = dims.getResult(1);
947 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
948 mlir::Value step = one;
949
950 // We use C indexing here, so len-1 as loopcount
951 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
952 auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
953 /*unordered=*/false,
954 /*finalCountValue=*/false, zero);
955 mlir::Value sumVal = loop.getRegionIterArgs()[0];
956
957 // Begin loop code
958 mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
959 builder.setInsertionPointToStart(loop.getBody());
960
961 mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
962 mlir::Value index = loop.getInductionVar();
963 mlir::Value addr1 =
964 builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
965 mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
966 // Convert to the result type.
967 elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
968
969 mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
970 mlir::Value addr2 =
971 builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
972 mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
973 // Convert to the result type.
974 elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
975
976 if (mlir::isa<mlir::FloatType>(resultElementType))
977 sumVal = builder.create<mlir::arith::AddFOp>(
978 loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
979 else if (mlir::isa<mlir::IntegerType>(resultElementType))
980 sumVal = builder.create<mlir::arith::AddIOp>(
981 loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
982 else
983 llvm_unreachable("unsupported type");
984
985 builder.create<fir::ResultOp>(loc, sumVal);
986 // End of loop.
987 builder.restoreInsertionPoint(loopEndPt);
988
989 mlir::Value resultVal = loop.getResult(0);
990 builder.create<mlir::func::ReturnOp>(loc, resultVal);
991}
992
993mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
994 fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
995 FunctionTypeGeneratorTy typeGenerator,
996 FunctionBodyGeneratorTy bodyGenerator) {
997 // WARNING: if the function generated here changes its signature
998 // or behavior (the body code), we should probably embed some
999 // versioning information into its name, otherwise libraries
1000 // statically linked with older versions of Flang may stop
1001 // working with object files created with newer Flang.
1002 // We can also avoid this by using internal linkage, but
1003 // this may increase the size of final executable/shared library.
1004 std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
1005 // If we already have a function, just return it.
1006 mlir::func::FuncOp newFunc = builder.getNamedFunction(replacementName);
1007 mlir::FunctionType fType = typeGenerator(builder);
1008 if (newFunc) {
1009 assert(newFunc.getFunctionType() == fType &&
1010 "type mismatch for simplified function");
1011 return newFunc;
1012 }
1013
1014 // Need to build the function!
1015 auto loc = mlir::UnknownLoc::get(builder.getContext());
1016 newFunc = builder.createFunction(loc, replacementName, fType);
1017 auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
1018 auto linkage =
1019 mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
1020 newFunc->setAttr("llvm.linkage", linkage);
1021
1022 // Save the position of the original call.
1023 mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
1024
1025 bodyGenerator(builder, newFunc);
1026
1027 // Now back to where we were adding code earlier...
1028 builder.restoreInsertionPoint(insertPt);
1029
1030 return newFunc;
1031}
1032
1033void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction(
1034 fir::CallOp call, const fir::KindMapping &kindMap,
1035 GenReductionBodyTy genBodyFunc) {
1036 // args[1] and args[2] are source filename and line number, ignored.
1037 mlir::Operation::operand_range args = call.getArgs();
1038
1039 const mlir::Value &dim = args[3];
1040 const mlir::Value &mask = args[4];
1041 // dim is zero when it is absent, which is an implementation
1042 // detail in the runtime library.
1043
1044 bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
1045 unsigned rank = getDimCount(args[0]);
1046
1047 // Rank is set to 0 for assumed shape arrays, don't simplify
1048 // in these cases
1049 if (!(dimAndMaskAbsent && rank > 0))
1050 return;
1051
1052 mlir::Type resultType = call.getResult(0).getType();
1053
1054 if (!mlir::isa<mlir::FloatType>(resultType) &&
1055 !mlir::isa<mlir::IntegerType>(resultType))
1056 return;
1057
1058 auto argType = getArgElementType(args[0]);
1059 if (!argType)
1060 return;
1061 assert(*argType == resultType &&
1062 "Argument/result types mismatch in reduction");
1063
1064 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1065
1066 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1067 std::string fmfString{builder.getFastMathFlagsString()};
1068 std::string funcName =
1069 (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1070 mlir::Twine{rank} +
1071 // We must mangle the generated function name with FastMathFlags
1072 // value.
1073 (fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString}))
1074 .str();
1075
1076 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1077 resultType);
1078}
1079
1080void SimplifyIntrinsicsPass::simplifyLogicalDim0Reduction(
1081 fir::CallOp call, const fir::KindMapping &kindMap,
1082 GenReductionBodyTy genBodyFunc) {
1083
1084 mlir::Operation::operand_range args = call.getArgs();
1085 const mlir::Value &dim = args[3];
1086 unsigned rank = getDimCount(args[0]);
1087
1088 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1089 // these cases.
1090 if (!(isZero(dim) && rank > 0))
1091 return;
1092
1093 mlir::Value inputBox = findBoxDef(args[0]);
1094
1095 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1096 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1097
1098 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1099
1100 // Treating logicals as integers makes things a lot easier
1101 fir::LogicalType logicalType = {
1102 mlir::dyn_cast<fir::LogicalType>(elementType)};
1103 fir::KindTy kind = logicalType.getFKind();
1104 mlir::Type intElementType = builder.getIntegerType(kind * 8);
1105
1106 // Mangle kind into function name as it is not done by default
1107 std::string funcName =
1108 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1109 mlir::Twine{kind} + "x" + mlir::Twine{rank})
1110 .str();
1111
1112 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1113 intElementType);
1114}
1115
1116void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction(
1117 fir::CallOp call, const fir::KindMapping &kindMap,
1118 GenReductionBodyTy genBodyFunc) {
1119
1120 mlir::Operation::operand_range args = call.getArgs();
1121 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1122 mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1123 unsigned rank = getDimCount(args[0]);
1124
1125 // getDimCount returns a rank of 0 for assumed shape arrays, don't simplify in
1126 // these cases. We check for Dim at the end as some logical functions (Any,
1127 // All) set dim to 1 instead of 0 when the argument is not present.
1128 if (funcNameBase.ends_with("Dim") || !(rank > 0))
1129 return;
1130
1131 mlir::Value inputBox = findBoxDef(args[0]);
1132 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1133
1134 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1135
1136 // Treating logicals as integers makes things a lot easier
1137 fir::LogicalType logicalType = {
1138 mlir::dyn_cast<fir::LogicalType>(elementType)};
1139 fir::KindTy kind = logicalType.getFKind();
1140 mlir::Type intElementType = builder.getIntegerType(kind * 8);
1141
1142 // Mangle kind into function name as it is not done by default
1143 std::string funcName =
1144 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1145 mlir::Twine{kind} + "x" + mlir::Twine{rank})
1146 .str();
1147
1148 simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName,
1149 intElementType);
1150}
1151
1152void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
1153 fir::CallOp call, const fir::KindMapping &kindMap, bool isMax) {
1154
1155 mlir::Operation::operand_range args = call.getArgs();
1156
1157 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1158 mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1159 bool isDim = funcNameBase.ends_with("Dim");
1160 mlir::Value back = args[isDim ? 7 : 6];
1161 if (isTrueOrNotConstant(back))
1162 return;
1163
1164 mlir::Value mask = args[isDim ? 6 : 5];
1165 mlir::Value maskDef = findMaskDef(mask);
1166
1167 // maskDef is set to NULL when the defining op is not one we accept.
1168 // This tends to be because it is a selectOp, in which case let the
1169 // runtime deal with it.
1170 if (maskDef == NULL)
1171 return;
1172
1173 unsigned rank = getDimCount(args[1]);
1174 if ((isDim && rank != 1) || !(rank > 0))
1175 return;
1176
1177 fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
1178 mlir::Location loc = call.getLoc();
1179 auto inputBox = findBoxDef(args[1]);
1180 mlir::Type inputType = hlfir::getFortranElementType(inputBox.getType());
1181
1182 if (mlir::isa<fir::CharacterType>(inputType))
1183 return;
1184
1185 int maskRank;
1186 fir::KindTy kind = 0;
1187 mlir::Type logicalElemType = builder.getI1Type();
1188 if (isOperandAbsent(mask)) {
1189 maskRank = -1;
1190 } else {
1191 maskRank = getDimCount(mask);
1192 mlir::Type maskElemTy = hlfir::getFortranElementType(maskDef.getType());
1193 fir::LogicalType logicalFirType = {
1194 mlir::dyn_cast<fir::LogicalType>(maskElemTy)};
1195 kind = logicalFirType.getFKind();
1196 // Convert fir::LogicalType to mlir::Type
1197 logicalElemType = logicalFirType;
1198 }
1199
1200 mlir::Operation *outputDef = args[0].getDefiningOp();
1201 mlir::Value outputAlloc = outputDef->getOperand(0);
1202 mlir::Type outType = hlfir::getFortranElementType(outputAlloc.getType());
1203
1204 std::string fmfString{builder.getFastMathFlagsString()};
1205 std::string funcName =
1206 (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1207 mlir::Twine{rank} +
1208 (maskRank >= 0
1209 ? "_Logical" + mlir::Twine{kind} + "x" + mlir::Twine{maskRank}
1210 : "") +
1211 "_")
1212 .str();
1213
1214 llvm::raw_string_ostream nameOS(funcName);
1215 outType.print(nameOS);
1216 if (isDim)
1217 nameOS << '_' << inputType;
1218 nameOS << '_' << fmfString;
1219
1220 auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
1221 return genRuntimeMinlocType(builder, rank);
1222 };
1223 auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType,
1224 isMax, isDim](fir::FirOpBuilder &builder,
1225 mlir::func::FuncOp &funcOp) {
1226 genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType,
1227 logicalElemType, outType, isDim);
1228 };
1229
1230 mlir::func::FuncOp newFunc =
1231 getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
1232 builder.create<fir::CallOp>(loc, newFunc,
1233 mlir::ValueRange{args[0], args[1], mask});
1234 call->dropAllReferences();
1235 call->erase();
1236}
1237
1238void SimplifyIntrinsicsPass::simplifyReductionBody(
1239 fir::CallOp call, const fir::KindMapping &kindMap,
1240 GenReductionBodyTy genBodyFunc, fir::FirOpBuilder &builder,
1241 const mlir::StringRef &funcName, mlir::Type elementType) {
1242
1243 mlir::Operation::operand_range args = call.getArgs();
1244
1245 mlir::Type resultType = call.getResult(0).getType();
1246 unsigned rank = getDimCount(args[0]);
1247
1248 mlir::Location loc = call.getLoc();
1249
1250 auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
1251 return genNoneBoxType(builder, resultType);
1252 };
1253 auto bodyGenerator = [&rank, &genBodyFunc,
1254 &elementType](fir::FirOpBuilder &builder,
1255 mlir::func::FuncOp &funcOp) {
1256 genBodyFunc(builder, funcOp, rank, elementType);
1257 };
1258 // Mangle the function name with the rank value as "x<rank>".
1259 mlir::func::FuncOp newFunc =
1260 getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
1261 auto newCall =
1262 builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
1263 call->replaceAllUsesWith(newCall.getResults());
1264 call->dropAllReferences();
1265 call->erase();
1266}
1267
1268void SimplifyIntrinsicsPass::runOnOperation() {
1269 LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
1270 mlir::ModuleOp module = getOperation();
1271 fir::KindMapping kindMap = fir::getKindMapping(module);
1272 module.walk([&](mlir::Operation *op) {
1273 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
1274 if (cuf::isCUDADeviceContext(op))
1275 return;
1276 if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
1277 mlir::StringRef funcName = callee.getLeafReference().getValue();
1278 // Replace call to runtime function for SUM when it has single
1279 // argument (no dim or mask argument) for 1D arrays with either
1280 // Integer4 or Real8 types. Other forms are ignored.
1281 // The new function is added to the module.
1282 //
1283 // Prototype for runtime call (from sum.cpp):
1284 // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
1285 // int dim, const Descriptor *mask)
1286 //
1287 if (funcName.starts_with(RTNAME_STRING(Sum))) {
1288 simplifyIntOrFloatReduction(call, kindMap, genRuntimeSumBody);
1289 return;
1290 }
1291 if (funcName.starts_with(RTNAME_STRING(DotProduct))) {
1292 LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
1293 LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
1294 llvm::dbgs() << "\n");
1295 mlir::Operation::operand_range args = call.getArgs();
1296 const mlir::Value &v1 = args[0];
1297 const mlir::Value &v2 = args[1];
1298 mlir::Location loc = call.getLoc();
1299 fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)};
1300 // Stringize the builder's FastMathFlags flags for mangling
1301 // the generated function name.
1302 std::string fmfString{builder.getFastMathFlagsString()};
1303
1304 mlir::Type type = call.getResult(0).getType();
1305 if (!mlir::isa<mlir::FloatType>(type) &&
1306 !mlir::isa<mlir::IntegerType>(type))
1307 return;
1308
1309 // Try to find the element types of the boxed arguments.
1310 auto arg1Type = getArgElementType(v1);
1311 auto arg2Type = getArgElementType(v2);
1312
1313 if (!arg1Type || !arg2Type)
1314 return;
1315
1316 // Support only floating point and integer arguments
1317 // now (e.g. logical is skipped here).
1318 if (!mlir::isa<mlir::FloatType, mlir::IntegerType>(*arg1Type))
1319 return;
1320 if (!mlir::isa<mlir::FloatType, mlir::IntegerType>(*arg2Type))
1321 return;
1322
1323 auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
1324 return genRuntimeDotType(builder, type);
1325 };
1326 auto bodyGenerator = [&arg1Type,
1327 &arg2Type](fir::FirOpBuilder &builder,
1328 mlir::func::FuncOp &funcOp) {
1329 genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type);
1330 };
1331
1332 // Suffix the function name with the element types
1333 // of the arguments.
1334 std::string typedFuncName(funcName);
1335 llvm::raw_string_ostream nameOS(typedFuncName);
1336 // We must mangle the generated function name with FastMathFlags
1337 // value.
1338 if (!fmfString.empty())
1339 nameOS << '_' << fmfString;
1340 nameOS << '_';
1341 arg1Type->print(nameOS);
1342 nameOS << '_';
1343 arg2Type->print(nameOS);
1344
1345 mlir::func::FuncOp newFunc = getOrCreateFunction(
1346 builder, typedFuncName, typeGenerator, bodyGenerator);
1347 auto newCall = builder.create<fir::CallOp>(loc, newFunc,
1348 mlir::ValueRange{v1, v2});
1349 call->replaceAllUsesWith(newCall.getResults());
1350 call->dropAllReferences();
1351 call->erase();
1352
1353 LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
1354 llvm::dbgs() << "\n");
1355 return;
1356 }
1357 if (funcName.starts_with(RTNAME_STRING(Maxval))) {
1358 simplifyIntOrFloatReduction(call, kindMap, genRuntimeMaxvalBody);
1359 return;
1360 }
1361 if (funcName.starts_with(RTNAME_STRING(Count))) {
1362 simplifyLogicalDim0Reduction(call, kindMap, genRuntimeCountBody);
1363 return;
1364 }
1365 if (funcName.starts_with(RTNAME_STRING(Any))) {
1366 simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAnyBody);
1367 return;
1368 }
1369 if (funcName.ends_with(RTNAME_STRING(All))) {
1370 simplifyLogicalDim1Reduction(call, kindMap, genRuntimeAllBody);
1371 return;
1372 }
1373 if (funcName.starts_with(RTNAME_STRING(Minloc))) {
1374 simplifyMinMaxlocReduction(call, kindMap, false);
1375 return;
1376 }
1377 if (funcName.starts_with(RTNAME_STRING(Maxloc))) {
1378 simplifyMinMaxlocReduction(call, kindMap, true);
1379 return;
1380 }
1381 }
1382 }
1383 });
1384 LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
1385}
1386
1387void SimplifyIntrinsicsPass::getDependentDialects(
1388 mlir::DialectRegistry &registry) const {
1389 // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
1390 registry.insert<mlir::LLVM::LLVMDialect>();
1391}
1392

source code of flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp