1//===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/IR/FuncOps.h"
10
11#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h"
12#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14#include "mlir/IR/BuiltinOps.h"
15#include "mlir/IR/BuiltinTypes.h"
16#include "mlir/IR/IRMapping.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/OpImplementation.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/IR/TypeUtilities.h"
21#include "mlir/IR/Value.h"
22#include "mlir/Interfaces/FunctionImplementation.h"
23#include "mlir/Transforms/InliningUtils.h"
24#include "llvm/ADT/APFloat.h"
25#include "llvm/ADT/MapVector.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/Support/FormatVariadic.h"
28#include "llvm/Support/raw_ostream.h"
29#include <numeric>
30
31#include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
32
33using namespace mlir;
34using namespace mlir::func;
35
36//===----------------------------------------------------------------------===//
37// FuncDialect
38//===----------------------------------------------------------------------===//
39
40void FuncDialect::initialize() {
41 addOperations<
42#define GET_OP_LIST
43#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
44 >();
45 declarePromisedInterface<ConvertToEmitCPatternInterface, FuncDialect>();
46 declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
47 declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
48 declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
49 FuncOp, ReturnOp>();
50}
51
52/// Materialize a single constant operation from a given attribute value with
53/// the desired resultant type.
54Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value,
55 Type type, Location loc) {
56 if (ConstantOp::isBuildableWith(value, type))
57 return builder.create<ConstantOp>(loc, type,
58 llvm::cast<FlatSymbolRefAttr>(value));
59 return nullptr;
60}
61
62//===----------------------------------------------------------------------===//
63// CallOp
64//===----------------------------------------------------------------------===//
65
66LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
67 // Check that the callee attribute was specified.
68 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
69 if (!fnAttr)
70 return emitOpError("requires a 'callee' symbol reference attribute");
71 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
72 if (!fn)
73 return emitOpError() << "'" << fnAttr.getValue()
74 << "' does not reference a valid function";
75
76 // Verify that the operand and result types match the callee.
77 auto fnType = fn.getFunctionType();
78 if (fnType.getNumInputs() != getNumOperands())
79 return emitOpError("incorrect number of operands for callee");
80
81 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
82 if (getOperand(i).getType() != fnType.getInput(i))
83 return emitOpError("operand type mismatch: expected operand type ")
84 << fnType.getInput(i) << ", but provided "
85 << getOperand(i).getType() << " for operand number " << i;
86
87 if (fnType.getNumResults() != getNumResults())
88 return emitOpError("incorrect number of results for callee");
89
90 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
91 if (getResult(i).getType() != fnType.getResult(i)) {
92 auto diag = emitOpError("result type mismatch at index ") << i;
93 diag.attachNote() << " op result types: " << getResultTypes();
94 diag.attachNote() << "function result types: " << fnType.getResults();
95 return diag;
96 }
97
98 return success();
99}
100
101FunctionType CallOp::getCalleeType() {
102 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
103}
104
105//===----------------------------------------------------------------------===//
106// CallIndirectOp
107//===----------------------------------------------------------------------===//
108
109/// Fold indirect calls that have a constant function as the callee operand.
110LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
111 PatternRewriter &rewriter) {
112 // Check that the callee is a constant callee.
113 SymbolRefAttr calledFn;
114 if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
115 return failure();
116
117 // Replace with a direct call.
118 rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
119 indirectCall.getResultTypes(),
120 indirectCall.getArgOperands());
121 return success();
122}
123
124//===----------------------------------------------------------------------===//
125// ConstantOp
126//===----------------------------------------------------------------------===//
127
128LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
129 StringRef fnName = getValue();
130 Type type = getType();
131
132 // Try to find the referenced function.
133 auto fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(
134 this->getOperation(), StringAttr::get(getContext(), fnName));
135 if (!fn)
136 return emitOpError() << "reference to undefined function '" << fnName
137 << "'";
138
139 // Check that the referenced function has the correct type.
140 if (fn.getFunctionType() != type)
141 return emitOpError("reference to function with mismatched type");
142
143 return success();
144}
145
146OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
147 return getValueAttr();
148}
149
150void ConstantOp::getAsmResultNames(
151 function_ref<void(Value, StringRef)> setNameFn) {
152 setNameFn(getResult(), "f");
153}
154
155bool ConstantOp::isBuildableWith(Attribute value, Type type) {
156 return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
157}
158
159//===----------------------------------------------------------------------===//
160// FuncOp
161//===----------------------------------------------------------------------===//
162
163FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
164 ArrayRef<NamedAttribute> attrs) {
165 OpBuilder builder(location->getContext());
166 OperationState state(location, getOperationName());
167 FuncOp::build(builder, state, name, type, attrs);
168 return cast<FuncOp>(Operation::create(state));
169}
170FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
171 Operation::dialect_attr_range attrs) {
172 SmallVector<NamedAttribute, 8> attrRef(attrs);
173 return create(location, name, type, llvm::ArrayRef(attrRef));
174}
175FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
176 ArrayRef<NamedAttribute> attrs,
177 ArrayRef<DictionaryAttr> argAttrs) {
178 FuncOp func = create(location, name, type, attrs);
179 func.setAllArgAttrs(argAttrs);
180 return func;
181}
182
183void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
184 FunctionType type, ArrayRef<NamedAttribute> attrs,
185 ArrayRef<DictionaryAttr> argAttrs) {
186 state.addAttribute(SymbolTable::getSymbolAttrName(),
187 builder.getStringAttr(name));
188 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
189 state.attributes.append(attrs.begin(), attrs.end());
190 state.addRegion();
191
192 if (argAttrs.empty())
193 return;
194 assert(type.getNumInputs() == argAttrs.size());
195 call_interface_impl::addArgAndResultAttrs(
196 builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
197 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
198}
199
200ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
201 auto buildFuncType =
202 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
203 function_interface_impl::VariadicFlag,
204 std::string &) { return builder.getFunctionType(argTypes, results); };
205
206 return function_interface_impl::parseFunctionOp(
207 parser, result, /*allowVariadic=*/false,
208 getFunctionTypeAttrName(result.name), buildFuncType,
209 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
210}
211
212void FuncOp::print(OpAsmPrinter &p) {
213 function_interface_impl::printFunctionOp(
214 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
215 getArgAttrsAttrName(), getResAttrsAttrName());
216}
217
218/// Clone the internal blocks from this function into dest and all attributes
219/// from this function to dest.
220void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
221 // Add the attributes of this function to dest.
222 llvm::MapVector<StringAttr, Attribute> newAttrMap;
223 for (const auto &attr : dest->getAttrs())
224 newAttrMap.insert({attr.getName(), attr.getValue()});
225 for (const auto &attr : (*this)->getAttrs())
226 newAttrMap.insert({attr.getName(), attr.getValue()});
227
228 auto newAttrs = llvm::to_vector(llvm::map_range(
229 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
230 return NamedAttribute(attrPair.first, attrPair.second);
231 }));
232 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
233
234 // Clone the body.
235 getBody().cloneInto(&dest.getBody(), mapper);
236}
237
238/// Create a deep copy of this function and all of its blocks, remapping
239/// any operands that use values outside of the function using the map that is
240/// provided (leaving them alone if no entry is present). Replaces references
241/// to cloned sub-values with the corresponding value that is copied, and adds
242/// those mappings to the mapper.
243FuncOp FuncOp::clone(IRMapping &mapper) {
244 // Create the new function.
245 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
246
247 // If the function has a body, then the user might be deleting arguments to
248 // the function by specifying them in the mapper. If so, we don't add the
249 // argument to the input type vector.
250 if (!isExternal()) {
251 FunctionType oldType = getFunctionType();
252
253 unsigned oldNumArgs = oldType.getNumInputs();
254 SmallVector<Type, 4> newInputs;
255 newInputs.reserve(oldNumArgs);
256 for (unsigned i = 0; i != oldNumArgs; ++i)
257 if (!mapper.contains(getArgument(i)))
258 newInputs.push_back(oldType.getInput(i));
259
260 /// If any of the arguments were dropped, update the type and drop any
261 /// necessary argument attributes.
262 if (newInputs.size() != oldNumArgs) {
263 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
264 oldType.getResults()));
265
266 if (ArrayAttr argAttrs = getAllArgAttrs()) {
267 SmallVector<Attribute> newArgAttrs;
268 newArgAttrs.reserve(newInputs.size());
269 for (unsigned i = 0; i != oldNumArgs; ++i)
270 if (!mapper.contains(getArgument(i)))
271 newArgAttrs.push_back(argAttrs[i]);
272 newFunc.setAllArgAttrs(newArgAttrs);
273 }
274 }
275 }
276
277 /// Clone the current function into the new one and return it.
278 cloneInto(newFunc, mapper);
279 return newFunc;
280}
281FuncOp FuncOp::clone() {
282 IRMapping mapper;
283 return clone(mapper);
284}
285
286//===----------------------------------------------------------------------===//
287// ReturnOp
288//===----------------------------------------------------------------------===//
289
290LogicalResult ReturnOp::verify() {
291 auto function = cast<FuncOp>((*this)->getParentOp());
292
293 // The operand number and types must match the function signature.
294 const auto &results = function.getFunctionType().getResults();
295 if (getNumOperands() != results.size())
296 return emitOpError("has ")
297 << getNumOperands() << " operands, but enclosing function (@"
298 << function.getName() << ") returns " << results.size();
299
300 for (unsigned i = 0, e = results.size(); i != e; ++i)
301 if (getOperand(i).getType() != results[i])
302 return emitError() << "type of return operand " << i << " ("
303 << getOperand(i).getType()
304 << ") doesn't match function result type ("
305 << results[i] << ")"
306 << " in function @" << function.getName();
307
308 return success();
309}
310
311//===----------------------------------------------------------------------===//
312// TableGen'd op method definitions
313//===----------------------------------------------------------------------===//
314
315#define GET_OP_CLASSES
316#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
317

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Func/IR/FuncOps.cpp