//===- FunctionInterfaces.cpp - Utility types for function-like ops -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Interfaces/FunctionInterfaces.h"
10
11using namespace mlir;
12
13//===----------------------------------------------------------------------===//
14// Tablegen Interface Definitions
15//===----------------------------------------------------------------------===//
16
17#include "mlir/Interfaces/FunctionInterfaces.cpp.inc"
18
19//===----------------------------------------------------------------------===//
20// Function Arguments and Results.
21//===----------------------------------------------------------------------===//
22
23static bool isEmptyAttrDict(Attribute attr) {
24 return llvm::cast<DictionaryAttr>(attr).empty();
25}
26
27DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op,
28 unsigned index) {
29 ArrayAttr attrs = op.getArgAttrsAttr();
30 DictionaryAttr argAttrs =
31 attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
32 return argAttrs;
33}
34
35DictionaryAttr
36function_interface_impl::getResultAttrDict(FunctionOpInterface op,
37 unsigned index) {
38 ArrayAttr attrs = op.getResAttrsAttr();
39 DictionaryAttr resAttrs =
40 attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
41 return resAttrs;
42}
43
44ArrayRef<NamedAttribute>
45function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) {
46 auto argDict = getArgAttrDict(op, index);
47 return argDict ? argDict.getValue() : std::nullopt;
48}
49
50ArrayRef<NamedAttribute>
51function_interface_impl::getResultAttrs(FunctionOpInterface op,
52 unsigned index) {
53 auto resultDict = getResultAttrDict(op, index);
54 return resultDict ? resultDict.getValue() : std::nullopt;
55}
56
57/// Get either the argument or result attributes array.
58template <bool isArg>
59static ArrayAttr getArgResAttrs(FunctionOpInterface op) {
60 if constexpr (isArg)
61 return op.getArgAttrsAttr();
62 else
63 return op.getResAttrsAttr();
64}
65
66/// Set either the argument or result attributes array.
67template <bool isArg>
68static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) {
69 if constexpr (isArg)
70 op.setArgAttrsAttr(attrs);
71 else
72 op.setResAttrsAttr(attrs);
73}
74
75/// Erase either the argument or result attributes array.
76template <bool isArg>
77static void removeArgResAttrs(FunctionOpInterface op) {
78 if constexpr (isArg)
79 op.removeArgAttrsAttr();
80 else
81 op.removeResAttrsAttr();
82}
83
84/// Set all of the argument or result attribute dictionaries for a function.
85template <bool isArg>
86static void setAllArgResAttrDicts(FunctionOpInterface op,
87 ArrayRef<Attribute> attrs) {
88 if (llvm::all_of(Range&: attrs, P: isEmptyAttrDict))
89 removeArgResAttrs<isArg>(op);
90 else
91 setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs));
92}
93
94void function_interface_impl::setAllArgAttrDicts(
95 FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
96 setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
97}
98
99void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op,
100 ArrayRef<Attribute> attrs) {
101 auto wrappedAttrs = llvm::map_range(C&: attrs, F: [op](Attribute attr) -> Attribute {
102 return !attr ? DictionaryAttr::get(op->getContext()) : attr;
103 });
104 setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(Range&: wrappedAttrs));
105}
106
107void function_interface_impl::setAllResultAttrDicts(
108 FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
109 setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
110}
111
112void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op,
113 ArrayRef<Attribute> attrs) {
114 auto wrappedAttrs = llvm::map_range(C&: attrs, F: [op](Attribute attr) -> Attribute {
115 return !attr ? DictionaryAttr::get(op->getContext()) : attr;
116 });
117 setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(Range&: wrappedAttrs));
118}
119
120/// Update the given index into an argument or result attribute dictionary.
121template <bool isArg>
122static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
123 unsigned index, DictionaryAttr attrs) {
124 ArrayAttr allAttrs = getArgResAttrs<isArg>(op);
125 if (!allAttrs) {
126 if (attrs.empty())
127 return;
128
129 // If this attribute is not empty, we need to create a new attribute array.
130 SmallVector<Attribute, 8> newAttrs(numTotalIndices,
131 DictionaryAttr::get(op->getContext()));
132 newAttrs[index] = attrs;
133 setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
134 return;
135 }
136 // Check to see if the attribute is different from what we already have.
137 if (allAttrs[index] == attrs)
138 return;
139
140 // If it is, check to see if the attribute array would now contain only empty
141 // dictionaries.
142 ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
143 if (attrs.empty() &&
144 llvm::all_of(Range: rawAttrArray.take_front(N: index), P: isEmptyAttrDict) &&
145 llvm::all_of(Range: rawAttrArray.drop_front(N: index + 1), P: isEmptyAttrDict))
146 return removeArgResAttrs<isArg>(op);
147
148 // Otherwise, create a new attribute array with the updated dictionary.
149 SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
150 newAttrs[index] = attrs;
151 setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
152}
153
154void function_interface_impl::setArgAttrs(FunctionOpInterface op,
155 unsigned index,
156 ArrayRef<NamedAttribute> attributes) {
157 assert(index < op.getNumArguments() && "invalid argument number");
158 return setArgResAttrDict</*isArg=*/true>(
159 op, op.getNumArguments(), index,
160 DictionaryAttr::get(op->getContext(), attributes));
161}
162
163void function_interface_impl::setArgAttrs(FunctionOpInterface op,
164 unsigned index,
165 DictionaryAttr attributes) {
166 return setArgResAttrDict</*isArg=*/true>(
167 op, op.getNumArguments(), index,
168 attributes ? attributes : DictionaryAttr::get(op->getContext()));
169}
170
171void function_interface_impl::setResultAttrs(
172 FunctionOpInterface op, unsigned index,
173 ArrayRef<NamedAttribute> attributes) {
174 assert(index < op.getNumResults() && "invalid result number");
175 return setArgResAttrDict</*isArg=*/false>(
176 op, op.getNumResults(), index,
177 DictionaryAttr::get(op->getContext(), attributes));
178}
179
180void function_interface_impl::setResultAttrs(FunctionOpInterface op,
181 unsigned index,
182 DictionaryAttr attributes) {
183 assert(index < op.getNumResults() && "invalid result number");
184 return setArgResAttrDict</*isArg=*/false>(
185 op, op.getNumResults(), index,
186 attributes ? attributes : DictionaryAttr::get(op->getContext()));
187}
188
189void function_interface_impl::insertFunctionArguments(
190 FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
191 ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
192 unsigned originalNumArgs, Type newType) {
193 assert(argIndices.size() == argTypes.size());
194 assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
195 assert(argIndices.size() == argLocs.size());
196 if (argIndices.empty())
197 return;
198
199 // There are 3 things that need to be updated:
200 // - Function type.
201 // - Arg attrs.
202 // - Block arguments of entry block.
203 Block &entry = op->getRegion(0).front();
204
205 // Update the argument attributes of the function.
206 ArrayAttr oldArgAttrs = op.getArgAttrsAttr();
207 if (oldArgAttrs || !argAttrs.empty()) {
208 SmallVector<DictionaryAttr, 4> newArgAttrs;
209 newArgAttrs.reserve(originalNumArgs + argIndices.size());
210 unsigned oldIdx = 0;
211 auto migrate = [&](unsigned untilIdx) {
212 if (!oldArgAttrs) {
213 newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
214 } else {
215 auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
216 newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
217 oldArgAttrRange.begin() + untilIdx);
218 }
219 oldIdx = untilIdx;
220 };
221 for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
222 migrate(argIndices[i]);
223 newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
224 }
225 migrate(originalNumArgs);
226 setAllArgAttrDicts(op, newArgAttrs);
227 }
228
229 // Update the function type and any entry block arguments.
230 op.setFunctionTypeAttr(TypeAttr::get(newType));
231 for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
232 entry.insertArgument(index: argIndices[i] + i, type: argTypes[i], loc: argLocs[i]);
233}
234
235void function_interface_impl::insertFunctionResults(
236 FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
237 TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
238 unsigned originalNumResults, Type newType) {
239 assert(resultIndices.size() == resultTypes.size());
240 assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
241 if (resultIndices.empty())
242 return;
243
244 // There are 2 things that need to be updated:
245 // - Function type.
246 // - Result attrs.
247
248 // Update the result attributes of the function.
249 ArrayAttr oldResultAttrs = op.getResAttrsAttr();
250 if (oldResultAttrs || !resultAttrs.empty()) {
251 SmallVector<DictionaryAttr, 4> newResultAttrs;
252 newResultAttrs.reserve(originalNumResults + resultIndices.size());
253 unsigned oldIdx = 0;
254 auto migrate = [&](unsigned untilIdx) {
255 if (!oldResultAttrs) {
256 newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
257 } else {
258 auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
259 newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
260 oldResultAttrsRange.begin() + untilIdx);
261 }
262 oldIdx = untilIdx;
263 };
264 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
265 migrate(resultIndices[i]);
266 newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
267 : resultAttrs[i]);
268 }
269 migrate(originalNumResults);
270 setAllResultAttrDicts(op, newResultAttrs);
271 }
272
273 // Update the function type.
274 op.setFunctionTypeAttr(TypeAttr::get(newType));
275}
276
277void function_interface_impl::eraseFunctionArguments(
278 FunctionOpInterface op, const BitVector &argIndices, Type newType) {
279 // There are 3 things that need to be updated:
280 // - Function type.
281 // - Arg attrs.
282 // - Block arguments of entry block.
283 Block &entry = op->getRegion(0).front();
284
285 // Update the argument attributes of the function.
286 if (ArrayAttr argAttrs = op.getArgAttrsAttr()) {
287 SmallVector<DictionaryAttr, 4> newArgAttrs;
288 newArgAttrs.reserve(argAttrs.size());
289 for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
290 if (!argIndices[i])
291 newArgAttrs.emplace_back(llvm::cast<DictionaryAttr>(argAttrs[i]));
292 setAllArgAttrDicts(op, newArgAttrs);
293 }
294
295 // Update the function type and any entry block arguments.
296 op.setFunctionTypeAttr(TypeAttr::get(newType));
297 entry.eraseArguments(eraseIndices: argIndices);
298}
299
300void function_interface_impl::eraseFunctionResults(
301 FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
302 // There are 2 things that need to be updated:
303 // - Function type.
304 // - Result attrs.
305
306 // Update the result attributes of the function.
307 if (ArrayAttr resAttrs = op.getResAttrsAttr()) {
308 SmallVector<DictionaryAttr, 4> newResultAttrs;
309 newResultAttrs.reserve(resAttrs.size());
310 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
311 if (!resultIndices[i])
312 newResultAttrs.emplace_back(llvm::cast<DictionaryAttr>(resAttrs[i]));
313 setAllResultAttrDicts(op, newResultAttrs);
314 }
315
316 // Update the function type.
317 op.setFunctionTypeAttr(TypeAttr::get(newType));
318}
319
320//===----------------------------------------------------------------------===//
321// Function type signature.
322//===----------------------------------------------------------------------===//
323
324void function_interface_impl::setFunctionType(FunctionOpInterface op,
325 Type newType) {
326 unsigned oldNumArgs = op.getNumArguments();
327 unsigned oldNumResults = op.getNumResults();
328 op.setFunctionTypeAttr(TypeAttr::get(newType));
329 unsigned newNumArgs = op.getNumArguments();
330 unsigned newNumResults = op.getNumResults();
331
332 // Functor used to update the argument and result attributes of the function.
333 auto emptyDict = DictionaryAttr::get(op.getContext());
334 auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) {
335 constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>;
336
337 if (oldCount == newCount)
338 return;
339 // The new type has no arguments/results, just drop the attribute.
340 if (newCount == 0)
341 return removeArgResAttrs<isArgVal>(op);
342 ArrayAttr attrs = getArgResAttrs<isArgVal>(op);
343 if (!attrs)
344 return;
345
346 // The new type has less arguments/results, take the first N attributes.
347 if (newCount < oldCount)
348 return setAllArgResAttrDicts<isArgVal>(
349 op, attrs.getValue().take_front(newCount));
350
351 // Otherwise, the new type has more arguments/results. Initialize the new
352 // arguments/results with empty dictionary attributes.
353 SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
354 newAttrs.resize(newCount, emptyDict);
355 setAllArgResAttrDicts<isArgVal>(op, newAttrs);
356 };
357
358 // Update the argument and result attributes.
359 updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs);
360 updateAttrFn(std::false_type{}, oldNumResults, newNumResults);
361}
362

source code of mlir/lib/Interfaces/FunctionInterfaces.cpp