1 | //===- TestOneToNTypeConversionPass.cpp - Test pass 1:N type conv. utils --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestDialect.h" |
10 | #include "TestOps.h" |
11 | #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" |
12 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
13 | #include "mlir/Pass/Pass.h" |
14 | #include "mlir/Transforms/OneToNTypeConversion.h" |
15 | |
16 | using namespace mlir; |
17 | |
18 | namespace { |
19 | /// Test pass that exercises the (poor-man's) 1:N type conversion mechanisms |
20 | /// in `applyPartialOneToNConversion` by converting built-in tuples to the |
21 | /// elements they consist of as well as some dummy ops operating on these |
22 | /// tuples. |
23 | struct TestOneToNTypeConversionPass |
24 | : public PassWrapper<TestOneToNTypeConversionPass, |
25 | OperationPass<ModuleOp>> { |
26 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneToNTypeConversionPass) |
27 | |
28 | TestOneToNTypeConversionPass() = default; |
29 | TestOneToNTypeConversionPass(const TestOneToNTypeConversionPass &pass) |
30 | : PassWrapper(pass) {} |
31 | |
32 | void getDependentDialects(DialectRegistry ®istry) const override { |
33 | registry.insert<test::TestDialect>(); |
34 | } |
35 | |
36 | StringRef getArgument() const final { |
37 | return "test-one-to-n-type-conversion" ; |
38 | } |
39 | |
40 | StringRef getDescription() const final { |
41 | return "Test pass for 1:N type conversion" ; |
42 | } |
43 | |
44 | Option<bool> convertFuncOps{*this, "convert-func-ops" , |
45 | llvm::cl::desc("Enable conversion on func ops" ), |
46 | llvm::cl::init(Val: false)}; |
47 | |
48 | Option<bool> convertSCFOps{*this, "convert-scf-ops" , |
49 | llvm::cl::desc("Enable conversion on scf ops" ), |
50 | llvm::cl::init(Val: false)}; |
51 | |
52 | Option<bool> convertTupleOps{*this, "convert-tuple-ops" , |
53 | llvm::cl::desc("Enable conversion on tuple ops" ), |
54 | llvm::cl::init(Val: false)}; |
55 | |
56 | void runOnOperation() override; |
57 | }; |
58 | |
59 | } // namespace |
60 | |
61 | namespace mlir { |
62 | namespace test { |
63 | void registerTestOneToNTypeConversionPass() { |
64 | PassRegistration<TestOneToNTypeConversionPass>(); |
65 | } |
66 | } // namespace test |
67 | } // namespace mlir |
68 | |
69 | namespace { |
70 | |
71 | /// Test pattern on for the `make_tuple` op from the test dialect that converts |
72 | /// this kind of op into it's "decomposed" form, i.e., the elements of the tuple |
73 | /// that is being produced by `test.make_tuple`, which are really just the |
74 | /// operands of this op. |
75 | class ConvertMakeTupleOp |
76 | : public OneToNOpConversionPattern<::test::MakeTupleOp> { |
77 | public: |
78 | using OneToNOpConversionPattern< |
79 | ::test::MakeTupleOp>::OneToNOpConversionPattern; |
80 | |
81 | LogicalResult |
82 | matchAndRewrite(::test::MakeTupleOp op, OpAdaptor adaptor, |
83 | OneToNPatternRewriter &rewriter) const override { |
84 | // Simply replace the current op with the converted operands. |
85 | rewriter.replaceOp(op, adaptor.getFlatOperands(), |
86 | adaptor.getResultMapping()); |
87 | return success(); |
88 | } |
89 | }; |
90 | |
91 | /// Test pattern on for the `get_tuple_element` op from the test dialect that |
92 | /// converts this kind of op into it's "decomposed" form, i.e., instead of |
93 | /// "physically" extracting one element from the tuple, we forward the one |
94 | /// element of the decomposed form that is being extracted (or the several |
95 | /// elements in case that element is a nested tuple). |
96 | class ConvertGetTupleElementOp |
97 | : public OneToNOpConversionPattern<::test::GetTupleElementOp> { |
98 | public: |
99 | using OneToNOpConversionPattern< |
100 | ::test::GetTupleElementOp>::OneToNOpConversionPattern; |
101 | |
102 | LogicalResult |
103 | matchAndRewrite(::test::GetTupleElementOp op, OpAdaptor adaptor, |
104 | OneToNPatternRewriter &rewriter) const override { |
105 | // Construct mapping for tuple element types. |
106 | auto stateType = cast<TupleType>(op->getOperand(0).getType()); |
107 | TypeRange originalElementTypes = stateType.getTypes(); |
108 | OneToNTypeMapping elementMapping(originalElementTypes); |
109 | if (failed(typeConverter->convertSignatureArgs(originalElementTypes, |
110 | elementMapping))) |
111 | return failure(); |
112 | |
113 | // Compute converted operands corresponding to original input tuple. |
114 | assert(adaptor.getOperands().size() == 1 && |
115 | "expected 'get_tuple_element' to have one operand" ); |
116 | ValueRange convertedTuple = adaptor.getOperands()[0]; |
117 | |
118 | // Got those converted operands that correspond to the index-th element ofq |
119 | // the original input tuple. |
120 | size_t index = op.getIndex(); |
121 | ValueRange = |
122 | elementMapping.getConvertedValues(convertedValues: convertedTuple, originalValueNo: index); |
123 | |
124 | rewriter.replaceOp(op, extractedElement, adaptor.getResultMapping()); |
125 | |
126 | return success(); |
127 | } |
128 | }; |
129 | |
130 | } // namespace |
131 | |
132 | static void populateDecomposeTuplesTestPatterns(TypeConverter &typeConverter, |
133 | RewritePatternSet &patterns) { |
134 | patterns.add< |
135 | // clang-format off |
136 | ConvertMakeTupleOp, |
137 | ConvertGetTupleElementOp |
138 | // clang-format on |
139 | >(arg&: typeConverter, args: patterns.getContext()); |
140 | } |
141 | |
142 | /// Creates a sequence of `test.get_tuple_element` ops for all elements of a |
143 | /// given tuple value. If some tuple elements are, in turn, tuples, the elements |
144 | /// of those are extracted recursively such that the returned values have the |
145 | /// same types as `resultTypes.getFlattenedTypes()`. |
146 | /// |
147 | /// This function has been copied (with small adaptions) from |
148 | /// TestDecomposeCallGraphTypes.cpp. |
149 | static std::optional<SmallVector<Value>> |
150 | buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input, |
151 | Location loc) { |
152 | TupleType inputType = dyn_cast<TupleType>(input.getType()); |
153 | if (!inputType) |
154 | return {}; |
155 | |
156 | SmallVector<Value> values; |
157 | for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) { |
158 | Value element = builder.create<::test::GetTupleElementOp>( |
159 | loc, elementType, input, builder.getI32IntegerAttr(idx)); |
160 | if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) { |
161 | // Recurse if the current element is also a tuple. |
162 | SmallVector<Type> flatRecursiveTypes; |
163 | nestedTupleType.getFlattenedTypes(flatRecursiveTypes); |
164 | std::optional<SmallVector<Value>> resursiveValues = |
165 | buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc); |
166 | if (!resursiveValues.has_value()) |
167 | return {}; |
168 | values.append(resursiveValues.value()); |
169 | } else { |
170 | values.push_back(element); |
171 | } |
172 | } |
173 | return values; |
174 | } |
175 | |
176 | /// Creates a `test.make_tuple` op out of the given inputs building a tuple of |
177 | /// type `resultType`. If that type is nested, each nested tuple is built |
178 | /// recursively with another `test.make_tuple` op. |
179 | /// |
180 | /// This function has been copied (with small adaptions) from |
181 | /// TestDecomposeCallGraphTypes.cpp. |
182 | static std::optional<Value> buildMakeTupleOp(OpBuilder &builder, |
183 | TupleType resultType, |
184 | ValueRange inputs, Location loc) { |
185 | // Build one value for each element at this nesting level. |
186 | SmallVector<Value> elements; |
187 | elements.reserve(N: resultType.getTypes().size()); |
188 | ValueRange::iterator inputIt = inputs.begin(); |
189 | for (Type elementType : resultType.getTypes()) { |
190 | if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) { |
191 | // Determine how many input values are needed for the nested elements of |
192 | // the nested TupleType and advance inputIt by that number. |
193 | // TODO: We only need the *number* of nested types, not the types itself. |
194 | // Maybe it's worth adding a more efficient overload? |
195 | SmallVector<Type> nestedFlattenedTypes; |
196 | nestedTupleType.getFlattenedTypes(nestedFlattenedTypes); |
197 | size_t numNestedFlattenedTypes = nestedFlattenedTypes.size(); |
198 | ValueRange nestedFlattenedelements(inputIt, |
199 | inputIt + numNestedFlattenedTypes); |
200 | inputIt += numNestedFlattenedTypes; |
201 | |
202 | // Recurse on the values for the nested TupleType. |
203 | std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType, |
204 | nestedFlattenedelements, loc); |
205 | if (!res.has_value()) |
206 | return {}; |
207 | |
208 | // The tuple constructed by the conversion is the element value. |
209 | elements.push_back(res.value()); |
210 | } else { |
211 | // Base case: take one input as is. |
212 | elements.push_back(*inputIt++); |
213 | } |
214 | } |
215 | |
216 | // Assemble the tuple from the elements. |
217 | return builder.create<::test::MakeTupleOp>(loc, resultType, elements); |
218 | } |
219 | |
220 | void TestOneToNTypeConversionPass::runOnOperation() { |
221 | ModuleOp module = getOperation(); |
222 | auto *context = &getContext(); |
223 | |
224 | // Assemble type converter. |
225 | OneToNTypeConverter typeConverter; |
226 | |
227 | typeConverter.addConversion(callback: [](Type type) { return type; }); |
228 | typeConverter.addConversion( |
229 | callback: [](TupleType tupleType, SmallVectorImpl<Type> &types) { |
230 | tupleType.getFlattenedTypes(types); |
231 | return success(); |
232 | }); |
233 | |
234 | typeConverter.addArgumentMaterialization(callback&: buildMakeTupleOp); |
235 | typeConverter.addSourceMaterialization(callback&: buildMakeTupleOp); |
236 | typeConverter.addTargetMaterialization(callback: buildGetTupleElementOps); |
237 | |
238 | // Assemble patterns. |
239 | RewritePatternSet patterns(context); |
240 | if (convertTupleOps) |
241 | populateDecomposeTuplesTestPatterns(typeConverter, patterns); |
242 | if (convertFuncOps) |
243 | populateFuncTypeConversionPatterns(typeConverter, patterns); |
244 | if (convertSCFOps) |
245 | scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns); |
246 | |
247 | // Run conversion. |
248 | if (failed(applyPartialOneToNConversion(module, typeConverter, |
249 | std::move(patterns)))) |
250 | return signalPassFailure(); |
251 | } |
252 | |