1 | //===- TestDecomposeCallGraphTypes.cpp - Test CG type decomposition -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestDialect.h" |
10 | #include "TestOps.h" |
11 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
12 | #include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h" |
13 | #include "mlir/IR/Builders.h" |
14 | #include "mlir/Pass/Pass.h" |
15 | #include "mlir/Transforms/DialectConversion.h" |
16 | |
17 | using namespace mlir; |
18 | |
19 | namespace { |
20 | /// Creates a sequence of `test.get_tuple_element` ops for all elements of a |
21 | /// given tuple value. If some tuple elements are, in turn, tuples, the elements |
22 | /// of those are extracted recursively such that the returned values have the |
23 | /// same types as `resultTypes.getFlattenedTypes()`. |
24 | static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc, |
25 | TupleType resultType, Value value, |
26 | SmallVectorImpl<Value> &values) { |
27 | for (unsigned i = 0, e = resultType.size(); i < e; ++i) { |
28 | Type elementType = resultType.getType(i); |
29 | Value element = builder.create<test::GetTupleElementOp>( |
30 | loc, elementType, value, builder.getI32IntegerAttr(i)); |
31 | if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) { |
32 | // Recurse if the current element is also a tuple. |
33 | if (failed(buildDecomposeTuple(builder, loc, nestedTupleType, element, |
34 | values))) |
35 | return failure(); |
36 | } else { |
37 | values.push_back(Elt: element); |
38 | } |
39 | } |
40 | return success(); |
41 | } |
42 | |
43 | /// Creates a `test.make_tuple` op out of the given inputs building a tuple of |
44 | /// type `resultType`. If that type is nested, each nested tuple is built |
45 | /// recursively with another `test.make_tuple` op. |
46 | static std::optional<Value> buildMakeTupleOp(OpBuilder &builder, |
47 | TupleType resultType, |
48 | ValueRange inputs, Location loc) { |
49 | // Build one value for each element at this nesting level. |
50 | SmallVector<Value> elements; |
51 | elements.reserve(N: resultType.getTypes().size()); |
52 | ValueRange::iterator inputIt = inputs.begin(); |
53 | for (Type elementType : resultType.getTypes()) { |
54 | if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) { |
55 | // Determine how many input values are needed for the nested elements of |
56 | // the nested TupleType and advance inputIt by that number. |
57 | // TODO: We only need the *number* of nested types, not the types itself. |
58 | // Maybe it's worth adding a more efficient overload? |
59 | SmallVector<Type> nestedFlattenedTypes; |
60 | nestedTupleType.getFlattenedTypes(nestedFlattenedTypes); |
61 | size_t numNestedFlattenedTypes = nestedFlattenedTypes.size(); |
62 | ValueRange nestedFlattenedelements(inputIt, |
63 | inputIt + numNestedFlattenedTypes); |
64 | inputIt += numNestedFlattenedTypes; |
65 | |
66 | // Recurse on the values for the nested TupleType. |
67 | std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType, |
68 | nestedFlattenedelements, loc); |
69 | if (!res.has_value()) |
70 | return {}; |
71 | |
72 | // The tuple constructed by the conversion is the element value. |
73 | elements.push_back(res.value()); |
74 | } else { |
75 | // Base case: take one input as is. |
76 | elements.push_back(*inputIt++); |
77 | } |
78 | } |
79 | |
80 | // Assemble the tuple from the elements. |
81 | return builder.create<test::MakeTupleOp>(loc, resultType, elements); |
82 | } |
83 | |
84 | /// A pass for testing call graph type decomposition. |
85 | /// |
86 | /// This instantiates the patterns with a TypeConverter and ValueDecomposer |
87 | /// that splits tuple types into their respective element types. |
88 | /// For example, `tuple<T1, T2, T3> --> T1, T2, T3`. |
89 | struct TestDecomposeCallGraphTypes |
90 | : public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> { |
91 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDecomposeCallGraphTypes) |
92 | |
93 | void getDependentDialects(DialectRegistry ®istry) const override { |
94 | registry.insert<test::TestDialect>(); |
95 | } |
96 | StringRef getArgument() const final { |
97 | return "test-decompose-call-graph-types" ; |
98 | } |
99 | StringRef getDescription() const final { |
100 | return "Decomposes types at call graph boundaries." ; |
101 | } |
102 | void runOnOperation() override { |
103 | ModuleOp module = getOperation(); |
104 | auto *context = &getContext(); |
105 | TypeConverter typeConverter; |
106 | ConversionTarget target(*context); |
107 | RewritePatternSet patterns(context); |
108 | |
109 | target.addLegalDialect<test::TestDialect>(); |
110 | |
111 | target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { |
112 | return typeConverter.isLegal(op.getOperandTypes()); |
113 | }); |
114 | target.addDynamicallyLegalOp<func::CallOp>( |
115 | [&](func::CallOp op) { return typeConverter.isLegal(op); }); |
116 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
117 | return typeConverter.isSignatureLegal(op.getFunctionType()); |
118 | }); |
119 | |
120 | typeConverter.addConversion(callback: [](Type type) { return type; }); |
121 | typeConverter.addConversion( |
122 | callback: [](TupleType tupleType, SmallVectorImpl<Type> &types) { |
123 | tupleType.getFlattenedTypes(types); |
124 | return success(); |
125 | }); |
126 | typeConverter.addArgumentMaterialization(callback&: buildMakeTupleOp); |
127 | |
128 | ValueDecomposer decomposer; |
129 | decomposer.addDecomposeValueConversion(callback&: buildDecomposeTuple); |
130 | |
131 | populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer, |
132 | patterns); |
133 | |
134 | if (failed(applyPartialConversion(module, target, std::move(patterns)))) |
135 | return signalPassFailure(); |
136 | } |
137 | }; |
138 | |
139 | } // namespace |
140 | |
141 | namespace mlir { |
142 | namespace test { |
143 | void registerTestDecomposeCallGraphTypes() { |
144 | PassRegistration<TestDecomposeCallGraphTypes>(); |
145 | } |
146 | } // namespace test |
147 | } // namespace mlir |
148 | |