1 | //===- RegionUtils.cpp - Region-related transformation utilities ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Transforms/RegionUtils.h" |
10 | #include "mlir/IR/Block.h" |
11 | #include "mlir/IR/IRMapping.h" |
12 | #include "mlir/IR/Operation.h" |
13 | #include "mlir/IR/PatternMatch.h" |
14 | #include "mlir/IR/RegionGraphTraits.h" |
15 | #include "mlir/IR/Value.h" |
16 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
17 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
18 | #include "mlir/Transforms/TopologicalSortUtils.h" |
19 | |
20 | #include "llvm/ADT/DepthFirstIterator.h" |
21 | #include "llvm/ADT/PostOrderIterator.h" |
22 | #include "llvm/ADT/SmallSet.h" |
23 | |
24 | #include <deque> |
25 | |
26 | using namespace mlir; |
27 | |
28 | void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, |
29 | Region ®ion) { |
30 | for (auto &use : llvm::make_early_inc_range(Range: orig.getUses())) { |
31 | if (region.isAncestor(other: use.getOwner()->getParentRegion())) |
32 | use.set(replacement); |
33 | } |
34 | } |
35 | |
36 | void mlir::visitUsedValuesDefinedAbove( |
37 | Region ®ion, Region &limit, function_ref<void(OpOperand *)> callback) { |
38 | assert(limit.isAncestor(®ion) && |
39 | "expected isolation limit to be an ancestor of the given region" ); |
40 | |
41 | // Collect proper ancestors of `limit` upfront to avoid traversing the region |
42 | // tree for every value. |
43 | SmallPtrSet<Region *, 4> properAncestors; |
44 | for (auto *reg = limit.getParentRegion(); reg != nullptr; |
45 | reg = reg->getParentRegion()) { |
46 | properAncestors.insert(Ptr: reg); |
47 | } |
48 | |
49 | region.walk(callback: [callback, &properAncestors](Operation *op) { |
50 | for (OpOperand &operand : op->getOpOperands()) |
51 | // Callback on values defined in a proper ancestor of region. |
52 | if (properAncestors.count(Ptr: operand.get().getParentRegion())) |
53 | callback(&operand); |
54 | }); |
55 | } |
56 | |
57 | void mlir::visitUsedValuesDefinedAbove( |
58 | MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) { |
59 | for (Region ®ion : regions) |
60 | visitUsedValuesDefinedAbove(region, limit&: region, callback); |
61 | } |
62 | |
63 | void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, |
64 | SetVector<Value> &values) { |
65 | visitUsedValuesDefinedAbove(region, limit, callback: [&](OpOperand *operand) { |
66 | values.insert(X: operand->get()); |
67 | }); |
68 | } |
69 | |
70 | void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions, |
71 | SetVector<Value> &values) { |
72 | for (Region ®ion : regions) |
73 | getUsedValuesDefinedAbove(region, limit&: region, values); |
74 | } |
75 | |
76 | //===----------------------------------------------------------------------===// |
77 | // Make block isolated from above. |
78 | //===----------------------------------------------------------------------===// |
79 | |
80 | SmallVector<Value> mlir::makeRegionIsolatedFromAbove( |
81 | RewriterBase &rewriter, Region ®ion, |
82 | llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) { |
83 | |
84 | // Get initial list of values used within region but defined above. |
85 | llvm::SetVector<Value> initialCapturedValues; |
86 | mlir::getUsedValuesDefinedAbove(regions: region, values&: initialCapturedValues); |
87 | |
88 | std::deque<Value> worklist(initialCapturedValues.begin(), |
89 | initialCapturedValues.end()); |
90 | llvm::DenseSet<Value> visited; |
91 | llvm::DenseSet<Operation *> visitedOps; |
92 | |
93 | llvm::SetVector<Value> finalCapturedValues; |
94 | SmallVector<Operation *> clonedOperations; |
95 | while (!worklist.empty()) { |
96 | Value currValue = worklist.front(); |
97 | worklist.pop_front(); |
98 | if (visited.count(V: currValue)) |
99 | continue; |
100 | visited.insert(V: currValue); |
101 | |
102 | Operation *definingOp = currValue.getDefiningOp(); |
103 | if (!definingOp || visitedOps.count(V: definingOp)) { |
104 | finalCapturedValues.insert(X: currValue); |
105 | continue; |
106 | } |
107 | visitedOps.insert(V: definingOp); |
108 | |
109 | if (!cloneOperationIntoRegion(definingOp)) { |
110 | // Defining operation isnt cloned, so add the current value to final |
111 | // captured values list. |
112 | finalCapturedValues.insert(X: currValue); |
113 | continue; |
114 | } |
115 | |
116 | // Add all operands of the operation to the worklist and mark the op as to |
117 | // be cloned. |
118 | for (Value operand : definingOp->getOperands()) { |
119 | if (visited.count(V: operand)) |
120 | continue; |
121 | worklist.push_back(x: operand); |
122 | } |
123 | clonedOperations.push_back(Elt: definingOp); |
124 | } |
125 | |
126 | // The operations to be cloned need to be ordered in topological order |
127 | // so that they can be cloned into the region without violating use-def |
128 | // chains. |
129 | mlir::computeTopologicalSorting(ops: clonedOperations); |
130 | |
131 | OpBuilder::InsertionGuard g(rewriter); |
132 | // Collect types of existing block |
133 | Block *entryBlock = ®ion.front(); |
134 | SmallVector<Type> newArgTypes = |
135 | llvm::to_vector(Range: entryBlock->getArgumentTypes()); |
136 | SmallVector<Location> newArgLocs = llvm::to_vector(Range: llvm::map_range( |
137 | C: entryBlock->getArguments(), F: [](BlockArgument b) { return b.getLoc(); })); |
138 | |
139 | // Append the types of the captured values. |
140 | for (auto value : finalCapturedValues) { |
141 | newArgTypes.push_back(Elt: value.getType()); |
142 | newArgLocs.push_back(Elt: value.getLoc()); |
143 | } |
144 | |
145 | // Create a new entry block. |
146 | Block *newEntryBlock = |
147 | rewriter.createBlock(parent: ®ion, insertPt: region.begin(), argTypes: newArgTypes, locs: newArgLocs); |
148 | auto newEntryBlockArgs = newEntryBlock->getArguments(); |
149 | |
150 | // Create a mapping between the captured values and the new arguments added. |
151 | IRMapping map; |
152 | auto replaceIfFn = [&](OpOperand &use) { |
153 | return use.getOwner()->getBlock()->getParent() == ®ion; |
154 | }; |
155 | for (auto [arg, capturedVal] : |
156 | llvm::zip(t: newEntryBlockArgs.take_back(N: finalCapturedValues.size()), |
157 | u&: finalCapturedValues)) { |
158 | map.map(from: capturedVal, to: arg); |
159 | rewriter.replaceUsesWithIf(from: capturedVal, to: arg, functor: replaceIfFn); |
160 | } |
161 | rewriter.setInsertionPointToStart(newEntryBlock); |
162 | for (auto *clonedOp : clonedOperations) { |
163 | Operation *newOp = rewriter.clone(op&: *clonedOp, mapper&: map); |
164 | rewriter.replaceOpUsesWithIf(from: clonedOp, to: newOp->getResults(), functor: replaceIfFn); |
165 | } |
166 | rewriter.mergeBlocks( |
167 | source: entryBlock, dest: newEntryBlock, |
168 | argValues: newEntryBlock->getArguments().take_front(N: entryBlock->getNumArguments())); |
169 | return llvm::to_vector(Range&: finalCapturedValues); |
170 | } |
171 | |
172 | //===----------------------------------------------------------------------===// |
173 | // Unreachable Block Elimination |
174 | //===----------------------------------------------------------------------===// |
175 | |
176 | /// Erase the unreachable blocks within the provided regions. Returns success |
177 | /// if any blocks were erased, failure otherwise. |
178 | // TODO: We could likely merge this with the DCE algorithm below. |
179 | LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter, |
180 | MutableArrayRef<Region> regions) { |
181 | // Set of blocks found to be reachable within a given region. |
182 | llvm::df_iterator_default_set<Block *, 16> reachable; |
183 | // If any blocks were found to be dead. |
184 | bool erasedDeadBlocks = false; |
185 | |
186 | SmallVector<Region *, 1> worklist; |
187 | worklist.reserve(N: regions.size()); |
188 | for (Region ®ion : regions) |
189 | worklist.push_back(Elt: ®ion); |
190 | while (!worklist.empty()) { |
191 | Region *region = worklist.pop_back_val(); |
192 | if (region->empty()) |
193 | continue; |
194 | |
195 | // If this is a single block region, just collect the nested regions. |
196 | if (std::next(x: region->begin()) == region->end()) { |
197 | for (Operation &op : region->front()) |
198 | for (Region ®ion : op.getRegions()) |
199 | worklist.push_back(Elt: ®ion); |
200 | continue; |
201 | } |
202 | |
203 | // Mark all reachable blocks. |
204 | reachable.clear(); |
205 | for (Block *block : depth_first_ext(G: ®ion->front(), S&: reachable)) |
206 | (void)block /* Mark all reachable blocks */; |
207 | |
208 | // Collect all of the dead blocks and push the live regions onto the |
209 | // worklist. |
210 | for (Block &block : llvm::make_early_inc_range(Range&: *region)) { |
211 | if (!reachable.count(Ptr: &block)) { |
212 | block.dropAllDefinedValueUses(); |
213 | rewriter.eraseBlock(block: &block); |
214 | erasedDeadBlocks = true; |
215 | continue; |
216 | } |
217 | |
218 | // Walk any regions within this block. |
219 | for (Operation &op : block) |
220 | for (Region ®ion : op.getRegions()) |
221 | worklist.push_back(Elt: ®ion); |
222 | } |
223 | } |
224 | |
225 | return success(isSuccess: erasedDeadBlocks); |
226 | } |
227 | |
228 | //===----------------------------------------------------------------------===// |
229 | // Dead Code Elimination |
230 | //===----------------------------------------------------------------------===// |
231 | |
232 | namespace { |
233 | /// Data structure used to track which values have already been proved live. |
234 | /// |
235 | /// Because Operation's can have multiple results, this data structure tracks |
236 | /// liveness for both Value's and Operation's to avoid having to look through |
237 | /// all Operation results when analyzing a use. |
238 | /// |
239 | /// This data structure essentially tracks the dataflow lattice. |
240 | /// The set of values/ops proved live increases monotonically to a fixed-point. |
241 | class LiveMap { |
242 | public: |
243 | /// Value methods. |
244 | bool wasProvenLive(Value value) { |
245 | // TODO: For results that are removable, e.g. for region based control flow, |
246 | // we could allow for these values to be tracked independently. |
247 | if (OpResult result = dyn_cast<OpResult>(Val&: value)) |
248 | return wasProvenLive(op: result.getOwner()); |
249 | return wasProvenLive(arg: cast<BlockArgument>(Val&: value)); |
250 | } |
251 | bool wasProvenLive(BlockArgument arg) { return liveValues.count(V: arg); } |
252 | void setProvedLive(Value value) { |
253 | // TODO: For results that are removable, e.g. for region based control flow, |
254 | // we could allow for these values to be tracked independently. |
255 | if (OpResult result = dyn_cast<OpResult>(Val&: value)) |
256 | return setProvedLive(result.getOwner()); |
257 | setProvedLive(cast<BlockArgument>(Val&: value)); |
258 | } |
259 | void setProvedLive(BlockArgument arg) { |
260 | changed |= liveValues.insert(V: arg).second; |
261 | } |
262 | |
263 | /// Operation methods. |
264 | bool wasProvenLive(Operation *op) { return liveOps.count(V: op); } |
265 | void setProvedLive(Operation *op) { changed |= liveOps.insert(V: op).second; } |
266 | |
267 | /// Methods for tracking if we have reached a fixed-point. |
268 | void resetChanged() { changed = false; } |
269 | bool hasChanged() { return changed; } |
270 | |
271 | private: |
272 | bool changed = false; |
273 | DenseSet<Value> liveValues; |
274 | DenseSet<Operation *> liveOps; |
275 | }; |
276 | } // namespace |
277 | |
278 | static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) { |
279 | Operation *owner = use.getOwner(); |
280 | unsigned operandIndex = use.getOperandNumber(); |
281 | // This pass generally treats all uses of an op as live if the op itself is |
282 | // considered live. However, for successor operands to terminators we need a |
283 | // finer-grained notion where we deduce liveness for operands individually. |
284 | // The reason for this is easiest to think about in terms of a classical phi |
285 | // node based SSA IR, where each successor operand is really an operand to a |
286 | // *separate* phi node, rather than all operands to the branch itself as with |
287 | // the block argument representation that MLIR uses. |
288 | // |
289 | // And similarly, because each successor operand is really an operand to a phi |
290 | // node, rather than to the terminator op itself, a terminator op can't e.g. |
291 | // "print" the value of a successor operand. |
292 | if (owner->hasTrait<OpTrait::IsTerminator>()) { |
293 | if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner)) |
294 | if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex)) |
295 | return !liveMap.wasProvenLive(*arg); |
296 | return false; |
297 | } |
298 | return false; |
299 | } |
300 | |
301 | static void processValue(Value value, LiveMap &liveMap) { |
302 | bool provedLive = llvm::any_of(Range: value.getUses(), P: [&](OpOperand &use) { |
303 | if (isUseSpeciallyKnownDead(use, liveMap)) |
304 | return false; |
305 | return liveMap.wasProvenLive(op: use.getOwner()); |
306 | }); |
307 | if (provedLive) |
308 | liveMap.setProvedLive(value); |
309 | } |
310 | |
311 | static void propagateLiveness(Region ®ion, LiveMap &liveMap); |
312 | |
313 | static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) { |
314 | // Terminators are always live. |
315 | liveMap.setProvedLive(op); |
316 | |
317 | // Check to see if we can reason about the successor operands and mutate them. |
318 | BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op); |
319 | if (!branchInterface) { |
320 | for (Block *successor : op->getSuccessors()) |
321 | for (BlockArgument arg : successor->getArguments()) |
322 | liveMap.setProvedLive(arg); |
323 | return; |
324 | } |
325 | |
326 | // If we can't reason about the operand to a successor, conservatively mark |
327 | // it as live. |
328 | for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { |
329 | SuccessorOperands successorOperands = |
330 | branchInterface.getSuccessorOperands(i); |
331 | for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount(); |
332 | opI != opE; ++opI) |
333 | liveMap.setProvedLive(op->getSuccessor(index: i)->getArgument(i: opI)); |
334 | } |
335 | } |
336 | |
337 | static void propagateLiveness(Operation *op, LiveMap &liveMap) { |
338 | // Recurse on any regions the op has. |
339 | for (Region ®ion : op->getRegions()) |
340 | propagateLiveness(region, liveMap); |
341 | |
342 | // Process terminator operations. |
343 | if (op->hasTrait<OpTrait::IsTerminator>()) |
344 | return propagateTerminatorLiveness(op, liveMap); |
345 | |
346 | // Don't reprocess live operations. |
347 | if (liveMap.wasProvenLive(op)) |
348 | return; |
349 | |
350 | // Process the op itself. |
351 | if (!wouldOpBeTriviallyDead(op)) |
352 | return liveMap.setProvedLive(op); |
353 | |
354 | // If the op isn't intrinsically alive, check it's results. |
355 | for (Value value : op->getResults()) |
356 | processValue(value, liveMap); |
357 | } |
358 | |
359 | static void propagateLiveness(Region ®ion, LiveMap &liveMap) { |
360 | if (region.empty()) |
361 | return; |
362 | |
363 | for (Block *block : llvm::post_order(G: ®ion.front())) { |
364 | // We process block arguments after the ops in the block, to promote |
365 | // faster convergence to a fixed point (we try to visit uses before defs). |
366 | for (Operation &op : llvm::reverse(C&: block->getOperations())) |
367 | propagateLiveness(op: &op, liveMap); |
368 | |
369 | // We currently do not remove entry block arguments, so there is no need to |
370 | // track their liveness. |
371 | // TODO: We could track these and enable removing dead operands/arguments |
372 | // from region control flow operations. |
373 | if (block->isEntryBlock()) |
374 | continue; |
375 | |
376 | for (Value value : block->getArguments()) { |
377 | if (!liveMap.wasProvenLive(value)) |
378 | processValue(value, liveMap); |
379 | } |
380 | } |
381 | } |
382 | |
383 | static void eraseTerminatorSuccessorOperands(Operation *terminator, |
384 | LiveMap &liveMap) { |
385 | BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator); |
386 | if (!branchOp) |
387 | return; |
388 | |
389 | for (unsigned succI = 0, succE = terminator->getNumSuccessors(); |
390 | succI < succE; succI++) { |
391 | // Iterating successors in reverse is not strictly needed, since we |
392 | // aren't erasing any successors. But it is slightly more efficient |
393 | // since it will promote later operands of the terminator being erased |
394 | // first, reducing the quadratic-ness. |
395 | unsigned succ = succE - succI - 1; |
396 | SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ); |
397 | Block *successor = terminator->getSuccessor(index: succ); |
398 | |
399 | for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) { |
400 | // Iterating args in reverse is needed for correctness, to avoid |
401 | // shifting later args when earlier args are erased. |
402 | unsigned arg = argE - argI - 1; |
403 | if (!liveMap.wasProvenLive(arg: successor->getArgument(i: arg))) |
404 | succOperands.erase(subStart: arg); |
405 | } |
406 | } |
407 | } |
408 | |
409 | static LogicalResult deleteDeadness(RewriterBase &rewriter, |
410 | MutableArrayRef<Region> regions, |
411 | LiveMap &liveMap) { |
412 | bool erasedAnything = false; |
413 | for (Region ®ion : regions) { |
414 | if (region.empty()) |
415 | continue; |
416 | bool hasSingleBlock = llvm::hasSingleElement(C&: region); |
417 | |
418 | // Delete every operation that is not live. Graph regions may have cycles |
419 | // in the use-def graph, so we must explicitly dropAllUses() from each |
420 | // operation as we erase it. Visiting the operations in post-order |
421 | // guarantees that in SSA CFG regions value uses are removed before defs, |
422 | // which makes dropAllUses() a no-op. |
423 | for (Block *block : llvm::post_order(G: ®ion.front())) { |
424 | if (!hasSingleBlock) |
425 | eraseTerminatorSuccessorOperands(terminator: block->getTerminator(), liveMap); |
426 | for (Operation &childOp : |
427 | llvm::make_early_inc_range(Range: llvm::reverse(C&: block->getOperations()))) { |
428 | if (!liveMap.wasProvenLive(op: &childOp)) { |
429 | erasedAnything = true; |
430 | childOp.dropAllUses(); |
431 | rewriter.eraseOp(op: &childOp); |
432 | } else { |
433 | erasedAnything |= succeeded( |
434 | result: deleteDeadness(rewriter, regions: childOp.getRegions(), liveMap)); |
435 | } |
436 | } |
437 | } |
438 | // Delete block arguments. |
439 | // The entry block has an unknown contract with their enclosing block, so |
440 | // skip it. |
441 | for (Block &block : llvm::drop_begin(RangeOrContainer&: region.getBlocks(), N: 1)) { |
442 | block.eraseArguments( |
443 | shouldEraseFn: [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); }); |
444 | } |
445 | } |
446 | return success(isSuccess: erasedAnything); |
447 | } |
448 | |
449 | // This function performs a simple dead code elimination algorithm over the |
450 | // given regions. |
451 | // |
452 | // The overall goal is to prove that Values are dead, which allows deleting ops |
453 | // and block arguments. |
454 | // |
455 | // This uses an optimistic algorithm that assumes everything is dead until |
456 | // proved otherwise, allowing it to delete recursively dead cycles. |
457 | // |
458 | // This is a simple fixed-point dataflow analysis algorithm on a lattice |
459 | // {Dead,Alive}. Because liveness flows backward, we generally try to |
460 | // iterate everything backward to speed up convergence to the fixed-point. This |
461 | // allows for being able to delete recursively dead cycles of the use-def graph, |
462 | // including block arguments. |
463 | // |
464 | // This function returns success if any operations or arguments were deleted, |
465 | // failure otherwise. |
466 | LogicalResult mlir::runRegionDCE(RewriterBase &rewriter, |
467 | MutableArrayRef<Region> regions) { |
468 | LiveMap liveMap; |
469 | do { |
470 | liveMap.resetChanged(); |
471 | |
472 | for (Region ®ion : regions) |
473 | propagateLiveness(region, liveMap); |
474 | } while (liveMap.hasChanged()); |
475 | |
476 | return deleteDeadness(rewriter, regions, liveMap); |
477 | } |
478 | |
479 | //===----------------------------------------------------------------------===// |
480 | // Block Merging |
481 | //===----------------------------------------------------------------------===// |
482 | |
483 | //===----------------------------------------------------------------------===// |
484 | // BlockEquivalenceData |
485 | |
486 | namespace { |
487 | /// This class contains the information for comparing the equivalencies of two |
488 | /// blocks. Blocks are considered equivalent if they contain the same operations |
489 | /// in the same order. The only allowed divergence is for operands that come |
490 | /// from sources outside of the parent block, i.e. the uses of values produced |
491 | /// within the block must be equivalent. |
492 | /// e.g., |
493 | /// Equivalent: |
494 | /// ^bb1(%arg0: i32) |
495 | /// return %arg0, %foo : i32, i32 |
496 | /// ^bb2(%arg1: i32) |
497 | /// return %arg1, %bar : i32, i32 |
498 | /// Not Equivalent: |
499 | /// ^bb1(%arg0: i32) |
500 | /// return %foo, %arg0 : i32, i32 |
501 | /// ^bb2(%arg1: i32) |
502 | /// return %arg1, %bar : i32, i32 |
503 | struct BlockEquivalenceData { |
504 | BlockEquivalenceData(Block *block); |
505 | |
506 | /// Return the order index for the given value that is within the block of |
507 | /// this data. |
508 | unsigned getOrderOf(Value value) const; |
509 | |
510 | /// The block this data refers to. |
511 | Block *block; |
512 | /// A hash value for this block. |
513 | llvm::hash_code hash; |
514 | /// A map of result producing operations to their relative orders within this |
515 | /// block. The order of an operation is the number of defined values that are |
516 | /// produced within the block before this operation. |
517 | DenseMap<Operation *, unsigned> opOrderIndex; |
518 | }; |
519 | } // namespace |
520 | |
521 | BlockEquivalenceData::BlockEquivalenceData(Block *block) |
522 | : block(block), hash(0) { |
523 | unsigned orderIt = block->getNumArguments(); |
524 | for (Operation &op : *block) { |
525 | if (unsigned numResults = op.getNumResults()) { |
526 | opOrderIndex.try_emplace(Key: &op, Args&: orderIt); |
527 | orderIt += numResults; |
528 | } |
529 | auto opHash = OperationEquivalence::computeHash( |
530 | op: &op, hashOperands: OperationEquivalence::ignoreHashValue, |
531 | hashResults: OperationEquivalence::ignoreHashValue, |
532 | flags: OperationEquivalence::IgnoreLocations); |
533 | hash = llvm::hash_combine(args: hash, args: opHash); |
534 | } |
535 | } |
536 | |
537 | unsigned BlockEquivalenceData::getOrderOf(Value value) const { |
538 | assert(value.getParentBlock() == block && "expected value of this block" ); |
539 | |
540 | // Arguments use the argument number as the order index. |
541 | if (BlockArgument arg = dyn_cast<BlockArgument>(Val&: value)) |
542 | return arg.getArgNumber(); |
543 | |
544 | // Otherwise, the result order is offset from the parent op's order. |
545 | OpResult result = cast<OpResult>(Val&: value); |
546 | auto opOrderIt = opOrderIndex.find(Val: result.getDefiningOp()); |
547 | assert(opOrderIt != opOrderIndex.end() && "expected op to have an order" ); |
548 | return opOrderIt->second + result.getResultNumber(); |
549 | } |
550 | |
551 | //===----------------------------------------------------------------------===// |
552 | // BlockMergeCluster |
553 | |
554 | namespace { |
555 | /// This class represents a cluster of blocks to be merged together. |
556 | class BlockMergeCluster { |
557 | public: |
558 | BlockMergeCluster(BlockEquivalenceData &&leaderData) |
559 | : leaderData(std::move(leaderData)) {} |
560 | |
561 | /// Attempt to add the given block to this cluster. Returns success if the |
562 | /// block was merged, failure otherwise. |
563 | LogicalResult addToCluster(BlockEquivalenceData &blockData); |
564 | |
565 | /// Try to merge all of the blocks within this cluster into the leader block. |
566 | LogicalResult merge(RewriterBase &rewriter); |
567 | |
568 | private: |
569 | /// The equivalence data for the leader of the cluster. |
570 | BlockEquivalenceData leaderData; |
571 | |
572 | /// The set of blocks that can be merged into the leader. |
573 | llvm::SmallSetVector<Block *, 1> blocksToMerge; |
574 | |
575 | /// A set of operand+index pairs that correspond to operands that need to be |
576 | /// replaced by arguments when the cluster gets merged. |
577 | std::set<std::pair<int, int>> operandsToMerge; |
578 | }; |
579 | } // namespace |
580 | |
581 | LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) { |
582 | if (leaderData.hash != blockData.hash) |
583 | return failure(); |
584 | Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block; |
585 | if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes()) |
586 | return failure(); |
587 | |
588 | // A set of operands that mismatch between the leader and the new block. |
589 | SmallVector<std::pair<int, int>, 8> mismatchedOperands; |
590 | auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end(); |
591 | auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end(); |
592 | for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) { |
593 | // Check that the operations are equivalent. |
594 | if (!OperationEquivalence::isEquivalentTo( |
595 | lhs: &*lhsIt, rhs: &*rhsIt, checkEquivalent: OperationEquivalence::ignoreValueEquivalence, |
596 | /*markEquivalent=*/nullptr, |
597 | flags: OperationEquivalence::Flags::IgnoreLocations)) |
598 | return failure(); |
599 | |
600 | // Compare the operands of the two operations. If the operand is within |
601 | // the block, it must refer to the same operation. |
602 | auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands(); |
603 | for (int operand : llvm::seq<int>(Begin: 0, End: lhsIt->getNumOperands())) { |
604 | Value lhsOperand = lhsOperands[operand]; |
605 | Value rhsOperand = rhsOperands[operand]; |
606 | if (lhsOperand == rhsOperand) |
607 | continue; |
608 | // Check that the types of the operands match. |
609 | if (lhsOperand.getType() != rhsOperand.getType()) |
610 | return failure(); |
611 | |
612 | // Check that these uses are both external, or both internal. |
613 | bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock; |
614 | bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock; |
615 | if (lhsIsInBlock != rhsIsInBlock) |
616 | return failure(); |
617 | // Let the operands differ if they are defined in a different block. These |
618 | // will become new arguments if the blocks get merged. |
619 | if (!lhsIsInBlock) { |
620 | |
621 | // Check whether the operands aren't the result of an immediate |
622 | // predecessors terminator. In that case we are not able to use it as a |
623 | // successor operand when branching to the merged block as it does not |
624 | // dominate its producing operation. |
625 | auto isValidSuccessorArg = [](Block *block, Value operand) { |
626 | if (operand.getDefiningOp() != |
627 | operand.getParentBlock()->getTerminator()) |
628 | return true; |
629 | return !llvm::is_contained(Range: block->getPredecessors(), |
630 | Element: operand.getParentBlock()); |
631 | }; |
632 | |
633 | if (!isValidSuccessorArg(leaderBlock, lhsOperand) || |
634 | !isValidSuccessorArg(mergeBlock, rhsOperand)) |
635 | return failure(); |
636 | |
637 | mismatchedOperands.emplace_back(Args&: opI, Args&: operand); |
638 | continue; |
639 | } |
640 | |
641 | // Otherwise, these operands must have the same logical order within the |
642 | // parent block. |
643 | if (leaderData.getOrderOf(value: lhsOperand) != blockData.getOrderOf(value: rhsOperand)) |
644 | return failure(); |
645 | } |
646 | |
647 | // If the lhs or rhs has external uses, the blocks cannot be merged as the |
648 | // merged version of this operation will not be either the lhs or rhs |
649 | // alone (thus semantically incorrect), but some mix dependending on which |
650 | // block preceeded this. |
651 | // TODO allow merging of operations when one block does not dominate the |
652 | // other |
653 | if (rhsIt->isUsedOutsideOfBlock(block: mergeBlock) || |
654 | lhsIt->isUsedOutsideOfBlock(block: leaderBlock)) { |
655 | return failure(); |
656 | } |
657 | } |
658 | // Make sure that the block sizes are equivalent. |
659 | if (lhsIt != lhsE || rhsIt != rhsE) |
660 | return failure(); |
661 | |
662 | // If we get here, the blocks are equivalent and can be merged. |
663 | operandsToMerge.insert(first: mismatchedOperands.begin(), last: mismatchedOperands.end()); |
664 | blocksToMerge.insert(X: blockData.block); |
665 | return success(); |
666 | } |
667 | |
668 | /// Returns true if the predecessor terminators of the given block can not have |
669 | /// their operands updated. |
670 | static bool ableToUpdatePredOperands(Block *block) { |
671 | for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { |
672 | if (!isa<BranchOpInterface>(Val: (*it)->getTerminator())) |
673 | return false; |
674 | } |
675 | return true; |
676 | } |
677 | |
678 | LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) { |
679 | // Don't consider clusters that don't have blocks to merge. |
680 | if (blocksToMerge.empty()) |
681 | return failure(); |
682 | |
683 | Block *leaderBlock = leaderData.block; |
684 | if (!operandsToMerge.empty()) { |
685 | // If the cluster has operands to merge, verify that the predecessor |
686 | // terminators of each of the blocks can have their successor operands |
687 | // updated. |
688 | // TODO: We could try and sub-partition this cluster if only some blocks |
689 | // cause the mismatch. |
690 | if (!ableToUpdatePredOperands(block: leaderBlock) || |
691 | !llvm::all_of(Range&: blocksToMerge, P: ableToUpdatePredOperands)) |
692 | return failure(); |
693 | |
694 | // Collect the iterators for each of the blocks to merge. We will walk all |
695 | // of the iterators at once to avoid operand index invalidation. |
696 | SmallVector<Block::iterator, 2> blockIterators; |
697 | blockIterators.reserve(N: blocksToMerge.size() + 1); |
698 | blockIterators.push_back(Elt: leaderBlock->begin()); |
699 | for (Block *mergeBlock : blocksToMerge) |
700 | blockIterators.push_back(Elt: mergeBlock->begin()); |
701 | |
702 | // Update each of the predecessor terminators with the new arguments. |
703 | SmallVector<SmallVector<Value, 8>, 2> newArguments( |
704 | 1 + blocksToMerge.size(), |
705 | SmallVector<Value, 8>(operandsToMerge.size())); |
706 | unsigned curOpIndex = 0; |
707 | for (const auto &it : llvm::enumerate(First&: operandsToMerge)) { |
708 | unsigned nextOpOffset = it.value().first - curOpIndex; |
709 | curOpIndex = it.value().first; |
710 | |
711 | // Process the operand for each of the block iterators. |
712 | for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) { |
713 | Block::iterator &blockIter = blockIterators[i]; |
714 | std::advance(i&: blockIter, n: nextOpOffset); |
715 | auto &operand = blockIter->getOpOperand(idx: it.value().second); |
716 | newArguments[i][it.index()] = operand.get(); |
717 | |
718 | // Update the operand and insert an argument if this is the leader. |
719 | if (i == 0) { |
720 | Value operandVal = operand.get(); |
721 | operand.set(leaderBlock->addArgument(type: operandVal.getType(), |
722 | loc: operandVal.getLoc())); |
723 | } |
724 | } |
725 | } |
726 | // Update the predecessors for each of the blocks. |
727 | auto updatePredecessors = [&](Block *block, unsigned clusterIndex) { |
728 | for (auto predIt = block->pred_begin(), predE = block->pred_end(); |
729 | predIt != predE; ++predIt) { |
730 | auto branch = cast<BranchOpInterface>((*predIt)->getTerminator()); |
731 | unsigned succIndex = predIt.getSuccessorIndex(); |
732 | branch.getSuccessorOperands(succIndex).append( |
733 | newArguments[clusterIndex]); |
734 | } |
735 | }; |
736 | updatePredecessors(leaderBlock, /*clusterIndex=*/0); |
737 | for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i) |
738 | updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1); |
739 | } |
740 | |
741 | // Replace all uses of the merged blocks with the leader and erase them. |
742 | for (Block *block : blocksToMerge) { |
743 | block->replaceAllUsesWith(newValue&: leaderBlock); |
744 | rewriter.eraseBlock(block); |
745 | } |
746 | return success(); |
747 | } |
748 | |
749 | /// Identify identical blocks within the given region and merge them, inserting |
750 | /// new block arguments as necessary. Returns success if any blocks were merged, |
751 | /// failure otherwise. |
752 | static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, |
753 | Region ®ion) { |
754 | if (region.empty() || llvm::hasSingleElement(C&: region)) |
755 | return failure(); |
756 | |
757 | // Identify sets of blocks, other than the entry block, that branch to the |
758 | // same successors. We will use these groups to create clusters of equivalent |
759 | // blocks. |
760 | DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors; |
761 | for (Block &block : llvm::drop_begin(RangeOrContainer&: region, N: 1)) |
762 | matchingSuccessors[block.getSuccessors()].push_back(Elt: &block); |
763 | |
764 | bool mergedAnyBlocks = false; |
765 | for (ArrayRef<Block *> blocks : llvm::make_second_range(c&: matchingSuccessors)) { |
766 | if (blocks.size() == 1) |
767 | continue; |
768 | |
769 | SmallVector<BlockMergeCluster, 1> clusters; |
770 | for (Block *block : blocks) { |
771 | BlockEquivalenceData data(block); |
772 | |
773 | // Don't allow merging if this block has any regions. |
774 | // TODO: Add support for regions if necessary. |
775 | bool hasNonEmptyRegion = llvm::any_of(Range&: *block, P: [](Operation &op) { |
776 | return llvm::any_of(Range: op.getRegions(), |
777 | P: [](Region ®ion) { return !region.empty(); }); |
778 | }); |
779 | if (hasNonEmptyRegion) |
780 | continue; |
781 | |
782 | // Try to add this block to an existing cluster. |
783 | bool addedToCluster = false; |
784 | for (auto &cluster : clusters) |
785 | if ((addedToCluster = succeeded(result: cluster.addToCluster(blockData&: data)))) |
786 | break; |
787 | if (!addedToCluster) |
788 | clusters.emplace_back(Args: std::move(data)); |
789 | } |
790 | for (auto &cluster : clusters) |
791 | mergedAnyBlocks |= succeeded(result: cluster.merge(rewriter)); |
792 | } |
793 | |
794 | return success(isSuccess: mergedAnyBlocks); |
795 | } |
796 | |
797 | /// Identify identical blocks within the given regions and merge them, inserting |
798 | /// new block arguments as necessary. |
799 | static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, |
800 | MutableArrayRef<Region> regions) { |
801 | llvm::SmallSetVector<Region *, 1> worklist; |
802 | for (auto ®ion : regions) |
803 | worklist.insert(X: ®ion); |
804 | bool anyChanged = false; |
805 | while (!worklist.empty()) { |
806 | Region *region = worklist.pop_back_val(); |
807 | if (succeeded(result: mergeIdenticalBlocks(rewriter, region&: *region))) { |
808 | worklist.insert(X: region); |
809 | anyChanged = true; |
810 | } |
811 | |
812 | // Add any nested regions to the worklist. |
813 | for (Block &block : *region) |
814 | for (auto &op : block) |
815 | for (auto &nestedRegion : op.getRegions()) |
816 | worklist.insert(X: &nestedRegion); |
817 | } |
818 | |
819 | return success(isSuccess: anyChanged); |
820 | } |
821 | |
822 | //===----------------------------------------------------------------------===// |
823 | // Region Simplification |
824 | //===----------------------------------------------------------------------===// |
825 | |
826 | /// Run a set of structural simplifications over the given regions. This |
827 | /// includes transformations like unreachable block elimination, dead argument |
828 | /// elimination, as well as some other DCE. This function returns success if any |
829 | /// of the regions were simplified, failure otherwise. |
830 | LogicalResult mlir::simplifyRegions(RewriterBase &rewriter, |
831 | MutableArrayRef<Region> regions) { |
832 | bool eliminatedBlocks = succeeded(result: eraseUnreachableBlocks(rewriter, regions)); |
833 | bool eliminatedOpsOrArgs = succeeded(result: runRegionDCE(rewriter, regions)); |
834 | bool mergedIdenticalBlocks = |
835 | succeeded(result: mergeIdenticalBlocks(rewriter, regions)); |
836 | return success(isSuccess: eliminatedBlocks || eliminatedOpsOrArgs || |
837 | mergedIdenticalBlocks); |
838 | } |
839 | |
840 | SetVector<Block *> mlir::getTopologicallySortedBlocks(Region ®ion) { |
841 | // For each block that has not been visited yet (i.e. that has no |
842 | // predecessors), add it to the list as well as its successors. |
843 | SetVector<Block *> blocks; |
844 | for (Block &b : region) { |
845 | if (blocks.count(key: &b) == 0) { |
846 | llvm::ReversePostOrderTraversal<Block *> traversal(&b); |
847 | blocks.insert(Start: traversal.begin(), End: traversal.end()); |
848 | } |
849 | } |
850 | assert(blocks.size() == region.getBlocks().size() && |
851 | "some blocks are not sorted" ); |
852 | |
853 | return blocks; |
854 | } |
855 | |