1 | //===- FunctionImplementation.cpp - Utilities for function-like ops -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Interfaces/FunctionImplementation.h" |
10 | #include "mlir/IR/Builders.h" |
11 | #include "mlir/IR/SymbolTable.h" |
12 | #include "mlir/Interfaces/FunctionInterfaces.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | static ParseResult |
17 | parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, |
18 | SmallVectorImpl<OpAsmParser::Argument> &arguments, |
19 | bool &isVariadic) { |
20 | |
21 | // Parse the function arguments. The argument list either has to consistently |
22 | // have ssa-id's followed by types, or just be a type list. It isn't ok to |
23 | // sometimes have SSA ID's and sometimes not. |
24 | isVariadic = false; |
25 | |
26 | return parser.parseCommaSeparatedList( |
27 | delimiter: OpAsmParser::Delimiter::Paren, parseElementFn: [&]() -> ParseResult { |
28 | // Ellipsis must be at end of the list. |
29 | if (isVariadic) |
30 | return parser.emitError( |
31 | loc: parser.getCurrentLocation(), |
32 | message: "variadic arguments must be in the end of the argument list" ); |
33 | |
34 | // Handle ellipsis as a special case. |
35 | if (allowVariadic && succeeded(Result: parser.parseOptionalEllipsis())) { |
36 | // This is a variadic designator. |
37 | isVariadic = true; |
38 | return success(); // Stop parsing arguments. |
39 | } |
40 | // Parse argument name if present. |
41 | OpAsmParser::Argument argument; |
42 | auto argPresent = parser.parseOptionalArgument( |
43 | result&: argument, /*allowType=*/true, /*allowAttrs=*/true); |
44 | if (argPresent.has_value()) { |
45 | if (failed(Result: argPresent.value())) |
46 | return failure(); // Present but malformed. |
47 | |
48 | // Reject this if the preceding argument was missing a name. |
49 | if (!arguments.empty() && arguments.back().ssaName.name.empty()) |
50 | return parser.emitError(loc: argument.ssaName.location, |
51 | message: "expected type instead of SSA identifier" ); |
52 | |
53 | } else { |
54 | argument.ssaName.location = parser.getCurrentLocation(); |
55 | // Otherwise we just have a type list without SSA names. Reject |
56 | // this if the preceding argument had a name. |
57 | if (!arguments.empty() && !arguments.back().ssaName.name.empty()) |
58 | return parser.emitError(loc: argument.ssaName.location, |
59 | message: "expected SSA identifier" ); |
60 | |
61 | NamedAttrList attrs; |
62 | if (parser.parseType(result&: argument.type) || |
63 | parser.parseOptionalAttrDict(result&: attrs) || |
64 | parser.parseOptionalLocationSpecifier(result&: argument.sourceLoc)) |
65 | return failure(); |
66 | argument.attrs = attrs.getDictionary(parser.getContext()); |
67 | } |
68 | arguments.push_back(Elt: argument); |
69 | return success(); |
70 | }); |
71 | } |
72 | |
73 | ParseResult function_interface_impl::parseFunctionSignatureWithArguments( |
74 | OpAsmParser &parser, bool allowVariadic, |
75 | SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic, |
76 | SmallVectorImpl<Type> &resultTypes, |
77 | SmallVectorImpl<DictionaryAttr> &resultAttrs) { |
78 | if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic)) |
79 | return failure(); |
80 | if (succeeded(Result: parser.parseOptionalArrow())) |
81 | return call_interface_impl::parseFunctionResultList(parser, resultTypes, |
82 | resultAttrs); |
83 | return success(); |
84 | } |
85 | |
86 | ParseResult function_interface_impl::parseFunctionOp( |
87 | OpAsmParser &parser, OperationState &result, bool allowVariadic, |
88 | StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, |
89 | StringAttr argAttrsName, StringAttr resAttrsName) { |
90 | SmallVector<OpAsmParser::Argument> entryArgs; |
91 | SmallVector<DictionaryAttr> resultAttrs; |
92 | SmallVector<Type> resultTypes; |
93 | auto &builder = parser.getBuilder(); |
94 | |
95 | // Parse visibility. |
96 | (void)impl::parseOptionalVisibilityKeyword(parser, attrs&: result.attributes); |
97 | |
98 | // Parse the name as a symbol. |
99 | StringAttr nameAttr; |
100 | if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), |
101 | result.attributes)) |
102 | return failure(); |
103 | |
104 | // Parse the function signature. |
105 | SMLoc signatureLocation = parser.getCurrentLocation(); |
106 | bool isVariadic = false; |
107 | if (parseFunctionSignatureWithArguments(parser, allowVariadic, entryArgs, |
108 | isVariadic, resultTypes, resultAttrs)) |
109 | return failure(); |
110 | |
111 | std::string errorMessage; |
112 | SmallVector<Type> argTypes; |
113 | argTypes.reserve(N: entryArgs.size()); |
114 | for (auto &arg : entryArgs) |
115 | argTypes.push_back(Elt: arg.type); |
116 | Type type = funcTypeBuilder(builder, argTypes, resultTypes, |
117 | VariadicFlag(isVariadic), errorMessage); |
118 | if (!type) { |
119 | return parser.emitError(loc: signatureLocation) |
120 | << "failed to construct function type" |
121 | << (errorMessage.empty() ? "" : ": " ) << errorMessage; |
122 | } |
123 | result.addAttribute(typeAttrName, TypeAttr::get(type)); |
124 | |
125 | // If function attributes are present, parse them. |
126 | NamedAttrList parsedAttributes; |
127 | SMLoc attributeDictLocation = parser.getCurrentLocation(); |
128 | if (parser.parseOptionalAttrDictWithKeyword(result&: parsedAttributes)) |
129 | return failure(); |
130 | |
131 | // Disallow attributes that are inferred from elsewhere in the attribute |
132 | // dictionary. |
133 | for (StringRef disallowed : |
134 | {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), |
135 | typeAttrName.getValue()}) { |
136 | if (parsedAttributes.get(disallowed)) |
137 | return parser.emitError(attributeDictLocation, "'" ) |
138 | << disallowed |
139 | << "' is an inferred attribute and should not be specified in the " |
140 | "explicit attribute dictionary" ; |
141 | } |
142 | result.attributes.append(newAttributes&: parsedAttributes); |
143 | |
144 | // Add the attributes to the function arguments. |
145 | assert(resultAttrs.size() == resultTypes.size()); |
146 | call_interface_impl::addArgAndResultAttrs( |
147 | builder, result, entryArgs, resultAttrs, argAttrsName, resAttrsName); |
148 | |
149 | // Parse the optional function body. The printer will not print the body if |
150 | // its empty, so disallow parsing of empty body in the parser. |
151 | auto *body = result.addRegion(); |
152 | SMLoc loc = parser.getCurrentLocation(); |
153 | OptionalParseResult parseResult = |
154 | parser.parseOptionalRegion(region&: *body, arguments: entryArgs, |
155 | /*enableNameShadowing=*/false); |
156 | if (parseResult.has_value()) { |
157 | if (failed(Result: *parseResult)) |
158 | return failure(); |
159 | // Function body was parsed, make sure its not empty. |
160 | if (body->empty()) |
161 | return parser.emitError(loc, message: "expected non-empty function body" ); |
162 | } |
163 | return success(); |
164 | } |
165 | |
166 | void function_interface_impl::printFunctionAttributes( |
167 | OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) { |
168 | // Print out function attributes, if present. |
169 | SmallVector<StringRef, 8> ignoredAttrs = {SymbolTable::getSymbolAttrName()}; |
170 | ignoredAttrs.append(in_start: elided.begin(), in_end: elided.end()); |
171 | |
172 | p.printOptionalAttrDictWithKeyword(attrs: op->getAttrs(), elidedAttrs: ignoredAttrs); |
173 | } |
174 | |
175 | void function_interface_impl::printFunctionOp( |
176 | OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, |
177 | StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) { |
178 | // Print the operation and the function name. |
179 | auto funcName = |
180 | op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()) |
181 | .getValue(); |
182 | p << ' '; |
183 | |
184 | StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); |
185 | if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName)) |
186 | p << visibility.getValue() << ' '; |
187 | p.printSymbolName(symbolRef: funcName); |
188 | |
189 | ArrayRef<Type> argTypes = op.getArgumentTypes(); |
190 | ArrayRef<Type> resultTypes = op.getResultTypes(); |
191 | printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); |
192 | printFunctionAttributes( |
193 | p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName}); |
194 | // Print the body if this is not an external function. |
195 | Region &body = op->getRegion(0); |
196 | if (!body.empty()) { |
197 | p << ' '; |
198 | p.printRegion(blocks&: body, /*printEntryBlockArgs=*/false, |
199 | /*printBlockTerminators=*/true); |
200 | } |
201 | } |
202 | |