1//===- StackArrays.cpp ----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/Builder/FIRBuilder.h"
10#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
11#include "flang/Optimizer/Dialect/FIRAttr.h"
12#include "flang/Optimizer/Dialect/FIRDialect.h"
13#include "flang/Optimizer/Dialect/FIROps.h"
14#include "flang/Optimizer/Dialect/FIRType.h"
15#include "flang/Optimizer/Dialect/Support/FIRContext.h"
16#include "flang/Optimizer/Transforms/Passes.h"
17#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
18#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
19#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
20#include "mlir/Analysis/DataFlowFramework.h"
21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/Diagnostics.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Interfaces/LoopLikeInterface.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Support/LogicalResult.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include "mlir/Transforms/Passes.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/DenseSet.h"
33#include "llvm/ADT/PointerUnion.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/raw_ostream.h"
36#include <optional>
37
38namespace fir {
39#define GEN_PASS_DEF_STACKARRAYS
40#include "flang/Optimizer/Transforms/Passes.h.inc"
41} // namespace fir
42
43#define DEBUG_TYPE "stack-arrays"
44
/// Command-line cap on how many fir.allocmem operations a single function may
/// contain before the pass skips that function entirely (0 = unlimited).
static llvm::cl::opt<std::size_t> maxAllocsPerFunc(
    "stack-arrays-max-allocs",
    llvm::cl::desc("The maximum number of heap allocations to consider in one "
                   "function before skipping (to save compilation time). Set "
                   "to 0 for no limit."),
    llvm::cl::init(1000), llvm::cl::Hidden);
51
52namespace {
53
/// The state of an SSA value at each program point.
/// States from different control-flow paths are merged by the ::join helper
/// defined later in this file.
enum class AllocationState {
  /// This means that the allocation state of a variable cannot be determined
  /// at this program point, e.g. because one route through a conditional freed
  /// the variable and the other route didn't.
  /// This asserts a known-unknown: different from the unknown-unknown of having
  /// no AllocationState stored for a particular SSA value
  Unknown,
  /// Means this SSA value was allocated on the heap in this function and has
  /// now been freed
  Freed,
  /// Means this SSA value was allocated on the heap in this function and is a
  /// candidate for moving to the stack
  Allocated,
};
69
70/// Stores where an alloca should be inserted. If the PointerUnion is an
71/// Operation the alloca should be inserted /after/ the operation. If it is a
72/// block, the alloca can be placed anywhere in that block.
73class InsertionPoint {
74 llvm::PointerUnion<mlir::Operation *, mlir::Block *> location;
75 bool saveRestoreStack;
76
77 /// Get contained pointer type or nullptr
78 template <class T>
79 T *tryGetPtr() const {
80 if (location.is<T *>())
81 return location.get<T *>();
82 return nullptr;
83 }
84
85public:
86 template <class T>
87 InsertionPoint(T *ptr, bool saveRestoreStack = false)
88 : location(ptr), saveRestoreStack{saveRestoreStack} {}
89 InsertionPoint(std::nullptr_t null)
90 : location(null), saveRestoreStack{false} {}
91
92 /// Get contained operation, or nullptr
93 mlir::Operation *tryGetOperation() const {
94 return tryGetPtr<mlir::Operation>();
95 }
96
97 /// Get contained block, or nullptr
98 mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); }
99
100 /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
101 /// intrinsic should be added before the alloca, and an llvm.stackrestore
102 /// intrinsic should be added where the freemem is
103 bool shouldSaveRestoreStack() const { return saveRestoreStack; }
104
105 operator bool() const { return tryGetOperation() || tryGetBlock(); }
106
107 bool operator==(const InsertionPoint &rhs) const {
108 return (location == rhs.location) &&
109 (saveRestoreStack == rhs.saveRestoreStack);
110 }
111
112 bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); }
113};
114
/// Maps SSA values to their AllocationState at a particular program point.
/// Also caches the insertion points for the new alloca operations
class LatticePoint : public mlir::dataflow::AbstractDenseLattice {
  // Maps all values we are interested in to states
  llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap;

public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint)
  using AbstractDenseLattice::AbstractDenseLattice;

  bool operator==(const LatticePoint &rhs) const {
    return stateMap == rhs.stateMap;
  }

  /// Join the lattice across control-flow edges
  mlir::ChangeResult join(const AbstractDenseLattice &lattice) override;

  void print(llvm::raw_ostream &os) const override;

  /// Clear all modifications
  mlir::ChangeResult reset();

  /// Set the state of an SSA value
  mlir::ChangeResult set(mlir::Value value, AllocationState state);

  /// Get fir.allocmem ops which were allocated in this function and always
  /// freed before the function returns, plus where to insert replacement
  /// fir.alloca ops
  void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const;

  /// Look up the state tracked for \p val, if any.
  std::optional<AllocationState> get(mlir::Value val) const;
};
147
/// Dense forward dataflow analysis that tracks the AllocationState of
/// fir.allocmem results through each function.
class AllocationAnalysis
    : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> {
public:
  using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;

  void visitOperation(mlir::Operation *op, const LatticePoint &before,
                      LatticePoint *after) override;

  /// At an entry point, the last modifications of all memory resources are
  /// yet to be determined
  void setToEntryState(LatticePoint *lattice) override;

protected:
  /// Visit control flow operations and decide whether to call visitOperation
  /// to apply the transfer function
  void processOperation(mlir::Operation *op) override;
};
165
/// Drives analysis to find candidate fir.allocmem operations which could be
/// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
class StackArraysAnalysisWrapper {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper)

  // Maps fir.allocmem -> place to insert alloca
  using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>;

  // The operation argument is unused; analysis runs lazily per function via
  // getCandidateOps.
  StackArraysAnalysisWrapper(mlir::Operation *op) {}

  // returns nullptr if analysis failed
  const AllocMemMap *getCandidateOps(mlir::Operation *func);

