1//===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
10#include "mlir/Dialect/Func/IR/FuncOps.h"
11#include "mlir/IR/BuiltinOps.h"
12
13using namespace mlir;
14using namespace mlir::func;
15
16//===----------------------------------------------------------------------===//
17// ValueDecomposer
18//===----------------------------------------------------------------------===//
19
20void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc,
21 Type type, Value value,
22 SmallVectorImpl<Value> &results) {
23 for (auto &conversion : decomposeValueConversions)
24 if (conversion(builder, loc, type, value, results))
25 return;
26 results.push_back(Elt: value);
27}
28
29//===----------------------------------------------------------------------===//
30// DecomposeCallGraphTypesOpConversionPattern
31//===----------------------------------------------------------------------===//
32
33namespace {
34/// Base OpConversionPattern class to make a ValueDecomposer available to
35/// inherited patterns.
36template <typename SourceOp>
37class DecomposeCallGraphTypesOpConversionPattern
38 : public OpConversionPattern<SourceOp> {
39public:
40 DecomposeCallGraphTypesOpConversionPattern(TypeConverter &typeConverter,
41 MLIRContext *context,
42 ValueDecomposer &decomposer,
43 PatternBenefit benefit = 1)
44 : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
45 decomposer(decomposer) {}
46
47protected:
48 ValueDecomposer &decomposer;
49};
50} // namespace
51
52//===----------------------------------------------------------------------===//
53// DecomposeCallGraphTypesForFuncArgs
54//===----------------------------------------------------------------------===//
55
56namespace {
57/// Expand function arguments according to the provided TypeConverter and
58/// ValueDecomposer.
59struct DecomposeCallGraphTypesForFuncArgs
60 : public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
61 using DecomposeCallGraphTypesOpConversionPattern::
62 DecomposeCallGraphTypesOpConversionPattern;
63
64 LogicalResult
65 matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
66 ConversionPatternRewriter &rewriter) const final {
67 auto functionType = op.getFunctionType();
68
69 // Convert function arguments using the provided TypeConverter.
70 TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
71 for (const auto &argType : llvm::enumerate(functionType.getInputs())) {
72 SmallVector<Type, 2> decomposedTypes;
73 if (failed(typeConverter->convertType(argType.value(), decomposedTypes)))
74 return failure();
75 if (!decomposedTypes.empty())
76 conversion.addInputs(argType.index(), decomposedTypes);
77 }
78
79 // If the SignatureConversion doesn't apply, bail out.
80 if (failed(rewriter.convertRegionTypes(region: &op.getBody(), converter: *getTypeConverter(),
81 entryConversion: &conversion)))
82 return failure();
83
84 // Update the signature of the function.
85 SmallVector<Type, 2> newResultTypes;
86 if (failed(typeConverter->convertTypes(functionType.getResults(),
87 newResultTypes)))
88 return failure();
89 rewriter.modifyOpInPlace(op, [&] {
90 op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
91 newResultTypes));
92 });
93 return success();
94 }
95};
96} // namespace
97
98//===----------------------------------------------------------------------===//
99// DecomposeCallGraphTypesForReturnOp
100//===----------------------------------------------------------------------===//
101
102namespace {
103/// Expand return operands according to the provided TypeConverter and
104/// ValueDecomposer.
105struct DecomposeCallGraphTypesForReturnOp
106 : public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
107 using DecomposeCallGraphTypesOpConversionPattern::
108 DecomposeCallGraphTypesOpConversionPattern;
109 LogicalResult
110 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
111 ConversionPatternRewriter &rewriter) const final {
112 SmallVector<Value, 2> newOperands;
113 for (Value operand : adaptor.getOperands())
114 decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
115 operand, newOperands);
116 rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
117 return success();
118 }
119};
120} // namespace
121
122//===----------------------------------------------------------------------===//
123// DecomposeCallGraphTypesForCallOp
124//===----------------------------------------------------------------------===//
125
126namespace {
127/// Expand call op operands and results according to the provided TypeConverter
128/// and ValueDecomposer.
129struct DecomposeCallGraphTypesForCallOp
130 : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
131 using DecomposeCallGraphTypesOpConversionPattern::
132 DecomposeCallGraphTypesOpConversionPattern;
133
134 LogicalResult
135 matchAndRewrite(CallOp op, OpAdaptor adaptor,
136 ConversionPatternRewriter &rewriter) const final {
137
138 // Create the operands list of the new `CallOp`.
139 SmallVector<Value, 2> newOperands;
140 for (Value operand : adaptor.getOperands())
141 decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
142 operand, newOperands);
143
144 // Create the new result types for the new `CallOp` and track the indices in
145 // the new call op's results that correspond to the old call op's results.
146 //
147 // expandedResultIndices[i] = "list of new result indices that old result i
148 // expanded to".
149 SmallVector<Type, 2> newResultTypes;
150 SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
151 for (Type resultType : op.getResultTypes()) {
152 unsigned oldSize = newResultTypes.size();
153 if (failed(typeConverter->convertType(resultType, newResultTypes)))
154 return failure();
155 auto &resultMapping = expandedResultIndices.emplace_back();
156 for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
157 resultMapping.push_back(i);
158 }
159
160 CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
161 newResultTypes, newOperands);
162
163 // Build a replacement value for each result to replace its uses. If a
164 // result has multiple mapping values, it needs to be materialized as a
165 // single value.
166 SmallVector<Value, 2> replacedValues;
167 replacedValues.reserve(N: op.getNumResults());
168 for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
169 auto decomposedValues = llvm::to_vector<6>(
170 llvm::map_range(expandedResultIndices[i],
171 [&](unsigned i) { return newCallOp.getResult(i); }));
172 if (decomposedValues.empty()) {
173 // No replacement is required.
174 replacedValues.push_back(Elt: nullptr);
175 } else if (decomposedValues.size() == 1) {
176 replacedValues.push_back(Elt: decomposedValues.front());
177 } else {
178 // Materialize a single Value to replace the original Value.
179 Value materialized = getTypeConverter()->materializeArgumentConversion(
180 rewriter, op.getLoc(), op.getType(i), decomposedValues);
181 replacedValues.push_back(Elt: materialized);
182 }
183 }
184 rewriter.replaceOp(op, replacedValues);
185 return success();
186 }
187};
188} // namespace
189
190void mlir::populateDecomposeCallGraphTypesPatterns(
191 MLIRContext *context, TypeConverter &typeConverter,
192 ValueDecomposer &decomposer, RewritePatternSet &patterns) {
193 patterns
194 .add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
195 DecomposeCallGraphTypesForReturnOp>(arg&: typeConverter, args&: context,
196 args&: decomposer);
197}
198

// Source: mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp