//===- Mem2Reg.cpp - Promotes memory slots into values ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
9 | #include "mlir/Transforms/Mem2Reg.h" |
10 | #include "mlir/Analysis/DataLayoutAnalysis.h" |
11 | #include "mlir/Analysis/SliceAnalysis.h" |
12 | #include "mlir/IR/Builders.h" |
13 | #include "mlir/IR/Dominance.h" |
14 | #include "mlir/IR/PatternMatch.h" |
15 | #include "mlir/IR/Value.h" |
16 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
17 | #include "mlir/Interfaces/MemorySlotInterfaces.h" |
18 | #include "mlir/Transforms/Passes.h" |
19 | #include "mlir/Transforms/RegionUtils.h" |
20 | #include "llvm/ADT/STLExtras.h" |
21 | #include "llvm/Support/Casting.h" |
22 | #include "llvm/Support/GenericIteratedDominanceFrontier.h" |
23 | |
24 | namespace mlir { |
25 | #define GEN_PASS_DEF_MEM2REG |
26 | #include "mlir/Transforms/Passes.h.inc" |
27 | } // namespace mlir |
28 | |
29 | #define DEBUG_TYPE "mem2reg" |
30 | |
31 | using namespace mlir; |
32 | |
/// mem2reg
///
/// This pass turns unnecessary uses of automatically allocated memory slots
/// into direct Value-based operations. For example, it will simplify storing a
/// constant in a memory slot to immediately load it to a direct use of that
/// constant. In other words, given a memory slot addressed by a non-aliased
/// "pointer" Value, mem2reg removes all the uses of that pointer.
///
/// Within a block, this is done by following the chain of stores and loads of
/// the slot and replacing the results of loads with the values previously
/// stored. If a load happens before any other store, a poison value is used
/// instead.
///
/// Control flow can create situations where a load could be replaced by
/// multiple possible stores depending on the control flow path taken. As a
/// result, this pass must introduce new block arguments in some blocks to
/// accommodate for the multiple possible definitions. Each predecessor will
/// populate the block argument with the definition reached at its end. With
/// this, the value stored can be well defined at block boundaries, allowing
/// the propagation of replacement through blocks.
///
/// This pass computes this transformation in four main steps. The first two
/// steps are performed during an analysis phase that does not mutate IR.
///
/// The two steps of the analysis phase are the following:
/// - A first step computes the list of operations that transitively use the
/// memory slot we would like to promote. The purpose of this phase is to
/// identify which uses must be removed to promote the slot, either by rewiring
/// the user or deleting it. Naturally, direct uses of the slot must be removed.
/// Sometimes additional uses must also be removed: this is notably the case
/// when a direct user of the slot cannot rewire its use and must delete itself,
/// and thus must make its users no longer use it. If any of those uses cannot
/// be removed by their users in any way, promotion cannot continue: this is
/// decided at this step.
/// - A second step computes the list of blocks where a block argument will be
/// needed ("merge points") without mutating the IR. These blocks are the blocks
/// leading to a definition clash between two predecessors. Such blocks happen
/// to be the Iterated Dominance Frontier (IDF) of the set of blocks containing
/// a store, as they represent the point where a clear defining dominator stops
/// existing. Computing this information in advance allows making sure the
/// terminators that will forward values are capable of doing so (inability to
/// do so aborts promotion at this step).
///
/// At this point, promotion is guaranteed to happen, and the mutation phase can
/// begin with the following steps:
/// - A third step computes the reaching definition of the memory slot at each
/// blocking user. This is the core of the mem2reg algorithm, also known as
/// load-store forwarding. This analyses loads and stores and propagates which
/// value must be stored in the slot at each blocking user. This is achieved by
/// doing a depth-first walk of the dominator tree of the function. This is
/// sufficient because the reaching definition at the beginning of a block is
/// either its new block argument if it is a merge block, or the definition
/// reaching the end of its immediate dominator (parent in the dominator tree).
/// We can therefore propagate this information down the dominator tree to
/// proceed with renaming within blocks.
/// - The final fourth step uses the reaching definition to remove blocking uses
/// in topological order.
///
/// For further reading, chapter three of SSA-based Compiler Design [1]
/// showcases SSA construction, where mem2reg is an adaptation of the same
/// process.
///
/// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
///      Springer.
97 | |
98 | namespace { |
99 | |
100 | using BlockingUsesMap = |
101 | llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>; |
102 | |
103 | /// Information computed during promotion analysis used to perform actual |
104 | /// promotion. |
105 | struct MemorySlotPromotionInfo { |
106 | /// Blocks for which at least two definitions of the slot values clash. |
107 | SmallPtrSet<Block *, 8> mergePoints; |
108 | /// Contains, for each operation, which uses must be eliminated by promotion. |
109 | /// This is a DAG structure because if an operation must eliminate some of |
110 | /// its uses, it is because the defining ops of the blocking uses requested |
111 | /// it. The defining ops therefore must also have blocking uses or be the |
112 | /// starting point of the blocking uses. |
113 | BlockingUsesMap userToBlockingUses; |
114 | }; |
115 | |
116 | /// Computes information for basic slot promotion. This will check that direct |
117 | /// slot promotion can be performed, and provide the information to execute the |
118 | /// promotion. This does not mutate IR. |
119 | class MemorySlotPromotionAnalyzer { |
120 | public: |
121 | MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance, |
122 | const DataLayout &dataLayout) |
123 | : slot(slot), dominance(dominance), dataLayout(dataLayout) {} |
124 | |
125 | /// Computes the information for slot promotion if promotion is possible, |
126 | /// returns nothing otherwise. |
127 | std::optional<MemorySlotPromotionInfo> computeInfo(); |
128 | |
129 | private: |
130 | /// Computes the transitive uses of the slot that block promotion. This finds |
131 | /// uses that would block the promotion, checks that the operation has a |
132 | /// solution to remove the blocking use, and potentially forwards the analysis |
133 | /// if the operation needs further blocking uses resolved to resolve its own |
134 | /// uses (typically, removing its users because it will delete itself to |
135 | /// resolve its own blocking uses). This will fail if one of the transitive |
136 | /// users cannot remove a requested use, and should prevent promotion. |
137 | LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses); |
138 | |
139 | /// Computes in which blocks the value stored in the slot is actually used, |
140 | /// meaning blocks leading to a load. This method uses `definingBlocks`, the |
141 | /// set of blocks containing a store to the slot (defining the value of the |
142 | /// slot). |
143 | SmallPtrSet<Block *, 16> |
144 | computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks); |
145 | |
146 | /// Computes the points in which multiple re-definitions of the slot's value |
147 | /// (stores) may conflict. |
148 | void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints); |
149 | |
150 | /// Ensures predecessors of merge points can properly provide their current |
151 | /// definition of the value stored in the slot to the merge point. This can |
152 | /// notably be an issue if the terminator used does not have the ability to |
153 | /// forward values through block operands. |
154 | bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints); |
155 | |
156 | MemorySlot slot; |
157 | DominanceInfo &dominance; |
158 | const DataLayout &dataLayout; |
159 | }; |
160 | |
161 | /// The MemorySlotPromoter handles the state of promoting a memory slot. It |
162 | /// wraps a slot and its associated allocator. This will perform the mutation of |
163 | /// IR. |
164 | class MemorySlotPromoter { |
165 | public: |
166 | MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, |
167 | RewriterBase &rewriter, DominanceInfo &dominance, |
168 | const DataLayout &dataLayout, MemorySlotPromotionInfo info, |
169 | const Mem2RegStatistics &statistics); |
170 | |
171 | /// Actually promotes the slot by mutating IR. Promoting a slot DOES |
172 | /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of |
173 | /// promotion info should NOT be performed in batches. |
174 | void promoteSlot(); |
175 | |
176 | private: |
177 | /// Computes the reaching definition for all the operations that require |
178 | /// promotion. `reachingDef` is the value the slot should contain at the |
179 | /// beginning of the block. This method returns the reached definition at the |
180 | /// end of the block. This method must only be called at most once per block. |
181 | Value computeReachingDefInBlock(Block *block, Value reachingDef); |
182 | |
183 | /// Computes the reaching definition for all the operations that require |
184 | /// promotion. `reachingDef` corresponds to the initial value the |
185 | /// slot will contain before any write, typically a poison value. |
186 | /// This method must only be called at most once per region. |
187 | void computeReachingDefInRegion(Region *region, Value reachingDef); |
188 | |
189 | /// Removes the blocking uses of the slot, in topological order. |
190 | void removeBlockingUses(); |
191 | |
192 | /// Lazily-constructed default value representing the content of the slot when |
193 | /// no store has been executed. This function may mutate IR. |
194 | Value getOrCreateDefaultValue(); |
195 | |
196 | MemorySlot slot; |
197 | PromotableAllocationOpInterface allocator; |
198 | RewriterBase &rewriter; |
199 | /// Potentially non-initialized default value. Use `getOrCreateDefaultValue` |
200 | /// to initialize it on demand. |
201 | Value defaultValue; |
202 | /// Contains the reaching definition at this operation. Reaching definitions |
203 | /// are only computed for promotable memory operations with blocking uses. |
204 | DenseMap<PromotableMemOpInterface, Value> reachingDefs; |
205 | DenseMap<PromotableMemOpInterface, Value> replacedValuesMap; |
206 | DominanceInfo &dominance; |
207 | const DataLayout &dataLayout; |
208 | MemorySlotPromotionInfo info; |
209 | const Mem2RegStatistics &statistics; |
210 | }; |
211 | |
212 | } // namespace |
213 | |
214 | MemorySlotPromoter::MemorySlotPromoter( |
215 | MemorySlot slot, PromotableAllocationOpInterface allocator, |
216 | RewriterBase &rewriter, DominanceInfo &dominance, |
217 | const DataLayout &dataLayout, MemorySlotPromotionInfo info, |
218 | const Mem2RegStatistics &statistics) |
219 | : slot(slot), allocator(allocator), rewriter(rewriter), |
220 | dominance(dominance), dataLayout(dataLayout), info(std::move(info)), |
221 | statistics(statistics) { |
222 | #ifndef NDEBUG |
223 | auto isResultOrNewBlockArgument = [&]() { |
224 | if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr)) |
225 | return arg.getOwner()->getParentOp() == allocator; |
226 | return slot.ptr.getDefiningOp() == allocator; |
227 | }; |
228 | |
229 | assert(isResultOrNewBlockArgument() && |
230 | "a slot must be a result of the allocator or an argument of the child " |
231 | "regions of the allocator" ); |
232 | #endif // NDEBUG |
233 | } |
234 | |
235 | Value MemorySlotPromoter::getOrCreateDefaultValue() { |
236 | if (defaultValue) |
237 | return defaultValue; |
238 | |
239 | RewriterBase::InsertionGuard guard(rewriter); |
240 | rewriter.setInsertionPointToStart(slot.ptr.getParentBlock()); |
241 | return defaultValue = allocator.getDefaultValue(slot, rewriter); |
242 | } |
243 | |
244 | LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses( |
245 | BlockingUsesMap &userToBlockingUses) { |
246 | // The promotion of an operation may require the promotion of further |
247 | // operations (typically, removing operations that use an operation that must |
248 | // delete itself). We thus need to start from the use of the slot pointer and |
249 | // propagate further requests through the forward slice. |
250 | |
251 | // First insert that all immediate users of the slot pointer must no longer |
252 | // use it. |
253 | for (OpOperand &use : slot.ptr.getUses()) { |
254 | SmallPtrSet<OpOperand *, 4> &blockingUses = |
255 | userToBlockingUses[use.getOwner()]; |
256 | blockingUses.insert(Ptr: &use); |
257 | } |
258 | |
259 | // Then, propagate the requirements for the removal of uses. The |
260 | // topologically-sorted forward slice allows for all blocking uses of an |
261 | // operation to have been computed before it is reached. Operations are |
262 | // traversed in topological order of their uses, starting from the slot |
263 | // pointer. |
264 | SetVector<Operation *> forwardSlice; |
265 | mlir::getForwardSlice(root: slot.ptr, forwardSlice: &forwardSlice); |
266 | for (Operation *user : forwardSlice) { |
267 | // If the next operation has no blocking uses, everything is fine. |
268 | if (!userToBlockingUses.contains(user)) |
269 | continue; |
270 | |
271 | SmallPtrSet<OpOperand *, 4> &blockingUses = userToBlockingUses[user]; |
272 | |
273 | SmallVector<OpOperand *> newBlockingUses; |
274 | // If the operation decides it cannot deal with removing the blocking uses, |
275 | // promotion must fail. |
276 | if (auto promotable = dyn_cast<PromotableOpInterface>(user)) { |
277 | if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, |
278 | dataLayout)) |
279 | return failure(); |
280 | } else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) { |
281 | if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses, |
282 | dataLayout)) |
283 | return failure(); |
284 | } else { |
285 | // An operation that has blocking uses must be promoted. If it is not |
286 | // promotable, promotion must fail. |
287 | return failure(); |
288 | } |
289 | |
290 | // Then, register any new blocking uses for coming operations. |
291 | for (OpOperand *blockingUse : newBlockingUses) { |
292 | assert(llvm::is_contained(user->getResults(), blockingUse->get())); |
293 | |
294 | SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet = |
295 | userToBlockingUses[blockingUse->getOwner()]; |
296 | newUserBlockingUseSet.insert(Ptr: blockingUse); |
297 | } |
298 | } |
299 | |
300 | // Because this pass currently only supports analysing the parent region of |
301 | // the slot pointer, if a promotable memory op that needs promotion is outside |
302 | // of this region, promotion must fail because it will be impossible to |
303 | // provide a valid `reachingDef` for it. |
304 | for (auto &[toPromote, _] : userToBlockingUses) |
305 | if (isa<PromotableMemOpInterface>(toPromote) && |
306 | toPromote->getParentRegion() != slot.ptr.getParentRegion()) |
307 | return failure(); |
308 | |
309 | return success(); |
310 | } |
311 | |
312 | SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn( |
313 | SmallPtrSetImpl<Block *> &definingBlocks) { |
314 | SmallPtrSet<Block *, 16> liveIn; |
315 | |
316 | // The worklist contains blocks in which it is known that the slot value is |
317 | // live-in. The further blocks where this value is live-in will be inferred |
318 | // from these. |
319 | SmallVector<Block *> liveInWorkList; |
320 | |
321 | // Blocks with a load before any other store to the slot are the starting |
322 | // points of the analysis. The slot value is definitely live-in in those |
323 | // blocks. |
324 | SmallPtrSet<Block *, 16> visited; |
325 | for (Operation *user : slot.ptr.getUsers()) { |
326 | if (visited.contains(Ptr: user->getBlock())) |
327 | continue; |
328 | visited.insert(Ptr: user->getBlock()); |
329 | |
330 | for (Operation &op : user->getBlock()->getOperations()) { |
331 | if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) { |
332 | // If this operation loads the slot, it is loading from it before |
333 | // ever writing to it, so the value is live-in in this block. |
334 | if (memOp.loadsFrom(slot)) { |
335 | liveInWorkList.push_back(Elt: user->getBlock()); |
336 | break; |
337 | } |
338 | |
339 | // If we store to the slot, further loads will see that value. |
340 | // Because we did not meet any load before, the value is not live-in. |
341 | if (memOp.storesTo(slot)) |
342 | break; |
343 | } |
344 | } |
345 | } |
346 | |
347 | // The information is then propagated to the predecessors until a def site |
348 | // (store) is found. |
349 | while (!liveInWorkList.empty()) { |
350 | Block *liveInBlock = liveInWorkList.pop_back_val(); |
351 | |
352 | if (!liveIn.insert(Ptr: liveInBlock).second) |
353 | continue; |
354 | |
355 | // If a predecessor is a defining block, either: |
356 | // - It has a load before its first store, in which case it is live-in but |
357 | // has already been processed in the initialisation step. |
358 | // - It has a store before any load, in which case it is not live-in. |
359 | // We can thus at this stage insert to the worklist only predecessors that |
360 | // are not defining blocks. |
361 | for (Block *pred : liveInBlock->getPredecessors()) |
362 | if (!definingBlocks.contains(Ptr: pred)) |
363 | liveInWorkList.push_back(Elt: pred); |
364 | } |
365 | |
366 | return liveIn; |
367 | } |
368 | |
369 | using IDFCalculator = llvm::IDFCalculatorBase<Block, false>; |
370 | void MemorySlotPromotionAnalyzer::computeMergePoints( |
371 | SmallPtrSetImpl<Block *> &mergePoints) { |
372 | if (slot.ptr.getParentRegion()->hasOneBlock()) |
373 | return; |
374 | |
375 | IDFCalculator idfCalculator(dominance.getDomTree(region: slot.ptr.getParentRegion())); |
376 | |
377 | SmallPtrSet<Block *, 16> definingBlocks; |
378 | for (Operation *user : slot.ptr.getUsers()) |
379 | if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user)) |
380 | if (storeOp.storesTo(slot)) |
381 | definingBlocks.insert(Ptr: user->getBlock()); |
382 | |
383 | idfCalculator.setDefiningBlocks(definingBlocks); |
384 | |
385 | SmallPtrSet<Block *, 16> liveIn = computeSlotLiveIn(definingBlocks); |
386 | idfCalculator.setLiveInBlocks(liveIn); |
387 | |
388 | SmallVector<Block *> mergePointsVec; |
389 | idfCalculator.calculate(IDFBlocks&: mergePointsVec); |
390 | |
391 | mergePoints.insert(I: mergePointsVec.begin(), E: mergePointsVec.end()); |
392 | } |
393 | |
394 | bool MemorySlotPromotionAnalyzer::areMergePointsUsable( |
395 | SmallPtrSetImpl<Block *> &mergePoints) { |
396 | for (Block *mergePoint : mergePoints) |
397 | for (Block *pred : mergePoint->getPredecessors()) |
398 | if (!isa<BranchOpInterface>(Val: pred->getTerminator())) |
399 | return false; |
400 | |
401 | return true; |
402 | } |
403 | |
404 | std::optional<MemorySlotPromotionInfo> |
405 | MemorySlotPromotionAnalyzer::computeInfo() { |
406 | MemorySlotPromotionInfo info; |
407 | |
408 | // First, find the set of operations that will need to be changed for the |
409 | // promotion to happen. These operations need to resolve some of their uses, |
410 | // either by rewiring them or simply deleting themselves. If any of them |
411 | // cannot find a way to resolve their blocking uses, we abort the promotion. |
412 | if (failed(computeBlockingUses(info.userToBlockingUses))) |
413 | return {}; |
414 | |
415 | // Then, compute blocks in which two or more definitions of the allocated |
416 | // variable may conflict. These blocks will need a new block argument to |
417 | // accommodate this. |
418 | computeMergePoints(mergePoints&: info.mergePoints); |
419 | |
420 | // The slot can be promoted if the block arguments to be created can |
421 | // actually be populated with values, which may not be possible depending |
422 | // on their predecessors. |
423 | if (!areMergePointsUsable(mergePoints&: info.mergePoints)) |
424 | return {}; |
425 | |
426 | return info; |
427 | } |
428 | |
429 | Value MemorySlotPromoter::computeReachingDefInBlock(Block *block, |
430 | Value reachingDef) { |
431 | SmallVector<Operation *> blockOps; |
432 | for (Operation &op : block->getOperations()) |
433 | blockOps.push_back(Elt: &op); |
434 | for (Operation *op : blockOps) { |
435 | if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) { |
436 | if (info.userToBlockingUses.contains(memOp)) |
437 | reachingDefs.insert({memOp, reachingDef}); |
438 | |
439 | if (memOp.storesTo(slot)) { |
440 | rewriter.setInsertionPointAfter(memOp); |
441 | Value stored = memOp.getStored(slot, rewriter, reachingDef, dataLayout); |
442 | assert(stored && "a memory operation storing to a slot must provide a " |
443 | "new definition of the slot" ); |
444 | reachingDef = stored; |
445 | replacedValuesMap[memOp] = stored; |
446 | } |
447 | } |
448 | } |
449 | |
450 | return reachingDef; |
451 | } |
452 | |
453 | void MemorySlotPromoter::computeReachingDefInRegion(Region *region, |
454 | Value reachingDef) { |
455 | assert(reachingDef && "expected an initial reaching def to be provided" ); |
456 | if (region->hasOneBlock()) { |
457 | computeReachingDefInBlock(block: ®ion->front(), reachingDef); |
458 | return; |
459 | } |
460 | |
461 | struct DfsJob { |
462 | llvm::DomTreeNodeBase<Block> *block; |
463 | Value reachingDef; |
464 | }; |
465 | |
466 | SmallVector<DfsJob> dfsStack; |
467 | |
468 | auto &domTree = dominance.getDomTree(region: slot.ptr.getParentRegion()); |
469 | |
470 | dfsStack.emplace_back<DfsJob>( |
471 | Args: {.block: domTree.getNode(BB: ®ion->front()), .reachingDef: reachingDef}); |
472 | |
473 | while (!dfsStack.empty()) { |
474 | DfsJob job = dfsStack.pop_back_val(); |
475 | Block *block = job.block->getBlock(); |
476 | |
477 | if (info.mergePoints.contains(block)) { |
478 | // If the block is a merge point, we need to add a block argument to hold |
479 | // the selected reaching definition. This has to be a bit complicated |
480 | // because of RewriterBase limitations: we need to create a new block with |
481 | // the extra block argument, move the content of the block to the new |
482 | // block, and replace the block with the new block in the merge point set. |
483 | SmallVector<Type> argTypes; |
484 | SmallVector<Location> argLocs; |
485 | for (BlockArgument arg : block->getArguments()) { |
486 | argTypes.push_back(Elt: arg.getType()); |
487 | argLocs.push_back(Elt: arg.getLoc()); |
488 | } |
489 | argTypes.push_back(Elt: slot.elemType); |
490 | argLocs.push_back(Elt: slot.ptr.getLoc()); |
491 | Block *newBlock = rewriter.createBlock(insertBefore: block, argTypes, locs: argLocs); |
492 | |
493 | info.mergePoints.erase(block); |
494 | info.mergePoints.insert(newBlock); |
495 | |
496 | rewriter.replaceAllUsesWith(from: block, to: newBlock); |
497 | rewriter.mergeBlocks(source: block, dest: newBlock, |
498 | argValues: newBlock->getArguments().drop_back()); |
499 | |
500 | block = newBlock; |
501 | |
502 | BlockArgument blockArgument = block->getArguments().back(); |
503 | rewriter.setInsertionPointToStart(block); |
504 | allocator.handleBlockArgument(slot, blockArgument, rewriter); |
505 | job.reachingDef = blockArgument; |
506 | |
507 | if (statistics.newBlockArgumentAmount) |
508 | (*statistics.newBlockArgumentAmount)++; |
509 | } |
510 | |
511 | job.reachingDef = computeReachingDefInBlock(block, reachingDef: job.reachingDef); |
512 | assert(job.reachingDef); |
513 | |
514 | if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) { |
515 | for (BlockOperand &blockOperand : terminator->getBlockOperands()) { |
516 | if (info.mergePoints.contains(blockOperand.get())) { |
517 | rewriter.modifyOpInPlace(terminator, [&]() { |
518 | terminator.getSuccessorOperands(blockOperand.getOperandNumber()) |
519 | .append(job.reachingDef); |
520 | }); |
521 | } |
522 | } |
523 | } |
524 | |
525 | for (auto *child : job.block->children()) |
526 | dfsStack.emplace_back<DfsJob>(Args: {.block: child, .reachingDef: job.reachingDef}); |
527 | } |
528 | } |
529 | |
530 | /// Sorts `ops` according to dominance. Relies on the topological order of basic |
531 | /// blocks to get a deterministic ordering. |
532 | static void dominanceSort(SmallVector<Operation *> &ops, Region ®ion) { |
533 | // Produce a topological block order and construct a map to lookup the indices |
534 | // of blocks. |
535 | DenseMap<Block *, size_t> topoBlockIndices; |
536 | SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(region); |
537 | for (auto [index, block] : llvm::enumerate(First&: topologicalOrder)) |
538 | topoBlockIndices[block] = index; |
539 | |
540 | // Combining the topological order of the basic blocks together with block |
541 | // internal operation order guarantees a deterministic, dominance respecting |
542 | // order. |
543 | llvm::sort(C&: ops, Comp: [&](Operation *lhs, Operation *rhs) { |
544 | size_t lhsBlockIndex = topoBlockIndices.at(Val: lhs->getBlock()); |
545 | size_t rhsBlockIndex = topoBlockIndices.at(Val: rhs->getBlock()); |
546 | if (lhsBlockIndex == rhsBlockIndex) |
547 | return lhs->isBeforeInBlock(other: rhs); |
548 | return lhsBlockIndex < rhsBlockIndex; |
549 | }); |
550 | } |
551 | |
552 | void MemorySlotPromoter::removeBlockingUses() { |
553 | llvm::SmallVector<Operation *> usersToRemoveUses( |
554 | llvm::make_first_range(info.userToBlockingUses)); |
555 | |
556 | // Sort according to dominance. |
557 | dominanceSort(ops&: usersToRemoveUses, region&: *slot.ptr.getParentBlock()->getParent()); |
558 | |
559 | llvm::SmallVector<Operation *> toErase; |
560 | // List of all replaced values in the slot. |
561 | llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList; |
562 | // Ops to visit with the `visitReplacedValues` method. |
563 | llvm::SmallVector<PromotableOpInterface> toVisit; |
564 | for (Operation *toPromote : llvm::reverse(C&: usersToRemoveUses)) { |
565 | if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) { |
566 | Value reachingDef = reachingDefs.lookup(toPromoteMemOp); |
567 | // If no reaching definition is known, this use is outside the reach of |
568 | // the slot. The default value should thus be used. |
569 | if (!reachingDef) |
570 | reachingDef = getOrCreateDefaultValue(); |
571 | |
572 | rewriter.setInsertionPointAfter(toPromote); |
573 | if (toPromoteMemOp.removeBlockingUses( |
574 | slot, info.userToBlockingUses[toPromote], rewriter, reachingDef, |
575 | dataLayout) == DeletionKind::Delete) |
576 | toErase.push_back(Elt: toPromote); |
577 | if (toPromoteMemOp.storesTo(slot)) |
578 | if (Value replacedValue = replacedValuesMap[toPromoteMemOp]) |
579 | replacedValuesList.push_back(Elt: {toPromoteMemOp, replacedValue}); |
580 | continue; |
581 | } |
582 | |
583 | auto toPromoteBasic = cast<PromotableOpInterface>(toPromote); |
584 | rewriter.setInsertionPointAfter(toPromote); |
585 | if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote], |
586 | rewriter) == DeletionKind::Delete) |
587 | toErase.push_back(Elt: toPromote); |
588 | if (toPromoteBasic.requiresReplacedValues()) |
589 | toVisit.push_back(toPromoteBasic); |
590 | } |
591 | for (PromotableOpInterface op : toVisit) { |
592 | rewriter.setInsertionPointAfter(op); |
593 | op.visitReplacedValues(replacedValuesList, rewriter); |
594 | } |
595 | |
596 | for (Operation *toEraseOp : toErase) |
597 | rewriter.eraseOp(op: toEraseOp); |
598 | |
599 | assert(slot.ptr.use_empty() && |
600 | "after promotion, the slot pointer should not be used anymore" ); |
601 | } |
602 | |
603 | void MemorySlotPromoter::promoteSlot() { |
604 | computeReachingDefInRegion(region: slot.ptr.getParentRegion(), |
605 | reachingDef: getOrCreateDefaultValue()); |
606 | |
607 | // Now that reaching definitions are known, remove all users. |
608 | removeBlockingUses(); |
609 | |
610 | // Update terminators in dead branches to forward default if they are |
611 | // succeeded by a merge points. |
612 | for (Block *mergePoint : info.mergePoints) { |
613 | for (BlockOperand &use : mergePoint->getUses()) { |
614 | auto user = cast<BranchOpInterface>(use.getOwner()); |
615 | SuccessorOperands succOperands = |
616 | user.getSuccessorOperands(use.getOperandNumber()); |
617 | assert(succOperands.size() == mergePoint->getNumArguments() || |
618 | succOperands.size() + 1 == mergePoint->getNumArguments()); |
619 | if (succOperands.size() + 1 == mergePoint->getNumArguments()) |
620 | rewriter.modifyOpInPlace( |
621 | user, [&]() { succOperands.append(getOrCreateDefaultValue()); }); |
622 | } |
623 | } |
624 | |
625 | LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr |
626 | << "\n" ); |
627 | |
628 | if (statistics.promotedAmount) |
629 | (*statistics.promotedAmount)++; |
630 | |
631 | allocator.handlePromotionComplete(slot, defaultValue, rewriter); |
632 | } |
633 | |
634 | LogicalResult mlir::tryToPromoteMemorySlots( |
635 | ArrayRef<PromotableAllocationOpInterface> allocators, |
636 | RewriterBase &rewriter, const DataLayout &dataLayout, |
637 | Mem2RegStatistics statistics) { |
638 | bool promotedAny = false; |
639 | |
640 | for (PromotableAllocationOpInterface allocator : allocators) { |
641 | for (MemorySlot slot : allocator.getPromotableSlots()) { |
642 | if (slot.ptr.use_empty()) |
643 | continue; |
644 | |
645 | DominanceInfo dominance; |
646 | MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout); |
647 | std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo(); |
648 | if (info) { |
649 | MemorySlotPromoter(slot, allocator, rewriter, dominance, dataLayout, |
650 | std::move(*info), statistics) |
651 | .promoteSlot(); |
652 | promotedAny = true; |
653 | } |
654 | } |
655 | } |
656 | |
657 | return success(isSuccess: promotedAny); |
658 | } |
659 | |
660 | namespace { |
661 | |
662 | struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> { |
663 | using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase; |
664 | |
665 | void runOnOperation() override { |
666 | Operation *scopeOp = getOperation(); |
667 | |
668 | Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount}; |
669 | |
670 | bool changed = false; |
671 | |
672 | for (Region ®ion : scopeOp->getRegions()) { |
673 | if (region.getBlocks().empty()) |
674 | continue; |
675 | |
676 | OpBuilder builder(®ion.front(), region.front().begin()); |
677 | IRRewriter rewriter(builder); |
678 | |
679 | // Promoting a slot can allow for further promotion of other slots, |
680 | // promotion is tried until no promotion succeeds. |
681 | while (true) { |
682 | SmallVector<PromotableAllocationOpInterface> allocators; |
683 | // Build a list of allocators to attempt to promote the slots of. |
684 | region.walk([&](PromotableAllocationOpInterface allocator) { |
685 | allocators.emplace_back(allocator); |
686 | }); |
687 | |
688 | auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); |
689 | const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp); |
690 | |
691 | // Attempt promoting until no promotion succeeds. |
692 | if (failed(tryToPromoteMemorySlots(allocators, rewriter, dataLayout, |
693 | statistics))) |
694 | break; |
695 | |
696 | changed = true; |
697 | getAnalysisManager().invalidate({}); |
698 | } |
699 | } |
700 | if (!changed) |
701 | markAllAnalysesPreserved(); |
702 | } |
703 | }; |
704 | |
705 | } // namespace |
706 | |