1 | //===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h" |
10 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
11 | #include "mlir/IR/BuiltinOps.h" |
12 | |
13 | using namespace mlir; |
14 | using namespace mlir::func; |
15 | |
16 | //===----------------------------------------------------------------------===// |
17 | // ValueDecomposer |
18 | //===----------------------------------------------------------------------===// |
19 | |
20 | void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc, |
21 | Type type, Value value, |
22 | SmallVectorImpl<Value> &results) { |
23 | for (auto &conversion : decomposeValueConversions) |
24 | if (conversion(builder, loc, type, value, results)) |
25 | return; |
26 | results.push_back(Elt: value); |
27 | } |
28 | |
29 | //===----------------------------------------------------------------------===// |
30 | // DecomposeCallGraphTypesOpConversionPattern |
31 | //===----------------------------------------------------------------------===// |
32 | |
33 | namespace { |
34 | /// Base OpConversionPattern class to make a ValueDecomposer available to |
35 | /// inherited patterns. |
36 | template <typename SourceOp> |
37 | class DecomposeCallGraphTypesOpConversionPattern |
38 | : public OpConversionPattern<SourceOp> { |
39 | public: |
40 | DecomposeCallGraphTypesOpConversionPattern(TypeConverter &typeConverter, |
41 | MLIRContext *context, |
42 | ValueDecomposer &decomposer, |
43 | PatternBenefit benefit = 1) |
44 | : OpConversionPattern<SourceOp>(typeConverter, context, benefit), |
45 | decomposer(decomposer) {} |
46 | |
47 | protected: |
48 | ValueDecomposer &decomposer; |
49 | }; |
50 | } // namespace |
51 | |
52 | //===----------------------------------------------------------------------===// |
53 | // DecomposeCallGraphTypesForFuncArgs |
54 | //===----------------------------------------------------------------------===// |
55 | |
56 | namespace { |
57 | /// Expand function arguments according to the provided TypeConverter and |
58 | /// ValueDecomposer. |
59 | struct DecomposeCallGraphTypesForFuncArgs |
60 | : public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> { |
61 | using DecomposeCallGraphTypesOpConversionPattern:: |
62 | DecomposeCallGraphTypesOpConversionPattern; |
63 | |
64 | LogicalResult |
65 | matchAndRewrite(func::FuncOp op, OpAdaptor adaptor, |
66 | ConversionPatternRewriter &rewriter) const final { |
67 | auto functionType = op.getFunctionType(); |
68 | |
69 | // Convert function arguments using the provided TypeConverter. |
70 | TypeConverter::SignatureConversion conversion(functionType.getNumInputs()); |
71 | for (const auto &argType : llvm::enumerate(functionType.getInputs())) { |
72 | SmallVector<Type, 2> decomposedTypes; |
73 | if (failed(typeConverter->convertType(argType.value(), decomposedTypes))) |
74 | return failure(); |
75 | if (!decomposedTypes.empty()) |
76 | conversion.addInputs(argType.index(), decomposedTypes); |
77 | } |
78 | |
79 | // If the SignatureConversion doesn't apply, bail out. |
80 | if (failed(rewriter.convertRegionTypes(region: &op.getBody(), converter: *getTypeConverter(), |
81 | entryConversion: &conversion))) |
82 | return failure(); |
83 | |
84 | // Update the signature of the function. |
85 | SmallVector<Type, 2> newResultTypes; |
86 | if (failed(typeConverter->convertTypes(functionType.getResults(), |
87 | newResultTypes))) |
88 | return failure(); |
89 | rewriter.modifyOpInPlace(op, [&] { |
90 | op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), |
91 | newResultTypes)); |
92 | }); |
93 | return success(); |
94 | } |
95 | }; |
96 | } // namespace |
97 | |
98 | //===----------------------------------------------------------------------===// |
99 | // DecomposeCallGraphTypesForReturnOp |
100 | //===----------------------------------------------------------------------===// |
101 | |
102 | namespace { |
103 | /// Expand return operands according to the provided TypeConverter and |
104 | /// ValueDecomposer. |
105 | struct DecomposeCallGraphTypesForReturnOp |
106 | : public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> { |
107 | using DecomposeCallGraphTypesOpConversionPattern:: |
108 | DecomposeCallGraphTypesOpConversionPattern; |
109 | LogicalResult |
110 | matchAndRewrite(ReturnOp op, OpAdaptor adaptor, |
111 | ConversionPatternRewriter &rewriter) const final { |
112 | SmallVector<Value, 2> newOperands; |
113 | for (Value operand : adaptor.getOperands()) |
114 | decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(), |
115 | operand, newOperands); |
116 | rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands); |
117 | return success(); |
118 | } |
119 | }; |
120 | } // namespace |
121 | |
122 | //===----------------------------------------------------------------------===// |
123 | // DecomposeCallGraphTypesForCallOp |
124 | //===----------------------------------------------------------------------===// |
125 | |
126 | namespace { |
127 | /// Expand call op operands and results according to the provided TypeConverter |
128 | /// and ValueDecomposer. |
129 | struct DecomposeCallGraphTypesForCallOp |
130 | : public DecomposeCallGraphTypesOpConversionPattern<CallOp> { |
131 | using DecomposeCallGraphTypesOpConversionPattern:: |
132 | DecomposeCallGraphTypesOpConversionPattern; |
133 | |
134 | LogicalResult |
135 | matchAndRewrite(CallOp op, OpAdaptor adaptor, |
136 | ConversionPatternRewriter &rewriter) const final { |
137 | |
138 | // Create the operands list of the new `CallOp`. |
139 | SmallVector<Value, 2> newOperands; |
140 | for (Value operand : adaptor.getOperands()) |
141 | decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(), |
142 | operand, newOperands); |
143 | |
144 | // Create the new result types for the new `CallOp` and track the indices in |
145 | // the new call op's results that correspond to the old call op's results. |
146 | // |
147 | // expandedResultIndices[i] = "list of new result indices that old result i |
148 | // expanded to". |
149 | SmallVector<Type, 2> newResultTypes; |
150 | SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices; |
151 | for (Type resultType : op.getResultTypes()) { |
152 | unsigned oldSize = newResultTypes.size(); |
153 | if (failed(typeConverter->convertType(resultType, newResultTypes))) |
154 | return failure(); |
155 | auto &resultMapping = expandedResultIndices.emplace_back(); |
156 | for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++) |
157 | resultMapping.push_back(i); |
158 | } |
159 | |
160 | CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(), |
161 | newResultTypes, newOperands); |
162 | |
163 | // Build a replacement value for each result to replace its uses. If a |
164 | // result has multiple mapping values, it needs to be materialized as a |
165 | // single value. |
166 | SmallVector<Value, 2> replacedValues; |
167 | replacedValues.reserve(N: op.getNumResults()); |
168 | for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { |
169 | auto decomposedValues = llvm::to_vector<6>( |
170 | llvm::map_range(expandedResultIndices[i], |
171 | [&](unsigned i) { return newCallOp.getResult(i); })); |
172 | if (decomposedValues.empty()) { |
173 | // No replacement is required. |
174 | replacedValues.push_back(Elt: nullptr); |
175 | } else if (decomposedValues.size() == 1) { |
176 | replacedValues.push_back(Elt: decomposedValues.front()); |
177 | } else { |
178 | // Materialize a single Value to replace the original Value. |
179 | Value materialized = getTypeConverter()->materializeArgumentConversion( |
180 | rewriter, op.getLoc(), op.getType(i), decomposedValues); |
181 | replacedValues.push_back(Elt: materialized); |
182 | } |
183 | } |
184 | rewriter.replaceOp(op, replacedValues); |
185 | return success(); |
186 | } |
187 | }; |
188 | } // namespace |
189 | |
190 | void mlir::populateDecomposeCallGraphTypesPatterns( |
191 | MLIRContext *context, TypeConverter &typeConverter, |
192 | ValueDecomposer &decomposer, RewritePatternSet &patterns) { |
193 | patterns |
194 | .add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs, |
195 | DecomposeCallGraphTypesForReturnOp>(arg&: typeConverter, args&: context, |
196 | args&: decomposer); |
197 | } |
198 | |