1//===- LowerHLFIROrderedAssignments.cpp - Lower HLFIR ordered assignments -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This file defines a pass to lower HLFIR ordered assignments.
9// Ordered assignments are all the operations with the
10// OrderedAssignmentTreeOpInterface that implements user defined assignments,
11// assignment to vector subscripted entities, and assignments inside forall and
12// where.
13// The pass lowers these operations to regular hlfir.assign, loops and, if
14// needed, introduces temporary storage to fulfill Fortran semantics.
15//
16// For each rewrite, an analysis builds an evaluation schedule, and then the
17// new code is generated by following the evaluation schedule.
18//===----------------------------------------------------------------------===//
19
20#include "ScheduleOrderedAssignments.h"
21#include "flang/Optimizer/Builder/FIRBuilder.h"
22#include "flang/Optimizer/Builder/HLFIRTools.h"
23#include "flang/Optimizer/Builder/TemporaryStorage.h"
24#include "flang/Optimizer/Builder/Todo.h"
25#include "flang/Optimizer/Dialect/Support/FIRContext.h"
26#include "flang/Optimizer/HLFIR/Passes.h"
27#include "mlir/IR/Dominance.h"
28#include "mlir/IR/IRMapping.h"
29#include "mlir/Transforms/DialectConversion.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/Debug.h"
33
34namespace hlfir {
35#define GEN_PASS_DEF_LOWERHLFIRORDEREDASSIGNMENTS
36#include "flang/Optimizer/HLFIR/Passes.h.inc"
37} // namespace hlfir
38
39#define DEBUG_TYPE "flang-ordered-assignment"
40
// Debug-only option to test the scheduling analysis in isolation: scheduled
// operations are erased without any codegen. The only goal is to allow
// printing and testing the scheduling debug info.
static llvm::cl::opt<bool> dbgScheduleOnly(
    "flang-dbg-order-assignment-schedule-only",
    llvm::cl::desc("Only run ordered assignment scheduling with no codegen"),
    llvm::cl::init(false));
