//===- FunctionInterfaces.h - Function Interfaces ---------------*- C++ -*-===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file declares helper functions shared by operations that implement
// FunctionOpInterface (i.e. function-like constructs), and includes the
// tablegen-generated declarations of the interface itself.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_IR_FUNCTIONINTERFACES_H |
15 | #define MLIR_IR_FUNCTIONINTERFACES_H |
16 | |
17 | #include "mlir/IR/Builders.h" |
18 | #include "mlir/IR/BuiltinTypes.h" |
19 | #include "mlir/IR/OpDefinition.h" |
20 | #include "mlir/IR/SymbolTable.h" |
21 | #include "mlir/IR/TypeUtilities.h" |
22 | #include "mlir/Interfaces/CallInterfaces.h" |
23 | #include "llvm/ADT/BitVector.h" |
24 | #include "llvm/ADT/SmallString.h" |
25 | |
namespace mlir {
class FunctionOpInterface;

// Out-of-line helpers used to implement FunctionOpInterface methods. They
// operate on the type-erased interface handle; the templated helpers further
// below operate on the concrete op type instead.
namespace function_interface_impl {

/// Returns the dictionary attribute corresponding to the argument at 'index'.
/// If there are no argument attributes at 'index', a null attribute is
/// returned.
DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index);

/// Returns the dictionary attribute corresponding to the result at 'index'.
/// If there are no result attributes at 'index', a null attribute is
/// returned.
DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index);

/// Return all of the attributes for the argument at 'index'.
/// NOTE(review): presumably an empty range when the argument carries no
/// attributes — confirm against the implementation.
ArrayRef<NamedAttribute> getArgAttrs(FunctionOpInterface op, unsigned index);

/// Return all of the attributes for the result at 'index'.
/// NOTE(review): presumably an empty range when the result carries no
/// attributes — confirm against the implementation.
ArrayRef<NamedAttribute> getResultAttrs(FunctionOpInterface op, unsigned index);

/// Set all of the argument or result attribute dictionaries for a function. The
/// size of `attrs` is expected to match the number of arguments/results of the
/// given `op`.
///
/// The `ArrayRef<Attribute>` overloads accept untyped entries. Note that
/// `verifyTrait` in this file rejects any entry that is not a DictionaryAttr,
/// so callers should only pass dictionaries through them.
void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs);
void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);
void setAllResultAttrDicts(FunctionOpInterface op,
                           ArrayRef<DictionaryAttr> attrs);
void setAllResultAttrDicts(FunctionOpInterface op, ArrayRef<Attribute> attrs);

/// Insert the specified arguments and update the function type attribute.
///
/// `argTypes`, `argAttrs` and `argLocs` describe the inserted arguments and
/// are presumably parallel to `argIndices`, which gives each one's insertion
/// position. `originalNumArgs` is the argument count before insertion and
/// `newType` is the function type to record afterwards.
/// NOTE(review): whether `argIndices` refers to the original or the updated
/// argument list, and whether it must be sorted, is not visible here —
/// confirm against the implementation.
void insertFunctionArguments(FunctionOpInterface op,
                             ArrayRef<unsigned> argIndices, TypeRange argTypes,
                             ArrayRef<DictionaryAttr> argAttrs,
                             ArrayRef<Location> argLocs,
                             unsigned originalNumArgs, Type newType);

/// Insert the specified results and update the function type attribute.
///
/// `resultTypes` and `resultAttrs` describe the inserted results and are
/// presumably parallel to `resultIndices`. `originalNumResults` is the result
/// count before insertion and `newType` is the function type to record
/// afterwards.
void insertFunctionResults(FunctionOpInterface op,
                           ArrayRef<unsigned> resultIndices,
                           TypeRange resultTypes,
                           ArrayRef<DictionaryAttr> resultAttrs,
                           unsigned originalNumResults, Type newType);

/// Erase the specified arguments and update the function type attribute.
/// `argIndices` presumably has one bit per existing argument, set for each
/// argument to erase; `newType` is the function type to record afterwards.
void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices,
                            Type newType);

/// Erase the specified results and update the function type attribute.
/// `resultIndices` presumably has one bit per existing result, set for each
/// result to erase; `newType` is the function type to record afterwards.
void eraseFunctionResults(FunctionOpInterface op,
                          const BitVector &resultIndices, Type newType);

