1//===- Operator.cpp - Operator class --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/TableGen/Operator.h"
14#include "mlir/TableGen/Argument.h"
15#include "mlir/TableGen/Predicate.h"
16#include "mlir/TableGen/Trait.h"
17#include "mlir/TableGen/Type.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/Sequence.h"
20#include "llvm/ADT/SmallPtrSet.h"
21#include "llvm/ADT/StringExtras.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/TableGen/Error.h"
27#include "llvm/TableGen/Record.h"
28
29#define DEBUG_TYPE "mlir-tblgen-operator"
30
31using namespace mlir;
32using namespace mlir::tblgen;
33
34using llvm::DagInit;
35using llvm::DefInit;
36using llvm::Init;
37using llvm::ListInit;
38using llvm::Record;
39using llvm::StringInit;
40
41Operator::Operator(const Record &def)
42 : dialect(def.getValueAsDef(FieldName: "opDialect")), def(def) {
43 // The first `_` in the op's TableGen def name is treated as separating the
44 // dialect prefix and the op class name. The dialect prefix will be ignored if
45 // not empty. Otherwise, if def name starts with a `_`, the `_` is considered
46 // as part of the class name.
47 StringRef prefix;
48 std::tie(args&: prefix, args&: cppClassName) = def.getName().split(Separator: '_');
49 if (prefix.empty()) {
50 // Class name with a leading underscore and without dialect prefix
51 cppClassName = def.getName();
52 } else if (cppClassName.empty()) {
53 // Class name without dialect prefix
54 cppClassName = prefix;
55 }
56
57 cppNamespace = def.getValueAsString(FieldName: "cppNamespace");
58
59 populateOpStructure();
60 assertInvariants();
61}
62
63std::string Operator::getOperationName() const {
64 auto prefix = dialect.getName();
65 auto opName = def.getValueAsString(FieldName: "opName");
66 if (prefix.empty())
67 return std::string(opName);
68 return std::string(llvm::formatv(Fmt: "{0}.{1}", Vals&: prefix, Vals&: opName));
69}
70
71std::string Operator::getAdaptorName() const {
72 return std::string(llvm::formatv(Fmt: "{0}Adaptor", Vals: getCppClassName()));
73}
74
75std::string Operator::getGenericAdaptorName() const {
76 return std::string(llvm::formatv(Fmt: "{0}GenericAdaptor", Vals: getCppClassName()));
77}
78
79/// Assert the invariants of accessors generated for the given name.
80static void assertAccessorInvariants(const Operator &op, StringRef name) {
81 std::string accessorName =
82 convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
83
84 // Functor used to detect when an accessor will cause an overlap with an
85 // operation API.
86 //
87 // There are a little bit more invasive checks possible for cases where not
88 // all ops have the trait that would cause overlap. For many cases here,
89 // renaming would be better (e.g., we can only guard in limited manner
90 // against methods from traits and interfaces here, so avoiding these in op
91 // definition is safer).
92 auto nameOverlapsWithOpAPI = [&](StringRef newName) {
93 if (newName == "AttributeNames" || newName == "Attributes" ||
94 newName == "Operation")
95 return true;
96 if (newName == "Operands")
97 return op.getNumOperands() != 1 || op.getNumVariableLengthOperands() != 1;
98 if (newName == "Regions")
99 return op.getNumRegions() != 1 || op.getNumVariadicRegions() != 1;
100 if (newName == "Type")
101 return op.getNumResults() != 1;
102 return false;
103 };
104 if (nameOverlapsWithOpAPI(accessorName)) {
105 // This error could be avoided in situations where the final function is
106 // identical, but preferably the op definition should avoid using generic
107 // names.
108 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "generated accessor for `" + name +
109 "` overlaps with a default one; please "
110 "rename to avoid overlap");
111 }
112}
113
114void Operator::assertInvariants() const {
115 // Check that the name of arguments/results/regions/successors don't overlap.
116 DenseMap<StringRef, StringRef> existingNames;
117 auto checkName = [&](StringRef name, StringRef entity) {
118 if (name.empty())
119 return;
120 auto insertion = existingNames.insert(KV: {name, entity});
121 if (insertion.second) {
122 // Assert invariants for accessors generated for this name.
123 assertAccessorInvariants(op: *this, name);
124 return;
125 }
126 if (entity == insertion.first->second)
127 PrintFatalError(ErrorLoc: getLoc(), Msg: "op has a conflict with two " + entity +
128 " having the same name '" + name + "'");
129 PrintFatalError(ErrorLoc: getLoc(), Msg: "op has a conflict with " +
130 insertion.first->second + " and " + entity +
131 " both having an entry with the name '" +
132 name + "'");
133 };
134 // Check operands amongst themselves.
135 for (int i : llvm::seq<int>(Begin: 0, End: getNumOperands()))
136 checkName(getOperand(index: i).name, "operands");
137
138 // Check results amongst themselves and against operands.
139 for (int i : llvm::seq<int>(Begin: 0, End: getNumResults()))
140 checkName(getResult(index: i).name, "results");
141
142 // Check regions amongst themselves and against operands and results.
143 for (int i : llvm::seq<int>(Begin: 0, End: getNumRegions()))
144 checkName(getRegion(index: i).name, "regions");
145
146 // Check successors amongst themselves and against operands, results, and
147 // regions.
148 for (int i : llvm::seq<int>(Begin: 0, End: getNumSuccessors()))
149 checkName(getSuccessor(index: i).name, "successors");
150}
151
152StringRef Operator::getDialectName() const { return dialect.getName(); }
153
154StringRef Operator::getCppClassName() const { return cppClassName; }
155
156std::string Operator::getQualCppClassName() const {
157 if (cppNamespace.empty())
158 return std::string(cppClassName);
159 return std::string(llvm::formatv(Fmt: "{0}::{1}", Vals: cppNamespace, Vals: cppClassName));
160}
161
162StringRef Operator::getCppNamespace() const { return cppNamespace; }
163
164int Operator::getNumResults() const {
165 const DagInit *results = def.getValueAsDag(FieldName: "results");
166 return results->getNumArgs();
167}
168
169StringRef Operator::getExtraClassDeclaration() const {
170 constexpr auto attr = "extraClassDeclaration";
171 if (def.isValueUnset(FieldName: attr))
172 return {};
173 return def.getValueAsString(FieldName: attr);
174}
175
176StringRef Operator::getExtraClassDefinition() const {
177 constexpr auto attr = "extraClassDefinition";
178 if (def.isValueUnset(FieldName: attr))
179 return {};
180 return def.getValueAsString(FieldName: attr);
181}
182
183const Record &Operator::getDef() const { return def; }
184
185bool Operator::skipDefaultBuilders() const {
186 return def.getValueAsBit(FieldName: "skipDefaultBuilders");
187}
188
189auto Operator::result_begin() const -> const_value_iterator {
190 return results.begin();
191}
192
193auto Operator::result_end() const -> const_value_iterator {
194 return results.end();
195}
196
197auto Operator::getResults() const -> const_value_range {
198 return {result_begin(), result_end()};
199}
200
201TypeConstraint Operator::getResultTypeConstraint(int index) const {
202 const DagInit *results = def.getValueAsDag(FieldName: "results");
203 return TypeConstraint(cast<DefInit>(Val: results->getArg(Num: index)));
204}
205
206StringRef Operator::getResultName(int index) const {
207 const DagInit *results = def.getValueAsDag(FieldName: "results");
208 return results->getArgNameStr(Num: index);
209}
210
211auto Operator::getResultDecorators(int index) const -> var_decorator_range {
212 const Record *result =
213 cast<DefInit>(Val: def.getValueAsDag(FieldName: "results")->getArg(Num: index))->getDef();
214 if (!result->isSubClassOf(Name: "OpVariable"))
215 return var_decorator_range(nullptr, nullptr);
216 return *result->getValueAsListInit(FieldName: "decorators");
217}
218
219unsigned Operator::getNumVariableLengthResults() const {
220 return llvm::count_if(Range: results, P: [](const NamedTypeConstraint &c) {
221 return c.constraint.isVariableLength();
222 });
223}
224
225unsigned Operator::getNumVariableLengthOperands() const {
226 return llvm::count_if(Range: operands, P: [](const NamedTypeConstraint &c) {
227 return c.constraint.isVariableLength();
228 });
229}
230
231bool Operator::hasSingleVariadicArg() const {
232 return getNumArgs() == 1 && isa<NamedTypeConstraint *>(Val: getArg(index: 0)) &&
233 getOperand(index: 0).isVariadic();
234}
235
236Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); }
237
238Operator::arg_iterator Operator::arg_end() const { return arguments.end(); }
239
240Operator::arg_range Operator::getArgs() const {
241 return {arg_begin(), arg_end()};
242}
243
244StringRef Operator::getArgName(int index) const {
245 const DagInit *argumentValues = def.getValueAsDag(FieldName: "arguments");
246 return argumentValues->getArgNameStr(Num: index);
247}
248
249auto Operator::getArgDecorators(int index) const -> var_decorator_range {
250 const Record *arg =
251 cast<DefInit>(Val: def.getValueAsDag(FieldName: "arguments")->getArg(Num: index))->getDef();
252 if (!arg->isSubClassOf(Name: "OpVariable"))
253 return var_decorator_range(nullptr, nullptr);
254 return *arg->getValueAsListInit(FieldName: "decorators");
255}
256
257const Trait *Operator::getTrait(StringRef trait) const {
258 for (const auto &t : traits) {
259 if (const auto *traitDef = dyn_cast<NativeTrait>(Val: &t)) {
260 if (traitDef->getFullyQualifiedTraitName() == trait)
261 return traitDef;
262 } else if (const auto *traitDef = dyn_cast<InternalTrait>(Val: &t)) {
263 if (traitDef->getFullyQualifiedTraitName() == trait)
264 return traitDef;
265 } else if (const auto *traitDef = dyn_cast<InterfaceTrait>(Val: &t)) {
266 if (traitDef->getFullyQualifiedTraitName() == trait)
267 return traitDef;
268 }
269 }
270 return nullptr;
271}
272
273auto Operator::region_begin() const -> const_region_iterator {
274 return regions.begin();
275}
276auto Operator::region_end() const -> const_region_iterator {
277 return regions.end();
278}
279auto Operator::getRegions() const
280 -> llvm::iterator_range<const_region_iterator> {
281 return {region_begin(), region_end()};
282}
283
284unsigned Operator::getNumRegions() const { return regions.size(); }
285
286const NamedRegion &Operator::getRegion(unsigned index) const {
287 return regions[index];
288}
289
290unsigned Operator::getNumVariadicRegions() const {
291 return llvm::count_if(Range: regions,
292 P: [](const NamedRegion &c) { return c.isVariadic(); });
293}
294
295auto Operator::successor_begin() const -> const_successor_iterator {
296 return successors.begin();
297}
298auto Operator::successor_end() const -> const_successor_iterator {
299 return successors.end();
300}
301auto Operator::getSuccessors() const
302 -> llvm::iterator_range<const_successor_iterator> {
303 return {successor_begin(), successor_end()};
304}
305
306unsigned Operator::getNumSuccessors() const { return successors.size(); }
307
308const NamedSuccessor &Operator::getSuccessor(unsigned index) const {
309 return successors[index];
310}
311
312unsigned Operator::getNumVariadicSuccessors() const {
313 return llvm::count_if(Range: successors,
314 P: [](const NamedSuccessor &c) { return c.isVariadic(); });
315}
316
317auto Operator::trait_begin() const -> const_trait_iterator {
318 return traits.begin();
319}
320auto Operator::trait_end() const -> const_trait_iterator {
321 return traits.end();
322}
323auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> {
324 return {trait_begin(), trait_end()};
325}
326
327auto Operator::attribute_begin() const -> const_attribute_iterator {
328 return attributes.begin();
329}
330auto Operator::attribute_end() const -> const_attribute_iterator {
331 return attributes.end();
332}
333auto Operator::getAttributes() const
334 -> llvm::iterator_range<const_attribute_iterator> {
335 return {attribute_begin(), attribute_end()};
336}
337auto Operator::attribute_begin() -> attribute_iterator {
338 return attributes.begin();
339}
340auto Operator::attribute_end() -> attribute_iterator {
341 return attributes.end();
342}
343auto Operator::getAttributes() -> llvm::iterator_range<attribute_iterator> {
344 return {attribute_begin(), attribute_end()};
345}
346
347auto Operator::operand_begin() const -> const_value_iterator {
348 return operands.begin();
349}
350auto Operator::operand_end() const -> const_value_iterator {
351 return operands.end();
352}
353auto Operator::getOperands() const -> const_value_range {
354 return {operand_begin(), operand_end()};
355}
356
357auto Operator::getArg(int index) const -> Argument { return arguments[index]; }
358
359bool Operator::isVariadic() const {
360 return any_of(Range: llvm::concat<const NamedTypeConstraint>(Ranges: operands, Ranges: results),
361 P: [](const NamedTypeConstraint &op) { return op.isVariadic(); });
362}
363
364void Operator::populateTypeInferenceInfo(
365 const llvm::StringMap<int> &argumentsAndResultsIndex) {
366 // If the type inference op interface is not registered, then do not attempt
367 // to determine if the result types an be inferred.
368 auto &recordKeeper = def.getRecords();
369 auto *inferTrait = recordKeeper.getDef(Name: inferTypeOpInterface);
370 allResultsHaveKnownTypes = false;
371 if (!inferTrait)
372 return;
373
374 // If there are no results, the skip this else the build method generated
375 // overlaps with another autogenerated builder.
376 if (getNumResults() == 0)
377 return;
378
379 // Skip ops with variadic or optional results.
380 if (getNumVariableLengthResults() > 0)
381 return;
382
383 // Skip cases currently being custom generated.
384 // TODO: Remove special cases.
385 if (getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType")) {
386 // Check for a non-variable length operand to use as the type anchor.
387 auto *operandI = llvm::find_if(Range&: arguments, P: [](const Argument &arg) {
388 NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: arg);
389 return operand && !operand->isVariableLength();
390 });
391 if (operandI == arguments.end())
392 return;
393
394 // All result types are inferred from the operand type.
395 int operandIdx = operandI - arguments.begin();
396 for (int i = 0; i < getNumResults(); ++i)
397 resultTypeMapping.emplace_back(Args&: operandIdx, Args: "$_self");
398
399 allResultsHaveKnownTypes = true;
400 traits.push_back(Elt: Trait::create(init: inferTrait->getDefInit()));
401 return;
402 }
403
404 /// This struct represents a node in this operation's result type inferenece
405 /// graph. Each node has a list of incoming type inference edges `sources`.
406 /// Each edge represents a "source" from which the result type can be
407 /// inferred, either an operand (leaf) or another result (node). When a node
408 /// is known to have a fully-inferred type, `inferred` is set to true.
409 struct ResultTypeInference {
410 /// The list of incoming type inference edges.
411 SmallVector<InferredResultType> sources;
412 /// This flag is set to true when the result type is known to be inferrable.
413 bool inferred = false;
414 };
415
416 // This vector represents the type inference graph, with one node for each
417 // operation result. The nth element is the node for the nth result.
418 SmallVector<ResultTypeInference> inference(getNumResults(), {});
419
420 // For all results whose types are buildable, initialize their type inference
421 // nodes with an edge to themselves. Mark those nodes are fully-inferred.
422 for (auto [idx, infer] : llvm::enumerate(First&: inference)) {
423 if (getResult(index: idx).constraint.getBuilderCall()) {
424 infer.sources.emplace_back(Args: InferredResultType::mapResultIndex(i: idx),
425 Args: "$_self");
426 infer.inferred = true;
427 }
428 }
429
430 // Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the
431 // result type inference graph.
432 for (const Trait &trait : traits) {
433 const Record &def = trait.getDef();
434
435 // If the infer type op interface was manually added, then treat it as
436 // intention that the op needs special handling.
437 // TODO: Reconsider whether to always generate, this is more conservative
438 // and keeps existing behavior so starting that way for now.
439 if (def.isSubClassOf(
440 Name: llvm::formatv(Fmt: "{0}::Trait", Vals&: inferTypeOpInterface).str()))
441 return;
442 if (const auto *traitDef = dyn_cast<InterfaceTrait>(Val: &trait))
443 if (&traitDef->getDef() == inferTrait)
444 return;
445
446 // The `TypesMatchWith` trait represents a 1 -> 1 type inference edge with a
447 // type transformer.
448 if (def.isSubClassOf(Name: "TypesMatchWith")) {
449 int target = argumentsAndResultsIndex.lookup(Key: def.getValueAsString(FieldName: "rhs"));
450 // Ignore operand type inference.
451 if (InferredResultType::isArgIndex(i: target))
452 continue;
453 int resultIndex = InferredResultType::unmapResultIndex(i: target);
454 ResultTypeInference &infer = inference[resultIndex];
455 // If the type of the result has already been inferred, do nothing.
456 if (infer.inferred)
457 continue;
458 int sourceIndex =
459 argumentsAndResultsIndex.lookup(Key: def.getValueAsString(FieldName: "lhs"));
460 infer.sources.emplace_back(Args&: sourceIndex,
461 Args: def.getValueAsString(FieldName: "transformer").str());
462 // Locally propagate inferredness.
463 infer.inferred =
464 InferredResultType::isArgIndex(i: sourceIndex) ||
465 inference[InferredResultType::unmapResultIndex(i: sourceIndex)].inferred;
466 continue;
467 }
468
469 // The `ShapedTypeMatchesElementCountAndTypes` trait represents a 1 -> 1
470 // type inference edge where a shaped type matches element count and types
471 // of variadic elements.
472 if (def.isSubClassOf(Name: "ShapedTypeMatchesElementCountAndTypes")) {
473 StringRef shapedArg = def.getValueAsString(FieldName: "shaped");
474 StringRef elementsArg = def.getValueAsString(FieldName: "elements");
475
476 int shapedIndex = argumentsAndResultsIndex.lookup(Key: shapedArg);
477 int elementsIndex = argumentsAndResultsIndex.lookup(Key: elementsArg);
478
479 // Handle result type inference from shaped type to variadic elements.
480 if (InferredResultType::isResultIndex(i: elementsIndex) &&
481 InferredResultType::isArgIndex(i: shapedIndex)) {
482 int resultIndex = InferredResultType::unmapResultIndex(i: elementsIndex);
483 ResultTypeInference &infer = inference[resultIndex];
484 if (!infer.inferred) {
485 infer.sources.emplace_back(
486 Args&: shapedIndex,
487 Args: "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
488 "ShapedType>($_self).getNumElements(), "
489 "::llvm::cast<::mlir::ShapedType>($_self).getElementType())");
490 infer.inferred = true;
491 }
492 }
493
494 // Type inference in the opposite direction is not possible as the actual
495 // shaped type can't be inferred from the variadic elements.
496
497 continue;
498 }
499
500 if (!def.isSubClassOf(Name: "AllTypesMatch"))
501 continue;
502
503 auto values = def.getValueAsListOfStrings(FieldName: "values");
504 // The `AllTypesMatch` trait represents an N <-> N fanin and fanout. That
505 // is, every result type has an edge from every other type. However, if any
506 // one of the values refers to an operand or a result with a fully-inferred
507 // type, we can infer all other types from that value. Try to find a
508 // fully-inferred type in the list.
509 std::optional<int> fullyInferredIndex;
510 SmallVector<int> resultIndices;
511 for (StringRef name : values) {
512 int index = argumentsAndResultsIndex.lookup(Key: name);
513 if (InferredResultType::isResultIndex(i: index))
514 resultIndices.push_back(Elt: InferredResultType::unmapResultIndex(i: index));
515 if (InferredResultType::isArgIndex(i: index) ||
516 inference[InferredResultType::unmapResultIndex(i: index)].inferred)
517 fullyInferredIndex = index;
518 }
519 if (fullyInferredIndex) {
520 // Make the fully-inferred type the only source for all results that
521 // aren't already inferred -- a 1 -> N fanout.
522 for (int resultIndex : resultIndices) {
523 ResultTypeInference &infer = inference[resultIndex];
524 if (!infer.inferred) {
525 infer.sources.assign(NumElts: 1, Elt: {*fullyInferredIndex, "$_self"});
526 infer.inferred = true;
527 }
528 }
529 } else {
530 // Add an edge between every result and every other type; N <-> N.
531 for (int resultIndex : resultIndices) {
532 for (int otherResultIndex : resultIndices) {
533 if (resultIndex == otherResultIndex)
534 continue;
535 inference[resultIndex].sources.emplace_back(
536 Args: InferredResultType::unmapResultIndex(i: otherResultIndex), Args: "$_self");
537 }
538 }
539 }
540 }
541
542 // Propagate inferredness until a fixed point.
543 std::vector<ResultTypeInference *> worklist;
544 for (ResultTypeInference &infer : inference)
545 if (!infer.inferred)
546 worklist.push_back(x: &infer);
547 bool changed;
548 do {
549 changed = false;
550 for (auto cur = worklist.begin(); cur != worklist.end();) {
551 ResultTypeInference &infer = **cur;
552
553 InferredResultType *iter =
554 llvm::find_if(Range&: infer.sources, P: [&](const InferredResultType &source) {
555 assert(InferredResultType::isResultIndex(source.getIndex()));
556 return inference[InferredResultType::unmapResultIndex(
557 i: source.getIndex())]
558 .inferred;
559 });
560 if (iter == infer.sources.end()) {
561 ++cur;
562 continue;
563 }
564
565 changed = true;
566 infer.inferred = true;
567 // Make this the only source for the result. This breaks any cycles.
568 infer.sources.assign(NumElts: 1, Elt: *iter);
569 cur = worklist.erase(position: cur);
570 }
571 } while (changed);
572
573 allResultsHaveKnownTypes = worklist.empty();
574
575 // If the types could be computed, then add type inference trait.
576 if (allResultsHaveKnownTypes) {
577 traits.push_back(Elt: Trait::create(init: inferTrait->getDefInit()));
578 for (const ResultTypeInference &infer : inference)
579 resultTypeMapping.push_back(Elt: infer.sources.front());
580 }
581}
582
583void Operator::populateOpStructure() {
584 auto &recordKeeper = def.getRecords();
585 auto *typeConstraintClass = recordKeeper.getClass(Name: "TypeConstraint");
586 auto *attrClass = recordKeeper.getClass(Name: "Attr");
587 auto *propertyClass = recordKeeper.getClass(Name: "Property");
588 auto *derivedAttrClass = recordKeeper.getClass(Name: "DerivedAttr");
589 auto *opVarClass = recordKeeper.getClass(Name: "OpVariable");
590 numNativeAttributes = 0;
591
592 const DagInit *argumentValues = def.getValueAsDag(FieldName: "arguments");
593 unsigned numArgs = argumentValues->getNumArgs();
594
595 // Mapping from name of to argument or result index. Arguments are indexed
596 // to match getArg index, while the results are negatively indexed.
597 llvm::StringMap<int> argumentsAndResultsIndex;
598
599 // Handle operands and native attributes.
600 for (unsigned i = 0; i != numArgs; ++i) {
601 auto *arg = argumentValues->getArg(Num: i);
602 auto givenName = argumentValues->getArgNameStr(Num: i);
603 auto *argDefInit = dyn_cast<DefInit>(Val: arg);
604 if (!argDefInit)
605 PrintFatalError(ErrorLoc: def.getLoc(),
606 Msg: Twine("undefined type for argument #") + Twine(i));
607 const Record *argDef = argDefInit->getDef();
608 if (argDef->isSubClassOf(R: opVarClass))
609 argDef = argDef->getValueAsDef(FieldName: "constraint");
610
611 if (argDef->isSubClassOf(R: typeConstraintClass)) {
612 operands.push_back(
613 Elt: NamedTypeConstraint{.name: givenName, .constraint: TypeConstraint(argDef)});
614 } else if (argDef->isSubClassOf(R: attrClass)) {
615 if (givenName.empty())
616 PrintFatalError(ErrorLoc: argDef->getLoc(), Msg: "attributes must be named");
617 if (argDef->isSubClassOf(R: derivedAttrClass))
618 PrintFatalError(ErrorLoc: argDef->getLoc(),
619 Msg: "derived attributes not allowed in argument list");
620 attributes.push_back(Elt: {.name: givenName, .attr: Attribute(argDef)});
621 ++numNativeAttributes;
622 } else if (argDef->isSubClassOf(R: propertyClass)) {
623 if (givenName.empty())
624 PrintFatalError(ErrorLoc: argDef->getLoc(), Msg: "properties must be named");
625 properties.push_back(Elt: {.name: givenName, .prop: Property(argDef)});
626 } else {
627 PrintFatalError(ErrorLoc: def.getLoc(),
628 Msg: "unexpected def type; only defs deriving "
629 "from TypeConstraint or Attr or Property are allowed");
630 }
631 if (!givenName.empty())
632 argumentsAndResultsIndex[givenName] = i;
633 }
634
635 // Handle derived attributes.
636 for (const auto &val : def.getValues()) {
637 if (auto *record = dyn_cast<llvm::RecordRecTy>(Val: val.getType())) {
638 if (!record->isSubClassOf(Class: attrClass))
639 continue;
640 if (!record->isSubClassOf(Class: derivedAttrClass))
641 PrintFatalError(ErrorLoc: def.getLoc(),
642 Msg: "unexpected Attr where only DerivedAttr is allowed");
643
644 if (record->getClasses().size() != 1) {
645 PrintFatalError(
646 ErrorLoc: def.getLoc(),
647 Msg: "unsupported attribute modelling, only single class expected");
648 }
649 attributes.push_back(Elt: {.name: cast<StringInit>(Val: val.getNameInit())->getValue(),
650 .attr: Attribute(cast<DefInit>(Val: val.getValue()))});
651 }
652 }
653
654 // Populate `arguments`. This must happen after we've finalized `operands` and
655 // `attributes` because we will put their elements' pointers in `arguments`.
656 // SmallVector may perform re-allocation under the hood when adding new
657 // elements.
658 int operandIndex = 0, attrIndex = 0, propIndex = 0;
659 for (unsigned i = 0; i != numArgs; ++i) {
660 const Record *argDef =
661 dyn_cast<DefInit>(Val: argumentValues->getArg(Num: i))->getDef();
662 if (argDef->isSubClassOf(R: opVarClass))
663 argDef = argDef->getValueAsDef(FieldName: "constraint");
664
665 if (argDef->isSubClassOf(R: typeConstraintClass)) {
666 attrOrOperandMapping.push_back(
667 Elt: {OperandOrAttribute::Kind::Operand, operandIndex});
668 arguments.emplace_back(Args: &operands[operandIndex++]);
669 } else if (argDef->isSubClassOf(R: attrClass)) {
670 attrOrOperandMapping.push_back(
671 Elt: {OperandOrAttribute::Kind::Attribute, attrIndex});
672 arguments.emplace_back(Args: &attributes[attrIndex++]);
673 } else {
674 assert(argDef->isSubClassOf(propertyClass));
675 arguments.emplace_back(Args: &properties[propIndex++]);
676 }
677 }
678
679 auto *resultsDag = def.getValueAsDag(FieldName: "results");
680 auto *outsOp = dyn_cast<DefInit>(Val: resultsDag->getOperator());
681 if (!outsOp || outsOp->getDef()->getName() != "outs") {
682 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "'results' must have 'outs' directive");
683 }
684
685 // Handle results.
686 for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
687 auto name = resultsDag->getArgNameStr(Num: i);
688 auto *resultInit = dyn_cast<DefInit>(Val: resultsDag->getArg(Num: i));
689 if (!resultInit) {
690 PrintFatalError(ErrorLoc: def.getLoc(),
691 Msg: Twine("undefined type for result #") + Twine(i));
692 }
693 auto *resultDef = resultInit->getDef();
694 if (resultDef->isSubClassOf(R: opVarClass))
695 resultDef = resultDef->getValueAsDef(FieldName: "constraint");
696 results.push_back(Elt: {.name: name, .constraint: TypeConstraint(resultDef)});
697 if (!name.empty())
698 argumentsAndResultsIndex[name] = InferredResultType::mapResultIndex(i);
699
700 // We currently only support VariadicOfVariadic operands.
701 if (results.back().constraint.isVariadicOfVariadic()) {
702 PrintFatalError(
703 ErrorLoc: def.getLoc(),
704 Msg: "'VariadicOfVariadic' results are currently not supported");
705 }
706 }
707
708 // Handle successors
709 auto *successorsDag = def.getValueAsDag(FieldName: "successors");
710 auto *successorsOp = dyn_cast<DefInit>(Val: successorsDag->getOperator());
711 if (!successorsOp || successorsOp->getDef()->getName() != "successor") {
712 PrintFatalError(ErrorLoc: def.getLoc(),
713 Msg: "'successors' must have 'successor' directive");
714 }
715
716 for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
717 auto name = successorsDag->getArgNameStr(Num: i);
718 auto *successorInit = dyn_cast<DefInit>(Val: successorsDag->getArg(Num: i));
719 if (!successorInit) {
720 PrintFatalError(ErrorLoc: def.getLoc(),
721 Msg: Twine("undefined kind for successor #") + Twine(i));
722 }
723 Successor successor(successorInit->getDef());
724
725 // Only support variadic successors if it is the last one for now.
726 if (i != e - 1 && successor.isVariadic())
727 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "only the last successor can be variadic");
728 successors.push_back(Elt: {.name: name, .constraint: successor});
729 }
730
731 // Create list of traits, skipping over duplicates: appending to lists in
732 // tablegen is easy, making them unique less so, so dedupe here.
733 if (auto *traitList = def.getValueAsListInit(FieldName: "traits")) {
734 // This is uniquing based on pointers of the trait.
735 SmallPtrSet<const Init *, 32> traitSet;
736 traits.reserve(N: traitSet.size());
737
738 // The declaration order of traits imply the verification order of traits.
739 // Some traits may require other traits to be verified first then they can
740 // do further verification based on those verified facts. If you see this
741 // error, fix the traits declaration order by checking the `dependentTraits`
742 // field.
743 auto verifyTraitValidity = [&](const Record *trait) {
744 auto *dependentTraits = trait->getValueAsListInit(FieldName: "dependentTraits");
745 for (auto *traitInit : *dependentTraits)
746 if (!traitSet.contains(Ptr: traitInit))
747 PrintFatalError(
748 ErrorLoc: def.getLoc(),
749 Msg: trait->getValueAsString(FieldName: "trait") + " requires " +
750 cast<DefInit>(Val: traitInit)->getDef()->getValueAsString(
751 FieldName: "trait") +
752 " to precede it in traits list");
753 };
754
755 std::function<void(const ListInit *)> insert;
756 insert = [&](const ListInit *traitList) {
757 for (auto *traitInit : *traitList) {
758 auto *def = cast<DefInit>(Val: traitInit)->getDef();
759 if (def->isSubClassOf(Name: "TraitList")) {
760 insert(def->getValueAsListInit(FieldName: "traits"));
761 continue;
762 }
763
764 // Ignore duplicates.
765 if (!traitSet.insert(Ptr: traitInit).second)
766 continue;
767
768 // If this is an interface with base classes, add the bases to the
769 // trait list.
770 if (def->isSubClassOf(Name: "Interface"))
771 insert(def->getValueAsListInit(FieldName: "baseInterfaces"));
772
773 // Verify if the trait has all the dependent traits declared before
774 // itself.
775 verifyTraitValidity(def);
776 traits.push_back(Elt: Trait::create(init: traitInit));
777 }
778 };
779 insert(traitList);
780 }
781
782 populateTypeInferenceInfo(argumentsAndResultsIndex);
783
784 // Handle regions
785 auto *regionsDag = def.getValueAsDag(FieldName: "regions");
786 auto *regionsOp = dyn_cast<DefInit>(Val: regionsDag->getOperator());
787 if (!regionsOp || regionsOp->getDef()->getName() != "region") {
788 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "'regions' must have 'region' directive");
789 }
790
791 for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
792 auto name = regionsDag->getArgNameStr(Num: i);
793 auto *regionInit = dyn_cast<DefInit>(Val: regionsDag->getArg(Num: i));
794 if (!regionInit) {
795 PrintFatalError(ErrorLoc: def.getLoc(),
796 Msg: Twine("undefined kind for region #") + Twine(i));
797 }
798 Region region(regionInit->getDef());
799 if (region.isVariadic()) {
800 // Only support variadic regions if it is the last one for now.
801 if (i != e - 1)
802 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "only the last region can be variadic");
803 if (name.empty())
804 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "variadic regions must be named");
805 }
806
807 regions.push_back(Elt: {.name: name, .constraint: region});
808 }
809
810 // Populate the builders.
811 auto *builderList = dyn_cast_or_null<ListInit>(Val: def.getValueInit(FieldName: "builders"));
812 if (builderList && !builderList->empty()) {
813 for (const Init *init : builderList->getElements())
814 builders.emplace_back(Args: cast<DefInit>(Val: init)->getDef(), Args: def.getLoc());
815 } else if (skipDefaultBuilders()) {
816 PrintFatalError(
817 ErrorLoc: def.getLoc(),
818 Msg: "default builders are skipped and no custom builders provided");
819 }
820
821 LLVM_DEBUG(print(llvm::dbgs()));
822}
823
824const InferredResultType &Operator::getInferredResultType(int index) const {
825 assert(allResultTypesKnown());
826 return resultTypeMapping[index];
827}
828
829ArrayRef<SMLoc> Operator::getLoc() const { return def.getLoc(); }
830
831bool Operator::hasDescription() const {
832 return !getDescription().trim().empty();
833}
834
835StringRef Operator::getDescription() const {
836 return def.getValueAsString(FieldName: "description");
837}
838
839bool Operator::hasSummary() const { return !getSummary().trim().empty(); }
840
841StringRef Operator::getSummary() const {
842 return def.getValueAsString(FieldName: "summary");
843}
844
845bool Operator::hasAssemblyFormat() const {
846 auto *valueInit = def.getValueInit(FieldName: "assemblyFormat");
847 return isa<StringInit>(Val: valueInit);
848}
849
850StringRef Operator::getAssemblyFormat() const {
851 return TypeSwitch<const Init *, StringRef>(def.getValueInit(FieldName: "assemblyFormat"))
852 .Case<StringInit>(caseFn: [&](auto *init) { return init->getValue(); });
853}
854
855void Operator::print(llvm::raw_ostream &os) const {
856 os << "op '" << getOperationName() << "'\n";
857 for (Argument arg : arguments) {
858 if (auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: arg))
859 os << "[attribute] " << attr->name << '\n';
860 else
861 os << "[operand] " << cast<NamedTypeConstraint *>(Val&: arg)->name << '\n';
862 }
863}
864
865auto Operator::VariableDecoratorIterator::unwrap(const Init *init)
866 -> VariableDecorator {
867 return VariableDecorator(cast<DefInit>(Val: init)->getDef());
868}
869
870auto Operator::getArgToOperandOrAttribute(int index) const
871 -> OperandOrAttribute {
872 return attrOrOperandMapping[index];
873}
874
875std::string Operator::getGetterName(StringRef name) const {
876 return "get" + convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
877}
878
879std::string Operator::getSetterName(StringRef name) const {
880 return "set" + convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
881}
882
883std::string Operator::getRemoverName(StringRef name) const {
884 return "remove" + convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
885}
886
887bool Operator::hasFolder() const { return def.getValueAsBit(FieldName: "hasFolder"); }
888
889bool Operator::useCustomPropertiesEncoding() const {
890 return def.getValueAsBit(FieldName: "useCustomPropertiesEncoding");
891}
892

source code of mlir/lib/TableGen/Operator.cpp