//===- SimplifyHLFIRIntrinsics.cpp - Simplify HLFIR Intrinsics ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Normally transformational intrinsics are lowered to calls to runtime
// functions. However, some cases of the intrinsics are faster when inlined
// into the calling function.
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Location.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace hlfir {
#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir

#define DEBUG_TYPE "simplify-hlfir-intrinsics"

static llvm::cl::opt<bool> forceMatmulAsElemental(
    "flang-inline-matmul-as-elemental",
    llvm::cl::desc("Expand hlfir.matmul as elemental operation"),
    llvm::cl::init(false));

namespace {

// Helper class to generate operations related to computing
// product of values.
class ProductFactory {
public:
  ProductFactory(mlir::Location loc, fir::FirOpBuilder &builder)
      : loc(loc), builder(builder) {}

  // Generate an update of the inner product value:
  //   acc += v1 * v2, OR
  //   acc += CONJ(v1) * v2, OR
  //   acc ||= v1 && v2
  //
  // The CONJ parameter specifies whether the first complex product argument
  // needs to be conjugated.
  template <bool CONJ = false>
  mlir::Value genAccumulateProduct(mlir::Value acc, mlir::Value v1,
                                   mlir::Value v2) {
    mlir::Type resultType = acc.getType();
    acc = castToProductType(acc, resultType);
    v1 = castToProductType(v1, resultType);
    v2 = castToProductType(v2, resultType);
    mlir::Value result;
    if (mlir::isa<mlir::FloatType>(resultType)) {
      result = builder.create<mlir::arith::AddFOp>(
          loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
    } else if (mlir::isa<mlir::ComplexType>(resultType)) {
      if constexpr (CONJ)
        result = fir::IntrinsicLibrary{builder, loc}.genConjg(resultType, v1);
      else
        result = v1;

      result = builder.create<fir::AddcOp>(
          loc, acc, builder.create<fir::MulcOp>(loc, result, v2));
    } else if (mlir::isa<mlir::IntegerType>(resultType)) {
      result = builder.create<mlir::arith::AddIOp>(
          loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
    } else if (mlir::isa<fir::LogicalType>(resultType)) {
      result = builder.create<mlir::arith::OrIOp>(
          loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
    } else {
      llvm_unreachable("unsupported type");
    }

    return builder.createConvert(loc, resultType, result);
  }

private:
  mlir::Location loc;
  fir::FirOpBuilder &builder;

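  // Cast \p value to the given product \p type:
  //   * logical values are converted to i1,
  //   * a real value used in a complex product is embedded as the real
  //     part of a complex zero of the product type,
  //   * otherwise, a plain conversion is generated.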
  mlir::Value castToProductType(mlir::Value value, mlir::Type type) {
    if (mlir::isa<fir::LogicalType>(type))
      return builder.createConvert(loc, builder.getIntegerType(1), value);

    // TODO: the multiplications/additions by/of zero resulting from
    // complex * real are optimized by LLVM under -fno-signed-zeros
    // -fno-honor-nans.
    // We can make them disappear by default if we:
    //   * either expand the complex multiplication into real
    //     operations, OR
    //   * set nnan nsz fast-math flags to the complex operations.
    if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
      mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
      fir::factory::Complex helper(builder, loc);
      mlir::Type partType = helper.getComplexPartType(type);
      return helper.insertComplexPart(zeroCmplx,
                                      castToProductType(value, partType),
                                      /*isImagPart=*/false);
    }
    return builder.createConvert(loc, type, value);
  }
};

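/// Expand hlfir.transpose into an hlfir.elemental whose kernel loads the
/// input element at the swapped indices, i.e. (schematically)
/// RESULT(i, j) = ARRAY(j, i).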
class TransposeAsElementalConversion
    : public mlir::OpRewritePattern<hlfir::TransposeOp> {
public:
  using mlir::OpRewritePattern<hlfir::TransposeOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::TransposeOp transpose,
                  mlir::PatternRewriter &rewriter) const override {
    hlfir::ExprType expr = transpose.getType();
    // TODO: hlfir.elemental supports polymorphic data types now,
    // so this can be supported.
    if (expr.isPolymorphic())
      return rewriter.notifyMatchFailure(transpose,
                                         "TRANSPOSE of polymorphic type");

    mlir::Location loc = transpose.getLoc();
    fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
    mlir::Type elementType = expr.getElementType();
    hlfir::Entity array = hlfir::Entity{transpose.getArray()};
    mlir::Value resultShape = genResultShape(loc, builder, array);
    llvm::SmallVector<mlir::Value, 1> typeParams;
    hlfir::genLengthParameters(loc, builder, array, typeParams);

    auto genKernel = [&array](mlir::Location loc, fir::FirOpBuilder &builder,
                              mlir::ValueRange inputIndices) -> hlfir::Entity {
      assert(inputIndices.size() == 2 && "checked in TransposeOp::validate");
      const std::initializer_list<mlir::Value> initList = {inputIndices[1],
                                                           inputIndices[0]};
      mlir::ValueRange transposedIndices(initList);
      hlfir::Entity element =
          hlfir::getElementAt(loc, builder, array, transposedIndices);
      hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, element);
      return val;
    };
    hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
        loc, builder, elementType, resultShape, typeParams, genKernel,
        /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
        transpose.getResult().getType());

    // it wouldn't be safe to replace block arguments with a different
    // hlfir.expr type. Types can differ due to differing amounts of shape
    // information
    assert(elementalOp.getResult().getType() ==
           transpose.getResult().getType());

    rewriter.replaceOp(transpose, elementalOp);
    return mlir::success();
  }

private:
  static mlir::Value genResultShape(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
                                    hlfir::Entity array) {
    llvm::SmallVector<mlir::Value, 2> inExtents =
        hlfir::genExtentsVector(loc, builder, array);

    // transpose indices
    assert(inExtents.size() == 2 && "checked in TransposeOp::validate");
    return builder.create<fir::ShapeOp>(
        loc, mlir::ValueRange{inExtents[1], inExtents[0]});
  }
};

/// Base class for converting reduction-like operations into
/// a reduction loop[-nest] optionally wrapped into hlfir.elemental.
/// It is used to handle operations produced for ALL, ANY, COUNT,
/// MAXLOC, MAXVAL, MINLOC, MINVAL, SUM intrinsics.
///
/// All of these operations take an input array, and optional
/// dim, mask arguments. ALL, ANY, COUNT do not have a mask argument.
class ReductionAsElementalConverter {
public:
  ReductionAsElementalConverter(mlir::Operation *op,
                                mlir::PatternRewriter &rewriter)
      : op{op}, rewriter{rewriter}, loc{op->getLoc()}, builder{rewriter, op} {
    assert(op->getNumResults() == 1);
  }
  virtual ~ReductionAsElementalConverter() {}

  /// Do the actual conversion or return mlir::failure(),
  /// if conversion is not possible.
  mlir::LogicalResult convert();

private:
  // Return fir.shape specifying the shape of the result
  // of a reduction with DIM=dimVal. The second return value
  // is the extent of the DIM dimension.
  std::tuple<mlir::Value, mlir::Value>
  genResultShapeForPartialReduction(hlfir::Entity array, int64_t dimVal);

  /// \p mask is a scalar or array logical mask.
  /// If \p isPresentPred is not nullptr, it is a dynamic predicate value
  /// identifying whether the mask's variable is present.
  /// \p indices is a range of one-based indices to access \p mask
  /// when it is an array.
  ///
  /// The method returns the scalar mask value to guard the access
  /// to a single element of the input array.
  mlir::Value genMaskValue(mlir::Value mask, mlir::Value isPresentPred,
                           mlir::ValueRange indices);

protected:
  /// Return the input array.
  virtual mlir::Value getSource() const = 0;

  /// Return DIM or nullptr, if it is not present.
  virtual mlir::Value getDim() const = 0;

  /// Return MASK or nullptr, if it is not present.
  virtual mlir::Value getMask() const { return nullptr; }

  /// Return FastMathFlags attached to the operation
  /// or arith::FastMathFlags::none, if the operation
  /// does not support FastMathFlags (e.g. ALL, ANY, COUNT).
  virtual mlir::arith::FastMathFlags getFastMath() const {
    return mlir::arith::FastMathFlags::none;
  }

  /// Generate initial values for the reduction values used
  /// by the reduction loop. In general, there is a single
  /// loop-carried reduction value (e.g. for SUM), but, for example,
  /// the MAXLOC/MINLOC implementation uses multiple reductions.
  /// \p oneBasedIndices contains any array indices predefined
  /// before the reduction loop, i.e. it is empty for total
  /// reductions, and contains the one-based indices of the wrapping
  /// hlfir.elemental.
  /// \p extents are the pre-computed extents of the input array.
  /// For total reductions, \p extents holds extents of all dimensions.
  /// For partial reductions, \p extents holds a single extent
  /// of the DIM dimension.
  virtual llvm::SmallVector<mlir::Value>
  genReductionInitValues(mlir::ValueRange oneBasedIndices,
                         const llvm::SmallVectorImpl<mlir::Value> &extents) = 0;

  /// Perform reduction(s) update given a single input array's element
  /// identified by \p array and \p oneBasedIndices coordinates.
  /// \p currentValue specifies the current value(s) of the reduction(s)
  /// inside the reduction loop body.
  virtual llvm::SmallVector<mlir::Value>
  reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
                   hlfir::Entity array, mlir::ValueRange oneBasedIndices) = 0;

  /// Given reduction value(s) in \p reductionResults produced
  /// by the reduction loop, apply any required updates and return
  /// new reduction value(s) to be used after the reduction loop
  /// (e.g. as the result yield of the wrapping hlfir.elemental).
  /// NOTE: if the reduction loop is wrapped in hlfir.elemental,
  /// the insertion point of any generated code is inside hlfir.elemental.
  virtual hlfir::Entity
  genFinalResult(const llvm::SmallVectorImpl<mlir::Value> &reductionResults) {
    assert(reductionResults.size() == 1 &&
           "default implementation of genFinalResult expects a single "
           "reduction value");
    return hlfir::Entity{reductionResults[0]};
  }

  /// Return mlir::success(), if the operation can be converted.
  /// The default implementation always returns mlir::success().
  /// The derived type may override the default implementation
  /// with its own definition.
  virtual mlir::LogicalResult isConvertible() const { return mlir::success(); }

  // The default implementation of isTotalReduction() just checks
  // if the result of the operation is a scalar.
  // A true result indicates that the reduction has to be done
  // across all elements; a false result indicates that
  // the result is an array expression produced by an hlfir.elemental
  // operation with a single reduction loop across the DIM dimension.
  //
  // MAXLOC/MINLOC must override this.
  virtual bool isTotalReduction() const { return getResultRank() == 0; }

  // Return true, if the reduction loop[-nest] may be unordered.
  // In general, FP reductions may only be unordered when
  // FastMathFlags::reassoc transformations are allowed.
  //
  // Some derived types may need to override this.
  virtual bool isUnordered() const {
    mlir::Type elemType = getSourceElementType();
    if (mlir::isa<mlir::IntegerType, fir::LogicalType, fir::CharacterType>(
            elemType))
      return true;
    return static_cast<bool>(getFastMath() &
                             mlir::arith::FastMathFlags::reassoc);
  }

  /// Return 0, if DIM is not present or its value does not matter
  /// (for example, a reduction of a 1D array does not care about
  /// the DIM value, assuming that it is a valid program).
  /// Return mlir::failure(), if DIM is a constant known
  /// to be invalid for the given array.
  /// Otherwise, return the DIM constant value.
  mlir::FailureOr<int64_t> getConstDim() const {
    int64_t dimVal = 0;
    if (!isTotalReduction()) {
      // In case of partial reduction we should ignore the operations
      // with invalid DIM values. They may appear in dead code
      // after constant propagation.
      auto constDim = fir::getIntIfConstant(getDim());
      if (!constDim)
        return rewriter.notifyMatchFailure(op, "Nonconstant DIM");
      dimVal = *constDim;

      if ((dimVal <= 0 || dimVal > getSourceRank()))
        return rewriter.notifyMatchFailure(op,
                                           "Invalid DIM for partial reduction");
    }
    return dimVal;
  }

  /// Return hlfir::Entity of the result.
  hlfir::Entity getResultEntity() const {
    return hlfir::Entity{op->getResult(0)};
  }

  /// Return type of the result (e.g. !hlfir.expr<?xi32>).
  mlir::Type getResultType() const { return getResultEntity().getType(); }

  /// Return the element type of the result (e.g. i32).
  mlir::Type getResultElementType() const {
    return hlfir::getFortranElementType(getResultType());
  }

  /// Return rank of the result.
  unsigned getResultRank() const { return getResultEntity().getRank(); }

  /// Return the element type of the source.
  mlir::Type getSourceElementType() const {
    return hlfir::getFortranElementType(getSource().getType());
  }

  /// Return rank of the input array.
  unsigned getSourceRank() const {
    return hlfir::Entity{getSource()}.getRank();
  }

  /// The reduction operation.
  mlir::Operation *op;

  mlir::PatternRewriter &rewriter;
  mlir::Location loc;
  fir::FirOpBuilder builder;
};

/// Generate initialization value for MIN or MAX reduction
/// of the given \p type.
template <bool IS_MAX>
static mlir::Value genMinMaxInitValue(mlir::Location loc,
                                      fir::FirOpBuilder &builder,
                                      mlir::Type type) {
  if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
    const llvm::fltSemantics &sem = ty.getFloatSemantics();
    // We must not use +/-INF here. If the reduction input is empty,
    // the result of reduction must be +/-LARGEST.
    llvm::APFloat limit = llvm::APFloat::getLargest(sem, /*Negative=*/IS_MAX);
    return builder.createRealConstant(loc, type, limit);
  }
  unsigned bits = type.getIntOrFloatBitWidth();
  int64_t limitInt = IS_MAX
                         ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
                         : llvm::APInt::getSignedMaxValue(bits).getSExtValue();
  return builder.createIntegerConstant(loc, type, limitInt);
}

/// Generate a comparison of an array element value \p elem
/// and the current reduction value \p reduction for MIN/MAX reduction.
template <bool IS_MAX>
static mlir::Value
genMinMaxComparison(mlir::Location loc, fir::FirOpBuilder &builder,
                    mlir::Value elem, mlir::Value reduction) {
  if (mlir::isa<mlir::FloatType>(reduction.getType())) {
    // For FP reductions we want the first smallest value to be used, that
    // is not NaN. An OGT/OLT condition will usually work for this unless all
    // the values are NaN or Inf. This follows the same logic as
    // NumericCompare for Minloc/Maxloc in extrema.cpp.
    mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
        loc,
        IS_MAX ? mlir::arith::CmpFPredicate::OGT
               : mlir::arith::CmpFPredicate::OLT,
        elem, reduction);
    mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
        loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
    mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
        loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
    cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
    return builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
  } else if (mlir::isa<mlir::IntegerType>(reduction.getType())) {
    return builder.create<mlir::arith::CmpIOp>(
        loc,
        IS_MAX ? mlir::arith::CmpIPredicate::sgt
               : mlir::arith::CmpIPredicate::slt,
        elem, reduction);
  }
  llvm_unreachable("unsupported type");
}

// Generate a predicate value indicating that an array with the given
// extents is not empty.
static mlir::Value
genIsNotEmptyArrayExtents(mlir::Location loc, fir::FirOpBuilder &builder,
                          const llvm::SmallVectorImpl<mlir::Value> &extents) {
  mlir::Value isNotEmpty = builder.createBool(loc, true);
  for (auto extent : extents) {
    mlir::Value zero =
        fir::factory::createZeroValue(builder, loc, extent.getType());
    mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::ne, extent, zero);
    isNotEmpty = builder.create<mlir::arith::AndIOp>(loc, isNotEmpty, cmp);
  }
  return isNotEmpty;
}

// Helper method for MIN/MAX LOC/VAL reductions.
// It returns a vector of indices such that they address
// the first element of an array (in case of total reduction)
// or its section (in case of partial reduction).
//
// In case of total reduction oneBasedIndices must be empty,
// otherwise, they contain the one-based indices of the wrapping
// hlfir.elemental.
// Basically, the method adds the necessary number of constant-one
// indices into oneBasedIndices.
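//
// For example (illustrative only): for a total reduction of a rank-3 array
// the result is {1, 1, 1}; for a partial reduction with DIM=2 and
// oneBasedIndices = {i, k} the result is {i, 1, k}.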
static llvm::SmallVector<mlir::Value> genFirstElementIndicesForReduction(
    mlir::Location loc, fir::FirOpBuilder &builder, bool isTotalReduction,
    mlir::FailureOr<int64_t> dim, unsigned rank,
    mlir::ValueRange oneBasedIndices) {
  llvm::SmallVector<mlir::Value> indices{oneBasedIndices};
  mlir::Value one =
      builder.createIntegerConstant(loc, builder.getIndexType(), 1);
  if (isTotalReduction) {
    assert(oneBasedIndices.size() == 0 &&
           "wrong number of indices for total reduction");
    // Set indices to all-ones.
    indices.append(rank, one);
  } else {
    assert(oneBasedIndices.size() == rank - 1 &&
           "there must be RANK-1 indices for partial reduction");
    assert(mlir::succeeded(dim) && "partial reduction with invalid DIM");
    // Insert constant-one index at DIM dimension.
    indices.insert(indices.begin() + *dim - 1, one);
  }
  return indices;
}

/// Implementation of ReductionAsElementalConverter interface
/// for MAXLOC/MINLOC.
template <typename T>
class MinMaxlocAsElementalConverter : public ReductionAsElementalConverter {
  static_assert(std::is_same_v<T, hlfir::MaxlocOp> ||
                std::is_same_v<T, hlfir::MinlocOp>);
  static constexpr unsigned maxRank = Fortran::common::maxRank;
  // We have the following reduction values in the reduction loop:
  //   * N integer coordinates, where N is:
  //     - RANK(ARRAY) for total reductions.
  //     - 1 for partial reductions.
  //   * 1 reduction value holding the current MIN/MAX.
  //   * 1 boolean indicating whether it is the first time
  //     the mask is true.
  //
  // If useIsFirst() returns false, then the boolean loop-carried
  // value is not used.
  static constexpr unsigned maxNumReductions = Fortran::common::maxRank + 2;
  static constexpr bool isMax = std::is_same_v<T, hlfir::MaxlocOp>;
  using Base = ReductionAsElementalConverter;

public:
  MinMaxlocAsElementalConverter(T op, mlir::PatternRewriter &rewriter)
      : Base{op.getOperation(), rewriter} {}

private:
  virtual mlir::Value getSource() const final { return getOp().getArray(); }
  virtual mlir::Value getDim() const final { return getOp().getDim(); }
  virtual mlir::Value getMask() const final { return getOp().getMask(); }
  virtual mlir::arith::FastMathFlags getFastMath() const final {
    return getOp().getFastmath();
  }

  virtual mlir::LogicalResult isConvertible() const final {
    if (getOp().getBack())
      return rewriter.notifyMatchFailure(
          getOp(), "BACK is not supported for MINLOC/MAXLOC inlining");
    if (mlir::isa<fir::CharacterType>(getSourceElementType()))
      return rewriter.notifyMatchFailure(
          getOp(),
          "CHARACTER type is not supported for MINLOC/MAXLOC inlining");
    return mlir::success();
  }

  // If the result is scalar, then DIM does not matter,
  // and this is a total reduction.
  // If DIM is not present, this is a total reduction.
  virtual bool isTotalReduction() const final {
    return getResultRank() == 0 || !getDim();
  }

  virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
      mlir::ValueRange oneBasedIndices,
      const llvm::SmallVectorImpl<mlir::Value> &extents) final;
  virtual llvm::SmallVector<mlir::Value>
  reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
                   hlfir::Entity array, mlir::ValueRange oneBasedIndices) final;
  virtual hlfir::Entity genFinalResult(
      const llvm::SmallVectorImpl<mlir::Value> &reductionResults) final;

private:
  T getOp() const { return mlir::cast<T>(op); }

  unsigned getNumCoors() const {
    return isTotalReduction() ? getSourceRank() : 1;
  }

  void
  checkReductions(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
    if (!useIsFirst())
      assert(reductions.size() == getNumCoors() + 1 &&
             "invalid number of reductions for MINLOC/MAXLOC");
    else
      assert(reductions.size() == getNumCoors() + 2 &&
             "invalid number of reductions for MINLOC/MAXLOC");
  }

  mlir::Value
  getCurrentMinMax(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
    checkReductions(reductions);
    return reductions[getNumCoors()];
  }

  mlir::Value
  getIsFirst(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
    checkReductions(reductions);
    assert(useIsFirst() && "IsFirst predicate must not be used");
    return reductions[getNumCoors() + 1];
  }

  // Return true iff the input can contain NaNs, and they should be
  // honored, such that all-NaNs input must produce the location
  // of the first unmasked NaN.
  bool honorNans() const {
    return !static_cast<bool>(getFastMath() & mlir::arith::FastMathFlags::nnan);
  }

  // Return true iff we have to use the loop-carried IsFirst predicate.
  // If there is no mask, we can initialize the reductions using
  // the first elements of the input.
  // If NaNs are not honored, we can initialize the starting MIN/MAX
  // value to +/-LARGEST; the coordinates are guaranteed to be updated
  // properly for non-empty input without NaNs.
  bool useIsFirst() const { return getMask() && honorNans(); }
};

template <typename T>
llvm::SmallVector<mlir::Value>
MinMaxlocAsElementalConverter<T>::genReductionInitValues(
    mlir::ValueRange oneBasedIndices,
    const llvm::SmallVectorImpl<mlir::Value> &extents) {
  fir::IfOp ifOp;
  if (!useIsFirst() && honorNans()) {
    // Check if we can load the value of the first element in the array
    // or its section (for partial reduction).
    assert(!getMask() && "cannot fetch first element when mask is present");
    assert(extents.size() == getNumCoors() &&
           "wrong number of extents for MINLOC/MAXLOC reduction");
    mlir::Value isNotEmpty = genIsNotEmptyArrayExtents(loc, builder, extents);

    llvm::SmallVector<mlir::Value> indices = genFirstElementIndicesForReduction(
        loc, builder, isTotalReduction(), getConstDim(), getSourceRank(),
        oneBasedIndices);

    llvm::SmallVector<mlir::Type> ifTypes(getNumCoors(),
                                          getResultElementType());
    ifTypes.push_back(getSourceElementType());
    ifOp = builder.create<fir::IfOp>(loc, ifTypes, isNotEmpty,
                                     /*withElseRegion=*/true);
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    mlir::Value one =
        builder.createIntegerConstant(loc, getResultElementType(), 1);
    llvm::SmallVector<mlir::Value> results(getNumCoors(), one);
    mlir::Value minMaxFirst =
        hlfir::loadElementAt(loc, builder, hlfir::Entity{getSource()}, indices);
    results.push_back(minMaxFirst);
    builder.create<fir::ResultOp>(loc, results);

    // In the 'else' block use default init values.
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  }

  // Initial value for the coordinate(s) is zero.
  mlir::Value zeroCoor =
      fir::factory::createZeroValue(builder, loc, getResultElementType());
  llvm::SmallVector<mlir::Value> result(getNumCoors(), zeroCoor);

  // Initial value for the MIN/MAX value.
  mlir::Value minMaxInit =
      genMinMaxInitValue<isMax>(loc, builder, getSourceElementType());
  result.push_back(minMaxInit);

  if (ifOp) {
    builder.create<fir::ResultOp>(loc, result);
    builder.setInsertionPointAfter(ifOp);
    result = ifOp.getResults();
  } else if (useIsFirst()) {
    // Initial value for isFirst predicate. It is switched to false,
    // when the reduction update dynamically happens inside the reduction
    // loop.
    mlir::Value trueVal = builder.createBool(loc, true);
    result.push_back(trueVal);
  }

  return result;
}

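// Reduce one element for MINLOC/MAXLOC: load the element, compare it against
// the current MIN/MAX (forcing the update when the loop-carried isFirst
// predicate is set), and conditionally update the stored coordinate(s) and
// the MIN/MAX value via arith.select.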
template <typename T>
llvm::SmallVector<mlir::Value>
MinMaxlocAsElementalConverter<T>::reduceOneElement(
    const llvm::SmallVectorImpl<mlir::Value> &currentValue, hlfir::Entity array,
    mlir::ValueRange oneBasedIndices) {
  checkReductions(currentValue);
  hlfir::Entity elementValue =
      hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
  mlir::Value cmp = genMinMaxComparison<isMax>(loc, builder, elementValue,
                                               getCurrentMinMax(currentValue));
  if (useIsFirst()) {
    // If isFirst is true, then do the reduction update regardless
    // of the FP comparison.
    cmp =
        builder.create<mlir::arith::OrIOp>(loc, cmp, getIsFirst(currentValue));
  }

  llvm::SmallVector<mlir::Value> newIndices;
  int64_t dim = 1;
  if (!isTotalReduction()) {
    auto dimVal = getConstDim();
    assert(mlir::succeeded(dimVal) &&
           "partial MINLOC/MAXLOC reduction with invalid DIM");
    dim = *dimVal;
    assert(getNumCoors() == 1 &&
           "partial MAXLOC/MINLOC reduction must compute one coordinate");
  }

  for (unsigned coorIdx = 0; coorIdx < getNumCoors(); ++coorIdx) {
    mlir::Value currentCoor = currentValue[coorIdx];
    mlir::Value newCoor = builder.createConvert(
        loc, currentCoor.getType(), oneBasedIndices[coorIdx + dim - 1]);
    mlir::Value update =
        builder.create<mlir::arith::SelectOp>(loc, cmp, newCoor, currentCoor);
    newIndices.push_back(update);
  }

  mlir::Value newMinMax = builder.create<mlir::arith::SelectOp>(
      loc, cmp, elementValue, getCurrentMinMax(currentValue));
  newIndices.push_back(newMinMax);

  if (useIsFirst()) {
    mlir::Value newIsFirst = builder.createBool(loc, false);
    newIndices.push_back(newIsFirst);
  }

  assert(currentValue.size() == newIndices.size() &&
         "invalid number of updated reductions");

  return newIndices;
}

template <typename T>
hlfir::Entity MinMaxlocAsElementalConverter<T>::genFinalResult(
    const llvm::SmallVectorImpl<mlir::Value> &reductionResults) {
  // Identification of the final result of MINLOC/MAXLOC:
  //   * If DIM is absent, the result is a rank-one array.
  //   * If DIM is present:
  //     - The result is scalar for rank-one input.
  //     - The result is an array of rank RANK(ARRAY)-1.
  checkReductions(reductionResults);

  // 16.9.137 & 16.9.143:
  // The subscripts returned by MINLOC/MAXLOC are in the range
  // 1 to the extent of the corresponding dimension.
  mlir::Type indexType = builder.getIndexType();

  // For partial reductions, the final result of the reduction
  // loop is just a scalar - the coordinate within the DIM dimension.
  if (getResultRank() == 0 || !isTotalReduction()) {
    // The result is a scalar, so just return the scalar.
    assert(getNumCoors() == 1 &&
           "unexpected number of coordinates for scalar result");
    return hlfir::Entity{reductionResults[0]};
  }
  // This is a total reduction, and there is no wrapping hlfir.elemental.
  // We have to pack the reduced coordinates into a rank-one array.
  unsigned rank = getSourceRank();
  // TODO: in order to avoid introducing new memory effects
  // we should not use a temporary in memory.
  // We can use hlfir.elemental with a switch to pack all the coordinates
  // into an array expression, or we can have a dedicated HLFIR operation
  // for this.
  mlir::Value tempArray = builder.createTemporary(
      loc, fir::SequenceType::get(rank, getResultElementType()));
  for (unsigned i = 0; i < rank; ++i) {
    mlir::Value coor = reductionResults[i];
    mlir::Value idx = builder.createIntegerConstant(loc, indexType, i + 1);
    mlir::Value resultElement =
        hlfir::getElementAt(loc, builder, hlfir::Entity{tempArray}, {idx});
    builder.create<hlfir::AssignOp>(loc, coor, resultElement);
  }
  mlir::Value tempExpr = builder.create<hlfir::AsExprOp>(
      loc, tempArray, builder.createBool(loc, false));
  return hlfir::Entity{tempExpr};
}

/// Base class for numeric reductions like MAXVAL, MINVAL, SUM.
template <typename OpT>
class NumericReductionAsElementalConverterBase
    : public ReductionAsElementalConverter {
  using Base = ReductionAsElementalConverter;

protected:
  NumericReductionAsElementalConverterBase(OpT op,
                                           mlir::PatternRewriter &rewriter)
      : Base{op.getOperation(), rewriter} {}

  virtual mlir::Value getSource() const final { return getOp().getArray(); }
  virtual mlir::Value getDim() const final { return getOp().getDim(); }
  virtual mlir::Value getMask() const final { return getOp().getMask(); }
  virtual mlir::arith::FastMathFlags getFastMath() const final {
    return getOp().getFastmath();
  }

  OpT getOp() const { return mlir::cast<OpT>(op); }

  void checkReductions(const llvm::SmallVectorImpl<mlir::Value> &reductions) {
    assert(reductions.size() == 1 && "reduction must produce single value");
  }
};

/// Reduction converter for MAXVAL/MINVAL.
template <typename T>
class MinMaxvalAsElementalConverter
    : public NumericReductionAsElementalConverterBase<T> {
  static_assert(std::is_same_v<T, hlfir::MaxvalOp> ||
                std::is_same_v<T, hlfir::MinvalOp>);
  // We have two reduction values:
  //   * The current MIN/MAX value.
  //   * 1 boolean indicating whether it is the first time
  //     the mask is true.
  //
  // The boolean flag is used to replace the initial value
  // with the first input element even if it is NaN.
  // If useIsFirst() returns false, then the boolean loop-carried
  // value is not used.
  static constexpr bool isMax = std::is_same_v<T, hlfir::MaxvalOp>;
  using Base = NumericReductionAsElementalConverterBase<T>;

public:
  MinMaxvalAsElementalConverter(T op, mlir::PatternRewriter &rewriter)
      : Base{op, rewriter} {}

private:
  virtual mlir::LogicalResult isConvertible() const final {
    if (mlir::isa<fir::CharacterType>(this->getSourceElementType()))
      return this->rewriter.notifyMatchFailure(
          this->getOp(),
          "CHARACTER type is not supported for MINVAL/MAXVAL inlining");
    return mlir::success();
  }

  virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
      mlir::ValueRange oneBasedIndices,
      const llvm::SmallVectorImpl<mlir::Value> &extents) final;

  virtual llvm::SmallVector<mlir::Value>
  reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
                   hlfir::Entity array,
                   mlir::ValueRange oneBasedIndices) final {
    this->checkReductions(currentValue);
    llvm::SmallVector<mlir::Value> result;
    fir::FirOpBuilder &builder = this->builder;
    mlir::Location loc = this->loc;
    hlfir::Entity elementValue =
        hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
    mlir::Value currentMinMax = getCurrentMinMax(currentValue);
    mlir::Value cmp =
        genMinMaxComparison<isMax>(loc, builder, elementValue, currentMinMax);
    if (useIsFirst())
      cmp = builder.create<mlir::arith::OrIOp>(loc, cmp,
                                               getIsFirst(currentValue));
    mlir::Value newMinMax = builder.create<mlir::arith::SelectOp>(
        loc, cmp, elementValue, currentMinMax);
    result.push_back(newMinMax);
    if (useIsFirst())
      result.push_back(builder.createBool(loc, false));
    return result;
  }

  virtual hlfir::Entity genFinalResult(
      const llvm::SmallVectorImpl<mlir::Value> &reductionResults) final {
    this->checkReductions(reductionResults);
    return hlfir::Entity{getCurrentMinMax(reductionResults)};
  }

  void
  checkReductions(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
    assert(reductions.size() == getNumReductions() &&
           "invalid number of reductions for MINVAL/MAXVAL");
  }

  mlir::Value
  getCurrentMinMax(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
    this->checkReductions(reductions);
    return reductions[0];
  }

  mlir::Value
  getIsFirst(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
    this->checkReductions(reductions);
    assert(useIsFirst() && "IsFirst predicate must not be used");
    return reductions[1];
  }

  // Return true iff the input can contain NaNs, and they should be
  // honored, such that all-NaNs input must produce a NaN result.
  bool honorNans() const {
    return !static_cast<bool>(this->getFastMath() &
                              mlir::arith::FastMathFlags::nnan);
  }

  // Return true iff we have to use the loop-carried IsFirst predicate.
  // If there is no mask, we can initialize the reductions using
  // the first elements of the input.
  // If NaNs are not honored, we can initialize the starting MIN/MAX
  // value to +/-LARGEST.
  bool useIsFirst() const { return this->getMask() && honorNans(); }

  std::size_t getNumReductions() const { return useIsFirst() ? 2 : 1; }
};

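// Generate the initial reduction value(s) for MINVAL/MAXVAL. When NaNs are
// honored and there is no mask, the first element of the array (or of its
// DIM section) is loaded under a non-emptiness check and used as the initial
// value; otherwise, the +/-LARGEST constant is used, plus the initial 'true'
// isFirst predicate when it is needed.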
template <typename T>
llvm::SmallVector<mlir::Value>
MinMaxvalAsElementalConverter<T>::genReductionInitValues(
    mlir::ValueRange oneBasedIndices,
    const llvm::SmallVectorImpl<mlir::Value> &extents) {
  llvm::SmallVector<mlir::Value> result;
  fir::FirOpBuilder &builder = this->builder;
  mlir::Location loc = this->loc;

  fir::IfOp ifOp;
  if (!useIsFirst() && honorNans()) {
    // Check if we can load the value of the first element in the array
    // or its section (for partial reduction).
    assert(!this->getMask() &&
           "cannot fetch first element when mask is present");
    assert(extents.size() ==
               (this->isTotalReduction() ? this->getSourceRank() : 1u) &&
           "wrong number of extents for MINVAL/MAXVAL reduction");
    mlir::Value isNotEmpty = genIsNotEmptyArrayExtents(loc, builder, extents);
    llvm::SmallVector<mlir::Value> indices = genFirstElementIndicesForReduction(
        loc, builder, this->isTotalReduction(), this->getConstDim(),
        this->getSourceRank(), oneBasedIndices);

    ifOp =
        builder.create<fir::IfOp>(loc, this->getResultElementType(), isNotEmpty,
                                  /*withElseRegion=*/true);
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    mlir::Value minMaxFirst = hlfir::loadElementAt(
        loc, builder, hlfir::Entity{this->getSource()}, indices);
    builder.create<fir::ResultOp>(loc, minMaxFirst);

    // In the 'else' block use default init values.
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  }

  mlir::Value init =
      genMinMaxInitValue<isMax>(loc, builder, this->getResultElementType());
  result.push_back(init);

  if (ifOp) {
    builder.create<fir::ResultOp>(loc, result);
    builder.setInsertionPointAfter(ifOp);
    result = ifOp.getResults();
  } else if (useIsFirst()) {
    // Initial value for isFirst predicate. It is switched to false,
    // when the reduction update dynamically happens inside the reduction
    // loop.
    result.push_back(builder.createBool(loc, true));
  }

  return result;
}

/// Reduction converter for SUM.
class SumAsElementalConverter
    : public NumericReductionAsElementalConverterBase<hlfir::SumOp> {
  using Base = NumericReductionAsElementalConverterBase;

public:
  SumAsElementalConverter(hlfir::SumOp op, mlir::PatternRewriter &rewriter)
      : Base{op, rewriter} {}

private:
  virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
      [[maybe_unused]] mlir::ValueRange oneBasedIndices,
      [[maybe_unused]] const llvm::SmallVectorImpl<mlir::Value> &extents)
      final {
    return {
        fir::factory::createZeroValue(builder, loc, getResultElementType())};
  }
  virtual llvm::SmallVector<mlir::Value>
  reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
                   hlfir::Entity array,
                   mlir::ValueRange oneBasedIndices) final {
    checkReductions(currentValue);
    hlfir::Entity elementValue =
        hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
    // NOTE: we can use "Kahan summation" the same way as the runtime
    // (e.g. when fast-math is not allowed), but let's start with
    // the simple version.
    return {genScalarAdd(currentValue[0], elementValue)};
  }

  // Generate scalar addition of the two values (of the same data type).
  mlir::Value genScalarAdd(mlir::Value value1, mlir::Value value2);
};

/// Base class for logical reductions like ALL, ANY, COUNT.
/// They do not have MASK and FastMathFlags.
template <typename OpT>
class LogicalReductionAsElementalConverterBase
    : public ReductionAsElementalConverter {
  using Base = ReductionAsElementalConverter;

public:
  LogicalReductionAsElementalConverterBase(OpT op,
                                           mlir::PatternRewriter &rewriter)
      : Base{op.getOperation(), rewriter} {}

protected:
  OpT getOp() const { return mlir::cast<OpT>(op); }

  void checkReductions(const llvm::SmallVectorImpl<mlir::Value> &reductions) {
    assert(reductions.size() == 1 && "reduction must produce single value");
  }

  virtual mlir::Value getSource() const final { return getOp().getMask(); }
  virtual mlir::Value getDim() const final { return getOp().getDim(); }

  virtual hlfir::Entity genFinalResult(
      const llvm::SmallVectorImpl<mlir::Value> &reductionResults) override {
    checkReductions(reductionResults);
    return hlfir::Entity{reductionResults[0]};
  }
};

/// Reduction converter for ALL/ANY.
template <typename T>
class AllAnyAsElementalConverter
    : public LogicalReductionAsElementalConverterBase<T> {
  static_assert(std::is_same_v<T, hlfir::AllOp> ||
                std::is_same_v<T, hlfir::AnyOp>);
  static constexpr bool isAll = std::is_same_v<T, hlfir::AllOp>;
  using Base = LogicalReductionAsElementalConverterBase<T>;

public:
  AllAnyAsElementalConverter(T op, mlir::PatternRewriter &rewriter)
      : Base{op, rewriter} {}

private:
  virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
      [[maybe_unused]] mlir::ValueRange oneBasedIndices,
      [[maybe_unused]] const llvm::SmallVectorImpl<mlir::Value> &extents)
      final {
    return {this->builder.createBool(this->loc, isAll ? true : false)};
  }
  virtual llvm::SmallVector<mlir::Value>
  reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
                   hlfir::Entity array,
                   mlir::ValueRange oneBasedIndices) final {
    this->checkReductions(currentValue);
    fir::FirOpBuilder &builder = this->builder;
    mlir::Location loc = this->loc;
    hlfir::Entity elementValue =
        hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
    mlir::Value mask =
        builder.createConvert(loc, builder.getI1Type(), elementValue);
    if constexpr (isAll)
      return {builder.create<mlir::arith::AndIOp>(loc, mask, currentValue[0])};
    else
      return {builder.create<mlir::arith::OrIOp>(loc, mask, currentValue[0])};
  }

  virtual hlfir::Entity genFinalResult(
      const llvm::SmallVectorImpl<mlir::Value> &reductionValues) final {
    this->checkReductions(reductionValues);
    return hlfir::Entity{this->builder.createConvert(
        this->loc, this->getResultElementType(), reductionValues[0])};
  }
};

/// Reduction converter for COUNT.
class CountAsElementalConverter
    : public LogicalReductionAsElementalConverterBase<hlfir::CountOp> {
  using Base = LogicalReductionAsElementalConverterBase<hlfir::CountOp>;

public:
  CountAsElementalConverter(hlfir::CountOp op, mlir::PatternRewriter &rewriter)
      : Base{op, rewriter} {}

private:
  virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
      [[maybe_unused]] mlir::ValueRange oneBasedIndices,
      [[maybe_unused]] const llvm::SmallVectorImpl<mlir::Value> &extents)
      final {
    return {
        fir::factory::createZeroValue(builder, loc, getResultElementType())};
  }
  virtual llvm::SmallVector<mlir::Value>
  reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
                   hlfir::Entity array,
                   mlir::ValueRange oneBasedIndices) final {
    checkReductions(currentValue);
    hlfir::Entity elementValue =
        hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
    mlir::Value cond =
        builder.createConvert(loc, builder.getI1Type(), elementValue);
    mlir::Value one =
        builder.createIntegerConstant(loc, getResultElementType(), 1);
    mlir::Value add1 =
        builder.create<mlir::arith::AddIOp>(loc, currentValue[0], one);
    return {builder.create<mlir::arith::SelectOp>(loc, cond, add1,
                                                  currentValue[0])};
  }
};

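// Drive the conversion. For a total reduction the reduction loop nest is
// generated directly and its final value replaces the operation; for a
// partial reduction the single reduction loop over the DIM dimension is
// wrapped into an hlfir.elemental that produces the result array.
// Schematically (mask handling omitted), the partial reduction computes:
//   RESULT(i_1, ..., i_(DIM-1), i_(DIM+1), ..., i_n) =
//       REDUCE over k of ARRAY(i_1, ..., i_(DIM-1), k, i_(DIM+1), ..., i_n)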
mlir::LogicalResult ReductionAsElementalConverter::convert() {
  mlir::LogicalResult canConvert(isConvertible());

  if (mlir::failed(canConvert))
    return canConvert;

  hlfir::Entity array = hlfir::Entity{getSource()};
  bool isTotalReduce = isTotalReduction();
  auto dimVal = getConstDim();
  if (mlir::failed(dimVal))
    return dimVal;
  mlir::Value mask = getMask();
  mlir::Value resultShape, dimExtent;
  llvm::SmallVector<mlir::Value> arrayExtents;
  if (isTotalReduce)
    arrayExtents = hlfir::genExtentsVector(loc, builder, array);
  else
    std::tie(resultShape, dimExtent) =
        genResultShapeForPartialReduction(array, *dimVal);

  // If the mask is present and is a scalar, then we'd better load its value
  // outside of the reduction loop to make loop unswitching easier.
  mlir::Value isPresentPred, maskValue;
  if (mask) {
    if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
      // MASK represented by a box might be dynamically optional,
      // so we have to check for its presence before accessing it.
      isPresentPred =
          builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
    }

    if (hlfir::Entity{mask}.isScalar())
      maskValue = genMaskValue(mask, isPresentPred, {});
  }

  auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                       mlir::ValueRange inputIndices) -> hlfir::Entity {
    // Loop over all indices in the DIM dimension, and reduce all values.
    // If DIM is not present, do a total reduction.

    llvm::SmallVector<mlir::Value> extents;
    if (isTotalReduce)
      extents = arrayExtents;
    else
      extents.push_back(
          builder.createConvert(loc, builder.getIndexType(), dimExtent));

    // Initial value for the reduction.
    llvm::SmallVector<mlir::Value, 1> reductionInitValues =
        genReductionInitValues(inputIndices, extents);

    auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                       mlir::ValueRange oneBasedIndices,
                       mlir::ValueRange reductionArgs)
        -> llvm::SmallVector<mlir::Value, 1> {
      // Generate the reduction loop-nest body.
      // The initial reduction value in the innermost loop
      // is passed via reductionArgs[0].
      llvm::SmallVector<mlir::Value> indices;
      if (isTotalReduce) {
        indices = oneBasedIndices;
      } else {
        indices = inputIndices;
        indices.insert(indices.begin() + *dimVal - 1, oneBasedIndices[0]);
      }

      llvm::SmallVector<mlir::Value, 1> reductionValues = reductionArgs;
      llvm::SmallVector<mlir::Type, 1> reductionTypes;
      llvm::transform(reductionValues, std::back_inserter(reductionTypes),
                      [](mlir::Value v) { return v.getType(); });
      fir::IfOp ifOp;
      if (mask) {
        // Make the reduction value update conditional on the value
        // of the mask.
        if (!maskValue) {
          // If the mask is an array, use the elemental and the loop indices
          // to address the proper mask element.
          maskValue = genMaskValue(mask, isPresentPred, indices);
        }
        mlir::Value isUnmasked =
            builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
        ifOp = builder.create<fir::IfOp>(loc, reductionTypes, isUnmasked,
                                         /*withElseRegion=*/true);
        // In the 'else' block return the current reduction value.
        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
        builder.create<fir::ResultOp>(loc, reductionValues);

        // In the 'then' block do the actual addition.
        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
      }
      reductionValues = reduceOneElement(reductionValues, array, indices);
      if (ifOp) {
        builder.create<fir::ResultOp>(loc, reductionValues);
        builder.setInsertionPointAfter(ifOp);
        reductionValues = ifOp.getResults();
      }

      return reductionValues;
    };

    llvm::SmallVector<mlir::Value, 1> reductionFinalValues =
        hlfir::genLoopNestWithReductions(
            loc, builder, extents, reductionInitValues, genBody, isUnordered());
    return genFinalResult(reductionFinalValues);
  };

  if (isTotalReduce) {
    hlfir::Entity result = genKernel(loc, builder, mlir::ValueRange{});
    rewriter.replaceOp(op, result);
    return mlir::success();
  }

  hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
      loc, builder, getResultElementType(), resultShape, /*typeParams=*/{},
      genKernel,
      /*isUnordered=*/true, /*polymorphicMold=*/nullptr, getResultType());

  // it wouldn't be safe to replace block arguments with a different
  // hlfir.expr type. Types can differ due to differing amounts of shape
  // information
  assert(elementalOp.getResult().getType() == op->getResult(0).getType());

  rewriter.replaceOp(op, elementalOp);
  return mlir::success();
}

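// For example (illustrative only): for a rank-3 array with extents
// (E1, E2, E3) and DIM=2, the returned shape is fir.shape(E1, E3) and the
// returned DIM extent is E2.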
std::tuple<mlir::Value, mlir::Value>
ReductionAsElementalConverter::genResultShapeForPartialReduction(
    hlfir::Entity array, int64_t dimVal) {
  llvm::SmallVector<mlir::Value> inExtents =
      hlfir::genExtentsVector(loc, builder, array);
  assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
         "DIM must be present and a positive constant not exceeding "
         "the array's rank");

  mlir::Value dimExtent = inExtents[dimVal - 1];
  inExtents.erase(inExtents.begin() + dimVal - 1);
  return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
}

mlir::Value SumAsElementalConverter::genScalarAdd(mlir::Value value1,
                                                  mlir::Value value2) {
  mlir::Type ty = value1.getType();
  assert(ty == value2.getType() && "reduction values' types do not match");
  if (mlir::isa<mlir::FloatType>(ty))
    return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
  else if (mlir::isa<mlir::ComplexType>(ty))
    return builder.create<fir::AddcOp>(loc, value1, value2);
  else if (mlir::isa<mlir::IntegerType>(ty))
    return builder.create<mlir::arith::AddIOp>(loc, value1, value2);

  llvm_unreachable("unsupported SUM reduction type");
}

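// Compute the scalar mask value guarding one reduction update (see the
// declaration comment above). When the mask variable may be dynamically
// absent, the load is wrapped into a fir.if that yields 'true' for the
// absent case.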
mlir::Value ReductionAsElementalConverter::genMaskValue(
    mlir::Value mask, mlir::Value isPresentPred, mlir::ValueRange indices) {
  mlir::OpBuilder::InsertionGuard guard(builder);
  fir::IfOp ifOp;
  mlir::Type maskType =
      hlfir::getFortranElementType(fir::unwrapPassByRefType(mask.getType()));
  if (isPresentPred) {
    ifOp = builder.create<fir::IfOp>(loc, maskType, isPresentPred,
                                     /*withElseRegion=*/true);

    // Use 'true', if the mask is not present.
    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
    mlir::Value trueValue = builder.createBool(loc, true);
    trueValue = builder.createConvert(loc, maskType, trueValue);
    builder.create<fir::ResultOp>(loc, trueValue);

    // Load the mask value, if the mask is present.
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  }

  hlfir::Entity maskVar{mask};
  if (maskVar.isScalar()) {
    if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
      // MASK may be a boxed scalar.
      mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, maskVar);
      mask = builder.create<fir::LoadOp>(loc, hlfir::Entity{addr});
    } else {
      mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
    }
  } else {
    // Load from the mask array.
    assert(!indices.empty() && "no indices for addressing the mask array");
    maskVar = hlfir::getElementAt(loc, builder, maskVar, indices);
    mask = hlfir::loadTrivialScalar(loc, builder, maskVar);
  }

  if (!isPresentPred)
    return mask;

  builder.create<fir::ResultOp>(loc, mask);
  return ifOp.getResult(0);
}

/// Convert an operation that is a partial or total reduction
/// over an array of values into a reduction loop[-nest]
/// optionally wrapped into hlfir.elemental.
template <typename Op>
class ReductionConversion : public mlir::OpRewritePattern<Op> {
public:
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
    if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
                  std::is_same_v<Op, hlfir::MinlocOp>) {
      MinMaxlocAsElementalConverter<Op> converter(op, rewriter);
      return converter.convert();
    } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
                         std::is_same_v<Op, hlfir::MinvalOp>) {
      MinMaxvalAsElementalConverter<Op> converter(op, rewriter);
      return converter.convert();
    } else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
      CountAsElementalConverter converter(op, rewriter);
      return converter.convert();
    } else if constexpr (std::is_same_v<Op, hlfir::AllOp> ||
                         std::is_same_v<Op, hlfir::AnyOp>) {
      AllAnyAsElementalConverter<Op> converter(op, rewriter);
      return converter.convert();
    } else if constexpr (std::is_same_v<Op, hlfir::SumOp>) {
      SumAsElementalConverter converter{op, rewriter};
      return converter.convert();
    }
    return rewriter.notifyMatchFailure(op, "unexpected reduction operation");
  }
};

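/// Convert hlfir.cshift into either an hlfir.elemental or an
/// hlfir.eval_in_mem based implementation. Conceptually, a circular shift
/// along a dimension of extent N by SH computes (in 1D notation):
///   RESULT(i) = ARRAY(1 + MODULO(i - 1 + SH, N))
/// This is a schematic description of the intrinsic's semantics, not of the
/// exact IR produced below.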
1273 | class CShiftConversion : public mlir::OpRewritePattern<hlfir::CShiftOp> { |
1274 | public: |
1275 | using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern; |
1276 | |
1277 | llvm::LogicalResult |
1278 | matchAndRewrite(hlfir::CShiftOp cshift, |
1279 | mlir::PatternRewriter &rewriter) const override { |
1280 | |
1281 | hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType()); |
    assert(expr &&
           "expected an expression type for the result of hlfir.cshift");
1284 | unsigned arrayRank = expr.getRank(); |
    // For a 1D CSHIFT, we may assume that the DIM argument
    // (whether present or absent) is equal to 1; otherwise,
    // the program is illegal.
1288 | int64_t dimVal = 1; |
1289 | if (arrayRank != 1) |
1290 | if (mlir::Value dim = cshift.getDim()) { |
1291 | auto constDim = fir::getIntIfConstant(dim); |
1292 | if (!constDim) |
          return rewriter.notifyMatchFailure(cshift,
                                             "Nonconstant DIM for CSHIFT");
1295 | dimVal = *constDim; |
1296 | } |
1297 | |
1298 | if (dimVal <= 0 || dimVal > arrayRank) |
      return rewriter.notifyMatchFailure(cshift, "Invalid DIM for CSHIFT");
1300 | |
1301 | // When DIM==1 and the contiguity of the input array is not statically |
1302 | // known, try to exploit the fact that the leading dimension might be |
1303 | // contiguous. We can do this now using hlfir.eval_in_mem with |
1304 | // a dynamic check for the leading dimension contiguity. |
1305 | // Otherwise, convert hlfir.cshift to hlfir.elemental. |
1306 | // |
    // Note that an hlfir.elemental can be inlined into other hlfir.elemental
    // operations, while hlfir.eval_in_mem prevents this, so we end up
    // creating a temporary array for the result. We may need more
    // sophisticated logic for picking the most efficient representation.
1312 | hlfir::Entity array = hlfir::Entity{cshift.getArray()}; |
1313 | mlir::Type elementType = array.getFortranElementType(); |
1314 | if (dimVal == 1 && fir::isa_trivial(elementType) && |
1315 | // genInMemCShift() only works for variables currently. |
1316 | array.isVariable()) |
1317 | rewriter.replaceOp(cshift, genInMemCShift(rewriter, cshift, dimVal)); |
1318 | else |
1319 | rewriter.replaceOp(cshift, genElementalCShift(rewriter, cshift, dimVal)); |
1320 | return mlir::success(); |
1321 | } |
1322 | |
1323 | private: |
1324 | /// Generate MODULO(\p shiftVal, \p extent). |
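  /// The result is always within [0, extent), e.g. a shift of -1 with
  /// an extent of 5 is normalized to MODULO(-1, 5) == 4.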
1325 | static mlir::Value normalizeShiftValue(mlir::Location loc, |
1326 | fir::FirOpBuilder &builder, |
1327 | mlir::Value shiftVal, |
1328 | mlir::Value extent, |
1329 | mlir::Type calcType) { |
1330 | shiftVal = builder.createConvert(loc, calcType, shiftVal); |
1331 | extent = builder.createConvert(loc, calcType, extent); |
1332 | // Make sure that we do not divide by zero. When the dimension |
1333 | // has zero size, turn the extent into 1. Note that the computed |
1334 | // MODULO value won't be used in this case, so it does not matter |
1335 | // which extent value we use. |
1336 | mlir::Value zero = builder.createIntegerConstant(loc, calcType, 0); |
1337 | mlir::Value one = builder.createIntegerConstant(loc, calcType, 1); |
1338 | mlir::Value isZero = builder.create<mlir::arith::CmpIOp>( |
1339 | loc, mlir::arith::CmpIPredicate::eq, extent, zero); |
1340 | extent = builder.create<mlir::arith::SelectOp>(loc, isZero, one, extent); |
1341 | shiftVal = fir::IntrinsicLibrary{builder, loc}.genModulo( |
1342 | calcType, {shiftVal, extent}); |
1343 | return builder.createConvert(loc, calcType, shiftVal); |
1344 | } |
1345 | |
1346 | /// Convert \p cshift into an hlfir.elemental using |
1347 | /// the pre-computed constant \p dimVal. |
1348 | static mlir::Operation *genElementalCShift(mlir::PatternRewriter &rewriter, |
1349 | hlfir::CShiftOp cshift, |
1350 | int64_t dimVal) { |
1351 | using Fortran::common::maxRank; |
1352 | hlfir::Entity shift = hlfir::Entity{cshift.getShift()}; |
1353 | hlfir::Entity array = hlfir::Entity{cshift.getArray()}; |
1354 | |
1355 | mlir::Location loc = cshift.getLoc(); |
1356 | fir::FirOpBuilder builder{rewriter, cshift.getOperation()}; |
1357 | // The new index computation involves MODULO, which is not implemented |
1358 | // for IndexType, so use I64 instead. |
1359 | mlir::Type calcType = builder.getI64Type(); |
    // None of the index arithmetic used below overflows
    // signed or unsigned i64.
1362 | builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw | |
1363 | mlir::arith::IntegerOverflowFlags::nuw); |
1364 | |
1365 | mlir::Value arrayShape = hlfir::genShape(loc, builder, array); |
1366 | llvm::SmallVector<mlir::Value, maxRank> arrayExtents = |
1367 | hlfir::getExplicitExtentsFromShape(arrayShape, builder); |
1368 | llvm::SmallVector<mlir::Value, 1> typeParams; |
1369 | hlfir::genLengthParameters(loc, builder, array, typeParams); |
1370 | mlir::Value shiftDimExtent = |
1371 | builder.createConvert(loc, calcType, arrayExtents[dimVal - 1]); |
1372 | mlir::Value shiftVal; |
1373 | if (shift.isScalar()) { |
1374 | shiftVal = hlfir::loadTrivialScalar(loc, builder, shift); |
1375 | shiftVal = |
1376 | normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent, calcType); |
1377 | } |
1378 | |
1379 | auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder, |
1380 | mlir::ValueRange inputIndices) -> hlfir::Entity { |
1381 | llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices}; |
1382 | if (!shiftVal) { |
        // When the array is not a vector, the section
        // (s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n))
1385 | // of the result has a value equal to: |
1386 | // CSHIFT(ARRAY(s(1), s(2), ..., s(dim-1), :, s(dim+1), ..., s(n)), |
1387 | // SH, 1), |
1388 | // where SH is either SHIFT (if scalar) or |
1389 | // SHIFT(s(1), s(2), ..., s(dim-1), s(dim+1), ..., s(n)). |
1390 | llvm::SmallVector<mlir::Value, maxRank> shiftIndices{indices}; |
1391 | shiftIndices.erase(shiftIndices.begin() + dimVal - 1); |
1392 | hlfir::Entity shiftElement = |
1393 | hlfir::getElementAt(loc, builder, shift, shiftIndices); |
1394 | shiftVal = hlfir::loadTrivialScalar(loc, builder, shiftElement); |
1395 | shiftVal = normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent, |
1396 | calcType); |
1397 | } |
1398 | |
1399 | // Element i of the result (1-based) is element |
1400 | // 'MODULO(i + SH - 1, SIZE(ARRAY,DIM)) + 1' (1-based) of the original |
1401 | // ARRAY (or its section, when ARRAY is not a vector). |
1402 | |
1403 | // Compute the index into the original array using the normalized |
1404 | // shift value, which satisfies (SH >= 0 && SH < SIZE(ARRAY,DIM)): |
1405 | // newIndex = |
1406 | // i + ((i <= SIZE(ARRAY,DIM) - SH) ? SH : SH - SIZE(ARRAY,DIM)) |
1407 | // |
1408 | // Such index computation allows for further loop vectorization |
1409 | // in LLVM. |
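      // For example, with SIZE(ARRAY,DIM) == 5 and a normalized SH == 2,
      // i == 1 maps to 1 + 2 == 3, while i == 4 maps to 4 + (2 - 5) == 1,
      // i.e. the access wraps around to the beginning of the dimension.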
1410 | mlir::Value wrapBound = |
1411 | builder.create<mlir::arith::SubIOp>(loc, shiftDimExtent, shiftVal); |
1412 | mlir::Value adjustedShiftVal = |
1413 | builder.create<mlir::arith::SubIOp>(loc, shiftVal, shiftDimExtent); |
1414 | mlir::Value index = |
1415 | builder.createConvert(loc, calcType, inputIndices[dimVal - 1]); |
1416 | mlir::Value wrapCheck = builder.create<mlir::arith::CmpIOp>( |
1417 | loc, mlir::arith::CmpIPredicate::sle, index, wrapBound); |
1418 | mlir::Value actualShift = builder.create<mlir::arith::SelectOp>( |
1419 | loc, wrapCheck, shiftVal, adjustedShiftVal); |
1420 | mlir::Value newIndex = |
1421 | builder.create<mlir::arith::AddIOp>(loc, index, actualShift); |
1422 | newIndex = builder.createConvert(loc, builder.getIndexType(), newIndex); |
1423 | indices[dimVal - 1] = newIndex; |
1424 | hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices); |
1425 | return hlfir::loadTrivialScalar(loc, builder, element); |
1426 | }; |
1427 | |
1428 | mlir::Type elementType = array.getFortranElementType(); |
1429 | hlfir::ElementalOp elementalOp = hlfir::genElementalOp( |
1430 | loc, builder, elementType, arrayShape, typeParams, genKernel, |
1431 | /*isUnordered=*/true, |
1432 | array.isPolymorphic() ? static_cast<mlir::Value>(array) : nullptr, |
1433 | cshift.getResult().getType()); |
1434 | return elementalOp.getOperation(); |
1435 | } |
1436 | |
1437 | /// Convert \p cshift into an hlfir.eval_in_mem using the pre-computed |
1438 | /// constant \p dimVal. |
1439 | /// The converted code looks like this: |
1440 | /// do i=1,SH |
1441 | /// result(i + (SIZE(ARRAY,DIM) - SH)) = array(i) |
1442 | /// end |
1443 | /// do i=1,SIZE(ARRAY,DIM) - SH |
1444 | /// result(i) = array(i + SH) |
1445 | /// end |
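  ///
  /// For example, for CSHIFT([1,2,3,4,5], SHIFT=2) the normalized SH is 2,
  /// so the first loop stores array(1:2) into result(4:5) and the second
  /// loop stores array(3:5) into result(1:3), producing [3,4,5,1,2].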
1446 | /// |
1447 | /// When \p dimVal is 1, we generate the same code twice |
1448 | /// under a dynamic check for the contiguity of the leading |
1449 | /// dimension. In the code corresponding to the contiguous |
1450 | /// leading dimension, the shift dimension is represented |
1451 | /// as a contiguous slice of the original array. |
1452 | /// This allows recognizing the above two loops as memcpy |
1453 | /// loop idioms in LLVM. |
1454 | static mlir::Operation *genInMemCShift(mlir::PatternRewriter &rewriter, |
1455 | hlfir::CShiftOp cshift, |
1456 | int64_t dimVal) { |
1457 | using Fortran::common::maxRank; |
1458 | hlfir::Entity shift = hlfir::Entity{cshift.getShift()}; |
1459 | hlfir::Entity array = hlfir::Entity{cshift.getArray()}; |
    assert(array.isVariable() && "array must be a variable");
    assert(!array.isPolymorphic() &&
           "genInMemCShift does not support polymorphic types");
1463 | mlir::Location loc = cshift.getLoc(); |
1464 | fir::FirOpBuilder builder{rewriter, cshift.getOperation()}; |
1465 | // The new index computation involves MODULO, which is not implemented |
1466 | // for IndexType, so use I64 instead. |
1467 | mlir::Type calcType = builder.getI64Type(); |
    // None of the index arithmetic used below overflows
    // signed or unsigned i64.
1470 | builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw | |
1471 | mlir::arith::IntegerOverflowFlags::nuw); |
1472 | |
1473 | mlir::Value arrayShape = hlfir::genShape(loc, builder, array); |
1474 | llvm::SmallVector<mlir::Value, maxRank> arrayExtents = |
1475 | hlfir::getExplicitExtentsFromShape(arrayShape, builder); |
1476 | llvm::SmallVector<mlir::Value, 1> typeParams; |
1477 | hlfir::genLengthParameters(loc, builder, array, typeParams); |
1478 | mlir::Value shiftDimExtent = |
1479 | builder.createConvert(loc, calcType, arrayExtents[dimVal - 1]); |
1480 | mlir::Value shiftVal; |
1481 | if (shift.isScalar()) { |
1482 | shiftVal = hlfir::loadTrivialScalar(loc, builder, shift); |
1483 | shiftVal = |
1484 | normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent, calcType); |
1485 | } |
1486 | |
1487 | hlfir::EvaluateInMemoryOp evalOp = |
1488 | builder.create<hlfir::EvaluateInMemoryOp>( |
1489 | loc, mlir::cast<hlfir::ExprType>(cshift.getType()), arrayShape); |
1490 | builder.setInsertionPointToStart(&evalOp.getBody().front()); |
1491 | |
1492 | mlir::Value resultArray = evalOp.getMemory(); |
1493 | mlir::Type arrayType = fir::dyn_cast_ptrEleTy(resultArray.getType()); |
1494 | resultArray = builder.createBox(loc, fir::BoxType::get(arrayType), |
1495 | resultArray, arrayShape, /*slice=*/nullptr, |
1496 | typeParams, /*tdesc=*/nullptr); |
1497 | |
1498 | // This is a generator of the dimension shift code. |
1499 | // The code is inserted inside a loop nest over the other dimensions |
1500 | // (if any). If exposeContiguity is true, the array's section |
1501 | // array(s(1), ..., s(dim-1), :, s(dim+1), ..., s(n)) is represented |
1502 | // as a contiguous 1D array. |
    // shiftVal is the normalized shift value that satisfies
    // (SH >= 0 && SH < SIZE(ARRAY,DIM)).
    //
1506 | auto genDimensionShift = [&](mlir::Location loc, fir::FirOpBuilder &builder, |
1507 | mlir::Value shiftVal, bool exposeContiguity, |
1508 | mlir::ValueRange oneBasedIndices) |
1509 | -> llvm::SmallVector<mlir::Value, 0> { |
1510 | // Create a vector of indices (s(1), ..., s(dim-1), nullptr, s(dim+1), |
1511 | // ..., s(n)) so that we can update the dimVal index as needed. |
1512 | llvm::SmallVector<mlir::Value, maxRank> srcIndices( |
1513 | oneBasedIndices.begin(), oneBasedIndices.begin() + (dimVal - 1)); |
1514 | srcIndices.push_back(nullptr); |
1515 | srcIndices.append(oneBasedIndices.begin() + (dimVal - 1), |
1516 | oneBasedIndices.end()); |
1517 | llvm::SmallVector<mlir::Value, maxRank> dstIndices(srcIndices); |
1518 | |
1519 | hlfir::Entity srcArray = array; |
1520 | if (exposeContiguity && mlir::isa<fir::BaseBoxType>(srcArray.getType())) { |
        assert(dimVal == 1 && "can expose contiguity only for dim 1");
1522 | llvm::SmallVector<mlir::Value, maxRank> arrayLbounds = |
1523 | hlfir::genLowerbounds(loc, builder, arrayShape, array.getRank()); |
1524 | hlfir::Entity section = |
1525 | hlfir::gen1DSection(loc, builder, srcArray, dimVal, arrayLbounds, |
1526 | arrayExtents, oneBasedIndices, typeParams); |
1527 | mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, section); |
1528 | mlir::Value shape = hlfir::genShape(loc, builder, section); |
1529 | mlir::Type boxType = fir::wrapInClassOrBoxType( |
1530 | hlfir::getFortranElementOrSequenceType(section.getType()), |
1531 | section.isPolymorphic()); |
1532 | srcArray = hlfir::Entity{ |
1533 | builder.createBox(loc, boxType, addr, shape, /*slice=*/nullptr, |
1534 | /*lengths=*/{}, /*tdesc=*/nullptr)}; |
1535 | // When shifting the dimension as a 1D section of the original |
1536 | // array, we only need one index for addressing. |
1537 | srcIndices.resize(1); |
1538 | } |
1539 | |
1540 | // Copy first portion of the array: |
1541 | // do i=1,SH |
1542 | // result(i + (SIZE(ARRAY,DIM) - SH)) = array(i) |
1543 | // end |
1544 | auto genAssign1 = [&](mlir::Location loc, fir::FirOpBuilder &builder, |
1545 | mlir::ValueRange index, |
1546 | mlir::ValueRange reductionArgs) |
1547 | -> llvm::SmallVector<mlir::Value, 0> { |
        assert(index.size() == 1 && "expected single loop");
1549 | mlir::Value srcIndex = builder.createConvert(loc, calcType, index[0]); |
1550 | srcIndices[dimVal - 1] = srcIndex; |
1551 | hlfir::Entity srcElementValue = |
1552 | hlfir::loadElementAt(loc, builder, srcArray, srcIndices); |
1553 | mlir::Value dstIndex = builder.create<mlir::arith::AddIOp>( |
1554 | loc, srcIndex, |
1555 | builder.create<mlir::arith::SubIOp>(loc, shiftDimExtent, shiftVal)); |
1556 | dstIndices[dimVal - 1] = dstIndex; |
1557 | hlfir::Entity dstElement = hlfir::getElementAt( |
1558 | loc, builder, hlfir::Entity{resultArray}, dstIndices); |
1559 | builder.create<hlfir::AssignOp>(loc, srcElementValue, dstElement); |
1560 | return {}; |
1561 | }; |
1562 | |
1563 | // Generate the first loop. |
1564 | hlfir::genLoopNestWithReductions(loc, builder, {shiftVal}, |
1565 | /*reductionInits=*/{}, genAssign1, |
1566 | /*isUnordered=*/true); |
1567 | |
1568 | // Copy second portion of the array: |
1569 | // do i=1,SIZE(ARRAY,DIM)-SH |
1570 | // result(i) = array(i + SH) |
1571 | // end |
1572 | auto genAssign2 = [&](mlir::Location loc, fir::FirOpBuilder &builder, |
1573 | mlir::ValueRange index, |
1574 | mlir::ValueRange reductionArgs) |
1575 | -> llvm::SmallVector<mlir::Value, 0> { |
        assert(index.size() == 1 && "expected single loop");
1577 | mlir::Value dstIndex = builder.createConvert(loc, calcType, index[0]); |
1578 | mlir::Value srcIndex = |
1579 | builder.create<mlir::arith::AddIOp>(loc, dstIndex, shiftVal); |
1580 | srcIndices[dimVal - 1] = srcIndex; |
1581 | hlfir::Entity srcElementValue = |
1582 | hlfir::loadElementAt(loc, builder, srcArray, srcIndices); |
1583 | dstIndices[dimVal - 1] = dstIndex; |
1584 | hlfir::Entity dstElement = hlfir::getElementAt( |
1585 | loc, builder, hlfir::Entity{resultArray}, dstIndices); |
1586 | builder.create<hlfir::AssignOp>(loc, srcElementValue, dstElement); |
1587 | return {}; |
1588 | }; |
1589 | |
1590 | // Generate the second loop. |
1591 | mlir::Value bound = |
1592 | builder.create<mlir::arith::SubIOp>(loc, shiftDimExtent, shiftVal); |
1593 | hlfir::genLoopNestWithReductions(loc, builder, {bound}, |
1594 | /*reductionInits=*/{}, genAssign2, |
1595 | /*isUnordered=*/true); |
1596 | return {}; |
1597 | }; |
1598 | |
1599 | // A wrapper around genDimensionShift that computes the normalized |
1600 | // shift value and manages the insertion of the multiple versions |
1601 | // of the shift based on the dynamic check of the leading dimension's |
1602 | // contiguity (when dimVal == 1). |
1603 | auto genShiftBody = [&](mlir::Location loc, fir::FirOpBuilder &builder, |
1604 | mlir::ValueRange oneBasedIndices, |
1605 | mlir::ValueRange reductionArgs) |
1606 | -> llvm::SmallVector<mlir::Value, 0> { |
1607 | // Copy the dimension with a shift: |
1608 | // SH is either SHIFT (if scalar) or SHIFT(oneBasedIndices). |
1609 | if (!shiftVal) { |
        assert(!oneBasedIndices.empty() && "scalar shift must be precomputed");
1611 | hlfir::Entity shiftElement = |
1612 | hlfir::getElementAt(loc, builder, shift, oneBasedIndices); |
1613 | shiftVal = hlfir::loadTrivialScalar(loc, builder, shiftElement); |
1614 | shiftVal = normalizeShiftValue(loc, builder, shiftVal, shiftDimExtent, |
1615 | calcType); |
1616 | } |
1617 | |
1618 | // If we can fetch the byte stride of the leading dimension, |
1619 | // and the byte size of the element, then we can generate |
1620 | // a dynamic contiguity check and expose the leading dimension's |
1621 | // contiguity in FIR, making memcpy loop idiom recognition |
1622 | // possible. |
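      // For example, for a unit-stride leading dimension the byte stride
      // equals the element size, while for a section like ARRAY(1:N:2) it is
      // twice the element size, so the non-contiguous version below is taken.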
1623 | mlir::Value elemSize; |
1624 | mlir::Value stride; |
1625 | if (dimVal == 1 && mlir::isa<fir::BaseBoxType>(array.getType())) { |
1626 | mlir::Type indexType = builder.getIndexType(); |
1627 | elemSize = |
1628 | builder.create<fir::BoxEleSizeOp>(loc, indexType, array.getBase()); |
1629 | mlir::Value dimIdx = |
1630 | builder.createIntegerConstant(loc, indexType, dimVal - 1); |
1631 | auto boxDim = builder.create<fir::BoxDimsOp>( |
1632 | loc, indexType, indexType, indexType, array.getBase(), dimIdx); |
1633 | stride = boxDim.getByteStride(); |
1634 | } |
1635 | |
1636 | if (array.isSimplyContiguous() || !elemSize || !stride) { |
1637 | genDimensionShift(loc, builder, shiftVal, /*exposeContiguity=*/false, |
1638 | oneBasedIndices); |
1639 | return {}; |
1640 | } |
1641 | |
1642 | mlir::Value isContiguous = builder.create<mlir::arith::CmpIOp>( |
1643 | loc, mlir::arith::CmpIPredicate::eq, elemSize, stride); |
1644 | builder.genIfOp(loc, {}, isContiguous, /*withElseRegion=*/true) |
1645 | .genThen([&]() { |
1646 | genDimensionShift(loc, builder, shiftVal, /*exposeContiguity=*/true, |
1647 | oneBasedIndices); |
1648 | }) |
1649 | .genElse([&]() { |
1650 | genDimensionShift(loc, builder, shiftVal, |
1651 | /*exposeContiguity=*/false, oneBasedIndices); |
1652 | }); |
1653 | |
1654 | return {}; |
1655 | }; |
1656 | |
1657 | // For 1D case, generate a single loop. |
1658 | // For ND case, generate a loop nest over the other dimensions |
1659 | // with a single loop inside (generated separately). |
1660 | llvm::SmallVector<mlir::Value, maxRank> newExtents(arrayExtents); |
1661 | newExtents.erase(newExtents.begin() + (dimVal - 1)); |
1662 | if (!newExtents.empty()) |
1663 | hlfir::genLoopNestWithReductions(loc, builder, newExtents, |
1664 | /*reductionInits=*/{}, genShiftBody, |
1665 | /*isUnordered=*/true); |
1666 | else |
1667 | genShiftBody(loc, builder, {}, {}); |
1668 | |
1669 | return evalOp.getOperation(); |
1670 | } |
1671 | }; |
1672 | |
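/// Inline hlfir.matmul and hlfir.matmul_transpose either as an
/// hlfir.elemental (each result element is computed by an inner-product
/// reduction loop) or as an hlfir.eval_in_mem with explicit accumulation
/// loop nests mimicking the runtime implementation.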
1673 | template <typename Op> |
1674 | class MatmulConversion : public mlir::OpRewritePattern<Op> { |
1675 | public: |
1676 | using mlir::OpRewritePattern<Op>::OpRewritePattern; |
1677 | |
1678 | llvm::LogicalResult |
1679 | matchAndRewrite(Op matmul, mlir::PatternRewriter &rewriter) const override { |
1680 | mlir::Location loc = matmul.getLoc(); |
1681 | fir::FirOpBuilder builder{rewriter, matmul.getOperation()}; |
1682 | hlfir::Entity lhs = hlfir::Entity{matmul.getLhs()}; |
1683 | hlfir::Entity rhs = hlfir::Entity{matmul.getRhs()}; |
1684 | mlir::Value resultShape, innerProductExtent; |
1685 | std::tie(resultShape, innerProductExtent) = |
1686 | genResultShape(loc, builder, lhs, rhs); |
1687 | |
1688 | if (forceMatmulAsElemental || isMatmulTranspose) { |
1689 | // Generate hlfir.elemental that produces the result of |
1690 | // MATMUL/MATMUL(TRANSPOSE). |
1691 | // Note that this implementation is very suboptimal for MATMUL, |
1692 | // but is quite good for MATMUL(TRANSPOSE), e.g.: |
1693 | // R(1:N) = R(1:N) + MATMUL(TRANSPOSE(X(1:N,1:N)), Y(1:N)) |
1694 | // Inlining MATMUL(TRANSPOSE) as hlfir.elemental may result |
1695 | // in merging the inner product computation with the elemental |
1696 | // addition. Note that the inner product computation will |
1697 | // benefit from processing the lowermost dimensions of X and Y, |
1698 | // which may be the best when they are contiguous. |
1699 | // |
      // This is why we always inline MATMUL(TRANSPOSE) as an elemental.
      // Plain MATMUL is inlined via hlfir.eval_in_mem below, unless
      // forceMatmulAsElemental forces the elemental form.
1702 | hlfir::ExprType resultType = |
1703 | mlir::cast<hlfir::ExprType>(matmul.getType()); |
1704 | hlfir::ElementalOp newOp = genElementalMatmul( |
1705 | loc, builder, resultType, resultShape, lhs, rhs, innerProductExtent); |
1706 | rewriter.replaceOp(matmul, newOp); |
1707 | return mlir::success(); |
1708 | } |
1709 | |
1710 | // Generate hlfir.eval_in_mem to mimic the MATMUL implementation |
1711 | // from Fortran runtime. The implementation needs to operate |
1712 | // with the result array as an in-memory object. |
1713 | hlfir::EvaluateInMemoryOp evalOp = |
1714 | builder.create<hlfir::EvaluateInMemoryOp>( |
1715 | loc, mlir::cast<hlfir::ExprType>(matmul.getType()), resultShape); |
1716 | builder.setInsertionPointToStart(&evalOp.getBody().front()); |
1717 | |
1718 | // Embox the raw array pointer to simplify designating it. |
1719 | // TODO: this currently results in redundant lower bounds |
1720 | // addition for the designator, but this should be fixed in |
1721 | // hlfir::Entity::mayHaveNonDefaultLowerBounds(). |
1722 | mlir::Value resultArray = evalOp.getMemory(); |
1723 | mlir::Type arrayType = fir::dyn_cast_ptrEleTy(resultArray.getType()); |
1724 | resultArray = builder.createBox(loc, fir::BoxType::get(arrayType), |
1725 | resultArray, resultShape, /*slice=*/nullptr, |
1726 | /*lengths=*/{}, /*tdesc=*/nullptr); |
1727 | |
1728 | // The contiguous MATMUL version is best for the cases |
1729 | // where the input arrays and (maybe) the result are contiguous |
1730 | // in their lowermost dimensions. |
    // It is especially beneficial when LLVM can recognize the contiguity
    // and vectorize the loops properly.
    // Note that the contiguous MATMUL inlining is correct
    // even when the input arrays are not contiguous.
    // TODO: we can try to recognize the cases when the contiguity
    // is not statically obvious and try to generate an explicitly
    // contiguous version under a dynamic check. This should allow
1738 | // LLVM to vectorize the loops better. Note that this can |
1739 | // also be postponed up to the LoopVersioning pass. |
1740 | // The fallback implementation may use genElementalMatmul() with |
1741 | // an hlfir.assign into the result of eval_in_mem. |
1742 | mlir::LogicalResult rewriteResult = |
1743 | genContiguousMatmul(loc, builder, hlfir::Entity{resultArray}, |
1744 | resultShape, lhs, rhs, innerProductExtent); |
1745 | |
1746 | if (mlir::failed(rewriteResult)) { |
1747 | // Erase the unclaimed eval_in_mem op. |
1748 | rewriter.eraseOp(evalOp); |
      return rewriter.notifyMatchFailure(matmul,
                                         "genContiguousMatmul() failed");
1751 | } |
1752 | |
1753 | rewriter.replaceOp(matmul, evalOp); |
1754 | return mlir::success(); |
1755 | } |
1756 | |
1757 | private: |
1758 | static constexpr bool isMatmulTranspose = |
1759 | std::is_same_v<Op, hlfir::MatmulTransposeOp>; |
1760 | |
1761 | // Return a tuple of: |
1762 | // * A fir.shape operation representing the shape of the result |
1763 | // of a MATMUL/MATMUL(TRANSPOSE). |
1764 | // * An extent of the dimensions of the input array |
1765 | // that are processed during the inner product computation. |
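  // For example, for LHS(NROWS,N) and RHS(N,NCOLS) the result shape is
  // (NROWS,NCOLS) and the inner product extent is N; for LHS(N) and
  // RHS(N,NCOLS) the result shape is (NCOLS).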
1766 | static std::tuple<mlir::Value, mlir::Value> |
1767 | genResultShape(mlir::Location loc, fir::FirOpBuilder &builder, |
1768 | hlfir::Entity input1, hlfir::Entity input2) { |
1769 | llvm::SmallVector<mlir::Value, 2> input1Extents = |
1770 | hlfir::genExtentsVector(loc, builder, input1); |
1771 | llvm::SmallVector<mlir::Value, 2> input2Extents = |
1772 | hlfir::genExtentsVector(loc, builder, input2); |
1773 | |
1774 | llvm::SmallVector<mlir::Value, 2> newExtents; |
1775 | mlir::Value innerProduct1Extent, innerProduct2Extent; |
1776 | if (input1Extents.size() == 1) { |
      assert(!isMatmulTranspose &&
             "hlfir.matmul_transpose's first operand must be rank-2 array");
      assert(input2Extents.size() == 2 &&
             "hlfir.matmul second argument must be rank-2 array");
1781 | newExtents.push_back(input2Extents[1]); |
1782 | innerProduct1Extent = input1Extents[0]; |
1783 | innerProduct2Extent = input2Extents[0]; |
1784 | } else { |
1785 | if (input2Extents.size() == 1) { |
1786 | assert(input1Extents.size() == 2 && |
1787 | "hlfir.matmul first argument must be rank-2 array" ); |
1788 | if constexpr (isMatmulTranspose) |
1789 | newExtents.push_back(input1Extents[1]); |
1790 | else |
1791 | newExtents.push_back(input1Extents[0]); |
1792 | } else { |
1793 | assert(input1Extents.size() == 2 && input2Extents.size() == 2 && |
1794 | "hlfir.matmul arguments must be rank-2 arrays" ); |
1795 | if constexpr (isMatmulTranspose) |
1796 | newExtents.push_back(input1Extents[1]); |
1797 | else |
1798 | newExtents.push_back(input1Extents[0]); |
1799 | |
1800 | newExtents.push_back(input2Extents[1]); |
1801 | } |
1802 | if constexpr (isMatmulTranspose) |
1803 | innerProduct1Extent = input1Extents[0]; |
1804 | else |
1805 | innerProduct1Extent = input1Extents[1]; |
1806 | |
1807 | innerProduct2Extent = input2Extents[0]; |
1808 | } |
1809 | // The inner product dimensions of the input arrays |
1810 | // must match. Pick the best (e.g. constant) out of them |
1811 | // so that the inner product loop bound can be used in |
1812 | // optimizations. |
1813 | llvm::SmallVector<mlir::Value> innerProductExtent = |
1814 | fir::factory::deduceOptimalExtents({innerProduct1Extent}, |
1815 | {innerProduct2Extent}); |
1816 | return {builder.create<fir::ShapeOp>(loc, newExtents), |
1817 | innerProductExtent[0]}; |
1818 | } |
1819 | |
1820 | static mlir::LogicalResult |
1821 | genContiguousMatmul(mlir::Location loc, fir::FirOpBuilder &builder, |
1822 | hlfir::Entity result, mlir::Value resultShape, |
1823 | hlfir::Entity lhs, hlfir::Entity rhs, |
1824 | mlir::Value innerProductExtent) { |
    // This code does not support MATMUL(TRANSPOSE); that case is supposed
    // to be inlined as hlfir.elemental.
1827 | if constexpr (isMatmulTranspose) |
1828 | return mlir::failure(); |
1829 | |
1830 | mlir::OpBuilder::InsertionGuard guard(builder); |
1831 | mlir::Type resultElementType = result.getFortranElementType(); |
1832 | llvm::SmallVector<mlir::Value, 2> resultExtents = |
1833 | mlir::cast<fir::ShapeOp>(resultShape.getDefiningOp()).getExtents(); |
1834 | |
1835 | // The inner product loop may be unordered if FastMathFlags::reassoc |
1836 | // transformations are allowed. The integer/logical inner product is |
1837 | // always unordered. |
    // Note that isUnordered is currently applied to all loops
    // in the loop nests generated below, while it only needs
    // to be applied to the inner-product loop.
1841 | bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) || |
1842 | mlir::isa<fir::LogicalType>(resultElementType) || |
1843 | static_cast<bool>(builder.getFastMathFlags() & |
1844 | mlir::arith::FastMathFlags::reassoc); |
1845 | |
1846 | // Insert the initialization loop nest that fills the whole result with |
1847 | // zeroes. |
1848 | mlir::Value initValue = |
1849 | fir::factory::createZeroValue(builder, loc, resultElementType); |
1850 | auto genInitBody = [&](mlir::Location loc, fir::FirOpBuilder &builder, |
1851 | mlir::ValueRange oneBasedIndices, |
1852 | mlir::ValueRange reductionArgs) |
1853 | -> llvm::SmallVector<mlir::Value, 0> { |
1854 | hlfir::Entity resultElement = |
1855 | hlfir::getElementAt(loc, builder, result, oneBasedIndices); |
1856 | builder.create<hlfir::AssignOp>(loc, initValue, resultElement); |
1857 | return {}; |
1858 | }; |
1859 | |
1860 | hlfir::genLoopNestWithReductions(loc, builder, resultExtents, |
1861 | /*reductionInits=*/{}, genInitBody, |
1862 | /*isUnordered=*/true); |
1863 | |
1864 | if (lhs.getRank() == 2 && rhs.getRank() == 2) { |
1865 | // LHS(NROWS,N) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS) |
1866 | // |
1867 | // Insert the computation loop nest: |
1868 | // DO 2 K = 1, N |
1869 | // DO 2 J = 1, NCOLS |
1870 | // DO 2 I = 1, NROWS |
1871 | // 2 RESULT(I,J) = RESULT(I,J) + LHS(I,K)*RHS(K,J) |
1872 | auto genMatrixMatrix = [&](mlir::Location loc, fir::FirOpBuilder &builder, |
1873 | mlir::ValueRange oneBasedIndices, |
1874 | mlir::ValueRange reductionArgs) |
1875 | -> llvm::SmallVector<mlir::Value, 0> { |
1876 | mlir::Value I = oneBasedIndices[0]; |
1877 | mlir::Value J = oneBasedIndices[1]; |
1878 | mlir::Value K = oneBasedIndices[2]; |
1879 | hlfir::Entity resultElement = |
1880 | hlfir::getElementAt(loc, builder, result, {I, J}); |
1881 | hlfir::Entity resultElementValue = |
1882 | hlfir::loadTrivialScalar(loc, builder, resultElement); |
1883 | hlfir::Entity lhsElementValue = |
1884 | hlfir::loadElementAt(loc, builder, lhs, {I, K}); |
1885 | hlfir::Entity rhsElementValue = |
1886 | hlfir::loadElementAt(loc, builder, rhs, {K, J}); |
1887 | mlir::Value productValue = |
1888 | ProductFactory{loc, builder}.genAccumulateProduct( |
1889 | resultElementValue, lhsElementValue, rhsElementValue); |
1890 | builder.create<hlfir::AssignOp>(loc, productValue, resultElement); |
1891 | return {}; |
1892 | }; |
1893 | |
1894 | // Note that the loops are inserted in reverse order, |
1895 | // so innerProductExtent should be passed as the last extent. |
1896 | hlfir::genLoopNestWithReductions( |
1897 | loc, builder, |
1898 | {resultExtents[0], resultExtents[1], innerProductExtent}, |
1899 | /*reductionInits=*/{}, genMatrixMatrix, isUnordered); |
1900 | return mlir::success(); |
1901 | } |
1902 | |
1903 | if (lhs.getRank() == 2 && rhs.getRank() == 1) { |
1904 | // LHS(NROWS,N) * RHS(N) -> RESULT(NROWS) |
1905 | // |
1906 | // Insert the computation loop nest: |
1907 | // DO 2 K = 1, N |
1908 | // DO 2 J = 1, NROWS |
1909 | // 2 RES(J) = RES(J) + LHS(J,K)*RHS(K) |
1910 | auto genMatrixVector = [&](mlir::Location loc, fir::FirOpBuilder &builder, |
1911 | mlir::ValueRange oneBasedIndices, |
1912 | mlir::ValueRange reductionArgs) |
1913 | -> llvm::SmallVector<mlir::Value, 0> { |
1914 | mlir::Value J = oneBasedIndices[0]; |
1915 | mlir::Value K = oneBasedIndices[1]; |
1916 | hlfir::Entity resultElement = |
1917 | hlfir::getElementAt(loc, builder, result, {J}); |
1918 | hlfir::Entity resultElementValue = |
1919 | hlfir::loadTrivialScalar(loc, builder, resultElement); |
1920 | hlfir::Entity lhsElementValue = |
1921 | hlfir::loadElementAt(loc, builder, lhs, {J, K}); |
1922 | hlfir::Entity rhsElementValue = |
1923 | hlfir::loadElementAt(loc, builder, rhs, {K}); |
1924 | mlir::Value productValue = |
1925 | ProductFactory{loc, builder}.genAccumulateProduct( |
1926 | resultElementValue, lhsElementValue, rhsElementValue); |
1927 | builder.create<hlfir::AssignOp>(loc, productValue, resultElement); |
1928 | return {}; |
1929 | }; |
1930 | hlfir::genLoopNestWithReductions( |
1931 | loc, builder, {resultExtents[0], innerProductExtent}, |
1932 | /*reductionInits=*/{}, genMatrixVector, isUnordered); |
1933 | return mlir::success(); |
1934 | } |
1935 | if (lhs.getRank() == 1 && rhs.getRank() == 2) { |
1936 | // LHS(N) * RHS(N,NCOLS) -> RESULT(NCOLS) |
1937 | // |
1938 | // Insert the computation loop nest: |
1939 | // DO 2 K = 1, N |
1940 | // DO 2 J = 1, NCOLS |
1941 | // 2 RES(J) = RES(J) + LHS(K)*RHS(K,J) |
1942 | auto genVectorMatrix = [&](mlir::Location loc, fir::FirOpBuilder &builder, |
1943 | mlir::ValueRange oneBasedIndices, |
1944 | mlir::ValueRange reductionArgs) |
1945 | -> llvm::SmallVector<mlir::Value, 0> { |
1946 | mlir::Value J = oneBasedIndices[0]; |
1947 | mlir::Value K = oneBasedIndices[1]; |
1948 | hlfir::Entity resultElement = |
1949 | hlfir::getElementAt(loc, builder, result, {J}); |
1950 | hlfir::Entity resultElementValue = |
1951 | hlfir::loadTrivialScalar(loc, builder, resultElement); |
1952 | hlfir::Entity lhsElementValue = |
1953 | hlfir::loadElementAt(loc, builder, lhs, {K}); |
1954 | hlfir::Entity rhsElementValue = |
1955 | hlfir::loadElementAt(loc, builder, rhs, {K, J}); |
1956 | mlir::Value productValue = |
1957 | ProductFactory{loc, builder}.genAccumulateProduct( |
1958 | resultElementValue, lhsElementValue, rhsElementValue); |
1959 | builder.create<hlfir::AssignOp>(loc, productValue, resultElement); |
1960 | return {}; |
1961 | }; |
1962 | hlfir::genLoopNestWithReductions( |
1963 | loc, builder, {resultExtents[0], innerProductExtent}, |
1964 | /*reductionInits=*/{}, genVectorMatrix, isUnordered); |
1965 | return mlir::success(); |
1966 | } |
1967 | |
1968 | llvm_unreachable("unsupported MATMUL arguments' ranks" ); |
1969 | } |
1970 | |
1971 | static hlfir::ElementalOp |
1972 | genElementalMatmul(mlir::Location loc, fir::FirOpBuilder &builder, |
1973 | hlfir::ExprType resultType, mlir::Value resultShape, |
1974 | hlfir::Entity lhs, hlfir::Entity rhs, |
1975 | mlir::Value innerProductExtent) { |
1976 | mlir::OpBuilder::InsertionGuard guard(builder); |
1977 | mlir::Type resultElementType = resultType.getElementType(); |
1978 | auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder, |
1979 | mlir::ValueRange resultIndices) -> hlfir::Entity { |
1980 | mlir::Value initValue = |
1981 | fir::factory::createZeroValue(builder, loc, resultElementType); |
1982 | // The inner product loop may be unordered if FastMathFlags::reassoc |
1983 | // transformations are allowed. The integer/logical inner product is |
1984 | // always unordered. |
1985 | bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) || |
1986 | mlir::isa<fir::LogicalType>(resultElementType) || |
1987 | static_cast<bool>(builder.getFastMathFlags() & |
1988 | mlir::arith::FastMathFlags::reassoc); |
1989 | |
1990 | auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder, |
1991 | mlir::ValueRange oneBasedIndices, |
1992 | mlir::ValueRange reductionArgs) |
1993 | -> llvm::SmallVector<mlir::Value, 1> { |
1994 | llvm::SmallVector<mlir::Value, 2> lhsIndices; |
1995 | llvm::SmallVector<mlir::Value, 2> rhsIndices; |
1996 | // MATMUL: |
1997 | // LHS(NROWS,N) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS) |
1998 | // LHS(NROWS,N) * RHS(N) -> RESULT(NROWS) |
1999 | // LHS(N) * RHS(N,NCOLS) -> RESULT(NCOLS) |
2000 | // |
2001 | // MATMUL(TRANSPOSE): |
2002 | // TRANSPOSE(LHS(N,NROWS)) * RHS(N,NCOLS) -> RESULT(NROWS,NCOLS) |
2003 | // TRANSPOSE(LHS(N,NROWS)) * RHS(N) -> RESULT(NROWS) |
2004 | // |
2005 | // The resultIndices iterate over (NROWS[,NCOLS]). |
2006 | // The oneBasedIndices iterate over (N). |
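        // For example, for a rank-2 MATMUL, result element (I,J)
        // accumulates LHS(I,K)*RHS(K,J) with K = oneBasedIndices[0];
        // for MATMUL(TRANSPOSE) the LHS access becomes LHS(K,I).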
2007 | if (lhs.getRank() > 1) |
2008 | lhsIndices.push_back(resultIndices[0]); |
2009 | lhsIndices.push_back(oneBasedIndices[0]); |
2010 | |
2011 | if constexpr (isMatmulTranspose) { |
2012 | // Swap the LHS indices for TRANSPOSE. |
2013 | std::swap(lhsIndices[0], lhsIndices[1]); |
2014 | } |
2015 | |
2016 | rhsIndices.push_back(oneBasedIndices[0]); |
2017 | if (rhs.getRank() > 1) |
2018 | rhsIndices.push_back(resultIndices.back()); |
2019 | |
2020 | hlfir::Entity lhsElementValue = |
2021 | hlfir::loadElementAt(loc, builder, lhs, lhsIndices); |
2022 | hlfir::Entity rhsElementValue = |
2023 | hlfir::loadElementAt(loc, builder, rhs, rhsIndices); |
2024 | mlir::Value productValue = |
2025 | ProductFactory{loc, builder}.genAccumulateProduct( |
2026 | reductionArgs[0], lhsElementValue, rhsElementValue); |
2027 | return {productValue}; |
2028 | }; |
2029 | llvm::SmallVector<mlir::Value, 1> innerProductValue = |
2030 | hlfir::genLoopNestWithReductions(loc, builder, {innerProductExtent}, |
2031 | {initValue}, genBody, isUnordered); |
2032 | return hlfir::Entity{innerProductValue[0]}; |
2033 | }; |
2034 | hlfir::ElementalOp elementalOp = hlfir::genElementalOp( |
2035 | loc, builder, resultElementType, resultShape, /*typeParams=*/{}, |
2036 | genKernel, |
2037 | /*isUnordered=*/true, /*polymorphicMold=*/nullptr, resultType); |
2038 | |
2039 | return elementalOp; |
2040 | } |
2041 | }; |
2042 | |
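/// Convert hlfir.dot_product into a single reduction loop.
/// Following the Fortran semantics, DOT_PRODUCT accumulates
/// CONJG(LHS(I))*RHS(I) for complex operands and LHS(I)*RHS(I)
/// otherwise (LHS(I).AND.RHS(I) combined with .OR. for logical
/// operands), over the common extent of the two vectors.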
2043 | class DotProductConversion |
2044 | : public mlir::OpRewritePattern<hlfir::DotProductOp> { |
2045 | public: |
2046 | using mlir::OpRewritePattern<hlfir::DotProductOp>::OpRewritePattern; |
2047 | |
2048 | llvm::LogicalResult |
2049 | matchAndRewrite(hlfir::DotProductOp product, |
2050 | mlir::PatternRewriter &rewriter) const override { |
2051 | hlfir::Entity op = hlfir::Entity{product}; |
2052 | if (!op.isScalar()) |
      return rewriter.notifyMatchFailure(product, "produces non-scalar result");
2054 | |
2055 | mlir::Location loc = product.getLoc(); |
2056 | fir::FirOpBuilder builder{rewriter, product.getOperation()}; |
2057 | hlfir::Entity lhs = hlfir::Entity{product.getLhs()}; |
2058 | hlfir::Entity rhs = hlfir::Entity{product.getRhs()}; |
2059 | mlir::Type resultElementType = product.getType(); |
2060 | bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) || |
2061 | mlir::isa<fir::LogicalType>(resultElementType) || |
2062 | static_cast<bool>(builder.getFastMathFlags() & |
2063 | mlir::arith::FastMathFlags::reassoc); |
2064 | |
2065 | mlir::Value extent = genProductExtent(loc, builder, lhs, rhs); |
2066 | |
2067 | auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder, |
2068 | mlir::ValueRange oneBasedIndices, |
2069 | mlir::ValueRange reductionArgs) |
2070 | -> llvm::SmallVector<mlir::Value, 1> { |
2071 | hlfir::Entity lhsElementValue = |
2072 | hlfir::loadElementAt(loc, builder, lhs, oneBasedIndices); |
2073 | hlfir::Entity rhsElementValue = |
2074 | hlfir::loadElementAt(loc, builder, rhs, oneBasedIndices); |
2075 | mlir::Value productValue = |
2076 | ProductFactory{loc, builder}.genAccumulateProduct</*CONJ=*/true>( |
2077 | reductionArgs[0], lhsElementValue, rhsElementValue); |
2078 | return {productValue}; |
2079 | }; |
2080 | |
2081 | mlir::Value initValue = |
2082 | fir::factory::createZeroValue(builder, loc, resultElementType); |
2083 | |
2084 | llvm::SmallVector<mlir::Value, 1> result = hlfir::genLoopNestWithReductions( |
2085 | loc, builder, {extent}, |
2086 | /*reductionInits=*/{initValue}, genBody, isUnordered); |
2087 | |
2088 | rewriter.replaceOp(product, result[0]); |
2089 | return mlir::success(); |
2090 | } |
2091 | |
2092 | private: |
2093 | static mlir::Value genProductExtent(mlir::Location loc, |
2094 | fir::FirOpBuilder &builder, |
2095 | hlfir::Entity input1, |
2096 | hlfir::Entity input2) { |
2097 | llvm::SmallVector<mlir::Value, 1> input1Extents = |
2098 | hlfir::genExtentsVector(loc, builder, input1); |
2099 | llvm::SmallVector<mlir::Value, 1> input2Extents = |
2100 | hlfir::genExtentsVector(loc, builder, input2); |
2101 | |
2102 | assert(input1Extents.size() == 1 && input2Extents.size() == 1 && |
2103 | "hlfir.dot_product arguments must be vectors" ); |
2104 | llvm::SmallVector<mlir::Value, 1> extent = |
2105 | fir::factory::deduceOptimalExtents(input1Extents, input2Extents); |
2106 | return extent[0]; |
2107 | } |
2108 | }; |
2109 | |
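/// Convert hlfir.reshape without the ORDER argument into an hlfir.elemental
/// that maps each result element either to an element of ARRAY or, once
/// the linear index exceeds SIZE(ARRAY), to an element of PAD taken in
/// a circular fashion.
/// For example, RESHAPE([1,2,3,4,5,6], [2,4], PAD=[0]) yields a 2x4 array
/// whose elements in array element order are 1,2,3,4,5,6,0,0.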
2110 | class ReshapeAsElementalConversion |
2111 | : public mlir::OpRewritePattern<hlfir::ReshapeOp> { |
2112 | public: |
2113 | using mlir::OpRewritePattern<hlfir::ReshapeOp>::OpRewritePattern; |
2114 | |
2115 | llvm::LogicalResult |
2116 | matchAndRewrite(hlfir::ReshapeOp reshape, |
2117 | mlir::PatternRewriter &rewriter) const override { |
2118 | // Do not inline RESHAPE with ORDER yet. The runtime implementation |
2119 | // may be good enough, unless the temporary creation overhead |
2120 | // is high. |
2121 | // TODO: If ORDER is constant, then we can still easily inline. |
2122 | // TODO: If the result's rank is 1, then we can assume ORDER == (/1/). |
2123 | if (reshape.getOrder()) |
2124 | return rewriter.notifyMatchFailure(reshape, |
2125 | "RESHAPE with ORDER argument" ); |
2126 | |
2127 | // Verify that the element types of ARRAY, PAD and the result |
2128 | // match before doing any transformations. For example, |
    // character types of different lengths may appear in dead code,
    // and it does not make sense to inline hlfir.reshape in this case
    // (a runtime call may have a smaller code-size footprint).
2132 | hlfir::Entity result = hlfir::Entity{reshape}; |
2133 | hlfir::Entity array = hlfir::Entity{reshape.getArray()}; |
2134 | mlir::Type elementType = array.getFortranElementType(); |
2135 | if (result.getFortranElementType() != elementType) |
2136 | return rewriter.notifyMatchFailure( |
2137 | reshape, "ARRAY and result have different types" ); |
2138 | mlir::Value pad = reshape.getPad(); |
2139 | if (pad && hlfir::getFortranElementType(pad.getType()) != elementType) |
2140 | return rewriter.notifyMatchFailure(reshape, |
2141 | "ARRAY and PAD have different types" ); |
2142 | |
2143 | // TODO: selecting between ARRAY and PAD of non-trivial element types |
2144 | // requires more work. We have to select between two references |
2145 | // to elements in ARRAY and PAD. This requires conditional |
2146 | // bufferization of the element, if ARRAY/PAD is an expression. |
2147 | if (pad && !fir::isa_trivial(elementType)) |
2148 | return rewriter.notifyMatchFailure(reshape, |
2149 | "PAD present with non-trivial type" ); |
2150 | |
2151 | mlir::Location loc = reshape.getLoc(); |
2152 | fir::FirOpBuilder builder{rewriter, reshape.getOperation()}; |
    // Assume that none of the index arithmetic below overflows
    // IndexType.
2155 | builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nuw); |
2156 | |
2157 | llvm::SmallVector<mlir::Value, 1> typeParams; |
2158 | hlfir::genLengthParameters(loc, builder, array, typeParams); |
2159 | |
2160 | // Fetch the extents of ARRAY, PAD and result beforehand. |
2161 | llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents = |
2162 | hlfir::genExtentsVector(loc, builder, array); |
2163 | |
    // If PAD is present, we have to use the ARRAY size to know where
    // to start taking elements from the PAD array.
2166 | mlir::Value arraySize = |
2167 | pad ? computeArraySize(loc, builder, arrayExtents) : nullptr; |
2168 | hlfir::Entity shape = hlfir::Entity{reshape.getShape()}; |
2169 | llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents; |
2170 | mlir::Type indexType = builder.getIndexType(); |
2171 | for (int idx = 0; idx < result.getRank(); ++idx) |
2172 | resultExtents.push_back(hlfir::loadElementAt( |
2173 | loc, builder, shape, |
2174 | builder.createIntegerConstant(loc, indexType, idx + 1))); |
2175 | auto resultShape = builder.create<fir::ShapeOp>(loc, resultExtents); |
2176 | |
2177 | auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder, |
2178 | mlir::ValueRange inputIndices) -> hlfir::Entity { |
2179 | mlir::Value linearIndex = |
2180 | computeLinearIndex(loc, builder, resultExtents, inputIndices); |
2181 | fir::IfOp ifOp; |
2182 | if (pad) { |
2183 | // PAD is present. Check if this element comes from the PAD array. |
2184 | mlir::Value isInsideArray = builder.create<mlir::arith::CmpIOp>( |
2185 | loc, mlir::arith::CmpIPredicate::ult, linearIndex, arraySize); |
2186 | ifOp = builder.create<fir::IfOp>(loc, elementType, isInsideArray, |
2187 | /*withElseRegion=*/true); |
2188 | |
2189 | // In the 'else' block, return an element from the PAD. |
2190 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
2191 | // PAD is dynamically optional, but we can unconditionally access it |
2192 | // in the 'else' block. If we have to start taking elements from it, |
2193 | // then it must be present in a valid program. |
2194 | llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents = |
2195 | hlfir::genExtentsVector(loc, builder, hlfir::Entity{pad}); |
2196 | // Subtract the ARRAY size from the zero-based linear index |
2197 | // to get the zero-based linear index into PAD. |
2198 | mlir::Value padLinearIndex = |
2199 | builder.create<mlir::arith::SubIOp>(loc, linearIndex, arraySize); |
2200 | llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices = |
2201 | delinearizeIndex(loc, builder, padExtents, padLinearIndex, |
2202 | /*wrapAround=*/true); |
2203 | mlir::Value padElement = |
2204 | hlfir::loadElementAt(loc, builder, hlfir::Entity{pad}, padIndices); |
2205 | builder.create<fir::ResultOp>(loc, padElement); |
2206 | |
2207 | // In the 'then' block, return an element from the ARRAY. |
2208 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
2209 | } |
2210 | |
2211 | llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices = |
2212 | delinearizeIndex(loc, builder, arrayExtents, linearIndex, |
2213 | /*wrapAround=*/false); |
2214 | mlir::Value arrayElement = |
2215 | hlfir::loadElementAt(loc, builder, array, arrayIndices); |
2216 | |
2217 | if (ifOp) { |
2218 | builder.create<fir::ResultOp>(loc, arrayElement); |
2219 | builder.setInsertionPointAfter(ifOp); |
2220 | arrayElement = ifOp.getResult(0); |
2221 | } |
2222 | |
2223 | return hlfir::Entity{arrayElement}; |
2224 | }; |
2225 | hlfir::ElementalOp elementalOp = hlfir::genElementalOp( |
2226 | loc, builder, elementType, resultShape, typeParams, genKernel, |
2227 | /*isUnordered=*/true, |
2228 | /*polymorphicMold=*/result.isPolymorphic() ? array : mlir::Value{}, |
2229 | reshape.getResult().getType()); |
2230 | assert(elementalOp.getResult().getType() == reshape.getResult().getType()); |
2231 | rewriter.replaceOp(reshape, elementalOp); |
2232 | return mlir::success(); |
2233 | } |
2234 | |
2235 | private: |
2236 | /// Compute zero-based linear index given an array extents |
2237 | /// and one-based indices: |
2238 | /// \p extents: [e0, e1, ..., en] |
2239 | /// \p indices: [i0, i1, ..., in] |
2240 | /// |
2241 | /// linear-index := |
2242 | /// (...((in-1)*e(n-1)+(i(n-1)-1))*e(n-2)+...)*e0+(i0-1) |
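  ///
  /// For example, for extents [2, 3, 4] and indices (2, 3, 1):
  /// linear-index = ((1-1)*3 + (3-1))*2 + (2-1) = 5.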
2243 | static mlir::Value computeLinearIndex(mlir::Location loc, |
2244 | fir::FirOpBuilder &builder, |
2245 | mlir::ValueRange extents, |
2246 | mlir::ValueRange indices) { |
2247 | std::size_t rank = extents.size(); |
2248 | assert(rank == indices.size()); |
2249 | mlir::Type indexType = builder.getIndexType(); |
2250 | mlir::Value zero = builder.createIntegerConstant(loc, indexType, 0); |
2251 | mlir::Value one = builder.createIntegerConstant(loc, indexType, 1); |
2252 | mlir::Value linearIndex = zero; |
2253 | std::size_t idx = 0; |
2254 | for (auto index : llvm::reverse(indices)) { |
2255 | mlir::Value tmp = builder.create<mlir::arith::SubIOp>( |
2256 | loc, builder.createConvert(loc, indexType, index), one); |
2257 | tmp = builder.create<mlir::arith::AddIOp>(loc, linearIndex, tmp); |
2258 | if (idx + 1 < rank) |
2259 | tmp = builder.create<mlir::arith::MulIOp>( |
2260 | loc, tmp, |
2261 | builder.createConvert(loc, indexType, extents[rank - idx - 2])); |
2262 | |
2263 | linearIndex = tmp; |
2264 | ++idx; |
2265 | } |
2266 | return linearIndex; |
2267 | } |
2268 | |
2269 | /// Compute one-based array indices from the given zero-based \p linearIndex |
2270 | /// and the array \p extents [e0, e1, ..., en]. |
2271 | /// i0 := linearIndex % e0 + 1 |
2272 | /// linearIndex := linearIndex / e0 |
2273 | /// i1 := linearIndex % e1 + 1 |
2274 | /// linearIndex := linearIndex / e1 |
2275 | /// ... |
2276 | /// i(n-1) := linearIndex % e(n-1) + 1 |
2277 | /// linearIndex := linearIndex / e(n-1) |
2278 | /// if (wrapAround) { |
2279 | /// // If the index is allowed to wrap around, then |
2280 | /// // we need to modulo it by the last dimension's extent. |
2281 | /// in := linearIndex % en + 1 |
2282 | /// } else { |
2283 | /// in := linearIndex + 1 |
2284 | /// } |
2285 | static llvm::SmallVector<mlir::Value, Fortran::common::maxRank> |
2286 | delinearizeIndex(mlir::Location loc, fir::FirOpBuilder &builder, |
2287 | mlir::ValueRange extents, mlir::Value linearIndex, |
2288 | bool wrapAround) { |
2289 | llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices; |
2290 | mlir::Type indexType = builder.getIndexType(); |
2291 | mlir::Value one = builder.createIntegerConstant(loc, indexType, 1); |
2292 | linearIndex = builder.createConvert(loc, indexType, linearIndex); |
2293 | |
2294 | for (std::size_t dim = 0; dim < extents.size(); ++dim) { |
2295 | mlir::Value extent = builder.createConvert(loc, indexType, extents[dim]); |
2296 | // Avoid the modulo for the last index, unless wrap around is allowed. |
2297 | mlir::Value currentIndex = linearIndex; |
2298 | if (dim != extents.size() - 1 || wrapAround) |
2299 | currentIndex = |
2300 | builder.create<mlir::arith::RemUIOp>(loc, linearIndex, extent); |
2301 | // The result of the last division is unused, so it will be DCEd. |
2302 | linearIndex = |
2303 | builder.create<mlir::arith::DivUIOp>(loc, linearIndex, extent); |
2304 | indices.push_back( |
2305 | builder.create<mlir::arith::AddIOp>(loc, currentIndex, one)); |
2306 | } |
2307 | return indices; |
2308 | } |
2309 | |
2310 | /// Return size of an array given its extents. |
2311 | static mlir::Value computeArraySize(mlir::Location loc, |
2312 | fir::FirOpBuilder &builder, |
2313 | mlir::ValueRange extents) { |
2314 | mlir::Type indexType = builder.getIndexType(); |
2315 | mlir::Value size = builder.createIntegerConstant(loc, indexType, 1); |
2316 | for (auto extent : extents) |
2317 | size = builder.create<mlir::arith::MulIOp>( |
2318 | loc, size, builder.createConvert(loc, indexType, extent)); |
2319 | return size; |
2320 | } |
2321 | }; |
2322 | |
2323 | class SimplifyHLFIRIntrinsics |
2324 | : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> { |
2325 | public: |
2326 | using SimplifyHLFIRIntrinsicsBase< |
2327 | SimplifyHLFIRIntrinsics>::SimplifyHLFIRIntrinsicsBase; |
2328 | |
2329 | void runOnOperation() override { |
2330 | mlir::MLIRContext *context = &getContext(); |
2331 | |
2332 | mlir::GreedyRewriteConfig config; |
    // Prevent the pattern driver from merging blocks.
2334 | config.setRegionSimplificationLevel( |
2335 | mlir::GreedySimplifyRegionLevel::Disabled); |
2336 | |
2337 | mlir::RewritePatternSet patterns(context); |
2338 | patterns.insert<TransposeAsElementalConversion>(context); |
2339 | patterns.insert<ReductionConversion<hlfir::SumOp>>(context); |
2340 | patterns.insert<CShiftConversion>(context); |
2341 | patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context); |
2342 | |
2343 | patterns.insert<ReductionConversion<hlfir::CountOp>>(context); |
2344 | patterns.insert<ReductionConversion<hlfir::AnyOp>>(context); |
2345 | patterns.insert<ReductionConversion<hlfir::AllOp>>(context); |
2346 | patterns.insert<ReductionConversion<hlfir::MaxlocOp>>(context); |
2347 | patterns.insert<ReductionConversion<hlfir::MinlocOp>>(context); |
2348 | patterns.insert<ReductionConversion<hlfir::MaxvalOp>>(context); |
2349 | patterns.insert<ReductionConversion<hlfir::MinvalOp>>(context); |
2350 | |
2351 | // If forceMatmulAsElemental is false, then hlfir.matmul inlining |
2352 | // will introduce hlfir.eval_in_mem operation with new memory side |
2353 | // effects. This conflicts with CSE and optimized bufferization, e.g.: |
2354 | // A(1:N,1:N) = A(1:N,1:N) - MATMUL(...) |
2355 | // If we introduce hlfir.eval_in_mem before CSE, then the current |
2356 | // MLIR CSE won't be able to optimize the trivial loads of 'N' value |
2357 | // that happen before and after hlfir.matmul. |
2358 | // If 'N' loads are not optimized, then the optimized bufferization |
2359 | // won't be able to prove that the slices of A are identical |
2360 | // on both sides of the assignment. |
    // This is actually a CSE problem, but we can work around it
    // for the time being.
2363 | if (forceMatmulAsElemental || this->allowNewSideEffects) |
2364 | patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context); |
2365 | |
2366 | patterns.insert<DotProductConversion>(context); |
2367 | patterns.insert<ReshapeAsElementalConversion>(context); |
2368 | |
2369 | if (mlir::failed(mlir::applyPatternsGreedily( |
2370 | getOperation(), std::move(patterns), config))) { |
2371 | mlir::emitError(getOperation()->getLoc(), |
2372 | "failure in HLFIR intrinsic simplification" ); |
2373 | signalPassFailure(); |
2374 | } |
2375 | } |
2376 | }; |
2377 | } // namespace |
2378 | |