1//===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Transforms/RegionUtils.h"
10#include "mlir/IR/Block.h"
11#include "mlir/IR/IRMapping.h"
12#include "mlir/IR/Operation.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/IR/RegionGraphTraits.h"
15#include "mlir/IR/Value.h"
16#include "mlir/Interfaces/ControlFlowInterfaces.h"
17#include "mlir/Interfaces/SideEffectInterfaces.h"
18#include "mlir/Transforms/TopologicalSortUtils.h"
19
20#include "llvm/ADT/DepthFirstIterator.h"
21#include "llvm/ADT/PostOrderIterator.h"
22#include "llvm/ADT/SmallSet.h"
23
24#include <deque>
25
26using namespace mlir;
27
28void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
29 Region &region) {
30 for (auto &use : llvm::make_early_inc_range(Range: orig.getUses())) {
31 if (region.isAncestor(other: use.getOwner()->getParentRegion()))
32 use.set(replacement);
33 }
34}
35
36void mlir::visitUsedValuesDefinedAbove(
37 Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
38 assert(limit.isAncestor(&region) &&
39 "expected isolation limit to be an ancestor of the given region");
40
41 // Collect proper ancestors of `limit` upfront to avoid traversing the region
42 // tree for every value.
43 SmallPtrSet<Region *, 4> properAncestors;
44 for (auto *reg = limit.getParentRegion(); reg != nullptr;
45 reg = reg->getParentRegion()) {
46 properAncestors.insert(Ptr: reg);
47 }
48
49 region.walk(callback: [callback, &properAncestors](Operation *op) {
50 for (OpOperand &operand : op->getOpOperands())
51 // Callback on values defined in a proper ancestor of region.
52 if (properAncestors.count(Ptr: operand.get().getParentRegion()))
53 callback(&operand);
54 });
55}
56
57void mlir::visitUsedValuesDefinedAbove(
58 MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
59 for (Region &region : regions)
60 visitUsedValuesDefinedAbove(region, limit&: region, callback);
61}
62
63void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
64 SetVector<Value> &values) {
65 visitUsedValuesDefinedAbove(region, limit, callback: [&](OpOperand *operand) {
66 values.insert(X: operand->get());
67 });
68}
69
70void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
71 SetVector<Value> &values) {
72 for (Region &region : regions)
73 getUsedValuesDefinedAbove(region, limit&: region, values);
74}
75
76//===----------------------------------------------------------------------===//
// Make region isolated from above.
78//===----------------------------------------------------------------------===//
79
80SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
81 RewriterBase &rewriter, Region &region,
82 llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {
83
84 // Get initial list of values used within region but defined above.
85 llvm::SetVector<Value> initialCapturedValues;
86 mlir::getUsedValuesDefinedAbove(regions: region, values&: initialCapturedValues);
87
88 std::deque<Value> worklist(initialCapturedValues.begin(),
89 initialCapturedValues.end());
90 llvm::DenseSet<Value> visited;
91 llvm::DenseSet<Operation *> visitedOps;
92
93 llvm::SetVector<Value> finalCapturedValues;
94 SmallVector<Operation *> clonedOperations;
95 while (!worklist.empty()) {
96 Value currValue = worklist.front();
97 worklist.pop_front();
98 if (visited.count(V: currValue))
99 continue;
100 visited.insert(V: currValue);
101
102 Operation *definingOp = currValue.getDefiningOp();
103 if (!definingOp || visitedOps.count(V: definingOp)) {
104 finalCapturedValues.insert(X: currValue);
105 continue;
106 }
107 visitedOps.insert(V: definingOp);
108
109 if (!cloneOperationIntoRegion(definingOp)) {
110 // Defining operation isnt cloned, so add the current value to final
111 // captured values list.
112 finalCapturedValues.insert(X: currValue);
113 continue;
114 }
115
116 // Add all operands of the operation to the worklist and mark the op as to
117 // be cloned.
118 for (Value operand : definingOp->getOperands()) {
119 if (visited.count(V: operand))
120 continue;
121 worklist.push_back(x: operand);
122 }
123 clonedOperations.push_back(Elt: definingOp);
124 }
125
126 // The operations to be cloned need to be ordered in topological order
127 // so that they can be cloned into the region without violating use-def
128 // chains.
129 mlir::computeTopologicalSorting(ops: clonedOperations);
130
131 OpBuilder::InsertionGuard g(rewriter);
132 // Collect types of existing block
133 Block *entryBlock = &region.front();
134 SmallVector<Type> newArgTypes =
135 llvm::to_vector(Range: entryBlock->getArgumentTypes());
136 SmallVector<Location> newArgLocs = llvm::to_vector(Range: llvm::map_range(
137 C: entryBlock->getArguments(), F: [](BlockArgument b) { return b.getLoc(); }));
138
139 // Append the types of the captured values.
140 for (auto value : finalCapturedValues) {
141 newArgTypes.push_back(Elt: value.getType());
142 newArgLocs.push_back(Elt: value.getLoc());
143 }
144
145 // Create a new entry block.
146 Block *newEntryBlock =
147 rewriter.createBlock(parent: &region, insertPt: region.begin(), argTypes: newArgTypes, locs: newArgLocs);
148 auto newEntryBlockArgs = newEntryBlock->getArguments();
149
150 // Create a mapping between the captured values and the new arguments added.
151 IRMapping map;
152 auto replaceIfFn = [&](OpOperand &use) {
153 return use.getOwner()->getBlock()->getParent() == &region;
154 };
155 for (auto [arg, capturedVal] :
156 llvm::zip(t: newEntryBlockArgs.take_back(N: finalCapturedValues.size()),
157 u&: finalCapturedValues)) {
158 map.map(from: capturedVal, to: arg);
159 rewriter.replaceUsesWithIf(from: capturedVal, to: arg, functor: replaceIfFn);
160 }
161 rewriter.setInsertionPointToStart(newEntryBlock);
162 for (auto *clonedOp : clonedOperations) {
163 Operation *newOp = rewriter.clone(op&: *clonedOp, mapper&: map);
164 rewriter.replaceOpUsesWithIf(from: clonedOp, to: newOp->getResults(), functor: replaceIfFn);
165 }
166 rewriter.mergeBlocks(
167 source: entryBlock, dest: newEntryBlock,
168 argValues: newEntryBlock->getArguments().take_front(N: entryBlock->getNumArguments()));
169 return llvm::to_vector(Range&: finalCapturedValues);
170}
171
172//===----------------------------------------------------------------------===//
173// Unreachable Block Elimination
174//===----------------------------------------------------------------------===//
175
176/// Erase the unreachable blocks within the provided regions. Returns success
177/// if any blocks were erased, failure otherwise.
178// TODO: We could likely merge this with the DCE algorithm below.
179LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
180 MutableArrayRef<Region> regions) {
181 // Set of blocks found to be reachable within a given region.
182 llvm::df_iterator_default_set<Block *, 16> reachable;
183 // If any blocks were found to be dead.
184 bool erasedDeadBlocks = false;
185
186 SmallVector<Region *, 1> worklist;
187 worklist.reserve(N: regions.size());
188 for (Region &region : regions)
189 worklist.push_back(Elt: &region);
190 while (!worklist.empty()) {
191 Region *region = worklist.pop_back_val();
192 if (region->empty())
193 continue;
194
195 // If this is a single block region, just collect the nested regions.
196 if (std::next(x: region->begin()) == region->end()) {
197 for (Operation &op : region->front())
198 for (Region &region : op.getRegions())
199 worklist.push_back(Elt: &region);
200 continue;
201 }
202
203 // Mark all reachable blocks.
204 reachable.clear();
205 for (Block *block : depth_first_ext(G: &region->front(), S&: reachable))
206 (void)block /* Mark all reachable blocks */;
207
208 // Collect all of the dead blocks and push the live regions onto the
209 // worklist.
210 for (Block &block : llvm::make_early_inc_range(Range&: *region)) {
211 if (!reachable.count(Ptr: &block)) {
212 block.dropAllDefinedValueUses();
213 rewriter.eraseBlock(block: &block);
214 erasedDeadBlocks = true;
215 continue;
216 }
217
218 // Walk any regions within this block.
219 for (Operation &op : block)
220 for (Region &region : op.getRegions())
221 worklist.push_back(Elt: &region);
222 }
223 }
224
225 return success(isSuccess: erasedDeadBlocks);
226}
227
228//===----------------------------------------------------------------------===//
229// Dead Code Elimination
230//===----------------------------------------------------------------------===//
231
232namespace {
233/// Data structure used to track which values have already been proved live.
234///
235/// Because Operation's can have multiple results, this data structure tracks
236/// liveness for both Value's and Operation's to avoid having to look through
237/// all Operation results when analyzing a use.
238///
239/// This data structure essentially tracks the dataflow lattice.
240/// The set of values/ops proved live increases monotonically to a fixed-point.
241class LiveMap {
242public:
243 /// Value methods.
244 bool wasProvenLive(Value value) {
245 // TODO: For results that are removable, e.g. for region based control flow,
246 // we could allow for these values to be tracked independently.
247 if (OpResult result = dyn_cast<OpResult>(Val&: value))
248 return wasProvenLive(op: result.getOwner());
249 return wasProvenLive(arg: cast<BlockArgument>(Val&: value));
250 }
251 bool wasProvenLive(BlockArgument arg) { return liveValues.count(V: arg); }
252 void setProvedLive(Value value) {
253 // TODO: For results that are removable, e.g. for region based control flow,
254 // we could allow for these values to be tracked independently.
255 if (OpResult result = dyn_cast<OpResult>(Val&: value))
256 return setProvedLive(result.getOwner());
257 setProvedLive(cast<BlockArgument>(Val&: value));
258 }
259 void setProvedLive(BlockArgument arg) {
260 changed |= liveValues.insert(V: arg).second;
261 }
262
263 /// Operation methods.
264 bool wasProvenLive(Operation *op) { return liveOps.count(V: op); }
265 void setProvedLive(Operation *op) { changed |= liveOps.insert(V: op).second; }
266
267 /// Methods for tracking if we have reached a fixed-point.
268 void resetChanged() { changed = false; }
269 bool hasChanged() { return changed; }
270
271private:
272 bool changed = false;
273 DenseSet<Value> liveValues;
274 DenseSet<Operation *> liveOps;
275};
276} // namespace
277
278static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
279 Operation *owner = use.getOwner();
280 unsigned operandIndex = use.getOperandNumber();
281 // This pass generally treats all uses of an op as live if the op itself is
282 // considered live. However, for successor operands to terminators we need a
283 // finer-grained notion where we deduce liveness for operands individually.
284 // The reason for this is easiest to think about in terms of a classical phi
285 // node based SSA IR, where each successor operand is really an operand to a
286 // *separate* phi node, rather than all operands to the branch itself as with
287 // the block argument representation that MLIR uses.
288 //
289 // And similarly, because each successor operand is really an operand to a phi
290 // node, rather than to the terminator op itself, a terminator op can't e.g.
291 // "print" the value of a successor operand.
292 if (owner->hasTrait<OpTrait::IsTerminator>()) {
293 if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
294 if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
295 return !liveMap.wasProvenLive(*arg);
296 return false;
297 }
298 return false;
299}
300
301static void processValue(Value value, LiveMap &liveMap) {
302 bool provedLive = llvm::any_of(Range: value.getUses(), P: [&](OpOperand &use) {
303 if (isUseSpeciallyKnownDead(use, liveMap))
304 return false;
305 return liveMap.wasProvenLive(op: use.getOwner());
306 });
307 if (provedLive)
308 liveMap.setProvedLive(value);
309}
310
311static void propagateLiveness(Region &region, LiveMap &liveMap);
312
313static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
314 // Terminators are always live.
315 liveMap.setProvedLive(op);
316
317 // Check to see if we can reason about the successor operands and mutate them.
318 BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
319 if (!branchInterface) {
320 for (Block *successor : op->getSuccessors())
321 for (BlockArgument arg : successor->getArguments())
322 liveMap.setProvedLive(arg);
323 return;
324 }
325
326 // If we can't reason about the operand to a successor, conservatively mark
327 // it as live.
328 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
329 SuccessorOperands successorOperands =
330 branchInterface.getSuccessorOperands(i);
331 for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
332 opI != opE; ++opI)
333 liveMap.setProvedLive(op->getSuccessor(index: i)->getArgument(i: opI));
334 }
335}
336
337static void propagateLiveness(Operation *op, LiveMap &liveMap) {
338 // Recurse on any regions the op has.
339 for (Region &region : op->getRegions())
340 propagateLiveness(region, liveMap);
341
342 // Process terminator operations.
343 if (op->hasTrait<OpTrait::IsTerminator>())
344 return propagateTerminatorLiveness(op, liveMap);
345
346 // Don't reprocess live operations.
347 if (liveMap.wasProvenLive(op))
348 return;
349
350 // Process the op itself.
351 if (!wouldOpBeTriviallyDead(op))
352 return liveMap.setProvedLive(op);
353
354 // If the op isn't intrinsically alive, check it's results.
355 for (Value value : op->getResults())
356 processValue(value, liveMap);
357}
358
359static void propagateLiveness(Region &region, LiveMap &liveMap) {
360 if (region.empty())
361 return;
362
363 for (Block *block : llvm::post_order(G: &region.front())) {
364 // We process block arguments after the ops in the block, to promote
365 // faster convergence to a fixed point (we try to visit uses before defs).
366 for (Operation &op : llvm::reverse(C&: block->getOperations()))
367 propagateLiveness(op: &op, liveMap);
368
369 // We currently do not remove entry block arguments, so there is no need to
370 // track their liveness.
371 // TODO: We could track these and enable removing dead operands/arguments
372 // from region control flow operations.
373 if (block->isEntryBlock())
374 continue;
375
376 for (Value value : block->getArguments()) {
377 if (!liveMap.wasProvenLive(value))
378 processValue(value, liveMap);
379 }
380 }
381}
382
383static void eraseTerminatorSuccessorOperands(Operation *terminator,
384 LiveMap &liveMap) {
385 BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
386 if (!branchOp)
387 return;
388
389 for (unsigned succI = 0, succE = terminator->getNumSuccessors();
390 succI < succE; succI++) {
391 // Iterating successors in reverse is not strictly needed, since we
392 // aren't erasing any successors. But it is slightly more efficient
393 // since it will promote later operands of the terminator being erased
394 // first, reducing the quadratic-ness.
395 unsigned succ = succE - succI - 1;
396 SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
397 Block *successor = terminator->getSuccessor(index: succ);
398
399 for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
400 // Iterating args in reverse is needed for correctness, to avoid
401 // shifting later args when earlier args are erased.
402 unsigned arg = argE - argI - 1;
403 if (!liveMap.wasProvenLive(arg: successor->getArgument(i: arg)))
404 succOperands.erase(subStart: arg);
405 }
406 }
407}
408
409static LogicalResult deleteDeadness(RewriterBase &rewriter,
410 MutableArrayRef<Region> regions,
411 LiveMap &liveMap) {
412 bool erasedAnything = false;
413 for (Region &region : regions) {
414 if (region.empty())
415 continue;
416 bool hasSingleBlock = llvm::hasSingleElement(C&: region);
417
418 // Delete every operation that is not live. Graph regions may have cycles
419 // in the use-def graph, so we must explicitly dropAllUses() from each
420 // operation as we erase it. Visiting the operations in post-order
421 // guarantees that in SSA CFG regions value uses are removed before defs,
422 // which makes dropAllUses() a no-op.
423 for (Block *block : llvm::post_order(G: &region.front())) {
424 if (!hasSingleBlock)
425 eraseTerminatorSuccessorOperands(terminator: block->getTerminator(), liveMap);
426 for (Operation &childOp :
427 llvm::make_early_inc_range(Range: llvm::reverse(C&: block->getOperations()))) {
428 if (!liveMap.wasProvenLive(op: &childOp)) {
429 erasedAnything = true;
430 childOp.dropAllUses();
431 rewriter.eraseOp(op: &childOp);
432 } else {
433 erasedAnything |= succeeded(
434 result: deleteDeadness(rewriter, regions: childOp.getRegions(), liveMap));
435 }
436 }
437 }
438 // Delete block arguments.
439 // The entry block has an unknown contract with their enclosing block, so
440 // skip it.
441 for (Block &block : llvm::drop_begin(RangeOrContainer&: region.getBlocks(), N: 1)) {
442 block.eraseArguments(
443 shouldEraseFn: [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
444 }
445 }
446 return success(isSuccess: erasedAnything);
447}
448
449// This function performs a simple dead code elimination algorithm over the
450// given regions.
451//
452// The overall goal is to prove that Values are dead, which allows deleting ops
453// and block arguments.
454//
455// This uses an optimistic algorithm that assumes everything is dead until
456// proved otherwise, allowing it to delete recursively dead cycles.
457//
458// This is a simple fixed-point dataflow analysis algorithm on a lattice
459// {Dead,Alive}. Because liveness flows backward, we generally try to
460// iterate everything backward to speed up convergence to the fixed-point. This
461// allows for being able to delete recursively dead cycles of the use-def graph,
462// including block arguments.
463//
464// This function returns success if any operations or arguments were deleted,
465// failure otherwise.
466LogicalResult mlir::runRegionDCE(RewriterBase &rewriter,
467 MutableArrayRef<Region> regions) {
468 LiveMap liveMap;
469 do {
470 liveMap.resetChanged();
471
472 for (Region &region : regions)
473 propagateLiveness(region, liveMap);
474 } while (liveMap.hasChanged());
475
476 return deleteDeadness(rewriter, regions, liveMap);
477}
478
479//===----------------------------------------------------------------------===//
480// Block Merging
481//===----------------------------------------------------------------------===//
482
483//===----------------------------------------------------------------------===//
484// BlockEquivalenceData
485
486namespace {
487/// This class contains the information for comparing the equivalencies of two
488/// blocks. Blocks are considered equivalent if they contain the same operations
489/// in the same order. The only allowed divergence is for operands that come
490/// from sources outside of the parent block, i.e. the uses of values produced
491/// within the block must be equivalent.
492/// e.g.,
493/// Equivalent:
494/// ^bb1(%arg0: i32)
495/// return %arg0, %foo : i32, i32
496/// ^bb2(%arg1: i32)
497/// return %arg1, %bar : i32, i32
498/// Not Equivalent:
499/// ^bb1(%arg0: i32)
500/// return %foo, %arg0 : i32, i32
501/// ^bb2(%arg1: i32)
502/// return %arg1, %bar : i32, i32
503struct BlockEquivalenceData {
504 BlockEquivalenceData(Block *block);
505
506 /// Return the order index for the given value that is within the block of
507 /// this data.
508 unsigned getOrderOf(Value value) const;
509
510 /// The block this data refers to.
511 Block *block;
512 /// A hash value for this block.
513 llvm::hash_code hash;
514 /// A map of result producing operations to their relative orders within this
515 /// block. The order of an operation is the number of defined values that are
516 /// produced within the block before this operation.
517 DenseMap<Operation *, unsigned> opOrderIndex;
518};
519} // namespace
520
521BlockEquivalenceData::BlockEquivalenceData(Block *block)
522 : block(block), hash(0) {
523 unsigned orderIt = block->getNumArguments();
524 for (Operation &op : *block) {
525 if (unsigned numResults = op.getNumResults()) {
526 opOrderIndex.try_emplace(Key: &op, Args&: orderIt);
527 orderIt += numResults;
528 }
529 auto opHash = OperationEquivalence::computeHash(
530 op: &op, hashOperands: OperationEquivalence::ignoreHashValue,
531 hashResults: OperationEquivalence::ignoreHashValue,
532 flags: OperationEquivalence::IgnoreLocations);
533 hash = llvm::hash_combine(args: hash, args: opHash);
534 }
535}
536
537unsigned BlockEquivalenceData::getOrderOf(Value value) const {
538 assert(value.getParentBlock() == block && "expected value of this block");
539
540 // Arguments use the argument number as the order index.
541 if (BlockArgument arg = dyn_cast<BlockArgument>(Val&: value))
542 return arg.getArgNumber();
543
544 // Otherwise, the result order is offset from the parent op's order.
545 OpResult result = cast<OpResult>(Val&: value);
546 auto opOrderIt = opOrderIndex.find(Val: result.getDefiningOp());
547 assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
548 return opOrderIt->second + result.getResultNumber();
549}
550
551//===----------------------------------------------------------------------===//
552// BlockMergeCluster
553
554namespace {
555/// This class represents a cluster of blocks to be merged together.
556class BlockMergeCluster {
557public:
558 BlockMergeCluster(BlockEquivalenceData &&leaderData)
559 : leaderData(std::move(leaderData)) {}
560
561 /// Attempt to add the given block to this cluster. Returns success if the
562 /// block was merged, failure otherwise.
563 LogicalResult addToCluster(BlockEquivalenceData &blockData);
564
565 /// Try to merge all of the blocks within this cluster into the leader block.
566 LogicalResult merge(RewriterBase &rewriter);
567
568private:
569 /// The equivalence data for the leader of the cluster.
570 BlockEquivalenceData leaderData;
571
572 /// The set of blocks that can be merged into the leader.
573 llvm::SmallSetVector<Block *, 1> blocksToMerge;
574
575 /// A set of operand+index pairs that correspond to operands that need to be
576 /// replaced by arguments when the cluster gets merged.
577 std::set<std::pair<int, int>> operandsToMerge;
578};
579} // namespace
580
581LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
582 if (leaderData.hash != blockData.hash)
583 return failure();
584 Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
585 if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
586 return failure();
587
588 // A set of operands that mismatch between the leader and the new block.
589 SmallVector<std::pair<int, int>, 8> mismatchedOperands;
590 auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
591 auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
592 for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
593 // Check that the operations are equivalent.
594 if (!OperationEquivalence::isEquivalentTo(
595 lhs: &*lhsIt, rhs: &*rhsIt, checkEquivalent: OperationEquivalence::ignoreValueEquivalence,
596 /*markEquivalent=*/nullptr,
597 flags: OperationEquivalence::Flags::IgnoreLocations))
598 return failure();
599
600 // Compare the operands of the two operations. If the operand is within
601 // the block, it must refer to the same operation.
602 auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
603 for (int operand : llvm::seq<int>(Begin: 0, End: lhsIt->getNumOperands())) {
604 Value lhsOperand = lhsOperands[operand];
605 Value rhsOperand = rhsOperands[operand];
606 if (lhsOperand == rhsOperand)
607 continue;
608 // Check that the types of the operands match.
609 if (lhsOperand.getType() != rhsOperand.getType())
610 return failure();
611
612 // Check that these uses are both external, or both internal.
613 bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
614 bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
615 if (lhsIsInBlock != rhsIsInBlock)
616 return failure();
617 // Let the operands differ if they are defined in a different block. These
618 // will become new arguments if the blocks get merged.
619 if (!lhsIsInBlock) {
620
621 // Check whether the operands aren't the result of an immediate
622 // predecessors terminator. In that case we are not able to use it as a
623 // successor operand when branching to the merged block as it does not
624 // dominate its producing operation.
625 auto isValidSuccessorArg = [](Block *block, Value operand) {
626 if (operand.getDefiningOp() !=
627 operand.getParentBlock()->getTerminator())
628 return true;
629 return !llvm::is_contained(Range: block->getPredecessors(),
630 Element: operand.getParentBlock());
631 };
632
633 if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
634 !isValidSuccessorArg(mergeBlock, rhsOperand))
635 return failure();
636
637 mismatchedOperands.emplace_back(Args&: opI, Args&: operand);
638 continue;
639 }
640
641 // Otherwise, these operands must have the same logical order within the
642 // parent block.
643 if (leaderData.getOrderOf(value: lhsOperand) != blockData.getOrderOf(value: rhsOperand))
644 return failure();
645 }
646
647 // If the lhs or rhs has external uses, the blocks cannot be merged as the
648 // merged version of this operation will not be either the lhs or rhs
649 // alone (thus semantically incorrect), but some mix dependending on which
650 // block preceeded this.
651 // TODO allow merging of operations when one block does not dominate the
652 // other
653 if (rhsIt->isUsedOutsideOfBlock(block: mergeBlock) ||
654 lhsIt->isUsedOutsideOfBlock(block: leaderBlock)) {
655 return failure();
656 }
657 }
658 // Make sure that the block sizes are equivalent.
659 if (lhsIt != lhsE || rhsIt != rhsE)
660 return failure();
661
662 // If we get here, the blocks are equivalent and can be merged.
663 operandsToMerge.insert(first: mismatchedOperands.begin(), last: mismatchedOperands.end());
664 blocksToMerge.insert(X: blockData.block);
665 return success();
666}
667
668/// Returns true if the predecessor terminators of the given block can not have
669/// their operands updated.
670static bool ableToUpdatePredOperands(Block *block) {
671 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
672 if (!isa<BranchOpInterface>(Val: (*it)->getTerminator()))
673 return false;
674 }
675 return true;
676}
677
678LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
679 // Don't consider clusters that don't have blocks to merge.
680 if (blocksToMerge.empty())
681 return failure();
682
683 Block *leaderBlock = leaderData.block;
684 if (!operandsToMerge.empty()) {
685 // If the cluster has operands to merge, verify that the predecessor
686 // terminators of each of the blocks can have their successor operands
687 // updated.
688 // TODO: We could try and sub-partition this cluster if only some blocks
689 // cause the mismatch.
690 if (!ableToUpdatePredOperands(block: leaderBlock) ||
691 !llvm::all_of(Range&: blocksToMerge, P: ableToUpdatePredOperands))
692 return failure();
693
694 // Collect the iterators for each of the blocks to merge. We will walk all
695 // of the iterators at once to avoid operand index invalidation.
696 SmallVector<Block::iterator, 2> blockIterators;
697 blockIterators.reserve(N: blocksToMerge.size() + 1);
698 blockIterators.push_back(Elt: leaderBlock->begin());
699 for (Block *mergeBlock : blocksToMerge)
700 blockIterators.push_back(Elt: mergeBlock->begin());
701
702 // Update each of the predecessor terminators with the new arguments.
703 SmallVector<SmallVector<Value, 8>, 2> newArguments(
704 1 + blocksToMerge.size(),
705 SmallVector<Value, 8>(operandsToMerge.size()));
706 unsigned curOpIndex = 0;
707 for (const auto &it : llvm::enumerate(First&: operandsToMerge)) {
708 unsigned nextOpOffset = it.value().first - curOpIndex;
709 curOpIndex = it.value().first;
710
711 // Process the operand for each of the block iterators.
712 for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
713 Block::iterator &blockIter = blockIterators[i];
714 std::advance(i&: blockIter, n: nextOpOffset);
715 auto &operand = blockIter->getOpOperand(idx: it.value().second);
716 newArguments[i][it.index()] = operand.get();
717
718 // Update the operand and insert an argument if this is the leader.
719 if (i == 0) {
720 Value operandVal = operand.get();
721 operand.set(leaderBlock->addArgument(type: operandVal.getType(),
722 loc: operandVal.getLoc()));
723 }
724 }
725 }
726 // Update the predecessors for each of the blocks.
727 auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
728 for (auto predIt = block->pred_begin(), predE = block->pred_end();
729 predIt != predE; ++predIt) {
730 auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
731 unsigned succIndex = predIt.getSuccessorIndex();
732 branch.getSuccessorOperands(succIndex).append(
733 newArguments[clusterIndex]);
734 }
735 };
736 updatePredecessors(leaderBlock, /*clusterIndex=*/0);
737 for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
738 updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
739 }
740
741 // Replace all uses of the merged blocks with the leader and erase them.
742 for (Block *block : blocksToMerge) {
743 block->replaceAllUsesWith(newValue&: leaderBlock);
744 rewriter.eraseBlock(block);
745 }
746 return success();
747}
748
749/// Identify identical blocks within the given region and merge them, inserting
750/// new block arguments as necessary. Returns success if any blocks were merged,
751/// failure otherwise.
752static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
753 Region &region) {
754 if (region.empty() || llvm::hasSingleElement(C&: region))
755 return failure();
756
757 // Identify sets of blocks, other than the entry block, that branch to the
758 // same successors. We will use these groups to create clusters of equivalent
759 // blocks.
760 DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors;
761 for (Block &block : llvm::drop_begin(RangeOrContainer&: region, N: 1))
762 matchingSuccessors[block.getSuccessors()].push_back(Elt: &block);
763
764 bool mergedAnyBlocks = false;
765 for (ArrayRef<Block *> blocks : llvm::make_second_range(c&: matchingSuccessors)) {
766 if (blocks.size() == 1)
767 continue;
768
769 SmallVector<BlockMergeCluster, 1> clusters;
770 for (Block *block : blocks) {
771 BlockEquivalenceData data(block);
772
773 // Don't allow merging if this block has any regions.
774 // TODO: Add support for regions if necessary.
775 bool hasNonEmptyRegion = llvm::any_of(Range&: *block, P: [](Operation &op) {
776 return llvm::any_of(Range: op.getRegions(),
777 P: [](Region &region) { return !region.empty(); });
778 });
779 if (hasNonEmptyRegion)
780 continue;
781
782 // Try to add this block to an existing cluster.
783 bool addedToCluster = false;
784 for (auto &cluster : clusters)
785 if ((addedToCluster = succeeded(result: cluster.addToCluster(blockData&: data))))
786 break;
787 if (!addedToCluster)
788 clusters.emplace_back(Args: std::move(data));
789 }
790 for (auto &cluster : clusters)
791 mergedAnyBlocks |= succeeded(result: cluster.merge(rewriter));
792 }
793
794 return success(isSuccess: mergedAnyBlocks);
795}
796
797/// Identify identical blocks within the given regions and merge them, inserting
798/// new block arguments as necessary.
799static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
800 MutableArrayRef<Region> regions) {
801 llvm::SmallSetVector<Region *, 1> worklist;
802 for (auto &region : regions)
803 worklist.insert(X: &region);
804 bool anyChanged = false;
805 while (!worklist.empty()) {
806 Region *region = worklist.pop_back_val();
807 if (succeeded(result: mergeIdenticalBlocks(rewriter, region&: *region))) {
808 worklist.insert(X: region);
809 anyChanged = true;
810 }
811
812 // Add any nested regions to the worklist.
813 for (Block &block : *region)
814 for (auto &op : block)
815 for (auto &nestedRegion : op.getRegions())
816 worklist.insert(X: &nestedRegion);
817 }
818
819 return success(isSuccess: anyChanged);
820}
821
822//===----------------------------------------------------------------------===//
823// Region Simplification
824//===----------------------------------------------------------------------===//
825
826/// Run a set of structural simplifications over the given regions. This
827/// includes transformations like unreachable block elimination, dead argument
828/// elimination, as well as some other DCE. This function returns success if any
829/// of the regions were simplified, failure otherwise.
830LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
831 MutableArrayRef<Region> regions) {
832 bool eliminatedBlocks = succeeded(result: eraseUnreachableBlocks(rewriter, regions));
833 bool eliminatedOpsOrArgs = succeeded(result: runRegionDCE(rewriter, regions));
834 bool mergedIdenticalBlocks =
835 succeeded(result: mergeIdenticalBlocks(rewriter, regions));
836 return success(isSuccess: eliminatedBlocks || eliminatedOpsOrArgs ||
837 mergedIdenticalBlocks);
838}
839
840SetVector<Block *> mlir::getTopologicallySortedBlocks(Region &region) {
841 // For each block that has not been visited yet (i.e. that has no
842 // predecessors), add it to the list as well as its successors.
843 SetVector<Block *> blocks;
844 for (Block &b : region) {
845 if (blocks.count(key: &b) == 0) {
846 llvm::ReversePostOrderTraversal<Block *> traversal(&b);
847 blocks.insert(Start: traversal.begin(), End: traversal.end());
848 }
849 }
850 assert(blocks.size() == region.getBlocks().size() &&
851 "some blocks are not sorted");
852
853 return blocks;
854}
855

source code of mlir/lib/Transforms/Utils/RegionUtils.cpp