| 1 | //===- CSE.cpp - Common Sub-expression Elimination ------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This transformation pass performs a simple common sub-expression elimination |
| 10 | // algorithm on operations within a region. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Transforms/CSE.h" |
| 15 | |
| 16 | #include "mlir/IR/Dominance.h" |
| 17 | #include "mlir/IR/PatternMatch.h" |
| 18 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
| 19 | #include "mlir/Pass/Pass.h" |
| 20 | #include "mlir/Transforms/Passes.h" |
| 21 | #include "llvm/ADT/DenseMapInfo.h" |
| 22 | #include "llvm/ADT/Hashing.h" |
| 23 | #include "llvm/ADT/ScopedHashTable.h" |
| 24 | #include "llvm/Support/Allocator.h" |
| 25 | #include "llvm/Support/RecyclingAllocator.h" |
| 26 | #include <deque> |
| 27 | |
| 28 | namespace mlir { |
| 29 | #define GEN_PASS_DEF_CSE |
| 30 | #include "mlir/Transforms/Passes.h.inc" |
| 31 | } // namespace mlir |
| 32 | |
| 33 | using namespace mlir; |
| 34 | |
| 35 | namespace { |
| 36 | struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> { |
| 37 | static unsigned getHashValue(const Operation *opC) { |
| 38 | return OperationEquivalence::computeHash( |
| 39 | const_cast<Operation *>(opC), |
| 40 | /*hashOperands=*/OperationEquivalence::directHashValue, |
| 41 | /*hashResults=*/OperationEquivalence::ignoreHashValue, |
| 42 | OperationEquivalence::IgnoreLocations); |
| 43 | } |
| 44 | static bool isEqual(const Operation *lhsC, const Operation *rhsC) { |
| 45 | auto *lhs = const_cast<Operation *>(lhsC); |
| 46 | auto *rhs = const_cast<Operation *>(rhsC); |
| 47 | if (lhs == rhs) |
| 48 | return true; |
| 49 | if (lhs == getTombstoneKey() || lhs == getEmptyKey() || |
| 50 | rhs == getTombstoneKey() || rhs == getEmptyKey()) |
| 51 | return false; |
| 52 | return OperationEquivalence::isEquivalentTo( |
| 53 | lhs: const_cast<Operation *>(lhsC), rhs: const_cast<Operation *>(rhsC), |
| 54 | flags: OperationEquivalence::IgnoreLocations); |
| 55 | } |
| 56 | }; |
| 57 | } // namespace |
| 58 | |
| 59 | namespace { |
| 60 | /// Simple common sub-expression elimination. |
| 61 | class CSEDriver { |
| 62 | public: |
| 63 | CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo) |
| 64 | : rewriter(rewriter), domInfo(domInfo) {} |
| 65 | |
| 66 | /// Simplify all operations within the given op. |
| 67 | void simplify(Operation *op, bool *changed = nullptr); |
| 68 | |
| 69 | int64_t getNumCSE() const { return numCSE; } |
| 70 | int64_t getNumDCE() const { return numDCE; } |
| 71 | |
| 72 | private: |
| 73 | /// Shared implementation of operation elimination and scoped map definitions. |
| 74 | using AllocatorTy = llvm::RecyclingAllocator< |
| 75 | llvm::BumpPtrAllocator, |
| 76 | llvm::ScopedHashTableVal<Operation *, Operation *>>; |
| 77 | using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *, |
| 78 | SimpleOperationInfo, AllocatorTy>; |
| 79 | |
| 80 | /// Cache holding MemoryEffects information between two operations. The first |
| 81 | /// operation is stored has the key. The second operation is stored inside a |
| 82 | /// pair in the value. The pair also hold the MemoryEffects between those |
| 83 | /// two operations. If the MemoryEffects is nullptr then we assume there is |
| 84 | /// no operation with MemoryEffects::Write between the two operations. |
| 85 | using MemEffectsCache = |
| 86 | DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>; |
| 87 | |
| 88 | /// Represents a single entry in the depth first traversal of a CFG. |
| 89 | struct CFGStackNode { |
| 90 | CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node) |
| 91 | : scope(knownValues), node(node), childIterator(node->begin()) {} |
| 92 | |
| 93 | /// Scope for the known values. |
| 94 | ScopedMapTy::ScopeTy scope; |
| 95 | |
| 96 | DominanceInfoNode *node; |
| 97 | DominanceInfoNode::const_iterator childIterator; |
| 98 | |
| 99 | /// If this node has been fully processed yet or not. |
| 100 | bool processed = false; |
| 101 | }; |
| 102 | |
| 103 | /// Attempt to eliminate a redundant operation. Returns success if the |
| 104 | /// operation was marked for removal, failure otherwise. |
| 105 | LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op, |
| 106 | bool hasSSADominance); |
| 107 | void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance); |
| 108 | void simplifyRegion(ScopedMapTy &knownValues, Region ®ion); |
| 109 | |
| 110 | void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, |
| 111 | Operation *existing, bool hasSSADominance); |
| 112 | |
| 113 | /// Check if there is side-effecting operations other than the given effect |
| 114 | /// between the two operations. |
| 115 | bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp); |
| 116 | |
| 117 | /// A rewriter for modifying the IR. |
| 118 | RewriterBase &rewriter; |
| 119 | |
| 120 | /// Operations marked as dead and to be erased. |
| 121 | std::vector<Operation *> opsToErase; |
| 122 | DominanceInfo *domInfo = nullptr; |
| 123 | MemEffectsCache memEffectsCache; |
| 124 | |
| 125 | // Various statistics. |
| 126 | int64_t numCSE = 0; |
| 127 | int64_t numDCE = 0; |
| 128 | }; |
| 129 | } // namespace |
| 130 | |
| 131 | void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, |
| 132 | Operation *existing, |
| 133 | bool hasSSADominance) { |
| 134 | // If we find one then replace all uses of the current operation with the |
| 135 | // existing one and mark it for deletion. We can only replace an operand in |
| 136 | // an operation if it has not been visited yet. |
| 137 | if (hasSSADominance) { |
| 138 | // If the region has SSA dominance, then we are guaranteed to have not |
| 139 | // visited any use of the current operation. |
| 140 | if (auto *rewriteListener = |
| 141 | dyn_cast_if_present<RewriterBase::Listener>(Val: rewriter.getListener())) |
| 142 | rewriteListener->notifyOperationReplaced(op, replacement: existing); |
| 143 | // Replace all uses, but do not remove the operation yet. This does not |
| 144 | // notify the listener because the original op is not erased. |
| 145 | rewriter.replaceAllUsesWith(from: op->getResults(), to: existing->getResults()); |
| 146 | opsToErase.push_back(x: op); |
| 147 | } else { |
| 148 | // When the region does not have SSA dominance, we need to check if we |
| 149 | // have visited a use before replacing any use. |
| 150 | auto wasVisited = [&](OpOperand &operand) { |
| 151 | return !knownValues.count(Key: operand.getOwner()); |
| 152 | }; |
| 153 | if (auto *rewriteListener = |
| 154 | dyn_cast_if_present<RewriterBase::Listener>(Val: rewriter.getListener())) |
| 155 | for (Value v : op->getResults()) |
| 156 | if (all_of(Range: v.getUses(), P: wasVisited)) |
| 157 | rewriteListener->notifyOperationReplaced(op, replacement: existing); |
| 158 | |
| 159 | // Replace all uses, but do not remove the operation yet. This does not |
| 160 | // notify the listener because the original op is not erased. |
| 161 | rewriter.replaceUsesWithIf(from: op->getResults(), to: existing->getResults(), |
| 162 | functor: wasVisited); |
| 163 | |
| 164 | // There may be some remaining uses of the operation. |
| 165 | if (op->use_empty()) |
| 166 | opsToErase.push_back(x: op); |
| 167 | } |
| 168 | |
| 169 | // If the existing operation has an unknown location and the current |
| 170 | // operation doesn't, then set the existing op's location to that of the |
| 171 | // current op. |
| 172 | if (isa<UnknownLoc>(Val: existing->getLoc()) && !isa<UnknownLoc>(Val: op->getLoc())) |
| 173 | existing->setLoc(op->getLoc()); |
| 174 | |
| 175 | ++numCSE; |
| 176 | } |
| 177 | |
| 178 | bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp, |
| 179 | Operation *toOp) { |
| 180 | assert(fromOp->getBlock() == toOp->getBlock()); |
| 181 | assert( |
| 182 | isa<MemoryEffectOpInterface>(fromOp) && |
| 183 | cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() && |
| 184 | isa<MemoryEffectOpInterface>(toOp) && |
| 185 | cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>()); |
| 186 | Operation *nextOp = fromOp->getNextNode(); |
| 187 | auto result = |
| 188 | memEffectsCache.try_emplace(Key: fromOp, Args: std::make_pair(x&: fromOp, y: nullptr)); |
| 189 | if (result.second) { |
| 190 | auto memEffectsCachePair = result.first->second; |
| 191 | if (memEffectsCachePair.second == nullptr) { |
| 192 | // No MemoryEffects::Write has been detected until the cached operation. |
| 193 | // Continue looking from the cached operation to toOp. |
| 194 | nextOp = memEffectsCachePair.first; |
| 195 | } else { |
| 196 | // MemoryEffects::Write has been detected before so there is no need to |
| 197 | // check further. |
| 198 | return true; |
| 199 | } |
| 200 | } |
| 201 | while (nextOp && nextOp != toOp) { |
| 202 | std::optional<SmallVector<MemoryEffects::EffectInstance>> effects = |
| 203 | getEffectsRecursively(rootOp: nextOp); |
| 204 | if (!effects) { |
| 205 | // TODO: Do we need to handle other effects generically? |
| 206 | // If the operation does not implement the MemoryEffectOpInterface we |
| 207 | // conservatively assume it writes. |
| 208 | result.first->second = |
| 209 | std::make_pair(x&: nextOp, y: MemoryEffects::Write::get()); |
| 210 | return true; |
| 211 | } |
| 212 | |
| 213 | for (const MemoryEffects::EffectInstance &effect : *effects) { |
| 214 | if (isa<MemoryEffects::Write>(effect.getEffect())) { |
| 215 | result.first->second = {nextOp, MemoryEffects::Write::get()}; |
| 216 | return true; |
| 217 | } |
| 218 | } |
| 219 | nextOp = nextOp->getNextNode(); |
| 220 | } |
| 221 | result.first->second = std::make_pair(x&: toOp, y: nullptr); |
| 222 | return false; |
| 223 | } |
| 224 | |
| 225 | /// Attempt to eliminate a redundant operation. |
| 226 | LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues, |
| 227 | Operation *op, |
| 228 | bool hasSSADominance) { |
| 229 | // Don't simplify terminator operations. |
| 230 | if (op->hasTrait<OpTrait::IsTerminator>()) |
| 231 | return failure(); |
| 232 | |
| 233 | // If the operation is already trivially dead just add it to the erase list. |
| 234 | if (isOpTriviallyDead(op)) { |
| 235 | opsToErase.push_back(x: op); |
| 236 | ++numDCE; |
| 237 | return success(); |
| 238 | } |
| 239 | |
| 240 | // Don't simplify operations with regions that have multiple blocks. |
| 241 | // TODO: We need additional tests to verify that we handle such IR correctly. |
| 242 | if (!llvm::all_of(Range: op->getRegions(), P: [](Region &r) { |
| 243 | return r.getBlocks().empty() || llvm::hasSingleElement(C&: r.getBlocks()); |
| 244 | })) |
| 245 | return failure(); |
| 246 | |
| 247 | // Some simple use case of operation with memory side-effect are dealt with |
| 248 | // here. Operations with no side-effect are done after. |
| 249 | if (!isMemoryEffectFree(op)) { |
| 250 | auto memEffects = dyn_cast<MemoryEffectOpInterface>(op); |
| 251 | // TODO: Only basic use case for operations with MemoryEffects::Read can be |
| 252 | // eleminated now. More work needs to be done for more complicated patterns |
| 253 | // and other side-effects. |
| 254 | if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>()) |
| 255 | return failure(); |
| 256 | |
| 257 | // Look for an existing definition for the operation. |
| 258 | if (auto *existing = knownValues.lookup(Key: op)) { |
| 259 | if (existing->getBlock() == op->getBlock() && |
| 260 | !hasOtherSideEffectingOpInBetween(fromOp: existing, toOp: op)) { |
| 261 | // The operation that can be deleted has been reach with no |
| 262 | // side-effecting operations in between the existing operation and |
| 263 | // this one so we can remove the duplicate. |
| 264 | replaceUsesAndDelete(knownValues, op, existing, hasSSADominance); |
| 265 | return success(); |
| 266 | } |
| 267 | } |
| 268 | knownValues.insert(Key: op, Val: op); |
| 269 | return failure(); |
| 270 | } |
| 271 | |
| 272 | // Look for an existing definition for the operation. |
| 273 | if (auto *existing = knownValues.lookup(Key: op)) { |
| 274 | replaceUsesAndDelete(knownValues, op, existing, hasSSADominance); |
| 275 | ++numCSE; |
| 276 | return success(); |
| 277 | } |
| 278 | |
| 279 | // Otherwise, we add this operation to the known values map. |
| 280 | knownValues.insert(Key: op, Val: op); |
| 281 | return failure(); |
| 282 | } |
| 283 | |
| 284 | void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb, |
| 285 | bool hasSSADominance) { |
| 286 | for (auto &op : *bb) { |
| 287 | // Most operations don't have regions, so fast path that case. |
| 288 | if (op.getNumRegions() != 0) { |
| 289 | // If this operation is isolated above, we can't process nested regions |
| 290 | // with the given 'knownValues' map. This would cause the insertion of |
| 291 | // implicit captures in explicit capture only regions. |
| 292 | if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) { |
| 293 | ScopedMapTy nestedKnownValues; |
| 294 | for (auto ®ion : op.getRegions()) |
| 295 | simplifyRegion(knownValues&: nestedKnownValues, region); |
| 296 | } else { |
| 297 | // Otherwise, process nested regions normally. |
| 298 | for (auto ®ion : op.getRegions()) |
| 299 | simplifyRegion(knownValues, region); |
| 300 | } |
| 301 | } |
| 302 | |
| 303 | // If the operation is simplified, we don't process any held regions. |
| 304 | if (succeeded(Result: simplifyOperation(knownValues, op: &op, hasSSADominance))) |
| 305 | continue; |
| 306 | } |
| 307 | // Clear the MemoryEffects cache since its usage is by block only. |
| 308 | memEffectsCache.clear(); |
| 309 | } |
| 310 | |
| 311 | void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region ®ion) { |
| 312 | // If the region is empty there is nothing to do. |
| 313 | if (region.empty()) |
| 314 | return; |
| 315 | |
| 316 | bool hasSSADominance = domInfo->hasSSADominance(region: ®ion); |
| 317 | |
| 318 | // If the region only contains one block, then simplify it directly. |
| 319 | if (region.hasOneBlock()) { |
| 320 | ScopedMapTy::ScopeTy scope(knownValues); |
| 321 | simplifyBlock(knownValues, bb: ®ion.front(), hasSSADominance); |
| 322 | return; |
| 323 | } |
| 324 | |
| 325 | // If the region does not have dominanceInfo, then skip it. |
| 326 | // TODO: Regions without SSA dominance should define a different |
| 327 | // traversal order which is appropriate and can be used here. |
| 328 | if (!hasSSADominance) |
| 329 | return; |
| 330 | |
| 331 | // Note, deque is being used here because there was significant performance |
| 332 | // gains over vector when the container becomes very large due to the |
| 333 | // specific access patterns. If/when these performance issues are no |
| 334 | // longer a problem we can change this to vector. For more information see |
| 335 | // the llvm mailing list discussion on this: |
| 336 | // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html |
| 337 | std::deque<std::unique_ptr<CFGStackNode>> stack; |
| 338 | |
| 339 | // Process the nodes of the dom tree for this region. |
| 340 | stack.emplace_back(args: std::make_unique<CFGStackNode>( |
| 341 | args&: knownValues, args: domInfo->getRootNode(region: ®ion))); |
| 342 | |
| 343 | while (!stack.empty()) { |
| 344 | auto ¤tNode = stack.back(); |
| 345 | |
| 346 | // Check to see if we need to process this node. |
| 347 | if (!currentNode->processed) { |
| 348 | currentNode->processed = true; |
| 349 | simplifyBlock(knownValues, bb: currentNode->node->getBlock(), |
| 350 | hasSSADominance); |
| 351 | } |
| 352 | |
| 353 | // Otherwise, check to see if we need to process a child node. |
| 354 | if (currentNode->childIterator != currentNode->node->end()) { |
| 355 | auto *childNode = *(currentNode->childIterator++); |
| 356 | stack.emplace_back( |
| 357 | args: std::make_unique<CFGStackNode>(args&: knownValues, args&: childNode)); |
| 358 | } else { |
| 359 | // Finally, if the node and all of its children have been processed |
| 360 | // then we delete the node. |
| 361 | stack.pop_back(); |
| 362 | } |
| 363 | } |
| 364 | } |
| 365 | |
| 366 | void CSEDriver::simplify(Operation *op, bool *changed) { |
| 367 | /// Simplify all regions. |
| 368 | ScopedMapTy knownValues; |
| 369 | for (auto ®ion : op->getRegions()) |
| 370 | simplifyRegion(knownValues, region); |
| 371 | |
| 372 | /// Erase any operations that were marked as dead during simplification. |
| 373 | for (auto *op : opsToErase) |
| 374 | rewriter.eraseOp(op); |
| 375 | if (changed) |
| 376 | *changed = !opsToErase.empty(); |
| 377 | |
| 378 | // Note: CSE does currently not remove ops with regions, so DominanceInfo |
| 379 | // does not have to be invalidated. |
| 380 | } |
| 381 | |
| 382 | void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter, |
| 383 | DominanceInfo &domInfo, Operation *op, |
| 384 | bool *changed) { |
| 385 | CSEDriver driver(rewriter, &domInfo); |
| 386 | driver.simplify(op, changed); |
| 387 | } |
| 388 | |
| 389 | namespace { |
| 390 | /// CSE pass. |
| 391 | struct CSE : public impl::CSEBase<CSE> { |
| 392 | void runOnOperation() override; |
| 393 | }; |
| 394 | } // namespace |
| 395 | |
| 396 | void CSE::runOnOperation() { |
| 397 | // Simplify the IR. |
| 398 | IRRewriter rewriter(&getContext()); |
| 399 | CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>()); |
| 400 | bool changed = false; |
| 401 | driver.simplify(op: getOperation(), changed: &changed); |
| 402 | |
| 403 | // Set statistics. |
| 404 | numCSE = driver.getNumCSE(); |
| 405 | numDCE = driver.getNumDCE(); |
| 406 | |
| 407 | // If there was no change to the IR, we mark all analyses as preserved. |
| 408 | if (!changed) |
| 409 | return markAllAnalysesPreserved(); |
| 410 | |
| 411 | // We currently don't remove region operations, so mark dominance as |
| 412 | // preserved. |
| 413 | markAnalysesPreserved<DominanceInfo, PostDominanceInfo>(); |
| 414 | } |
| 415 | |
| 416 | std::unique_ptr<Pass> mlir::createCSEPass() { return std::make_unique<CSE>(); } |
| 417 | |