1//===- OptimizedBufferization.cpp - special cases for bufferization -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// In some special cases we can bufferize hlfir expressions in a more optimal
9// way so as to avoid creating temporaries. This pass handles these. It should
10// be run before the catch-all bufferization pass.
11//
12// This requires constant subexpression elimination to have already been run.
13//===----------------------------------------------------------------------===//
14
15#include "flang/Optimizer/Analysis/AliasAnalysis.h"
16#include "flang/Optimizer/Builder/FIRBuilder.h"
17#include "flang/Optimizer/Builder/HLFIRTools.h"
18#include "flang/Optimizer/Dialect/FIROps.h"
19#include "flang/Optimizer/Dialect/FIRType.h"
20#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
21#include "flang/Optimizer/HLFIR/HLFIROps.h"
22#include "flang/Optimizer/HLFIR/Passes.h"
23#include "flang/Optimizer/OpenMP/Passes.h"
24#include "flang/Optimizer/Transforms/Utils.h"
25#include "mlir/Dialect/Func/IR/FuncOps.h"
26#include "mlir/IR/Dominance.h"
27#include "mlir/IR/PatternMatch.h"
28#include "mlir/Interfaces/SideEffectInterfaces.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Support/LLVM.h"
31#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include <iterator>
34#include <memory>
35#include <mlir/Analysis/AliasAnalysis.h>
36#include <optional>
37
38namespace hlfir {
39#define GEN_PASS_DEF_OPTIMIZEDBUFFERIZATION
40#include "flang/Optimizer/HLFIR/Passes.h.inc"
41} // namespace hlfir
42
43#define DEBUG_TYPE "opt-bufferization"
44
45namespace {
46
47/// This transformation should match in place modification of arrays.
48/// It should match code of the form
49/// %array = some.operation // array has shape %shape
50/// %expr = hlfir.elemental %shape : [...] {
51/// bb0(%arg0: index)
52/// %0 = hlfir.designate %array(%arg0)
53/// [...] // no other reads or writes to %array
54/// hlfir.yield_element %element
55/// }
56/// hlfir.assign %expr to %array
57/// hlfir.destroy %expr
58///
59/// Or
60///
61/// %read_array = some.operation // shape %shape
62/// %expr = hlfir.elemental %shape : [...] {
63/// bb0(%arg0: index)
64/// %0 = hlfir.designate %read_array(%arg0)
65/// [...]
66/// hlfir.yield_element %element
67/// }
68/// %write_array = some.operation // with shape %shape
69/// [...] // operations which don't effect write_array
70/// hlfir.assign %expr to %write_array
71/// hlfir.destroy %expr
72///
73/// In these cases, it is safe to turn the elemental into a do loop and modify
74/// elements of %array in place without creating an extra temporary for the
75/// elemental. We must check that there are no reads from the array at indexes
76/// which might conflict with the assignment or any writes. For now we will keep
77/// that strict and say that all reads must be at the elemental index (it is
78/// probably safe to read from higher indices if lowering to an ordered loop).
class ElementalAssignBufferization
    : public mlir::OpRewritePattern<hlfir::ElementalOp> {
private:
  /// Operations connected to a successfully matched elemental:
  /// the array being modified in place, the hlfir.assign that writes the
  /// elemental's value into it, and the hlfir.destroy of the elemental.
  struct MatchInfo {
    mlir::Value array;        // LHS of the assign, modified in place
    hlfir::AssignOp assign;   // assignment of the elemental's result
    hlfir::DestroyOp destroy; // destruction of the elemental's result
  };
  /// determines if the transformation can be applied to this elemental
  static std::optional<MatchInfo> findMatch(hlfir::ElementalOp elemental);

  /// Returns the array indices for the given hlfir.designate.
  /// It recognizes the computations used to transform the one-based indices
  /// into the array's lb-based indices, and returns the one-based indices
  /// in these cases.
  static llvm::SmallVector<mlir::Value>
  getDesignatorIndices(hlfir::DesignateOp designate);

public:
  using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::ElementalOp elemental,
                  mlir::PatternRewriter &rewriter) const override;
};
104
105/// recursively collect all effects between start and end (including start, not
106/// including end) start must properly dominate end, start and end must be in
107/// the same block. If any operations with unknown effects are found,
108/// std::nullopt is returned
109static std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>>
110getEffectsBetween(mlir::Operation *start, mlir::Operation *end) {
111 mlir::SmallVector<mlir::MemoryEffects::EffectInstance> ret;
112 if (start == end)
113 return ret;
114 assert(start->getBlock() && end->getBlock() && "TODO: block arguments");
115 assert(start->getBlock() == end->getBlock());
116 assert(mlir::DominanceInfo{}.properlyDominates(start, end));
117
118 mlir::Operation *nextOp = start;
119 while (nextOp && nextOp != end) {
120 std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>>
121 effects = mlir::getEffectsRecursively(nextOp);
122 if (!effects)
123 return std::nullopt;
124 ret.append(*effects);
125 nextOp = nextOp->getNextNode();
126 }
127 return ret;
128}
129
130/// If effect is a read or write on val, return whether it aliases.
131/// Otherwise return mlir::AliasResult::NoAlias
132static mlir::AliasResult
133containsReadOrWriteEffectOn(const mlir::MemoryEffects::EffectInstance &effect,
134 mlir::Value val) {
135 fir::AliasAnalysis aliasAnalysis;
136
137 if (mlir::isa<mlir::MemoryEffects::Read, mlir::MemoryEffects::Write>(
138 effect.getEffect())) {
139 mlir::Value accessedVal = effect.getValue();
140 if (mlir::isa<fir::DebuggingResource>(effect.getResource()))
141 return mlir::AliasResult::NoAlias;
142 if (!accessedVal)
143 return mlir::AliasResult::MayAlias;
144 if (accessedVal == val)
145 return mlir::AliasResult::MustAlias;
146
147 // if the accessed value might alias val
148 mlir::AliasResult res = aliasAnalysis.alias(val, accessedVal);
149 if (!res.isNo())
150 return res;
151
152 // FIXME: alias analysis of fir.load
153 // follow this common pattern:
154 // %ref = hlfir.designate %array(%index)
155 // %val = fir.load $ref
156 if (auto designate = accessedVal.getDefiningOp<hlfir::DesignateOp>()) {
157 if (designate.getMemref() == val)
158 return mlir::AliasResult::MustAlias;
159
160 // if the designate is into an array that might alias val
161 res = aliasAnalysis.alias(val, designate.getMemref());
162 if (!res.isNo())
163 return res;
164 }
165 }
166 return mlir::AliasResult::NoAlias;
167}
168
169// Helper class for analyzing two array slices represented
170// by two hlfir.designate operations.
171class ArraySectionAnalyzer {
172public:
173 // The result of the analyzis is one of the values below.
174 enum class SlicesOverlapKind {
175 // Slices overlap is unknown.
176 Unknown,
177 // Slices are definitely identical.
178 DefinitelyIdentical,
179 // Slices are definitely disjoint.
180 DefinitelyDisjoint,
181 // Slices may be either disjoint or identical,
182 // i.e. there is definitely no partial overlap.
183 EitherIdenticalOrDisjoint
184 };
185
186 // Analyzes two hlfir.designate results and returns the overlap kind.
187 // The callers may use this method when the alias analysis reports
188 // an alias of some kind, so that we can run Fortran specific analysis
189 // on the array slices to see if they are identical or disjoint.
190 // Note that the alias analysis are not able to give such an answer
191 // about the references.
192 static SlicesOverlapKind analyze(mlir::Value ref1, mlir::Value ref2);
193
194private:
195 struct SectionDesc {
196 // An array section is described by <lb, ub, stride> tuple.
197 // If the designator's subscript is not a triple, then
198 // the section descriptor is constructed as <lb, nullptr, nullptr>.
199 mlir::Value lb, ub, stride;
200
201 SectionDesc(mlir::Value lb, mlir::Value ub, mlir::Value stride)
202 : lb(lb), ub(ub), stride(stride) {
203 assert(lb && "lower bound or index must be specified");
204 normalize();
205 }
206
207 // Normalize the section descriptor:
208 // 1. If UB is nullptr, then it is set to LB.
209 // 2. If LB==UB, then stride does not matter,
210 // so it is reset to nullptr.
211 // 3. If STRIDE==1, then it is reset to nullptr.
212 void normalize() {
213 if (!ub)
214 ub = lb;
215 if (lb == ub)
216 stride = nullptr;
217 if (stride)
218 if (auto val = fir::getIntIfConstant(stride))
219 if (*val == 1)
220 stride = nullptr;
221 }
222
223 bool operator==(const SectionDesc &other) const {
224 return lb == other.lb && ub == other.ub && stride == other.stride;
225 }
226 };
227
228 // Given an operand_iterator over the indices operands,
229 // read the subscript values and return them as SectionDesc
230 // updating the iterator. If isTriplet is true,
231 // the subscript is a triplet, and the result is <lb, ub, stride>.
232 // Otherwise, the subscript is a scalar index, and the result
233 // is <index, nullptr, nullptr>.
234 static SectionDesc readSectionDesc(mlir::Operation::operand_iterator &it,
235 bool isTriplet) {
236 if (isTriplet)
237 return {*it++, *it++, *it++};
238 return {*it++, nullptr, nullptr};
239 }
240
241 // Return the ordered lower and upper bounds of the section.
242 // If stride is known to be non-negative, then the ordered
243 // bounds match the <lb, ub> of the descriptor.
244 // If stride is known to be negative, then the ordered
245 // bounds are <ub, lb> of the descriptor.
246 // If stride is unknown, we cannot deduce any order,
247 // so the result is <nullptr, nullptr>
248 static std::pair<mlir::Value, mlir::Value>
249 getOrderedBounds(const SectionDesc &desc) {
250 mlir::Value stride = desc.stride;
251 // Null stride means stride=1.
252 if (!stride)
253 return {desc.lb, desc.ub};
254 // Reverse the bounds, if stride is negative.
255 if (auto val = fir::getIntIfConstant(stride)) {
256 if (*val >= 0)
257 return {desc.lb, desc.ub};
258 else
259 return {desc.ub, desc.lb};
260 }
261
262 return {nullptr, nullptr};
263 }
264
265 // Given two array sections <lb1, ub1, stride1> and
266 // <lb2, ub2, stride2>, return true only if the sections
267 // are known to be disjoint.
268 //
269 // For example, for any positive constant C:
270 // X:Y does not overlap with (Y+C):Z
271 // X:Y does not overlap with Z:(X-C)
272 static bool areDisjointSections(const SectionDesc &desc1,
273 const SectionDesc &desc2) {
274 auto [lb1, ub1] = getOrderedBounds(desc1);
275 auto [lb2, ub2] = getOrderedBounds(desc2);
276 if (!lb1 || !lb2)
277 return false;
278 // Note that this comparison must be made on the ordered bounds,
279 // otherwise 'a(x:y:1) = a(z:x-1:-1) + 1' may be incorrectly treated
280 // as not overlapping (x=2, y=10, z=9).
281 if (isLess(ub1, lb2) || isLess(ub2, lb1))
282 return true;
283 return false;
284 }
285
286 // Given two array sections <lb1, ub1, stride1> and
287 // <lb2, ub2, stride2>, return true only if the sections
288 // are known to be identical.
289 //
290 // For example:
291 // <x, x, stride>
292 // <x, nullptr, nullptr>
293 //
294 // These sections are identical, from the point of which array
295 // elements are being addresses, even though the shape
296 // of the array slices might be different.
297 static bool areIdenticalSections(const SectionDesc &desc1,
298 const SectionDesc &desc2) {
299 if (desc1 == desc2)
300 return true;
301 return false;
302 }
303
304 // Return true, if v1 is known to be less than v2.
305 static bool isLess(mlir::Value v1, mlir::Value v2);
306};
307
ArraySectionAnalyzer::SlicesOverlapKind
ArraySectionAnalyzer::analyze(mlir::Value ref1, mlir::Value ref2) {
  // The same SSA value trivially designates the same elements.
  if (ref1 == ref2)
    return SlicesOverlapKind::DefinitelyIdentical;

  auto des1 = ref1.getDefiningOp<hlfir::DesignateOp>();
  auto des2 = ref2.getDefiningOp<hlfir::DesignateOp>();
  // We only support a pair of designators right now.
  if (!des1 || !des2)
    return SlicesOverlapKind::Unknown;

  if (des1.getMemref() != des2.getMemref()) {
    // If the bases are different, then there is unknown overlap.
    LLVM_DEBUG(llvm::dbgs() << "No identical base for:\n"
                            << des1 << "and:\n"
                            << des2 << "\n");
    return SlicesOverlapKind::Unknown;
  }

  // Require all components of the designators to be the same.
  // It might be too strict, e.g. we may probably allow for
  // different type parameters.
  if (des1.getComponent() != des2.getComponent() ||
      des1.getComponentShape() != des2.getComponentShape() ||
      des1.getSubstring() != des2.getSubstring() ||
      des1.getComplexPart() != des2.getComplexPart() ||
      des1.getTypeparams() != des2.getTypeparams()) {
    LLVM_DEBUG(llvm::dbgs() << "Different designator specs for:\n"
                            << des1 << "and:\n"
                            << des2 << "\n");
    return SlicesOverlapKind::Unknown;
  }

  // Analyze the subscripts dimension by dimension. readSectionDesc advances
  // the operand iterators by one or three operands depending on whether the
  // corresponding subscript is a scalar index or a triplet.
  auto des1It = des1.getIndices().begin();
  auto des2It = des2.getIndices().begin();
  bool identicalTriplets = true;
  bool identicalIndices = true;
  for (auto [isTriplet1, isTriplet2] :
       llvm::zip(des1.getIsTriplet(), des2.getIsTriplet())) {
    SectionDesc desc1 = readSectionDesc(des1It, isTriplet1);
    SectionDesc desc2 = readSectionDesc(des2It, isTriplet2);

    // See if we can prove that any of the sections do not overlap.
    // This is mostly a Polyhedron/nf performance hack that looks for
    // particular relations between the lower and upper bounds
    // of the array sections, e.g. for any positive constant C:
    //   X:Y does not overlap with (Y+C):Z
    //   X:Y does not overlap with Z:(X-C)
    if (areDisjointSections(desc1, desc2))
      return SlicesOverlapKind::DefinitelyDisjoint;

    if (!areIdenticalSections(desc1, desc2)) {
      if (isTriplet1 || isTriplet2) {
        // For example:
        //   hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0)
        //   hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1)
        //
        // If all the triplets (section specifiers) are the same, then
        // we do not care if %0 is equal to %1 - the slices are either
        // identical or completely disjoint.
        //
        // Also, treat these as identical sections:
        //   hlfir.designate %6#0 (%c2:%c2:%c1)
        //   hlfir.designate %6#0 (%c2)
        identicalTriplets = false;
        LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n"
                                << des1 << "and:\n"
                                << des2 << "\n");
      } else {
        identicalIndices = false;
        LLVM_DEBUG(llvm::dbgs() << "Indices mismatch for:\n"
                                << des1 << "and:\n"
                                << des2 << "\n");
      }
    }
  }

  if (identicalTriplets) {
    if (identicalIndices)
      return SlicesOverlapKind::DefinitelyIdentical;
    else
      return SlicesOverlapKind::EitherIdenticalOrDisjoint;
  }

  LLVM_DEBUG(llvm::dbgs() << "Different sections for:\n"
                          << des1 << "and:\n"
                          << des2 << "\n");
  return SlicesOverlapKind::Unknown;
}
398
399bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) {
400 auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
401 auto *op = v.getDefiningOp();
402 while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
403 op = conv.getValue().getDefiningOp();
404 return op;
405 };
406
407 auto isPositiveConstant = [](mlir::Value v) -> bool {
408 if (auto val = fir::getIntIfConstant(v))
409 return *val > 0;
410 return false;
411 };
412
413 auto *op1 = removeConvert(v1);
414 auto *op2 = removeConvert(v2);
415 if (!op1 || !op2)
416 return false;
417
418 // Check if they are both constants.
419 if (auto val1 = fir::getIntIfConstant(op1->getResult(0)))
420 if (auto val2 = fir::getIntIfConstant(op2->getResult(0)))
421 return *val1 < *val2;
422
423 // Handle some variable cases (C > 0):
424 // v2 = v1 + C
425 // v2 = C + v1
426 // v1 = v2 - C
427 if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
428 if ((addi.getLhs().getDefiningOp() == op1 &&
429 isPositiveConstant(addi.getRhs())) ||
430 (addi.getRhs().getDefiningOp() == op1 &&
431 isPositiveConstant(addi.getLhs())))
432 return true;
433 if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
434 if (subi.getLhs().getDefiningOp() == op2 &&
435 isPositiveConstant(subi.getRhs()))
436 return true;
437 return false;
438}
439
llvm::SmallVector<mlir::Value>
ElementalAssignBufferization::getDesignatorIndices(
    hlfir::DesignateOp designate) {
  mlir::Value memref = designate.getMemref();

  // If the object is a box, then the indices may be adjusted
  // according to the box's lower bound(s). Scan through
  // the computations to try to find the one-based indices.
  if (mlir::isa<fir::BaseBoxType>(memref.getType())) {
    // Look for the following pattern:
    //   %13 = fir.load %12 : !fir.ref<!fir.box<...>
    //   %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ...
    //   %17 = arith.subi %14#0, %c1 : index
    //   %18 = arith.addi %arg2, %17 : index
    //   %19 = hlfir.designate %13 (%18)  : (!fir.box<...>, index) -> ...
    //
    // %arg2 is a one-based index.

    auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) {
      // Return true, if v and dim are such that:
      //   %14:3 = fir.box_dims %13, %dim : (!fir.box<...>, index) -> ...
      //   %17 = arith.subi %14#0, %c1 : index
      //   %19 = hlfir.designate %13 (...) : (!fir.box<...>, index) -> ...
      if (auto subOp =
              mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp())) {
        // The subtrahend must be the constant 1 (lb - 1).
        auto cst = fir::getIntIfConstant(subOp.getRhs());
        if (!cst || *cst != 1)
          return false;
        if (auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>(
                subOp.getLhs().getDefiningOp())) {
          // The box_dims must query this designator's memref, and its
          // first result (the lower bound) must feed the subtraction.
          if (memref != dimsOp.getVal() ||
              dimsOp.getResult(0) != subOp.getLhs())
            return false;
          // The queried dimension must be the constant `dim`.
          auto dimsOpDim = fir::getIntIfConstant(dimsOp.getDim());
          return dimsOpDim && dimsOpDim == dim;
        }
      }
      return false;
    };

    llvm::SmallVector<mlir::Value> newIndices;
    for (auto index : llvm::enumerate(designate.getIndices())) {
      if (auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>(
              index.value().getDefiningOp())) {
        // Either addend of the addition may be the normalized lower bound;
        // the other addend is then the one-based index we are looking for.
        for (unsigned opNum = 0; opNum < 2; ++opNum)
          if (isNormalizedLb(addOp->getOperand(opNum), index.index())) {
            newIndices.push_back(addOp->getOperand((opNum + 1) % 2));
            break;
          }

        // If new one-based index was not added, exit early.
        if (newIndices.size() <= index.index())
          break;
      }
    }

    // If any of the indices is not adjusted to the array's lb,
    // then return the original designator indices.
    if (newIndices.size() != designate.getIndices().size())
      return designate.getIndices();

    return newIndices;
  }

  return designate.getIndices();
}
506
std::optional<ElementalAssignBufferization::MatchInfo>
ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
  mlir::Operation::user_range users = elemental->getUsers();
  // the only uses of the elemental should be the assignment and the destroy
  if (std::distance(users.begin(), users.end()) != 2) {
    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the elemental\n");
    return std::nullopt;
  }

  // If the ElementalOp must produce a temporary (e.g. for
  // finalization purposes), then we cannot inline it.
  if (hlfir::elementalOpMustProduceTemp(elemental)) {
    LLVM_DEBUG(llvm::dbgs() << "ElementalOp must produce a temp\n");
    return std::nullopt;
  }

  MatchInfo match;
  // Classify the two users: one must be an hlfir.assign and the other
  // an hlfir.destroy, in either order.
  for (mlir::Operation *user : users)
    mlir::TypeSwitch<mlir::Operation *, void>(user)
        .Case([&](hlfir::AssignOp op) { match.assign = op; })
        .Case([&](hlfir::DestroyOp op) { match.destroy = op; });

  if (!match.assign || !match.destroy) {
    LLVM_DEBUG(llvm::dbgs() << "Couldn't find assign or destroy\n");
    return std::nullopt;
  }

  // the array is what the elemental is assigned into
  // TODO: this could be extended to also allow hlfir.expr by first bufferizing
  // the incoming expression
  match.array = match.assign.getLhs();
  mlir::Type arrayType = mlir::dyn_cast<fir::SequenceType>(
      fir::unwrapPassByRefType(match.array.getType()));
  if (!arrayType) {
    LLVM_DEBUG(llvm::dbgs() << "AssignOp's result is not an array\n");
    return std::nullopt;
  }

  // require that the array elements are trivial
  // TODO: this is just to make the pass easier to think about. Not an inherent
  // limitation
  mlir::Type eleTy = hlfir::getFortranElementType(arrayType);
  if (!fir::isa_trivial(eleTy)) {
    LLVM_DEBUG(llvm::dbgs() << "AssignOp's data type is not trivial\n");
    return std::nullopt;
  }

  // The array must have the same shape as the elemental.
  //
  // f2018 10.2.1.2 (3) requires the lhs and rhs of an assignment to be
  // conformable unless the lhs is an allocatable array. In HLFIR we can
  // see this from the presence or absence of the realloc attribute on
  // hlfir.assign. If it is not a realloc assignment, we can trust that
  // the shapes do conform.
  //
  // TODO: the lhs's shape is dynamic, so it is hard to prove that
  // there is no reallocation of the lhs due to the assignment.
  // We can probably try generating multiple versions of the code
  // with checking for the shape match, length parameters match, etc.
  if (match.assign.isAllocatableAssignment()) {
    LLVM_DEBUG(llvm::dbgs() << "AssignOp may involve (re)allocation of LHS\n");
    return std::nullopt;
  }

  // the transformation wants to apply the elemental in a do-loop at the
  // hlfir.assign, check there are no effects which make this unsafe

  // keep track of any values written to in the elemental, as these can't be
  // read from or written to between the elemental and the assignment
  mlir::SmallVector<mlir::Value, 1> notToBeAccessedBeforeAssign;
  // likewise, values read in the elemental cannot be written to between the
  // elemental and the assign
  mlir::SmallVector<mlir::Value, 1> notToBeWrittenBeforeAssign;

  // 1) side effects in the elemental body - it isn't sufficient to just look
  // for ordered elementals because we also cannot support out of order reads
  std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>>
      effects = getEffectsBetween(&elemental.getBody()->front(),
                                  elemental.getBody()->getTerminator());
  if (!effects) {
    LLVM_DEBUG(llvm::dbgs()
               << "operation with unknown effects inside elemental\n");
    return std::nullopt;
  }
  for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
    mlir::AliasResult res = containsReadOrWriteEffectOn(effect, match.array);
    if (res.isNo()) {
      // The effect does not touch match.array, but record what it does
      // touch so step 2) can check for conflicts before the assignment.
      if (effect.getValue()) {
        if (mlir::isa<mlir::MemoryEffects::Write>(effect.getEffect()))
          notToBeAccessedBeforeAssign.push_back(effect.getValue());
        else if (mlir::isa<mlir::MemoryEffects::Read>(effect.getEffect()))
          notToBeWrittenBeforeAssign.push_back(effect.getValue());
      }

      // this is safe in the elemental
      continue;
    }

    // don't allow any aliasing writes in the elemental
    if (mlir::isa<mlir::MemoryEffects::Write>(effect.getEffect())) {
      LLVM_DEBUG(llvm::dbgs() << "write inside the elemental body\n");
      return std::nullopt;
    }

    if (effect.getValue() == nullptr) {
      LLVM_DEBUG(llvm::dbgs()
                 << "side-effect with no value, cannot analyze further\n");
      return std::nullopt;
    }

    // allow if and only if the reads are from the elemental indices, in order
    // => each iteration doesn't read values written by other iterations
    // don't allow reads from a different value which may alias: fir alias
    // analysis isn't precise enough to tell us if two aliasing arrays overlap
    // exactly or only partially. If they overlap partially, a designate at the
    // elemental indices could be accessing different elements: e.g. we could
    // designate two slices of the same array at different start indexes. These
    // two MustAlias but index 1 of one array isn't the same element as index 1
    // of the other array.
    if (!res.isPartial()) {
      if (auto designate =
              effect.getValue().getDefiningOp<hlfir::DesignateOp>()) {
        ArraySectionAnalyzer::SlicesOverlapKind overlap =
            ArraySectionAnalyzer::analyze(match.array, designate.getMemref());
        if (overlap ==
            ArraySectionAnalyzer::SlicesOverlapKind::DefinitelyDisjoint)
          continue;

        if (overlap == ArraySectionAnalyzer::SlicesOverlapKind::Unknown) {
          LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
                                  << " at " << elemental.getLoc() << "\n");
          return std::nullopt;
        }
        // The slices are either identical or disjoint, so the read is safe
        // iff it happens exactly at the current elemental indices.
        auto indices = getDesignatorIndices(designate);
        auto elementalIndices = elemental.getIndices();
        if (indices.size() == elementalIndices.size() &&
            std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
                       elementalIndices.end()))
          continue;

        LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
                                << " at " << elemental.getLoc() << "\n");
        return std::nullopt;
      }
    }
    LLVM_DEBUG(llvm::dbgs() << "disallowed side-effect: " << effect.getValue()
                            << " for " << elemental.getLoc() << "\n");
    return std::nullopt;
  }

  // 2) look for conflicting effects between the elemental and the assignment
  effects = getEffectsBetween(elemental->getNextNode(), match.assign);
  if (!effects) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "operation with unknown effects between elemental and assign\n");
    return std::nullopt;
  }
  for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
    // not safe to access anything written in the elemental as this write
    // will be moved to the assignment
    for (mlir::Value val : notToBeAccessedBeforeAssign) {
      mlir::AliasResult res = containsReadOrWriteEffectOn(effect, val);
      if (!res.isNo()) {
        LLVM_DEBUG(llvm::dbgs()
                   << "disallowed side-effect: " << effect.getValue() << " for "
                   << elemental.getLoc() << "\n");
        return std::nullopt;
      }
    }
    // Anything that is read inside the elemental can only be safely read
    // between the elemental and the assignment.
    for (mlir::Value val : notToBeWrittenBeforeAssign) {
      mlir::AliasResult res = containsReadOrWriteEffectOn(effect, val);
      if (!res.isNo() &&
          !mlir::isa<mlir::MemoryEffects::Read>(effect.getEffect())) {
        LLVM_DEBUG(llvm::dbgs()
                   << "disallowed non-read side-effect: " << effect.getValue()
                   << " for " << elemental.getLoc() << "\n");
        return std::nullopt;
      }
    }
  }

  return match;
}
693
llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
    hlfir::ElementalOp elemental, mlir::PatternRewriter &rewriter) const {
  std::optional<MatchInfo> match = findMatch(elemental);
  if (!match)
    return rewriter.notifyMatchFailure(
        elemental, "cannot prove safety of ElementalAssignBufferization");

  mlir::Location loc = elemental->getLoc();
  fir::FirOpBuilder builder(rewriter, elemental.getOperation());
  auto rhsExtents = hlfir::getIndexExtents(loc, builder, elemental.getShape());

  // create the loop at the assignment
  builder.setInsertionPoint(match->assign);
  hlfir::Entity lhs{match->array};
  lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
  mlir::Value lhsShape = hlfir::genShape(loc, builder, lhs);
  llvm::SmallVector<mlir::Value> lhsExtents =
      hlfir::getIndexExtents(loc, builder, lhsShape);
  // Combine the RHS and LHS extents into the loop extents (presumably
  // picking the better-known extent per dimension — see
  // fir::factory::deduceOptimalExtents for the exact rule).
  llvm::SmallVector<mlir::Value> extents =
      fir::factory::deduceOptimalExtents(rhsExtents, lhsExtents);

  // Generate a loop nest looping around the hlfir.elemental shape and clone
  // hlfir.elemental region inside the inner loop
  hlfir::LoopNest loopNest =
      hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered(),
                         flangomp::shouldUseWorkshareLowering(elemental));
  builder.setInsertionPointToStart(loopNest.body);
  auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
                                        loopNest.oneBasedIndices);
  hlfir::Entity elementValue{yield.getElementValue()};
  // The yield terminator is not needed once the region is inlined.
  rewriter.eraseOp(yield);

  // Assign the element value to the array element for this iteration.
  auto arrayElement =
      hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
  builder.create<hlfir::AssignOp>(
      loc, elementValue, arrayElement, /*realloc=*/false,
      /*keep_lhs_length_if_realloc=*/false, match->assign.getTemporaryLhs());

  // The elemental and its two users have been fully replaced by the loop
  // nest, so they can all be removed.
  rewriter.eraseOp(match->assign);
  rewriter.eraseOp(match->destroy);
  rewriter.eraseOp(elemental);
  return mlir::success();
}
738
739/// Expand hlfir.assign of a scalar RHS to array LHS into a loop nest
740/// of element-by-element assignments:
741/// hlfir.assign %cst to %0 : f32, !fir.ref<!fir.array<6x6xf32>>
742/// into:
743/// fir.do_loop %arg0 = %c1 to %c6 step %c1 unordered {
744/// fir.do_loop %arg1 = %c1 to %c6 step %c1 unordered {
745/// %1 = hlfir.designate %0 (%arg1, %arg0) :
746/// (!fir.ref<!fir.array<6x6xf32>>, index, index) -> !fir.ref<f32>
747/// hlfir.assign %cst to %1 : f32, !fir.ref<f32>
748/// }
749/// }
750class BroadcastAssignBufferization
751 : public mlir::OpRewritePattern<hlfir::AssignOp> {
752private:
753public:
754 using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern;
755
756 llvm::LogicalResult
757 matchAndRewrite(hlfir::AssignOp assign,
758 mlir::PatternRewriter &rewriter) const override;
759};
760
761llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
762 hlfir::AssignOp assign, mlir::PatternRewriter &rewriter) const {
763 // Since RHS is a scalar and LHS is an array, LHS must be allocated
764 // in a conforming Fortran program, and LHS cannot be reallocated
765 // as a result of the assignment. So we can ignore isAllocatableAssignment
766 // and do the transformation always.
767 mlir::Value rhs = assign.getRhs();
768 if (!fir::isa_trivial(rhs.getType()))
769 return rewriter.notifyMatchFailure(
770 assign, "AssignOp's RHS is not a trivial scalar");
771
772 hlfir::Entity lhs{assign.getLhs()};
773 if (!lhs.isArray())
774 return rewriter.notifyMatchFailure(assign,
775 "AssignOp's LHS is not an array");
776
777 mlir::Type eleTy = lhs.getFortranElementType();
778 if (!fir::isa_trivial(eleTy))
779 return rewriter.notifyMatchFailure(
780 assign, "AssignOp's LHS data type is not trivial");
781
782 mlir::Location loc = assign->getLoc();
783 fir::FirOpBuilder builder(rewriter, assign.getOperation());
784 builder.setInsertionPoint(assign);
785 lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
786 mlir::Value shape = hlfir::genShape(loc, builder, lhs);
787 llvm::SmallVector<mlir::Value> extents =
788 hlfir::getIndexExtents(loc, builder, shape);
789 hlfir::LoopNest loopNest =
790 hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
791 flangomp::shouldUseWorkshareLowering(assign));
792 builder.setInsertionPointToStart(loopNest.body);
793 auto arrayElement =
794 hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
795 builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
796 rewriter.eraseOp(assign);
797 return mlir::success();
798}
799
/// Rewrite pattern for hlfir.eval_in_mem. Presumably tries to evaluate the
/// expression directly into the LHS of a subsequent hlfir.assign instead of
/// a temporary (see tryUsingAssignLhsDirectly below) — confirm against the
/// matchAndRewrite definition.
class EvaluateIntoMemoryAssignBufferization
    : public mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp> {

public:
  using mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::EvaluateInMemoryOp,
                  mlir::PatternRewriter &rewriter) const override;
};
810
/// Try to evaluate an hlfir.eval_in_mem directly into the memory of the LHS
/// of the hlfir.assign that consumes its result, instead of into a temporary.
/// On success, the eval_in_mem region is moved to the assignment point, the
/// assign/destroy/eval_in_mem operations are erased, and success is returned.
/// Otherwise the IR is left untouched and failure is returned.
static llvm::LogicalResult
tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem,
                          mlir::PatternRewriter &rewriter) {
  mlir::Location loc = evalInMem.getLoc();
  hlfir::DestroyOp destroy;
  hlfir::AssignOp assign;
  // Scan the users of the eval_in_mem result, looking for the hlfir.assign
  // and hlfir.destroy that consume it; bail out when there are too many
  // users for this simple pattern.
  for (auto user : llvm::enumerate(evalInMem->getUsers())) {
    if (user.index() > 2)
      return mlir::failure();
    mlir::TypeSwitch<mlir::Operation *, void>(user.value())
        .Case([&](hlfir::AssignOp op) { assign = op; })
        .Case([&](hlfir::DestroyOp op) { destroy = op; });
  }
  // Both an assign and a destroy must have been found, the expression must
  // not require finalization, and whole-allocatable assignments (which may
  // reallocate the LHS) cannot reuse the LHS storage.
  if (!assign || !destroy || destroy.mustFinalizeExpr() ||
      assign.isAllocatableAssignment())
    return mlir::failure();

  hlfir::Entity lhs{assign.getLhs()};
  // EvaluateInMemoryOp memory is contiguous, so in general, it can only be
  // replaced by the LHS if the LHS is contiguous.
  if (!lhs.isSimplyContiguous())
    return mlir::failure();
  // Character assignment may involve truncation/padding, so the LHS
  // cannot be used to evaluate RHS in place without proving the LHS and
  // RHS lengths are the same.
  if (lhs.isCharacter())
    return mlir::failure();
  fir::AliasAnalysis aliasAnalysis;
  // The region must not read or write the LHS.
  // Note that getModRef is used instead of mlir::MemoryEffects because
  // EvaluateInMemoryOp is typically expected to hold fir.calls and that
  // Fortran calls cannot be modeled in a useful way with mlir::MemoryEffects:
  // it is hard/impossible to list all the read/written SSA values in a call,
  // but it is often possible to tell that an SSA value cannot be accessed,
  // hence getModRef is needed here and below. Also note that getModRef uses
  // mlir::MemoryEffects for operations that do not have special handling in
  // getModRef.
  if (aliasAnalysis.getModRef(evalInMem.getBody(), lhs).isModOrRef())
    return mlir::failure();
  // Any variables affected between the hlfir.evalInMem and assignment must not
  // be read or written inside the region since it will be moved at the
  // assignment insertion point.
  auto effects = getEffectsBetween(evalInMem->getNextNode(), assign);
  if (!effects) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "operation with unknown effects between eval_in_mem and assign\n");
    return mlir::failure();
  }
  for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
    mlir::Value affected = effect.getValue();
    // Conservatively give up if the affected value is unknown, or if it may
    // be read or written by the eval_in_mem region.
    if (!affected ||
        aliasAnalysis.getModRef(evalInMem.getBody(), affected).isModOrRef())
      return mlir::failure();
  }

  // All checks passed: evaluate the expression directly into the LHS storage
  // at the assignment point and erase the now-redundant operations.
  rewriter.setInsertionPoint(assign);
  fir::FirOpBuilder builder(rewriter, evalInMem.getOperation());
  mlir::Value rawLhs = hlfir::genVariableRawAddress(loc, builder, lhs);
  hlfir::computeEvaluateOpIn(loc, builder, evalInMem, rawLhs);
  rewriter.eraseOp(assign);
  rewriter.eraseOp(destroy);
  rewriter.eraseOp(evalInMem);
  return mlir::success();
}
876
877llvm::LogicalResult EvaluateIntoMemoryAssignBufferization::matchAndRewrite(
878 hlfir::EvaluateInMemoryOp evalInMem,
879 mlir::PatternRewriter &rewriter) const {
880 if (mlir::succeeded(tryUsingAssignLhsDirectly(evalInMem, rewriter)))
881 return mlir::success();
882 // Rewrite to temp + as_expr here so that the assign + as_expr pattern can
883 // kick-in for simple types and at least implement the assignment inline
884 // instead of call Assign runtime.
885 fir::FirOpBuilder builder(rewriter, evalInMem.getOperation());
886 mlir::Location loc = evalInMem.getLoc();
887 auto [temp, isHeapAllocated] = hlfir::computeEvaluateOpInNewTemp(
888 loc, builder, evalInMem, evalInMem.getShape(), evalInMem.getTypeparams());
889 rewriter.replaceOpWithNewOp<hlfir::AsExprOp>(
890 evalInMem, temp, /*mustFree=*/builder.createBool(loc, isHeapAllocated));
891 return mlir::success();
892}
893
894class OptimizedBufferizationPass
895 : public hlfir::impl::OptimizedBufferizationBase<
896 OptimizedBufferizationPass> {
897public:
898 void runOnOperation() override {
899 mlir::MLIRContext *context = &getContext();
900
901 mlir::GreedyRewriteConfig config;
902 // Prevent the pattern driver from merging blocks
903 config.setRegionSimplificationLevel(
904 mlir::GreedySimplifyRegionLevel::Disabled);
905
906 mlir::RewritePatternSet patterns(context);
907 // TODO: right now the patterns are non-conflicting,
908 // but it might be better to run this pass on hlfir.assign
909 // operations and decide which transformation to apply
910 // at one place (e.g. we may use some heuristics and
911 // choose different optimization strategies).
912 // This requires small code reordering in ElementalAssignBufferization.
913 patterns.insert<ElementalAssignBufferization>(context);
914 patterns.insert<BroadcastAssignBufferization>(context);
915 patterns.insert<EvaluateIntoMemoryAssignBufferization>(context);
916
917 if (mlir::failed(mlir::applyPatternsGreedily(
918 getOperation(), std::move(patterns), config))) {
919 mlir::emitError(getOperation()->getLoc(),
920 "failure in HLFIR optimized bufferization");
921 signalPassFailure();
922 }
923 }
924};
925} // namespace
926

// source code of flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp