//===- TestControlFlowSink.cpp - Test control-flow sink pass --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass tests the control-flow sink utilities by implementing an example
// control-flow sink pass.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/IR/Dominance.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Transforms/ControlFlowSinkUtils.h"
18
19using namespace mlir;
20
21namespace {
22/// An example control-flow sink pass to test the control-flow sink utilites.
23/// This pass will sink ops named `test.sink_me` and tag them with an attribute
24/// `was_sunk` into the first region of `test.sink_target` ops.
25struct TestControlFlowSinkPass
26 : public PassWrapper<TestControlFlowSinkPass, OperationPass<func::FuncOp>> {
27 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestControlFlowSinkPass)
28
29 /// Get the command-line argument of the test pass.
30 StringRef getArgument() const final { return "test-control-flow-sink"; }
31 /// Get the description of the test pass.
32 StringRef getDescription() const final {
33 return "Test control-flow sink pass";
34 }
35
36 /// Runs the pass on the function.
37 void runOnOperation() override {
38 auto &domInfo = getAnalysis<DominanceInfo>();
39 auto shouldMoveIntoRegion = [](Operation *op, Region *region) {
40 return region->getRegionNumber() == 0 &&
41 op->getName().getStringRef() == "test.sink_me";
42 };
43 auto moveIntoRegion = [](Operation *op, Region *region) {
44 Block &entry = region->front();
45 op->moveBefore(block: &entry, iterator: entry.begin());
46 op->setAttr("was_sunk",
47 Builder(op).getI32IntegerAttr(region->getRegionNumber()));
48 };
49
50 getOperation()->walk([&](Operation *op) {
51 if (op->getName().getStringRef() != "test.sink_target")
52 return;
53 SmallVector<Region *> regions =
54 llvm::to_vector(Range: RegionRange(op->getRegions()));
55 controlFlowSink(regions, domInfo, shouldMoveIntoRegion, moveIntoRegion);
56 });
57 }
58};
59} // end anonymous namespace
60
61namespace mlir {
62namespace test {
63void registerTestControlFlowSink() {
64 PassRegistration<TestControlFlowSinkPass>();
65}
66} // end namespace test
67} // end namespace mlir
68

// Source: mlir/test/lib/Transforms/TestControlFlowSink.cpp