1 | //===- FuncConversions.cpp - Function conversions -------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
10 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
11 | #include "mlir/Transforms/DialectConversion.h" |
12 | |
13 | using namespace mlir; |
14 | using namespace mlir::func; |
15 | |
16 | /// Flatten the given value ranges into a single vector of values. |
17 | static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { |
18 | SmallVector<Value> result; |
19 | for (const auto &vals : values) |
20 | llvm::append_range(C&: result, R: vals); |
21 | return result; |
22 | } |
23 | |
24 | namespace { |
25 | /// Converts the operand and result types of the CallOp, used together with the |
26 | /// FuncOpSignatureConversion. |
27 | struct CallOpSignatureConversion : public OpConversionPattern<CallOp> { |
28 | using OpConversionPattern<CallOp>::OpConversionPattern; |
29 | |
30 | /// Hook for derived classes to implement combined matching and rewriting. |
31 | LogicalResult |
32 | matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor, |
33 | ConversionPatternRewriter &rewriter) const override { |
34 | // Convert the original function results. Keep track of how many result |
35 | // types an original result type is converted into. |
36 | SmallVector<size_t> numResultsReplacments; |
37 | SmallVector<Type, 1> convertedResults; |
38 | size_t numFlattenedResults = 0; |
39 | for (auto [idx, type] : llvm::enumerate(callOp.getResultTypes())) { |
40 | if (failed(typeConverter->convertTypes(type, convertedResults))) |
41 | return failure(); |
42 | numResultsReplacments.push_back(convertedResults.size() - |
43 | numFlattenedResults); |
44 | numFlattenedResults = convertedResults.size(); |
45 | } |
46 | |
47 | // Substitute with the new result types from the corresponding FuncType |
48 | // conversion. |
49 | auto newCallOp = rewriter.create<CallOp>( |
50 | callOp.getLoc(), callOp.getCallee(), convertedResults, |
51 | flattenValues(adaptor.getOperands())); |
52 | SmallVector<ValueRange> replacements; |
53 | size_t offset = 0; |
54 | for (int i = 0, e = callOp->getNumResults(); i < e; ++i) { |
55 | replacements.push_back( |
56 | Elt: newCallOp->getResults().slice(offset, numResultsReplacments[i])); |
57 | offset += numResultsReplacments[i]; |
58 | } |
59 | assert(offset == convertedResults.size() && |
60 | "expected that all converted results are used"); |
61 | rewriter.replaceOpWithMultiple(callOp, replacements); |
62 | return success(); |
63 | } |
64 | }; |
65 | } // namespace |
66 | |
67 | void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns, |
68 | const TypeConverter &converter) { |
69 | patterns.add<CallOpSignatureConversion>(arg: converter, args: patterns.getContext()); |
70 | } |
71 | |
72 | namespace { |
73 | /// Only needed to support partial conversion of functions where this pattern |
74 | /// ensures that the branch operation arguments matches up with the succesor |
75 | /// block arguments. |
76 | class BranchOpInterfaceTypeConversion |
77 | : public OpInterfaceConversionPattern<BranchOpInterface> { |
78 | public: |
79 | using OpInterfaceConversionPattern< |
80 | BranchOpInterface>::OpInterfaceConversionPattern; |
81 | |
82 | BranchOpInterfaceTypeConversion( |
83 | const TypeConverter &typeConverter, MLIRContext *ctx, |
84 | function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) |
85 | : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1), |
86 | shouldConvertBranchOperand(shouldConvertBranchOperand) {} |
87 | |
88 | LogicalResult |
89 | matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands, |
90 | ConversionPatternRewriter &rewriter) const final { |
91 | // For a branch operation, only some operands go to the target blocks, so |
92 | // only rewrite those. |
93 | SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end()); |
94 | for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors(); |
95 | succIdx < succEnd; ++succIdx) { |
96 | OperandRange forwardedOperands = |
97 | op.getSuccessorOperands(succIdx).getForwardedOperands(); |
98 | if (forwardedOperands.empty()) |
99 | continue; |
100 | |
101 | for (int idx = forwardedOperands.getBeginOperandIndex(), |
102 | eidx = idx + forwardedOperands.size(); |
103 | idx < eidx; ++idx) { |
104 | if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx)) |
105 | newOperands[idx] = operands[idx]; |
106 | } |
107 | } |
108 | rewriter.modifyOpInPlace( |
109 | op, [newOperands, op]() { op->setOperands(newOperands); }); |
110 | return success(); |
111 | } |
112 | |
113 | private: |
114 | function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand; |
115 | }; |
116 | } // namespace |
117 | |
118 | namespace { |
119 | /// Only needed to support partial conversion of functions where this pattern |
120 | /// ensures that the branch operation arguments matches up with the succesor |
121 | /// block arguments. |
122 | class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> { |
123 | public: |
124 | using OpConversionPattern<ReturnOp>::OpConversionPattern; |
125 | |
126 | LogicalResult |
127 | matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor, |
128 | ConversionPatternRewriter &rewriter) const final { |
129 | rewriter.replaceOpWithNewOp<ReturnOp>(op, |
130 | flattenValues(adaptor.getOperands())); |
131 | return success(); |
132 | } |
133 | }; |
134 | } // namespace |
135 | |
136 | void mlir::populateBranchOpInterfaceTypeConversionPattern( |
137 | RewritePatternSet &patterns, const TypeConverter &typeConverter, |
138 | function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) { |
139 | patterns.add<BranchOpInterfaceTypeConversion>( |
140 | arg: typeConverter, args: patterns.getContext(), args&: shouldConvertBranchOperand); |
141 | } |
142 | |
143 | bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern( |
144 | Operation *op, const TypeConverter &converter) { |
145 | // All successor operands of branch like operations must be rewritten. |
146 | if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { |
147 | for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) { |
148 | auto successorOperands = branchOp.getSuccessorOperands(p); |
149 | if (!converter.isLegal( |
150 | successorOperands.getForwardedOperands().getTypes())) |
151 | return false; |
152 | } |
153 | return true; |
154 | } |
155 | |
156 | return false; |
157 | } |
158 | |
159 | void mlir::populateReturnOpTypeConversionPattern( |
160 | RewritePatternSet &patterns, const TypeConverter &typeConverter) { |
161 | patterns.add<ReturnOpTypeConversion>(arg: typeConverter, args: patterns.getContext()); |
162 | } |
163 | |
164 | bool mlir::isLegalForReturnOpTypeConversionPattern( |
165 | Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal) { |
166 | // If this is a `return` and the user pass wants to convert/transform across |
167 | // function boundaries, then `converter` is invoked to check whether the |
168 | // `return` op is legal. |
169 | if (isa<ReturnOp>(op) && !returnOpAlwaysLegal) |
170 | return converter.isLegal(op); |
171 | |
172 | // ReturnLike operations have to be legalized with their parent. For |
173 | // return this is handled, for other ops they remain as is. |
174 | return op->hasTrait<OpTrait::ReturnLike>(); |
175 | } |
176 | |
177 | bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) { |
178 | // If it is not a terminator, ignore it. |
179 | if (!op->mightHaveTrait<OpTrait::IsTerminator>()) |
180 | return true; |
181 | |
182 | // If it is not the last operation in the block, also ignore it. We do |
183 | // this to handle unknown operations, as well. |
184 | Block *block = op->getBlock(); |
185 | if (!block || &block->back() != op) |
186 | return true; |
187 | |
188 | // We don't want to handle terminators in nested regions, assume they are |
189 | // always legal. |
190 | if (!isa_and_nonnull<FuncOp>(op->getParentOp())) |
191 | return true; |
192 | |
193 | return false; |
194 | } |
195 |
Definitions
- flattenValues
- CallOpSignatureConversion
- matchAndRewrite
- populateCallOpTypeConversionPattern
- BranchOpInterfaceTypeConversion
- BranchOpInterfaceTypeConversion
- matchAndRewrite
- ReturnOpTypeConversion
- matchAndRewrite
- populateBranchOpInterfaceTypeConversionPattern
- isLegalForBranchOpInterfaceTypeConversionPattern
- populateReturnOpTypeConversionPattern
- isLegalForReturnOpTypeConversionPattern
Learn to use CMake with our Intro Training
Find out more