| 1 | //===- Mem2Reg.cpp - Promotes memory slots into values ----------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Transforms/Mem2Reg.h" |
| 10 | #include "mlir/Analysis/DataLayoutAnalysis.h" |
| 11 | #include "mlir/Analysis/SliceAnalysis.h" |
| 12 | #include "mlir/Analysis/TopologicalSortUtils.h" |
| 13 | #include "mlir/IR/Builders.h" |
| 14 | #include "mlir/IR/Dominance.h" |
| 15 | #include "mlir/IR/PatternMatch.h" |
| 16 | #include "mlir/IR/RegionKindInterface.h" |
| 17 | #include "mlir/IR/Value.h" |
| 18 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
| 19 | #include "mlir/Interfaces/MemorySlotInterfaces.h" |
| 20 | #include "mlir/Transforms/Passes.h" |
| 21 | #include "llvm/ADT/STLExtras.h" |
| 22 | #include "llvm/Support/GenericIteratedDominanceFrontier.h" |
| 23 | |
| 24 | namespace mlir { |
| 25 | #define GEN_PASS_DEF_MEM2REG |
| 26 | #include "mlir/Transforms/Passes.h.inc" |
| 27 | } // namespace mlir |
| 28 | |
| 29 | #define DEBUG_TYPE "mem2reg" |
| 30 | |
| 31 | using namespace mlir; |
| 32 | |
| 33 | /// mem2reg |
| 34 | /// |
| 35 | /// This pass turns unnecessary uses of automatically allocated memory slots |
| 36 | /// into direct Value-based operations. For example, it will simplify storing a |
| 37 | /// constant in a memory slot to immediately load it to a direct use of that |
| 38 | /// constant. In other words, given a memory slot addressed by a non-aliased |
| 39 | /// "pointer" Value, mem2reg removes all the uses of that pointer. |
| 40 | /// |
| 41 | /// Within a block, this is done by following the chain of stores and loads of |
| 42 | /// the slot and replacing the results of loads with the values previously |
| 43 | /// stored. If a load happens before any other store, a poison value is used |
| 44 | /// instead. |
| 45 | /// |
| 46 | /// Control flow can create situations where a load could be replaced by |
| 47 | /// multiple possible stores depending on the control flow path taken. As a |
| 48 | /// result, this pass must introduce new block arguments in some blocks to |
| 49 | /// accommodate the multiple possible definitions. Each predecessor will |
| 50 | /// populate the block argument with the definition reached at its end. With |
| 51 | /// this, the value stored can be well defined at block boundaries, allowing |
| 52 | /// the propagation of replacement through blocks. |
| 53 | /// |
| 54 | /// This pass computes this transformation in four main steps. The two first |
| 55 | /// steps are performed during an analysis phase that does not mutate IR. |
| 56 | /// |
| 57 | /// The two steps of the analysis phase are the following: |
| 58 | /// - A first step computes the list of operations that transitively use the |
| 59 | /// memory slot we would like to promote. The purpose of this phase is to |
| 60 | /// identify which uses must be removed to promote the slot, either by rewiring |
| 61 | /// the user or deleting it. Naturally, direct uses of the slot must be removed. |
| 62 | /// Sometimes additional uses must also be removed: this is notably the case |
| 63 | /// when a direct user of the slot cannot rewire its use and must delete itself, |
| 64 | /// and thus must make its users no longer use it. If any of those uses cannot |
| 65 | /// be removed by their users in any way, promotion cannot continue: this is |
| 66 | /// decided at this step. |
| 67 | /// - A second step computes the list of blocks where a block argument will be |
| 68 | /// needed ("merge points") without mutating the IR. These blocks are the blocks |
| 69 | /// leading to a definition clash between two predecessors. Such blocks happen |
| 70 | /// to be the Iterated Dominance Frontier (IDF) of the set of blocks containing |
| 71 | /// a store, as they represent the point where a clear defining dominator stops |
| 72 | /// existing. Computing this information in advance allows making sure the |
| 73 | /// terminators that will forward values are capable of doing so (inability to |
| 74 | /// do so aborts promotion at this step). |
| 75 | /// |
| 76 | /// At this point, promotion is guaranteed to happen, and the mutation phase can |
| 77 | /// begin with the following steps: |
| 78 | /// - A third step computes the reaching definition of the memory slot at each |
| 79 | /// blocking user. This is the core of the mem2reg algorithm, also known as |
| 80 | /// load-store forwarding. This analyses loads and stores and propagates which |
| 81 | /// value must be stored in the slot at each blocking user. This is achieved by |
| 82 | /// doing a depth-first walk of the dominator tree of the function. This is |
| 83 | /// sufficient because the reaching definition at the beginning of a block is |
| 84 | /// either its new block argument if it is a merge block, or the definition |
| 85 | /// reaching the end of its immediate dominator (parent in the dominator tree). |
| 86 | /// We can therefore propagate this information down the dominator tree to |
| 87 | /// proceed with renaming within blocks. |
| 88 | /// - The final fourth step uses the reaching definition to remove blocking uses |
| 89 | /// in topological order. |
| 90 | /// |
| 91 | /// For further reading, chapter three of SSA-based Compiler Design [1] |
| 92 | /// showcases SSA construction, where mem2reg is an adaptation of the same |
| 93 | /// process. |
| 94 | /// |
| 95 | /// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022), |
| 96 | /// Springer. |
| 97 | |
| 98 | namespace { |
| 99 | |
| 100 | using BlockingUsesMap = |
| 101 | llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>; |
| 102 | |
| 103 | /// Information computed during promotion analysis used to perform actual |
| 104 | /// promotion. |
| 105 | struct MemorySlotPromotionInfo { |
| 106 | /// Blocks for which at least two definitions of the slot values clash. |
| 107 | SmallPtrSet<Block *, 8> mergePoints; |
| 108 | /// Contains, for each operation, which uses must be eliminated by promotion. |
| 109 | /// This is a DAG structure because if an operation must eliminate some of |
| 110 | /// its uses, it is because the defining ops of the blocking uses requested |
| 111 | /// it. The defining ops therefore must also have blocking uses or be the |
| 112 | /// starting point of the blocking uses. |
| 113 | BlockingUsesMap userToBlockingUses; |
| 114 | }; |
| 115 | |
| 116 | /// Computes information for basic slot promotion. This will check that direct |
| 117 | /// slot promotion can be performed, and provide the information to execute the |
| 118 | /// promotion. This does not mutate IR. |
| 119 | class MemorySlotPromotionAnalyzer { |
| 120 | public: |
| 121 | MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance, |
| 122 | const DataLayout &dataLayout) |
| 123 | : slot(slot), dominance(dominance), dataLayout(dataLayout) {} |
| 124 | |
| 125 | /// Computes the information for slot promotion if promotion is possible, |
| 126 | /// returns nothing otherwise. |
| 127 | std::optional<MemorySlotPromotionInfo> computeInfo(); |
| 128 | |
| 129 | private: |
| 130 | /// Computes the transitive uses of the slot that block promotion. This finds |
| 131 | /// uses that would block the promotion, checks that the operation has a |
| 132 | /// solution to remove the blocking use, and potentially forwards the analysis |
| 133 | /// if the operation needs further blocking uses resolved to resolve its own |
| 134 | /// uses (typically, removing its users because it will delete itself to |
| 135 | /// resolve its own blocking uses). This will fail if one of the transitive |
| 136 | /// users cannot remove a requested use, and should prevent promotion. |
| 137 | LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses); |
| 138 | |
| 139 | /// Computes in which blocks the value stored in the slot is actually used, |
| 140 | /// meaning blocks leading to a load. This method uses `definingBlocks`, the |
| 141 | /// set of blocks containing a store to the slot (defining the value of the |
| 142 | /// slot). |
| 143 | SmallPtrSet<Block *, 16> |
| 144 | computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks); |
| 145 | |
| 146 | /// Computes the points in which multiple re-definitions of the slot's value |
| 147 | /// (stores) may conflict. |
| 148 | void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints); |
| 149 | |
| 150 | /// Ensures predecessors of merge points can properly provide their current |
| 151 | /// definition of the value stored in the slot to the merge point. This can |
| 152 | /// notably be an issue if the terminator used does not have the ability to |
| 153 | /// forward values through block operands. |
| 154 | bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints); |
| 155 | |
| 156 | MemorySlot slot; |
| 157 | DominanceInfo &dominance; |
| 158 | const DataLayout &dataLayout; |
| 159 | }; |
| 160 | |
| 161 | using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t>>; |
| 162 | |
| 163 | /// The MemorySlotPromoter handles the state of promoting a memory slot. It |
| 164 | /// wraps a slot and its associated allocator. This will perform the mutation of |
| 165 | /// IR. |
| 166 | class MemorySlotPromoter { |
| 167 | public: |
| 168 | MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, |
| 169 | OpBuilder &builder, DominanceInfo &dominance, |
| 170 | const DataLayout &dataLayout, MemorySlotPromotionInfo info, |
| 171 | const Mem2RegStatistics &statistics, |
| 172 | BlockIndexCache &blockIndexCache); |
| 173 | |
| 174 | /// Actually promotes the slot by mutating IR. Promoting a slot DOES |
| 175 | /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of |
| 176 | /// promotion info should NOT be performed in batches. |
| 177 | /// Returns a promotable allocation op if a new allocator was created, nullopt |
| 178 | /// otherwise. |
| 179 | std::optional<PromotableAllocationOpInterface> promoteSlot(); |
| 180 | |
| 181 | private: |
| 182 | /// Computes the reaching definition for all the operations that require |
| 183 | /// promotion. `reachingDef` is the value the slot should contain at the |
| 184 | /// beginning of the block. This method returns the reached definition at the |
| 185 | /// end of the block. This method must only be called at most once per block. |
| 186 | Value computeReachingDefInBlock(Block *block, Value reachingDef); |
| 187 | |
| 188 | /// Computes the reaching definition for all the operations that require |
| 189 | /// promotion. `reachingDef` corresponds to the initial value the |
| 190 | /// slot will contain before any write, typically a poison value. |
| 191 | /// This method must only be called at most once per region. |
| 192 | void computeReachingDefInRegion(Region *region, Value reachingDef); |
| 193 | |
| 194 | /// Removes the blocking uses of the slot, in topological order. |
| 195 | void removeBlockingUses(); |
| 196 | |
| 197 | /// Lazily-constructed default value representing the content of the slot when |
| 198 | /// no store has been executed. This function may mutate IR. |
| 199 | Value getOrCreateDefaultValue(); |
| 200 | |
| 201 | MemorySlot slot; |
| 202 | PromotableAllocationOpInterface allocator; |
| 203 | OpBuilder &builder; |
| 204 | /// Potentially non-initialized default value. Use `getOrCreateDefaultValue` |
| 205 | /// to initialize it on demand. |
| 206 | Value defaultValue; |
| 207 | /// Contains the reaching definition at this operation. Reaching definitions |
| 208 | /// are only computed for promotable memory operations with blocking uses. |
| 209 | DenseMap<PromotableMemOpInterface, Value> reachingDefs; |
| 210 | DenseMap<PromotableMemOpInterface, Value> replacedValuesMap; |
| 211 | DominanceInfo &dominance; |
| 212 | const DataLayout &dataLayout; |
| 213 | MemorySlotPromotionInfo info; |
| 214 | const Mem2RegStatistics &statistics; |
| 215 | |
| 216 | /// Shared cache of block indices of specific regions. |
| 217 | BlockIndexCache &blockIndexCache; |
| 218 | }; |
| 219 | |
| 220 | } // namespace |
| 221 | |
| 222 | MemorySlotPromoter::MemorySlotPromoter( |
| 223 | MemorySlot slot, PromotableAllocationOpInterface allocator, |
| 224 | OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout, |
| 225 | MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics, |
| 226 | BlockIndexCache &blockIndexCache) |
| 227 | : slot(slot), allocator(allocator), builder(builder), dominance(dominance), |
| 228 | dataLayout(dataLayout), info(std::move(info)), statistics(statistics), |
| 229 | blockIndexCache(blockIndexCache) { |
| 230 | #ifndef NDEBUG |
| 231 | auto isResultOrNewBlockArgument = [&]() { |
| 232 | if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr)) |
| 233 | return arg.getOwner()->getParentOp() == allocator; |
| 234 | return slot.ptr.getDefiningOp() == allocator; |
| 235 | }; |
| 236 | |
| 237 | assert(isResultOrNewBlockArgument() && |
| 238 | "a slot must be a result of the allocator or an argument of the child " |
| 239 | "regions of the allocator" ); |
| 240 | #endif // NDEBUG |
| 241 | } |
| 242 | |
| 243 | Value MemorySlotPromoter::getOrCreateDefaultValue() { |
| 244 | if (defaultValue) |
| 245 | return defaultValue; |
| 246 | |
| 247 | OpBuilder::InsertionGuard guard(builder); |
| 248 | builder.setInsertionPointToStart(slot.ptr.getParentBlock()); |
| 249 | return defaultValue = allocator.getDefaultValue(slot, builder); |
| 250 | } |
| 251 | |
| 252 | LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses( |
| 253 | BlockingUsesMap &userToBlockingUses) { |
| 254 | // The promotion of an operation may require the promotion of further |
| 255 | // operations (typically, removing operations that use an operation that must |
| 256 | // delete itself). We thus need to start from the use of the slot pointer and |
| 257 | // propagate further requests through the forward slice. |
| 258 | |
| 259 | // Because this pass currently only supports analysing the parent region of |
| 260 | // the slot pointer, if a promotable memory op that needs promotion is within |
| 261 | // a graph region, the slot may only be used in a graph region and should |
| 262 | // therefore be ignored. |
| 263 | Region *slotPtrRegion = slot.ptr.getParentRegion(); |
| 264 | auto slotPtrRegionOp = |
| 265 | dyn_cast<RegionKindInterface>(slotPtrRegion->getParentOp()); |
| 266 | if (slotPtrRegionOp && |
| 267 | slotPtrRegionOp.getRegionKind(slotPtrRegion->getRegionNumber()) == |
| 268 | RegionKind::Graph) |
| 269 | return failure(); |
| 270 | |
| 271 | // First insert that all immediate users of the slot pointer must no longer |
| 272 | // use it. |
| 273 | for (OpOperand &use : slot.ptr.getUses()) { |
| 274 | SmallPtrSet<OpOperand *, 4> &blockingUses = |
| 275 | userToBlockingUses[use.getOwner()]; |
| 276 | blockingUses.insert(Ptr: &use); |
| 277 | } |
| 278 | |
| 279 | // Then, propagate the requirements for the removal of uses. The |
| 280 | // topologically-sorted forward slice allows for all blocking uses of an |
| 281 | // operation to have been computed before it is reached. Operations are |
| 282 | // traversed in topological order of their uses, starting from the slot |
| 283 | // pointer. |
| 284 | SetVector<Operation *> forwardSlice; |
| 285 | mlir::getForwardSlice(root: slot.ptr, forwardSlice: &forwardSlice); |
| 286 | for (Operation *user : forwardSlice) { |
| 287 | // If the next operation has no blocking uses, everything is fine. |
| 288 | auto it = userToBlockingUses.find(user); |
| 289 | if (it == userToBlockingUses.end()) |
| 290 | continue; |
| 291 | |
| 292 | SmallPtrSet<OpOperand *, 4> &blockingUses = it->second; |
| 293 | |
| 294 | SmallVector<OpOperand *> newBlockingUses; |
| 295 | // If the operation decides it cannot deal with removing the blocking uses, |
| 296 | // promotion must fail. |
| 297 | if (auto promotable = dyn_cast<PromotableOpInterface>(user)) { |
| 298 | if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, |
| 299 | dataLayout)) |
| 300 | return failure(); |
| 301 | } else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) { |
| 302 | if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses, |
| 303 | dataLayout)) |
| 304 | return failure(); |
| 305 | } else { |
| 306 | // An operation that has blocking uses must be promoted. If it is not |
| 307 | // promotable, promotion must fail. |
| 308 | return failure(); |
| 309 | } |
| 310 | |
| 311 | // Then, register any new blocking uses for coming operations. |
| 312 | for (OpOperand *blockingUse : newBlockingUses) { |
| 313 | assert(llvm::is_contained(user->getResults(), blockingUse->get())); |
| 314 | |
| 315 | SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet = |
| 316 | userToBlockingUses[blockingUse->getOwner()]; |
| 317 | newUserBlockingUseSet.insert(Ptr: blockingUse); |
| 318 | } |
| 319 | } |
| 320 | |
| 321 | // Because this pass currently only supports analysing the parent region of |
| 322 | // the slot pointer, if a promotable memory op that needs promotion is outside |
| 323 | // of this region, promotion must fail because it will be impossible to |
| 324 | // provide a valid `reachingDef` for it. |
| 325 | for (auto &[toPromote, _] : userToBlockingUses) |
| 326 | if (isa<PromotableMemOpInterface>(toPromote) && |
| 327 | toPromote->getParentRegion() != slot.ptr.getParentRegion()) |
| 328 | return failure(); |
| 329 | |
| 330 | return success(); |
| 331 | } |
| 332 | |
| 333 | SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn( |
| 334 | SmallPtrSetImpl<Block *> &definingBlocks) { |
| 335 | SmallPtrSet<Block *, 16> liveIn; |
| 336 | |
| 337 | // The worklist contains blocks in which it is known that the slot value is |
| 338 | // live-in. The further blocks where this value is live-in will be inferred |
| 339 | // from these. |
| 340 | SmallVector<Block *> liveInWorkList; |
| 341 | |
| 342 | // Blocks with a load before any other store to the slot are the starting |
| 343 | // points of the analysis. The slot value is definitely live-in in those |
| 344 | // blocks. |
| 345 | SmallPtrSet<Block *, 16> visited; |
| 346 | for (Operation *user : slot.ptr.getUsers()) { |
| 347 | if (!visited.insert(Ptr: user->getBlock()).second) |
| 348 | continue; |
| 349 | |
| 350 | for (Operation &op : user->getBlock()->getOperations()) { |
| 351 | if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) { |
| 352 | // If this operation loads the slot, it is loading from it before |
| 353 | // ever writing to it, so the value is live-in in this block. |
| 354 | if (memOp.loadsFrom(slot)) { |
| 355 | liveInWorkList.push_back(Elt: user->getBlock()); |
| 356 | break; |
| 357 | } |
| 358 | |
| 359 | // If we store to the slot, further loads will see that value. |
| 360 | // Because we did not meet any load before, the value is not live-in. |
| 361 | if (memOp.storesTo(slot)) |
| 362 | break; |
| 363 | } |
| 364 | } |
| 365 | } |
| 366 | |
| 367 | // The information is then propagated to the predecessors until a def site |
| 368 | // (store) is found. |
| 369 | while (!liveInWorkList.empty()) { |
| 370 | Block *liveInBlock = liveInWorkList.pop_back_val(); |
| 371 | |
| 372 | if (!liveIn.insert(Ptr: liveInBlock).second) |
| 373 | continue; |
| 374 | |
| 375 | // If a predecessor is a defining block, either: |
| 376 | // - It has a load before its first store, in which case it is live-in but |
| 377 | // has already been processed in the initialisation step. |
| 378 | // - It has a store before any load, in which case it is not live-in. |
| 379 | // We can thus at this stage insert to the worklist only predecessors that |
| 380 | // are not defining blocks. |
| 381 | for (Block *pred : liveInBlock->getPredecessors()) |
| 382 | if (!definingBlocks.contains(Ptr: pred)) |
| 383 | liveInWorkList.push_back(Elt: pred); |
| 384 | } |
| 385 | |
| 386 | return liveIn; |
| 387 | } |
| 388 | |
| 389 | using IDFCalculator = llvm::IDFCalculatorBase<Block, false>; |
| 390 | void MemorySlotPromotionAnalyzer::computeMergePoints( |
| 391 | SmallPtrSetImpl<Block *> &mergePoints) { |
| 392 | if (slot.ptr.getParentRegion()->hasOneBlock()) |
| 393 | return; |
| 394 | |
| 395 | IDFCalculator idfCalculator(dominance.getDomTree(region: slot.ptr.getParentRegion())); |
| 396 | |
| 397 | SmallPtrSet<Block *, 16> definingBlocks; |
| 398 | for (Operation *user : slot.ptr.getUsers()) |
| 399 | if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user)) |
| 400 | if (storeOp.storesTo(slot)) |
| 401 | definingBlocks.insert(Ptr: user->getBlock()); |
| 402 | |
| 403 | idfCalculator.setDefiningBlocks(definingBlocks); |
| 404 | |
| 405 | SmallPtrSet<Block *, 16> liveIn = computeSlotLiveIn(definingBlocks); |
| 406 | idfCalculator.setLiveInBlocks(liveIn); |
| 407 | |
| 408 | SmallVector<Block *> mergePointsVec; |
| 409 | idfCalculator.calculate(IDFBlocks&: mergePointsVec); |
| 410 | |
| 411 | mergePoints.insert_range(R&: mergePointsVec); |
| 412 | } |
| 413 | |
| 414 | bool MemorySlotPromotionAnalyzer::areMergePointsUsable( |
| 415 | SmallPtrSetImpl<Block *> &mergePoints) { |
| 416 | for (Block *mergePoint : mergePoints) |
| 417 | for (Block *pred : mergePoint->getPredecessors()) |
| 418 | if (!isa<BranchOpInterface>(Val: pred->getTerminator())) |
| 419 | return false; |
| 420 | |
| 421 | return true; |
| 422 | } |
| 423 | |
| 424 | std::optional<MemorySlotPromotionInfo> |
| 425 | MemorySlotPromotionAnalyzer::computeInfo() { |
| 426 | MemorySlotPromotionInfo info; |
| 427 | |
| 428 | // First, find the set of operations that will need to be changed for the |
| 429 | // promotion to happen. These operations need to resolve some of their uses, |
| 430 | // either by rewiring them or simply deleting themselves. If any of them |
| 431 | // cannot find a way to resolve their blocking uses, we abort the promotion. |
| 432 | if (failed(computeBlockingUses(info.userToBlockingUses))) |
| 433 | return {}; |
| 434 | |
| 435 | // Then, compute blocks in which two or more definitions of the allocated |
| 436 | // variable may conflict. These blocks will need a new block argument to |
| 437 | // accommodate this. |
| 438 | computeMergePoints(mergePoints&: info.mergePoints); |
| 439 | |
| 440 | // The slot can be promoted if the block arguments to be created can |
| 441 | // actually be populated with values, which may not be possible depending |
| 442 | // on their predecessors. |
| 443 | if (!areMergePointsUsable(mergePoints&: info.mergePoints)) |
| 444 | return {}; |
| 445 | |
| 446 | return info; |
| 447 | } |
| 448 | |
| 449 | Value MemorySlotPromoter::computeReachingDefInBlock(Block *block, |
| 450 | Value reachingDef) { |
| 451 | SmallVector<Operation *> blockOps; |
| 452 | for (Operation &op : block->getOperations()) |
| 453 | blockOps.push_back(Elt: &op); |
| 454 | for (Operation *op : blockOps) { |
| 455 | if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) { |
| 456 | if (info.userToBlockingUses.contains(memOp)) |
| 457 | reachingDefs.insert({memOp, reachingDef}); |
| 458 | |
| 459 | if (memOp.storesTo(slot)) { |
| 460 | builder.setInsertionPointAfter(memOp); |
| 461 | Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout); |
| 462 | assert(stored && "a memory operation storing to a slot must provide a " |
| 463 | "new definition of the slot" ); |
| 464 | reachingDef = stored; |
| 465 | replacedValuesMap[memOp] = stored; |
| 466 | } |
| 467 | } |
| 468 | } |
| 469 | |
| 470 | return reachingDef; |
| 471 | } |
| 472 | |
| 473 | void MemorySlotPromoter::computeReachingDefInRegion(Region *region, |
| 474 | Value reachingDef) { |
| 475 | assert(reachingDef && "expected an initial reaching def to be provided" ); |
| 476 | if (region->hasOneBlock()) { |
| 477 | computeReachingDefInBlock(block: ®ion->front(), reachingDef); |
| 478 | return; |
| 479 | } |
| 480 | |
| 481 | struct DfsJob { |
| 482 | llvm::DomTreeNodeBase<Block> *block; |
| 483 | Value reachingDef; |
| 484 | }; |
| 485 | |
| 486 | SmallVector<DfsJob> dfsStack; |
| 487 | |
| 488 | auto &domTree = dominance.getDomTree(region: slot.ptr.getParentRegion()); |
| 489 | |
| 490 | dfsStack.emplace_back<DfsJob>( |
| 491 | Args: {.block: domTree.getNode(BB: ®ion->front()), .reachingDef: reachingDef}); |
| 492 | |
| 493 | while (!dfsStack.empty()) { |
| 494 | DfsJob job = dfsStack.pop_back_val(); |
| 495 | Block *block = job.block->getBlock(); |
| 496 | |
| 497 | if (info.mergePoints.contains(block)) { |
| 498 | BlockArgument blockArgument = |
| 499 | block->addArgument(type: slot.elemType, loc: slot.ptr.getLoc()); |
| 500 | builder.setInsertionPointToStart(block); |
| 501 | allocator.handleBlockArgument(slot, blockArgument, builder); |
| 502 | job.reachingDef = blockArgument; |
| 503 | |
| 504 | if (statistics.newBlockArgumentAmount) |
| 505 | (*statistics.newBlockArgumentAmount)++; |
| 506 | } |
| 507 | |
| 508 | job.reachingDef = computeReachingDefInBlock(block, reachingDef: job.reachingDef); |
| 509 | assert(job.reachingDef); |
| 510 | |
| 511 | if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) { |
| 512 | for (BlockOperand &blockOperand : terminator->getBlockOperands()) { |
| 513 | if (info.mergePoints.contains(blockOperand.get())) { |
| 514 | terminator.getSuccessorOperands(blockOperand.getOperandNumber()) |
| 515 | .append(job.reachingDef); |
| 516 | } |
| 517 | } |
| 518 | } |
| 519 | |
| 520 | for (auto *child : job.block->children()) |
| 521 | dfsStack.emplace_back<DfsJob>(Args: {.block: child, .reachingDef: job.reachingDef}); |
| 522 | } |
| 523 | } |
| 524 | |
| 525 | /// Gets or creates a block index mapping for `region`. |
| 526 | static const DenseMap<Block *, size_t> & |
| 527 | getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) { |
| 528 | auto [it, inserted] = blockIndexCache.try_emplace(region); |
| 529 | if (!inserted) |
| 530 | return it->second; |
| 531 | |
| 532 | DenseMap<Block *, size_t> &blockIndices = it->second; |
| 533 | SetVector<Block *> topologicalOrder = getBlocksSortedByDominance(region&: *region); |
| 534 | for (auto [index, block] : llvm::enumerate(First&: topologicalOrder)) |
| 535 | blockIndices[block] = index; |
| 536 | return blockIndices; |
| 537 | } |
| 538 | |
| 539 | /// Sorts `ops` according to dominance. Relies on the topological order of basic |
| 540 | /// blocks to get a deterministic ordering. Uses `blockIndexCache` to avoid the |
| 541 | /// potentially expensive recomputation of a block index map. |
| 542 | static void dominanceSort(SmallVector<Operation *> &ops, Region ®ion, |
| 543 | BlockIndexCache &blockIndexCache) { |
| 544 | // Produce a topological block order and construct a map to lookup the indices |
| 545 | // of blocks. |
| 546 | const DenseMap<Block *, size_t> &topoBlockIndices = |
| 547 | getOrCreateBlockIndices(blockIndexCache, ®ion); |
| 548 | |
| 549 | // Combining the topological order of the basic blocks together with block |
| 550 | // internal operation order guarantees a deterministic, dominance respecting |
| 551 | // order. |
| 552 | llvm::sort(C&: ops, Comp: [&](Operation *lhs, Operation *rhs) { |
| 553 | size_t lhsBlockIndex = topoBlockIndices.at(Val: lhs->getBlock()); |
| 554 | size_t rhsBlockIndex = topoBlockIndices.at(Val: rhs->getBlock()); |
| 555 | if (lhsBlockIndex == rhsBlockIndex) |
| 556 | return lhs->isBeforeInBlock(other: rhs); |
| 557 | return lhsBlockIndex < rhsBlockIndex; |
| 558 | }); |
| 559 | } |
| 560 | |
| 561 | void MemorySlotPromoter::removeBlockingUses() { |
| 562 | llvm::SmallVector<Operation *> usersToRemoveUses( |
| 563 | llvm::make_first_range(info.userToBlockingUses)); |
| 564 | |
| 565 | // Sort according to dominance. |
| 566 | dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent(), |
| 567 | blockIndexCache); |
| 568 | |
| 569 | llvm::SmallVector<Operation *> toErase; |
| 570 | // List of all replaced values in the slot. |
| 571 | llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList; |
| 572 | // Ops to visit with the `visitReplacedValues` method. |
| 573 | llvm::SmallVector<PromotableOpInterface> toVisit; |
| 574 | for (Operation *toPromote : llvm::reverse(C&: usersToRemoveUses)) { |
| 575 | if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) { |
| 576 | Value reachingDef = reachingDefs.lookup(toPromoteMemOp); |
| 577 | // If no reaching definition is known, this use is outside the reach of |
| 578 | // the slot. The default value should thus be used. |
| 579 | if (!reachingDef) |
| 580 | reachingDef = getOrCreateDefaultValue(); |
| 581 | |
| 582 | builder.setInsertionPointAfter(toPromote); |
| 583 | if (toPromoteMemOp.removeBlockingUses( |
| 584 | slot, info.userToBlockingUses[toPromote], builder, reachingDef, |
| 585 | dataLayout) == DeletionKind::Delete) |
| 586 | toErase.push_back(Elt: toPromote); |
| 587 | if (toPromoteMemOp.storesTo(slot)) |
| 588 | if (Value replacedValue = replacedValuesMap[toPromoteMemOp]) |
| 589 | replacedValuesList.push_back(Elt: {toPromoteMemOp, replacedValue}); |
| 590 | continue; |
| 591 | } |
| 592 | |
| 593 | auto toPromoteBasic = cast<PromotableOpInterface>(toPromote); |
| 594 | builder.setInsertionPointAfter(toPromote); |
| 595 | if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote], |
| 596 | builder) == DeletionKind::Delete) |
| 597 | toErase.push_back(Elt: toPromote); |
| 598 | if (toPromoteBasic.requiresReplacedValues()) |
| 599 | toVisit.push_back(toPromoteBasic); |
| 600 | } |
| 601 | for (PromotableOpInterface op : toVisit) { |
| 602 | builder.setInsertionPointAfter(op); |
| 603 | op.visitReplacedValues(replacedValuesList, builder); |
| 604 | } |
| 605 | |
| 606 | for (Operation *toEraseOp : toErase) |
| 607 | toEraseOp->erase(); |
| 608 | |
| 609 | assert(slot.ptr.use_empty() && |
| 610 | "after promotion, the slot pointer should not be used anymore" ); |
| 611 | } |
| 612 | |
| 613 | std::optional<PromotableAllocationOpInterface> |
| 614 | MemorySlotPromoter::promoteSlot() { |
| 615 | computeReachingDefInRegion(region: slot.ptr.getParentRegion(), |
| 616 | reachingDef: getOrCreateDefaultValue()); |
| 617 | |
| 618 | // Now that reaching definitions are known, remove all users. |
| 619 | removeBlockingUses(); |
| 620 | |
| 621 | // Update terminators in dead branches to forward default if they are |
| 622 | // succeeded by a merge points. |
| 623 | for (Block *mergePoint : info.mergePoints) { |
| 624 | for (BlockOperand &use : mergePoint->getUses()) { |
| 625 | auto user = cast<BranchOpInterface>(use.getOwner()); |
| 626 | SuccessorOperands succOperands = |
| 627 | user.getSuccessorOperands(use.getOperandNumber()); |
| 628 | assert(succOperands.size() == mergePoint->getNumArguments() || |
| 629 | succOperands.size() + 1 == mergePoint->getNumArguments()); |
| 630 | if (succOperands.size() + 1 == mergePoint->getNumArguments()) |
| 631 | succOperands.append(getOrCreateDefaultValue()); |
| 632 | } |
| 633 | } |
| 634 | |
| 635 | LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr |
| 636 | << "\n" ); |
| 637 | |
| 638 | if (statistics.promotedAmount) |
| 639 | (*statistics.promotedAmount)++; |
| 640 | |
| 641 | return allocator.handlePromotionComplete(slot, defaultValue, builder); |
| 642 | } |
| 643 | |
| 644 | LogicalResult mlir::tryToPromoteMemorySlots( |
| 645 | ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder, |
| 646 | const DataLayout &dataLayout, DominanceInfo &dominance, |
| 647 | Mem2RegStatistics statistics) { |
| 648 | bool promotedAny = false; |
| 649 | |
| 650 | // A cache that stores deterministic block indices which are used to determine |
| 651 | // a valid operation modification order. The block index maps are computed |
| 652 | // lazily and cached to avoid expensive recomputation. |
| 653 | BlockIndexCache blockIndexCache; |
| 654 | |
| 655 | SmallVector<PromotableAllocationOpInterface> workList(allocators); |
| 656 | |
| 657 | SmallVector<PromotableAllocationOpInterface> newWorkList; |
| 658 | newWorkList.reserve(workList.size()); |
| 659 | while (true) { |
| 660 | bool changesInThisRound = false; |
| 661 | for (PromotableAllocationOpInterface allocator : workList) { |
| 662 | bool changedAllocator = false; |
| 663 | for (MemorySlot slot : allocator.getPromotableSlots()) { |
| 664 | if (slot.ptr.use_empty()) |
| 665 | continue; |
| 666 | |
| 667 | MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout); |
| 668 | std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo(); |
| 669 | if (info) { |
| 670 | std::optional<PromotableAllocationOpInterface> newAllocator = |
| 671 | MemorySlotPromoter(slot, allocator, builder, dominance, |
| 672 | dataLayout, std::move(*info), statistics, |
| 673 | blockIndexCache) |
| 674 | .promoteSlot(); |
| 675 | changedAllocator = true; |
| 676 | // Add newly created allocators to the worklist for further |
| 677 | // processing. |
| 678 | if (newAllocator) |
| 679 | newWorkList.push_back(*newAllocator); |
| 680 | |
| 681 | // A break is required, since promoting a slot may invalidate the |
| 682 | // remaining slots of an allocator. |
| 683 | break; |
| 684 | } |
| 685 | } |
| 686 | if (!changedAllocator) |
| 687 | newWorkList.push_back(allocator); |
| 688 | changesInThisRound |= changedAllocator; |
| 689 | } |
| 690 | if (!changesInThisRound) |
| 691 | break; |
| 692 | promotedAny = true; |
| 693 | |
| 694 | // Swap the vector's backing memory and clear the entries in newWorkList |
| 695 | // afterwards. This ensures that additional heap allocations can be avoided. |
| 696 | workList.swap(newWorkList); |
| 697 | newWorkList.clear(); |
| 698 | } |
| 699 | |
| 700 | return success(IsSuccess: promotedAny); |
| 701 | } |
| 702 | |
| 703 | namespace { |
| 704 | |
| 705 | struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> { |
| 706 | using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase; |
| 707 | |
| 708 | void runOnOperation() override { |
| 709 | Operation *scopeOp = getOperation(); |
| 710 | |
| 711 | Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount}; |
| 712 | |
| 713 | bool changed = false; |
| 714 | |
| 715 | auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); |
| 716 | const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp); |
| 717 | auto &dominance = getAnalysis<DominanceInfo>(); |
| 718 | |
| 719 | for (Region ®ion : scopeOp->getRegions()) { |
| 720 | if (region.getBlocks().empty()) |
| 721 | continue; |
| 722 | |
| 723 | OpBuilder builder(®ion.front(), region.front().begin()); |
| 724 | |
| 725 | SmallVector<PromotableAllocationOpInterface> allocators; |
| 726 | // Build a list of allocators to attempt to promote the slots of. |
| 727 | region.walk([&](PromotableAllocationOpInterface allocator) { |
| 728 | allocators.emplace_back(allocator); |
| 729 | }); |
| 730 | |
| 731 | // Attempt promoting as many of the slots as possible. |
| 732 | if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout, |
| 733 | dominance, statistics))) |
| 734 | changed = true; |
| 735 | } |
| 736 | if (!changed) |
| 737 | markAllAnalysesPreserved(); |
| 738 | } |
| 739 | }; |
| 740 | |
| 741 | } // namespace |
| 742 | |