1//===- Operator.cpp - Operator class --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/TableGen/Operator.h"
14#include "mlir/TableGen/Argument.h"
15#include "mlir/TableGen/Predicate.h"
16#include "mlir/TableGen/Trait.h"
17#include "mlir/TableGen/Type.h"
18#include "llvm/ADT/EquivalenceClasses.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/Sequence.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/ErrorHandling.h"
26#include "llvm/Support/FormatVariadic.h"
27#include "llvm/TableGen/Error.h"
28#include "llvm/TableGen/Record.h"
29#include <list>
30
31#define DEBUG_TYPE "mlir-tblgen-operator"
32
33using namespace mlir;
34using namespace mlir::tblgen;
35
36using llvm::DagInit;
37using llvm::DefInit;
38using llvm::Init;
39using llvm::ListInit;
40using llvm::Record;
41using llvm::StringInit;
42
43Operator::Operator(const Record &def)
44 : dialect(def.getValueAsDef(FieldName: "opDialect")), def(def) {
45 // The first `_` in the op's TableGen def name is treated as separating the
46 // dialect prefix and the op class name. The dialect prefix will be ignored if
47 // not empty. Otherwise, if def name starts with a `_`, the `_` is considered
48 // as part of the class name.
49 StringRef prefix;
50 std::tie(args&: prefix, args&: cppClassName) = def.getName().split(Separator: '_');
51 if (prefix.empty()) {
52 // Class name with a leading underscore and without dialect prefix
53 cppClassName = def.getName();
54 } else if (cppClassName.empty()) {
55 // Class name without dialect prefix
56 cppClassName = prefix;
57 }
58
59 cppNamespace = def.getValueAsString(FieldName: "cppNamespace");
60
61 populateOpStructure();
62 assertInvariants();
63}
64
65std::string Operator::getOperationName() const {
66 auto prefix = dialect.getName();
67 auto opName = def.getValueAsString(FieldName: "opName");
68 if (prefix.empty())
69 return std::string(opName);
70 return std::string(llvm::formatv(Fmt: "{0}.{1}", Vals&: prefix, Vals&: opName));
71}
72
73std::string Operator::getAdaptorName() const {
74 return std::string(llvm::formatv(Fmt: "{0}Adaptor", Vals: getCppClassName()));
75}
76
77std::string Operator::getGenericAdaptorName() const {
78 return std::string(llvm::formatv(Fmt: "{0}GenericAdaptor", Vals: getCppClassName()));
79}
80
81/// Assert the invariants of accessors generated for the given name.
82static void assertAccessorInvariants(const Operator &op, StringRef name) {
83 std::string accessorName =
84 convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
85
86 // Functor used to detect when an accessor will cause an overlap with an
87 // operation API.
88 //
89 // There are a little bit more invasive checks possible for cases where not
90 // all ops have the trait that would cause overlap. For many cases here,
91 // renaming would be better (e.g., we can only guard in limited manner
92 // against methods from traits and interfaces here, so avoiding these in op
93 // definition is safer).
94 auto nameOverlapsWithOpAPI = [&](StringRef newName) {
95 if (newName == "AttributeNames" || newName == "Attributes" ||
96 newName == "Operation")
97 return true;
98 if (newName == "Operands")
99 return op.getNumOperands() != 1 || op.getNumVariableLengthOperands() != 1;
100 if (newName == "Regions")
101 return op.getNumRegions() != 1 || op.getNumVariadicRegions() != 1;
102 if (newName == "Type")
103 return op.getNumResults() != 1;
104 return false;
105 };
106 if (nameOverlapsWithOpAPI(accessorName)) {
107 // This error could be avoided in situations where the final function is
108 // identical, but preferably the op definition should avoid using generic
109 // names.
110 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "generated accessor for `" + name +
111 "` overlaps with a default one; please "
112 "rename to avoid overlap");
113 }
114}
115
116void Operator::assertInvariants() const {
117 // Check that the name of arguments/results/regions/successors don't overlap.
118 DenseMap<StringRef, StringRef> existingNames;
119 auto checkName = [&](StringRef name, StringRef entity) {
120 if (name.empty())
121 return;
122 auto insertion = existingNames.insert(KV: {name, entity});
123 if (insertion.second) {
124 // Assert invariants for accessors generated for this name.
125 assertAccessorInvariants(op: *this, name);
126 return;
127 }
128 if (entity == insertion.first->second)
129 PrintFatalError(ErrorLoc: getLoc(), Msg: "op has a conflict with two " + entity +
130 " having the same name '" + name + "'");
131 PrintFatalError(ErrorLoc: getLoc(), Msg: "op has a conflict with " +
132 insertion.first->second + " and " + entity +
133 " both having an entry with the name '" +
134 name + "'");
135 };
136 // Check operands amongst themselves.
137 for (int i : llvm::seq<int>(Begin: 0, End: getNumOperands()))
138 checkName(getOperand(index: i).name, "operands");
139
140 // Check results amongst themselves and against operands.
141 for (int i : llvm::seq<int>(Begin: 0, End: getNumResults()))
142 checkName(getResult(index: i).name, "results");
143
144 // Check regions amongst themselves and against operands and results.
145 for (int i : llvm::seq<int>(Begin: 0, End: getNumRegions()))
146 checkName(getRegion(index: i).name, "regions");
147
148 // Check successors amongst themselves and against operands, results, and
149 // regions.
150 for (int i : llvm::seq<int>(Begin: 0, End: getNumSuccessors()))
151 checkName(getSuccessor(index: i).name, "successors");
152}
153
154StringRef Operator::getDialectName() const { return dialect.getName(); }
155
156StringRef Operator::getCppClassName() const { return cppClassName; }
157
158std::string Operator::getQualCppClassName() const {
159 if (cppNamespace.empty())
160 return std::string(cppClassName);
161 return std::string(llvm::formatv(Fmt: "{0}::{1}", Vals: cppNamespace, Vals: cppClassName));
162}
163
164StringRef Operator::getCppNamespace() const { return cppNamespace; }
165
166int Operator::getNumResults() const {
167 const DagInit *results = def.getValueAsDag(FieldName: "results");
168 return results->getNumArgs();
169}
170
171StringRef Operator::getExtraClassDeclaration() const {
172 constexpr auto attr = "extraClassDeclaration";
173 if (def.isValueUnset(FieldName: attr))
174 return {};
175 return def.getValueAsString(FieldName: attr);
176}
177
178StringRef Operator::getExtraClassDefinition() const {
179 constexpr auto attr = "extraClassDefinition";
180 if (def.isValueUnset(FieldName: attr))
181 return {};
182 return def.getValueAsString(FieldName: attr);
183}
184
185const Record &Operator::getDef() const { return def; }
186
187bool Operator::skipDefaultBuilders() const {
188 return def.getValueAsBit(FieldName: "skipDefaultBuilders");
189}
190
191auto Operator::result_begin() const -> const_value_iterator {
192 return results.begin();
193}
194
195auto Operator::result_end() const -> const_value_iterator {
196 return results.end();
197}
198
199auto Operator::getResults() const -> const_value_range {
200 return {result_begin(), result_end()};
201}
202
203TypeConstraint Operator::getResultTypeConstraint(int index) const {
204 const DagInit *results = def.getValueAsDag(FieldName: "results");
205 return TypeConstraint(cast<DefInit>(Val: results->getArg(Num: index)));
206}
207
208StringRef Operator::getResultName(int index) const {
209 const DagInit *results = def.getValueAsDag(FieldName: "results");
210 return results->getArgNameStr(Num: index);
211}
212
213auto Operator::getResultDecorators(int index) const -> var_decorator_range {
214 const Record *result =
215 cast<DefInit>(Val: def.getValueAsDag(FieldName: "results")->getArg(Num: index))->getDef();
216 if (!result->isSubClassOf(Name: "OpVariable"))
217 return var_decorator_range(nullptr, nullptr);
218 return *result->getValueAsListInit(FieldName: "decorators");
219}
220
221unsigned Operator::getNumVariableLengthResults() const {
222 return llvm::count_if(Range: results, P: [](const NamedTypeConstraint &c) {
223 return c.constraint.isVariableLength();
224 });
225}
226
227unsigned Operator::getNumVariableLengthOperands() const {
228 return llvm::count_if(Range: operands, P: [](const NamedTypeConstraint &c) {
229 return c.constraint.isVariableLength();
230 });
231}
232
233bool Operator::hasSingleVariadicArg() const {
234 return getNumArgs() == 1 && isa<NamedTypeConstraint *>(Val: getArg(index: 0)) &&
235 getOperand(index: 0).isVariadic();
236}
237
238Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); }
239
240Operator::arg_iterator Operator::arg_end() const { return arguments.end(); }
241
242Operator::arg_range Operator::getArgs() const {
243 return {arg_begin(), arg_end()};
244}
245
246StringRef Operator::getArgName(int index) const {
247 const DagInit *argumentValues = def.getValueAsDag(FieldName: "arguments");
248 return argumentValues->getArgNameStr(Num: index);
249}
250
251auto Operator::getArgDecorators(int index) const -> var_decorator_range {
252 const Record *arg =
253 cast<DefInit>(Val: def.getValueAsDag(FieldName: "arguments")->getArg(Num: index))->getDef();
254 if (!arg->isSubClassOf(Name: "OpVariable"))
255 return var_decorator_range(nullptr, nullptr);
256 return *arg->getValueAsListInit(FieldName: "decorators");
257}
258
259const Trait *Operator::getTrait(StringRef trait) const {
260 for (const auto &t : traits) {
261 if (const auto *traitDef = dyn_cast<NativeTrait>(Val: &t)) {
262 if (traitDef->getFullyQualifiedTraitName() == trait)
263 return traitDef;
264 } else if (const auto *traitDef = dyn_cast<InternalTrait>(Val: &t)) {
265 if (traitDef->getFullyQualifiedTraitName() == trait)
266 return traitDef;
267 } else if (const auto *traitDef = dyn_cast<InterfaceTrait>(Val: &t)) {
268 if (traitDef->getFullyQualifiedTraitName() == trait)
269 return traitDef;
270 }
271 }
272 return nullptr;
273}
274
275auto Operator::region_begin() const -> const_region_iterator {
276 return regions.begin();
277}
278auto Operator::region_end() const -> const_region_iterator {
279 return regions.end();
280}
281auto Operator::getRegions() const
282 -> llvm::iterator_range<const_region_iterator> {
283 return {region_begin(), region_end()};
284}
285
286unsigned Operator::getNumRegions() const { return regions.size(); }
287
288const NamedRegion &Operator::getRegion(unsigned index) const {
289 return regions[index];
290}
291
292unsigned Operator::getNumVariadicRegions() const {
293 return llvm::count_if(Range: regions,
294 P: [](const NamedRegion &c) { return c.isVariadic(); });
295}
296
297auto Operator::successor_begin() const -> const_successor_iterator {
298 return successors.begin();
299}
300auto Operator::successor_end() const -> const_successor_iterator {
301 return successors.end();
302}
303auto Operator::getSuccessors() const
304 -> llvm::iterator_range<const_successor_iterator> {
305 return {successor_begin(), successor_end()};
306}
307
308unsigned Operator::getNumSuccessors() const { return successors.size(); }
309
310const NamedSuccessor &Operator::getSuccessor(unsigned index) const {
311 return successors[index];
312}
313
314unsigned Operator::getNumVariadicSuccessors() const {
315 return llvm::count_if(Range: successors,
316 P: [](const NamedSuccessor &c) { return c.isVariadic(); });
317}
318
319auto Operator::trait_begin() const -> const_trait_iterator {
320 return traits.begin();
321}
322auto Operator::trait_end() const -> const_trait_iterator {
323 return traits.end();
324}
325auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> {
326 return {trait_begin(), trait_end()};
327}
328
329auto Operator::attribute_begin() const -> const_attribute_iterator {
330 return attributes.begin();
331}
332auto Operator::attribute_end() const -> const_attribute_iterator {
333 return attributes.end();
334}
335auto Operator::getAttributes() const
336 -> llvm::iterator_range<const_attribute_iterator> {
337 return {attribute_begin(), attribute_end()};
338}
339auto Operator::attribute_begin() -> attribute_iterator {
340 return attributes.begin();
341}
342auto Operator::attribute_end() -> attribute_iterator {
343 return attributes.end();
344}
345auto Operator::getAttributes() -> llvm::iterator_range<attribute_iterator> {
346 return {attribute_begin(), attribute_end()};
347}
348
349auto Operator::operand_begin() const -> const_value_iterator {
350 return operands.begin();
351}
352auto Operator::operand_end() const -> const_value_iterator {
353 return operands.end();
354}
355auto Operator::getOperands() const -> const_value_range {
356 return {operand_begin(), operand_end()};
357}
358
359auto Operator::getArg(int index) const -> Argument { return arguments[index]; }
360
361bool Operator::isVariadic() const {
362 return any_of(Range: llvm::concat<const NamedTypeConstraint>(Ranges: operands, Ranges: results),
363 P: [](const NamedTypeConstraint &op) { return op.isVariadic(); });
364}
365
366void Operator::populateTypeInferenceInfo(
367 const llvm::StringMap<int> &argumentsAndResultsIndex) {
368 // If the type inference op interface is not registered, then do not attempt
369 // to determine if the result types an be inferred.
370 auto &recordKeeper = def.getRecords();
371 auto *inferTrait = recordKeeper.getDef(Name: inferTypeOpInterface);
372 allResultsHaveKnownTypes = false;
373 if (!inferTrait)
374 return;
375
376 // If there are no results, the skip this else the build method generated
377 // overlaps with another autogenerated builder.
378 if (getNumResults() == 0)
379 return;
380
381 // Skip ops with variadic or optional results.
382 if (getNumVariableLengthResults() > 0)
383 return;
384
385 // Skip cases currently being custom generated.
386 // TODO: Remove special cases.
387 if (getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType")) {
388 // Check for a non-variable length operand to use as the type anchor.
389 auto *operandI = llvm::find_if(Range&: arguments, P: [](const Argument &arg) {
390 NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: arg);
391 return operand && !operand->isVariableLength();
392 });
393 if (operandI == arguments.end())
394 return;
395
396 // All result types are inferred from the operand type.
397 int operandIdx = operandI - arguments.begin();
398 for (int i = 0; i < getNumResults(); ++i)
399 resultTypeMapping.emplace_back(Args&: operandIdx, Args: "$_self");
400
401 allResultsHaveKnownTypes = true;
402 traits.push_back(Elt: Trait::create(init: inferTrait->getDefInit()));
403 return;
404 }
405
406 /// This struct represents a node in this operation's result type inferenece
407 /// graph. Each node has a list of incoming type inference edges `sources`.
408 /// Each edge represents a "source" from which the result type can be
409 /// inferred, either an operand (leaf) or another result (node). When a node
410 /// is known to have a fully-inferred type, `inferred` is set to true.
411 struct ResultTypeInference {
412 /// The list of incoming type inference edges.
413 SmallVector<InferredResultType> sources;
414 /// This flag is set to true when the result type is known to be inferrable.
415 bool inferred = false;
416 };
417
418 // This vector represents the type inference graph, with one node for each
419 // operation result. The nth element is the node for the nth result.
420 SmallVector<ResultTypeInference> inference(getNumResults(), {});
421
422 // For all results whose types are buildable, initialize their type inference
423 // nodes with an edge to themselves. Mark those nodes are fully-inferred.
424 for (auto [idx, infer] : llvm::enumerate(First&: inference)) {
425 if (getResult(index: idx).constraint.getBuilderCall()) {
426 infer.sources.emplace_back(Args: InferredResultType::mapResultIndex(i: idx),
427 Args: "$_self");
428 infer.inferred = true;
429 }
430 }
431
432 // Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the
433 // result type inference graph.
434 for (const Trait &trait : traits) {
435 const Record &def = trait.getDef();
436
437 // If the infer type op interface was manually added, then treat it as
438 // intention that the op needs special handling.
439 // TODO: Reconsider whether to always generate, this is more conservative
440 // and keeps existing behavior so starting that way for now.
441 if (def.isSubClassOf(
442 Name: llvm::formatv(Fmt: "{0}::Trait", Vals&: inferTypeOpInterface).str()))
443 return;
444 if (const auto *traitDef = dyn_cast<InterfaceTrait>(Val: &trait))
445 if (&traitDef->getDef() == inferTrait)
446 return;
447
448 // The `TypesMatchWith` trait represents a 1 -> 1 type inference edge with a
449 // type transformer.
450 if (def.isSubClassOf(Name: "TypesMatchWith")) {
451 int target = argumentsAndResultsIndex.lookup(Key: def.getValueAsString(FieldName: "rhs"));
452 // Ignore operand type inference.
453 if (InferredResultType::isArgIndex(i: target))
454 continue;
455 int resultIndex = InferredResultType::unmapResultIndex(i: target);
456 ResultTypeInference &infer = inference[resultIndex];
457 // If the type of the result has already been inferred, do nothing.
458 if (infer.inferred)
459 continue;
460 int sourceIndex =
461 argumentsAndResultsIndex.lookup(Key: def.getValueAsString(FieldName: "lhs"));
462 infer.sources.emplace_back(Args&: sourceIndex,
463 Args: def.getValueAsString(FieldName: "transformer").str());
464 // Locally propagate inferredness.
465 infer.inferred =
466 InferredResultType::isArgIndex(i: sourceIndex) ||
467 inference[InferredResultType::unmapResultIndex(i: sourceIndex)].inferred;
468 continue;
469 }
470
471 if (!def.isSubClassOf(Name: "AllTypesMatch"))
472 continue;
473
474 auto values = def.getValueAsListOfStrings(FieldName: "values");
475 // The `AllTypesMatch` trait represents an N <-> N fanin and fanout. That
476 // is, every result type has an edge from every other type. However, if any
477 // one of the values refers to an operand or a result with a fully-inferred
478 // type, we can infer all other types from that value. Try to find a
479 // fully-inferred type in the list.
480 std::optional<int> fullyInferredIndex;
481 SmallVector<int> resultIndices;
482 for (StringRef name : values) {
483 int index = argumentsAndResultsIndex.lookup(Key: name);
484 if (InferredResultType::isResultIndex(i: index))
485 resultIndices.push_back(Elt: InferredResultType::unmapResultIndex(i: index));
486 if (InferredResultType::isArgIndex(i: index) ||
487 inference[InferredResultType::unmapResultIndex(i: index)].inferred)
488 fullyInferredIndex = index;
489 }
490 if (fullyInferredIndex) {
491 // Make the fully-inferred type the only source for all results that
492 // aren't already inferred -- a 1 -> N fanout.
493 for (int resultIndex : resultIndices) {
494 ResultTypeInference &infer = inference[resultIndex];
495 if (!infer.inferred) {
496 infer.sources.assign(NumElts: 1, Elt: {*fullyInferredIndex, "$_self"});
497 infer.inferred = true;
498 }
499 }
500 } else {
501 // Add an edge between every result and every other type; N <-> N.
502 for (int resultIndex : resultIndices) {
503 for (int otherResultIndex : resultIndices) {
504 if (resultIndex == otherResultIndex)
505 continue;
506 inference[resultIndex].sources.emplace_back(
507 Args: InferredResultType::unmapResultIndex(i: otherResultIndex), Args: "$_self");
508 }
509 }
510 }
511 }
512
513 // Propagate inferredness until a fixed point.
514 std::vector<ResultTypeInference *> worklist;
515 for (ResultTypeInference &infer : inference)
516 if (!infer.inferred)
517 worklist.push_back(x: &infer);
518 bool changed;
519 do {
520 changed = false;
521 for (auto cur = worklist.begin(); cur != worklist.end();) {
522 ResultTypeInference &infer = **cur;
523
524 InferredResultType *iter =
525 llvm::find_if(Range&: infer.sources, P: [&](const InferredResultType &source) {
526 assert(InferredResultType::isResultIndex(source.getIndex()));
527 return inference[InferredResultType::unmapResultIndex(
528 i: source.getIndex())]
529 .inferred;
530 });
531 if (iter == infer.sources.end()) {
532 ++cur;
533 continue;
534 }
535
536 changed = true;
537 infer.inferred = true;
538 // Make this the only source for the result. This breaks any cycles.
539 infer.sources.assign(NumElts: 1, Elt: *iter);
540 cur = worklist.erase(position: cur);
541 }
542 } while (changed);
543
544 allResultsHaveKnownTypes = worklist.empty();
545
546 // If the types could be computed, then add type inference trait.
547 if (allResultsHaveKnownTypes) {
548 traits.push_back(Elt: Trait::create(init: inferTrait->getDefInit()));
549 for (const ResultTypeInference &infer : inference)
550 resultTypeMapping.push_back(Elt: infer.sources.front());
551 }
552}
553
554void Operator::populateOpStructure() {
555 auto &recordKeeper = def.getRecords();
556 auto *typeConstraintClass = recordKeeper.getClass(Name: "TypeConstraint");
557 auto *attrClass = recordKeeper.getClass(Name: "Attr");
558 auto *propertyClass = recordKeeper.getClass(Name: "Property");
559 auto *derivedAttrClass = recordKeeper.getClass(Name: "DerivedAttr");
560 auto *opVarClass = recordKeeper.getClass(Name: "OpVariable");
561 numNativeAttributes = 0;
562
563 const DagInit *argumentValues = def.getValueAsDag(FieldName: "arguments");
564 unsigned numArgs = argumentValues->getNumArgs();
565
566 // Mapping from name of to argument or result index. Arguments are indexed
567 // to match getArg index, while the results are negatively indexed.
568 llvm::StringMap<int> argumentsAndResultsIndex;
569
570 // Handle operands and native attributes.
571 for (unsigned i = 0; i != numArgs; ++i) {
572 auto *arg = argumentValues->getArg(Num: i);
573 auto givenName = argumentValues->getArgNameStr(Num: i);
574 auto *argDefInit = dyn_cast<DefInit>(Val: arg);
575 if (!argDefInit)
576 PrintFatalError(ErrorLoc: def.getLoc(),
577 Msg: Twine("undefined type for argument #") + Twine(i));
578 const Record *argDef = argDefInit->getDef();
579 if (argDef->isSubClassOf(R: opVarClass))
580 argDef = argDef->getValueAsDef(FieldName: "constraint");
581
582 if (argDef->isSubClassOf(R: typeConstraintClass)) {
583 operands.push_back(
584 Elt: NamedTypeConstraint{.name: givenName, .constraint: TypeConstraint(argDef)});
585 } else if (argDef->isSubClassOf(R: attrClass)) {
586 if (givenName.empty())
587 PrintFatalError(ErrorLoc: argDef->getLoc(), Msg: "attributes must be named");
588 if (argDef->isSubClassOf(R: derivedAttrClass))
589 PrintFatalError(ErrorLoc: argDef->getLoc(),
590 Msg: "derived attributes not allowed in argument list");
591 attributes.push_back(Elt: {.name: givenName, .attr: Attribute(argDef)});
592 ++numNativeAttributes;
593 } else if (argDef->isSubClassOf(R: propertyClass)) {
594 if (givenName.empty())
595 PrintFatalError(ErrorLoc: argDef->getLoc(), Msg: "properties must be named");
596 properties.push_back(Elt: {.name: givenName, .prop: Property(argDef)});
597 } else {
598 PrintFatalError(ErrorLoc: def.getLoc(),
599 Msg: "unexpected def type; only defs deriving "
600 "from TypeConstraint or Attr or Property are allowed");
601 }
602 if (!givenName.empty())
603 argumentsAndResultsIndex[givenName] = i;
604 }
605
606 // Handle derived attributes.
607 for (const auto &val : def.getValues()) {
608 if (auto *record = dyn_cast<llvm::RecordRecTy>(Val: val.getType())) {
609 if (!record->isSubClassOf(Class: attrClass))
610 continue;
611 if (!record->isSubClassOf(Class: derivedAttrClass))
612 PrintFatalError(ErrorLoc: def.getLoc(),
613 Msg: "unexpected Attr where only DerivedAttr is allowed");
614
615 if (record->getClasses().size() != 1) {
616 PrintFatalError(
617 ErrorLoc: def.getLoc(),
618 Msg: "unsupported attribute modelling, only single class expected");
619 }
620 attributes.push_back(Elt: {.name: cast<StringInit>(Val: val.getNameInit())->getValue(),
621 .attr: Attribute(cast<DefInit>(Val: val.getValue()))});
622 }
623 }
624
625 // Populate `arguments`. This must happen after we've finalized `operands` and
626 // `attributes` because we will put their elements' pointers in `arguments`.
627 // SmallVector may perform re-allocation under the hood when adding new
628 // elements.
629 int operandIndex = 0, attrIndex = 0, propIndex = 0;
630 for (unsigned i = 0; i != numArgs; ++i) {
631 const Record *argDef =
632 dyn_cast<DefInit>(Val: argumentValues->getArg(Num: i))->getDef();
633 if (argDef->isSubClassOf(R: opVarClass))
634 argDef = argDef->getValueAsDef(FieldName: "constraint");
635
636 if (argDef->isSubClassOf(R: typeConstraintClass)) {
637 attrOrOperandMapping.push_back(
638 Elt: {OperandOrAttribute::Kind::Operand, operandIndex});
639 arguments.emplace_back(Args: &operands[operandIndex++]);
640 } else if (argDef->isSubClassOf(R: attrClass)) {
641 attrOrOperandMapping.push_back(
642 Elt: {OperandOrAttribute::Kind::Attribute, attrIndex});
643 arguments.emplace_back(Args: &attributes[attrIndex++]);
644 } else {
645 assert(argDef->isSubClassOf(propertyClass));
646 arguments.emplace_back(Args: &properties[propIndex++]);
647 }
648 }
649
650 auto *resultsDag = def.getValueAsDag(FieldName: "results");
651 auto *outsOp = dyn_cast<DefInit>(Val: resultsDag->getOperator());
652 if (!outsOp || outsOp->getDef()->getName() != "outs") {
653 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "'results' must have 'outs' directive");
654 }
655
656 // Handle results.
657 for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
658 auto name = resultsDag->getArgNameStr(Num: i);
659 auto *resultInit = dyn_cast<DefInit>(Val: resultsDag->getArg(Num: i));
660 if (!resultInit) {
661 PrintFatalError(ErrorLoc: def.getLoc(),
662 Msg: Twine("undefined type for result #") + Twine(i));
663 }
664 auto *resultDef = resultInit->getDef();
665 if (resultDef->isSubClassOf(R: opVarClass))
666 resultDef = resultDef->getValueAsDef(FieldName: "constraint");
667 results.push_back(Elt: {.name: name, .constraint: TypeConstraint(resultDef)});
668 if (!name.empty())
669 argumentsAndResultsIndex[name] = InferredResultType::mapResultIndex(i);
670
671 // We currently only support VariadicOfVariadic operands.
672 if (results.back().constraint.isVariadicOfVariadic()) {
673 PrintFatalError(
674 ErrorLoc: def.getLoc(),
675 Msg: "'VariadicOfVariadic' results are currently not supported");
676 }
677 }
678
679 // Handle successors
680 auto *successorsDag = def.getValueAsDag(FieldName: "successors");
681 auto *successorsOp = dyn_cast<DefInit>(Val: successorsDag->getOperator());
682 if (!successorsOp || successorsOp->getDef()->getName() != "successor") {
683 PrintFatalError(ErrorLoc: def.getLoc(),
684 Msg: "'successors' must have 'successor' directive");
685 }
686
687 for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
688 auto name = successorsDag->getArgNameStr(Num: i);
689 auto *successorInit = dyn_cast<DefInit>(Val: successorsDag->getArg(Num: i));
690 if (!successorInit) {
691 PrintFatalError(ErrorLoc: def.getLoc(),
692 Msg: Twine("undefined kind for successor #") + Twine(i));
693 }
694 Successor successor(successorInit->getDef());
695
696 // Only support variadic successors if it is the last one for now.
697 if (i != e - 1 && successor.isVariadic())
698 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "only the last successor can be variadic");
699 successors.push_back(Elt: {.name: name, .constraint: successor});
700 }
701
702 // Create list of traits, skipping over duplicates: appending to lists in
703 // tablegen is easy, making them unique less so, so dedupe here.
704 if (auto *traitList = def.getValueAsListInit(FieldName: "traits")) {
705 // This is uniquing based on pointers of the trait.
706 SmallPtrSet<const Init *, 32> traitSet;
707 traits.reserve(N: traitSet.size());
708
709 // The declaration order of traits imply the verification order of traits.
710 // Some traits may require other traits to be verified first then they can
711 // do further verification based on those verified facts. If you see this
712 // error, fix the traits declaration order by checking the `dependentTraits`
713 // field.
714 auto verifyTraitValidity = [&](const Record *trait) {
715 auto *dependentTraits = trait->getValueAsListInit(FieldName: "dependentTraits");
716 for (auto *traitInit : *dependentTraits)
717 if (!traitSet.contains(Ptr: traitInit))
718 PrintFatalError(
719 ErrorLoc: def.getLoc(),
720 Msg: trait->getValueAsString(FieldName: "trait") + " requires " +
721 cast<DefInit>(Val: traitInit)->getDef()->getValueAsString(
722 FieldName: "trait") +
723 " to precede it in traits list");
724 };
725
726 std::function<void(const ListInit *)> insert;
727 insert = [&](const ListInit *traitList) {
728 for (auto *traitInit : *traitList) {
729 auto *def = cast<DefInit>(Val: traitInit)->getDef();
730 if (def->isSubClassOf(Name: "TraitList")) {
731 insert(def->getValueAsListInit(FieldName: "traits"));
732 continue;
733 }
734
735 // Ignore duplicates.
736 if (!traitSet.insert(Ptr: traitInit).second)
737 continue;
738
739 // If this is an interface with base classes, add the bases to the
740 // trait list.
741 if (def->isSubClassOf(Name: "Interface"))
742 insert(def->getValueAsListInit(FieldName: "baseInterfaces"));
743
744 // Verify if the trait has all the dependent traits declared before
745 // itself.
746 verifyTraitValidity(def);
747 traits.push_back(Elt: Trait::create(init: traitInit));
748 }
749 };
750 insert(traitList);
751 }
752
753 populateTypeInferenceInfo(argumentsAndResultsIndex);
754
755 // Handle regions
756 auto *regionsDag = def.getValueAsDag(FieldName: "regions");
757 auto *regionsOp = dyn_cast<DefInit>(Val: regionsDag->getOperator());
758 if (!regionsOp || regionsOp->getDef()->getName() != "region") {
759 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "'regions' must have 'region' directive");
760 }
761
762 for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
763 auto name = regionsDag->getArgNameStr(Num: i);
764 auto *regionInit = dyn_cast<DefInit>(Val: regionsDag->getArg(Num: i));
765 if (!regionInit) {
766 PrintFatalError(ErrorLoc: def.getLoc(),
767 Msg: Twine("undefined kind for region #") + Twine(i));
768 }
769 Region region(regionInit->getDef());
770 if (region.isVariadic()) {
771 // Only support variadic regions if it is the last one for now.
772 if (i != e - 1)
773 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "only the last region can be variadic");
774 if (name.empty())
775 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "variadic regions must be named");
776 }
777
778 regions.push_back(Elt: {.name: name, .constraint: region});
779 }
780
781 // Populate the builders.
782 auto *builderList = dyn_cast_or_null<ListInit>(Val: def.getValueInit(FieldName: "builders"));
783 if (builderList && !builderList->empty()) {
784 for (const Init *init : builderList->getElements())
785 builders.emplace_back(Args: cast<DefInit>(Val: init)->getDef(), Args: def.getLoc());
786 } else if (skipDefaultBuilders()) {
787 PrintFatalError(
788 ErrorLoc: def.getLoc(),
789 Msg: "default builders are skipped and no custom builders provided");
790 }
791
792 LLVM_DEBUG(print(llvm::dbgs()));
793}
794
795const InferredResultType &Operator::getInferredResultType(int index) const {
796 assert(allResultTypesKnown());
797 return resultTypeMapping[index];
798}
799
800ArrayRef<SMLoc> Operator::getLoc() const { return def.getLoc(); }
801
802bool Operator::hasDescription() const {
803 return !getDescription().trim().empty();
804}
805
806StringRef Operator::getDescription() const {
807 return def.getValueAsString(FieldName: "description");
808}
809
810bool Operator::hasSummary() const { return !getSummary().trim().empty(); }
811
812StringRef Operator::getSummary() const {
813 return def.getValueAsString(FieldName: "summary");
814}
815
816bool Operator::hasAssemblyFormat() const {
817 auto *valueInit = def.getValueInit(FieldName: "assemblyFormat");
818 return isa<StringInit>(Val: valueInit);
819}
820
821StringRef Operator::getAssemblyFormat() const {
822 return TypeSwitch<const Init *, StringRef>(def.getValueInit(FieldName: "assemblyFormat"))
823 .Case<StringInit>(caseFn: [&](auto *init) { return init->getValue(); });
824}
825
826void Operator::print(llvm::raw_ostream &os) const {
827 os << "op '" << getOperationName() << "'\n";
828 for (Argument arg : arguments) {
829 if (auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: arg))
830 os << "[attribute] " << attr->name << '\n';
831 else
832 os << "[operand] " << cast<NamedTypeConstraint *>(Val&: arg)->name << '\n';
833 }
834}
835
836auto Operator::VariableDecoratorIterator::unwrap(const Init *init)
837 -> VariableDecorator {
838 return VariableDecorator(cast<DefInit>(Val: init)->getDef());
839}
840
841auto Operator::getArgToOperandOrAttribute(int index) const
842 -> OperandOrAttribute {
843 return attrOrOperandMapping[index];
844}
845
846std::string Operator::getGetterName(StringRef name) const {
847 return "get" + convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
848}
849
850std::string Operator::getSetterName(StringRef name) const {
851 return "set" + convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
852}
853
854std::string Operator::getRemoverName(StringRef name) const {
855 return "remove" + convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
856}
857
858bool Operator::hasFolder() const { return def.getValueAsBit(FieldName: "hasFolder"); }
859
860bool Operator::useCustomPropertiesEncoding() const {
861 return def.getValueAsBit(FieldName: "useCustomPropertiesEncoding");
862}
863

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/TableGen/Operator.cpp