1//===- SimplifyHLFIRIntrinsics.cpp - Simplify HLFIR Intrinsics ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// Normally transformational intrinsics are lowered to calls to runtime
9// functions. However, some cases of the intrinsics are faster when inlined
10// into the calling function.
11//===----------------------------------------------------------------------===//
12
13#include "flang/Optimizer/Builder/Complex.h"
14#include "flang/Optimizer/Builder/FIRBuilder.h"
15#include "flang/Optimizer/Builder/HLFIRTools.h"
16#include "flang/Optimizer/Builder/IntrinsicCall.h"
17#include "flang/Optimizer/Dialect/FIRDialect.h"
18#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
19#include "flang/Optimizer/HLFIR/HLFIROps.h"
20#include "flang/Optimizer/HLFIR/Passes.h"
21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/IR/Location.h"
23#include "mlir/Pass/Pass.h"
24#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25
26namespace hlfir {
27#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
28#include "flang/Optimizer/HLFIR/Passes.h.inc"
29} // namespace hlfir
30
31#define DEBUG_TYPE "simplify-hlfir-intrinsics"
32
// Debug/testing option forcing hlfir.matmul to be expanded as an
// hlfir.elemental operation instead of the default inlining strategy.
// Off by default.
static llvm::cl::opt<bool> forceMatmulAsElemental(
    "flang-inline-matmul-as-elemental",
    llvm::cl::desc("Expand hlfir.matmul as elemental operation"),
    llvm::cl::init(false));
