1//===- FuncConversions.cpp - Function conversions -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
10#include "mlir/Dialect/Func/IR/FuncOps.h"
11#include "mlir/Transforms/DialectConversion.h"
12
13using namespace mlir;
14using namespace mlir::func;
15
16/// Flatten the given value ranges into a single vector of values.
17static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
18 SmallVector<Value> result;
19 for (const auto &vals : values)
20 llvm::append_range(C&: result, R: vals);
21 return result;
22}
23
24namespace {
25/// Converts the operand and result types of the CallOp, used together with the
26/// FuncOpSignatureConversion.
27struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
28 using OpConversionPattern<CallOp>::OpConversionPattern;
29
30 /// Hook for derived classes to implement combined matching and rewriting.
31 LogicalResult
32 matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
33 ConversionPatternRewriter &rewriter) const override {
34 // Convert the original function results. Keep track of how many result
35 // types an original result type is converted into.
36 SmallVector<size_t> numResultsReplacments;
37 SmallVector<Type, 1> convertedResults;
38 size_t numFlattenedResults = 0;
39 for (auto [idx, type] : llvm::enumerate(callOp.getResultTypes())) {
40 if (failed(typeConverter->convertTypes(type, convertedResults)))
41 return failure();
42 numResultsReplacments.push_back(convertedResults.size() -
43 numFlattenedResults);
44 numFlattenedResults = convertedResults.size();
45 }
46
47 // Substitute with the new result types from the corresponding FuncType
48 // conversion.
49 auto newCallOp = rewriter.create<CallOp>(
50 callOp.getLoc(), callOp.getCallee(), convertedResults,
51 flattenValues(adaptor.getOperands()));
52 SmallVector<ValueRange> replacements;
53 size_t offset = 0;
54 for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {
55 replacements.push_back(
56 Elt: newCallOp->getResults().slice(offset, numResultsReplacments[i]));
57 offset += numResultsReplacments[i];
58 }
59 assert(offset == convertedResults.size() &&
60 "expected that all converted results are used");
61 rewriter.replaceOpWithMultiple(callOp, replacements);
62 return success();
63 }
64};
65} // namespace
66
67void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
68 const TypeConverter &converter) {
69 patterns.add<CallOpSignatureConversion>(arg: converter, args: patterns.getContext());
70}
71
72namespace {
73/// Only needed to support partial conversion of functions where this pattern
74/// ensures that the branch operation arguments matches up with the succesor
75/// block arguments.
76class BranchOpInterfaceTypeConversion
77 : public OpInterfaceConversionPattern<BranchOpInterface> {
78public:
79 using OpInterfaceConversionPattern<
80 BranchOpInterface>::OpInterfaceConversionPattern;
81
82 BranchOpInterfaceTypeConversion(
83 const TypeConverter &typeConverter, MLIRContext *ctx,
84 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand)
85 : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
86 shouldConvertBranchOperand(shouldConvertBranchOperand) {}
87
88 LogicalResult
89 matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
90 ConversionPatternRewriter &rewriter) const final {
91 // For a branch operation, only some operands go to the target blocks, so
92 // only rewrite those.
93 SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
94 for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
95 succIdx < succEnd; ++succIdx) {
96 OperandRange forwardedOperands =
97 op.getSuccessorOperands(succIdx).getForwardedOperands();
98 if (forwardedOperands.empty())
99 continue;
100
101 for (int idx = forwardedOperands.getBeginOperandIndex(),
102 eidx = idx + forwardedOperands.size();
103 idx < eidx; ++idx) {
104 if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
105 newOperands[idx] = operands[idx];
106 }
107 }
108 rewriter.modifyOpInPlace(
109 op, [newOperands, op]() { op->setOperands(newOperands); });
110 return success();
111 }
112
113private:
114 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand;
115};
116} // namespace
117
118namespace {
119/// Only needed to support partial conversion of functions where this pattern
120/// ensures that the branch operation arguments matches up with the succesor
121/// block arguments.
122class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
123public:
124 using OpConversionPattern<ReturnOp>::OpConversionPattern;
125
126 LogicalResult
127 matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
128 ConversionPatternRewriter &rewriter) const final {
129 rewriter.replaceOpWithNewOp<ReturnOp>(op,
130 flattenValues(adaptor.getOperands()));
131 return success();
132 }
133};
134} // namespace
135
136void mlir::populateBranchOpInterfaceTypeConversionPattern(
137 RewritePatternSet &patterns, const TypeConverter &typeConverter,
138 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) {
139 patterns.add<BranchOpInterfaceTypeConversion>(
140 arg: typeConverter, args: patterns.getContext(), args&: shouldConvertBranchOperand);
141}
142
143bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
144 Operation *op, const TypeConverter &converter) {
145 // All successor operands of branch like operations must be rewritten.
146 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
147 for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
148 auto successorOperands = branchOp.getSuccessorOperands(p);
149 if (!converter.isLegal(
150 successorOperands.getForwardedOperands().getTypes()))
151 return false;
152 }
153 return true;
154 }
155
156 return false;
157}
158
159void mlir::populateReturnOpTypeConversionPattern(
160 RewritePatternSet &patterns, const TypeConverter &typeConverter) {
161 patterns.add<ReturnOpTypeConversion>(arg: typeConverter, args: patterns.getContext());
162}
163
164bool mlir::isLegalForReturnOpTypeConversionPattern(
165 Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal) {
166 // If this is a `return` and the user pass wants to convert/transform across
167 // function boundaries, then `converter` is invoked to check whether the
168 // `return` op is legal.
169 if (isa<ReturnOp>(op) && !returnOpAlwaysLegal)
170 return converter.isLegal(op);
171
172 // ReturnLike operations have to be legalized with their parent. For
173 // return this is handled, for other ops they remain as is.
174 return op->hasTrait<OpTrait::ReturnLike>();
175}
176
177bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) {
178 // If it is not a terminator, ignore it.
179 if (!op->mightHaveTrait<OpTrait::IsTerminator>())
180 return true;
181
182 // If it is not the last operation in the block, also ignore it. We do
183 // this to handle unknown operations, as well.
184 Block *block = op->getBlock();
185 if (!block || &block->back() != op)
186 return true;
187
188 // We don't want to handle terminators in nested regions, assume they are
189 // always legal.
190 if (!isa_and_nonnull<FuncOp>(op->getParentOp()))
191 return true;
192
193 return false;
194}
195

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp