1//===- BufferOptimizations.cpp - pre-pass optimizations for bufferization -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logic for three optimization passes. The first two
10// passes try to move alloc nodes out of blocks to reduce the number of
11// allocations and copies during buffer deallocation. The third pass tries to
12// convert heap-based allocations to stack-based allocations, if possible.
13
14#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
15
16#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
17#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
18#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/IR/Operation.h"
22#include "mlir/Interfaces/LoopLikeInterface.h"
23#include "mlir/Pass/Pass.h"
24
25namespace mlir {
26namespace bufferization {
27#define GEN_PASS_DEF_BUFFERHOISTING
28#define GEN_PASS_DEF_BUFFERLOOPHOISTING
29#define GEN_PASS_DEF_PROMOTEBUFFERSTOSTACK
30#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
31} // namespace bufferization
32} // namespace mlir
33
34using namespace mlir;
35using namespace mlir::bufferization;
36
37/// Returns true if the given operation implements a known high-level region-
38/// based control-flow interface.
39static bool isKnownControlFlowInterface(Operation *op) {
40 return isa<LoopLikeOpInterface, RegionBranchOpInterface>(op);
41}
42
43/// Returns true if the given operation represents a loop by testing whether it
44/// implements the `LoopLikeOpInterface` or the `RegionBranchOpInterface`. In
45/// the case of a `RegionBranchOpInterface`, it checks all region-based control-
46/// flow edges for cycles.
47static bool isLoop(Operation *op) {
48 // If the operation implements the `LoopLikeOpInterface` it can be considered
49 // a loop.
50 if (isa<LoopLikeOpInterface>(op))
51 return true;
52
53 // If the operation does not implement the `RegionBranchOpInterface`, it is
54 // (currently) not possible to detect a loop.
55 auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
56 if (!regionInterface)
57 return false;
58
59 return regionInterface.hasLoop();
60}
61
62/// Returns true if the given operation implements the AllocationOpInterface
63/// and it supports the dominate block hoisting.
64static bool allowAllocDominateBlockHoisting(Operation *op) {
65 auto allocOp = dyn_cast<AllocationOpInterface>(op);
66 return allocOp &&
67 static_cast<uint8_t>(allocOp.getHoistingKind() & HoistingKind::Block);
68}
69
70/// Returns true if the given operation implements the AllocationOpInterface
71/// and it supports the loop hoisting.
72static bool allowAllocLoopHoisting(Operation *op) {
73 auto allocOp = dyn_cast<AllocationOpInterface>(op);
74 return allocOp &&
75 static_cast<uint8_t>(allocOp.getHoistingKind() & HoistingKind::Loop);
76}
77
78/// Check if the size of the allocation is less than the given size. The
79/// transformation is only applied to small buffers since large buffers could
80/// exceed the stack space.
81static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes,
82 unsigned maxRankOfAllocatedMemRef) {
83 auto type = dyn_cast<ShapedType>(alloc.getType());
84 if (!type || !alloc.getDefiningOp<memref::AllocOp>())
85 return false;
86 if (!type.hasStaticShape()) {
87 // Check if the dynamic shape dimension of the alloc is produced by
88 // `memref.rank`. If this is the case, it is likely to be small.
89 // Furthermore, the dimension is limited to the maximum rank of the
90 // allocated memref to avoid large values by multiplying several small
91 // values.
92 if (type.getRank() <= maxRankOfAllocatedMemRef) {
93 return llvm::all_of(Range: alloc.getDefiningOp()->getOperands(),
94 P: [&](Value operand) {
95 return operand.getDefiningOp<memref::RankOp>();
96 });
97 }
98 return false;
99 }
100 unsigned bitwidth = mlir::DataLayout::closest(op: alloc.getDefiningOp())
101 .getTypeSizeInBits(t: type.getElementType());
102 return type.getNumElements() * bitwidth <= maximumSizeInBytes * 8;
103}
104
105/// Checks whether the given aliases leave the allocation scope.
106static bool
107leavesAllocationScope(Region *parentRegion,
108 const BufferViewFlowAnalysis::ValueSetT &aliases) {
109 for (Value alias : aliases) {
110 for (auto *use : alias.getUsers()) {
111 // If there is at least one alias that leaves the parent region, we know
112 // that this alias escapes the whole region and hence the associated
113 // allocation leaves allocation scope.
114 if (isa<RegionBranchTerminatorOpInterface>(use) &&
115 use->getParentRegion() == parentRegion)
116 return true;
117 }
118 }
119 return false;
120}
121
122/// Checks, if an automated allocation scope for a given alloc value exists.
123static bool hasAllocationScope(Value alloc,
124 const BufferViewFlowAnalysis &aliasAnalysis) {
125 Region *region = alloc.getParentRegion();
126 do {
127 if (Operation *parentOp = region->getParentOp()) {
128 // Check if the operation is an automatic allocation scope and whether an
129 // alias leaves the scope. This means, an allocation yields out of
130 // this scope and can not be transformed in a stack-based allocation.
131 if (parentOp->hasTrait<OpTrait::AutomaticAllocationScope>() &&
132 !leavesAllocationScope(parentRegion: region, aliases: aliasAnalysis.resolve(value: alloc)))
133 return true;
134 // Check if the operation is a known control flow interface and break the
135 // loop to avoid transformation in loops. Furthermore skip transformation
136 // if the operation does not implement a RegionBeanchOpInterface.
137 if (isLoop(op: parentOp) || !isKnownControlFlowInterface(op: parentOp))
138 break;
139 }
140 } while ((region = region->getParentRegion()));
141 return false;
142}
143
144namespace {
145
146//===----------------------------------------------------------------------===//
147// BufferAllocationHoisting
148//===----------------------------------------------------------------------===//
149
150/// A base implementation compatible with the `BufferAllocationHoisting` class.
151struct BufferAllocationHoistingStateBase {
152 /// A pointer to the current dominance info.
153 DominanceInfo *dominators;
154
155 /// The current allocation value.
156 Value allocValue;
157
158 /// The current placement block (if any).
159 Block *placementBlock;
160
161 /// Initializes the state base.
162 BufferAllocationHoistingStateBase(DominanceInfo *dominators, Value allocValue,
163 Block *placementBlock)
164 : dominators(dominators), allocValue(allocValue),
165 placementBlock(placementBlock) {}
166};
167
168/// Implements the actual hoisting logic for allocation nodes.
169template <typename StateT>
170class BufferAllocationHoisting : public BufferPlacementTransformationBase {
171public:
172 BufferAllocationHoisting(Operation *op)
173 : BufferPlacementTransformationBase(op), dominators(op),
174 postDominators(op), scopeOp(op) {}
175
176 /// Moves allocations upwards.
177 void hoist() {
178 SmallVector<Value> allocsAndAllocas;
179 for (BufferPlacementAllocs::AllocEntry &entry : allocs)
180 allocsAndAllocas.push_back(Elt: std::get<0>(t&: entry));
181 scopeOp->walk([&](memref::AllocaOp op) {
182 allocsAndAllocas.push_back(Elt: op.getMemref());
183 });
184
185 for (auto allocValue : allocsAndAllocas) {
186 if (!StateT::shouldHoistOpType(allocValue.getDefiningOp()))
187 continue;
188 Operation *definingOp = allocValue.getDefiningOp();
189 assert(definingOp && "No defining op");
190 auto operands = definingOp->getOperands();
191 auto resultAliases = aliases.resolve(value: allocValue);
192 // Determine the common dominator block of all aliases.
193 Block *dominatorBlock =
194 findCommonDominator(allocValue, resultAliases, dominators);
195 // Init the initial hoisting state.
196 StateT state(&dominators, allocValue, allocValue.getParentBlock());
197 // Check for additional allocation dependencies to compute an upper bound
198 // for hoisting.
199 Block *dependencyBlock = nullptr;
200 // If this node has dependencies, check all dependent nodes. This ensures
201 // that all dependency values have been computed before allocating the
202 // buffer.
203 for (Value depValue : operands) {
204 Block *depBlock = depValue.getParentBlock();
205 if (!dependencyBlock || dominators.dominates(a: dependencyBlock, b: depBlock))
206 dependencyBlock = depBlock;
207 }
208
209 // Find the actual placement block and determine the start operation using
210 // an upper placement-block boundary. The idea is that placement block
211 // cannot be moved any further upwards than the given upper bound.
212 Block *placementBlock = findPlacementBlock(
213 state, upperBound: state.computeUpperBound(dominatorBlock, dependencyBlock));
214 Operation *startOperation = BufferPlacementAllocs::getStartOperation(
215 allocValue, placementBlock, liveness);
216
217 // Move the alloc in front of the start operation.
218 Operation *allocOperation = allocValue.getDefiningOp();
219 allocOperation->moveBefore(existingOp: startOperation);
220 }
221 }
222
223private:
224 /// Finds a valid placement block by walking upwards in the CFG until we
225 /// either cannot continue our walk due to constraints (given by the StateT
226 /// implementation) or we have reached the upper-most dominator block.
227 Block *findPlacementBlock(StateT &state, Block *upperBound) {
228 Block *currentBlock = state.placementBlock;
229 // Walk from the innermost regions/loops to the outermost regions/loops and
230 // find an appropriate placement block that satisfies the constraint of the
231 // current StateT implementation. Walk until we reach the upperBound block
232 // (if any).
233
234 // If we are not able to find a valid parent operation or an associated
235 // parent block, break the walk loop.
236 Operation *parentOp;
237 Block *parentBlock;
238 while ((parentOp = currentBlock->getParentOp()) &&
239 (parentBlock = parentOp->getBlock()) &&
240 (!upperBound ||
241 dominators.properlyDominates(a: upperBound, b: currentBlock))) {
242 // Try to find an immediate dominator and check whether the parent block
243 // is above the immediate dominator (if any).
244 DominanceInfoNode *idom = nullptr;
245
246 // DominanceInfo doesn't support getNode queries for single-block regions.
247 if (!currentBlock->isEntryBlock())
248 idom = dominators.getNode(a: currentBlock)->getIDom();
249
250 if (idom && dominators.properlyDominates(a: parentBlock, b: idom->getBlock())) {
251 // If the current immediate dominator is below the placement block, move
252 // to the immediate dominator block.
253 currentBlock = idom->getBlock();
254 state.recordMoveToDominator(currentBlock);
255 } else {
256 // We have to move to our parent block since an immediate dominator does
257 // either not exist or is above our parent block. If we cannot move to
258 // our parent operation due to constraints given by the StateT
259 // implementation, break the walk loop. Furthermore, we should not move
260 // allocations out of unknown region-based control-flow operations.
261 if (!isKnownControlFlowInterface(op: parentOp) ||
262 !state.isLegalPlacement(parentOp))
263 break;
264 // Move to our parent block by notifying the current StateT
265 // implementation.
266 currentBlock = parentBlock;
267 state.recordMoveToParent(currentBlock);
268 }
269 }
270 // Return the finally determined placement block.
271 return state.placementBlock;
272 }
273
274 /// The dominator info to find the appropriate start operation to move the
275 /// allocs.
276 DominanceInfo dominators;
277
278 /// The post dominator info to move the dependent allocs in the right
279 /// position.
280 PostDominanceInfo postDominators;
281
282 /// The map storing the final placement blocks of a given alloc value.
283 llvm::DenseMap<Value, Block *> placementBlocks;
284
285 /// The operation that this transformation is working on. It is used to also
286 /// gather allocas.
287 Operation *scopeOp;
288};
289
290/// A state implementation compatible with the `BufferAllocationHoisting` class
291/// that hoists allocations into dominator blocks while keeping them inside of
292/// loops.
293struct BufferAllocationHoistingState : BufferAllocationHoistingStateBase {
294 using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;
295
296 /// Computes the upper bound for the placement block search.
297 Block *computeUpperBound(Block *dominatorBlock, Block *dependencyBlock) {
298 // If we do not have a dependency block, the upper bound is given by the
299 // dominator block.
300 if (!dependencyBlock)
301 return dominatorBlock;
302
303 // Find the "lower" block of the dominator and the dependency block to
304 // ensure that we do not move allocations above this block.
305 return dominators->properlyDominates(a: dominatorBlock, b: dependencyBlock)
306 ? dependencyBlock
307 : dominatorBlock;
308 }
309
310 /// Returns true if the given operation does not represent a loop.
311 bool isLegalPlacement(Operation *op) { return !isLoop(op); }
312
313 /// Returns true if the given operation should be considered for hoisting.
314 static bool shouldHoistOpType(Operation *op) {
315 return allowAllocDominateBlockHoisting(op);
316 }
317
318 /// Sets the current placement block to the given block.
319 void recordMoveToDominator(Block *block) { placementBlock = block; }
320
321 /// Sets the current placement block to the given block.
322 void recordMoveToParent(Block *block) { recordMoveToDominator(block); }
323};
324
325/// A state implementation compatible with the `BufferAllocationHoisting` class
326/// that hoists allocations out of loops.
327struct BufferAllocationLoopHoistingState : BufferAllocationHoistingStateBase {
328 using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;
329
330 /// Remembers the dominator block of all aliases.
331 Block *aliasDominatorBlock = nullptr;
332
333 /// Computes the upper bound for the placement block search.
334 Block *computeUpperBound(Block *dominatorBlock, Block *dependencyBlock) {
335 aliasDominatorBlock = dominatorBlock;
336 // If there is a dependency block, we have to use this block as an upper
337 // bound to satisfy all allocation value dependencies.
338 return dependencyBlock ? dependencyBlock : nullptr;
339 }
340
341 /// Returns true if the given operation represents a loop and one of the
342 /// aliases caused the `aliasDominatorBlock` to be "above" the block of the
343 /// given loop operation. If this is the case, it indicates that the
344 /// allocation is passed via a back edge.
345 bool isLegalPlacement(Operation *op) {
346 return isLoop(op) &&
347 !dominators->dominates(a: aliasDominatorBlock, b: op->getBlock());
348 }
349
350 /// Returns true if the given operation should be considered for hoisting.
351 static bool shouldHoistOpType(Operation *op) {
352 return allowAllocLoopHoisting(op);
353 }
354
355 /// Does not change the internal placement block, as we want to move
356 /// operations out of loops only.
357 void recordMoveToDominator(Block *block) {}
358
359 /// Sets the current placement block to the given block.
360 void recordMoveToParent(Block *block) { placementBlock = block; }
361};
362
363//===----------------------------------------------------------------------===//
364// BufferPlacementPromotion
365//===----------------------------------------------------------------------===//
366
367/// Promotes heap-based allocations to stack-based allocations (if possible).
368class BufferPlacementPromotion : BufferPlacementTransformationBase {
369public:
370 BufferPlacementPromotion(Operation *op)
371 : BufferPlacementTransformationBase(op) {}
372
373 /// Promote buffers to stack-based allocations.
374 void promote(function_ref<bool(Value)> isSmallAlloc) {
375 for (BufferPlacementAllocs::AllocEntry &entry : allocs) {
376 Value alloc = std::get<0>(t&: entry);
377 Operation *dealloc = std::get<1>(t&: entry);
378 // Checking several requirements to transform an AllocOp into an AllocaOp.
379 // The transformation is done if the allocation is limited to a given
380 // size. Furthermore, a deallocation must not be defined for this
381 // allocation entry and a parent allocation scope must exist.
382 if (!isSmallAlloc(alloc) || dealloc ||
383 !hasAllocationScope(alloc, aliasAnalysis: aliases))
384 continue;
385
386 Operation *startOperation = BufferPlacementAllocs::getStartOperation(
387 allocValue: alloc, placementBlock: alloc.getParentBlock(), liveness);
388 // Build a new alloca that is associated with its parent
389 // `AutomaticAllocationScope` determined during the initialization phase.
390 OpBuilder builder(startOperation);
391 Operation *allocOp = alloc.getDefiningOp();
392 if (auto allocInterface = dyn_cast<AllocationOpInterface>(allocOp)) {
393 Operation *alloca =
394 allocInterface.buildPromotedAlloc(builder, alloc).value();
395 if (!alloca)
396 continue;
397 // Replace the original alloc by a newly created alloca.
398 allocOp->replaceAllUsesWith(values&: alloca);
399 allocOp->erase();
400 }
401 }
402 }
403};
404
405//===----------------------------------------------------------------------===//
406// BufferOptimizationPasses
407//===----------------------------------------------------------------------===//
408
409/// The buffer hoisting pass that hoists allocation nodes into dominating
410/// blocks.
411struct BufferHoistingPass
412 : public bufferization::impl::BufferHoistingBase<BufferHoistingPass> {
413
414 void runOnOperation() override {
415 // Hoist all allocations into dominator blocks.
416 BufferAllocationHoisting<BufferAllocationHoistingState> optimizer(
417 getOperation());
418 optimizer.hoist();
419 }
420};
421
422/// The buffer loop hoisting pass that hoists allocation nodes out of loops.
423struct BufferLoopHoistingPass
424 : public bufferization::impl::BufferLoopHoistingBase<
425 BufferLoopHoistingPass> {
426
427 void runOnOperation() override {
428 // Hoist all allocations out of loops.
429 hoistBuffersFromLoops(getOperation());
430 }
431};
432
433/// The promote buffer to stack pass that tries to convert alloc nodes into
434/// alloca nodes.
435class PromoteBuffersToStackPass
436 : public bufferization::impl::PromoteBuffersToStackBase<
437 PromoteBuffersToStackPass> {
438public:
439 PromoteBuffersToStackPass(unsigned maxAllocSizeInBytes,
440 unsigned maxRankOfAllocatedMemRef) {
441 this->maxAllocSizeInBytes = maxAllocSizeInBytes;
442 this->maxRankOfAllocatedMemRef = maxRankOfAllocatedMemRef;
443 }
444
445 explicit PromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc)
446 : isSmallAlloc(std::move(isSmallAlloc)) {}
447
448 LogicalResult initialize(MLIRContext *context) override {
449 if (isSmallAlloc == nullptr) {
450 isSmallAlloc = [=](Value alloc) {
451 return defaultIsSmallAlloc(alloc, maxAllocSizeInBytes,
452 maxRankOfAllocatedMemRef);
453 };
454 }
455 return success();
456 }
457
458 void runOnOperation() override {
459 // Move all allocation nodes and convert candidates into allocas.
460 BufferPlacementPromotion optimizer(getOperation());
461 optimizer.promote(isSmallAlloc);
462 }
463
464private:
465 std::function<bool(Value)> isSmallAlloc;
466};
467
468} // namespace
469
470void mlir::bufferization::hoistBuffersFromLoops(Operation *op) {
471 BufferAllocationHoisting<BufferAllocationLoopHoistingState> optimizer(op);
472 optimizer.hoist();
473}
474
475std::unique_ptr<Pass> mlir::bufferization::createBufferHoistingPass() {
476 return std::make_unique<BufferHoistingPass>();
477}
478
479std::unique_ptr<Pass> mlir::bufferization::createBufferLoopHoistingPass() {
480 return std::make_unique<BufferLoopHoistingPass>();
481}
482
483std::unique_ptr<Pass> mlir::bufferization::createPromoteBuffersToStackPass(
484 unsigned maxAllocSizeInBytes, unsigned maxRankOfAllocatedMemRef) {
485 return std::make_unique<PromoteBuffersToStackPass>(args&: maxAllocSizeInBytes,
486 args&: maxRankOfAllocatedMemRef);
487}
488
489std::unique_ptr<Pass> mlir::bufferization::createPromoteBuffersToStackPass(
490 std::function<bool(Value)> isSmallAlloc) {
491 return std::make_unique<PromoteBuffersToStackPass>(args: std::move(isSmallAlloc));
492}
493

// Source: mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp