1//===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transformation pass performs a simple common sub-expression elimination
10// algorithm on operations within a region.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Transforms/CSE.h"
15
16#include "mlir/IR/Dominance.h"
17#include "mlir/IR/PatternMatch.h"
18#include "mlir/Interfaces/SideEffectInterfaces.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Transforms/Passes.h"
21#include "llvm/ADT/DenseMapInfo.h"
22#include "llvm/ADT/Hashing.h"
23#include "llvm/ADT/ScopedHashTable.h"
24#include "llvm/Support/Allocator.h"
25#include "llvm/Support/RecyclingAllocator.h"
26#include <deque>
27
28namespace mlir {
29#define GEN_PASS_DEF_CSE
30#include "mlir/Transforms/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34
35namespace {
36struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
37 static unsigned getHashValue(const Operation *opC) {
38 return OperationEquivalence::computeHash(
39 const_cast<Operation *>(opC),
40 /*hashOperands=*/OperationEquivalence::directHashValue,
41 /*hashResults=*/OperationEquivalence::ignoreHashValue,
42 OperationEquivalence::IgnoreLocations);
43 }
44 static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
45 auto *lhs = const_cast<Operation *>(lhsC);
46 auto *rhs = const_cast<Operation *>(rhsC);
47 if (lhs == rhs)
48 return true;
49 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
50 rhs == getTombstoneKey() || rhs == getEmptyKey())
51 return false;
52 return OperationEquivalence::isEquivalentTo(
53 lhs: const_cast<Operation *>(lhsC), rhs: const_cast<Operation *>(rhsC),
54 flags: OperationEquivalence::IgnoreLocations);
55 }
56};
57} // namespace
58
59namespace {
60/// Simple common sub-expression elimination.
61class CSEDriver {
62public:
63 CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
64 : rewriter(rewriter), domInfo(domInfo) {}
65
66 /// Simplify all operations within the given op.
67 void simplify(Operation *op, bool *changed = nullptr);
68
69 int64_t getNumCSE() const { return numCSE; }
70 int64_t getNumDCE() const { return numDCE; }
71
72private:
73 /// Shared implementation of operation elimination and scoped map definitions.
74 using AllocatorTy = llvm::RecyclingAllocator<
75 llvm::BumpPtrAllocator,
76 llvm::ScopedHashTableVal<Operation *, Operation *>>;
77 using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
78 SimpleOperationInfo, AllocatorTy>;
79
80 /// Cache holding MemoryEffects information between two operations. The first
81 /// operation is stored has the key. The second operation is stored inside a
82 /// pair in the value. The pair also hold the MemoryEffects between those
83 /// two operations. If the MemoryEffects is nullptr then we assume there is
84 /// no operation with MemoryEffects::Write between the two operations.
85 using MemEffectsCache =
86 DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>;
87
88 /// Represents a single entry in the depth first traversal of a CFG.
89 struct CFGStackNode {
90 CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
91 : scope(knownValues), node(node), childIterator(node->begin()) {}
92
93 /// Scope for the known values.
94 ScopedMapTy::ScopeTy scope;
95
96 DominanceInfoNode *node;
97 DominanceInfoNode::const_iterator childIterator;
98
99 /// If this node has been fully processed yet or not.
100 bool processed = false;
101 };
102
103 /// Attempt to eliminate a redundant operation. Returns success if the
104 /// operation was marked for removal, failure otherwise.
105 LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
106 bool hasSSADominance);
107 void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
108 void simplifyRegion(ScopedMapTy &knownValues, Region &region);
109
110 void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
111 Operation *existing, bool hasSSADominance);
112
113 /// Check if there is side-effecting operations other than the given effect
114 /// between the two operations.
115 bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
116
117 /// A rewriter for modifying the IR.
118 RewriterBase &rewriter;
119
120 /// Operations marked as dead and to be erased.
121 std::vector<Operation *> opsToErase;
122 DominanceInfo *domInfo = nullptr;
123 MemEffectsCache memEffectsCache;
124
125 // Various statistics.
126 int64_t numCSE = 0;
127 int64_t numDCE = 0;
128};
129} // namespace
130
131void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
132 Operation *existing,
133 bool hasSSADominance) {
134 // If we find one then replace all uses of the current operation with the
135 // existing one and mark it for deletion. We can only replace an operand in
136 // an operation if it has not been visited yet.
137 if (hasSSADominance) {
138 // If the region has SSA dominance, then we are guaranteed to have not
139 // visited any use of the current operation.
140 if (auto *rewriteListener =
141 dyn_cast_if_present<RewriterBase::Listener>(Val: rewriter.getListener()))
142 rewriteListener->notifyOperationReplaced(op, replacement: existing);
143 // Replace all uses, but do not remote the operation yet. This does not
144 // notify the listener because the original op is not erased.
145 rewriter.replaceAllUsesWith(from: op->getResults(), to: existing->getResults());
146 opsToErase.push_back(x: op);
147 } else {
148 // When the region does not have SSA dominance, we need to check if we
149 // have visited a use before replacing any use.
150 auto wasVisited = [&](OpOperand &operand) {
151 return !knownValues.count(Key: operand.getOwner());
152 };
153 if (auto *rewriteListener =
154 dyn_cast_if_present<RewriterBase::Listener>(Val: rewriter.getListener()))
155 for (Value v : op->getResults())
156 if (all_of(Range: v.getUses(), P: wasVisited))
157 rewriteListener->notifyOperationReplaced(op, replacement: existing);
158
159 // Replace all uses, but do not remote the operation yet. This does not
160 // notify the listener because the original op is not erased.
161 rewriter.replaceUsesWithIf(from: op->getResults(), to: existing->getResults(),
162 functor: wasVisited);
163
164 // There may be some remaining uses of the operation.
165 if (op->use_empty())
166 opsToErase.push_back(x: op);
167 }
168
169 // If the existing operation has an unknown location and the current
170 // operation doesn't, then set the existing op's location to that of the
171 // current op.
172 if (isa<UnknownLoc>(Val: existing->getLoc()) && !isa<UnknownLoc>(Val: op->getLoc()))
173 existing->setLoc(op->getLoc());
174
175 ++numCSE;
176}
177
178bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
179 Operation *toOp) {
180 assert(fromOp->getBlock() == toOp->getBlock());
181 assert(
182 isa<MemoryEffectOpInterface>(fromOp) &&
183 cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() &&
184 isa<MemoryEffectOpInterface>(toOp) &&
185 cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>());
186 Operation *nextOp = fromOp->getNextNode();
187 auto result =
188 memEffectsCache.try_emplace(Key: fromOp, Args: std::make_pair(x&: fromOp, y: nullptr));
189 if (result.second) {
190 auto memEffectsCachePair = result.first->second;
191 if (memEffectsCachePair.second == nullptr) {
192 // No MemoryEffects::Write has been detected until the cached operation.
193 // Continue looking from the cached operation to toOp.
194 nextOp = memEffectsCachePair.first;
195 } else {
196 // MemoryEffects::Write has been detected before so there is no need to
197 // check further.
198 return true;
199 }
200 }
201 while (nextOp && nextOp != toOp) {
202 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
203 getEffectsRecursively(rootOp: nextOp);
204 if (!effects) {
205 // TODO: Do we need to handle other effects generically?
206 // If the operation does not implement the MemoryEffectOpInterface we
207 // conservatively assume it writes.
208 result.first->second =
209 std::make_pair(x&: nextOp, y: MemoryEffects::Write::get());
210 return true;
211 }
212
213 for (const MemoryEffects::EffectInstance &effect : *effects) {
214 if (isa<MemoryEffects::Write>(effect.getEffect())) {
215 result.first->second = {nextOp, MemoryEffects::Write::get()};
216 return true;
217 }
218 }
219 nextOp = nextOp->getNextNode();
220 }
221 result.first->second = std::make_pair(x&: toOp, y: nullptr);
222 return false;
223}
224
225/// Attempt to eliminate a redundant operation.
226LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
227 Operation *op,
228 bool hasSSADominance) {
229 // Don't simplify terminator operations.
230 if (op->hasTrait<OpTrait::IsTerminator>())
231 return failure();
232
233 // If the operation is already trivially dead just add it to the erase list.
234 if (isOpTriviallyDead(op)) {
235 opsToErase.push_back(x: op);
236 ++numDCE;
237 return success();
238 }
239
240 // Don't simplify operations with regions that have multiple blocks.
241 // TODO: We need additional tests to verify that we handle such IR correctly.
242 if (!llvm::all_of(Range: op->getRegions(), P: [](Region &r) {
243 return r.getBlocks().empty() || llvm::hasSingleElement(C&: r.getBlocks());
244 }))
245 return failure();
246
247 // Some simple use case of operation with memory side-effect are dealt with
248 // here. Operations with no side-effect are done after.
249 if (!isMemoryEffectFree(op)) {
250 auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
251 // TODO: Only basic use case for operations with MemoryEffects::Read can be
252 // eleminated now. More work needs to be done for more complicated patterns
253 // and other side-effects.
254 if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>())
255 return failure();
256
257 // Look for an existing definition for the operation.
258 if (auto *existing = knownValues.lookup(Key: op)) {
259 if (existing->getBlock() == op->getBlock() &&
260 !hasOtherSideEffectingOpInBetween(fromOp: existing, toOp: op)) {
261 // The operation that can be deleted has been reach with no
262 // side-effecting operations in between the existing operation and
263 // this one so we can remove the duplicate.
264 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
265 return success();
266 }
267 }
268 knownValues.insert(Key: op, Val: op);
269 return failure();
270 }
271
272 // Look for an existing definition for the operation.
273 if (auto *existing = knownValues.lookup(Key: op)) {
274 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
275 ++numCSE;
276 return success();
277 }
278
279 // Otherwise, we add this operation to the known values map.
280 knownValues.insert(Key: op, Val: op);
281 return failure();
282}
283
284void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
285 bool hasSSADominance) {
286 for (auto &op : *bb) {
287 // Most operations don't have regions, so fast path that case.
288 if (op.getNumRegions() != 0) {
289 // If this operation is isolated above, we can't process nested regions
290 // with the given 'knownValues' map. This would cause the insertion of
291 // implicit captures in explicit capture only regions.
292 if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
293 ScopedMapTy nestedKnownValues;
294 for (auto &region : op.getRegions())
295 simplifyRegion(knownValues&: nestedKnownValues, region);
296 } else {
297 // Otherwise, process nested regions normally.
298 for (auto &region : op.getRegions())
299 simplifyRegion(knownValues, region);
300 }
301 }
302
303 // If the operation is simplified, we don't process any held regions.
304 if (succeeded(result: simplifyOperation(knownValues, op: &op, hasSSADominance)))
305 continue;
306 }
307 // Clear the MemoryEffects cache since its usage is by block only.
308 memEffectsCache.clear();
309}
310
311void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
312 // If the region is empty there is nothing to do.
313 if (region.empty())
314 return;
315
316 bool hasSSADominance = domInfo->hasSSADominance(region: &region);
317
318 // If the region only contains one block, then simplify it directly.
319 if (region.hasOneBlock()) {
320 ScopedMapTy::ScopeTy scope(knownValues);
321 simplifyBlock(knownValues, bb: &region.front(), hasSSADominance);
322 return;
323 }
324
325 // If the region does not have dominanceInfo, then skip it.
326 // TODO: Regions without SSA dominance should define a different
327 // traversal order which is appropriate and can be used here.
328 if (!hasSSADominance)
329 return;
330
331 // Note, deque is being used here because there was significant performance
332 // gains over vector when the container becomes very large due to the
333 // specific access patterns. If/when these performance issues are no
334 // longer a problem we can change this to vector. For more information see
335 // the llvm mailing list discussion on this:
336 // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
337 std::deque<std::unique_ptr<CFGStackNode>> stack;
338
339 // Process the nodes of the dom tree for this region.
340 stack.emplace_back(args: std::make_unique<CFGStackNode>(
341 args&: knownValues, args: domInfo->getRootNode(region: &region)));
342
343 while (!stack.empty()) {
344 auto &currentNode = stack.back();
345
346 // Check to see if we need to process this node.
347 if (!currentNode->processed) {
348 currentNode->processed = true;
349 simplifyBlock(knownValues, bb: currentNode->node->getBlock(),
350 hasSSADominance);
351 }
352
353 // Otherwise, check to see if we need to process a child node.
354 if (currentNode->childIterator != currentNode->node->end()) {
355 auto *childNode = *(currentNode->childIterator++);
356 stack.emplace_back(
357 args: std::make_unique<CFGStackNode>(args&: knownValues, args&: childNode));
358 } else {
359 // Finally, if the node and all of its children have been processed
360 // then we delete the node.
361 stack.pop_back();
362 }
363 }
364}
365
366void CSEDriver::simplify(Operation *op, bool *changed) {
367 /// Simplify all regions.
368 ScopedMapTy knownValues;
369 for (auto &region : op->getRegions())
370 simplifyRegion(knownValues, region);
371
372 /// Erase any operations that were marked as dead during simplification.
373 for (auto *op : opsToErase)
374 rewriter.eraseOp(op);
375 if (changed)
376 *changed = !opsToErase.empty();
377
378 // Note: CSE does currently not remove ops with regions, so DominanceInfo
379 // does not have to be invalidated.
380}
381
382void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter,
383 DominanceInfo &domInfo, Operation *op,
384 bool *changed) {
385 CSEDriver driver(rewriter, &domInfo);
386 driver.simplify(op, changed);
387}
388
389namespace {
390/// CSE pass.
391struct CSE : public impl::CSEBase<CSE> {
392 void runOnOperation() override;
393};
394} // namespace
395
396void CSE::runOnOperation() {
397 // Simplify the IR.
398 IRRewriter rewriter(&getContext());
399 CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
400 bool changed = false;
401 driver.simplify(op: getOperation(), changed: &changed);
402
403 // Set statistics.
404 numCSE = driver.getNumCSE();
405 numDCE = driver.getNumDCE();
406
407 // If there was no change to the IR, we mark all analyses as preserved.
408 if (!changed)
409 return markAllAnalysesPreserved();
410
411 // We currently don't remove region operations, so mark dominance as
412 // preserved.
413 markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
414}
415
416std::unique_ptr<Pass> mlir::createCSEPass() { return std::make_unique<CSE>(); }
417

source code of mlir/lib/Transforms/CSE.cpp