1 | //===- StackArrays.cpp ----------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
10 | #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" |
11 | #include "flang/Optimizer/Dialect/FIRAttr.h" |
12 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
13 | #include "flang/Optimizer/Dialect/FIROps.h" |
14 | #include "flang/Optimizer/Dialect/FIRType.h" |
15 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
16 | #include "flang/Optimizer/Transforms/Passes.h" |
17 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
18 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
19 | #include "mlir/Analysis/DataFlow/DenseAnalysis.h" |
20 | #include "mlir/Analysis/DataFlowFramework.h" |
21 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
22 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
23 | #include "mlir/IR/Builders.h" |
24 | #include "mlir/IR/Diagnostics.h" |
25 | #include "mlir/IR/Value.h" |
26 | #include "mlir/Interfaces/LoopLikeInterface.h" |
27 | #include "mlir/Pass/Pass.h" |
28 | #include "mlir/Support/LogicalResult.h" |
29 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
30 | #include "mlir/Transforms/Passes.h" |
31 | #include "llvm/ADT/DenseMap.h" |
32 | #include "llvm/ADT/DenseSet.h" |
33 | #include "llvm/ADT/PointerUnion.h" |
34 | #include "llvm/Support/Casting.h" |
35 | #include "llvm/Support/raw_ostream.h" |
36 | #include <optional> |
37 | |
38 | namespace fir { |
39 | #define GEN_PASS_DEF_STACKARRAYS |
40 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
41 | } // namespace fir |
42 | |
43 | #define DEBUG_TYPE "stack-arrays" |
44 | |
/// Command-line cap on per-function work: functions containing more than this
/// many fir.allocmem operations are skipped entirely (0 disables the limit).
static llvm::cl::opt<std::size_t> maxAllocsPerFunc(
    "stack-arrays-max-allocs",
    llvm::cl::desc("The maximum number of heap allocations to consider in one "
                   "function before skipping (to save compilation time). Set "
                   "to 0 for no limit."),
    llvm::cl::init(1000), llvm::cl::Hidden);