48
49namespace {
50
/// Structure that represents a masked expression being lowered. Masked
/// expressions are any expressions inside an hlfir.where. As described in
/// Fortran 2018 section 10.2.3.2, the evaluation of the elemental parts of
/// such expressions must be masked, while the evaluation of the none elemental
/// parts must not be masked. This structure analyzes the region evaluating the
/// expression and allows splitting the generation of the none elemental part
/// from the elemental part.
struct MaskedArrayExpr {
  /// \p region is the single-block region computing the masked expression.
  MaskedArrayExpr(mlir::Location loc, mlir::Region &region);

  /// Generate the none elemental part. Must be called outside of the
  /// loops created for the WHERE construct.
  void generateNoneElementalPart(fir::FirOpBuilder &builder,
                                 mlir::IRMapping &mapper);

  /// Methods below can only be called once generateNoneElementalPart has been
  /// called.

  /// Return the shape of the expression.
  mlir::Value generateShape(fir::FirOpBuilder &builder,
                            mlir::IRMapping &mapper);
  /// Return the value of an element value for this expression given the current
  /// where loop indices.
  mlir::Value generateElementalParts(fir::FirOpBuilder &builder,
                                     mlir::ValueRange oneBasedIndices,
                                     mlir::IRMapping &mapper);
  /// Generate the cleanup for the none elemental parts, if any. This must be
  /// called after the loops created for the WHERE construct.
  void generateNoneElementalCleanupIfAny(fir::FirOpBuilder &builder,
                                         mlir::IRMapping &mapper);

  /// Location to use for all operations generated for this expression.
  mlir::Location loc;
  /// Region of the original ordered assignment tree computing the expression.
  mlir::Region &region;
  /// Was generateNoneElementalPart called? Guards the methods above that
  /// require the none elemental part to have been generated first.
  bool noneElementalPartWasGenerated = false;
  /// Set of operations that form the elemental parts of the
  /// expression evaluation. These are the hlfir.elemental and
  /// hlfir.elemental_addr that form the elemental tree producing
  /// the expression value. hlfir.elemental that produce values
  /// used inside transformational operations are not part of this set.
  llvm::SmallSet<mlir::Operation *, 4> elementalParts{};
};
93} // namespace
94
95namespace {
96/// Structure that visits an ordered assignment tree and generates code for
97/// it according to a schedule.
/// Structure that visits an ordered assignment tree and generates code for
/// it according to a schedule.
class OrderedAssignmentRewriter {
public:
  /// \p root is the top-level operation of the ordered assignment tree
  /// (e.g. an hlfir.forall or hlfir.where) that will be lowered.
  OrderedAssignmentRewriter(fir::FirOpBuilder &builder,
                            hlfir::OrderedAssignmentTreeOpInterface root)
      : builder{builder}, root{root} {}

  /// Generate code for the current run of the schedule.
  /// Resets the per-run state (mapper, savedInCurrentRunBeforeUse) when done.
  void lowerRun(hlfir::Run &run) {
    currentRun = &run;
    walk(root);
    currentRun = nullptr;
    assert(constructStack.empty() && "must exit constructs after a run");
    mapper.clear();
    savedInCurrentRunBeforeUse.clear();
  }

  /// After all run have been lowered, clean-up all the temporary
  /// storage that were created (do not call final routines).
  void cleanupSavedEntities() {
    for (auto &temp : savedEntities)
      temp.second.destroy(root.getLoc(), builder);
  }

  /// Lowered value for an expression, and the original hlfir.yield if any
  /// clean-up needs to be cloned after usage.
  using ValueAndCleanUp = std::pair<mlir::Value, std::optional<hlfir::YieldOp>>;

private:
  /// Walk the part of an order assignment tree node that needs
  /// to be evaluated in the current run.
  void walk(hlfir::OrderedAssignmentTreeOpInterface node);

  /// Generate code when entering a given ordered assignment node.
  void pre(hlfir::ForallOp forallOp);
  void pre(hlfir::ForallIndexOp);
  void pre(hlfir::ForallMaskOp);
  void pre(hlfir::WhereOp whereOp);
  void pre(hlfir::ElseWhereOp elseWhereOp);
  void pre(hlfir::RegionAssignOp);

  /// Generate code when leaving a given ordered assignment node.
  void post(hlfir::ForallOp);
  void post(hlfir::ForallMaskOp);
  void post(hlfir::WhereOp);
  void post(hlfir::ElseWhereOp);
  /// Enter (and maybe create) the fir.if else block of an ElseWhereOp,
  /// but do not generate the elswhere mask or the new fir.if.
  void enterElsewhere(hlfir::ElseWhereOp);

  /// Are there any leaf region in the node that must be saved in the current
  /// run?
  bool mustSaveRegionIn(
      hlfir::OrderedAssignmentTreeOpInterface node,
      llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const;
  /// Should this node be evaluated in the current run? Saving a region in a
  /// node does not imply the node needs to be evaluated.
  bool
  isRequiredInCurrentRun(hlfir::OrderedAssignmentTreeOpInterface node) const;

  /// Generate a scalar value yielded by an ordered assignment tree region.
  /// If the value was not saved in a previous run, this clone the region
  /// code, except the final yield, at the current execution point.
  /// If the value was saved in a previous run, this fetches the saved value
  /// from the temporary storage and returns the value.
  /// Inside Forall, the value will be hoisted outside of the forall loops if
  /// it does not depend on the forall indices.
  /// An optional type can be provided to get a value from a specific type
  /// (the cast will be hoisted if the computation is hoisted).
  mlir::Value generateYieldedScalarValue(
      mlir::Region &region,
      std::optional<mlir::Type> castToType = std::nullopt);

  /// Generate an entity yielded by an ordered assignment tree region, and
  /// optionally return the (uncloned) yield if there is any clean-up that
  /// should be done after using the entity. Like, generateYieldedScalarValue,
  /// this will return the saved value if the region was saved in a previous
  /// run.
  ValueAndCleanUp
  generateYieldedEntity(mlir::Region &region,
                        std::optional<mlir::Type> castToType = std::nullopt);

  /// Bundle of everything produced when lowering a left-hand side region.
  struct LhsValueAndCleanUp {
    /// Address (or element address for vector subscripts) of the LHS.
    mlir::Value lhs;
    /// Yield carrying clean-up code for the elemental part, if any.
    std::optional<hlfir::YieldOp> elementalCleanup;
    /// Clean-up region for the non-elemental part, if any.
    mlir::Region *nonElementalCleanup = nullptr;
    /// Loop nest created for a vector subscripted LHS, if any.
    std::optional<hlfir::LoopNest> vectorSubscriptLoopNest;
    /// Shape used to create vectorSubscriptLoopNest, if any.
    std::optional<mlir::Value> vectorSubscriptShape;
  };

  /// Generate the left-hand side. If the left-hand side is vector
  /// subscripted (hlfir.elemental_addr), this will create a loop nest
  /// (unless it was already created by a WHERE mask) and return the
  /// element address.
  LhsValueAndCleanUp
  generateYieldedLHS(mlir::Location loc, mlir::Region &lhsRegion,
                     std::optional<hlfir::Entity> loweredRhs = std::nullopt);

  /// If \p maybeYield is present and has a clean-up, generate the clean-up
  /// at the current insertion point (by cloning).
  void generateCleanupIfAny(std::optional<hlfir::YieldOp> maybeYield);
  void generateCleanupIfAny(mlir::Region *cleanupRegion);

  /// Generate a masked entity. This can only be called when whereLoopNest was
  /// set (When an hlfir.where is being visited).
  /// This method returns the scalar element (that may have been previously
  /// saved) for the current indices inside the where loop.
  mlir::Value generateMaskedEntity(mlir::Location loc, mlir::Region &region) {
    MaskedArrayExpr maskedExpr(loc, region);
    return generateMaskedEntity(maskedExpr);
  }
  mlir::Value generateMaskedEntity(MaskedArrayExpr &maskedExpr);

  /// Create a fir.if at the current position inside the where loop nest
  /// given the element value of a mask.
  void generateMaskIfOp(mlir::Value cdt);

  /// Save a value for subsequent runs.
  void generateSaveEntity(hlfir::SaveEntity savedEntity,
                          bool willUseSavedEntityInSameRun);
  void saveLeftHandSide(hlfir::SaveEntity savedEntity,
                        hlfir::RegionAssignOp regionAssignOp);

  /// Get a value if it was saved in this run or a previous run. Returns
  /// nullopt if it has not been saved.
  std::optional<ValueAndCleanUp> getIfSaved(mlir::Region &region);

  /// Generate code before the loop nest for the current run, if any.
  /// If there is no construct on the stack, the callback is run in place.
  void doBeforeLoopNest(const std::function<void()> &callback) {
    if (constructStack.empty()) {
      callback();
      return;
    }
    auto insertionPoint = builder.saveInsertionPoint();
    builder.setInsertionPoint(constructStack[0]);
    callback();
    builder.restoreInsertionPoint(insertionPoint);
  }

  /// Can the current loop nest iteration number be computed? For simplicity,
  /// this is true if and only if all the bounds and steps of the fir.do_loop
  /// nest dominates the outer loop. The argument is filled with the current
  /// loop nest on success.
  bool currentLoopNestIterationNumberCanBeComputed(
      llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest);

  /// Record a temporary storage for \p region and return a pointer to the
  /// stored instance. The region must not already have a saved entity.
  template <typename T>
  fir::factory::TemporaryStorage *insertSavedEntity(mlir::Region &region,
                                                    T &&temp) {
    auto inserted =
        savedEntities.insert(std::make_pair(&region, std::forward<T>(temp)));
    assert(inserted.second && "temp must have been emplaced");
    return &inserted.first->second;
  }

  fir::FirOpBuilder &builder;

  /// Map containing the mapping between the original order assignment tree
  /// operations and the operations that have been cloned in the current run.
  /// It is reset between two runs.
  mlir::IRMapping mapper;
  /// Dominance info is used to determine if inner loop bounds are all computed
  /// before outer loop for the current loop. It does not need to be reset
  /// between runs.
  mlir::DominanceInfo dominanceInfo;
  /// Construct stack in the current run. This allows setting back the insertion
  /// point correctly when leaving a node that requires a fir.do_loop or fir.if
  /// operation.
  llvm::SmallVector<mlir::Operation *> constructStack;
  /// Current where loop nest, if any.
  std::optional<hlfir::LoopNest> whereLoopNest;

  /// Map of temporary storage to keep track of saved entity once the run
  /// that saves them has been lowered. It is kept in-between runs.
  /// llvm::MapVector is used to guarantee deterministic order
  /// of iterating through savedEntities (e.g. for generating
  /// destruction code for the temporary storages).
  llvm::MapVector<mlir::Region *, fir::factory::TemporaryStorage> savedEntities;
  /// Map holding the values that were saved in the current run and that also
  /// need to be used (because their construct will be visited). It is reset
  /// after each run. It avoids having to store and fetch in the temporary
  /// during the same run, which would require the temporary to have different
  /// fetching and storing counters.
  llvm::DenseMap<mlir::Region *, ValueAndCleanUp> savedInCurrentRunBeforeUse;

  /// Root of the order assignment tree being lowered.
  hlfir::OrderedAssignmentTreeOpInterface root;
  /// Pointer to the current run of the schedule being lowered.
  hlfir::Run *currentRun = nullptr;

  /// When allocating temporary storage inlined, indicate if the storage should
  /// be heap or stack allocated. Temporary allocated with the runtime are heap
  /// allocated by the runtime.
  bool allocateOnHeap = true;
};
292} // namespace
293
294void OrderedAssignmentRewriter::walk(
295 hlfir::OrderedAssignmentTreeOpInterface node) {
296 bool mustVisit =
297 isRequiredInCurrentRun(node) || mlir::isa<hlfir::ForallIndexOp>(node);
298 llvm::SmallVector<hlfir::SaveEntity> saveEntities;
299 mlir::Operation *nodeOp = node.getOperation();
300 if (mustSaveRegionIn(node, saveEntities)) {
301 mlir::IRRewriter::InsertPoint insertionPoint;
302 if (auto elseWhereOp = mlir::dyn_cast<hlfir::ElseWhereOp>(nodeOp)) {
303 // ElseWhere mask to save must be evaluated inside the fir.if else
304 // for the previous where/elsewehere (its evaluation must be
305 // masked by the "pending control mask").
306 insertionPoint = builder.saveInsertionPoint();
307 enterElsewhere(elseWhereOp);
308 }
309 for (hlfir::SaveEntity saveEntity : saveEntities)
310 generateSaveEntity(savedEntity: saveEntity, willUseSavedEntityInSameRun: mustVisit);
311 if (insertionPoint.isSet())
312 builder.restoreInsertionPoint(insertionPoint);
313 }
314 if (mustVisit) {
315 llvm::TypeSwitch<mlir::Operation *, void>(nodeOp)
316 .Case<hlfir::ForallOp, hlfir::ForallIndexOp, hlfir::ForallMaskOp,
317 hlfir::RegionAssignOp, hlfir::WhereOp, hlfir::ElseWhereOp>(
318 [&](auto concreteOp) { pre(concreteOp); })
319 .Default([](auto) {});
320 if (auto *body = node.getSubTreeRegion()) {
321 for (mlir::Operation &op : body->getOps())
322 if (auto subNode =
323 mlir::dyn_cast<hlfir::OrderedAssignmentTreeOpInterface>(op))
324 walk(subNode);
325 llvm::TypeSwitch<mlir::Operation *, void>(nodeOp)
326 .Case<hlfir::ForallOp, hlfir::ForallMaskOp, hlfir::WhereOp,
327 hlfir::ElseWhereOp>([&](auto concreteOp) { post(concreteOp); })
328 .Default([](auto) {});
329 }
330 }
331}
332
333void OrderedAssignmentRewriter::pre(hlfir::ForallOp forallOp) {
334 /// Create a fir.do_loop given the hlfir.forall control values.
335 mlir::Type idxTy = builder.getIndexType();
336 mlir::Location loc = forallOp.getLoc();
337 mlir::Value lb = generateYieldedScalarValue(forallOp.getLbRegion(), idxTy);
338 mlir::Value ub = generateYieldedScalarValue(forallOp.getUbRegion(), idxTy);
339 mlir::Value step;
340 if (forallOp.getStepRegion().empty()) {
341 auto insertionPoint = builder.saveInsertionPoint();
342 if (!constructStack.empty())
343 builder.setInsertionPoint(constructStack[0]);
344 step = builder.createIntegerConstant(loc, idxTy, 1);
345 if (!constructStack.empty())
346 builder.restoreInsertionPoint(insertionPoint);
347 } else {
348 step = generateYieldedScalarValue(forallOp.getStepRegion(), idxTy);
349 }
350 auto doLoop = builder.create<fir::DoLoopOp>(loc, lb, ub, step);
351 builder.setInsertionPointToStart(doLoop.getBody());
352 mlir::Value oldIndex = forallOp.getForallIndexValue();
353 mlir::Value newIndex =
354 builder.createConvert(loc, oldIndex.getType(), doLoop.getInductionVar());
355 mapper.map(oldIndex, newIndex);
356 constructStack.push_back(doLoop);
357}
358
359void OrderedAssignmentRewriter::post(hlfir::ForallOp) {
360 assert(!constructStack.empty() && "must contain a loop");
361 builder.setInsertionPointAfter(constructStack.pop_back_val());
362}
363
364void OrderedAssignmentRewriter::pre(hlfir::ForallIndexOp forallIndexOp) {
365 mlir::Location loc = forallIndexOp.getLoc();
366 mlir::Type intTy = fir::unwrapRefType(forallIndexOp.getType());
367 mlir::Value indexVar =
368 builder.createTemporary(loc, intTy, forallIndexOp.getName());
369 mlir::Value newVal = mapper.lookupOrDefault(forallIndexOp.getIndex());
370 builder.createStoreWithConvert(loc, newVal, indexVar);
371 mapper.map(forallIndexOp, indexVar);
372}
373
374void OrderedAssignmentRewriter::pre(hlfir::ForallMaskOp forallMaskOp) {
375 mlir::Location loc = forallMaskOp.getLoc();
376 mlir::Value mask = generateYieldedScalarValue(forallMaskOp.getMaskRegion(),
377 builder.getI1Type());
378 auto ifOp = builder.create<fir::IfOp>(loc, std::nullopt, mask, false);
379 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
380 constructStack.push_back(ifOp);
381}
382
383void OrderedAssignmentRewriter::post(hlfir::ForallMaskOp forallMaskOp) {
384 assert(!constructStack.empty() && "must contain an ifop");
385 builder.setInsertionPointAfter(constructStack.pop_back_val());
386}
387
388/// Convert an entity to the type of a given mold.
389/// This is intended to help with cases where hlfir entity is a value while
390/// it must be used as a variable or vice-versa. These mismatches may occur
391/// between the type of user defined assignment block arguments and the actual
392/// argument that was lowered for them. The actual may be an in-memory copy
393/// while the block argument expects an hlfir.expr.
394static hlfir::Entity
395convertToMoldType(mlir::Location loc, fir::FirOpBuilder &builder,
396 hlfir::Entity input, hlfir::Entity mold,
397 llvm::SmallVectorImpl<hlfir::CleanupFunction> &cleanups) {
398 if (input.getType() == mold.getType())
399 return input;
400 fir::FirOpBuilder *b = &builder;
401 if (input.isVariable() && mold.isValue()) {
402 if (fir::isa_trivial(mold.getType())) {
403 // fir.ref<T> to T.
404 mlir::Value load = builder.create<fir::LoadOp>(loc, input);
405 return hlfir::Entity{builder.createConvert(loc, mold.getType(), load)};
406 }
407 // fir.ref<T> to hlfir.expr<T>.
408 mlir::Value asExpr = builder.create<hlfir::AsExprOp>(loc, input);
409 if (asExpr.getType() != mold.getType())
410 TODO(loc, "hlfir.expr conversion");
411 cleanups.emplace_back([=]() { b->create<hlfir::DestroyOp>(loc, asExpr); });
412 return hlfir::Entity{asExpr};
413 }
414 if (input.isValue() && mold.isVariable()) {
415 // T to fir.ref<T>, or hlfir.expr<T> to fir.ref<T>.
416 hlfir::AssociateOp associate = hlfir::genAssociateExpr(
417 loc, builder, input, mold.getFortranElementType(), ".tmp.val2ref");
418 cleanups.emplace_back(
419 [=]() { b->create<hlfir::EndAssociateOp>(loc, associate); });
420 return hlfir::Entity{associate.getBase()};
421 }
422 // Variable to Variable mismatch (e.g., fir.heap<T> vs fir.ref<T>), or value
423 // to Value mismatch (e.g. i1 vs fir.logical<4>).
424 if (mlir::isa<fir::BaseBoxType>(mold.getType()) &&
425 !mlir::isa<fir::BaseBoxType>(input.getType())) {
426 // An entity may have have been saved without descriptor while the original
427 // value had a descriptor (e.g., it was not contiguous).
428 auto emboxed = hlfir::convertToBox(loc, builder, input, mold.getType());
429 assert(!emboxed.second && "temp should already be in memory");
430 input = hlfir::Entity{fir::getBase(emboxed.first)};
431 }
432 return hlfir::Entity{builder.createConvert(loc, mold.getType(), input)};
433}
434
435void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
436 mlir::Location loc = regionAssignOp.getLoc();
437 std::optional<hlfir::LoopNest> elementalLoopNest;
438 auto [rhsValue, oldRhsYield] =
439 generateYieldedEntity(regionAssignOp.getRhsRegion());
440 hlfir::Entity rhsEntity{rhsValue};
441 LhsValueAndCleanUp loweredLhs =
442 generateYieldedLHS(loc, regionAssignOp.getLhsRegion(), rhsEntity);
443 hlfir::Entity lhsEntity{loweredLhs.lhs};
444 if (loweredLhs.vectorSubscriptLoopNest)
445 rhsEntity = hlfir::getElementAt(
446 loc, builder, rhsEntity,
447 loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
448 if (!regionAssignOp.getUserDefinedAssignment().empty()) {
449 hlfir::Entity userAssignLhs{regionAssignOp.getUserAssignmentLhs()};
450 hlfir::Entity userAssignRhs{regionAssignOp.getUserAssignmentRhs()};
451 std::optional<hlfir::LoopNest> elementalLoopNest;
452 if (lhsEntity.isArray() && userAssignLhs.isScalar()) {
453 // Elemental assignment with array argument (the RHS cannot be an array
454 // if the LHS is not).
455 mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
456 elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
457 builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody());
458 lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
459 elementalLoopNest->oneBasedIndices);
460 rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
461 elementalLoopNest->oneBasedIndices);
462 }
463
464 llvm::SmallVector<hlfir::CleanupFunction, 2> argConversionCleanups;
465 lhsEntity = convertToMoldType(loc, builder, lhsEntity, userAssignLhs,
466 argConversionCleanups);
467 rhsEntity = convertToMoldType(loc, builder, rhsEntity, userAssignRhs,
468 argConversionCleanups);
469 mapper.map(userAssignLhs, lhsEntity);
470 mapper.map(userAssignRhs, rhsEntity);
471 for (auto &op :
472 regionAssignOp.getUserDefinedAssignment().front().without_terminator())
473 (void)builder.clone(op, mapper);
474 for (auto &cleanupConversion : argConversionCleanups)
475 cleanupConversion();
476 if (elementalLoopNest)
477 builder.setInsertionPointAfter(elementalLoopNest->outerLoop);
478 } else {
479 // TODO: preserve allocatable assignment aspects for forall once
480 // they are conveyed in hlfir.region_assign.
481 builder.create<hlfir::AssignOp>(loc, rhsEntity, lhsEntity);
482 }
483 generateCleanupIfAny(loweredLhs.elementalCleanup);
484 if (loweredLhs.vectorSubscriptLoopNest)
485 builder.setInsertionPointAfter(
486 loweredLhs.vectorSubscriptLoopNest->outerLoop);
487 generateCleanupIfAny(oldRhsYield);
488 generateCleanupIfAny(loweredLhs.nonElementalCleanup);
489}
490
491void OrderedAssignmentRewriter::generateMaskIfOp(mlir::Value cdt) {
492 mlir::Location loc = cdt.getLoc();
493 cdt = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{cdt});
494 cdt = builder.createConvert(loc, builder.getI1Type(), cdt);
495 auto ifOp = builder.create<fir::IfOp>(cdt.getLoc(), std::nullopt, cdt,
496 /*withElseRegion=*/false);
497 constructStack.push_back(ifOp.getOperation());
498 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
499}
500
/// Lower an hlfir.where: for the top-level WHERE, create the "elemental" loop
/// nest over the mask shape and guard the body with a fir.if on the mask
/// element; for a nested WHERE, only the fir.if is created since the loops
/// already exist.
void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
  mlir::Location loc = whereOp.getLoc();
  if (!whereLoopNest) {
    // This is the top-level WHERE. Start a loop nest iterating on the shape of
    // the where mask.
    if (auto maybeSaved = getIfSaved(whereOp.getMaskRegion())) {
      // Use the saved value to get the shape and condition element.
      hlfir::Entity savedMask{maybeSaved->first};
      mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
      whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
      constructStack.push_back(whereLoopNest->outerLoop.getOperation());
      builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
      mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
                                            whereLoopNest->oneBasedIndices);
      generateMaskIfOp(cdt);
      if (maybeSaved->second) {
        // If this is the same run as the one that saved the value, the clean-up
        // was left-over to be done now. It must be emitted after the loop
        // nest, then the insertion point is restored inside the loops.
        auto insertionPoint = builder.saveInsertionPoint();
        builder.setInsertionPointAfter(whereLoopNest->outerLoop);
        generateCleanupIfAny(maybeSaved->second);
        builder.restoreInsertionPoint(insertionPoint);
      }
      return;
    }
    // The mask was not evaluated yet or can be safely re-evaluated.
    // Evaluate its non-elemental parts before the loops, then its elemental
    // part inside the loops for the current indices.
    MaskedArrayExpr mask(loc, whereOp.getMaskRegion());
    mask.generateNoneElementalPart(builder, mapper);
    mlir::Value shape = mask.generateShape(builder, mapper);
    whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
    constructStack.push_back(whereLoopNest->outerLoop.getOperation());
    builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
    mlir::Value cdt = generateMaskedEntity(mask);
    generateMaskIfOp(cdt);
    return;
  }
  // Where Loops have been already created by a parent WHERE.
  // Generate a fir.if with the value of the current element of the mask
  // inside the loops. The case where the mask was saved is handled in the
  // generateYieldedScalarValue call.
  mlir::Value cdt = generateYieldedScalarValue(whereOp.getMaskRegion());
  generateMaskIfOp(cdt);
}
544
545void OrderedAssignmentRewriter::post(hlfir::WhereOp whereOp) {
546 assert(!constructStack.empty() && "must contain a fir.if");
547 builder.setInsertionPointAfter(constructStack.pop_back_val());
548 // If all where/elsewhere fir.if have been popped, this is the outer whereOp,
549 // and the where loop must be exited.
550 assert(!constructStack.empty() && "must contain a fir.do_loop or fir.if");
551 if (mlir::isa<fir::DoLoopOp>(constructStack.back())) {
552 builder.setInsertionPointAfter(constructStack.pop_back_val());
553 whereLoopNest.reset();
554 }
555}
556
557void OrderedAssignmentRewriter::enterElsewhere(hlfir::ElseWhereOp elseWhereOp) {
558 // Create an "else" region for the current where/elsewhere fir.if.
559 auto ifOp = mlir::dyn_cast<fir::IfOp>(constructStack.back());
560 assert(ifOp && "must be an if");
561 if (ifOp.getElseRegion().empty()) {
562 mlir::Location loc = elseWhereOp.getLoc();
563 builder.createBlock(&ifOp.getElseRegion());
564 auto end = builder.create<fir::ResultOp>(loc);
565 builder.setInsertionPoint(end);
566 } else {
567 builder.setInsertionPoint(&ifOp.getElseRegion().back().back());
568 }
569}
570
571void OrderedAssignmentRewriter::pre(hlfir::ElseWhereOp elseWhereOp) {
572 enterElsewhere(elseWhereOp);
573 if (elseWhereOp.getMaskRegion().empty())
574 return;
575 // Create new nested fir.if with elsewhere mask if any.
576 mlir::Value cdt = generateYieldedScalarValue(elseWhereOp.getMaskRegion());
577 generateMaskIfOp(cdt);
578}
579
580void OrderedAssignmentRewriter::post(hlfir::ElseWhereOp elseWhereOp) {
581 // Exit ifOp that was created for the elseWhereOp mask, if any.
582 if (elseWhereOp.getMaskRegion().empty())
583 return;
584 assert(!constructStack.empty() && "must contain a fir.if");
585 builder.setInsertionPointAfter(constructStack.pop_back_val());
586}
587
588/// Is this value a Forall index?
589/// Forall index are block arguments of hlfir.forall body, or the result
590/// of hlfir.forall_index.
591static bool isForallIndex(mlir::Value value) {
592 if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(value)) {
593 if (mlir::Block *block = blockArg.getOwner())
594 return block->isEntryBlock() &&
595 mlir::isa_and_nonnull<hlfir::ForallOp>(block->getParentOp());
596 return false;
597 }
598 return value.getDefiningOp<hlfir::ForallIndexOp>();
599}
600
601static OrderedAssignmentRewriter::ValueAndCleanUp
602castIfNeeded(mlir::Location loc, fir::FirOpBuilder &builder,
603 OrderedAssignmentRewriter::ValueAndCleanUp valueAndCleanUp,
604 std::optional<mlir::Type> castToType) {
605 if (!castToType.has_value())
606 return valueAndCleanUp;
607 mlir::Value cast =
608 builder.createConvert(loc, *castToType, valueAndCleanUp.first);
609 return {cast, valueAndCleanUp.second};
610}
611
612std::optional<OrderedAssignmentRewriter::ValueAndCleanUp>
613OrderedAssignmentRewriter::getIfSaved(mlir::Region &region) {
614 mlir::Location loc = region.getParentOp()->getLoc();
615 // If the region was saved in the same run, use the value that was evaluated
616 // instead of fetching the temp, and do clean-up, if any, that were delayed.
617 // This is done to avoid requiring the temporary stack to have different
618 // fetching and storing counters, and also because it produces slightly better
619 // code.
620 if (auto savedInSameRun = savedInCurrentRunBeforeUse.find(&region);
621 savedInSameRun != savedInCurrentRunBeforeUse.end())
622 return savedInSameRun->second;
623 // If the region was saved in a previous run, fetch the saved value.
624 if (auto temp = savedEntities.find(&region); temp != savedEntities.end()) {
625 doBeforeLoopNest(callback: [&]() { temp->second.resetFetchPosition(loc, builder); });
626 return ValueAndCleanUp{temp->second.fetch(loc, builder), std::nullopt};
627 }
628 return std::nullopt;
629}
630
/// Lower a region yielding an entity: return the saved value when the region
/// was saved in a previous run, generate masked evaluation inside WHERE
/// loops, or clone the region code at the current (possibly hoisted)
/// insertion point.
OrderedAssignmentRewriter::ValueAndCleanUp
OrderedAssignmentRewriter::generateYieldedEntity(
    mlir::Region &region, std::optional<mlir::Type> castToType) {
  mlir::Location loc = region.getParentOp()->getLoc();
  if (auto maybeValueAndCleanUp = getIfSaved(region))
    return castIfNeeded(loc, builder, *maybeValueAndCleanUp, castToType);
  // Otherwise, evaluate the region now.

  // Masked expression must not evaluate the elemental parts that are masked,
  // they have custom code generation.
  if (whereLoopNest.has_value()) {
    mlir::Value maskedValue = generateMaskedEntity(loc, region);
    return castIfNeeded(loc, builder, {maskedValue, std::nullopt}, castToType);
  }

  assert(region.hasOneBlock() && "region must contain one block");
  auto oldYield = mlir::dyn_cast_or_null<hlfir::YieldOp>(
      region.back().getOperations().back());
  assert(oldYield && "region computing entities must end with a YieldOp");
  mlir::Block::OpListType &ops = region.back().getOperations();

  // Inside Forall, scalars that do not depend on forall indices can be hoisted
  // here because their evaluation is required to only call pure procedures, and
  // if they depend on a variable previously assigned to in a forall assignment,
  // this assignment must have been scheduled in a previous run. Hoisting of
  // scalars is done here to help creating simple temporary storage if needed.
  // Inner forall bounds can often be hoisted, and this allows computing the
  // total number of iterations to create temporary storages.
  bool hoistComputation = false;
  if (fir::isa_trivial(oldYield.getEntity().getType()) &&
      !constructStack.empty()) {
    // Hoist only if no operation in the region reads a forall index.
    hoistComputation = true;
    for (mlir::Operation &op : ops)
      if (llvm::any_of(op.getOperands(), [](mlir::Value value) {
            return isForallIndex(value);
          })) {
        hoistComputation = false;
        break;
      }
  }
  auto insertionPoint = builder.saveInsertionPoint();
  if (hoistComputation)
    builder.setInsertionPoint(constructStack[0]);

  // Clone all operations except the final hlfir.yield.
  assert(!ops.empty() && "yield block cannot be empty");
  auto end = ops.end();
  for (auto opIt = ops.begin(); std::next(opIt) != end; ++opIt)
    (void)builder.clone(*opIt, mapper);
  // Get the value for the yielded entity, it may be the result of an operation
  // that was cloned, or it may be the same as the previous value if the yield
  // operand was created before the ordered assignment tree.
  mlir::Value newEntity = mapper.lookupOrDefault(oldYield.getEntity());
  if (castToType.has_value())
    newEntity =
        builder.createConvert(newEntity.getLoc(), *castToType, newEntity);

  if (hoistComputation) {
    // Hoisted trivial scalars clean-up can be done right away, the value is
    // in registers.
    generateCleanupIfAny(oldYield);
    builder.restoreInsertionPoint(insertionPoint);
    return {newEntity, std::nullopt};
  }
  // Otherwise, return the original yield so the caller can clone its clean-up
  // after using the entity.
  if (oldYield.getCleanup().empty())
    return {newEntity, std::nullopt};
  return {newEntity, oldYield};
}
699
700mlir::Value OrderedAssignmentRewriter::generateYieldedScalarValue(
701 mlir::Region &region, std::optional<mlir::Type> castToType) {
702 mlir::Location loc = region.getParentOp()->getLoc();
703 auto [value, maybeYield] = generateYieldedEntity(region, castToType);
704 value = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{value});
705 assert(fir::isa_trivial(value.getType()) && "not a trivial scalar value");
706 generateCleanupIfAny(maybeYield);
707 return value;
708}
709
710OrderedAssignmentRewriter::LhsValueAndCleanUp
711OrderedAssignmentRewriter::generateYieldedLHS(
712 mlir::Location loc, mlir::Region &lhsRegion,
713 std::optional<hlfir::Entity> loweredRhs) {
714 LhsValueAndCleanUp loweredLhs;
715 hlfir::ElementalAddrOp elementalAddrLhs =
716 mlir::dyn_cast<hlfir::ElementalAddrOp>(lhsRegion.back().back());
717 if (auto temp = savedEntities.find(&lhsRegion); temp != savedEntities.end()) {
718 // The LHS address was computed and saved in a previous run. Fetch it.
719 doBeforeLoopNest(callback: [&]() { temp->second.resetFetchPosition(loc, builder); });
720 if (elementalAddrLhs && !whereLoopNest) {
721 // Vector subscripted designator address are saved element by element.
722 // If no "elemental" loops have been created yet, the shape of the
723 // RHS, if it is an array can be used, or the shape of the vector
724 // subscripted designator must be retrieved to generate the "elemental"
725 // loop nest.
726 if (loweredRhs && loweredRhs->isArray()) {
727 // The RHS shape can be used to create the elemental loops and avoid
728 // saving the LHS shape.
729 loweredLhs.vectorSubscriptShape =
730 hlfir::genShape(loc, builder, *loweredRhs);
731 } else {
732 // If the shape cannot be retrieved from the RHS, it must have been
733 // saved. Get it from the temporary.
734 auto &vectorTmp =
735 temp->second.cast<fir::factory::AnyVectorSubscriptStack>();
736 loweredLhs.vectorSubscriptShape = vectorTmp.fetchShape(loc, builder);
737 }
738 loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
739 loc, builder, loweredLhs.vectorSubscriptShape.value());
740 builder.setInsertionPointToStart(
741 loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
742 }
743 loweredLhs.lhs = temp->second.fetch(loc, builder);
744 return loweredLhs;
745 }
746 // The LHS has not yet been evaluated and saved. Evaluate it now.
747 if (elementalAddrLhs && !whereLoopNest) {
748 // This is a vector subscripted entity. The address of elements must
749 // be returned. If no "elemental" loops have been created for a WHERE,
750 // create them now based on the vector subscripted designator shape.
751 for (auto &op : lhsRegion.front().without_terminator())
752 (void)builder.clone(op, mapper);
753 loweredLhs.vectorSubscriptShape =
754 mapper.lookupOrDefault(elementalAddrLhs.getShape());
755 loweredLhs.vectorSubscriptLoopNest =
756 hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
757 !elementalAddrLhs.isOrdered());
758 builder.setInsertionPointToStart(
759 loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
760 mapper.map(elementalAddrLhs.getIndices(),
761 loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
762 for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
763 (void)builder.clone(op, mapper);
764 loweredLhs.elementalCleanup = elementalAddrLhs.getYieldOp();
765 loweredLhs.lhs =
766 mapper.lookupOrDefault(loweredLhs.elementalCleanup->getEntity());
767 } else {
768 // This is a designator without vector subscripts. Generate it as
769 // it is done for other entities.
770 auto [lhs, yield] = generateYieldedEntity(lhsRegion);
771 loweredLhs.lhs = lhs;
772 if (yield && !yield->getCleanup().empty())
773 loweredLhs.nonElementalCleanup = &yield->getCleanup();
774 }
775 return loweredLhs;
776}
777
// Lower one element of \p maskedExpr at the current WHERE loop iteration.
// The none elemental parts are emitted once, hoisted before the WHERE loops
// (guarded by the noneElementalPartWasGenerated flag); their clean-ups are
// inserted after the loops; the elemental part is emitted at the current
// insertion point (inside the loops and any fir.if generated for masks).
mlir::Value
OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
  assert(whereLoopNest.has_value() && "must be inside WHERE loop nest");
  // Remember the current position: it is restored before emitting the
  // elemental parts below.
  auto insertionPoint = builder.saveInsertionPoint();
  if (!maskedExpr.noneElementalPartWasGenerated) {
    // Generate none elemental part before the where loops (but inside the
    // current forall loops if any).
    builder.setInsertionPoint(whereLoopNest->outerLoop);
    maskedExpr.generateNoneElementalPart(builder, mapper);
  }
  // Generate the none elemental part cleanup after the where loops.
  builder.setInsertionPointAfter(whereLoopNest->outerLoop);
  maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
  // Generate the value of the current element for the masked expression
  // at the current insertion point (inside the where loops, and any fir.if
  // generated for previous masks).
  builder.restoreInsertionPoint(insertionPoint);
  return maskedExpr.generateElementalParts(
      builder, whereLoopNest->oneBasedIndices, mapper);
}
798
799void OrderedAssignmentRewriter::generateCleanupIfAny(
800 std::optional<hlfir::YieldOp> maybeYield) {
801 if (maybeYield.has_value())
802 generateCleanupIfAny(&maybeYield->getCleanup());
803}
804void OrderedAssignmentRewriter::generateCleanupIfAny(
805 mlir::Region *cleanupRegion) {
806 if (cleanupRegion && !cleanupRegion->empty()) {
807 assert(cleanupRegion->hasOneBlock() && "region must contain one block");
808 for (auto &op : cleanupRegion->back().without_terminator())
809 builder.clone(op, mapper);
810 }
811}
812
813bool OrderedAssignmentRewriter::mustSaveRegionIn(
814 hlfir::OrderedAssignmentTreeOpInterface node,
815 llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const {
816 for (auto &action : currentRun->actions)
817 if (hlfir::SaveEntity *savedEntity =
818 std::get_if<hlfir::SaveEntity>(&action))
819 if (node.getOperation() == savedEntity->yieldRegion->getParentOp())
820 saveEntities.push_back(*savedEntity);
821 return !saveEntities.empty();
822}
823
824bool OrderedAssignmentRewriter::isRequiredInCurrentRun(
825 hlfir::OrderedAssignmentTreeOpInterface node) const {
826 // hlfir.forall_index do not contain saved regions/assignments,
827 // but if their hlfir.forall parent was required, they are
828 // required (the forall indices needs to be mapped).
829 if (mlir::isa<hlfir::ForallIndexOp>(node))
830 return true;
831 for (auto &action : currentRun->actions)
832 if (hlfir::SaveEntity *savedEntity =
833 std::get_if<hlfir::SaveEntity>(&action)) {
834 // A SaveEntity action does not require evaluating the node that contains
835 // it, but it requires to evaluate all the parents of the nodes that
836 // contains it. For instance, an saving a bound in hlfir.forall B does not
837 // require creating the loops for B, but it requires creating the loops
838 // for any forall parent A of the forall B.
839 if (node->isProperAncestor(savedEntity->yieldRegion->getParentOp()))
840 return true;
841 } else {
842 auto assign = std::get<hlfir::RegionAssignOp>(action);
843 if (node->isAncestor(assign.getOperation()))
844 return true;
845 }
846 return false;
847}
848
849/// Is the apply using all the elemental indices in order?
850static bool isInOrderApply(hlfir::ApplyOp apply,
851 hlfir::ElementalOpInterface elemental) {
852 mlir::Region::BlockArgListType elementalIndices = elemental.getIndices();
853 if (elementalIndices.size() != apply.getIndices().size())
854 return false;
855 for (auto [elementalIdx, applyIdx] :
856 llvm::zip(elementalIndices, apply.getIndices()))
857 if (elementalIdx != applyIdx)
858 return false;
859 return true;
860}
861
862/// Gather the tree of hlfir::ElementalOpInterface use-def, if any, starting
863/// from \p elemental, which may be a nullptr.
864static void
865gatherElementalTree(hlfir::ElementalOpInterface elemental,
866 llvm::SmallPtrSetImpl<mlir::Operation *> &elementalOps,
867 bool isOutOfOrder) {
868 if (elemental) {
869 // Only inline an applied elemental that must be executed in order if the
870 // applying indices are in order. An hlfir::Elemental may have been created
871 // for a transformational like transpose, and Fortran 2018 standard
872 // section 10.2.3.2, point 10 imply that impure elemental sub-expression
873 // evaluations should not be masked if they are the arguments of
874 // transformational expressions.
875 if (isOutOfOrder && elemental.isOrdered())
876 return;
877 elementalOps.insert(elemental.getOperation());
878 for (mlir::Operation &op : elemental.getElementalRegion().getOps())
879 if (auto apply = mlir::dyn_cast<hlfir::ApplyOp>(op)) {
880 bool isUnorderedApply =
881 isOutOfOrder || !isInOrderApply(apply, elemental);
882 auto maybeElemental =
883 mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
884 apply.getExpr().getDefiningOp());
885 gatherElementalTree(maybeElemental, elementalOps, isUnorderedApply);
886 }
887 }
888}
889
890MaskedArrayExpr::MaskedArrayExpr(mlir::Location loc, mlir::Region &region)
891 : loc{loc}, region{region} {
892 mlir::Operation &terminator = region.back().back();
893 if (auto elementalAddr =
894 mlir::dyn_cast<hlfir::ElementalOpInterface>(terminator)) {
895 // Vector subscripted designator (hlfir.elemental_addr terminator).
896 gatherElementalTree(elementalAddr, elementalParts, /*isOutOfOrder=*/false);
897 return;
898 }
899 // Try if elemental expression.
900 mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
901 auto maybeElemental = mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
902 entity.getDefiningOp());
903 gatherElementalTree(maybeElemental, elementalParts, /*isOutOfOrder=*/false);
904}
905
906void MaskedArrayExpr::generateNoneElementalPart(fir::FirOpBuilder &builder,
907 mlir::IRMapping &mapper) {
908 assert(!noneElementalPartWasGenerated &&
909 "none elemental parts already generated");
910 // Clone all operations, except the elemental and the final yield.
911 mlir::Block::OpListType &ops = region.back().getOperations();
912 assert(!ops.empty() && "yield block cannot be empty");
913 auto end = ops.end();
914 for (auto opIt = ops.begin(); std::next(opIt) != end; ++opIt)
915 if (!elementalParts.contains(&*opIt))
916 (void)builder.clone(*opIt, mapper);
917 noneElementalPartWasGenerated = true;
918}
919
920mlir::Value MaskedArrayExpr::generateShape(fir::FirOpBuilder &builder,
921 mlir::IRMapping &mapper) {
922 assert(noneElementalPartWasGenerated &&
923 "non elemental part must have been generated");
924 mlir::Operation &terminator = region.back().back();
925 // If the operation that produced the yielded entity is elemental, it was not
926 // cloned, but it holds a shape argument that was cloned. Return the cloned
927 // shape.
928 if (auto elementalAddrOp = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator))
929 return mapper.lookupOrDefault(elementalAddrOp.getShape());
930 mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
931 if (auto elemental = entity.getDefiningOp<hlfir::ElementalOp>())
932 return mapper.lookupOrDefault(elemental.getShape());
933 // Otherwise, the whole entity was cloned, and the shape can be generated
934 // from it.
935 hlfir::Entity clonedEntity{mapper.lookupOrDefault(entity)};
936 return hlfir::genShape(loc, builder, hlfir::Entity{clonedEntity});
937}
938
// Produce the value of the masked expression for the element at
// \p oneBasedIndices, inlining the gathered elemental tree when the yielded
// entity was produced by an elemental operation, or addressing the cloned
// entity otherwise.
mlir::Value
MaskedArrayExpr::generateElementalParts(fir::FirOpBuilder &builder,
                                        mlir::ValueRange oneBasedIndices,
                                        mlir::IRMapping &mapper) {
  assert(noneElementalPartWasGenerated &&
         "non elemental part must have been generated");
  mlir::Operation &terminator = region.back().back();
  // First try the vector-subscripted designator case (hlfir.elemental_addr
  // terminator); `elemental` stays null otherwise.
  hlfir::ElementalOpInterface elemental =
      mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator);
  if (!elemental) {
    // If the terminator is not an hlfir.elemental_addr, try if the yielded
    // entity was produced by an hlfir.elemental.
    mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
    elemental = entity.getDefiningOp<hlfir::ElementalOp>();
    if (!elemental) {
      // The yielded entity was not produced by an elemental operation,
      // get its clone in the non elemental part evaluation and address it.
      hlfir::Entity clonedEntity{mapper.lookupOrDefault(entity)};
      return hlfir::getElementAt(loc, builder, clonedEntity, oneBasedIndices);
    }
  }

  // Recursively inline only the elemental operations that were gathered as
  // part of this masked expression's elemental tree.
  auto mustRecursivelyInline =
      [&](hlfir::ElementalOp appliedElemental) -> bool {
    return elementalParts.contains(appliedElemental.getOperation());
  };
  return inlineElementalOp(loc, builder, elemental, oneBasedIndices, mapper,
                           mustRecursivelyInline);
}
968
969void MaskedArrayExpr::generateNoneElementalCleanupIfAny(
970 fir::FirOpBuilder &builder, mlir::IRMapping &mapper) {
971 mlir::Operation &terminator = region.back().back();
972 mlir::Region *cleanupRegion = nullptr;
973 if (auto elementalAddr = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator)) {
974 cleanupRegion = &elementalAddr.getCleanup();
975 } else {
976 auto yieldOp = mlir::cast<hlfir::YieldOp>(terminator);
977 cleanupRegion = &yieldOp.getCleanup();
978 }
979 if (cleanupRegion->empty())
980 return;
981 for (mlir::Operation &op : cleanupRegion->front().without_terminator()) {
982 if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(op))
983 if (elementalParts.contains(destroy.getExpr().getDefiningOp()))
984 continue;
985 (void)builder.clone(op, mapper);
986 }
987}
988
989static hlfir::RegionAssignOp
990getAssignIfLeftHandSideRegion(mlir::Region &region) {
991 auto assign = mlir::dyn_cast<hlfir::RegionAssignOp>(region.getParentOp());
992 if (assign && (&assign.getLhsRegion() == &region))
993 return assign;
994 return nullptr;
995}
996
997bool OrderedAssignmentRewriter::currentLoopNestIterationNumberCanBeComputed(
998 llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest) {
999 if (constructStack.empty())
1000 return true;
1001 mlir::Operation *outerLoop = constructStack[0];
1002 mlir::Operation *currentConstruct = constructStack.back();
1003 // Loop through the loops until the outer construct is met, and test if the
1004 // loop operands dominate the outer construct.
1005 while (currentConstruct) {
1006 if (auto doLoop = mlir::dyn_cast<fir::DoLoopOp>(currentConstruct)) {
1007 if (llvm::any_of(doLoop->getOperands(), [&](mlir::Value value) {
1008 return !dominanceInfo.properlyDominates(value, outerLoop);
1009 })) {
1010 return false;
1011 }
1012 loopNest.push_back(doLoop);
1013 }
1014 if (currentConstruct == outerLoop)
1015 currentConstruct = nullptr;
1016 else
1017 currentConstruct = currentConstruct->getParentOp();
1018 }
1019 return true;
1020}
1021
1022static mlir::Value
1023computeLoopNestIterationNumber(mlir::Location loc, fir::FirOpBuilder &builder,
1024 llvm::ArrayRef<fir::DoLoopOp> loopNest) {
1025 mlir::Value loopExtent;
1026 for (fir::DoLoopOp doLoop : loopNest) {
1027 mlir::Value extent = builder.genExtentFromTriplet(
1028 loc, doLoop.getLowerBound(), doLoop.getUpperBound(), doLoop.getStep(),
1029 builder.getIndexType());
1030 if (!loopExtent)
1031 loopExtent = extent;
1032 else
1033 loopExtent = builder.create<mlir::arith::MulIOp>(loc, loopExtent, extent);
1034 }
1035 assert(loopExtent && "loopNest must not be empty");
1036 return loopExtent;
1037}
1038
1039/// Return a name for temporary storage that indicates in which context
1040/// the temporary storage was created.
1041static llvm::StringRef
1042getTempName(hlfir::OrderedAssignmentTreeOpInterface root) {
1043 if (mlir::isa<hlfir::ForallOp>(root.getOperation()))
1044 return ".tmp.forall";
1045 if (mlir::isa<hlfir::WhereOp>(root.getOperation()))
1046 return ".tmp.where";
1047 return ".tmp.assign";
1048}
1049
1050void OrderedAssignmentRewriter::generateSaveEntity(
1051 hlfir::SaveEntity savedEntity, bool willUseSavedEntityInSameRun) {
1052 mlir::Region &region = *savedEntity.yieldRegion;
1053
1054 if (hlfir::RegionAssignOp regionAssignOp =
1055 getAssignIfLeftHandSideRegion(region)) {
1056 // Need to save the address, not the values.
1057 assert(!willUseSavedEntityInSameRun &&
1058 "lhs cannot be used in the loop nest where it is saved");
1059 return saveLeftHandSide(savedEntity, regionAssignOp);
1060 }
1061
1062 mlir::Location loc = region.getParentOp()->getLoc();
1063 // Evaluate the region inside the loop nest (if any).
1064 auto [clonedValue, oldYield] = generateYieldedEntity(region);
1065 hlfir::Entity entity{clonedValue};
1066 entity = hlfir::loadTrivialScalar(loc, builder, entity);
1067 mlir::Type entityType = entity.getType();
1068
1069 llvm::StringRef tempName = getTempName(root);
1070 fir::factory::TemporaryStorage *temp = nullptr;
1071 if (constructStack.empty()) {
1072 // Value evaluated outside of any loops (this may be the first MASK of a
1073 // WHERE construct, or an LHS/RHS temp of hlfir.region_assign outside of
1074 // WHERE/FORALL).
1075 temp = insertSavedEntity(
1076 region, fir::factory::SimpleCopy(loc, builder, entity, tempName));
1077 } else {
1078 // Need to create a temporary for values computed inside loops.
1079 // Create temporary storage outside of the loop nest given the entity
1080 // type (and the loop context).
1081 llvm::SmallVector<fir::DoLoopOp> loopNest;
1082 bool loopShapeCanBePreComputed =
1083 currentLoopNestIterationNumberCanBeComputed(loopNest);
1084 doBeforeLoopNest(callback: [&] {
1085 /// For simple scalars inside loops whose total iteration number can be
1086 /// pre-computed, create a rank-1 array outside of the loops. It will be
1087 /// assigned/fetched inside the loops like a normal Fortran array given
1088 /// the iteration count.
1089 if (loopShapeCanBePreComputed && fir::isa_trivial(entityType)) {
1090 mlir::Value loopExtent =
1091 computeLoopNestIterationNumber(loc, builder, loopNest);
1092 auto sequenceType =
1093 builder.getVarLenSeqTy(entityType).cast<fir::SequenceType>();
1094 temp = insertSavedEntity(region,
1095 fir::factory::HomogeneousScalarStack{
1096 loc, builder, sequenceType, loopExtent,
1097 /*lenParams=*/{}, allocateOnHeap,
1098 /*stackThroughLoops=*/true, tempName});
1099
1100 } else {
1101 // If the number of iteration is not known, or if the values at each
1102 // iterations are values that may have different shape, type parameters
1103 // or dynamic type, use the runtime to create and manage a stack-like
1104 // temporary.
1105 temp = insertSavedEntity(
1106 region, fir::factory::AnyValueStack{loc, builder, entityType});
1107 }
1108 });
1109 // Inside the loop nest (and any fir.if if there are active masks), copy
1110 // the value to the temp and do clean-ups for the value if any.
1111 temp->pushValue(loc, builder, entity);
1112 }
1113
1114 // Delay the clean-up if the entity will be used in the same run (i.e., the
1115 // parent construct will be visited and needs to be lowered). When possible,
1116 // this is not done for hlfir.expr because this use would prevent the
1117 // hlfir.expr storage from being moved when creating the temporary in
1118 // bufferization, and that would lead to an extra copy.
1119 if (willUseSavedEntityInSameRun &&
1120 (!temp->canBeFetchedAfterPush() ||
1121 !mlir::isa<hlfir::ExprType>(entity.getType()))) {
1122 auto inserted =
1123 savedInCurrentRunBeforeUse.try_emplace(&region, entity, oldYield);
1124 assert(inserted.second && "entity must have been emplaced");
1125 (void)inserted;
1126 } else {
1127 if (constructStack.empty() &&
1128 mlir::isa<hlfir::RegionAssignOp>(region.getParentOp())) {
1129 // Here the clean-up code is inserted after the original
1130 // RegionAssignOp, so that the assignment code happens
1131 // before the cleanup. We do this only for standalone
1132 // operations, because the clean-up is handled specially
1133 // during lowering of the parent constructs if any
1134 // (e.g. see generateNoneElementalCleanupIfAny for
1135 // WhereOp).
1136 auto insertionPoint = builder.saveInsertionPoint();
1137 builder.setInsertionPointAfter(region.getParentOp());
1138 generateCleanupIfAny(oldYield);
1139 builder.restoreInsertionPoint(insertionPoint);
1140 } else {
1141 generateCleanupIfAny(oldYield);
1142 }
1143 }
1144}
1145
1146static bool rhsIsArray(hlfir::RegionAssignOp regionAssignOp) {
1147 auto yieldOp = mlir::dyn_cast<hlfir::YieldOp>(
1148 regionAssignOp.getRhsRegion().back().back());
1149 return yieldOp && hlfir::Entity{yieldOp.getEntity()}.isArray();
1150}
1151
1152void OrderedAssignmentRewriter::saveLeftHandSide(
1153 hlfir::SaveEntity savedEntity, hlfir::RegionAssignOp regionAssignOp) {
1154 mlir::Region &region = *savedEntity.yieldRegion;
1155 mlir::Location loc = region.getParentOp()->getLoc();
1156 LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
1157 fir::factory::TemporaryStorage *temp = nullptr;
1158 if (loweredLhs.vectorSubscriptLoopNest)
1159 constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop);
1160 if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
1161 // Vector subscripted entity for which the shape must also be saved on top
1162 // of the element addresses (e.g. the shape may change in each forall
1163 // iteration and is needed to create the elemental loops).
1164 mlir::Value shape = loweredLhs.vectorSubscriptShape.value();
1165 int rank = mlir::cast<fir::ShapeType>(shape.getType()).getRank();
1166 const bool shapeIsInvariant =
1167 constructStack.empty() ||
1168 dominanceInfo.properlyDominates(shape, constructStack[0]);
1169 doBeforeLoopNest(callback: [&] {
1170 // Outside of any forall/where/elemental loops, create a temporary that
1171 // will both be able to save the vector subscripted designator shape(s)
1172 // and element addresses.
1173 temp =
1174 insertSavedEntity(region, fir::factory::AnyVectorSubscriptStack{
1175 loc, builder, loweredLhs.lhs.getType(),
1176 shapeIsInvariant, rank});
1177 });
1178 // Save shape before the elemental loop nest created by the vector
1179 // subscripted LHS.
1180 auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
1181 auto insertionPoint = builder.saveInsertionPoint();
1182 builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop);
1183 vectorTmp.pushShape(loc, builder, shape);
1184 builder.restoreInsertionPoint(insertionPoint);
1185 } else {
1186 // Otherwise, only save the LHS address.
1187 // If the LHS address dominates the constructs, its SSA value can
1188 // simply be tracked and there is no need to save the address in memory.
1189 // Otherwise, the addresses are stored at each iteration in memory with
1190 // a descriptor stack.
1191 if (constructStack.empty() ||
1192 dominanceInfo.properlyDominates(loweredLhs.lhs, constructStack[0]))
1193 doBeforeLoopNest(callback: [&] {
1194 temp = insertSavedEntity(region, fir::factory::SSARegister{});
1195 });
1196 else
1197 doBeforeLoopNest(callback: [&] {
1198 temp = insertSavedEntity(
1199 region, fir::factory::AnyVariableStack{loc, builder,
1200 loweredLhs.lhs.getType()});
1201 });
1202 }
1203 temp->pushValue(loc, builder, loweredLhs.lhs);
1204 generateCleanupIfAny(loweredLhs.elementalCleanup);
1205 if (loweredLhs.vectorSubscriptLoopNest) {
1206 constructStack.pop_back();
1207 builder.setInsertionPointAfter(
1208 loweredLhs.vectorSubscriptLoopNest->outerLoop);
1209 }
1210}
1211
1212/// Lower an ordered assignment tree to fir.do_loop and hlfir.assign given
1213/// a schedule.
1214static void lower(hlfir::OrderedAssignmentTreeOpInterface root,
1215 mlir::PatternRewriter &rewriter, hlfir::Schedule &schedule) {
1216 auto module = root->getParentOfType<mlir::ModuleOp>();
1217 fir::FirOpBuilder builder(rewriter, module);
1218 OrderedAssignmentRewriter assignmentRewriter(builder, root);
1219 for (auto &run : schedule)
1220 assignmentRewriter.lowerRun(run);
1221 assignmentRewriter.cleanupSavedEntities();
1222}
1223
/// Shared rewrite entry point for all the ordered assignment tree root
/// operations. It calls the scheduler and then apply the schedule.
static mlir::LogicalResult rewrite(hlfir::OrderedAssignmentTreeOpInterface root,
                                   bool tryFusingAssignments,
                                   mlir::PatternRewriter &rewriter) {
  hlfir::Schedule schedule =
      hlfir::buildEvaluationSchedule(root, tryFusingAssignments);

  // NOTE: the early-return below lives inside LLVM_DEBUG, so it is compiled
  // out in builds where LLVM_DEBUG expands to nothing (NDEBUG builds): the
  // dbgScheduleOnly option only takes effect in assert-enabled builds.
  LLVM_DEBUG(
      /// Debug option to print the scheduling debug info without doing
      /// any code generation. The operations are simply erased to avoid
      /// failing and calling the rewrite patterns on nested operations.
      /// The only purpose of this is to help testing scheduling without
      /// having to test generated code.
      if (dbgScheduleOnly) {
        rewriter.eraseOp(root);
        return mlir::success();
      });
  lower(root, rewriter, schedule);
  // The whole tree has been lowered to regular FIR/HLFIR; the root (and all
  // its nested regions) can now be removed.
  rewriter.eraseOp(root);
  return mlir::success();
}
1246
1247namespace {
1248
1249class ForallOpConversion : public mlir::OpRewritePattern<hlfir::ForallOp> {
1250public:
1251 explicit ForallOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments)
1252 : OpRewritePattern{ctx}, tryFusingAssignments{tryFusingAssignments} {}
1253
1254 mlir::LogicalResult
1255 matchAndRewrite(hlfir::ForallOp forallOp,
1256 mlir::PatternRewriter &rewriter) const override {
1257 auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
1258 forallOp.getOperation());
1259 if (mlir::failed(::rewrite(root, tryFusingAssignments, rewriter)))
1260 TODO(forallOp.getLoc(), "FORALL construct or statement in HLFIR");
1261 return mlir::success();
1262 }
1263 const bool tryFusingAssignments;
1264};
1265
1266class WhereOpConversion : public mlir::OpRewritePattern<hlfir::WhereOp> {
1267public:
1268 explicit WhereOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments)
1269 : OpRewritePattern{ctx}, tryFusingAssignments{tryFusingAssignments} {}
1270
1271 mlir::LogicalResult
1272 matchAndRewrite(hlfir::WhereOp whereOp,
1273 mlir::PatternRewriter &rewriter) const override {
1274 auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
1275 whereOp.getOperation());
1276 return ::rewrite(root, tryFusingAssignments, rewriter);
1277 }
1278 const bool tryFusingAssignments;
1279};
1280
1281class RegionAssignConversion
1282 : public mlir::OpRewritePattern<hlfir::RegionAssignOp> {
1283public:
1284 explicit RegionAssignConversion(mlir::MLIRContext *ctx)
1285 : OpRewritePattern{ctx} {}
1286
1287 mlir::LogicalResult
1288 matchAndRewrite(hlfir::RegionAssignOp regionAssignOp,
1289 mlir::PatternRewriter &rewriter) const override {
1290 auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
1291 regionAssignOp.getOperation());
1292 return ::rewrite(root, /*tryFusingAssignments=*/false, rewriter);
1293 }
1294};
1295
1296class LowerHLFIROrderedAssignments
1297 : public hlfir::impl::LowerHLFIROrderedAssignmentsBase<
1298 LowerHLFIROrderedAssignments> {
1299public:
1300 void runOnOperation() override {
1301 // Running on a ModuleOp because this pass may generate FuncOp declaration
1302 // for runtime calls. This could be a FuncOp pass otherwise.
1303 auto module = this->getOperation();
1304 auto *context = &getContext();
1305 mlir::RewritePatternSet patterns(context);
1306 // Patterns are only defined for the OrderedAssignmentTreeOpInterface
1307 // operations that can be the root of ordered assignments. The other
1308 // operations will be taken care of while rewriting these trees (they
1309 // cannot exist outside of these operations given their verifiers/traits).
1310 patterns.insert<ForallOpConversion, WhereOpConversion>(
1311 context, this->tryFusingAssignments.getValue());
1312 patterns.insert<RegionAssignConversion>(context);
1313 mlir::ConversionTarget target(*context);
1314 target.markUnknownOpDynamicallyLegal([](mlir::Operation *op) {
1315 return !mlir::isa<hlfir::OrderedAssignmentTreeOpInterface>(op);
1316 });
1317 if (mlir::failed(mlir::applyPartialConversion(module, target,
1318 std::move(patterns)))) {
1319 mlir::emitError(mlir::UnknownLoc::get(context),
1320 "failure in HLFIR ordered assignments lowering pass");
1321 signalPassFailure();
1322 }
1323 }
1324};
1325} // namespace
1326
1327std::unique_ptr<mlir::Pass> hlfir::createLowerHLFIROrderedAssignmentsPass() {
1328 return std::make_unique<LowerHLFIROrderedAssignments>();
1329}
1330

// Source file: flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp