1//===- TestFunc.cpp - Pass to test helpers on function utilities ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/BuiltinOps.h"
10#include "mlir/Interfaces/FunctionInterfaces.h"
11#include "mlir/Pass/Pass.h"
12
13using namespace mlir;
14
15namespace {
16/// This is a test pass for verifying FunctionOpInterface's insertArgument
17/// method.
18struct TestFuncInsertArg
19 : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> {
20 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertArg)
21
22 StringRef getArgument() const final { return "test-func-insert-arg"; }
23 StringRef getDescription() const final { return "Test inserting func args."; }
24 void runOnOperation() override {
25 auto module = getOperation();
26
27 UnknownLoc unknownLoc = UnknownLoc::get(module.getContext());
28 for (auto func : module.getOps<FunctionOpInterface>()) {
29 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args");
30 if (!inserts || inserts.empty())
31 continue;
32 SmallVector<unsigned, 4> indicesToInsert;
33 SmallVector<Type, 4> typesToInsert;
34 SmallVector<DictionaryAttr, 4> attrsToInsert;
35 SmallVector<Location, 4> locsToInsert;
36 for (auto insert : inserts.getAsRange<ArrayAttr>()) {
37 indicesToInsert.push_back(
38 cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
39 typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
40 attrsToInsert.push_back(insert.size() > 2
41 ? cast<DictionaryAttr>(insert[2])
42 : DictionaryAttr::get(&getContext()));
43 locsToInsert.push_back(insert.size() > 3
44 ? Location(cast<LocationAttr>(insert[3]))
45 : unknownLoc);
46 }
47 func->removeAttr("test.insert_args");
48 if (succeeded(func.insertArguments(indicesToInsert, typesToInsert,
49 attrsToInsert, locsToInsert)))
50 continue;
51
52 emitError(func->getLoc()) << "failed to insert arguments";
53 return signalPassFailure();
54 }
55 }
56};
57
58/// This is a test pass for verifying FunctionOpInterface's insertResult method.
59struct TestFuncInsertResult
60 : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> {
61 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncInsertResult)
62
63 StringRef getArgument() const final { return "test-func-insert-result"; }
64 StringRef getDescription() const final {
65 return "Test inserting func results.";
66 }
67 void runOnOperation() override {
68 auto module = getOperation();
69
70 for (auto func : module.getOps<FunctionOpInterface>()) {
71 auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results");
72 if (!inserts || inserts.empty())
73 continue;
74 SmallVector<unsigned, 4> indicesToInsert;
75 SmallVector<Type, 4> typesToInsert;
76 SmallVector<DictionaryAttr, 4> attrsToInsert;
77 for (auto insert : inserts.getAsRange<ArrayAttr>()) {
78 indicesToInsert.push_back(
79 cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
80 typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
81 attrsToInsert.push_back(insert.size() > 2
82 ? cast<DictionaryAttr>(insert[2])
83 : DictionaryAttr::get(&getContext()));
84 }
85 func->removeAttr("test.insert_results");
86 if (succeeded(func.insertResults(indicesToInsert, typesToInsert,
87 attrsToInsert)))
88 continue;
89
90 emitError(func->getLoc()) << "failed to insert results";
91 return signalPassFailure();
92 }
93 }
94};
95
96/// This is a test pass for verifying FunctionOpInterface's eraseArgument
97/// method.
98struct TestFuncEraseArg
99 : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
100 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseArg)
101
102 StringRef getArgument() const final { return "test-func-erase-arg"; }
103 StringRef getDescription() const final { return "Test erasing func args."; }
104 void runOnOperation() override {
105 auto module = getOperation();
106
107 for (auto func : module.getOps<FunctionOpInterface>()) {
108 BitVector indicesToErase(func.getNumArguments());
109 for (auto argIndex : llvm::seq<int>(0, func.getNumArguments()))
110 if (func.getArgAttr(argIndex, "test.erase_this_arg"))
111 indicesToErase.set(argIndex);
112 if (succeeded(func.eraseArguments(indicesToErase)))
113 continue;
114 emitError(func->getLoc()) << "failed to erase arguments";
115 return signalPassFailure();
116 }
117 }
118};
119
120/// This is a test pass for verifying FunctionOpInterface's eraseResult method.
121struct TestFuncEraseResult
122 : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
123 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncEraseResult)
124
125 StringRef getArgument() const final { return "test-func-erase-result"; }
126 StringRef getDescription() const final {
127 return "Test erasing func results.";
128 }
129 void runOnOperation() override {
130 auto module = getOperation();
131
132 for (auto func : module.getOps<FunctionOpInterface>()) {
133 BitVector indicesToErase(func.getNumResults());
134 for (auto resultIndex : llvm::seq<int>(0, func.getNumResults()))
135 if (func.getResultAttr(resultIndex, "test.erase_this_result"))
136 indicesToErase.set(resultIndex);
137 if (succeeded(func.eraseResults(indicesToErase)))
138 continue;
139 emitError(func->getLoc()) << "failed to erase results";
140 return signalPassFailure();
141 }
142 }
143};
144
145/// This is a test pass for verifying FunctionOpInterface's setType method.
146struct TestFuncSetType
147 : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
148 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFuncSetType)
149
150 StringRef getArgument() const final { return "test-func-set-type"; }
151 StringRef getDescription() const final {
152 return "Test FunctionOpInterface::setType.";
153 }
154 void runOnOperation() override {
155 auto module = getOperation();
156 SymbolTable symbolTable(module);
157
158 for (auto func : module.getOps<FunctionOpInterface>()) {
159 auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from");
160 if (!sym)
161 continue;
162 func.setType(symbolTable.lookup<FunctionOpInterface>(sym.getValue())
163 .getFunctionType());
164 }
165 }
166};
167} // namespace
168
169namespace mlir {
170void registerTestFunc() {
171 PassRegistration<TestFuncInsertArg>();
172
173 PassRegistration<TestFuncInsertResult>();
174
175 PassRegistration<TestFuncEraseArg>();
176
177 PassRegistration<TestFuncEraseResult>();
178
179 PassRegistration<TestFuncSetType>();
180}
181} // namespace mlir
182

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/IR/TestFunc.cpp