1 | //===- Operator.cpp - Operator class --------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Operator wrapper to simplify using TableGen Record defining a MLIR Op. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/TableGen/Operator.h" |
14 | #include "mlir/TableGen/Argument.h" |
15 | #include "mlir/TableGen/Predicate.h" |
16 | #include "mlir/TableGen/Trait.h" |
17 | #include "mlir/TableGen/Type.h" |
18 | #include "llvm/ADT/EquivalenceClasses.h" |
19 | #include "llvm/ADT/STLExtras.h" |
20 | #include "llvm/ADT/Sequence.h" |
21 | #include "llvm/ADT/SmallPtrSet.h" |
22 | #include "llvm/ADT/StringExtras.h" |
23 | #include "llvm/ADT/TypeSwitch.h" |
24 | #include "llvm/Support/Debug.h" |
25 | #include "llvm/Support/ErrorHandling.h" |
26 | #include "llvm/Support/FormatVariadic.h" |
27 | #include "llvm/TableGen/Error.h" |
28 | #include "llvm/TableGen/Record.h" |
29 | #include <list> |
30 | |
31 | #define DEBUG_TYPE "mlir-tblgen-operator" |
32 | |
33 | using namespace mlir; |
34 | using namespace mlir::tblgen; |
35 | |
36 | using llvm::DagInit; |
37 | using llvm::DefInit; |
38 | using llvm::Record; |
39 | |
40 | Operator::Operator(const llvm::Record &def) |
41 | : dialect(def.getValueAsDef(FieldName: "opDialect" )), def(def) { |
42 | // The first `_` in the op's TableGen def name is treated as separating the |
43 | // dialect prefix and the op class name. The dialect prefix will be ignored if |
44 | // not empty. Otherwise, if def name starts with a `_`, the `_` is considered |
45 | // as part of the class name. |
46 | StringRef prefix; |
47 | std::tie(args&: prefix, args&: cppClassName) = def.getName().split(Separator: '_'); |
48 | if (prefix.empty()) { |
49 | // Class name with a leading underscore and without dialect prefix |
50 | cppClassName = def.getName(); |
51 | } else if (cppClassName.empty()) { |
52 | // Class name without dialect prefix |
53 | cppClassName = prefix; |
54 | } |
55 | |
56 | cppNamespace = def.getValueAsString(FieldName: "cppNamespace" ); |
57 | |
58 | populateOpStructure(); |
59 | assertInvariants(); |
60 | } |
61 | |
62 | std::string Operator::getOperationName() const { |
63 | auto prefix = dialect.getName(); |
64 | auto opName = def.getValueAsString(FieldName: "opName" ); |
65 | if (prefix.empty()) |
66 | return std::string(opName); |
67 | return std::string(llvm::formatv(Fmt: "{0}.{1}" , Vals&: prefix, Vals&: opName)); |
68 | } |
69 | |
70 | std::string Operator::getAdaptorName() const { |
71 | return std::string(llvm::formatv(Fmt: "{0}Adaptor" , Vals: getCppClassName())); |
72 | } |
73 | |
74 | std::string Operator::getGenericAdaptorName() const { |
75 | return std::string(llvm::formatv(Fmt: "{0}GenericAdaptor" , Vals: getCppClassName())); |
76 | } |
77 | |
78 | /// Assert the invariants of accessors generated for the given name. |
79 | static void assertAccessorInvariants(const Operator &op, StringRef name) { |
80 | std::string accessorName = |
81 | convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true); |
82 | |
83 | // Functor used to detect when an accessor will cause an overlap with an |
84 | // operation API. |
85 | // |
86 | // There are a little bit more invasive checks possible for cases where not |
87 | // all ops have the trait that would cause overlap. For many cases here, |
88 | // renaming would be better (e.g., we can only guard in limited manner |
89 | // against methods from traits and interfaces here, so avoiding these in op |
90 | // definition is safer). |
91 | auto nameOverlapsWithOpAPI = [&](StringRef newName) { |
92 | if (newName == "AttributeNames" || newName == "Attributes" || |
93 | newName == "Operation" ) |
94 | return true; |
95 | if (newName == "Operands" ) |
96 | return op.getNumOperands() != 1 || op.getNumVariableLengthOperands() != 1; |
97 | if (newName == "Regions" ) |
98 | return op.getNumRegions() != 1 || op.getNumVariadicRegions() != 1; |
99 | if (newName == "Type" ) |
100 | return op.getNumResults() != 1; |
101 | return false; |
102 | }; |
103 | if (nameOverlapsWithOpAPI(accessorName)) { |
104 | // This error could be avoided in situations where the final function is |
105 | // identical, but preferably the op definition should avoid using generic |
106 | // names. |
107 | PrintFatalError(ErrorLoc: op.getLoc(), Msg: "generated accessor for `" + name + |
108 | "` overlaps with a default one; please " |
109 | "rename to avoid overlap" ); |
110 | } |
111 | } |
112 | |
113 | void Operator::assertInvariants() const { |
114 | // Check that the name of arguments/results/regions/successors don't overlap. |
115 | DenseMap<StringRef, StringRef> existingNames; |
116 | auto checkName = [&](StringRef name, StringRef entity) { |
117 | if (name.empty()) |
118 | return; |
119 | auto insertion = existingNames.insert(KV: {name, entity}); |
120 | if (insertion.second) { |
121 | // Assert invariants for accessors generated for this name. |
122 | assertAccessorInvariants(op: *this, name); |
123 | return; |
124 | } |
125 | if (entity == insertion.first->second) |
126 | PrintFatalError(ErrorLoc: getLoc(), Msg: "op has a conflict with two " + entity + |
127 | " having the same name '" + name + "'" ); |
128 | PrintFatalError(ErrorLoc: getLoc(), Msg: "op has a conflict with " + |
129 | insertion.first->second + " and " + entity + |
130 | " both having an entry with the name '" + |
131 | name + "'" ); |
132 | }; |
133 | // Check operands amongst themselves. |
134 | for (int i : llvm::seq<int>(Begin: 0, End: getNumOperands())) |
135 | checkName(getOperand(index: i).name, "operands" ); |
136 | |
137 | // Check results amongst themselves and against operands. |
138 | for (int i : llvm::seq<int>(Begin: 0, End: getNumResults())) |
139 | checkName(getResult(index: i).name, "results" ); |
140 | |
141 | // Check regions amongst themselves and against operands and results. |
142 | for (int i : llvm::seq<int>(Begin: 0, End: getNumRegions())) |
143 | checkName(getRegion(index: i).name, "regions" ); |
144 | |
145 | // Check successors amongst themselves and against operands, results, and |
146 | // regions. |
147 | for (int i : llvm::seq<int>(Begin: 0, End: getNumSuccessors())) |
148 | checkName(getSuccessor(index: i).name, "successors" ); |
149 | } |
150 | |
151 | StringRef Operator::getDialectName() const { return dialect.getName(); } |
152 | |
153 | StringRef Operator::getCppClassName() const { return cppClassName; } |
154 | |
155 | std::string Operator::getQualCppClassName() const { |
156 | if (cppNamespace.empty()) |
157 | return std::string(cppClassName); |
158 | return std::string(llvm::formatv(Fmt: "{0}::{1}" , Vals: cppNamespace, Vals: cppClassName)); |
159 | } |
160 | |
161 | StringRef Operator::getCppNamespace() const { return cppNamespace; } |
162 | |
163 | int Operator::getNumResults() const { |
164 | DagInit *results = def.getValueAsDag(FieldName: "results" ); |
165 | return results->getNumArgs(); |
166 | } |
167 | |
168 | StringRef Operator::() const { |
169 | constexpr auto attr = "extraClassDeclaration" ; |
170 | if (def.isValueUnset(FieldName: attr)) |
171 | return {}; |
172 | return def.getValueAsString(FieldName: attr); |
173 | } |
174 | |
175 | StringRef Operator::() const { |
176 | constexpr auto attr = "extraClassDefinition" ; |
177 | if (def.isValueUnset(FieldName: attr)) |
178 | return {}; |
179 | return def.getValueAsString(FieldName: attr); |
180 | } |
181 | |
182 | const llvm::Record &Operator::getDef() const { return def; } |
183 | |
184 | bool Operator::skipDefaultBuilders() const { |
185 | return def.getValueAsBit(FieldName: "skipDefaultBuilders" ); |
186 | } |
187 | |
188 | auto Operator::result_begin() const -> const_value_iterator { |
189 | return results.begin(); |
190 | } |
191 | |
192 | auto Operator::result_end() const -> const_value_iterator { |
193 | return results.end(); |
194 | } |
195 | |
196 | auto Operator::getResults() const -> const_value_range { |
197 | return {result_begin(), result_end()}; |
198 | } |
199 | |
200 | TypeConstraint Operator::getResultTypeConstraint(int index) const { |
201 | DagInit *results = def.getValueAsDag(FieldName: "results" ); |
202 | return TypeConstraint(cast<DefInit>(Val: results->getArg(Num: index))); |
203 | } |
204 | |
205 | StringRef Operator::getResultName(int index) const { |
206 | DagInit *results = def.getValueAsDag(FieldName: "results" ); |
207 | return results->getArgNameStr(Num: index); |
208 | } |
209 | |
210 | auto Operator::getResultDecorators(int index) const -> var_decorator_range { |
211 | Record *result = |
212 | cast<DefInit>(Val: def.getValueAsDag(FieldName: "results" )->getArg(Num: index))->getDef(); |
213 | if (!result->isSubClassOf(Name: "OpVariable" )) |
214 | return var_decorator_range(nullptr, nullptr); |
215 | return *result->getValueAsListInit(FieldName: "decorators" ); |
216 | } |
217 | |
218 | unsigned Operator::getNumVariableLengthResults() const { |
219 | return llvm::count_if(Range: results, P: [](const NamedTypeConstraint &c) { |
220 | return c.constraint.isVariableLength(); |
221 | }); |
222 | } |
223 | |
224 | unsigned Operator::getNumVariableLengthOperands() const { |
225 | return llvm::count_if(Range: operands, P: [](const NamedTypeConstraint &c) { |
226 | return c.constraint.isVariableLength(); |
227 | }); |
228 | } |
229 | |
230 | bool Operator::hasSingleVariadicArg() const { |
231 | return getNumArgs() == 1 && getArg(index: 0).is<NamedTypeConstraint *>() && |
232 | getOperand(index: 0).isVariadic(); |
233 | } |
234 | |
235 | Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); } |
236 | |
237 | Operator::arg_iterator Operator::arg_end() const { return arguments.end(); } |
238 | |
239 | Operator::arg_range Operator::getArgs() const { |
240 | return {arg_begin(), arg_end()}; |
241 | } |
242 | |
243 | StringRef Operator::getArgName(int index) const { |
244 | DagInit *argumentValues = def.getValueAsDag(FieldName: "arguments" ); |
245 | return argumentValues->getArgNameStr(Num: index); |
246 | } |
247 | |
248 | auto Operator::getArgDecorators(int index) const -> var_decorator_range { |
249 | Record *arg = |
250 | cast<DefInit>(Val: def.getValueAsDag(FieldName: "arguments" )->getArg(Num: index))->getDef(); |
251 | if (!arg->isSubClassOf(Name: "OpVariable" )) |
252 | return var_decorator_range(nullptr, nullptr); |
253 | return *arg->getValueAsListInit(FieldName: "decorators" ); |
254 | } |
255 | |
256 | const Trait *Operator::getTrait(StringRef trait) const { |
257 | for (const auto &t : traits) { |
258 | if (const auto *traitDef = dyn_cast<NativeTrait>(Val: &t)) { |
259 | if (traitDef->getFullyQualifiedTraitName() == trait) |
260 | return traitDef; |
261 | } else if (const auto *traitDef = dyn_cast<InternalTrait>(Val: &t)) { |
262 | if (traitDef->getFullyQualifiedTraitName() == trait) |
263 | return traitDef; |
264 | } else if (const auto *traitDef = dyn_cast<InterfaceTrait>(Val: &t)) { |
265 | if (traitDef->getFullyQualifiedTraitName() == trait) |
266 | return traitDef; |
267 | } |
268 | } |
269 | return nullptr; |
270 | } |
271 | |
272 | auto Operator::region_begin() const -> const_region_iterator { |
273 | return regions.begin(); |
274 | } |
275 | auto Operator::region_end() const -> const_region_iterator { |
276 | return regions.end(); |
277 | } |
278 | auto Operator::getRegions() const |
279 | -> llvm::iterator_range<const_region_iterator> { |
280 | return {region_begin(), region_end()}; |
281 | } |
282 | |
283 | unsigned Operator::getNumRegions() const { return regions.size(); } |
284 | |
285 | const NamedRegion &Operator::getRegion(unsigned index) const { |
286 | return regions[index]; |
287 | } |
288 | |
289 | unsigned Operator::getNumVariadicRegions() const { |
290 | return llvm::count_if(Range: regions, |
291 | P: [](const NamedRegion &c) { return c.isVariadic(); }); |
292 | } |
293 | |
294 | auto Operator::successor_begin() const -> const_successor_iterator { |
295 | return successors.begin(); |
296 | } |
297 | auto Operator::successor_end() const -> const_successor_iterator { |
298 | return successors.end(); |
299 | } |
300 | auto Operator::getSuccessors() const |
301 | -> llvm::iterator_range<const_successor_iterator> { |
302 | return {successor_begin(), successor_end()}; |
303 | } |
304 | |
305 | unsigned Operator::getNumSuccessors() const { return successors.size(); } |
306 | |
307 | const NamedSuccessor &Operator::getSuccessor(unsigned index) const { |
308 | return successors[index]; |
309 | } |
310 | |
311 | unsigned Operator::getNumVariadicSuccessors() const { |
312 | return llvm::count_if(Range: successors, |
313 | P: [](const NamedSuccessor &c) { return c.isVariadic(); }); |
314 | } |
315 | |
316 | auto Operator::trait_begin() const -> const_trait_iterator { |
317 | return traits.begin(); |
318 | } |
319 | auto Operator::trait_end() const -> const_trait_iterator { |
320 | return traits.end(); |
321 | } |
322 | auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> { |
323 | return {trait_begin(), trait_end()}; |
324 | } |
325 | |
326 | auto Operator::attribute_begin() const -> const_attribute_iterator { |
327 | return attributes.begin(); |
328 | } |
329 | auto Operator::attribute_end() const -> const_attribute_iterator { |
330 | return attributes.end(); |
331 | } |
332 | auto Operator::getAttributes() const |
333 | -> llvm::iterator_range<const_attribute_iterator> { |
334 | return {attribute_begin(), attribute_end()}; |
335 | } |
336 | auto Operator::attribute_begin() -> attribute_iterator { |
337 | return attributes.begin(); |
338 | } |
339 | auto Operator::attribute_end() -> attribute_iterator { |
340 | return attributes.end(); |
341 | } |
342 | auto Operator::getAttributes() -> llvm::iterator_range<attribute_iterator> { |
343 | return {attribute_begin(), attribute_end()}; |
344 | } |
345 | |
346 | auto Operator::operand_begin() const -> const_value_iterator { |
347 | return operands.begin(); |
348 | } |
349 | auto Operator::operand_end() const -> const_value_iterator { |
350 | return operands.end(); |
351 | } |
352 | auto Operator::getOperands() const -> const_value_range { |
353 | return {operand_begin(), operand_end()}; |
354 | } |
355 | |
356 | auto Operator::getArg(int index) const -> Argument { return arguments[index]; } |
357 | |
358 | bool Operator::isVariadic() const { |
359 | return any_of(Range: llvm::concat<const NamedTypeConstraint>(Ranges: operands, Ranges: results), |
360 | P: [](const NamedTypeConstraint &op) { return op.isVariadic(); }); |
361 | } |
362 | |
363 | void Operator::populateTypeInferenceInfo( |
364 | const llvm::StringMap<int> &argumentsAndResultsIndex) { |
365 | // If the type inference op interface is not registered, then do not attempt |
366 | // to determine if the result types an be inferred. |
367 | auto &recordKeeper = def.getRecords(); |
368 | auto *inferTrait = recordKeeper.getDef(Name: inferTypeOpInterface); |
369 | allResultsHaveKnownTypes = false; |
370 | if (!inferTrait) |
371 | return; |
372 | |
373 | // If there are no results, the skip this else the build method generated |
374 | // overlaps with another autogenerated builder. |
375 | if (getNumResults() == 0) |
376 | return; |
377 | |
378 | // Skip ops with variadic or optional results. |
379 | if (getNumVariableLengthResults() > 0) |
380 | return; |
381 | |
382 | // Skip cases currently being custom generated. |
383 | // TODO: Remove special cases. |
384 | if (getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType" )) { |
385 | // Check for a non-variable length operand to use as the type anchor. |
386 | auto *operandI = llvm::find_if(Range&: arguments, P: [](const Argument &arg) { |
387 | NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val: arg); |
388 | return operand && !operand->isVariableLength(); |
389 | }); |
390 | if (operandI == arguments.end()) |
391 | return; |
392 | |
393 | // All result types are inferred from the operand type. |
394 | int operandIdx = operandI - arguments.begin(); |
395 | for (int i = 0; i < getNumResults(); ++i) |
396 | resultTypeMapping.emplace_back(Args&: operandIdx, Args: "$_self" ); |
397 | |
398 | allResultsHaveKnownTypes = true; |
399 | traits.push_back(Elt: Trait::create(init: inferTrait->getDefInit())); |
400 | return; |
401 | } |
402 | |
403 | /// This struct represents a node in this operation's result type inferenece |
404 | /// graph. Each node has a list of incoming type inference edges `sources`. |
405 | /// Each edge represents a "source" from which the result type can be |
406 | /// inferred, either an operand (leaf) or another result (node). When a node |
407 | /// is known to have a fully-inferred type, `inferred` is set to true. |
408 | struct ResultTypeInference { |
409 | /// The list of incoming type inference edges. |
410 | SmallVector<InferredResultType> sources; |
411 | /// This flag is set to true when the result type is known to be inferrable. |
412 | bool inferred = false; |
413 | }; |
414 | |
415 | // This vector represents the type inference graph, with one node for each |
416 | // operation result. The nth element is the node for the nth result. |
417 | SmallVector<ResultTypeInference> inference(getNumResults(), {}); |
418 | |
419 | // For all results whose types are buildable, initialize their type inference |
420 | // nodes with an edge to themselves. Mark those nodes are fully-inferred. |
421 | for (auto [idx, infer] : llvm::enumerate(First&: inference)) { |
422 | if (getResult(index: idx).constraint.getBuilderCall()) { |
423 | infer.sources.emplace_back(Args: InferredResultType::mapResultIndex(i: idx), |
424 | Args: "$_self" ); |
425 | infer.inferred = true; |
426 | } |
427 | } |
428 | |
429 | // Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the |
430 | // result type inference graph. |
431 | for (const Trait &trait : traits) { |
432 | const llvm::Record &def = trait.getDef(); |
433 | |
434 | // If the infer type op interface was manually added, then treat it as |
435 | // intention that the op needs special handling. |
436 | // TODO: Reconsider whether to always generate, this is more conservative |
437 | // and keeps existing behavior so starting that way for now. |
438 | if (def.isSubClassOf( |
439 | Name: llvm::formatv(Fmt: "{0}::Trait" , Vals&: inferTypeOpInterface).str())) |
440 | return; |
441 | if (const auto *traitDef = dyn_cast<InterfaceTrait>(Val: &trait)) |
442 | if (&traitDef->getDef() == inferTrait) |
443 | return; |
444 | |
445 | // The `TypesMatchWith` trait represents a 1 -> 1 type inference edge with a |
446 | // type transformer. |
447 | if (def.isSubClassOf(Name: "TypesMatchWith" )) { |
448 | int target = argumentsAndResultsIndex.lookup(Key: def.getValueAsString(FieldName: "rhs" )); |
449 | // Ignore operand type inference. |
450 | if (InferredResultType::isArgIndex(i: target)) |
451 | continue; |
452 | int resultIndex = InferredResultType::unmapResultIndex(i: target); |
453 | ResultTypeInference &infer = inference[resultIndex]; |
454 | // If the type of the result has already been inferred, do nothing. |
455 | if (infer.inferred) |
456 | continue; |
457 | int sourceIndex = |
458 | argumentsAndResultsIndex.lookup(Key: def.getValueAsString(FieldName: "lhs" )); |
459 | infer.sources.emplace_back(Args&: sourceIndex, |
460 | Args: def.getValueAsString(FieldName: "transformer" ).str()); |
461 | // Locally propagate inferredness. |
462 | infer.inferred = |
463 | InferredResultType::isArgIndex(i: sourceIndex) || |
464 | inference[InferredResultType::unmapResultIndex(i: sourceIndex)].inferred; |
465 | continue; |
466 | } |
467 | |
468 | if (!def.isSubClassOf(Name: "AllTypesMatch" )) |
469 | continue; |
470 | |
471 | auto values = def.getValueAsListOfStrings(FieldName: "values" ); |
472 | // The `AllTypesMatch` trait represents an N <-> N fanin and fanout. That |
473 | // is, every result type has an edge from every other type. However, if any |
474 | // one of the values refers to an operand or a result with a fully-inferred |
475 | // type, we can infer all other types from that value. Try to find a |
476 | // fully-inferred type in the list. |
477 | std::optional<int> fullyInferredIndex; |
478 | SmallVector<int> resultIndices; |
479 | for (StringRef name : values) { |
480 | int index = argumentsAndResultsIndex.lookup(Key: name); |
481 | if (InferredResultType::isResultIndex(i: index)) |
482 | resultIndices.push_back(Elt: InferredResultType::unmapResultIndex(i: index)); |
483 | if (InferredResultType::isArgIndex(i: index) || |
484 | inference[InferredResultType::unmapResultIndex(i: index)].inferred) |
485 | fullyInferredIndex = index; |
486 | } |
487 | if (fullyInferredIndex) { |
488 | // Make the fully-inferred type the only source for all results that |
489 | // aren't already inferred -- a 1 -> N fanout. |
490 | for (int resultIndex : resultIndices) { |
491 | ResultTypeInference &infer = inference[resultIndex]; |
492 | if (!infer.inferred) { |
493 | infer.sources.assign(NumElts: 1, Elt: {*fullyInferredIndex, "$_self" }); |
494 | infer.inferred = true; |
495 | } |
496 | } |
497 | } else { |
498 | // Add an edge between every result and every other type; N <-> N. |
499 | for (int resultIndex : resultIndices) { |
500 | for (int otherResultIndex : resultIndices) { |
501 | if (resultIndex == otherResultIndex) |
502 | continue; |
503 | inference[resultIndex].sources.emplace_back(Args&: otherResultIndex, |
504 | Args: "$_self" ); |
505 | } |
506 | } |
507 | } |
508 | } |
509 | |
510 | // Propagate inferredness until a fixed point. |
511 | std::vector<ResultTypeInference *> worklist; |
512 | for (ResultTypeInference &infer : inference) |
513 | if (!infer.inferred) |
514 | worklist.push_back(x: &infer); |
515 | bool changed; |
516 | do { |
517 | changed = false; |
518 | for (auto cur = worklist.begin(); cur != worklist.end();) { |
519 | ResultTypeInference &infer = **cur; |
520 | |
521 | InferredResultType *iter = |
522 | llvm::find_if(Range&: infer.sources, P: [&](const InferredResultType &source) { |
523 | assert(InferredResultType::isResultIndex(source.getIndex())); |
524 | return inference[InferredResultType::unmapResultIndex( |
525 | i: source.getIndex())] |
526 | .inferred; |
527 | }); |
528 | if (iter == infer.sources.end()) { |
529 | ++cur; |
530 | continue; |
531 | } |
532 | |
533 | changed = true; |
534 | infer.inferred = true; |
535 | // Make this the only source for the result. This breaks any cycles. |
536 | infer.sources.assign(NumElts: 1, Elt: *iter); |
537 | cur = worklist.erase(position: cur); |
538 | } |
539 | } while (changed); |
540 | |
541 | allResultsHaveKnownTypes = worklist.empty(); |
542 | |
543 | // If the types could be computed, then add type inference trait. |
544 | if (allResultsHaveKnownTypes) { |
545 | traits.push_back(Elt: Trait::create(init: inferTrait->getDefInit())); |
546 | for (const ResultTypeInference &infer : inference) |
547 | resultTypeMapping.push_back(Elt: infer.sources.front()); |
548 | } |
549 | } |
550 | |
551 | void Operator::populateOpStructure() { |
552 | auto &recordKeeper = def.getRecords(); |
553 | auto *typeConstraintClass = recordKeeper.getClass(Name: "TypeConstraint" ); |
554 | auto *attrClass = recordKeeper.getClass(Name: "Attr" ); |
555 | auto *propertyClass = recordKeeper.getClass(Name: "Property" ); |
556 | auto *derivedAttrClass = recordKeeper.getClass(Name: "DerivedAttr" ); |
557 | auto *opVarClass = recordKeeper.getClass(Name: "OpVariable" ); |
558 | numNativeAttributes = 0; |
559 | |
560 | DagInit *argumentValues = def.getValueAsDag(FieldName: "arguments" ); |
561 | unsigned numArgs = argumentValues->getNumArgs(); |
562 | |
563 | // Mapping from name of to argument or result index. Arguments are indexed |
564 | // to match getArg index, while the results are negatively indexed. |
565 | llvm::StringMap<int> argumentsAndResultsIndex; |
566 | |
567 | // Handle operands and native attributes. |
568 | for (unsigned i = 0; i != numArgs; ++i) { |
569 | auto *arg = argumentValues->getArg(Num: i); |
570 | auto givenName = argumentValues->getArgNameStr(Num: i); |
571 | auto *argDefInit = dyn_cast<DefInit>(Val: arg); |
572 | if (!argDefInit) |
573 | PrintFatalError(ErrorLoc: def.getLoc(), |
574 | Msg: Twine("undefined type for argument #" ) + Twine(i)); |
575 | Record *argDef = argDefInit->getDef(); |
576 | if (argDef->isSubClassOf(R: opVarClass)) |
577 | argDef = argDef->getValueAsDef(FieldName: "constraint" ); |
578 | |
579 | if (argDef->isSubClassOf(R: typeConstraintClass)) { |
580 | operands.push_back( |
581 | Elt: NamedTypeConstraint{.name: givenName, .constraint: TypeConstraint(argDef)}); |
582 | } else if (argDef->isSubClassOf(R: attrClass)) { |
583 | if (givenName.empty()) |
584 | PrintFatalError(ErrorLoc: argDef->getLoc(), Msg: "attributes must be named" ); |
585 | if (argDef->isSubClassOf(R: derivedAttrClass)) |
586 | PrintFatalError(ErrorLoc: argDef->getLoc(), |
587 | Msg: "derived attributes not allowed in argument list" ); |
588 | attributes.push_back(Elt: {.name: givenName, .attr: Attribute(argDef)}); |
589 | ++numNativeAttributes; |
590 | } else if (argDef->isSubClassOf(R: propertyClass)) { |
591 | if (givenName.empty()) |
592 | PrintFatalError(ErrorLoc: argDef->getLoc(), Msg: "properties must be named" ); |
593 | properties.push_back(Elt: {.name: givenName, .prop: Property(argDef)}); |
594 | } else { |
595 | PrintFatalError(ErrorLoc: def.getLoc(), |
596 | Msg: "unexpected def type; only defs deriving " |
597 | "from TypeConstraint or Attr or Property are allowed" ); |
598 | } |
599 | if (!givenName.empty()) |
600 | argumentsAndResultsIndex[givenName] = i; |
601 | } |
602 | |
603 | // Handle derived attributes. |
604 | for (const auto &val : def.getValues()) { |
605 | if (auto *record = dyn_cast<llvm::RecordRecTy>(Val: val.getType())) { |
606 | if (!record->isSubClassOf(Class: attrClass)) |
607 | continue; |
608 | if (!record->isSubClassOf(Class: derivedAttrClass)) |
609 | PrintFatalError(ErrorLoc: def.getLoc(), |
610 | Msg: "unexpected Attr where only DerivedAttr is allowed" ); |
611 | |
612 | if (record->getClasses().size() != 1) { |
613 | PrintFatalError( |
614 | ErrorLoc: def.getLoc(), |
615 | Msg: "unsupported attribute modelling, only single class expected" ); |
616 | } |
617 | attributes.push_back( |
618 | Elt: {.name: cast<llvm::StringInit>(Val: val.getNameInit())->getValue(), |
619 | .attr: Attribute(cast<DefInit>(Val: val.getValue()))}); |
620 | } |
621 | } |
622 | |
623 | // Populate `arguments`. This must happen after we've finalized `operands` and |
624 | // `attributes` because we will put their elements' pointers in `arguments`. |
625 | // SmallVector may perform re-allocation under the hood when adding new |
626 | // elements. |
627 | int operandIndex = 0, attrIndex = 0, propIndex = 0; |
628 | for (unsigned i = 0; i != numArgs; ++i) { |
629 | Record *argDef = dyn_cast<DefInit>(Val: argumentValues->getArg(Num: i))->getDef(); |
630 | if (argDef->isSubClassOf(R: opVarClass)) |
631 | argDef = argDef->getValueAsDef(FieldName: "constraint" ); |
632 | |
633 | if (argDef->isSubClassOf(R: typeConstraintClass)) { |
634 | attrOrOperandMapping.push_back( |
635 | Elt: {OperandOrAttribute::Kind::Operand, operandIndex}); |
636 | arguments.emplace_back(Args: &operands[operandIndex++]); |
637 | } else if (argDef->isSubClassOf(R: attrClass)) { |
638 | attrOrOperandMapping.push_back( |
639 | Elt: {OperandOrAttribute::Kind::Attribute, attrIndex}); |
640 | arguments.emplace_back(Args: &attributes[attrIndex++]); |
641 | } else { |
642 | assert(argDef->isSubClassOf(propertyClass)); |
643 | arguments.emplace_back(Args: &properties[propIndex++]); |
644 | } |
645 | } |
646 | |
647 | auto *resultsDag = def.getValueAsDag(FieldName: "results" ); |
648 | auto *outsOp = dyn_cast<DefInit>(Val: resultsDag->getOperator()); |
649 | if (!outsOp || outsOp->getDef()->getName() != "outs" ) { |
650 | PrintFatalError(ErrorLoc: def.getLoc(), Msg: "'results' must have 'outs' directive" ); |
651 | } |
652 | |
653 | // Handle results. |
654 | for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) { |
655 | auto name = resultsDag->getArgNameStr(Num: i); |
656 | auto *resultInit = dyn_cast<DefInit>(Val: resultsDag->getArg(Num: i)); |
657 | if (!resultInit) { |
658 | PrintFatalError(ErrorLoc: def.getLoc(), |
659 | Msg: Twine("undefined type for result #" ) + Twine(i)); |
660 | } |
661 | auto *resultDef = resultInit->getDef(); |
662 | if (resultDef->isSubClassOf(R: opVarClass)) |
663 | resultDef = resultDef->getValueAsDef(FieldName: "constraint" ); |
664 | results.push_back(Elt: {.name: name, .constraint: TypeConstraint(resultDef)}); |
665 | if (!name.empty()) |
666 | argumentsAndResultsIndex[name] = InferredResultType::mapResultIndex(i); |
667 | |
668 | // We currently only support VariadicOfVariadic operands. |
669 | if (results.back().constraint.isVariadicOfVariadic()) { |
670 | PrintFatalError( |
671 | ErrorLoc: def.getLoc(), |
672 | Msg: "'VariadicOfVariadic' results are currently not supported" ); |
673 | } |
674 | } |
675 | |
676 | // Handle successors |
677 | auto *successorsDag = def.getValueAsDag(FieldName: "successors" ); |
678 | auto *successorsOp = dyn_cast<DefInit>(Val: successorsDag->getOperator()); |
679 | if (!successorsOp || successorsOp->getDef()->getName() != "successor" ) { |
680 | PrintFatalError(ErrorLoc: def.getLoc(), |
681 | Msg: "'successors' must have 'successor' directive" ); |
682 | } |
683 | |
684 | for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) { |
685 | auto name = successorsDag->getArgNameStr(Num: i); |
686 | auto *successorInit = dyn_cast<DefInit>(Val: successorsDag->getArg(Num: i)); |
687 | if (!successorInit) { |
688 | PrintFatalError(ErrorLoc: def.getLoc(), |
689 | Msg: Twine("undefined kind for successor #" ) + Twine(i)); |
690 | } |
691 | Successor successor(successorInit->getDef()); |
692 | |
693 | // Only support variadic successors if it is the last one for now. |
694 | if (i != e - 1 && successor.isVariadic()) |
695 | PrintFatalError(ErrorLoc: def.getLoc(), Msg: "only the last successor can be variadic" ); |
696 | successors.push_back(Elt: {.name: name, .constraint: successor}); |
697 | } |
698 | |
699 | // Create list of traits, skipping over duplicates: appending to lists in |
700 | // tablegen is easy, making them unique less so, so dedupe here. |
701 | if (auto *traitList = def.getValueAsListInit(FieldName: "traits" )) { |
702 | // This is uniquing based on pointers of the trait. |
703 | SmallPtrSet<const llvm::Init *, 32> traitSet; |
704 | traits.reserve(N: traitSet.size()); |
705 | |
706 | // The declaration order of traits imply the verification order of traits. |
707 | // Some traits may require other traits to be verified first then they can |
708 | // do further verification based on those verified facts. If you see this |
709 | // error, fix the traits declaration order by checking the `dependentTraits` |
710 | // field. |
711 | auto verifyTraitValidity = [&](Record *trait) { |
712 | auto *dependentTraits = trait->getValueAsListInit(FieldName: "dependentTraits" ); |
713 | for (auto *traitInit : *dependentTraits) |
714 | if (!traitSet.contains(Ptr: traitInit)) |
715 | PrintFatalError( |
716 | ErrorLoc: def.getLoc(), |
717 | Msg: trait->getValueAsString(FieldName: "trait" ) + " requires " + |
718 | cast<DefInit>(Val: traitInit)->getDef()->getValueAsString( |
719 | FieldName: "trait" ) + |
720 | " to precede it in traits list" ); |
721 | }; |
722 | |
723 | std::function<void(llvm::ListInit *)> insert; |
724 | insert = [&](llvm::ListInit *traitList) { |
725 | for (auto *traitInit : *traitList) { |
726 | auto *def = cast<DefInit>(Val: traitInit)->getDef(); |
727 | if (def->isSubClassOf(Name: "TraitList" )) { |
728 | insert(def->getValueAsListInit(FieldName: "traits" )); |
729 | continue; |
730 | } |
731 | |
732 | // Ignore duplicates. |
733 | if (!traitSet.insert(Ptr: traitInit).second) |
734 | continue; |
735 | |
736 | // If this is an interface with base classes, add the bases to the |
737 | // trait list. |
738 | if (def->isSubClassOf(Name: "Interface" )) |
739 | insert(def->getValueAsListInit(FieldName: "baseInterfaces" )); |
740 | |
741 | // Verify if the trait has all the dependent traits declared before |
742 | // itself. |
743 | verifyTraitValidity(def); |
744 | traits.push_back(Elt: Trait::create(init: traitInit)); |
745 | } |
746 | }; |
747 | insert(traitList); |
748 | } |
749 | |
750 | populateTypeInferenceInfo(argumentsAndResultsIndex); |
751 | |
752 | // Handle regions |
753 | auto *regionsDag = def.getValueAsDag(FieldName: "regions" ); |
754 | auto *regionsOp = dyn_cast<DefInit>(Val: regionsDag->getOperator()); |
755 | if (!regionsOp || regionsOp->getDef()->getName() != "region" ) { |
756 | PrintFatalError(ErrorLoc: def.getLoc(), Msg: "'regions' must have 'region' directive" ); |
757 | } |
758 | |
759 | for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) { |
760 | auto name = regionsDag->getArgNameStr(Num: i); |
761 | auto *regionInit = dyn_cast<DefInit>(Val: regionsDag->getArg(Num: i)); |
762 | if (!regionInit) { |
763 | PrintFatalError(ErrorLoc: def.getLoc(), |
764 | Msg: Twine("undefined kind for region #" ) + Twine(i)); |
765 | } |
766 | Region region(regionInit->getDef()); |
767 | if (region.isVariadic()) { |
768 | // Only support variadic regions if it is the last one for now. |
769 | if (i != e - 1) |
770 | PrintFatalError(ErrorLoc: def.getLoc(), Msg: "only the last region can be variadic" ); |
771 | if (name.empty()) |
772 | PrintFatalError(ErrorLoc: def.getLoc(), Msg: "variadic regions must be named" ); |
773 | } |
774 | |
775 | regions.push_back(Elt: {.name: name, .constraint: region}); |
776 | } |
777 | |
778 | // Populate the builders. |
779 | auto *builderList = |
780 | dyn_cast_or_null<llvm::ListInit>(Val: def.getValueInit(FieldName: "builders" )); |
781 | if (builderList && !builderList->empty()) { |
782 | for (llvm::Init *init : builderList->getValues()) |
783 | builders.emplace_back(Args: cast<llvm::DefInit>(Val: init)->getDef(), Args: def.getLoc()); |
784 | } else if (skipDefaultBuilders()) { |
785 | PrintFatalError( |
786 | ErrorLoc: def.getLoc(), |
787 | Msg: "default builders are skipped and no custom builders provided" ); |
788 | } |
789 | |
790 | LLVM_DEBUG(print(llvm::dbgs())); |
791 | } |
792 | |
793 | const InferredResultType &Operator::getInferredResultType(int index) const { |
794 | assert(allResultTypesKnown()); |
795 | return resultTypeMapping[index]; |
796 | } |
797 | |
798 | ArrayRef<SMLoc> Operator::getLoc() const { return def.getLoc(); } |
799 | |
800 | bool Operator::hasDescription() const { |
801 | return def.getValue(Name: "description" ) != nullptr; |
802 | } |
803 | |
804 | StringRef Operator::getDescription() const { |
805 | return def.getValueAsString(FieldName: "description" ); |
806 | } |
807 | |
808 | bool Operator::hasSummary() const { return def.getValue(Name: "summary" ) != nullptr; } |
809 | |
810 | StringRef Operator::getSummary() const { |
811 | return def.getValueAsString(FieldName: "summary" ); |
812 | } |
813 | |
814 | bool Operator::hasAssemblyFormat() const { |
815 | auto *valueInit = def.getValueInit(FieldName: "assemblyFormat" ); |
816 | return isa<llvm::StringInit>(Val: valueInit); |
817 | } |
818 | |
819 | StringRef Operator::getAssemblyFormat() const { |
820 | return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit(FieldName: "assemblyFormat" )) |
821 | .Case<llvm::StringInit>(caseFn: [&](auto *init) { return init->getValue(); }); |
822 | } |
823 | |
824 | void Operator::print(llvm::raw_ostream &os) const { |
825 | os << "op '" << getOperationName() << "'\n" ; |
826 | for (Argument arg : arguments) { |
827 | if (auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(Val&: arg)) |
828 | os << "[attribute] " << attr->name << '\n'; |
829 | else |
830 | os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n'; |
831 | } |
832 | } |
833 | |
834 | auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init) |
835 | -> VariableDecorator { |
836 | return VariableDecorator(cast<llvm::DefInit>(Val: init)->getDef()); |
837 | } |
838 | |
839 | auto Operator::getArgToOperandOrAttribute(int index) const |
840 | -> OperandOrAttribute { |
841 | return attrOrOperandMapping[index]; |
842 | } |
843 | |
844 | std::string Operator::getGetterName(StringRef name) const { |
845 | return "get" + convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true); |
846 | } |
847 | |
848 | std::string Operator::getSetterName(StringRef name) const { |
849 | return "set" + convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true); |
850 | } |
851 | |
852 | std::string Operator::getRemoverName(StringRef name) const { |
853 | return "remove" + convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true); |
854 | } |
855 | |
856 | bool Operator::hasFolder() const { return def.getValueAsBit(FieldName: "hasFolder" ); } |
857 | |
858 | bool Operator::useCustomPropertiesEncoding() const { |
859 | return def.getValueAsBit(FieldName: "useCustomPropertiesEncoding" ); |
860 | } |
861 | |