1//===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/IR/FuncOps.h"
10
11#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/BuiltinTypes.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Matchers.h"
17#include "mlir/IR/OpImplementation.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/IR/TypeUtilities.h"
20#include "mlir/IR/Value.h"
21#include "mlir/Interfaces/FunctionImplementation.h"
22#include "mlir/Support/MathExtras.h"
23#include "mlir/Transforms/InliningUtils.h"
24#include "llvm/ADT/APFloat.h"
25#include "llvm/ADT/MapVector.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/Support/FormatVariadic.h"
28#include "llvm/Support/raw_ostream.h"
29#include <numeric>
30
31#include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
32
33using namespace mlir;
34using namespace mlir::func;
35
36//===----------------------------------------------------------------------===//
37// FuncDialect
38//===----------------------------------------------------------------------===//
39
40void FuncDialect::initialize() {
41 addOperations<
42#define GET_OP_LIST
43#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
44 >();
45 declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
46 declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
47 declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
48 FuncOp, ReturnOp>();
49}
50
51/// Materialize a single constant operation from a given attribute value with
52/// the desired resultant type.
53Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value,
54 Type type, Location loc) {
55 if (ConstantOp::isBuildableWith(value, type))
56 return builder.create<ConstantOp>(loc, type,
57 llvm::cast<FlatSymbolRefAttr>(value));
58 return nullptr;
59}
60
61//===----------------------------------------------------------------------===//
62// CallOp
63//===----------------------------------------------------------------------===//
64
65LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
66 // Check that the callee attribute was specified.
67 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
68 if (!fnAttr)
69 return emitOpError("requires a 'callee' symbol reference attribute");
70 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
71 if (!fn)
72 return emitOpError() << "'" << fnAttr.getValue()
73 << "' does not reference a valid function";
74
75 // Verify that the operand and result types match the callee.
76 auto fnType = fn.getFunctionType();
77 if (fnType.getNumInputs() != getNumOperands())
78 return emitOpError("incorrect number of operands for callee");
79
80 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
81 if (getOperand(i).getType() != fnType.getInput(i))
82 return emitOpError("operand type mismatch: expected operand type ")
83 << fnType.getInput(i) << ", but provided "
84 << getOperand(i).getType() << " for operand number " << i;
85
86 if (fnType.getNumResults() != getNumResults())
87 return emitOpError("incorrect number of results for callee");
88
89 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
90 if (getResult(i).getType() != fnType.getResult(i)) {
91 auto diag = emitOpError("result type mismatch at index ") << i;
92 diag.attachNote() << " op result types: " << getResultTypes();
93 diag.attachNote() << "function result types: " << fnType.getResults();
94 return diag;
95 }
96
97 return success();
98}
99
100FunctionType CallOp::getCalleeType() {
101 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
102}
103
104//===----------------------------------------------------------------------===//
105// CallIndirectOp
106//===----------------------------------------------------------------------===//
107
108/// Fold indirect calls that have a constant function as the callee operand.
109LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
110 PatternRewriter &rewriter) {
111 // Check that the callee is a constant callee.
112 SymbolRefAttr calledFn;
113 if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
114 return failure();
115
116 // Replace with a direct call.
117 rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
118 indirectCall.getResultTypes(),
119 indirectCall.getArgOperands());
120 return success();
121}
122
123//===----------------------------------------------------------------------===//
124// ConstantOp
125//===----------------------------------------------------------------------===//
126
127LogicalResult ConstantOp::verify() {
128 StringRef fnName = getValue();
129 Type type = getType();
130
131 // Try to find the referenced function.
132 auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName);
133 if (!fn)
134 return emitOpError() << "reference to undefined function '" << fnName
135 << "'";
136
137 // Check that the referenced function has the correct type.
138 if (fn.getFunctionType() != type)
139 return emitOpError("reference to function with mismatched type");
140
141 return success();
142}
143
144OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
145 return getValueAttr();
146}
147
148void ConstantOp::getAsmResultNames(
149 function_ref<void(Value, StringRef)> setNameFn) {
150 setNameFn(getResult(), "f");
151}
152
153bool ConstantOp::isBuildableWith(Attribute value, Type type) {
154 return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
155}
156
157//===----------------------------------------------------------------------===//
158// FuncOp
159//===----------------------------------------------------------------------===//
160
161FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
162 ArrayRef<NamedAttribute> attrs) {
163 OpBuilder builder(location->getContext());
164 OperationState state(location, getOperationName());
165 FuncOp::build(builder, state, name, type, attrs);
166 return cast<FuncOp>(Operation::create(state));
167}
168FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
169 Operation::dialect_attr_range attrs) {
170 SmallVector<NamedAttribute, 8> attrRef(attrs);
171 return create(location, name, type, llvm::ArrayRef(attrRef));
172}
173FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
174 ArrayRef<NamedAttribute> attrs,
175 ArrayRef<DictionaryAttr> argAttrs) {
176 FuncOp func = create(location, name, type, attrs);
177 func.setAllArgAttrs(argAttrs);
178 return func;
179}
180
181void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
182 FunctionType type, ArrayRef<NamedAttribute> attrs,
183 ArrayRef<DictionaryAttr> argAttrs) {
184 state.addAttribute(SymbolTable::getSymbolAttrName(),
185 builder.getStringAttr(name));
186 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
187 state.attributes.append(attrs.begin(), attrs.end());
188 state.addRegion();
189
190 if (argAttrs.empty())
191 return;
192 assert(type.getNumInputs() == argAttrs.size());
193 function_interface_impl::addArgAndResultAttrs(
194 builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
195 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
196}
197
198ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
199 auto buildFuncType =
200 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
201 function_interface_impl::VariadicFlag,
202 std::string &) { return builder.getFunctionType(argTypes, results); };
203
204 return function_interface_impl::parseFunctionOp(
205 parser, result, /*allowVariadic=*/false,
206 getFunctionTypeAttrName(result.name), buildFuncType,
207 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
208}
209
210void FuncOp::print(OpAsmPrinter &p) {
211 function_interface_impl::printFunctionOp(
212 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
213 getArgAttrsAttrName(), getResAttrsAttrName());
214}
215
216/// Clone the internal blocks from this function into dest and all attributes
217/// from this function to dest.
218void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
219 // Add the attributes of this function to dest.
220 llvm::MapVector<StringAttr, Attribute> newAttrMap;
221 for (const auto &attr : dest->getAttrs())
222 newAttrMap.insert({attr.getName(), attr.getValue()});
223 for (const auto &attr : (*this)->getAttrs())
224 newAttrMap.insert({attr.getName(), attr.getValue()});
225
226 auto newAttrs = llvm::to_vector(llvm::map_range(
227 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
228 return NamedAttribute(attrPair.first, attrPair.second);
229 }));
230 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
231
232 // Clone the body.
233 getBody().cloneInto(&dest.getBody(), mapper);
234}
235
236/// Create a deep copy of this function and all of its blocks, remapping
237/// any operands that use values outside of the function using the map that is
238/// provided (leaving them alone if no entry is present). Replaces references
239/// to cloned sub-values with the corresponding value that is copied, and adds
240/// those mappings to the mapper.
241FuncOp FuncOp::clone(IRMapping &mapper) {
242 // Create the new function.
243 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
244
245 // If the function has a body, then the user might be deleting arguments to
246 // the function by specifying them in the mapper. If so, we don't add the
247 // argument to the input type vector.
248 if (!isExternal()) {
249 FunctionType oldType = getFunctionType();
250
251 unsigned oldNumArgs = oldType.getNumInputs();
252 SmallVector<Type, 4> newInputs;
253 newInputs.reserve(oldNumArgs);
254 for (unsigned i = 0; i != oldNumArgs; ++i)
255 if (!mapper.contains(getArgument(i)))
256 newInputs.push_back(oldType.getInput(i));
257
258 /// If any of the arguments were dropped, update the type and drop any
259 /// necessary argument attributes.
260 if (newInputs.size() != oldNumArgs) {
261 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
262 oldType.getResults()));
263
264 if (ArrayAttr argAttrs = getAllArgAttrs()) {
265 SmallVector<Attribute> newArgAttrs;
266 newArgAttrs.reserve(newInputs.size());
267 for (unsigned i = 0; i != oldNumArgs; ++i)
268 if (!mapper.contains(getArgument(i)))
269 newArgAttrs.push_back(argAttrs[i]);
270 newFunc.setAllArgAttrs(newArgAttrs);
271 }
272 }
273 }
274
275 /// Clone the current function into the new one and return it.
276 cloneInto(newFunc, mapper);
277 return newFunc;
278}
279FuncOp FuncOp::clone() {
280 IRMapping mapper;
281 return clone(mapper);
282}
283
284//===----------------------------------------------------------------------===//
285// ReturnOp
286//===----------------------------------------------------------------------===//
287
288LogicalResult ReturnOp::verify() {
289 auto function = cast<FuncOp>((*this)->getParentOp());
290
291 // The operand number and types must match the function signature.
292 const auto &results = function.getFunctionType().getResults();
293 if (getNumOperands() != results.size())
294 return emitOpError("has ")
295 << getNumOperands() << " operands, but enclosing function (@"
296 << function.getName() << ") returns " << results.size();
297
298 for (unsigned i = 0, e = results.size(); i != e; ++i)
299 if (getOperand(i).getType() != results[i])
300 return emitError() << "type of return operand " << i << " ("
301 << getOperand(i).getType()
302 << ") doesn't match function result type ("
303 << results[i] << ")"
304 << " in function @" << function.getName();
305
306 return success();
307}
308
309//===----------------------------------------------------------------------===//
310// TableGen'd op method definitions
311//===----------------------------------------------------------------------===//
312
313#define GET_OP_CLASSES
314#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
315

// Source: mlir/lib/Dialect/Func/IR/FuncOps.cpp