1 | //===- FunctionSupport.cpp - Utility types for function-like ops ----------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Interfaces/FunctionInterfaces.h" |
10 | |
11 | using namespace mlir; |
12 | |
13 | //===----------------------------------------------------------------------===// |
14 | // Tablegen Interface Definitions |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | #include "mlir/Interfaces/FunctionInterfaces.cpp.inc" |
18 | |
19 | //===----------------------------------------------------------------------===// |
20 | // Function Arguments and Results. |
21 | //===----------------------------------------------------------------------===// |
22 | |
23 | static bool isEmptyAttrDict(Attribute attr) { |
24 | return llvm::cast<DictionaryAttr>(attr).empty(); |
25 | } |
26 | |
27 | DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op, |
28 | unsigned index) { |
29 | ArrayAttr attrs = op.getArgAttrsAttr(); |
30 | DictionaryAttr argAttrs = |
31 | attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr(); |
32 | return argAttrs; |
33 | } |
34 | |
35 | DictionaryAttr |
36 | function_interface_impl::getResultAttrDict(FunctionOpInterface op, |
37 | unsigned index) { |
38 | ArrayAttr attrs = op.getResAttrsAttr(); |
39 | DictionaryAttr resAttrs = |
40 | attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr(); |
41 | return resAttrs; |
42 | } |
43 | |
44 | ArrayRef<NamedAttribute> |
45 | function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) { |
46 | auto argDict = getArgAttrDict(op, index); |
47 | return argDict ? argDict.getValue() : std::nullopt; |
48 | } |
49 | |
50 | ArrayRef<NamedAttribute> |
51 | function_interface_impl::getResultAttrs(FunctionOpInterface op, |
52 | unsigned index) { |
53 | auto resultDict = getResultAttrDict(op, index); |
54 | return resultDict ? resultDict.getValue() : std::nullopt; |
55 | } |
56 | |
57 | /// Get either the argument or result attributes array. |
58 | template <bool isArg> |
59 | static ArrayAttr getArgResAttrs(FunctionOpInterface op) { |
60 | if constexpr (isArg) |
61 | return op.getArgAttrsAttr(); |
62 | else |
63 | return op.getResAttrsAttr(); |
64 | } |
65 | |
66 | /// Set either the argument or result attributes array. |
67 | template <bool isArg> |
68 | static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) { |
69 | if constexpr (isArg) |
70 | op.setArgAttrsAttr(attrs); |
71 | else |
72 | op.setResAttrsAttr(attrs); |
73 | } |
74 | |
75 | /// Erase either the argument or result attributes array. |
76 | template <bool isArg> |
77 | static void removeArgResAttrs(FunctionOpInterface op) { |
78 | if constexpr (isArg) |
79 | op.removeArgAttrsAttr(); |
80 | else |
81 | op.removeResAttrsAttr(); |
82 | } |
83 | |
84 | /// Set all of the argument or result attribute dictionaries for a function. |
85 | template <bool isArg> |
86 | static void setAllArgResAttrDicts(FunctionOpInterface op, |
87 | ArrayRef<Attribute> attrs) { |
88 | if (llvm::all_of(Range&: attrs, P: isEmptyAttrDict)) |
89 | removeArgResAttrs<isArg>(op); |
90 | else |
91 | setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs)); |
92 | } |
93 | |
94 | void function_interface_impl::setAllArgAttrDicts( |
95 | FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) { |
96 | setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size())); |
97 | } |
98 | |
99 | void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op, |
100 | ArrayRef<Attribute> attrs) { |
101 | auto wrappedAttrs = llvm::map_range(C&: attrs, F: [op](Attribute attr) -> Attribute { |
102 | return !attr ? DictionaryAttr::get(op->getContext()) : attr; |
103 | }); |
104 | setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(Range&: wrappedAttrs)); |
105 | } |
106 | |
107 | void function_interface_impl::setAllResultAttrDicts( |
108 | FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) { |
109 | setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size())); |
110 | } |
111 | |
112 | void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op, |
113 | ArrayRef<Attribute> attrs) { |
114 | auto wrappedAttrs = llvm::map_range(C&: attrs, F: [op](Attribute attr) -> Attribute { |
115 | return !attr ? DictionaryAttr::get(op->getContext()) : attr; |
116 | }); |
117 | setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(Range&: wrappedAttrs)); |
118 | } |
119 | |
120 | /// Update the given index into an argument or result attribute dictionary. |
121 | template <bool isArg> |
122 | static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices, |
123 | unsigned index, DictionaryAttr attrs) { |
124 | ArrayAttr allAttrs = getArgResAttrs<isArg>(op); |
125 | if (!allAttrs) { |
126 | if (attrs.empty()) |
127 | return; |
128 | |
129 | // If this attribute is not empty, we need to create a new attribute array. |
130 | SmallVector<Attribute, 8> newAttrs(numTotalIndices, |
131 | DictionaryAttr::get(op->getContext())); |
132 | newAttrs[index] = attrs; |
133 | setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs)); |
134 | return; |
135 | } |
136 | // Check to see if the attribute is different from what we already have. |
137 | if (allAttrs[index] == attrs) |
138 | return; |
139 | |
140 | // If it is, check to see if the attribute array would now contain only empty |
141 | // dictionaries. |
142 | ArrayRef<Attribute> rawAttrArray = allAttrs.getValue(); |
143 | if (attrs.empty() && |
144 | llvm::all_of(Range: rawAttrArray.take_front(N: index), P: isEmptyAttrDict) && |
145 | llvm::all_of(Range: rawAttrArray.drop_front(N: index + 1), P: isEmptyAttrDict)) |
146 | return removeArgResAttrs<isArg>(op); |
147 | |
148 | // Otherwise, create a new attribute array with the updated dictionary. |
149 | SmallVector<Attribute, 8> newAttrs(rawAttrArray); |
150 | newAttrs[index] = attrs; |
151 | setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs)); |
152 | } |
153 | |
154 | void function_interface_impl::setArgAttrs(FunctionOpInterface op, |
155 | unsigned index, |
156 | ArrayRef<NamedAttribute> attributes) { |
157 | assert(index < op.getNumArguments() && "invalid argument number"); |
158 | return setArgResAttrDict</*isArg=*/true>( |
159 | op, op.getNumArguments(), index, |
160 | DictionaryAttr::get(op->getContext(), attributes)); |
161 | } |
162 | |
163 | void function_interface_impl::setArgAttrs(FunctionOpInterface op, |
164 | unsigned index, |
165 | DictionaryAttr attributes) { |
166 | return setArgResAttrDict</*isArg=*/true>( |
167 | op, op.getNumArguments(), index, |
168 | attributes ? attributes : DictionaryAttr::get(op->getContext())); |
169 | } |
170 | |
171 | void function_interface_impl::setResultAttrs( |
172 | FunctionOpInterface op, unsigned index, |
173 | ArrayRef<NamedAttribute> attributes) { |
174 | assert(index < op.getNumResults() && "invalid result number"); |
175 | return setArgResAttrDict</*isArg=*/false>( |
176 | op, op.getNumResults(), index, |
177 | DictionaryAttr::get(op->getContext(), attributes)); |
178 | } |
179 | |
180 | void function_interface_impl::setResultAttrs(FunctionOpInterface op, |
181 | unsigned index, |
182 | DictionaryAttr attributes) { |
183 | assert(index < op.getNumResults() && "invalid result number"); |
184 | return setArgResAttrDict</*isArg=*/false>( |
185 | op, op.getNumResults(), index, |
186 | attributes ? attributes : DictionaryAttr::get(op->getContext())); |
187 | } |
188 | |
189 | void function_interface_impl::insertFunctionArguments( |
190 | FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes, |
191 | ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs, |
192 | unsigned originalNumArgs, Type newType) { |
193 | assert(argIndices.size() == argTypes.size()); |
194 | assert(argIndices.size() == argAttrs.size() || argAttrs.empty()); |
195 | assert(argIndices.size() == argLocs.size()); |
196 | if (argIndices.empty()) |
197 | return; |
198 | |
199 | // There are 3 things that need to be updated: |
200 | // - Function type. |
201 | // - Arg attrs. |
202 | // - Block arguments of entry block, if not empty. |
203 | |
204 | // Update the argument attributes of the function. |
205 | ArrayAttr oldArgAttrs = op.getArgAttrsAttr(); |
206 | if (oldArgAttrs || !argAttrs.empty()) { |
207 | SmallVector<DictionaryAttr, 4> newArgAttrs; |
208 | newArgAttrs.reserve(originalNumArgs + argIndices.size()); |
209 | unsigned oldIdx = 0; |
210 | auto migrate = [&](unsigned untilIdx) { |
211 | if (!oldArgAttrs) { |
212 | newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx); |
213 | } else { |
214 | auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>(); |
215 | newArgAttrs.append(oldArgAttrRange.begin() + oldIdx, |
216 | oldArgAttrRange.begin() + untilIdx); |
217 | } |
218 | oldIdx = untilIdx; |
219 | }; |
220 | for (unsigned i = 0, e = argIndices.size(); i < e; ++i) { |
221 | migrate(argIndices[i]); |
222 | newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]); |
223 | } |
224 | migrate(originalNumArgs); |
225 | setAllArgAttrDicts(op, newArgAttrs); |
226 | } |
227 | |
228 | // Update the function type. |
229 | op.setFunctionTypeAttr(TypeAttr::get(newType)); |
230 | |
231 | // Update entry block arguments, if not empty. |
232 | if (!op.isExternal()) { |
233 | Block &entry = op->getRegion(0).front(); |
234 | for (unsigned i = 0, e = argIndices.size(); i < e; ++i) |
235 | entry.insertArgument(index: argIndices[i] + i, type: argTypes[i], loc: argLocs[i]); |
236 | } |
237 | } |
238 | |
239 | void function_interface_impl::insertFunctionResults( |
240 | FunctionOpInterface op, ArrayRef<unsigned> resultIndices, |
241 | TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs, |
242 | unsigned originalNumResults, Type newType) { |
243 | assert(resultIndices.size() == resultTypes.size()); |
244 | assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty()); |
245 | if (resultIndices.empty()) |
246 | return; |
247 | |
248 | // There are 2 things that need to be updated: |
249 | // - Function type. |
250 | // - Result attrs. |
251 | |
252 | // Update the result attributes of the function. |
253 | ArrayAttr oldResultAttrs = op.getResAttrsAttr(); |
254 | if (oldResultAttrs || !resultAttrs.empty()) { |
255 | SmallVector<DictionaryAttr, 4> newResultAttrs; |
256 | newResultAttrs.reserve(originalNumResults + resultIndices.size()); |
257 | unsigned oldIdx = 0; |
258 | auto migrate = [&](unsigned untilIdx) { |
259 | if (!oldResultAttrs) { |
260 | newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx); |
261 | } else { |
262 | auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>(); |
263 | newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx, |
264 | oldResultAttrsRange.begin() + untilIdx); |
265 | } |
266 | oldIdx = untilIdx; |
267 | }; |
268 | for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) { |
269 | migrate(resultIndices[i]); |
270 | newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{} |
271 | : resultAttrs[i]); |
272 | } |
273 | migrate(originalNumResults); |
274 | setAllResultAttrDicts(op, newResultAttrs); |
275 | } |
276 | |
277 | // Update the function type. |
278 | op.setFunctionTypeAttr(TypeAttr::get(newType)); |
279 | } |
280 | |
281 | void function_interface_impl::eraseFunctionArguments( |
282 | FunctionOpInterface op, const BitVector &argIndices, Type newType) { |
283 | // There are 3 things that need to be updated: |
284 | // - Function type. |
285 | // - Arg attrs. |
286 | // - Block arguments of entry block, if not empty. |
287 | |
288 | // Update the argument attributes of the function. |
289 | if (ArrayAttr argAttrs = op.getArgAttrsAttr()) { |
290 | SmallVector<DictionaryAttr, 4> newArgAttrs; |
291 | newArgAttrs.reserve(argAttrs.size()); |
292 | for (unsigned i = 0, e = argIndices.size(); i < e; ++i) |
293 | if (!argIndices[i]) |
294 | newArgAttrs.emplace_back(llvm::cast<DictionaryAttr>(argAttrs[i])); |
295 | setAllArgAttrDicts(op, newArgAttrs); |
296 | } |
297 | |
298 | // Update the function type. |
299 | op.setFunctionTypeAttr(TypeAttr::get(newType)); |
300 | |
301 | // Update entry block arguments, if not empty. |
302 | if (!op.isExternal()) { |
303 | Block &entry = op->getRegion(0).front(); |
304 | entry.eraseArguments(eraseIndices: argIndices); |
305 | } |
306 | } |
307 | |
308 | void function_interface_impl::eraseFunctionResults( |
309 | FunctionOpInterface op, const BitVector &resultIndices, Type newType) { |
310 | // There are 2 things that need to be updated: |
311 | // - Function type. |
312 | // - Result attrs. |
313 | |
314 | // Update the result attributes of the function. |
315 | if (ArrayAttr resAttrs = op.getResAttrsAttr()) { |
316 | SmallVector<DictionaryAttr, 4> newResultAttrs; |
317 | newResultAttrs.reserve(resAttrs.size()); |
318 | for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) |
319 | if (!resultIndices[i]) |
320 | newResultAttrs.emplace_back(llvm::cast<DictionaryAttr>(resAttrs[i])); |
321 | setAllResultAttrDicts(op, newResultAttrs); |
322 | } |
323 | |
324 | // Update the function type. |
325 | op.setFunctionTypeAttr(TypeAttr::get(newType)); |
326 | } |
327 | |
328 | //===----------------------------------------------------------------------===// |
329 | // Function type signature. |
330 | //===----------------------------------------------------------------------===// |
331 | |
332 | void function_interface_impl::setFunctionType(FunctionOpInterface op, |
333 | Type newType) { |
334 | unsigned oldNumArgs = op.getNumArguments(); |
335 | unsigned oldNumResults = op.getNumResults(); |
336 | op.setFunctionTypeAttr(TypeAttr::get(newType)); |
337 | unsigned newNumArgs = op.getNumArguments(); |
338 | unsigned newNumResults = op.getNumResults(); |
339 | |
340 | // Functor used to update the argument and result attributes of the function. |
341 | auto emptyDict = DictionaryAttr::get(op.getContext()); |
342 | auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) { |
343 | constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>; |
344 | |
345 | if (oldCount == newCount) |
346 | return; |
347 | // The new type has no arguments/results, just drop the attribute. |
348 | if (newCount == 0) |
349 | return removeArgResAttrs<isArgVal>(op); |
350 | ArrayAttr attrs = getArgResAttrs<isArgVal>(op); |
351 | if (!attrs) |
352 | return; |
353 | |
354 | // The new type has less arguments/results, take the first N attributes. |
355 | if (newCount < oldCount) |
356 | return setAllArgResAttrDicts<isArgVal>( |
357 | op, attrs.getValue().take_front(newCount)); |
358 | |
359 | // Otherwise, the new type has more arguments/results. Initialize the new |
360 | // arguments/results with empty dictionary attributes. |
361 | SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end()); |
362 | newAttrs.resize(newCount, emptyDict); |
363 | setAllArgResAttrDicts<isArgVal>(op, newAttrs); |
364 | }; |
365 | |
366 | // Update the argument and result attributes. |
367 | updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs); |
368 | updateAttrFn(std::false_type{}, oldNumResults, newNumResults); |
369 | } |
370 |
Definitions
- isEmptyAttrDict
- getArgAttrDict
- getResultAttrDict
- getArgAttrs
- getResultAttrs
- getArgResAttrs
- setArgResAttrs
- removeArgResAttrs
- setAllArgResAttrDicts
- setAllArgAttrDicts
- setAllArgAttrDicts
- setAllResultAttrDicts
- setAllResultAttrDicts
- setArgResAttrDict
- setArgAttrs
- setArgAttrs
- setResultAttrs
- setResultAttrs
- insertFunctionArguments
- insertFunctionResults
- eraseFunctionArguments
- eraseFunctionResults
Improve your Profiling and Debugging skills
Find out more