1 | //===- FuncConversions.cpp - Function conversions -------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
10 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
11 | #include "mlir/Transforms/DialectConversion.h" |
12 | |
13 | using namespace mlir; |
14 | using namespace mlir::func; |
15 | |
16 | namespace { |
17 | /// Converts the operand and result types of the CallOp, used together with the |
18 | /// FuncOpSignatureConversion. |
19 | struct CallOpSignatureConversion : public OpConversionPattern<CallOp> { |
20 | using OpConversionPattern<CallOp>::OpConversionPattern; |
21 | |
22 | /// Hook for derived classes to implement combined matching and rewriting. |
23 | LogicalResult |
24 | matchAndRewrite(CallOp callOp, OpAdaptor adaptor, |
25 | ConversionPatternRewriter &rewriter) const override { |
26 | // Convert the original function results. |
27 | SmallVector<Type, 1> convertedResults; |
28 | if (failed(typeConverter->convertTypes(callOp.getResultTypes(), |
29 | convertedResults))) |
30 | return failure(); |
31 | |
32 | // If this isn't a one-to-one type mapping, we don't know how to aggregate |
33 | // the results. |
34 | if (callOp->getNumResults() != convertedResults.size()) |
35 | return failure(); |
36 | |
37 | // Substitute with the new result types from the corresponding FuncType |
38 | // conversion. |
39 | rewriter.replaceOpWithNewOp<CallOp>( |
40 | callOp, callOp.getCallee(), convertedResults, adaptor.getOperands()); |
41 | return success(); |
42 | } |
43 | }; |
44 | } // namespace |
45 | |
46 | void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns, |
47 | TypeConverter &converter) { |
48 | patterns.add<CallOpSignatureConversion>(arg&: converter, args: patterns.getContext()); |
49 | } |
50 | |
51 | namespace { |
52 | /// Only needed to support partial conversion of functions where this pattern |
53 | /// ensures that the branch operation arguments matches up with the succesor |
54 | /// block arguments. |
55 | class BranchOpInterfaceTypeConversion |
56 | : public OpInterfaceConversionPattern<BranchOpInterface> { |
57 | public: |
58 | using OpInterfaceConversionPattern< |
59 | BranchOpInterface>::OpInterfaceConversionPattern; |
60 | |
61 | BranchOpInterfaceTypeConversion( |
62 | TypeConverter &typeConverter, MLIRContext *ctx, |
63 | function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) |
64 | : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1), |
65 | shouldConvertBranchOperand(shouldConvertBranchOperand) {} |
66 | |
67 | LogicalResult |
68 | matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands, |
69 | ConversionPatternRewriter &rewriter) const final { |
70 | // For a branch operation, only some operands go to the target blocks, so |
71 | // only rewrite those. |
72 | SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end()); |
73 | for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors(); |
74 | succIdx < succEnd; ++succIdx) { |
75 | OperandRange forwardedOperands = |
76 | op.getSuccessorOperands(succIdx).getForwardedOperands(); |
77 | if (forwardedOperands.empty()) |
78 | continue; |
79 | |
80 | for (int idx = forwardedOperands.getBeginOperandIndex(), |
81 | eidx = idx + forwardedOperands.size(); |
82 | idx < eidx; ++idx) { |
83 | if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx)) |
84 | newOperands[idx] = operands[idx]; |
85 | } |
86 | } |
87 | rewriter.modifyOpInPlace( |
88 | op, [newOperands, op]() { op->setOperands(newOperands); }); |
89 | return success(); |
90 | } |
91 | |
92 | private: |
93 | function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand; |
94 | }; |
95 | } // namespace |
96 | |
97 | namespace { |
98 | /// Only needed to support partial conversion of functions where this pattern |
99 | /// ensures that the branch operation arguments matches up with the succesor |
100 | /// block arguments. |
101 | class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> { |
102 | public: |
103 | using OpConversionPattern<ReturnOp>::OpConversionPattern; |
104 | |
105 | LogicalResult |
106 | matchAndRewrite(ReturnOp op, OpAdaptor adaptor, |
107 | ConversionPatternRewriter &rewriter) const final { |
108 | // For a return, all operands go to the results of the parent, so |
109 | // rewrite them all. |
110 | rewriter.modifyOpInPlace(op, |
111 | [&] { op->setOperands(adaptor.getOperands()); }); |
112 | return success(); |
113 | } |
114 | }; |
115 | } // namespace |
116 | |
117 | void mlir::populateBranchOpInterfaceTypeConversionPattern( |
118 | RewritePatternSet &patterns, TypeConverter &typeConverter, |
119 | function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) { |
120 | patterns.add<BranchOpInterfaceTypeConversion>( |
121 | arg&: typeConverter, args: patterns.getContext(), args&: shouldConvertBranchOperand); |
122 | } |
123 | |
124 | bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern( |
125 | Operation *op, TypeConverter &converter) { |
126 | // All successor operands of branch like operations must be rewritten. |
127 | if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { |
128 | for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) { |
129 | auto successorOperands = branchOp.getSuccessorOperands(p); |
130 | if (!converter.isLegal( |
131 | successorOperands.getForwardedOperands().getTypes())) |
132 | return false; |
133 | } |
134 | return true; |
135 | } |
136 | |
137 | return false; |
138 | } |
139 | |
140 | void mlir::populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, |
141 | TypeConverter &typeConverter) { |
142 | patterns.add<ReturnOpTypeConversion>(arg&: typeConverter, args: patterns.getContext()); |
143 | } |
144 | |
145 | bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op, |
146 | TypeConverter &converter, |
147 | bool returnOpAlwaysLegal) { |
148 | // If this is a `return` and the user pass wants to convert/transform across |
149 | // function boundaries, then `converter` is invoked to check whether the |
150 | // `return` op is legal. |
151 | if (isa<ReturnOp>(op) && !returnOpAlwaysLegal) |
152 | return converter.isLegal(op); |
153 | |
154 | // ReturnLike operations have to be legalized with their parent. For |
155 | // return this is handled, for other ops they remain as is. |
156 | return op->hasTrait<OpTrait::ReturnLike>(); |
157 | } |
158 | |
159 | bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) { |
160 | // If it is not a terminator, ignore it. |
161 | if (!op->mightHaveTrait<OpTrait::IsTerminator>()) |
162 | return true; |
163 | |
164 | // If it is not the last operation in the block, also ignore it. We do |
165 | // this to handle unknown operations, as well. |
166 | Block *block = op->getBlock(); |
167 | if (!block || &block->back() != op) |
168 | return true; |
169 | |
170 | // We don't want to handle terminators in nested regions, assume they are |
171 | // always legal. |
172 | if (!isa_and_nonnull<FuncOp>(op->getParentOp())) |
173 | return true; |
174 | |
175 | return false; |
176 | } |
177 | |