1 | //===- FuncOps.cpp - Func Dialect Operations ------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
10 | |
11 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
12 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
13 | #include "mlir/IR/BuiltinOps.h" |
14 | #include "mlir/IR/BuiltinTypes.h" |
15 | #include "mlir/IR/IRMapping.h" |
16 | #include "mlir/IR/Matchers.h" |
17 | #include "mlir/IR/OpImplementation.h" |
18 | #include "mlir/IR/PatternMatch.h" |
19 | #include "mlir/IR/TypeUtilities.h" |
20 | #include "mlir/IR/Value.h" |
21 | #include "mlir/Interfaces/FunctionImplementation.h" |
22 | #include "mlir/Support/MathExtras.h" |
23 | #include "mlir/Transforms/InliningUtils.h" |
24 | #include "llvm/ADT/APFloat.h" |
25 | #include "llvm/ADT/MapVector.h" |
26 | #include "llvm/ADT/STLExtras.h" |
27 | #include "llvm/Support/FormatVariadic.h" |
28 | #include "llvm/Support/raw_ostream.h" |
29 | #include <numeric> |
30 | |
31 | #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc" |
32 | |
33 | using namespace mlir; |
34 | using namespace mlir::func; |
35 | |
36 | //===----------------------------------------------------------------------===// |
37 | // FuncDialect |
38 | //===----------------------------------------------------------------------===// |
39 | |
/// Registers all operations of the func dialect (the list is produced by the
/// TableGen'd `GET_OP_LIST` section) and declares promised implementations of
/// external interfaces: the inliner interface and the ConvertToLLVM pattern
/// interface on the dialect, plus the bufferization interface on `CallOp`,
/// `FuncOp` and `ReturnOp`.
void FuncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
      >();
  // These interfaces are implemented in other libraries; declaring them as
  // promised lets a missing registration be diagnosed when the interface is
  // used.
  declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
                            FuncOp, ReturnOp>();
}
50 | |
51 | /// Materialize a single constant operation from a given attribute value with |
52 | /// the desired resultant type. |
53 | Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value, |
54 | Type type, Location loc) { |
55 | if (ConstantOp::isBuildableWith(value, type)) |
56 | return builder.create<ConstantOp>(loc, type, |
57 | llvm::cast<FlatSymbolRefAttr>(value)); |
58 | return nullptr; |
59 | } |
60 | |
61 | //===----------------------------------------------------------------------===// |
62 | // CallOp |
63 | //===----------------------------------------------------------------------===// |
64 | |
65 | LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
66 | // Check that the callee attribute was specified. |
67 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee" ); |
68 | if (!fnAttr) |
69 | return emitOpError("requires a 'callee' symbol reference attribute" ); |
70 | FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); |
71 | if (!fn) |
72 | return emitOpError() << "'" << fnAttr.getValue() |
73 | << "' does not reference a valid function" ; |
74 | |
75 | // Verify that the operand and result types match the callee. |
76 | auto fnType = fn.getFunctionType(); |
77 | if (fnType.getNumInputs() != getNumOperands()) |
78 | return emitOpError("incorrect number of operands for callee" ); |
79 | |
80 | for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) |
81 | if (getOperand(i).getType() != fnType.getInput(i)) |
82 | return emitOpError("operand type mismatch: expected operand type " ) |
83 | << fnType.getInput(i) << ", but provided " |
84 | << getOperand(i).getType() << " for operand number " << i; |
85 | |
86 | if (fnType.getNumResults() != getNumResults()) |
87 | return emitOpError("incorrect number of results for callee" ); |
88 | |
89 | for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) |
90 | if (getResult(i).getType() != fnType.getResult(i)) { |
91 | auto diag = emitOpError("result type mismatch at index " ) << i; |
92 | diag.attachNote() << " op result types: " << getResultTypes(); |
93 | diag.attachNote() << "function result types: " << fnType.getResults(); |
94 | return diag; |
95 | } |
96 | |
97 | return success(); |
98 | } |
99 | |
100 | FunctionType CallOp::getCalleeType() { |
101 | return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); |
102 | } |
103 | |
104 | //===----------------------------------------------------------------------===// |
105 | // CallIndirectOp |
106 | //===----------------------------------------------------------------------===// |
107 | |
108 | /// Fold indirect calls that have a constant function as the callee operand. |
109 | LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall, |
110 | PatternRewriter &rewriter) { |
111 | // Check that the callee is a constant callee. |
112 | SymbolRefAttr calledFn; |
113 | if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) |
114 | return failure(); |
115 | |
116 | // Replace with a direct call. |
117 | rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, |
118 | indirectCall.getResultTypes(), |
119 | indirectCall.getArgOperands()); |
120 | return success(); |
121 | } |
122 | |
123 | //===----------------------------------------------------------------------===// |
124 | // ConstantOp |
125 | //===----------------------------------------------------------------------===// |
126 | |
127 | LogicalResult ConstantOp::verify() { |
128 | StringRef fnName = getValue(); |
129 | Type type = getType(); |
130 | |
131 | // Try to find the referenced function. |
132 | auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName); |
133 | if (!fn) |
134 | return emitOpError() << "reference to undefined function '" << fnName |
135 | << "'" ; |
136 | |
137 | // Check that the referenced function has the correct type. |
138 | if (fn.getFunctionType() != type) |
139 | return emitOpError("reference to function with mismatched type" ); |
140 | |
141 | return success(); |
142 | } |
143 | |
144 | OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { |
145 | return getValueAttr(); |
146 | } |
147 | |
148 | void ConstantOp::getAsmResultNames( |
149 | function_ref<void(Value, StringRef)> setNameFn) { |
150 | setNameFn(getResult(), "f" ); |
151 | } |
152 | |
153 | bool ConstantOp::isBuildableWith(Attribute value, Type type) { |
154 | return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type); |
155 | } |
156 | |
157 | //===----------------------------------------------------------------------===// |
158 | // FuncOp |
159 | //===----------------------------------------------------------------------===// |
160 | |
161 | FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
162 | ArrayRef<NamedAttribute> attrs) { |
163 | OpBuilder builder(location->getContext()); |
164 | OperationState state(location, getOperationName()); |
165 | FuncOp::build(builder, state, name, type, attrs); |
166 | return cast<FuncOp>(Operation::create(state)); |
167 | } |
168 | FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
169 | Operation::dialect_attr_range attrs) { |
170 | SmallVector<NamedAttribute, 8> attrRef(attrs); |
171 | return create(location, name, type, llvm::ArrayRef(attrRef)); |
172 | } |
173 | FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
174 | ArrayRef<NamedAttribute> attrs, |
175 | ArrayRef<DictionaryAttr> argAttrs) { |
176 | FuncOp func = create(location, name, type, attrs); |
177 | func.setAllArgAttrs(argAttrs); |
178 | return func; |
179 | } |
180 | |
181 | void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, |
182 | FunctionType type, ArrayRef<NamedAttribute> attrs, |
183 | ArrayRef<DictionaryAttr> argAttrs) { |
184 | state.addAttribute(SymbolTable::getSymbolAttrName(), |
185 | builder.getStringAttr(name)); |
186 | state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); |
187 | state.attributes.append(attrs.begin(), attrs.end()); |
188 | state.addRegion(); |
189 | |
190 | if (argAttrs.empty()) |
191 | return; |
192 | assert(type.getNumInputs() == argAttrs.size()); |
193 | function_interface_impl::addArgAndResultAttrs( |
194 | builder, state, argAttrs, /*resultAttrs=*/std::nullopt, |
195 | getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); |
196 | } |
197 | |
198 | ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { |
199 | auto buildFuncType = |
200 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
201 | function_interface_impl::VariadicFlag, |
202 | std::string &) { return builder.getFunctionType(argTypes, results); }; |
203 | |
204 | return function_interface_impl::parseFunctionOp( |
205 | parser, result, /*allowVariadic=*/false, |
206 | getFunctionTypeAttrName(result.name), buildFuncType, |
207 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
208 | } |
209 | |
210 | void FuncOp::print(OpAsmPrinter &p) { |
211 | function_interface_impl::printFunctionOp( |
212 | p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), |
213 | getArgAttrsAttrName(), getResAttrsAttrName()); |
214 | } |
215 | |
216 | /// Clone the internal blocks from this function into dest and all attributes |
217 | /// from this function to dest. |
218 | void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) { |
219 | // Add the attributes of this function to dest. |
220 | llvm::MapVector<StringAttr, Attribute> newAttrMap; |
221 | for (const auto &attr : dest->getAttrs()) |
222 | newAttrMap.insert({attr.getName(), attr.getValue()}); |
223 | for (const auto &attr : (*this)->getAttrs()) |
224 | newAttrMap.insert({attr.getName(), attr.getValue()}); |
225 | |
226 | auto newAttrs = llvm::to_vector(llvm::map_range( |
227 | newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) { |
228 | return NamedAttribute(attrPair.first, attrPair.second); |
229 | })); |
230 | dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs)); |
231 | |
232 | // Clone the body. |
233 | getBody().cloneInto(&dest.getBody(), mapper); |
234 | } |
235 | |
236 | /// Create a deep copy of this function and all of its blocks, remapping |
237 | /// any operands that use values outside of the function using the map that is |
238 | /// provided (leaving them alone if no entry is present). Replaces references |
239 | /// to cloned sub-values with the corresponding value that is copied, and adds |
240 | /// those mappings to the mapper. |
241 | FuncOp FuncOp::clone(IRMapping &mapper) { |
242 | // Create the new function. |
243 | FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions()); |
244 | |
245 | // If the function has a body, then the user might be deleting arguments to |
246 | // the function by specifying them in the mapper. If so, we don't add the |
247 | // argument to the input type vector. |
248 | if (!isExternal()) { |
249 | FunctionType oldType = getFunctionType(); |
250 | |
251 | unsigned oldNumArgs = oldType.getNumInputs(); |
252 | SmallVector<Type, 4> newInputs; |
253 | newInputs.reserve(oldNumArgs); |
254 | for (unsigned i = 0; i != oldNumArgs; ++i) |
255 | if (!mapper.contains(getArgument(i))) |
256 | newInputs.push_back(oldType.getInput(i)); |
257 | |
258 | /// If any of the arguments were dropped, update the type and drop any |
259 | /// necessary argument attributes. |
260 | if (newInputs.size() != oldNumArgs) { |
261 | newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, |
262 | oldType.getResults())); |
263 | |
264 | if (ArrayAttr argAttrs = getAllArgAttrs()) { |
265 | SmallVector<Attribute> newArgAttrs; |
266 | newArgAttrs.reserve(newInputs.size()); |
267 | for (unsigned i = 0; i != oldNumArgs; ++i) |
268 | if (!mapper.contains(getArgument(i))) |
269 | newArgAttrs.push_back(argAttrs[i]); |
270 | newFunc.setAllArgAttrs(newArgAttrs); |
271 | } |
272 | } |
273 | } |
274 | |
275 | /// Clone the current function into the new one and return it. |
276 | cloneInto(newFunc, mapper); |
277 | return newFunc; |
278 | } |
279 | FuncOp FuncOp::clone() { |
280 | IRMapping mapper; |
281 | return clone(mapper); |
282 | } |
283 | |
284 | //===----------------------------------------------------------------------===// |
285 | // ReturnOp |
286 | //===----------------------------------------------------------------------===// |
287 | |
288 | LogicalResult ReturnOp::verify() { |
289 | auto function = cast<FuncOp>((*this)->getParentOp()); |
290 | |
291 | // The operand number and types must match the function signature. |
292 | const auto &results = function.getFunctionType().getResults(); |
293 | if (getNumOperands() != results.size()) |
294 | return emitOpError("has " ) |
295 | << getNumOperands() << " operands, but enclosing function (@" |
296 | << function.getName() << ") returns " << results.size(); |
297 | |
298 | for (unsigned i = 0, e = results.size(); i != e; ++i) |
299 | if (getOperand(i).getType() != results[i]) |
300 | return emitError() << "type of return operand " << i << " (" |
301 | << getOperand(i).getType() |
302 | << ") doesn't match function result type (" |
303 | << results[i] << ")" |
304 | << " in function @" << function.getName(); |
305 | |
306 | return success(); |
307 | } |
308 | |
309 | //===----------------------------------------------------------------------===// |
310 | // TableGen'd op method definitions |
311 | //===----------------------------------------------------------------------===// |
312 | |
313 | #define GET_OP_CLASSES |
314 | #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" |
315 | |