1 | //===- FuncOps.cpp - Func Dialect Operations ------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
10 | |
11 | #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" |
12 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
13 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
14 | #include "mlir/IR/BuiltinOps.h" |
15 | #include "mlir/IR/BuiltinTypes.h" |
16 | #include "mlir/IR/IRMapping.h" |
17 | #include "mlir/IR/Matchers.h" |
18 | #include "mlir/IR/OpImplementation.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/IR/TypeUtilities.h" |
21 | #include "mlir/IR/Value.h" |
22 | #include "mlir/Interfaces/FunctionImplementation.h" |
23 | #include "mlir/Transforms/InliningUtils.h" |
24 | #include "llvm/ADT/APFloat.h" |
25 | #include "llvm/ADT/MapVector.h" |
26 | #include "llvm/ADT/STLExtras.h" |
27 | #include "llvm/Support/FormatVariadic.h" |
28 | #include "llvm/Support/raw_ostream.h" |
29 | #include <numeric> |
30 | |
31 | #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc" |
32 | |
33 | using namespace mlir; |
34 | using namespace mlir::func; |
35 | |
36 | //===----------------------------------------------------------------------===// |
37 | // FuncDialect |
38 | //===----------------------------------------------------------------------===// |
39 | |
40 | void FuncDialect::initialize() { |
41 | addOperations< |
42 | #define GET_OP_LIST |
43 | #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" |
44 | >(); |
45 | declarePromisedInterface<ConvertToEmitCPatternInterface, FuncDialect>(); |
46 | declarePromisedInterface<DialectInlinerInterface, FuncDialect>(); |
47 | declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>(); |
48 | declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp, |
49 | FuncOp, ReturnOp>(); |
50 | } |
51 | |
52 | /// Materialize a single constant operation from a given attribute value with |
53 | /// the desired resultant type. |
54 | Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value, |
55 | Type type, Location loc) { |
56 | if (ConstantOp::isBuildableWith(value, type)) |
57 | return builder.create<ConstantOp>(loc, type, |
58 | llvm::cast<FlatSymbolRefAttr>(value)); |
59 | return nullptr; |
60 | } |
61 | |
62 | //===----------------------------------------------------------------------===// |
63 | // CallOp |
64 | //===----------------------------------------------------------------------===// |
65 | |
66 | LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
67 | // Check that the callee attribute was specified. |
68 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee" ); |
69 | if (!fnAttr) |
70 | return emitOpError("requires a 'callee' symbol reference attribute" ); |
71 | FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); |
72 | if (!fn) |
73 | return emitOpError() << "'" << fnAttr.getValue() |
74 | << "' does not reference a valid function" ; |
75 | |
76 | // Verify that the operand and result types match the callee. |
77 | auto fnType = fn.getFunctionType(); |
78 | if (fnType.getNumInputs() != getNumOperands()) |
79 | return emitOpError("incorrect number of operands for callee" ); |
80 | |
81 | for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) |
82 | if (getOperand(i).getType() != fnType.getInput(i)) |
83 | return emitOpError("operand type mismatch: expected operand type " ) |
84 | << fnType.getInput(i) << ", but provided " |
85 | << getOperand(i).getType() << " for operand number " << i; |
86 | |
87 | if (fnType.getNumResults() != getNumResults()) |
88 | return emitOpError("incorrect number of results for callee" ); |
89 | |
90 | for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) |
91 | if (getResult(i).getType() != fnType.getResult(i)) { |
92 | auto diag = emitOpError("result type mismatch at index " ) << i; |
93 | diag.attachNote() << " op result types: " << getResultTypes(); |
94 | diag.attachNote() << "function result types: " << fnType.getResults(); |
95 | return diag; |
96 | } |
97 | |
98 | return success(); |
99 | } |
100 | |
101 | FunctionType CallOp::getCalleeType() { |
102 | return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); |
103 | } |
104 | |
105 | //===----------------------------------------------------------------------===// |
106 | // CallIndirectOp |
107 | //===----------------------------------------------------------------------===// |
108 | |
109 | /// Fold indirect calls that have a constant function as the callee operand. |
110 | LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall, |
111 | PatternRewriter &rewriter) { |
112 | // Check that the callee is a constant callee. |
113 | SymbolRefAttr calledFn; |
114 | if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) |
115 | return failure(); |
116 | |
117 | // Replace with a direct call. |
118 | rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, |
119 | indirectCall.getResultTypes(), |
120 | indirectCall.getArgOperands()); |
121 | return success(); |
122 | } |
123 | |
124 | //===----------------------------------------------------------------------===// |
125 | // ConstantOp |
126 | //===----------------------------------------------------------------------===// |
127 | |
128 | LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
129 | StringRef fnName = getValue(); |
130 | Type type = getType(); |
131 | |
132 | // Try to find the referenced function. |
133 | auto fn = symbolTable.lookupNearestSymbolFrom<FuncOp>( |
134 | this->getOperation(), StringAttr::get(getContext(), fnName)); |
135 | if (!fn) |
136 | return emitOpError() << "reference to undefined function '" << fnName |
137 | << "'" ; |
138 | |
139 | // Check that the referenced function has the correct type. |
140 | if (fn.getFunctionType() != type) |
141 | return emitOpError("reference to function with mismatched type" ); |
142 | |
143 | return success(); |
144 | } |
145 | |
146 | OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { |
147 | return getValueAttr(); |
148 | } |
149 | |
150 | void ConstantOp::getAsmResultNames( |
151 | function_ref<void(Value, StringRef)> setNameFn) { |
152 | setNameFn(getResult(), "f" ); |
153 | } |
154 | |
155 | bool ConstantOp::isBuildableWith(Attribute value, Type type) { |
156 | return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type); |
157 | } |
158 | |
159 | //===----------------------------------------------------------------------===// |
160 | // FuncOp |
161 | //===----------------------------------------------------------------------===// |
162 | |
163 | FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
164 | ArrayRef<NamedAttribute> attrs) { |
165 | OpBuilder builder(location->getContext()); |
166 | OperationState state(location, getOperationName()); |
167 | FuncOp::build(builder, state, name, type, attrs); |
168 | return cast<FuncOp>(Operation::create(state)); |
169 | } |
170 | FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
171 | Operation::dialect_attr_range attrs) { |
172 | SmallVector<NamedAttribute, 8> attrRef(attrs); |
173 | return create(location, name, type, llvm::ArrayRef(attrRef)); |
174 | } |
175 | FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, |
176 | ArrayRef<NamedAttribute> attrs, |
177 | ArrayRef<DictionaryAttr> argAttrs) { |
178 | FuncOp func = create(location, name, type, attrs); |
179 | func.setAllArgAttrs(argAttrs); |
180 | return func; |
181 | } |
182 | |
183 | void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, |
184 | FunctionType type, ArrayRef<NamedAttribute> attrs, |
185 | ArrayRef<DictionaryAttr> argAttrs) { |
186 | state.addAttribute(SymbolTable::getSymbolAttrName(), |
187 | builder.getStringAttr(name)); |
188 | state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); |
189 | state.attributes.append(attrs.begin(), attrs.end()); |
190 | state.addRegion(); |
191 | |
192 | if (argAttrs.empty()) |
193 | return; |
194 | assert(type.getNumInputs() == argAttrs.size()); |
195 | call_interface_impl::addArgAndResultAttrs( |
196 | builder, state, argAttrs, /*resultAttrs=*/std::nullopt, |
197 | getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); |
198 | } |
199 | |
200 | ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { |
201 | auto buildFuncType = |
202 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
203 | function_interface_impl::VariadicFlag, |
204 | std::string &) { return builder.getFunctionType(argTypes, results); }; |
205 | |
206 | return function_interface_impl::parseFunctionOp( |
207 | parser, result, /*allowVariadic=*/false, |
208 | getFunctionTypeAttrName(result.name), buildFuncType, |
209 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
210 | } |
211 | |
212 | void FuncOp::print(OpAsmPrinter &p) { |
213 | function_interface_impl::printFunctionOp( |
214 | p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), |
215 | getArgAttrsAttrName(), getResAttrsAttrName()); |
216 | } |
217 | |
218 | /// Clone the internal blocks from this function into dest and all attributes |
219 | /// from this function to dest. |
220 | void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) { |
221 | // Add the attributes of this function to dest. |
222 | llvm::MapVector<StringAttr, Attribute> newAttrMap; |
223 | for (const auto &attr : dest->getAttrs()) |
224 | newAttrMap.insert({attr.getName(), attr.getValue()}); |
225 | for (const auto &attr : (*this)->getAttrs()) |
226 | newAttrMap.insert({attr.getName(), attr.getValue()}); |
227 | |
228 | auto newAttrs = llvm::to_vector(llvm::map_range( |
229 | newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) { |
230 | return NamedAttribute(attrPair.first, attrPair.second); |
231 | })); |
232 | dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs)); |
233 | |
234 | // Clone the body. |
235 | getBody().cloneInto(&dest.getBody(), mapper); |
236 | } |
237 | |
238 | /// Create a deep copy of this function and all of its blocks, remapping |
239 | /// any operands that use values outside of the function using the map that is |
240 | /// provided (leaving them alone if no entry is present). Replaces references |
241 | /// to cloned sub-values with the corresponding value that is copied, and adds |
242 | /// those mappings to the mapper. |
243 | FuncOp FuncOp::clone(IRMapping &mapper) { |
244 | // Create the new function. |
245 | FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions()); |
246 | |
247 | // If the function has a body, then the user might be deleting arguments to |
248 | // the function by specifying them in the mapper. If so, we don't add the |
249 | // argument to the input type vector. |
250 | if (!isExternal()) { |
251 | FunctionType oldType = getFunctionType(); |
252 | |
253 | unsigned oldNumArgs = oldType.getNumInputs(); |
254 | SmallVector<Type, 4> newInputs; |
255 | newInputs.reserve(oldNumArgs); |
256 | for (unsigned i = 0; i != oldNumArgs; ++i) |
257 | if (!mapper.contains(getArgument(i))) |
258 | newInputs.push_back(oldType.getInput(i)); |
259 | |
260 | /// If any of the arguments were dropped, update the type and drop any |
261 | /// necessary argument attributes. |
262 | if (newInputs.size() != oldNumArgs) { |
263 | newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, |
264 | oldType.getResults())); |
265 | |
266 | if (ArrayAttr argAttrs = getAllArgAttrs()) { |
267 | SmallVector<Attribute> newArgAttrs; |
268 | newArgAttrs.reserve(newInputs.size()); |
269 | for (unsigned i = 0; i != oldNumArgs; ++i) |
270 | if (!mapper.contains(getArgument(i))) |
271 | newArgAttrs.push_back(argAttrs[i]); |
272 | newFunc.setAllArgAttrs(newArgAttrs); |
273 | } |
274 | } |
275 | } |
276 | |
277 | /// Clone the current function into the new one and return it. |
278 | cloneInto(newFunc, mapper); |
279 | return newFunc; |
280 | } |
281 | FuncOp FuncOp::clone() { |
282 | IRMapping mapper; |
283 | return clone(mapper); |
284 | } |
285 | |
286 | //===----------------------------------------------------------------------===// |
287 | // ReturnOp |
288 | //===----------------------------------------------------------------------===// |
289 | |
290 | LogicalResult ReturnOp::verify() { |
291 | auto function = cast<FuncOp>((*this)->getParentOp()); |
292 | |
293 | // The operand number and types must match the function signature. |
294 | const auto &results = function.getFunctionType().getResults(); |
295 | if (getNumOperands() != results.size()) |
296 | return emitOpError("has " ) |
297 | << getNumOperands() << " operands, but enclosing function (@" |
298 | << function.getName() << ") returns " << results.size(); |
299 | |
300 | for (unsigned i = 0, e = results.size(); i != e; ++i) |
301 | if (getOperand(i).getType() != results[i]) |
302 | return emitError() << "type of return operand " << i << " (" |
303 | << getOperand(i).getType() |
304 | << ") doesn't match function result type (" |
305 | << results[i] << ")" |
306 | << " in function @" << function.getName(); |
307 | |
308 | return success(); |
309 | } |
310 | |
311 | //===----------------------------------------------------------------------===// |
312 | // TableGen'd op method definitions |
313 | //===----------------------------------------------------------------------===// |
314 | |
315 | #define GET_OP_CLASSES |
316 | #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" |
317 | |