//===- Mem2Reg.cpp - Promotes memory slots into values ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Mem2Reg.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"

namespace mlir {
#define GEN_PASS_DEF_MEM2REG
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "mem2reg"

using namespace mlir;

/// mem2reg
///
/// This pass turns unnecessary uses of automatically allocated memory slots
/// into direct Value-based operations. For example, it will replace a load
/// from a memory slot with the constant previously stored into that slot,
/// producing a direct use of the constant. In other words, given a memory
/// slot addressed by a non-aliased "pointer" Value, mem2reg removes all the
/// uses of that pointer.
///
/// Within a block, this is done by following the chain of stores and loads of
/// the slot and replacing the results of loads with the values previously
/// stored. If a load happens before any other store, a poison value is used
/// instead.
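///
/// As a rough illustration (schematic pseudo-IR; the concrete operations
/// depend on the dialect implementing the memory slot interfaces), within a
/// single block a sequence such as:
///
///   %slot = <allocate slot>
///   <store> %v, %slot
///   %loaded = <load> %slot
///   <use> %loaded
///
/// is rewritten so that `%v` is used directly; the load, the store, and
/// eventually the allocation itself can then be deleted.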
/// Control flow can create situations where a load could be replaced by
/// multiple possible stores depending on the control flow path taken. As a
/// result, this pass must introduce new block arguments in some blocks to
/// accommodate the multiple possible definitions. Each predecessor will
/// populate the block argument with the definition reached at its end. With
/// this, the value stored can be well defined at block boundaries, allowing
/// the propagation of replacement through blocks.
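///
/// Schematically (pseudo-IR again, with argument types elided), when two
/// predecessors store different values before branching to a common
/// successor:
///
///   ^bb1: <store> %a, %slot; br ^bb3
///   ^bb2: <store> %b, %slot; br ^bb3
///   ^bb3: %v = <load> %slot
///
/// promotion forwards the stored values through a new block argument:
///
///   ^bb1: br ^bb3(%a)
///   ^bb2: br ^bb3(%b)
///   ^bb3(%v): ...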
///
/// This pass computes this transformation in four main steps. The first two
/// steps are performed during an analysis phase that does not mutate IR.
///
/// The two steps of the analysis phase are the following:
/// - A first step computes the list of operations that transitively use the
/// memory slot we would like to promote. The purpose of this phase is to
/// identify which uses must be removed to promote the slot, either by rewiring
/// the user or deleting it. Naturally, direct uses of the slot must be removed.
/// Sometimes additional uses must also be removed: this is notably the case
/// when a direct user of the slot cannot rewire its use and must delete itself,
/// and thus must make its users no longer use it. If any of those uses cannot
/// be removed by their users in any way, promotion cannot continue: this is
/// decided at this step.
/// - A second step computes the list of blocks where a block argument will be
/// needed ("merge points") without mutating the IR. These blocks are the blocks
/// leading to a definition clash between two predecessors. Such blocks happen
/// to be the Iterated Dominance Frontier (IDF) of the set of blocks containing
/// a store, as they represent the points where a clear defining dominator stops
/// existing. Computing this information in advance makes it possible to verify
/// that the terminators that will forward values are capable of doing so
/// (inability to do so aborts promotion at this step).
///
/// At this point, promotion is guaranteed to happen, and the mutation phase can
/// begin with the following steps:
/// - A third step computes the reaching definition of the memory slot at each
/// blocking user. This is the core of the mem2reg algorithm, also known as
/// load-store forwarding. It analyses loads and stores and propagates which
/// value must be stored in the slot at each blocking user. This is achieved by
/// doing a depth-first walk of the dominator tree of the function. This is
/// sufficient because the reaching definition at the beginning of a block is
/// either its new block argument if it is a merge block, or the definition
/// reaching the end of its immediate dominator (parent in the dominator tree).
/// We can therefore propagate this information down the dominator tree to
/// proceed with renaming within blocks.
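///
/// For instance (a schematic example), if a block storing %a immediately
/// dominates two blocks, one storing %b and one containing no store, the
/// definition reaching the end of the first child is %b, while the second
/// child inherits %a from its immediate dominator.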
/// - The fourth and final step uses the reaching definitions to remove
/// blocking uses in topological order.
///
/// For further reading, chapter three of SSA-based Compiler Design [1]
/// showcases SSA construction, of which mem2reg is an adaptation.
///
/// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
/// Springer.

namespace {

using BlockingUsesMap =
    llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;

/// Information computed during promotion analysis used to perform actual
/// promotion.
struct MemorySlotPromotionInfo {
  /// Blocks for which at least two definitions of the slot values clash.
  SmallPtrSet<Block *, 8> mergePoints;
  /// Contains, for each operation, which uses must be eliminated by promotion.
  /// This is a DAG structure because if an operation must eliminate some of
  /// its uses, it is because the defining ops of the blocking uses requested
  /// it. The defining ops therefore must also have blocking uses or be the
  /// starting point of the blocking uses.
  BlockingUsesMap userToBlockingUses;
};

/// Computes information for basic slot promotion. This will check that direct
/// slot promotion can be performed, and provide the information to execute the
/// promotion. This does not mutate IR.
class MemorySlotPromotionAnalyzer {
public:
  MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance,
                              const DataLayout &dataLayout)
      : slot(slot), dominance(dominance), dataLayout(dataLayout) {}

  /// Computes the information for slot promotion if promotion is possible,
  /// returns nothing otherwise.
  std::optional<MemorySlotPromotionInfo> computeInfo();

private:
  /// Computes the transitive uses of the slot that block promotion. This finds
  /// uses that would block the promotion, checks that the operation has a
  /// solution to remove the blocking use, and potentially forwards the
  /// analysis if the operation needs further blocking uses resolved to resolve
  /// its own uses (typically, removing its users because it will delete itself
  /// to resolve its own blocking uses). This fails if one of the transitive
  /// users cannot remove a requested use, which prevents promotion.
  LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);

  /// Computes in which blocks the value stored in the slot is actually used,
  /// meaning blocks leading to a load. This method uses `definingBlocks`, the
  /// set of blocks containing a store to the slot (defining the value of the
  /// slot).
  SmallPtrSet<Block *, 16>
  computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);

  /// Computes the points at which multiple re-definitions of the slot's value
  /// (stores) may conflict.
  void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);

  /// Ensures predecessors of merge points can properly provide their current
  /// definition of the value stored in the slot to the merge point. This can
  /// notably be an issue if the terminator used does not have the ability to
  /// forward values through block operands.
  bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);

  MemorySlot slot;
  DominanceInfo &dominance;
  const DataLayout &dataLayout;
};

using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t>>;

/// The MemorySlotPromoter handles the state of promoting a memory slot. It
/// wraps a slot and its associated allocator. This will perform the mutation
/// of IR.
class MemorySlotPromoter {
public:
  MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
                     OpBuilder &builder, DominanceInfo &dominance,
                     const DataLayout &dataLayout, MemorySlotPromotionInfo info,
                     const Mem2RegStatistics &statistics,
                     BlockIndexCache &blockIndexCache);

  /// Actually promotes the slot by mutating IR. Promoting a slot DOES
  /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
  /// promotion info should NOT be performed in batches.
  /// Returns a promotable allocation op if a new allocator was created,
  /// nullopt otherwise.
  std::optional<PromotableAllocationOpInterface> promoteSlot();

private:
  /// Computes the reaching definition for all the operations that require
  /// promotion. `reachingDef` is the value the slot should contain at the
  /// beginning of the block. This method returns the reaching definition at
  /// the end of the block. It must only be called at most once per block.
  Value computeReachingDefInBlock(Block *block, Value reachingDef);

  /// Computes the reaching definition for all the operations that require
  /// promotion. `reachingDef` corresponds to the initial value the
  /// slot will contain before any write, typically a poison value.
  /// This method must only be called at most once per region.
  void computeReachingDefInRegion(Region *region, Value reachingDef);

  /// Removes the blocking uses of the slot, in topological order.
  void removeBlockingUses();

  /// Lazily constructs the default value representing the content of the slot
  /// when no store has been executed. This function may mutate IR.
  Value getOrCreateDefaultValue();

  MemorySlot slot;
  PromotableAllocationOpInterface allocator;
  OpBuilder &builder;
  /// Potentially non-initialized default value. Use `getOrCreateDefaultValue`
  /// to initialize it on demand.
  Value defaultValue;
  /// Contains the reaching definition at this operation. Reaching definitions
  /// are only computed for promotable memory operations with blocking uses.
  DenseMap<PromotableMemOpInterface, Value> reachingDefs;
  DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
  DominanceInfo &dominance;
  const DataLayout &dataLayout;
  MemorySlotPromotionInfo info;
  const Mem2RegStatistics &statistics;

  /// Shared cache of block indices of specific regions.
  BlockIndexCache &blockIndexCache;
};

} // namespace

MemorySlotPromoter::MemorySlotPromoter(
    MemorySlot slot, PromotableAllocationOpInterface allocator,
    OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout,
    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
    BlockIndexCache &blockIndexCache)
    : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
      dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
      blockIndexCache(blockIndexCache) {
#ifndef NDEBUG
  auto isResultOrNewBlockArgument = [&]() {
    if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
      return arg.getOwner()->getParentOp() == allocator;
    return slot.ptr.getDefiningOp() == allocator;
  };

  assert(isResultOrNewBlockArgument() &&
         "a slot must be a result of the allocator or an argument of the child "
         "regions of the allocator");
#endif // NDEBUG
}

Value MemorySlotPromoter::getOrCreateDefaultValue() {
  if (defaultValue)
    return defaultValue;

  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(slot.ptr.getParentBlock());
  return defaultValue = allocator.getDefaultValue(slot, builder);
}

LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
    BlockingUsesMap &userToBlockingUses) {
  // The promotion of an operation may require the promotion of further
  // operations (typically, removing operations that use an operation that must
  // delete itself). We thus need to start from the use of the slot pointer and
  // propagate further requests through the forward slice.

  // Because this pass currently only supports analysing the parent region of
  // the slot pointer, if a promotable memory op that needs promotion is within
  // a graph region, the slot may only be used in a graph region and should
  // therefore be ignored.
  Region *slotPtrRegion = slot.ptr.getParentRegion();
  auto slotPtrRegionOp =
      dyn_cast<RegionKindInterface>(slotPtrRegion->getParentOp());
  if (slotPtrRegionOp &&
      slotPtrRegionOp.getRegionKind(slotPtrRegion->getRegionNumber()) ==
          RegionKind::Graph)
    return failure();

  // First insert that all immediate users of the slot pointer must no longer
  // use it.
  for (OpOperand &use : slot.ptr.getUses()) {
    SmallPtrSet<OpOperand *, 4> &blockingUses =
        userToBlockingUses[use.getOwner()];
    blockingUses.insert(&use);
  }

  // Then, propagate the requirements for the removal of uses. The
  // topologically-sorted forward slice allows for all blocking uses of an
  // operation to have been computed before it is reached. Operations are
  // traversed in topological order of their uses, starting from the slot
  // pointer.
  SetVector<Operation *> forwardSlice;
  mlir::getForwardSlice(slot.ptr, &forwardSlice);
  for (Operation *user : forwardSlice) {
    // If the next operation has no blocking uses, everything is fine.
    auto it = userToBlockingUses.find(user);
    if (it == userToBlockingUses.end())
      continue;

    SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;

    SmallVector<OpOperand *> newBlockingUses;
    // If the operation decides it cannot deal with removing the blocking uses,
    // promotion must fail.
    if (auto promotable = dyn_cast<PromotableOpInterface>(user)) {
      if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
                                       dataLayout))
        return failure();
    } else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
      if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
                                       dataLayout))
        return failure();
    } else {
      // An operation that has blocking uses must be promoted. If it is not
      // promotable, promotion must fail.
      return failure();
    }

    // Then, register any new blocking uses for coming operations.
    for (OpOperand *blockingUse : newBlockingUses) {
      assert(llvm::is_contained(user->getResults(), blockingUse->get()));

      SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
          userToBlockingUses[blockingUse->getOwner()];
      newUserBlockingUseSet.insert(blockingUse);
    }
  }

  // Because this pass currently only supports analysing the parent region of
  // the slot pointer, if a promotable memory op that needs promotion is outside
  // of this region, promotion must fail because it will be impossible to
  // provide a valid `reachingDef` for it.
  for (auto &[toPromote, _] : userToBlockingUses)
    if (isa<PromotableMemOpInterface>(toPromote) &&
        toPromote->getParentRegion() != slot.ptr.getParentRegion())
      return failure();

  return success();
}

SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
    SmallPtrSetImpl<Block *> &definingBlocks) {
  SmallPtrSet<Block *, 16> liveIn;

  // The worklist contains blocks in which it is known that the slot value is
  // live-in. The further blocks where this value is live-in will be inferred
  // from these.
  SmallVector<Block *> liveInWorkList;

  // Blocks with a load before any other store to the slot are the starting
  // points of the analysis. The slot value is definitely live-in in those
  // blocks.
  SmallPtrSet<Block *, 16> visited;
  for (Operation *user : slot.ptr.getUsers()) {
    if (!visited.insert(user->getBlock()).second)
      continue;

    for (Operation &op : user->getBlock()->getOperations()) {
      if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
        // If this operation loads the slot, it is loading from it before
        // ever writing to it, so the value is live-in in this block.
        if (memOp.loadsFrom(slot)) {
          liveInWorkList.push_back(user->getBlock());
          break;
        }

        // If we store to the slot, further loads will see that value.
        // Because we did not meet any load before, the value is not live-in.
        if (memOp.storesTo(slot))
          break;
      }
    }
  }

  // The information is then propagated to the predecessors until a def site
  // (store) is found.
  while (!liveInWorkList.empty()) {
    Block *liveInBlock = liveInWorkList.pop_back_val();

    if (!liveIn.insert(liveInBlock).second)
      continue;

    // If a predecessor is a defining block, either:
    // - It has a load before its first store, in which case it is live-in but
    //   has already been processed in the initialisation step.
    // - It has a store before any load, in which case it is not live-in.
    // We can thus at this stage insert to the worklist only predecessors that
    // are not defining blocks.
    for (Block *pred : liveInBlock->getPredecessors())
      if (!definingBlocks.contains(pred))
        liveInWorkList.push_back(pred);
  }

  return liveIn;
}

using IDFCalculator = llvm::IDFCalculatorBase<Block, false>;
void MemorySlotPromotionAnalyzer::computeMergePoints(
    SmallPtrSetImpl<Block *> &mergePoints) {
  if (slot.ptr.getParentRegion()->hasOneBlock())
    return;

  IDFCalculator idfCalculator(dominance.getDomTree(slot.ptr.getParentRegion()));

  SmallPtrSet<Block *, 16> definingBlocks;
  for (Operation *user : slot.ptr.getUsers())
    if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
      if (storeOp.storesTo(slot))
        definingBlocks.insert(user->getBlock());

  idfCalculator.setDefiningBlocks(definingBlocks);

  SmallPtrSet<Block *, 16> liveIn = computeSlotLiveIn(definingBlocks);
  idfCalculator.setLiveInBlocks(liveIn);

  SmallVector<Block *> mergePointsVec;
  idfCalculator.calculate(mergePointsVec);

  mergePoints.insert_range(mergePointsVec);
}

bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
    SmallPtrSetImpl<Block *> &mergePoints) {
  for (Block *mergePoint : mergePoints)
    for (Block *pred : mergePoint->getPredecessors())
      if (!isa<BranchOpInterface>(pred->getTerminator()))
        return false;

  return true;
}

std::optional<MemorySlotPromotionInfo>
MemorySlotPromotionAnalyzer::computeInfo() {
  MemorySlotPromotionInfo info;

  // First, find the set of operations that will need to be changed for the
  // promotion to happen. These operations need to resolve some of their uses,
  // either by rewiring them or simply deleting themselves. If any of them
  // cannot find a way to resolve their blocking uses, we abort the promotion.
  if (failed(computeBlockingUses(info.userToBlockingUses)))
    return {};

  // Then, compute blocks in which two or more definitions of the allocated
  // variable may conflict. These blocks will need a new block argument to
  // accommodate this.
  computeMergePoints(info.mergePoints);

  // The slot can be promoted if the block arguments to be created can
  // actually be populated with values, which may not be possible depending
  // on their predecessors.
  if (!areMergePointsUsable(info.mergePoints))
    return {};

  return info;
}

Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
                                                    Value reachingDef) {
  SmallVector<Operation *> blockOps;
  for (Operation &op : block->getOperations())
    blockOps.push_back(&op);
  for (Operation *op : blockOps) {
    if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
      if (info.userToBlockingUses.contains(memOp))
        reachingDefs.insert({memOp, reachingDef});

      if (memOp.storesTo(slot)) {
        builder.setInsertionPointAfter(memOp);
        Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);
        assert(stored && "a memory operation storing to a slot must provide a "
                         "new definition of the slot");
        reachingDef = stored;
        replacedValuesMap[memOp] = stored;
      }
    }
  }

  return reachingDef;
}

void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
                                                    Value reachingDef) {
  assert(reachingDef && "expected an initial reaching def to be provided");
  if (region->hasOneBlock()) {
    computeReachingDefInBlock(&region->front(), reachingDef);
    return;
  }

  // A job for the dominator-tree DFS: a node of the tree, along with the
  // definition of the slot value reaching the beginning of its block.
  struct DfsJob {
    llvm::DomTreeNodeBase<Block> *block;
    Value reachingDef;
  };

  SmallVector<DfsJob> dfsStack;

  auto &domTree = dominance.getDomTree(slot.ptr.getParentRegion());

  dfsStack.emplace_back<DfsJob>(
      {domTree.getNode(&region->front()), reachingDef});

  while (!dfsStack.empty()) {
    DfsJob job = dfsStack.pop_back_val();
    Block *block = job.block->getBlock();

    if (info.mergePoints.contains(block)) {
      BlockArgument blockArgument =
          block->addArgument(slot.elemType, slot.ptr.getLoc());
      builder.setInsertionPointToStart(block);
      allocator.handleBlockArgument(slot, blockArgument, builder);
      job.reachingDef = blockArgument;

      if (statistics.newBlockArgumentAmount)
        (*statistics.newBlockArgumentAmount)++;
    }

    job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
    assert(job.reachingDef);

    if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
      for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
        if (info.mergePoints.contains(blockOperand.get())) {
          terminator.getSuccessorOperands(blockOperand.getOperandNumber())
              .append(job.reachingDef);
        }
      }
    }

    for (auto *child : job.block->children())
      dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
  }
}

/// Gets or creates a block index mapping for `region`.
static const DenseMap<Block *, size_t> &
getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) {
  auto [it, inserted] = blockIndexCache.try_emplace(region);
  if (!inserted)
    return it->second;

  DenseMap<Block *, size_t> &blockIndices = it->second;
  SetVector<Block *> topologicalOrder = getBlocksSortedByDominance(*region);
  for (auto [index, block] : llvm::enumerate(topologicalOrder))
    blockIndices[block] = index;
  return blockIndices;
}

/// Sorts `ops` according to dominance. Relies on the topological order of
/// basic blocks to get a deterministic ordering. Uses `blockIndexCache` to
/// avoid the potentially expensive recomputation of a block index map.
static void dominanceSort(SmallVector<Operation *> &ops, Region &region,
                          BlockIndexCache &blockIndexCache) {
  // Produce a topological block order and construct a map to lookup the
  // indices of blocks.
  const DenseMap<Block *, size_t> &topoBlockIndices =
      getOrCreateBlockIndices(blockIndexCache, &region);

  // Combining the topological order of the basic blocks together with block
  // internal operation order guarantees a deterministic, dominance respecting
  // order.
  llvm::sort(ops, [&](Operation *lhs, Operation *rhs) {
    size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());
    size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());
    if (lhsBlockIndex == rhsBlockIndex)
      return lhs->isBeforeInBlock(rhs);
    return lhsBlockIndex < rhsBlockIndex;
  });
}

void MemorySlotPromoter::removeBlockingUses() {
  llvm::SmallVector<Operation *> usersToRemoveUses(
      llvm::make_first_range(info.userToBlockingUses));

  // Sort according to dominance.
  dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent(),
                blockIndexCache);

  llvm::SmallVector<Operation *> toErase;
  // List of all replaced values in the slot.
  llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;
  // Ops to visit with the `visitReplacedValues` method.
  llvm::SmallVector<PromotableOpInterface> toVisit;
  for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
    if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
      Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
      // If no reaching definition is known, this use is outside the reach of
      // the slot. The default value should thus be used.
      if (!reachingDef)
        reachingDef = getOrCreateDefaultValue();

      builder.setInsertionPointAfter(toPromote);
      if (toPromoteMemOp.removeBlockingUses(
              slot, info.userToBlockingUses[toPromote], builder, reachingDef,
              dataLayout) == DeletionKind::Delete)
        toErase.push_back(toPromote);
      if (toPromoteMemOp.storesTo(slot))
        if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
          replacedValuesList.push_back({toPromoteMemOp, replacedValue});
      continue;
    }

    auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
    builder.setInsertionPointAfter(toPromote);
    if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
                                          builder) == DeletionKind::Delete)
      toErase.push_back(toPromote);
    if (toPromoteBasic.requiresReplacedValues())
      toVisit.push_back(toPromoteBasic);
  }
  for (PromotableOpInterface op : toVisit) {
    builder.setInsertionPointAfter(op);
    op.visitReplacedValues(replacedValuesList, builder);
  }

  for (Operation *toEraseOp : toErase)
    toEraseOp->erase();

  assert(slot.ptr.use_empty() &&
         "after promotion, the slot pointer should not be used anymore");
}

std::optional<PromotableAllocationOpInterface>
MemorySlotPromoter::promoteSlot() {
  computeReachingDefInRegion(slot.ptr.getParentRegion(),
                             getOrCreateDefaultValue());

  // Now that reaching definitions are known, remove all users.
  removeBlockingUses();

  // Update terminators in dead branches to forward the default value if they
  // are succeeded by a merge point.
  for (Block *mergePoint : info.mergePoints) {
    for (BlockOperand &use : mergePoint->getUses()) {
      auto user = cast<BranchOpInterface>(use.getOwner());
      SuccessorOperands succOperands =
          user.getSuccessorOperands(use.getOperandNumber());
      assert(succOperands.size() == mergePoint->getNumArguments() ||
             succOperands.size() + 1 == mergePoint->getNumArguments());
      if (succOperands.size() + 1 == mergePoint->getNumArguments())
        succOperands.append(getOrCreateDefaultValue());
    }
  }

  LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
                          << "\n");

  if (statistics.promotedAmount)
    (*statistics.promotedAmount)++;

  return allocator.handlePromotionComplete(slot, defaultValue, builder);
}

LogicalResult mlir::tryToPromoteMemorySlots(
    ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
    const DataLayout &dataLayout, DominanceInfo &dominance,
    Mem2RegStatistics statistics) {
  bool promotedAny = false;

  // A cache that stores deterministic block indices which are used to
  // determine a valid operation modification order. The block index maps are
  // computed lazily and cached to avoid expensive recomputation.
  BlockIndexCache blockIndexCache;

  SmallVector<PromotableAllocationOpInterface> workList(allocators);

  SmallVector<PromotableAllocationOpInterface> newWorkList;
  newWorkList.reserve(workList.size());
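  // Iterate until a fixed point is reached: promoting a slot invalidates the
  // promotion info of the remaining slots and may create new allocators, so
  // unpromoted candidates are re-analyzed on the next round.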
  while (true) {
    bool changesInThisRound = false;
    for (PromotableAllocationOpInterface allocator : workList) {
      bool changedAllocator = false;
      for (MemorySlot slot : allocator.getPromotableSlots()) {
        if (slot.ptr.use_empty())
          continue;

        MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
        std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
        if (info) {
          std::optional<PromotableAllocationOpInterface> newAllocator =
              MemorySlotPromoter(slot, allocator, builder, dominance,
                                 dataLayout, std::move(*info), statistics,
                                 blockIndexCache)
                  .promoteSlot();
          changedAllocator = true;
          // Add newly created allocators to the worklist for further
          // processing.
          if (newAllocator)
            newWorkList.push_back(*newAllocator);

          // A break is required, since promoting a slot may invalidate the
          // remaining slots of an allocator.
          break;
        }
      }
      if (!changedAllocator)
        newWorkList.push_back(allocator);
      changesInThisRound |= changedAllocator;
    }
    if (!changesInThisRound)
      break;
    promotedAny = true;

    // Swap the vector's backing memory and clear the entries in newWorkList
    // afterwards. This ensures that additional heap allocations can be
    // avoided.
    workList.swap(newWorkList);
    newWorkList.clear();
  }

  return success(promotedAny);
}

namespace {

struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
  using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;

  void runOnOperation() override {
    Operation *scopeOp = getOperation();

    Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};

    bool changed = false;

    auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
    auto &dominance = getAnalysis<DominanceInfo>();

    for (Region &region : scopeOp->getRegions()) {
      if (region.getBlocks().empty())
        continue;

      OpBuilder builder(&region.front(), region.front().begin());

      SmallVector<PromotableAllocationOpInterface> allocators;
      // Build a list of allocators to attempt to promote the slots of.
      region.walk([&](PromotableAllocationOpInterface allocator) {
        allocators.emplace_back(allocator);
      });

      // Attempt promoting as many of the slots as possible.
      if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout,
                                            dominance, statistics)))
        changed = true;
    }
    if (!changed)
      markAllAnalysesPreserved();
  }
};

} // namespace