//===- TestFunc.cpp - Pass to test helpers on function utilities ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/BuiltinOps.h"
10#include "mlir/Interfaces/FunctionInterfaces.h"
11#include "mlir/Pass/Pass.h"
12
13using namespace mlir;
14
15namespace {
16/// This is a test pass for verifying FunctionOpInterface's insertArgument
17/// method.
18struct TestFuncInsertArg
19 : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> {
20 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertArg)
21
22 StringRef getArgument() const final { return "test-func-insert-arg"; }
23 StringRef getDescription() const final { return "Test inserting func args."; }
24 void runOnOperation() override {
25 auto module = getOperation();
26
27 UnknownLoc unknownLoc = UnknownLoc::get(module.getContext());
28 for (auto func : module.getOps<FunctionOpInterface>()) {
29 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args");
30 if (!inserts || inserts.empty())
31 continue;
32 SmallVector<unsigned, 4> indicesToInsert;
33 SmallVector<Type, 4> typesToInsert;
34 SmallVector<DictionaryAttr, 4> attrsToInsert;
35 SmallVector<Location, 4> locsToInsert;
36 for (auto insert : inserts.getAsRange<ArrayAttr>()) {
37 indicesToInsert.push_back(
38 cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
39 typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
40 attrsToInsert.push_back(insert.size() > 2
41 ? cast<DictionaryAttr>(insert[2])
42 : DictionaryAttr::get(&getContext()));
43 locsToInsert.push_back(insert.size() > 3
44 ? Location(cast<LocationAttr>(insert[3]))
45 : unknownLoc);
46 }
47 func->removeAttr("test.insert_args");
48 func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert,
49 locsToInsert);
50 }
51 }
52};
53
54/// This is a test pass for verifying FunctionOpInterface's insertResult method.
55struct TestFuncInsertResult
56 : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> {
57 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertResult)
58
59 StringRef getArgument() const final { return "test-func-insert-result"; }
60 StringRef getDescription() const final {
61 return "Test inserting func results.";
62 }
63 void runOnOperation() override {
64 auto module = getOperation();
65
66 for (auto func : module.getOps<FunctionOpInterface>()) {
67 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results");
68 if (!inserts || inserts.empty())
69 continue;
70 SmallVector<unsigned, 4> indicesToInsert;
71 SmallVector<Type, 4> typesToInsert;
72 SmallVector<DictionaryAttr, 4> attrsToInsert;
73 for (auto insert : inserts.getAsRange<ArrayAttr>()) {
74 indicesToInsert.push_back(
75 cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
76 typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
77 attrsToInsert.push_back(insert.size() > 2
78 ? cast<DictionaryAttr>(insert[2])
79 : DictionaryAttr::get(&getContext()));
80 }
81 func->removeAttr("test.insert_results");
82 func.insertResults(indicesToInsert, typesToInsert, attrsToInsert);
83 }
84 }
85};
86
87/// This is a test pass for verifying FunctionOpInterface's eraseArgument
88/// method.
89struct TestFuncEraseArg
90 : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
91 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseArg)
92
93 StringRef getArgument() const final { return "test-func-erase-arg"; }
94 StringRef getDescription() const final { return "Test erasing func args."; }
95 void runOnOperation() override {
96 auto module = getOperation();
97
98 for (auto func : module.getOps<FunctionOpInterface>()) {
99 BitVector indicesToErase(func.getNumArguments());
100 for (auto argIndex : llvm::seq<int>(0, func.getNumArguments()))
101 if (func.getArgAttr(argIndex, "test.erase_this_arg"))
102 indicesToErase.set(argIndex);
103 func.eraseArguments(indicesToErase);
104 }
105 }
106};
107
108/// This is a test pass for verifying FunctionOpInterface's eraseResult method.
109struct TestFuncEraseResult
110 : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
111 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseResult)
112
113 StringRef getArgument() const final { return "test-func-erase-result"; }
114 StringRef getDescription() const final {
115 return "Test erasing func results.";
116 }
117 void runOnOperation() override {
118 auto module = getOperation();
119
120 for (auto func : module.getOps<FunctionOpInterface>()) {
121 BitVector indicesToErase(func.getNumResults());
122 for (auto resultIndex : llvm::seq<int>(0, func.getNumResults()))
123 if (func.getResultAttr(resultIndex, "test.erase_this_result"))
124 indicesToErase.set(resultIndex);
125 func.eraseResults(indicesToErase);
126 }
127 }
128};
129
130/// This is a test pass for verifying FunctionOpInterface's setType method.
131struct TestFuncSetType
132 : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
133 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncSetType)
134
135 StringRef getArgument() const final { return "test-func-set-type"; }
136 StringRef getDescription() const final {
137 return "Test FunctionOpInterface::setType.";
138 }
139 void runOnOperation() override {
140 auto module = getOperation();
141 SymbolTable symbolTable(module);
142
143 for (auto func : module.getOps<FunctionOpInterface>()) {
144 auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from");
145 if (!sym)
146 continue;
147 func.setType(symbolTable.lookup<FunctionOpInterface>(sym.getValue())
148 .getFunctionType());
149 }
150 }
151};
152} // namespace
153
154namespace mlir {
155void registerTestFunc() {
156 PassRegistration<TestFuncInsertArg>();
157
158 PassRegistration<TestFuncInsertResult>();
159
160 PassRegistration<TestFuncEraseArg>();
161
162 PassRegistration<TestFuncEraseResult>();
163
164 PassRegistration<TestFuncSetType>();
165}
166} // namespace mlir
167

source code of mlir/test/lib/IR/TestFunc.cpp