private:
  // Cache of per-function analysis results.
  llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;

  // Run the dataflow analysis on one func.func and record candidates in
  // funcMaps.
  mlir::LogicalResult analyseFunction(mlir::Operation *func);
};
185
/// Converts a fir.allocmem to a fir.alloca
class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
public:
  explicit AllocMemConversion(
      mlir::MLIRContext *ctx,
      const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
      : OpRewritePattern(ctx), candidateOps{candidateOps} {}

  mlir::LogicalResult
  matchAndRewrite(fir::AllocMemOp allocmem,
                  mlir::PatternRewriter &rewriter) const override;

  /// Determine where to insert the alloca operation. The returned value should
  /// be checked to see if it is inside a loop
  static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);

private:
  /// Handle to the DFA (already run)
  const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;

  /// If we failed to find an insertion point not inside a loop, see if it would
  /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
  static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc);

  /// Returns the alloca if it was successfully inserted, otherwise {}
  std::optional<fir::AllocaOp>
  insertAlloca(fir::AllocMemOp &oldAlloc,
               mlir::PatternRewriter &rewriter) const;

  /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
  void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
                              mlir::PatternRewriter &rewriter) const;
};
219
/// Pass driver: runs the allocation analysis over every function in the
/// module and rewrites candidate fir.allocmem ops into fir.alloca ops.
class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
public:
  StackArraysPass() = default;
  StackArraysPass(const StackArraysPass &pass);

  llvm::StringRef getDescription() const override;

  void runOnOperation() override;
  void runOnFunc(mlir::Operation *func);

