1 | //===- OptimizedBufferization.cpp - special cases for bufferization -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // In some special cases we can bufferize hlfir expressions in a more optimal |
9 | // way so as to avoid creating temporaries. This pass handles these. It should |
10 | // be run before the catch-all bufferization pass. |
11 | // |
12 | // This requires constant subexpression elimination to have already been run. |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "flang/Optimizer/Analysis/AliasAnalysis.h" |
16 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
17 | #include "flang/Optimizer/Builder/HLFIRTools.h" |
18 | #include "flang/Optimizer/Dialect/FIROps.h" |
19 | #include "flang/Optimizer/Dialect/FIRType.h" |
20 | #include "flang/Optimizer/HLFIR/HLFIRDialect.h" |
21 | #include "flang/Optimizer/HLFIR/HLFIROps.h" |
22 | #include "flang/Optimizer/HLFIR/Passes.h" |
23 | #include "flang/Optimizer/Transforms/Utils.h" |
24 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
25 | #include "mlir/IR/Dominance.h" |
26 | #include "mlir/IR/PatternMatch.h" |
27 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
28 | #include "mlir/Pass/Pass.h" |
29 | #include "mlir/Support/LLVM.h" |
30 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
31 | #include "llvm/ADT/TypeSwitch.h" |
32 | #include <iterator> |
33 | #include <memory> |
34 | #include <mlir/Analysis/AliasAnalysis.h> |
35 | #include <optional> |
36 | |
37 | namespace hlfir { |
38 | #define GEN_PASS_DEF_OPTIMIZEDBUFFERIZATION |
39 | #include "flang/Optimizer/HLFIR/Passes.h.inc" |
40 | } // namespace hlfir |
41 | |
42 | #define DEBUG_TYPE "opt-bufferization" |
43 | |
44 | namespace { |
45 | |
46 | /// This transformation should match in place modification of arrays. |
47 | /// It should match code of the form |
48 | /// %array = some.operation // array has shape %shape |
49 | /// %expr = hlfir.elemental %shape : [...] { |
50 | /// bb0(%arg0: index) |
51 | /// %0 = hlfir.designate %array(%arg0) |
52 | /// [...] // no other reads or writes to %array |
53 | /// hlfir.yield_element %element |
54 | /// } |
55 | /// hlfir.assign %expr to %array |
56 | /// hlfir.destroy %expr |
57 | /// |
58 | /// Or |
59 | /// |
60 | /// %read_array = some.operation // shape %shape |
61 | /// %expr = hlfir.elemental %shape : [...] { |
62 | /// bb0(%arg0: index) |
63 | /// %0 = hlfir.designate %read_array(%arg0) |
64 | /// [...] |
65 | /// hlfir.yield_element %element |
66 | /// } |
67 | /// %write_array = some.operation // with shape %shape |
/// [...] // operations which don't affect write_array
69 | /// hlfir.assign %expr to %write_array |
70 | /// hlfir.destroy %expr |
71 | /// |
72 | /// In these cases, it is safe to turn the elemental into a do loop and modify |
73 | /// elements of %array in place without creating an extra temporary for the |
74 | /// elemental. We must check that there are no reads from the array at indexes |
75 | /// which might conflict with the assignment or any writes. For now we will keep |
76 | /// that strict and say that all reads must be at the elemental index (it is |
77 | /// probably safe to read from higher indices if lowering to an ordered loop). |
class ElementalAssignBufferization
    : public mlir::OpRewritePattern<hlfir::ElementalOp> {
private:
  /// Operations participating in the rewrite, collected by findMatch().
  struct MatchInfo {
    // The array the elemental is assigned into (the LHS of the hlfir.assign).
    mlir::Value array;
    // The hlfir.assign consuming the elemental's result.
    hlfir::AssignOp assign;
    // The hlfir.destroy of the elemental's result.
    hlfir::DestroyOp destroy;
  };
  /// determines if the transformation can be applied to this elemental
  static std::optional<MatchInfo> findMatch(hlfir::ElementalOp elemental);

public:
  using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(hlfir::ElementalOp elemental,
                  mlir::PatternRewriter &rewriter) const override;
};
96 | |
/// Recursively collect all effects between start and end (including start,
/// not including end). start must properly dominate end, and start and end
/// must be in the same block. If any operation with unknown effects is
/// found, std::nullopt is returned.
101 | static std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>> |
102 | getEffectsBetween(mlir::Operation *start, mlir::Operation *end) { |
103 | mlir::SmallVector<mlir::MemoryEffects::EffectInstance> ret; |
104 | if (start == end) |
105 | return ret; |
106 | assert(start->getBlock() && end->getBlock() && "TODO: block arguments" ); |
107 | assert(start->getBlock() == end->getBlock()); |
108 | assert(mlir::DominanceInfo{}.properlyDominates(start, end)); |
109 | |
110 | mlir::Operation *nextOp = start; |
111 | while (nextOp && nextOp != end) { |
112 | std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>> |
113 | effects = mlir::getEffectsRecursively(nextOp); |
114 | if (!effects) |
115 | return std::nullopt; |
116 | ret.append(*effects); |
117 | nextOp = nextOp->getNextNode(); |
118 | } |
119 | return ret; |
120 | } |
121 | |
122 | /// If effect is a read or write on val, return whether it aliases. |
123 | /// Otherwise return mlir::AliasResult::NoAlias |
124 | static mlir::AliasResult |
125 | containsReadOrWriteEffectOn(const mlir::MemoryEffects::EffectInstance &effect, |
126 | mlir::Value val) { |
127 | fir::AliasAnalysis aliasAnalysis; |
128 | |
129 | if (mlir::isa<mlir::MemoryEffects::Read, mlir::MemoryEffects::Write>( |
130 | effect.getEffect())) { |
131 | mlir::Value accessedVal = effect.getValue(); |
132 | if (mlir::isa<fir::DebuggingResource>(effect.getResource())) |
133 | return mlir::AliasResult::NoAlias; |
134 | if (!accessedVal) |
135 | return mlir::AliasResult::MayAlias; |
136 | if (accessedVal == val) |
137 | return mlir::AliasResult::MustAlias; |
138 | |
139 | // if the accessed value might alias val |
140 | mlir::AliasResult res = aliasAnalysis.alias(val, accessedVal); |
141 | if (!res.isNo()) |
142 | return res; |
143 | |
144 | // FIXME: alias analysis of fir.load |
145 | // follow this common pattern: |
146 | // %ref = hlfir.designate %array(%index) |
147 | // %val = fir.load $ref |
148 | if (auto designate = accessedVal.getDefiningOp<hlfir::DesignateOp>()) { |
149 | if (designate.getMemref() == val) |
150 | return mlir::AliasResult::MustAlias; |
151 | |
152 | // if the designate is into an array that might alias val |
153 | res = aliasAnalysis.alias(val, designate.getMemref()); |
154 | if (!res.isNo()) |
155 | return res; |
156 | } |
157 | } |
158 | return mlir::AliasResult::NoAlias; |
159 | } |
160 | |
// Returns true if the given array references represent identical
// or completely disjoint array slices. The callers may use this
// method when the alias analysis reports an alias of some kind,
// so that we can run Fortran specific analysis on the array slices
// to see if they are identical or disjoint. Note that the alias
// analysis is not able to give such an answer about the references.
static bool areIdenticalOrDisjointSlices(mlir::Value ref1, mlir::Value ref2) {
  // Same SSA value: trivially identical.
  if (ref1 == ref2)
    return true;

  auto des1 = ref1.getDefiningOp<hlfir::DesignateOp>();
  auto des2 = ref2.getDefiningOp<hlfir::DesignateOp>();
  // We only support a pair of designators right now.
  if (!des1 || !des2)
    return false;

  if (des1.getMemref() != des2.getMemref()) {
    // If the bases are different, then there is unknown overlap.
    LLVM_DEBUG(llvm::dbgs() << "No identical base for:\n"
                            << des1 << "and:\n"
                            << des2 << "\n");
    return false;
  }

  // Require all components of the designators to be the same.
  // It might be too strict, e.g. we may probably allow for
  // different type parameters.
  if (des1.getComponent() != des2.getComponent() ||
      des1.getComponentShape() != des2.getComponentShape() ||
      des1.getSubstring() != des2.getSubstring() ||
      des1.getComplexPart() != des2.getComplexPart() ||
      des1.getTypeparams() != des2.getTypeparams()) {
    LLVM_DEBUG(llvm::dbgs() << "Different designator specs for:\n"
                            << des1 << "and:\n"
                            << des2 << "\n");
    return false;
  }

  // Triplet subscripts must appear in the same subscript positions in both
  // designators, otherwise the flat index walk below would be misaligned.
  if (des1.getIsTriplet() != des2.getIsTriplet()) {
    LLVM_DEBUG(llvm::dbgs() << "Different sections for:\n"
                            << des1 << "and:\n"
                            << des2 << "\n");
    return false;
  }

  // Analyze the subscripts.
  // For example:
  //   hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0) shape %9
  //   hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1) shape %9
  //
  // If all the triplets (section specifiers) are the same, then
  // we do not care if %0 is equal to %1 - the slices are either
  // identical or completely disjoint.
  //
  // Note: the indices are stored flat, so a triplet occupies three
  // consecutive entries (lower bound, upper bound, stride) while a scalar
  // subscript occupies one; that is why the iterators advance by different
  // amounts below.
  auto des1It = des1.getIndices().begin();
  auto des2It = des2.getIndices().begin();
  bool identicalTriplets = true;
  for (bool isTriplet : des1.getIsTriplet()) {
    if (isTriplet) {
      for (int i = 0; i < 3; ++i)
        if (*des1It++ != *des2It++) {
          LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n"
                                  << des1 << "and:\n"
                                  << des2 << "\n");
          identicalTriplets = false;
          break;
        }
    } else {
      ++des1It;
      ++des2It;
    }
  }
  if (identicalTriplets)
    return true;

  // See if we can prove that any of the triplets do not overlap.
  // This is mostly a Polyhedron/nf performance hack that looks for
  // particular relations between the lower and upper bounds
  // of the array sections, e.g. for any positive constant C:
  //   X:Y does not overlap with (Y+C):Z
  //   X:Y does not overlap with Z:(X-C)
  auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
    // Skip through any chain of fir.convert operations to reach the
    // operation that actually produces the value.
    auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
      auto *op = v.getDefiningOp();
      while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
        op = conv.getValue().getDefiningOp();
      return op;
    };

    // Whether v is a compile-time integer constant strictly greater than 0.
    auto isPositiveConstant = [](mlir::Value v) -> bool {
      if (auto conOp =
              mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
        if (auto iattr = conOp.getValue().dyn_cast<mlir::IntegerAttr>())
          return iattr.getInt() > 0;
      return false;
    };

    auto *op1 = removeConvert(v1);
    auto *op2 = removeConvert(v2);
    if (!op1 || !op2)
      return false;
    // Is v2 == v1 + C for some positive constant C?
    if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
      if ((addi.getLhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getRhs())) ||
          (addi.getRhs().getDefiningOp() == op1 &&
           isPositiveConstant(addi.getLhs())))
        return true;
    // Is v1 == v2 - C for some positive constant C?
    if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
      if (subi.getLhs().getDefiningOp() == op2 &&
          isPositiveConstant(subi.getRhs()))
        return true;
    return false;
  };

  // Walk the flat index lists again, this time comparing each triplet's
  // bounds pairwise to prove the sections disjoint in some dimension.
  des1It = des1.getIndices().begin();
  des2It = des2.getIndices().begin();
  for (bool isTriplet : des1.getIsTriplet()) {
    if (isTriplet) {
      mlir::Value des1Lb = *des1It++;
      mlir::Value des1Ub = *des1It++;
      mlir::Value des2Lb = *des2It++;
      mlir::Value des2Ub = *des2It++;
      // Ignore strides.
      ++des1It;
      ++des2It;
      if (displacedByConstant(des1Ub, des2Lb) ||
          displacedByConstant(des2Ub, des1Lb))
        return true;
    } else {
      ++des1It;
      ++des2It;
    }
  }

  // Could not prove the slices identical or disjoint; assume overlap.
  return false;
}
296 | |
297 | std::optional<ElementalAssignBufferization::MatchInfo> |
298 | ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) { |
299 | mlir::Operation::user_range users = elemental->getUsers(); |
300 | // the only uses of the elemental should be the assignment and the destroy |
301 | if (std::distance(users.begin(), users.end()) != 2) { |
302 | LLVM_DEBUG(llvm::dbgs() << "Too many uses of the elemental\n" ); |
303 | return std::nullopt; |
304 | } |
305 | |
306 | // If the ElementalOp must produce a temporary (e.g. for |
307 | // finalization purposes), then we cannot inline it. |
308 | if (hlfir::elementalOpMustProduceTemp(elemental)) { |
309 | LLVM_DEBUG(llvm::dbgs() << "ElementalOp must produce a temp\n" ); |
310 | return std::nullopt; |
311 | } |
312 | |
313 | MatchInfo match; |
314 | for (mlir::Operation *user : users) |
315 | mlir::TypeSwitch<mlir::Operation *, void>(user) |
316 | .Case([&](hlfir::AssignOp op) { match.assign = op; }) |
317 | .Case([&](hlfir::DestroyOp op) { match.destroy = op; }); |
318 | |
319 | if (!match.assign || !match.destroy) { |
320 | LLVM_DEBUG(llvm::dbgs() << "Couldn't find assign or destroy\n" ); |
321 | return std::nullopt; |
322 | } |
323 | |
324 | // the array is what the elemental is assigned into |
325 | // TODO: this could be extended to also allow hlfir.expr by first bufferizing |
326 | // the incoming expression |
327 | match.array = match.assign.getLhs(); |
328 | mlir::Type arrayType = mlir::dyn_cast<fir::SequenceType>( |
329 | fir::unwrapPassByRefType(match.array.getType())); |
330 | if (!arrayType) |
331 | return std::nullopt; |
332 | |
333 | // require that the array elements are trivial |
334 | // TODO: this is just to make the pass easier to think about. Not an inherent |
335 | // limitation |
336 | mlir::Type eleTy = hlfir::getFortranElementType(arrayType); |
337 | if (!fir::isa_trivial(eleTy)) |
338 | return std::nullopt; |
339 | |
340 | // the array must have the same shape as the elemental. CSE should have |
341 | // deduplicated the fir.shape operations where they are provably the same |
342 | // so we just have to check for the same ssa value |
343 | // TODO: add more ways of getting the shape of the array |
344 | mlir::Value arrayShape; |
345 | if (match.array.getDefiningOp()) |
346 | arrayShape = |
347 | mlir::TypeSwitch<mlir::Operation *, mlir::Value>( |
348 | match.array.getDefiningOp()) |
349 | .Case([](hlfir::DesignateOp designate) { |
350 | return designate.getShape(); |
351 | }) |
352 | .Case([](hlfir::DeclareOp declare) { return declare.getShape(); }) |
353 | .Default([](mlir::Operation *) { return mlir::Value{}; }); |
354 | if (!arrayShape) { |
355 | LLVM_DEBUG(llvm::dbgs() << "Can't get shape of " << match.array << " at " |
356 | << elemental->getLoc() << "\n" ); |
357 | return std::nullopt; |
358 | } |
359 | if (arrayShape != elemental.getShape()) { |
360 | // f2018 10.2.1.2 (3) requires the lhs and rhs of an assignment to be |
361 | // conformable unless the lhs is an allocatable array. In HLFIR we can |
362 | // see this from the presence or absence of the realloc attribute on |
363 | // hlfir.assign. If it is not a realloc assignment, we can trust that |
364 | // the shapes do conform |
365 | if (match.assign.getRealloc()) |
366 | return std::nullopt; |
367 | } |
368 | |
369 | // the transformation wants to apply the elemental in a do-loop at the |
370 | // hlfir.assign, check there are no effects which make this unsafe |
371 | |
372 | // keep track of any values written to in the elemental, as these can't be |
373 | // read from between the elemental and the assignment |
374 | // likewise, values read in the elemental cannot be written to between the |
375 | // elemental and the assign |
376 | mlir::SmallVector<mlir::Value, 1> notToBeAccessedBeforeAssign; |
377 | // any accesses to the array between the array and the assignment means it |
378 | // would be unsafe to move the elemental to the assignment |
379 | notToBeAccessedBeforeAssign.push_back(match.array); |
380 | |
381 | // 1) side effects in the elemental body - it isn't sufficient to just look |
382 | // for ordered elementals because we also cannot support out of order reads |
383 | std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>> |
384 | effects = getEffectsBetween(&elemental.getBody()->front(), |
385 | elemental.getBody()->getTerminator()); |
386 | if (!effects) { |
387 | LLVM_DEBUG(llvm::dbgs() |
388 | << "operation with unknown effects inside elemental\n" ); |
389 | return std::nullopt; |
390 | } |
391 | for (const mlir::MemoryEffects::EffectInstance &effect : *effects) { |
392 | mlir::AliasResult res = containsReadOrWriteEffectOn(effect, match.array); |
393 | if (res.isNo()) { |
394 | if (mlir::isa<mlir::MemoryEffects::Write, mlir::MemoryEffects::Read>( |
395 | effect.getEffect())) |
396 | if (effect.getValue()) |
397 | notToBeAccessedBeforeAssign.push_back(effect.getValue()); |
398 | |
399 | // this is safe in the elemental |
400 | continue; |
401 | } |
402 | |
403 | // don't allow any aliasing writes in the elemental |
404 | if (mlir::isa<mlir::MemoryEffects::Write>(effect.getEffect())) { |
405 | LLVM_DEBUG(llvm::dbgs() << "write inside the elemental body\n" ); |
406 | return std::nullopt; |
407 | } |
408 | |
409 | // allow if and only if the reads are from the elemental indices, in order |
410 | // => each iteration doesn't read values written by other iterations |
411 | // don't allow reads from a different value which may alias: fir alias |
412 | // analysis isn't precise enough to tell us if two aliasing arrays overlap |
413 | // exactly or only partially. If they overlap partially, a designate at the |
414 | // elemental indices could be accessing different elements: e.g. we could |
415 | // designate two slices of the same array at different start indexes. These |
416 | // two MustAlias but index 1 of one array isn't the same element as index 1 |
417 | // of the other array. |
418 | if (!res.isPartial()) { |
419 | if (auto designate = |
420 | effect.getValue().getDefiningOp<hlfir::DesignateOp>()) { |
421 | if (!areIdenticalOrDisjointSlices(match.array, designate.getMemref())) { |
422 | LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate |
423 | << " at " << elemental.getLoc() << "\n" ); |
424 | return std::nullopt; |
425 | } |
426 | auto indices = designate.getIndices(); |
427 | auto elementalIndices = elemental.getIndices(); |
428 | if (indices.size() != elementalIndices.size()) { |
429 | LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate |
430 | << " at " << elemental.getLoc() << "\n" ); |
431 | return std::nullopt; |
432 | } |
433 | if (std::equal(indices.begin(), indices.end(), elementalIndices.begin(), |
434 | elementalIndices.end())) |
435 | continue; |
436 | } |
437 | } |
438 | LLVM_DEBUG(llvm::dbgs() << "disallowed side-effect: " << effect.getValue() |
439 | << " for " << elemental.getLoc() << "\n" ); |
440 | return std::nullopt; |
441 | } |
442 | |
443 | // 2) look for conflicting effects between the elemental and the assignment |
444 | effects = getEffectsBetween(elemental->getNextNode(), match.assign); |
445 | if (!effects) { |
446 | LLVM_DEBUG( |
447 | llvm::dbgs() |
448 | << "operation with unknown effects between elemental and assign\n" ); |
449 | return std::nullopt; |
450 | } |
451 | for (const mlir::MemoryEffects::EffectInstance &effect : *effects) { |
452 | // not safe to access anything written in the elemental as this write |
453 | // will be moved to the assignment |
454 | for (mlir::Value val : notToBeAccessedBeforeAssign) { |
455 | mlir::AliasResult res = containsReadOrWriteEffectOn(effect, val); |
456 | if (!res.isNo()) { |
457 | LLVM_DEBUG(llvm::dbgs() |
458 | << "diasllowed side-effect: " << effect.getValue() << " for " |
459 | << elemental.getLoc() << "\n" ); |
460 | return std::nullopt; |
461 | } |
462 | } |
463 | } |
464 | |
465 | return match; |
466 | } |
467 | |
mlir::LogicalResult ElementalAssignBufferization::matchAndRewrite(
    hlfir::ElementalOp elemental, mlir::PatternRewriter &rewriter) const {
  // Only rewrite when findMatch can prove it is safe to apply the elemental
  // directly at the assignment site.
  std::optional<MatchInfo> match = findMatch(elemental);
  if (!match)
    return rewriter.notifyMatchFailure(
        elemental, "cannot prove safety of ElementalAssignBufferization");

  mlir::Location loc = elemental->getLoc();
  fir::FirOpBuilder builder(rewriter, elemental.getOperation());
  auto extents = hlfir::getIndexExtents(loc, builder, elemental.getShape());

  // create the loop at the assignment
  builder.setInsertionPoint(match->assign);

  // Generate a loop nest looping around the hlfir.elemental shape and clone
  // hlfir.elemental region inside the inner loop. The loops may be unordered
  // only when the elemental itself is not marked ordered.
  hlfir::LoopNest loopNest =
      hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
  builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
  auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
                                        loopNest.oneBasedIndices);
  hlfir::Entity elementValue{yield.getElementValue()};
  // The yield terminator is not needed once the body is inlined into a loop.
  rewriter.eraseOp(yield);

  // Assign the element value to the array element for this iteration.
  auto arrayElement = hlfir::getElementAt(
      loc, builder, hlfir::Entity{match->array}, loopNest.oneBasedIndices);
  builder.create<hlfir::AssignOp>(
      loc, elementValue, arrayElement, /*realloc=*/false,
      /*keep_lhs_length_if_realloc=*/false, match->assign.getTemporaryLhs());

  // The original assignment, destroy, and elemental are now dead.
  rewriter.eraseOp(match->assign);
  rewriter.eraseOp(match->destroy);
  rewriter.eraseOp(elemental);
  return mlir::success();
}
504 | |
505 | /// Expand hlfir.assign of a scalar RHS to array LHS into a loop nest |
506 | /// of element-by-element assignments: |
507 | /// hlfir.assign %cst to %0 : f32, !fir.ref<!fir.array<6x6xf32>> |
508 | /// into: |
509 | /// fir.do_loop %arg0 = %c1 to %c6 step %c1 unordered { |
510 | /// fir.do_loop %arg1 = %c1 to %c6 step %c1 unordered { |
511 | /// %1 = hlfir.designate %0 (%arg1, %arg0) : |
512 | /// (!fir.ref<!fir.array<6x6xf32>>, index, index) -> !fir.ref<f32> |
513 | /// hlfir.assign %cst to %1 : f32, !fir.ref<f32> |
514 | /// } |
515 | /// } |
516 | class BroadcastAssignBufferization |
517 | : public mlir::OpRewritePattern<hlfir::AssignOp> { |
518 | private: |
519 | public: |
520 | using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern; |
521 | |
522 | mlir::LogicalResult |
523 | matchAndRewrite(hlfir::AssignOp assign, |
524 | mlir::PatternRewriter &rewriter) const override; |
525 | }; |
526 | |
527 | mlir::LogicalResult BroadcastAssignBufferization::matchAndRewrite( |
528 | hlfir::AssignOp assign, mlir::PatternRewriter &rewriter) const { |
529 | // Since RHS is a scalar and LHS is an array, LHS must be allocated |
530 | // in a conforming Fortran program, and LHS cannot be reallocated |
531 | // as a result of the assignment. So we can ignore isAllocatableAssignment |
532 | // and do the transformation always. |
533 | mlir::Value rhs = assign.getRhs(); |
534 | if (!fir::isa_trivial(rhs.getType())) |
535 | return rewriter.notifyMatchFailure( |
536 | assign, "AssignOp's RHS is not a trivial scalar" ); |
537 | |
538 | hlfir::Entity lhs{assign.getLhs()}; |
539 | if (!lhs.isArray()) |
540 | return rewriter.notifyMatchFailure(assign, |
541 | "AssignOp's LHS is not an array" ); |
542 | |
543 | mlir::Type eleTy = lhs.getFortranElementType(); |
544 | if (!fir::isa_trivial(eleTy)) |
545 | return rewriter.notifyMatchFailure( |
546 | assign, "AssignOp's LHS data type is not trivial" ); |
547 | |
548 | mlir::Location loc = assign->getLoc(); |
549 | fir::FirOpBuilder builder(rewriter, assign.getOperation()); |
550 | builder.setInsertionPoint(assign); |
551 | lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs); |
552 | mlir::Value shape = hlfir::genShape(loc, builder, lhs); |
553 | llvm::SmallVector<mlir::Value> extents = |
554 | hlfir::getIndexExtents(loc, builder, shape); |
555 | hlfir::LoopNest loopNest = |
556 | hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true); |
557 | builder.setInsertionPointToStart(loopNest.innerLoop.getBody()); |
558 | auto arrayElement = |
559 | hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices); |
560 | builder.create<hlfir::AssignOp>(loc, rhs, arrayElement); |
561 | rewriter.eraseOp(assign); |
562 | return mlir::success(); |
563 | } |
564 | |
565 | /// Expand hlfir.assign of array RHS to array LHS into a loop nest |
566 | /// of element-by-element assignments: |
567 | /// hlfir.assign %4 to %5 : !fir.ref<!fir.array<3x3xf32>>, |
568 | /// !fir.ref<!fir.array<3x3xf32>> |
569 | /// into: |
570 | /// fir.do_loop %arg1 = %c1 to %c3 step %c1 unordered { |
571 | /// fir.do_loop %arg2 = %c1 to %c3 step %c1 unordered { |
572 | /// %6 = hlfir.designate %4 (%arg2, %arg1) : |
573 | /// (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32> |
574 | /// %7 = fir.load %6 : !fir.ref<f32> |
575 | /// %8 = hlfir.designate %5 (%arg2, %arg1) : |
576 | /// (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32> |
577 | /// hlfir.assign %7 to %8 : f32, !fir.ref<f32> |
578 | /// } |
579 | /// } |
580 | /// |
581 | /// The transformation is correct only when LHS and RHS do not alias. |
582 | /// This transformation does not support runtime checking for |
583 | /// non-conforming LHS/RHS arrays' shapes currently. |
584 | class VariableAssignBufferization |
585 | : public mlir::OpRewritePattern<hlfir::AssignOp> { |
586 | private: |
587 | public: |
588 | using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern; |
589 | |
590 | mlir::LogicalResult |
591 | matchAndRewrite(hlfir::AssignOp assign, |
592 | mlir::PatternRewriter &rewriter) const override; |
593 | }; |
594 | |
595 | mlir::LogicalResult VariableAssignBufferization::matchAndRewrite( |
596 | hlfir::AssignOp assign, mlir::PatternRewriter &rewriter) const { |
597 | if (assign.isAllocatableAssignment()) |
598 | return rewriter.notifyMatchFailure(assign, "AssignOp may imply allocation" ); |
599 | |
600 | hlfir::Entity rhs{assign.getRhs()}; |
601 | // TODO: ExprType check is here to avoid conflicts with |
602 | // ElementalAssignBufferization pattern. We need to combine |
603 | // these matchers into a single one that applies to AssignOp. |
604 | if (rhs.getType().isa<hlfir::ExprType>()) |
605 | return rewriter.notifyMatchFailure(assign, "RHS is not in memory" ); |
606 | |
607 | if (!rhs.isArray()) |
608 | return rewriter.notifyMatchFailure(assign, |
609 | "AssignOp's RHS is not an array" ); |
610 | |
611 | mlir::Type rhsEleTy = rhs.getFortranElementType(); |
612 | if (!fir::isa_trivial(rhsEleTy)) |
613 | return rewriter.notifyMatchFailure( |
614 | assign, "AssignOp's RHS data type is not trivial" ); |
615 | |
616 | hlfir::Entity lhs{assign.getLhs()}; |
617 | if (!lhs.isArray()) |
618 | return rewriter.notifyMatchFailure(assign, |
619 | "AssignOp's LHS is not an array" ); |
620 | |
621 | mlir::Type lhsEleTy = lhs.getFortranElementType(); |
622 | if (!fir::isa_trivial(lhsEleTy)) |
623 | return rewriter.notifyMatchFailure( |
624 | assign, "AssignOp's LHS data type is not trivial" ); |
625 | |
626 | if (lhsEleTy != rhsEleTy) |
627 | return rewriter.notifyMatchFailure(assign, |
628 | "RHS/LHS element types mismatch" ); |
629 | |
630 | fir::AliasAnalysis aliasAnalysis; |
631 | mlir::AliasResult aliasRes = aliasAnalysis.alias(lhs, rhs); |
632 | // TODO: use areIdenticalOrDisjointSlices() to check if |
633 | // we can still do the expansion. |
634 | if (!aliasRes.isNo()) { |
635 | LLVM_DEBUG(llvm::dbgs() << "VariableAssignBufferization:\n" |
636 | << "\tLHS: " << lhs << "\n" |
637 | << "\tRHS: " << rhs << "\n" |
638 | << "\tALIAS: " << aliasRes << "\n" ); |
639 | return rewriter.notifyMatchFailure(assign, "RHS/LHS may alias" ); |
640 | } |
641 | |
642 | mlir::Location loc = assign->getLoc(); |
643 | fir::FirOpBuilder builder(rewriter, assign.getOperation()); |
644 | builder.setInsertionPoint(assign); |
645 | rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs); |
646 | lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs); |
647 | mlir::Value shape = hlfir::genShape(loc, builder, lhs); |
648 | llvm::SmallVector<mlir::Value> extents = |
649 | hlfir::getIndexExtents(loc, builder, shape); |
650 | hlfir::LoopNest loopNest = |
651 | hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true); |
652 | builder.setInsertionPointToStart(loopNest.innerLoop.getBody()); |
653 | auto rhsArrayElement = |
654 | hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices); |
655 | rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement); |
656 | auto lhsArrayElement = |
657 | hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices); |
658 | builder.create<hlfir::AssignOp>(loc, rhsArrayElement, lhsArrayElement); |
659 | rewriter.eraseOp(assign); |
660 | return mlir::success(); |
661 | } |
662 | |
663 | using GenBodyFn = |
664 | std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value, |
665 | const llvm::SmallVectorImpl<mlir::Value> &)>; |
666 | static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder, |
667 | mlir::Location loc, mlir::Value init, |
668 | mlir::Value shape, GenBodyFn genBody) { |
669 | auto extents = hlfir::getIndexExtents(loc, builder, shape); |
670 | mlir::Value reduction = init; |
671 | mlir::IndexType idxTy = builder.getIndexType(); |
672 | mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1); |
673 | |
674 | // Create a reduction loop nest. We use one-based indices so that they can be |
675 | // passed to the elemental, and reverse the order so that they can be |
676 | // generated in column-major order for better performance. |
677 | llvm::SmallVector<mlir::Value> indices(extents.size(), mlir::Value{}); |
678 | for (unsigned i = 0; i < extents.size(); ++i) { |
679 | auto loop = builder.create<fir::DoLoopOp>( |
680 | loc, oneIdx, extents[extents.size() - i - 1], oneIdx, false, |
681 | /*finalCountValue=*/false, reduction); |
682 | reduction = loop.getRegionIterArgs()[0]; |
683 | indices[extents.size() - i - 1] = loop.getInductionVar(); |
684 | // Set insertion point to the loop body so that the next loop |
685 | // is inserted inside the current one. |
686 | builder.setInsertionPointToStart(loop.getBody()); |
687 | } |
688 | |
689 | // Generate the body |
690 | reduction = genBody(builder, loc, reduction, indices); |
691 | |
692 | // Unwind the loop nest. |
693 | for (unsigned i = 0; i < extents.size(); ++i) { |
694 | auto result = builder.create<fir::ResultOp>(loc, reduction); |
695 | auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp()); |
696 | reduction = loop.getResult(0); |
697 | // Set insertion point after the loop operation that we have |
698 | // just processed. |
699 | builder.setInsertionPointAfter(loop.getOperation()); |
700 | } |
701 | |
702 | return reduction; |
703 | } |
704 | |
705 | /// Given a reduction operation with an elemental mask, attempt to generate a |
706 | /// do-loop to perform the operation inline. |
707 | /// %e = hlfir.elemental %shape unordered |
708 | /// %r = hlfir.count %e |
709 | /// => |
710 | /// %r = for.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init) |
711 | /// %i = <inline elemental> |
712 | /// %c = <reduce count> %i |
713 | /// fir.result %c |
/// Rewrite pattern turning ANY/ALL/COUNT of an hlfir.elemental mask into an
/// inline reduction do-loop, so the mask array is never materialized as a
/// temporary. The elemental body is inlined into the innermost loop and the
/// per-element value is folded into a scalar accumulator.
template <typename Op>
class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
public:
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    // The mask must come straight from an hlfir.elemental, and DIM= must be
    // absent: with DIM= the result is an array, not the scalar this loop
    // produces.
    hlfir::ElementalOp elemental =
        op.getMask().template getDefiningOp<hlfir::ElementalOp>();
    if (!elemental || op.getDim())
      return rewriter.notifyMatchFailure(op, "Did not find valid elemental" );

    fir::KindMapping kindMap =
        fir::getKindMapping(op->template getParentOfType<mlir::ModuleOp>());
    fir::FirOpBuilder builder{op, kindMap};

    // Pick the reduction identity value and the per-element combining
    // function for the matched intrinsic.
    mlir::Value init;
    GenBodyFn genBodyFn;
    if constexpr (std::is_same_v<Op, hlfir::AnyOp>) {
      // ANY: identity is false; combine elements with logical OR.
      init = builder.createIntegerConstant(loc, builder.getI1Type(), 0);
      genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
                              mlir::Value reduction,
                              const llvm::SmallVectorImpl<mlir::Value> &indices)
          -> mlir::Value {
        // Inline the elemental and get the condition from it.
        auto yield = inlineElementalOp(loc, builder, elemental, indices);
        mlir::Value cond = builder.create<fir::ConvertOp>(
            loc, builder.getI1Type(), yield.getElementValue());
        yield->erase();

        // Fold the element into the accumulator: reduction | cond.
        return builder.create<mlir::arith::OrIOp>(loc, reduction, cond);
      };
    } else if constexpr (std::is_same_v<Op, hlfir::AllOp>) {
      // ALL: identity is true; combine elements with logical AND.
      init = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
      genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
                              mlir::Value reduction,
                              const llvm::SmallVectorImpl<mlir::Value> &indices)
          -> mlir::Value {
        // Inline the elemental and get the condition from it.
        auto yield = inlineElementalOp(loc, builder, elemental, indices);
        mlir::Value cond = builder.create<fir::ConvertOp>(
            loc, builder.getI1Type(), yield.getElementValue());
        yield->erase();

        // Fold the element into the accumulator: reduction & cond.
        return builder.create<mlir::arith::AndIOp>(loc, reduction, cond);
      };
    } else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
      // COUNT: identity is 0 of the result integer type; add one per true
      // element.
      init = builder.createIntegerConstant(loc, op.getType(), 0);
      genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
                              mlir::Value reduction,
                              const llvm::SmallVectorImpl<mlir::Value> &indices)
          -> mlir::Value {
        // Inline the elemental and get the condition from it.
        auto yield = inlineElementalOp(loc, builder, elemental, indices);
        mlir::Value cond = builder.create<fir::ConvertOp>(
            loc, builder.getI1Type(), yield.getElementValue());
        yield->erase();

        // Conditionally add one to the current value: select keeps the old
        // count when the mask element is false.
        mlir::Value one =
            builder.createIntegerConstant(loc, reduction.getType(), 1);
        mlir::Value add1 =
            builder.create<mlir::arith::AddIOp>(loc, reduction, one);
        return builder.create<mlir::arith::SelectOp>(loc, cond, add1,
                                                     reduction);
      };
    } else {
      return mlir::failure();
    }

    // Emit the loop nest over the elemental's shape (operand 0) and convert
    // the accumulator to the op's result type if they differ.
    mlir::Value res = generateReductionLoop(builder, loc, init,
                                            elemental.getOperand(0), genBodyFn);
    if (res.getType() != op.getType())
      res = builder.create<fir::ConvertOp>(loc, op.getType(), res);

    // Check if the op was the only user of the elemental (apart from a
    // destroy), and remove it if so.
    mlir::Operation::user_range elemUsers = elemental->getUsers();
    hlfir::DestroyOp elemDestroy;
    if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
      elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
      if (!elemDestroy)
        elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
    }

    rewriter.replaceOp(op, res);
    if (elemDestroy) {
      rewriter.eraseOp(elemDestroy);
      rewriter.eraseOp(elemental);
    }
    return mlir::success();
  }
};
810 | |
811 | // Look for minloc(mask=elemental) and generate the minloc loop with |
812 | // inlined elemental. |
813 | // %e = hlfir.elemental %shape ({ ... }) |
814 | // %m = hlfir.minloc %array mask %e |
/// Rewrite pattern inlining an hlfir.elemental mask directly into the
/// minloc/maxloc reduction loop, avoiding the mask temporary. The result
/// indices are accumulated in a stack temporary (resultArr) which is then
/// wrapped in an hlfir.as_expr; assigns from the old result are redirected to
/// resultArr so VariableAssignBufferization can pick them up.
template <typename Op>
class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
public:
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(Op mloc, mlir::PatternRewriter &rewriter) const override {
    // Only the simple form is handled: a mask must be present, and DIM=/BACK=
    // must be absent.
    if (!mloc.getMask() || mloc.getDim() || mloc.getBack())
      return rewriter.notifyMatchFailure(mloc,
                                         "Did not find valid minloc/maxloc" );

    bool isMax = std::is_same_v<Op, hlfir::MaxlocOp>;

    auto elemental =
        mloc.getMask().template getDefiningOp<hlfir::ElementalOp>();
    if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
      return rewriter.notifyMatchFailure(mloc, "Did not find elemental" );

    mlir::Value array = mloc.getArray();

    // The result is a rank-1 hlfir.expr whose extent is the rank of the
    // input array.
    unsigned rank = mlir::cast<hlfir::ExprType>(mloc.getType()).getShape()[0];
    mlir::Type arrayType = array.getType();
    if (!arrayType.isa<fir::BoxType>())
      return rewriter.notifyMatchFailure(
          mloc, "Currently requires a boxed type input" );
    mlir::Type elementType = hlfir::getFortranElementType(arrayType);
    if (!fir::isa_trivial(elementType))
      return rewriter.notifyMatchFailure(
          mloc, "Character arrays are currently not handled" );

    mlir::Location loc = mloc.getLoc();
    fir::FirOpBuilder builder{rewriter, mloc.getOperation()};
    // Temporary holding the result indices (one per input dimension).
    mlir::Value resultArr = builder.createTemporary(
        loc, fir::SequenceType::get(
                 rank, hlfir::getFortranElementType(mloc.getType())));

    // Initial reduction value: +/-inf for floats, signed int min/max for
    // integers, so any real element replaces it on the first comparison.
    auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
                        mlir::Type elementType) {
      if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
        const llvm::fltSemantics &sem = ty.getFloatSemantics();
        llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
        return builder.createRealConstant(loc, elementType, limit);
      }
      unsigned bits = elementType.getIntOrFloatBitWidth();
      int64_t limitInt =
          isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
                : llvm::APInt::getSignedMaxValue(bits).getSExtValue();
      return builder.createIntegerConstant(loc, elementType, limitInt);
    };

    // Innermost-loop body: inline the elemental mask, and under the mask
    // compare the element against the running extremum, updating resultArr
    // and the "seen a masked element" flag when it wins.
    auto genBodyOp =
        [&rank, &resultArr, &elemental, isMax](
            fir::FirOpBuilder builder, mlir::Location loc,
            mlir::Type elementType, mlir::Value array, mlir::Value flagRef,
            mlir::Value reduction,
            const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
      // We are in the innermost loop: generate the elemental inline.
      // The loop indices are zero-based; the elemental expects one-based.
      mlir::Value oneIdx =
          builder.createIntegerConstant(loc, builder.getIndexType(), 1);
      llvm::SmallVector<mlir::Value> oneBasedIndices;
      llvm::transform(
          indices, std::back_inserter(oneBasedIndices), [&](mlir::Value V) {
            return builder.create<mlir::arith::AddIOp>(loc, V, oneIdx);
          });
      hlfir::YieldElementOp yield =
          hlfir::inlineElementalOp(loc, builder, elemental, oneBasedIndices);
      mlir::Value maskElem = yield.getElementValue();
      yield->erase();

      // fir.if needs an i1 condition; convert the logical mask element.
      mlir::Type ifCompatType = builder.getI1Type();
      mlir::Value ifCompatElem =
          builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);

      llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
      fir::IfOp maskIfOp =
          builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
                                    /*withElseRegion=*/true);
      builder.setInsertionPointToStart(&maskIfOp.getThenRegion().front());

      // Set flag that mask was true at some point
      mlir::Value flagSet = builder.createIntegerConstant(
          loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
      mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
      mlir::Value addr = hlfir::getElementAt(loc, builder, hlfir::Entity{array},
                                             oneBasedIndices);
      mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);

      // Compare with the current reduction value.
      mlir::Value cmp;
      if (elementType.isa<mlir::FloatType>()) {
        // For FP reductions we want the first smallest value to be used, that
        // is not NaN. An OGT/OLT condition will usually work for this unless
        // all the values are NaN or Inf. This follows the same logic as
        // NumericCompare for Minloc/Maxloc in extrema.cpp.
        cmp = builder.create<mlir::arith::CmpFOp>(
            loc,
            isMax ? mlir::arith::CmpFPredicate::OGT
                  : mlir::arith::CmpFPredicate::OLT,
            elem, reduction);

        // Also take the element when the current reduction is NaN but the
        // element is not, so a non-NaN value displaces a NaN seed.
        mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
            loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
        mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
            loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
        cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
        cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
      } else if (elementType.isa<mlir::IntegerType>()) {
        cmp = builder.create<mlir::arith::CmpIOp>(
            loc,
            isMax ? mlir::arith::CmpIPredicate::sgt
                  : mlir::arith::CmpIPredicate::slt,
            elem, reduction);
      } else {
        llvm_unreachable("unsupported type" );
      }

      // The condition used for the loop is isFirst || <the condition above>.
      // isFirst is the negation of the flag: true until a masked element has
      // been seen (XOr with 1 inverts the i1 value).
      isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
      isFirst = builder.create<mlir::arith::XOrIOp>(
          loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
      cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);

      // Set the new coordinate to the result
      fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
                                                 /*withElseRegion*/ true);

      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
      builder.create<fir::StoreOp>(loc, flagSet, flagRef);
      mlir::Type resultElemTy =
          hlfir::getFortranElementType(resultArr.getType());
      mlir::Type returnRefTy = builder.getRefType(resultElemTy);
      mlir::IndexType idxTy = builder.getIndexType();

      // Store the (one-based) coordinates of the new extremum into resultArr.
      for (unsigned int i = 0; i < rank; ++i) {
        mlir::Value index = builder.createIntegerConstant(loc, idxTy, i + 1);
        mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>(
            loc, returnRefTy, resultArr, index);
        mlir::Value fortranIndex = builder.create<fir::ConvertOp>(
            loc, resultElemTy, oneBasedIndices[i]);
        builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
      }
      builder.create<fir::ResultOp>(loc, elem);
      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
      builder.create<fir::ResultOp>(loc, reduction);
      builder.setInsertionPointAfter(ifOp);

      // Close the mask if: yield the (possibly updated) reduction from both
      // branches.
      builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
      builder.setInsertionPointToStart(&maskIfOp.getElseRegion().front());
      builder.create<fir::ResultOp>(loc, reduction);
      builder.setInsertionPointAfter(maskIfOp);

      return maskIfOp.getResult(0);
    };
    // Address of result element `index` (zero-based in, one-based designate).
    auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
                        const mlir::Type &resultElemType, mlir::Value resultArr,
                        mlir::Value index) {
      mlir::Type resultRefTy = builder.getRefType(resultElemType);
      mlir::Value oneIdx =
          builder.createIntegerConstant(loc, builder.getIndexType(), 1);
      index = builder.create<mlir::arith::AddIOp>(loc, index, oneIdx);
      return builder.create<hlfir::DesignateOp>(loc, resultRefTy, resultArr,
                                                index);
    };

    // Initialize the result to all zeroes (the value returned when the mask
    // is never true).
    mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
    mlir::Type resultRefTy = builder.getRefType(resultElemTy);
    mlir::Value returnValue =
        builder.createIntegerConstant(loc, resultElemTy, 0);
    for (unsigned int i = 0; i < rank; ++i) {
      mlir::Value index =
          builder.createIntegerConstant(loc, builder.getIndexType(), i + 1);
      mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>(
          loc, resultRefTy, resultArr, index);
      builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
    }

    fir::genMinMaxlocReductionLoop(builder, array, init, genBodyOp, getAddrFn,
                                   rank, elementType, loc, builder.getI1Type(),
                                   resultArr, false);

    mlir::Value asExpr = builder.create<hlfir::AsExprOp>(
        loc, resultArr, builder.createBool(loc, false));

    // Check all the users - the destroy is no longer required, and any assign
    // can use resultArr directly so that VariableAssignBufferization in this
    // pass can optimize the results. Other operations are replaced with an
    // AsExpr for the temporary resultArr.
    llvm::SmallVector<hlfir::DestroyOp> destroys;
    llvm::SmallVector<hlfir::AssignOp> assigns;
    for (auto user : mloc->getUsers()) {
      if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(user))
        destroys.push_back(destroy);
      else if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user))
        assigns.push_back(assign);
    }

    // Check if the minloc/maxloc was the only user of the elemental (apart from
    // a destroy), and remove it if so.
    mlir::Operation::user_range elemUsers = elemental->getUsers();
    hlfir::DestroyOp elemDestroy;
    if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
      elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
      if (!elemDestroy)
        elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
    }

    for (auto d : destroys)
      rewriter.eraseOp(d);
    for (auto a : assigns)
      a.setOperand(0, resultArr);
    rewriter.replaceOp(mloc, asExpr);
    if (elemDestroy) {
      rewriter.eraseOp(elemDestroy);
      rewriter.eraseOp(elemental);
    }
    return mlir::success();
  }
};
1035 | |
1036 | class OptimizedBufferizationPass |
1037 | : public hlfir::impl::OptimizedBufferizationBase< |
1038 | OptimizedBufferizationPass> { |
1039 | public: |
1040 | void runOnOperation() override { |
1041 | mlir::func::FuncOp func = getOperation(); |
1042 | mlir::MLIRContext *context = &getContext(); |
1043 | |
1044 | mlir::GreedyRewriteConfig config; |
1045 | // Prevent the pattern driver from merging blocks |
1046 | config.enableRegionSimplification = false; |
1047 | |
1048 | mlir::RewritePatternSet patterns(context); |
1049 | // TODO: right now the patterns are non-conflicting, |
1050 | // but it might be better to run this pass on hlfir.assign |
1051 | // operations and decide which transformation to apply |
1052 | // at one place (e.g. we may use some heuristics and |
1053 | // choose different optimization strategies). |
1054 | // This requires small code reordering in ElementalAssignBufferization. |
1055 | patterns.insert<ElementalAssignBufferization>(context); |
1056 | patterns.insert<BroadcastAssignBufferization>(context); |
1057 | patterns.insert<VariableAssignBufferization>(context); |
1058 | patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context); |
1059 | patterns.insert<ReductionElementalConversion<hlfir::AnyOp>>(context); |
1060 | patterns.insert<ReductionElementalConversion<hlfir::AllOp>>(context); |
1061 | patterns.insert<MinMaxlocElementalConversion<hlfir::MinlocOp>>(context); |
1062 | patterns.insert<MinMaxlocElementalConversion<hlfir::MaxlocOp>>(context); |
1063 | |
1064 | if (mlir::failed(mlir::applyPatternsAndFoldGreedily( |
1065 | func, std::move(patterns), config))) { |
1066 | mlir::emitError(func.getLoc(), |
1067 | "failure in HLFIR optimized bufferization" ); |
1068 | signalPassFailure(); |
1069 | } |
1070 | } |
1071 | }; |
1072 | } // namespace |
1073 | |
1074 | std::unique_ptr<mlir::Pass> hlfir::createOptimizedBufferizationPass() { |
1075 | return std::make_unique<OptimizedBufferizationPass>(); |
1076 | } |
1077 | |