1 | //===- LowerHLFIROrderedAssignments.cpp - Lower HLFIR ordered assignments -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // This file defines a pass to lower HLFIR ordered assignments. |
9 | // Ordered assignments are all the operations with the |
10 | // OrderedAssignmentTreeOpInterface that implements user defined assignments, |
11 | // assignment to vector subscripted entities, and assignments inside forall and |
12 | // where. |
13 | // The pass lowers these operations to regular hlfir.assign, loops and, if |
14 | // needed, introduces temporary storage to fulfill Fortran semantics. |
15 | // |
16 | // For each rewrite, an analysis builds an evaluation schedule, and then the |
17 | // new code is generated by following the evaluation schedule. |
18 | //===----------------------------------------------------------------------===// |
19 | |
20 | #include "ScheduleOrderedAssignments.h" |
21 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
22 | #include "flang/Optimizer/Builder/HLFIRTools.h" |
23 | #include "flang/Optimizer/Builder/TemporaryStorage.h" |
24 | #include "flang/Optimizer/Builder/Todo.h" |
25 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
26 | #include "flang/Optimizer/HLFIR/Passes.h" |
27 | #include "mlir/IR/Dominance.h" |
28 | #include "mlir/IR/IRMapping.h" |
29 | #include "mlir/Transforms/DialectConversion.h" |
30 | #include "llvm/ADT/SmallSet.h" |
31 | #include "llvm/ADT/TypeSwitch.h" |
32 | #include "llvm/Support/Debug.h" |
33 | |
34 | namespace hlfir { |
35 | #define GEN_PASS_DEF_LOWERHLFIRORDEREDASSIGNMENTS |
36 | #include "flang/Optimizer/HLFIR/Passes.h.inc" |
37 | } // namespace hlfir |
38 | |
39 | #define DEBUG_TYPE "flang-ordered-assignment" |
40 | |
// Debug/testing option: only run the ordered assignment scheduling analysis
// and erase the operations without generating any code. Its sole purpose is
// to allow printing and testing the scheduling debug info.
static llvm::cl::opt<bool> dbgScheduleOnly(
    "flang-dbg-order-assignment-schedule-only",
    llvm::cl::desc("Only run ordered assignment scheduling with no codegen"),
    llvm::cl::init(false));
