1 | //===- Operator.cpp - Operator class --------------------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Operator wrapper to simplify using TableGen Record defining a MLIR Op. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/TableGen/Operator.h" |
14 | #include "mlir/TableGen/Argument.h" |
15 | #include "mlir/TableGen/Predicate.h" |
16 | #include "mlir/TableGen/Trait.h" |
17 | #include "mlir/TableGen/Type.h" |
18 | #include "llvm/ADT/EquivalenceClasses.h" |
19 | #include "llvm/ADT/STLExtras.h" |
20 | #include "llvm/ADT/Sequence.h" |
21 | #include "llvm/ADT/SmallPtrSet.h" |
22 | #include "llvm/ADT/StringExtras.h" |
23 | #include "llvm/ADT/TypeSwitch.h" |
24 | #include "llvm/Support/Debug.h" |
25 | #include "llvm/Support/ErrorHandling.h" |
26 | #include "llvm/Support/FormatVariadic.h" |
27 | #include "llvm/TableGen/Error.h" |
28 | #include "llvm/TableGen/Record.h" |
29 | #include <list> |
30 | |
31 | #define DEBUG_TYPE "mlir-tblgen-operator" |
32 | |
33 | using namespace mlir; |
34 | using namespace mlir::tblgen; |
35 | |
36 | using llvm::DagInit; |
37 | using llvm::DefInit; |
38 | using llvm::Init; |
39 | using llvm::ListInit; |
40 | using llvm::Record; |
41 | using llvm::StringInit; |
42 | |
43 | Operator::Operator(const Record &def) |
44 | : dialect(def.getValueAsDef(FieldName: "opDialect")), def(def) { |
45 | // The first `_` in the op's TableGen def name is treated as separating the |
46 | // dialect prefix and the op class name. The dialect prefix will be ignored if |
47 | // not empty. Otherwise, if def name starts with a `_`, the `_` is considered |
48 | // as part of the class name. |
49 | StringRef prefix; |
50 | std::tie(args&: prefix, args&: cppClassName) = def.getName().split(Separator: '_'); |
51 | if (prefix.empty()) { |
52 | // Class name with a leading underscore and without dialect prefix |
53 | cppClassName = def.getName(); |
54 | } else if (cppClassName.empty()) { |
55 | // Class name without dialect prefix |
56 | cppClassName = prefix; |
57 | } |
58 | |
59 | cppNamespace = def.getValueAsString(FieldName: "cppNamespace"); |
60 | |
61 | populateOpStructure(); |
62 | assertInvariants(); |
63 | } |
64 | |
65 | std::string Operator::getOperationName() const { |
66 | auto prefix = dialect.getName(); |
67 | auto opName = def.getValueAsString(FieldName: "opName"); |
68 | if (prefix.empty()) |
69 | return std::string(opName); |
70 | return std::string(llvm::formatv(Fmt: "{0}.{1}", Vals&: prefix, Vals&: opName)); |
71 | } |
72 | |
73 | std::string Operator::getAdaptorName() const { |
74 | return std::string(llvm::formatv(Fmt: "{0}Adaptor", Vals: getCppClassName())); |
75 | } |
76 | |
77 | std::string Operator::getGenericAdaptorName() const { |
78 | return std::string(llvm::formatv(Fmt: "{0}GenericAdaptor", Vals: getCppClassName())); |
79 | } |
80 | |
81 | /// Assert the invariants of accessors generated for the given name. |
82 | static void assertAccessorInvariants(const Operator &op, StringRef name) { |
83 | std::string accessorName = |
84 | convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true); |
85 | |
86 | // Functor used to detect when an accessor will cause an overlap with an |
87 | // operation API. |
88 | // |
89 | // There are a little bit more invasive checks possible for cases where not |
90 | // all ops have the trait that would cause overlap. For many cases here, |
91 | // renaming would be better (e.g., we can only guard in limited manner |
92 | // against methods from traits and interfaces here, so avoiding these in op |
93 | // definition is safer). |
94 | auto nameOverlapsWithOpAPI = [&](StringRef newName) { |
95 | if (newName == "AttributeNames"|| newName == "Attributes"|| |
96 | newName == "Operation") |
97 | return true; |
98 | if (newName == "Operands") |
99 | return op.getNumOperands() != 1 || op.getNumVariableLengthOperands() != 1; |
100 | if (newName == "Regions") |
101 | return op.getNumRegions() != 1 || op.getNumVariadicRegions() != 1; |
102 | if (newName == "Type") |
103 | return op.getNumResults() != 1; |
104 | return false; |
105 | }; |
106 | if (nameOverlapsWithOpAPI(accessorName)) { |
107 | // This error could be avoided in situations where the final function is |
108 | // identical, but preferably the op definition should avoid using generic |
109 | // names. |
110 | PrintFatalError(ErrorLoc: op.getLoc(), Msg: "generated accessor for `"+ name + |
111 | "` overlaps with a default one; please " |
112 | "rename to avoid overlap"); |
113 | } |
114 | } |
115 | |
116 | void Operator::assertInvariants() const { |
117 | // Check that the name of arguments/results/regions/successors don't overlap. |
118 | DenseMap<StringRef, StringRef> existingNames; |
119 | auto checkName = [&](StringRef name, StringRef entity) { |
120 | if (name.empty()) |
121 | return; |
122 | auto insertion = existingNames.insert(KV: {name, entity}); |
123 | if (insertion.second) { |
124 | // Assert invariants for accessors generated for this name. |
125 | assertAccessorInvariants(op: *this, name); |
126 | return; |
127 | } |
128 | if (entity == insertion.first->second) |
129 | PrintFatalError(ErrorLoc: getLoc(), Msg: "op has a conflict with two "+ entity + |
130 | " having the same name '"+ name + "'"); |
131 | PrintFatalError(ErrorLoc: getLoc(), Msg: "op has a conflict with "+ |
132 | insertion.first->second + " and "+ entity + |
133 | " both having an entry with the name '"+ |
134 | name + "'"); |
135 | }; |
136 | // Check operands amongst themselves. |
137 | for (int i : llvm::seq<int>(Begin: 0, End: getNumOperands())) |
138 | checkName(getOperand(index: i).name, "operands"); |
139 | |
140 | // Check results amongst themselves and against operands. |
141 | for (int i : llvm::seq<int>(Begin: 0, End: getNumResults())) |
142 | checkName(getResult(index: i).name, "results"); |
143 | |
144 | // Check regions amongst themselves and against operands and results. |
145 | for (int i : llvm::seq<int>(Begin: 0, End: getNumRegions())) |
146 | checkName(getRegion(index: i).name, "regions"); |
147 | |
148 | // Check successors amongst themselves and against operands, results, and |
149 | // regions. |
150 | for (int i : llvm::seq<int>(Begin: 0, End: getNumSuccessors())) |
151 | checkName(getSuccessor(index: i).name, "successors"); |
152 | } |
153 | |
154 | StringRef Operator::getDialectName() const { return dialect.getName(); } |
155 | |
156 | StringRef Operator::getCppClassName() const { return cppClassName; } |
157 | |
158 | std::string Operator::getQualCppClassName() const { |
159 | if (cppNamespace.empty()) |
160 | return std::string(cppClassName); |
161 | return std::string(llvm::formatv(Fmt: "{0}::{1}", Vals: cppNamespace, Vals: cppClassName)); |
162 | } |
163 | |
164 | StringRef Operator::getCppNamespace() const { return cppNamespace; } |
165 | |
166 | int Operator::getNumResults() const { |
167 | const DagInit *results = def.getValueAsDag(FieldName: "results"); |
168 | return results->getNumArgs(); |
169 | } |
170 | |
171 | StringRef Operator::getExtraClassDeclaration() const { |
172 | constexpr auto attr = "extraClassDeclaration"; |
173 | if (def.isValueUnset(FieldName: attr)) |
174 | return {}; |
175 | return def.getValueAsString(FieldName: attr); |
176 | } |
177 | |
178 | StringRef Operator::getExtraClassDefinition() const { |
179 | constexpr auto attr = "extraClassDefinition"; |
180 | if (def.isValueUnset(FieldName: attr)) |
181 | return {}; |
182 | return def.getValueAsString(FieldName: attr); |
183 | } |
184 | |
185 | const Record &Operator::getDef() const { return def; } |
186 | |
187 | bool Operator::skipDefaultBuilders() const { |
188 | return def.getValueAsBit(FieldName: "skipDefaultBuilders"); |
189 | } |
190 | |
191 | auto Operator::result_begin() const -> const_value_iterator { |
192 | return results.begin(); |
193 | } |
194 | |
195 | auto Operator::result_end() const -> const_value_iterator { |
196 | return results.end(); |
197 | } |
198 | |
199 | auto Operator::getResults() const -> const_value_range { |
200 | return {result_begin(), result_end()}; |
201 | } |
202 | |
203 | TypeConstraint Operator::getResultTypeConstraint(int index) const { |
204 | const DagInit *results = def.getValueAsDag(FieldName: "results"); |
205 | return TypeConstraint(cast<DefInit>(Val: results->getArg(Num: index))); |
206 | } |
207 | |
208 | StringRef Operator::getResultName(int index) const { |
209 | const DagInit *results = def.getValueAsDag(FieldName: "results"); |
210 | return results->getArgNameStr(Num: index); |
211 | } |
212 | |
213 | auto Operator::getResultDecorators(int index) const -> var_decorator_range { |
214 | const Record *result = |
215 | cast<DefInit>(Val: def.getValueAsDag(FieldName: "results")->getArg(Num: index))->getDef(); |
216 | if (!result->isSubClassOf(Name: "OpVariable")) |
217 | return var_decorator_range(nullptr, nullptr); |
218 | return *result->getValueAsListInit(FieldName: "decorators"); |
219 | } |
220 | |
221 | unsigned Operator::getNumVariableLengthResults() const { |
222 | return llvm::count_if(Range: results, P: [](const NamedTypeConstraint &c) { |
223 | return c.constraint.isVariableLength(); |
224 | }); |
225 | } |
226 | |
227 | unsigned Operator::getNumVariableLengthOperands() const { |
228 | return llvm::count_if(Range: operands, P: [](const NamedTypeConstraint &c) { |
229 | return c.constraint.isVariableLength(); |
230 | }); |
231 | } |
232 | |
233 | bool Operator::hasSingleVariadicArg() const { |
234 | return getNumArgs() == 1 && isa<NamedTypeConstraint *>(Val: getArg(index: 0)) && |
235 | getOperand(index: 0).isVariadic(); |
236 | } |
237 | |
238 | Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); } |
239 | |
240 | Operator::arg_iterator Operator::arg_end() const { return arguments.end(); } |
241 | |
242 | Operator::arg_range Operator::getArgs() const { |
243 | return {arg_begin(), arg_end()}; |
244 | } |
245 | |
246 | StringRef Operator::getArgName(int index) const { |
247 | const DagInit *argumentValues = def.getValueAsDag(FieldName: "arguments"); |
248 | return argumentValues->getArgNameStr(Num: index); |
249 | } |
250 | |
251 | auto Operator::getArgDecorators(int index) const -> var_decorator_range { |
252 | const Record *arg = |
253 | cast<DefInit>(Val: def.getValueAsDag(FieldName: "arguments")->getArg(Num: index))->getDef(); |
254 | if (!arg->isSubClassOf(Name: "OpVariable")) |
255 | return var_decorator_range(nullptr, nullptr); |
256 | return *arg->getValueAsListInit(FieldName: "decorators"); |
257 | } |
258 | |
259 | const Trait *Operator::getTrait(StringRef trait) const { |
260 | for (const auto &t : traits) { |
261 | if (const auto *traitDef = dyn_cast<NativeTrait>(Val: &t)) { |
262 | if (traitDef->getFullyQualifiedTraitName() == trait) |
263 | return traitDef; |
264 | } else if (const auto *traitDef = dyn_cast<InternalTrait>(Val: &t)) { |
265 | if (traitDef->getFullyQualifiedTraitName() == trait) |
266 | return traitDef; |
267 | } else if (const auto *traitDef = dyn_cast<InterfaceTrait>(Val: &t)) { |
268 | if (traitDef->getFullyQualifiedTraitName() == trait) |
269 | return traitDef; |
270 | } |
271 | } |
272 | return nullptr; |
273 | } |
274 | |
275 | auto Operator::region_begin() const -> const_region_iterator { |
276 | return regions.begin(); |
277 | } |
278 | auto Operator::region_end() const -> const_region_iterator { |
279 | return regions.end(); |
280 | } |
281 | auto Operator::getRegions() const |
282 | -> llvm::iterator_range<const_region_iterator> { |
283 | return {region_begin(), region_end()}; |
284 | } |
285 | |
286 | unsigned Operator::getNumRegions() const { return regions.size(); } |
287 | |
288 | const NamedRegion &Operator::getRegion(unsigned index) const { |
289 | return regions[index]; |
290 | } |
291 | |
292 | unsigned Operator::getNumVariadicRegions() const { |
293 | return llvm::count_if(Range: regions, |
294 | P: [](const NamedRegion &c) { return c.isVariadic(); }); |
295 | } |
296 | |
297 | auto Operator::successor_begin() const -> const_successor_iterator { |
298 | return successors.begin(); |
299 | } |
300 | auto Operator::successor_end() const -> const_successor_iterator { |
301 | return successors.end(); |
302 | } |
303 | auto Operator::getSuccessors() const |
304 | -> llvm::iterator_range<const_successor_iterator> { |
305 | return {successor_begin(), successor_end()}; |
306 | } |
307 | |
308 | unsigned Operator::getNumSuccessors() const { return successors.size(); } |
309 | |
310 | const NamedSuccessor &Operator::getSuccessor(unsigned index) const { |
311 | return successors[index]; |
312 | } |
313 | |
314 | unsigned Operator::getNumVariadicSuccessors() const { |
315 | return llvm::count_if(Range: successors, |
316 | P: [](const NamedSuccessor &c) { return c.isVariadic(); }); |
317 | } |
318 | |
319 | auto Operator::trait_begin() const -> const_trait_iterator { |
320 | return traits.begin(); |
321 | } |
322 | auto Operator::trait_end() const -> const_trait_iterator { |
323 | return traits.end(); |
324 | } |
325 | auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> { |
326 | return {trait_begin(), trait_end()}; |
327 | } |
328 | |
329 | auto Operator::attribute_begin() const -> const_attribute_iterator { |
330 | return attributes.begin(); |
331 | } |
332 | auto Operator::attribute_end() const -> const_attribute_iterator { |
333 | return attributes.end(); |
334 | } |
335 | auto Operator::getAttributes() const |
336 | -> llvm::iterator_range<const_attribute_iterator> { |
337 | return {attribute_begin(), attribute_end()}; |
338 | } |
339 | auto Operator::attribute_begin() -> attribute_iterator { |
340 | return attributes.begin(); |
341 | } |
342 | auto Operator::attribute_end() -> attribute_iterator { |
343 | return attributes.end(); |
344 | } |
345 | auto Operator::getAttributes() -> llvm::iterator_range<attribute_iterator> { |
346 | return {attribute_begin(), attribute_end()}; |
347 | } |
348 | |
349 | auto Operator::operand_begin() const -> const_value_iterator { |
350 | return operands.begin(); |
351 | } |
352 | auto Operator::operand_end() const -> const_value_iterator { |
353 | return operands.end(); |
354 | } |
355 | auto Operator::getOperands() const -> const_value_range { |
356 | return {operand_begin(), operand_end()}; |
357 | } |
358 | |
359 | auto Operator::getArg(int index) const -> Argument { return arguments[index]; } |
360 | |
361 | bool Operator::isVariadic() const { |
362 | return any_of(Range: llvm::concat<const NamedTypeConstraint>(Ranges: operands, Ranges: results), |
363 | P: [](const NamedTypeConstraint &op) { return op.isVariadic(); }); |
364 | } |
365 | |
366 | void Operator::populateTypeInferenceInfo( |
367 | const llvm::StringMap<int> &argumentsAndResultsIndex) { |
368 | // If the type inference op interface is not registered, then do not attempt |
369 | // to determine if the result types an be inferred. |
370 | auto &recordKeeper = def.getRecords(); |
371 | auto *inferTrait = recordKeeper.getDef(Name: inferTypeOpInterface); |
372 | allResultsHaveKnownTypes = false; |
373 | if (!inferTrait) |
374 | return; |
375 | |
376 | // If there are no results, the skip this else the build method generated |
377 | // overlaps with another autogenerated builder. |
378 | if (getNumResults() == 0) |
379 | return; |
380 | |
381 | // Skip ops with variadic or optional results. |
382 | if (getNumVariableLengthResults() > 0) |
383 | return; |
384 | |
385 | // Skip cases currently being custom generated. |
386 | // TODO: Remove special cases. |
387 | if (getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType")) { |
388 | // Check for a non-variable length operand to use as the type anchor. |
389 | auto *operandI = llvm::find_if(Range&: arguments, P: [](const Argument &arg) { |
390 | NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: arg); |
391 | return operand && !operand->isVariableLength(); |
392 | }); |
393 | if (operandI == arguments.end()) |
394 | return; |
395 | |
396 | // All result types are inferred from the operand type. |
397 | int operandIdx = operandI - arguments.begin(); |
398 | for (int i = 0; i < getNumResults(); ++i) |
399 | resultTypeMapping.emplace_back(Args&: operandIdx, Args: "$_self"); |
400 | |
401 | allResultsHaveKnownTypes = true; |
402 | traits.push_back(Elt: Trait::create(init: inferTrait->getDefInit())); |
403 | return; |
404 | } |
405 | |
406 | /// This struct represents a node in this operation's result type inferenece |
407 | /// graph. Each node has a list of incoming type inference edges `sources`. |
408 | /// Each edge represents a "source" from which the result type can be |
409 | /// inferred, either an operand (leaf) or another result (node). When a node |
410 | /// is known to have a fully-inferred type, `inferred` is set to true. |
411 | struct ResultTypeInference { |
412 | /// The list of incoming type inference edges. |
413 | SmallVector<InferredResultType> sources; |
414 | /// This flag is set to true when the result type is known to be inferrable. |
415 | bool inferred = false; |
416 | }; |
417 | |
418 | // This vector represents the type inference graph, with one node for each |
419 | // operation result. The nth element is the node for the nth result. |
420 | SmallVector<ResultTypeInference> inference(getNumResults(), {}); |
421 | |
422 | // For all results whose types are buildable, initialize their type inference |
423 | // nodes with an edge to themselves. Mark those nodes are fully-inferred. |
424 | for (auto [idx, infer] : llvm::enumerate(First&: inference)) { |
425 | if (getResult(index: idx).constraint.getBuilderCall()) { |
426 | infer.sources.emplace_back(Args: InferredResultType::mapResultIndex(i: idx), |
427 | Args: "$_self"); |
428 | infer.inferred = true; |
429 | } |
430 | } |
431 | |
432 | // Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the |
433 | // result type inference graph. |
434 | for (const Trait &trait : traits) { |
435 | const Record &def = trait.getDef(); |
436 | |
437 | // If the infer type op interface was manually added, then treat it as |
438 | // intention that the op needs special handling. |
439 | // TODO: Reconsider whether to always generate, this is more conservative |
440 | // and keeps existing behavior so starting that way for now. |
441 | if (def.isSubClassOf( |
442 | Name: llvm::formatv(Fmt: "{0}::Trait", Vals&: inferTypeOpInterface).str())) |
443 | return; |
444 | if (const auto *traitDef = dyn_cast<InterfaceTrait>(Val: &trait)) |
445 | if (&traitDef->getDef() == inferTrait) |
446 | return; |
447 | |
448 | // The `TypesMatchWith` trait represents a 1 -> 1 type inference edge with a |
449 | // type transformer. |
450 | if (def.isSubClassOf(Name: "TypesMatchWith")) { |
451 | int target = argumentsAndResultsIndex.lookup(Key: def.getValueAsString(FieldName: "rhs")); |
452 | // Ignore operand type inference. |
453 | if (InferredResultType::isArgIndex(i: target)) |
454 | continue; |
455 | int resultIndex = InferredResultType::unmapResultIndex(i: target); |
456 | ResultTypeInference &infer = inference[resultIndex]; |
457 | // If the type of the result has already been inferred, do nothing. |
458 | if (infer.inferred) |
459 | continue; |
460 | int sourceIndex = |
461 | argumentsAndResultsIndex.lookup(Key: def.getValueAsString(FieldName: "lhs")); |
462 | infer.sources.emplace_back(Args&: sourceIndex, |
463 | Args: def.getValueAsString(FieldName: "transformer").str()); |
464 | // Locally propagate inferredness. |
465 | infer.inferred = |
466 | InferredResultType::isArgIndex(i: sourceIndex) || |
467 | inference[InferredResultType::unmapResultIndex(i: sourceIndex)].inferred; |
468 | continue; |
469 | } |
470 | |
471 | if (!def.isSubClassOf(Name: "AllTypesMatch")) |
472 | continue; |
473 | |
474 | auto values = def.getValueAsListOfStrings(FieldName: "values"); |
475 | // The `AllTypesMatch` trait represents an N <-> N fanin and fanout. That |
476 | // is, every result type has an edge from every other type. However, if any |
477 | // one of the values refers to an operand or a result with a fully-inferred |
478 | // type, we can infer all other types from that value. Try to find a |
479 | // fully-inferred type in the list. |
480 | std::optional<int> fullyInferredIndex; |
481 | SmallVector<int> resultIndices; |
482 | for (StringRef name : values) { |
483 | int index = argumentsAndResultsIndex.lookup(Key: name); |
484 | if (InferredResultType::isResultIndex(i: index)) |
485 | resultIndices.push_back(Elt: InferredResultType::unmapResultIndex(i: index)); |
486 | if (InferredResultType::isArgIndex(i: index) || |
487 | inference[InferredResultType::unmapResultIndex(i: index)].inferred) |
488 | fullyInferredIndex = index; |
489 | } |
490 | if (fullyInferredIndex) { |
491 | // Make the fully-inferred type the only source for all results that |
492 | // aren't already inferred -- a 1 -> N fanout. |
493 | for (int resultIndex : resultIndices) { |
494 | ResultTypeInference &infer = inference[resultIndex]; |
495 | if (!infer.inferred) { |
496 | infer.sources.assign(NumElts: 1, Elt: {*fullyInferredIndex, "$_self"}); |
497 | infer.inferred = true; |
498 | } |
499 | } |
500 | } else { |
501 | // Add an edge between every result and every other type; N <-> N. |
502 | for (int resultIndex : resultIndices) { |
503 | for (int otherResultIndex : resultIndices) { |
504 | if (resultIndex == otherResultIndex) |
505 | continue; |
506 | inference[resultIndex].sources.emplace_back( |
507 | Args: InferredResultType::unmapResultIndex(i: otherResultIndex), Args: "$_self"); |
508 | } |
509 | } |
510 | } |
511 | } |
512 | |
513 | // Propagate inferredness until a fixed point. |
514 | std::vector<ResultTypeInference *> worklist; |
515 | for (ResultTypeInference &infer : inference) |
516 | if (!infer.inferred) |
517 | worklist.push_back(x: &infer); |
518 | bool changed; |
519 | do { |
520 | changed = false; |
521 | for (auto cur = worklist.begin(); cur != worklist.end();) { |
522 | ResultTypeInference &infer = **cur; |
523 | |
524 | InferredResultType *iter = |
525 | llvm::find_if(Range&: infer.sources, P: [&](const InferredResultType &source) { |
526 | assert(InferredResultType::isResultIndex(source.getIndex())); |
527 | return inference[InferredResultType::unmapResultIndex( |
528 | i: source.getIndex())] |
529 | .inferred; |
530 | }); |
531 | if (iter == infer.sources.end()) { |
532 | ++cur; |
533 | continue; |
534 | } |
535 | |
536 | changed = true; |
537 | infer.inferred = true; |
538 | // Make this the only source for the result. This breaks any cycles. |
539 | infer.sources.assign(NumElts: 1, Elt: *iter); |
540 | cur = worklist.erase(position: cur); |
541 | } |
542 | } while (changed); |
543 | |
544 | allResultsHaveKnownTypes = worklist.empty(); |
545 | |
546 | // If the types could be computed, then add type inference trait. |
547 | if (allResultsHaveKnownTypes) { |
548 | traits.push_back(Elt: Trait::create(init: inferTrait->getDefInit())); |
549 | for (const ResultTypeInference &infer : inference) |
550 | resultTypeMapping.push_back(Elt: infer.sources.front()); |
551 | } |
552 | } |
553 | |
554 | void Operator::populateOpStructure() { |
555 | auto &recordKeeper = def.getRecords(); |
556 | auto *typeConstraintClass = recordKeeper.getClass(Name: "TypeConstraint"); |
557 | auto *attrClass = recordKeeper.getClass(Name: "Attr"); |
558 | auto *propertyClass = recordKeeper.getClass(Name: "Property"); |
559 | auto *derivedAttrClass = recordKeeper.getClass(Name: "DerivedAttr"); |
560 | auto *opVarClass = recordKeeper.getClass(Name: "OpVariable"); |
561 | numNativeAttributes = 0; |
562 | |
563 | const DagInit *argumentValues = def.getValueAsDag(FieldName: "arguments"); |
564 | unsigned numArgs = argumentValues->getNumArgs(); |
565 | |
566 | // Mapping from name of to argument or result index. Arguments are indexed |
567 | // to match getArg index, while the results are negatively indexed. |
568 | llvm::StringMap<int> argumentsAndResultsIndex; |
569 | |
570 | // Handle operands and native attributes. |
571 | for (unsigned i = 0; i != numArgs; ++i) { |
572 | auto *arg = argumentValues->getArg(Num: i); |
573 | auto givenName = argumentValues->getArgNameStr(Num: i); |
574 | auto *argDefInit = dyn_cast<DefInit>(Val: arg); |
575 | if (!argDefInit) |
576 | PrintFatalError(ErrorLoc: def.getLoc(), |
577 | Msg: Twine("undefined type for argument #") + Twine(i)); |
578 | const Record *argDef = argDefInit->getDef(); |
579 | if (argDef->isSubClassOf(R: opVarClass)) |
580 | argDef = argDef->getValueAsDef(FieldName: "constraint"); |
581 | |
582 | if (argDef->isSubClassOf(R: typeConstraintClass)) { |
583 | operands.push_back( |
584 | Elt: NamedTypeConstraint{.name: givenName, .constraint: TypeConstraint(argDef)}); |
585 | } else if (argDef->isSubClassOf(R: attrClass)) { |
586 | if (givenName.empty()) |
587 | PrintFatalError(ErrorLoc: argDef->getLoc(), Msg: "attributes must be named"); |
588 | if (argDef->isSubClassOf(R: derivedAttrClass)) |
589 | PrintFatalError(ErrorLoc: argDef->getLoc(), |
590 | Msg: "derived attributes not allowed in argument list"); |
591 | attributes.push_back(Elt: {.name: givenName, .attr: Attribute(argDef)}); |
592 | ++numNativeAttributes; |
593 | } else if (argDef->isSubClassOf(R: propertyClass)) { |
594 | if (givenName.empty()) |
595 | PrintFatalError(ErrorLoc: argDef->getLoc(), Msg: "properties must be named"); |
596 | properties.push_back(Elt: {.name: givenName, .prop: Property(argDef)}); |
597 | } else { |
598 | PrintFatalError(ErrorLoc: def.getLoc(), |
599 | Msg: "unexpected def type; only defs deriving " |
600 | "from TypeConstraint or Attr or Property are allowed"); |
601 | } |
602 | if (!givenName.empty()) |
603 | argumentsAndResultsIndex[givenName] = i; |
604 | } |
605 | |
606 | // Handle derived attributes. |
607 | for (const auto &val : def.getValues()) { |
608 | if (auto *record = dyn_cast<llvm::RecordRecTy>(Val: val.getType())) { |
609 | if (!record->isSubClassOf(Class: attrClass)) |
610 | continue; |
611 | if (!record->isSubClassOf(Class: derivedAttrClass)) |
612 | PrintFatalError(ErrorLoc: def.getLoc(), |
613 | Msg: "unexpected Attr where only DerivedAttr is allowed"); |
614 | |
615 | if (record->getClasses().size() != 1) { |
616 | PrintFatalError( |
617 | ErrorLoc: def.getLoc(), |
618 | Msg: "unsupported attribute modelling, only single class expected"); |
619 | } |
620 | attributes.push_back(Elt: {.name: cast<StringInit>(Val: val.getNameInit())->getValue(), |
621 | .attr: Attribute(cast<DefInit>(Val: val.getValue()))}); |
622 | } |
623 | } |
624 | |
625 | // Populate `arguments`. This must happen after we've finalized `operands` and |
626 | // `attributes` because we will put their elements' pointers in `arguments`. |
627 | // SmallVector may perform re-allocation under the hood when adding new |
628 | // elements. |
629 | int operandIndex = 0, attrIndex = 0, propIndex = 0; |
630 | for (unsigned i = 0; i != numArgs; ++i) { |
631 | const Record *argDef = |
632 | dyn_cast<DefInit>(Val: argumentValues->getArg(Num: i))->getDef(); |
633 | if (argDef->isSubClassOf(R: opVarClass)) |
634 | argDef = argDef->getValueAsDef(FieldName: "constraint"); |
635 | |
636 | if (argDef->isSubClassOf(R: typeConstraintClass)) { |
637 | attrOrOperandMapping.push_back( |
638 | Elt: {OperandOrAttribute::Kind::Operand, operandIndex}); |
639 | arguments.emplace_back(Args: &operands[operandIndex++]); |
640 | } else if (argDef->isSubClassOf(R: attrClass)) { |
641 | attrOrOperandMapping.push_back( |
642 | Elt: {OperandOrAttribute::Kind::Attribute, attrIndex}); |
643 | arguments.emplace_back(Args: &attributes[attrIndex++]); |
644 | } else { |
645 | assert(argDef->isSubClassOf(propertyClass)); |
646 | arguments.emplace_back(Args: &properties[propIndex++]); |
647 | } |
648 | } |
649 | |
650 | auto *resultsDag = def.getValueAsDag(FieldName: "results"); |
651 | auto *outsOp = dyn_cast<DefInit>(Val: resultsDag->getOperator()); |
652 | if (!outsOp || outsOp->getDef()->getName() != "outs") { |
653 | PrintFatalError(ErrorLoc: def.getLoc(), Msg: "'results' must have 'outs' directive"); |
654 | } |
655 | |
656 | // Handle results. |
657 | for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) { |
658 | auto name = resultsDag->getArgNameStr(Num: i); |
659 | auto *resultInit = dyn_cast<DefInit>(Val: resultsDag->getArg(Num: i)); |
660 | if (!resultInit) { |
661 | PrintFatalError(ErrorLoc: def.getLoc(), |
662 | Msg: Twine("undefined type for result #") + Twine(i)); |
663 | } |
664 | auto *resultDef = resultInit->getDef(); |
665 | if (resultDef->isSubClassOf(R: opVarClass)) |
666 | resultDef = resultDef->getValueAsDef(FieldName: "constraint"); |
667 | results.push_back(Elt: {.name: name, .constraint: TypeConstraint(resultDef)}); |
668 | if (!name.empty()) |
669 | argumentsAndResultsIndex[name] = InferredResultType::mapResultIndex(i); |
670 | |
671 | // We currently only support VariadicOfVariadic operands. |
672 | if (results.back().constraint.isVariadicOfVariadic()) { |
673 | PrintFatalError( |
674 | ErrorLoc: def.getLoc(), |
675 | Msg: "'VariadicOfVariadic' results are currently not supported"); |
676 | } |
677 | } |
678 | |
679 | // Handle successors |
680 | auto *successorsDag = def.getValueAsDag(FieldName: "successors"); |
681 | auto *successorsOp = dyn_cast<DefInit>(Val: successorsDag->getOperator()); |
682 | if (!successorsOp || successorsOp->getDef()->getName() != "successor") { |
683 | PrintFatalError(ErrorLoc: def.getLoc(), |
684 | Msg: "'successors' must have 'successor' directive"); |
685 | } |
686 | |
687 | for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) { |
688 | auto name = successorsDag->getArgNameStr(Num: i); |
689 | auto *successorInit = dyn_cast<DefInit>(Val: successorsDag->getArg(Num: i)); |
690 | if (!successorInit) { |
691 | PrintFatalError(ErrorLoc: def.getLoc(), |
692 | Msg: Twine("undefined kind for successor #") + Twine(i)); |
693 | } |
694 | Successor successor(successorInit->getDef()); |
695 | |
696 | // Only support variadic successors if it is the last one for now. |
697 | if (i != e - 1 && successor.isVariadic()) |
698 | PrintFatalError(ErrorLoc: def.getLoc(), Msg: "only the last successor can be variadic"); |
699 | successors.push_back(Elt: {.name: name, .constraint: successor}); |
700 | } |
701 | |
702 | // Create list of traits, skipping over duplicates: appending to lists in |
703 | // tablegen is easy, making them unique less so, so dedupe here. |
704 | if (auto *traitList = def.getValueAsListInit(FieldName: "traits")) { |
705 | // This is uniquing based on pointers of the trait. |
706 | SmallPtrSet<const Init *, 32> traitSet; |
707 | traits.reserve(N: traitSet.size()); |
708 | |
709 | // The declaration order of traits imply the verification order of traits. |
710 | // Some traits may require other traits to be verified first then they can |
711 | // do further verification based on those verified facts. If you see this |
712 | // error, fix the traits declaration order by checking the `dependentTraits` |
713 | // field. |
714 | auto verifyTraitValidity = [&](const Record *trait) { |
715 | auto *dependentTraits = trait->getValueAsListInit(FieldName: "dependentTraits"); |
716 | for (auto *traitInit : *dependentTraits) |
717 | if (!traitSet.contains(Ptr: traitInit)) |
718 | PrintFatalError( |
719 | ErrorLoc: def.getLoc(), |
720 | Msg: trait->getValueAsString(FieldName: "trait") + " requires "+ |
721 | cast<DefInit>(Val: traitInit)->getDef()->getValueAsString( |
722 | FieldName: "trait") + |
723 | " to precede it in traits list"); |
724 | }; |
725 | |
726 | std::function<void(const ListInit *)> insert; |
727 | insert = [&](const ListInit *traitList) { |
728 | for (auto *traitInit : *traitList) { |
729 | auto *def = cast<DefInit>(Val: traitInit)->getDef(); |
730 | if (def->isSubClassOf(Name: "TraitList")) { |
731 | insert(def->getValueAsListInit(FieldName: "traits")); |
732 | continue; |
733 | } |
734 | |
735 | // Ignore duplicates. |
736 | if (!traitSet.insert(Ptr: traitInit).second) |
737 | continue; |
738 | |
739 | // If this is an interface with base classes, add the bases to the |
740 | // trait list. |
741 | if (def->isSubClassOf(Name: "Interface")) |
742 | insert(def->getValueAsListInit(FieldName: "baseInterfaces")); |
743 | |
744 | // Verify if the trait has all the dependent traits declared before |
745 | // itself. |
746 | verifyTraitValidity(def); |
747 | traits.push_back(Elt: Trait::create(init: traitInit)); |
748 | } |
749 | }; |
750 | insert(traitList); |
751 | } |
752 | |
753 | populateTypeInferenceInfo(argumentsAndResultsIndex); |
754 | |
755 | // Handle regions |
756 | auto *regionsDag = def.getValueAsDag(FieldName: "regions"); |
757 | auto *regionsOp = dyn_cast<DefInit>(Val: regionsDag->getOperator()); |
758 | if (!regionsOp || regionsOp->getDef()->getName() != "region") { |
759 | PrintFatalError(ErrorLoc: def.getLoc(), Msg: "'regions' must have 'region' directive"); |
760 | } |
761 | |
762 | for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) { |
763 | auto name = regionsDag->getArgNameStr(Num: i); |
764 | auto *regionInit = dyn_cast<DefInit>(Val: regionsDag->getArg(Num: i)); |
765 | if (!regionInit) { |
766 | PrintFatalError(ErrorLoc: def.getLoc(), |
767 | Msg: Twine("undefined kind for region #") + Twine(i)); |
768 | } |
769 | Region region(regionInit->getDef()); |
770 | if (region.isVariadic()) { |
771 | // Only support variadic regions if it is the last one for now. |
772 | if (i != e - 1) |
773 | PrintFatalError(ErrorLoc: def.getLoc(), Msg: "only the last region can be variadic"); |
774 | if (name.empty()) |
775 | PrintFatalError(ErrorLoc: def.getLoc(), Msg: "variadic regions must be named"); |
776 | } |
777 | |
778 | regions.push_back(Elt: {.name: name, .constraint: region}); |
779 | } |
780 | |
781 | // Populate the builders. |
782 | auto *builderList = dyn_cast_or_null<ListInit>(Val: def.getValueInit(FieldName: "builders")); |
783 | if (builderList && !builderList->empty()) { |
784 | for (const Init *init : builderList->getElements()) |
785 | builders.emplace_back(Args: cast<DefInit>(Val: init)->getDef(), Args: def.getLoc()); |
786 | } else if (skipDefaultBuilders()) { |
787 | PrintFatalError( |
788 | ErrorLoc: def.getLoc(), |
789 | Msg: "default builders are skipped and no custom builders provided"); |
790 | } |
791 | |
792 | LLVM_DEBUG(print(llvm::dbgs())); |
793 | } |
794 | |
795 | const InferredResultType &Operator::getInferredResultType(int index) const { |
796 | assert(allResultTypesKnown()); |
797 | return resultTypeMapping[index]; |
798 | } |
799 | |
800 | ArrayRef<SMLoc> Operator::getLoc() const { return def.getLoc(); } |
801 | |
802 | bool Operator::hasDescription() const { |
803 | return !getDescription().trim().empty(); |
804 | } |
805 | |
806 | StringRef Operator::getDescription() const { |
807 | return def.getValueAsString(FieldName: "description"); |
808 | } |
809 | |
810 | bool Operator::hasSummary() const { return !getSummary().trim().empty(); } |
811 | |
812 | StringRef Operator::getSummary() const { |
813 | return def.getValueAsString(FieldName: "summary"); |
814 | } |
815 | |
816 | bool Operator::hasAssemblyFormat() const { |
817 | auto *valueInit = def.getValueInit(FieldName: "assemblyFormat"); |
818 | return isa<StringInit>(Val: valueInit); |
819 | } |
820 | |
821 | StringRef Operator::getAssemblyFormat() const { |
822 | return TypeSwitch<const Init *, StringRef>(def.getValueInit(FieldName: "assemblyFormat")) |
823 | .Case<StringInit>(caseFn: [&](auto *init) { return init->getValue(); }); |
824 | } |
825 | |
826 | void Operator::print(llvm::raw_ostream &os) const { |
827 | os << "op '"<< getOperationName() << "'\n"; |
828 | for (Argument arg : arguments) { |
829 | if (auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: arg)) |
830 | os << "[attribute] "<< attr->name << '\n'; |
831 | else |
832 | os << "[operand] "<< cast<NamedTypeConstraint *>(Val&: arg)->name << '\n'; |
833 | } |
834 | } |
835 | |
836 | auto Operator::VariableDecoratorIterator::unwrap(const Init *init) |
837 | -> VariableDecorator { |
838 | return VariableDecorator(cast<DefInit>(Val: init)->getDef()); |
839 | } |
840 | |
841 | auto Operator::getArgToOperandOrAttribute(int index) const |
842 | -> OperandOrAttribute { |
843 | return attrOrOperandMapping[index]; |
844 | } |
845 | |
846 | std::string Operator::getGetterName(StringRef name) const { |
847 | return "get"+ convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true); |
848 | } |
849 | |
850 | std::string Operator::getSetterName(StringRef name) const { |
851 | return "set"+ convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true); |
852 | } |
853 | |
854 | std::string Operator::getRemoverName(StringRef name) const { |
855 | return "remove"+ convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true); |
856 | } |
857 | |
858 | bool Operator::hasFolder() const { return def.getValueAsBit(FieldName: "hasFolder"); } |
859 | |
860 | bool Operator::useCustomPropertiesEncoding() const { |
861 | return def.getValueAsBit(FieldName: "useCustomPropertiesEncoding"); |
862 | } |
863 |
Definitions
- Operator
- getOperationName
- getAdaptorName
- getGenericAdaptorName
- assertAccessorInvariants
- assertInvariants
- getDialectName
- getCppClassName
- getQualCppClassName
- getCppNamespace
- getNumResults
- getExtraClassDeclaration
- getExtraClassDefinition
- getDef
- skipDefaultBuilders
- result_begin
- result_end
- getResults
- getResultTypeConstraint
- getResultName
- getResultDecorators
- getNumVariableLengthResults
- getNumVariableLengthOperands
- hasSingleVariadicArg
- arg_begin
- arg_end
- getArgs
- getArgName
- getArgDecorators
- getTrait
- region_begin
- region_end
- getRegions
- getNumRegions
- getRegion
- getNumVariadicRegions
- successor_begin
- successor_end
- getSuccessors
- getNumSuccessors
- getSuccessor
- getNumVariadicSuccessors
- trait_begin
- trait_end
- getTraits
- attribute_begin
- attribute_end
- getAttributes
- attribute_begin
- attribute_end
- getAttributes
- operand_begin
- operand_end
- getOperands
- getArg
- isVariadic
- populateTypeInferenceInfo
- populateOpStructure
- getInferredResultType
- getLoc
- hasDescription
- getDescription
- hasSummary
- getSummary
- hasAssemblyFormat
- getAssemblyFormat
- unwrap
- getArgToOperandOrAttribute
- getGetterName
- getSetterName
- getRemoverName
- hasFolder
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more