//===- LowerHLFIROrderedAssignments.cpp - Lower HLFIR ordered assignments -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file defines a pass to lower HLFIR ordered assignments.
// Ordered assignments are all the operations with the
// OrderedAssignmentTreeOpInterface that implement user-defined assignments,
// assignments to vector subscripted entities, and assignments inside forall
// and where constructs.
// The pass lowers these operations to regular hlfir.assign operations and
// loops and, if needed, introduces temporary storage to fulfill Fortran
// semantics.
//
// For each rewrite, an analysis builds an evaluation schedule, and then the
// new code is generated by following the evaluation schedule.
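//
// As an illustration (a sketch of the intent, not the exact generated
// schedule), a FORALL assignment such as:
//   forall (i=1:n) x(i) = x(n+1-i)
// must behave as if all right-hand sides were evaluated before any left-hand
// side is assigned, so a typical schedule consists of a first run saving the
// x(n+1-i) values into a temporary, and a second run assigning the saved
// values to x(i).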
//===----------------------------------------------------------------------===//

#include "ScheduleOrderedAssignments.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/TemporaryStorage.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <unordered_set>

namespace hlfir {
#define GEN_PASS_DEF_LOWERHLFIRORDEREDASSIGNMENTS
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir

#define DEBUG_TYPE "flang-ordered-assignment"

// Test option to run only the scheduling part (operations are erased without
// codegen). Its only goal is to allow printing and testing the scheduling
// debug info.
static llvm::cl::opt<bool> dbgScheduleOnly(
    "flang-dbg-order-assignment-schedule-only",
    llvm::cl::desc("Only run ordered assignment scheduling with no codegen"),
    llvm::cl::init(false));
namespace {

/// Structure that represents a masked expression being lowered. Masked
/// expressions are any expressions inside an hlfir.where. As described in
/// Fortran 2018 section 10.2.3.2, the evaluation of the elemental parts of
/// such expressions must be masked, while the evaluation of the non-elemental
/// parts must not be masked. This structure analyzes the region evaluating
/// the expression and allows splitting the generation of the non-elemental
/// part from the elemental part.
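/// For instance (an illustration of this splitting), in
/// "where (mask) x = a + sum(b)", the elemental addition "a + sum(b)" must
/// only be evaluated for the elements where "mask" is true, while the
/// transformational "sum(b)" is not elemental and is evaluated once,
/// unmasked, before the where loops.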
struct MaskedArrayExpr {
  MaskedArrayExpr(mlir::Location loc, mlir::Region &region,
                  bool isOuterMaskExpr);

  /// Generate the non-elemental part. Must be called outside of the
  /// loops created for the WHERE construct.
  void generateNoneElementalPart(fir::FirOpBuilder &builder,
                                 mlir::IRMapping &mapper);

  /// Methods below can only be called once generateNoneElementalPart has been
  /// called.

  /// Return the shape of the expression.
  mlir::Value generateShape(fir::FirOpBuilder &builder,
                            mlir::IRMapping &mapper);
  /// Return the value of an element of this expression given the current
  /// where loop indices.
  mlir::Value generateElementalParts(fir::FirOpBuilder &builder,
                                     mlir::ValueRange oneBasedIndices,
                                     mlir::IRMapping &mapper);
  /// Generate the cleanup for the non-elemental parts, if any. This must be
  /// called after the loops created for the WHERE construct.
  void generateNoneElementalCleanupIfAny(fir::FirOpBuilder &builder,
                                         mlir::IRMapping &mapper);

  /// Helper to clone the clean-ups of the masked expr region terminator.
  /// This is called outside of the loops for the initial mask, and inside
  /// the loops for the other masked expressions.
  mlir::Operation *generateMaskedExprCleanUps(fir::FirOpBuilder &builder,
                                              mlir::IRMapping &mapper);

  mlir::Location loc;
  mlir::Region &region;
  /// Set of operations that form the elemental parts of the
  /// expression evaluation. These are the hlfir.elemental and
  /// hlfir.elemental_addr that form the elemental tree producing
  /// the expression value. hlfir.elemental operations that produce values
  /// used inside transformational operations are not part of this set.
  llvm::SmallSet<mlir::Operation *, 4> elementalParts{};
  /// Was generateNoneElementalPart called?
  bool noneElementalPartWasGenerated = false;
  /// Is this expression the mask expression of the outer where statement?
  /// It is special because its evaluation is not masked by anything yet.
  bool isOuterMaskExpr = false;
};
} // namespace

namespace {
/// Structure that visits an ordered assignment tree and generates code for
/// it according to a schedule.
class OrderedAssignmentRewriter {
public:
  OrderedAssignmentRewriter(fir::FirOpBuilder &builder,
                            hlfir::OrderedAssignmentTreeOpInterface root)
      : builder{builder}, root{root} {}

  /// Generate code for the current run of the schedule.
  void lowerRun(hlfir::Run &run) {
    currentRun = &run;
    walk(root);
    currentRun = nullptr;
    assert(constructStack.empty() && "must exit constructs after a run");
    mapper.clear();
    savedInCurrentRunBeforeUse.clear();
  }

  /// After all runs have been lowered, clean up all the temporary storage
  /// that was created (do not call final routines).
  void cleanupSavedEntities() {
    for (auto &temp : savedEntities)
      temp.second.destroy(root.getLoc(), builder);
  }

  /// Lowered value for an expression, and the original hlfir.yield if any
  /// clean-up needs to be cloned after usage.
  using ValueAndCleanUp = std::pair<mlir::Value, std::optional<hlfir::YieldOp>>;

private:
  /// Walk the part of an ordered assignment tree node that needs
  /// to be evaluated in the current run.
  void walk(hlfir::OrderedAssignmentTreeOpInterface node);

  /// Generate code when entering a given ordered assignment node.
  void pre(hlfir::ForallOp forallOp);
  void pre(hlfir::ForallIndexOp);
  void pre(hlfir::ForallMaskOp);
  void pre(hlfir::WhereOp whereOp);
  void pre(hlfir::ElseWhereOp elseWhereOp);
  void pre(hlfir::RegionAssignOp);

  /// Generate code when leaving a given ordered assignment node.
  void post(hlfir::ForallOp);
  void post(hlfir::ForallMaskOp);
  void post(hlfir::WhereOp);
  void post(hlfir::ElseWhereOp);
  /// Enter (and maybe create) the fir.if else block of an ElseWhereOp,
  /// but do not generate the elsewhere mask or the new fir.if.
  void enterElsewhere(hlfir::ElseWhereOp);

  /// Are there any leaf regions in the node that must be saved in the current
  /// run?
  bool mustSaveRegionIn(
      hlfir::OrderedAssignmentTreeOpInterface node,
      llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const;
  /// Should this node be evaluated in the current run? Saving a region in a
  /// node does not imply the node needs to be evaluated.
  bool
  isRequiredInCurrentRun(hlfir::OrderedAssignmentTreeOpInterface node) const;

  /// Generate a scalar value yielded by an ordered assignment tree region.
  /// If the value was not saved in a previous run, this clones the region
  /// code, except the final yield, at the current execution point.
  /// If the value was saved in a previous run, this fetches the saved value
  /// from the temporary storage and returns the value.
  /// Inside Forall, the value will be hoisted outside of the forall loops if
  /// it does not depend on the forall indices.
  /// An optional type can be provided to get a value of a specific type
  /// (the cast will be hoisted if the computation is hoisted).
  mlir::Value generateYieldedScalarValue(
      mlir::Region &region,
      std::optional<mlir::Type> castToType = std::nullopt);

  /// Generate an entity yielded by an ordered assignment tree region, and
  /// optionally return the (uncloned) yield if there is any clean-up that
  /// should be done after using the entity. Like generateYieldedScalarValue,
  /// this will return the saved value if the region was saved in a previous
  /// run.
  ValueAndCleanUp
  generateYieldedEntity(mlir::Region &region,
                        std::optional<mlir::Type> castToType = std::nullopt);

  struct LhsValueAndCleanUp {
    mlir::Value lhs;
    std::optional<hlfir::YieldOp> elementalCleanup;
    mlir::Region *nonElementalCleanup = nullptr;
    std::optional<hlfir::LoopNest> vectorSubscriptLoopNest;
    std::optional<mlir::Value> vectorSubscriptShape;
  };

  /// Generate the left-hand side. If the left-hand side is vector
  /// subscripted (hlfir.elemental_addr), this will create a loop nest
  /// (unless it was already created by a WHERE mask) and return the
  /// element address.
  LhsValueAndCleanUp
  generateYieldedLHS(mlir::Location loc, mlir::Region &lhsRegion,
                     std::optional<hlfir::Entity> loweredRhs = std::nullopt);

  /// If \p maybeYield is present and has a clean-up, generate the clean-up
  /// at the current insertion point (by cloning).
  void generateCleanupIfAny(std::optional<hlfir::YieldOp> maybeYield);
  void generateCleanupIfAny(mlir::Region *cleanupRegion);

  /// Generate a masked entity. This can only be called when whereLoopNest was
  /// set (when an hlfir.where is being visited).
  /// This method returns the scalar element (that may have been previously
  /// saved) for the current indices inside the where loop.
  mlir::Value generateMaskedEntity(mlir::Location loc, mlir::Region &region) {
    MaskedArrayExpr maskedExpr(loc, region, /*isOuterMaskExpr=*/!whereLoopNest);
    return generateMaskedEntity(maskedExpr);
  }
  mlir::Value generateMaskedEntity(MaskedArrayExpr &maskedExpr);

  /// Create a fir.if at the current position inside the where loop nest
  /// given the element value of a mask.
  void generateMaskIfOp(mlir::Value cdt);

  /// Save a value for subsequent runs.
  void generateSaveEntity(hlfir::SaveEntity savedEntity,
                          bool willUseSavedEntityInSameRun);
  /// Save a variable address instead of its value.
  void saveNonVectorSubscriptedAddress(hlfir::SaveEntity savedEntity);
  /// Save a LHS variable address instead of its value, handling the cases
  /// where the LHS is vector subscripted.
  void saveLeftHandSide(hlfir::SaveEntity savedEntity,
                        hlfir::RegionAssignOp regionAssignOp);

  /// Get a value if it was saved in this run or a previous run. Returns
  /// nullopt if it has not been saved.
  std::optional<ValueAndCleanUp> getIfSaved(mlir::Region &region);

  /// Generate code before the loop nest for the current run, if any.
  void doBeforeLoopNest(const std::function<void()> &callback) {
    if (constructStack.empty()) {
      callback();
      return;
    }
    auto insertionPoint = builder.saveInsertionPoint();
    builder.setInsertionPoint(constructStack[0]);
    callback();
    builder.restoreInsertionPoint(insertionPoint);
  }

  /// Can the current loop nest iteration number be computed? For simplicity,
  /// this is true if and only if all the bounds and steps of the fir.do_loop
  /// nest dominate the outer loop. The argument is filled with the current
  /// loop nest on success.
  bool currentLoopNestIterationNumberCanBeComputed(
      llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest);

  template <typename T>
  fir::factory::TemporaryStorage *insertSavedEntity(mlir::Region &region,
                                                    T &&temp) {
    auto inserted =
        savedEntities.insert(std::make_pair(&region, std::forward<T>(temp)));
    assert(inserted.second && "temp must have been emplaced");
    return &inserted.first->second;
  }

  /// Given a top-level hlfir.where, look for hlfir.exactly_once operations
  /// inside it and see if any of the values live into an hlfir.exactly_once
  /// do not dominate the hlfir.where. This may happen due to CSE reusing
  /// results of operations from the region parent inside hlfir.exactly_once.
  /// Since we are going to clone the body of hlfir.exactly_once before
  /// the top-level hlfir.where, such def-use would cause problems.
  /// There are other ways this could be resolved,
  /// e.g. making hlfir.exactly_once IsolatedFromAbove or making
  /// it a region of hlfir.where and wiring the result(s) through
  /// the block arguments. For the time being, this canonicalization
  /// tries to undo the effects of CSE.
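  ///
  /// As a rough sketch of the problematic pattern (schematic, not exact IR
  /// syntax):
  ///   hlfir.where {
  ///     %a = ... // part of the mask evaluation
  ///     hlfir.exactly_once {
  ///       ... use of %a ... // CSE reused %a instead of recomputing it
  ///     }
  ///   }
  /// Cloning the hlfir.exactly_once body before the hlfir.where would leave
  /// a use of %a that its definition does not dominate.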
  void canonicalizeExactlyOnceInsideWhere(hlfir::WhereOp whereOp);

  fir::FirOpBuilder &builder;

  /// Map containing the mapping between the original ordered assignment tree
  /// operations and the operations that have been cloned in the current run.
  /// It is reset between two runs.
  mlir::IRMapping mapper;
  /// Dominance info is used to determine if inner loop bounds are all computed
  /// before the outer loop of the current loop nest. It does not need to be
  /// reset between runs.
  mlir::DominanceInfo dominanceInfo;
  /// Construct stack in the current run. This allows setting back the
  /// insertion point correctly when leaving a node that requires a fir.do_loop
  /// or fir.if operation.
  llvm::SmallVector<mlir::Operation *> constructStack;
  /// Current where loop nest, if any.
  std::optional<hlfir::LoopNest> whereLoopNest;

  /// Map of temporary storage to keep track of saved entities once the run
  /// that saves them has been lowered. It is kept in-between runs.
  /// llvm::MapVector is used to guarantee a deterministic order
  /// of iteration through savedEntities (e.g. for generating
  /// destruction code for the temporary storages).
  llvm::MapVector<mlir::Region *, fir::factory::TemporaryStorage> savedEntities;
  /// Map holding the values that were saved in the current run and that also
  /// need to be used (because their construct will be visited). It is reset
  /// after each run. It avoids having to store and fetch in the temporary
  /// during the same run, which would require the temporary to have different
  /// fetching and storing counters.
  llvm::DenseMap<mlir::Region *, ValueAndCleanUp> savedInCurrentRunBeforeUse;

  /// Root of the ordered assignment tree being lowered.
  hlfir::OrderedAssignmentTreeOpInterface root;
  /// Pointer to the current run of the schedule being lowered.
  hlfir::Run *currentRun = nullptr;

  /// When allocating temporary storage inline, indicates if the storage should
  /// be heap or stack allocated. Temporaries allocated with the runtime are
  /// heap allocated by the runtime.
  bool allocateOnHeap = true;
};
} // namespace

void OrderedAssignmentRewriter::walk(
    hlfir::OrderedAssignmentTreeOpInterface node) {
  bool mustVisit =
      isRequiredInCurrentRun(node) || mlir::isa<hlfir::ForallIndexOp>(node);
  llvm::SmallVector<hlfir::SaveEntity> saveEntities;
  mlir::Operation *nodeOp = node.getOperation();
  if (mustSaveRegionIn(node, saveEntities)) {
    mlir::IRRewriter::InsertPoint insertionPoint;
    if (auto elseWhereOp = mlir::dyn_cast<hlfir::ElseWhereOp>(nodeOp)) {
      // The ElseWhere mask to save must be evaluated inside the fir.if else
      // block for the previous where/elsewhere (its evaluation must be
      // masked by the "pending control mask").
      insertionPoint = builder.saveInsertionPoint();
      enterElsewhere(elseWhereOp);
    }
    for (hlfir::SaveEntity saveEntity : saveEntities)
      generateSaveEntity(saveEntity, /*willUseSavedEntityInSameRun=*/mustVisit);
    if (insertionPoint.isSet())
      builder.restoreInsertionPoint(insertionPoint);
  }
  if (mustVisit) {
    llvm::TypeSwitch<mlir::Operation *, void>(nodeOp)
        .Case<hlfir::ForallOp, hlfir::ForallIndexOp, hlfir::ForallMaskOp,
              hlfir::RegionAssignOp, hlfir::WhereOp, hlfir::ElseWhereOp>(
            [&](auto concreteOp) { pre(concreteOp); })
        .Default([](auto) {});
    if (auto *body = node.getSubTreeRegion()) {
      for (mlir::Operation &op : body->getOps())
        if (auto subNode =
                mlir::dyn_cast<hlfir::OrderedAssignmentTreeOpInterface>(op))
          walk(subNode);
      llvm::TypeSwitch<mlir::Operation *, void>(nodeOp)
          .Case<hlfir::ForallOp, hlfir::ForallMaskOp, hlfir::WhereOp,
                hlfir::ElseWhereOp>([&](auto concreteOp) { post(concreteOp); })
          .Default([](auto) {});
    }
  }
}

void OrderedAssignmentRewriter::pre(hlfir::ForallOp forallOp) {
  // Create a fir.do_loop given the hlfir.forall control values.
  mlir::Type idxTy = builder.getIndexType();
  mlir::Location loc = forallOp.getLoc();
  mlir::Value lb = generateYieldedScalarValue(forallOp.getLbRegion(), idxTy);
  mlir::Value ub = generateYieldedScalarValue(forallOp.getUbRegion(), idxTy);
  mlir::Value step;
  if (forallOp.getStepRegion().empty()) {
    auto insertionPoint = builder.saveInsertionPoint();
    if (!constructStack.empty())
      builder.setInsertionPoint(constructStack[0]);
    step = builder.createIntegerConstant(loc, idxTy, 1);
    if (!constructStack.empty())
      builder.restoreInsertionPoint(insertionPoint);
  } else {
    step = generateYieldedScalarValue(forallOp.getStepRegion(), idxTy);
  }
  auto doLoop = builder.create<fir::DoLoopOp>(loc, lb, ub, step);
  builder.setInsertionPointToStart(doLoop.getBody());
  mlir::Value oldIndex = forallOp.getForallIndexValue();
  mlir::Value newIndex =
      builder.createConvert(loc, oldIndex.getType(), doLoop.getInductionVar());
  mapper.map(oldIndex, newIndex);
  constructStack.push_back(doLoop);
}

void OrderedAssignmentRewriter::post(hlfir::ForallOp) {
  assert(!constructStack.empty() && "must contain a loop");
  builder.setInsertionPointAfter(constructStack.pop_back_val());
}

void OrderedAssignmentRewriter::pre(hlfir::ForallIndexOp forallIndexOp) {
  mlir::Location loc = forallIndexOp.getLoc();
  mlir::Type intTy = fir::unwrapRefType(forallIndexOp.getType());
  mlir::Value indexVar =
      builder.createTemporary(loc, intTy, forallIndexOp.getName());
  mlir::Value newVal = mapper.lookupOrDefault(forallIndexOp.getIndex());
  builder.createStoreWithConvert(loc, newVal, indexVar);
  mapper.map(forallIndexOp, indexVar);
}

void OrderedAssignmentRewriter::pre(hlfir::ForallMaskOp forallMaskOp) {
  mlir::Location loc = forallMaskOp.getLoc();
  mlir::Value mask = generateYieldedScalarValue(forallMaskOp.getMaskRegion(),
                                                builder.getI1Type());
  auto ifOp = builder.create<fir::IfOp>(loc, std::nullopt, mask,
                                        /*withElseRegion=*/false);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  constructStack.push_back(ifOp);
}

void OrderedAssignmentRewriter::post(hlfir::ForallMaskOp forallMaskOp) {
  assert(!constructStack.empty() && "must contain an ifop");
  builder.setInsertionPointAfter(constructStack.pop_back_val());
}

/// Convert an entity to the type of a given mold.
/// This is intended to help with cases where the hlfir entity is a value
/// while it must be used as a variable, or vice-versa. These mismatches may
/// occur between the type of user-defined assignment block arguments and the
/// actual argument that was lowered for them. The actual argument may be an
/// in-memory copy while the block argument expects an hlfir.expr.
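/// For example: if a user-defined assignment RHS block argument expects an
/// hlfir.expr<!fir.array<10xi32>> while the saved actual argument was
/// materialized as a fir.ref<!fir.array<10xi32>>, this helper wraps the
/// variable in hlfir.as_expr and registers the matching hlfir.destroy
/// clean-up.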
static hlfir::Entity
convertToMoldType(mlir::Location loc, fir::FirOpBuilder &builder,
                  hlfir::Entity input, hlfir::Entity mold,
                  llvm::SmallVectorImpl<hlfir::CleanupFunction> &cleanups) {
  if (input.getType() == mold.getType())
    return input;
  fir::FirOpBuilder *b = &builder;
  if (input.isVariable() && mold.isValue()) {
    if (fir::isa_trivial(mold.getType())) {
      // fir.ref<T> to T.
      mlir::Value load = builder.create<fir::LoadOp>(loc, input);
      return hlfir::Entity{builder.createConvert(loc, mold.getType(), load)};
    }
    // fir.ref<T> to hlfir.expr<T>.
    mlir::Value asExpr = builder.create<hlfir::AsExprOp>(loc, input);
    if (asExpr.getType() != mold.getType())
      TODO(loc, "hlfir.expr conversion");
    cleanups.emplace_back([=]() { b->create<hlfir::DestroyOp>(loc, asExpr); });
    return hlfir::Entity{asExpr};
  }
  if (input.isValue() && mold.isVariable()) {
    // T to fir.ref<T>, or hlfir.expr<T> to fir.ref<T>.
    hlfir::AssociateOp associate = hlfir::genAssociateExpr(
        loc, builder, input, mold.getFortranElementType(), ".tmp.val2ref");
    cleanups.emplace_back(
        [=]() { b->create<hlfir::EndAssociateOp>(loc, associate); });
    return hlfir::Entity{associate.getBase()};
  }
  // Variable to variable mismatch (e.g., fir.heap<T> vs fir.ref<T>), or value
  // to value mismatch (e.g. i1 vs fir.logical<4>).
  if (mlir::isa<fir::BaseBoxType>(mold.getType()) &&
      !mlir::isa<fir::BaseBoxType>(input.getType())) {
    // An entity may have been saved without a descriptor while the original
    // value had a descriptor (e.g., it was not contiguous).
    auto emboxed = hlfir::convertToBox(loc, builder, input, mold.getType());
    assert(!emboxed.second && "temp should already be in memory");
    input = hlfir::Entity{fir::getBase(emboxed.first)};
  }
  return hlfir::Entity{builder.createConvert(loc, mold.getType(), input)};
}

void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
  mlir::Location loc = regionAssignOp.getLoc();
  if (regionAssignOp.isPointerAssignment()) {
    auto [lhsValue, oldLhsYield] =
        generateYieldedEntity(regionAssignOp.getLhsRegion());
    auto [rhsValue, oldRhsYield] =
        generateYieldedEntity(regionAssignOp.getRhsRegion());
    builder.createStoreWithConvert(loc, rhsValue, lhsValue);
    generateCleanupIfAny(oldLhsYield);
    generateCleanupIfAny(oldRhsYield);
    return;
  }
  auto [rhsValue, oldRhsYield] =
      generateYieldedEntity(regionAssignOp.getRhsRegion());
  hlfir::Entity rhsEntity{rhsValue};
  LhsValueAndCleanUp loweredLhs =
      generateYieldedLHS(loc, regionAssignOp.getLhsRegion(), rhsEntity);
  hlfir::Entity lhsEntity{loweredLhs.lhs};
  if (loweredLhs.vectorSubscriptLoopNest)
    rhsEntity = hlfir::getElementAt(
        loc, builder, rhsEntity,
        loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
  if (!regionAssignOp.getUserDefinedAssignment().empty()) {
    hlfir::Entity userAssignLhs{regionAssignOp.getUserAssignmentLhs()};
    hlfir::Entity userAssignRhs{regionAssignOp.getUserAssignmentRhs()};
    std::optional<hlfir::LoopNest> elementalLoopNest;
    if (lhsEntity.isArray() && userAssignLhs.isScalar()) {
      // Elemental assignment with array argument (the RHS cannot be an array
      // if the LHS is not).
      mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
      elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
      builder.setInsertionPointToStart(elementalLoopNest->body);
      lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
                                      elementalLoopNest->oneBasedIndices);
      rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
                                      elementalLoopNest->oneBasedIndices);
    }

    llvm::SmallVector<hlfir::CleanupFunction, 2> argConversionCleanups;
    lhsEntity = convertToMoldType(loc, builder, lhsEntity, userAssignLhs,
                                  argConversionCleanups);
    rhsEntity = convertToMoldType(loc, builder, rhsEntity, userAssignRhs,
                                  argConversionCleanups);
    mapper.map(userAssignLhs, lhsEntity);
    mapper.map(userAssignRhs, rhsEntity);
    for (auto &op :
         regionAssignOp.getUserDefinedAssignment().front().without_terminator())
      (void)builder.clone(op, mapper);
    for (auto &cleanupConversion : argConversionCleanups)
      cleanupConversion();
    if (elementalLoopNest)
      builder.setInsertionPointAfter(elementalLoopNest->outerOp);
  } else {
    // TODO: preserve allocatable assignment aspects for forall once
    // they are conveyed in hlfir.region_assign.
    builder.create<hlfir::AssignOp>(loc, rhsEntity, lhsEntity);
  }
  generateCleanupIfAny(loweredLhs.elementalCleanup);
  if (loweredLhs.vectorSubscriptLoopNest)
    builder.setInsertionPointAfter(loweredLhs.vectorSubscriptLoopNest->outerOp);
  generateCleanupIfAny(oldRhsYield);
  generateCleanupIfAny(loweredLhs.nonElementalCleanup);
}

void OrderedAssignmentRewriter::generateMaskIfOp(mlir::Value cdt) {
  mlir::Location loc = cdt.getLoc();
  cdt = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{cdt});
  cdt = builder.createConvert(loc, builder.getI1Type(), cdt);
  auto ifOp = builder.create<fir::IfOp>(cdt.getLoc(), std::nullopt, cdt,
                                        /*withElseRegion=*/false);
  constructStack.push_back(ifOp.getOperation());
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
}

void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
  mlir::Location loc = whereOp.getLoc();
  if (!whereLoopNest) {
    // Make sure the liveness information is valid for the inner
    // hlfir.exactly_once operations, so that their bodies can be cloned
    // before the top-level hlfir.where.
    canonicalizeExactlyOnceInsideWhere(whereOp);
    // This is the top-level WHERE. Start a loop nest iterating on the shape of
    // the where mask.
    if (auto maybeSaved = getIfSaved(whereOp.getMaskRegion())) {
      // Use the saved value to get the shape and condition element.
      hlfir::Entity savedMask{maybeSaved->first};
      mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
      whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
      constructStack.push_back(whereLoopNest->outerOp);
      builder.setInsertionPointToStart(whereLoopNest->body);
      mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
                                            whereLoopNest->oneBasedIndices);
      generateMaskIfOp(cdt);
      if (maybeSaved->second) {
        // If this is the same run as the one that saved the value, the
        // clean-up was left over to be done now.
        auto insertionPoint = builder.saveInsertionPoint();
        builder.setInsertionPointAfter(whereLoopNest->outerOp);
        generateCleanupIfAny(maybeSaved->second);
        builder.restoreInsertionPoint(insertionPoint);
      }
      return;
    }
    // The mask was not evaluated yet or can be safely re-evaluated.
    MaskedArrayExpr mask(loc, whereOp.getMaskRegion(),
                         /*isOuterMaskExpr=*/true);
    mask.generateNoneElementalPart(builder, mapper);
    mlir::Value shape = mask.generateShape(builder, mapper);
    whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
    constructStack.push_back(whereLoopNest->outerOp);
    builder.setInsertionPointToStart(whereLoopNest->body);
    mlir::Value cdt = generateMaskedEntity(mask);
    generateMaskIfOp(cdt);
    return;
  }
  // The WHERE loops have already been created by a parent WHERE.
  // Generate a fir.if with the value of the current element of the mask
  // inside the loops. The case where the mask was saved is handled in the
  // generateYieldedScalarValue call.
  mlir::Value cdt = generateYieldedScalarValue(whereOp.getMaskRegion());
  generateMaskIfOp(cdt);
}

void OrderedAssignmentRewriter::post(hlfir::WhereOp whereOp) {
  assert(!constructStack.empty() && "must contain a fir.if");
  builder.setInsertionPointAfter(constructStack.pop_back_val());
  // If all where/elsewhere fir.if have been popped, this is the outer whereOp,
  // and the where loop must be exited.
  assert(!constructStack.empty() && "must contain a fir.do_loop or fir.if");
  if (mlir::isa<fir::DoLoopOp>(constructStack.back())) {
    builder.setInsertionPointAfter(constructStack.pop_back_val());
    whereLoopNest.reset();
  }
}

void OrderedAssignmentRewriter::enterElsewhere(hlfir::ElseWhereOp elseWhereOp) {
  // Create an "else" region for the current where/elsewhere fir.if.
  auto ifOp = mlir::dyn_cast<fir::IfOp>(constructStack.back());
  assert(ifOp && "must be an if");
  if (ifOp.getElseRegion().empty()) {
    mlir::Location loc = elseWhereOp.getLoc();
    builder.createBlock(&ifOp.getElseRegion());
    auto end = builder.create<fir::ResultOp>(loc);
    builder.setInsertionPoint(end);
  } else {
    builder.setInsertionPoint(&ifOp.getElseRegion().back().back());
  }
}

void OrderedAssignmentRewriter::pre(hlfir::ElseWhereOp elseWhereOp) {
  enterElsewhere(elseWhereOp);
  if (elseWhereOp.getMaskRegion().empty())
    return;
  // Create a new nested fir.if with the elsewhere mask, if any.
  mlir::Value cdt = generateYieldedScalarValue(elseWhereOp.getMaskRegion());
  generateMaskIfOp(cdt);
}

void OrderedAssignmentRewriter::post(hlfir::ElseWhereOp elseWhereOp) {
  // Exit the ifOp that was created for the elseWhereOp mask, if any.
  if (elseWhereOp.getMaskRegion().empty())
    return;
  assert(!constructStack.empty() && "must contain a fir.if");
  builder.setInsertionPointAfter(constructStack.pop_back_val());
}

/// Is this value a Forall index?
/// Forall indices are block arguments of the hlfir.forall body, or the
/// results of hlfir.forall_index.
static bool isForallIndex(mlir::Value value) {
  if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(value)) {
    if (mlir::Block *block = blockArg.getOwner())
      return block->isEntryBlock() &&
             mlir::isa_and_nonnull<hlfir::ForallOp>(block->getParentOp());
    return false;
  }
  return value.getDefiningOp<hlfir::ForallIndexOp>();
}

static OrderedAssignmentRewriter::ValueAndCleanUp
castIfNeeded(mlir::Location loc, fir::FirOpBuilder &builder,
             OrderedAssignmentRewriter::ValueAndCleanUp valueAndCleanUp,
             std::optional<mlir::Type> castToType) {
  if (!castToType.has_value())
    return valueAndCleanUp;
  mlir::Value cast =
      builder.createConvert(loc, *castToType, valueAndCleanUp.first);
  return {cast, valueAndCleanUp.second};
}

std::optional<OrderedAssignmentRewriter::ValueAndCleanUp>
OrderedAssignmentRewriter::getIfSaved(mlir::Region &region) {
  mlir::Location loc = region.getParentOp()->getLoc();
  // If the region was saved in the same run, use the value that was evaluated
  // instead of fetching the temp, and do the clean-ups, if any, that were
  // delayed. This is done to avoid requiring the temporary stack to have
  // different fetching and storing counters, and also because it produces
  // slightly better code.
  if (auto savedInSameRun = savedInCurrentRunBeforeUse.find(&region);
      savedInSameRun != savedInCurrentRunBeforeUse.end())
    return savedInSameRun->second;
  // If the region was saved in a previous run, fetch the saved value.
  if (auto temp = savedEntities.find(&region); temp != savedEntities.end()) {
    doBeforeLoopNest([&]() { temp->second.resetFetchPosition(loc, builder); });
    return ValueAndCleanUp{temp->second.fetch(loc, builder), std::nullopt};
  }
  return std::nullopt;
}

static hlfir::YieldOp getYield(mlir::Region &region) {
  auto yield = mlir::dyn_cast_or_null<hlfir::YieldOp>(
      region.back().getOperations().back());
  assert(yield && "region computing entities must end with a YieldOp");
  return yield;
}

OrderedAssignmentRewriter::ValueAndCleanUp
OrderedAssignmentRewriter::generateYieldedEntity(
    mlir::Region &region, std::optional<mlir::Type> castToType) {
  mlir::Location loc = region.getParentOp()->getLoc();
  if (auto maybeValueAndCleanUp = getIfSaved(region))
    return castIfNeeded(loc, builder, *maybeValueAndCleanUp, castToType);
  // Otherwise, evaluate the region now.

  // Masked expressions must not evaluate the elemental parts that are masked;
  // they have custom code generation.
  if (whereLoopNest.has_value()) {
    mlir::Value maskedValue = generateMaskedEntity(loc, region);
    return castIfNeeded(loc, builder, {maskedValue, std::nullopt}, castToType);
  }

  auto oldYield = getYield(region);
  // Inside Forall, scalars that do not depend on forall indices can be hoisted
  // here because their evaluation is required to only call pure procedures,
  // and if they depend on a variable previously assigned to in a forall
  // assignment, that assignment must have been scheduled in a previous run.
  // Hoisting scalars here helps create simple temporary storage if needed.
  // Inner forall bounds can often be hoisted, and this allows computing the
  // total number of iterations to create temporary storages.
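  // For instance (an illustrative sketch, not an exact schedule): in
  //   forall (i=1:n) x(i) = y(i) + m * k
  // the scalar m * k does not depend on i, so its evaluation can be hoisted
  // before the fir.do_loop created for the forall.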
  bool hoistComputation = false;
  if (fir::isa_trivial(oldYield.getEntity().getType()) &&
      !constructStack.empty()) {
    mlir::WalkResult walkResult =
        region.walk([&](mlir::Operation *op) -> mlir::WalkResult {
          if (llvm::any_of(op->getOperands(), [](mlir::Value value) {
                return isForallIndex(value);
              }))
            return mlir::WalkResult::interrupt();
          return mlir::WalkResult::advance();
        });
    hoistComputation = !walkResult.wasInterrupted();
  }
  auto insertionPoint = builder.saveInsertionPoint();
  if (hoistComputation)
    builder.setInsertionPoint(constructStack[0]);

  // Clone all operations except the final hlfir.yield.
  assert(region.hasOneBlock() && "region must contain one block");
  for (auto &op : region.back().without_terminator())
    (void)builder.clone(op, mapper);
  // Get the value for the yielded entity; it may be the result of an operation
  // that was cloned, or it may be the same as the previous value if the yield
  // operand was created before the ordered assignment tree.
  mlir::Value newEntity = mapper.lookupOrDefault(oldYield.getEntity());
  if (castToType.has_value())
    newEntity =
        builder.createConvert(newEntity.getLoc(), *castToType, newEntity);

  if (hoistComputation) {
    // Clean-up of hoisted trivial scalars can be done right away: the value
    // is in registers.
    generateCleanupIfAny(oldYield);
    builder.restoreInsertionPoint(insertionPoint);
    return {newEntity, std::nullopt};
  }
  if (oldYield.getCleanup().empty())
    return {newEntity, std::nullopt};
  return {newEntity, oldYield};
}

mlir::Value OrderedAssignmentRewriter::generateYieldedScalarValue(
    mlir::Region &region, std::optional<mlir::Type> castToType) {
  mlir::Location loc = region.getParentOp()->getLoc();
  auto [value, maybeYield] = generateYieldedEntity(region, castToType);
  value = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{value});
  assert(fir::isa_trivial(value.getType()) && "not a trivial scalar value");
  generateCleanupIfAny(maybeYield);
  return value;
}

OrderedAssignmentRewriter::LhsValueAndCleanUp
OrderedAssignmentRewriter::generateYieldedLHS(
    mlir::Location loc, mlir::Region &lhsRegion,
    std::optional<hlfir::Entity> loweredRhs) {
  LhsValueAndCleanUp loweredLhs;
  hlfir::ElementalAddrOp elementalAddrLhs =
      mlir::dyn_cast<hlfir::ElementalAddrOp>(lhsRegion.back().back());
  if (auto temp = savedEntities.find(&lhsRegion); temp != savedEntities.end()) {
    // The LHS address was computed and saved in a previous run. Fetch it.
    doBeforeLoopNest([&]() { temp->second.resetFetchPosition(loc, builder); });
    if (elementalAddrLhs && !whereLoopNest) {
      // Vector subscripted designator addresses are saved element by element.
      // If no "elemental" loops have been created yet, the shape of the RHS
      // (if it is an array) can be used, or the shape of the vector
      // subscripted designator must be retrieved, to generate the "elemental"
      // loop nest.
      if (loweredRhs && loweredRhs->isArray()) {
        // The RHS shape can be used to create the elemental loops and avoid
        // saving the LHS shape.
        loweredLhs.vectorSubscriptShape =
            hlfir::genShape(loc, builder, *loweredRhs);
      } else {
        // If the shape cannot be retrieved from the RHS, it must have been
        // saved. Get it from the temporary.
        auto &vectorTmp =
            temp->second.cast<fir::factory::AnyVectorSubscriptStack>();
        loweredLhs.vectorSubscriptShape = vectorTmp.fetchShape(loc, builder);
      }
      loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
          loc, builder, loweredLhs.vectorSubscriptShape.value());
      builder.setInsertionPointToStart(
          loweredLhs.vectorSubscriptLoopNest->body);
    }
    loweredLhs.lhs = temp->second.fetch(loc, builder);
    return loweredLhs;
  }
  // The LHS has not yet been evaluated and saved. Evaluate it now.
  if (elementalAddrLhs && !whereLoopNest) {
    // This is a vector subscripted entity. The addresses of its elements must
    // be returned. If no "elemental" loops have been created for a WHERE,
    // create them now based on the vector subscripted designator shape.
    for (auto &op : lhsRegion.front().without_terminator())
      (void)builder.clone(op, mapper);
    loweredLhs.vectorSubscriptShape =
        mapper.lookupOrDefault(elementalAddrLhs.getShape());
    loweredLhs.vectorSubscriptLoopNest =
        hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
                           !elementalAddrLhs.isOrdered());
    builder.setInsertionPointToStart(loweredLhs.vectorSubscriptLoopNest->body);
    mapper.map(elementalAddrLhs.getIndices(),
               loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
    for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
      (void)builder.clone(op, mapper);
    loweredLhs.elementalCleanup = elementalAddrLhs.getYieldOp();
    loweredLhs.lhs =
        mapper.lookupOrDefault(loweredLhs.elementalCleanup->getEntity());
  } else {
    // This is a designator without vector subscripts. Generate it as
    // it is done for other entities.
    auto [lhs, yield] = generateYieldedEntity(lhsRegion);
    loweredLhs.lhs = lhs;
    if (yield && !yield->getCleanup().empty())
      loweredLhs.nonElementalCleanup = &yield->getCleanup();
  }
  return loweredLhs;
}

mlir::Value
OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
  assert(whereLoopNest.has_value() && "must be inside WHERE loop nest");
  auto insertionPoint = builder.saveInsertionPoint();
  if (!maskedExpr.noneElementalPartWasGenerated) {
    // Generate the non-elemental part before the where loops (but inside the
    // current forall loops if any).
    builder.setInsertionPoint(whereLoopNest->outerOp);
    maskedExpr.generateNoneElementalPart(builder, mapper);
  }
  // Generate the non-elemental part clean-up after the where loops.
  builder.setInsertionPointAfter(whereLoopNest->outerOp);
  maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
  // Generate the value of the current element for the masked expression
  // at the current insertion point (inside the where loops, and any fir.if
  // generated for previous masks).
  builder.restoreInsertionPoint(insertionPoint);
  mlir::Value scalar = maskedExpr.generateElementalParts(
      builder, whereLoopNest->oneBasedIndices, mapper);
  // Generate clean-ups for the elemental parts inside the loops (setting the
  // insertion point so that the assignment will be generated before the
  // clean-ups).
  if (!maskedExpr.isOuterMaskExpr)
    if (mlir::Operation *firstCleanup =
            maskedExpr.generateMaskedExprCleanUps(builder, mapper))
      builder.setInsertionPoint(firstCleanup);
  return scalar;
}

void OrderedAssignmentRewriter::generateCleanupIfAny(
    std::optional<hlfir::YieldOp> maybeYield) {
  if (maybeYield.has_value())
    generateCleanupIfAny(&maybeYield->getCleanup());
}
void OrderedAssignmentRewriter::generateCleanupIfAny(
    mlir::Region *cleanupRegion) {
  if (cleanupRegion && !cleanupRegion->empty()) {
    assert(cleanupRegion->hasOneBlock() && "region must contain one block");
    for (auto &op : cleanupRegion->back().without_terminator())
      builder.clone(op, mapper);
  }
}

bool OrderedAssignmentRewriter::mustSaveRegionIn(
    hlfir::OrderedAssignmentTreeOpInterface node,
    llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const {
  for (auto &action : currentRun->actions)
    if (hlfir::SaveEntity *savedEntity =
            std::get_if<hlfir::SaveEntity>(&action))
      if (node.getOperation() == savedEntity->yieldRegion->getParentOp())
        saveEntities.push_back(*savedEntity);
  return !saveEntities.empty();
}

bool OrderedAssignmentRewriter::isRequiredInCurrentRun(
    hlfir::OrderedAssignmentTreeOpInterface node) const {
  // hlfir.forall_index operations do not contain saved regions/assignments,
  // but if their hlfir.forall parent was required, they are required too
  // (the forall indices need to be mapped).
  if (mlir::isa<hlfir::ForallIndexOp>(node))
    return true;
  for (auto &action : currentRun->actions)
    if (hlfir::SaveEntity *savedEntity =
            std::get_if<hlfir::SaveEntity>(&action)) {
      // A SaveEntity action does not require evaluating the node that contains
      // it, but it requires evaluating all the parents of the node that
      // contains it. For instance, saving a bound of hlfir.forall B does not
      // require creating the loops for B, but it does require creating the
      // loops for any forall parent A of B.
      if (node->isProperAncestor(savedEntity->yieldRegion->getParentOp()))
        return true;
    } else {
      auto assign = std::get<hlfir::RegionAssignOp>(action);
      if (node->isAncestor(assign.getOperation()))
        return true;
    }
  return false;
}

/// Is the apply using all the elemental indices in order?
static bool isInOrderApply(hlfir::ApplyOp apply,
                           hlfir::ElementalOpInterface elemental) {
  mlir::Region::BlockArgListType elementalIndices = elemental.getIndices();
  if (elementalIndices.size() != apply.getIndices().size())
    return false;
  for (auto [elementalIdx, applyIdx] :
       llvm::zip(elementalIndices, apply.getIndices()))
    if (elementalIdx != applyIdx)
      return false;
  return true;
}

/// Gather the tree of hlfir::ElementalOpInterface use-def, if any, starting
/// from \p elemental, which may be a nullptr.
static void
gatherElementalTree(hlfir::ElementalOpInterface elemental,
                    llvm::SmallPtrSetImpl<mlir::Operation *> &elementalOps,
                    bool isOutOfOrder) {
  if (elemental) {
    // Only inline an applied elemental that must be executed in order if the
    // applying indices are in order. An hlfir.elemental may have been created
    // for a transformational operation like transpose, and Fortran 2018
    // standard section 10.2.3.2, point 10 implies that impure elemental
    // sub-expression evaluations should not be masked if they are the
    // arguments of transformational expressions.
    if (isOutOfOrder && elemental.isOrdered())
      return;
    elementalOps.insert(elemental.getOperation());
    for (mlir::Operation &op : elemental.getElementalRegion().getOps())
      if (auto apply = mlir::dyn_cast<hlfir::ApplyOp>(op)) {
        bool isUnorderedApply =
            isOutOfOrder || !isInOrderApply(apply, elemental);
        auto maybeElemental =
            mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
                apply.getExpr().getDefiningOp());
        gatherElementalTree(maybeElemental, elementalOps, isUnorderedApply);
      }
  }
}

MaskedArrayExpr::MaskedArrayExpr(mlir::Location loc, mlir::Region &region,
                                 bool isOuterMaskExpr)
    : loc{loc}, region{region}, isOuterMaskExpr{isOuterMaskExpr} {
  mlir::Operation &terminator = region.back().back();
  if (auto elementalAddr =
          mlir::dyn_cast<hlfir::ElementalOpInterface>(terminator)) {
    // Vector subscripted designator (hlfir.elemental_addr terminator).
    gatherElementalTree(elementalAddr, elementalParts, /*isOutOfOrder=*/false);
    return;
  }
  // Otherwise, check if this is an elemental expression.
  mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
  auto maybeElemental = mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
      entity.getDefiningOp());
  gatherElementalTree(maybeElemental, elementalParts, /*isOutOfOrder=*/false);
}

void MaskedArrayExpr::generateNoneElementalPart(fir::FirOpBuilder &builder,
                                                mlir::IRMapping &mapper) {
  assert(!noneElementalPartWasGenerated &&
         "non-elemental parts already generated");
  if (isOuterMaskExpr) {
    // The outer mask expression is actually not masked, but it is handled
    // this way so that its elemental part, if any, can be inlined in the
    // WHERE loops. All of the operations outside of hlfir.elemental/
    // hlfir.elemental_addr must be emitted now because their values may be
    // required to deduce the mask shape and the WHERE loop bounds.
    for (mlir::Operation &op : region.back().without_terminator())
      if (!elementalParts.contains(&op))
        (void)builder.clone(op, mapper);
  } else {
    // For actual masked expressions, Fortran requires elemental expressions,
    // even the scalar ones that are not encoded with hlfir.elemental, to be
    // evaluated only when the mask is true. Blindly hoisting the whole scalar
    // SSA tree could be wrong if the scalar computation has side effects and
    // would never have been evaluated (e.g. division by zero) had the mask
    // been fully false. See F'2023 10.2.3.2 point 10.
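    // For instance (illustrative): in "where (x /= 0.) y = 1./x", the
    // division 1./x must only be evaluated for elements where x /= 0. holds.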
    // Clone only the bodies of all hlfir.exactly_once operations, which
    // contain the evaluation of sub-expression trees whose root was a
    // non-elemental function call at the Fortran level (the call itself may
    // have been inlined since). These must be evaluated only once as per
    // F'2023 10.2.3.2 point 9.
    for (mlir::Operation &op : region.back().without_terminator())
      if (auto exactlyOnce = mlir::dyn_cast<hlfir::ExactlyOnceOp>(op)) {
        for (mlir::Operation &subOp :
             exactlyOnce.getBody().back().without_terminator())
          (void)builder.clone(subOp, mapper);
        mlir::Value oldYield = getYield(exactlyOnce.getBody()).getEntity();
        auto newYield = mapper.lookupOrDefault(oldYield);
        mapper.map(exactlyOnce.getResult(), newYield);
      }
  }
  noneElementalPartWasGenerated = true;
}

mlir::Value MaskedArrayExpr::generateShape(fir::FirOpBuilder &builder,
                                           mlir::IRMapping &mapper) {
  assert(noneElementalPartWasGenerated &&
         "non-elemental part must have been generated");
  mlir::Operation &terminator = region.back().back();
  // If the operation that produced the yielded entity is elemental, it was not
  // cloned, but it holds a shape argument that was cloned. Return the cloned
  // shape.
  if (auto elementalAddrOp = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator))
    return mapper.lookupOrDefault(elementalAddrOp.getShape());
  mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
  if (auto elemental = entity.getDefiningOp<hlfir::ElementalOp>())
    return mapper.lookupOrDefault(elemental.getShape());
  // Otherwise, the whole entity was cloned, and the shape can be generated
  // from it.
  hlfir::Entity clonedEntity{mapper.lookupOrDefault(entity)};
  return hlfir::genShape(loc, builder, hlfir::Entity{clonedEntity});
}

mlir::Value
MaskedArrayExpr::generateElementalParts(fir::FirOpBuilder &builder,
                                        mlir::ValueRange oneBasedIndices,
                                        mlir::IRMapping &mapper) {
  assert(noneElementalPartWasGenerated &&
         "non-elemental part must have been generated");
  if (!isOuterMaskExpr) {
    // Clone all operations that are not hlfir.exactly_once and that are not
    // hlfir.elemental/hlfir.elemental_addr. For the outer mask, this was
    // already done outside of the loops.
    for (mlir::Operation &op : region.back().without_terminator())
      if (!mlir::isa<hlfir::ExactlyOnceOp>(op) && !elementalParts.contains(&op))
        (void)builder.clone(op, mapper);
  }
  // Clone and "index" the bodies of hlfir.elemental/hlfir.elemental_addr.
  mlir::Operation &terminator = region.back().back();
  hlfir::ElementalOpInterface elemental =
      mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator);
  if (!elemental) {
    // If the terminator is not an hlfir.elemental_addr, check whether the
    // yielded entity was produced by an hlfir.elemental.
    mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
    elemental = entity.getDefiningOp<hlfir::ElementalOp>();
    if (!elemental) {
      // The yielded entity was not produced by an elemental operation;
      // get its clone in the non-elemental part evaluation and address it.
      hlfir::Entity clonedEntity{mapper.lookupOrDefault(entity)};
      return hlfir::getElementAt(loc, builder, clonedEntity, oneBasedIndices);
    }
  }

  auto mustRecursivelyInline =
      [&](hlfir::ElementalOp appliedElemental) -> bool {
    return elementalParts.contains(appliedElemental.getOperation());
  };
  return inlineElementalOp(loc, builder, elemental, oneBasedIndices, mapper,
                           mustRecursivelyInline);
}

mlir::Operation *
MaskedArrayExpr::generateMaskedExprCleanUps(fir::FirOpBuilder &builder,
                                            mlir::IRMapping &mapper) {
  // Clone the clean-ups from the region itself, except for the destroys
  // of the hlfir.elemental operations that have been inlined.
  mlir::Operation &terminator = region.back().back();
  mlir::Region *cleanupRegion = nullptr;
  if (auto elementalAddr = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator)) {
    cleanupRegion = &elementalAddr.getCleanup();
  } else {
    auto yieldOp = mlir::cast<hlfir::YieldOp>(terminator);
    cleanupRegion = &yieldOp.getCleanup();
  }
  if (cleanupRegion->empty())
    return nullptr;
  mlir::Operation *firstNewCleanup = nullptr;
  for (mlir::Operation &op : cleanupRegion->front().without_terminator()) {
    if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(op))
      if (elementalParts.contains(destroy.getExpr().getDefiningOp()))
        continue;
    mlir::Operation *cleanup = builder.clone(op, mapper);
    if (!firstNewCleanup)
      firstNewCleanup = cleanup;
  }
  return firstNewCleanup;
}

void MaskedArrayExpr::generateNoneElementalCleanupIfAny(
    fir::FirOpBuilder &builder, mlir::IRMapping &mapper) {
  if (!isOuterMaskExpr) {
    // Clone the clean-ups of hlfir.exactly_once operations (in reverse order
    // to properly deal with stack restores).
    for (mlir::Operation &op :
         llvm::reverse(region.back().without_terminator()))
      if (auto exactlyOnce = mlir::dyn_cast<hlfir::ExactlyOnceOp>(op)) {
        mlir::Region &cleanupRegion =
            getYield(exactlyOnce.getBody()).getCleanup();
        if (!cleanupRegion.empty())
          for (mlir::Operation &cleanupOp :
               cleanupRegion.front().without_terminator())
            (void)builder.clone(cleanupOp, mapper);
      }
  } else {
    // For the outer mask, the region clean-ups must be generated
    // outside of the loops since the non-elemental part of the mask
    // is generated before the loops.
    generateMaskedExprCleanUps(builder, mapper);
  }
}

static hlfir::RegionAssignOp
getAssignIfLeftHandSideRegion(mlir::Region &region) {
  auto assign = mlir::dyn_cast<hlfir::RegionAssignOp>(region.getParentOp());
  if (assign && (&assign.getLhsRegion() == &region))
    return assign;
  return nullptr;
}

static bool isPointerAssignmentRHS(mlir::Region &region) {
  auto assign = mlir::dyn_cast<hlfir::RegionAssignOp>(region.getParentOp());
  return assign && assign.isPointerAssignment() &&
         (&assign.getRhsRegion() == &region);
}

bool OrderedAssignmentRewriter::currentLoopNestIterationNumberCanBeComputed(
    llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest) {
  if (constructStack.empty())
    return true;
  mlir::Operation *outerLoop = constructStack[0];
  mlir::Operation *currentConstruct = constructStack.back();
  // Walk up the constructs until the outer construct is met, and test if the
  // loop operands dominate the outer construct.
  while (currentConstruct) {
    if (auto doLoop = mlir::dyn_cast<fir::DoLoopOp>(currentConstruct)) {
      if (llvm::any_of(doLoop->getOperands(), [&](mlir::Value value) {
            return !dominanceInfo.properlyDominates(value, outerLoop);
          })) {
        return false;
      }
      loopNest.push_back(doLoop);
    }
    if (currentConstruct == outerLoop)
      currentConstruct = nullptr;
    else
      currentConstruct = currentConstruct->getParentOp();
  }
  return true;
}

static mlir::Value
computeLoopNestIterationNumber(mlir::Location loc, fir::FirOpBuilder &builder,
                               llvm::ArrayRef<fir::DoLoopOp> loopNest) {
  mlir::Value loopExtent;
  for (fir::DoLoopOp doLoop : loopNest) {
    mlir::Value extent = builder.genExtentFromTriplet(
        loc, doLoop.getLowerBound(), doLoop.getUpperBound(), doLoop.getStep(),
        builder.getIndexType());
    if (!loopExtent)
      loopExtent = extent;
    else
      loopExtent = builder.create<mlir::arith::MulIOp>(loc, loopExtent, extent);
  }
  assert(loopExtent && "loopNest must not be empty");
  return loopExtent;
}

/// Return a name for temporary storage that indicates in which context
/// the temporary storage was created.
static llvm::StringRef
getTempName(hlfir::OrderedAssignmentTreeOpInterface root) {
  if (mlir::isa<hlfir::ForallOp>(root.getOperation()))
    return ".tmp.forall";
  if (mlir::isa<hlfir::WhereOp>(root.getOperation()))
    return ".tmp.where";
  return ".tmp.assign";
}

void OrderedAssignmentRewriter::generateSaveEntity(
    hlfir::SaveEntity savedEntity, bool willUseSavedEntityInSameRun) {
  mlir::Region &region = *savedEntity.yieldRegion;

  if (hlfir::RegionAssignOp regionAssignOp =
          getAssignIfLeftHandSideRegion(region)) {
    // Need to save the address, not the value.
    assert(!willUseSavedEntityInSameRun &&
           "lhs cannot be used in the loop nest where it is saved");
    return saveLeftHandSide(savedEntity, regionAssignOp);
  }
  if (isPointerAssignmentRHS(region)) {
    assert(!willUseSavedEntityInSameRun &&
           "rhs cannot be used in the loop nest where it is saved");
    return saveNonVectorSubscriptedAddress(savedEntity);
  }

  mlir::Location loc = region.getParentOp()->getLoc();
  // Evaluate the region inside the loop nest (if any).
  auto [clonedValue, oldYield] = generateYieldedEntity(region);
  hlfir::Entity entity{clonedValue};
  entity = hlfir::loadTrivialScalar(loc, builder, entity);
  mlir::Type entityType = entity.getType();

  llvm::StringRef tempName = getTempName(root);
  fir::factory::TemporaryStorage *temp = nullptr;
  if (constructStack.empty()) {
    // Value evaluated outside of any loops (this may be the first MASK of a
    // WHERE construct, or an LHS/RHS temp of hlfir.region_assign outside of
    // WHERE/FORALL).
    temp = insertSavedEntity(
        region, fir::factory::SimpleCopy(loc, builder, entity, tempName));
  } else {
    // Need to create a temporary for values computed inside loops.
    // Create the temporary storage outside of the loop nest given the entity
    // type (and the loop context).
    llvm::SmallVector<fir::DoLoopOp> loopNest;
    bool loopShapeCanBePreComputed =
        currentLoopNestIterationNumberCanBeComputed(loopNest);
    doBeforeLoopNest([&] {
      // For simple scalars inside loops whose total iteration number can be
      // pre-computed, create a rank-1 array outside of the loops. It will be
      // assigned/fetched inside the loops like a normal Fortran array given
      // the iteration count.
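      // Illustrative sketch: for "forall (i=1:n) x(i) = x(i+n)", the n saved
      // scalar values of x(i+n) can live in a rank-1 temporary of extent n,
      // indexed by the iteration count when storing and when fetching.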
      if (loopShapeCanBePreComputed && fir::isa_trivial(entityType)) {
        mlir::Value loopExtent =
            computeLoopNestIterationNumber(loc, builder, loopNest);
        auto sequenceType =
            mlir::cast<fir::SequenceType>(builder.getVarLenSeqTy(entityType));
        temp = insertSavedEntity(region,
                                 fir::factory::HomogeneousScalarStack{
                                     loc, builder, sequenceType, loopExtent,
                                     /*lenParams=*/{}, allocateOnHeap,
                                     /*stackThroughLoops=*/true, tempName});

      } else {
        // If the number of iterations is not known, or if the values at each
        // iteration may have different shapes, type parameters, or dynamic
        // types, use the runtime to create and manage a stack-like temporary.
        temp = insertSavedEntity(
            region, fir::factory::AnyValueStack{loc, builder, entityType});
      }
    });
    // Inside the loop nest (and any fir.if if there are active masks), copy
    // the value to the temp and do clean-ups for the value if any.
    temp->pushValue(loc, builder, entity);
  }

1239 // Delay the clean-up if the entity will be used in the same run (i.e., the
1240 // parent construct will be visited and needs to be lowered). When possible,
1241 // this is not done for hlfir.expr because this use would prevent the
1242 // hlfir.expr storage from being moved when creating the temporary in
1243 // bufferization, and that would lead to an extra copy.
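  // A typical case (for illustration): a WHERE mask that is saved and then
  // consumed while lowering the rest of the construct in the same run can be
  // reused directly instead of being fetched back from the temporary.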
  if (willUseSavedEntityInSameRun &&
      (!temp->canBeFetchedAfterPush() ||
       !mlir::isa<hlfir::ExprType>(entity.getType()))) {
    auto inserted =
        savedInCurrentRunBeforeUse.try_emplace(&region, entity, oldYield);
    assert(inserted.second && "entity must have been emplaced");
    (void)inserted;
  } else {
    if (constructStack.empty() &&
        mlir::isa<hlfir::RegionAssignOp>(region.getParentOp())) {
      // Here the clean-up code is inserted after the original
      // RegionAssignOp, so that the assignment code happens
      // before the cleanup. We do this only for standalone
      // operations, because the clean-up is handled specially
      // during lowering of the parent constructs if any
      // (e.g. see generateNoneElementalCleanupIfAny for
      // WhereOp).
      auto insertionPoint = builder.saveInsertionPoint();
      builder.setInsertionPointAfter(region.getParentOp());
      generateCleanupIfAny(oldYield);
      builder.restoreInsertionPoint(insertionPoint);
    } else {
      generateCleanupIfAny(oldYield);
    }
  }
}

static bool rhsIsArray(hlfir::RegionAssignOp regionAssignOp) {
  auto yieldOp = mlir::dyn_cast<hlfir::YieldOp>(
      regionAssignOp.getRhsRegion().back().back());
  return yieldOp && hlfir::Entity{yieldOp.getEntity()}.isArray();
}

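/// Does \p region yield a vector subscripted designator? Such regions are
/// terminated by an hlfir.elemental_addr instead of an hlfir.yield (an
/// illustrative example being a left-hand side like "x(v(1:n))").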
static bool isVectorSubscripted(mlir::Region &region) {
  return llvm::isa<hlfir::ElementalAddrOp>(region.back().back());
}

void OrderedAssignmentRewriter::saveNonVectorSubscriptedAddress(
    hlfir::SaveEntity savedEntity) {
  mlir::Region &region = *savedEntity.yieldRegion;
  mlir::Location loc = region.getParentOp()->getLoc();
  assert(!isVectorSubscripted(region) &&
         "expected variable without vector subscripts");
  ValueAndCleanUp varAndCleanup = generateYieldedEntity(region);
  hlfir::Entity var{varAndCleanup.first};
  fir::factory::TemporaryStorage *temp = nullptr;
  // If the address dominates the constructs, its SSA value can simply be
  // tracked and there is no need to save the address in memory. Otherwise,
  // the addresses are stored at each iteration in memory with a descriptor
  // stack.
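  // For illustration (hypothetical Fortran): in
  //   forall (i=1:n) p(i)%ptr => t(i)
  // the target address t(i) depends on the forall index, so one address is
  // pushed on the stack per iteration, while an address that is invariant in
  // the construct is simply tracked as an SSA value.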
  if (constructStack.empty() ||
      dominanceInfo.properlyDominates(var, constructStack[0]))
    doBeforeLoopNest(
        [&] { temp = insertSavedEntity(region, fir::factory::SSARegister{}); });
  else
    doBeforeLoopNest([&] {
      if (var.isMutableBox() || var.isProcedure() || var.isProcedurePointer())
        // Store single C pointer to entity.
        temp = insertSavedEntity(
            region, fir::factory::AnyAddressStack{loc, builder, var.getType()});
      else
        // Store the base address and dynamic shape/length/type information
        // as descriptor.
        temp = insertSavedEntity(region, fir::factory::AnyVariableStack{
                                             loc, builder, var.getType()});
    });
  temp->pushValue(loc, builder, var);
  generateCleanupIfAny(varAndCleanup.second);
}

void OrderedAssignmentRewriter::saveLeftHandSide(
    hlfir::SaveEntity savedEntity, hlfir::RegionAssignOp regionAssignOp) {
  mlir::Region &region = *savedEntity.yieldRegion;
  if (!isVectorSubscripted(region)) {
    saveNonVectorSubscriptedAddress(savedEntity);
    return;
  }
  // Save vector subscripted LHS address.
  mlir::Location loc = region.getParentOp()->getLoc();
  LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
  // loweredLhs.vectorSubscriptLoopNest is empty inside a WHERE because the
  // WHERE loops are already indexing the vector subscripted designator.
  if (loweredLhs.vectorSubscriptLoopNest)
    constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp);
  fir::factory::TemporaryStorage *temp = nullptr;
  if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
    // Vector subscripted entity for which the shape must also be saved on top
    // of the element addresses (e.g. the shape may change in each forall
    // iteration and is needed to create the elemental loops).
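    // For illustration (hypothetical Fortran): in
    //   forall (i=1:n) x(v(i, 1:m(i))) = y(i)
    // the designator shape changes with each value of i, so the shape must be
    // saved alongside the element addresses to later drive the assignments.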
    mlir::Value shape = loweredLhs.vectorSubscriptShape.value();
    int rank = mlir::cast<fir::ShapeType>(shape.getType()).getRank();
    const bool shapeIsInvariant =
        constructStack.empty() ||
        dominanceInfo.properlyDominates(shape, constructStack[0]);
    doBeforeLoopNest([&] {
      // Outside of any forall/where/elemental loops, create a temporary that
      // will both be able to save the vector subscripted designator shape(s)
      // and element addresses.
      temp =
          insertSavedEntity(region, fir::factory::AnyVectorSubscriptStack{
                                        loc, builder, loweredLhs.lhs.getType(),
                                        shapeIsInvariant, rank});
    });
    // Save shape before the elemental loop nest created by the vector
    // subscripted LHS.
    auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
    auto insertionPoint = builder.saveInsertionPoint();
    builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp);
    vectorTmp.pushShape(loc, builder, shape);
    builder.restoreInsertionPoint(insertionPoint);
  } else {
    // Only the scalar element addresses need to be saved. Their computation
    // depends on the inner loop indices generated for the vector subscripts
    // (so there is no need to waste time checking dominance), and so far they
    // can only be saved in a variable stack.
    doBeforeLoopNest([&] {
      temp = insertSavedEntity(
          region, fir::factory::AnyVariableStack{loc, builder,
                                                 loweredLhs.lhs.getType()});
    });
  }
  temp->pushValue(loc, builder, loweredLhs.lhs);
  generateCleanupIfAny(loweredLhs.elementalCleanup);
  if (loweredLhs.vectorSubscriptLoopNest) {
    constructStack.pop_back();
    builder.setInsertionPointAfter(loweredLhs.vectorSubscriptLoopNest->outerOp);
  }
}

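/// Clone pure, region-free operations that define values used inside the
/// hlfir.exactly_once regions of \p whereOp but that do not dominate the
/// top-level hlfir.where, so that the only remaining live-ins are values
/// defined above the construct. As an illustration, a fir.shape operation
/// computed between the WHERE mask and an hlfir.exactly_once region would be
/// pulled into that region's body.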
void OrderedAssignmentRewriter::canonicalizeExactlyOnceInsideWhere(
    hlfir::WhereOp whereOp) {
  auto getDefinition = [](mlir::Value v) {
    mlir::Operation *op = v.getDefiningOp();
    bool isValid = true;
    if (!op) {
      LLVM_DEBUG(
          llvm::dbgs()
          << "Value live into hlfir.exactly_once has no defining operation: "
          << v << "\n");
      isValid = false;
    } else {
      if (op->getNumRegions() != 0) {
        LLVM_DEBUG(
            llvm::dbgs()
            << "Cannot pull an operation with regions into hlfir.exactly_once: "
            << *op << "\n");
        isValid = false;
      }
      auto effects = mlir::getEffectsRecursively(op);
      if (!effects || !effects->empty()) {
        LLVM_DEBUG(llvm::dbgs() << "Side effects on operation with result live "
                                   "into hlfir.exactly_once: "
                                << *op << "\n");
        isValid = false;
      }
    }
    assert(isValid && "invalid live-in");
    (void)isValid;
    return op;
  };
  mlir::Liveness liveness(whereOp.getOperation());
  whereOp->walk([&](hlfir::ExactlyOnceOp op) {
    std::unordered_set<mlir::Operation *> liveInSet;
    LLVM_DEBUG(llvm::dbgs() << "Canonicalizing:\n" << op << "\n");
    auto &liveIns = liveness.getLiveIn(&op.getBody().front());
    if (liveIns.empty())
      return;
    // Note that the liveIns set is not ordered.
    for (mlir::Value liveIn : liveIns) {
      if (!dominanceInfo.properlyDominates(liveIn, whereOp)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Does not dominate top-level where: " << liveIn << "\n");
        liveInSet.insert(getDefinition(liveIn));
      }
    }

    // Populate the set of operations that we need to pull into
    // hlfir.exactly_once, so that the only live-ins left are the ones
    // that dominate whereOp.
    std::unordered_set<mlir::Operation *> cloneSet(liveInSet);
    llvm::SmallVector<mlir::Operation *> workList(cloneSet.begin(),
                                                  cloneSet.end());
    while (!workList.empty()) {
      mlir::Operation *current = workList.pop_back_val();
      for (mlir::Value operand : current->getOperands()) {
        if (dominanceInfo.properlyDominates(operand, whereOp))
          continue;
        mlir::Operation *def = getDefinition(operand);
        if (cloneSet.count(def))
          continue;
        cloneSet.insert(def);
        workList.push_back(def);
      }
    }

    // Sort the operations by dominance. This preserves their order
    // after the cloning, and also guarantees stable IR generation.
    llvm::SmallVector<mlir::Operation *> cloneList(cloneSet.begin(),
                                                   cloneSet.end());
    llvm::sort(cloneList, [&](mlir::Operation *L, mlir::Operation *R) {
      return dominanceInfo.properlyDominates(L, R);
    });

    // Clone the operations.
    mlir::IRMapping mapper;
    mlir::Operation::CloneOptions options;
    options.cloneOperands();
    mlir::OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(&op.getBody().front());

    for (auto *toClone : cloneList) {
      LLVM_DEBUG(llvm::dbgs() << "Cloning: " << *toClone << "\n");
      builder.insert(toClone->clone(mapper, options));
    }
    for (mlir::Operation *oldOp : liveInSet)
      for (mlir::Value oldVal : oldOp->getResults()) {
        mlir::Value newVal = mapper.lookup(oldVal);
        if (!newVal) {
          LLVM_DEBUG(llvm::dbgs() << "No clone found for: " << oldVal << "\n");
          assert(false && "missing clone");
        }
        mlir::replaceAllUsesInRegionWith(oldVal, newVal, op.getBody());
      }

    LLVM_DEBUG(llvm::dbgs() << "Finished canonicalization\n");
    if (!liveInSet.empty())
      LLVM_DEBUG(llvm::dbgs() << op << "\n");
  });
}

/// Lower an ordered assignment tree to fir.do_loop and hlfir.assign given
/// a schedule.
static void lower(hlfir::OrderedAssignmentTreeOpInterface root,
                  mlir::PatternRewriter &rewriter, hlfir::Schedule &schedule) {
  auto module = root->getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder builder(rewriter, module);
  OrderedAssignmentRewriter assignmentRewriter(builder, root);
  for (auto &run : schedule)
    assignmentRewriter.lowerRun(run);
  assignmentRewriter.cleanupSavedEntities();
}

/// Shared rewrite entry point for all the ordered assignment tree root
/// operations. It calls the scheduler and then applies the schedule.
static llvm::LogicalResult rewrite(hlfir::OrderedAssignmentTreeOpInterface root,
                                   bool tryFusingAssignments,
                                   mlir::PatternRewriter &rewriter) {
  hlfir::Schedule schedule =
      hlfir::buildEvaluationSchedule(root, tryFusingAssignments);

  LLVM_DEBUG(
      // Debug option to print the scheduling debug info without doing
      // any code generation. The operations are simply erased to avoid
      // failing and calling the rewrite patterns on nested operations.
      // The only purpose of this is to help testing scheduling without
      // having to test generated code.
      if (dbgScheduleOnly) {
        rewriter.eraseOp(root);
        return mlir::success();
      });
  lower(root, rewriter, schedule);
  rewriter.eraseOp(root);
  return mlir::success();
}

namespace {

class ForallOpConversion : public mlir::OpRewritePattern<hlfir::ForallOp> {
public:
  explicit ForallOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments)
      : OpRewritePattern{ctx}, tryFusingAssignments{tryFusingAssignments} {}

  llvm::LogicalResult
  matchAndRewrite(hlfir::ForallOp forallOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
        forallOp.getOperation());
    if (mlir::failed(::rewrite(root, tryFusingAssignments, rewriter)))
      TODO(forallOp.getLoc(), "FORALL construct or statement in HLFIR");
    return mlir::success();
  }
  const bool tryFusingAssignments;
};

class WhereOpConversion : public mlir::OpRewritePattern<hlfir::WhereOp> {
public:
  explicit WhereOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments)
      : OpRewritePattern{ctx}, tryFusingAssignments{tryFusingAssignments} {}

  llvm::LogicalResult
  matchAndRewrite(hlfir::WhereOp whereOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
        whereOp.getOperation());
    return ::rewrite(root, tryFusingAssignments, rewriter);
  }
  const bool tryFusingAssignments;
};

class RegionAssignConversion
    : public mlir::OpRewritePattern<hlfir::RegionAssignOp> {
public:
  explicit RegionAssignConversion(mlir::MLIRContext *ctx)
      : OpRewritePattern{ctx} {}

  llvm::LogicalResult
  matchAndRewrite(hlfir::RegionAssignOp regionAssignOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
        regionAssignOp.getOperation());
    return ::rewrite(root, /*tryFusingAssignments=*/false, rewriter);
  }
};

class LowerHLFIROrderedAssignments
    : public hlfir::impl::LowerHLFIROrderedAssignmentsBase<
          LowerHLFIROrderedAssignments> {
public:
  using LowerHLFIROrderedAssignmentsBase<
      LowerHLFIROrderedAssignments>::LowerHLFIROrderedAssignmentsBase;

  void runOnOperation() override {
    // Running on a ModuleOp because this pass may generate FuncOp declarations
    // for runtime calls. This could be a FuncOp pass otherwise.
    auto module = this->getOperation();
    auto *context = &getContext();
    mlir::RewritePatternSet patterns(context);
    // Patterns are only defined for the OrderedAssignmentTreeOpInterface
    // operations that can be the root of ordered assignments. The other
    // operations will be taken care of while rewriting these trees (they
    // cannot exist outside of these operations given their verifiers/traits).
    patterns.insert<ForallOpConversion, WhereOpConversion>(
        context, this->tryFusingAssignments.getValue());
    patterns.insert<RegionAssignConversion>(context);
    mlir::ConversionTarget target(*context);
    target.markUnknownOpDynamicallyLegal([](mlir::Operation *op) {
      return !mlir::isa<hlfir::OrderedAssignmentTreeOpInterface>(op);
    });
    if (mlir::failed(mlir::applyPartialConversion(module, target,
                                                  std::move(patterns)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in HLFIR ordered assignments lowering pass");
      signalPassFailure();
    }
  }
};
} // namespace
