1 | //===- RegionUtils.cpp - Region-related transformation utilities ----------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Transforms/RegionUtils.h" |
10 | |
11 | #include "mlir/Analysis/SliceAnalysis.h" |
12 | #include "mlir/Analysis/TopologicalSortUtils.h" |
13 | #include "mlir/IR/Block.h" |
14 | #include "mlir/IR/BuiltinOps.h" |
15 | #include "mlir/IR/Dominance.h" |
16 | #include "mlir/IR/IRMapping.h" |
17 | #include "mlir/IR/Operation.h" |
18 | #include "mlir/IR/PatternMatch.h" |
19 | #include "mlir/IR/RegionGraphTraits.h" |
20 | #include "mlir/IR/Value.h" |
21 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
22 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
23 | #include "mlir/Support/LogicalResult.h" |
24 | |
25 | #include "llvm/ADT/DepthFirstIterator.h" |
26 | #include "llvm/ADT/PostOrderIterator.h" |
27 | #include "llvm/ADT/STLExtras.h" |
28 | #include "llvm/ADT/SmallSet.h" |
29 | |
30 | #include <deque> |
31 | #include <iterator> |
32 | |
33 | using namespace mlir; |
34 | |
35 | void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, |
36 | Region ®ion) { |
37 | for (auto &use : llvm::make_early_inc_range(Range: orig.getUses())) { |
38 | if (region.isAncestor(other: use.getOwner()->getParentRegion())) |
39 | use.set(replacement); |
40 | } |
41 | } |
42 | |
43 | void mlir::visitUsedValuesDefinedAbove( |
44 | Region ®ion, Region &limit, function_ref<void(OpOperand *)> callback) { |
45 | assert(limit.isAncestor(®ion) && |
46 | "expected isolation limit to be an ancestor of the given region"); |
47 | |
48 | // Collect proper ancestors of `limit` upfront to avoid traversing the region |
49 | // tree for every value. |
50 | SmallPtrSet<Region *, 4> properAncestors; |
51 | for (auto *reg = limit.getParentRegion(); reg != nullptr; |
52 | reg = reg->getParentRegion()) { |
53 | properAncestors.insert(Ptr: reg); |
54 | } |
55 | |
56 | region.walk(callback: [callback, &properAncestors](Operation *op) { |
57 | for (OpOperand &operand : op->getOpOperands()) |
58 | // Callback on values defined in a proper ancestor of region. |
59 | if (properAncestors.count(Ptr: operand.get().getParentRegion())) |
60 | callback(&operand); |
61 | }); |
62 | } |
63 | |
64 | void mlir::visitUsedValuesDefinedAbove( |
65 | MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) { |
66 | for (Region ®ion : regions) |
67 | visitUsedValuesDefinedAbove(region, limit&: region, callback); |
68 | } |
69 | |
70 | void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, |
71 | SetVector<Value> &values) { |
72 | visitUsedValuesDefinedAbove(region, limit, callback: [&](OpOperand *operand) { |
73 | values.insert(X: operand->get()); |
74 | }); |
75 | } |
76 | |
77 | void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions, |
78 | SetVector<Value> &values) { |
79 | for (Region ®ion : regions) |
80 | getUsedValuesDefinedAbove(region, limit&: region, values); |
81 | } |
82 | |
83 | //===----------------------------------------------------------------------===// |
84 | // Make block isolated from above. |
85 | //===----------------------------------------------------------------------===// |
86 | |
/// Make `region` isolated from above: every value defined outside the region
/// but used inside it is either re-materialized inside the region (its
/// defining op is cloned when `cloneOperationIntoRegion` approves) or turned
/// into a new entry block argument. Returns the values that ended up captured
/// as arguments, in the order the corresponding arguments were appended.
SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
    RewriterBase &rewriter, Region &region,
    llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {

  // Get initial list of values used within region but defined above.
  llvm::SetVector<Value> initialCapturedValues;
  mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);

  std::deque<Value> worklist(initialCapturedValues.begin(),
                             initialCapturedValues.end());
  llvm::DenseSet<Value> visited;
  llvm::DenseSet<Operation *> visitedOps;

  llvm::SetVector<Value> finalCapturedValues;
  SmallVector<Operation *> clonedOperations;
  // Worklist walk: deciding to clone an op adds its operands as new candidate
  // captures, so the capture set is computed to a fixed point.
  while (!worklist.empty()) {
    Value currValue = worklist.front();
    worklist.pop_front();
    if (visited.count(currValue))
      continue;
    visited.insert(currValue);

    Operation *definingOp = currValue.getDefiningOp();
    // Block arguments (no defining op) and ops already handled can only be
    // captured as new region arguments.
    if (!definingOp || visitedOps.count(definingOp)) {
      finalCapturedValues.insert(currValue);
      continue;
    }
    visitedOps.insert(definingOp);

    if (!cloneOperationIntoRegion(definingOp)) {
      // Defining operation isn't cloned, so add the current value to final
      // captured values list.
      finalCapturedValues.insert(currValue);
      continue;
    }

    // Add all operands of the operation to the worklist and mark the op as to
    // be cloned.
    for (Value operand : definingOp->getOperands()) {
      if (visited.count(operand))
        continue;
      worklist.push_back(operand);
    }
    clonedOperations.push_back(definingOp);
  }

  // The operations to be cloned need to be ordered in topological order
  // so that they can be cloned into the region without violating use-def
  // chains.
  mlir::computeTopologicalSorting(clonedOperations);

  OpBuilder::InsertionGuard g(rewriter);
  // Collect types of existing block.
  Block *entryBlock = &region.front();
  SmallVector<Type> newArgTypes =
      llvm::to_vector(entryBlock->getArgumentTypes());
  SmallVector<Location> newArgLocs = llvm::to_vector(llvm::map_range(
      entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); }));

  // Append the types of the captured values.
  for (auto value : finalCapturedValues) {
    newArgTypes.push_back(value.getType());
    newArgLocs.push_back(value.getLoc());
  }

  // Create a new entry block carrying the old arguments followed by one new
  // argument per captured value.
  Block *newEntryBlock =
      rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
  auto newEntryBlockArgs = newEntryBlock->getArguments();

  // Create a mapping between the captured values and the new arguments added.
  IRMapping map;
  // Only uses inside `region` are redirected; uses elsewhere must keep
  // referring to the original (outer) values.
  auto replaceIfFn = [&](OpOperand &use) {
    return use.getOwner()->getBlock()->getParent() == &region;
  };
  for (auto [arg, capturedVal] :
       llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
                 finalCapturedValues)) {
    map.map(capturedVal, arg);
    rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
  }
  rewriter.setInsertionPointToStart(newEntryBlock);
  // Clone in topological order so each clone's operands are already mapped.
  for (auto *clonedOp : clonedOperations) {
    Operation *newOp = rewriter.clone(*clonedOp, map);
    rewriter.replaceOpUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
  }
  // Splice the old entry block into the new one, forwarding the original
  // arguments (the leading ones of the new block).
  rewriter.mergeBlocks(
      entryBlock, newEntryBlock,
      newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
  return llvm::to_vector(finalCapturedValues);
}
178 | |
179 | //===----------------------------------------------------------------------===// |
180 | // Unreachable Block Elimination |
181 | //===----------------------------------------------------------------------===// |
182 | |
/// Erase the unreachable blocks within the provided regions. Returns success
/// if any blocks were erased, failure otherwise.
// TODO: We could likely merge this with the DCE algorithm below.
LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
                                           MutableArrayRef<Region> regions) {
  // Set of blocks found to be reachable within a given region.
  llvm::df_iterator_default_set<Block *, 16> reachable;
  // If any blocks were found to be dead.
  bool erasedDeadBlocks = false;

  SmallVector<Region *, 1> worklist;
  worklist.reserve(regions.size());
  for (Region &region : regions)
    worklist.push_back(&region);
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();
    if (region->empty())
      continue;

    // If this is a single block region, just collect the nested regions.
    // (A lone block is trivially reachable.)
    if (region->hasOneBlock()) {
      for (Operation &op : region->front())
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
      continue;
    }

    // Mark all reachable blocks via a DFS from the entry block; the visited
    // set doubles as the reachability set.
    reachable.clear();
    for (Block *block : depth_first_ext(&region->front(), reachable))
      (void)block /* Mark all reachable blocks */;

    // Collect all of the dead blocks and push the live regions onto the
    // worklist.
    for (Block &block : llvm::make_early_inc_range(*region)) {
      if (!reachable.count(&block)) {
        // Unreachable blocks may still reference each other's values; drop
        // those uses before erasing so the erase is safe.
        block.dropAllDefinedValueUses();
        rewriter.eraseBlock(&block);
        erasedDeadBlocks = true;
        continue;
      }

      // Walk any regions within this block.
      for (Operation &op : block)
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
    }
  }

  return success(erasedDeadBlocks);
}
234 | |
235 | //===----------------------------------------------------------------------===// |
236 | // Dead Code Elimination |
237 | //===----------------------------------------------------------------------===// |
238 | |
namespace {
/// Data structure used to track which values have already been proved live.
///
/// Because Operation's can have multiple results, this data structure tracks
/// liveness for both Value's and Operation's to avoid having to look through
/// all Operation results when analyzing a use.
///
/// This data structure essentially tracks the dataflow lattice.
/// The set of values/ops proved live increases monotonically to a fixed-point.
class LiveMap {
public:
  /// Value methods.
  bool wasProvenLive(Value value) {
    // TODO: For results that are removable, e.g. for region based control flow,
    // we could allow for these values to be tracked independently.
    if (OpResult result = dyn_cast<OpResult>(value))
      return wasProvenLive(result.getOwner());
    return wasProvenLive(cast<BlockArgument>(value));
  }
  bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
  void setProvedLive(Value value) {
    // TODO: For results that are removable, e.g. for region based control flow,
    // we could allow for these values to be tracked independently.
    if (OpResult result = dyn_cast<OpResult>(value))
      return setProvedLive(result.getOwner());
    setProvedLive(cast<BlockArgument>(value));
  }
  void setProvedLive(BlockArgument arg) {
    // insert().second is true only when the set actually grew; that feeds the
    // fixed-point detection in the driver loop.
    changed |= liveValues.insert(arg).second;
  }