/// Set a FunctionOpInterface operation's type signature to `newType`.
void setFunctionType(FunctionOpInterface op, Type newType);
80 | |
81 | //===----------------------------------------------------------------------===// |
82 | // Function Argument Attribute. |
83 | //===----------------------------------------------------------------------===// |
84 | |
/// Set the attributes held by the argument at 'index', given either as a list
/// of named attributes or as a dictionary. The `setArgAttr`/`removeArgAttr`
/// helpers in this file pass the complete, merged dictionary, which indicates
/// that this replaces the argument's attributes rather than merging into them.
void setArgAttrs(FunctionOpInterface op, unsigned index,
                 ArrayRef<NamedAttribute> attributes);
void setArgAttrs(FunctionOpInterface op, unsigned index,
                 DictionaryAttr attributes);
90 | |
91 | /// If the an attribute exists with the specified name, change it to the new |
92 | /// value. Otherwise, add a new attribute with the specified name/value. |
93 | template <typename ConcreteType> |
94 | void setArgAttr(ConcreteType op, unsigned index, StringAttr name, |
95 | Attribute value) { |
96 | NamedAttrList attributes(op.getArgAttrDict(index)); |
97 | Attribute oldValue = attributes.set(name, value); |
98 | |
99 | // If the attribute changed, then set the new arg attribute list. |
100 | if (value != oldValue) |
101 | op.setArgAttrs(index, attributes.getDictionary(value.getContext())); |
102 | } |
103 | |
104 | /// Remove the attribute 'name' from the argument at 'index'. Returns the |
105 | /// removed attribute, or nullptr if `name` was not a valid attribute. |
106 | template <typename ConcreteType> |
107 | Attribute removeArgAttr(ConcreteType op, unsigned index, StringAttr name) { |
108 | // Build an attribute list and remove the attribute at 'name'. |
109 | NamedAttrList attributes(op.getArgAttrDict(index)); |
110 | Attribute removedAttr = attributes.erase(name); |
111 | |
112 | // If the attribute was removed, then update the argument dictionary. |
113 | if (removedAttr) |
114 | op.setArgAttrs(index, attributes.getDictionary(removedAttr.getContext())); |
115 | return removedAttr; |
116 | } |
117 | |
118 | //===----------------------------------------------------------------------===// |
119 | // Function Result Attribute. |
120 | //===----------------------------------------------------------------------===// |
121 | |
/// Set the attributes held by the result at 'index', given either as a list
/// of named attributes or as a dictionary. The `setResultAttr` and
/// `removeResultAttr` helpers in this file pass the complete, merged
/// dictionary, which indicates that this replaces the result's attributes
/// rather than merging into them.
void setResultAttrs(FunctionOpInterface op, unsigned index,
                    ArrayRef<NamedAttribute> attributes);
void setResultAttrs(FunctionOpInterface op, unsigned index,
                    DictionaryAttr attributes);
127 | |
128 | /// If the an attribute exists with the specified name, change it to the new |
129 | /// value. Otherwise, add a new attribute with the specified name/value. |
130 | template <typename ConcreteType> |
131 | void setResultAttr(ConcreteType op, unsigned index, StringAttr name, |
132 | Attribute value) { |
133 | NamedAttrList attributes(op.getResultAttrDict(index)); |
134 | Attribute oldAttr = attributes.set(name, value); |
135 | |
136 | // If the attribute changed, then set the new arg attribute list. |
137 | if (oldAttr != value) |
138 | op.setResultAttrs(index, attributes.getDictionary(value.getContext())); |
139 | } |
140 | |
141 | /// Remove the attribute 'name' from the result at 'index'. |
142 | template <typename ConcreteType> |
143 | Attribute removeResultAttr(ConcreteType op, unsigned index, StringAttr name) { |
144 | // Build an attribute list and remove the attribute at 'name'. |
145 | NamedAttrList attributes(op.getResultAttrDict(index)); |
146 | Attribute removedAttr = attributes.erase(name); |
147 | |
148 | // If the attribute was removed, then update the result dictionary. |
149 | if (removedAttr) |
150 | op.setResultAttrs(index, |
151 | attributes.getDictionary(removedAttr.getContext())); |
152 | return removedAttr; |
153 | } |
154 | |
155 | /// This function defines the internal implementation of the `verifyTrait` |
156 | /// method on FunctionOpInterface::Trait. |
157 | template <typename ConcreteOp> |
158 | LogicalResult verifyTrait(ConcreteOp op) { |
159 | if (failed(op.verifyType())) |
160 | return failure(); |
161 | |
162 | if (ArrayAttr allArgAttrs = op.getAllArgAttrs()) { |
163 | unsigned numArgs = op.getNumArguments(); |
164 | if (allArgAttrs.size() != numArgs) { |
165 | return op.emitOpError() |
166 | << "expects argument attribute array to have the same number of " |
167 | "elements as the number of function arguments, got " |
168 | << allArgAttrs.size() << ", but expected " << numArgs; |
169 | } |
170 | for (unsigned i = 0; i != numArgs; ++i) { |
171 | DictionaryAttr argAttrs = |
172 | llvm::dyn_cast_or_null<DictionaryAttr>(allArgAttrs[i]); |
173 | if (!argAttrs) { |
174 | return op.emitOpError() << "expects argument attribute dictionary " |
175 | "to be a DictionaryAttr, but got `" |
176 | << allArgAttrs[i] << "`" ; |
177 | } |
178 | |
179 | // Verify that all of the argument attributes are dialect attributes, i.e. |
180 | // that they contain a dialect prefix in their name. Call the dialect, if |
181 | // registered, to verify the attributes themselves. |
182 | for (auto attr : argAttrs) { |
183 | if (!attr.getName().strref().contains('.')) |
184 | return op.emitOpError("arguments may only have dialect attributes" ); |
185 | if (Dialect *dialect = attr.getNameDialect()) { |
186 | if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, |
187 | /*argIndex=*/i, attr))) |
188 | return failure(); |
189 | } |
190 | } |
191 | } |
192 | } |
193 | if (ArrayAttr allResultAttrs = op.getAllResultAttrs()) { |
194 | unsigned numResults = op.getNumResults(); |
195 | if (allResultAttrs.size() != numResults) { |
196 | return op.emitOpError() |
197 | << "expects result attribute array to have the same number of " |
198 | "elements as the number of function results, got " |
199 | << allResultAttrs.size() << ", but expected " << numResults; |
200 | } |
201 | for (unsigned i = 0; i != numResults; ++i) { |
202 | DictionaryAttr resultAttrs = |
203 | llvm::dyn_cast_or_null<DictionaryAttr>(allResultAttrs[i]); |
204 | if (!resultAttrs) { |
205 | return op.emitOpError() << "expects result attribute dictionary " |
206 | "to be a DictionaryAttr, but got `" |
207 | << allResultAttrs[i] << "`" ; |
208 | } |
209 | |
210 | // Verify that all of the result attributes are dialect attributes, i.e. |
211 | // that they contain a dialect prefix in their name. Call the dialect, if |
212 | // registered, to verify the attributes themselves. |
213 | for (auto attr : resultAttrs) { |
214 | if (!attr.getName().strref().contains('.')) |
215 | return op.emitOpError("results may only have dialect attributes" ); |
216 | if (Dialect *dialect = attr.getNameDialect()) { |
217 | if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, |
218 | /*resultIndex=*/i, |
219 | attr))) |
220 | return failure(); |
221 | } |
222 | } |
223 | } |
224 | } |
225 | |
226 | // Check that the op has exactly one region for the body. |
227 | if (op->getNumRegions() != 1) |
228 | return op.emitOpError("expects one region" ); |
229 | |
230 | return op.verifyBody(); |
231 | } |
232 | } // namespace function_interface_impl |
233 | } // namespace mlir |
234 | |
235 | //===----------------------------------------------------------------------===// |
236 | // Tablegen Interface Declarations |
237 | //===----------------------------------------------------------------------===// |
238 | |
239 | #include "mlir/Interfaces/FunctionInterfaces.h.inc" |
240 | |
241 | #endif // MLIR_IR_FUNCTIONINTERFACES_H |
242 | |