1//===- TestDecomposeCallGraphTypes.cpp - Test CG type decomposition -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
13#include "mlir/IR/Builders.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Transforms/DialectConversion.h"
16
17using namespace mlir;
18
19namespace {
20/// Creates a sequence of `test.get_tuple_element` ops for all elements of a
21/// given tuple value. If some tuple elements are, in turn, tuples, the elements
22/// of those are extracted recursively such that the returned values have the
23/// same types as `resultTypes.getFlattenedTypes()`.
24static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc,
25 TupleType resultType, Value value,
26 SmallVectorImpl<Value> &values) {
27 for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
28 Type elementType = resultType.getType(i);
29 Value element = builder.create<test::GetTupleElementOp>(
30 loc, elementType, value, builder.getI32IntegerAttr(i));
31 if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
32 // Recurse if the current element is also a tuple.
33 if (failed(buildDecomposeTuple(builder, loc, nestedTupleType, element,
34 values)))
35 return failure();
36 } else {
37 values.push_back(Elt: element);
38 }
39 }
40 return success();
41}
42
43/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
44/// type `resultType`. If that type is nested, each nested tuple is built
45/// recursively with another `test.make_tuple` op.
46static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
47 TupleType resultType,
48 ValueRange inputs, Location loc) {
49 // Build one value for each element at this nesting level.
50 SmallVector<Value> elements;
51 elements.reserve(N: resultType.getTypes().size());
52 ValueRange::iterator inputIt = inputs.begin();
53 for (Type elementType : resultType.getTypes()) {
54 if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
55 // Determine how many input values are needed for the nested elements of
56 // the nested TupleType and advance inputIt by that number.
57 // TODO: We only need the *number* of nested types, not the types itself.
58 // Maybe it's worth adding a more efficient overload?
59 SmallVector<Type> nestedFlattenedTypes;
60 nestedTupleType.getFlattenedTypes(nestedFlattenedTypes);
61 size_t numNestedFlattenedTypes = nestedFlattenedTypes.size();
62 ValueRange nestedFlattenedelements(inputIt,
63 inputIt + numNestedFlattenedTypes);
64 inputIt += numNestedFlattenedTypes;
65
66 // Recurse on the values for the nested TupleType.
67 std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType,
68 nestedFlattenedelements, loc);
69 if (!res.has_value())
70 return {};
71
72 // The tuple constructed by the conversion is the element value.
73 elements.push_back(res.value());
74 } else {
75 // Base case: take one input as is.
76 elements.push_back(*inputIt++);
77 }
78 }
79
80 // Assemble the tuple from the elements.
81 return builder.create<test::MakeTupleOp>(loc, resultType, elements);
82}
83
84/// A pass for testing call graph type decomposition.
85///
86/// This instantiates the patterns with a TypeConverter and ValueDecomposer
87/// that splits tuple types into their respective element types.
88/// For example, `tuple<T1, T2, T3> --> T1, T2, T3`.
89struct TestDecomposeCallGraphTypes
90 : public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> {
91 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDecomposeCallGraphTypes)
92
93 void getDependentDialects(DialectRegistry &registry) const override {
94 registry.insert<test::TestDialect>();
95 }
96 StringRef getArgument() const final {
97 return "test-decompose-call-graph-types";
98 }
99 StringRef getDescription() const final {
100 return "Decomposes types at call graph boundaries.";
101 }
102 void runOnOperation() override {
103 ModuleOp module = getOperation();
104 auto *context = &getContext();
105 TypeConverter typeConverter;
106 ConversionTarget target(*context);
107 RewritePatternSet patterns(context);
108
109 target.addLegalDialect<test::TestDialect>();
110
111 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
112 return typeConverter.isLegal(op.getOperandTypes());
113 });
114 target.addDynamicallyLegalOp<func::CallOp>(
115 [&](func::CallOp op) { return typeConverter.isLegal(op); });
116 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
117 return typeConverter.isSignatureLegal(op.getFunctionType());
118 });
119
120 typeConverter.addConversion(callback: [](Type type) { return type; });
121 typeConverter.addConversion(
122 callback: [](TupleType tupleType, SmallVectorImpl<Type> &types) {
123 tupleType.getFlattenedTypes(types);
124 return success();
125 });
126 typeConverter.addArgumentMaterialization(callback&: buildMakeTupleOp);
127
128 ValueDecomposer decomposer;
129 decomposer.addDecomposeValueConversion(callback&: buildDecomposeTuple);
130
131 populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer,
132 patterns);
133
134 if (failed(applyPartialConversion(module, target, std::move(patterns))))
135 return signalPassFailure();
136 }
137};
138
139} // namespace
140
141namespace mlir {
142namespace test {
143void registerTestDecomposeCallGraphTypes() {
144 PassRegistration<TestDecomposeCallGraphTypes>();
145}
146} // namespace test
147} // namespace mlir
148

source code of mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp