//===- Mem2Reg.cpp - Promotes memory slots into values ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Mem2Reg.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"

namespace mlir {
#define GEN_PASS_DEF_MEM2REG
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "mem2reg"

using namespace mlir;

/// mem2reg
///
/// This pass turns unnecessary uses of automatically allocated memory slots
/// into direct Value-based operations. For example, it will simplify storing a
/// constant in a memory slot to immediately load it to a direct use of that
/// constant. In other words, given a memory slot addressed by a non-aliased
/// "pointer" Value, mem2reg removes all the uses of that pointer.
///
/// Within a block, this is done by following the chain of stores and loads of
/// the slot and replacing the results of loads with the values previously
/// stored. If a load happens before any other store, a poison value is used
/// instead.
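///
/// As an illustrative sketch (names and exact syntax are incidental; LLVM
/// dialect operations, which implement the memory slot interfaces, are used
/// here), the sequence
///
///   %slot = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
///   llvm.store %v, %slot : i32, !llvm.ptr
///   %l = llvm.load %slot : !llvm.ptr -> i32
///
/// is rewritten so that all uses of %l directly use %v instead, after which
/// the load, the store, and the alloca can be erased.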
///
/// Control flow can create situations where a load could be replaced by
/// multiple possible stores depending on the control flow path taken. As a
/// result, this pass must introduce new block arguments in some blocks to
/// accommodate the multiple possible definitions. Each predecessor will
/// populate the block argument with the definition reached at its end. With
/// this, the value stored can be well defined at block boundaries, allowing
/// the propagation of replacement through blocks.
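///
/// Continuing the sketch above, a load at the merge point of two stores:
///
///   ^bb1:
///     llvm.store %a, %slot : i32, !llvm.ptr
///     llvm.br ^bb3
///   ^bb2:
///     llvm.store %b, %slot : i32, !llvm.ptr
///     llvm.br ^bb3
///   ^bb3:
///     %l = llvm.load %slot : !llvm.ptr -> i32
///
/// becomes, after promotion:
///
///   ^bb1:
///     llvm.br ^bb3(%a : i32)
///   ^bb2:
///     llvm.br ^bb3(%b : i32)
///   ^bb3(%arg : i32):  // new block argument, replaces uses of %l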
///
/// This pass computes this transformation in four main steps. The first two
/// steps are performed during an analysis phase that does not mutate IR.
///
/// The two steps of the analysis phase are the following:
/// - A first step computes the list of operations that transitively use the
/// memory slot we would like to promote. The purpose of this phase is to
/// identify which uses must be removed to promote the slot, either by rewiring
/// the user or deleting it. Naturally, direct uses of the slot must be removed.
/// Sometimes additional uses must also be removed: this is notably the case
/// when a direct user of the slot cannot rewire its use and must delete itself,
/// and thus must make its users no longer use it. If any of those uses cannot
/// be removed by their users in any way, promotion cannot continue: this is
/// decided at this step.
/// - A second step computes the list of blocks where a block argument will be
/// needed ("merge points") without mutating the IR. These blocks are the blocks
/// leading to a definition clash between two predecessors. Such blocks happen
/// to be the Iterated Dominance Frontier (IDF) of the set of blocks containing
/// a store, as they represent the point where a clear defining dominator stops
/// existing. Computing this information in advance allows making sure the
/// terminators that will forward values are capable of doing so (inability to
/// do so aborts promotion at this step).
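///
/// As a small worked example, consider the diamond CFG below, where ^bb1 and
/// ^bb2 both store to the slot. Neither store dominates ^bb3, so ^bb3 is in
/// the IDF of {^bb1, ^bb2} and becomes a merge point:
///
///        ^bb0
///       /    \
///    ^bb1    ^bb2   (each stores to the slot)
///       \    /
///        ^bb3       (merge point: a block argument is required)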
///
/// At this point, promotion is guaranteed to happen, and the mutation phase can
/// begin with the following steps:
/// - A third step computes the reaching definition of the memory slot at each
/// blocking user. This is the core of the mem2reg algorithm, also known as
/// load-store forwarding. This analyses loads and stores and propagates which
/// value must be stored in the slot at each blocking user. This is achieved by
/// doing a depth-first walk of the dominator tree of the function. This is
/// sufficient because the reaching definition at the beginning of a block is
/// either its new block argument if it is a merge block, or the definition
/// reaching the end of its immediate dominator (parent in the dominator tree).
/// We can therefore propagate this information down the dominator tree to
/// proceed with renaming within blocks (a pseudocode sketch follows this
/// list).
/// - The fourth and final step uses the reaching definitions to remove blocking
/// uses in topological order.
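///
/// A pseudocode sketch of the renaming walk of the third step (mirroring
/// computeReachingDefInRegion and computeReachingDefInBlock below):
///
///   rename(block, reachingDef):
///     if block is a merge point:
///       reachingDef = freshly added block argument
///     for op in block:
///       if op is a memory op with blocking uses: record reachingDef for op
///       if op stores to the slot: reachingDef = the stored value
///     append reachingDef to branch operands targeting successor merge points
///     for child of block in the dominator tree:
///       rename(child, reachingDef)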
///
/// For further reading, chapter three of SSA-based Compiler Design [1]
/// showcases SSA construction, where mem2reg is an adaptation of the same
/// process.
///
/// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
/// Springer.

namespace {

using BlockingUsesMap =
    llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;

/// Information computed during promotion analysis used to perform actual
/// promotion.
struct MemorySlotPromotionInfo {
  /// Blocks for which at least two definitions of the slot values clash.
  SmallPtrSet<Block *, 8> mergePoints;
  /// Contains, for each operation, which uses must be eliminated by promotion.
  /// This is a DAG structure because if an operation must eliminate some of
  /// its uses, it is because the defining ops of the blocking uses requested
  /// it. The defining ops therefore must also have blocking uses or be the
  /// starting point of the blocking uses.
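  /// For instance, if a user of the slot pointer can only resolve its blocking
  /// use by deleting itself, the uses of its results become blocking uses for
  /// their owners in turn.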
  BlockingUsesMap userToBlockingUses;
};

/// Computes information for basic slot promotion. This will check that direct
/// slot promotion can be performed, and provide the information to execute the
/// promotion. This does not mutate IR.
class MemorySlotPromotionAnalyzer {
public:
  MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance,
                              const DataLayout &dataLayout)
      : slot(slot), dominance(dominance), dataLayout(dataLayout) {}

  /// Computes the information for slot promotion if promotion is possible,
  /// returns nothing otherwise.
  std::optional<MemorySlotPromotionInfo> computeInfo();

private:
  /// Computes the transitive uses of the slot that block promotion. This finds
  /// uses that would block the promotion, checks that the operation has a
  /// solution to remove the blocking use, and potentially forwards the analysis
  /// if the operation needs further blocking uses resolved to resolve its own
  /// uses (typically, removing its users because it will delete itself to
  /// resolve its own blocking uses). This will fail if one of the transitive
  /// users cannot remove a requested use, and should prevent promotion.
  LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);

  /// Computes in which blocks the value stored in the slot is actually used,
  /// meaning blocks leading to a load. This method uses `definingBlocks`, the
  /// set of blocks containing a store to the slot (defining the value of the
  /// slot).
  SmallPtrSet<Block *, 16>
  computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);

  /// Computes the points in which multiple re-definitions of the slot's value
  /// (stores) may conflict.
  void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);

  /// Ensures predecessors of merge points can properly provide their current
  /// definition of the value stored in the slot to the merge point. This can
  /// notably be an issue if the terminator used does not have the ability to
  /// forward values through block operands.
  bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);

  MemorySlot slot;
  DominanceInfo &dominance;
  const DataLayout &dataLayout;
};

using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t>>;

/// The MemorySlotPromoter handles the state of promoting a memory slot. It
/// wraps a slot and its associated allocator. This will perform the mutation of
/// IR.
class MemorySlotPromoter {
public:
  MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
                     OpBuilder &builder, DominanceInfo &dominance,
                     const DataLayout &dataLayout, MemorySlotPromotionInfo info,
                     const Mem2RegStatistics &statistics,
                     BlockIndexCache &blockIndexCache);

  /// Actually promotes the slot by mutating IR. Promoting a slot DOES
  /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
  /// promotion info should NOT be performed in batches.
  /// Returns a promotable allocation op if a new allocator was created, nullopt
  /// otherwise.
  std::optional<PromotableAllocationOpInterface> promoteSlot();

private:
  /// Computes the reaching definition for all the operations that require
  /// promotion. `reachingDef` is the value the slot should contain at the
  /// beginning of the block. This method returns the reached definition at the
  /// end of the block. This method must only be called at most once per block.
  Value computeReachingDefInBlock(Block *block, Value reachingDef);

  /// Computes the reaching definition for all the operations that require
  /// promotion. `reachingDef` corresponds to the initial value the
  /// slot will contain before any write, typically a poison value.
  /// This method must only be called at most once per region.
  void computeReachingDefInRegion(Region *region, Value reachingDef);

  /// Removes the blocking uses of the slot, in topological order.
  void removeBlockingUses();

  /// Lazily-constructed default value representing the content of the slot
  /// when no store has been executed. This function may mutate IR.
  Value getOrCreateDefaultValue();

  MemorySlot slot;
  PromotableAllocationOpInterface allocator;
  OpBuilder &builder;
  /// Potentially non-initialized default value. Use `getOrCreateDefaultValue`
  /// to initialize it on demand.
  Value defaultValue;
  /// Contains the reaching definition at this operation. Reaching definitions
  /// are only computed for promotable memory operations with blocking uses.
  DenseMap<PromotableMemOpInterface, Value> reachingDefs;
  DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
  DominanceInfo &dominance;
  const DataLayout &dataLayout;
  MemorySlotPromotionInfo info;
  const Mem2RegStatistics &statistics;

  /// Shared cache of block indices of specific regions.
  BlockIndexCache &blockIndexCache;
};

} // namespace

MemorySlotPromoter::MemorySlotPromoter(
    MemorySlot slot, PromotableAllocationOpInterface allocator,
    OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout,
    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
    BlockIndexCache &blockIndexCache)
    : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
      dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
      blockIndexCache(blockIndexCache) {
#ifndef NDEBUG
  auto isResultOrNewBlockArgument = [&]() {
    if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
      return arg.getOwner()->getParentOp() == allocator;
    return slot.ptr.getDefiningOp() == allocator;
  };

  assert(isResultOrNewBlockArgument() &&
         "a slot must be a result of the allocator or an argument of the child "
         "regions of the allocator");
#endif // NDEBUG
}

Value MemorySlotPromoter::getOrCreateDefaultValue() {
  if (defaultValue)
    return defaultValue;

  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(slot.ptr.getParentBlock());
  return defaultValue = allocator.getDefaultValue(slot, builder);
}

LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
    BlockingUsesMap &userToBlockingUses) {
  // The promotion of an operation may require the promotion of further
  // operations (typically, removing operations that use an operation that must
  // delete itself). We thus need to start from the use of the slot pointer and
  // propagate further requests through the forward slice.

  // Because this pass currently only supports analysing the parent region of
  // the slot pointer, if a promotable memory op that needs promotion is within
  // a graph region, the slot may only be used in a graph region and should
  // therefore be ignored.
  Region *slotPtrRegion = slot.ptr.getParentRegion();
  auto slotPtrRegionOp =
      dyn_cast<RegionKindInterface>(slotPtrRegion->getParentOp());
  if (slotPtrRegionOp &&
      slotPtrRegionOp.getRegionKind(slotPtrRegion->getRegionNumber()) ==
          RegionKind::Graph)
    return failure();

  // First insert that all immediate users of the slot pointer must no longer
  // use it.
  for (OpOperand &use : slot.ptr.getUses()) {
    SmallPtrSet<OpOperand *, 4> &blockingUses =
        userToBlockingUses[use.getOwner()];
    blockingUses.insert(&use);
  }

  // Then, propagate the requirements for the removal of uses. The
  // topologically-sorted forward slice allows for all blocking uses of an
  // operation to have been computed before it is reached. Operations are
  // traversed in topological order of their uses, starting from the slot
  // pointer.
  SetVector<Operation *> forwardSlice;
  mlir::getForwardSlice(slot.ptr, &forwardSlice);
  for (Operation *user : forwardSlice) {
    // If the next operation has no blocking uses, everything is fine.
    auto it = userToBlockingUses.find(user);
    if (it == userToBlockingUses.end())
      continue;

    SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;

    SmallVector<OpOperand *> newBlockingUses;
    // If the operation decides it cannot deal with removing the blocking uses,
    // promotion must fail.
    if (auto promotable = dyn_cast<PromotableOpInterface>(user)) {
      if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
                                       dataLayout))
        return failure();
    } else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
      if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
                                       dataLayout))
        return failure();
    } else {
      // An operation that has blocking uses must be promoted. If it is not
      // promotable, promotion must fail.
      return failure();
    }

    // Then, register any new blocking uses for coming operations.
    for (OpOperand *blockingUse : newBlockingUses) {
      assert(llvm::is_contained(user->getResults(), blockingUse->get()));

      SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
          userToBlockingUses[blockingUse->getOwner()];
      newUserBlockingUseSet.insert(blockingUse);
    }
  }

  // Because this pass currently only supports analysing the parent region of
  // the slot pointer, if a promotable memory op that needs promotion is outside
  // of this region, promotion must fail because it will be impossible to
  // provide a valid `reachingDef` for it.
  for (auto &[toPromote, _] : userToBlockingUses)
    if (isa<PromotableMemOpInterface>(toPromote) &&
        toPromote->getParentRegion() != slot.ptr.getParentRegion())
      return failure();

  return success();
}

SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
    SmallPtrSetImpl<Block *> &definingBlocks) {
  SmallPtrSet<Block *, 16> liveIn;

  // The worklist contains blocks in which it is known that the slot value is
  // live-in. The further blocks where this value is live-in will be inferred
  // from these.
  SmallVector<Block *> liveInWorkList;

  // Blocks with a load before any other store to the slot are the starting
  // points of the analysis. The slot value is definitely live-in in those
  // blocks.
  SmallPtrSet<Block *, 16> visited;
  for (Operation *user : slot.ptr.getUsers()) {
    if (!visited.insert(user->getBlock()).second)
      continue;

    for (Operation &op : user->getBlock()->getOperations()) {
      if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
        // If this operation loads the slot, it is loading from it before
        // ever writing to it, so the value is live-in in this block.
        if (memOp.loadsFrom(slot)) {
          liveInWorkList.push_back(user->getBlock());
          break;
        }

        // If we store to the slot, further loads will see that value.
        // Because we did not meet any load before, the value is not live-in.
        if (memOp.storesTo(slot))
          break;
      }
    }
  }

  // The information is then propagated to the predecessors until a def site
  // (store) is found.
  while (!liveInWorkList.empty()) {
    Block *liveInBlock = liveInWorkList.pop_back_val();

    if (!liveIn.insert(liveInBlock).second)
      continue;

    // If a predecessor is a defining block, either:
    // - It has a load before its first store, in which case it is live-in but
    //   has already been processed in the initialisation step.
    // - It has a store before any load, in which case it is not live-in.
    // We can thus at this stage insert to the worklist only predecessors that
    // are not defining blocks.
    for (Block *pred : liveInBlock->getPredecessors())
      if (!definingBlocks.contains(pred))
        liveInWorkList.push_back(pred);
  }

  return liveIn;
}

using IDFCalculator = llvm::IDFCalculatorBase<Block, false>;
void MemorySlotPromotionAnalyzer::computeMergePoints(
    SmallPtrSetImpl<Block *> &mergePoints) {
  if (slot.ptr.getParentRegion()->hasOneBlock())
    return;

  IDFCalculator idfCalculator(dominance.getDomTree(slot.ptr.getParentRegion()));

  SmallPtrSet<Block *, 16> definingBlocks;
  for (Operation *user : slot.ptr.getUsers())
    if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
      if (storeOp.storesTo(slot))
        definingBlocks.insert(user->getBlock());

  idfCalculator.setDefiningBlocks(definingBlocks);

  SmallPtrSet<Block *, 16> liveIn = computeSlotLiveIn(definingBlocks);
  idfCalculator.setLiveInBlocks(liveIn);

  SmallVector<Block *> mergePointsVec;
  idfCalculator.calculate(mergePointsVec);

  mergePoints.insert_range(mergePointsVec);
}

bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
    SmallPtrSetImpl<Block *> &mergePoints) {
  for (Block *mergePoint : mergePoints)
    for (Block *pred : mergePoint->getPredecessors())
      if (!isa<BranchOpInterface>(pred->getTerminator()))
        return false;

  return true;
}

std::optional<MemorySlotPromotionInfo>
MemorySlotPromotionAnalyzer::computeInfo() {
  MemorySlotPromotionInfo info;

  // First, find the set of operations that will need to be changed for the
  // promotion to happen. These operations need to resolve some of their uses,
  // either by rewiring them or simply deleting themselves. If any of them
  // cannot find a way to resolve their blocking uses, we abort the promotion.
  if (failed(computeBlockingUses(info.userToBlockingUses)))
    return {};

  // Then, compute blocks in which two or more definitions of the allocated
  // variable may conflict. These blocks will need a new block argument to
  // accommodate this.
  computeMergePoints(info.mergePoints);

  // The slot can be promoted if the block arguments to be created can
  // actually be populated with values, which may not be possible depending
  // on their predecessors.
  if (!areMergePointsUsable(info.mergePoints))
    return {};

  return info;
}

Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
                                                    Value reachingDef) {
  SmallVector<Operation *> blockOps;
  for (Operation &op : block->getOperations())
    blockOps.push_back(&op);
  for (Operation *op : blockOps) {
    if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
      if (info.userToBlockingUses.contains(memOp))
        reachingDefs.insert({memOp, reachingDef});

      if (memOp.storesTo(slot)) {
        builder.setInsertionPointAfter(memOp);
        Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);
        assert(stored && "a memory operation storing to a slot must provide a "
                         "new definition of the slot");
        reachingDef = stored;
        replacedValuesMap[memOp] = stored;
      }
    }
  }

  return reachingDef;
}

void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
                                                    Value reachingDef) {
  assert(reachingDef && "expected an initial reaching def to be provided");
  if (region->hasOneBlock()) {
    computeReachingDefInBlock(&region->front(), reachingDef);
    return;
  }

  struct DfsJob {
    llvm::DomTreeNodeBase<Block> *block;
    Value reachingDef;
  };

  SmallVector<DfsJob> dfsStack;

  auto &domTree = dominance.getDomTree(slot.ptr.getParentRegion());

  dfsStack.emplace_back<DfsJob>(
      {domTree.getNode(&region->front()), reachingDef});

  while (!dfsStack.empty()) {
    DfsJob job = dfsStack.pop_back_val();
    Block *block = job.block->getBlock();

    if (info.mergePoints.contains(block)) {
      BlockArgument blockArgument =
          block->addArgument(slot.elemType, slot.ptr.getLoc());
      builder.setInsertionPointToStart(block);
      allocator.handleBlockArgument(slot, blockArgument, builder);
      job.reachingDef = blockArgument;

      if (statistics.newBlockArgumentAmount)
        (*statistics.newBlockArgumentAmount)++;
    }

    job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
    assert(job.reachingDef);

    if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
      for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
        if (info.mergePoints.contains(blockOperand.get())) {
          terminator.getSuccessorOperands(blockOperand.getOperandNumber())
              .append(job.reachingDef);
        }
      }
    }

    for (auto *child : job.block->children())
      dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
  }
}

/// Gets or creates a block index mapping for `region`.
static const DenseMap<Block *, size_t> &
getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) {
  auto [it, inserted] = blockIndexCache.try_emplace(region);
  if (!inserted)
    return it->second;

  DenseMap<Block *, size_t> &blockIndices = it->second;
  SetVector<Block *> topologicalOrder = getBlocksSortedByDominance(*region);
  for (auto [index, block] : llvm::enumerate(topologicalOrder))
    blockIndices[block] = index;
  return blockIndices;
}

/// Sorts `ops` according to dominance. Relies on the topological order of basic
/// blocks to get a deterministic ordering. Uses `blockIndexCache` to avoid the
/// potentially expensive recomputation of a block index map.
static void dominanceSort(SmallVector<Operation *> &ops, Region &region,
                          BlockIndexCache &blockIndexCache) {
  // Produce a topological block order and construct a map to look up the
  // indices of blocks.
  const DenseMap<Block *, size_t> &topoBlockIndices =
      getOrCreateBlockIndices(blockIndexCache, &region);

  // Combining the topological order of the basic blocks together with block
  // internal operation order guarantees a deterministic, dominance respecting
  // order.
  llvm::sort(ops, [&](Operation *lhs, Operation *rhs) {
    size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());
    size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());
    if (lhsBlockIndex == rhsBlockIndex)
      return lhs->isBeforeInBlock(rhs);
    return lhsBlockIndex < rhsBlockIndex;
  });
}

void MemorySlotPromoter::removeBlockingUses() {
  llvm::SmallVector<Operation *> usersToRemoveUses(
      llvm::make_first_range(info.userToBlockingUses));

  // Sort according to dominance.
  dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent(),
                blockIndexCache);

  llvm::SmallVector<Operation *> toErase;
  // List of all replaced values in the slot.
  llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;
  // Ops to visit with the `visitReplacedValues` method.
  llvm::SmallVector<PromotableOpInterface> toVisit;
  for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
    if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
      Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
      // If no reaching definition is known, this use is outside the reach of
      // the slot. The default value should thus be used.
      if (!reachingDef)
        reachingDef = getOrCreateDefaultValue();

      builder.setInsertionPointAfter(toPromote);
      if (toPromoteMemOp.removeBlockingUses(
              slot, info.userToBlockingUses[toPromote], builder, reachingDef,
              dataLayout) == DeletionKind::Delete)
        toErase.push_back(toPromote);
      if (toPromoteMemOp.storesTo(slot))
        if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
          replacedValuesList.push_back({toPromoteMemOp, replacedValue});
      continue;
    }

    auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
    builder.setInsertionPointAfter(toPromote);
    if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
                                          builder) == DeletionKind::Delete)
      toErase.push_back(toPromote);
    if (toPromoteBasic.requiresReplacedValues())
      toVisit.push_back(toPromoteBasic);
  }
  for (PromotableOpInterface op : toVisit) {
    builder.setInsertionPointAfter(op);
    op.visitReplacedValues(replacedValuesList, builder);
  }

  for (Operation *toEraseOp : toErase)
    toEraseOp->erase();

  assert(slot.ptr.use_empty() &&
         "after promotion, the slot pointer should not be used anymore");
}

std::optional<PromotableAllocationOpInterface>
MemorySlotPromoter::promoteSlot() {
  computeReachingDefInRegion(slot.ptr.getParentRegion(),
                             getOrCreateDefaultValue());

  // Now that reaching definitions are known, remove all users.
  removeBlockingUses();

  // Update terminators in dead branches to forward the default value if they
  // are succeeded by a merge point.
  for (Block *mergePoint : info.mergePoints) {
    for (BlockOperand &use : mergePoint->getUses()) {
      auto user = cast<BranchOpInterface>(use.getOwner());
      SuccessorOperands succOperands =
          user.getSuccessorOperands(use.getOperandNumber());
      assert(succOperands.size() == mergePoint->getNumArguments() ||
             succOperands.size() + 1 == mergePoint->getNumArguments());
      if (succOperands.size() + 1 == mergePoint->getNumArguments())
        succOperands.append(getOrCreateDefaultValue());
    }
  }

  LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
                          << "\n");

  if (statistics.promotedAmount)
    (*statistics.promotedAmount)++;

  return allocator.handlePromotionComplete(slot, defaultValue, builder);
}

LogicalResult mlir::tryToPromoteMemorySlots(
    ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
    const DataLayout &dataLayout, DominanceInfo &dominance,
    Mem2RegStatistics statistics) {
  bool promotedAny = false;

  // A cache that stores deterministic block indices which are used to determine
  // a valid operation modification order. The block index maps are computed
  // lazily and cached to avoid expensive recomputation.
  BlockIndexCache blockIndexCache;

  SmallVector<PromotableAllocationOpInterface> workList(allocators);

  SmallVector<PromotableAllocationOpInterface> newWorkList;
  newWorkList.reserve(workList.size());
  while (true) {
    bool changesInThisRound = false;
    for (PromotableAllocationOpInterface allocator : workList) {
      bool changedAllocator = false;
      for (MemorySlot slot : allocator.getPromotableSlots()) {
        if (slot.ptr.use_empty())
          continue;

        MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
        std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
        if (info) {
          std::optional<PromotableAllocationOpInterface> newAllocator =
              MemorySlotPromoter(slot, allocator, builder, dominance,
                                 dataLayout, std::move(*info), statistics,
                                 blockIndexCache)
                  .promoteSlot();
          changedAllocator = true;
          // Add newly created allocators to the worklist for further
          // processing.
          if (newAllocator)
            newWorkList.push_back(*newAllocator);

          // A break is required, since promoting a slot may invalidate the
          // remaining slots of an allocator.
          break;
        }
      }
      if (!changedAllocator)
        newWorkList.push_back(allocator);
      changesInThisRound |= changedAllocator;
    }
    if (!changesInThisRound)
      break;
    promotedAny = true;

    // Swap the vector's backing memory and clear the entries in newWorkList
    // afterwards. This ensures that additional heap allocations can be avoided.
    workList.swap(newWorkList);
    newWorkList.clear();
  }

  return success(promotedAny);
}

namespace {

struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
  using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;

  void runOnOperation() override {
    Operation *scopeOp = getOperation();

    Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};

    bool changed = false;

    auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
    auto &dominance = getAnalysis<DominanceInfo>();

    for (Region &region : scopeOp->getRegions()) {
      if (region.getBlocks().empty())
        continue;

      OpBuilder builder(&region.front(), region.front().begin());

      SmallVector<PromotableAllocationOpInterface> allocators;
      // Build a list of allocators to attempt to promote the slots of.
      region.walk([&](PromotableAllocationOpInterface allocator) {
        allocators.emplace_back(allocator);
      });

      // Attempt promoting as many of the slots as possible.
      if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout,
                                            dominance, statistics)))
        changed = true;
    }
    if (!changed)
      markAllAnalysesPreserved();
  }
};

} // namespace