  /// Operation methods.
  bool wasProvenLive(Operation *op) { return liveOps.count(op); }
  void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }

  /// Methods for tracking if we have reached a fixed-point.
  void resetChanged() { changed = false; }
  bool hasChanged() { return changed; }

private:
  /// Whether the live sets grew since the last resetChanged().
  bool changed = false;
  DenseSet<Value> liveValues;
  DenseSet<Operation *> liveOps;
};
} // namespace
284 | |
285 | static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) { |
286 | Operation *owner = use.getOwner(); |
287 | unsigned operandIndex = use.getOperandNumber(); |
288 | // This pass generally treats all uses of an op as live if the op itself is |
289 | // considered live. However, for successor operands to terminators we need a |
290 | // finer-grained notion where we deduce liveness for operands individually. |
291 | // The reason for this is easiest to think about in terms of a classical phi |
292 | // node based SSA IR, where each successor operand is really an operand to a |
293 | // *separate* phi node, rather than all operands to the branch itself as with |
294 | // the block argument representation that MLIR uses. |
295 | // |
296 | // And similarly, because each successor operand is really an operand to a phi |
297 | // node, rather than to the terminator op itself, a terminator op can't e.g. |
298 | // "print" the value of a successor operand. |
299 | if (owner->hasTrait<OpTrait::IsTerminator>()) { |
300 | if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner)) |
301 | if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex)) |
302 | return !liveMap.wasProvenLive(*arg); |
303 | return false; |
304 | } |
305 | return false; |
306 | } |
307 | |
308 | static void processValue(Value value, LiveMap &liveMap) { |
309 | bool provedLive = llvm::any_of(Range: value.getUses(), P: [&](OpOperand &use) { |
310 | if (isUseSpeciallyKnownDead(use, liveMap)) |
311 | return false; |
312 | return liveMap.wasProvenLive(op: use.getOwner()); |
313 | }); |
314 | if (provedLive) |
315 | liveMap.setProvedLive(value); |
316 | } |
317 | |
318 | static void propagateLiveness(Region ®ion, LiveMap &liveMap); |
319 | |
/// Mark a terminator live and propagate liveness to the block arguments that
/// its successors receive, where operand-level reasoning is not possible.
static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
  // Terminators are always live.
  liveMap.setProvedLive(op);

  // Check to see if we can reason about the successor operands and mutate them.
  BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
  if (!branchInterface) {
    // Without BranchOpInterface we cannot map operands to individual block
    // arguments, so conservatively mark every successor argument live.
    for (Block *successor : op->getSuccessors())
      for (BlockArgument arg : successor->getArguments())
        liveMap.setProvedLive(arg);
    return;
  }

  // If we can't reason about the operand to a successor, conservatively mark
  // it as live. "Produced" operands are generated internally by the terminator
  // and cannot be erased, so the leading arguments they feed stay live.
  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
    SuccessorOperands successorOperands =
        branchInterface.getSuccessorOperands(i);
    for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
         opI != opE; ++opI)
      liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
  }
}
343 | |
/// Propagate liveness through a single operation: nested regions first, then
/// the op itself and its results.
static void propagateLiveness(Operation *op, LiveMap &liveMap) {
  // Recurse on any regions the op has.
  for (Region &region : op->getRegions())
    propagateLiveness(region, liveMap);

  // Process terminator operations.
  if (op->hasTrait<OpTrait::IsTerminator>())
    return propagateTerminatorLiveness(op, liveMap);

  // Don't reprocess live operations.
  if (liveMap.wasProvenLive(op))
    return;

  // Process the op itself. Ops that could not be trivially erased (e.g. due to
  // side effects) are intrinsically live.
  if (!wouldOpBeTriviallyDead(op))
    return liveMap.setProvedLive(op);

  // If the op isn't intrinsically alive, check it's results.
  for (Value value : op->getResults())
    processValue(value, liveMap);
}
365 | |
/// Propagate liveness through every block of `region`, visiting blocks in
/// post-order and ops in reverse so uses tend to be seen before defs.
static void propagateLiveness(Region &region, LiveMap &liveMap) {
  if (region.empty())
    return;

  for (Block *block : llvm::post_order(&region.front())) {
    // We process block arguments after the ops in the block, to promote
    // faster convergence to a fixed point (we try to visit uses before defs).
    for (Operation &op : llvm::reverse(block->getOperations()))
      propagateLiveness(&op, liveMap);

    // We currently do not remove entry block arguments, so there is no need to
    // track their liveness.
    // TODO: We could track these and enable removing dead operands/arguments
    // from region control flow operations.
    if (block->isEntryBlock())
      continue;

    // Re-check only arguments not yet proven live; the live set is monotonic.
    for (Value value : block->getArguments()) {
      if (!liveMap.wasProvenLive(value))
        processValue(value, liveMap);
    }
  }
}
389 | |
/// Erase from `terminator` every successor operand that feeds a block argument
/// which was not proven live. No-op for terminators without BranchOpInterface.
static void eraseTerminatorSuccessorOperands(Operation *terminator,
                                             LiveMap &liveMap) {
  BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
  if (!branchOp)
    return;

  for (unsigned succI = 0, succE = terminator->getNumSuccessors();
       succI < succE; succI++) {
    // Iterating successors in reverse is not strictly needed, since we
    // aren't erasing any successors. But it is slightly more efficient
    // since it will promote later operands of the terminator being erased
    // first, reducing the quadratic-ness.
    unsigned succ = succE - succI - 1;
    SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
    Block *successor = terminator->getSuccessor(succ);

    for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
      // Iterating args in reverse is needed for correctness, to avoid
      // shifting later args when earlier args are erased.
      unsigned arg = argE - argI - 1;
      if (!liveMap.wasProvenLive(successor->getArgument(arg)))
        succOperands.erase(arg);
    }
  }
}
415 | |
/// Erase all ops and (non-entry) block arguments that were not proven live.
/// Returns success if anything was erased, failure otherwise.
static LogicalResult deleteDeadness(RewriterBase &rewriter,
                                    MutableArrayRef<Region> regions,
                                    LiveMap &liveMap) {
  bool erasedAnything = false;
  for (Region &region : regions) {
    if (region.empty())
      continue;
    bool hasSingleBlock = llvm::hasSingleElement(region);

    // Delete every operation that is not live. Graph regions may have cycles
    // in the use-def graph, so we must explicitly dropAllUses() from each
    // operation as we erase it. Visiting the operations in post-order
    // guarantees that in SSA CFG regions value uses are removed before defs,
    // which makes dropAllUses() a no-op.
    for (Block *block : llvm::post_order(&region.front())) {
      // Single-block regions have no successor operands to trim.
      if (!hasSingleBlock)
        eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
      for (Operation &childOp :
           llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
        if (!liveMap.wasProvenLive(&childOp)) {
          erasedAnything = true;
          childOp.dropAllUses();
          rewriter.eraseOp(&childOp);
        } else {
          // Live op: recurse into its regions to erase nested deadness.
          erasedAnything |= succeeded(
              deleteDeadness(rewriter, childOp.getRegions(), liveMap));
        }
      }
    }
    // Delete block arguments.
    // The entry block has an unknown contract with their enclosing block, so
    // skip it.
    for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
      block.eraseArguments(
          [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
    }
  }
  return success(erasedAnything);
}
455 | |
456 | // This function performs a simple dead code elimination algorithm over the |
457 | // given regions. |
458 | // |
459 | // The overall goal is to prove that Values are dead, which allows deleting ops |
460 | // and block arguments. |
461 | // |
462 | // This uses an optimistic algorithm that assumes everything is dead until |
463 | // proved otherwise, allowing it to delete recursively dead cycles. |
464 | // |
465 | // This is a simple fixed-point dataflow analysis algorithm on a lattice |
466 | // {Dead,Alive}. Because liveness flows backward, we generally try to |
467 | // iterate everything backward to speed up convergence to the fixed-point. This |
468 | // allows for being able to delete recursively dead cycles of the use-def graph, |
469 | // including block arguments. |
470 | // |
471 | // This function returns success if any operations or arguments were deleted, |
472 | // failure otherwise. |
473 | LogicalResult mlir::runRegionDCE(RewriterBase &rewriter, |
474 | MutableArrayRef<Region> regions) { |
475 | LiveMap liveMap; |
476 | do { |
477 | liveMap.resetChanged(); |
478 | |
479 | for (Region ®ion : regions) |
480 | propagateLiveness(region, liveMap); |
481 | } while (liveMap.hasChanged()); |
482 | |
483 | return deleteDeadness(rewriter, regions, liveMap); |
484 | } |
485 | |
486 | //===----------------------------------------------------------------------===// |
487 | // Block Merging |
488 | //===----------------------------------------------------------------------===// |
489 | |
490 | //===----------------------------------------------------------------------===// |
491 | // BlockEquivalenceData |
492 | //===----------------------------------------------------------------------===// |
493 | |
namespace {
/// This class contains the information for comparing the equivalencies of two
/// blocks. Blocks are considered equivalent if they contain the same operations
/// in the same order. The only allowed divergence is for operands that come
/// from sources outside of the parent block, i.e. the uses of values produced
/// within the block must be equivalent.
/// e.g.,
/// Equivalent:
///  ^bb1(%arg0: i32)
///    return %arg0, %foo : i32, i32
///  ^bb2(%arg1: i32)
///    return %arg1, %bar : i32, i32
/// Not Equivalent:
///  ^bb1(%arg0: i32)
///    return %foo, %arg0 : i32, i32
///  ^bb2(%arg1: i32)
///    return %arg1, %bar : i32, i32
struct BlockEquivalenceData {
  BlockEquivalenceData(Block *block);

  /// Return the order index for the given value that is within the block of
  /// this data.
  unsigned getOrderOf(Value value) const;

  /// The block this data refers to.
  Block *block;
  /// A hash value for this block. Computed structurally (ignoring operands,
  /// results, and locations) so candidate blocks can be bucketed cheaply.
  llvm::hash_code hash;
  /// A map of result producing operations to their relative orders within this
  /// block. The order of an operation is the number of defined values that are
  /// produced within the block before this operation.
  DenseMap<Operation *, unsigned> opOrderIndex;
};
} // namespace
528 | |
/// Build the per-block equivalence summary: assign each result-producing op a
/// base order index and fold every op's structural hash into `hash`.
BlockEquivalenceData::BlockEquivalenceData(Block *block)
    : block(block), hash(0) {
  // Order indices start after the block arguments, which occupy [0, numArgs).
  unsigned orderIt = block->getNumArguments();
  for (Operation &op : *block) {
    if (unsigned numResults = op.getNumResults()) {
      opOrderIndex.try_emplace(&op, orderIt);
      orderIt += numResults;
    }
    // Hash ignores operands/results/locations so blocks differing only in
    // those can still land in the same bucket.
    auto opHash = OperationEquivalence::computeHash(
        &op, OperationEquivalence::ignoreHashValue,
        OperationEquivalence::ignoreHashValue,
        OperationEquivalence::IgnoreLocations);
    hash = llvm::hash_combine(hash, opHash);
  }
}
544 | |
545 | unsigned BlockEquivalenceData::getOrderOf(Value value) const { |
546 | assert(value.getParentBlock() == block && "expected value of this block"); |
547 | |
548 | // Arguments use the argument number as the order index. |
549 | if (BlockArgument arg = dyn_cast<BlockArgument>(Val&: value)) |
550 | return arg.getArgNumber(); |
551 | |
552 | // Otherwise, the result order is offset from the parent op's order. |
553 | OpResult result = cast<OpResult>(Val&: value); |
554 | auto opOrderIt = opOrderIndex.find(Val: result.getDefiningOp()); |
555 | assert(opOrderIt != opOrderIndex.end() && "expected op to have an order"); |
556 | return opOrderIt->second + result.getResultNumber(); |
557 | } |
558 | |
559 | //===----------------------------------------------------------------------===// |
560 | // BlockMergeCluster |
561 | //===----------------------------------------------------------------------===// |
562 | |
namespace {
/// This class represents a cluster of blocks to be merged together.
class BlockMergeCluster {
public:
  BlockMergeCluster(BlockEquivalenceData &&leaderData)
      : leaderData(std::move(leaderData)) {}

  /// Attempt to add the given block to this cluster. Returns success if the
  /// block was merged, failure otherwise.
  LogicalResult addToCluster(BlockEquivalenceData &blockData);

  /// Try to merge all of the blocks within this cluster into the leader block.
  LogicalResult merge(RewriterBase &rewriter);

private:
  /// The equivalence data for the leader of the cluster.
  BlockEquivalenceData leaderData;

  /// The set of blocks that can be merged into the leader.
  llvm::SmallSetVector<Block *, 1> blocksToMerge;

  /// A set of operand+index pairs that correspond to operands that need to be
  /// replaced by arguments when the cluster gets merged. Kept as an ordered
  /// std::set so merged arguments are produced in a deterministic order.
  std::set<std::pair<int, int>> operandsToMerge;
};
} // namespace
589 | |
/// Check whether `blockData`'s block is mergeable with the cluster leader and,
/// if so, record it (plus the operand positions that must become new block
/// arguments). Fails without modifying the cluster otherwise.
LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
  // Cheap rejections first: structural hash, then argument types.
  if (leaderData.hash != blockData.hash)
    return failure();
  Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
  if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
    return failure();

  // A set of operands that mismatch between the leader and the new block.
  SmallVector<std::pair<int, int>, 8> mismatchedOperands;
  auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
  auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
  for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
    // Check that the operations are equivalent.
    if (!OperationEquivalence::isEquivalentTo(
            &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
            /*markEquivalent=*/nullptr,
            OperationEquivalence::Flags::IgnoreLocations))
      return failure();

    // Compare the operands of the two operations. If the operand is within
    // the block, it must refer to the same operation.
    auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
    for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
      Value lhsOperand = lhsOperands[operand];
      Value rhsOperand = rhsOperands[operand];
      if (lhsOperand == rhsOperand)
        continue;
      // Check that the types of the operands match.
      if (lhsOperand.getType() != rhsOperand.getType())
        return failure();

      // Check that these uses are both external, or both internal.
      bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
      bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
      if (lhsIsInBlock != rhsIsInBlock)
        return failure();
      // Let the operands differ if they are defined in a different block. These
      // will become new arguments if the blocks get merged.
      if (!lhsIsInBlock) {

        // Check whether the operands aren't the result of an immediate
        // predecessors terminator. In that case we are not able to use it as a
        // successor operand when branching to the merged block as it does not
        // dominate its producing operation.
        auto isValidSuccessorArg = [](Block *block, Value operand) {
          if (operand.getDefiningOp() !=
              operand.getParentBlock()->getTerminator())
            return true;
          return !llvm::is_contained(block->getPredecessors(),
                                     operand.getParentBlock());
        };

        if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
            !isValidSuccessorArg(mergeBlock, rhsOperand))
          return failure();

        mismatchedOperands.emplace_back(opI, operand);
        continue;
      }

      // Otherwise, these operands must have the same logical order within the
      // parent block.
      if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
        return failure();
    }

    // If the lhs or rhs has external uses, the blocks cannot be merged as the
    // merged version of this operation will not be either the lhs or rhs
    // alone (thus semantically incorrect), but some mix dependending on which
    // block preceeded this.
    // TODO allow merging of operations when one block does not dominate the
    // other
    if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
        lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
      return failure();
    }
  }
  // Make sure that the block sizes are equivalent.
  if (lhsIt != lhsE || rhsIt != rhsE)
    return failure();

  // If we get here, the blocks are equivalent and can be merged.
  operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
  blocksToMerge.insert(blockData.block);
  return success();
}
676 | |
677 | /// Returns true if the predecessor terminators of the given block can not have |
678 | /// their operands updated. |
679 | static bool ableToUpdatePredOperands(Block *block) { |
680 | for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { |
681 | if (!isa<BranchOpInterface>(Val: (*it)->getTerminator())) |
682 | return false; |
683 | } |
684 | return true; |
685 | } |
686 | |
/// Prunes the redundant entries of a list of new block arguments. E.g., if we
/// are passing an argument list like [x, y, z, x] this would return [x, y, z]
/// and it would update `block` (to whom the arguments are passed) accordingly.
/// The new arguments were appended at the back of the block, hence we need to
/// know how many `numOldArguments` existed before them in order to correctly
/// locate and replace the new arguments in the block.
///
/// `newArguments` holds one argument list per merged block (list 0 belongs to
/// the leader); an argument position can only be pruned when it is redundant
/// in *every* list.
static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
    const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
    RewriterBase &rewriter, unsigned numOldArguments, Block *block) {

  SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
      newArguments.size(), SmallVector<Value, 8>());

  if (newArguments.empty())
    return newArguments;

  // `newArguments` is a 2D array of size `numLists` x `numArgs`
  unsigned numLists = newArguments.size();
  unsigned numArgs = newArguments[0].size();

  // Map that for each arg index contains the index that we can use in place of
  // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
  // idxToReplacement[3] = 0
  llvm::DenseMap<unsigned, unsigned> idxToReplacement;

  // This is a useful data structure to track the first appearance of a Value
  // on a given list of arguments
  DenseMap<Value, unsigned> firstValueToIdx;
  for (unsigned j = 0; j < numArgs; ++j) {
    Value newArg = newArguments[0][j];
    // try_emplace keeps the *first* occurrence if the value repeats.
    firstValueToIdx.try_emplace(newArg, j);
  }

  // Go through the first list of arguments (list 0).
  for (unsigned j = 0; j < numArgs; ++j) {
    // Look back to see if there are possible redundancies in list 0. Please
    // note that we are using a map to annotate when an argument was seen first
    // to avoid a O(N^2) algorithm. This has the drawback that if we have two
    // lists like:
    // list0: [%a, %a, %a]
    // list1: [%c, %b, %b]
    // We cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot
    // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
    // the number of arguments can be potentially unbounded we cannot afford a
    // O(N^2) algorithm (to search to all the possible pairs) and we need to
    // accept the trade-off.
    unsigned k = firstValueToIdx[newArguments[0][j]];
    if (k == j)
      continue;

    bool shouldReplaceJ = true;
    unsigned replacement = k;
    // If a possible redundancy is found, then scan the other lists: we
    // can prune the arguments if and only if they are redundant in every
    // list.
    for (unsigned i = 1; i < numLists; ++i)
      shouldReplaceJ =
          shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
    // Save the replacement.
    if (shouldReplaceJ)
      idxToReplacement[j] = replacement;
  }

  // Populate the pruned argument list, skipping the columns proven redundant.
  for (unsigned i = 0; i < numLists; ++i)
    for (unsigned j = 0; j < numArgs; ++j)
      if (!idxToReplacement.contains(j))
        newArgumentsPruned[i].push_back(newArguments[i][j]);

  // Replace the block's redundant arguments. Keys in `idxToReplacement` are
  // indices into the *new* argument lists, so they are offset by
  // `numOldArguments` when addressing the block's argument list.
  SmallVector<unsigned> toErase;
  for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
    if (idxToReplacement.contains(idx)) {
      Value oldArg = block->getArgument(numOldArguments + idx);
      Value newArg =
          block->getArgument(numOldArguments + idxToReplacement[idx]);
      rewriter.replaceAllUsesWith(oldArg, newArg);
      toErase.push_back(numOldArguments + idx);
    }
  }

  // Erase the block's redundant arguments; iterate in reverse so that earlier
  // indices stay valid while later ones are removed.
  for (unsigned idxToErase : llvm::reverse(toErase))
    block->eraseArgument(idxToErase);
  return newArgumentsPruned;
}
773 | |
/// Merge the blocks of this cluster into the cluster's leader block.
/// Differing operands between the blocks are turned into new leader-block
/// arguments, predecessors are rewritten to forward the corresponding values,
/// and the merged blocks are erased. Fails (leaving the IR untouched) when
/// there is nothing to merge or predecessors cannot be updated.
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
  // Don't consider clusters that don't have blocks to merge.
  if (blocksToMerge.empty())
    return failure();

  Block *leaderBlock = leaderData.block;
  if (!operandsToMerge.empty()) {
    // If the cluster has operands to merge, verify that the predecessor
    // terminators of each of the blocks can have their successor operands
    // updated.
    // TODO: We could try and sub-partition this cluster if only some blocks
    // cause the mismatch.
    if (!ableToUpdatePredOperands(leaderBlock) ||
        !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
      return failure();

    // Collect the iterators for each of the blocks to merge. We will walk all
    // of the iterators at once to avoid operand index invalidation.
    SmallVector<Block::iterator, 2> blockIterators;
    blockIterators.reserve(blocksToMerge.size() + 1);
    blockIterators.push_back(leaderBlock->begin());
    for (Block *mergeBlock : blocksToMerge)
      blockIterators.push_back(mergeBlock->begin());

    // Update each of the predecessor terminators with the new arguments.
    // `newArguments[i][a]` holds the value block i forwards for new argument a
    // (index 0 is the leader).
    SmallVector<SmallVector<Value, 8>, 2> newArguments(
        1 + blocksToMerge.size(),
        SmallVector<Value, 8>(operandsToMerge.size()));
    unsigned curOpIndex = 0;
    unsigned numOldArguments = leaderBlock->getNumArguments();
    for (const auto &it : llvm::enumerate(operandsToMerge)) {
      // `operandsToMerge` stores absolute operation indices; advance each
      // iterator only by the delta from the previous entry.
      unsigned nextOpOffset = it.value().first - curOpIndex;
      curOpIndex = it.value().first;

      // Process the operand for each of the block iterators.
      for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
        Block::iterator &blockIter = blockIterators[i];
        std::advance(blockIter, nextOpOffset);
        auto &operand = blockIter->getOpOperand(it.value().second);
        newArguments[i][it.index()] = operand.get();

        // Update the operand and insert an argument if this is the leader.
        if (i == 0) {
          Value operandVal = operand.get();
          operand.set(leaderBlock->addArgument(operandVal.getType(),
                                               operandVal.getLoc()));
        }
      }
    }

    // Prune redundant arguments and update the leader block argument list.
    newArguments = pruneRedundantArguments(newArguments, rewriter,
                                           numOldArguments, leaderBlock);

    // Update the predecessors for each of the blocks: every predecessor
    // branch now forwards the values its successor used to compute inline.
    auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
      for (auto predIt = block->pred_begin(), predE = block->pred_end();
           predIt != predE; ++predIt) {
        auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
        unsigned succIndex = predIt.getSuccessorIndex();
        branch.getSuccessorOperands(succIndex).append(
            newArguments[clusterIndex]);
      }
    };
    updatePredecessors(leaderBlock, /*clusterIndex=*/0);
    for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
      updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
  }

  // Replace all uses of the merged blocks with the leader and erase them.
  for (Block *block : blocksToMerge) {
    block->replaceAllUsesWith(leaderBlock);
    rewriter.eraseBlock(block);
  }
  return success();
}
850 | |
851 | /// Identify identical blocks within the given region and merge them, inserting |
852 | /// new block arguments as necessary. Returns success if any blocks were merged, |
853 | /// failure otherwise. |
854 | static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, |
855 | Region ®ion) { |
856 | if (region.empty() || llvm::hasSingleElement(C&: region)) |
857 | return failure(); |
858 | |
859 | // Identify sets of blocks, other than the entry block, that branch to the |
860 | // same successors. We will use these groups to create clusters of equivalent |
861 | // blocks. |
862 | DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors; |
863 | for (Block &block : llvm::drop_begin(RangeOrContainer&: region, N: 1)) |
864 | matchingSuccessors[block.getSuccessors()].push_back(Elt: &block); |
865 | |
866 | bool mergedAnyBlocks = false; |
867 | for (ArrayRef<Block *> blocks : llvm::make_second_range(c&: matchingSuccessors)) { |
868 | if (blocks.size() == 1) |
869 | continue; |
870 | |
871 | SmallVector<BlockMergeCluster, 1> clusters; |
872 | for (Block *block : blocks) { |
873 | BlockEquivalenceData data(block); |
874 | |
875 | // Don't allow merging if this block has any regions. |
876 | // TODO: Add support for regions if necessary. |
877 | bool hasNonEmptyRegion = llvm::any_of(Range&: *block, P: [](Operation &op) { |
878 | return llvm::any_of(Range: op.getRegions(), |
879 | P: [](Region ®ion) { return !region.empty(); }); |
880 | }); |
881 | if (hasNonEmptyRegion) |
882 | continue; |
883 | |
884 | // Don't allow merging if this block's arguments are used outside of the |
885 | // original block. |
886 | bool argHasExternalUsers = llvm::any_of( |
887 | Range: block->getArguments(), P: [block](mlir::BlockArgument &arg) { |
888 | return arg.isUsedOutsideOfBlock(block); |
889 | }); |
890 | if (argHasExternalUsers) |
891 | continue; |
892 | |
893 | // Try to add this block to an existing cluster. |
894 | bool addedToCluster = false; |
895 | for (auto &cluster : clusters) |
896 | if ((addedToCluster = succeeded(Result: cluster.addToCluster(blockData&: data)))) |
897 | break; |
898 | if (!addedToCluster) |
899 | clusters.emplace_back(Args: std::move(data)); |
900 | } |
901 | for (auto &cluster : clusters) |
902 | mergedAnyBlocks |= succeeded(Result: cluster.merge(rewriter)); |
903 | } |
904 | |
905 | return success(IsSuccess: mergedAnyBlocks); |
906 | } |
907 | |
908 | /// Identify identical blocks within the given regions and merge them, inserting |
909 | /// new block arguments as necessary. |
910 | static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, |
911 | MutableArrayRef<Region> regions) { |
912 | llvm::SmallSetVector<Region *, 1> worklist; |
913 | for (auto ®ion : regions) |
914 | worklist.insert(X: ®ion); |
915 | bool anyChanged = false; |
916 | while (!worklist.empty()) { |
917 | Region *region = worklist.pop_back_val(); |
918 | if (succeeded(Result: mergeIdenticalBlocks(rewriter, region&: *region))) { |
919 | worklist.insert(X: region); |
920 | anyChanged = true; |
921 | } |
922 | |
923 | // Add any nested regions to the worklist. |
924 | for (Block &block : *region) |
925 | for (auto &op : block) |
926 | for (auto &nestedRegion : op.getRegions()) |
927 | worklist.insert(X: &nestedRegion); |
928 | } |
929 | |
930 | return success(IsSuccess: anyChanged); |
931 | } |
932 | |
933 | /// If a block's argument is always the same across different invocations, then |
934 | /// drop the argument and use the value directly inside the block |
935 | static LogicalResult dropRedundantArguments(RewriterBase &rewriter, |
936 | Block &block) { |
937 | SmallVector<size_t> argsToErase; |
938 | |
939 | // Go through the arguments of the block. |
940 | for (auto [argIdx, blockOperand] : llvm::enumerate(First: block.getArguments())) { |
941 | bool sameArg = true; |
942 | Value commonValue; |
943 | |
944 | // Go through the block predecessor and flag if they pass to the block |
945 | // different values for the same argument. |
946 | for (Block::pred_iterator predIt = block.pred_begin(), |
947 | predE = block.pred_end(); |
948 | predIt != predE; ++predIt) { |
949 | auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator()); |
950 | if (!branch) { |
951 | sameArg = false; |
952 | break; |
953 | } |
954 | unsigned succIndex = predIt.getSuccessorIndex(); |
955 | SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex); |
956 | auto branchOperands = succOperands.getForwardedOperands(); |
957 | if (!commonValue) { |
958 | commonValue = branchOperands[argIdx]; |
959 | continue; |
960 | } |
961 | if (branchOperands[argIdx] != commonValue) { |
962 | sameArg = false; |
963 | break; |
964 | } |
965 | } |
966 | |
967 | // If they are passing the same value, drop the argument. |
968 | if (commonValue && sameArg) { |
969 | argsToErase.push_back(Elt: argIdx); |
970 | |
971 | // Remove the argument from the block. |
972 | rewriter.replaceAllUsesWith(from: blockOperand, to: commonValue); |
973 | } |
974 | } |
975 | |
976 | // Remove the arguments. |
977 | for (size_t argIdx : llvm::reverse(C&: argsToErase)) { |
978 | block.eraseArgument(index: argIdx); |
979 | |
980 | // Remove the argument from the branch ops. |
981 | for (auto predIt = block.pred_begin(), predE = block.pred_end(); |
982 | predIt != predE; ++predIt) { |
983 | auto branch = cast<BranchOpInterface>((*predIt)->getTerminator()); |
984 | unsigned succIndex = predIt.getSuccessorIndex(); |
985 | SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex); |
986 | succOperands.erase(subStart: argIdx); |
987 | } |
988 | } |
989 | return success(IsSuccess: !argsToErase.empty()); |
990 | } |
991 | |
992 | /// This optimization drops redundant argument to blocks. I.e., if a given |
993 | /// argument to a block receives the same value from each of the block |
994 | /// predecessors, we can remove the argument from the block and use directly the |
995 | /// original value. This is a simple example: |
996 | /// |
997 | /// %cond = llvm.call @rand() : () -> i1 |
998 | /// %val0 = llvm.mlir.constant(1 : i64) : i64 |
999 | /// %val1 = llvm.mlir.constant(2 : i64) : i64 |
1000 | /// %val2 = llvm.mlir.constant(3 : i64) : i64 |
1001 | /// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2 |
1002 | /// : i64) |
1003 | /// |
1004 | /// ^bb1(%arg0 : i64, %arg1 : i64): |
1005 | /// llvm.call @foo(%arg0, %arg1) |
1006 | /// |
1007 | /// The previous IR can be rewritten as: |
1008 | /// %cond = llvm.call @rand() : () -> i1 |
1009 | /// %val0 = llvm.mlir.constant(1 : i64) : i64 |
1010 | /// %val1 = llvm.mlir.constant(2 : i64) : i64 |
1011 | /// %val2 = llvm.mlir.constant(3 : i64) : i64 |
1012 | /// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64) |
1013 | /// |
1014 | /// ^bb1(%arg0 : i64): |
1015 | /// llvm.call @foo(%val0, %arg0) |
1016 | /// |
1017 | static LogicalResult dropRedundantArguments(RewriterBase &rewriter, |
1018 | MutableArrayRef<Region> regions) { |
1019 | llvm::SmallSetVector<Region *, 1> worklist; |
1020 | for (Region ®ion : regions) |
1021 | worklist.insert(X: ®ion); |
1022 | bool anyChanged = false; |
1023 | while (!worklist.empty()) { |
1024 | Region *region = worklist.pop_back_val(); |
1025 | |
1026 | // Add any nested regions to the worklist. |
1027 | for (Block &block : *region) { |
1028 | anyChanged = |
1029 | succeeded(Result: dropRedundantArguments(rewriter, block)) || anyChanged; |
1030 | |
1031 | for (Operation &op : block) |
1032 | for (Region &nestedRegion : op.getRegions()) |
1033 | worklist.insert(X: &nestedRegion); |
1034 | } |
1035 | } |
1036 | return success(IsSuccess: anyChanged); |
1037 | } |
1038 | |
1039 | //===----------------------------------------------------------------------===// |
1040 | // Region Simplification |
1041 | //===----------------------------------------------------------------------===// |
1042 | |
1043 | /// Run a set of structural simplifications over the given regions. This |
1044 | /// includes transformations like unreachable block elimination, dead argument |
1045 | /// elimination, as well as some other DCE. This function returns success if any |
1046 | /// of the regions were simplified, failure otherwise. |
1047 | LogicalResult mlir::simplifyRegions(RewriterBase &rewriter, |
1048 | MutableArrayRef<Region> regions, |
1049 | bool mergeBlocks) { |
1050 | bool eliminatedBlocks = succeeded(Result: eraseUnreachableBlocks(rewriter, regions)); |
1051 | bool eliminatedOpsOrArgs = succeeded(Result: runRegionDCE(rewriter, regions)); |
1052 | bool mergedIdenticalBlocks = false; |
1053 | bool droppedRedundantArguments = false; |
1054 | if (mergeBlocks) { |
1055 | mergedIdenticalBlocks = succeeded(Result: mergeIdenticalBlocks(rewriter, regions)); |
1056 | droppedRedundantArguments = |
1057 | succeeded(Result: dropRedundantArguments(rewriter, regions)); |
1058 | } |
1059 | return success(IsSuccess: eliminatedBlocks || eliminatedOpsOrArgs || |
1060 | mergedIdenticalBlocks || droppedRedundantArguments); |
1061 | } |
1062 | |
1063 | //===---------------------------------------------------------------------===// |
1064 | // Move operation dependencies |
1065 | //===---------------------------------------------------------------------===// |
1066 | |
1067 | LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, |
1068 | Operation *op, |
1069 | Operation *insertionPoint, |
1070 | DominanceInfo &dominance) { |
1071 | // Currently unsupported case where the op and insertion point are |
1072 | // in different basic blocks. |
1073 | if (op->getBlock() != insertionPoint->getBlock()) { |
1074 | return rewriter.notifyMatchFailure( |
1075 | arg&: op, msg: "unsupported case where operation and insertion point are not in " |
1076 | "the same basic block"); |
1077 | } |
1078 | // If `insertionPoint` does not dominate `op`, do nothing |
1079 | if (!dominance.properlyDominates(a: insertionPoint, b: op)) { |
1080 | return rewriter.notifyMatchFailure(arg&: op, |
1081 | msg: "insertion point does not dominate op"); |
1082 | } |
1083 | |
1084 | // Find the backward slice of operation for each `Value` the operation |
1085 | // depends on. Prune the slice to only include operations not already |
1086 | // dominated by the `insertionPoint` |
1087 | BackwardSliceOptions options; |
1088 | options.inclusive = false; |
1089 | options.omitUsesFromAbove = false; |
1090 | // Since current support is to only move within a same basic block, |
1091 | // the slices dont need to look past block arguments. |
1092 | options.omitBlockArguments = true; |
1093 | options.filter = [&](Operation *sliceBoundaryOp) { |
1094 | return !dominance.properlyDominates(a: sliceBoundaryOp, b: insertionPoint); |
1095 | }; |
1096 | llvm::SetVector<Operation *> slice; |
1097 | LogicalResult result = getBackwardSlice(op, backwardSlice: &slice, options); |
1098 | assert(result.succeeded() && "expected a backward slice"); |
1099 | (void)result; |
1100 | |
1101 | // If the slice contains `insertionPoint` cannot move the dependencies. |
1102 | if (slice.contains(key: insertionPoint)) { |
1103 | return rewriter.notifyMatchFailure( |
1104 | arg&: op, |
1105 | msg: "cannot move dependencies before operation in backward slice of op"); |
1106 | } |
1107 | |
1108 | // We should move the slice in topological order, but `getBackwardSlice` |
1109 | // already does that. So no need to sort again. |
1110 | for (Operation *op : slice) { |
1111 | rewriter.moveOpBefore(op, existingOp: insertionPoint); |
1112 | } |
1113 | return success(); |
1114 | } |
1115 | |
1116 | LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, |
1117 | Operation *op, |
1118 | Operation *insertionPoint) { |
1119 | DominanceInfo dominance(op); |
1120 | return moveOperationDependencies(rewriter, op, insertionPoint, dominance); |
1121 | } |
1122 | |
1123 | LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, |
1124 | ValueRange values, |
1125 | Operation *insertionPoint, |
1126 | DominanceInfo &dominance) { |
1127 | // Remove the values that already dominate the insertion point. |
1128 | SmallVector<Value> prunedValues; |
1129 | for (auto value : values) { |
1130 | if (dominance.properlyDominates(a: value, b: insertionPoint)) { |
1131 | continue; |
1132 | } |
1133 | // Block arguments are not supported. |
1134 | if (isa<BlockArgument>(Val: value)) { |
1135 | return rewriter.notifyMatchFailure( |
1136 | arg&: insertionPoint, |
1137 | msg: "unsupported case of moving block argument before insertion point"); |
1138 | } |
1139 | // Check for currently unsupported case if the insertion point is in a |
1140 | // different block. |
1141 | if (value.getDefiningOp()->getBlock() != insertionPoint->getBlock()) { |
1142 | return rewriter.notifyMatchFailure( |
1143 | arg&: insertionPoint, |
1144 | msg: "unsupported case of moving definition of value before an insertion " |
1145 | "point in a different basic block"); |
1146 | } |
1147 | prunedValues.push_back(Elt: value); |
1148 | } |
1149 | |
1150 | // Find the backward slice of operation for each `Value` the operation |
1151 | // depends on. Prune the slice to only include operations not already |
1152 | // dominated by the `insertionPoint` |
1153 | BackwardSliceOptions options; |
1154 | options.inclusive = true; |
1155 | options.omitUsesFromAbove = false; |
1156 | // Since current support is to only move within a same basic block, |
1157 | // the slices dont need to look past block arguments. |
1158 | options.omitBlockArguments = true; |
1159 | options.filter = [&](Operation *sliceBoundaryOp) { |
1160 | return !dominance.properlyDominates(a: sliceBoundaryOp, b: insertionPoint); |
1161 | }; |
1162 | llvm::SetVector<Operation *> slice; |
1163 | for (auto value : prunedValues) { |
1164 | LogicalResult result = getBackwardSlice(root: value, backwardSlice: &slice, options); |
1165 | assert(result.succeeded() && "expected a backward slice"); |
1166 | (void)result; |
1167 | } |
1168 | |
1169 | // If the slice contains `insertionPoint` cannot move the dependencies. |
1170 | if (slice.contains(key: insertionPoint)) { |
1171 | return rewriter.notifyMatchFailure( |
1172 | arg&: insertionPoint, |
1173 | msg: "cannot move dependencies before operation in backward slice of op"); |
1174 | } |
1175 | |
1176 | // Sort operations topologically before moving. |
1177 | mlir::topologicalSort(toSort: slice); |
1178 | |
1179 | for (Operation *op : slice) { |
1180 | rewriter.moveOpBefore(op, existingOp: insertionPoint); |
1181 | } |
1182 | return success(); |
1183 | } |
1184 | |
1185 | LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, |
1186 | ValueRange values, |
1187 | Operation *insertionPoint) { |
1188 | DominanceInfo dominance(insertionPoint); |
1189 | return moveValueDefinitions(rewriter, values, insertionPoint, dominance); |
1190 | } |
1191 |
Definitions
- replaceAllUsesInRegionWith
- visitUsedValuesDefinedAbove
- visitUsedValuesDefinedAbove
- getUsedValuesDefinedAbove
- getUsedValuesDefinedAbove
- makeRegionIsolatedFromAbove
- eraseUnreachableBlocks
- LiveMap
- wasProvenLive
- wasProvenLive
- setProvedLive
- setProvedLive
- wasProvenLive
- setProvedLive
- resetChanged
- hasChanged
- isUseSpeciallyKnownDead
- processValue
- propagateTerminatorLiveness
- propagateLiveness
- propagateLiveness
- eraseTerminatorSuccessorOperands
- deleteDeadness
- runRegionDCE
- BlockEquivalenceData
- BlockEquivalenceData
- getOrderOf
- BlockMergeCluster
- BlockMergeCluster
- addToCluster
- ableToUpdatePredOperands
- pruneRedundantArguments
- merge
- mergeIdenticalBlocks
- mergeIdenticalBlocks
- dropRedundantArguments
- dropRedundantArguments
- simplifyRegions
- moveOperationDependencies
- moveOperationDependencies
- moveValueDefinitions
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more