51 | |
52 | namespace { |
53 | |
/// The state of an SSA value at each program point.
/// Joining two differing states across CFG paths always collapses to Unknown
/// (see the static ::join helper below).
enum class AllocationState {
  /// This means that the allocation state of a variable cannot be determined
  /// at this program point, e.g. because one route through a conditional freed
  /// the variable and the other route didn't.
  /// This asserts a known-unknown: different from the unknown-unknown of having
  /// no AllocationState stored for a particular SSA value
  Unknown,
  /// Means this SSA value was allocated on the heap in this function and has
  /// now been freed
  Freed,
  /// Means this SSA value was allocated on the heap in this function and is a
  /// candidate for moving to the stack
  Allocated,
};
69 | |
/// Stores where an alloca should be inserted. If the PointerUnion is an
/// Operation the alloca should be inserted /after/ the operation. If it is a
/// block, the alloca can be placed anywhere in that block.
class InsertionPoint {
  llvm::PointerUnion<mlir::Operation *, mlir::Block *> location;
  // Whether llvm.stacksave/llvm.stackrestore must bracket the allocation's
  // lifetime (needed when the alloca has to remain inside a loop)
  bool saveRestoreStack;

  /// Get contained pointer type or nullptr
  template <class T>
  T *tryGetPtr() const {
    if (location.is<T *>())
      return location.get<T *>();
    return nullptr;
  }

public:
  template <class T>
  InsertionPoint(T *ptr, bool saveRestoreStack = false)
      : location(ptr), saveRestoreStack{saveRestoreStack} {}
  /// Construct an invalid insertion point (converts to false via operator bool)
  InsertionPoint(std::nullptr_t null)
      : location(null), saveRestoreStack{false} {}

  /// Get contained operation, or nullptr
  mlir::Operation *tryGetOperation() const {
    return tryGetPtr<mlir::Operation>();
  }

  /// Get contained block, or nullptr
  mlir::Block *tryGetBlock() const { return tryGetPtr<mlir::Block>(); }

  /// Get whether the stack should be saved/restored. If yes, an llvm.stacksave
  /// intrinsic should be added before the alloca, and an llvm.stackrestore
  /// intrinsic should be added where the freemem is
  bool shouldSaveRestoreStack() const { return saveRestoreStack; }

  /// True iff this insertion point holds a non-null operation or block
  operator bool() const { return tryGetOperation() || tryGetBlock(); }

  bool operator==(const InsertionPoint &rhs) const {
    return (location == rhs.location) &&
           (saveRestoreStack == rhs.saveRestoreStack);
  }

  bool operator!=(const InsertionPoint &rhs) const { return !(*this == rhs); }
};
114 | |
/// Maps SSA values to their AllocationState at a particular program point.
/// Also caches the insertion points for the new alloca operations
class LatticePoint : public mlir::dataflow::AbstractDenseLattice {
  // Maps all values we are interested in to states
  llvm::SmallDenseMap<mlir::Value, AllocationState, 1> stateMap;

public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LatticePoint)
  using AbstractDenseLattice::AbstractDenseLattice;

  bool operator==(const LatticePoint &rhs) const {
    return stateMap == rhs.stateMap;
  }

  /// Join the lattice across control-flow edges
  mlir::ChangeResult join(const AbstractDenseLattice &lattice) override;

  void print(llvm::raw_ostream &os) const override;

  /// Clear all modifications
  mlir::ChangeResult reset();

  /// Set the state of an SSA value
  mlir::ChangeResult set(mlir::Value value, AllocationState state);

  /// Get fir.allocmem ops which were allocated in this function and always
  /// freed before the function returns, plus where to insert replacement
  /// fir.alloca ops
  void appendFreedValues(llvm::DenseSet<mlir::Value> &out) const;

  /// Look up the tracked state for \p val, or std::nullopt if it is untracked
  std::optional<AllocationState> get(mlir::Value val) const;
};
147 | |
/// Dense forward dataflow analysis tracking the AllocationState of each
/// interesting SSA value at every program point.
class AllocationAnalysis
    : public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> {
public:
  using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;

  /// Transfer function applied to each (non-region-branch) operation
  void visitOperation(mlir::Operation *op, const LatticePoint &before,
                      LatticePoint *after) override;

  /// At an entry point, the last modifications of all memory resources are
  /// yet to be determined
  void setToEntryState(LatticePoint *lattice) override;

protected:
  /// Visit control flow operations and decide whether to call visitOperation
  /// to apply the transfer function
  void processOperation(mlir::Operation *op) override;
};
165 | |
/// Drives analysis to find candidate fir.allocmem operations which could be
/// moved to the stack. Intended to be used with mlir::Pass::getAnalysis
class StackArraysAnalysisWrapper {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StackArraysAnalysisWrapper)

  // Maps fir.allocmem -> place to insert alloca
  using AllocMemMap = llvm::DenseMap<mlir::Operation *, InsertionPoint>;

  // The operation argument is unused: analysis runs lazily, per function, via
  // getCandidateOps. The constructor signature is required by getAnalysis.
  StackArraysAnalysisWrapper(mlir::Operation *op) {}

  // returns nullptr if analysis failed
  const AllocMemMap *getCandidateOps(mlir::Operation *func);

private:
  // Per-function cache of analysis results, keyed by the func.func operation
  llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;

  mlir::LogicalResult analyseFunction(mlir::Operation *func);
};
185 | |
/// Converts a fir.allocmem to a fir.alloca
class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
public:
  explicit AllocMemConversion(
      mlir::MLIRContext *ctx,
      const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
      : OpRewritePattern(ctx), candidateOps{candidateOps} {}

  mlir::LogicalResult
  matchAndRewrite(fir::AllocMemOp allocmem,
                  mlir::PatternRewriter &rewriter) const override;

  /// Determine where to insert the alloca operation. The returned value should
  /// be checked to see if it is inside a loop
  static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);

private:
  /// Handle to the DFA (already run)
  const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;

  /// If we failed to find an insertion point not inside a loop, see if it would
  /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
  static InsertionPoint findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc);

  /// Returns the alloca if it was successfully inserted, otherwise {}
  std::optional<fir::AllocaOp>
  insertAlloca(fir::AllocMemOp &oldAlloc,
               mlir::PatternRewriter &rewriter) const;

  /// Inserts a stacksave before oldAlloc and a stackrestore after each freemem
  void insertStackSaveRestore(fir::AllocMemOp &oldAlloc,
                              mlir::PatternRewriter &rewriter) const;
};
219 | |
/// The pass itself: walks every func.func in the module and rewrites
/// heap-allocated array temporaries into stack allocations where safe.
class StackArraysPass : public fir::impl::StackArraysBase<StackArraysPass> {
public:
  StackArraysPass() = default;
  StackArraysPass(const StackArraysPass &pass);