private:
  // Counts successfully converted allocations across the whole module.
  Statistic runCount{this, "stackArraysRunCount",
                     "Number of heap allocations moved to the stack"};
};
234
235} // namespace
236
237static void print(llvm::raw_ostream &os, AllocationState state) {
238 switch (state) {
239 case AllocationState::Unknown:
240 os << "Unknown";
241 break;
242 case AllocationState::Freed:
243 os << "Freed";
244 break;
245 case AllocationState::Allocated:
246 os << "Allocated";
247 break;
248 }
249}
250
251/// Join two AllocationStates for the same value coming from different CFG
252/// blocks
253static AllocationState join(AllocationState lhs, AllocationState rhs) {
254 // | Allocated | Freed | Unknown
255 // ========= | ========= | ========= | =========
256 // Allocated | Allocated | Unknown | Unknown
257 // Freed | Unknown | Freed | Unknown
258 // Unknown | Unknown | Unknown | Unknown
259 if (lhs == rhs)
260 return lhs;
261 return AllocationState::Unknown;
262}
263
264mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) {
265 const auto &rhs = static_cast<const LatticePoint &>(lattice);
266 mlir::ChangeResult changed = mlir::ChangeResult::NoChange;
267
268 // add everything from rhs to map, handling cases where values are in both
269 for (const auto &[value, rhsState] : rhs.stateMap) {
270 auto it = stateMap.find(value);
271 if (it != stateMap.end()) {
272 // value is present in both maps
273 AllocationState myState = it->second;
274 AllocationState newState = ::join(myState, rhsState);
275 if (newState != myState) {
276 changed = mlir::ChangeResult::Change;
277 it->getSecond() = newState;
278 }
279 } else {
280 // value not present in current map: add it
281 stateMap.insert({value, rhsState});
282 changed = mlir::ChangeResult::Change;
283 }
284 }
285
286 return changed;
287}
288
/// Print every tracked value with its state (debug aid).
void LatticePoint::print(llvm::raw_ostream &os) const {
  for (const auto &[value, state] : stateMap) {
    os << value << ": ";
    ::print(os, state);
  }
}
295
296mlir::ChangeResult LatticePoint::reset() {
297 if (stateMap.empty())
298 return mlir::ChangeResult::NoChange;
299 stateMap.clear();
300 return mlir::ChangeResult::Change;
301}
302
303mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) {
304 if (stateMap.count(value)) {
305 // already in map
306 AllocationState &oldState = stateMap[value];
307 if (oldState != state) {
308 stateMap[value] = state;
309 return mlir::ChangeResult::Change;
310 }
311 return mlir::ChangeResult::NoChange;
312 }
313 stateMap.insert({value, state});
314 return mlir::ChangeResult::Change;
315}
316
317/// Get values which were allocated in this function and always freed before
318/// the function returns
319void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const {
320 for (auto &[value, state] : stateMap) {
321 if (state == AllocationState::Freed)
322 out.insert(value);
323 }
324}
325
326std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
327 auto it = stateMap.find(val);
328 if (it == stateMap.end())
329 return {};
330 return it->second;
331}
332
/// Transfer function: compute the lattice after \p op from the lattice before
/// it. fir.allocmem results for arrays become Allocated; fir.freemem of a
/// locally-allocated value becomes Freed; fir.result forwards the state to
/// the parent operation's lattice.
void AllocationAnalysis::visitOperation(mlir::Operation *op,
                                        const LatticePoint &before,
                                        LatticePoint *after) {
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
                          << "\n");
  LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");

  // propagate before -> after
  mlir::ChangeResult changed = after->join(before);

  if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) {
    assert(op->getNumResults() == 1 && "fir.allocmem has one result");
    auto attr = op->getAttrOfType<fir::MustBeHeapAttr>(
        fir::MustBeHeapAttr::getAttrName());
    if (attr && attr.getValue()) {
      LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
      // skip allocation marked not to be moved
      return;
    }

    // only array (fir.seq) allocations are candidates
    auto retTy = allocmem.getAllocatedType();
    if (!retTy.isa<fir::SequenceType>()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "--Allocation is not for an array: skipping\n");
      return;
    }

    mlir::Value result = op->getResult(0);
    changed |= after->set(result, AllocationState::Allocated);
  } else if (mlir::isa<fir::FreeMemOp>(op)) {
    assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
    mlir::Value operand = op->getOperand(0);
    std::optional<AllocationState> operandState = before.get(operand);
    if (operandState && *operandState == AllocationState::Allocated) {
      // don't tag things not allocated in this function as freed, so that we
      // don't think they are candidates for moving to the stack
      changed |= after->set(operand, AllocationState::Freed);
    }
  } else if (mlir::isa<fir::ResultOp>(op)) {
    // region terminator: merge this lattice into the enclosing operation's
    // lattice so state survives leaving the region
    mlir::Operation *parent = op->getParentOp();
    LatticePoint *parentLattice = getLattice(parent);
    assert(parentLattice);
    mlir::ChangeResult parentChanged = parentLattice->join(*after);
    propagateIfChanged(parentLattice, parentChanged);
  }

  // we pass lattices straight through fir.call because called functions should
  // not deallocate flang-generated array temporaries

  LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
  propagateIfChanged(after, changed);
}
385
/// At a function entry nothing is known to be allocated or freed yet, so the
/// entry lattice is the empty state map.
void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
  propagateIfChanged(lattice, lattice->reset());
}
389
/// Mostly a copy of AbstractDenseLattice::processOperation - the difference
/// being that call operations are passed through to the transfer function
void AllocationAnalysis::processOperation(mlir::Operation *op) {
  // If the containing block is not executable, bail out.
  if (!getOrCreateFor<mlir::dataflow::Executable>(op, op->getBlock())->isLive())
    return;

  // Get the dense lattice to update
  mlir::dataflow::AbstractDenseLattice *after = getLattice(op);

  // If this op implements region control-flow, then control-flow dictates its
  // transfer function.
  if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op))
    return visitRegionBranchOperation(op, branch, after);

  // pass call operations through to the transfer function
  // (unlike the upstream implementation, which handles calls specially)

  // Get the dense state before the execution of the op: the lattice after the
  // previous op, or the block-entry lattice for the first op in a block.
  const mlir::dataflow::AbstractDenseLattice *before;
  if (mlir::Operation *prev = op->getPrevNode())
    before = getLatticeFor(op, prev);
  else
    before = getLatticeFor(op, op->getBlock());

  /// Invoke the operation transfer function
  visitOperationImpl(op, *before, after);
}
417
/// Run the dataflow analysis on one func.func: join the lattices at every
/// exit (return/unreachable), collect values freed on all paths, and record
/// each such allocmem together with its alloca insertion point in funcMaps.
mlir::LogicalResult
StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));
  // cheap pre-scan: skip functions with no heap allocations, and bail out of
  // very large ones to bound compile time
  size_t nAllocs = 0;
  func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; });
  // don't bother with the analysis if there are no heap allocations
  if (nAllocs == 0)
    return mlir::success();
  if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) {
    LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with "
                            << nAllocs << " heap allocations");
    return mlir::success();
  }

  mlir::DataFlowSolver solver;
  // constant propagation is required for dead code analysis, dead code analysis
  // is required to mark blocks live (required for mlir dense dfa)
  solver.load<mlir::dataflow::SparseConstantPropagation>();
  solver.load<mlir::dataflow::DeadCodeAnalysis>();

  auto [it, inserted] = funcMaps.try_emplace(func);
  AllocMemMap &candidateOps = it->second;

  solver.load<AllocationAnalysis>();
  if (failed(solver.initializeAndRun(func))) {
    llvm::errs() << "DataFlowSolver failed!";
    return mlir::failure();
  }

  // join the lattices at every function exit so only values freed on every
  // path remain Freed in `point`
  LatticePoint point{func};
  auto joinOperationLattice = [&](mlir::Operation *op) {
    const LatticePoint *lattice = solver.lookupState<LatticePoint>(op);
    // there will be no lattice for an unreachable block
    if (lattice)
      (void)point.join(*lattice);
  };
  func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
  func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
  llvm::DenseSet<mlir::Value> freedValues;
  point.appendFreedValues(freedValues);

  // We only replace allocations which are definitely freed on all routes
  // through the function because otherwise the allocation may have an intended
  // lifetime longer than the current stack frame (e.g. a heap allocation which
  // is then freed by another function).
  for (mlir::Value freedValue : freedValues) {
    fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
    InsertionPoint insertionPoint =
        AllocMemConversion::findAllocaInsertionPoint(allocmem);
    if (insertionPoint)
      candidateOps.insert({allocmem, insertionPoint});
  }

  LLVM_DEBUG(for (auto [allocMemOp, _]
                  : candidateOps) {
    llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
  });
  return mlir::success();
}
477
478const StackArraysAnalysisWrapper::AllocMemMap *
479StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
480 if (!funcMaps.contains(func))
481 if (mlir::failed(analyseFunction(func)))
482 return nullptr;
483 return &funcMaps[func];
484}
485
/// Restore the old allocation type expected by existing code: fir.allocmem
/// yields a fir.heap<T> while fir.alloca yields a fir.ref<T>, so insert a
/// fir.convert right after the alloca to keep users type-correct.
static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
                                         const mlir::Location &loc,
                                         mlir::Value heap, mlir::Value stack) {
  mlir::Type heapTy = heap.getType();
  mlir::Type stackTy = stack.getType();

  // no conversion needed if the types already match
  if (heapTy == stackTy)
    return stack;

  fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy);
  LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy =
      mlir::cast<fir::ReferenceType>(stackTy);
  assert(firHeapTy.getElementType() == firRefTy.getElementType() &&
         "Allocations must have the same type");

  // place the convert immediately after the alloca, then put the rewriter
  // back where it was
  auto insertionPoint = rewriter.saveInsertionPoint();
  rewriter.setInsertionPointAfter(stack.getDefiningOp());
  mlir::Value conv =
      rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult();
  rewriter.restoreInsertionPoint(insertionPoint);
  return conv;
}
509
/// Rewrite one candidate fir.allocmem into a fir.alloca: create the alloca at
/// its pre-computed insertion point, erase the matching fir.freemem ops, and
/// redirect all remaining uses to the (type-converted) alloca result.
mlir::LogicalResult
AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
                                    mlir::PatternRewriter &rewriter) const {
  auto oldInsertionPt = rewriter.saveInsertionPoint();
  // add alloca operation
  std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter);
  rewriter.restoreInsertionPoint(oldInsertionPt);
  if (!alloca)
    return mlir::failure();

  // remove freemem operations
  llvm::SmallVector<mlir::Operation *> erases;
  for (mlir::Operation *user : allocmem.getOperation()->getUsers())
    if (mlir::isa<fir::FreeMemOp>(user))
      erases.push_back(user);
  // now we are done iterating the users, it is safe to mutate them
  for (mlir::Operation *erase : erases)
    rewriter.eraseOp(erase);

  // replace references to heap allocation with references to stack allocation
  mlir::Value newValue = convertAllocationType(
      rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
  rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);

  // remove allocmem operation
  rewriter.eraseOp(allocmem.getOperation());

  return mlir::success();
}
539
/// Whether \p block is nested inside a loop (per MLIR's loop detection).
static bool isInLoop(mlir::Block *block) {
  return mlir::LoopLikeOpInterface::blockIsInLoop(block);
}
543
544static bool isInLoop(mlir::Operation *op) {
545 return isInLoop(op->getBlock()) ||
546 op->getParentOfType<mlir::LoopLikeOpInterface>();
547}
548
InsertionPoint
AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) {
  // Ideally the alloca should be inserted at the end of the function entry
  // block so that we do not allocate stack space in a loop. However,
  // the operands to the alloca may not be available that early, so insert it
  // after the last operand becomes available
  // If the old allocmem op was in an openmp region then it should not be moved
  // outside of that
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
                          << oldAlloc << "\n");

  // check that an Operation or Block we are about to return is not in a loop
  auto checkReturn = [&](auto *point) -> InsertionPoint {
    if (isInLoop(point)) {
      mlir::Operation *oldAllocOp = oldAlloc.getOperation();
      if (isInLoop(oldAllocOp)) {
        // where we want to put it is in a loop, and even the old location is in
        // a loop. Give up.
        return findAllocaLoopInsertionPoint(oldAlloc);
      }
      // fall back to the old allocation site, which is not in a loop
      return {oldAllocOp};
    }
    return {point};
  };

  auto oldOmpRegion =
      oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();

  // Find when the last operand value becomes available
  mlir::Block *operandsBlock = nullptr;
  mlir::Operation *lastOperand = nullptr;
  for (mlir::Value operand : oldAlloc.getOperands()) {
    LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n");
    mlir::Operation *op = operand.getDefiningOp();
    // a block argument has no defining op; keep the alloca where the
    // allocmem was
    if (!op)
      return checkReturn(oldAlloc.getOperation());
    if (!operandsBlock)
      operandsBlock = op->getBlock();
    else if (operandsBlock != op->getBlock()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "----operand declared in a different block!\n");
      // Operation::isBeforeInBlock requires the operations to be in the same
      // block. The best we can do is the location of the allocmem.
      return checkReturn(oldAlloc.getOperation());
    }
    if (!lastOperand || lastOperand->isBeforeInBlock(op))
      lastOperand = op;
  }

  if (lastOperand) {
    // there were value operands to the allocmem so insert after the last one
    LLVM_DEBUG(llvm::dbgs()
               << "--Placing after last operand: " << *lastOperand << "\n");
    // check we aren't moving out of an omp region
    auto lastOpOmpRegion =
        lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
    if (lastOpOmpRegion == oldOmpRegion)
      return checkReturn(lastOperand);
    // Presumably this happened because the operands became ready before the
    // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
    // imply that oldOmpRegion comes after lastOpOmpRegion.
    // NOTE(review): this assumes oldOmpRegion is non-null on this path —
    // TODO confirm an operand cannot live in an omp region while the
    // allocmem sits outside of one.
    return checkReturn(oldOmpRegion.getAllocaBlock());
  }

  // There were no value operands to the allocmem so we are safe to insert it
  // as early as we want

  // handle openmp case
  if (oldOmpRegion)
    return checkReturn(oldOmpRegion.getAllocaBlock());

  // fall back to the function entry block
  mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>();
  assert(func && "This analysis is run on func.func");
  mlir::Block &entryBlock = func.getBlocks().front();
  LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
  return checkReturn(&entryBlock);
}
627
/// Last resort: keep the alloca at the old allocation site (inside a loop)
/// and request llvm.stacksave/llvm.stackrestore bracketing so the loop does
/// not grow the stack on every iteration. Returns a null insertion point if
/// that bracketing cannot be reasoned about safely.
InsertionPoint
AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
  mlir::Operation *oldAllocOp = oldAlloc;
  // This is only called as a last resort. We should try to insert at the
  // location of the old allocation, which is inside of a loop, using
  // llvm.stacksave/llvm.stackrestore

  // find freemem ops
  llvm::SmallVector<mlir::Operation *, 1> freeOps;
  for (mlir::Operation *user : oldAllocOp->getUsers())
    if (mlir::isa<fir::FreeMemOp>(user))
      freeOps.push_back(user);
  assert(freeOps.size() && "DFA should only return freed memory");

  // Don't attempt to reason about a stacksave/stackrestore between different
  // blocks
  for (mlir::Operation *free : freeOps)
    if (free->getBlock() != oldAllocOp->getBlock())
      return {nullptr};

  // Check that there aren't any other stack allocations in between the
  // stack save and stack restore
  // note: for flang generated temporaries there should only be one free op
  for (mlir::Operation *free : freeOps) {
    for (mlir::Operation *op = oldAlloc; op && op != free;
         op = op->getNextNode()) {
      if (mlir::isa<fir::AllocaOp>(op))
        return {nullptr};
    }
  }

  return InsertionPoint{oldAllocOp, /*saveRestoreStack=*/true};
}
661
/// Create the replacement fir.alloca for \p oldAlloc at its pre-computed
/// insertion point, copying the type, names, typeparams and shape from the
/// allocmem. Returns std::nullopt when oldAlloc is not a recorded candidate
/// or its insertion point is empty.
std::optional<fir::AllocaOp>
AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,
                                 mlir::PatternRewriter &rewriter) const {
  auto it = candidateOps.find(oldAlloc.getOperation());
  if (it == candidateOps.end())
    return {};
  InsertionPoint insertionPoint = it->second;
  if (!insertionPoint)
    return {};

  // bracket the allocation with stacksave/stackrestore when required
  if (insertionPoint.shouldSaveRestoreStack())
    insertStackSaveRestore(oldAlloc, rewriter);

  mlir::Location loc = oldAlloc.getLoc();
  mlir::Type varTy = oldAlloc.getInType();
  if (mlir::Operation *op = insertionPoint.tryGetOperation()) {
    // an operation insertion point means /after/ that operation
    rewriter.setInsertionPointAfter(op);
  } else {
    mlir::Block *block = insertionPoint.tryGetBlock();
    assert(block && "There must be a valid insertion point");
    rewriter.setInsertionPointToStart(block);
  }

  // allocmem name attributes are optional; alloca takes plain StringRefs
  auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
    if (opt)
      return *opt;
    return {};
  };

  llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
  llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
  return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
                                        oldAlloc.getTypeparams(),
                                        oldAlloc.getShape());
}
697
/// Emit a call to llvm.stacksave immediately before \p oldAlloc and a call to
/// llvm.stackrestore (with the saved stack pointer) at every fir.freemem of
/// the old allocation.
void AllocMemConversion::insertStackSaveRestore(
    fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
  auto oldPoint = rewriter.saveInsertionPoint();
  auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder builder{rewriter, mod};

  // obtain the llvm.stacksave wrapper and call it, capturing the stack pointer
  mlir::func::FuncOp stackSaveFn = fir::factory::getLlvmStackSave(builder);
  mlir::SymbolRefAttr stackSaveSym =
      builder.getSymbolRefAttr(stackSaveFn.getName());

  builder.setInsertionPoint(oldAlloc);
  mlir::Value sp =
      builder
          .create<fir::CallOp>(oldAlloc.getLoc(),
                               stackSaveFn.getFunctionType().getResults(),
                               stackSaveSym, mlir::ValueRange{})
          .getResult(0);

  // obtain the llvm.stackrestore wrapper for use at each freemem site
  mlir::func::FuncOp stackRestoreFn =
      fir::factory::getLlvmStackRestore(builder);
  mlir::SymbolRefAttr stackRestoreSym =
      builder.getSymbolRefAttr(stackRestoreFn.getName());

  for (mlir::Operation *user : oldAlloc->getUsers()) {
    if (mlir::isa<fir::FreeMemOp>(user)) {
      builder.setInsertionPoint(user);
      builder.create<fir::CallOp>(user->getLoc(),
                                  stackRestoreFn.getFunctionType().getResults(),
                                  stackRestoreSym, mlir::ValueRange{sp});
    }
  }

  rewriter.restoreInsertionPoint(oldPoint);
}
732
/// Copy construction delegates to the generated base pass class.
StackArraysPass::StackArraysPass(const StackArraysPass &pass)
    : fir::impl::StackArraysBase<StackArraysPass>(pass) {}
