//===- CallInterfaces.cpp - Call Interfaces -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Interfaces/CallInterfaces.h"
10#include "mlir/IR/Builders.h"
11
12using namespace mlir;
13
14//===----------------------------------------------------------------------===//
15// Argument and result attributes utilities
16//===----------------------------------------------------------------------===//
17
18static ParseResult
19parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
20 SmallVectorImpl<DictionaryAttr> &attrs) {
21 // Parse individual function results.
22 return parser.parseCommaSeparatedList(parseElementFn: [&]() -> ParseResult {
23 types.emplace_back();
24 attrs.emplace_back();
25 NamedAttrList attrList;
26 if (parser.parseType(result&: types.back()) ||
27 parser.parseOptionalAttrDict(result&: attrList))
28 return failure();
29 attrs.back() = attrList.getDictionary(parser.getContext());
30 return success();
31 });
32}
33
34ParseResult call_interface_impl::parseFunctionResultList(
35 OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
36 SmallVectorImpl<DictionaryAttr> &resultAttrs) {
37 if (failed(Result: parser.parseOptionalLParen())) {
38 // We already know that there is no `(`, so parse a type.
39 // Because there is no `(`, it cannot be a function type.
40 Type ty;
41 if (parser.parseType(result&: ty))
42 return failure();
43 resultTypes.push_back(Elt: ty);
44 resultAttrs.emplace_back();
45 return success();
46 }
47
48 // Special case for an empty set of parens.
49 if (succeeded(Result: parser.parseOptionalRParen()))
50 return success();
51 if (parseTypeAndAttrList(parser, types&: resultTypes, attrs&: resultAttrs))
52 return failure();
53 return parser.parseRParen();
54}
55
56ParseResult call_interface_impl::parseFunctionSignature(
57 OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
58 SmallVectorImpl<DictionaryAttr> &argAttrs,
59 SmallVectorImpl<Type> &resultTypes,
60 SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
61 // Parse arguments.
62 if (parser.parseLParen())
63 return failure();
64 if (failed(Result: parser.parseOptionalRParen())) {
65 if (parseTypeAndAttrList(parser, types&: argTypes, attrs&: argAttrs))
66 return failure();
67 if (parser.parseRParen())
68 return failure();
69 }
70 // Parse results.
71 if (succeeded(Result: parser.parseOptionalArrow()))
72 return call_interface_impl::parseFunctionResultList(parser, resultTypes,
73 resultAttrs);
74 if (mustParseEmptyResult)
75 return failure();
76 return success();
77}
78
79/// Print a function result list. The provided `attrs` must either be null, or
80/// contain a set of DictionaryAttrs of the same arity as `types`.
81static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
82 ArrayAttr attrs) {
83 assert(!types.empty() && "Should not be called for empty result list.");
84 assert((!attrs || attrs.size() == types.size()) &&
85 "Invalid number of attributes.");
86
87 auto &os = p.getStream();
88 bool needsParens = types.size() > 1 || llvm::isa<FunctionType>(Val: types[0]) ||
89 (attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty());
90 if (needsParens)
91 os << '(';
92 llvm::interleaveComma(c: llvm::seq<size_t>(Begin: 0, End: types.size()), os, each_fn: [&](size_t i) {
93 p.printType(type: types[i]);
94 if (attrs)
95 p.printOptionalAttrDict(attrs: llvm::cast<DictionaryAttr>(attrs[i]).getValue());
96 });
97 if (needsParens)
98 os << ')';
99}
100
101void call_interface_impl::printFunctionSignature(
102 OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
103 TypeRange resultTypes, ArrayAttr resultAttrs, Region *body,
104 bool printEmptyResult) {
105 bool isExternal = !body || body->empty();
106 if (!isExternal && !isVariadic && !argAttrs && !resultAttrs &&
107 printEmptyResult) {
108 p.printFunctionalType(inputs&: argTypes, results&: resultTypes);
109 return;
110 }
111
112 p << '(';
113 for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
114 if (i > 0)
115 p << ", ";
116
117 if (!isExternal) {
118 ArrayRef<NamedAttribute> attrs;
119 if (argAttrs)
120 attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
121 p.printRegionArgument(arg: body->getArgument(i), argAttrs: attrs);
122 } else {
123 p.printType(type: argTypes[i]);
124 if (argAttrs)
125 p.printOptionalAttrDict(
126 attrs: llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
127 }
128 }
129
130 if (isVariadic) {
131 if (!argTypes.empty())
132 p << ", ";
133 p << "...";
134 }
135
136 p << ')';
137
138 if (!resultTypes.empty()) {
139 p << " -> ";
140 printFunctionResultList(p, resultTypes, resultAttrs);
141 } else if (printEmptyResult) {
142 p << " -> ()";
143 }
144}
145
146void call_interface_impl::addArgAndResultAttrs(
147 Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
148 ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
149 StringAttr resAttrsName) {
150 auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
151 return attrs && !attrs.empty();
152 };
153 // Convert the specified array of dictionary attrs (which may have null
154 // entries) to an ArrayAttr of dictionaries.
155 auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
156 SmallVector<Attribute> attrs;
157 for (auto &dict : dictAttrs)
158 attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
159 return builder.getArrayAttr(attrs);
160 };
161
162 // Add the attributes to the operation arguments.
163 if (llvm::any_of(Range&: argAttrs, P: nonEmptyAttrsFn))
164 result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
165
166 // Add the attributes to the operation results.
167 if (llvm::any_of(Range&: resultAttrs, P: nonEmptyAttrsFn))
168 result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
169}
170
171void call_interface_impl::addArgAndResultAttrs(
172 Builder &builder, OperationState &result,
173 ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
174 StringAttr argAttrsName, StringAttr resAttrsName) {
175 SmallVector<DictionaryAttr> argAttrs;
176 for (const auto &arg : args)
177 argAttrs.push_back(arg.attrs);
178 addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
179 resAttrsName);
180}
181
182//===----------------------------------------------------------------------===//
183// CallOpInterface
184//===----------------------------------------------------------------------===//
185
186Operation *
187call_interface_impl::resolveCallable(CallOpInterface call,
188 SymbolTableCollection *symbolTable) {
189 CallInterfaceCallable callable = call.getCallableForCallee();
190 if (auto symbolVal = dyn_cast<Value>(callable))
191 return symbolVal.getDefiningOp();
192
193 // If the callable isn't a value, lookup the symbol reference.
194 auto symbolRef = cast<SymbolRefAttr>(callable);
195 if (symbolTable)
196 return symbolTable->lookupNearestSymbolFrom(call.getOperation(), symbolRef);
197 return SymbolTable::lookupNearestSymbolFrom(call.getOperation(), symbolRef);
198}
199
200//===----------------------------------------------------------------------===//
201// CallInterfaces
202//===----------------------------------------------------------------------===//
203
204#include "mlir/Interfaces/CallInterfaces.cpp.inc"
205

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Interfaces/CallInterfaces.cpp