1 | //===- TestDecomposeCallGraphTypes.cpp - Test CG type decomposition -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestDialect.h" |
10 | #include "TestOps.h" |
11 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
12 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
13 | #include "mlir/IR/Builders.h" |
14 | #include "mlir/Pass/Pass.h" |
15 | #include "mlir/Transforms/DialectConversion.h" |
16 | |
17 | using namespace mlir; |
18 | |
19 | namespace { |
20 | /// Creates a sequence of `test.get_tuple_element` ops for all elements of a |
21 | /// given tuple value. If some tuple elements are, in turn, tuples, the elements |
22 | /// of those are extracted recursively such that the returned values have the |
23 | /// same types as `resultTypes.getFlattenedTypes()`. |
24 | static SmallVector<Value> buildDecomposeTuple(OpBuilder &builder, |
25 | TypeRange resultTypes, |
26 | ValueRange inputs, Location loc) { |
27 | // Skip materialization if the single input value is not a tuple. |
28 | if (inputs.size() != 1) |
29 | return {}; |
30 | Value tuple = inputs.front(); |
31 | auto tupleType = dyn_cast<TupleType>(tuple.getType()); |
32 | if (!tupleType) |
33 | return {}; |
34 | // Skip materialization if the flattened types do not match the requested |
35 | // result types. |
36 | SmallVector<Type> flattenedTypes; |
37 | tupleType.getFlattenedTypes(flattenedTypes); |
38 | if (TypeRange(resultTypes) != TypeRange(flattenedTypes)) |
39 | return {}; |
40 | // Recursively decompose the tuple. |
41 | SmallVector<Value> result; |
42 | std::function<void(Value)> decompose = [&](Value tuple) { |
43 | auto tupleType = dyn_cast<TupleType>(tuple.getType()); |
44 | if (!tupleType) { |
45 | // This is not a tuple. |
46 | result.push_back(Elt: tuple); |
47 | return; |
48 | } |
49 | for (unsigned i = 0, e = tupleType.size(); i < e; ++i) { |
50 | Type elementType = tupleType.getType(i); |
51 | Value element = builder.create<test::GetTupleElementOp>( |
52 | loc, elementType, tuple, builder.getI32IntegerAttr(i)); |
53 | decompose(element); |
54 | } |
55 | }; |
56 | decompose(tuple); |
57 | return result; |
58 | } |
59 | |
60 | /// Creates a `test.make_tuple` op out of the given inputs building a tuple of |
61 | /// type `resultType`. If that type is nested, each nested tuple is built |
62 | /// recursively with another `test.make_tuple` op. |
63 | static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType, |
64 | ValueRange inputs, Location loc) { |
65 | // Build one value for each element at this nesting level. |
66 | SmallVector<Value> elements; |
67 | elements.reserve(N: resultType.getTypes().size()); |
68 | ValueRange::iterator inputIt = inputs.begin(); |
69 | for (Type elementType : resultType.getTypes()) { |
70 | if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) { |
71 | // Determine how many input values are needed for the nested elements of |
72 | // the nested TupleType and advance inputIt by that number. |
73 | // TODO: We only need the *number* of nested types, not the types itself. |
74 | // Maybe it's worth adding a more efficient overload? |
75 | SmallVector<Type> nestedFlattenedTypes; |
76 | nestedTupleType.getFlattenedTypes(nestedFlattenedTypes); |
77 | size_t numNestedFlattenedTypes = nestedFlattenedTypes.size(); |
78 | ValueRange nestedFlattenedelements(inputIt, |
79 | inputIt + numNestedFlattenedTypes); |
80 | inputIt += numNestedFlattenedTypes; |
81 | |
82 | // Recurse on the values for the nested TupleType. |
83 | Value res = buildMakeTupleOp(builder, nestedTupleType, |
84 | nestedFlattenedelements, loc); |
85 | if (!res) |
86 | return Value(); |
87 | |
88 | // The tuple constructed by the conversion is the element value. |
89 | elements.push_back(res); |
90 | } else { |
91 | // Base case: take one input as is. |
92 | elements.push_back(*inputIt++); |
93 | } |
94 | } |
95 | |
96 | // Assemble the tuple from the elements. |
97 | return builder.create<test::MakeTupleOp>(loc, resultType, elements); |
98 | } |
99 | |
100 | /// A pass for testing call graph type decomposition. |
101 | /// |
102 | /// This instantiates the patterns with a TypeConverter that splits tuple types |
103 | /// into their respective element types. |
104 | /// For example, `tuple<T1, T2, T3> --> T1, T2, T3`. |
105 | struct TestDecomposeCallGraphTypes |
106 | : public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> { |
107 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDecomposeCallGraphTypes) |
108 | |
109 | void getDependentDialects(DialectRegistry ®istry) const override { |
110 | registry.insert<test::TestDialect>(); |
111 | } |
112 | StringRef getArgument() const final { |
113 | return "test-decompose-call-graph-types" ; |
114 | } |
115 | StringRef getDescription() const final { |
116 | return "Decomposes types at call graph boundaries." ; |
117 | } |
118 | void runOnOperation() override { |
119 | ModuleOp module = getOperation(); |
120 | auto *context = &getContext(); |
121 | TypeConverter typeConverter; |
122 | ConversionTarget target(*context); |
123 | RewritePatternSet patterns(context); |
124 | |
125 | target.addLegalDialect<test::TestDialect>(); |
126 | |
127 | target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { |
128 | return typeConverter.isLegal(op.getOperandTypes()); |
129 | }); |
130 | target.addDynamicallyLegalOp<func::CallOp>( |
131 | [&](func::CallOp op) { return typeConverter.isLegal(op); }); |
132 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
133 | return typeConverter.isSignatureLegal(op.getFunctionType()); |
134 | }); |
135 | |
136 | typeConverter.addConversion(callback: [](Type type) { return type; }); |
137 | typeConverter.addConversion( |
138 | callback: [](TupleType tupleType, SmallVectorImpl<Type> &types) { |
139 | tupleType.getFlattenedTypes(types); |
140 | return success(); |
141 | }); |
142 | typeConverter.addSourceMaterialization(callback&: buildMakeTupleOp); |
143 | typeConverter.addTargetMaterialization(callback&: buildDecomposeTuple); |
144 | |
145 | populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( |
146 | patterns, typeConverter); |
147 | populateReturnOpTypeConversionPattern(patterns, converter: typeConverter); |
148 | populateCallOpTypeConversionPattern(patterns, converter: typeConverter); |
149 | |
150 | if (failed(applyPartialConversion(module, target, std::move(patterns)))) |
151 | return signalPassFailure(); |
152 | } |
153 | }; |
154 | |
155 | } // namespace |
156 | |
157 | namespace mlir { |
158 | namespace test { |
159 | void registerTestDecomposeCallGraphTypes() { |
160 | PassRegistration<TestDecomposeCallGraphTypes>(); |
161 | } |
162 | } // namespace test |
163 | } // namespace mlir |
164 | |