37
38namespace {
39
40// Helper class to generate operations related to computing
41// product of values.
42class ProductFactory {
43public:
44 ProductFactory(mlir::Location loc, fir::FirOpBuilder &builder)
45 : loc(loc), builder(builder) {}
46
47 // Generate an update of the inner product value:
48 // acc += v1 * v2, OR
49 // acc += CONJ(v1) * v2, OR
50 // acc ||= v1 && v2
51 //
52 // CONJ parameter specifies whether the first complex product argument
53 // needs to be conjugated.
54 template <bool CONJ = false>
55 mlir::Value genAccumulateProduct(mlir::Value acc, mlir::Value v1,
56 mlir::Value v2) {
57 mlir::Type resultType = acc.getType();
58 acc = castToProductType(acc, resultType);
59 v1 = castToProductType(v1, resultType);
60 v2 = castToProductType(v2, resultType);
61 mlir::Value result;
62 if (mlir::isa<mlir::FloatType>(resultType)) {
63 result = builder.create<mlir::arith::AddFOp>(
64 loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
65 } else if (mlir::isa<mlir::ComplexType>(resultType)) {
66 if constexpr (CONJ)
67 result = fir::IntrinsicLibrary{builder, loc}.genConjg(resultType, v1);
68 else
69 result = v1;
70
71 result = builder.create<fir::AddcOp>(
72 loc, acc, builder.create<fir::MulcOp>(loc, result, v2));
73 } else if (mlir::isa<mlir::IntegerType>(resultType)) {
74 result = builder.create<mlir::arith::AddIOp>(
75 loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
76 } else if (mlir::isa<fir::LogicalType>(resultType)) {
77 result = builder.create<mlir::arith::OrIOp>(
78 loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
79 } else {
80 llvm_unreachable("unsupported type");
81 }
82
83 return builder.createConvert(loc, resultType, result);
84 }
85
86private:
87 mlir::Location loc;
88 fir::FirOpBuilder &builder;
89
90 mlir::Value castToProductType(mlir::Value value, mlir::Type type) {
91 if (mlir::isa<fir::LogicalType>(type))
92 return builder.createConvert(loc, builder.getIntegerType(1), value);
93
94 // TODO: the multiplications/additions by/of zero resulting from
95 // complex * real are optimized by LLVM under -fno-signed-zeros
96 // -fno-honor-nans.
97 // We can make them disappear by default if we:
98 // * either expand the complex multiplication into real
99 // operations, OR
100 // * set nnan nsz fast-math flags to the complex operations.
101 if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
102 mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
103 fir::factory::Complex helper(builder, loc);
104 mlir::Type partType = helper.getComplexPartType(type);
105 return helper.insertComplexPart(zeroCmplx,
106 castToProductType(value, partType),
107 /*isImagPart=*/false);
108 }
109 return builder.createConvert(loc, type, value);
110 }
111};
112
113class TransposeAsElementalConversion
114 : public mlir::OpRewritePattern<hlfir::TransposeOp> {
115public:
116 using mlir::OpRewritePattern<hlfir::TransposeOp>::OpRewritePattern;
117
118 llvm::LogicalResult
119 matchAndRewrite(hlfir::TransposeOp transpose,
120 mlir::PatternRewriter &rewriter) const override {
121 hlfir::ExprType expr = transpose.getType();
122 // TODO: hlfir.elemental supports polymorphic data types now,
123 // so this can be supported.
124 if (expr.isPolymorphic())
125 return rewriter.notifyMatchFailure(transpose,
126 "TRANSPOSE of polymorphic type");
127
128 mlir::Location loc = transpose.getLoc();
129 fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
130 mlir::Type elementType = expr.getElementType();
131 hlfir::Entity array = hlfir::Entity{transpose.getArray()};
132 mlir::Value resultShape = genResultShape(loc, builder, array);
133 llvm::SmallVector<mlir::Value, 1> typeParams;
134 hlfir::genLengthParameters(loc, builder, array, typeParams);
135
136 auto genKernel = [&array](mlir::Location loc, fir::FirOpBuilder &builder,
137 mlir::ValueRange inputIndices) -> hlfir::Entity {
138 assert(inputIndices.size() == 2 && "checked in TransposeOp::validate");
139 const std::initializer_list<mlir::Value> initList = {inputIndices[1],
140 inputIndices[0]};
141 mlir::ValueRange transposedIndices(initList);
142 hlfir::Entity element =
143 hlfir::getElementAt(loc, builder, array, transposedIndices);
144 hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, element);
145 return val;
146 };
147 hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
148 loc, builder, elementType, resultShape, typeParams, genKernel,
149 /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
150 transpose.getResult().getType());
151
152 // it wouldn't be safe to replace block arguments with a different
153 // hlfir.expr type. Types can differ due to differing amounts of shape
154 // information
155 assert(elementalOp.getResult().getType() ==
156 transpose.getResult().getType());
157
158 rewriter.replaceOp(transpose, elementalOp);
159 return mlir::success();
160 }
161
162private:
163 static mlir::Value genResultShape(mlir::Location loc,
164 fir::FirOpBuilder &builder,
165 hlfir::Entity array) {
166 llvm::SmallVector<mlir::Value, 2> inExtents =
167 hlfir::genExtentsVector(loc, builder, array);
168
169 // transpose indices
170 assert(inExtents.size() == 2 && "checked in TransposeOp::validate");
171 return builder.create<fir::ShapeOp>(
172 loc, mlir::ValueRange{inExtents[1], inExtents[0]});
173 }
174};
175
/// Base class for converting reduction-like operations into
/// a reduction loop[-nest] optionally wrapped into hlfir.elemental.
/// It is used to handle operations produced for ALL, ANY, COUNT,
/// MAXLOC, MAXVAL, MINLOC, MINVAL, SUM intrinsics.
///
/// All of these operations take an input array, and optional
/// dim, mask arguments. ALL, ANY, COUNT do not have mask argument.
class ReductionAsElementalConverter {
public:
  /// \p op must be a reduction operation with exactly one result.
  ReductionAsElementalConverter(mlir::Operation *op,
                                mlir::PatternRewriter &rewriter)
      : op{op}, rewriter{rewriter}, loc{op->getLoc()}, builder{rewriter, op} {
    assert(op->getNumResults() == 1);
  }
  virtual ~ReductionAsElementalConverter() {}

  /// Do the actual conversion or return mlir::failure(),
  /// if conversion is not possible.
  /// (Defined out-of-line, outside this chunk.)
  mlir::LogicalResult convert();

private:
  // Return fir.shape specifying the shape of the result
  // of a reduction with DIM=dimVal. The second return value
  // is the extent of the DIM dimension.
  std::tuple<mlir::Value, mlir::Value>
  genResultShapeForPartialReduction(hlfir::Entity array, int64_t dimVal);

  /// \p mask is a scalar or array logical mask.
  /// If \p isPresentPred is not nullptr, it is a dynamic predicate value
  /// identifying whether the mask's variable is present.
  /// \p indices is a range of one-based indices to access \p mask
  /// when it is an array.
  ///
  /// The method returns the scalar mask value to guard the access
  /// to a single element of the input array.
  mlir::Value genMaskValue(mlir::Value mask, mlir::Value isPresentPred,
                           mlir::ValueRange indices);

protected:
  /// Return the input array.
  virtual mlir::Value getSource() const = 0;

  /// Return DIM or nullptr, if it is not present.
  virtual mlir::Value getDim() const = 0;

  /// Return MASK or nullptr, if it is not present.
  virtual mlir::Value getMask() const { return nullptr; }

  /// Return FastMathFlags attached to the operation
  /// or arith::FastMathFlags::none, if the operation
  /// does not support FastMathFlags (e.g. ALL, ANY, COUNT).
  virtual mlir::arith::FastMathFlags getFastMath() const {
    return mlir::arith::FastMathFlags::none;
  }

  /// Generates initial values for the reduction values used
  /// by the reduction loop. In general, there is a single
  /// loop-carried reduction value (e.g. for SUM), but, for example,
  /// MAXLOC/MINLOC implementation uses multiple reductions.
  /// \p oneBasedIndices contains any array indices predefined
  /// before the reduction loop, i.e. it is empty for total
  /// reductions, and contains the one-based indices of the wrapping
  /// hlfir.elemental.
  /// \p extents are the pre-computed extents of the input array.
  /// For total reductions, \p extents holds extents of all dimensions.
  /// For partial reductions, \p extents holds a single extent
  /// of the DIM dimension.
  virtual llvm::SmallVector<mlir::Value>
  genReductionInitValues(mlir::ValueRange oneBasedIndices,
                         const llvm::SmallVectorImpl<mlir::Value> &extents) = 0;

  /// Perform reduction(s) update given a single input array's element
  /// identified by \p array and \p oneBasedIndices coordinates.
  /// \p currentValue specifies the current value(s) of the reduction(s)
  /// inside the reduction loop body.
  virtual llvm::SmallVector<mlir::Value>
  reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
                   hlfir::Entity array, mlir::ValueRange oneBasedIndices) = 0;

  /// Given reduction value(s) in \p reductionResults produced
  /// by the reduction loop, apply any required updates and return
  /// new reduction value(s) to be used after the reduction loop
  /// (e.g. as the result yield of the wrapping hlfir.elemental).
  /// NOTE: if the reduction loop is wrapped in hlfir.elemental,
  /// the insertion point of any generated code is inside hlfir.elemental.
  virtual hlfir::Entity
  genFinalResult(const llvm::SmallVectorImpl<mlir::Value> &reductionResults) {
    assert(reductionResults.size() == 1 &&
           "default implementation of genFinalResult expect a single reduction "
           "value");
    return hlfir::Entity{reductionResults[0]};
  }

  /// Return mlir::success(), if the operation can be converted.
  /// The default implementation always returns mlir::success().
  /// The derived type may override the default implementation
  /// with its own definition.
  virtual mlir::LogicalResult isConvertible() const { return mlir::success(); }

  // Default implementation of isTotalReduction() just checks
  // if the result of the operation is a scalar.
  // True result indicates that the reduction has to be done
  // across all elements, false result indicates that
  // the result is an array expression produced by an hlfir.elemental
  // operation with a single reduction loop across the DIM dimension.
  //
  // MAXLOC/MINLOC must override this.
  virtual bool isTotalReduction() const { return getResultRank() == 0; }

  // Return true, if the reduction loop[-nest] may be unordered.
  // In general, FP reductions may only be unordered when
  // FastMathFlags::reassoc transformations are allowed.
  //
  // Some derived types may need to override this.
  virtual bool isUnordered() const {
    // Integer, logical and character reductions do not depend
    // on the iteration order.
    mlir::Type elemType = getSourceElementType();
    if (mlir::isa<mlir::IntegerType, fir::LogicalType, fir::CharacterType>(
            elemType))
      return true;
    return static_cast<bool>(getFastMath() &
                             mlir::arith::FastMathFlags::reassoc);
  }

  /// Return 0, if DIM is not present or its value does not matter
  /// (for example, a reduction of 1D array does not care about
  /// the DIM value, assuming that it is a valid program).
  /// Return mlir::failure(), if DIM is a constant known
  /// to be invalid for the given array.
  /// Otherwise, return DIM constant value.
  mlir::FailureOr<int64_t> getConstDim() const {
    int64_t dimVal = 0;
    if (!isTotalReduction()) {
      // In case of partial reduction we should ignore the operations
      // with invalid DIM values. They may appear in dead code
      // after constant propagation.
      auto constDim = fir::getIntIfConstant(getDim());
      if (!constDim)
        return rewriter.notifyMatchFailure(op, "Nonconstant DIM");
      dimVal = *constDim;

      if ((dimVal <= 0 || dimVal > getSourceRank()))
        return rewriter.notifyMatchFailure(op,
                                           "Invalid DIM for partial reduction");
    }
    return dimVal;
  }

  /// Return hlfir::Entity of the result.
  hlfir::Entity getResultEntity() const {
    return hlfir::Entity{op->getResult(0)};
  }

  /// Return type of the result (e.g. !hlfir.expr<?xi32>).
  mlir::Type getResultType() const { return getResultEntity().getType(); }

  /// Return the element type of the result (e.g. i32).
  mlir::Type getResultElementType() const {
    return hlfir::getFortranElementType(getResultType());
  }

  /// Return rank of the result.
  unsigned getResultRank() const { return getResultEntity().getRank(); }

  /// Return the element type of the source.
  mlir::Type getSourceElementType() const {
    return hlfir::getFortranElementType(getSource().getType());
  }

  /// Return rank of the input array.
  unsigned getSourceRank() const {
    return hlfir::Entity{getSource()}.getRank();
  }

  /// The reduction operation.
  mlir::Operation *op;

  /// Rewriter driving the conversion; also used to report match failures.
  mlir::PatternRewriter &rewriter;
  /// Location of \p op, used for all generated operations.
  mlir::Location loc;
  /// FIR builder wrapping the pattern rewriter.
  fir::FirOpBuilder builder;
};
356
357/// Generate initialization value for MIN or MAX reduction
358/// of the given \p type.
359template <bool IS_MAX>
360static mlir::Value genMinMaxInitValue(mlir::Location loc,
361 fir::FirOpBuilder &builder,
362 mlir::Type type) {
363 if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
364 const llvm::fltSemantics &sem = ty.getFloatSemantics();
365 // We must not use +/-INF here. If the reduction input is empty,
366 // the result of reduction must be +/-LARGEST.
367 llvm::APFloat limit = llvm::APFloat::getLargest(sem, /*Negative=*/IS_MAX);
368 return builder.createRealConstant(loc, type, limit);
369 }
370 unsigned bits = type.getIntOrFloatBitWidth();
371 int64_t limitInt = IS_MAX
372 ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
373 : llvm::APInt::getSignedMaxValue(bits).getSExtValue();
374 return builder.createIntegerConstant(loc, type, limitInt);
375}
376
377/// Generate a comparison of an array element value \p elem
378/// and the current reduction value \p reduction for MIN/MAX reduction.
379template <bool IS_MAX>
380static mlir::Value
381genMinMaxComparison(mlir::Location loc, fir::FirOpBuilder &builder,
382 mlir::Value elem, mlir::Value reduction) {
383 if (mlir::isa<mlir::FloatType>(reduction.getType())) {
384 // For FP reductions we want the first smallest value to be used, that
385 // is not NaN. A OGL/OLT condition will usually work for this unless all
386 // the values are Nan or Inf. This follows the same logic as
387 // NumericCompare for Minloc/Maxloc in extrema.cpp.
388 mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
389 loc,
390 IS_MAX ? mlir::arith::CmpFPredicate::OGT
391 : mlir::arith::CmpFPredicate::OLT,
392 elem, reduction);
393 mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
394 loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
395 mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
396 loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
397 cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
398 return builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
399 } else if (mlir::isa<mlir::IntegerType>(reduction.getType())) {
400 return builder.create<mlir::arith::CmpIOp>(
401 loc,
402 IS_MAX ? mlir::arith::CmpIPredicate::sgt
403 : mlir::arith::CmpIPredicate::slt,
404 elem, reduction);
405 }
406 llvm_unreachable("unsupported type");
407}
408
409// Generate a predicate value indicating that an array with the given
410// extents is not empty.
411static mlir::Value
412genIsNotEmptyArrayExtents(mlir::Location loc, fir::FirOpBuilder &builder,
413 const llvm::SmallVectorImpl<mlir::Value> &extents) {
414 mlir::Value isNotEmpty = builder.createBool(loc, true);
415 for (auto extent : extents) {
416 mlir::Value zero =
417 fir::factory::createZeroValue(builder, loc, extent.getType());
418 mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
419 loc, mlir::arith::CmpIPredicate::ne, extent, zero);
420 isNotEmpty = builder.create<mlir::arith::AndIOp>(loc, isNotEmpty, cmp);
421 }
422 return isNotEmpty;
423}
424
425// Helper method for MIN/MAX LOC/VAL reductions.
426// It returns a vector of indices such that they address
427// the first element of an array (in case of total reduction)
428// or its section (in case of partial reduction).
429//
430// If case of total reduction oneBasedIndices must be empty,
431// otherwise, they contain the one based indices of the wrapping
432// hlfir.elemental.
433// Basically, the method adds the necessary number of constant-one
434// indices into oneBasedIndices.
435static llvm::SmallVector<mlir::Value> genFirstElementIndicesForReduction(
436 mlir::Location loc, fir::FirOpBuilder &builder, bool isTotalReduction,
437 mlir::FailureOr<int64_t> dim, unsigned rank,
438 mlir::ValueRange oneBasedIndices) {
439 llvm::SmallVector<mlir::Value> indices{oneBasedIndices};
440 mlir::Value one =
441 builder.createIntegerConstant(loc, builder.getIndexType(), 1);
442 if (isTotalReduction) {
443 assert(oneBasedIndices.size() == 0 &&
444 "wrong number of indices for total reduction");
445 // Set indices to all-ones.
446 indices.append(rank, one);
447 } else {
448 assert(oneBasedIndices.size() == rank - 1 &&
449 "there must be RANK-1 indices for partial reduction");
450 assert(mlir::succeeded(dim) && "partial reduction with invalid DIM");
451 // Insert constant-one index at DIM dimension.
452 indices.insert(indices.begin() + *dim - 1, one);
453 }
454 return indices;
455}
456
457/// Implementation of ReductionAsElementalConverter interface
458/// for MAXLOC/MINLOC.
459template <typename T>
460class MinMaxlocAsElementalConverter : public ReductionAsElementalConverter {
461 static_assert(std::is_same_v<T, hlfir::MaxlocOp> ||
462 std::is_same_v<T, hlfir::MinlocOp>);
463 static constexpr unsigned maxRank = Fortran::common::maxRank;
464 // We have the following reduction values in the reduction loop:
465 // * N integer coordinates, where N is:
466 // - RANK(ARRAY) for total reductions.
467 // - 1 for partial reductions.
468 // * 1 reduction value holding the current MIN/MAX.
469 // * 1 boolean indicating whether it is the first time
470 // the mask is true.
471 //
472 // If useIsFirst() returns false, then the boolean loop-carried
473 // value is not used.
474 static constexpr unsigned maxNumReductions = Fortran::common::maxRank + 2;
475 static constexpr bool isMax = std::is_same_v<T, hlfir::MaxlocOp>;
476 using Base = ReductionAsElementalConverter;
477
478public:
479 MinMaxlocAsElementalConverter(T op, mlir::PatternRewriter &rewriter)
480 : Base{op.getOperation(), rewriter} {}
481
482private:
483 virtual mlir::Value getSource() const final { return getOp().getArray(); }
484 virtual mlir::Value getDim() const final { return getOp().getDim(); }
485 virtual mlir::Value getMask() const final { return getOp().getMask(); }
486 virtual mlir::arith::FastMathFlags getFastMath() const final {
487 return getOp().getFastmath();
488 }
489
490 virtual mlir::LogicalResult isConvertible() const final {
491 if (getOp().getBack())
492 return rewriter.notifyMatchFailure(
493 getOp(), "BACK is not supported for MINLOC/MAXLOC inlining");
494 if (mlir::isa<fir::CharacterType>(getSourceElementType()))
495 return rewriter.notifyMatchFailure(
496 getOp(),
497 "CHARACTER type is not supported for MINLOC/MAXLOC inlining");
498 return mlir::success();
499 }
500
501 // If the result is scalar, then DIM does not matter,
502 // and this is a total reduction.
503 // If DIM is not present, this is a total reduction.
504 virtual bool isTotalReduction() const final {
505 return getResultRank() == 0 || !getDim();
506 }
507
508 virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
509 mlir::ValueRange oneBasedIndices,
510 const llvm::SmallVectorImpl<mlir::Value> &extents) final;
511 virtual llvm::SmallVector<mlir::Value>
512 reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
513 hlfir::Entity array, mlir::ValueRange oneBasedIndices) final;
514 virtual hlfir::Entity genFinalResult(
515 const llvm::SmallVectorImpl<mlir::Value> &reductionResults) final;
516
517private:
518 T getOp() const { return mlir::cast<T>(op); }
519
520 unsigned getNumCoors() const {
521 return isTotalReduction() ? getSourceRank() : 1;
522 }
523
524 void
525 checkReductions(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
526 if (!useIsFirst())
527 assert(reductions.size() == getNumCoors() + 1 &&
528 "invalid number of reductions for MINLOC/MAXLOC");
529 else
530 assert(reductions.size() == getNumCoors() + 2 &&
531 "invalid number of reductions for MINLOC/MAXLOC");
532 }
533
534 mlir::Value
535 getCurrentMinMax(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
536 checkReductions(reductions);
537 return reductions[getNumCoors()];
538 }
539
540 mlir::Value
541 getIsFirst(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
542 checkReductions(reductions);
543 assert(useIsFirst() && "IsFirst predicate must not be used");
544 return reductions[getNumCoors() + 1];
545 }
546
547 // Return true iff the input can contain NaNs, and they should be
548 // honored, such that all-NaNs input must produce the location
549 // of the first unmasked NaN.
550 bool honorNans() const {
551 return !static_cast<bool>(getFastMath() & mlir::arith::FastMathFlags::nnan);
552 }
553
554 // Return true iff we have to use the loop-carried IsFirst predicate.
555 // If there is no mask, we can initialize the reductions using
556 // the first elements of the input.
557 // If NaNs are not honored, we can initialize the starting MIN/MAX
558 // value to +/-LARGEST; the coordinates are guaranteed to be updated
559 // properly for non-empty input without NaNs.
560 bool useIsFirst() const { return getMask() && honorNans(); }
561};
562
template <typename T>
llvm::SmallVector<mlir::Value>
MinMaxlocAsElementalConverter<T>::genReductionInitValues(
    mlir::ValueRange oneBasedIndices,
    const llvm::SmallVectorImpl<mlir::Value> &extents) {
  // When NaNs are honored and there is no mask, seed the reductions from
  // the first element of the array (or its section) when it exists: this
  // is done under a fir.if guarded by an "array is not empty" check, with
  // the 'else' branch falling back to the default init values computed
  // below.
  fir::IfOp ifOp;
  if (!useIsFirst() && honorNans()) {
    // Check if we can load the value of the first element in the array
    // or its section (for partial reduction).
    assert(!getMask() && "cannot fetch first element when mask is present");
    assert(extents.size() == getNumCoors() &&
           "wrong number of extents for MINLOC/MAXLOC reduction");
    mlir::Value isNotEmpty = genIsNotEmptyArrayExtents(loc, builder, extents);

    llvm::SmallVector<mlir::Value> indices = genFirstElementIndicesForReduction(
        loc, builder, isTotalReduction(), getConstDim(), getSourceRank(),
        oneBasedIndices);

    // The fir.if yields getNumCoors() coordinate values followed by
    // the current MIN/MAX value.
    llvm::SmallVector<mlir::Type> ifTypes(getNumCoors(),
                                          getResultElementType());
    ifTypes.push_back(getSourceElementType());
    ifOp = builder.create<fir::IfOp>(loc, ifTypes, isNotEmpty,
                                     /*withElseRegion=*/true);
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    // 'then' branch: the first element exists. Its coordinates are all
    // ones, and the MIN/MAX reduction starts at its value.
    mlir::Value one =
        builder.createIntegerConstant(loc, getResultElementType(), 1);
    llvm::SmallVector<mlir::Value> results(getNumCoors(), one);
    mlir::Value minMaxFirst =
        hlfir::loadElementAt(loc, builder, hlfir::Entity{getSource()}, indices);
    results.push_back(minMaxFirst);
    builder.create<fir::ResultOp>(loc, results);

    // In the 'else' block use default init values.
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  }

  // Initial value for the coordinate(s) is zero.
  mlir::Value zeroCoor =
      fir::factory::createZeroValue(builder, loc, getResultElementType());
  llvm::SmallVector<mlir::Value> result(getNumCoors(), zeroCoor);

  // Initial value for the MIN/MAX value.
  mlir::Value minMaxInit =
      genMinMaxInitValue<isMax>(loc, builder, getSourceElementType());
  result.push_back(minMaxInit);

  if (ifOp) {
    // The default values become the 'else' yield; the overall init
    // values are the fir.if results.
    builder.create<fir::ResultOp>(loc, result);
    builder.setInsertionPointAfter(ifOp);
    result = ifOp.getResults();
  } else if (useIsFirst()) {
    // Initial value for isFirst predicate. It is switched to false,
    // when the reduction update dynamically happens inside the reduction
    // loop.
    mlir::Value trueVal = builder.createBool(loc, true);
    result.push_back(trueVal);
  }

  return result;
}
623
624template <typename T>
625llvm::SmallVector<mlir::Value>
626MinMaxlocAsElementalConverter<T>::reduceOneElement(
627 const llvm::SmallVectorImpl<mlir::Value> &currentValue, hlfir::Entity array,
628 mlir::ValueRange oneBasedIndices) {
629 checkReductions(currentValue);
630 hlfir::Entity elementValue =
631 hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
632 mlir::Value cmp = genMinMaxComparison<isMax>(loc, builder, elementValue,
633 getCurrentMinMax(currentValue));
634 if (useIsFirst()) {
635 // If isFirst is true, then do the reduction update regardless
636 // of the FP comparison.
637 cmp =
638 builder.create<mlir::arith::OrIOp>(loc, cmp, getIsFirst(currentValue));
639 }
640
641 llvm::SmallVector<mlir::Value> newIndices;
642 int64_t dim = 1;
643 if (!isTotalReduction()) {
644 auto dimVal = getConstDim();
645 assert(mlir::succeeded(dimVal) &&
646 "partial MINLOC/MAXLOC reduction with invalid DIM");
647 dim = *dimVal;
648 assert(getNumCoors() == 1 &&
649 "partial MAXLOC/MINLOC reduction must compute one coordinate");
650 }
651
652 for (unsigned coorIdx = 0; coorIdx < getNumCoors(); ++coorIdx) {
653 mlir::Value currentCoor = currentValue[coorIdx];
654 mlir::Value newCoor = builder.createConvert(
655 loc, currentCoor.getType(), oneBasedIndices[coorIdx + dim - 1]);
656 mlir::Value update =
657 builder.create<mlir::arith::SelectOp>(loc, cmp, newCoor, currentCoor);
658 newIndices.push_back(update);
659 }
660
661 mlir::Value newMinMax = builder.create<mlir::arith::SelectOp>(
662 loc, cmp, elementValue, getCurrentMinMax(currentValue));
663 newIndices.push_back(newMinMax);
664
665 if (useIsFirst()) {
666 mlir::Value newIsFirst = builder.createBool(loc, false);
667 newIndices.push_back(newIsFirst);
668 }
669
670 assert(currentValue.size() == newIndices.size() &&
671 "invalid number of updated reductions");
672
673 return newIndices;
674}
675
676template <typename T>
677hlfir::Entity MinMaxlocAsElementalConverter<T>::genFinalResult(
678 const llvm::SmallVectorImpl<mlir::Value> &reductionResults) {
679 // Identification of the final result of MINLOC/MAXLOC:
680 // * If DIM is absent, the result is rank-one array.
681 // * If DIM is present:
682 // - The result is scalar for rank-one input.
683 // - The result is an array of rank RANK(ARRAY)-1.
684 checkReductions(reductionResults);
685
686 // 16.9.137 & 16.9.143:
687 // The subscripts returned by MINLOC/MAXLOC are in the range
688 // 1 to the extent of the corresponding dimension.
689 mlir::Type indexType = builder.getIndexType();
690
691 // For partial reductions, the final result of the reduction
692 // loop is just a scalar - the coordinate within DIM dimension.
693 if (getResultRank() == 0 || !isTotalReduction()) {
694 // The result is a scalar, so just return the scalar.
695 assert(getNumCoors() == 1 &&
696 "unpexpected number of coordinates for scalar result");
697 return hlfir::Entity{reductionResults[0]};
698 }
699 // This is a total reduction, and there is no wrapping hlfir.elemental.
700 // We have to pack the reduced coordinates into a rank-one array.
701 unsigned rank = getSourceRank();
702 // TODO: in order to avoid introducing new memory effects
703 // we should not use a temporary in memory.
704 // We can use hlfir.elemental with a switch to pack all the coordinates
705 // into an array expression, or we can have a dedicated HLFIR operation
706 // for this.
707 mlir::Value tempArray = builder.createTemporary(
708 loc, fir::SequenceType::get(rank, getResultElementType()));
709 for (unsigned i = 0; i < rank; ++i) {
710 mlir::Value coor = reductionResults[i];
711 mlir::Value idx = builder.createIntegerConstant(loc, indexType, i + 1);
712 mlir::Value resultElement =
713 hlfir::getElementAt(loc, builder, hlfir::Entity{tempArray}, {idx});
714 builder.create<hlfir::AssignOp>(loc, coor, resultElement);
715 }
716 mlir::Value tempExpr = builder.create<hlfir::AsExprOp>(
717 loc, tempArray, builder.createBool(loc, false));
718 return hlfir::Entity{tempExpr};
719}
720
/// Base class for numeric reductions like MAXVAL, MINVAL, SUM.
/// It provides the operand/flags accessors shared by these
/// operations (ARRAY, DIM, MASK and the fastmath attribute).
template <typename OpT>
class NumericReductionAsElementalConverterBase
    : public ReductionAsElementalConverter {
  using Base = ReductionAsElementalConverter;

protected:
  NumericReductionAsElementalConverterBase(OpT op,
                                           mlir::PatternRewriter &rewriter)
      : Base{op.getOperation(), rewriter} {}

  virtual mlir::Value getSource() const final { return getOp().getArray(); }
  virtual mlir::Value getDim() const final { return getOp().getDim(); }
  virtual mlir::Value getMask() const final { return getOp().getMask(); }
  virtual mlir::arith::FastMathFlags getFastMath() const final {
    return getOp().getFastmath();
  }

  /// Return the reduction operation cast to its concrete type.
  OpT getOp() const { return mlir::cast<OpT>(op); }

  /// Assert that there is exactly one loop-carried reduction value.
  void checkReductions(const llvm::SmallVectorImpl<mlir::Value> &reductions) {
    assert(reductions.size() == 1 && "reduction must produce single value");
  }
};
745
/// Reduction converter for MAXVAL/MINVAL.
747template <typename T>
748class MinMaxvalAsElementalConverter
749 : public NumericReductionAsElementalConverterBase<T> {
750 static_assert(std::is_same_v<T, hlfir::MaxvalOp> ||
751 std::is_same_v<T, hlfir::MinvalOp>);
752 // We have two reduction values:
753 // * The current MIN/MAX value.
754 // * 1 boolean indicating whether it is the first time
755 // the mask is true.
756 //
757 // The boolean flag is used to replace the initial value
758 // with the first input element even if it is NaN.
759 // If useIsFirst() returns false, then the boolean loop-carried
760 // value is not used.
761 static constexpr bool isMax = std::is_same_v<T, hlfir::MaxvalOp>;
762 using Base = NumericReductionAsElementalConverterBase<T>;
763
764public:
765 MinMaxvalAsElementalConverter(T op, mlir::PatternRewriter &rewriter)
766 : Base{op, rewriter} {}
767
768private:
769 virtual mlir::LogicalResult isConvertible() const final {
770 if (mlir::isa<fir::CharacterType>(this->getSourceElementType()))
771 return this->rewriter.notifyMatchFailure(
772 this->getOp(),
773 "CHARACTER type is not supported for MINVAL/MAXVAL inlining");
774 return mlir::success();
775 }
776
777 virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
778 mlir::ValueRange oneBasedIndices,
779 const llvm::SmallVectorImpl<mlir::Value> &extents) final;
780
781 virtual llvm::SmallVector<mlir::Value>
782 reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
783 hlfir::Entity array,
784 mlir::ValueRange oneBasedIndices) final {
785 this->checkReductions(currentValue);
786 llvm::SmallVector<mlir::Value> result;
787 fir::FirOpBuilder &builder = this->builder;
788 mlir::Location loc = this->loc;
789 hlfir::Entity elementValue =
790 hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
791 mlir::Value currentMinMax = getCurrentMinMax(currentValue);
792 mlir::Value cmp =
793 genMinMaxComparison<isMax>(loc, builder, elementValue, currentMinMax);
794 if (useIsFirst())
795 cmp = builder.create<mlir::arith::OrIOp>(loc, cmp,
796 getIsFirst(currentValue));
797 mlir::Value newMinMax = builder.create<mlir::arith::SelectOp>(
798 loc, cmp, elementValue, currentMinMax);
799 result.push_back(newMinMax);
800 if (useIsFirst())
801 result.push_back(builder.createBool(loc, false));
802 return result;
803 }
804
805 virtual hlfir::Entity genFinalResult(
806 const llvm::SmallVectorImpl<mlir::Value> &reductionResults) final {
807 this->checkReductions(reductionResults);
808 return hlfir::Entity{getCurrentMinMax(reductionResults)};
809 }
810
811 void
812 checkReductions(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
813 assert(reductions.size() == getNumReductions() &&
814 "invalid number of reductions for MINVAL/MAXVAL");
815 }
816
817 mlir::Value
818 getCurrentMinMax(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
819 this->checkReductions(reductions);
820 return reductions[0];
821 }
822
823 mlir::Value
824 getIsFirst(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
825 this->checkReductions(reductions);
826 assert(useIsFirst() && "IsFirst predicate must not be used");
827 return reductions[1];
828 }
829
830 // Return true iff the input can contain NaNs, and they should be
831 // honored, such that all-NaNs input must produce NaN result.
832 bool honorNans() const {
833 return !static_cast<bool>(this->getFastMath() &
834 mlir::arith::FastMathFlags::nnan);
835 }
836
837 // Return true iff we have to use the loop-carried IsFirst predicate.
838 // If there is no mask, we can initialize the reductions using
839 // the first elements of the input.
840 // If NaNs are not honored, we can initialize the starting MIN/MAX
841 // value to +/-LARGEST.
842 bool useIsFirst() const { return this->getMask() && honorNans(); }
843
844 std::size_t getNumReductions() const { return useIsFirst() ? 2 : 1; }
845};
846
// Generate the initial value(s) of the MIN/MAX reduction(s).
//
// When there is no mask and NaNs must be honored, the initial value is
// taken from the first element of the array (or its section, for a
// partial reduction), guarded by an emptiness check, so that an
// all-NaNs input yields a NaN result. Otherwise, the default
// +/-LARGEST initial value is used, and the isFirst predicate (when
// needed) starts as true.
template <typename T>
llvm::SmallVector<mlir::Value>
MinMaxvalAsElementalConverter<T>::genReductionInitValues(
    mlir::ValueRange oneBasedIndices,
    const llvm::SmallVectorImpl<mlir::Value> &extents) {
  llvm::SmallVector<mlir::Value> result;
  fir::FirOpBuilder &builder = this->builder;
  mlir::Location loc = this->loc;

  fir::IfOp ifOp;
  if (!useIsFirst() && honorNans()) {
    // Check if we can load the value of the first element in the array
    // or its section (for partial reduction).
    assert(!this->getMask() &&
           "cannot fetch first element when mask is present");
    assert(extents.size() ==
               (this->isTotalReduction() ? this->getSourceRank() : 1u) &&
           "wrong number of extents for MINVAL/MAXVAL reduction");
    mlir::Value isNotEmpty = genIsNotEmptyArrayExtents(loc, builder, extents);
    llvm::SmallVector<mlir::Value> indices = genFirstElementIndicesForReduction(
        loc, builder, this->isTotalReduction(), this->getConstDim(),
        this->getSourceRank(), oneBasedIndices);

    ifOp =
        builder.create<fir::IfOp>(loc, this->getResultElementType(), isNotEmpty,
                                  /*withElseRegion=*/true);
    // In the 'then' block load the first element and yield it.
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    mlir::Value minMaxFirst = hlfir::loadElementAt(
        loc, builder, hlfir::Entity{this->getSource()}, indices);
    builder.create<fir::ResultOp>(loc, minMaxFirst);

    // In the 'else' block use default init values.
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  }

  mlir::Value init =
      genMinMaxInitValue<isMax>(loc, builder, this->getResultElementType());
  result.push_back(init);

  if (ifOp) {
    builder.create<fir::ResultOp>(loc, result);
    builder.setInsertionPointAfter(ifOp);
    result = ifOp.getResults();
  } else if (useIsFirst()) {
    // Initial value for isFirst predicate. It is switched to false,
    // when the reduction update dynamically happens inside the reduction
    // loop.
    result.push_back(builder.createBool(loc, true));
  }

  return result;
}
899
900/// Reduction converter for SUM.
901class SumAsElementalConverter
902 : public NumericReductionAsElementalConverterBase<hlfir::SumOp> {
903 using Base = NumericReductionAsElementalConverterBase;
904
905public:
906 SumAsElementalConverter(hlfir::SumOp op, mlir::PatternRewriter &rewriter)
907 : Base{op, rewriter} {}
908
909private:
910 virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
911 [[maybe_unused]] mlir::ValueRange oneBasedIndices,
912 [[maybe_unused]] const llvm::SmallVectorImpl<mlir::Value> &extents)
913 final {
914 return {
915 fir::factory::createZeroValue(builder, loc, getResultElementType())};
916 }
917 virtual llvm::SmallVector<mlir::Value>
918 reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
919 hlfir::Entity array,
920 mlir::ValueRange oneBasedIndices) final {
921 checkReductions(currentValue);
922 hlfir::Entity elementValue =
923 hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
924 // NOTE: we can use "Kahan summation" same way as the runtime
925 // (e.g. when fast-math is not allowed), but let's start with
926 // the simple version.
927 return {genScalarAdd(currentValue[0], elementValue)};
928 }
929
930 // Generate scalar addition of the two values (of the same data type).
931 mlir::Value genScalarAdd(mlir::Value value1, mlir::Value value2);
932};
933
/// Base class for logical reductions like ALL, ANY, COUNT.
/// They do not have MASK and FastMathFlags.
template <typename OpT>
class LogicalReductionAsElementalConverterBase
    : public ReductionAsElementalConverter {
  using Base = ReductionAsElementalConverter;

public:
  LogicalReductionAsElementalConverterBase(OpT op,
                                           mlir::PatternRewriter &rewriter)
      : Base{op.getOperation(), rewriter} {}

protected:
  // Recover the concrete reduction operation from the type-erased handle.
  OpT getOp() const { return mlir::cast<OpT>(op); }

  // Logical reductions carry a single loop-carried value.
  void checkReductions(const llvm::SmallVectorImpl<mlir::Value> &reductions) {
    assert(reductions.size() == 1 && "reduction must produce single value");
  }

  // For ALL/ANY/COUNT the reduction input is the MASK operand.
  virtual mlir::Value getSource() const final { return getOp().getMask(); }
  // Optional DIM operand (null for a total reduction).
  virtual mlir::Value getDim() const final { return getOp().getDim(); }

  // By default, the final result is the single reduction value as-is.
  virtual hlfir::Entity genFinalResult(
      const llvm::SmallVectorImpl<mlir::Value> &reductionResults) override {
    checkReductions(reductionResults);
    return hlfir::Entity{reductionResults[0]};
  }
};
962
963/// Reduction converter for ALL/ANY.
964template <typename T>
965class AllAnyAsElementalConverter
966 : public LogicalReductionAsElementalConverterBase<T> {
967 static_assert(std::is_same_v<T, hlfir::AllOp> ||
968 std::is_same_v<T, hlfir::AnyOp>);
969 static constexpr bool isAll = std::is_same_v<T, hlfir::AllOp>;
970 using Base = LogicalReductionAsElementalConverterBase<T>;
971
972public:
973 AllAnyAsElementalConverter(T op, mlir::PatternRewriter &rewriter)
974 : Base{op, rewriter} {}
975
976private:
977 virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
978 [[maybe_unused]] mlir::ValueRange oneBasedIndices,
979 [[maybe_unused]] const llvm::SmallVectorImpl<mlir::Value> &extents)
980 final {
981 return {this->builder.createBool(this->loc, isAll ? true : false)};
982 }
983 virtual llvm::SmallVector<mlir::Value>
984 reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
985 hlfir::Entity array,
986 mlir::ValueRange oneBasedIndices) final {
987 this->checkReductions(currentValue);
988 fir::FirOpBuilder &builder = this->builder;
989 mlir::Location loc = this->loc;
990 hlfir::Entity elementValue =
991 hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
992 mlir::Value mask =
993 builder.createConvert(loc, builder.getI1Type(), elementValue);
994 if constexpr (isAll)
995 return {builder.create<mlir::arith::AndIOp>(loc, mask, currentValue[0])};
996 else
997 return {builder.create<mlir::arith::OrIOp>(loc, mask, currentValue[0])};
998 }
999
1000 virtual hlfir::Entity genFinalResult(
1001 const llvm::SmallVectorImpl<mlir::Value> &reductionValues) final {
1002 this->checkReductions(reductionValues);
1003 return hlfir::Entity{this->builder.createConvert(
1004 this->loc, this->getResultElementType(), reductionValues[0])};
1005 }
1006};
1007
1008/// Reduction converter for COUNT.
1009class CountAsElementalConverter
1010 : public LogicalReductionAsElementalConverterBase<hlfir::CountOp> {
1011 using Base = LogicalReductionAsElementalConverterBase<hlfir::CountOp>;
1012
1013public:
1014 CountAsElementalConverter(hlfir::CountOp op, mlir::PatternRewriter &rewriter)
1015 : Base{op, rewriter} {}
1016
1017private:
1018 virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
1019 [[maybe_unused]] mlir::ValueRange oneBasedIndices,
1020 [[maybe_unused]] const llvm::SmallVectorImpl<mlir::Value> &extents)
1021 final {
1022 return {
1023 fir::factory::createZeroValue(builder, loc, getResultElementType())};
1024 }
1025 virtual llvm::SmallVector<mlir::Value>
1026 reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
1027 hlfir::Entity array,
1028 mlir::ValueRange oneBasedIndices) final {
1029 checkReductions(currentValue);
1030 hlfir::Entity elementValue =
1031 hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
1032 mlir::Value cond =
1033 builder.createConvert(loc, builder.getI1Type(), elementValue);
1034 mlir::Value one =
1035 builder.createIntegerConstant(loc, getResultElementType(), 1);
1036 mlir::Value add1 =
1037 builder.create<mlir::arith::AddIOp>(loc, currentValue[0], one);
1038 return {builder.create<mlir::arith::SelectOp>(loc, cond, add1,
1039 currentValue[0])};
1040 }
1041};
1042
// Convert the reduction operation into a reduction loop-nest
// (for a total reduction) or into an hlfir.elemental whose body
// reduces along the DIM dimension (for a partial reduction).
// The loop-carried values and the per-element update are provided
// by the derived converter via the virtual hooks.
mlir::LogicalResult ReductionAsElementalConverter::convert() {
  mlir::LogicalResult canConvert(isConvertible());

  // Bail out if the derived converter rejects this case.
  if (mlir::failed(canConvert))
    return canConvert;

  hlfir::Entity array = hlfir::Entity{getSource()};
  bool isTotalReduce = isTotalReduction();
  // Partial reductions require a constant DIM.
  auto dimVal = getConstDim();
  if (mlir::failed(dimVal))
    return dimVal;
  mlir::Value mask = getMask();
  mlir::Value resultShape, dimExtent;
  llvm::SmallVector<mlir::Value> arrayExtents;
  if (isTotalReduce)
    arrayExtents = hlfir::genExtentsVector(loc, builder, array);
  else
    std::tie(resultShape, dimExtent) =
        genResultShapeForPartialReduction(array, *dimVal);

  // If the mask is present and is a scalar, then we'd better load its value
  // outside of the reduction loop making the loop unswitching easier.
  mlir::Value isPresentPred, maskValue;
  if (mask) {
    if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
      // MASK represented by a box might be dynamically optional,
      // so we have to check for its presence before accessing it.
      isPresentPred =
          builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
    }

    if (hlfir::Entity{mask}.isScalar())
      maskValue = genMaskValue(mask, isPresentPred, {});
  }

  // Emit the reduction loop-nest. For a partial reduction, this becomes
  // the body of the hlfir.elemental computing one result element.
  auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                       mlir::ValueRange inputIndices) -> hlfir::Entity {
    // Loop over all indices in the DIM dimension, and reduce all values.
    // If DIM is not present, do total reduction.

    llvm::SmallVector<mlir::Value> extents;
    if (isTotalReduce)
      extents = arrayExtents;
    else
      extents.push_back(
          builder.createConvert(loc, builder.getIndexType(), dimExtent));

    // Initial value for the reduction.
    llvm::SmallVector<mlir::Value, 1> reductionInitValues =
        genReductionInitValues(inputIndices, extents);

    auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                       mlir::ValueRange oneBasedIndices,
                       mlir::ValueRange reductionArgs)
        -> llvm::SmallVector<mlir::Value, 1> {
      // Generate the reduction loop-nest body.
      // The initial reduction value in the innermost loop
      // is passed via reductionArgs[0].
      llvm::SmallVector<mlir::Value> indices;
      if (isTotalReduce) {
        indices = oneBasedIndices;
      } else {
        // Reconstruct the full array index by inserting the loop index
        // at the DIM position of the elemental's indices.
        indices = inputIndices;
        indices.insert(indices.begin() + *dimVal - 1, oneBasedIndices[0]);
      }

      llvm::SmallVector<mlir::Value, 1> reductionValues = reductionArgs;
      llvm::SmallVector<mlir::Type, 1> reductionTypes;
      llvm::transform(reductionValues, std::back_inserter(reductionTypes),
                      [](mlir::Value v) { return v.getType(); });
      fir::IfOp ifOp;
      if (mask) {
        // Make the reduction value update conditional on the value
        // of the mask.
        if (!maskValue) {
          // If the mask is an array, use the elemental and the loop indices
          // to address the proper mask element.
          maskValue = genMaskValue(mask, isPresentPred, indices);
        }
        mlir::Value isUnmasked =
            builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
        ifOp = builder.create<fir::IfOp>(loc, reductionTypes, isUnmasked,
                                         /*withElseRegion=*/true);
        // In the 'else' block return the current reduction value.
        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
        builder.create<fir::ResultOp>(loc, reductionValues);

        // In the 'then' block do the actual addition.
        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
      }
      reductionValues = reduceOneElement(reductionValues, array, indices);
      if (ifOp) {
        builder.create<fir::ResultOp>(loc, reductionValues);
        builder.setInsertionPointAfter(ifOp);
        reductionValues = ifOp.getResults();
      }

      return reductionValues;
    };

    llvm::SmallVector<mlir::Value, 1> reductionFinalValues =
        hlfir::genLoopNestWithReductions(
            loc, builder, extents, reductionInitValues, genBody, isUnordered());
    return genFinalResult(reductionFinalValues);
  };

  if (isTotalReduce) {
    // Total reduction: the kernel result replaces the operation directly.
    hlfir::Entity result = genKernel(loc, builder, mlir::ValueRange{});
    rewriter.replaceOp(op, result);
    return mlir::success();
  }

  hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
      loc, builder, getResultElementType(), resultShape, /*typeParams=*/{},
      genKernel,
      /*isUnordered=*/true, /*polymorphicMold=*/nullptr, getResultType());

  // it wouldn't be safe to replace block arguments with a different
  // hlfir.expr type. Types can differ due to differing amounts of shape
  // information
  assert(elementalOp.getResult().getType() == op->getResult(0).getType());

  rewriter.replaceOp(op, elementalOp);
  return mlir::success();
}
1168
1169std::tuple<mlir::Value, mlir::Value>
1170ReductionAsElementalConverter::genResultShapeForPartialReduction(
1171 hlfir::Entity array, int64_t dimVal) {
1172 llvm::SmallVector<mlir::Value> inExtents =
1173 hlfir::genExtentsVector(loc, builder, array);
1174 assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
1175 "DIM must be present and a positive constant not exceeding "
1176 "the array's rank");
1177
1178 mlir::Value dimExtent = inExtents[dimVal - 1];
1179 inExtents.erase(inExtents.begin() + dimVal - 1);
1180 return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
1181}
1182
1183mlir::Value SumAsElementalConverter::genScalarAdd(mlir::Value value1,
1184 mlir::Value value2) {
1185 mlir::Type ty = value1.getType();
1186 assert(ty == value2.getType() && "reduction values' types do not match");
1187 if (mlir::isa<mlir::FloatType>(ty))
1188 return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
1189 else if (mlir::isa<mlir::ComplexType>(ty))
1190 return builder.create<fir::AddcOp>(loc, value1, value2);
1191 else if (mlir::isa<mlir::IntegerType>(ty))
1192 return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
1193
1194 llvm_unreachable("unsupported SUM reduction type");
1195}
1196
// Compute the value of the scalar mask, or of one element of an array
// mask (addressed by \p indices). When \p isPresentPred is non-null,
// the mask may be dynamically absent: the access is wrapped into a
// fir.if on the presence predicate, and 'true' is substituted when
// the mask is not present.
mlir::Value ReductionAsElementalConverter::genMaskValue(
    mlir::Value mask, mlir::Value isPresentPred, mlir::ValueRange indices) {
  // Restore the insertion point on exit - this function moves it into
  // the fir.if regions below.
  mlir::OpBuilder::InsertionGuard guard(builder);
  fir::IfOp ifOp;
  mlir::Type maskType =
      hlfir::getFortranElementType(fir::unwrapPassByRefType(mask.getType()));
  if (isPresentPred) {
    ifOp = builder.create<fir::IfOp>(loc, maskType, isPresentPred,
                                     /*withElseRegion=*/true);

    // Use 'true', if the mask is not present.
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
    mlir::Value trueValue = builder.createBool(loc, true);
    trueValue = builder.createConvert(loc, maskType, trueValue);
    builder.create<fir::ResultOp>(loc, trueValue);

    // Load the mask value, if the mask is present.
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  }

  hlfir::Entity maskVar{mask};
  if (maskVar.isScalar()) {
    if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
      // MASK may be a boxed scalar.
      mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, maskVar);
      mask = builder.create<fir::LoadOp>(loc, hlfir::Entity{addr});
    } else {
      mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
    }
  } else {
    // Load from the mask array.
    assert(!indices.empty() && "no indices for addressing the mask array");
    maskVar = hlfir::getElementAt(loc, builder, maskVar, indices);
    mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
  }

  if (!isPresentPred)
    return mask;

  // Yield the loaded value from the 'then' region and return the
  // fir.if result.
  builder.create<fir::ResultOp>(loc, mask);
  return ifOp.getResult(0);
}
1239
/// Convert an operation that is a partial or total reduction
/// over an array of values into a reduction loop[-nest]
/// optionally wrapped into hlfir.elemental.
template <typename Op>
class ReductionConversion : public mlir::OpRewritePattern<Op> {
public:
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
    // Dispatch to the converter matching the reduction operation type.
    if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
                  std::is_same_v<Op, hlfir::MinlocOp>) {
      MinMaxlocAsElementalConverter<Op> converter(op, rewriter);
      return converter.convert();
    } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
                         std::is_same_v<Op, hlfir::MinvalOp>) {
      MinMaxvalAsElementalConverter<Op> converter(op, rewriter);
      return converter.convert();
    } else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
      CountAsElementalConverter converter(op, rewriter);
      return converter.convert();
    } else if constexpr (std::is_same_v<Op, hlfir::AllOp> ||
                         std::is_same_v<Op, hlfir::AnyOp>) {
      AllAnyAsElementalConverter<Op> converter(op, rewriter);
      return converter.convert();
    } else if constexpr (std::is_same_v<Op, hlfir::SumOp>) {
      SumAsElementalConverter converter{op, rewriter};
      return converter.convert();
    }
    return rewriter.notifyMatchFailure(op, "unexpected reduction operation");
  }
};
1272
1273class CShiftConversion : public mlir::OpRewritePattern<hlfir::CShiftOp> {
1274public:
1275 using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
1276
  // Rewrite hlfir.cshift either into an hlfir.elemental or into an
  // hlfir.eval_in_mem, depending on DIM and the array's properties.
  // Fails to match when DIM is non-constant or out of range.
  llvm::LogicalResult
  matchAndRewrite(hlfir::CShiftOp cshift,
                  mlir::PatternRewriter &rewriter) const override {

    hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType());
    assert(expr &&
           "expected an expression type for the result of hlfir.cshift");
    unsigned arrayRank = expr.getRank();
    // When it is a 1D CSHIFT, we may assume that the DIM argument
    // (whether it is present or absent) is equal to 1, otherwise,
    // the program is illegal.
    int64_t dimVal = 1;
    if (arrayRank != 1)
      if (mlir::Value dim = cshift.getDim()) {
        auto constDim = fir::getIntIfConstant(dim);
        if (!constDim)
          return rewriter.notifyMatchFailure(cshift,
                                             "Nonconstant DIM for CSHIFT");
        dimVal = *constDim;
      }

    if (dimVal <= 0 || dimVal > arrayRank)
      return rewriter.notifyMatchFailure(cshift, "Invalid DIM for CSHIFT");

    // When DIM==1 and the contiguity of the input array is not statically
    // known, try to exploit the fact that the leading dimension might be
    // contiguous. We can do this now using hlfir.eval_in_mem with
    // a dynamic check for the leading dimension contiguity.
    // Otherwise, convert hlfir.cshift to hlfir.elemental.
    //
    // Note that the hlfir.elemental can be inlined into other hlfir.elemental,
    // while hlfir.eval_in_mem prevents this, and we will end up creating
    // a temporary array for the result. We may need to come up with
    // a more sophisticated logic for picking the most efficient
    // representation.
    hlfir::Entity array = hlfir::Entity{cshift.getArray()};
    mlir::Type elementType = array.getFortranElementType();
    if (dimVal == 1 && fir::isa_trivial(elementType) &&
        // genInMemCShift() only works for variables currently.
        array.isVariable())
      rewriter.replaceOp(cshift, genInMemCShift(rewriter, cshift, dimVal));
    else
      rewriter.replaceOp(cshift, genElementalCShift(rewriter, cshift, dimVal));
    return mlir::success();
  }