  llvm::StringRef getDescription() const override;

  void runOnOperation() override;
  /// Run the analysis and rewrite on a single func.func operation
  void runOnFunc(mlir::Operation *func);

private:
  // Counts how many allocations were moved, across all functions
  Statistic runCount{this, "stackArraysRunCount",
                     "Number of heap allocations moved to the stack"};
};
234 | |
235 | } // namespace |
236 | |
237 | static void print(llvm::raw_ostream &os, AllocationState state) { |
238 | switch (state) { |
239 | case AllocationState::Unknown: |
240 | os << "Unknown" ; |
241 | break; |
242 | case AllocationState::Freed: |
243 | os << "Freed" ; |
244 | break; |
245 | case AllocationState::Allocated: |
246 | os << "Allocated" ; |
247 | break; |
248 | } |
249 | } |
250 | |
251 | /// Join two AllocationStates for the same value coming from different CFG |
252 | /// blocks |
253 | static AllocationState join(AllocationState lhs, AllocationState rhs) { |
254 | // | Allocated | Freed | Unknown |
255 | // ========= | ========= | ========= | ========= |
256 | // Allocated | Allocated | Unknown | Unknown |
257 | // Freed | Unknown | Freed | Unknown |
258 | // Unknown | Unknown | Unknown | Unknown |
259 | if (lhs == rhs) |
260 | return lhs; |
261 | return AllocationState::Unknown; |
262 | } |
263 | |
264 | mlir::ChangeResult LatticePoint::join(const AbstractDenseLattice &lattice) { |
265 | const auto &rhs = static_cast<const LatticePoint &>(lattice); |
266 | mlir::ChangeResult changed = mlir::ChangeResult::NoChange; |
267 | |
268 | // add everything from rhs to map, handling cases where values are in both |
269 | for (const auto &[value, rhsState] : rhs.stateMap) { |
270 | auto it = stateMap.find(value); |
271 | if (it != stateMap.end()) { |
272 | // value is present in both maps |
273 | AllocationState myState = it->second; |
274 | AllocationState newState = ::join(myState, rhsState); |
275 | if (newState != myState) { |
276 | changed = mlir::ChangeResult::Change; |
277 | it->getSecond() = newState; |
278 | } |
279 | } else { |
280 | // value not present in current map: add it |
281 | stateMap.insert({value, rhsState}); |
282 | changed = mlir::ChangeResult::Change; |
283 | } |
284 | } |
285 | |
286 | return changed; |
287 | } |
288 | |
289 | void LatticePoint::print(llvm::raw_ostream &os) const { |
290 | for (const auto &[value, state] : stateMap) { |
291 | os << value << ": " ; |
292 | ::print(os, state); |
293 | } |
294 | } |
295 | |
296 | mlir::ChangeResult LatticePoint::reset() { |
297 | if (stateMap.empty()) |
298 | return mlir::ChangeResult::NoChange; |
299 | stateMap.clear(); |
300 | return mlir::ChangeResult::Change; |
301 | } |
302 | |
303 | mlir::ChangeResult LatticePoint::set(mlir::Value value, AllocationState state) { |
304 | if (stateMap.count(value)) { |
305 | // already in map |
306 | AllocationState &oldState = stateMap[value]; |
307 | if (oldState != state) { |
308 | stateMap[value] = state; |
309 | return mlir::ChangeResult::Change; |
310 | } |
311 | return mlir::ChangeResult::NoChange; |
312 | } |
313 | stateMap.insert({value, state}); |
314 | return mlir::ChangeResult::Change; |
315 | } |
316 | |
317 | /// Get values which were allocated in this function and always freed before |
318 | /// the function returns |
319 | void LatticePoint::appendFreedValues(llvm::DenseSet<mlir::Value> &out) const { |
320 | for (auto &[value, state] : stateMap) { |
321 | if (state == AllocationState::Freed) |
322 | out.insert(value); |
323 | } |
324 | } |
325 | |
326 | std::optional<AllocationState> LatticePoint::get(mlir::Value val) const { |
327 | auto it = stateMap.find(val); |
328 | if (it == stateMap.end()) |
329 | return {}; |
330 | return it->second; |
331 | } |
332 | |
/// Transfer function: compute the lattice \p after an operation from the
/// lattice \p before it.
/// - array fir.allocmem results (unless marked fir.must_be_heap) become
///   Allocated
/// - fir.freemem marks its operand Freed, but only if this function saw the
///   matching allocation
/// - fir.result joins the current lattice into the parent operation's lattice
///   so states flow out of region-carrying ops
void AllocationAnalysis::visitOperation(mlir::Operation *op,
                                        const LatticePoint &before,
                                        LatticePoint *after) {
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
                          << "\n");
  LLVM_DEBUG(llvm::dbgs() << "--Lattice in: " << before << "\n");

  // propagate before -> after
  mlir::ChangeResult changed = after->join(before);

  if (auto allocmem = mlir::dyn_cast<fir::AllocMemOp>(op)) {
    assert(op->getNumResults() == 1 && "fir.allocmem has one result");
    auto attr = op->getAttrOfType<fir::MustBeHeapAttr>(
        fir::MustBeHeapAttr::getAttrName());
    if (attr && attr.getValue()) {
      LLVM_DEBUG(llvm::dbgs() << "--Found fir.must_be_heap: skipping\n");
      // skip allocation marked not to be moved
      // NOTE(review): this early return (and the one below) skips the
      // propagateIfChanged(after, changed) at the end of this function even
      // though the join above may have changed `after` - confirm intended
      return;
    }

    auto retTy = allocmem.getAllocatedType();
    if (!retTy.isa<fir::SequenceType>()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "--Allocation is not for an array: skipping\n");
      return;
    }

    mlir::Value result = op->getResult(0);
    changed |= after->set(result, AllocationState::Allocated);
  } else if (mlir::isa<fir::FreeMemOp>(op)) {
    assert(op->getNumOperands() == 1 && "fir.freemem has one operand");
    mlir::Value operand = op->getOperand(0);
    std::optional<AllocationState> operandState = before.get(operand);
    if (operandState && *operandState == AllocationState::Allocated) {
      // don't tag things not allocated in this function as freed, so that we
      // don't think they are candidates for moving to the stack
      changed |= after->set(operand, AllocationState::Freed);
    }
  } else if (mlir::isa<fir::ResultOp>(op)) {
    // forward the state at the end of a region to the op owning that region
    mlir::Operation *parent = op->getParentOp();
    LatticePoint *parentLattice = getLattice(parent);
    assert(parentLattice);
    mlir::ChangeResult parentChanged = parentLattice->join(*after);
    propagateIfChanged(parentLattice, parentChanged);
  }

  // we pass lattices straight through fir.call because called functions should
  // not deallocate flang-generated array temporaries

  LLVM_DEBUG(llvm::dbgs() << "--Lattice out: " << *after << "\n");
  propagateIfChanged(after, changed);
}
385 | |
386 | void AllocationAnalysis::setToEntryState(LatticePoint *lattice) { |
387 | propagateIfChanged(lattice, lattice->reset()); |
388 | } |
389 | |
/// Mostly a copy of AbstractDenseLattice::processOperation - the difference
/// being that call operations are passed through to the transfer function
/// (the base implementation handles calls specially; here fir.call must reach
/// visitOperation so lattices flow straight through it).
void AllocationAnalysis::processOperation(mlir::Operation *op) {
  // If the containing block is not executable, bail out.
  if (!getOrCreateFor<mlir::dataflow::Executable>(op, op->getBlock())->isLive())
    return;

  // Get the dense lattice to update
  mlir::dataflow::AbstractDenseLattice *after = getLattice(op);

  // If this op implements region control-flow, then control-flow dictates its
  // transfer function.
  if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op))
    return visitRegionBranchOperation(op, branch, after);

  // pass call operations through to the transfer function

  // Get the dense state before the execution of the op: either the state after
  // the previous op in the block, or the block's entry state for the first op.
  const mlir::dataflow::AbstractDenseLattice *before;
  if (mlir::Operation *prev = op->getPrevNode())
    before = getLatticeFor(op, prev);
  else
    before = getLatticeFor(op, op->getBlock());

  /// Invoke the operation transfer function
  visitOperationImpl(op, *before, after);
}
417 | |
/// Run the allocation dataflow analysis on \p func and populate
/// funcMaps[func] with the fir.allocmem ops that can be moved to the stack,
/// each mapped to the place where the replacement fir.alloca should go.
/// Returns failure only if the dataflow solver itself failed.
mlir::LogicalResult
StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));
  size_t nAllocs = 0;
  func->walk([&nAllocs](fir::AllocMemOp) { nAllocs++; });
  // don't bother with the analysis if there are no heap allocations
  if (nAllocs == 0)
    return mlir::success();
  if ((maxAllocsPerFunc != 0) && (nAllocs > maxAllocsPerFunc)) {
    LLVM_DEBUG(llvm::dbgs() << "Skipping stack arrays for function with "
                            << nAllocs << " heap allocations");
    return mlir::success();
  }

  mlir::DataFlowSolver solver;
  // constant propagation is required for dead code analysis, dead code analysis
  // is required to mark blocks live (required for mlir dense dfa)
  solver.load<mlir::dataflow::SparseConstantPropagation>();
  solver.load<mlir::dataflow::DeadCodeAnalysis>();

  auto [it, inserted] = funcMaps.try_emplace(func);
  AllocMemMap &candidateOps = it->second;

  solver.load<AllocationAnalysis>();
  if (failed(solver.initializeAndRun(func))) {
    // NOTE(review): this diagnostic has no trailing newline
    llvm::errs() << "DataFlowSolver failed!";
    return mlir::failure();
  }

  // Join the lattices at every function exit (return and unreachable) to get
  // the set of values freed on all paths through the function
  LatticePoint point{func};
  auto joinOperationLattice = [&](mlir::Operation *op) {
    const LatticePoint *lattice = solver.lookupState<LatticePoint>(op);
    // there will be no lattice for an unreachable block
    if (lattice)
      (void)point.join(*lattice);
  };
  func->walk([&](mlir::func::ReturnOp child) { joinOperationLattice(child); });
  func->walk([&](fir::UnreachableOp child) { joinOperationLattice(child); });
  llvm::DenseSet<mlir::Value> freedValues;
  point.appendFreedValues(freedValues);

  // We only replace allocations which are definitely freed on all routes
  // through the function because otherwise the allocation may have an intended
  // lifetime longer than the current stack frame (e.g. a heap allocation which
  // is then freed by another function).
  for (mlir::Value freedValue : freedValues) {
    fir::AllocMemOp allocmem = freedValue.getDefiningOp<fir::AllocMemOp>();
    InsertionPoint insertionPoint =
        AllocMemConversion::findAllocaInsertionPoint(allocmem);
    if (insertionPoint)
      candidateOps.insert({allocmem, insertionPoint});
  }

  LLVM_DEBUG(for (auto [allocMemOp, _]
                  : candidateOps) {
    llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
  });
  return mlir::success();
}
477 | |
478 | const StackArraysAnalysisWrapper::AllocMemMap * |
479 | StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) { |
480 | if (!funcMaps.contains(func)) |
481 | if (mlir::failed(analyseFunction(func))) |
482 | return nullptr; |
483 | return &funcMaps[func]; |
484 | } |
485 | |
/// Restore the old allocation type expected by existing code: if the new stack
/// allocation's type differs from the old heap allocation's, insert a
/// fir.convert back to the heap type (fir.heap<T> vs fir.ref<T>) right after
/// the stack allocation so existing users see the type they expect.
static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
                                         const mlir::Location &loc,
                                         mlir::Value heap, mlir::Value stack) {
  mlir::Type heapTy = heap.getType();
  mlir::Type stackTy = stack.getType();

  if (heapTy == stackTy)
    return stack;

  fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy);
  LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy =
      mlir::cast<fir::ReferenceType>(stackTy);
  assert(firHeapTy.getElementType() == firRefTy.getElementType() &&
         "Allocations must have the same type");

  // insert the convert immediately after the stack allocation, then restore
  // the rewriter's insertion point for the caller
  auto insertionPoint = rewriter.saveInsertionPoint();
  rewriter.setInsertionPointAfter(stack.getDefiningOp());
  mlir::Value conv =
      rewriter.create<fir::ConvertOp>(loc, firHeapTy, stack).getResult();
  rewriter.restoreInsertionPoint(insertionPoint);
  return conv;
}
509 | |
/// Replace one candidate fir.allocmem with a fir.alloca: insert the alloca at
/// the precomputed insertion point, erase the matching fir.freemem ops, and
/// redirect all uses of the heap allocation to the stack allocation.
mlir::LogicalResult
AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
                                    mlir::PatternRewriter &rewriter) const {
  auto oldInsertionPt = rewriter.saveInsertionPoint();
  // add alloca operation
  std::optional<fir::AllocaOp> alloca = insertAlloca(allocmem, rewriter);
  rewriter.restoreInsertionPoint(oldInsertionPt);
  if (!alloca)
    return mlir::failure();

  // remove freemem operations
  llvm::SmallVector<mlir::Operation *> erases;
  for (mlir::Operation *user : allocmem.getOperation()->getUsers())
    if (mlir::isa<fir::FreeMemOp>(user))
      erases.push_back(user);
  // now we are done iterating the users, it is safe to mutate them
  for (mlir::Operation *erase : erases)
    rewriter.eraseOp(erase);

  // replace references to heap allocation with references to stack allocation
  mlir::Value newValue = convertAllocationType(
      rewriter, allocmem.getLoc(), allocmem.getResult(), alloca->getResult());
  rewriter.replaceAllUsesWith(allocmem.getResult(), newValue);

  // remove allocmem operation
  rewriter.eraseOp(allocmem.getOperation());

  return mlir::success();
}
539 | |
/// Is \p block (transitively) contained in a CFG or region loop?
static bool isInLoop(mlir::Block *block) {
  return mlir::LoopLikeOpInterface::blockIsInLoop(block);
}
543 | |
/// Is \p op inside a loop, either via its block's CFG position or by being
/// nested inside a LoopLikeOpInterface op (e.g. fir.do_loop)?
static bool isInLoop(mlir::Operation *op) {
  return isInLoop(op->getBlock()) ||
         op->getParentOfType<mlir::LoopLikeOpInterface>();
}
548 | |
InsertionPoint
AllocMemConversion::findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc) {
  // Ideally the alloca should be inserted at the end of the function entry
  // block so that we do not allocate stack space in a loop. However,
  // the operands to the alloca may not be available that early, so insert it
  // after the last operand becomes available
  // If the old allocmem op was in an openmp region then it should not be moved
  // outside of that
  LLVM_DEBUG(llvm::dbgs() << "StackArrays: findAllocaInsertionPoint: "
                          << oldAlloc << "\n");

  // check that an Operation or Block we are about to return is not in a loop
  auto checkReturn = [&](auto *point) -> InsertionPoint {
    if (isInLoop(point)) {
      mlir::Operation *oldAllocOp = oldAlloc.getOperation();
      if (isInLoop(oldAllocOp)) {
        // where we want to put it is in a loop, and even the old location is in
        // a loop. Give up.
        return findAllocaLoopInsertionPoint(oldAlloc);
      }
      // fall back to the old allocmem's own location (known not in a loop)
      return {oldAllocOp};
    }
    return {point};
  };

  auto oldOmpRegion =
      oldAlloc->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();

  // Find when the last operand value becomes available
  mlir::Block *operandsBlock = nullptr;
  mlir::Operation *lastOperand = nullptr;
  for (mlir::Value operand : oldAlloc.getOperands()) {
    LLVM_DEBUG(llvm::dbgs() << "--considering operand " << operand << "\n");
    mlir::Operation *op = operand.getDefiningOp();
    if (!op)
      // operand is a block argument: we cannot order it, so stay put
      return checkReturn(oldAlloc.getOperation());
    if (!operandsBlock)
      operandsBlock = op->getBlock();
    else if (operandsBlock != op->getBlock()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "----operand declared in a different block!\n");
      // Operation::isBeforeInBlock requires the operations to be in the same
      // block. The best we can do is the location of the allocmem.
      return checkReturn(oldAlloc.getOperation());
    }
    if (!lastOperand || lastOperand->isBeforeInBlock(op))
      lastOperand = op;
  }

  if (lastOperand) {
    // there were value operands to the allocmem so insert after the last one
    LLVM_DEBUG(llvm::dbgs()
               << "--Placing after last operand: " << *lastOperand << "\n");
    // check we aren't moving out of an omp region
    auto lastOpOmpRegion =
        lastOperand->getParentOfType<mlir::omp::OutlineableOpenMPOpInterface>();
    if (lastOpOmpRegion == oldOmpRegion)
      return checkReturn(lastOperand);
    // Presumably this happened because the operands became ready before the
    // start of this openmp region. (lastOpOmpRegion != oldOmpRegion) should
    // imply that oldOmpRegion comes after lastOpOmpRegion.
    // NOTE(review): relies on oldOmpRegion being non-null here (SSA dominance
    // should prevent operands defined inside an omp region being used by an
    // allocmem outside it) - confirm
    return checkReturn(oldOmpRegion.getAllocaBlock());
  }

  // There were no value operands to the allocmem so we are safe to insert it
  // as early as we want

  // handle openmp case
  if (oldOmpRegion)
    return checkReturn(oldOmpRegion.getAllocaBlock());

  // fall back to the function entry block
  mlir::func::FuncOp func = oldAlloc->getParentOfType<mlir::func::FuncOp>();
  assert(func && "This analysis is run on func.func");
  mlir::Block &entryBlock = func.getBlocks().front();
  LLVM_DEBUG(llvm::dbgs() << "--Placing at the start of func entry block\n");
  return checkReturn(&entryBlock);
}
627 | |
/// Last resort when both the desired insertion point and the old allocation
/// are inside a loop: keep the alloca at the old location but bracket it with
/// llvm.stacksave/llvm.stackrestore so stack space does not grow per
/// iteration. Returns an invalid InsertionPoint if this cannot be proven safe.
InsertionPoint
AllocMemConversion::findAllocaLoopInsertionPoint(fir::AllocMemOp &oldAlloc) {
  mlir::Operation *oldAllocOp = oldAlloc;
  // This is only called as a last resort. We should try to insert at the
  // location of the old allocation, which is inside of a loop, using
  // llvm.stacksave/llvm.stackrestore

  // find freemem ops
  llvm::SmallVector<mlir::Operation *, 1> freeOps;
  for (mlir::Operation *user : oldAllocOp->getUsers())
    if (mlir::isa<fir::FreeMemOp>(user))
      freeOps.push_back(user);
  assert(freeOps.size() && "DFA should only return freed memory");

  // Don't attempt to reason about a stacksave/stackrestore between different
  // blocks
  for (mlir::Operation *free : freeOps)
    if (free->getBlock() != oldAllocOp->getBlock())
      return {nullptr};

  // Check that there aren't any other stack allocations in between the
  // stack save and stack restore, as the restore would deallocate them too
  // note: for flang generated temporaries there should only be one free op
  for (mlir::Operation *free : freeOps) {
    for (mlir::Operation *op = oldAlloc; op && op != free;
         op = op->getNextNode()) {
      if (mlir::isa<fir::AllocaOp>(op))
        return {nullptr};
    }
  }

  return InsertionPoint{oldAllocOp, /*shouldStackSaveRestore=*/true};
}
661 | |
/// Create the replacement fir.alloca for \p oldAlloc at the insertion point
/// recorded by the analysis, adding stacksave/stackrestore calls first when
/// the insertion point requires them. Returns {} if \p oldAlloc is not a
/// candidate. Leaves the rewriter's insertion point at the new alloca.
std::optional<fir::AllocaOp>
AllocMemConversion::insertAlloca(fir::AllocMemOp &oldAlloc,
                                 mlir::PatternRewriter &rewriter) const {
  auto it = candidateOps.find(oldAlloc.getOperation());
  if (it == candidateOps.end())
    return {};
  InsertionPoint insertionPoint = it->second;
  if (!insertionPoint)
    return {};

  if (insertionPoint.shouldSaveRestoreStack())
    insertStackSaveRestore(oldAlloc, rewriter);

  mlir::Location loc = oldAlloc.getLoc();
  mlir::Type varTy = oldAlloc.getInType();
  // an Operation insertion point means "after that op"; a Block means
  // "anywhere in the block", so the start is used
  if (mlir::Operation *op = insertionPoint.tryGetOperation()) {
    rewriter.setInsertionPointAfter(op);
  } else {
    mlir::Block *block = insertionPoint.tryGetBlock();
    assert(block && "There must be a valid insertion point");
    rewriter.setInsertionPointToStart(block);
  }

  // fir.alloca takes plain StringRefs where fir.allocmem has optional names
  auto unpackName = [](std::optional<llvm::StringRef> opt) -> llvm::StringRef {
    if (opt)
      return *opt;
    return {};
  };

  llvm::StringRef uniqName = unpackName(oldAlloc.getUniqName());
  llvm::StringRef bindcName = unpackName(oldAlloc.getBindcName());
  return rewriter.create<fir::AllocaOp>(loc, varTy, uniqName, bindcName,
                                        oldAlloc.getTypeparams(),
                                        oldAlloc.getShape());
}
697 | |
/// Emit a call to llvm.stacksave immediately before \p oldAlloc and a call to
/// llvm.stackrestore (with the saved stack pointer) immediately before each
/// fir.freemem that frees it, so the loop-local stack allocation is released
/// every iteration.
void AllocMemConversion::insertStackSaveRestore(
    fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const {
  auto oldPoint = rewriter.saveInsertionPoint();
  auto mod = oldAlloc->getParentOfType<mlir::ModuleOp>();
  // FirOpBuilder is used to declare/look up the llvm intrinsic wrappers
  fir::FirOpBuilder builder{rewriter, mod};

  mlir::func::FuncOp stackSaveFn = fir::factory::getLlvmStackSave(builder);
  mlir::SymbolRefAttr stackSaveSym =
      builder.getSymbolRefAttr(stackSaveFn.getName());

  builder.setInsertionPoint(oldAlloc);
  mlir::Value sp =
      builder
          .create<fir::CallOp>(oldAlloc.getLoc(),
                               stackSaveFn.getFunctionType().getResults(),
                               stackSaveSym, mlir::ValueRange{})
          .getResult(0);

  mlir::func::FuncOp stackRestoreFn =
      fir::factory::getLlvmStackRestore(builder);
  mlir::SymbolRefAttr stackRestoreSym =
      builder.getSymbolRefAttr(stackRestoreFn.getName());

  for (mlir::Operation *user : oldAlloc->getUsers()) {
    if (mlir::isa<fir::FreeMemOp>(user)) {
      builder.setInsertionPoint(user);
      builder.create<fir::CallOp>(user->getLoc(),
                                  stackRestoreFn.getFunctionType().getResults(),
                                  stackRestoreSym, mlir::ValueRange{sp});
    }
  }

  rewriter.restoreInsertionPoint(oldPoint);
}
732 | |
// Copy constructor required by the MLIR pass infrastructure for pass cloning
StackArraysPass::StackArraysPass(const StackArraysPass &pass)
    : fir::impl::StackArraysBase<StackArraysPass>(pass) {}
735 | |
/// One-line description shown by pass tooling (e.g. --help output)
llvm::StringRef StackArraysPass::getDescription() const {
  return "Move heap allocated array temporaries to the stack";
}
739 | |
/// Module entry point: apply the rewrite function-by-function
void StackArraysPass::runOnOperation() {
  mlir::ModuleOp mod = getOperation();

  mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
}
745 | |
/// Run the analysis on one function and rewrite every candidate fir.allocmem
/// using the greedy pattern driver restricted to just those ops.
void StackArraysPass::runOnFunc(mlir::Operation *func) {
  assert(mlir::isa<mlir::func::FuncOp>(func));

  auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
  const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
      analysis.getCandidateOps(func);
  if (!candidateOps) {
    // analysis failed (dataflow solver error)
    signalPassFailure();
    return;
  }

  if (candidateOps->empty())
    return;
  runCount += candidateOps->size();

  // snapshot the candidate ops so the driver operates on a stable list
  llvm::SmallVector<mlir::Operation *> opsToConvert;
  opsToConvert.reserve(candidateOps->size());
  for (auto [op, _] : *candidateOps)
    opsToConvert.push_back(op);

  mlir::MLIRContext &context = getContext();
  mlir::RewritePatternSet patterns(&context);
  mlir::GreedyRewriteConfig config;
  // prevent the pattern driver from merging blocks
  config.enableRegionSimplification = false;

  patterns.insert<AllocMemConversion>(&context, *candidateOps);
  if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
                                                std::move(patterns), config))) {
    mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
    signalPassFailure();
  }
}
779 | |
/// Public factory for the pass, declared in flang's Passes.h
std::unique_ptr<mlir::Pass> fir::createStackArraysPass() {
  return std::make_unique<StackArraysPass>();
}
783 | |