735
/// One-line human-readable description of this pass.
llvm::StringRef StackArraysPass::getDescription() const {
  return "Move heap allocated array temporaries to the stack";
}
739
740void StackArraysPass::runOnOperation() {
741 mlir::ModuleOp mod = getOperation();
742
743 mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
744}
745
/// Apply the rewrite to one function: query the analysis for candidate
/// fir.allocmem ops, then convert them with the greedy pattern driver.
void StackArraysPass::runOnFunc(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));

  auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
  const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
      analysis.getCandidateOps(func);
  // a null map means the dataflow solver failed for this function
  if (!candidateOps) {
    signalPassFailure();
    return;
  }

  if (candidateOps->empty())
    return;
  runCount += candidateOps->size();

  // snapshot the candidate ops: the pattern driver below mutates the IR
  llvm::SmallVector<mlir::Operation *> opsToConvert;
  opsToConvert.reserve(candidateOps->size());
  for (auto [op, _] : *candidateOps)
    opsToConvert.push_back(op);

  mlir::MLIRContext &context = getContext();
  mlir::RewritePatternSet patterns(&context);
  mlir::GreedyRewriteConfig config;
  // prevent the pattern driver from merging blocks
  config.enableRegionSimplification = false;

  patterns.insert<AllocMemConversion>(&context, *candidateOps);
  if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
                                                std::move(patterns), config))) {
    mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
    signalPassFailure();
  }
}
779
/// Factory for the stack-arrays pass.
std::unique_ptr<mlir::Pass> fir::createStackArraysPass() {
  return std::make_unique<StackArraysPass>();
}
783

// Source: flang/lib/Optimizer/Transforms/StackArrays.cpp