48 | |
49 | namespace { |
50 | |
51 | /// Structure that represents a masked expression being lowered. Masked |
52 | /// expressions are any expressions inside an hlfir.where. As described in |
53 | /// Fortran 2018 section 10.2.3.2, the evaluation of the elemental parts of such |
54 | /// expressions must be masked, while the evaluation of none elemental parts |
55 | /// must not be masked. This structure analyzes the region evaluating the |
56 | /// expression and allows splitting the generation of the none elemental part |
57 | /// from the elemental part. |
struct MaskedArrayExpr {
  MaskedArrayExpr(mlir::Location loc, mlir::Region &region);

  /// Generate the non-elemental part of the expression. Must be called
  /// outside of the loops created for the WHERE construct.
  void generateNoneElementalPart(fir::FirOpBuilder &builder,
                                 mlir::IRMapping &mapper);

  /// Methods below can only be called once generateNoneElementalPart has been
  /// called.

  /// Return the shape of the expression.
  mlir::Value generateShape(fir::FirOpBuilder &builder,
                            mlir::IRMapping &mapper);
  /// Return the element value of this expression for the current
  /// where loop indices.
  mlir::Value generateElementalParts(fir::FirOpBuilder &builder,
                                     mlir::ValueRange oneBasedIndices,
                                     mlir::IRMapping &mapper);
  /// Generate the cleanup for the non-elemental parts, if any. This must be
  /// called after the loops created for the WHERE construct.
  void generateNoneElementalCleanupIfAny(fir::FirOpBuilder &builder,
                                         mlir::IRMapping &mapper);

  /// Location of the masked expression being lowered.
  mlir::Location loc;
  /// Region evaluating the masked expression.
  mlir::Region &region;
  /// Was generateNoneElementalPart called?
  bool noneElementalPartWasGenerated = false;
  /// Set of operations that form the elemental parts of the
  /// expression evaluation. These are the hlfir.elemental and
  /// hlfir.elemental_addr that form the elemental tree producing
  /// the expression value. hlfir.elemental that produce values
  /// used inside transformational operations are not part of this set.
  llvm::SmallSet<mlir::Operation *, 4> elementalParts{};
};
93 | } // namespace |
94 | |
95 | namespace { |
96 | /// Structure that visits an ordered assignment tree and generates code for |
97 | /// it according to a schedule. |
class OrderedAssignmentRewriter {
public:
  OrderedAssignmentRewriter(fir::FirOpBuilder &builder,
                            hlfir::OrderedAssignmentTreeOpInterface root)
      : builder{builder}, root{root} {}

  /// Generate code for the current run of the schedule.
  void lowerRun(hlfir::Run &run) {
    currentRun = &run;
    walk(root);
    currentRun = nullptr;
    assert(constructStack.empty() && "must exit constructs after a run");
    mapper.clear();
    savedInCurrentRunBeforeUse.clear();
  }

  /// After all runs have been lowered, clean-up all the temporary
  /// storage that were created (do not call final routines).
  void cleanupSavedEntities() {
    for (auto &temp : savedEntities)
      temp.second.destroy(root.getLoc(), builder);
  }

  /// Lowered value for an expression, and the original hlfir.yield if any
  /// clean-up needs to be cloned after usage.
  using ValueAndCleanUp = std::pair<mlir::Value, std::optional<hlfir::YieldOp>>;

private:
  /// Walk the part of an ordered assignment tree node that needs
  /// to be evaluated in the current run.
  void walk(hlfir::OrderedAssignmentTreeOpInterface node);

  /// Generate code when entering a given ordered assignment node.
  void pre(hlfir::ForallOp forallOp);
  void pre(hlfir::ForallIndexOp);
  void pre(hlfir::ForallMaskOp);
  void pre(hlfir::WhereOp whereOp);
  void pre(hlfir::ElseWhereOp elseWhereOp);
  void pre(hlfir::RegionAssignOp);

  /// Generate code when leaving a given ordered assignment node.
  void post(hlfir::ForallOp);
  void post(hlfir::ForallMaskOp);
  void post(hlfir::WhereOp);
  void post(hlfir::ElseWhereOp);
  /// Enter (and maybe create) the fir.if else block of an ElseWhereOp,
  /// but do not generate the elsewhere mask or the new fir.if.
  void enterElsewhere(hlfir::ElseWhereOp);

  /// Are there any leaf regions in the node that must be saved in the current
  /// run?
  bool mustSaveRegionIn(
      hlfir::OrderedAssignmentTreeOpInterface node,
      llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const;
  /// Should this node be evaluated in the current run? Saving a region in a
  /// node does not imply the node needs to be evaluated.
  bool
  isRequiredInCurrentRun(hlfir::OrderedAssignmentTreeOpInterface node) const;

  /// Generate a scalar value yielded by an ordered assignment tree region.
  /// If the value was not saved in a previous run, this clones the region
  /// code, except the final yield, at the current execution point.
  /// If the value was saved in a previous run, this fetches the saved value
  /// from the temporary storage and returns the value.
  /// Inside Forall, the value will be hoisted outside of the forall loops if
  /// it does not depend on the forall indices.
  /// An optional type can be provided to get a value from a specific type
  /// (the cast will be hoisted if the computation is hoisted).
  mlir::Value generateYieldedScalarValue(
      mlir::Region &region,
      std::optional<mlir::Type> castToType = std::nullopt);

  /// Generate an entity yielded by an ordered assignment tree region, and
  /// optionally return the (uncloned) yield if there is any clean-up that
  /// should be done after using the entity. Like generateYieldedScalarValue,
  /// this will return the saved value if the region was saved in a previous
  /// run.
  ValueAndCleanUp
  generateYieldedEntity(mlir::Region &region,
                        std::optional<mlir::Type> castToType = std::nullopt);

  /// Lowered left-hand side and the clean-ups that must be generated after
  /// using it.
  struct LhsValueAndCleanUp {
    mlir::Value lhs;
    std::optional<hlfir::YieldOp> elementalCleanup;
    mlir::Region *nonElementalCleanup = nullptr;
    std::optional<hlfir::LoopNest> vectorSubscriptLoopNest;
    std::optional<mlir::Value> vectorSubscriptShape;
  };

  /// Generate the left-hand side. If the left-hand side is vector
  /// subscripted (hlfir.elemental_addr), this will create a loop nest
  /// (unless it was already created by a WHERE mask) and return the
  /// element address.
  LhsValueAndCleanUp
  generateYieldedLHS(mlir::Location loc, mlir::Region &lhsRegion,
                     std::optional<hlfir::Entity> loweredRhs = std::nullopt);

  /// If \p maybeYield is present and has a clean-up, generate the clean-up
  /// at the current insertion point (by cloning).
  void generateCleanupIfAny(std::optional<hlfir::YieldOp> maybeYield);
  void generateCleanupIfAny(mlir::Region *cleanupRegion);

  /// Generate a masked entity. This can only be called when whereLoopNest was
  /// set (when an hlfir.where is being visited).
  /// This method returns the scalar element (that may have been previously
  /// saved) for the current indices inside the where loop.
  mlir::Value generateMaskedEntity(mlir::Location loc, mlir::Region &region) {
    MaskedArrayExpr maskedExpr(loc, region);
    return generateMaskedEntity(maskedExpr);
  }
  mlir::Value generateMaskedEntity(MaskedArrayExpr &maskedExpr);

  /// Create a fir.if at the current position inside the where loop nest
  /// given the element value of a mask.
  void generateMaskIfOp(mlir::Value cdt);

  /// Save a value for subsequent runs.
  void generateSaveEntity(hlfir::SaveEntity savedEntity,
                          bool willUseSavedEntityInSameRun);
  void saveLeftHandSide(hlfir::SaveEntity savedEntity,
                        hlfir::RegionAssignOp regionAssignOp);

  /// Get a value if it was saved in this run or a previous run. Returns
  /// nullopt if it has not been saved.
  std::optional<ValueAndCleanUp> getIfSaved(mlir::Region &region);

  /// Generate code before the loop nest for the current run, if any.
  void doBeforeLoopNest(const std::function<void()> &callback) {
    if (constructStack.empty()) {
      callback();
      return;
    }
    // constructStack[0] is the outermost loop/if of the run: insert there.
    auto insertionPoint = builder.saveInsertionPoint();
    builder.setInsertionPoint(constructStack[0]);
    callback();
    builder.restoreInsertionPoint(insertionPoint);
  }

  /// Can the current loop nest iteration number be computed? For simplicity,
  /// this is true if and only if all the bounds and steps of the fir.do_loop
  /// nest dominates the outer loop. The argument is filled with the current
  /// loop nest on success.
  bool currentLoopNestIterationNumberCanBeComputed(
      llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest);

  /// Record the temporary storage created for \p region and return a stable
  /// pointer to it. Asserts if a storage was already recorded for the region.
  template <typename T>
  fir::factory::TemporaryStorage *insertSavedEntity(mlir::Region &region,
                                                    T &&temp) {
    auto inserted =
        savedEntities.insert(std::make_pair(&region, std::forward<T>(temp)));
    assert(inserted.second && "temp must have been emplaced");
    return &inserted.first->second;
  }

  fir::FirOpBuilder &builder;

  /// Map containing the mapping between the original ordered assignment tree
  /// operations and the operations that have been cloned in the current run.
  /// It is reset between two runs.
  mlir::IRMapping mapper;
  /// Dominance info is used to determine if inner loop bounds are all computed
  /// before outer loop for the current loop. It does not need to be reset
  /// between runs.
  mlir::DominanceInfo dominanceInfo;
  /// Construct stack in the current run. This allows setting back the insertion
  /// point correctly when leaving a node that requires a fir.do_loop or fir.if
  /// operation.
  llvm::SmallVector<mlir::Operation *> constructStack;
  /// Current where loop nest, if any.
  std::optional<hlfir::LoopNest> whereLoopNest;

  /// Map of temporary storage to keep track of saved entity once the run
  /// that saves them has been lowered. It is kept in-between runs.
  /// llvm::MapVector is used to guarantee deterministic order
  /// of iterating through savedEntities (e.g. for generating
  /// destruction code for the temporary storages).
  llvm::MapVector<mlir::Region *, fir::factory::TemporaryStorage> savedEntities;
  /// Map holding the values that were saved in the current run and that also
  /// need to be used (because their construct will be visited). It is reset
  /// after each run. It avoids having to store and fetch in the temporary
  /// during the same run, which would require the temporary to have different
  /// fetching and storing counters.
  llvm::DenseMap<mlir::Region *, ValueAndCleanUp> savedInCurrentRunBeforeUse;

  /// Root of the ordered assignment tree being lowered.
  hlfir::OrderedAssignmentTreeOpInterface root;
  /// Pointer to the current run of the schedule being lowered.
  hlfir::Run *currentRun = nullptr;

  /// When allocating temporary storage inlined, indicate if the storage should
  /// be heap or stack allocated. Temporary allocated with the runtime are heap
  /// allocated by the runtime.
  bool allocateOnHeap = true;
};
292 | } // namespace |
293 | |
294 | void OrderedAssignmentRewriter::walk( |
295 | hlfir::OrderedAssignmentTreeOpInterface node) { |
296 | bool mustVisit = |
297 | isRequiredInCurrentRun(node) || mlir::isa<hlfir::ForallIndexOp>(node); |
298 | llvm::SmallVector<hlfir::SaveEntity> saveEntities; |
299 | mlir::Operation *nodeOp = node.getOperation(); |
300 | if (mustSaveRegionIn(node, saveEntities)) { |
301 | mlir::IRRewriter::InsertPoint insertionPoint; |
302 | if (auto elseWhereOp = mlir::dyn_cast<hlfir::ElseWhereOp>(nodeOp)) { |
303 | // ElseWhere mask to save must be evaluated inside the fir.if else |
304 | // for the previous where/elsewehere (its evaluation must be |
305 | // masked by the "pending control mask"). |
306 | insertionPoint = builder.saveInsertionPoint(); |
307 | enterElsewhere(elseWhereOp); |
308 | } |
309 | for (hlfir::SaveEntity saveEntity : saveEntities) |
310 | generateSaveEntity(savedEntity: saveEntity, willUseSavedEntityInSameRun: mustVisit); |
311 | if (insertionPoint.isSet()) |
312 | builder.restoreInsertionPoint(insertionPoint); |
313 | } |
314 | if (mustVisit) { |
315 | llvm::TypeSwitch<mlir::Operation *, void>(nodeOp) |
316 | .Case<hlfir::ForallOp, hlfir::ForallIndexOp, hlfir::ForallMaskOp, |
317 | hlfir::RegionAssignOp, hlfir::WhereOp, hlfir::ElseWhereOp>( |
318 | [&](auto concreteOp) { pre(concreteOp); }) |
319 | .Default([](auto) {}); |
320 | if (auto *body = node.getSubTreeRegion()) { |
321 | for (mlir::Operation &op : body->getOps()) |
322 | if (auto subNode = |
323 | mlir::dyn_cast<hlfir::OrderedAssignmentTreeOpInterface>(op)) |
324 | walk(subNode); |
325 | llvm::TypeSwitch<mlir::Operation *, void>(nodeOp) |
326 | .Case<hlfir::ForallOp, hlfir::ForallMaskOp, hlfir::WhereOp, |
327 | hlfir::ElseWhereOp>([&](auto concreteOp) { post(concreteOp); }) |
328 | .Default([](auto) {}); |
329 | } |
330 | } |
331 | } |
332 | |
void OrderedAssignmentRewriter::pre(hlfir::ForallOp forallOp) {
  // Create a fir.do_loop given the hlfir.forall control values.
  mlir::Type idxTy = builder.getIndexType();
  mlir::Location loc = forallOp.getLoc();
  mlir::Value lb = generateYieldedScalarValue(forallOp.getLbRegion(), idxTy);
  mlir::Value ub = generateYieldedScalarValue(forallOp.getUbRegion(), idxTy);
  mlir::Value step;
  if (forallOp.getStepRegion().empty()) {
    // Implicit step is one. Materialize the constant before the outermost
    // construct of the run (if any) so that it is defined outside the loops.
    auto insertionPoint = builder.saveInsertionPoint();
    if (!constructStack.empty())
      builder.setInsertionPoint(constructStack[0]);
    step = builder.createIntegerConstant(loc, idxTy, 1);
    if (!constructStack.empty())
      builder.restoreInsertionPoint(insertionPoint);
  } else {
    step = generateYieldedScalarValue(forallOp.getStepRegion(), idxTy);
  }
  auto doLoop = builder.create<fir::DoLoopOp>(loc, lb, ub, step);
  builder.setInsertionPointToStart(doLoop.getBody());
  // Map the forall index to the loop induction variable, converted back to
  // the Fortran integer type of the original index.
  mlir::Value oldIndex = forallOp.getForallIndexValue();
  mlir::Value newIndex =
      builder.createConvert(loc, oldIndex.getType(), doLoop.getInductionVar());
  mapper.map(oldIndex, newIndex);
  constructStack.push_back(doLoop);
}
358 | |
359 | void OrderedAssignmentRewriter::post(hlfir::ForallOp) { |
360 | assert(!constructStack.empty() && "must contain a loop" ); |
361 | builder.setInsertionPointAfter(constructStack.pop_back_val()); |
362 | } |
363 | |
364 | void OrderedAssignmentRewriter::pre(hlfir::ForallIndexOp forallIndexOp) { |
365 | mlir::Location loc = forallIndexOp.getLoc(); |
366 | mlir::Type intTy = fir::unwrapRefType(forallIndexOp.getType()); |
367 | mlir::Value indexVar = |
368 | builder.createTemporary(loc, intTy, forallIndexOp.getName()); |
369 | mlir::Value newVal = mapper.lookupOrDefault(forallIndexOp.getIndex()); |
370 | builder.createStoreWithConvert(loc, newVal, indexVar); |
371 | mapper.map(forallIndexOp, indexVar); |
372 | } |
373 | |
374 | void OrderedAssignmentRewriter::pre(hlfir::ForallMaskOp forallMaskOp) { |
375 | mlir::Location loc = forallMaskOp.getLoc(); |
376 | mlir::Value mask = generateYieldedScalarValue(forallMaskOp.getMaskRegion(), |
377 | builder.getI1Type()); |
378 | auto ifOp = builder.create<fir::IfOp>(loc, std::nullopt, mask, false); |
379 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
380 | constructStack.push_back(ifOp); |
381 | } |
382 | |
383 | void OrderedAssignmentRewriter::post(hlfir::ForallMaskOp forallMaskOp) { |
384 | assert(!constructStack.empty() && "must contain an ifop" ); |
385 | builder.setInsertionPointAfter(constructStack.pop_back_val()); |
386 | } |
387 | |
/// Convert an entity to the type of a given mold.
/// This is intended to help with cases where hlfir entity is a value while
/// it must be used as a variable or vice-versa. These mismatches may occur
/// between the type of user defined assignment block arguments and the actual
/// argument that was lowered for them. The actual may be an in-memory copy
/// while the block argument expects an hlfir.expr.
/// Any clean-up required for the conversion is appended to \p cleanups.
static hlfir::Entity
convertToMoldType(mlir::Location loc, fir::FirOpBuilder &builder,
                  hlfir::Entity input, hlfir::Entity mold,
                  llvm::SmallVectorImpl<hlfir::CleanupFunction> &cleanups) {
  if (input.getType() == mold.getType())
    return input;
  // Captured by copy in the cleanup lambdas (a reference to a local would
  // dangle once this function returns).
  fir::FirOpBuilder *b = &builder;
  if (input.isVariable() && mold.isValue()) {
    if (fir::isa_trivial(mold.getType())) {
      // fir.ref<T> to T.
      mlir::Value load = builder.create<fir::LoadOp>(loc, input);
      return hlfir::Entity{builder.createConvert(loc, mold.getType(), load)};
    }
    // fir.ref<T> to hlfir.expr<T>.
    mlir::Value asExpr = builder.create<hlfir::AsExprOp>(loc, input);
    if (asExpr.getType() != mold.getType())
      TODO(loc, "hlfir.expr conversion");
    cleanups.emplace_back([=]() { b->create<hlfir::DestroyOp>(loc, asExpr); });
    return hlfir::Entity{asExpr};
  }
  if (input.isValue() && mold.isVariable()) {
    // T to fir.ref<T>, or hlfir.expr<T> to fir.ref<T>.
    hlfir::AssociateOp associate = hlfir::genAssociateExpr(
        loc, builder, input, mold.getFortranElementType(), ".tmp.val2ref");
    cleanups.emplace_back(
        [=]() { b->create<hlfir::EndAssociateOp>(loc, associate); });
    return hlfir::Entity{associate.getBase()};
  }
  // Variable to variable mismatch (e.g., fir.heap<T> vs fir.ref<T>), or value
  // to value mismatch (e.g. i1 vs fir.logical<4>).
  if (mlir::isa<fir::BaseBoxType>(mold.getType()) &&
      !mlir::isa<fir::BaseBoxType>(input.getType())) {
    // An entity may have been saved without descriptor while the original
    // value had a descriptor (e.g., it was not contiguous).
    auto emboxed = hlfir::convertToBox(loc, builder, input, mold.getType());
    assert(!emboxed.second && "temp should already be in memory");
    input = hlfir::Entity{fir::getBase(emboxed.first)};
  }
  return hlfir::Entity{builder.createConvert(loc, mold.getType(), input)};
}
434 | |
435 | void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) { |
436 | mlir::Location loc = regionAssignOp.getLoc(); |
437 | std::optional<hlfir::LoopNest> elementalLoopNest; |
438 | auto [rhsValue, oldRhsYield] = |
439 | generateYieldedEntity(regionAssignOp.getRhsRegion()); |
440 | hlfir::Entity rhsEntity{rhsValue}; |
441 | LhsValueAndCleanUp loweredLhs = |
442 | generateYieldedLHS(loc, regionAssignOp.getLhsRegion(), rhsEntity); |
443 | hlfir::Entity lhsEntity{loweredLhs.lhs}; |
444 | if (loweredLhs.vectorSubscriptLoopNest) |
445 | rhsEntity = hlfir::getElementAt( |
446 | loc, builder, rhsEntity, |
447 | loweredLhs.vectorSubscriptLoopNest->oneBasedIndices); |
448 | if (!regionAssignOp.getUserDefinedAssignment().empty()) { |
449 | hlfir::Entity userAssignLhs{regionAssignOp.getUserAssignmentLhs()}; |
450 | hlfir::Entity userAssignRhs{regionAssignOp.getUserAssignmentRhs()}; |
451 | std::optional<hlfir::LoopNest> elementalLoopNest; |
452 | if (lhsEntity.isArray() && userAssignLhs.isScalar()) { |
453 | // Elemental assignment with array argument (the RHS cannot be an array |
454 | // if the LHS is not). |
455 | mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity); |
456 | elementalLoopNest = hlfir::genLoopNest(loc, builder, shape); |
457 | builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody()); |
458 | lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity, |
459 | elementalLoopNest->oneBasedIndices); |
460 | rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity, |
461 | elementalLoopNest->oneBasedIndices); |
462 | } |
463 | |
464 | llvm::SmallVector<hlfir::CleanupFunction, 2> argConversionCleanups; |
465 | lhsEntity = convertToMoldType(loc, builder, lhsEntity, userAssignLhs, |
466 | argConversionCleanups); |
467 | rhsEntity = convertToMoldType(loc, builder, rhsEntity, userAssignRhs, |
468 | argConversionCleanups); |
469 | mapper.map(userAssignLhs, lhsEntity); |
470 | mapper.map(userAssignRhs, rhsEntity); |
471 | for (auto &op : |
472 | regionAssignOp.getUserDefinedAssignment().front().without_terminator()) |
473 | (void)builder.clone(op, mapper); |
474 | for (auto &cleanupConversion : argConversionCleanups) |
475 | cleanupConversion(); |
476 | if (elementalLoopNest) |
477 | builder.setInsertionPointAfter(elementalLoopNest->outerLoop); |
478 | } else { |
479 | // TODO: preserve allocatable assignment aspects for forall once |
480 | // they are conveyed in hlfir.region_assign. |
481 | builder.create<hlfir::AssignOp>(loc, rhsEntity, lhsEntity); |
482 | } |
483 | generateCleanupIfAny(loweredLhs.elementalCleanup); |
484 | if (loweredLhs.vectorSubscriptLoopNest) |
485 | builder.setInsertionPointAfter( |
486 | loweredLhs.vectorSubscriptLoopNest->outerLoop); |
487 | generateCleanupIfAny(oldRhsYield); |
488 | generateCleanupIfAny(loweredLhs.nonElementalCleanup); |
489 | } |
490 | |
491 | void OrderedAssignmentRewriter::generateMaskIfOp(mlir::Value cdt) { |
492 | mlir::Location loc = cdt.getLoc(); |
493 | cdt = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{cdt}); |
494 | cdt = builder.createConvert(loc, builder.getI1Type(), cdt); |
495 | auto ifOp = builder.create<fir::IfOp>(cdt.getLoc(), std::nullopt, cdt, |
496 | /*withElseRegion=*/false); |
497 | constructStack.push_back(ifOp.getOperation()); |
498 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
499 | } |
500 | |
void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
  mlir::Location loc = whereOp.getLoc();
  if (!whereLoopNest) {
    // This is the top-level WHERE. Start a loop nest iterating on the shape of
    // the where mask.
    if (auto maybeSaved = getIfSaved(whereOp.getMaskRegion())) {
      // Use the saved value to get the shape and condition element.
      hlfir::Entity savedMask{maybeSaved->first};
      mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
      whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
      constructStack.push_back(whereLoopNest->outerLoop.getOperation());
      builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
      mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
                                            whereLoopNest->oneBasedIndices);
      generateMaskIfOp(cdt);
      if (maybeSaved->second) {
        // If this is the same run as the one that saved the value, the clean-up
        // was left-over to be done now. It must be emitted after the loop
        // nest since the saved mask is used inside the loops.
        auto insertionPoint = builder.saveInsertionPoint();
        builder.setInsertionPointAfter(whereLoopNest->outerLoop);
        generateCleanupIfAny(maybeSaved->second);
        builder.restoreInsertionPoint(insertionPoint);
      }
      return;
    }
    // The mask was not evaluated yet or can be safely re-evaluated.
    // Evaluate its non-elemental part before the loops, then its elemental
    // part inside them.
    MaskedArrayExpr mask(loc, whereOp.getMaskRegion());
    mask.generateNoneElementalPart(builder, mapper);
    mlir::Value shape = mask.generateShape(builder, mapper);
    whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
    constructStack.push_back(whereLoopNest->outerLoop.getOperation());
    builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
    mlir::Value cdt = generateMaskedEntity(mask);
    generateMaskIfOp(cdt);
    return;
  }
  // Where Loops have been already created by a parent WHERE.
  // Generate a fir.if with the value of the current element of the mask
  // inside the loops. The case where the mask was saved is handled in the
  // generateYieldedScalarValue call.
  mlir::Value cdt = generateYieldedScalarValue(whereOp.getMaskRegion());
  generateMaskIfOp(cdt);
}
544 | |
545 | void OrderedAssignmentRewriter::post(hlfir::WhereOp whereOp) { |
546 | assert(!constructStack.empty() && "must contain a fir.if" ); |
547 | builder.setInsertionPointAfter(constructStack.pop_back_val()); |
548 | // If all where/elsewhere fir.if have been popped, this is the outer whereOp, |
549 | // and the where loop must be exited. |
550 | assert(!constructStack.empty() && "must contain a fir.do_loop or fir.if" ); |
551 | if (mlir::isa<fir::DoLoopOp>(constructStack.back())) { |
552 | builder.setInsertionPointAfter(constructStack.pop_back_val()); |
553 | whereLoopNest.reset(); |
554 | } |
555 | } |
556 | |
557 | void OrderedAssignmentRewriter::enterElsewhere(hlfir::ElseWhereOp elseWhereOp) { |
558 | // Create an "else" region for the current where/elsewhere fir.if. |
559 | auto ifOp = mlir::dyn_cast<fir::IfOp>(constructStack.back()); |
560 | assert(ifOp && "must be an if" ); |
561 | if (ifOp.getElseRegion().empty()) { |
562 | mlir::Location loc = elseWhereOp.getLoc(); |
563 | builder.createBlock(&ifOp.getElseRegion()); |
564 | auto end = builder.create<fir::ResultOp>(loc); |
565 | builder.setInsertionPoint(end); |
566 | } else { |
567 | builder.setInsertionPoint(&ifOp.getElseRegion().back().back()); |
568 | } |
569 | } |
570 | |
571 | void OrderedAssignmentRewriter::pre(hlfir::ElseWhereOp elseWhereOp) { |
572 | enterElsewhere(elseWhereOp); |
573 | if (elseWhereOp.getMaskRegion().empty()) |
574 | return; |
575 | // Create new nested fir.if with elsewhere mask if any. |
576 | mlir::Value cdt = generateYieldedScalarValue(elseWhereOp.getMaskRegion()); |
577 | generateMaskIfOp(cdt); |
578 | } |
579 | |
580 | void OrderedAssignmentRewriter::post(hlfir::ElseWhereOp elseWhereOp) { |
581 | // Exit ifOp that was created for the elseWhereOp mask, if any. |
582 | if (elseWhereOp.getMaskRegion().empty()) |
583 | return; |
584 | assert(!constructStack.empty() && "must contain a fir.if" ); |
585 | builder.setInsertionPointAfter(constructStack.pop_back_val()); |
586 | } |
587 | |
588 | /// Is this value a Forall index? |
589 | /// Forall index are block arguments of hlfir.forall body, or the result |
590 | /// of hlfir.forall_index. |
591 | static bool isForallIndex(mlir::Value value) { |
592 | if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(value)) { |
593 | if (mlir::Block *block = blockArg.getOwner()) |
594 | return block->isEntryBlock() && |
595 | mlir::isa_and_nonnull<hlfir::ForallOp>(block->getParentOp()); |
596 | return false; |
597 | } |
598 | return value.getDefiningOp<hlfir::ForallIndexOp>(); |
599 | } |
600 | |
601 | static OrderedAssignmentRewriter::ValueAndCleanUp |
602 | castIfNeeded(mlir::Location loc, fir::FirOpBuilder &builder, |
603 | OrderedAssignmentRewriter::ValueAndCleanUp valueAndCleanUp, |
604 | std::optional<mlir::Type> castToType) { |
605 | if (!castToType.has_value()) |
606 | return valueAndCleanUp; |
607 | mlir::Value cast = |
608 | builder.createConvert(loc, *castToType, valueAndCleanUp.first); |
609 | return {cast, valueAndCleanUp.second}; |
610 | } |
611 | |
612 | std::optional<OrderedAssignmentRewriter::ValueAndCleanUp> |
613 | OrderedAssignmentRewriter::getIfSaved(mlir::Region ®ion) { |
614 | mlir::Location loc = region.getParentOp()->getLoc(); |
615 | // If the region was saved in the same run, use the value that was evaluated |
616 | // instead of fetching the temp, and do clean-up, if any, that were delayed. |
617 | // This is done to avoid requiring the temporary stack to have different |
618 | // fetching and storing counters, and also because it produces slightly better |
619 | // code. |
620 | if (auto savedInSameRun = savedInCurrentRunBeforeUse.find(®ion); |
621 | savedInSameRun != savedInCurrentRunBeforeUse.end()) |
622 | return savedInSameRun->second; |
623 | // If the region was saved in a previous run, fetch the saved value. |
624 | if (auto temp = savedEntities.find(®ion); temp != savedEntities.end()) { |
625 | doBeforeLoopNest(callback: [&]() { temp->second.resetFetchPosition(loc, builder); }); |
626 | return ValueAndCleanUp{temp->second.fetch(loc, builder), std::nullopt}; |
627 | } |
628 | return std::nullopt; |
629 | } |
630 | |
OrderedAssignmentRewriter::ValueAndCleanUp
OrderedAssignmentRewriter::generateYieldedEntity(
    mlir::Region &region, std::optional<mlir::Type> castToType) {
  mlir::Location loc = region.getParentOp()->getLoc();
  // Reuse the value if the region was already evaluated and saved, either
  // earlier in this run or in a previous run (in which case it is fetched
  // from the temporary storage).
  if (auto maybeValueAndCleanUp = getIfSaved(region))
    return castIfNeeded(loc, builder, *maybeValueAndCleanUp, castToType);
  // Otherwise, evaluate the region now.

  // Masked expression must not evaluate the elemental parts that are masked,
  // they have custom code generation.
  if (whereLoopNest.has_value()) {
    mlir::Value maskedValue = generateMaskedEntity(loc, region);
    return castIfNeeded(loc, builder, {maskedValue, std::nullopt}, castToType);
  }

  assert(region.hasOneBlock() && "region must contain one block" );
  auto oldYield = mlir::dyn_cast_or_null<hlfir::YieldOp>(
      region.back().getOperations().back());
  assert(oldYield && "region computing entities must end with a YieldOp" );
  mlir::Block::OpListType &ops = region.back().getOperations();

  // Inside Forall, scalars that do not depend on forall indices can be hoisted
  // here because their evaluation is required to only call pure procedures, and
  // if they depend on a variable previously assigned to in a forall assignment,
  // this assignment must have been scheduled in a previous run. Hoisting of
  // scalars is done here to help creating simple temporary storage if needed.
  // Inner forall bounds can often be hoisted, and this allows computing the
  // total number of iterations to create temporary storages.
  bool hoistComputation = false;
  if (fir::isa_trivial(oldYield.getEntity().getType()) &&
      !constructStack.empty()) {
    // Only hoist when no operation of the region uses a forall index.
    hoistComputation = true;
    for (mlir::Operation &op : ops)
      if (llvm::any_of(op.getOperands(), [](mlir::Value value) {
            return isForallIndex(value);
          })) {
        hoistComputation = false;
        break;
      }
  }
  // When hoisting, clone the region before the outermost construct; the
  // insertion point is restored once the value has been computed.
  auto insertionPoint = builder.saveInsertionPoint();
  if (hoistComputation)
    builder.setInsertionPoint(constructStack[0]);

  // Clone all operations except the final hlfir.yield.
  assert(!ops.empty() && "yield block cannot be empty" );
  auto end = ops.end();
  for (auto opIt = ops.begin(); std::next(opIt) != end; ++opIt)
    (void)builder.clone(*opIt, mapper);
  // Get the value for the yielded entity, it may be the result of an operation
  // that was cloned, or it may be the same as the previous value if the yield
  // operand was created before the ordered assignment tree.
  mlir::Value newEntity = mapper.lookupOrDefault(oldYield.getEntity());
  if (castToType.has_value())
    newEntity =
        builder.createConvert(newEntity.getLoc(), *castToType, newEntity);

  if (hoistComputation) {
    // Hoisted trivial scalars clean-up can be done right away, the value is
    // in registers.
    generateCleanupIfAny(oldYield);
    builder.restoreInsertionPoint(insertionPoint);
    return {newEntity, std::nullopt};
  }
  if (oldYield.getCleanup().empty())
    return {newEntity, std::nullopt};
  // Otherwise, the clean-up is delayed: the caller is responsible for
  // generating it once done with the value.
  return {newEntity, oldYield};
}
699 | |
700 | mlir::Value OrderedAssignmentRewriter::generateYieldedScalarValue( |
701 | mlir::Region ®ion, std::optional<mlir::Type> castToType) { |
702 | mlir::Location loc = region.getParentOp()->getLoc(); |
703 | auto [value, maybeYield] = generateYieldedEntity(region, castToType); |
704 | value = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{value}); |
705 | assert(fir::isa_trivial(value.getType()) && "not a trivial scalar value" ); |
706 | generateCleanupIfAny(maybeYield); |
707 | return value; |
708 | } |
709 | |
710 | OrderedAssignmentRewriter::LhsValueAndCleanUp |
711 | OrderedAssignmentRewriter::generateYieldedLHS( |
712 | mlir::Location loc, mlir::Region &lhsRegion, |
713 | std::optional<hlfir::Entity> loweredRhs) { |
714 | LhsValueAndCleanUp loweredLhs; |
715 | hlfir::ElementalAddrOp elementalAddrLhs = |
716 | mlir::dyn_cast<hlfir::ElementalAddrOp>(lhsRegion.back().back()); |
717 | if (auto temp = savedEntities.find(&lhsRegion); temp != savedEntities.end()) { |
718 | // The LHS address was computed and saved in a previous run. Fetch it. |
719 | doBeforeLoopNest(callback: [&]() { temp->second.resetFetchPosition(loc, builder); }); |
720 | if (elementalAddrLhs && !whereLoopNest) { |
721 | // Vector subscripted designator address are saved element by element. |
722 | // If no "elemental" loops have been created yet, the shape of the |
723 | // RHS, if it is an array can be used, or the shape of the vector |
724 | // subscripted designator must be retrieved to generate the "elemental" |
725 | // loop nest. |
726 | if (loweredRhs && loweredRhs->isArray()) { |
727 | // The RHS shape can be used to create the elemental loops and avoid |
728 | // saving the LHS shape. |
729 | loweredLhs.vectorSubscriptShape = |
730 | hlfir::genShape(loc, builder, *loweredRhs); |
731 | } else { |
732 | // If the shape cannot be retrieved from the RHS, it must have been |
733 | // saved. Get it from the temporary. |
734 | auto &vectorTmp = |
735 | temp->second.cast<fir::factory::AnyVectorSubscriptStack>(); |
736 | loweredLhs.vectorSubscriptShape = vectorTmp.fetchShape(loc, builder); |
737 | } |
738 | loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest( |
739 | loc, builder, loweredLhs.vectorSubscriptShape.value()); |
740 | builder.setInsertionPointToStart( |
741 | loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody()); |
742 | } |
743 | loweredLhs.lhs = temp->second.fetch(loc, builder); |
744 | return loweredLhs; |
745 | } |
746 | // The LHS has not yet been evaluated and saved. Evaluate it now. |
747 | if (elementalAddrLhs && !whereLoopNest) { |
748 | // This is a vector subscripted entity. The address of elements must |
749 | // be returned. If no "elemental" loops have been created for a WHERE, |
750 | // create them now based on the vector subscripted designator shape. |
751 | for (auto &op : lhsRegion.front().without_terminator()) |
752 | (void)builder.clone(op, mapper); |
753 | loweredLhs.vectorSubscriptShape = |
754 | mapper.lookupOrDefault(elementalAddrLhs.getShape()); |
755 | loweredLhs.vectorSubscriptLoopNest = |
756 | hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape, |
757 | !elementalAddrLhs.isOrdered()); |
758 | builder.setInsertionPointToStart( |
759 | loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody()); |
760 | mapper.map(elementalAddrLhs.getIndices(), |
761 | loweredLhs.vectorSubscriptLoopNest->oneBasedIndices); |
762 | for (auto &op : elementalAddrLhs.getBody().front().without_terminator()) |
763 | (void)builder.clone(op, mapper); |
764 | loweredLhs.elementalCleanup = elementalAddrLhs.getYieldOp(); |
765 | loweredLhs.lhs = |
766 | mapper.lookupOrDefault(loweredLhs.elementalCleanup->getEntity()); |
767 | } else { |
768 | // This is a designator without vector subscripts. Generate it as |
769 | // it is done for other entities. |
770 | auto [lhs, yield] = generateYieldedEntity(lhsRegion); |
771 | loweredLhs.lhs = lhs; |
772 | if (yield && !yield->getCleanup().empty()) |
773 | loweredLhs.nonElementalCleanup = &yield->getCleanup(); |
774 | } |
775 | return loweredLhs; |
776 | } |
777 | |
// Generate the value of the current element of a masked expression. Must be
// called with the insertion point inside the WHERE loop nest (and inside any
// fir.if generated for previous masks), as asserted below.
mlir::Value
OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
  assert(whereLoopNest.has_value() && "must be inside WHERE loop nest" );
  auto insertionPoint = builder.saveInsertionPoint();
  if (!maskedExpr.noneElementalPartWasGenerated) {
    // Generate none elemental part before the where loops (but inside the
    // current forall loops if any).
    builder.setInsertionPoint(whereLoopNest->outerLoop);
    maskedExpr.generateNoneElementalPart(builder, mapper);
  }
  // Generate the none elemental part cleanup after the where loops.
  // NOTE(review): this clean-up is emitted on every call while the none
  // elemental part above is generated at most once -- presumably each
  // MaskedArrayExpr is evaluated once per loop nest; confirm with callers.
  builder.setInsertionPointAfter(whereLoopNest->outerLoop);
  maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
  // Generate the value of the current element for the masked expression
  // at the current insertion point (inside the where loops, and any fir.if
  // generated for previous masks).
  builder.restoreInsertionPoint(insertionPoint);
  return maskedExpr.generateElementalParts(
      builder, whereLoopNest->oneBasedIndices, mapper);
}
798 | |
799 | void OrderedAssignmentRewriter::generateCleanupIfAny( |
800 | std::optional<hlfir::YieldOp> maybeYield) { |
801 | if (maybeYield.has_value()) |
802 | generateCleanupIfAny(&maybeYield->getCleanup()); |
803 | } |
804 | void OrderedAssignmentRewriter::generateCleanupIfAny( |
805 | mlir::Region *cleanupRegion) { |
806 | if (cleanupRegion && !cleanupRegion->empty()) { |
807 | assert(cleanupRegion->hasOneBlock() && "region must contain one block" ); |
808 | for (auto &op : cleanupRegion->back().without_terminator()) |
809 | builder.clone(op, mapper); |
810 | } |
811 | } |
812 | |
813 | bool OrderedAssignmentRewriter::mustSaveRegionIn( |
814 | hlfir::OrderedAssignmentTreeOpInterface node, |
815 | llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const { |
816 | for (auto &action : currentRun->actions) |
817 | if (hlfir::SaveEntity *savedEntity = |
818 | std::get_if<hlfir::SaveEntity>(&action)) |
819 | if (node.getOperation() == savedEntity->yieldRegion->getParentOp()) |
820 | saveEntities.push_back(*savedEntity); |
821 | return !saveEntities.empty(); |
822 | } |
823 | |
824 | bool OrderedAssignmentRewriter::isRequiredInCurrentRun( |
825 | hlfir::OrderedAssignmentTreeOpInterface node) const { |
826 | // hlfir.forall_index do not contain saved regions/assignments, |
827 | // but if their hlfir.forall parent was required, they are |
828 | // required (the forall indices needs to be mapped). |
829 | if (mlir::isa<hlfir::ForallIndexOp>(node)) |
830 | return true; |
831 | for (auto &action : currentRun->actions) |
832 | if (hlfir::SaveEntity *savedEntity = |
833 | std::get_if<hlfir::SaveEntity>(&action)) { |
834 | // A SaveEntity action does not require evaluating the node that contains |
835 | // it, but it requires to evaluate all the parents of the nodes that |
836 | // contains it. For instance, an saving a bound in hlfir.forall B does not |
837 | // require creating the loops for B, but it requires creating the loops |
838 | // for any forall parent A of the forall B. |
839 | if (node->isProperAncestor(savedEntity->yieldRegion->getParentOp())) |
840 | return true; |
841 | } else { |
842 | auto assign = std::get<hlfir::RegionAssignOp>(action); |
843 | if (node->isAncestor(assign.getOperation())) |
844 | return true; |
845 | } |
846 | return false; |
847 | } |
848 | |
849 | /// Is the apply using all the elemental indices in order? |
850 | static bool isInOrderApply(hlfir::ApplyOp apply, |
851 | hlfir::ElementalOpInterface elemental) { |
852 | mlir::Region::BlockArgListType elementalIndices = elemental.getIndices(); |
853 | if (elementalIndices.size() != apply.getIndices().size()) |
854 | return false; |
855 | for (auto [elementalIdx, applyIdx] : |
856 | llvm::zip(elementalIndices, apply.getIndices())) |
857 | if (elementalIdx != applyIdx) |
858 | return false; |
859 | return true; |
860 | } |
861 | |
862 | /// Gather the tree of hlfir::ElementalOpInterface use-def, if any, starting |
863 | /// from \p elemental, which may be a nullptr. |
864 | static void |
865 | gatherElementalTree(hlfir::ElementalOpInterface elemental, |
866 | llvm::SmallPtrSetImpl<mlir::Operation *> &elementalOps, |
867 | bool isOutOfOrder) { |
868 | if (elemental) { |
869 | // Only inline an applied elemental that must be executed in order if the |
870 | // applying indices are in order. An hlfir::Elemental may have been created |
871 | // for a transformational like transpose, and Fortran 2018 standard |
872 | // section 10.2.3.2, point 10 imply that impure elemental sub-expression |
873 | // evaluations should not be masked if they are the arguments of |
874 | // transformational expressions. |
875 | if (isOutOfOrder && elemental.isOrdered()) |
876 | return; |
877 | elementalOps.insert(elemental.getOperation()); |
878 | for (mlir::Operation &op : elemental.getElementalRegion().getOps()) |
879 | if (auto apply = mlir::dyn_cast<hlfir::ApplyOp>(op)) { |
880 | bool isUnorderedApply = |
881 | isOutOfOrder || !isInOrderApply(apply, elemental); |
882 | auto maybeElemental = |
883 | mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>( |
884 | apply.getExpr().getDefiningOp()); |
885 | gatherElementalTree(maybeElemental, elementalOps, isUnorderedApply); |
886 | } |
887 | } |
888 | } |
889 | |
890 | MaskedArrayExpr::MaskedArrayExpr(mlir::Location loc, mlir::Region ®ion) |
891 | : loc{loc}, region{region} { |
892 | mlir::Operation &terminator = region.back().back(); |
893 | if (auto elementalAddr = |
894 | mlir::dyn_cast<hlfir::ElementalOpInterface>(terminator)) { |
895 | // Vector subscripted designator (hlfir.elemental_addr terminator). |
896 | gatherElementalTree(elementalAddr, elementalParts, /*isOutOfOrder=*/false); |
897 | return; |
898 | } |
899 | // Try if elemental expression. |
900 | mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity(); |
901 | auto maybeElemental = mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>( |
902 | entity.getDefiningOp()); |
903 | gatherElementalTree(maybeElemental, elementalParts, /*isOutOfOrder=*/false); |
904 | } |
905 | |
906 | void MaskedArrayExpr::generateNoneElementalPart(fir::FirOpBuilder &builder, |
907 | mlir::IRMapping &mapper) { |
908 | assert(!noneElementalPartWasGenerated && |
909 | "none elemental parts already generated" ); |
910 | // Clone all operations, except the elemental and the final yield. |
911 | mlir::Block::OpListType &ops = region.back().getOperations(); |
912 | assert(!ops.empty() && "yield block cannot be empty" ); |
913 | auto end = ops.end(); |
914 | for (auto opIt = ops.begin(); std::next(opIt) != end; ++opIt) |
915 | if (!elementalParts.contains(&*opIt)) |
916 | (void)builder.clone(*opIt, mapper); |
917 | noneElementalPartWasGenerated = true; |
918 | } |
919 | |
920 | mlir::Value MaskedArrayExpr::generateShape(fir::FirOpBuilder &builder, |
921 | mlir::IRMapping &mapper) { |
922 | assert(noneElementalPartWasGenerated && |
923 | "non elemental part must have been generated" ); |
924 | mlir::Operation &terminator = region.back().back(); |
925 | // If the operation that produced the yielded entity is elemental, it was not |
926 | // cloned, but it holds a shape argument that was cloned. Return the cloned |
927 | // shape. |
928 | if (auto elementalAddrOp = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator)) |
929 | return mapper.lookupOrDefault(elementalAddrOp.getShape()); |
930 | mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity(); |
931 | if (auto elemental = entity.getDefiningOp<hlfir::ElementalOp>()) |
932 | return mapper.lookupOrDefault(elemental.getShape()); |
933 | // Otherwise, the whole entity was cloned, and the shape can be generated |
934 | // from it. |
935 | hlfir::Entity clonedEntity{mapper.lookupOrDefault(entity)}; |
936 | return hlfir::genShape(loc, builder, hlfir::Entity{clonedEntity}); |
937 | } |
938 | |
mlir::Value
MaskedArrayExpr::generateElementalParts(fir::FirOpBuilder &builder,
                                        mlir::ValueRange oneBasedIndices,
                                        mlir::IRMapping &mapper) {
  assert(noneElementalPartWasGenerated &&
         "non elemental part must have been generated" );
  mlir::Operation &terminator = region.back().back();
  // For vector subscripted designators, the terminator itself is the
  // elemental operation (hlfir.elemental_addr).
  hlfir::ElementalOpInterface elemental =
      mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator);
  if (!elemental) {
    // If the terminator is not an hlfir.elemental_addr, try if the yielded
    // entity was produced by an hlfir.elemental.
    mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
    elemental = entity.getDefiningOp<hlfir::ElementalOp>();
    if (!elemental) {
      // The yielded entity was not produced by an elemental operation,
      // get its clone in the non elemental part evaluation and address it.
      hlfir::Entity clonedEntity{mapper.lookupOrDefault(entity)};
      return hlfir::getElementAt(loc, builder, clonedEntity, oneBasedIndices);
    }
  }

  // Inline the elemental body for the element at \p oneBasedIndices at the
  // current insertion point, recursing only into applied elementals that
  // were gathered as part of this masked expression.
  auto mustRecursivelyInline =
      [&](hlfir::ElementalOp appliedElemental) -> bool {
    return elementalParts.contains(appliedElemental.getOperation());
  };
  return inlineElementalOp(loc, builder, elemental, oneBasedIndices, mapper,
                           mustRecursivelyInline);
}
968 | |
969 | void MaskedArrayExpr::generateNoneElementalCleanupIfAny( |
970 | fir::FirOpBuilder &builder, mlir::IRMapping &mapper) { |
971 | mlir::Operation &terminator = region.back().back(); |
972 | mlir::Region *cleanupRegion = nullptr; |
973 | if (auto elementalAddr = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator)) { |
974 | cleanupRegion = &elementalAddr.getCleanup(); |
975 | } else { |
976 | auto yieldOp = mlir::cast<hlfir::YieldOp>(terminator); |
977 | cleanupRegion = &yieldOp.getCleanup(); |
978 | } |
979 | if (cleanupRegion->empty()) |
980 | return; |
981 | for (mlir::Operation &op : cleanupRegion->front().without_terminator()) { |
982 | if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(op)) |
983 | if (elementalParts.contains(destroy.getExpr().getDefiningOp())) |
984 | continue; |
985 | (void)builder.clone(op, mapper); |
986 | } |
987 | } |
988 | |
989 | static hlfir::RegionAssignOp |
990 | getAssignIfLeftHandSideRegion(mlir::Region ®ion) { |
991 | auto assign = mlir::dyn_cast<hlfir::RegionAssignOp>(region.getParentOp()); |
992 | if (assign && (&assign.getLhsRegion() == ®ion)) |
993 | return assign; |
994 | return nullptr; |
995 | } |
996 | |
bool OrderedAssignmentRewriter::currentLoopNestIterationNumberCanBeComputed(
    llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest) {
  // Return true if the total iteration count of the currently active loop
  // nest can be computed before entering the outermost construct. On
  // success, \p loopNest is filled with the fir.do_loop of the construct
  // stack, from innermost to outermost.
  if (constructStack.empty())
    return true;
  mlir::Operation *outerLoop = constructStack[0];
  mlir::Operation *currentConstruct = constructStack.back();
  // Loop through the loops until the outer construct is met, and test if the
  // loop operands dominate the outer construct.
  while (currentConstruct) {
    if (auto doLoop = mlir::dyn_cast<fir::DoLoopOp>(currentConstruct)) {
      // The loop bounds and step must dominate the outermost construct so
      // the extent can be computed there.
      if (llvm::any_of(doLoop->getOperands(), [&](mlir::Value value) {
            return !dominanceInfo.properlyDominates(value, outerLoop);
          })) {
        return false;
      }
      loopNest.push_back(doLoop);
    }
    // Walk up the parent operations (constructs that are not fir.do_loop are
    // simply skipped) until the outermost construct is reached.
    if (currentConstruct == outerLoop)
      currentConstruct = nullptr;
    else
      currentConstruct = currentConstruct->getParentOp();
  }
  return true;
}
1021 | |
1022 | static mlir::Value |
1023 | computeLoopNestIterationNumber(mlir::Location loc, fir::FirOpBuilder &builder, |
1024 | llvm::ArrayRef<fir::DoLoopOp> loopNest) { |
1025 | mlir::Value loopExtent; |
1026 | for (fir::DoLoopOp doLoop : loopNest) { |
1027 | mlir::Value extent = builder.genExtentFromTriplet( |
1028 | loc, doLoop.getLowerBound(), doLoop.getUpperBound(), doLoop.getStep(), |
1029 | builder.getIndexType()); |
1030 | if (!loopExtent) |
1031 | loopExtent = extent; |
1032 | else |
1033 | loopExtent = builder.create<mlir::arith::MulIOp>(loc, loopExtent, extent); |
1034 | } |
1035 | assert(loopExtent && "loopNest must not be empty" ); |
1036 | return loopExtent; |
1037 | } |
1038 | |
1039 | /// Return a name for temporary storage that indicates in which context |
1040 | /// the temporary storage was created. |
1041 | static llvm::StringRef |
1042 | getTempName(hlfir::OrderedAssignmentTreeOpInterface root) { |
1043 | if (mlir::isa<hlfir::ForallOp>(root.getOperation())) |
1044 | return ".tmp.forall" ; |
1045 | if (mlir::isa<hlfir::WhereOp>(root.getOperation())) |
1046 | return ".tmp.where" ; |
1047 | return ".tmp.assign" ; |
1048 | } |
1049 | |
1050 | void OrderedAssignmentRewriter::generateSaveEntity( |
1051 | hlfir::SaveEntity savedEntity, bool willUseSavedEntityInSameRun) { |
1052 | mlir::Region ®ion = *savedEntity.yieldRegion; |
1053 | |
1054 | if (hlfir::RegionAssignOp regionAssignOp = |
1055 | getAssignIfLeftHandSideRegion(region)) { |
1056 | // Need to save the address, not the values. |
1057 | assert(!willUseSavedEntityInSameRun && |
1058 | "lhs cannot be used in the loop nest where it is saved" ); |
1059 | return saveLeftHandSide(savedEntity, regionAssignOp); |
1060 | } |
1061 | |
1062 | mlir::Location loc = region.getParentOp()->getLoc(); |
1063 | // Evaluate the region inside the loop nest (if any). |
1064 | auto [clonedValue, oldYield] = generateYieldedEntity(region); |
1065 | hlfir::Entity entity{clonedValue}; |
1066 | entity = hlfir::loadTrivialScalar(loc, builder, entity); |
1067 | mlir::Type entityType = entity.getType(); |
1068 | |
1069 | llvm::StringRef tempName = getTempName(root); |
1070 | fir::factory::TemporaryStorage *temp = nullptr; |
1071 | if (constructStack.empty()) { |
1072 | // Value evaluated outside of any loops (this may be the first MASK of a |
1073 | // WHERE construct, or an LHS/RHS temp of hlfir.region_assign outside of |
1074 | // WHERE/FORALL). |
1075 | temp = insertSavedEntity( |
1076 | region, fir::factory::SimpleCopy(loc, builder, entity, tempName)); |
1077 | } else { |
1078 | // Need to create a temporary for values computed inside loops. |
1079 | // Create temporary storage outside of the loop nest given the entity |
1080 | // type (and the loop context). |
1081 | llvm::SmallVector<fir::DoLoopOp> loopNest; |
1082 | bool loopShapeCanBePreComputed = |
1083 | currentLoopNestIterationNumberCanBeComputed(loopNest); |
1084 | doBeforeLoopNest(callback: [&] { |
1085 | /// For simple scalars inside loops whose total iteration number can be |
1086 | /// pre-computed, create a rank-1 array outside of the loops. It will be |
1087 | /// assigned/fetched inside the loops like a normal Fortran array given |
1088 | /// the iteration count. |
1089 | if (loopShapeCanBePreComputed && fir::isa_trivial(entityType)) { |
1090 | mlir::Value loopExtent = |
1091 | computeLoopNestIterationNumber(loc, builder, loopNest); |
1092 | auto sequenceType = |
1093 | builder.getVarLenSeqTy(entityType).cast<fir::SequenceType>(); |
1094 | temp = insertSavedEntity(region, |
1095 | fir::factory::HomogeneousScalarStack{ |
1096 | loc, builder, sequenceType, loopExtent, |
1097 | /*lenParams=*/{}, allocateOnHeap, |
1098 | /*stackThroughLoops=*/true, tempName}); |
1099 | |
1100 | } else { |
1101 | // If the number of iteration is not known, or if the values at each |
1102 | // iterations are values that may have different shape, type parameters |
1103 | // or dynamic type, use the runtime to create and manage a stack-like |
1104 | // temporary. |
1105 | temp = insertSavedEntity( |
1106 | region, fir::factory::AnyValueStack{loc, builder, entityType}); |
1107 | } |
1108 | }); |
1109 | // Inside the loop nest (and any fir.if if there are active masks), copy |
1110 | // the value to the temp and do clean-ups for the value if any. |
1111 | temp->pushValue(loc, builder, entity); |
1112 | } |
1113 | |
1114 | // Delay the clean-up if the entity will be used in the same run (i.e., the |
1115 | // parent construct will be visited and needs to be lowered). When possible, |
1116 | // this is not done for hlfir.expr because this use would prevent the |
1117 | // hlfir.expr storage from being moved when creating the temporary in |
1118 | // bufferization, and that would lead to an extra copy. |
1119 | if (willUseSavedEntityInSameRun && |
1120 | (!temp->canBeFetchedAfterPush() || |
1121 | !mlir::isa<hlfir::ExprType>(entity.getType()))) { |
1122 | auto inserted = |
1123 | savedInCurrentRunBeforeUse.try_emplace(®ion, entity, oldYield); |
1124 | assert(inserted.second && "entity must have been emplaced" ); |
1125 | (void)inserted; |
1126 | } else { |
1127 | if (constructStack.empty() && |
1128 | mlir::isa<hlfir::RegionAssignOp>(region.getParentOp())) { |
1129 | // Here the clean-up code is inserted after the original |
1130 | // RegionAssignOp, so that the assignment code happens |
1131 | // before the cleanup. We do this only for standalone |
1132 | // operations, because the clean-up is handled specially |
1133 | // during lowering of the parent constructs if any |
1134 | // (e.g. see generateNoneElementalCleanupIfAny for |
1135 | // WhereOp). |
1136 | auto insertionPoint = builder.saveInsertionPoint(); |
1137 | builder.setInsertionPointAfter(region.getParentOp()); |
1138 | generateCleanupIfAny(oldYield); |
1139 | builder.restoreInsertionPoint(insertionPoint); |
1140 | } else { |
1141 | generateCleanupIfAny(oldYield); |
1142 | } |
1143 | } |
1144 | } |
1145 | |
1146 | static bool rhsIsArray(hlfir::RegionAssignOp regionAssignOp) { |
1147 | auto yieldOp = mlir::dyn_cast<hlfir::YieldOp>( |
1148 | regionAssignOp.getRhsRegion().back().back()); |
1149 | return yieldOp && hlfir::Entity{yieldOp.getEntity()}.isArray(); |
1150 | } |
1151 | |
1152 | void OrderedAssignmentRewriter::saveLeftHandSide( |
1153 | hlfir::SaveEntity savedEntity, hlfir::RegionAssignOp regionAssignOp) { |
1154 | mlir::Region ®ion = *savedEntity.yieldRegion; |
1155 | mlir::Location loc = region.getParentOp()->getLoc(); |
1156 | LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region); |
1157 | fir::factory::TemporaryStorage *temp = nullptr; |
1158 | if (loweredLhs.vectorSubscriptLoopNest) |
1159 | constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop); |
1160 | if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) { |
1161 | // Vector subscripted entity for which the shape must also be saved on top |
1162 | // of the element addresses (e.g. the shape may change in each forall |
1163 | // iteration and is needed to create the elemental loops). |
1164 | mlir::Value shape = loweredLhs.vectorSubscriptShape.value(); |
1165 | int rank = mlir::cast<fir::ShapeType>(shape.getType()).getRank(); |
1166 | const bool shapeIsInvariant = |
1167 | constructStack.empty() || |
1168 | dominanceInfo.properlyDominates(shape, constructStack[0]); |
1169 | doBeforeLoopNest(callback: [&] { |
1170 | // Outside of any forall/where/elemental loops, create a temporary that |
1171 | // will both be able to save the vector subscripted designator shape(s) |
1172 | // and element addresses. |
1173 | temp = |
1174 | insertSavedEntity(region, fir::factory::AnyVectorSubscriptStack{ |
1175 | loc, builder, loweredLhs.lhs.getType(), |
1176 | shapeIsInvariant, rank}); |
1177 | }); |
1178 | // Save shape before the elemental loop nest created by the vector |
1179 | // subscripted LHS. |
1180 | auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>(); |
1181 | auto insertionPoint = builder.saveInsertionPoint(); |
1182 | builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop); |
1183 | vectorTmp.pushShape(loc, builder, shape); |
1184 | builder.restoreInsertionPoint(insertionPoint); |
1185 | } else { |
1186 | // Otherwise, only save the LHS address. |
1187 | // If the LHS address dominates the constructs, its SSA value can |
1188 | // simply be tracked and there is no need to save the address in memory. |
1189 | // Otherwise, the addresses are stored at each iteration in memory with |
1190 | // a descriptor stack. |
1191 | if (constructStack.empty() || |
1192 | dominanceInfo.properlyDominates(loweredLhs.lhs, constructStack[0])) |
1193 | doBeforeLoopNest(callback: [&] { |
1194 | temp = insertSavedEntity(region, fir::factory::SSARegister{}); |
1195 | }); |
1196 | else |
1197 | doBeforeLoopNest(callback: [&] { |
1198 | temp = insertSavedEntity( |
1199 | region, fir::factory::AnyVariableStack{loc, builder, |
1200 | loweredLhs.lhs.getType()}); |
1201 | }); |
1202 | } |
1203 | temp->pushValue(loc, builder, loweredLhs.lhs); |
1204 | generateCleanupIfAny(loweredLhs.elementalCleanup); |
1205 | if (loweredLhs.vectorSubscriptLoopNest) { |
1206 | constructStack.pop_back(); |
1207 | builder.setInsertionPointAfter( |
1208 | loweredLhs.vectorSubscriptLoopNest->outerLoop); |
1209 | } |
1210 | } |
1211 | |
1212 | /// Lower an ordered assignment tree to fir.do_loop and hlfir.assign given |
1213 | /// a schedule. |
1214 | static void lower(hlfir::OrderedAssignmentTreeOpInterface root, |
1215 | mlir::PatternRewriter &rewriter, hlfir::Schedule &schedule) { |
1216 | auto module = root->getParentOfType<mlir::ModuleOp>(); |
1217 | fir::FirOpBuilder builder(rewriter, module); |
1218 | OrderedAssignmentRewriter assignmentRewriter(builder, root); |
1219 | for (auto &run : schedule) |
1220 | assignmentRewriter.lowerRun(run); |
1221 | assignmentRewriter.cleanupSavedEntities(); |
1222 | } |
1223 | |
/// Shared rewrite entry point for all the ordered assignment tree root
/// operations. It calls the scheduler and then apply the schedule.
static mlir::LogicalResult rewrite(hlfir::OrderedAssignmentTreeOpInterface root,
                                   bool tryFusingAssignments,
                                   mlir::PatternRewriter &rewriter) {
  hlfir::Schedule schedule =
      hlfir::buildEvaluationSchedule(root, tryFusingAssignments);

  // The early return below only exists in builds where LLVM_DEBUG expands to
  // code (i.e. assertion-enabled builds with the debug type active).
  LLVM_DEBUG(
      /// Debug option to print the scheduling debug info without doing
      /// any code generation. The operations are simply erased to avoid
      /// failing and calling the rewrite patterns on nested operations.
      /// The only purpose of this is to help testing scheduling without
      /// having to test generated code.
      if (dbgScheduleOnly) {
        rewriter.eraseOp(root);
        return mlir::success();
      });
  // Generate the code following the schedule, then erase the original
  // ordered assignment tree root.
  lower(root, rewriter, schedule);
  rewriter.eraseOp(root);
  return mlir::success();
}
1246 | |
1247 | namespace { |
1248 | |
1249 | class ForallOpConversion : public mlir::OpRewritePattern<hlfir::ForallOp> { |
1250 | public: |
1251 | explicit ForallOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments) |
1252 | : OpRewritePattern{ctx}, tryFusingAssignments{tryFusingAssignments} {} |
1253 | |
1254 | mlir::LogicalResult |
1255 | matchAndRewrite(hlfir::ForallOp forallOp, |
1256 | mlir::PatternRewriter &rewriter) const override { |
1257 | auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>( |
1258 | forallOp.getOperation()); |
1259 | if (mlir::failed(::rewrite(root, tryFusingAssignments, rewriter))) |
1260 | TODO(forallOp.getLoc(), "FORALL construct or statement in HLFIR" ); |
1261 | return mlir::success(); |
1262 | } |
1263 | const bool tryFusingAssignments; |
1264 | }; |
1265 | |
1266 | class WhereOpConversion : public mlir::OpRewritePattern<hlfir::WhereOp> { |
1267 | public: |
1268 | explicit WhereOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments) |
1269 | : OpRewritePattern{ctx}, tryFusingAssignments{tryFusingAssignments} {} |
1270 | |
1271 | mlir::LogicalResult |
1272 | matchAndRewrite(hlfir::WhereOp whereOp, |
1273 | mlir::PatternRewriter &rewriter) const override { |
1274 | auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>( |
1275 | whereOp.getOperation()); |
1276 | return ::rewrite(root, tryFusingAssignments, rewriter); |
1277 | } |
1278 | const bool tryFusingAssignments; |
1279 | }; |
1280 | |
1281 | class RegionAssignConversion |
1282 | : public mlir::OpRewritePattern<hlfir::RegionAssignOp> { |
1283 | public: |
1284 | explicit RegionAssignConversion(mlir::MLIRContext *ctx) |
1285 | : OpRewritePattern{ctx} {} |
1286 | |
1287 | mlir::LogicalResult |
1288 | matchAndRewrite(hlfir::RegionAssignOp regionAssignOp, |
1289 | mlir::PatternRewriter &rewriter) const override { |
1290 | auto root = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>( |
1291 | regionAssignOp.getOperation()); |
1292 | return ::rewrite(root, /*tryFusingAssignments=*/false, rewriter); |
1293 | } |
1294 | }; |
1295 | |
1296 | class LowerHLFIROrderedAssignments |
1297 | : public hlfir::impl::LowerHLFIROrderedAssignmentsBase< |
1298 | LowerHLFIROrderedAssignments> { |
1299 | public: |
1300 | void runOnOperation() override { |
1301 | // Running on a ModuleOp because this pass may generate FuncOp declaration |
1302 | // for runtime calls. This could be a FuncOp pass otherwise. |
1303 | auto module = this->getOperation(); |
1304 | auto *context = &getContext(); |
1305 | mlir::RewritePatternSet patterns(context); |
1306 | // Patterns are only defined for the OrderedAssignmentTreeOpInterface |
1307 | // operations that can be the root of ordered assignments. The other |
1308 | // operations will be taken care of while rewriting these trees (they |
1309 | // cannot exist outside of these operations given their verifiers/traits). |
1310 | patterns.insert<ForallOpConversion, WhereOpConversion>( |
1311 | context, this->tryFusingAssignments.getValue()); |
1312 | patterns.insert<RegionAssignConversion>(context); |
1313 | mlir::ConversionTarget target(*context); |
1314 | target.markUnknownOpDynamicallyLegal([](mlir::Operation *op) { |
1315 | return !mlir::isa<hlfir::OrderedAssignmentTreeOpInterface>(op); |
1316 | }); |
1317 | if (mlir::failed(mlir::applyPartialConversion(module, target, |
1318 | std::move(patterns)))) { |
1319 | mlir::emitError(mlir::UnknownLoc::get(context), |
1320 | "failure in HLFIR ordered assignments lowering pass" ); |
1321 | signalPassFailure(); |
1322 | } |
1323 | } |
1324 | }; |
1325 | } // namespace |
1326 | |
// Public factory for the pass, used by pass pipeline registration.
std::unique_ptr<mlir::Pass> hlfir::createLowerHLFIROrderedAssignmentsPass() {
  return std::make_unique<LowerHLFIROrderedAssignments>();
}
1330 | |