1 | //===- TestFunc.cpp - Pass to test helpers on function utilities ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/BuiltinOps.h" |
10 | #include "mlir/Interfaces/FunctionInterfaces.h" |
11 | #include "mlir/Pass/Pass.h" |
12 | |
13 | using namespace mlir; |
14 | |
15 | namespace { |
16 | /// This is a test pass for verifying FunctionOpInterface's insertArgument |
17 | /// method. |
18 | struct TestFuncInsertArg |
19 | : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> { |
20 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertArg) |
21 | |
22 | StringRef getArgument() const final { return "test-func-insert-arg" ; } |
23 | StringRef getDescription() const final { return "Test inserting func args." ; } |
24 | void runOnOperation() override { |
25 | auto module = getOperation(); |
26 | |
27 | UnknownLoc unknownLoc = UnknownLoc::get(module.getContext()); |
28 | for (auto func : module.getOps<FunctionOpInterface>()) { |
29 | auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args" ); |
30 | if (!inserts || inserts.empty()) |
31 | continue; |
32 | SmallVector<unsigned, 4> indicesToInsert; |
33 | SmallVector<Type, 4> typesToInsert; |
34 | SmallVector<DictionaryAttr, 4> attrsToInsert; |
35 | SmallVector<Location, 4> locsToInsert; |
36 | for (auto insert : inserts.getAsRange<ArrayAttr>()) { |
37 | indicesToInsert.push_back( |
38 | cast<IntegerAttr>(insert[0]).getValue().getZExtValue()); |
39 | typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue()); |
40 | attrsToInsert.push_back(insert.size() > 2 |
41 | ? cast<DictionaryAttr>(insert[2]) |
42 | : DictionaryAttr::get(&getContext())); |
43 | locsToInsert.push_back(insert.size() > 3 |
44 | ? Location(cast<LocationAttr>(insert[3])) |
45 | : unknownLoc); |
46 | } |
47 | func->removeAttr("test.insert_args" ); |
48 | func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert, |
49 | locsToInsert); |
50 | } |
51 | } |
52 | }; |
53 | |
54 | /// This is a test pass for verifying FunctionOpInterface's insertResult method. |
55 | struct TestFuncInsertResult |
56 | : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> { |
57 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertResult) |
58 | |
59 | StringRef getArgument() const final { return "test-func-insert-result" ; } |
60 | StringRef getDescription() const final { |
61 | return "Test inserting func results." ; |
62 | } |
63 | void runOnOperation() override { |
64 | auto module = getOperation(); |
65 | |
66 | for (auto func : module.getOps<FunctionOpInterface>()) { |
67 | auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results" ); |
68 | if (!inserts || inserts.empty()) |
69 | continue; |
70 | SmallVector<unsigned, 4> indicesToInsert; |
71 | SmallVector<Type, 4> typesToInsert; |
72 | SmallVector<DictionaryAttr, 4> attrsToInsert; |
73 | for (auto insert : inserts.getAsRange<ArrayAttr>()) { |
74 | indicesToInsert.push_back( |
75 | cast<IntegerAttr>(insert[0]).getValue().getZExtValue()); |
76 | typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue()); |
77 | attrsToInsert.push_back(insert.size() > 2 |
78 | ? cast<DictionaryAttr>(insert[2]) |
79 | : DictionaryAttr::get(&getContext())); |
80 | } |
81 | func->removeAttr("test.insert_results" ); |
82 | func.insertResults(indicesToInsert, typesToInsert, attrsToInsert); |
83 | } |
84 | } |
85 | }; |
86 | |
87 | /// This is a test pass for verifying FunctionOpInterface's eraseArgument |
88 | /// method. |
89 | struct TestFuncEraseArg |
90 | : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> { |
91 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseArg) |
92 | |
93 | StringRef getArgument() const final { return "test-func-erase-arg" ; } |
94 | StringRef getDescription() const final { return "Test erasing func args." ; } |
95 | void runOnOperation() override { |
96 | auto module = getOperation(); |
97 | |
98 | for (auto func : module.getOps<FunctionOpInterface>()) { |
99 | BitVector indicesToErase(func.getNumArguments()); |
100 | for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) |
101 | if (func.getArgAttr(argIndex, "test.erase_this_arg" )) |
102 | indicesToErase.set(argIndex); |
103 | func.eraseArguments(indicesToErase); |
104 | } |
105 | } |
106 | }; |
107 | |
108 | /// This is a test pass for verifying FunctionOpInterface's eraseResult method. |
109 | struct TestFuncEraseResult |
110 | : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> { |
111 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseResult) |
112 | |
113 | StringRef getArgument() const final { return "test-func-erase-result" ; } |
114 | StringRef getDescription() const final { |
115 | return "Test erasing func results." ; |
116 | } |
117 | void runOnOperation() override { |
118 | auto module = getOperation(); |
119 | |
120 | for (auto func : module.getOps<FunctionOpInterface>()) { |
121 | BitVector indicesToErase(func.getNumResults()); |
122 | for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) |
123 | if (func.getResultAttr(resultIndex, "test.erase_this_result" )) |
124 | indicesToErase.set(resultIndex); |
125 | func.eraseResults(indicesToErase); |
126 | } |
127 | } |
128 | }; |
129 | |
130 | /// This is a test pass for verifying FunctionOpInterface's setType method. |
131 | struct TestFuncSetType |
132 | : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> { |
133 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncSetType) |
134 | |
135 | StringRef getArgument() const final { return "test-func-set-type" ; } |
136 | StringRef getDescription() const final { |
137 | return "Test FunctionOpInterface::setType." ; |
138 | } |
139 | void runOnOperation() override { |
140 | auto module = getOperation(); |
141 | SymbolTable symbolTable(module); |
142 | |
143 | for (auto func : module.getOps<FunctionOpInterface>()) { |
144 | auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from" ); |
145 | if (!sym) |
146 | continue; |
147 | func.setType(symbolTable.lookup<FunctionOpInterface>(sym.getValue()) |
148 | .getFunctionType()); |
149 | } |
150 | } |
151 | }; |
152 | } // namespace |
153 | |
154 | namespace mlir { |
155 | void registerTestFunc() { |
156 | PassRegistration<TestFuncInsertArg>(); |
157 | |
158 | PassRegistration<TestFuncInsertResult>(); |
159 | |
160 | PassRegistration<TestFuncEraseArg>(); |
161 | |
162 | PassRegistration<TestFuncEraseResult>(); |
163 | |
164 | PassRegistration<TestFuncSetType>(); |
165 | } |
166 | } // namespace mlir |
167 | |