1322
1323private:
1324 /// Generate MODULO(\p shiftVal, \p extent).
1325 static mlir::Value normalizeShiftValue(mlir::Location loc,
1326 fir::FirOpBuilder &builder,
1327 mlir::Value shiftVal,
1328 mlir::Value extent,
1329 mlir::Type calcType) {
1330 shiftVal = builder.createConvert(loc, calcType, shiftVal);
1331 extent = builder.createConvert(loc, calcType, extent);
1332 // Make sure that we do not divide by zero. When the dimension
1333 // has zero size, turn the extent into 1. Note that the computed
1334 // MODULO value won't be used in this case, so it does not matter
1335 // which extent value we use.
1336 mlir::Value zero = builder.createIntegerConstant(loc, calcType, 0);
1337 mlir::Value one = builder.createIntegerConstant(loc, calcType, 1);
1338 mlir::Value isZero = builder.create<mlir::arith::CmpIOp>(
1339 loc, mlir::arith::CmpIPredicate::eq, extent, zero);
1340 extent = builder.create<mlir::arith::SelectOp>(loc, isZero, one, extent);
1341 shiftVal = fir::IntrinsicLibrary{builder, loc}.genModulo(
1342 calcType, {shiftVal, extent});
1343 return builder.createConvert(loc, calcType, shiftVal);
1344 }
1345
  /// Convert \p cshift into an hlfir.elemental using
  /// the pre-computed constant \p dimVal.
  /// Each result element is fetched from the input array at an index
  /// shifted (with wrap-around) along dimension \p dimVal.
  static mlir::Operation *genElementalCShift(mlir::PatternRewriter &rewriter,
                                             hlfir::CShiftOp cshift,
                                             int64_t dimVal) {
    using Fortran::common::maxRank;
    hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
    hlfir::Entity array = hlfir::Entity{cshift.getArray()};

    mlir::Location loc = cshift.getLoc();
    fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
    // The new index computation involves MODULO, which is not implemented
    // for IndexType, so use I64 instead.
    mlir::Type calcType = builder.getI64Type();
    // All the indices arithmetic used below does not overflow
    // signed and unsigned I64.
    builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw |
                                    mlir::arith::IntegerOverflowFlags::nuw);

    mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
    llvm::SmallVector<mlir::Value, maxRank> arrayExtents =
        hlfir::getExplicitExtentsFromShape(arrayShape, builder);
    llvm::SmallVector<mlir::Value, 1> typeParams;
    hlfir::genLengthParameters(loc, builder, array, typeParams);
    mlir::Value shiftDimExtent =
        builder.createConvert(loc, calcType, arrayExtents[dimVal - 1]);
    mlir::Value shiftVal;
    if (shift.isScalar()) {
      // A scalar shift is loaded and normalized once, outside the
      // elemental body.
      shiftVal = hlfir::loadTrivialScalar(loc, builder, shift);
      shiftVal =
          normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent, calcType);
    }

    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                         mlir::ValueRange inputIndices) -> hlfir::Entity {
      llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
      if (!shiftVal) {
        // When the array is not a vector, section
        // (s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)
        // of the result has a value equal to:
        // CSHIFT(ARRAY(s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)),
        //        SH, 1),
        // where SH is either SHIFT (if scalar) or
        // SHIFT(s(1), s(2), ..., s(dim-1), s(dim+1), ..., s(n)).
        llvm::SmallVector<mlir::Value, maxRank> shiftIndices{indices};
        shiftIndices.erase(shiftIndices.begin() + dimVal - 1);
        hlfir::Entity shiftElement =
            hlfir::getElementAt(loc, builder, shift, shiftIndices);
        shiftVal = hlfir::loadTrivialScalar(loc, builder, shiftElement);
        shiftVal = normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent,
                                       calcType);
      }

      // Element i of the result (1-based) is element
      // 'MODULO(i + SH - 1, SIZE(ARRAY,DIM)) + 1' (1-based) of the original
      // ARRAY (or its section, when ARRAY is not a vector).

      // Compute the index into the original array using the normalized
      // shift value, which satisfies (SH >= 0 && SH < SIZE(ARRAY,DIM)):
      //   newIndex =
      //     i + ((i <= SIZE(ARRAY,DIM) - SH) ? SH : SH - SIZE(ARRAY,DIM))
      //
      // Such index computation allows for further loop vectorization
      // in LLVM.
      mlir::Value wrapBound =
          builder.create<mlir::arith::SubIOp>(loc, shiftDimExtent, shiftVal);
      mlir::Value adjustedShiftVal =
          builder.create<mlir::arith::SubIOp>(loc, shiftVal, shiftDimExtent);
      mlir::Value index =
          builder.createConvert(loc, calcType, inputIndices[dimVal - 1]);
      mlir::Value wrapCheck = builder.create<mlir::arith::CmpIOp>(
          loc, mlir::arith::CmpIPredicate::sle, index, wrapBound);
      mlir::Value actualShift = builder.create<mlir::arith::SelectOp>(
          loc, wrapCheck, shiftVal, adjustedShiftVal);
      mlir::Value newIndex =
          builder.create<mlir::arith::AddIOp>(loc, index, actualShift);
      newIndex = builder.createConvert(loc, builder.getIndexType(), newIndex);
      indices[dimVal - 1] = newIndex;
      hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
      return hlfir::loadTrivialScalar(loc, builder, element);
    };

    mlir::Type elementType = array.getFortranElementType();
    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
        loc, builder, elementType, arrayShape, typeParams, genKernel,
        /*isUnordered=*/true,
        array.isPolymorphic() ? static_cast<mlir::Value>(array) : nullptr,
        cshift.getResult().getType());
    return elementalOp.getOperation();
  }
1436
1437 /// Convert \p cshift into an hlfir.eval_in_mem using the pre-computed
1438 /// constant \p dimVal.
1439 /// The converted code looks like this:
1440 /// do i=1,SH
1441 /// result(i + (SIZE(ARRAY,DIM) - SH)) = array(i)
1442 /// end
1443 /// do i=1,SIZE(ARRAY,DIM) - SH
1444 /// result(i) = array(i + SH)
1445 /// end
1446 ///
1447 /// When \p dimVal is 1, we generate the same code twice
1448 /// under a dynamic check for the contiguity of the leading
1449 /// dimension. In the code corresponding to the contiguous
1450 /// leading dimension, the shift dimension is represented
1451 /// as a contiguous slice of the original array.
1452 /// This allows recognizing the above two loops as memcpy
1453 /// loop idioms in LLVM.
1454 static mlir::Operation *genInMemCShift(mlir::PatternRewriter &rewriter,
1455 hlfir::CShiftOp cshift,
1456 int64_t dimVal) {
1457 using Fortran::common::maxRank;
1458 hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
1459 hlfir::Entity array = hlfir::Entity{cshift.getArray()};
1460 assert(array.isVariable() && "array must be a variable");
1461 assert(!array.isPolymorphic() &&
1462 "genInMemCShift does not support polymorphic types");
1463 mlir::Location loc = cshift.getLoc();
1464 fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
1465 // The new index computation involves MODULO, which is not implemented
1466 // for IndexType, so use I64 instead.
1467 mlir::Type calcType = builder.getI64Type();
1468 // All the indices arithmetic used below does not overflow
1469 // signed and unsigned I64.
1470 builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw |
1471 mlir::arith::IntegerOverflowFlags::nuw);
1472
1473 mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
1474 llvm::SmallVector<mlir::Value, maxRank> arrayExtents =
1475 hlfir::getExplicitExtentsFromShape(arrayShape, builder);
1476 llvm::SmallVector<mlir::Value, 1> typeParams;
1477 hlfir::genLengthParameters(loc, builder, array, typeParams);
1478 mlir::Value shiftDimExtent =
1479 builder.createConvert(loc, calcType, arrayExtents[dimVal - 1]);
1480 mlir::Value shiftVal;
1481 if (shift.isScalar()) {
1482 shiftVal = hlfir::loadTrivialScalar(loc, builder, shift);
1483 shiftVal =
1484 normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent, calcType);
1485 }
1486
1487 hlfir::EvaluateInMemoryOp evalOp =
1488 builder.create<hlfir::EvaluateInMemoryOp>(
1489 loc, mlir::cast<hlfir::ExprType>(cshift.getType()), arrayShape);
1490 builder.setInsertionPointToStart(&evalOp.getBody().front());
1491
1492 mlir::Value resultArray = evalOp.getMemory();
1493 mlir::Type arrayType = fir::dyn_cast_ptrEleTy(resultArray.getType());
1494 resultArray = builder.createBox(loc, fir::BoxType::get(arrayType),
1495 resultArray, arrayShape, /*slice=*/nullptr,
1496 typeParams, /*tdesc=*/nullptr);
1497
1498 // This is a generator of the dimension shift code.
1499 // The code is inserted inside a loop nest over the other dimensions
1500 // (if any). If exposeContiguity is true, the array's section
1501 // array(s(1), ..., s(dim-1), :, s(dim+1), ..., s(n)) is represented
1502 // as a contiguous 1D array.
1503 // shiftVal is the normalized shift value that satisfies (SH >= 0 && SH <
1504 // SIZE(ARRAY,DIM)).
1505 //
1506 auto genDimensionShift = [&](mlir::Location loc, fir::FirOpBuilder &builder,
1507 mlir::Value shiftVal, bool exposeContiguity,
1508 mlir::ValueRange oneBasedIndices)
1509 -> llvm::SmallVector<mlir::Value, 0> {
1510 // Create a vector of indices (s(1), ..., s(dim-1), nullptr, s(dim+1),
1511 // ..., s(n)) so that we can update the dimVal index as needed.
1512 llvm::SmallVector<mlir::Value, maxRank> srcIndices(
1513 oneBasedIndices.begin(), oneBasedIndices.begin() + (dimVal - 1));
1514 srcIndices.push_back(nullptr);
1515 srcIndices.append(oneBasedIndices.begin() + (dimVal - 1),
1516 oneBasedIndices.end());
1517 llvm::SmallVector<mlir::Value, maxRank> dstIndices(srcIndices);
1518
1519 hlfir::Entity srcArray = array;
1520 if (exposeContiguity && mlir::isa<fir::BaseBoxType>(srcArray.getType())) {
1521 assert(dimVal == 1 && "can expose contiguity only for dim 1");
1522 llvm::SmallVector<mlir::Value, maxRank> arrayLbounds =
1523 hlfir::genLowerbounds(loc, builder, arrayShape, array.getRank());
1524 hlfir::Entity section =
1525 hlfir::gen1DSection(loc, builder, srcArray, dimVal, arrayLbounds,
1526 arrayExtents, oneBasedIndices, typeParams);
1527 mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, section);
1528 mlir::Value shape = hlfir::genShape(loc, builder, section);
1529 mlir::Type boxType = fir::wrapInClassOrBoxType(
1530 hlfir::getFortranElementOrSequenceType(section.getType()),
1531 section.isPolymorphic());
1532 srcArray = hlfir::Entity{
1533 builder.createBox(loc, boxType, addr, shape, /*slice=*/nullptr,
1534 /*lengths=*/{}, /*tdesc=*/nullptr)};
1535 // When shifting the dimension as a 1D section of the original
1536 // array, we only need one index for addressing.
1537 srcIndices.resize(1);
1538 }
1539
1540 // Copy first portion of the array:
1541 // do i=1,SH
1542 // result(i + (SIZE(ARRAY,DIM) - SH)) = array(i)
1543 // end
1544 auto genAssign1 = [&](mlir::Location loc, fir::FirOpBuilder &builder,
1545 mlir::ValueRange index,
1546 mlir::ValueRange reductionArgs)
1547 -> llvm::SmallVector<mlir::Value, 0> {
1548 assert(index.size() == 1 && "expected single loop");
1549 mlir::Value srcIndex = builder.createConvert(loc, calcType, index[0]);
1550 srcIndices[dimVal - 1] = srcIndex;
1551 hlfir::Entity srcElementValue =
1552 hlfir::loadElementAt(loc, builder, srcArray, srcIndices);
1553 mlir::Value dstIndex = builder.create<mlir::arith::AddIOp>(
1554 loc, srcIndex,
1555 builder.create<mlir::arith::SubIOp>(loc, shiftDimExtent, shiftVal));
1556 dstIndices[dimVal - 1] = dstIndex;
1557 hlfir::Entity dstElement = hlfir::getElementAt(
1558 loc, builder, hlfir::Entity{resultArray}, dstIndices);
1559 builder.create<hlfir::AssignOp>(loc, srcElementValue, dstElement);
1560 return {};
1561 };
1562
1563 // Generate the first loop.
1564 hlfir::genLoopNestWithReductions(loc, builder, {shiftVal},
1565 /*reductionInits=*/{}, genAssign1,
1566 /*isUnordered=*/true);
1567
1568 // Copy second portion of the array:
1569 // do i=1,SIZE(ARRAY,DIM)-SH
1570 // result(i) = array(i + SH)
1571 // end
1572 auto genAssign2 = [&](mlir::Location loc, fir::FirOpBuilder &builder,
1573 mlir::ValueRange index,
1574 mlir::ValueRange reductionArgs)
1575 -> llvm::SmallVector<mlir::Value, 0> {
1576 assert(index.size() == 1 && "expected single loop");
1577 mlir::Value dstIndex = builder.createConvert(loc, calcType, index[0]);
1578 mlir::Value srcIndex =
1579 builder.create<mlir::arith::AddIOp>(loc, dstIndex, shiftVal);
1580 srcIndices[dimVal - 1] = srcIndex;
1581 hlfir::Entity srcElementValue =
1582 hlfir::loadElementAt(loc, builder, srcArray, srcIndices);
1583 dstIndices[dimVal - 1] = dstIndex;
1584 hlfir::Entity dstElement = hlfir::getElementAt(
1585 loc, builder, hlfir::Entity{resultArray}, dstIndices);
1586 builder.create<hlfir::AssignOp>(loc, srcElementValue, dstElement);
1587 return {};
1588 };
1589
1590 // Generate the second loop.
1591 mlir::Value bound =
1592 builder.create<mlir::arith::SubIOp>(loc, shiftDimExtent, shiftVal);
1593 hlfir::genLoopNestWithReductions(loc, builder, {bound},
1594 /*reductionInits=*/{}, genAssign2,
1595 /*isUnordered=*/true);
1596 return {};
1597 };
1598
1599 // A wrapper around genDimensionShift that computes the normalized
1600 // shift value and manages the insertion of the multiple versions
1601 // of the shift based on the dynamic check of the leading dimension's
1602 // contiguity (when dimVal == 1).
1603 auto genShiftBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
1604 mlir::ValueRange oneBasedIndices,
1605 mlir::ValueRange reductionArgs)
1606 -> llvm::SmallVector<mlir::Value, 0> {
1607 // Copy the dimension with a shift:
1608 // SH is either SHIFT (if scalar) or SHIFT(oneBasedIndices).
1609 if (!shiftVal) {
1610 assert(!oneBasedIndices.empty() && "scalar shift must be precomputed");
1611 hlfir::Entity shiftElement =
1612 hlfir::getElementAt(loc, builder, shift, oneBasedIndices);
1613 shiftVal = hlfir::loadTrivialScalar(loc, builder, shiftElement);
1614 shiftVal = normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent,
1615 calcType);
1616 }
1617
1618 // If we can fetch the byte stride of the leading dimension,
1619 // and the byte size of the element, then we can generate
1620 // a dynamic contiguity check and expose the leading dimension's
1621 // contiguity in FIR, making memcpy loop idiom recognition
1622 // possible.
1623 mlir::Value elemSize;
1624 mlir::Value stride;
1625 if (dimVal == 1 && mlir::isa<fir::BaseBoxType>(array.getType())) {
1626 mlir::Type indexType = builder.getIndexType();
1627 elemSize =
1628 builder.create<fir::BoxEleSizeOp>(loc, indexType, array.getBase());
1629 mlir::Value dimIdx =
1630 builder.createIntegerConstant(loc, indexType, dimVal - 1);
1631 auto boxDim = builder.create<fir::BoxDimsOp>(
1632 loc, indexType, indexType, indexType, array.getBase(), dimIdx);
1633 stride = boxDim.getByteStride();
1634 }
1635
1636 if (array.isSimplyContiguous() || !elemSize || !stride) {
1637 genDimensionShift(loc, builder, shiftVal, /*exposeContiguity=*/false,
1638 oneBasedIndices);
1639 return {};
1640 }
1641
1642 mlir::Value isContiguous = builder.create<mlir::arith::CmpIOp>(
1643 loc, mlir::arith::CmpIPredicate::eq, elemSize, stride);
1644 builder.genIfOp(loc, {}, isContiguous, /*withElseRegion=*/true)
1645 .genThen([&]() {
1646 genDimensionShift(loc, builder, shiftVal, /*exposeContiguity=*/true,
1647 oneBasedIndices);
1648 })
1649 .genElse([&]() {
1650 genDimensionShift(loc, builder, shiftVal,
1651 /*exposeContiguity=*/false, oneBasedIndices);
1652 });
1653
1654 return {};
1655 };
1656
1657 // For 1D case, generate a single loop.
1658 // For ND case, generate a loop nest over the other dimensions
1659 // with a single loop inside (generated separately).
1660 llvm::SmallVector<mlir::Value, maxRank> newExtents(arrayExtents);
1661 newExtents.erase(newExtents.begin() + (dimVal - 1));
1662 if (!newExtents.empty())
1663 hlfir::genLoopNestWithReductions(loc, builder, newExtents,
1664 /*reductionInits=*/{}, genShiftBody,
1665 /*isUnordered=*/true);
1666 else
1667 genShiftBody(loc, builder, {}, {});
1668
1669 return evalOp.getOperation();
1670 }
1671};
1672
// Rewrite hlfir.matmul (Op == hlfir::MatmulOp) or hlfir.matmul_transpose
// (Op == hlfir::MatmulTransposeOp) into inline code instead of a runtime
// call. MATMUL(TRANSPOSE) — and plain MATMUL when forceMatmulAsElemental
// is set — is expanded as an hlfir.elemental whose kernel computes one
// inner product. Otherwise, MATMUL is expanded into explicit loop nests
// inside an hlfir.eval_in_mem mimicking the runtime implementation.
template <typename Op>
class MatmulConversion : public mlir::OpRewritePattern<Op> {
public:
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(Op matmul, mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = matmul.getLoc();
    fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
    hlfir::Entity lhs = hlfir::Entity{matmul.getLhs()};
    hlfir::Entity rhs = hlfir::Entity{matmul.getRhs()};
    // Compute the result shape and the shared inner-product extent up
    // front: both expansion strategies below need them.
    mlir::Value resultShape, innerProductExtent;
    std::tie(resultShape, innerProductExtent) =
        genResultShape(loc, builder, lhs, rhs);

    if (forceMatmulAsElemental || isMatmulTranspose) {
      // Generate hlfir.elemental that produces the result of
      // MATMUL/MATMUL(TRANSPOSE).
      // Note that this implementation is very suboptimal for MATMUL,
      // but is quite good for MATMUL(TRANSPOSE), e.g.:
      //   R(1:N) = R(1:N) + MATMUL(TRANSPOSE(X(1:N,1:N)), Y(1:N))
      // Inlining MATMUL(TRANSPOSE) as hlfir.elemental may result
      // in merging the inner product computation with the elemental
      // addition. Note that the inner product computation will
      // benefit from processing the lowermost dimensions of X and Y,
      // which may be the best when they are contiguous.
      //
      // This is why we always inline MATMUL(TRANSPOSE) as an elemental.
      // MATMUL is inlined below by default unless forceMatmulAsElemental.
      hlfir::ExprType resultType =
          mlir::cast<hlfir::ExprType>(matmul.getType());
      hlfir::ElementalOp newOp = genElementalMatmul(
          loc, builder, resultType, resultShape, lhs, rhs, innerProductExtent);
      rewriter.replaceOp(matmul, newOp);
      return mlir::success();
    }

    // Generate hlfir.eval_in_mem to mimic the MATMUL implementation
    // from Fortran runtime. The implementation needs to operate
    // with the result array as an in-memory object.
    hlfir::EvaluateInMemoryOp evalOp =
        builder.create<hlfir::EvaluateInMemoryOp>(
            loc, mlir::cast<hlfir::ExprType>(matmul.getType()), resultShape);
    builder.setInsertionPointToStart(&evalOp.getBody().front());

    // Embox the raw array pointer to simplify designating it.
    // TODO: this currently results in redundant lower bounds
    // addition for the designator, but this should be fixed in
    // hlfir::Entity::mayHaveNonDefaultLowerBounds().
    mlir::Value resultArray = evalOp.getMemory();
    mlir::Type arrayType = fir::dyn_cast_ptrEleTy(resultArray.getType());
    resultArray = builder.createBox(loc, fir::BoxType::get(arrayType),
                                    resultArray, resultShape, /*slice=*/nullptr,
                                    /*lengths=*/{}, /*tdesc=*/nullptr);

    // The contiguous MATMUL version is best for the cases
    // where the input arrays and (maybe) the result are contiguous
    // in their lowermost dimensions.
    // Especially, when LLVM can recognize the continuity
    // and vectorize the loops properly.
    // Note that the contiguous MATMUL inlining is correct
    // even when the input arrays are not contiguous.
    // TODO: we can try to recognize the cases when the continuity
    // is not statically obvious and try to generate an explicitly
    // continuous version under a dynamic check. This should allow
    // LLVM to vectorize the loops better. Note that this can
    // also be postponed up to the LoopVersioning pass.
    // The fallback implementation may use genElementalMatmul() with
    // an hlfir.assign into the result of eval_in_mem.
    mlir::LogicalResult rewriteResult =
        genContiguousMatmul(loc, builder, hlfir::Entity{resultArray},
                            resultShape, lhs, rhs, innerProductExtent);

    if (mlir::failed(rewriteResult)) {
      // Erase the unclaimed eval_in_mem op.
      rewriter.eraseOp(evalOp);
      return rewriter.notifyMatchFailure(matmul,
                                         "genContiguousMatmul() failed");
    }

    rewriter.replaceOp(matmul, evalOp);
    return mlir::success();
  }

private:
  // True when this pattern instance handles hlfir.matmul_transpose.
  static constexpr bool isMatmulTranspose =
      std::is_same_v<Op, hlfir::MatmulTransposeOp>;

  // Return a tuple of:
  // * A fir.shape operation representing the shape of the result
  //   of a MATMUL/MATMUL(TRANSPOSE).
  // * An extent of the dimensions of the input array
  //   that are processed during the inner product computation.
  static std::tuple<mlir::Value, mlir::Value>
  genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
                 hlfir::Entity input1, hlfir::Entity input2) {
    llvm::SmallVector<mlir::Value, 2> input1Extents =
        hlfir::genExtentsVector(loc, builder, input1);
    llvm::SmallVector<mlir::Value, 2> input2Extents =
        hlfir::genExtentsVector(loc, builder, input2);

    llvm::SmallVector<mlir::Value, 2> newExtents;
    mlir::Value innerProduct1Extent, innerProduct2Extent;
    if (input1Extents.size() == 1) {
      // Vector * matrix case: RESULT(NCOLS).
      assert(!isMatmulTranspose &&
             "hlfir.matmul_transpose's first operand must be rank-2 array");
      assert(input2Extents.size() == 2 &&
             "hlfir.matmul second argument must be rank-2 array");
      newExtents.push_back(input2Extents[1]);
      innerProduct1Extent = input1Extents[0];
      innerProduct2Extent = input2Extents[0];
    } else {
      if (input2Extents.size() == 1) {
        // Matrix * vector case: RESULT(NROWS).
        assert(input1Extents.size() == 2 &&
               "hlfir.matmul first argument must be rank-2 array");
        if constexpr (isMatmulTranspose)
          newExtents.push_back(input1Extents[1]);
        else
          newExtents.push_back(input1Extents[0]);
      } else {
        // Matrix * matrix case: RESULT(NROWS,NCOLS).
        assert(input1Extents.size() == 2 && input2Extents.size() == 2 &&
               "hlfir.matmul arguments must be rank-2 arrays");
        if constexpr (isMatmulTranspose)
          newExtents.push_back(input1Extents[1]);
        else
          newExtents.push_back(input1Extents[0]);

        newExtents.push_back(input2Extents[1]);
      }
      // For TRANSPOSE(LHS) the inner-product dimension of LHS is the
      // first one, otherwise it is the second one.
      if constexpr (isMatmulTranspose)
        innerProduct1Extent = input1Extents[0];
      else
        innerProduct1Extent = input1Extents[1];

      innerProduct2Extent = input2Extents[0];
    }
    // The inner product dimensions of the input arrays
    // must match. Pick the best (e.g. constant) out of them
    // so that the inner product loop bound can be used in
    // optimizations.
    llvm::SmallVector<mlir::Value> innerProductExtent =
        fir::factory::deduceOptimalExtents({innerProduct1Extent},
                                           {innerProduct2Extent});
    return {builder.create<fir::ShapeOp>(loc, newExtents),
            innerProductExtent[0]};
  }

  // Generate the loop-nest MATMUL expansion into \p result (the in-memory
  // result of the enclosing hlfir.eval_in_mem): first a loop nest that
  // zero-initializes the result, then a computation loop nest matching
  // the operand ranks. Returns failure for MATMUL(TRANSPOSE), which is
  // always expanded as hlfir.elemental instead.
  static mlir::LogicalResult
  genContiguousMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
                      hlfir::Entity result, mlir::Value resultShape,
                      hlfir::Entity lhs, hlfir::Entity rhs,
                      mlir::Value innerProductExtent) {
    // This code does not support MATMUL(TRANSPOSE), and it is supposed
    // to be inlined as hlfir.elemental.
    if constexpr (isMatmulTranspose)
      return mlir::failure();

    mlir::OpBuilder::InsertionGuard guard(builder);
    mlir::Type resultElementType = result.getFortranElementType();
    llvm::SmallVector<mlir::Value, 2> resultExtents =
        mlir::cast<fir::ShapeOp>(resultShape.getDefiningOp()).getExtents();

    // The inner product loop may be unordered if FastMathFlags::reassoc
    // transformations are allowed. The integer/logical inner product is
    // always unordered.
    // Note that isUnordered is currently applied to all loops
    // in the loop nests generated below, while it has to be applied
    // only to one.
    bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
                       mlir::isa<fir::LogicalType>(resultElementType) ||
                       static_cast<bool>(builder.getFastMathFlags() &
                                         mlir::arith::FastMathFlags::reassoc);

    // Insert the initialization loop nest that fills the whole result with
    // zeroes.
    mlir::Value initValue =
        fir::factory::createZeroValue(builder, loc, resultElementType);
    auto genInitBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                           mlir::ValueRange oneBasedIndices,
                           mlir::ValueRange reductionArgs)
        -> llvm::SmallVector<mlir::Value, 0> {
      hlfir::Entity resultElement =
          hlfir::getElementAt(loc, builder, result, oneBasedIndices);
      builder.create<hlfir::AssignOp>(loc, initValue, resultElement);
      return {};
    };

    hlfir::genLoopNestWithReductions(loc, builder, resultExtents,
                                     /*reductionInits=*/{}, genInitBody,
                                     /*isUnordered=*/true);

    if (lhs.getRank() == 2 && rhs.getRank() == 2) {
      //   LHS(NROWS,N) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
      //
      // Insert the computation loop nest:
      //   DO 2 K = 1, N
      //    DO 2 J = 1, NCOLS
      //     DO 2 I = 1, NROWS
      // 2    RESULT(I,J) = RESULT(I,J) + LHS(I,K)*RHS(K,J)
      auto genMatrixMatrix = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                                 mlir::ValueRange oneBasedIndices,
                                 mlir::ValueRange reductionArgs)
          -> llvm::SmallVector<mlir::Value, 0> {
        mlir::Value I = oneBasedIndices[0];
        mlir::Value J = oneBasedIndices[1];
        mlir::Value K = oneBasedIndices[2];
        hlfir::Entity resultElement =
            hlfir::getElementAt(loc, builder, result, {I, J});
        hlfir::Entity resultElementValue =
            hlfir::loadTrivialScalar(loc, builder, resultElement);
        hlfir::Entity lhsElementValue =
            hlfir::loadElementAt(loc, builder, lhs, {I, K});
        hlfir::Entity rhsElementValue =
            hlfir::loadElementAt(loc, builder, rhs, {K, J});
        mlir::Value productValue =
            ProductFactory{loc, builder}.genAccumulateProduct(
                resultElementValue, lhsElementValue, rhsElementValue);
        builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
        return {};
      };

      // Note that the loops are inserted in reverse order,
      // so innerProductExtent should be passed as the last extent.
      hlfir::genLoopNestWithReductions(
          loc, builder,
          {resultExtents[0], resultExtents[1], innerProductExtent},
          /*reductionInits=*/{}, genMatrixMatrix, isUnordered);
      return mlir::success();
    }

    if (lhs.getRank() == 2 && rhs.getRank() == 1) {
      //   LHS(NROWS,N) * RHS(N) -> RESULT(NROWS)
      //
      // Insert the computation loop nest:
      //   DO 2 K = 1, N
      //    DO 2 J = 1, NROWS
      // 2   RES(J) = RES(J) + LHS(J,K)*RHS(K)
      auto genMatrixVector = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                                 mlir::ValueRange oneBasedIndices,
                                 mlir::ValueRange reductionArgs)
          -> llvm::SmallVector<mlir::Value, 0> {
        mlir::Value J = oneBasedIndices[0];
        mlir::Value K = oneBasedIndices[1];
        hlfir::Entity resultElement =
            hlfir::getElementAt(loc, builder, result, {J});
        hlfir::Entity resultElementValue =
            hlfir::loadTrivialScalar(loc, builder, resultElement);
        hlfir::Entity lhsElementValue =
            hlfir::loadElementAt(loc, builder, lhs, {J, K});
        hlfir::Entity rhsElementValue =
            hlfir::loadElementAt(loc, builder, rhs, {K});
        mlir::Value productValue =
            ProductFactory{loc, builder}.genAccumulateProduct(
                resultElementValue, lhsElementValue, rhsElementValue);
        builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
        return {};
      };
      hlfir::genLoopNestWithReductions(
          loc, builder, {resultExtents[0], innerProductExtent},
          /*reductionInits=*/{}, genMatrixVector, isUnordered);
      return mlir::success();
    }
    if (lhs.getRank() == 1 && rhs.getRank() == 2) {
      //   LHS(N) * RHS(N,NCOLS) -> RESULT(NCOLS)
      //
      // Insert the computation loop nest:
      //   DO 2 K = 1, N
      //    DO 2 J = 1, NCOLS
      // 2   RES(J) = RES(J) + LHS(K)*RHS(K,J)
      auto genVectorMatrix = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                                 mlir::ValueRange oneBasedIndices,
                                 mlir::ValueRange reductionArgs)
          -> llvm::SmallVector<mlir::Value, 0> {
        mlir::Value J = oneBasedIndices[0];
        mlir::Value K = oneBasedIndices[1];
        hlfir::Entity resultElement =
            hlfir::getElementAt(loc, builder, result, {J});
        hlfir::Entity resultElementValue =
            hlfir::loadTrivialScalar(loc, builder, resultElement);
        hlfir::Entity lhsElementValue =
            hlfir::loadElementAt(loc, builder, lhs, {K});
        hlfir::Entity rhsElementValue =
            hlfir::loadElementAt(loc, builder, rhs, {K, J});
        mlir::Value productValue =
            ProductFactory{loc, builder}.genAccumulateProduct(
                resultElementValue, lhsElementValue, rhsElementValue);
        builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
        return {};
      };
      hlfir::genLoopNestWithReductions(
          loc, builder, {resultExtents[0], innerProductExtent},
          /*reductionInits=*/{}, genVectorMatrix, isUnordered);
      return mlir::success();
    }

    llvm_unreachable("unsupported MATMUL arguments' ranks");
  }

  // Generate an hlfir.elemental whose kernel computes one element of the
  // MATMUL/MATMUL(TRANSPOSE) result as a reduction loop over the inner
  // product dimension.
  static hlfir::ElementalOp
  genElementalMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
                     hlfir::ExprType resultType, mlir::Value resultShape,
                     hlfir::Entity lhs, hlfir::Entity rhs,
                     mlir::Value innerProductExtent) {
    mlir::OpBuilder::InsertionGuard guard(builder);
    mlir::Type resultElementType = resultType.getElementType();
    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                         mlir::ValueRange resultIndices) -> hlfir::Entity {
      mlir::Value initValue =
          fir::factory::createZeroValue(builder, loc, resultElementType);
      // The inner product loop may be unordered if FastMathFlags::reassoc
      // transformations are allowed. The integer/logical inner product is
      // always unordered.
      bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
                         mlir::isa<fir::LogicalType>(resultElementType) ||
                         static_cast<bool>(builder.getFastMathFlags() &
                                           mlir::arith::FastMathFlags::reassoc);

      auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                         mlir::ValueRange oneBasedIndices,
                         mlir::ValueRange reductionArgs)
          -> llvm::SmallVector<mlir::Value, 1> {
        llvm::SmallVector<mlir::Value, 2> lhsIndices;
        llvm::SmallVector<mlir::Value, 2> rhsIndices;
        // MATMUL:
        //   LHS(NROWS,N) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
        //   LHS(NROWS,N) * RHS(N) -> RESULT(NROWS)
        //   LHS(N) * RHS(N,NCOLS) -> RESULT(NCOLS)
        //
        // MATMUL(TRANSPOSE):
        //   TRANSPOSE(LHS(N,NROWS)) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS)
        //   TRANSPOSE(LHS(N,NROWS)) * RHS(N) -> RESULT(NROWS)
        //
        // The resultIndices iterate over (NROWS[,NCOLS]).
        // The oneBasedIndices iterate over (N).
        if (lhs.getRank() > 1)
          lhsIndices.push_back(resultIndices[0]);
        lhsIndices.push_back(oneBasedIndices[0]);

        if constexpr (isMatmulTranspose) {
          // Swap the LHS indices for TRANSPOSE.
          std::swap(lhsIndices[0], lhsIndices[1]);
        }

        rhsIndices.push_back(oneBasedIndices[0]);
        if (rhs.getRank() > 1)
          rhsIndices.push_back(resultIndices.back());

        hlfir::Entity lhsElementValue =
            hlfir::loadElementAt(loc, builder, lhs, lhsIndices);
        hlfir::Entity rhsElementValue =
            hlfir::loadElementAt(loc, builder, rhs, rhsIndices);
        mlir::Value productValue =
            ProductFactory{loc, builder}.genAccumulateProduct(
                reductionArgs[0], lhsElementValue, rhsElementValue);
        return {productValue};
      };
      llvm::SmallVector<mlir::Value, 1> innerProductValue =
          hlfir::genLoopNestWithReductions(loc, builder, {innerProductExtent},
                                           {initValue}, genBody, isUnordered);
      return hlfir::Entity{innerProductValue[0]};
    };
    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
        loc, builder, resultElementType, resultShape, /*typeParams=*/{},
        genKernel,
        /*isUnordered=*/true, /*polymorphicMold=*/nullptr, resultType);

    return elementalOp;
  }
};
2042
2043class DotProductConversion
2044 : public mlir::OpRewritePattern<hlfir::DotProductOp> {
2045public:
2046 using mlir::OpRewritePattern<hlfir::DotProductOp>::OpRewritePattern;
2047
2048 llvm::LogicalResult
2049 matchAndRewrite(hlfir::DotProductOp product,
2050 mlir::PatternRewriter &rewriter) const override {
2051 hlfir::Entity op = hlfir::Entity{product};
2052 if (!op.isScalar())
2053 return rewriter.notifyMatchFailure(product, "produces non-scalar result");
2054
2055 mlir::Location loc = product.getLoc();
2056 fir::FirOpBuilder builder{rewriter, product.getOperation()};
2057 hlfir::Entity lhs = hlfir::Entity{product.getLhs()};
2058 hlfir::Entity rhs = hlfir::Entity{product.getRhs()};
2059 mlir::Type resultElementType = product.getType();
2060 bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
2061 mlir::isa<fir::LogicalType>(resultElementType) ||
2062 static_cast<bool>(builder.getFastMathFlags() &
2063 mlir::arith::FastMathFlags::reassoc);
2064
2065 mlir::Value extent = genProductExtent(loc, builder, lhs, rhs);
2066
2067 auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
2068 mlir::ValueRange oneBasedIndices,
2069 mlir::ValueRange reductionArgs)
2070 -> llvm::SmallVector<mlir::Value, 1> {
2071 hlfir::Entity lhsElementValue =
2072 hlfir::loadElementAt(loc, builder, lhs, oneBasedIndices);
2073 hlfir::Entity rhsElementValue =
2074 hlfir::loadElementAt(loc, builder, rhs, oneBasedIndices);
2075 mlir::Value productValue =
2076 ProductFactory{loc, builder}.genAccumulateProduct</*CONJ=*/true>(
2077 reductionArgs[0], lhsElementValue, rhsElementValue);
2078 return {productValue};
2079 };
2080
2081 mlir::Value initValue =
2082 fir::factory::createZeroValue(builder, loc, resultElementType);
2083
2084 llvm::SmallVector<mlir::Value, 1> result = hlfir::genLoopNestWithReductions(
2085 loc, builder, {extent},
2086 /*reductionInits=*/{initValue}, genBody, isUnordered);
2087
2088 rewriter.replaceOp(product, result[0]);
2089 return mlir::success();
2090 }
2091
2092private:
2093 static mlir::Value genProductExtent(mlir::Location loc,
2094 fir::FirOpBuilder &builder,
2095 hlfir::Entity input1,
2096 hlfir::Entity input2) {
2097 llvm::SmallVector<mlir::Value, 1> input1Extents =
2098 hlfir::genExtentsVector(loc, builder, input1);
2099 llvm::SmallVector<mlir::Value, 1> input2Extents =
2100 hlfir::genExtentsVector(loc, builder, input2);
2101
2102 assert(input1Extents.size() == 1 && input2Extents.size() == 1 &&
2103 "hlfir.dot_product arguments must be vectors");
2104 llvm::SmallVector<mlir::Value, 1> extent =
2105 fir::factory::deduceOptimalExtents(input1Extents, input2Extents);
2106 return extent[0];
2107 }
2108};
2109
// Expand hlfir.reshape (without ORDER) into an hlfir.elemental whose
// kernel maps each result element, via a zero-based linear index, back
// to the corresponding element of ARRAY or (when the linear index runs
// past ARRAY's size) of PAD.
class ReshapeAsElementalConversion
    : public mlir::OpRewritePattern<hlfir::ReshapeOp> {
public:
  using mlir::OpRewritePattern<hlfir::ReshapeOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::ReshapeOp reshape,
                  mlir::PatternRewriter &rewriter) const override {
    // Do not inline RESHAPE with ORDER yet. The runtime implementation
    // may be good enough, unless the temporary creation overhead
    // is high.
    // TODO: If ORDER is constant, then we can still easily inline.
    // TODO: If the result's rank is 1, then we can assume ORDER == (/1/).
    if (reshape.getOrder())
      return rewriter.notifyMatchFailure(reshape,
                                         "RESHAPE with ORDER argument");

    // Verify that the element types of ARRAY, PAD and the result
    // match before doing any transformations. For example,
    // the character types of different lengths may appear in the dead
    // code, and it just does not make sense to inline hlfir.reshape
    // in this case (a runtime call might have less code size footprint).
    hlfir::Entity result = hlfir::Entity{reshape};
    hlfir::Entity array = hlfir::Entity{reshape.getArray()};
    mlir::Type elementType = array.getFortranElementType();
    if (result.getFortranElementType() != elementType)
      return rewriter.notifyMatchFailure(
          reshape, "ARRAY and result have different types");
    mlir::Value pad = reshape.getPad();
    if (pad && hlfir::getFortranElementType(pad.getType()) != elementType)
      return rewriter.notifyMatchFailure(reshape,
                                         "ARRAY and PAD have different types");

    // TODO: selecting between ARRAY and PAD of non-trivial element types
    // requires more work. We have to select between two references
    // to elements in ARRAY and PAD. This requires conditional
    // bufferization of the element, if ARRAY/PAD is an expression.
    if (pad && !fir::isa_trivial(elementType))
      return rewriter.notifyMatchFailure(reshape,
                                         "PAD present with non-trivial type");

    mlir::Location loc = reshape.getLoc();
    fir::FirOpBuilder builder{rewriter, reshape.getOperation()};
    // Assume that all the indices arithmetic does not overflow
    // the IndexType.
    builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nuw);

    llvm::SmallVector<mlir::Value, 1> typeParams;
    hlfir::genLengthParameters(loc, builder, array, typeParams);

    // Fetch the extents of ARRAY, PAD and result beforehand.
    llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents =
        hlfir::genExtentsVector(loc, builder, array);

    // If PAD is present, we have to use array size to start taking
    // elements from the PAD array.
    mlir::Value arraySize =
        pad ? computeArraySize(loc, builder, arrayExtents) : nullptr;
    hlfir::Entity shape = hlfir::Entity{reshape.getShape()};
    // The result extents are the elements of the SHAPE argument,
    // loaded with one-based indices.
    llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents;
    mlir::Type indexType = builder.getIndexType();
    for (int idx = 0; idx < result.getRank(); ++idx)
      resultExtents.push_back(hlfir::loadElementAt(
          loc, builder, shape,
          builder.createIntegerConstant(loc, indexType, idx + 1)));
    auto resultShape = builder.create<fir::ShapeOp>(loc, resultExtents);

    // The kernel computes one element of the reshaped result from
    // the element of ARRAY (or PAD) at the same linear position.
    auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                         mlir::ValueRange inputIndices) -> hlfir::Entity {
      mlir::Value linearIndex =
          computeLinearIndex(loc, builder, resultExtents, inputIndices);
      fir::IfOp ifOp;
      if (pad) {
        // PAD is present. Check if this element comes from the PAD array.
        mlir::Value isInsideArray = builder.create<mlir::arith::CmpIOp>(
            loc, mlir::arith::CmpIPredicate::ult, linearIndex, arraySize);
        ifOp = builder.create<fir::IfOp>(loc, elementType, isInsideArray,
                                         /*withElseRegion=*/true);

        // In the 'else' block, return an element from the PAD.
        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
        // PAD is dynamically optional, but we can unconditionally access it
        // in the 'else' block. If we have to start taking elements from it,
        // then it must be present in a valid program.
        llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents =
            hlfir::genExtentsVector(loc, builder, hlfir::Entity{pad});
        // Subtract the ARRAY size from the zero-based linear index
        // to get the zero-based linear index into PAD.
        mlir::Value padLinearIndex =
            builder.create<mlir::arith::SubIOp>(loc, linearIndex, arraySize);
        llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices =
            delinearizeIndex(loc, builder, padExtents, padLinearIndex,
                             /*wrapAround=*/true);
        mlir::Value padElement =
            hlfir::loadElementAt(loc, builder, hlfir::Entity{pad}, padIndices);
        builder.create<fir::ResultOp>(loc, padElement);

        // In the 'then' block, return an element from the ARRAY.
        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
      }

      llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices =
          delinearizeIndex(loc, builder, arrayExtents, linearIndex,
                           /*wrapAround=*/false);
      mlir::Value arrayElement =
          hlfir::loadElementAt(loc, builder, array, arrayIndices);

      if (ifOp) {
        // Terminate the 'then' block and continue after the fir.if,
        // using its result as the element value.
        builder.create<fir::ResultOp>(loc, arrayElement);
        builder.setInsertionPointAfter(ifOp);
        arrayElement = ifOp.getResult(0);
      }

      return hlfir::Entity{arrayElement};
    };
    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
        loc, builder, elementType, resultShape, typeParams, genKernel,
        /*isUnordered=*/true,
        /*polymorphicMold=*/result.isPolymorphic() ? array : mlir::Value{},
        reshape.getResult().getType());
    assert(elementalOp.getResult().getType() == reshape.getResult().getType());
    rewriter.replaceOp(reshape, elementalOp);
    return mlir::success();
  }

