1//===- Operator.cpp - Operator class --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/TableGen/Operator.h"
14#include "mlir/TableGen/Argument.h"
15#include "mlir/TableGen/Predicate.h"
16#include "mlir/TableGen/Trait.h"
17#include "mlir/TableGen/Type.h"
18#include "llvm/ADT/EquivalenceClasses.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/Sequence.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/ErrorHandling.h"
26#include "llvm/Support/FormatVariadic.h"
27#include "llvm/TableGen/Error.h"
28#include "llvm/TableGen/Record.h"
29#include <list>
30
31#define DEBUG_TYPE "mlir-tblgen-operator"
32
33using namespace mlir;
34using namespace mlir::tblgen;
35
36using llvm::DagInit;
37using llvm::DefInit;
38using llvm::Record;
39
40Operator::Operator(const llvm::Record &def)
41 : dialect(def.getValueAsDef(FieldName: "opDialect")), def(def) {
42 // The first `_` in the op's TableGen def name is treated as separating the
43 // dialect prefix and the op class name. The dialect prefix will be ignored if
44 // not empty. Otherwise, if def name starts with a `_`, the `_` is considered
45 // as part of the class name.
46 StringRef prefix;
47 std::tie(args&: prefix, args&: cppClassName) = def.getName().split(Separator: '_');
48 if (prefix.empty()) {
49 // Class name with a leading underscore and without dialect prefix
50 cppClassName = def.getName();
51 } else if (cppClassName.empty()) {
52 // Class name without dialect prefix
53 cppClassName = prefix;
54 }
55
56 cppNamespace = def.getValueAsString(FieldName: "cppNamespace");
57
58 populateOpStructure();
59 assertInvariants();
60}
61
62std::string Operator::getOperationName() const {
63 auto prefix = dialect.getName();
64 auto opName = def.getValueAsString(FieldName: "opName");
65 if (prefix.empty())
66 return std::string(opName);
67 return std::string(llvm::formatv(Fmt: "{0}.{1}", Vals&: prefix, Vals&: opName));
68}
69
70std::string Operator::getAdaptorName() const {
71 return std::string(llvm::formatv(Fmt: "{0}Adaptor", Vals: getCppClassName()));
72}
73
74std::string Operator::getGenericAdaptorName() const {
75 return std::string(llvm::formatv(Fmt: "{0}GenericAdaptor", Vals: getCppClassName()));
76}
77
78/// Assert the invariants of accessors generated for the given name.
79static void assertAccessorInvariants(const Operator &op, StringRef name) {
80 std::string accessorName =
81 convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
82
83 // Functor used to detect when an accessor will cause an overlap with an
84 // operation API.
85 //
86 // There are a little bit more invasive checks possible for cases where not
87 // all ops have the trait that would cause overlap. For many cases here,
88 // renaming would be better (e.g., we can only guard in limited manner
89 // against methods from traits and interfaces here, so avoiding these in op
90 // definition is safer).
91 auto nameOverlapsWithOpAPI = [&](StringRef newName) {
92 if (newName == "AttributeNames" || newName == "Attributes" ||
93 newName == "Operation")
94 return true;
95 if (newName == "Operands")
96 return op.getNumOperands() != 1 || op.getNumVariableLengthOperands() != 1;
97 if (newName == "Regions")
98 return op.getNumRegions() != 1 || op.getNumVariadicRegions() != 1;
99 if (newName == "Type")
100 return op.getNumResults() != 1;
101 return false;
102 };
103 if (nameOverlapsWithOpAPI(accessorName)) {
104 // This error could be avoided in situations where the final function is
105 // identical, but preferably the op definition should avoid using generic
106 // names.
107 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "generated accessor for `" + name +
108 "` overlaps with a default one; please "
109 "rename to avoid overlap");
110 }
111}
112
113void Operator::assertInvariants() const {
114 // Check that the name of arguments/results/regions/successors don't overlap.
115 DenseMap<StringRef, StringRef> existingNames;
116 auto checkName = [&](StringRef name, StringRef entity) {
117 if (name.empty())
118 return;
119 auto insertion = existingNames.insert(KV: {name, entity});
120 if (insertion.second) {
121 // Assert invariants for accessors generated for this name.
122 assertAccessorInvariants(op: *this, name);
123 return;
124 }
125 if (entity == insertion.first->second)
126 PrintFatalError(ErrorLoc: getLoc(), Msg: "op has a conflict with two " + entity +
127 " having the same name '" + name + "'");
128 PrintFatalError(ErrorLoc: getLoc(), Msg: "op has a conflict with " +
129 insertion.first->second + " and " + entity +
130 " both having an entry with the name '" +
131 name + "'");
132 };
133 // Check operands amongst themselves.
134 for (int i : llvm::seq<int>(Begin: 0, End: getNumOperands()))
135 checkName(getOperand(index: i).name, "operands");
136
137 // Check results amongst themselves and against operands.
138 for (int i : llvm::seq<int>(Begin: 0, End: getNumResults()))
139 checkName(getResult(index: i).name, "results");
140
141 // Check regions amongst themselves and against operands and results.
142 for (int i : llvm::seq<int>(Begin: 0, End: getNumRegions()))
143 checkName(getRegion(index: i).name, "regions");
144
145 // Check successors amongst themselves and against operands, results, and
146 // regions.
147 for (int i : llvm::seq<int>(Begin: 0, End: getNumSuccessors()))
148 checkName(getSuccessor(index: i).name, "successors");
149}
150
151StringRef Operator::getDialectName() const { return dialect.getName(); }
152
153StringRef Operator::getCppClassName() const { return cppClassName; }
154
155std::string Operator::getQualCppClassName() const {
156 if (cppNamespace.empty())
157 return std::string(cppClassName);
158 return std::string(llvm::formatv(Fmt: "{0}::{1}", Vals: cppNamespace, Vals: cppClassName));
159}
160
161StringRef Operator::getCppNamespace() const { return cppNamespace; }
162
163int Operator::getNumResults() const {
164 DagInit *results = def.getValueAsDag(FieldName: "results");
165 return results->getNumArgs();
166}
167
168StringRef Operator::getExtraClassDeclaration() const {
169 constexpr auto attr = "extraClassDeclaration";
170 if (def.isValueUnset(FieldName: attr))
171 return {};
172 return def.getValueAsString(FieldName: attr);
173}
174
175StringRef Operator::getExtraClassDefinition() const {
176 constexpr auto attr = "extraClassDefinition";
177 if (def.isValueUnset(FieldName: attr))
178 return {};
179 return def.getValueAsString(FieldName: attr);
180}
181
182const llvm::Record &Operator::getDef() const { return def; }
183
184bool Operator::skipDefaultBuilders() const {
185 return def.getValueAsBit(FieldName: "skipDefaultBuilders");
186}
187
188auto Operator::result_begin() const -> const_value_iterator {
189 return results.begin();
190}
191
192auto Operator::result_end() const -> const_value_iterator {
193 return results.end();
194}
195
196auto Operator::getResults() const -> const_value_range {
197 return {result_begin(), result_end()};
198}
199
200TypeConstraint Operator::getResultTypeConstraint(int index) const {
201 DagInit *results = def.getValueAsDag(FieldName: "results");
202 return TypeConstraint(cast<DefInit>(Val: results->getArg(Num: index)));
203}
204
205StringRef Operator::getResultName(int index) const {
206 DagInit *results = def.getValueAsDag(FieldName: "results");
207 return results->getArgNameStr(Num: index);
208}
209
210auto Operator::getResultDecorators(int index) const -> var_decorator_range {
211 Record *result =
212 cast<DefInit>(Val: def.getValueAsDag(FieldName: "results")->getArg(Num: index))->getDef();
213 if (!result->isSubClassOf(Name: "OpVariable"))
214 return var_decorator_range(nullptr, nullptr);
215 return *result->getValueAsListInit(FieldName: "decorators");
216}
217
218unsigned Operator::getNumVariableLengthResults() const {
219 return llvm::count_if(Range: results, P: [](const NamedTypeConstraint &c) {
220 return c.constraint.isVariableLength();
221 });
222}
223
224unsigned Operator::getNumVariableLengthOperands() const {
225 return llvm::count_if(Range: operands, P: [](const NamedTypeConstraint &c) {
226 return c.constraint.isVariableLength();
227 });
228}
229
230bool Operator::hasSingleVariadicArg() const {
231 return getNumArgs() == 1 && getArg(index: 0).is<NamedTypeConstraint *>() &&
232 getOperand(index: 0).isVariadic();
233}
234
235Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); }
236
237Operator::arg_iterator Operator::arg_end() const { return arguments.end(); }
238
239Operator::arg_range Operator::getArgs() const {
240 return {arg_begin(), arg_end()};
241}
242
243StringRef Operator::getArgName(int index) const {
244 DagInit *argumentValues = def.getValueAsDag(FieldName: "arguments");
245 return argumentValues->getArgNameStr(Num: index);
246}
247
248auto Operator::getArgDecorators(int index) const -> var_decorator_range {
249 Record *arg =
250 cast<DefInit>(Val: def.getValueAsDag(FieldName: "arguments")->getArg(Num: index))->getDef();
251 if (!arg->isSubClassOf(Name: "OpVariable"))
252 return var_decorator_range(nullptr, nullptr);
253 return *arg->getValueAsListInit(FieldName: "decorators");
254}
255
256const Trait *Operator::getTrait(StringRef trait) const {
257 for (const auto &t : traits) {
258 if (const auto *traitDef = dyn_cast<NativeTrait>(Val: &t)) {
259 if (traitDef->getFullyQualifiedTraitName() == trait)
260 return traitDef;
261 } else if (const auto *traitDef = dyn_cast<InternalTrait>(Val: &t)) {
262 if (traitDef->getFullyQualifiedTraitName() == trait)
263 return traitDef;
264 } else if (const auto *traitDef = dyn_cast<InterfaceTrait>(Val: &t)) {
265 if (traitDef->getFullyQualifiedTraitName() == trait)
266 return traitDef;
267 }
268 }
269 return nullptr;
270}
271
272auto Operator::region_begin() const -> const_region_iterator {
273 return regions.begin();
274}
275auto Operator::region_end() const -> const_region_iterator {
276 return regions.end();
277}
278auto Operator::getRegions() const
279 -> llvm::iterator_range<const_region_iterator> {
280 return {region_begin(), region_end()};
281}
282
283unsigned Operator::getNumRegions() const { return regions.size(); }
284
285const NamedRegion &Operator::getRegion(unsigned index) const {
286 return regions[index];
287}
288
289unsigned Operator::getNumVariadicRegions() const {
290 return llvm::count_if(Range: regions,
291 P: [](const NamedRegion &c) { return c.isVariadic(); });
292}
293
294auto Operator::successor_begin() const -> const_successor_iterator {
295 return successors.begin();
296}
297auto Operator::successor_end() const -> const_successor_iterator {
298 return successors.end();
299}
300auto Operator::getSuccessors() const
301 -> llvm::iterator_range<const_successor_iterator> {
302 return {successor_begin(), successor_end()};
303}
304
305unsigned Operator::getNumSuccessors() const { return successors.size(); }
306
307const NamedSuccessor &Operator::getSuccessor(unsigned index) const {
308 return successors[index];
309}
310
311unsigned Operator::getNumVariadicSuccessors() const {
312 return llvm::count_if(Range: successors,
313 P: [](const NamedSuccessor &c) { return c.isVariadic(); });
314}
315
316auto Operator::trait_begin() const -> const_trait_iterator {
317 return traits.begin();
318}
319auto Operator::trait_end() const -> const_trait_iterator {
320 return traits.end();
321}
322auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> {
323 return {trait_begin(), trait_end()};
324}
325
326auto Operator::attribute_begin() const -> const_attribute_iterator {
327 return attributes.begin();
328}
329auto Operator::attribute_end() const -> const_attribute_iterator {
330 return attributes.end();
331}
332auto Operator::getAttributes() const
333 -> llvm::iterator_range<const_attribute_iterator> {
334 return {attribute_begin(), attribute_end()};
335}
336auto Operator::attribute_begin() -> attribute_iterator {
337 return attributes.begin();
338}
339auto Operator::attribute_end() -> attribute_iterator {
340 return attributes.end();
341}
342auto Operator::getAttributes() -> llvm::iterator_range<attribute_iterator> {
343 return {attribute_begin(), attribute_end()};
344}
345
346auto Operator::operand_begin() const -> const_value_iterator {
347 return operands.begin();
348}
349auto Operator::operand_end() const -> const_value_iterator {
350 return operands.end();
351}
352auto Operator::getOperands() const -> const_value_range {
353 return {operand_begin(), operand_end()};
354}
355
356auto Operator::getArg(int index) const -> Argument { return arguments[index]; }
357
358bool Operator::isVariadic() const {
359 return any_of(Range: llvm::concat<const NamedTypeConstraint>(Ranges: operands, Ranges: results),
360 P: [](const NamedTypeConstraint &op) { return op.isVariadic(); });
361}
362
363void Operator::populateTypeInferenceInfo(
364 const llvm::StringMap<int> &argumentsAndResultsIndex) {
365 // If the type inference op interface is not registered, then do not attempt
366 // to determine if the result types an be inferred.
367 auto &recordKeeper = def.getRecords();
368 auto *inferTrait = recordKeeper.getDef(Name: inferTypeOpInterface);
369 allResultsHaveKnownTypes = false;
370 if (!inferTrait)
371 return;
372
373 // If there are no results, the skip this else the build method generated
374 // overlaps with another autogenerated builder.
375 if (getNumResults() == 0)
376 return;
377
378 // Skip ops with variadic or optional results.
379 if (getNumVariableLengthResults() > 0)
380 return;
381
382 // Skip cases currently being custom generated.
383 // TODO: Remove special cases.
384 if (getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType")) {
385 // Check for a non-variable length operand to use as the type anchor.
386 auto *operandI = llvm::find_if(Range&: arguments, P: [](const Argument &arg) {
387 NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: arg);
388 return operand && !operand->isVariableLength();
389 });
390 if (operandI == arguments.end())
391 return;
392
393 // All result types are inferred from the operand type.
394 int operandIdx = operandI - arguments.begin();
395 for (int i = 0; i < getNumResults(); ++i)
396 resultTypeMapping.emplace_back(Args&: operandIdx, Args: "$_self");
397
398 allResultsHaveKnownTypes = true;
399 traits.push_back(Elt: Trait::create(init: inferTrait->getDefInit()));
400 return;
401 }
402
403 /// This struct represents a node in this operation's result type inferenece
404 /// graph. Each node has a list of incoming type inference edges `sources`.
405 /// Each edge represents a "source" from which the result type can be
406 /// inferred, either an operand (leaf) or another result (node). When a node
407 /// is known to have a fully-inferred type, `inferred` is set to true.
408 struct ResultTypeInference {
409 /// The list of incoming type inference edges.
410 SmallVector<InferredResultType> sources;
411 /// This flag is set to true when the result type is known to be inferrable.
412 bool inferred = false;
413 };
414
415 // This vector represents the type inference graph, with one node for each
416 // operation result. The nth element is the node for the nth result.
417 SmallVector<ResultTypeInference> inference(getNumResults(), {});
418
419 // For all results whose types are buildable, initialize their type inference
420 // nodes with an edge to themselves. Mark those nodes are fully-inferred.
421 for (auto [idx, infer] : llvm::enumerate(First&: inference)) {
422 if (getResult(index: idx).constraint.getBuilderCall()) {
423 infer.sources.emplace_back(Args: InferredResultType::mapResultIndex(i: idx),
424 Args: "$_self");
425 infer.inferred = true;
426 }
427 }
428
429 // Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the
430 // result type inference graph.
431 for (const Trait &trait : traits) {
432 const llvm::Record &def = trait.getDef();
433
434 // If the infer type op interface was manually added, then treat it as
435 // intention that the op needs special handling.
436 // TODO: Reconsider whether to always generate, this is more conservative
437 // and keeps existing behavior so starting that way for now.
438 if (def.isSubClassOf(
439 Name: llvm::formatv(Fmt: "{0}::Trait", Vals&: inferTypeOpInterface).str()))
440 return;
441 if (const auto *traitDef = dyn_cast<InterfaceTrait>(Val: &trait))
442 if (&traitDef->getDef() == inferTrait)
443 return;
444
445 // The `TypesMatchWith` trait represents a 1 -> 1 type inference edge with a
446 // type transformer.
447 if (def.isSubClassOf(Name: "TypesMatchWith")) {
448 int target = argumentsAndResultsIndex.lookup(Key: def.getValueAsString(FieldName: "rhs"));
449 // Ignore operand type inference.
450 if (InferredResultType::isArgIndex(i: target))
451 continue;
452 int resultIndex = InferredResultType::unmapResultIndex(i: target);
453 ResultTypeInference &infer = inference[resultIndex];
454 // If the type of the result has already been inferred, do nothing.
455 if (infer.inferred)
456 continue;
457 int sourceIndex =
458 argumentsAndResultsIndex.lookup(Key: def.getValueAsString(FieldName: "lhs"));
459 infer.sources.emplace_back(Args&: sourceIndex,
460 Args: def.getValueAsString(FieldName: "transformer").str());
461 // Locally propagate inferredness.
462 infer.inferred =
463 InferredResultType::isArgIndex(i: sourceIndex) ||
464 inference[InferredResultType::unmapResultIndex(i: sourceIndex)].inferred;
465 continue;
466 }
467
468 if (!def.isSubClassOf(Name: "AllTypesMatch"))
469 continue;
470
471 auto values = def.getValueAsListOfStrings(FieldName: "values");
472 // The `AllTypesMatch` trait represents an N <-> N fanin and fanout. That
473 // is, every result type has an edge from every other type. However, if any
474 // one of the values refers to an operand or a result with a fully-inferred
475 // type, we can infer all other types from that value. Try to find a
476 // fully-inferred type in the list.
477 std::optional<int> fullyInferredIndex;
478 SmallVector<int> resultIndices;
479 for (StringRef name : values) {
480 int index = argumentsAndResultsIndex.lookup(Key: name);
481 if (InferredResultType::isResultIndex(i: index))
482 resultIndices.push_back(Elt: InferredResultType::unmapResultIndex(i: index));
483 if (InferredResultType::isArgIndex(i: index) ||
484 inference[InferredResultType::unmapResultIndex(i: index)].inferred)
485 fullyInferredIndex = index;
486 }
487 if (fullyInferredIndex) {
488 // Make the fully-inferred type the only source for all results that
489 // aren't already inferred -- a 1 -> N fanout.
490 for (int resultIndex : resultIndices) {
491 ResultTypeInference &infer = inference[resultIndex];
492 if (!infer.inferred) {
493 infer.sources.assign(NumElts: 1, Elt: {*fullyInferredIndex, "$_self"});
494 infer.inferred = true;
495 }
496 }
497 } else {
498 // Add an edge between every result and every other type; N <-> N.
499 for (int resultIndex : resultIndices) {
500 for (int otherResultIndex : resultIndices) {
501 if (resultIndex == otherResultIndex)
502 continue;
503 inference[resultIndex].sources.emplace_back(Args&: otherResultIndex,
504 Args: "$_self");
505 }
506 }
507 }
508 }
509
510 // Propagate inferredness until a fixed point.
511 std::vector<ResultTypeInference *> worklist;
512 for (ResultTypeInference &infer : inference)
513 if (!infer.inferred)
514 worklist.push_back(x: &infer);
515 bool changed;
516 do {
517 changed = false;
518 for (auto cur = worklist.begin(); cur != worklist.end();) {
519 ResultTypeInference &infer = **cur;
520
521 InferredResultType *iter =
522 llvm::find_if(Range&: infer.sources, P: [&](const InferredResultType &source) {
523 assert(InferredResultType::isResultIndex(source.getIndex()));
524 return inference[InferredResultType::unmapResultIndex(
525 i: source.getIndex())]
526 .inferred;
527 });
528 if (iter == infer.sources.end()) {
529 ++cur;
530 continue;
531 }
532
533 changed = true;
534 infer.inferred = true;
535 // Make this the only source for the result. This breaks any cycles.
536 infer.sources.assign(NumElts: 1, Elt: *iter);
537 cur = worklist.erase(position: cur);
538 }
539 } while (changed);
540
541 allResultsHaveKnownTypes = worklist.empty();
542
543 // If the types could be computed, then add type inference trait.
544 if (allResultsHaveKnownTypes) {
545 traits.push_back(Elt: Trait::create(init: inferTrait->getDefInit()));
546 for (const ResultTypeInference &infer : inference)
547 resultTypeMapping.push_back(Elt: infer.sources.front());
548 }
549}
550
551void Operator::populateOpStructure() {
552 auto &recordKeeper = def.getRecords();
553 auto *typeConstraintClass = recordKeeper.getClass(Name: "TypeConstraint");
554 auto *attrClass = recordKeeper.getClass(Name: "Attr");
555 auto *propertyClass = recordKeeper.getClass(Name: "Property");
556 auto *derivedAttrClass = recordKeeper.getClass(Name: "DerivedAttr");
557 auto *opVarClass = recordKeeper.getClass(Name: "OpVariable");
558 numNativeAttributes = 0;
559
560 DagInit *argumentValues = def.getValueAsDag(FieldName: "arguments");
561 unsigned numArgs = argumentValues->getNumArgs();
562
563 // Mapping from name of to argument or result index. Arguments are indexed
564 // to match getArg index, while the results are negatively indexed.
565 llvm::StringMap<int> argumentsAndResultsIndex;
566
567 // Handle operands and native attributes.
568 for (unsigned i = 0; i != numArgs; ++i) {
569 auto *arg = argumentValues->getArg(Num: i);
570 auto givenName = argumentValues->getArgNameStr(Num: i);
571 auto *argDefInit = dyn_cast<DefInit>(Val: arg);
572 if (!argDefInit)
573 PrintFatalError(ErrorLoc: def.getLoc(),
574 Msg: Twine("undefined type for argument #") + Twine(i));
575 Record *argDef = argDefInit->getDef();
576 if (argDef->isSubClassOf(R: opVarClass))
577 argDef = argDef->getValueAsDef(FieldName: "constraint");
578
579 if (argDef->isSubClassOf(R: typeConstraintClass)) {
580 operands.push_back(
581 Elt: NamedTypeConstraint{.name: givenName, .constraint: TypeConstraint(argDef)});
582 } else if (argDef->isSubClassOf(R: attrClass)) {
583 if (givenName.empty())
584 PrintFatalError(ErrorLoc: argDef->getLoc(), Msg: "attributes must be named");
585 if (argDef->isSubClassOf(R: derivedAttrClass))
586 PrintFatalError(ErrorLoc: argDef->getLoc(),
587 Msg: "derived attributes not allowed in argument list");
588 attributes.push_back(Elt: {.name: givenName, .attr: Attribute(argDef)});
589 ++numNativeAttributes;
590 } else if (argDef->isSubClassOf(R: propertyClass)) {
591 if (givenName.empty())
592 PrintFatalError(ErrorLoc: argDef->getLoc(), Msg: "properties must be named");
593 properties.push_back(Elt: {.name: givenName, .prop: Property(argDef)});
594 } else {
595 PrintFatalError(ErrorLoc: def.getLoc(),
596 Msg: "unexpected def type; only defs deriving "
597 "from TypeConstraint or Attr or Property are allowed");
598 }
599 if (!givenName.empty())
600 argumentsAndResultsIndex[givenName] = i;
601 }
602
603 // Handle derived attributes.
604 for (const auto &val : def.getValues()) {
605 if (auto *record = dyn_cast<llvm::RecordRecTy>(Val: val.getType())) {
606 if (!record->isSubClassOf(Class: attrClass))
607 continue;
608 if (!record->isSubClassOf(Class: derivedAttrClass))
609 PrintFatalError(ErrorLoc: def.getLoc(),
610 Msg: "unexpected Attr where only DerivedAttr is allowed");
611
612 if (record->getClasses().size() != 1) {
613 PrintFatalError(
614 ErrorLoc: def.getLoc(),
615 Msg: "unsupported attribute modelling, only single class expected");
616 }
617 attributes.push_back(
618 Elt: {.name: cast<llvm::StringInit>(Val: val.getNameInit())->getValue(),
619 .attr: Attribute(cast<DefInit>(Val: val.getValue()))});
620 }
621 }
622
623 // Populate `arguments`. This must happen after we've finalized `operands` and
624 // `attributes` because we will put their elements' pointers in `arguments`.
625 // SmallVector may perform re-allocation under the hood when adding new
626 // elements.
627 int operandIndex = 0, attrIndex = 0, propIndex = 0;
628 for (unsigned i = 0; i != numArgs; ++i) {
629 Record *argDef = dyn_cast<DefInit>(Val: argumentValues->getArg(Num: i))->getDef();
630 if (argDef->isSubClassOf(R: opVarClass))
631 argDef = argDef->getValueAsDef(FieldName: "constraint");
632
633 if (argDef->isSubClassOf(R: typeConstraintClass)) {
634 attrOrOperandMapping.push_back(
635 Elt: {OperandOrAttribute::Kind::Operand, operandIndex});
636 arguments.emplace_back(Args: &operands[operandIndex++]);
637 } else if (argDef->isSubClassOf(R: attrClass)) {
638 attrOrOperandMapping.push_back(
639 Elt: {OperandOrAttribute::Kind::Attribute, attrIndex});
640 arguments.emplace_back(Args: &attributes[attrIndex++]);
641 } else {
642 assert(argDef->isSubClassOf(propertyClass));
643 arguments.emplace_back(Args: &properties[propIndex++]);
644 }
645 }
646
647 auto *resultsDag = def.getValueAsDag(FieldName: "results");
648 auto *outsOp = dyn_cast<DefInit>(Val: resultsDag->getOperator());
649 if (!outsOp || outsOp->getDef()->getName() != "outs") {
650 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "'results' must have 'outs' directive");
651 }
652
653 // Handle results.
654 for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
655 auto name = resultsDag->getArgNameStr(Num: i);
656 auto *resultInit = dyn_cast<DefInit>(Val: resultsDag->getArg(Num: i));
657 if (!resultInit) {
658 PrintFatalError(ErrorLoc: def.getLoc(),
659 Msg: Twine("undefined type for result #") + Twine(i));
660 }
661 auto *resultDef = resultInit->getDef();
662 if (resultDef->isSubClassOf(R: opVarClass))
663 resultDef = resultDef->getValueAsDef(FieldName: "constraint");
664 results.push_back(Elt: {.name: name, .constraint: TypeConstraint(resultDef)});
665 if (!name.empty())
666 argumentsAndResultsIndex[name] = InferredResultType::mapResultIndex(i);
667
668 // We currently only support VariadicOfVariadic operands.
669 if (results.back().constraint.isVariadicOfVariadic()) {
670 PrintFatalError(
671 ErrorLoc: def.getLoc(),
672 Msg: "'VariadicOfVariadic' results are currently not supported");
673 }
674 }
675
676 // Handle successors
677 auto *successorsDag = def.getValueAsDag(FieldName: "successors");
678 auto *successorsOp = dyn_cast<DefInit>(Val: successorsDag->getOperator());
679 if (!successorsOp || successorsOp->getDef()->getName() != "successor") {
680 PrintFatalError(ErrorLoc: def.getLoc(),
681 Msg: "'successors' must have 'successor' directive");
682 }
683
684 for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
685 auto name = successorsDag->getArgNameStr(Num: i);
686 auto *successorInit = dyn_cast<DefInit>(Val: successorsDag->getArg(Num: i));
687 if (!successorInit) {
688 PrintFatalError(ErrorLoc: def.getLoc(),
689 Msg: Twine("undefined kind for successor #") + Twine(i));
690 }
691 Successor successor(successorInit->getDef());
692
693 // Only support variadic successors if it is the last one for now.
694 if (i != e - 1 && successor.isVariadic())
695 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "only the last successor can be variadic");
696 successors.push_back(Elt: {.name: name, .constraint: successor});
697 }
698
699 // Create list of traits, skipping over duplicates: appending to lists in
700 // tablegen is easy, making them unique less so, so dedupe here.
701 if (auto *traitList = def.getValueAsListInit(FieldName: "traits")) {
702 // This is uniquing based on pointers of the trait.
703 SmallPtrSet<const llvm::Init *, 32> traitSet;
704 traits.reserve(N: traitSet.size());
705
706 // The declaration order of traits imply the verification order of traits.
707 // Some traits may require other traits to be verified first then they can
708 // do further verification based on those verified facts. If you see this
709 // error, fix the traits declaration order by checking the `dependentTraits`
710 // field.
711 auto verifyTraitValidity = [&](Record *trait) {
712 auto *dependentTraits = trait->getValueAsListInit(FieldName: "dependentTraits");
713 for (auto *traitInit : *dependentTraits)
714 if (!traitSet.contains(Ptr: traitInit))
715 PrintFatalError(
716 ErrorLoc: def.getLoc(),
717 Msg: trait->getValueAsString(FieldName: "trait") + " requires " +
718 cast<DefInit>(Val: traitInit)->getDef()->getValueAsString(
719 FieldName: "trait") +
720 " to precede it in traits list");
721 };
722
723 std::function<void(llvm::ListInit *)> insert;
724 insert = [&](llvm::ListInit *traitList) {
725 for (auto *traitInit : *traitList) {
726 auto *def = cast<DefInit>(Val: traitInit)->getDef();
727 if (def->isSubClassOf(Name: "TraitList")) {
728 insert(def->getValueAsListInit(FieldName: "traits"));
729 continue;
730 }
731
732 // Ignore duplicates.
733 if (!traitSet.insert(Ptr: traitInit).second)
734 continue;
735
736 // If this is an interface with base classes, add the bases to the
737 // trait list.
738 if (def->isSubClassOf(Name: "Interface"))
739 insert(def->getValueAsListInit(FieldName: "baseInterfaces"));
740
741 // Verify if the trait has all the dependent traits declared before
742 // itself.
743 verifyTraitValidity(def);
744 traits.push_back(Elt: Trait::create(init: traitInit));
745 }
746 };
747 insert(traitList);
748 }
749
750 populateTypeInferenceInfo(argumentsAndResultsIndex);
751
752 // Handle regions
753 auto *regionsDag = def.getValueAsDag(FieldName: "regions");
754 auto *regionsOp = dyn_cast<DefInit>(Val: regionsDag->getOperator());
755 if (!regionsOp || regionsOp->getDef()->getName() != "region") {
756 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "'regions' must have 'region' directive");
757 }
758
759 for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
760 auto name = regionsDag->getArgNameStr(Num: i);
761 auto *regionInit = dyn_cast<DefInit>(Val: regionsDag->getArg(Num: i));
762 if (!regionInit) {
763 PrintFatalError(ErrorLoc: def.getLoc(),
764 Msg: Twine("undefined kind for region #") + Twine(i));
765 }
766 Region region(regionInit->getDef());
767 if (region.isVariadic()) {
768 // Only support variadic regions if it is the last one for now.
769 if (i != e - 1)
770 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "only the last region can be variadic");
771 if (name.empty())
772 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "variadic regions must be named");
773 }
774
775 regions.push_back(Elt: {.name: name, .constraint: region});
776 }
777
778 // Populate the builders.
779 auto *builderList =
780 dyn_cast_or_null<llvm::ListInit>(Val: def.getValueInit(FieldName: "builders"));
781 if (builderList && !builderList->empty()) {
782 for (llvm::Init *init : builderList->getValues())
783 builders.emplace_back(Args: cast<llvm::DefInit>(Val: init)->getDef(), Args: def.getLoc());
784 } else if (skipDefaultBuilders()) {
785 PrintFatalError(
786 ErrorLoc: def.getLoc(),
787 Msg: "default builders are skipped and no custom builders provided");
788 }
789
790 LLVM_DEBUG(print(llvm::dbgs()));
791}
792
793const InferredResultType &Operator::getInferredResultType(int index) const {
794 assert(allResultTypesKnown());
795 return resultTypeMapping[index];
796}
797
798ArrayRef<SMLoc> Operator::getLoc() const { return def.getLoc(); }
799
800bool Operator::hasDescription() const {
801 return def.getValue(Name: "description") != nullptr;
802}
803
804StringRef Operator::getDescription() const {
805 return def.getValueAsString(FieldName: "description");
806}
807
808bool Operator::hasSummary() const { return def.getValue(Name: "summary") != nullptr; }
809
810StringRef Operator::getSummary() const {
811 return def.getValueAsString(FieldName: "summary");
812}
813
814bool Operator::hasAssemblyFormat() const {
815 auto *valueInit = def.getValueInit(FieldName: "assemblyFormat");
816 return isa<llvm::StringInit>(Val: valueInit);
817}
818
819StringRef Operator::getAssemblyFormat() const {
820 return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit(FieldName: "assemblyFormat"))
821 .Case<llvm::StringInit>(caseFn: [&](auto *init) { return init->getValue(); });
822}
823
824void Operator::print(llvm::raw_ostream &os) const {
825 os << "op '" << getOperationName() << "'\n";
826 for (Argument arg : arguments) {
827 if (auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: arg))
828 os << "[attribute] " << attr->name << '\n';
829 else
830 os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
831 }
832}
833
834auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
835 -> VariableDecorator {
836 return VariableDecorator(cast<llvm::DefInit>(Val: init)->getDef());
837}
838
839auto Operator::getArgToOperandOrAttribute(int index) const
840 -> OperandOrAttribute {
841 return attrOrOperandMapping[index];
842}
843
844std::string Operator::getGetterName(StringRef name) const {
845 return "get" + convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
846}
847
848std::string Operator::getSetterName(StringRef name) const {
849 return "set" + convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
850}
851
852std::string Operator::getRemoverName(StringRef name) const {
853 return "remove" + convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
854}
855
856bool Operator::hasFolder() const { return def.getValueAsBit(FieldName: "hasFolder"); }
857
858bool Operator::useCustomPropertiesEncoding() const {
859 return def.getValueAsBit(FieldName: "useCustomPropertiesEncoding");
860}
861

// Source: mlir/lib/TableGen/Operator.cpp