1 | //===- CSE.cpp - Common Sub-expression Elimination ------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This transformation pass performs a simple common sub-expression elimination |
10 | // algorithm on operations within a region. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Transforms/CSE.h" |
15 | |
16 | #include "mlir/IR/Dominance.h" |
17 | #include "mlir/IR/PatternMatch.h" |
18 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
19 | #include "mlir/Pass/Pass.h" |
20 | #include "mlir/Transforms/Passes.h" |
21 | #include "llvm/ADT/DenseMapInfo.h" |
22 | #include "llvm/ADT/Hashing.h" |
23 | #include "llvm/ADT/ScopedHashTable.h" |
24 | #include "llvm/Support/Allocator.h" |
25 | #include "llvm/Support/RecyclingAllocator.h" |
26 | #include <deque> |
27 | |
28 | namespace mlir { |
29 | #define GEN_PASS_DEF_CSE |
30 | #include "mlir/Transforms/Passes.h.inc" |
31 | } // namespace mlir |
32 | |
33 | using namespace mlir; |
34 | |
35 | namespace { |
36 | struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> { |
37 | static unsigned getHashValue(const Operation *opC) { |
38 | return OperationEquivalence::computeHash( |
39 | const_cast<Operation *>(opC), |
40 | /*hashOperands=*/OperationEquivalence::directHashValue, |
41 | /*hashResults=*/OperationEquivalence::ignoreHashValue, |
42 | OperationEquivalence::IgnoreLocations); |
43 | } |
44 | static bool isEqual(const Operation *lhsC, const Operation *rhsC) { |
45 | auto *lhs = const_cast<Operation *>(lhsC); |
46 | auto *rhs = const_cast<Operation *>(rhsC); |
47 | if (lhs == rhs) |
48 | return true; |
49 | if (lhs == getTombstoneKey() || lhs == getEmptyKey() || |
50 | rhs == getTombstoneKey() || rhs == getEmptyKey()) |
51 | return false; |
52 | return OperationEquivalence::isEquivalentTo( |
53 | lhs: const_cast<Operation *>(lhsC), rhs: const_cast<Operation *>(rhsC), |
54 | flags: OperationEquivalence::IgnoreLocations); |
55 | } |
56 | }; |
57 | } // namespace |
58 | |
59 | namespace { |
60 | /// Simple common sub-expression elimination. |
61 | class CSEDriver { |
62 | public: |
63 | CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo) |
64 | : rewriter(rewriter), domInfo(domInfo) {} |
65 | |
66 | /// Simplify all operations within the given op. |
67 | void simplify(Operation *op, bool *changed = nullptr); |
68 | |
69 | int64_t getNumCSE() const { return numCSE; } |
70 | int64_t getNumDCE() const { return numDCE; } |
71 | |
72 | private: |
73 | /// Shared implementation of operation elimination and scoped map definitions. |
74 | using AllocatorTy = llvm::RecyclingAllocator< |
75 | llvm::BumpPtrAllocator, |
76 | llvm::ScopedHashTableVal<Operation *, Operation *>>; |
77 | using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *, |
78 | SimpleOperationInfo, AllocatorTy>; |
79 | |
80 | /// Cache holding MemoryEffects information between two operations. The first |
81 | /// operation is stored has the key. The second operation is stored inside a |
82 | /// pair in the value. The pair also hold the MemoryEffects between those |
83 | /// two operations. If the MemoryEffects is nullptr then we assume there is |
84 | /// no operation with MemoryEffects::Write between the two operations. |
85 | using MemEffectsCache = |
86 | DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>; |
87 | |
88 | /// Represents a single entry in the depth first traversal of a CFG. |
89 | struct CFGStackNode { |
90 | CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node) |
91 | : scope(knownValues), node(node), childIterator(node->begin()) {} |
92 | |
93 | /// Scope for the known values. |
94 | ScopedMapTy::ScopeTy scope; |
95 | |
96 | DominanceInfoNode *node; |
97 | DominanceInfoNode::const_iterator childIterator; |
98 | |
99 | /// If this node has been fully processed yet or not. |
100 | bool processed = false; |
101 | }; |
102 | |
103 | /// Attempt to eliminate a redundant operation. Returns success if the |
104 | /// operation was marked for removal, failure otherwise. |
105 | LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op, |
106 | bool hasSSADominance); |
107 | void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance); |
108 | void simplifyRegion(ScopedMapTy &knownValues, Region ®ion); |
109 | |
110 | void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, |
111 | Operation *existing, bool hasSSADominance); |
112 | |
113 | /// Check if there is side-effecting operations other than the given effect |
114 | /// between the two operations. |
115 | bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp); |
116 | |
117 | /// A rewriter for modifying the IR. |
118 | RewriterBase &rewriter; |
119 | |
120 | /// Operations marked as dead and to be erased. |
121 | std::vector<Operation *> opsToErase; |
122 | DominanceInfo *domInfo = nullptr; |
123 | MemEffectsCache memEffectsCache; |
124 | |
125 | // Various statistics. |
126 | int64_t numCSE = 0; |
127 | int64_t numDCE = 0; |
128 | }; |
129 | } // namespace |
130 | |
131 | void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, |
132 | Operation *existing, |
133 | bool hasSSADominance) { |
134 | // If we find one then replace all uses of the current operation with the |
135 | // existing one and mark it for deletion. We can only replace an operand in |
136 | // an operation if it has not been visited yet. |
137 | if (hasSSADominance) { |
138 | // If the region has SSA dominance, then we are guaranteed to have not |
139 | // visited any use of the current operation. |
140 | if (auto *rewriteListener = |
141 | dyn_cast_if_present<RewriterBase::Listener>(Val: rewriter.getListener())) |
142 | rewriteListener->notifyOperationReplaced(op, replacement: existing); |
143 | // Replace all uses, but do not remote the operation yet. This does not |
144 | // notify the listener because the original op is not erased. |
145 | rewriter.replaceAllUsesWith(from: op->getResults(), to: existing->getResults()); |
146 | opsToErase.push_back(x: op); |
147 | } else { |
148 | // When the region does not have SSA dominance, we need to check if we |
149 | // have visited a use before replacing any use. |
150 | auto wasVisited = [&](OpOperand &operand) { |
151 | return !knownValues.count(Key: operand.getOwner()); |
152 | }; |
153 | if (auto *rewriteListener = |
154 | dyn_cast_if_present<RewriterBase::Listener>(Val: rewriter.getListener())) |
155 | for (Value v : op->getResults()) |
156 | if (all_of(Range: v.getUses(), P: wasVisited)) |
157 | rewriteListener->notifyOperationReplaced(op, replacement: existing); |
158 | |
159 | // Replace all uses, but do not remote the operation yet. This does not |
160 | // notify the listener because the original op is not erased. |
161 | rewriter.replaceUsesWithIf(from: op->getResults(), to: existing->getResults(), |
162 | functor: wasVisited); |
163 | |
164 | // There may be some remaining uses of the operation. |
165 | if (op->use_empty()) |
166 | opsToErase.push_back(x: op); |
167 | } |
168 | |
169 | // If the existing operation has an unknown location and the current |
170 | // operation doesn't, then set the existing op's location to that of the |
171 | // current op. |
172 | if (isa<UnknownLoc>(Val: existing->getLoc()) && !isa<UnknownLoc>(Val: op->getLoc())) |
173 | existing->setLoc(op->getLoc()); |
174 | |
175 | ++numCSE; |
176 | } |
177 | |
178 | bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp, |
179 | Operation *toOp) { |
180 | assert(fromOp->getBlock() == toOp->getBlock()); |
181 | assert( |
182 | isa<MemoryEffectOpInterface>(fromOp) && |
183 | cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() && |
184 | isa<MemoryEffectOpInterface>(toOp) && |
185 | cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>()); |
186 | Operation *nextOp = fromOp->getNextNode(); |
187 | auto result = |
188 | memEffectsCache.try_emplace(Key: fromOp, Args: std::make_pair(x&: fromOp, y: nullptr)); |
189 | if (result.second) { |
190 | auto memEffectsCachePair = result.first->second; |
191 | if (memEffectsCachePair.second == nullptr) { |
192 | // No MemoryEffects::Write has been detected until the cached operation. |
193 | // Continue looking from the cached operation to toOp. |
194 | nextOp = memEffectsCachePair.first; |
195 | } else { |
196 | // MemoryEffects::Write has been detected before so there is no need to |
197 | // check further. |
198 | return true; |
199 | } |
200 | } |
201 | while (nextOp && nextOp != toOp) { |
202 | std::optional<SmallVector<MemoryEffects::EffectInstance>> effects = |
203 | getEffectsRecursively(rootOp: nextOp); |
204 | if (!effects) { |
205 | // TODO: Do we need to handle other effects generically? |
206 | // If the operation does not implement the MemoryEffectOpInterface we |
207 | // conservatively assume it writes. |
208 | result.first->second = |
209 | std::make_pair(x&: nextOp, y: MemoryEffects::Write::get()); |
210 | return true; |
211 | } |
212 | |
213 | for (const MemoryEffects::EffectInstance &effect : *effects) { |
214 | if (isa<MemoryEffects::Write>(effect.getEffect())) { |
215 | result.first->second = {nextOp, MemoryEffects::Write::get()}; |
216 | return true; |
217 | } |
218 | } |
219 | nextOp = nextOp->getNextNode(); |
220 | } |
221 | result.first->second = std::make_pair(x&: toOp, y: nullptr); |
222 | return false; |
223 | } |
224 | |
225 | /// Attempt to eliminate a redundant operation. |
226 | LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues, |
227 | Operation *op, |
228 | bool hasSSADominance) { |
229 | // Don't simplify terminator operations. |
230 | if (op->hasTrait<OpTrait::IsTerminator>()) |
231 | return failure(); |
232 | |
233 | // If the operation is already trivially dead just add it to the erase list. |
234 | if (isOpTriviallyDead(op)) { |
235 | opsToErase.push_back(x: op); |
236 | ++numDCE; |
237 | return success(); |
238 | } |
239 | |
240 | // Don't simplify operations with regions that have multiple blocks. |
241 | // TODO: We need additional tests to verify that we handle such IR correctly. |
242 | if (!llvm::all_of(Range: op->getRegions(), P: [](Region &r) { |
243 | return r.getBlocks().empty() || llvm::hasSingleElement(C&: r.getBlocks()); |
244 | })) |
245 | return failure(); |
246 | |
247 | // Some simple use case of operation with memory side-effect are dealt with |
248 | // here. Operations with no side-effect are done after. |
249 | if (!isMemoryEffectFree(op)) { |
250 | auto memEffects = dyn_cast<MemoryEffectOpInterface>(op); |
251 | // TODO: Only basic use case for operations with MemoryEffects::Read can be |
252 | // eleminated now. More work needs to be done for more complicated patterns |
253 | // and other side-effects. |
254 | if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>()) |
255 | return failure(); |
256 | |
257 | // Look for an existing definition for the operation. |
258 | if (auto *existing = knownValues.lookup(Key: op)) { |
259 | if (existing->getBlock() == op->getBlock() && |
260 | !hasOtherSideEffectingOpInBetween(fromOp: existing, toOp: op)) { |
261 | // The operation that can be deleted has been reach with no |
262 | // side-effecting operations in between the existing operation and |
263 | // this one so we can remove the duplicate. |
264 | replaceUsesAndDelete(knownValues, op, existing, hasSSADominance); |
265 | return success(); |
266 | } |
267 | } |
268 | knownValues.insert(Key: op, Val: op); |
269 | return failure(); |
270 | } |
271 | |
272 | // Look for an existing definition for the operation. |
273 | if (auto *existing = knownValues.lookup(Key: op)) { |
274 | replaceUsesAndDelete(knownValues, op, existing, hasSSADominance); |
275 | ++numCSE; |
276 | return success(); |
277 | } |
278 | |
279 | // Otherwise, we add this operation to the known values map. |
280 | knownValues.insert(Key: op, Val: op); |
281 | return failure(); |
282 | } |
283 | |
284 | void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb, |
285 | bool hasSSADominance) { |
286 | for (auto &op : *bb) { |
287 | // Most operations don't have regions, so fast path that case. |
288 | if (op.getNumRegions() != 0) { |
289 | // If this operation is isolated above, we can't process nested regions |
290 | // with the given 'knownValues' map. This would cause the insertion of |
291 | // implicit captures in explicit capture only regions. |
292 | if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) { |
293 | ScopedMapTy nestedKnownValues; |
294 | for (auto ®ion : op.getRegions()) |
295 | simplifyRegion(knownValues&: nestedKnownValues, region); |
296 | } else { |
297 | // Otherwise, process nested regions normally. |
298 | for (auto ®ion : op.getRegions()) |
299 | simplifyRegion(knownValues, region); |
300 | } |
301 | } |
302 | |
303 | // If the operation is simplified, we don't process any held regions. |
304 | if (succeeded(result: simplifyOperation(knownValues, op: &op, hasSSADominance))) |
305 | continue; |
306 | } |
307 | // Clear the MemoryEffects cache since its usage is by block only. |
308 | memEffectsCache.clear(); |
309 | } |
310 | |
311 | void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region ®ion) { |
312 | // If the region is empty there is nothing to do. |
313 | if (region.empty()) |
314 | return; |
315 | |
316 | bool hasSSADominance = domInfo->hasSSADominance(region: ®ion); |
317 | |
318 | // If the region only contains one block, then simplify it directly. |
319 | if (region.hasOneBlock()) { |
320 | ScopedMapTy::ScopeTy scope(knownValues); |
321 | simplifyBlock(knownValues, bb: ®ion.front(), hasSSADominance); |
322 | return; |
323 | } |
324 | |
325 | // If the region does not have dominanceInfo, then skip it. |
326 | // TODO: Regions without SSA dominance should define a different |
327 | // traversal order which is appropriate and can be used here. |
328 | if (!hasSSADominance) |
329 | return; |
330 | |
331 | // Note, deque is being used here because there was significant performance |
332 | // gains over vector when the container becomes very large due to the |
333 | // specific access patterns. If/when these performance issues are no |
334 | // longer a problem we can change this to vector. For more information see |
335 | // the llvm mailing list discussion on this: |
336 | // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html |
337 | std::deque<std::unique_ptr<CFGStackNode>> stack; |
338 | |
339 | // Process the nodes of the dom tree for this region. |
340 | stack.emplace_back(args: std::make_unique<CFGStackNode>( |
341 | args&: knownValues, args: domInfo->getRootNode(region: ®ion))); |
342 | |
343 | while (!stack.empty()) { |
344 | auto ¤tNode = stack.back(); |
345 | |
346 | // Check to see if we need to process this node. |
347 | if (!currentNode->processed) { |
348 | currentNode->processed = true; |
349 | simplifyBlock(knownValues, bb: currentNode->node->getBlock(), |
350 | hasSSADominance); |
351 | } |
352 | |
353 | // Otherwise, check to see if we need to process a child node. |
354 | if (currentNode->childIterator != currentNode->node->end()) { |
355 | auto *childNode = *(currentNode->childIterator++); |
356 | stack.emplace_back( |
357 | args: std::make_unique<CFGStackNode>(args&: knownValues, args&: childNode)); |
358 | } else { |
359 | // Finally, if the node and all of its children have been processed |
360 | // then we delete the node. |
361 | stack.pop_back(); |
362 | } |
363 | } |
364 | } |
365 | |
366 | void CSEDriver::simplify(Operation *op, bool *changed) { |
367 | /// Simplify all regions. |
368 | ScopedMapTy knownValues; |
369 | for (auto ®ion : op->getRegions()) |
370 | simplifyRegion(knownValues, region); |
371 | |
372 | /// Erase any operations that were marked as dead during simplification. |
373 | for (auto *op : opsToErase) |
374 | rewriter.eraseOp(op); |
375 | if (changed) |
376 | *changed = !opsToErase.empty(); |
377 | |
378 | // Note: CSE does currently not remove ops with regions, so DominanceInfo |
379 | // does not have to be invalidated. |
380 | } |
381 | |
382 | void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter, |
383 | DominanceInfo &domInfo, Operation *op, |
384 | bool *changed) { |
385 | CSEDriver driver(rewriter, &domInfo); |
386 | driver.simplify(op, changed); |
387 | } |
388 | |
389 | namespace { |
390 | /// CSE pass. |
391 | struct CSE : public impl::CSEBase<CSE> { |
392 | void runOnOperation() override; |
393 | }; |
394 | } // namespace |
395 | |
396 | void CSE::runOnOperation() { |
397 | // Simplify the IR. |
398 | IRRewriter rewriter(&getContext()); |
399 | CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>()); |
400 | bool changed = false; |
401 | driver.simplify(op: getOperation(), changed: &changed); |
402 | |
403 | // Set statistics. |
404 | numCSE = driver.getNumCSE(); |
405 | numDCE = driver.getNumDCE(); |
406 | |
407 | // If there was no change to the IR, we mark all analyses as preserved. |
408 | if (!changed) |
409 | return markAllAnalysesPreserved(); |
410 | |
411 | // We currently don't remove region operations, so mark dominance as |
412 | // preserved. |
413 | markAnalysesPreserved<DominanceInfo, PostDominanceInfo>(); |
414 | } |
415 | |
416 | std::unique_ptr<Pass> mlir::createCSEPass() { return std::make_unique<CSE>(); } |
417 | |