private:
  /// Compute zero-based linear index given an array extents
  /// and one-based indices:
  ///   \p extents: [e0, e1, ..., en]
  ///   \p indices: [i0, i1, ..., in]
  ///
  ///   linear-index :=
  ///       (...((in-1)*e(n-1)+(i(n-1)-1))*e(n-2)+...)*e0+(i0-1)
  static mlir::Value computeLinearIndex(mlir::Location loc,
                                        fir::FirOpBuilder &builder,
                                        mlir::ValueRange extents,
                                        mlir::ValueRange indices) {
    std::size_t rank = extents.size();
    assert(rank == indices.size());
    mlir::Type indexType = builder.getIndexType();
    mlir::Value zero = builder.createIntegerConstant(loc, indexType, 0);
    mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);
    mlir::Value linearIndex = zero;
    std::size_t idx = 0;
    // Walk the indices from the last (slowest varying) dimension to the
    // first, applying Horner's scheme with the extents as the bases.
    for (auto index : llvm::reverse(indices)) {
      mlir::Value tmp = builder.create<mlir::arith::SubIOp>(
          loc, builder.createConvert(loc, indexType, index), one);
      tmp = builder.create<mlir::arith::AddIOp>(loc, linearIndex, tmp);
      if (idx + 1 < rank)
        tmp = builder.create<mlir::arith::MulIOp>(
            loc, tmp,
            builder.createConvert(loc, indexType, extents[rank - idx - 2]));

      linearIndex = tmp;
      ++idx;
    }
    return linearIndex;
  }

  /// Compute one-based array indices from the given zero-based \p linearIndex
  /// and the array \p extents [e0, e1, ..., en].
  ///   i0 := linearIndex % e0 + 1
  ///   linearIndex := linearIndex / e0
  ///   i1 := linearIndex % e1 + 1
  ///   linearIndex := linearIndex / e1
  ///   ...
  ///   i(n-1) := linearIndex % e(n-1) + 1
  ///   linearIndex := linearIndex / e(n-1)
  ///   if (wrapAround) {
  ///     // If the index is allowed to wrap around, then
  ///     // we need to modulo it by the last dimension's extent.
  ///     in := linearIndex % en + 1
  ///   } else {
  ///     in := linearIndex + 1
  ///   }
  static llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
  delinearizeIndex(mlir::Location loc, fir::FirOpBuilder &builder,
                   mlir::ValueRange extents, mlir::Value linearIndex,
                   bool wrapAround) {
    llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
    mlir::Type indexType = builder.getIndexType();
    mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);
    linearIndex = builder.createConvert(loc, indexType, linearIndex);

    for (std::size_t dim = 0; dim < extents.size(); ++dim) {
      mlir::Value extent = builder.createConvert(loc, indexType, extents[dim]);
      // Avoid the modulo for the last index, unless wrap around is allowed.
      mlir::Value currentIndex = linearIndex;
      if (dim != extents.size() - 1 || wrapAround)
        currentIndex =
            builder.create<mlir::arith::RemUIOp>(loc, linearIndex, extent);
      // The result of the last division is unused, so it will be DCEd.
      linearIndex =
          builder.create<mlir::arith::DivUIOp>(loc, linearIndex, extent);
      indices.push_back(
          builder.create<mlir::arith::AddIOp>(loc, currentIndex, one));
    }
    return indices;
  }

  /// Return size of an array given its extents.
  /// The size is the product of all extents, computed in IndexType.
  static mlir::Value computeArraySize(mlir::Location loc,
                                      fir::FirOpBuilder &builder,
                                      mlir::ValueRange extents) {
    mlir::Type indexType = builder.getIndexType();
    mlir::Value size = builder.createIntegerConstant(loc, indexType, 1);
    for (auto extent : extents)
      size = builder.create<mlir::arith::MulIOp>(
          loc, size, builder.createConvert(loc, indexType, extent));
    return size;
  }
};
2322
2323class SimplifyHLFIRIntrinsics
2324 : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
2325public:
2326 using SimplifyHLFIRIntrinsicsBase<
2327 SimplifyHLFIRIntrinsics>::SimplifyHLFIRIntrinsicsBase;
2328
2329 void runOnOperation() override {
2330 mlir::MLIRContext *context = &getContext();
2331
2332 mlir::GreedyRewriteConfig config;
2333 // Prevent the pattern driver from merging blocks
2334 config.setRegionSimplificationLevel(
2335 mlir::GreedySimplifyRegionLevel::Disabled);
2336
2337 mlir::RewritePatternSet patterns(context);
2338 patterns.insert<TransposeAsElementalConversion>(context);
2339 patterns.insert<ReductionConversion<hlfir::SumOp>>(context);
2340 patterns.insert<CShiftConversion>(context);
2341 patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
2342
2343 patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
2344 patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
2345 patterns.insert<ReductionConversion<hlfir::AllOp>>(context);
2346 patterns.insert<ReductionConversion<hlfir::MaxlocOp>>(context);
2347 patterns.insert<ReductionConversion<hlfir::MinlocOp>>(context);
2348 patterns.insert<ReductionConversion<hlfir::MaxvalOp>>(context);
2349 patterns.insert<ReductionConversion<hlfir::MinvalOp>>(context);
2350
2351 // If forceMatmulAsElemental is false, then hlfir.matmul inlining
2352 // will introduce hlfir.eval_in_mem operation with new memory side
2353 // effects. This conflicts with CSE and optimized bufferization, e.g.:
2354 // A(1:N,1:N) = A(1:N,1:N) - MATMUL(...)
2355 // If we introduce hlfir.eval_in_mem before CSE, then the current
2356 // MLIR CSE won't be able to optimize the trivial loads of 'N' value
2357 // that happen before and after hlfir.matmul.
2358 // If 'N' loads are not optimized, then the optimized bufferization
2359 // won't be able to prove that the slices of A are identical
2360 // on both sides of the assignment.
2361 // This is actually the CSE problem, but we can work it around
2362 // for the time being.
2363 if (forceMatmulAsElemental || this->allowNewSideEffects)
2364 patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);
2365
2366 patterns.insert<DotProductConversion>(context);
2367 patterns.insert<ReshapeAsElementalConversion>(context);
2368
2369 if (mlir::failed(mlir::applyPatternsGreedily(
2370 getOperation(), std::move(patterns), config))) {
2371 mlir::emitError(getOperation()->getLoc(),
2372 "failure in HLFIR intrinsic simplification");
2373 signalPassFailure();
2374 }
2375 }
2376};
2377} // namespace
2378

// flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp