//===- FuncConversions.cpp - Function conversions -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
10#include "mlir/Dialect/Func/IR/FuncOps.h"
11#include "mlir/Transforms/DialectConversion.h"
12
13using namespace mlir;
14using namespace mlir::func;
15
16namespace {
17/// Converts the operand and result types of the CallOp, used together with the
18/// FuncOpSignatureConversion.
19struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
20 using OpConversionPattern<CallOp>::OpConversionPattern;
21
22 /// Hook for derived classes to implement combined matching and rewriting.
23 LogicalResult
24 matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
25 ConversionPatternRewriter &rewriter) const override {
26 // Convert the original function results.
27 SmallVector<Type, 1> convertedResults;
28 if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
29 convertedResults)))
30 return failure();
31
32 // If this isn't a one-to-one type mapping, we don't know how to aggregate
33 // the results.
34 if (callOp->getNumResults() != convertedResults.size())
35 return failure();
36
37 // Substitute with the new result types from the corresponding FuncType
38 // conversion.
39 rewriter.replaceOpWithNewOp<CallOp>(
40 callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
41 return success();
42 }
43};
44} // namespace
45
46void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
47 TypeConverter &converter) {
48 patterns.add<CallOpSignatureConversion>(arg&: converter, args: patterns.getContext());
49}
50
51namespace {
52/// Only needed to support partial conversion of functions where this pattern
53/// ensures that the branch operation arguments matches up with the succesor
54/// block arguments.
55class BranchOpInterfaceTypeConversion
56 : public OpInterfaceConversionPattern<BranchOpInterface> {
57public:
58 using OpInterfaceConversionPattern<
59 BranchOpInterface>::OpInterfaceConversionPattern;
60
61 BranchOpInterfaceTypeConversion(
62 TypeConverter &typeConverter, MLIRContext *ctx,
63 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand)
64 : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
65 shouldConvertBranchOperand(shouldConvertBranchOperand) {}
66
67 LogicalResult
68 matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
69 ConversionPatternRewriter &rewriter) const final {
70 // For a branch operation, only some operands go to the target blocks, so
71 // only rewrite those.
72 SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
73 for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
74 succIdx < succEnd; ++succIdx) {
75 OperandRange forwardedOperands =
76 op.getSuccessorOperands(succIdx).getForwardedOperands();
77 if (forwardedOperands.empty())
78 continue;
79
80 for (int idx = forwardedOperands.getBeginOperandIndex(),
81 eidx = idx + forwardedOperands.size();
82 idx < eidx; ++idx) {
83 if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
84 newOperands[idx] = operands[idx];
85 }
86 }
87 rewriter.modifyOpInPlace(
88 op, [newOperands, op]() { op->setOperands(newOperands); });
89 return success();
90 }
91
92private:
93 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand;
94};
95} // namespace
96
97namespace {
98/// Only needed to support partial conversion of functions where this pattern
99/// ensures that the branch operation arguments matches up with the succesor
100/// block arguments.
101class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
102public:
103 using OpConversionPattern<ReturnOp>::OpConversionPattern;
104
105 LogicalResult
106 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
107 ConversionPatternRewriter &rewriter) const final {
108 // For a return, all operands go to the results of the parent, so
109 // rewrite them all.
110 rewriter.modifyOpInPlace(op,
111 [&] { op->setOperands(adaptor.getOperands()); });
112 return success();
113 }
114};
115} // namespace
116
117void mlir::populateBranchOpInterfaceTypeConversionPattern(
118 RewritePatternSet &patterns, TypeConverter &typeConverter,
119 function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) {
120 patterns.add<BranchOpInterfaceTypeConversion>(
121 arg&: typeConverter, args: patterns.getContext(), args&: shouldConvertBranchOperand);
122}
123
124bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
125 Operation *op, TypeConverter &converter) {
126 // All successor operands of branch like operations must be rewritten.
127 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
128 for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
129 auto successorOperands = branchOp.getSuccessorOperands(p);
130 if (!converter.isLegal(
131 successorOperands.getForwardedOperands().getTypes()))
132 return false;
133 }
134 return true;
135 }
136
137 return false;
138}
139
140void mlir::populateReturnOpTypeConversionPattern(RewritePatternSet &patterns,
141 TypeConverter &typeConverter) {
142 patterns.add<ReturnOpTypeConversion>(arg&: typeConverter, args: patterns.getContext());
143}
144
145bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op,
146 TypeConverter &converter,
147 bool returnOpAlwaysLegal) {
148 // If this is a `return` and the user pass wants to convert/transform across
149 // function boundaries, then `converter` is invoked to check whether the
150 // `return` op is legal.
151 if (isa<ReturnOp>(op) && !returnOpAlwaysLegal)
152 return converter.isLegal(op);
153
154 // ReturnLike operations have to be legalized with their parent. For
155 // return this is handled, for other ops they remain as is.
156 return op->hasTrait<OpTrait::ReturnLike>();
157}
158
159bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) {
160 // If it is not a terminator, ignore it.
161 if (!op->mightHaveTrait<OpTrait::IsTerminator>())
162 return true;
163
164 // If it is not the last operation in the block, also ignore it. We do
165 // this to handle unknown operations, as well.
166 Block *block = op->getBlock();
167 if (!block || &block->back() != op)
168 return true;
169
170 // We don't want to handle terminators in nested regions, assume they are
171 // always legal.
172 if (!isa_and_nonnull<FuncOp>(op->getParentOp()))
173 return true;
174
175 return false;
176}
177

// Source: mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp