//===- TestInliningCallback.cpp - Pass to inline calls in the test dialect ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file implements a pass to test inlining callbacks including
// canHandleMultipleBlocks and doClone.
//===----------------------------------------------------------------------===//
12
13#include "TestDialect.h"
14#include "TestOps.h"
15#include "mlir/Analysis/CallGraph.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/IR/IRMapping.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Transforms/Inliner.h"
22#include "mlir/Transforms/InliningUtils.h"
23#include "llvm/ADT/StringSet.h"
24
25using namespace mlir;
26using namespace test;
27
28namespace {
29struct InlinerCallback
30 : public PassWrapper<InlinerCallback, OperationPass<func::FuncOp>> {
31 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InlinerCallback)
32
33 StringRef getArgument() const final { return "test-inline-callback"; }
34 StringRef getDescription() const final {
35 return "Test inlining region calls with call back functions";
36 }
37
38 void getDependentDialects(DialectRegistry &registry) const override {
39 registry.insert<scf::SCFDialect>();
40 }
41
42 static LogicalResult runPipelineHelper(Pass &pass, OpPassManager &pipeline,
43 Operation *op) {
44 return mlir::cast<InlinerCallback>(Val&: pass).runPipeline(pipeline, op);
45 }
46
47 // Customize the implementation of Inliner::doClone
48 // Wrap the callee into scf.execute_region operation
49 static void testDoClone(OpBuilder &builder, Region *src, Block *inlineBlock,
50 Block *postInsertBlock, IRMapping &mapper,
51 bool shouldCloneInlinedRegion) {
52 // Create a new scf.execute_region operation
53 mlir::Operation &call = inlineBlock->back();
54 builder.setInsertionPointAfter(&call);
55
56 auto executeRegionOp = builder.create<mlir::scf::ExecuteRegionOp>(
57 call.getLoc(), call.getResultTypes());
58 mlir::Region &region = executeRegionOp.getRegion();
59
60 // Move the inlined blocks into the region
61 src->cloneInto(dest: &region, mapper);
62
63 // Split block before scf operation.
64 inlineBlock->splitBlock(executeRegionOp.getOperation());
65
66 // Replace all test.return with scf.yield
67 for (mlir::Block &block : region) {
68
69 for (mlir::Operation &op : llvm::make_early_inc_range(block)) {
70 if (test::TestReturnOp returnOp =
71 llvm::dyn_cast<test::TestReturnOp>(&op)) {
72 mlir::OpBuilder returnBuilder(returnOp);
73 returnBuilder.create<mlir::scf::YieldOp>(returnOp.getLoc(),
74 returnOp.getOperands());
75 returnOp.erase();
76 }
77 }
78 }
79
80 // Add test.return after scf.execute_region
81 builder.setInsertionPointAfter(executeRegionOp);
82 builder.create<test::TestReturnOp>(executeRegionOp.getLoc(),
83 executeRegionOp.getResults());
84 }
85
86 void runOnOperation() override {
87 InlinerConfig config;
88 CallGraph &cg = getAnalysis<CallGraph>();
89
90 func::FuncOp function = getOperation();
91
92 // By default, assume that any inlining is profitable.
93 auto profitabilityCb = [&](const mlir::Inliner::ResolvedCall &call) {
94 return true;
95 };
96
97 // Set the clone callback in the config
98 config.setCloneCallback([](OpBuilder &builder, Region *src,
99 Block *inlineBlock, Block *postInsertBlock,
100 IRMapping &mapper,
101 bool shouldCloneInlinedRegion) {
102 return testDoClone(builder, src, inlineBlock, postInsertBlock, mapper,
103 shouldCloneInlinedRegion);
104 });
105
106 // Set canHandleMultipleBlocks to true in the config
107 config.setCanHandleMultipleBlocks();
108
109 // Get an instance of the inliner.
110 Inliner inliner(function, cg, *this, getAnalysisManager(),
111 runPipelineHelper, config, profitabilityCb);
112
113 // Collect each of the direct function calls within the module.
114 SmallVector<func::CallIndirectOp> callers;
115 function.walk(
116 [&](func::CallIndirectOp caller) { callers.push_back(caller); });
117
118 // Build the inliner interface.
119 InlinerInterface interface(&getContext());
120
121 // Try to inline each of the call operations.
122 for (auto caller : callers) {
123 auto callee = dyn_cast_or_null<FunctionalRegionOp>(
124 caller.getCallee().getDefiningOp());
125 if (!callee)
126 continue;
127
128 // Inline the functional region operation, but only clone the internal
129 // region if there is more than one use.
130 if (failed(inlineRegion(
131 interface, config.getCloneCallback(), &callee.getBody(), caller,
132 caller.getArgOperands(), caller.getResults(), caller.getLoc(),
133 /*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse())))
134 continue;
135
136 // If the inlining was successful then erase the call and callee if
137 // possible.
138 caller.erase();
139 if (callee.use_empty())
140 callee.erase();
141 }
142 }
143};
144} // namespace
145
146namespace mlir {
147namespace test {
148void registerInlinerCallback() { PassRegistration<InlinerCallback>(); }
149} // namespace test
150} // namespace mlir
151

source code of mlir/test/lib/Transforms/TestInliningCallback.cpp