| 1 | //===- FuncOps.cpp - Func Dialect Operations ------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 10 | |
| 11 | #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" |
| 12 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
| 13 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
| 14 | #include "mlir/IR/BuiltinOps.h" |
| 15 | #include "mlir/IR/BuiltinTypes.h" |
| 16 | #include "mlir/IR/IRMapping.h" |
| 17 | #include "mlir/IR/Matchers.h" |
| 18 | #include "mlir/IR/OpImplementation.h" |
| 19 | #include "mlir/IR/PatternMatch.h" |
| 20 | #include "mlir/IR/TypeUtilities.h" |
| 21 | #include "mlir/IR/Value.h" |
| 22 | #include "mlir/Interfaces/FunctionImplementation.h" |
| 23 | #include "mlir/Transforms/InliningUtils.h" |
| 24 | #include "llvm/ADT/APFloat.h" |
| 25 | #include "llvm/ADT/MapVector.h" |
| 26 | #include "llvm/ADT/STLExtras.h" |
| 27 | #include "llvm/Support/FormatVariadic.h" |
| 28 | #include "llvm/Support/raw_ostream.h" |
| 29 | #include <numeric> |
| 30 | |
| 31 | #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc" |
| 32 | |
| 33 | using namespace mlir; |
| 34 | using namespace mlir::func; |
| 35 | |
| 36 | //===----------------------------------------------------------------------===// |
| 37 | // FuncDialect |
| 38 | //===----------------------------------------------------------------------===// |
| 39 | |
/// Dialect initialization hook.
///
/// Registers every operation listed in the TableGen-generated GET_OP_LIST. It
/// also declares, without implementing, the interfaces this dialect promises:
///   * ConvertToEmitCPatternInterface, DialectInlinerInterface and
///     ConvertToLLVMPatternInterface on the dialect itself;
///   * bufferization::BufferizableOpInterface on CallOp, FuncOp and ReturnOp.
void FuncDialect::initialize() {
  // The op list is produced by TableGen; GET_OP_LIST makes the .inc file
  // expand to a comma-separated list of the dialect's op classes.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
      >();
  // None of these interface implementations are defined in this file; they
  // are declared as "promised" so the implementations can be registered from
  // elsewhere (presumably the conversion/inliner/bufferization libraries).
  declarePromisedInterface<ConvertToEmitCPatternInterface, FuncDialect>();
  declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
                            FuncOp, ReturnOp>();
}
| 51 | |
| 52 | /// Materialize a single constant operation from a given attribute value with |
| 53 | /// the desired resultant type. |
| 54 | Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value, |
| 55 | Type type, Location loc) { |
| 56 | if (ConstantOp::isBuildableWith(value, type)) |
| 57 | return builder.create<ConstantOp>(loc, type, |
| 58 | llvm::cast<FlatSymbolRefAttr>(value)); |
| 59 | return nullptr; |
| 60 | } |
| 61 | |
| 62 | //===----------------------------------------------------------------------===// |
| 63 | // CallOp |
| 64 | //===----------------------------------------------------------------------===// |
| 65 | |
| 66 | LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| 67 | // Check that the callee attribute was specified. |
| 68 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee" ); |
| 69 | if (!fnAttr) |
| 70 | return emitOpError("requires a 'callee' symbol reference attribute" ); |
| 71 | FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); |
| 72 | if (!fn) |
| 73 | return emitOpError() << "'" << fnAttr.getValue() |
| 74 | << "' does not reference a valid function" ; |
| 75 | |
| 76 | // Verify that the operand and result types match the callee. |
| 77 | auto fnType = fn.getFunctionType(); |
| 78 | if (fnType.getNumInputs() != getNumOperands()) |
| 79 | return emitOpError("incorrect number of operands for callee" ); |
| 80 | |
| 81 | for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) |
| 82 | if (getOperand(i).getType() != fnType.getInput(i)) |
| 83 | return emitOpError("operand type mismatch: expected operand type " ) |
| 84 | << fnType.getInput(i) << ", but provided " |
| 85 | << getOperand(i).getType() << " for operand number " << i; |
| 86 | |
| 87 | if (fnType.getNumResults() != getNumResults()) |
| 88 | return emitOpError("incorrect number of results for callee" ); |
| 89 | |
| 90 | for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) |
| 91 | if (getResult(i).getType() != fnType.getResult(i)) { |
| 92 | auto diag = emitOpError("result type mismatch at index " ) << i; |
| 93 | diag.attachNote() << " op result types: " << getResultTypes(); |
| 94 | diag.attachNote() << "function result types: " << fnType.getResults(); |
| 95 | return diag; |
| 96 | } |
| 97 | |
| 98 | return success(); |
| 99 | } |
| 100 | |
| 101 | FunctionType CallOp::getCalleeType() { |
| 102 | return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); |
| 103 | } |
| 104 | |
| 105 | //===----------------------------------------------------------------------===// |
| 106 | // CallIndirectOp |
| 107 | //===----------------------------------------------------------------------===// |
| 108 | |
| 109 | /// Fold indirect calls that have a constant function as the callee operand. |
| 110 | LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall, |
| 111 | PatternRewriter &rewriter) { |
| 112 | // Check that the callee is a constant callee. |
| 113 | SymbolRefAttr calledFn; |
| 114 | if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) |
| 115 | return failure(); |
| 116 | |
| 117 | // Replace with a direct call. |
| 118 | rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, |
| 119 | indirectCall.getResultTypes(), |
| 120 | indirectCall.getArgOperands()); |
| 121 | return success(); |
| 122 | } |
| 123 | |
| 124 | //===----------------------------------------------------------------------===// |
| 125 | // ConstantOp |
| 126 | //===----------------------------------------------------------------------===// |
| 127 | |
| 128 | LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| 129 | StringRef fnName = getValue(); |
| 130 | Type type = getType(); |
| 131 | |
| 132 | // Try to find the referenced function. |
| 133 | auto fn = symbolTable.lookupNearestSymbolFrom<FuncOp>( |
| 134 | this->getOperation(), StringAttr::get(getContext(), fnName)); |
| 135 | if (!fn) |
| 136 | return emitOpError() << "reference to undefined function '" << fnName |
| 137 | << "'" ; |
| 138 | |
| 139 | // Check that the referenced function has the correct type. |
| 140 | if (fn.getFunctionType() != type) |
| 141 | return emitOpError("reference to function with mismatched type" ); |
| 142 | |
| 143 | return success(); |
| 144 | } |
| 145 | |
| 146 | OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { |
| 147 | return getValueAttr(); |
| 148 | } |
| 149 | |
| 150 | void ConstantOp::getAsmResultNames( |
| 151 | function_ref<void(Value, StringRef)> setNameFn) { |
| 152 | setNameFn(getResult(), "f" ); |
| 153 | } |
| 154 | |
| 155 | bool ConstantOp::isBuildableWith(Attribute value, Type type) { |
| 156 | return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type); |
| 157 | } |
| 158 | |
| 159 | //===----------------------------------------------------------------------===// |
| 160 | // FuncOp |
| 161 | //===----------------------------------------------------------------------===// |
| 162 | |
| 163 | FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
| 164 | ArrayRef<NamedAttribute> attrs) { |
| 165 | OpBuilder builder(location->getContext()); |
| 166 | OperationState state(location, getOperationName()); |
| 167 | FuncOp::build(builder, state, name, type, attrs); |
| 168 | return cast<FuncOp>(Operation::create(state)); |
| 169 | } |
| 170 | FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
| 171 | Operation::dialect_attr_range attrs) { |
| 172 | SmallVector<NamedAttribute, 8> attrRef(attrs); |
| 173 | return create(location, name, type, llvm::ArrayRef(attrRef)); |
| 174 | } |
| 175 | FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
| 176 | ArrayRef<NamedAttribute> attrs, |
| 177 | ArrayRef<DictionaryAttr> argAttrs) { |
| 178 | FuncOp func = create(location, name, type, attrs); |
| 179 | func.setAllArgAttrs(argAttrs); |
| 180 | return func; |
| 181 | } |
| 182 | |
| 183 | void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, |
| 184 | FunctionType type, ArrayRef<NamedAttribute> attrs, |
| 185 | ArrayRef<DictionaryAttr> argAttrs) { |
| 186 | state.addAttribute(SymbolTable::getSymbolAttrName(), |
| 187 | builder.getStringAttr(name)); |
| 188 | state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); |
| 189 | state.attributes.append(attrs.begin(), attrs.end()); |
| 190 | state.addRegion(); |
| 191 | |
| 192 | if (argAttrs.empty()) |
| 193 | return; |
| 194 | assert(type.getNumInputs() == argAttrs.size()); |
| 195 | call_interface_impl::addArgAndResultAttrs( |
| 196 | builder, state, argAttrs, /*resultAttrs=*/std::nullopt, |
| 197 | getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); |
| 198 | } |
| 199 | |
| 200 | ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { |
| 201 | auto buildFuncType = |
| 202 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
| 203 | function_interface_impl::VariadicFlag, |
| 204 | std::string &) { return builder.getFunctionType(argTypes, results); }; |
| 205 | |
| 206 | return function_interface_impl::parseFunctionOp( |
| 207 | parser, result, /*allowVariadic=*/false, |
| 208 | getFunctionTypeAttrName(result.name), buildFuncType, |
| 209 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
| 210 | } |
| 211 | |
| 212 | void FuncOp::print(OpAsmPrinter &p) { |
| 213 | function_interface_impl::printFunctionOp( |
| 214 | p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), |
| 215 | getArgAttrsAttrName(), getResAttrsAttrName()); |
| 216 | } |
| 217 | |
| 218 | /// Clone the internal blocks from this function into dest and all attributes |
| 219 | /// from this function to dest. |
| 220 | void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) { |
| 221 | // Add the attributes of this function to dest. |
| 222 | llvm::MapVector<StringAttr, Attribute> newAttrMap; |
| 223 | for (const auto &attr : dest->getAttrs()) |
| 224 | newAttrMap.insert({attr.getName(), attr.getValue()}); |
| 225 | for (const auto &attr : (*this)->getAttrs()) |
| 226 | newAttrMap.insert({attr.getName(), attr.getValue()}); |
| 227 | |
| 228 | auto newAttrs = llvm::to_vector(llvm::map_range( |
| 229 | newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) { |
| 230 | return NamedAttribute(attrPair.first, attrPair.second); |
| 231 | })); |
| 232 | dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs)); |
| 233 | |
| 234 | // Clone the body. |
| 235 | getBody().cloneInto(&dest.getBody(), mapper); |
| 236 | } |
| 237 | |
| 238 | /// Create a deep copy of this function and all of its blocks, remapping |
| 239 | /// any operands that use values outside of the function using the map that is |
| 240 | /// provided (leaving them alone if no entry is present). Replaces references |
| 241 | /// to cloned sub-values with the corresponding value that is copied, and adds |
| 242 | /// those mappings to the mapper. |
| 243 | FuncOp FuncOp::clone(IRMapping &mapper) { |
| 244 | // Create the new function. |
| 245 | FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions()); |
| 246 | |
| 247 | // If the function has a body, then the user might be deleting arguments to |
| 248 | // the function by specifying them in the mapper. If so, we don't add the |
| 249 | // argument to the input type vector. |
| 250 | if (!isExternal()) { |
| 251 | FunctionType oldType = getFunctionType(); |
| 252 | |
| 253 | unsigned oldNumArgs = oldType.getNumInputs(); |
| 254 | SmallVector<Type, 4> newInputs; |
| 255 | newInputs.reserve(oldNumArgs); |
| 256 | for (unsigned i = 0; i != oldNumArgs; ++i) |
| 257 | if (!mapper.contains(getArgument(i))) |
| 258 | newInputs.push_back(oldType.getInput(i)); |
| 259 | |
| 260 | /// If any of the arguments were dropped, update the type and drop any |
| 261 | /// necessary argument attributes. |
| 262 | if (newInputs.size() != oldNumArgs) { |
| 263 | newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, |
| 264 | oldType.getResults())); |
| 265 | |
| 266 | if (ArrayAttr argAttrs = getAllArgAttrs()) { |
| 267 | SmallVector<Attribute> newArgAttrs; |
| 268 | newArgAttrs.reserve(newInputs.size()); |
| 269 | for (unsigned i = 0; i != oldNumArgs; ++i) |
| 270 | if (!mapper.contains(getArgument(i))) |
| 271 | newArgAttrs.push_back(argAttrs[i]); |
| 272 | newFunc.setAllArgAttrs(newArgAttrs); |
| 273 | } |
| 274 | } |
| 275 | } |
| 276 | |
| 277 | /// Clone the current function into the new one and return it. |
| 278 | cloneInto(newFunc, mapper); |
| 279 | return newFunc; |
| 280 | } |
| 281 | FuncOp FuncOp::clone() { |
| 282 | IRMapping mapper; |
| 283 | return clone(mapper); |
| 284 | } |
| 285 | |
| 286 | //===----------------------------------------------------------------------===// |
| 287 | // ReturnOp |
| 288 | //===----------------------------------------------------------------------===// |
| 289 | |
| 290 | LogicalResult ReturnOp::verify() { |
| 291 | auto function = cast<FuncOp>((*this)->getParentOp()); |
| 292 | |
| 293 | // The operand number and types must match the function signature. |
| 294 | const auto &results = function.getFunctionType().getResults(); |
| 295 | if (getNumOperands() != results.size()) |
| 296 | return emitOpError("has " ) |
| 297 | << getNumOperands() << " operands, but enclosing function (@" |
| 298 | << function.getName() << ") returns " << results.size(); |
| 299 | |
| 300 | for (unsigned i = 0, e = results.size(); i != e; ++i) |
| 301 | if (getOperand(i).getType() != results[i]) |
| 302 | return emitError() << "type of return operand " << i << " (" |
| 303 | << getOperand(i).getType() |
| 304 | << ") doesn't match function result type (" |
| 305 | << results[i] << ")" |
| 306 | << " in function @" << function.getName(); |
| 307 | |
| 308 | return success(); |
| 309 | } |
| 310 | |
| 311 | //===----------------------------------------------------------------------===// |
| 312 | // TableGen'd op method definitions |
| 313 | //===----------------------------------------------------------------------===// |
| 314 | |
| 315 | #define GET_OP_CLASSES |
| 316 | #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" |
| 317 | |