1//===- TestDecomposeCallGraphTypes.cpp - Test CG type decomposition -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
13#include "mlir/IR/Builders.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Transforms/DialectConversion.h"
16
17using namespace mlir;
18
19namespace {
20/// Creates a sequence of `test.get_tuple_element` ops for all elements of a
21/// given tuple value. If some tuple elements are, in turn, tuples, the elements
22/// of those are extracted recursively such that the returned values have the
23/// same types as `resultTypes.getFlattenedTypes()`.
24static SmallVector<Value> buildDecomposeTuple(OpBuilder &builder,
25 TypeRange resultTypes,
26 ValueRange inputs, Location loc) {
27 // Skip materialization if the single input value is not a tuple.
28 if (inputs.size() != 1)
29 return {};
30 Value tuple = inputs.front();
31 auto tupleType = dyn_cast<TupleType>(tuple.getType());
32 if (!tupleType)
33 return {};
34 // Skip materialization if the flattened types do not match the requested
35 // result types.
36 SmallVector<Type> flattenedTypes;
37 tupleType.getFlattenedTypes(flattenedTypes);
38 if (TypeRange(resultTypes) != TypeRange(flattenedTypes))
39 return {};
40 // Recursively decompose the tuple.
41 SmallVector<Value> result;
42 std::function<void(Value)> decompose = [&](Value tuple) {
43 auto tupleType = dyn_cast<TupleType>(tuple.getType());
44 if (!tupleType) {
45 // This is not a tuple.
46 result.push_back(Elt: tuple);
47 return;
48 }
49 for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
50 Type elementType = tupleType.getType(i);
51 Value element = builder.create<test::GetTupleElementOp>(
52 loc, elementType, tuple, builder.getI32IntegerAttr(i));
53 decompose(element);
54 }
55 };
56 decompose(tuple);
57 return result;
58}
59
60/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
61/// type `resultType`. If that type is nested, each nested tuple is built
62/// recursively with another `test.make_tuple` op.
63static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,
64 ValueRange inputs, Location loc) {
65 // Build one value for each element at this nesting level.
66 SmallVector<Value> elements;
67 elements.reserve(N: resultType.getTypes().size());
68 ValueRange::iterator inputIt = inputs.begin();
69 for (Type elementType : resultType.getTypes()) {
70 if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
71 // Determine how many input values are needed for the nested elements of
72 // the nested TupleType and advance inputIt by that number.
73 // TODO: We only need the *number* of nested types, not the types itself.
74 // Maybe it's worth adding a more efficient overload?
75 SmallVector<Type> nestedFlattenedTypes;
76 nestedTupleType.getFlattenedTypes(nestedFlattenedTypes);
77 size_t numNestedFlattenedTypes = nestedFlattenedTypes.size();
78 ValueRange nestedFlattenedelements(inputIt,
79 inputIt + numNestedFlattenedTypes);
80 inputIt += numNestedFlattenedTypes;
81
82 // Recurse on the values for the nested TupleType.
83 Value res = buildMakeTupleOp(builder, nestedTupleType,
84 nestedFlattenedelements, loc);
85 if (!res)
86 return Value();
87
88 // The tuple constructed by the conversion is the element value.
89 elements.push_back(res);
90 } else {
91 // Base case: take one input as is.
92 elements.push_back(*inputIt++);
93 }
94 }
95
96 // Assemble the tuple from the elements.
97 return builder.create<test::MakeTupleOp>(loc, resultType, elements);
98}
99
100/// A pass for testing call graph type decomposition.
101///
102/// This instantiates the patterns with a TypeConverter that splits tuple types
103/// into their respective element types.
104/// For example, `tuple<T1, T2, T3> --> T1, T2, T3`.
105struct TestDecomposeCallGraphTypes
106 : public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> {
107 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDecomposeCallGraphTypes)
108
109 void getDependentDialects(DialectRegistry &registry) const override {
110 registry.insert<test::TestDialect>();
111 }
112 StringRef getArgument() const final {
113 return "test-decompose-call-graph-types";
114 }
115 StringRef getDescription() const final {
116 return "Decomposes types at call graph boundaries.";
117 }
118 void runOnOperation() override {
119 ModuleOp module = getOperation();
120 auto *context = &getContext();
121 TypeConverter typeConverter;
122 ConversionTarget target(*context);
123 RewritePatternSet patterns(context);
124
125 target.addLegalDialect<test::TestDialect>();
126
127 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
128 return typeConverter.isLegal(op.getOperandTypes());
129 });
130 target.addDynamicallyLegalOp<func::CallOp>(
131 [&](func::CallOp op) { return typeConverter.isLegal(op); });
132 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
133 return typeConverter.isSignatureLegal(op.getFunctionType());
134 });
135
136 typeConverter.addConversion(callback: [](Type type) { return type; });
137 typeConverter.addConversion(
138 callback: [](TupleType tupleType, SmallVectorImpl<Type> &types) {
139 tupleType.getFlattenedTypes(types);
140 return success();
141 });
142 typeConverter.addSourceMaterialization(callback&: buildMakeTupleOp);
143 typeConverter.addTargetMaterialization(callback&: buildDecomposeTuple);
144
145 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
146 patterns, typeConverter);
147 populateReturnOpTypeConversionPattern(patterns, converter: typeConverter);
148 populateCallOpTypeConversionPattern(patterns, converter: typeConverter);
149
150 if (failed(applyPartialConversion(module, target, std::move(patterns))))
151 return signalPassFailure();
152 }
153};
154
155} // namespace
156
157namespace mlir {
158namespace test {
159void registerTestDecomposeCallGraphTypes() {
160 PassRegistration<TestDecomposeCallGraphTypes>();
161}
162} // namespace test
163} // namespace mlir
164

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp