1//===- TestOneToNTypeConversionPass.cpp - Test pass 1:N type conv. utils --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/OneToNTypeConversion.h"

#include <cstdint>
15
16using namespace mlir;
17
18namespace {
19/// Test pass that exercises the (poor-man's) 1:N type conversion mechanisms
20/// in `applyPartialOneToNConversion` by converting built-in tuples to the
21/// elements they consist of as well as some dummy ops operating on these
22/// tuples.
23struct TestOneToNTypeConversionPass
24 : public PassWrapper<TestOneToNTypeConversionPass,
25 OperationPass<ModuleOp>> {
26 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneToNTypeConversionPass)
27
28 TestOneToNTypeConversionPass() = default;
29 TestOneToNTypeConversionPass(const TestOneToNTypeConversionPass &pass)
30 : PassWrapper(pass) {}
31
32 void getDependentDialects(DialectRegistry &registry) const override {
33 registry.insert<test::TestDialect>();
34 }
35
36 StringRef getArgument() const final {
37 return "test-one-to-n-type-conversion";
38 }
39
40 StringRef getDescription() const final {
41 return "Test pass for 1:N type conversion";
42 }
43
44 Option<bool> convertFuncOps{*this, "convert-func-ops",
45 llvm::cl::desc("Enable conversion on func ops"),
46 llvm::cl::init(Val: false)};
47
48 Option<bool> convertSCFOps{*this, "convert-scf-ops",
49 llvm::cl::desc("Enable conversion on scf ops"),
50 llvm::cl::init(Val: false)};
51
52 Option<bool> convertTupleOps{*this, "convert-tuple-ops",
53 llvm::cl::desc("Enable conversion on tuple ops"),
54 llvm::cl::init(Val: false)};
55
56 void runOnOperation() override;
57};
58
59} // namespace
60
61namespace mlir {
62namespace test {
63void registerTestOneToNTypeConversionPass() {
64 PassRegistration<TestOneToNTypeConversionPass>();
65}
66} // namespace test
67} // namespace mlir
68
69namespace {
70
71/// Test pattern on for the `make_tuple` op from the test dialect that converts
72/// this kind of op into it's "decomposed" form, i.e., the elements of the tuple
73/// that is being produced by `test.make_tuple`, which are really just the
74/// operands of this op.
75class ConvertMakeTupleOp
76 : public OneToNOpConversionPattern<::test::MakeTupleOp> {
77public:
78 using OneToNOpConversionPattern<
79 ::test::MakeTupleOp>::OneToNOpConversionPattern;
80
81 LogicalResult
82 matchAndRewrite(::test::MakeTupleOp op, OpAdaptor adaptor,
83 OneToNPatternRewriter &rewriter) const override {
84 // Simply replace the current op with the converted operands.
85 rewriter.replaceOp(op, adaptor.getFlatOperands(),
86 adaptor.getResultMapping());
87 return success();
88 }
89};
90
91/// Test pattern on for the `get_tuple_element` op from the test dialect that
92/// converts this kind of op into it's "decomposed" form, i.e., instead of
93/// "physically" extracting one element from the tuple, we forward the one
94/// element of the decomposed form that is being extracted (or the several
95/// elements in case that element is a nested tuple).
96class ConvertGetTupleElementOp
97 : public OneToNOpConversionPattern<::test::GetTupleElementOp> {
98public:
99 using OneToNOpConversionPattern<
100 ::test::GetTupleElementOp>::OneToNOpConversionPattern;
101
102 LogicalResult
103 matchAndRewrite(::test::GetTupleElementOp op, OpAdaptor adaptor,
104 OneToNPatternRewriter &rewriter) const override {
105 // Construct mapping for tuple element types.
106 auto stateType = cast<TupleType>(op->getOperand(0).getType());
107 TypeRange originalElementTypes = stateType.getTypes();
108 OneToNTypeMapping elementMapping(originalElementTypes);
109 if (failed(typeConverter->convertSignatureArgs(originalElementTypes,
110 elementMapping)))
111 return failure();
112
113 // Compute converted operands corresponding to original input tuple.
114 assert(adaptor.getOperands().size() == 1 &&
115 "expected 'get_tuple_element' to have one operand");
116 ValueRange convertedTuple = adaptor.getOperands()[0];
117
118 // Got those converted operands that correspond to the index-th element ofq
119 // the original input tuple.
120 size_t index = op.getIndex();
121 ValueRange extractedElement =
122 elementMapping.getConvertedValues(convertedValues: convertedTuple, originalValueNo: index);
123
124 rewriter.replaceOp(op, extractedElement, adaptor.getResultMapping());
125
126 return success();
127 }
128};
129
130} // namespace
131
132static void populateDecomposeTuplesTestPatterns(TypeConverter &typeConverter,
133 RewritePatternSet &patterns) {
134 patterns.add<
135 // clang-format off
136 ConvertMakeTupleOp,
137 ConvertGetTupleElementOp
138 // clang-format on
139 >(arg&: typeConverter, args: patterns.getContext());
140}
141
142/// Creates a sequence of `test.get_tuple_element` ops for all elements of a
143/// given tuple value. If some tuple elements are, in turn, tuples, the elements
144/// of those are extracted recursively such that the returned values have the
145/// same types as `resultTypes.getFlattenedTypes()`.
146///
147/// This function has been copied (with small adaptions) from
148/// TestDecomposeCallGraphTypes.cpp.
149static std::optional<SmallVector<Value>>
150buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
151 Location loc) {
152 TupleType inputType = dyn_cast<TupleType>(input.getType());
153 if (!inputType)
154 return {};
155
156 SmallVector<Value> values;
157 for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) {
158 Value element = builder.create<::test::GetTupleElementOp>(
159 loc, elementType, input, builder.getI32IntegerAttr(idx));
160 if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
161 // Recurse if the current element is also a tuple.
162 SmallVector<Type> flatRecursiveTypes;
163 nestedTupleType.getFlattenedTypes(flatRecursiveTypes);
164 std::optional<SmallVector<Value>> resursiveValues =
165 buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc);
166 if (!resursiveValues.has_value())
167 return {};
168 values.append(resursiveValues.value());
169 } else {
170 values.push_back(element);
171 }
172 }
173 return values;
174}
175
176/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
177/// type `resultType`. If that type is nested, each nested tuple is built
178/// recursively with another `test.make_tuple` op.
179///
180/// This function has been copied (with small adaptions) from
181/// TestDecomposeCallGraphTypes.cpp.
182static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
183 TupleType resultType,
184 ValueRange inputs, Location loc) {
185 // Build one value for each element at this nesting level.
186 SmallVector<Value> elements;
187 elements.reserve(N: resultType.getTypes().size());
188 ValueRange::iterator inputIt = inputs.begin();
189 for (Type elementType : resultType.getTypes()) {
190 if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
191 // Determine how many input values are needed for the nested elements of
192 // the nested TupleType and advance inputIt by that number.
193 // TODO: We only need the *number* of nested types, not the types itself.
194 // Maybe it's worth adding a more efficient overload?
195 SmallVector<Type> nestedFlattenedTypes;
196 nestedTupleType.getFlattenedTypes(nestedFlattenedTypes);
197 size_t numNestedFlattenedTypes = nestedFlattenedTypes.size();
198 ValueRange nestedFlattenedelements(inputIt,
199 inputIt + numNestedFlattenedTypes);
200 inputIt += numNestedFlattenedTypes;
201
202 // Recurse on the values for the nested TupleType.
203 std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType,
204 nestedFlattenedelements, loc);
205 if (!res.has_value())
206 return {};
207
208 // The tuple constructed by the conversion is the element value.
209 elements.push_back(res.value());
210 } else {
211 // Base case: take one input as is.
212 elements.push_back(*inputIt++);
213 }
214 }
215
216 // Assemble the tuple from the elements.
217 return builder.create<::test::MakeTupleOp>(loc, resultType, elements);
218}
219
220void TestOneToNTypeConversionPass::runOnOperation() {
221 ModuleOp module = getOperation();
222 auto *context = &getContext();
223
224 // Assemble type converter.
225 OneToNTypeConverter typeConverter;
226
227 typeConverter.addConversion(callback: [](Type type) { return type; });
228 typeConverter.addConversion(
229 callback: [](TupleType tupleType, SmallVectorImpl<Type> &types) {
230 tupleType.getFlattenedTypes(types);
231 return success();
232 });
233
234 typeConverter.addArgumentMaterialization(callback&: buildMakeTupleOp);
235 typeConverter.addSourceMaterialization(callback&: buildMakeTupleOp);
236 typeConverter.addTargetMaterialization(callback: buildGetTupleElementOps);
237
238 // Assemble patterns.
239 RewritePatternSet patterns(context);
240 if (convertTupleOps)
241 populateDecomposeTuplesTestPatterns(typeConverter, patterns);
242 if (convertFuncOps)
243 populateFuncTypeConversionPatterns(typeConverter, patterns);
244 if (convertSCFOps)
245 scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns);
246
247 // Run conversion.
248 if (failed(applyPartialOneToNConversion(module, typeConverter,
249 std::move(patterns))))
250 return signalPassFailure();
251}
252

source code of mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp