1 | //===- FunctionImplementation.cpp - Utilities for function-like ops -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Interfaces/FunctionImplementation.h" |
10 | #include "mlir/IR/Builders.h" |
11 | #include "mlir/IR/SymbolTable.h" |
12 | #include "mlir/Interfaces/FunctionInterfaces.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | static ParseResult |
17 | parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, |
18 | SmallVectorImpl<OpAsmParser::Argument> &arguments, |
19 | bool &isVariadic) { |
20 | |
21 | // Parse the function arguments. The argument list either has to consistently |
22 | // have ssa-id's followed by types, or just be a type list. It isn't ok to |
23 | // sometimes have SSA ID's and sometimes not. |
24 | isVariadic = false; |
25 | |
26 | return parser.parseCommaSeparatedList( |
27 | delimiter: OpAsmParser::Delimiter::Paren, parseElementFn: [&]() -> ParseResult { |
28 | // Ellipsis must be at end of the list. |
29 | if (isVariadic) |
30 | return parser.emitError( |
31 | loc: parser.getCurrentLocation(), |
32 | message: "variadic arguments must be in the end of the argument list" ); |
33 | |
34 | // Handle ellipsis as a special case. |
35 | if (allowVariadic && succeeded(result: parser.parseOptionalEllipsis())) { |
36 | // This is a variadic designator. |
37 | isVariadic = true; |
38 | return success(); // Stop parsing arguments. |
39 | } |
40 | // Parse argument name if present. |
41 | OpAsmParser::Argument argument; |
42 | auto argPresent = parser.parseOptionalArgument( |
43 | result&: argument, /*allowType=*/true, /*allowAttrs=*/true); |
44 | if (argPresent.has_value()) { |
45 | if (failed(result: argPresent.value())) |
46 | return failure(); // Present but malformed. |
47 | |
48 | // Reject this if the preceding argument was missing a name. |
49 | if (!arguments.empty() && arguments.back().ssaName.name.empty()) |
50 | return parser.emitError(loc: argument.ssaName.location, |
51 | message: "expected type instead of SSA identifier" ); |
52 | |
53 | } else { |
54 | argument.ssaName.location = parser.getCurrentLocation(); |
55 | // Otherwise we just have a type list without SSA names. Reject |
56 | // this if the preceding argument had a name. |
57 | if (!arguments.empty() && !arguments.back().ssaName.name.empty()) |
58 | return parser.emitError(loc: argument.ssaName.location, |
59 | message: "expected SSA identifier" ); |
60 | |
61 | NamedAttrList attrs; |
62 | if (parser.parseType(result&: argument.type) || |
63 | parser.parseOptionalAttrDict(result&: attrs) || |
64 | parser.parseOptionalLocationSpecifier(result&: argument.sourceLoc)) |
65 | return failure(); |
66 | argument.attrs = attrs.getDictionary(parser.getContext()); |
67 | } |
68 | arguments.push_back(Elt: argument); |
69 | return success(); |
70 | }); |
71 | } |
72 | |
73 | /// Parse a function result list. |
74 | /// |
75 | /// function-result-list ::= function-result-list-parens |
76 | /// | non-function-type |
77 | /// function-result-list-parens ::= `(` `)` |
78 | /// | `(` function-result-list-no-parens `)` |
79 | /// function-result-list-no-parens ::= function-result (`,` function-result)* |
80 | /// function-result ::= type attribute-dict? |
81 | /// |
82 | static ParseResult |
83 | parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes, |
84 | SmallVectorImpl<DictionaryAttr> &resultAttrs) { |
85 | if (failed(result: parser.parseOptionalLParen())) { |
86 | // We already know that there is no `(`, so parse a type. |
87 | // Because there is no `(`, it cannot be a function type. |
88 | Type ty; |
89 | if (parser.parseType(result&: ty)) |
90 | return failure(); |
91 | resultTypes.push_back(Elt: ty); |
92 | resultAttrs.emplace_back(); |
93 | return success(); |
94 | } |
95 | |
96 | // Special case for an empty set of parens. |
97 | if (succeeded(result: parser.parseOptionalRParen())) |
98 | return success(); |
99 | |
100 | // Parse individual function results. |
101 | if (parser.parseCommaSeparatedList(parseElementFn: [&]() -> ParseResult { |
102 | resultTypes.emplace_back(); |
103 | resultAttrs.emplace_back(); |
104 | NamedAttrList attrs; |
105 | if (parser.parseType(result&: resultTypes.back()) || |
106 | parser.parseOptionalAttrDict(result&: attrs)) |
107 | return failure(); |
108 | resultAttrs.back() = attrs.getDictionary(parser.getContext()); |
109 | return success(); |
110 | })) |
111 | return failure(); |
112 | |
113 | return parser.parseRParen(); |
114 | } |
115 | |
116 | ParseResult function_interface_impl::parseFunctionSignature( |
117 | OpAsmParser &parser, bool allowVariadic, |
118 | SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic, |
119 | SmallVectorImpl<Type> &resultTypes, |
120 | SmallVectorImpl<DictionaryAttr> &resultAttrs) { |
121 | if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic)) |
122 | return failure(); |
123 | if (succeeded(result: parser.parseOptionalArrow())) |
124 | return parseFunctionResultList(parser, resultTypes, resultAttrs); |
125 | return success(); |
126 | } |
127 | |
128 | void function_interface_impl::addArgAndResultAttrs( |
129 | Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs, |
130 | ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName, |
131 | StringAttr resAttrsName) { |
132 | auto nonEmptyAttrsFn = [](DictionaryAttr attrs) { |
133 | return attrs && !attrs.empty(); |
134 | }; |
135 | // Convert the specified array of dictionary attrs (which may have null |
136 | // entries) to an ArrayAttr of dictionaries. |
137 | auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) { |
138 | SmallVector<Attribute> attrs; |
139 | for (auto &dict : dictAttrs) |
140 | attrs.push_back(dict ? dict : builder.getDictionaryAttr({})); |
141 | return builder.getArrayAttr(attrs); |
142 | }; |
143 | |
144 | // Add the attributes to the function arguments. |
145 | if (llvm::any_of(Range&: argAttrs, P: nonEmptyAttrsFn)) |
146 | result.addAttribute(argAttrsName, getArrayAttr(argAttrs)); |
147 | |
148 | // Add the attributes to the function results. |
149 | if (llvm::any_of(Range&: resultAttrs, P: nonEmptyAttrsFn)) |
150 | result.addAttribute(resAttrsName, getArrayAttr(resultAttrs)); |
151 | } |
152 | |
153 | void function_interface_impl::addArgAndResultAttrs( |
154 | Builder &builder, OperationState &result, |
155 | ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs, |
156 | StringAttr argAttrsName, StringAttr resAttrsName) { |
157 | SmallVector<DictionaryAttr> argAttrs; |
158 | for (const auto &arg : args) |
159 | argAttrs.push_back(arg.attrs); |
160 | addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName, |
161 | resAttrsName); |
162 | } |
163 | |
164 | ParseResult function_interface_impl::parseFunctionOp( |
165 | OpAsmParser &parser, OperationState &result, bool allowVariadic, |
166 | StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, |
167 | StringAttr argAttrsName, StringAttr resAttrsName) { |
168 | SmallVector<OpAsmParser::Argument> entryArgs; |
169 | SmallVector<DictionaryAttr> resultAttrs; |
170 | SmallVector<Type> resultTypes; |
171 | auto &builder = parser.getBuilder(); |
172 | |
173 | // Parse visibility. |
174 | (void)impl::parseOptionalVisibilityKeyword(parser, attrs&: result.attributes); |
175 | |
176 | // Parse the name as a symbol. |
177 | StringAttr nameAttr; |
178 | if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), |
179 | result.attributes)) |
180 | return failure(); |
181 | |
182 | // Parse the function signature. |
183 | SMLoc signatureLocation = parser.getCurrentLocation(); |
184 | bool isVariadic = false; |
185 | if (parseFunctionSignature(parser, allowVariadic, entryArgs, isVariadic, |
186 | resultTypes, resultAttrs)) |
187 | return failure(); |
188 | |
189 | std::string errorMessage; |
190 | SmallVector<Type> argTypes; |
191 | argTypes.reserve(N: entryArgs.size()); |
192 | for (auto &arg : entryArgs) |
193 | argTypes.push_back(Elt: arg.type); |
194 | Type type = funcTypeBuilder(builder, argTypes, resultTypes, |
195 | VariadicFlag(isVariadic), errorMessage); |
196 | if (!type) { |
197 | return parser.emitError(loc: signatureLocation) |
198 | << "failed to construct function type" |
199 | << (errorMessage.empty() ? "" : ": " ) << errorMessage; |
200 | } |
201 | result.addAttribute(typeAttrName, TypeAttr::get(type)); |
202 | |
203 | // If function attributes are present, parse them. |
204 | NamedAttrList parsedAttributes; |
205 | SMLoc attributeDictLocation = parser.getCurrentLocation(); |
206 | if (parser.parseOptionalAttrDictWithKeyword(result&: parsedAttributes)) |
207 | return failure(); |
208 | |
209 | // Disallow attributes that are inferred from elsewhere in the attribute |
210 | // dictionary. |
211 | for (StringRef disallowed : |
212 | {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), |
213 | typeAttrName.getValue()}) { |
214 | if (parsedAttributes.get(disallowed)) |
215 | return parser.emitError(attributeDictLocation, "'" ) |
216 | << disallowed |
217 | << "' is an inferred attribute and should not be specified in the " |
218 | "explicit attribute dictionary" ; |
219 | } |
220 | result.attributes.append(newAttributes&: parsedAttributes); |
221 | |
222 | // Add the attributes to the function arguments. |
223 | assert(resultAttrs.size() == resultTypes.size()); |
224 | addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName, |
225 | resAttrsName); |
226 | |
227 | // Parse the optional function body. The printer will not print the body if |
228 | // its empty, so disallow parsing of empty body in the parser. |
229 | auto *body = result.addRegion(); |
230 | SMLoc loc = parser.getCurrentLocation(); |
231 | OptionalParseResult parseResult = |
232 | parser.parseOptionalRegion(region&: *body, arguments: entryArgs, |
233 | /*enableNameShadowing=*/false); |
234 | if (parseResult.has_value()) { |
235 | if (failed(result: *parseResult)) |
236 | return failure(); |
237 | // Function body was parsed, make sure its not empty. |
238 | if (body->empty()) |
239 | return parser.emitError(loc, message: "expected non-empty function body" ); |
240 | } |
241 | return success(); |
242 | } |
243 | |
244 | /// Print a function result list. The provided `attrs` must either be null, or |
245 | /// contain a set of DictionaryAttrs of the same arity as `types`. |
246 | static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types, |
247 | ArrayAttr attrs) { |
248 | assert(!types.empty() && "Should not be called for empty result list." ); |
249 | assert((!attrs || attrs.size() == types.size()) && |
250 | "Invalid number of attributes." ); |
251 | |
252 | auto &os = p.getStream(); |
253 | bool needsParens = types.size() > 1 || llvm::isa<FunctionType>(Val: types[0]) || |
254 | (attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty()); |
255 | if (needsParens) |
256 | os << '('; |
257 | llvm::interleaveComma(c: llvm::seq<size_t>(Begin: 0, End: types.size()), os, each_fn: [&](size_t i) { |
258 | p.printType(type: types[i]); |
259 | if (attrs) |
260 | p.printOptionalAttrDict(attrs: llvm::cast<DictionaryAttr>(attrs[i]).getValue()); |
261 | }); |
262 | if (needsParens) |
263 | os << ')'; |
264 | } |
265 | |
266 | void function_interface_impl::printFunctionSignature( |
267 | OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes, |
268 | bool isVariadic, ArrayRef<Type> resultTypes) { |
269 | Region &body = op->getRegion(0); |
270 | bool isExternal = body.empty(); |
271 | |
272 | p << '('; |
273 | ArrayAttr argAttrs = op.getArgAttrsAttr(); |
274 | for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { |
275 | if (i > 0) |
276 | p << ", " ; |
277 | |
278 | if (!isExternal) { |
279 | ArrayRef<NamedAttribute> attrs; |
280 | if (argAttrs) |
281 | attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue(); |
282 | p.printRegionArgument(arg: body.getArgument(i), argAttrs: attrs); |
283 | } else { |
284 | p.printType(type: argTypes[i]); |
285 | if (argAttrs) |
286 | p.printOptionalAttrDict( |
287 | attrs: llvm::cast<DictionaryAttr>(argAttrs[i]).getValue()); |
288 | } |
289 | } |
290 | |
291 | if (isVariadic) { |
292 | if (!argTypes.empty()) |
293 | p << ", " ; |
294 | p << "..." ; |
295 | } |
296 | |
297 | p << ')'; |
298 | |
299 | if (!resultTypes.empty()) { |
300 | p.getStream() << " -> " ; |
301 | auto resultAttrs = op.getResAttrsAttr(); |
302 | printFunctionResultList(p, resultTypes, resultAttrs); |
303 | } |
304 | } |
305 | |
306 | void function_interface_impl::printFunctionAttributes( |
307 | OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) { |
308 | // Print out function attributes, if present. |
309 | SmallVector<StringRef, 8> ignoredAttrs = {SymbolTable::getSymbolAttrName()}; |
310 | ignoredAttrs.append(in_start: elided.begin(), in_end: elided.end()); |
311 | |
312 | p.printOptionalAttrDictWithKeyword(attrs: op->getAttrs(), elidedAttrs: ignoredAttrs); |
313 | } |
314 | |
315 | void function_interface_impl::printFunctionOp( |
316 | OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, |
317 | StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) { |
318 | // Print the operation and the function name. |
319 | auto funcName = |
320 | op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()) |
321 | .getValue(); |
322 | p << ' '; |
323 | |
324 | StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); |
325 | if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName)) |
326 | p << visibility.getValue() << ' '; |
327 | p.printSymbolName(symbolRef: funcName); |
328 | |
329 | ArrayRef<Type> argTypes = op.getArgumentTypes(); |
330 | ArrayRef<Type> resultTypes = op.getResultTypes(); |
331 | printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); |
332 | printFunctionAttributes( |
333 | p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName}); |
334 | // Print the body if this is not an external function. |
335 | Region &body = op->getRegion(0); |
336 | if (!body.empty()) { |
337 | p << ' '; |
338 | p.printRegion(blocks&: body, /*printEntryBlockArgs=*/false, |
339 | /*printBlockTerminators=*/true); |
340 | } |
341 | } |
342 | |