1 | //===- InlinerExtension.cpp - Func Inliner Extension ----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" |
10 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
11 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
12 | #include "mlir/IR/DialectInterface.h" |
13 | #include "mlir/Transforms/InliningUtils.h" |
14 | |
15 | using namespace mlir; |
16 | using namespace mlir::func; |
17 | |
18 | //===----------------------------------------------------------------------===// |
19 | // FuncDialect Interfaces |
20 | //===----------------------------------------------------------------------===// |
21 | namespace { |
22 | /// This class defines the interface for handling inlining with func operations. |
23 | struct FuncInlinerInterface : public DialectInlinerInterface { |
24 | using DialectInlinerInterface::DialectInlinerInterface; |
25 | |
26 | //===--------------------------------------------------------------------===// |
27 | // Analysis Hooks |
28 | //===--------------------------------------------------------------------===// |
29 | |
30 | /// Call operations can be inlined unless specified otherwise by attributes |
31 | /// on either the call or the callbale. |
32 | bool isLegalToInline(Operation *call, Operation *callable, |
33 | bool wouldBeCloned) const final { |
34 | auto callOp = dyn_cast<func::CallOp>(call); |
35 | auto funcOp = dyn_cast<func::FuncOp>(callable); |
36 | return !(callOp && callOp.getNoInline()) && |
37 | !(funcOp && funcOp.getNoInline()); |
38 | } |
39 | |
40 | /// All operations can be inlined. |
41 | bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { |
42 | return true; |
43 | } |
44 | |
45 | /// All function bodies can be inlined. |
46 | bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { |
47 | return true; |
48 | } |
49 | |
50 | //===--------------------------------------------------------------------===// |
51 | // Transformation Hooks |
52 | //===--------------------------------------------------------------------===// |
53 | |
54 | /// Handle the given inlined terminator by replacing it with a new operation |
55 | /// as necessary. |
56 | void handleTerminator(Operation *op, Block *newDest) const final { |
57 | // Only return needs to be handled here. |
58 | auto returnOp = dyn_cast<ReturnOp>(op); |
59 | if (!returnOp) |
60 | return; |
61 | |
62 | // Replace the return with a branch to the dest. |
63 | OpBuilder builder(op); |
64 | builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands()); |
65 | op->erase(); |
66 | } |
67 | |
68 | /// Handle the given inlined terminator by replacing it with a new operation |
69 | /// as necessary. |
70 | void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { |
71 | // Only return needs to be handled here. |
72 | auto returnOp = cast<ReturnOp>(op); |
73 | |
74 | // Replace the values directly with the return operands. |
75 | assert(returnOp.getNumOperands() == valuesToRepl.size()); |
76 | for (const auto &it : llvm::enumerate(returnOp.getOperands())) |
77 | valuesToRepl[it.index()].replaceAllUsesWith(it.value()); |
78 | } |
79 | }; |
80 | } // namespace |
81 | |
82 | //===----------------------------------------------------------------------===// |
83 | // Registration |
84 | //===----------------------------------------------------------------------===// |
85 | |
86 | void mlir::func::registerInlinerExtension(DialectRegistry ®istry) { |
87 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, func::FuncDialect *dialect) { |
88 | dialect->addInterfaces<FuncInlinerInterface>(); |
89 | |
90 | // The inliner extension relies on the ControlFlow dialect. |
91 | ctx->getOrLoadDialect<cf::ControlFlowDialect>(); |
92 | }); |
93 | } |
94 | |