//===- Mem2Reg.cpp - Promotes memory slots into values ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Mem2Reg.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"

namespace mlir {
#define GEN_PASS_DEF_MEM2REG
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "mem2reg"

using namespace mlir;

/// mem2reg
///
/// This pass turns unnecessary uses of automatically allocated memory slots
/// into direct Value-based operations. For example, storing a constant in a
/// memory slot only to immediately load it back is simplified into a direct
/// use of that constant. In other words, given a memory slot addressed by a
/// non-aliased "pointer" Value, mem2reg removes all the uses of that pointer.
///
/// Within a block, this is done by following the chain of stores and loads of
/// the slot and replacing the results of loads with the values previously
/// stored. If a load happens before any other store, a poison value is used
/// instead.
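///
/// As an illustrative sketch (assuming the LLVM dialect, whose alloca, load
/// and store ops implement the memory slot interfaces), IR such as
///
///   %one = llvm.mlir.constant(1 : i32) : i32
///   %cst = llvm.mlir.constant(42 : i32) : i32
///   %slot = llvm.alloca %one x i32 : (i32) -> !llvm.ptr
///   llvm.store %cst, %slot : i32, !llvm.ptr
///   %loaded = llvm.load %slot : !llvm.ptr -> i32
///   llvm.return %loaded : i32
///
/// becomes `llvm.return %cst : i32`, with the alloca, store, and load erased.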
///
/// Control flow can create situations where a load could be replaced by
/// multiple possible stores depending on the control flow path taken. As a
/// result, this pass must introduce new block arguments in some blocks to
/// accommodate the multiple possible definitions. Each predecessor will
/// populate the block argument with the definition reached at its end. With
/// this, the value stored can be well defined at block boundaries, allowing
/// the propagation of replacement through blocks.
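///
/// A sketch of that situation (LLVM dialect again, types elided):
///
///   ^bb1:
///     llvm.store %a, %slot ...
///     llvm.br ^bb3
///   ^bb2:
///     llvm.store %b, %slot ...
///     llvm.br ^bb3
///   ^bb3:
///     %v = llvm.load %slot ...
///
/// Here ^bb3 is such a block ("merge point"): it gains a new block argument,
/// ^bb1 and ^bb2 forward %a and %b respectively through their branches, and
/// the load is replaced by the new block argument.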
///
/// This pass computes this transformation in four main steps. The first two
/// steps are performed during an analysis phase that does not mutate IR.
///
/// The two steps of the analysis phase are the following:
/// - A first step computes the list of operations that transitively use the
/// memory slot we would like to promote. The purpose of this phase is to
/// identify which uses must be removed to promote the slot, either by rewiring
/// the user or deleting it. Naturally, direct uses of the slot must be removed.
/// Sometimes additional uses must also be removed: this is notably the case
/// when a direct user of the slot cannot rewire its use and must delete itself,
/// and thus must make its users no longer use it. If any of those uses cannot
/// be removed by their users in any way, promotion cannot continue: this is
/// decided at this step.
/// - A second step computes the list of blocks where a block argument will be
/// needed ("merge points") without mutating the IR. These blocks are the blocks
/// leading to a definition clash between two predecessors. Such blocks happen
/// to be the Iterated Dominance Frontier (IDF) of the set of blocks containing
/// a store, as they represent the points where a clear defining dominator stops
/// existing. Computing this information in advance makes it possible to check
/// that the terminators that will forward values are capable of doing so
/// (inability to do so aborts promotion at this step).
///
/// At this point, promotion is guaranteed to happen, and the mutation phase can
/// begin with the following steps:
/// - A third step computes the reaching definition of the memory slot at each
/// blocking user. This is the core of the mem2reg algorithm, also known as
/// load-store forwarding. This step analyses loads and stores and propagates
/// which value the slot holds at each blocking user. This is achieved by
/// doing a depth-first walk of the dominator tree of the function. This is
/// sufficient because the reaching definition at the beginning of a block is
/// either its new block argument if it is a merge block, or the definition
/// reaching the end of its immediate dominator (parent in the dominator tree).
/// We can therefore propagate this information down the dominator tree to
/// proceed with renaming within blocks.
/// - The fourth and final step uses the reaching definitions to remove blocking
/// uses in topological order.
///
/// For further reading, chapter three of SSA-based Compiler Design [1]
/// showcases SSA construction, of which mem2reg is an adaptation.
///
/// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
/// Springer.

namespace {

using BlockingUsesMap =
    llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;

/// Information computed during promotion analysis used to perform actual
/// promotion.
struct MemorySlotPromotionInfo {
  /// Blocks in which at least two definitions of the slot's value clash.
  SmallPtrSet<Block *, 8> mergePoints;
  /// Contains, for each operation, which uses must be eliminated by
  /// promotion. This is a DAG structure because if an operation must eliminate
  /// some of its uses, it is because the defining ops of the blocking uses
  /// requested it. The defining ops therefore must also have blocking uses or
  /// be the starting point of the blocking uses.
  BlockingUsesMap userToBlockingUses;
};

/// Computes information for basic slot promotion. This will check that direct
/// slot promotion can be performed, and provide the information to execute the
/// promotion. This does not mutate IR.
class MemorySlotPromotionAnalyzer {
public:
  MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance,
                              const DataLayout &dataLayout)
      : slot(slot), dominance(dominance), dataLayout(dataLayout) {}

  /// Computes the information for slot promotion if promotion is possible,
  /// returns nothing otherwise.
  std::optional<MemorySlotPromotionInfo> computeInfo();

private:
  /// Computes the transitive uses of the slot that block promotion. This
  /// finds uses that would block the promotion, checks that the operation has
  /// a solution to remove the blocking use, and potentially forwards the
  /// analysis if the operation needs further blocking uses resolved to
  /// resolve its own uses (typically, removing its users because it will
  /// delete itself to resolve its own blocking uses). This fails if one of
  /// the transitive users cannot remove a requested use, which prevents
  /// promotion.
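  /// As an illustrative sketch (using a hypothetical pointer-forwarding op):
  /// if `%view = hypo.view %slot` can only resolve its blocking use of %slot
  /// by erasing itself, then every use of %view is registered as a new
  /// blocking use that the users of %view must in turn be able to remove.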
  LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);

  /// Computes in which blocks the value stored in the slot is actually used,
  /// meaning blocks leading to a load. This method uses `definingBlocks`, the
  /// set of blocks containing a store to the slot (defining the value of the
  /// slot).
  SmallPtrSet<Block *, 16>
  computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);

  /// Computes the points at which multiple re-definitions of the slot's value
  /// (stores) may conflict.
  void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);

  /// Ensures predecessors of merge points can properly provide their current
  /// definition of the value stored in the slot to the merge point. This can
  /// notably be an issue if the terminator used does not have the ability to
  /// forward values through block operands.
  bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);

  MemorySlot slot;
  DominanceInfo &dominance;
  const DataLayout &dataLayout;
};

/// The MemorySlotPromoter handles the state of promoting a memory slot. It
/// wraps a slot and its associated allocator. This will perform the mutation
/// of IR.
class MemorySlotPromoter {
public:
  MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
                     RewriterBase &rewriter, DominanceInfo &dominance,
                     const DataLayout &dataLayout, MemorySlotPromotionInfo info,
                     const Mem2RegStatistics &statistics);

  /// Actually promotes the slot by mutating IR. Promoting a slot DOES
  /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
  /// promotion info should NOT be performed in batches.
  void promoteSlot();

private:
  /// Computes the reaching definition for all the operations that require
  /// promotion. `reachingDef` is the value the slot should contain at the
  /// beginning of the block. This method returns the reaching definition at
  /// the end of the block. It must be called at most once per block.
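  /// For example (a sketch): in a block that stores %a to the slot, then
  /// loads from it, then stores %b, the load's reaching definition is %a and
  /// the value returned for the end of the block is %b.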
  Value computeReachingDefInBlock(Block *block, Value reachingDef);

  /// Computes the reaching definition for all the operations that require
  /// promotion. `reachingDef` corresponds to the initial value the
  /// slot will contain before any write, typically a poison value.
  /// This method must be called at most once per region.
  void computeReachingDefInRegion(Region *region, Value reachingDef);

  /// Removes the blocking uses of the slot, in topological order.
  void removeBlockingUses();

  /// Lazily-constructed default value representing the content of the slot
  /// when no store has been executed. This function may mutate IR.
  Value getOrCreateDefaultValue();

  MemorySlot slot;
  PromotableAllocationOpInterface allocator;
  RewriterBase &rewriter;
  /// Potentially non-initialized default value. Use `getOrCreateDefaultValue`
  /// to initialize it on demand.
  Value defaultValue;
  /// Contains the reaching definition at each promotable memory operation.
  /// Reaching definitions are only computed for promotable memory operations
  /// with blocking uses.
  DenseMap<PromotableMemOpInterface, Value> reachingDefs;
  DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
  DominanceInfo &dominance;
  const DataLayout &dataLayout;
  MemorySlotPromotionInfo info;
  const Mem2RegStatistics &statistics;
};

} // namespace
213
214MemorySlotPromoter::MemorySlotPromoter(
215 MemorySlot slot, PromotableAllocationOpInterface allocator,
216 RewriterBase &rewriter, DominanceInfo &dominance,
217 const DataLayout &dataLayout, MemorySlotPromotionInfo info,
218 const Mem2RegStatistics &statistics)
219 : slot(slot), allocator(allocator), rewriter(rewriter),
220 dominance(dominance), dataLayout(dataLayout), info(std::move(info)),
221 statistics(statistics) {
222#ifndef NDEBUG
223 auto isResultOrNewBlockArgument = [&]() {
224 if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
225 return arg.getOwner()->getParentOp() == allocator;
226 return slot.ptr.getDefiningOp() == allocator;
227 };
228
229 assert(isResultOrNewBlockArgument() &&
230 "a slot must be a result of the allocator or an argument of the child "
231 "regions of the allocator");
232#endif // NDEBUG
233}
234
235Value MemorySlotPromoter::getOrCreateDefaultValue() {
236 if (defaultValue)
237 return defaultValue;
238
239 RewriterBase::InsertionGuard guard(rewriter);
240 rewriter.setInsertionPointToStart(slot.ptr.getParentBlock());
241 return defaultValue = allocator.getDefaultValue(slot, rewriter);
242}
243
244LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
245 BlockingUsesMap &userToBlockingUses) {
246 // The promotion of an operation may require the promotion of further
247 // operations (typically, removing operations that use an operation that must
248 // delete itself). We thus need to start from the use of the slot pointer and
249 // propagate further requests through the forward slice.
250
251 // First insert that all immediate users of the slot pointer must no longer
252 // use it.
253 for (OpOperand &use : slot.ptr.getUses()) {
254 SmallPtrSet<OpOperand *, 4> &blockingUses =
255 userToBlockingUses[use.getOwner()];
256 blockingUses.insert(Ptr: &use);
257 }
258
259 // Then, propagate the requirements for the removal of uses. The
260 // topologically-sorted forward slice allows for all blocking uses of an
261 // operation to have been computed before it is reached. Operations are
262 // traversed in topological order of their uses, starting from the slot
263 // pointer.
264 SetVector<Operation *> forwardSlice;
265 mlir::getForwardSlice(root: slot.ptr, forwardSlice: &forwardSlice);
266 for (Operation *user : forwardSlice) {
267 // If the next operation has no blocking uses, everything is fine.
268 if (!userToBlockingUses.contains(user))
269 continue;
270
271 SmallPtrSet<OpOperand *, 4> &blockingUses = userToBlockingUses[user];
272
273 SmallVector<OpOperand *> newBlockingUses;
274 // If the operation decides it cannot deal with removing the blocking uses,
275 // promotion must fail.
276 if (auto promotable = dyn_cast<PromotableOpInterface>(user)) {
277 if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
278 dataLayout))
279 return failure();
280 } else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
281 if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
282 dataLayout))
283 return failure();
284 } else {
285 // An operation that has blocking uses must be promoted. If it is not
286 // promotable, promotion must fail.
287 return failure();
288 }
289
290 // Then, register any new blocking uses for coming operations.
291 for (OpOperand *blockingUse : newBlockingUses) {
292 assert(llvm::is_contained(user->getResults(), blockingUse->get()));
293
294 SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
295 userToBlockingUses[blockingUse->getOwner()];
296 newUserBlockingUseSet.insert(Ptr: blockingUse);
297 }
298 }
299
300 // Because this pass currently only supports analysing the parent region of
301 // the slot pointer, if a promotable memory op that needs promotion is outside
302 // of this region, promotion must fail because it will be impossible to
303 // provide a valid `reachingDef` for it.
304 for (auto &[toPromote, _] : userToBlockingUses)
305 if (isa<PromotableMemOpInterface>(toPromote) &&
306 toPromote->getParentRegion() != slot.ptr.getParentRegion())
307 return failure();
308
309 return success();
310}

SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
    SmallPtrSetImpl<Block *> &definingBlocks) {
  SmallPtrSet<Block *, 16> liveIn;

  // The worklist contains blocks in which it is known that the slot value is
  // live-in. The further blocks where this value is live-in will be inferred
  // from these.
  SmallVector<Block *> liveInWorkList;

  // Blocks with a load before any other store to the slot are the starting
  // points of the analysis. The slot value is definitely live-in in those
  // blocks.
  SmallPtrSet<Block *, 16> visited;
  for (Operation *user : slot.ptr.getUsers()) {
    if (visited.contains(user->getBlock()))
      continue;
    visited.insert(user->getBlock());

    for (Operation &op : user->getBlock()->getOperations()) {
      if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
        // If this operation loads the slot, it is loading from it before
        // ever writing to it, so the value is live-in in this block.
        if (memOp.loadsFrom(slot)) {
          liveInWorkList.push_back(user->getBlock());
          break;
        }

        // If we store to the slot, further loads will see that value.
        // Because we did not meet any load before, the value is not live-in.
        if (memOp.storesTo(slot))
          break;
      }
    }
  }

  // The information is then propagated to the predecessors until a def site
  // (store) is found.
  while (!liveInWorkList.empty()) {
    Block *liveInBlock = liveInWorkList.pop_back_val();

    if (!liveIn.insert(liveInBlock).second)
      continue;

    // If a predecessor is a defining block, either:
    // - It has a load before its first store, in which case it is live-in but
    //   has already been processed in the initialisation step.
    // - It has a store before any load, in which case it is not live-in.
    // We can thus at this stage insert to the worklist only predecessors that
    // are not defining blocks.
    for (Block *pred : liveInBlock->getPredecessors())
      if (!definingBlocks.contains(pred))
        liveInWorkList.push_back(pred);
  }

  return liveIn;
}

using IDFCalculator = llvm::IDFCalculatorBase<Block, false>;
void MemorySlotPromotionAnalyzer::computeMergePoints(
    SmallPtrSetImpl<Block *> &mergePoints) {
  if (slot.ptr.getParentRegion()->hasOneBlock())
    return;

  IDFCalculator idfCalculator(
      dominance.getDomTree(slot.ptr.getParentRegion()));

  SmallPtrSet<Block *, 16> definingBlocks;
  for (Operation *user : slot.ptr.getUsers())
    if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
      if (storeOp.storesTo(slot))
        definingBlocks.insert(user->getBlock());

  idfCalculator.setDefiningBlocks(definingBlocks);

  SmallPtrSet<Block *, 16> liveIn = computeSlotLiveIn(definingBlocks);
  idfCalculator.setLiveInBlocks(liveIn);

  SmallVector<Block *> mergePointsVec;
  idfCalculator.calculate(mergePointsVec);

  mergePoints.insert(mergePointsVec.begin(), mergePointsVec.end());
}

bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
    SmallPtrSetImpl<Block *> &mergePoints) {
  for (Block *mergePoint : mergePoints)
    for (Block *pred : mergePoint->getPredecessors())
      if (!isa<BranchOpInterface>(pred->getTerminator()))
        return false;

  return true;
}

std::optional<MemorySlotPromotionInfo>
MemorySlotPromotionAnalyzer::computeInfo() {
  MemorySlotPromotionInfo info;

  // First, find the set of operations that will need to be changed for the
  // promotion to happen. These operations need to resolve some of their uses,
  // either by rewiring them or simply deleting themselves. If any of them
  // cannot find a way to resolve their blocking uses, we abort the promotion.
  if (failed(computeBlockingUses(info.userToBlockingUses)))
    return {};

  // Then, compute blocks in which two or more definitions of the allocated
  // variable may conflict. These blocks will need a new block argument to
  // accommodate this.
  computeMergePoints(info.mergePoints);

  // The slot can be promoted if the block arguments to be created can
  // actually be populated with values, which may not be possible depending
  // on their predecessors.
  if (!areMergePointsUsable(info.mergePoints))
    return {};

  return info;
}

Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
                                                    Value reachingDef) {
  SmallVector<Operation *> blockOps;
  for (Operation &op : block->getOperations())
    blockOps.push_back(&op);
  for (Operation *op : blockOps) {
    if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
      if (info.userToBlockingUses.contains(memOp))
        reachingDefs.insert({memOp, reachingDef});

      if (memOp.storesTo(slot)) {
        rewriter.setInsertionPointAfter(memOp);
        Value stored = memOp.getStored(slot, rewriter, reachingDef, dataLayout);
        assert(stored && "a memory operation storing to a slot must provide a "
                         "new definition of the slot");
        reachingDef = stored;
        replacedValuesMap[memOp] = stored;
      }
    }
  }

  return reachingDef;
}

void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
                                                    Value reachingDef) {
  assert(reachingDef && "expected an initial reaching def to be provided");
  if (region->hasOneBlock()) {
    computeReachingDefInBlock(&region->front(), reachingDef);
    return;
  }

  struct DfsJob {
    llvm::DomTreeNodeBase<Block> *block;
    Value reachingDef;
  };

  SmallVector<DfsJob> dfsStack;

  auto &domTree = dominance.getDomTree(slot.ptr.getParentRegion());

  dfsStack.emplace_back<DfsJob>(
      {domTree.getNode(&region->front()), reachingDef});

  while (!dfsStack.empty()) {
    DfsJob job = dfsStack.pop_back_val();
    Block *block = job.block->getBlock();

    if (info.mergePoints.contains(block)) {
      // If the block is a merge point, we need to add a block argument to hold
      // the selected reaching definition. This has to be a bit complicated
      // because of RewriterBase limitations: we need to create a new block with
      // the extra block argument, move the content of the block to the new
      // block, and replace the block with the new block in the merge point set.
      SmallVector<Type> argTypes;
      SmallVector<Location> argLocs;
      for (BlockArgument arg : block->getArguments()) {
        argTypes.push_back(arg.getType());
        argLocs.push_back(arg.getLoc());
      }
      argTypes.push_back(slot.elemType);
      argLocs.push_back(slot.ptr.getLoc());
      Block *newBlock = rewriter.createBlock(block, argTypes, argLocs);

      info.mergePoints.erase(block);
      info.mergePoints.insert(newBlock);

      rewriter.replaceAllUsesWith(block, newBlock);
      rewriter.mergeBlocks(block, newBlock,
                           newBlock->getArguments().drop_back());

      block = newBlock;

      BlockArgument blockArgument = block->getArguments().back();
      rewriter.setInsertionPointToStart(block);
      allocator.handleBlockArgument(slot, blockArgument, rewriter);
      job.reachingDef = blockArgument;

      if (statistics.newBlockArgumentAmount)
        (*statistics.newBlockArgumentAmount)++;
    }

    job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
    assert(job.reachingDef);

    if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
      for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
        if (info.mergePoints.contains(blockOperand.get())) {
          rewriter.modifyOpInPlace(terminator, [&]() {
            terminator.getSuccessorOperands(blockOperand.getOperandNumber())
                .append(job.reachingDef);
          });
        }
      }
    }

    for (auto *child : job.block->children())
      dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
  }
}

/// Sorts `ops` according to dominance. Relies on the topological order of
/// basic blocks to get a deterministic ordering.
static void dominanceSort(SmallVector<Operation *> &ops, Region &region) {
  // Produce a topological block order and construct a map to lookup the
  // indices of blocks.
  DenseMap<Block *, size_t> topoBlockIndices;
  SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(region);
  for (auto [index, block] : llvm::enumerate(topologicalOrder))
    topoBlockIndices[block] = index;

  // Combining the topological order of the basic blocks together with block
  // internal operation order guarantees a deterministic, dominance respecting
  // order.
  llvm::sort(ops, [&](Operation *lhs, Operation *rhs) {
    size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());
    size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());
    if (lhsBlockIndex == rhsBlockIndex)
      return lhs->isBeforeInBlock(rhs);
    return lhsBlockIndex < rhsBlockIndex;
  });
}

void MemorySlotPromoter::removeBlockingUses() {
  llvm::SmallVector<Operation *> usersToRemoveUses(
      llvm::make_first_range(info.userToBlockingUses));

  // Sort according to dominance.
  dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());

  llvm::SmallVector<Operation *> toErase;
  // List of all replaced values in the slot.
  llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;
  // Ops to visit with the `visitReplacedValues` method.
  llvm::SmallVector<PromotableOpInterface> toVisit;
  for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
    if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
      Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
      // If no reaching definition is known, this use is outside the reach of
      // the slot. The default value should thus be used.
      if (!reachingDef)
        reachingDef = getOrCreateDefaultValue();

      rewriter.setInsertionPointAfter(toPromote);
      if (toPromoteMemOp.removeBlockingUses(
              slot, info.userToBlockingUses[toPromote], rewriter, reachingDef,
              dataLayout) == DeletionKind::Delete)
        toErase.push_back(toPromote);
      if (toPromoteMemOp.storesTo(slot))
        if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
          replacedValuesList.push_back({toPromoteMemOp, replacedValue});
      continue;
    }

    auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
    rewriter.setInsertionPointAfter(toPromote);
    if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
                                          rewriter) == DeletionKind::Delete)
      toErase.push_back(toPromote);
    if (toPromoteBasic.requiresReplacedValues())
      toVisit.push_back(toPromoteBasic);
  }
  for (PromotableOpInterface op : toVisit) {
    rewriter.setInsertionPointAfter(op);
    op.visitReplacedValues(replacedValuesList, rewriter);
  }

  for (Operation *toEraseOp : toErase)
    rewriter.eraseOp(toEraseOp);

  assert(slot.ptr.use_empty() &&
         "after promotion, the slot pointer should not be used anymore");
}

void MemorySlotPromoter::promoteSlot() {
  computeReachingDefInRegion(slot.ptr.getParentRegion(),
                             getOrCreateDefaultValue());

  // Now that reaching definitions are known, remove all users.
  removeBlockingUses();

  // Update terminators in dead branches to forward the default value if they
  // are succeeded by a merge point.
  for (Block *mergePoint : info.mergePoints) {
    for (BlockOperand &use : mergePoint->getUses()) {
      auto user = cast<BranchOpInterface>(use.getOwner());
      SuccessorOperands succOperands =
          user.getSuccessorOperands(use.getOperandNumber());
      assert(succOperands.size() == mergePoint->getNumArguments() ||
             succOperands.size() + 1 == mergePoint->getNumArguments());
      if (succOperands.size() + 1 == mergePoint->getNumArguments())
        rewriter.modifyOpInPlace(
            user, [&]() { succOperands.append(getOrCreateDefaultValue()); });
    }
  }

  LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
                          << "\n");

  if (statistics.promotedAmount)
    (*statistics.promotedAmount)++;

  allocator.handlePromotionComplete(slot, defaultValue, rewriter);
}

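// A minimal usage sketch (assuming `allocOp` is an operation implementing
// PromotableAllocationOpInterface; the names below are illustrative only):
//
//   Operation *op = allocOp.getOperation();
//   IRRewriter rewriter(op->getContext());
//   DataLayout dataLayout = DataLayout::closest(op);
//   (void)tryToPromoteMemorySlots({allocOp}, rewriter, dataLayout,
//                                 Mem2RegStatistics());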
LogicalResult mlir::tryToPromoteMemorySlots(
    ArrayRef<PromotableAllocationOpInterface> allocators,
    RewriterBase &rewriter, const DataLayout &dataLayout,
    Mem2RegStatistics statistics) {
  bool promotedAny = false;

  for (PromotableAllocationOpInterface allocator : allocators) {
    for (MemorySlot slot : allocator.getPromotableSlots()) {
      if (slot.ptr.use_empty())
        continue;

      DominanceInfo dominance;
      MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
      std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
      if (info) {
        MemorySlotPromoter(slot, allocator, rewriter, dominance, dataLayout,
                           std::move(*info), statistics)
            .promoteSlot();
        promotedAny = true;
      }
    }
  }

  return success(promotedAny);
}

namespace {

struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
  using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;

  void runOnOperation() override {
    Operation *scopeOp = getOperation();

    Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};

    bool changed = false;

    for (Region &region : scopeOp->getRegions()) {
      if (region.getBlocks().empty())
        continue;

      OpBuilder builder(&region.front(), region.front().begin());
      IRRewriter rewriter(builder);

      // Promoting a slot can allow for further promotion of other slots, so
      // promotion is tried until no promotion succeeds.
      while (true) {
        SmallVector<PromotableAllocationOpInterface> allocators;
        // Build a list of allocators to attempt to promote the slots of.
        region.walk([&](PromotableAllocationOpInterface allocator) {
          allocators.emplace_back(allocator);
        });

        auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
        const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);

        // Attempt promoting until no promotion succeeds.
        if (failed(tryToPromoteMemorySlots(allocators, rewriter, dataLayout,
                                           statistics)))
          break;

        changed = true;
        getAnalysisManager().invalidate({});
      }
    }
    if (!changed)
      markAllAnalysesPreserved();
  }
};

} // namespace