1//===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/IR/FuncOps.h"
10
11#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h"
12#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14#include "mlir/IR/BuiltinTypes.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Matchers.h"
17#include "mlir/IR/OpImplementation.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/IR/TypeUtilities.h"
20#include "mlir/IR/Value.h"
21#include "mlir/Interfaces/FunctionImplementation.h"
22#include "mlir/Transforms/InliningUtils.h"
23#include "llvm/ADT/APFloat.h"
24#include "llvm/ADT/MapVector.h"
25#include "llvm/ADT/STLExtras.h"
26
27#include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
28
29using namespace mlir;
30using namespace mlir::func;
31
32//===----------------------------------------------------------------------===//
33// FuncDialect
34//===----------------------------------------------------------------------===//
35
36void FuncDialect::initialize() {
37 addOperations<
38#define GET_OP_LIST
39#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
40 >();
41 declarePromisedInterface<ConvertToEmitCPatternInterface, FuncDialect>();
42 declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
43 declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
44 declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
45 FuncOp, ReturnOp>();
46}
47
48/// Materialize a single constant operation from a given attribute value with
49/// the desired resultant type.
50Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value,
51 Type type, Location loc) {
52 if (ConstantOp::isBuildableWith(value, type))
53 return builder.create<ConstantOp>(location: loc, args&: type,
54 args: llvm::cast<FlatSymbolRefAttr>(Val&: value));
55 return nullptr;
56}
57
58//===----------------------------------------------------------------------===//
59// CallOp
60//===----------------------------------------------------------------------===//
61
62LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
63 // Check that the callee attribute was specified.
64 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>(name: "callee");
65 if (!fnAttr)
66 return emitOpError(message: "requires a 'callee' symbol reference attribute");
67 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(from: *this, symbol: fnAttr);
68 if (!fn)
69 return emitOpError() << "'" << fnAttr.getValue()
70 << "' does not reference a valid function";
71
72 // Verify that the operand and result types match the callee.
73 auto fnType = fn.getFunctionType();
74 if (fnType.getNumInputs() != getNumOperands())
75 return emitOpError(message: "incorrect number of operands for callee");
76
77 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
78 if (getOperand(i).getType() != fnType.getInput(i))
79 return emitOpError(message: "operand type mismatch: expected operand type ")
80 << fnType.getInput(i) << ", but provided "
81 << getOperand(i).getType() << " for operand number " << i;
82
83 if (fnType.getNumResults() != getNumResults())
84 return emitOpError(message: "incorrect number of results for callee");
85
86 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
87 if (getResult(i).getType() != fnType.getResult(i)) {
88 auto diag = emitOpError(message: "result type mismatch at index ") << i;
89 diag.attachNote() << " op result types: " << getResultTypes();
90 diag.attachNote() << "function result types: " << fnType.getResults();
91 return diag;
92 }
93
94 return success();
95}
96
97FunctionType CallOp::getCalleeType() {
98 return FunctionType::get(context: getContext(), inputs: getOperandTypes(), results: getResultTypes());
99}
100
101//===----------------------------------------------------------------------===//
102// CallIndirectOp
103//===----------------------------------------------------------------------===//
104
105/// Fold indirect calls that have a constant function as the callee operand.
106LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
107 PatternRewriter &rewriter) {
108 // Check that the callee is a constant callee.
109 SymbolRefAttr calledFn;
110 if (!matchPattern(value: indirectCall.getCallee(), pattern: m_Constant(bind_value: &calledFn)))
111 return failure();
112
113 // Replace with a direct call.
114 rewriter.replaceOpWithNewOp<CallOp>(op: indirectCall, args&: calledFn,
115 args: indirectCall.getResultTypes(),
116 args: indirectCall.getArgOperands());
117 return success();
118}
119
120//===----------------------------------------------------------------------===//
121// ConstantOp
122//===----------------------------------------------------------------------===//
123
124LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
125 StringRef fnName = getValue();
126 Type type = getType();
127
128 // Try to find the referenced function.
129 auto fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(
130 from: this->getOperation(), symbol: StringAttr::get(context: getContext(), bytes: fnName));
131 if (!fn)
132 return emitOpError() << "reference to undefined function '" << fnName
133 << "'";
134
135 // Check that the referenced function has the correct type.
136 if (fn.getFunctionType() != type)
137 return emitOpError(message: "reference to function with mismatched type");
138
139 return success();
140}
141
142OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
143 return getValueAttr();
144}
145
146void ConstantOp::getAsmResultNames(
147 function_ref<void(Value, StringRef)> setNameFn) {
148 setNameFn(getResult(), "f");
149}
150
151bool ConstantOp::isBuildableWith(Attribute value, Type type) {
152 return llvm::isa<FlatSymbolRefAttr>(Val: value) && llvm::isa<FunctionType>(Val: type);
153}
154
155//===----------------------------------------------------------------------===//
156// FuncOp
157//===----------------------------------------------------------------------===//
158
159FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
160 ArrayRef<NamedAttribute> attrs) {
161 OpBuilder builder(location->getContext());
162 OperationState state(location, getOperationName());
163 FuncOp::build(odsBuilder&: builder, odsState&: state, name, type, attrs);
164 return cast<FuncOp>(Val: Operation::create(state));
165}
166FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
167 Operation::dialect_attr_range attrs) {
168 SmallVector<NamedAttribute, 8> attrRef(attrs);
169 return create(location, name, type, attrs: llvm::ArrayRef(attrRef));
170}
171FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
172 ArrayRef<NamedAttribute> attrs,
173 ArrayRef<DictionaryAttr> argAttrs) {
174 FuncOp func = create(location, name, type, attrs);
175 func.setAllArgAttrs(argAttrs);
176 return func;
177}
178
179void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
180 FunctionType type, ArrayRef<NamedAttribute> attrs,
181 ArrayRef<DictionaryAttr> argAttrs) {
182 state.addAttribute(name: SymbolTable::getSymbolAttrName(),
183 attr: builder.getStringAttr(bytes: name));
184 state.addAttribute(name: getFunctionTypeAttrName(name: state.name), attr: TypeAttr::get(type));
185 state.attributes.append(inStart: attrs.begin(), inEnd: attrs.end());
186 state.addRegion();
187
188 if (argAttrs.empty())
189 return;
190 assert(type.getNumInputs() == argAttrs.size());
191 call_interface_impl::addArgAndResultAttrs(
192 builder, result&: state, argAttrs, /*resultAttrs=*/{},
193 argAttrsName: getArgAttrsAttrName(name: state.name), resAttrsName: getResAttrsAttrName(name: state.name));
194}
195
196ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
197 auto buildFuncType =
198 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
199 function_interface_impl::VariadicFlag,
200 std::string &) { return builder.getFunctionType(inputs: argTypes, results); };
201
202 return function_interface_impl::parseFunctionOp(
203 parser, result, /*allowVariadic=*/false,
204 typeAttrName: getFunctionTypeAttrName(name: result.name), funcTypeBuilder: buildFuncType,
205 argAttrsName: getArgAttrsAttrName(name: result.name), resAttrsName: getResAttrsAttrName(name: result.name));
206}
207
208void FuncOp::print(OpAsmPrinter &p) {
209 function_interface_impl::printFunctionOp(
210 p, op: *this, /*isVariadic=*/false, typeAttrName: getFunctionTypeAttrName(),
211 argAttrsName: getArgAttrsAttrName(), resAttrsName: getResAttrsAttrName());
212}
213
214/// Clone the internal blocks from this function into dest and all attributes
215/// from this function to dest.
216void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
217 // Add the attributes of this function to dest.
218 llvm::MapVector<StringAttr, Attribute> newAttrMap;
219 for (const auto &attr : dest->getAttrs())
220 newAttrMap.insert(KV: {attr.getName(), attr.getValue()});
221 for (const auto &attr : (*this)->getAttrs())
222 newAttrMap.insert(KV: {attr.getName(), attr.getValue()});
223
224 auto newAttrs = llvm::to_vector(Range: llvm::map_range(
225 C&: newAttrMap, F: [](std::pair<StringAttr, Attribute> attrPair) {
226 return NamedAttribute(attrPair.first, attrPair.second);
227 }));
228 dest->setAttrs(DictionaryAttr::get(context: getContext(), value: newAttrs));
229
230 // Clone the body.
231 getBody().cloneInto(dest: &dest.getBody(), mapper);
232}
233
234/// Create a deep copy of this function and all of its blocks, remapping
235/// any operands that use values outside of the function using the map that is
236/// provided (leaving them alone if no entry is present). Replaces references
237/// to cloned sub-values with the corresponding value that is copied, and adds
238/// those mappings to the mapper.
239FuncOp FuncOp::clone(IRMapping &mapper) {
240 // Create the new function.
241 FuncOp newFunc = cast<FuncOp>(Val: getOperation()->cloneWithoutRegions());
242
243 // If the function has a body, then the user might be deleting arguments to
244 // the function by specifying them in the mapper. If so, we don't add the
245 // argument to the input type vector.
246 if (!isExternal()) {
247 FunctionType oldType = getFunctionType();
248
249 unsigned oldNumArgs = oldType.getNumInputs();
250 SmallVector<Type, 4> newInputs;
251 newInputs.reserve(N: oldNumArgs);
252 for (unsigned i = 0; i != oldNumArgs; ++i)
253 if (!mapper.contains(from: getArgument(idx: i)))
254 newInputs.push_back(Elt: oldType.getInput(i));
255
256 /// If any of the arguments were dropped, update the type and drop any
257 /// necessary argument attributes.
258 if (newInputs.size() != oldNumArgs) {
259 newFunc.setType(FunctionType::get(context: oldType.getContext(), inputs: newInputs,
260 results: oldType.getResults()));
261
262 if (ArrayAttr argAttrs = getAllArgAttrs()) {
263 SmallVector<Attribute> newArgAttrs;
264 newArgAttrs.reserve(N: newInputs.size());
265 for (unsigned i = 0; i != oldNumArgs; ++i)
266 if (!mapper.contains(from: getArgument(idx: i)))
267 newArgAttrs.push_back(Elt: argAttrs[i]);
268 newFunc.setAllArgAttrs(newArgAttrs);
269 }
270 }
271 }
272
273 /// Clone the current function into the new one and return it.
274 cloneInto(dest: newFunc, mapper);
275 return newFunc;
276}
277FuncOp FuncOp::clone() {
278 IRMapping mapper;
279 return clone(mapper);
280}
281
282//===----------------------------------------------------------------------===//
283// ReturnOp
284//===----------------------------------------------------------------------===//
285
286LogicalResult ReturnOp::verify() {
287 auto function = cast<FuncOp>(Val: (*this)->getParentOp());
288
289 // The operand number and types must match the function signature.
290 const auto &results = function.getFunctionType().getResults();
291 if (getNumOperands() != results.size())
292 return emitOpError(message: "has ")
293 << getNumOperands() << " operands, but enclosing function (@"
294 << function.getName() << ") returns " << results.size();
295
296 for (unsigned i = 0, e = results.size(); i != e; ++i)
297 if (getOperand(i).getType() != results[i])
298 return emitError() << "type of return operand " << i << " ("
299 << getOperand(i).getType()
300 << ") doesn't match function result type ("
301 << results[i] << ")"
302 << " in function @" << function.getName();
303
304 return success();
305}
306
307//===----------------------------------------------------------------------===//
308// TableGen'd op method definitions
309//===----------------------------------------------------------------------===//
310
311#define GET_OP_CLASSES
312#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
313

// Source: mlir/lib/Dialect/Func/IR/FuncOps.cpp