| 1 | //===- FunctionImplementation.cpp - Utilities for function-like ops -------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Interfaces/FunctionImplementation.h" |
| 10 | #include "mlir/IR/Builders.h" |
| 11 | #include "mlir/IR/SymbolTable.h" |
| 12 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 13 | |
| 14 | using namespace mlir; |
| 15 | |
| 16 | static ParseResult |
| 17 | parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, |
| 18 | SmallVectorImpl<OpAsmParser::Argument> &arguments, |
| 19 | bool &isVariadic) { |
| 20 | |
| 21 | // Parse the function arguments. The argument list either has to consistently |
| 22 | // have ssa-id's followed by types, or just be a type list. It isn't ok to |
| 23 | // sometimes have SSA ID's and sometimes not. |
| 24 | isVariadic = false; |
| 25 | |
| 26 | return parser.parseCommaSeparatedList( |
| 27 | delimiter: OpAsmParser::Delimiter::Paren, parseElementFn: [&]() -> ParseResult { |
| 28 | // Ellipsis must be at end of the list. |
| 29 | if (isVariadic) |
| 30 | return parser.emitError( |
| 31 | loc: parser.getCurrentLocation(), |
| 32 | message: "variadic arguments must be in the end of the argument list" ); |
| 33 | |
| 34 | // Handle ellipsis as a special case. |
| 35 | if (allowVariadic && succeeded(Result: parser.parseOptionalEllipsis())) { |
| 36 | // This is a variadic designator. |
| 37 | isVariadic = true; |
| 38 | return success(); // Stop parsing arguments. |
| 39 | } |
| 40 | // Parse argument name if present. |
| 41 | OpAsmParser::Argument argument; |
| 42 | auto argPresent = parser.parseOptionalArgument( |
| 43 | result&: argument, /*allowType=*/true, /*allowAttrs=*/true); |
| 44 | if (argPresent.has_value()) { |
| 45 | if (failed(Result: argPresent.value())) |
| 46 | return failure(); // Present but malformed. |
| 47 | |
| 48 | // Reject this if the preceding argument was missing a name. |
| 49 | if (!arguments.empty() && arguments.back().ssaName.name.empty()) |
| 50 | return parser.emitError(loc: argument.ssaName.location, |
| 51 | message: "expected type instead of SSA identifier" ); |
| 52 | |
| 53 | } else { |
| 54 | argument.ssaName.location = parser.getCurrentLocation(); |
| 55 | // Otherwise we just have a type list without SSA names. Reject |
| 56 | // this if the preceding argument had a name. |
| 57 | if (!arguments.empty() && !arguments.back().ssaName.name.empty()) |
| 58 | return parser.emitError(loc: argument.ssaName.location, |
| 59 | message: "expected SSA identifier" ); |
| 60 | |
| 61 | NamedAttrList attrs; |
| 62 | if (parser.parseType(result&: argument.type) || |
| 63 | parser.parseOptionalAttrDict(result&: attrs) || |
| 64 | parser.parseOptionalLocationSpecifier(result&: argument.sourceLoc)) |
| 65 | return failure(); |
| 66 | argument.attrs = attrs.getDictionary(parser.getContext()); |
| 67 | } |
| 68 | arguments.push_back(Elt: argument); |
| 69 | return success(); |
| 70 | }); |
| 71 | } |
| 72 | |
| 73 | ParseResult function_interface_impl::parseFunctionSignatureWithArguments( |
| 74 | OpAsmParser &parser, bool allowVariadic, |
| 75 | SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic, |
| 76 | SmallVectorImpl<Type> &resultTypes, |
| 77 | SmallVectorImpl<DictionaryAttr> &resultAttrs) { |
| 78 | if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic)) |
| 79 | return failure(); |
| 80 | if (succeeded(Result: parser.parseOptionalArrow())) |
| 81 | return call_interface_impl::parseFunctionResultList(parser, resultTypes, |
| 82 | resultAttrs); |
| 83 | return success(); |
| 84 | } |
| 85 | |
| 86 | ParseResult function_interface_impl::parseFunctionOp( |
| 87 | OpAsmParser &parser, OperationState &result, bool allowVariadic, |
| 88 | StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, |
| 89 | StringAttr argAttrsName, StringAttr resAttrsName) { |
| 90 | SmallVector<OpAsmParser::Argument> entryArgs; |
| 91 | SmallVector<DictionaryAttr> resultAttrs; |
| 92 | SmallVector<Type> resultTypes; |
| 93 | auto &builder = parser.getBuilder(); |
| 94 | |
| 95 | // Parse visibility. |
| 96 | (void)impl::parseOptionalVisibilityKeyword(parser, attrs&: result.attributes); |
| 97 | |
| 98 | // Parse the name as a symbol. |
| 99 | StringAttr nameAttr; |
| 100 | if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), |
| 101 | result.attributes)) |
| 102 | return failure(); |
| 103 | |
| 104 | // Parse the function signature. |
| 105 | SMLoc signatureLocation = parser.getCurrentLocation(); |
| 106 | bool isVariadic = false; |
| 107 | if (parseFunctionSignatureWithArguments(parser, allowVariadic, entryArgs, |
| 108 | isVariadic, resultTypes, resultAttrs)) |
| 109 | return failure(); |
| 110 | |
| 111 | std::string errorMessage; |
| 112 | SmallVector<Type> argTypes; |
| 113 | argTypes.reserve(N: entryArgs.size()); |
| 114 | for (auto &arg : entryArgs) |
| 115 | argTypes.push_back(Elt: arg.type); |
| 116 | Type type = funcTypeBuilder(builder, argTypes, resultTypes, |
| 117 | VariadicFlag(isVariadic), errorMessage); |
| 118 | if (!type) { |
| 119 | return parser.emitError(loc: signatureLocation) |
| 120 | << "failed to construct function type" |
| 121 | << (errorMessage.empty() ? "" : ": " ) << errorMessage; |
| 122 | } |
| 123 | result.addAttribute(typeAttrName, TypeAttr::get(type)); |
| 124 | |
| 125 | // If function attributes are present, parse them. |
| 126 | NamedAttrList parsedAttributes; |
| 127 | SMLoc attributeDictLocation = parser.getCurrentLocation(); |
| 128 | if (parser.parseOptionalAttrDictWithKeyword(result&: parsedAttributes)) |
| 129 | return failure(); |
| 130 | |
| 131 | // Disallow attributes that are inferred from elsewhere in the attribute |
| 132 | // dictionary. |
| 133 | for (StringRef disallowed : |
| 134 | {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), |
| 135 | typeAttrName.getValue()}) { |
| 136 | if (parsedAttributes.get(disallowed)) |
| 137 | return parser.emitError(attributeDictLocation, "'" ) |
| 138 | << disallowed |
| 139 | << "' is an inferred attribute and should not be specified in the " |
| 140 | "explicit attribute dictionary" ; |
| 141 | } |
| 142 | result.attributes.append(newAttributes&: parsedAttributes); |
| 143 | |
| 144 | // Add the attributes to the function arguments. |
| 145 | assert(resultAttrs.size() == resultTypes.size()); |
| 146 | call_interface_impl::addArgAndResultAttrs( |
| 147 | builder, result, entryArgs, resultAttrs, argAttrsName, resAttrsName); |
| 148 | |
| 149 | // Parse the optional function body. The printer will not print the body if |
| 150 | // its empty, so disallow parsing of empty body in the parser. |
| 151 | auto *body = result.addRegion(); |
| 152 | SMLoc loc = parser.getCurrentLocation(); |
| 153 | OptionalParseResult parseResult = |
| 154 | parser.parseOptionalRegion(region&: *body, arguments: entryArgs, |
| 155 | /*enableNameShadowing=*/false); |
| 156 | if (parseResult.has_value()) { |
| 157 | if (failed(Result: *parseResult)) |
| 158 | return failure(); |
| 159 | // Function body was parsed, make sure its not empty. |
| 160 | if (body->empty()) |
| 161 | return parser.emitError(loc, message: "expected non-empty function body" ); |
| 162 | } |
| 163 | return success(); |
| 164 | } |
| 165 | |
| 166 | void function_interface_impl::printFunctionAttributes( |
| 167 | OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) { |
| 168 | // Print out function attributes, if present. |
| 169 | SmallVector<StringRef, 8> ignoredAttrs = {SymbolTable::getSymbolAttrName()}; |
| 170 | ignoredAttrs.append(in_start: elided.begin(), in_end: elided.end()); |
| 171 | |
| 172 | p.printOptionalAttrDictWithKeyword(attrs: op->getAttrs(), elidedAttrs: ignoredAttrs); |
| 173 | } |
| 174 | |
| 175 | void function_interface_impl::printFunctionOp( |
| 176 | OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, |
| 177 | StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) { |
| 178 | // Print the operation and the function name. |
| 179 | auto funcName = |
| 180 | op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()) |
| 181 | .getValue(); |
| 182 | p << ' '; |
| 183 | |
| 184 | StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); |
| 185 | if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName)) |
| 186 | p << visibility.getValue() << ' '; |
| 187 | p.printSymbolName(symbolRef: funcName); |
| 188 | |
| 189 | ArrayRef<Type> argTypes = op.getArgumentTypes(); |
| 190 | ArrayRef<Type> resultTypes = op.getResultTypes(); |
| 191 | printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); |
| 192 | printFunctionAttributes( |
| 193 | p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName}); |
| 194 | // Print the body if this is not an external function. |
| 195 | Region &body = op->getRegion(0); |
| 196 | if (!body.empty()) { |
| 197 | p << ' '; |
| 198 | p.printRegion(blocks&: body, /*printEntryBlockArgs=*/false, |
| 199 | /*printBlockTerminators=*/true); |
| 200 | } |
| 201 | } |
| 202 | |