1//===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "OpFormatGen.h"
10#include "FormatGen.h"
11#include "OpClass.h"
12#include "mlir/Support/LLVM.h"
13#include "mlir/TableGen/Class.h"
14#include "mlir/TableGen/Format.h"
15#include "mlir/TableGen/Operator.h"
16#include "mlir/TableGen/Trait.h"
17#include "llvm/ADT/MapVector.h"
18#include "llvm/ADT/Sequence.h"
19#include "llvm/ADT/SetVector.h"
20#include "llvm/ADT/SmallBitVector.h"
21#include "llvm/ADT/StringExtras.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/Signals.h"
24#include "llvm/Support/SourceMgr.h"
25#include "llvm/TableGen/Record.h"
26
27#define DEBUG_TYPE "mlir-tblgen-opformatgen"
28
29using namespace mlir;
30using namespace mlir::tblgen;
31
32//===----------------------------------------------------------------------===//
33// VariableElement
34
35namespace {
36/// This class represents an instance of an op variable element. A variable
37/// refers to something registered on the operation itself, e.g. an operand,
38/// result, attribute, region, or successor.
39template <typename VarT, VariableElement::Kind VariableKind>
40class OpVariableElement : public VariableElementBase<VariableKind> {
41public:
42 using Base = OpVariableElement<VarT, VariableKind>;
43
44 /// Create an op variable element with the variable value.
45 OpVariableElement(const VarT *var) : var(var) {}
46
47 /// Get the variable.
48 const VarT *getVar() { return var; }
49
50protected:
51 /// The op variable, e.g. a type or attribute constraint.
52 const VarT *var;
53};
54
55/// This class represents a variable that refers to an attribute argument.
56struct AttributeVariable
57 : public OpVariableElement<NamedAttribute, VariableElement::Attribute> {
58 using Base::Base;
59
60 /// Return the constant builder call for the type of this attribute, or
61 /// std::nullopt if it doesn't have one.
62 std::optional<StringRef> getTypeBuilder() const {
63 std::optional<Type> attrType = var->attr.getValueType();
64 return attrType ? attrType->getBuilderCall() : std::nullopt;
65 }
66
67 /// Return if this attribute refers to a UnitAttr.
68 bool isUnitAttr() const {
69 return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
70 }
71
72 /// Indicate if this attribute is printed "qualified" (that is it is
73 /// prefixed with the `#dialect.mnemonic`).
74 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
75 void setShouldBeQualified(bool qualified = true) {
76 shouldBeQualifiedFlag = qualified;
77 }
78
79private:
80 bool shouldBeQualifiedFlag = false;
81};
82
83/// This class represents a variable that refers to an operand argument.
84using OperandVariable =
85 OpVariableElement<NamedTypeConstraint, VariableElement::Operand>;
86
87/// This class represents a variable that refers to a result.
88using ResultVariable =
89 OpVariableElement<NamedTypeConstraint, VariableElement::Result>;
90
91/// This class represents a variable that refers to a region.
92using RegionVariable = OpVariableElement<NamedRegion, VariableElement::Region>;
93
94/// This class represents a variable that refers to a successor.
95using SuccessorVariable =
96 OpVariableElement<NamedSuccessor, VariableElement::Successor>;
97
98/// This class represents a variable that refers to a property argument.
99using PropertyVariable =
100 OpVariableElement<NamedProperty, VariableElement::Property>;
101} // namespace
102
103//===----------------------------------------------------------------------===//
104// DirectiveElement
105
106namespace {
107/// This class represents the `operands` directive. This directive represents
108/// all of the operands of an operation.
109using OperandsDirective = DirectiveElementBase<DirectiveElement::Operands>;
110
111/// This class represents the `results` directive. This directive represents
112/// all of the results of an operation.
113using ResultsDirective = DirectiveElementBase<DirectiveElement::Results>;
114
115/// This class represents the `regions` directive. This directive represents
116/// all of the regions of an operation.
117using RegionsDirective = DirectiveElementBase<DirectiveElement::Regions>;
118
119/// This class represents the `successors` directive. This directive represents
120/// all of the successors of an operation.
121using SuccessorsDirective = DirectiveElementBase<DirectiveElement::Successors>;
122
123/// This class represents the `attr-dict` directive. This directive represents
124/// the attribute dictionary of the operation.
125class AttrDictDirective
126 : public DirectiveElementBase<DirectiveElement::AttrDict> {
127public:
128 explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
129
130 /// Return whether the dictionary should be printed with the 'attributes'
131 /// keyword.
132 bool isWithKeyword() const { return withKeyword; }
133
134private:
135 /// If the dictionary should be printed with the 'attributes' keyword.
136 bool withKeyword;
137};
138
139/// This class represents the `prop-dict` directive. This directive represents
140/// the properties of the operation, expressed as a directionary.
141class PropDictDirective
142 : public DirectiveElementBase<DirectiveElement::PropDict> {
143public:
144 explicit PropDictDirective() = default;
145};
146
147/// This class represents the `functional-type` directive. This directive takes
148/// two arguments and formats them, respectively, as the inputs and results of a
149/// FunctionType.
150class FunctionalTypeDirective
151 : public DirectiveElementBase<DirectiveElement::FunctionalType> {
152public:
153 FunctionalTypeDirective(FormatElement *inputs, FormatElement *results)
154 : inputs(inputs), results(results) {}
155
156 FormatElement *getInputs() const { return inputs; }
157 FormatElement *getResults() const { return results; }
158
159private:
160 /// The input and result arguments.
161 FormatElement *inputs, *results;
162};
163
164/// This class represents the `type` directive.
165class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
166public:
167 TypeDirective(FormatElement *arg) : arg(arg) {}
168
169 FormatElement *getArg() const { return arg; }
170
171 /// Indicate if this type is printed "qualified" (that is it is
172 /// prefixed with the `!dialect.mnemonic`).
173 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
174 void setShouldBeQualified(bool qualified = true) {
175 shouldBeQualifiedFlag = qualified;
176 }
177
178private:
179 /// The argument that is used to format the directive.
180 FormatElement *arg;
181
182 bool shouldBeQualifiedFlag = false;
183};
184
185/// This class represents a group of order-independent optional clauses. Each
186/// clause starts with a literal element and has a coressponding parsing
187/// element. A parsing element is a continous sequence of format elements.
188/// Each clause can appear 0 or 1 time.
189class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> {
190public:
191 OIListElement(std::vector<FormatElement *> &&literalElements,
192 std::vector<std::vector<FormatElement *>> &&parsingElements)
193 : literalElements(std::move(literalElements)),
194 parsingElements(std::move(parsingElements)) {}
195
196 /// Returns a range to iterate over the LiteralElements.
197 auto getLiteralElements() const {
198 function_ref<LiteralElement *(FormatElement * el)>
199 literalElementCastConverter =
200 [](FormatElement *el) { return cast<LiteralElement>(Val: el); };
201 return llvm::map_range(C: literalElements, F: literalElementCastConverter);
202 }
203
204 /// Returns a range to iterate over the parsing elements corresponding to the
205 /// clauses.
206 ArrayRef<std::vector<FormatElement *>> getParsingElements() const {
207 return parsingElements;
208 }
209
210 /// Returns a range to iterate over tuples of parsing and literal elements.
211 auto getClauses() const {
212 return llvm::zip(t: getLiteralElements(), u: getParsingElements());
213 }
214
215 /// If the parsing element is a single UnitAttr element, then it returns the
216 /// attribute variable. Otherwise, returns nullptr.
217 AttributeVariable *
218 getUnitAttrParsingElement(ArrayRef<FormatElement *> pelement) {
219 if (pelement.size() == 1) {
220 auto *attrElem = dyn_cast<AttributeVariable>(Val: pelement[0]);
221 if (attrElem && attrElem->isUnitAttr())
222 return attrElem;
223 }
224 return nullptr;
225 }
226
227private:
228 /// A vector of `LiteralElement` objects. Each element stores the keyword
229 /// for one case of oilist element. For example, an oilist element along with
230 /// the `literalElements` vector:
231 /// ```
232 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
233 /// literalElements = { `keyword`, `otherKeyword` }
234 /// ```
235 std::vector<FormatElement *> literalElements;
236
237 /// A vector of valid declarative assembly format vectors. Each object in
238 /// parsing elements is a vector of elements in assembly format syntax.
239 /// For example, an oilist element along with the parsingElements vector:
240 /// ```
241 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
242 /// parsingElements = {
243 /// { `=`, `(`, $arg0, `)` },
244 /// { `<`, $arg1, `>` }
245 /// }
246 /// ```
247 std::vector<std::vector<FormatElement *>> parsingElements;
248};
249} // namespace
250
251//===----------------------------------------------------------------------===//
252// OperationFormat
253//===----------------------------------------------------------------------===//
254
255namespace {
256
257using ConstArgument =
258 llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
259
260struct OperationFormat {
261 /// This class represents a specific resolver for an operand or result type.
262 class TypeResolution {
263 public:
264 TypeResolution() = default;
265
266 /// Get the index into the buildable types for this type, or std::nullopt.
267 std::optional<int> getBuilderIdx() const { return builderIdx; }
268 void setBuilderIdx(int idx) { builderIdx = idx; }
269
270 /// Get the variable this type is resolved to, or nullptr.
271 const NamedTypeConstraint *getVariable() const {
272 return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(Val: resolver);
273 }
274 /// Get the attribute this type is resolved to, or nullptr.
275 const NamedAttribute *getAttribute() const {
276 return llvm::dyn_cast_if_present<const NamedAttribute *>(Val: resolver);
277 }
278 /// Get the transformer for the type of the variable, or std::nullopt.
279 std::optional<StringRef> getVarTransformer() const {
280 return variableTransformer;
281 }
282 void setResolver(ConstArgument arg, std::optional<StringRef> transformer) {
283 resolver = arg;
284 variableTransformer = transformer;
285 assert(getVariable() || getAttribute());
286 }
287
288 private:
289 /// If the type is resolved with a buildable type, this is the index into
290 /// 'buildableTypes' in the parent format.
291 std::optional<int> builderIdx;
292 /// If the type is resolved based upon another operand or result, this is
293 /// the variable or the attribute that this type is resolved to.
294 ConstArgument resolver;
295 /// If the type is resolved based upon another operand or result, this is
296 /// a transformer to apply to the variable when resolving.
297 std::optional<StringRef> variableTransformer;
298 };
299
300 /// The context in which an element is generated.
301 enum class GenContext {
302 /// The element is generated at the top-level or with the same behaviour.
303 Normal,
304 /// The element is generated inside an optional group.
305 Optional
306 };
307
308 OperationFormat(const Operator &op)
309 : useProperties(op.getDialect().usePropertiesForAttributes() &&
310 !op.getAttributes().empty()),
311 opCppClassName(op.getCppClassName()) {
312 operandTypes.resize(new_size: op.getNumOperands(), x: TypeResolution());
313 resultTypes.resize(new_size: op.getNumResults(), x: TypeResolution());
314
315 hasImplicitTermTrait = llvm::any_of(Range: op.getTraits(), P: [](const Trait &trait) {
316 return trait.getDef().isSubClassOf(Name: "SingleBlockImplicitTerminatorImpl");
317 });
318
319 hasSingleBlockTrait = op.getTrait(trait: "::mlir::OpTrait::SingleBlock");
320 }
321
322 /// Generate the operation parser from this format.
323 void genParser(Operator &op, OpClass &opClass);
324 /// Generate the parser code for a specific format element.
325 void genElementParser(FormatElement *element, MethodBody &body,
326 FmtContext &attrTypeCtx,
327 GenContext genCtx = GenContext::Normal);
328 /// Generate the C++ to resolve the types of operands and results during
329 /// parsing.
330 void genParserTypeResolution(Operator &op, MethodBody &body);
331 /// Generate the C++ to resolve the types of the operands during parsing.
332 void genParserOperandTypeResolution(
333 Operator &op, MethodBody &body,
334 function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
335 /// Generate the C++ to resolve regions during parsing.
336 void genParserRegionResolution(Operator &op, MethodBody &body);
337 /// Generate the C++ to resolve successors during parsing.
338 void genParserSuccessorResolution(Operator &op, MethodBody &body);
339 /// Generate the C++ to handling variadic segment size traits.
340 void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
341
342 /// Generate the operation printer from this format.
343 void genPrinter(Operator &op, OpClass &opClass);
344
345 /// Generate the printer code for a specific format element.
346 void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op,
347 bool &shouldEmitSpace, bool &lastWasPunctuation);
348
349 /// The various elements in this format.
350 std::vector<FormatElement *> elements;
351
352 /// A flag indicating if all operand/result types were seen. If the format
353 /// contains these, it can not contain individual type resolvers.
354 bool allOperands = false, allOperandTypes = false, allResultTypes = false;
355
356 /// A flag indicating if this operation infers its result types
357 bool infersResultTypes = false;
358
359 /// A flag indicating if this operation has the SingleBlockImplicitTerminator
360 /// trait.
361 bool hasImplicitTermTrait;
362
363 /// A flag indicating if this operation has the SingleBlock trait.
364 bool hasSingleBlockTrait;
365
366 /// Indicate whether attribute are stored in properties.
367 bool useProperties;
368
369 /// Indicate whether prop-dict is used in the format
370 bool hasPropDict;
371
372 /// The Operation class name
373 StringRef opCppClassName;
374
375 /// A map of buildable types to indices.
376 llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
377
378 /// The index of the buildable type, if valid, for every operand and result.
379 std::vector<TypeResolution> operandTypes, resultTypes;
380
381 /// The set of attributes explicitly used within the format.
382 llvm::SmallSetVector<const NamedAttribute *, 8> usedAttributes;
383 llvm::StringSet<> inferredAttributes;
384
385 /// The set of properties explicitly used within the format.
386 llvm::SmallSetVector<const NamedProperty *, 8> usedProperties;
387};
388} // namespace
389
390//===----------------------------------------------------------------------===//
391// Parser Gen
392
393/// Returns true if we can format the given attribute as an EnumAttr in the
394/// parser format.
395static bool canFormatEnumAttr(const NamedAttribute *attr) {
396 Attribute baseAttr = attr->attr.getBaseAttr();
397 const EnumAttr *enumAttr = dyn_cast<EnumAttr>(Val: &baseAttr);
398 if (!enumAttr)
399 return false;
400
401 // The attribute must have a valid underlying type and a constant builder.
402 return !enumAttr->getUnderlyingType().empty() &&
403 !enumAttr->getConstBuilderTemplate().empty();
404}
405
406/// Returns if we should format the given attribute as an SymbolNameAttr.
407static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
408 return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
409}
410
/// The code snippet used to generate a parser call for an attribute, going
/// through `parseCustomAttributeWithFallback`.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
/// The snippet is presumably expanded with `llvm::formatv`, where `{{` is the
/// escape for a literal `{`.
const char *const attrParserCode = R"(
 if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
 return ::mlir::failure();
 }
)";

/// The code snippet used to generate a parser call for an attribute, going
/// through the generic `parseAttribute` entry point.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
 if (parser.parseAttribute({0}Attr, {1}))
 return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for an optional attribute.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
/// NOTE: the snippet deliberately ends with a body-less
/// `if (... succeeded(...))` header. The code handling a successfully parsed
/// attribute is presumably appended by the caller as that `if`'s body.
const char *const optionalAttrParserCode = R"(
 ::mlir::OptionalParseResult parseResult{0}Attr =
 parser.parseOptionalAttribute({0}Attr, {1});
 if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
 return ::mlir::failure();
 if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
)";

/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
 if (parser.parseSymbolName({0}Attr))
 return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional symbol name
/// attribute. The parse result is intentionally discarded; the generated code
/// carries a comment explaining why.
///
/// {0}: The name of the attribute.
const char *const optionalSymbolNameAttrParserCode = R"(
 // Parsing an optional symbol name doesn't fail, so no need to check the
 // result.
 (void)parser.parseOptionalSymbolName({0}Attr);
)";
450
/// The code snippet used to generate a parser call for an enum attribute.
///
/// The generated code first tries to parse a bare keyword from the allowed
/// set. Failing that, it tries a string attribute whose value is used as the
/// keyword. It then symbolizes the string and builds the attribute.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The code spliced in (as a statement) when neither a keyword nor a
///      string attribute is present.
/// {6}: The attribute assignment expression, spliced in as a statement after
///      `{0}Attr` is built.
///
/// Note: the emitted `emitError` statement previously ended in `;;`. That
/// stray empty statement was removed so the generated code stays clean under
/// -Wextra-semi-stmt.
const char *const enumAttrParserCode = R"(
 {
 ::llvm::StringRef attrStr;
 ::mlir::NamedAttrList attrStorage;
 auto loc = parser.getCurrentLocation();
 if (parser.parseOptionalKeyword(&attrStr, {4})) {
 ::mlir::StringAttr attrVal;
 ::mlir::OptionalParseResult parseResult =
 parser.parseOptionalAttribute(attrVal,
 parser.getBuilder().getNoneType(),
 "{0}", attrStorage);
 if (parseResult.has_value()) {{
 if (failed(*parseResult))
 return ::mlir::failure();
 attrStr = attrVal.getValue();
 } else {
 {5}
 }
 }
 if (!attrStr.empty()) {
 auto attrOptional = {1}::{2}(attrStr);
 if (!attrOptional)
 return parser.emitError(loc, "invalid ")
 << "{0} attribute specification: \"" << attrStr << '"';

 {0}Attr = {3};
 {6}
 }
 }
)";
490
/// The code snippets used to generate a parser call for an operand.
///
/// {0}: The name of the operand.
///
/// Variadic form: records the location in `{0}OperandsLoc` and parses an
/// operand list into `{0}Operands`.
const char *const variadicOperandParserCode = R"(
 {0}OperandsLoc = parser.getCurrentLocation();
 if (parser.parseOperandList({0}Operands))
 return ::mlir::failure();
)";
/// Optional form: parses zero or one operand, appending it to `{0}Operands`
/// only when present. A present-but-malformed operand fails the parse.
const char *const optionalOperandParserCode = R"(
 {
 {0}OperandsLoc = parser.getCurrentLocation();
 ::mlir::OpAsmParser::UnresolvedOperand operand;
 ::mlir::OptionalParseResult parseResult =
 parser.parseOptionalOperand(operand);
 if (parseResult.has_value()) {
 if (failed(*parseResult))
 return ::mlir::failure();
 {0}Operands.push_back(operand);
 }
 }
)";
/// Single form: parses exactly one operand into `{0}RawOperand`.
const char *const operandParserCode = R"(
 {0}OperandsLoc = parser.getCurrentLocation();
 if (parser.parseOperand({0}RawOperand))
 return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for a VariadicOfVariadic
/// operand.
///
/// The generated code parses a comma-separated sequence of parenthesized
/// operand lists. All operands are flattened into `{0}Operands`, and the size
/// of each group is recorded in `{0}OperandGroupSizes`.
///
/// {0}: The name of the operand.
/// NOTE(review): the snippet only references {0}. A {1} argument (formerly
/// documented as the segment size attribute name) is not used by it.
const char *const variadicOfVariadicOperandParserCode = R"(
 {
 {0}OperandsLoc = parser.getCurrentLocation();
 int32_t curSize = 0;
 do {
 if (parser.parseOptionalLParen())
 break;
 if (parser.parseOperandList({0}Operands) || parser.parseRParen())
 return ::mlir::failure();
 {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
 curSize = {0}Operands.size();
 } while (succeeded(parser.parseOptionalComma()));
 }
)";
536
/// The code snippet used to generate a parser call for a type list.
///
/// VariadicOfVariadic form: parses a comma-separated sequence of
/// parenthesized type lists, where an empty `()` group is accepted. All types
/// are appended to the single flat list `{0}Types`.
///
/// {0}: The name for the type list.
const char *const variadicOfVariadicTypeParserCode = R"(
 do {
 if (parser.parseOptionalLParen())
 break;
 if (parser.parseOptionalRParen() &&
 (parser.parseTypeList({0}Types) || parser.parseRParen()))
 return ::mlir::failure();
 } while (succeeded(parser.parseOptionalComma()));
)";
/// Variadic form: parses a type list into `{0}Types`.
/// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"(
 if (parser.parseTypeList({0}Types))
 return ::mlir::failure();
)";
/// Optional form: parses zero or one type, appending it to `{0}Types` only
/// when present.
/// {0}: The name for the type list.
const char *const optionalTypeParserCode = R"(
 {
 ::mlir::Type optionalType;
 ::mlir::OptionalParseResult parseResult =
 parser.parseOptionalType(optionalType);
 if (parseResult.has_value()) {
 if (failed(*parseResult))
 return ::mlir::failure();
 {0}Types.push_back(optionalType);
 }
 }
)";
/// Single form using the type's custom parser with a fallback
/// (`parseCustomTypeWithFallback`).
/// {0}: The C++ type to parse into.
/// {1}: The name of the argument; the parsed type is stored in `{1}RawType`.
const char *const typeParserCode = R"(
 {
 {0} type;
 if (parser.parseCustomTypeWithFallback(type))
 return ::mlir::failure();
 {1}RawType = type;
 }
)";
/// Single form for qualified types, using the generic `parseType`.
/// {1}: The name of the argument; the parsed type is stored in `{1}RawType`.
/// {0} is not referenced; presumably the same arguments as `typeParserCode`
/// are passed to it.
const char *const qualifiedTypeParserCode = R"(
 if (parser.parseType({1}RawType))
 return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for a functional type.
/// The parsed FunctionType's inputs and results are assigned to the two type
/// lists.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
 ::mlir::FunctionType {0}__{1}_functionType;
 if (parser.parseType({0}__{1}_functionType))
 return ::mlir::failure();
 {0}Types = {0}__{1}_functionType.getInputs();
 {1}Types = {0}__{1}_functionType.getResults();
)";

/// The code snippet used to generate a parser call to infer return types.
/// It calls the op's static `inferReturnTypes` with the already-parsed
/// operands, attributes, properties and regions, then adds the inferred
/// types to `result`.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
 ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
 if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
 result.location, result.operands,
 result.attributes.getDictionary(parser.getContext()),
 result.getRawProperties(),
 result.regions, inferredReturnTypes)))
 return ::mlir::failure();
 result.addTypes(inferredReturnTypes);
)";
603
// All snippets below are declared `const char *const`, like every other
// snippet in this file. The previous `const char *` form left the pointers
// mutable and gave them external linkage, which risks ODR clashes with
// same-named globals in other translation units.

/// The code snippet used to generate a parser call for a region list.
/// Parses an optional first region, then any comma-separated trailing
/// regions, moving each into `{0}Regions`.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
 {
 std::unique_ptr<::mlir::Region> region;
 auto firstRegionResult = parser.parseOptionalRegion(region);
 if (firstRegionResult.has_value()) {
 if (failed(*firstRegionResult))
 return ::mlir::failure();
 {0}Regions.emplace_back(std::move(region));

 // Parse any trailing regions.
 while (succeeded(parser.parseOptionalComma())) {
 region = std::make_unique<::mlir::Region>();
 if (parser.parseRegion(*region))
 return ::mlir::failure();
 {0}Regions.emplace_back(std::move(region));
 }
 }
 }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
 for (auto &region : {0}Regions)
 ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a list of regions have a block.
///
/// {0}: The name of the region list.
const char *const regionListEnsureSingleBlockParserCode = R"(
 for (auto &region : {0}Regions)
 if (region->empty()) region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
 {
 auto parseResult = parser.parseOptionalRegion(*{0}Region);
 if (parseResult.has_value() && failed(*parseResult))
 return ::mlir::failure();
 }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
 if (parser.parseRegion(*{0}Region))
 return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
 ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block.
///
/// {0}: The name of the region.
const char *const regionEnsureSingleBlockParserCode = R"(
 if ({0}Region->empty()) {0}Region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for a successor list.
/// Parses an optional first successor, then any comma-separated trailing
/// successors, appending each to `{0}Successors`.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
 {
 ::mlir::Block *succ;
 auto firstSucc = parser.parseOptionalSuccessor(succ);
 if (firstSucc.has_value()) {
 if (failed(*firstSucc))
 return ::mlir::failure();
 {0}Successors.emplace_back(succ);

 // Parse any trailing successors.
 while (succeeded(parser.parseOptionalComma())) {
 if (parser.parseSuccessor(succ))
 return ::mlir::failure();
 {0}Successors.emplace_back(succ);
 }
 }
 }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
 if (parser.parseSuccessor({0}Successor))
 return ::mlir::failure();
)";

/// The code snippet used to generate a parser for OIList. It rejects a clause
/// seen a second time and otherwise marks `{0}Clause` as seen.
///
/// {0}: literal keyword corresponding to a case for oilist
const char *const oilistParserCode = R"(
 if ({0}Clause) {
 return parser.emitError(parser.getNameLoc())
 << "`{0}` clause can appear at most once in the expansion of the "
 "oilist directive";
 }
 {0}Clause = true;
)";
717
namespace {
/// Describes how many values a parsed argument may bind to.
enum class ArgumentLengthKind {
  /// A variadic of variadics: 0->N groups, each holding a range of elements.
  VariadicOfVariadic,
  /// Variadic: anywhere from 0 to N elements.
  Variadic,
  /// Optional: either 0 or 1 element.
  Optional,
  /// Single: always exactly 1 element.
  Single
};
} // namespace
732
733/// Get the length kind for the given constraint.
734static ArgumentLengthKind
735getArgumentLengthKind(const NamedTypeConstraint *var) {
736 if (var->isOptional())
737 return ArgumentLengthKind::Optional;
738 if (var->isVariadicOfVariadic())
739 return ArgumentLengthKind::VariadicOfVariadic;
740 if (var->isVariadic())
741 return ArgumentLengthKind::Variadic;
742 return ArgumentLengthKind::Single;
743}
744
745/// Get the name used for the type list for the given type directive operand.
746/// 'lengthKind' to the corresponding kind for the given argument.
747static StringRef getTypeListName(FormatElement *arg,
748 ArgumentLengthKind &lengthKind) {
749 if (auto *operand = dyn_cast<OperandVariable>(Val: arg)) {
750 lengthKind = getArgumentLengthKind(var: operand->getVar());
751 return operand->getVar()->name;
752 }
753 if (auto *result = dyn_cast<ResultVariable>(Val: arg)) {
754 lengthKind = getArgumentLengthKind(var: result->getVar());
755 return result->getVar()->name;
756 }
757 lengthKind = ArgumentLengthKind::Variadic;
758 if (isa<OperandsDirective>(Val: arg))
759 return "allOperand";
760 if (isa<ResultsDirective>(Val: arg))
761 return "allResult";
762 llvm_unreachable("unknown 'type' directive argument");
763}
764
765/// Generate the parser for a literal value.
766static void genLiteralParser(StringRef value, MethodBody &body) {
767 // Handle the case of a keyword/identifier.
768 if (value.front() == '_' || isalpha(value.front())) {
769 body << "Keyword(\"" << value << "\")";
770 return;
771 }
772 body << (StringRef)StringSwitch<StringRef>(value)
773 .Case(S: "->", Value: "Arrow()")
774 .Case(S: ":", Value: "Colon()")
775 .Case(S: ",", Value: "Comma()")
776 .Case(S: "=", Value: "Equal()")
777 .Case(S: "<", Value: "Less()")
778 .Case(S: ">", Value: "Greater()")
779 .Case(S: "{", Value: "LBrace()")
780 .Case(S: "}", Value: "RBrace()")
781 .Case(S: "(", Value: "LParen()")
782 .Case(S: ")", Value: "RParen()")
783 .Case(S: "[", Value: "LSquare()")
784 .Case(S: "]", Value: "RSquare()")
785 .Case(S: "?", Value: "Question()")
786 .Case(S: "+", Value: "Plus()")
787 .Case(S: "*", Value: "Star()")
788 .Case(S: "...", Value: "Ellipsis()");
789}
790
791/// Generate the storage code required for parsing the given element.
792static void genElementParserStorage(FormatElement *element, const Operator &op,
793 MethodBody &body) {
794 if (auto *optional = dyn_cast<OptionalElement>(Val: element)) {
795 ArrayRef<FormatElement *> elements = optional->getThenElements();
796
797 // If the anchor is a unit attribute, it won't be parsed directly so elide
798 // it.
799 auto *anchor = dyn_cast<AttributeVariable>(Val: optional->getAnchor());
800 FormatElement *elidedAnchorElement = nullptr;
801 if (anchor && anchor != elements.front() && anchor->isUnitAttr())
802 elidedAnchorElement = anchor;
803 for (FormatElement *childElement : elements)
804 if (childElement != elidedAnchorElement)
805 genElementParserStorage(element: childElement, op, body);
806 for (FormatElement *childElement : optional->getElseElements())
807 genElementParserStorage(element: childElement, op, body);
808
809 } else if (auto *oilist = dyn_cast<OIListElement>(Val: element)) {
810 for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) {
811 if (!oilist->getUnitAttrParsingElement(pelement))
812 for (FormatElement *element : pelement)
813 genElementParserStorage(element, op, body);
814 }
815
816 } else if (auto *custom = dyn_cast<CustomDirective>(Val: element)) {
817 for (FormatElement *paramElement : custom->getArguments())
818 genElementParserStorage(element: paramElement, op, body);
819
820 } else if (isa<OperandsDirective>(Val: element)) {
821 body << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
822 "allOperands;\n";
823
824 } else if (isa<RegionsDirective>(Val: element)) {
825 body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
826 "fullRegions;\n";
827
828 } else if (isa<SuccessorsDirective>(Val: element)) {
829 body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
830
831 } else if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) {
832 const NamedAttribute *var = attr->getVar();
833 body << llvm::formatv(Fmt: " {0} {1}Attr;\n", Vals: var->attr.getStorageType(),
834 Vals: var->name);
835
836 } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) {
837 StringRef name = operand->getVar()->name;
838 if (operand->getVar()->isVariableLength()) {
839 body
840 << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
841 << name << "Operands;\n";
842 if (operand->getVar()->isVariadicOfVariadic()) {
843 body << " llvm::SmallVector<int32_t> " << name
844 << "OperandGroupSizes;\n";
845 }
846 } else {
847 body << " ::mlir::OpAsmParser::UnresolvedOperand " << name
848 << "RawOperand{};\n"
849 << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
850 << name << "Operands(&" << name << "RawOperand, 1);";
851 }
852 body << llvm::formatv(Fmt: " ::llvm::SMLoc {0}OperandsLoc;\n"
853 " (void){0}OperandsLoc;\n",
854 Vals&: name);
855
856 } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) {
857 StringRef name = region->getVar()->name;
858 if (region->getVar()->isVariadic()) {
859 body << llvm::formatv(
860 Fmt: " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
861 "{0}Regions;\n",
862 Vals&: name);
863 } else {
864 body << llvm::formatv(Fmt: " std::unique_ptr<::mlir::Region> {0}Region = "
865 "std::make_unique<::mlir::Region>();\n",
866 Vals&: name);
867 }
868
869 } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) {
870 StringRef name = successor->getVar()->name;
871 if (successor->getVar()->isVariadic()) {
872 body << llvm::formatv(Fmt: " ::llvm::SmallVector<::mlir::Block *, 2> "
873 "{0}Successors;\n",
874 Vals&: name);
875 } else {
876 body << llvm::formatv(Fmt: " ::mlir::Block *{0}Successor = nullptr;\n", Vals&: name);
877 }
878
879 } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) {
880 ArgumentLengthKind lengthKind;
881 StringRef name = getTypeListName(arg: dir->getArg(), lengthKind);
882 if (lengthKind != ArgumentLengthKind::Single)
883 body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
884 else
885 body
886 << llvm::formatv(Fmt: " ::mlir::Type {0}RawType{{};\n", Vals&: name)
887 << llvm::formatv(
888 Fmt: " ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n",
889 Vals&: name);
890 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(Val: element)) {
891 ArgumentLengthKind ignored;
892 body << " ::llvm::ArrayRef<::mlir::Type> "
893 << getTypeListName(arg: dir->getInputs(), lengthKind&: ignored) << "Types;\n";
894 body << " ::llvm::ArrayRef<::mlir::Type> "
895 << getTypeListName(arg: dir->getResults(), lengthKind&: ignored) << "Types;\n";
896 }
897}
898
899/// Generate the parser for a parameter to a custom directive.
900static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
901 if (auto *attr = dyn_cast<AttributeVariable>(Val: param)) {
902 body << attr->getVar()->name << "Attr";
903 } else if (isa<AttrDictDirective>(Val: param)) {
904 body << "result.attributes";
905 } else if (isa<PropDictDirective>(Val: param)) {
906 body << "result";
907 } else if (auto *operand = dyn_cast<OperandVariable>(Val: param)) {
908 StringRef name = operand->getVar()->name;
909 ArgumentLengthKind lengthKind = getArgumentLengthKind(var: operand->getVar());
910 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
911 body << llvm::formatv(Fmt: "{0}OperandGroups", Vals&: name);
912 else if (lengthKind == ArgumentLengthKind::Variadic)
913 body << llvm::formatv(Fmt: "{0}Operands", Vals&: name);
914 else if (lengthKind == ArgumentLengthKind::Optional)
915 body << llvm::formatv(Fmt: "{0}Operand", Vals&: name);
916 else
917 body << formatv(Fmt: "{0}RawOperand", Vals&: name);
918
919 } else if (auto *region = dyn_cast<RegionVariable>(Val: param)) {
920 StringRef name = region->getVar()->name;
921 if (region->getVar()->isVariadic())
922 body << llvm::formatv(Fmt: "{0}Regions", Vals&: name);
923 else
924 body << llvm::formatv(Fmt: "*{0}Region", Vals&: name);
925
926 } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: param)) {
927 StringRef name = successor->getVar()->name;
928 if (successor->getVar()->isVariadic())
929 body << llvm::formatv(Fmt: "{0}Successors", Vals&: name);
930 else
931 body << llvm::formatv(Fmt: "{0}Successor", Vals&: name);
932
933 } else if (auto *dir = dyn_cast<RefDirective>(Val: param)) {
934 genCustomParameterParser(param: dir->getArg(), body);
935
936 } else if (auto *dir = dyn_cast<TypeDirective>(Val: param)) {
937 ArgumentLengthKind lengthKind;
938 StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind);
939 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
940 body << llvm::formatv(Fmt: "{0}TypeGroups", Vals&: listName);
941 else if (lengthKind == ArgumentLengthKind::Variadic)
942 body << llvm::formatv(Fmt: "{0}Types", Vals&: listName);
943 else if (lengthKind == ArgumentLengthKind::Optional)
944 body << llvm::formatv(Fmt: "{0}Type", Vals&: listName);
945 else
946 body << formatv(Fmt: "{0}RawType", Vals&: listName);
947
948 } else if (auto *string = dyn_cast<StringElement>(Val: param)) {
949 FmtContext ctx;
950 ctx.withBuilder(subst: "parser.getBuilder()");
951 ctx.addSubst(placeholder: "_ctxt", subst: "parser.getContext()");
952 body << tgfmt(fmt: string->getValue(), ctx: &ctx);
953
954 } else if (auto *property = dyn_cast<PropertyVariable>(Val: param)) {
955 body << llvm::formatv(Fmt: "result.getOrAddProperties<Properties>().{0}",
956 Vals: property->getVar()->name);
957 } else {
958 llvm_unreachable("unknown custom directive parameter");
959 }
960}
961
962/// Generate the parser for a custom directive.
963static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
964 bool useProperties,
965 StringRef opCppClassName,
966 bool isOptional = false) {
967 body << " {\n";
968
969 // Preprocess the directive variables.
970 // * Add a local variable for optional operands and types. This provides a
971 // better API to the user defined parser methods.
972 // * Set the location of operand variables.
973 for (FormatElement *param : dir->getArguments()) {
974 if (auto *operand = dyn_cast<OperandVariable>(Val: param)) {
975 auto *var = operand->getVar();
976 body << " " << var->name
977 << "OperandsLoc = parser.getCurrentLocation();\n";
978 if (var->isOptional()) {
979 body << llvm::formatv(
980 Fmt: " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
981 "{0}Operand;\n",
982 Vals: var->name);
983 } else if (var->isVariadicOfVariadic()) {
984 body << llvm::formatv(Fmt: " "
985 "::llvm::SmallVector<::llvm::SmallVector<::mlir::"
986 "OpAsmParser::UnresolvedOperand>> "
987 "{0}OperandGroups;\n",
988 Vals: var->name);
989 }
990 } else if (auto *dir = dyn_cast<TypeDirective>(Val: param)) {
991 ArgumentLengthKind lengthKind;
992 StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind);
993 if (lengthKind == ArgumentLengthKind::Optional) {
994 body << llvm::formatv(Fmt: " ::mlir::Type {0}Type;\n", Vals&: listName);
995 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
996 body << llvm::formatv(
997 Fmt: " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
998 "{0}TypeGroups;\n",
999 Vals&: listName);
1000 }
1001 } else if (auto *dir = dyn_cast<RefDirective>(Val: param)) {
1002 FormatElement *input = dir->getArg();
1003 if (auto *operand = dyn_cast<OperandVariable>(Val: input)) {
1004 if (!operand->getVar()->isOptional())
1005 continue;
1006 body << llvm::formatv(
1007 Fmt: " {0} {1}Operand = {1}Operands.empty() ? {0}() : "
1008 "{1}Operands[0];\n",
1009 Vals: "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
1010 Vals: operand->getVar()->name);
1011
1012 } else if (auto *type = dyn_cast<TypeDirective>(Val: input)) {
1013 ArgumentLengthKind lengthKind;
1014 StringRef listName = getTypeListName(arg: type->getArg(), lengthKind);
1015 if (lengthKind == ArgumentLengthKind::Optional) {
1016 body << llvm::formatv(Fmt: " ::mlir::Type {0}Type = {0}Types.empty() ? "
1017 "::mlir::Type() : {0}Types[0];\n",
1018 Vals&: listName);
1019 }
1020 }
1021 }
1022 }
1023
1024 body << " auto odsResult = parse" << dir->getName() << "(parser";
1025 for (FormatElement *param : dir->getArguments()) {
1026 body << ", ";
1027 genCustomParameterParser(param, body);
1028 }
1029 body << ");\n";
1030
1031 if (isOptional) {
1032 body << " if (!odsResult.has_value()) return {};\n"
1033 << " if (::mlir::failed(*odsResult)) return ::mlir::failure();\n";
1034 } else {
1035 body << " if (odsResult) return ::mlir::failure();\n";
1036 }
1037
1038 // After parsing, add handling for any of the optional constructs.
1039 for (FormatElement *param : dir->getArguments()) {
1040 if (auto *attr = dyn_cast<AttributeVariable>(Val: param)) {
1041 const NamedAttribute *var = attr->getVar();
1042 if (var->attr.isOptional() || var->attr.hasDefaultValue())
1043 body << llvm::formatv(Fmt: " if ({0}Attr)\n ", Vals: var->name);
1044 if (useProperties) {
1045 body << formatv(
1046 Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
1047 Vals: var->name, Vals&: opCppClassName);
1048 } else {
1049 body << llvm::formatv(Fmt: " result.addAttribute(\"{0}\", {0}Attr);\n",
1050 Vals: var->name);
1051 }
1052
1053 } else if (auto *operand = dyn_cast<OperandVariable>(Val: param)) {
1054 const NamedTypeConstraint *var = operand->getVar();
1055 if (var->isOptional()) {
1056 body << llvm::formatv(Fmt: " if ({0}Operand.has_value())\n"
1057 " {0}Operands.push_back(*{0}Operand);\n",
1058 Vals: var->name);
1059 } else if (var->isVariadicOfVariadic()) {
1060 body << llvm::formatv(
1061 Fmt: " for (const auto &subRange : {0}OperandGroups) {{\n"
1062 " {0}Operands.append(subRange.begin(), subRange.end());\n"
1063 " {0}OperandGroupSizes.push_back(subRange.size());\n"
1064 " }\n",
1065 Vals: var->name, Vals: var->constraint.getVariadicOfVariadicSegmentSizeAttr());
1066 }
1067 } else if (auto *dir = dyn_cast<TypeDirective>(Val: param)) {
1068 ArgumentLengthKind lengthKind;
1069 StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind);
1070 if (lengthKind == ArgumentLengthKind::Optional) {
1071 body << llvm::formatv(Fmt: " if ({0}Type)\n"
1072 " {0}Types.push_back({0}Type);\n",
1073 Vals&: listName);
1074 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1075 body << llvm::formatv(
1076 Fmt: " for (const auto &subRange : {0}TypeGroups)\n"
1077 " {0}Types.append(subRange.begin(), subRange.end());\n",
1078 Vals&: listName);
1079 }
1080 }
1081 }
1082
1083 body << " }\n";
1084}
1085
1086/// Generate the parser for a enum attribute.
1087static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
1088 FmtContext &attrTypeCtx, bool parseAsOptional,
1089 bool useProperties, StringRef opCppClassName) {
1090 Attribute baseAttr = var->attr.getBaseAttr();
1091 const EnumAttr &enumAttr = cast<EnumAttr>(Val&: baseAttr);
1092 std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1093
1094 // Generate the code for building an attribute for this enum.
1095 std::string attrBuilderStr;
1096 {
1097 llvm::raw_string_ostream os(attrBuilderStr);
1098 os << tgfmt(fmt: enumAttr.getConstBuilderTemplate(), ctx: &attrTypeCtx,
1099 vals: "*attrOptional");
1100 }
1101
1102 // Build a string containing the cases that can be formatted as a keyword.
1103 std::string validCaseKeywordsStr = "{";
1104 llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr);
1105 for (const EnumAttrCase &attrCase : cases)
1106 if (canFormatStringAsKeyword(value: attrCase.getStr()))
1107 validCaseKeywordsOS << '"' << attrCase.getStr() << "\",";
1108 validCaseKeywordsOS.str().back() = '}';
1109
1110 // If the attribute is not optional, build an error message for the missing
1111 // attribute.
1112 std::string errorMessage;
1113 if (!parseAsOptional) {
1114 llvm::raw_string_ostream errorMessageOS(errorMessage);
1115 errorMessageOS
1116 << "return parser.emitError(loc, \"expected string or "
1117 "keyword containing one of the following enum values for attribute '"
1118 << var->name << "' [";
1119 llvm::interleaveComma(c: cases, os&: errorMessageOS, each_fn: [&](const auto &attrCase) {
1120 errorMessageOS << attrCase.getStr();
1121 });
1122 errorMessageOS << "]\");";
1123 }
1124 std::string attrAssignment;
1125 if (useProperties) {
1126 attrAssignment =
1127 formatv(Fmt: " "
1128 "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;",
1129 Vals: var->name, Vals&: opCppClassName);
1130 } else {
1131 attrAssignment =
1132 formatv(Fmt: "result.addAttribute(\"{0}\", {0}Attr);", Vals: var->name);
1133 }
1134
1135 body << formatv(Fmt: enumAttrParserCode, Vals: var->name, Vals: enumAttr.getCppNamespace(),
1136 Vals: enumAttr.getStringToSymbolFnName(), Vals&: attrBuilderStr,
1137 Vals&: validCaseKeywordsStr, Vals&: errorMessage, Vals&: attrAssignment);
1138}
1139
1140// Generate the parser for an attribute.
1141static void genAttrParser(AttributeVariable *attr, MethodBody &body,
1142 FmtContext &attrTypeCtx, bool parseAsOptional,
1143 bool useProperties, StringRef opCppClassName) {
1144 const NamedAttribute *var = attr->getVar();
1145
1146 // Check to see if we can parse this as an enum attribute.
1147 if (canFormatEnumAttr(attr: var))
1148 return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional,
1149 useProperties, opCppClassName);
1150
1151 // Check to see if we should parse this as a symbol name attribute.
1152 if (shouldFormatSymbolNameAttr(attr: var)) {
1153 body << formatv(Fmt: parseAsOptional ? optionalSymbolNameAttrParserCode
1154 : symbolNameAttrParserCode,
1155 Vals: var->name);
1156 } else {
1157
1158 // If this attribute has a buildable type, use that when parsing the
1159 // attribute.
1160 std::string attrTypeStr;
1161 if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
1162 llvm::raw_string_ostream os(attrTypeStr);
1163 os << tgfmt(fmt: *typeBuilder, ctx: &attrTypeCtx);
1164 } else {
1165 attrTypeStr = "::mlir::Type{}";
1166 }
1167 if (parseAsOptional) {
1168 body << formatv(Fmt: optionalAttrParserCode, Vals: var->name, Vals&: attrTypeStr);
1169 } else {
1170 if (attr->shouldBeQualified() ||
1171 var->attr.getStorageType() == "::mlir::Attribute")
1172 body << formatv(Fmt: genericAttrParserCode, Vals: var->name, Vals&: attrTypeStr);
1173 else
1174 body << formatv(Fmt: attrParserCode, Vals: var->name, Vals&: attrTypeStr);
1175 }
1176 }
1177 if (useProperties) {
1178 body << formatv(
1179 Fmt: " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = "
1180 "{0}Attr;\n",
1181 Vals: var->name, Vals&: opCppClassName);
1182 } else {
1183 body << formatv(
1184 Fmt: " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n",
1185 Vals: var->name);
1186 }
1187}
1188
1189// Generates the 'setPropertiesFromParsedAttr' used to set properties from a
1190// 'prop-dict' dictionary attr.
1191static void genParsedAttrPropertiesSetter(OperationFormat &fmt, Operator &op,
1192 OpClass &opClass) {
1193 // Not required unless 'prop-dict' is present.
1194 if (!fmt.hasPropDict)
1195 return;
1196
1197 SmallVector<MethodParameter> paramList;
1198 paramList.emplace_back(Args: "Properties &", Args: "prop");
1199 paramList.emplace_back(Args: "::mlir::Attribute", Args: "attr");
1200 paramList.emplace_back(Args: "::llvm::function_ref<::mlir::InFlightDiagnostic()>",
1201 Args: "emitError");
1202
1203 Method *method = opClass.addStaticMethod(retType: "::mlir::LogicalResult",
1204 name: "setPropertiesFromParsedAttr",
1205 args: std::move(paramList));
1206 MethodBody &body = method->body().indent();
1207
1208 body << R"decl(
1209::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
1210if (!dict) {
1211 emitError() << "expected DictionaryAttr to set properties";
1212 return ::mlir::failure();
1213}
1214)decl";
1215
1216 // TODO: properties might be optional as well.
1217 const char *propFromAttrFmt = R"decl(
1218auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
1219 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
1220 {0};
1221};
1222auto attr = dict.get("{1}");
1223if (!attr) {{
1224 emitError() << "expected key entry for {1} in DictionaryAttr to set "
1225 "Properties.";
1226 return ::mlir::failure();
1227}
1228if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError)))
1229 return ::mlir::failure();
1230)decl";
1231
1232 // Generate the setter for any property not parsed elsewhere.
1233 for (const NamedProperty &namedProperty : op.getProperties()) {
1234 if (fmt.usedProperties.contains(key: &namedProperty))
1235 continue;
1236
1237 auto scope = body.scope(open: "{\n", close: "}\n", /*indent=*/true);
1238
1239 StringRef name = namedProperty.name;
1240 const Property &prop = namedProperty.prop;
1241 FmtContext fctx;
1242 body << formatv(Fmt: propFromAttrFmt,
1243 Vals: tgfmt(fmt: prop.getConvertFromAttributeCall(),
1244 ctx: &fctx.addSubst(placeholder: "_attr", subst: "propAttr")
1245 .addSubst(placeholder: "_storage", subst: "propStorage")
1246 .addSubst(placeholder: "_diag", subst: "emitError")),
1247 Vals&: name);
1248 }
1249
1250 // Generate the setter for any attribute not parsed elsewhere.
1251 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1252 if (fmt.usedAttributes.contains(key: &namedAttr))
1253 continue;
1254
1255 const Attribute &attr = namedAttr.attr;
1256 // Derived attributes do not need to be parsed.
1257 if (attr.isDerivedAttr())
1258 continue;
1259
1260 auto scope = body.scope(open: "{\n", close: "}\n", /*indent=*/true);
1261
1262 // If the attribute has a default value or is optional, it does not need to
1263 // be present in the parsed dictionary attribute.
1264 bool isRequired = !attr.isOptional() && !attr.hasDefaultValue();
1265 body << formatv(Fmt: R"decl(
1266auto &propStorage = prop.{0};
1267auto attr = dict.get("{0}");
1268if (attr || /*isRequired=*/{1}) {{
1269 if (!attr) {{
1270 emitError() << "expected key entry for {0} in DictionaryAttr to set "
1271 "Properties.";
1272 return ::mlir::failure();
1273 }
1274 auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
1275 if (convertedAttr) {{
1276 propStorage = convertedAttr;
1277 } else {{
1278 emitError() << "Invalid attribute `{0}` in property conversion: " << attr;
1279 return ::mlir::failure();
1280 }
1281}
1282)decl",
1283 Vals: namedAttr.name, Vals&: isRequired);
1284 }
1285 body << "return ::mlir::success();\n";
1286}
1287
1288void OperationFormat::genParser(Operator &op, OpClass &opClass) {
1289 SmallVector<MethodParameter> paramList;
1290 paramList.emplace_back(Args: "::mlir::OpAsmParser &", Args: "parser");
1291 paramList.emplace_back(Args: "::mlir::OperationState &", Args: "result");
1292
1293 auto *method = opClass.addStaticMethod(retType: "::mlir::ParseResult", name: "parse",
1294 args: std::move(paramList));
1295 auto &body = method->body();
1296
1297 // Generate variables to store the operands and type within the format. This
1298 // allows for referencing these variables in the presence of optional
1299 // groupings.
1300 for (FormatElement *element : elements)
1301 genElementParserStorage(element, op, body);
1302
1303 // A format context used when parsing attributes with buildable types.
1304 FmtContext attrTypeCtx;
1305 attrTypeCtx.withBuilder(subst: "parser.getBuilder()");
1306
1307 // Generate parsers for each of the elements.
1308 for (FormatElement *element : elements)
1309 genElementParser(element, body, attrTypeCtx);
1310
1311 // Generate the code to resolve the operand/result types and successors now
1312 // that they have been parsed.
1313 genParserRegionResolution(op, body);
1314 genParserSuccessorResolution(op, body);
1315 genParserVariadicSegmentResolution(op, body);
1316 genParserTypeResolution(op, body);
1317
1318 body << " return ::mlir::success();\n";
1319
1320 genParsedAttrPropertiesSetter(fmt&: *this, op, opClass);
1321}
1322
1323void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
1324 FmtContext &attrTypeCtx,
1325 GenContext genCtx) {
1326 /// Optional Group.
1327 if (auto *optional = dyn_cast<OptionalElement>(Val: element)) {
1328 auto genElementParsers = [&](FormatElement *firstElement,
1329 ArrayRef<FormatElement *> elements,
1330 bool thenGroup) {
1331 // If the anchor is a unit attribute, we don't need to print it. When
1332 // parsing, we will add this attribute if this group is present.
1333 FormatElement *elidedAnchorElement = nullptr;
1334 auto *anchorAttr = dyn_cast<AttributeVariable>(Val: optional->getAnchor());
1335 if (anchorAttr && anchorAttr != firstElement &&
1336 anchorAttr->isUnitAttr()) {
1337 elidedAnchorElement = anchorAttr;
1338
1339 if (!thenGroup == optional->isInverted()) {
1340 // Add the anchor unit attribute to the operation state.
1341 if (useProperties) {
1342 body << formatv(
1343 Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = "
1344 "parser.getBuilder().getUnitAttr();",
1345 Vals: anchorAttr->getVar()->name, Vals&: opCppClassName);
1346 } else {
1347 body << " result.addAttribute(\"" << anchorAttr->getVar()->name
1348 << "\", parser.getBuilder().getUnitAttr());\n";
1349 }
1350 }
1351 }
1352
1353 // Generate the rest of the elements inside an optional group. Elements in
1354 // an optional group after the guard are parsed as required.
1355 for (FormatElement *childElement : elements)
1356 if (childElement != elidedAnchorElement)
1357 genElementParser(element: childElement, body, attrTypeCtx,
1358 genCtx: GenContext::Optional);
1359 };
1360
1361 ArrayRef<FormatElement *> thenElements =
1362 optional->getThenElements(/*parseable=*/true);
1363
1364 // Generate a special optional parser for the first element to gate the
1365 // parsing of the rest of the elements.
1366 FormatElement *firstElement = thenElements.front();
1367 if (auto *attrVar = dyn_cast<AttributeVariable>(Val: firstElement)) {
1368 genAttrParser(attr: attrVar, body, attrTypeCtx, /*parseAsOptional=*/true,
1369 useProperties, opCppClassName);
1370 body << " if (" << attrVar->getVar()->name << "Attr) {\n";
1371 } else if (auto *literal = dyn_cast<LiteralElement>(Val: firstElement)) {
1372 body << " if (::mlir::succeeded(parser.parseOptional";
1373 genLiteralParser(value: literal->getSpelling(), body);
1374 body << ")) {\n";
1375 } else if (auto *opVar = dyn_cast<OperandVariable>(Val: firstElement)) {
1376 genElementParser(element: opVar, body, attrTypeCtx);
1377 body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
1378 } else if (auto *regionVar = dyn_cast<RegionVariable>(Val: firstElement)) {
1379 const NamedRegion *region = regionVar->getVar();
1380 if (region->isVariadic()) {
1381 genElementParser(element: regionVar, body, attrTypeCtx);
1382 body << " if (!" << region->name << "Regions.empty()) {\n";
1383 } else {
1384 body << llvm::formatv(Fmt: optionalRegionParserCode, Vals: region->name);
1385 body << " if (!" << region->name << "Region->empty()) {\n ";
1386 if (hasImplicitTermTrait)
1387 body << llvm::formatv(Fmt: regionEnsureTerminatorParserCode, Vals: region->name);
1388 else if (hasSingleBlockTrait)
1389 body << llvm::formatv(Fmt: regionEnsureSingleBlockParserCode,
1390 Vals: region->name);
1391 }
1392 } else if (auto *custom = dyn_cast<CustomDirective>(Val: firstElement)) {
1393 body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n";
1394 genCustomDirectiveParser(dir: custom, body, useProperties, opCppClassName,
1395 /*isOptional=*/true);
1396 body << " return ::mlir::success();\n"
1397 << " }(); optResult.has_value() && ::mlir::failed(*optResult)) {\n"
1398 << " return ::mlir::failure();\n"
1399 << " } else if (optResult.has_value()) {\n";
1400 }
1401
1402 genElementParsers(firstElement, thenElements.drop_front(),
1403 /*thenGroup=*/true);
1404 body << " }";
1405
1406 // Generate the else elements.
1407 auto elseElements = optional->getElseElements();
1408 if (!elseElements.empty()) {
1409 body << " else {\n";
1410 ArrayRef<FormatElement *> elseElements =
1411 optional->getElseElements(/*parseable=*/true);
1412 genElementParsers(elseElements.front(), elseElements,
1413 /*thenGroup=*/false);
1414 body << " }";
1415 }
1416 body << "\n";
1417
1418 /// OIList Directive
1419 } else if (OIListElement *oilist = dyn_cast<OIListElement>(Val: element)) {
1420 for (LiteralElement *le : oilist->getLiteralElements())
1421 body << " bool " << le->getSpelling() << "Clause = false;\n";
1422
1423 // Generate the parsing loop
1424 body << " while(true) {\n";
1425 for (auto clause : oilist->getClauses()) {
1426 LiteralElement *lelement = std::get<0>(t&: clause);
1427 ArrayRef<FormatElement *> pelement = std::get<1>(t&: clause);
1428 body << "if (succeeded(parser.parseOptional";
1429 genLiteralParser(value: lelement->getSpelling(), body);
1430 body << ")) {\n";
1431 StringRef lelementName = lelement->getSpelling();
1432 body << formatv(Fmt: oilistParserCode, Vals&: lelementName);
1433 if (AttributeVariable *unitAttrElem =
1434 oilist->getUnitAttrParsingElement(pelement)) {
1435 if (useProperties) {
1436 body << formatv(
1437 Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = "
1438 "parser.getBuilder().getUnitAttr();",
1439 Vals: unitAttrElem->getVar()->name, Vals&: opCppClassName);
1440 } else {
1441 body << " result.addAttribute(\"" << unitAttrElem->getVar()->name
1442 << "\", UnitAttr::get(parser.getContext()));\n";
1443 }
1444 } else {
1445 for (FormatElement *el : pelement)
1446 genElementParser(element: el, body, attrTypeCtx);
1447 }
1448 body << " } else ";
1449 }
1450 body << " {\n";
1451 body << " break;\n";
1452 body << " }\n";
1453 body << "}\n";
1454
1455 /// Literals.
1456 } else if (LiteralElement *literal = dyn_cast<LiteralElement>(Val: element)) {
1457 body << " if (parser.parse";
1458 genLiteralParser(value: literal->getSpelling(), body);
1459 body << ")\n return ::mlir::failure();\n";
1460
1461 /// Whitespaces.
1462 } else if (isa<WhitespaceElement>(Val: element)) {
1463 // Nothing to parse.
1464
1465 /// Arguments.
1466 } else if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) {
1467 bool parseAsOptional =
1468 (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional());
1469 genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties,
1470 opCppClassName);
1471
1472 } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) {
1473 ArgumentLengthKind lengthKind = getArgumentLengthKind(var: operand->getVar());
1474 StringRef name = operand->getVar()->name;
1475 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
1476 body << llvm::formatv(
1477 Fmt: variadicOfVariadicOperandParserCode, Vals&: name,
1478 Vals: operand->getVar()->constraint.getVariadicOfVariadicSegmentSizeAttr());
1479 else if (lengthKind == ArgumentLengthKind::Variadic)
1480 body << llvm::formatv(Fmt: variadicOperandParserCode, Vals&: name);
1481 else if (lengthKind == ArgumentLengthKind::Optional)
1482 body << llvm::formatv(Fmt: optionalOperandParserCode, Vals&: name);
1483 else
1484 body << formatv(Fmt: operandParserCode, Vals&: name);
1485
1486 } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) {
1487 bool isVariadic = region->getVar()->isVariadic();
1488 body << llvm::formatv(Fmt: isVariadic ? regionListParserCode : regionParserCode,
1489 Vals: region->getVar()->name);
1490 if (hasImplicitTermTrait)
1491 body << llvm::formatv(Fmt: isVariadic ? regionListEnsureTerminatorParserCode
1492 : regionEnsureTerminatorParserCode,
1493 Vals: region->getVar()->name);
1494 else if (hasSingleBlockTrait)
1495 body << llvm::formatv(Fmt: isVariadic ? regionListEnsureSingleBlockParserCode
1496 : regionEnsureSingleBlockParserCode,
1497 Vals: region->getVar()->name);
1498
1499 } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) {
1500 bool isVariadic = successor->getVar()->isVariadic();
1501 body << formatv(Fmt: isVariadic ? successorListParserCode : successorParserCode,
1502 Vals: successor->getVar()->name);
1503
1504 /// Directives.
1505 } else if (auto *attrDict = dyn_cast<AttrDictDirective>(Val: element)) {
1506 body.indent() << "{\n";
1507 body.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n"
1508 << "if (parser.parseOptionalAttrDict"
1509 << (attrDict->isWithKeyword() ? "WithKeyword" : "")
1510 << "(result.attributes))\n"
1511 << " return ::mlir::failure();\n";
1512 if (useProperties) {
1513 body << "if (failed(verifyInherentAttrs(result.name, result.attributes, "
1514 "[&]() {\n"
1515 << " return parser.emitError(loc) << \"'\" << "
1516 "result.name.getStringRef() << \"' op \";\n"
1517 << " })))\n"
1518 << " return ::mlir::failure();\n";
1519 }
1520 body.unindent() << "}\n";
1521 body.unindent();
1522 } else if (isa<PropDictDirective>(Val: element)) {
1523 body << " if (parseProperties(parser, result))\n"
1524 << " return ::mlir::failure();\n";
1525 } else if (auto *customDir = dyn_cast<CustomDirective>(Val: element)) {
1526 genCustomDirectiveParser(dir: customDir, body, useProperties, opCppClassName);
1527 } else if (isa<OperandsDirective>(Val: element)) {
1528 body << " [[maybe_unused]] ::llvm::SMLoc allOperandLoc ="
1529 << " parser.getCurrentLocation();\n"
1530 << " if (parser.parseOperandList(allOperands))\n"
1531 << " return ::mlir::failure();\n";
1532
1533 } else if (isa<RegionsDirective>(Val: element)) {
1534 body << llvm::formatv(Fmt: regionListParserCode, Vals: "full");
1535 if (hasImplicitTermTrait)
1536 body << llvm::formatv(Fmt: regionListEnsureTerminatorParserCode, Vals: "full");
1537 else if (hasSingleBlockTrait)
1538 body << llvm::formatv(Fmt: regionListEnsureSingleBlockParserCode, Vals: "full");
1539
1540 } else if (isa<SuccessorsDirective>(Val: element)) {
1541 body << llvm::formatv(Fmt: successorListParserCode, Vals: "full");
1542
1543 } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) {
1544 ArgumentLengthKind lengthKind;
1545 StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind);
1546 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1547 body << llvm::formatv(Fmt: variadicOfVariadicTypeParserCode, Vals&: listName);
1548 } else if (lengthKind == ArgumentLengthKind::Variadic) {
1549 body << llvm::formatv(Fmt: variadicTypeParserCode, Vals&: listName);
1550 } else if (lengthKind == ArgumentLengthKind::Optional) {
1551 body << llvm::formatv(Fmt: optionalTypeParserCode, Vals&: listName);
1552 } else {
1553 const char *parserCode =
1554 dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
1555 TypeSwitch<FormatElement *>(dir->getArg())
1556 .Case<OperandVariable, ResultVariable>(caseFn: [&](auto operand) {
1557 body << formatv(parserCode,
1558 operand->getVar()->constraint.getCPPClassName(),
1559 listName);
1560 })
1561 .Default(defaultFn: [&](auto operand) {
1562 body << formatv(Fmt: parserCode, Vals: "::mlir::Type", Vals&: listName);
1563 });
1564 }
1565 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(Val: element)) {
1566 ArgumentLengthKind ignored;
1567 body << formatv(Fmt: functionalTypeParserCode,
1568 Vals: getTypeListName(arg: dir->getInputs(), lengthKind&: ignored),
1569 Vals: getTypeListName(arg: dir->getResults(), lengthKind&: ignored));
1570 } else {
1571 llvm_unreachable("unknown format element");
1572 }
1573}
1574
1575void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
1576 // If any of type resolutions use transformed variables, make sure that the
1577 // types of those variables are resolved.
1578 SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
1579 FmtContext verifierFCtx;
1580 for (TypeResolution &resolver :
1581 llvm::concat<TypeResolution>(Ranges&: resultTypes, Ranges&: operandTypes)) {
1582 std::optional<StringRef> transformer = resolver.getVarTransformer();
1583 if (!transformer)
1584 continue;
1585 // Ensure that we don't verify the same variables twice.
1586 const NamedTypeConstraint *variable = resolver.getVariable();
1587 if (!variable || !verifiedVariables.insert(Ptr: variable).second)
1588 continue;
1589
1590 auto constraint = variable->constraint;
1591 body << " for (::mlir::Type type : " << variable->name << "Types) {\n"
1592 << " (void)type;\n"
1593 << " if (!("
1594 << tgfmt(fmt: constraint.getConditionTemplate(),
1595 ctx: &verifierFCtx.withSelf(subst: "type"))
1596 << ")) {\n"
1597 << formatv(Fmt: " return parser.emitError(parser.getNameLoc()) << "
1598 "\"'{0}' must be {1}, but got \" << type;\n",
1599 Vals: variable->name, Vals: constraint.getSummary())
1600 << " }\n"
1601 << " }\n";
1602 }
1603
1604 // Initialize the set of buildable types.
1605 if (!buildableTypes.empty()) {
1606 FmtContext typeBuilderCtx;
1607 typeBuilderCtx.withBuilder(subst: "parser.getBuilder()");
1608 for (auto &it : buildableTypes)
1609 body << " ::mlir::Type odsBuildableType" << it.second << " = "
1610 << tgfmt(fmt: it.first, ctx: &typeBuilderCtx) << ";\n";
1611 }
1612
1613 // Emit the code necessary for a type resolver.
1614 auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
1615 if (std::optional<int> val = resolver.getBuilderIdx()) {
1616 body << "odsBuildableType" << *val;
1617 } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
1618 if (std::optional<StringRef> tform = resolver.getVarTransformer()) {
1619 FmtContext fmtContext;
1620 fmtContext.addSubst(placeholder: "_ctxt", subst: "parser.getContext()");
1621 if (var->isVariadic())
1622 fmtContext.withSelf(subst: var->name + "Types");
1623 else
1624 fmtContext.withSelf(subst: var->name + "Types[0]");
1625 body << tgfmt(fmt: *tform, ctx: &fmtContext);
1626 } else {
1627 body << var->name << "Types";
1628 if (!var->isVariadic())
1629 body << "[0]";
1630 }
1631 } else if (const NamedAttribute *attr = resolver.getAttribute()) {
1632 if (std::optional<StringRef> tform = resolver.getVarTransformer())
1633 body << tgfmt(fmt: *tform,
1634 ctx: &FmtContext().withSelf(subst: attr->name + "Attr.getType()"));
1635 else
1636 body << attr->name << "Attr.getType()";
1637 } else {
1638 body << curVar << "Types";
1639 }
1640 };
1641
1642 // Resolve each of the result types.
1643 if (!infersResultTypes) {
1644 if (allResultTypes) {
1645 body << " result.addTypes(allResultTypes);\n";
1646 } else {
1647 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
1648 body << " result.addTypes(";
1649 emitTypeResolver(resultTypes[i], op.getResultName(index: i));
1650 body << ");\n";
1651 }
1652 }
1653 }
1654
1655 // Emit the operand type resolutions.
1656 genParserOperandTypeResolution(op, body, emitTypeResolver);
1657
1658 // Handle return type inference once all operands have been resolved
1659 if (infersResultTypes)
1660 body << formatv(Fmt: inferReturnTypesParserCode, Vals: op.getCppClassName());
1661}
1662
1663void OperationFormat::genParserOperandTypeResolution(
1664 Operator &op, MethodBody &body,
1665 function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
1666 // Early exit if there are no operands.
1667 if (op.getNumOperands() == 0)
1668 return;
1669
1670 // Handle the case where all operand types are grouped together with
1671 // "types(operands)".
1672 if (allOperandTypes) {
1673 // If `operands` was specified, use the full operand list directly.
1674 if (allOperands) {
1675 body << " if (parser.resolveOperands(allOperands, allOperandTypes, "
1676 "allOperandLoc, result.operands))\n"
1677 " return ::mlir::failure();\n";
1678 return;
1679 }
1680
1681 // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1682 // llvm::concat does not allow the case of a single range, so guard it here.
1683 body << " if (parser.resolveOperands(";
1684 if (op.getNumOperands() > 1) {
1685 body << "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(";
1686 llvm::interleaveComma(c: op.getOperands(), os&: body, each_fn: [&](auto &operand) {
1687 body << operand.name << "Operands";
1688 });
1689 body << ")";
1690 } else {
1691 body << op.operand_begin()->name << "Operands";
1692 }
1693 body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1694 << " return ::mlir::failure();\n";
1695 return;
1696 }
1697
1698 // Handle the case where all operands are grouped together with "operands".
1699 if (allOperands) {
1700 body << " if (parser.resolveOperands(allOperands, ";
1701
1702 // Group all of the operand types together to perform the resolution all at
1703 // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1704 // the case of a single range, so guard it here.
1705 if (op.getNumOperands() > 1) {
1706 body << "::llvm::concat<const ::mlir::Type>(";
1707 llvm::interleaveComma(
1708 c: llvm::seq<int>(Begin: 0, End: op.getNumOperands()), os&: body, each_fn: [&](int i) {
1709 body << "::llvm::ArrayRef<::mlir::Type>(";
1710 emitTypeResolver(operandTypes[i], op.getOperand(index: i).name);
1711 body << ")";
1712 });
1713 body << ")";
1714 } else {
1715 emitTypeResolver(operandTypes.front(), op.getOperand(index: 0).name);
1716 }
1717
1718 body << ", allOperandLoc, result.operands))\n return "
1719 "::mlir::failure();\n";
1720 return;
1721 }
1722
1723 // The final case is the one where each of the operands types are resolved
1724 // separately.
1725 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
1726 NamedTypeConstraint &operand = op.getOperand(index: i);
1727 body << " if (parser.resolveOperands(" << operand.name << "Operands, ";
1728
1729 // Resolve the type of this operand.
1730 TypeResolution &operandType = operandTypes[i];
1731 emitTypeResolver(operandType, operand.name);
1732
1733 body << ", " << operand.name
1734 << "OperandsLoc, result.operands))\n return ::mlir::failure();\n";
1735 }
1736}
1737
1738void OperationFormat::genParserRegionResolution(Operator &op,
1739 MethodBody &body) {
1740 // Check for the case where all regions were parsed.
1741 bool hasAllRegions = llvm::any_of(
1742 Range&: elements, P: [](FormatElement *elt) { return isa<RegionsDirective>(Val: elt); });
1743 if (hasAllRegions) {
1744 body << " result.addRegions(fullRegions);\n";
1745 return;
1746 }
1747
1748 // Otherwise, handle each region individually.
1749 for (const NamedRegion &region : op.getRegions()) {
1750 if (region.isVariadic())
1751 body << " result.addRegions(" << region.name << "Regions);\n";
1752 else
1753 body << " result.addRegion(std::move(" << region.name << "Region));\n";
1754 }
1755}
1756
1757void OperationFormat::genParserSuccessorResolution(Operator &op,
1758 MethodBody &body) {
1759 // Check for the case where all successors were parsed.
1760 bool hasAllSuccessors = llvm::any_of(Range&: elements, P: [](FormatElement *elt) {
1761 return isa<SuccessorsDirective>(Val: elt);
1762 });
1763 if (hasAllSuccessors) {
1764 body << " result.addSuccessors(fullSuccessors);\n";
1765 return;
1766 }
1767
1768 // Otherwise, handle each successor individually.
1769 for (const NamedSuccessor &successor : op.getSuccessors()) {
1770 if (successor.isVariadic())
1771 body << " result.addSuccessors(" << successor.name << "Successors);\n";
1772 else
1773 body << " result.addSuccessors(" << successor.name << "Successor);\n";
1774 }
1775}
1776
1777void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
1778 MethodBody &body) {
1779 if (!allOperands) {
1780 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
1781 auto interleaveFn = [&](const NamedTypeConstraint &operand) {
1782 // If the operand is variadic emit the parsed size.
1783 if (operand.isVariableLength())
1784 body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
1785 else
1786 body << "1";
1787 };
1788 if (op.getDialect().usePropertiesForAttributes()) {
1789 body << "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1790 llvm::interleaveComma(c: op.getOperands(), os&: body, each_fn: interleaveFn);
1791 body << formatv(Fmt: "}), "
1792 "result.getOrAddProperties<{0}::Properties>()."
1793 "operandSegmentSizes.begin());\n",
1794 Vals: op.getCppClassName());
1795 } else {
1796 body << " result.addAttribute(\"operandSegmentSizes\", "
1797 << "parser.getBuilder().getDenseI32ArrayAttr({";
1798 llvm::interleaveComma(c: op.getOperands(), os&: body, each_fn: interleaveFn);
1799 body << "}));\n";
1800 }
1801 }
1802 for (const NamedTypeConstraint &operand : op.getOperands()) {
1803 if (!operand.isVariadicOfVariadic())
1804 continue;
1805 if (op.getDialect().usePropertiesForAttributes()) {
1806 body << llvm::formatv(
1807 Fmt: " result.getOrAddProperties<{0}::Properties>().{1} = "
1808 "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
1809 Vals: op.getCppClassName(),
1810 Vals: operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
1811 Vals: operand.name);
1812 } else {
1813 body << llvm::formatv(
1814 Fmt: " result.addAttribute(\"{0}\", "
1815 "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
1816 "\n",
1817 Vals: operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
1818 Vals: operand.name);
1819 }
1820 }
1821 }
1822
1823 if (!allResultTypes &&
1824 op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments")) {
1825 auto interleaveFn = [&](const NamedTypeConstraint &result) {
1826 // If the result is variadic emit the parsed size.
1827 if (result.isVariableLength())
1828 body << "static_cast<int32_t>(" << result.name << "Types.size())";
1829 else
1830 body << "1";
1831 };
1832 if (op.getDialect().usePropertiesForAttributes()) {
1833 body << "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1834 llvm::interleaveComma(c: op.getResults(), os&: body, each_fn: interleaveFn);
1835 body << formatv(Fmt: "}), "
1836 "result.getOrAddProperties<{0}::Properties>()."
1837 "resultSegmentSizes.begin());\n",
1838 Vals: op.getCppClassName());
1839 } else {
1840 body << " result.addAttribute(\"resultSegmentSizes\", "
1841 << "parser.getBuilder().getDenseI32ArrayAttr({";
1842 llvm::interleaveComma(c: op.getResults(), os&: body, each_fn: interleaveFn);
1843 body << "}));\n";
1844 }
1845 }
1846}
1847
1848//===----------------------------------------------------------------------===//
1849// PrinterGen
1850
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait. The terminator
/// is only printed when it carries attributes, operands, or results.
///
/// {0}: The name of the region.
///
/// The pointer itself is const so the snippet is immutable and has internal
/// linkage.
const char *const regionSingleBlockImplicitTerminatorPrinterCode = R"(
 {
 bool printTerminator = true;
 if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
 printTerminator = !term->getAttrDictionary().empty() ||
 term->getNumOperands() != 0 ||
 term->getNumResults() != 0;
 }
 _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
 /*printBlockTerminators=*/printTerminator);
 }
)";
1867
/// The code snippet used to generate a printer call for an enum that has cases
/// that can't be represented with a keyword. It opens a scope that binds
/// `caseValue` and its string form `caseValueStr`; the caller closes it.
///
/// {0}: The name of the enum attribute.
/// {1}: The name of the enum attributes symbolToString function.
///
/// The pointer itself is const so the snippet is immutable and has internal
/// linkage.
const char *const enumAttrBeginPrinterCode = R"(
 {
 auto caseValue = {0}();
 auto caseValueStr = {1}(caseValue);
)";
1878
1879/// Generate the printer for the 'prop-dict' directive.
1880static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
1881 MethodBody &body) {
1882 body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedProps;\n";
1883 for (const NamedProperty *namedProperty : fmt.usedProperties)
1884 body << " elidedProps.push_back(\"" << namedProperty->name << "\");\n";
1885 for (const NamedAttribute *namedAttr : fmt.usedAttributes)
1886 body << " elidedProps.push_back(\"" << namedAttr->name << "\");\n";
1887
1888 // Add code to check attributes for equality with the default value
1889 // for attributes with the elidePrintingDefaultValue bit set.
1890 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1891 const Attribute &attr = namedAttr.attr;
1892 if (!attr.isDerivedAttr() && attr.hasDefaultValue()) {
1893 const StringRef &name = namedAttr.name;
1894 FmtContext fctx;
1895 fctx.withBuilder(subst: "odsBuilder");
1896 std::string defaultValue = std::string(
1897 tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, vals: attr.getDefaultValue()));
1898 body << " {\n";
1899 body << " ::mlir::Builder odsBuilder(getContext());\n";
1900 body << " ::mlir::Attribute attr = " << op.getGetterName(name)
1901 << "Attr();\n";
1902 body << " if(attr && (attr == " << defaultValue << "))\n";
1903 body << " elidedProps.push_back(\"" << name << "\");\n";
1904 body << " }\n";
1905 }
1906 }
1907
1908 body << " _odsPrinter << \" \";\n"
1909 << " printProperties(this->getContext(), _odsPrinter, "
1910 "getProperties(), elidedProps);\n";
1911}
1912
1913/// Generate the printer for the 'attr-dict' directive.
1914static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
1915 MethodBody &body, bool withKeyword) {
1916 body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;\n";
1917 // Elide the variadic segment size attributes if necessary.
1918 if (!fmt.allOperands &&
1919 op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments"))
1920 body << " elidedAttrs.push_back(\"operandSegmentSizes\");\n";
1921 if (!fmt.allResultTypes &&
1922 op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments"))
1923 body << " elidedAttrs.push_back(\"resultSegmentSizes\");\n";
1924 for (const StringRef key : fmt.inferredAttributes.keys())
1925 body << " elidedAttrs.push_back(\"" << key << "\");\n";
1926 for (const NamedAttribute *attr : fmt.usedAttributes)
1927 body << " elidedAttrs.push_back(\"" << attr->name << "\");\n";
1928 // Add code to check attributes for equality with the default value
1929 // for attributes with the elidePrintingDefaultValue bit set.
1930 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1931 const Attribute &attr = namedAttr.attr;
1932 if (!attr.isDerivedAttr() && attr.hasDefaultValue()) {
1933 const StringRef &name = namedAttr.name;
1934 FmtContext fctx;
1935 fctx.withBuilder(subst: "odsBuilder");
1936 std::string defaultValue = std::string(
1937 tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, vals: attr.getDefaultValue()));
1938 body << " {\n";
1939 body << " ::mlir::Builder odsBuilder(getContext());\n";
1940 body << " ::mlir::Attribute attr = " << op.getGetterName(name)
1941 << "Attr();\n";
1942 body << " if(attr && (attr == " << defaultValue << "))\n";
1943 body << " elidedAttrs.push_back(\"" << name << "\");\n";
1944 body << " }\n";
1945 }
1946 }
1947 if (fmt.hasPropDict)
1948 body << " _odsPrinter.printOptionalAttrDict"
1949 << (withKeyword ? "WithKeyword" : "")
1950 << "(llvm::to_vector((*this)->getDiscardableAttrs()), elidedAttrs);\n";
1951 else
1952 body << " _odsPrinter.printOptionalAttrDict"
1953 << (withKeyword ? "WithKeyword" : "")
1954 << "((*this)->getAttrs(), elidedAttrs);\n";
1955}
1956
1957/// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1958/// space should be emitted before this element. `lastWasPunctuation` is true if
1959/// the previous element was a punctuation literal.
1960static void genLiteralPrinter(StringRef value, MethodBody &body,
1961 bool &shouldEmitSpace, bool &lastWasPunctuation) {
1962 body << " _odsPrinter";
1963
1964 // Don't insert a space for certain punctuation.
1965 if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
1966 body << " << ' '";
1967 body << " << \"" << value << "\";\n";
1968
1969 // Insert a space after certain literals.
1970 shouldEmitSpace =
1971 value.size() != 1 || !StringRef("<({[").contains(C: value.front());
1972 lastWasPunctuation = value.front() != '_' && !isalpha(value.front());
1973}
1974
1975/// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1976/// are set to false.
1977static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace,
1978 bool &lastWasPunctuation) {
1979 if (value) {
1980 body << " _odsPrinter << ' ';\n";
1981 lastWasPunctuation = false;
1982 } else {
1983 lastWasPunctuation = true;
1984 }
1985 shouldEmitSpace = false;
1986}
1987
1988/// Generate the printer for a custom directive parameter.
1989static void genCustomDirectiveParameterPrinter(FormatElement *element,
1990 const Operator &op,
1991 MethodBody &body) {
1992 if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) {
1993 body << op.getGetterName(name: attr->getVar()->name) << "Attr()";
1994
1995 } else if (isa<AttrDictDirective>(Val: element)) {
1996 body << "getOperation()->getAttrDictionary()";
1997
1998 } else if (isa<PropDictDirective>(Val: element)) {
1999 body << "getProperties()";
2000
2001 } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) {
2002 body << op.getGetterName(name: operand->getVar()->name) << "()";
2003
2004 } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) {
2005 body << op.getGetterName(name: region->getVar()->name) << "()";
2006
2007 } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) {
2008 body << op.getGetterName(name: successor->getVar()->name) << "()";
2009
2010 } else if (auto *dir = dyn_cast<RefDirective>(Val: element)) {
2011 genCustomDirectiveParameterPrinter(element: dir->getArg(), op, body);
2012
2013 } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) {
2014 auto *typeOperand = dir->getArg();
2015 auto *operand = dyn_cast<OperandVariable>(Val: typeOperand);
2016 auto *var = operand ? operand->getVar()
2017 : cast<ResultVariable>(Val: typeOperand)->getVar();
2018 std::string name = op.getGetterName(name: var->name);
2019 if (var->isVariadic())
2020 body << name << "().getTypes()";
2021 else if (var->isOptional())
2022 body << llvm::formatv(Fmt: "({0}() ? {0}().getType() : ::mlir::Type())", Vals&: name);
2023 else
2024 body << name << "().getType()";
2025
2026 } else if (auto *string = dyn_cast<StringElement>(Val: element)) {
2027 FmtContext ctx;
2028 ctx.withBuilder(subst: "::mlir::Builder(getContext())");
2029 ctx.addSubst(placeholder: "_ctxt", subst: "getContext()");
2030 body << tgfmt(fmt: string->getValue(), ctx: &ctx);
2031
2032 } else if (auto *property = dyn_cast<PropertyVariable>(Val: element)) {
2033 FmtContext ctx;
2034 ctx.addSubst(placeholder: "_ctxt", subst: "getContext()");
2035 const NamedProperty *namedProperty = property->getVar();
2036 ctx.addSubst(placeholder: "_storage", subst: "getProperties()." + namedProperty->name);
2037 body << tgfmt(fmt: namedProperty->prop.getConvertFromStorageCall(), ctx: &ctx);
2038 } else {
2039 llvm_unreachable("unknown custom directive parameter");
2040 }
2041}
2042
2043/// Generate the printer for a custom directive.
2044static void genCustomDirectivePrinter(CustomDirective *customDir,
2045 const Operator &op, MethodBody &body) {
2046 body << " print" << customDir->getName() << "(_odsPrinter, *this";
2047 for (FormatElement *param : customDir->getArguments()) {
2048 body << ", ";
2049 genCustomDirectiveParameterPrinter(element: param, op, body);
2050 }
2051 body << ");\n";
2052}
2053
2054/// Generate the printer for a region with the given variable name.
2055static void genRegionPrinter(const Twine &regionName, MethodBody &body,
2056 bool hasImplicitTermTrait) {
2057 if (hasImplicitTermTrait)
2058 body << llvm::formatv(Fmt: regionSingleBlockImplicitTerminatorPrinterCode,
2059 Vals: regionName);
2060 else
2061 body << " _odsPrinter.printRegion(" << regionName << ");\n";
2062}
2063static void genVariadicRegionPrinter(const Twine &regionListName,
2064 MethodBody &body,
2065 bool hasImplicitTermTrait) {
2066 body << " llvm::interleaveComma(" << regionListName
2067 << ", _odsPrinter, [&](::mlir::Region &region) {\n ";
2068 genRegionPrinter(regionName: "region", body, hasImplicitTermTrait);
2069 body << " });\n";
2070}
2071
2072/// Generate the C++ for an operand to a (*-)type directive.
2073static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
2074 MethodBody &body,
2075 bool useArrayRef = true) {
2076 if (isa<OperandsDirective>(Val: arg))
2077 return body << "getOperation()->getOperandTypes()";
2078 if (isa<ResultsDirective>(Val: arg))
2079 return body << "getOperation()->getResultTypes()";
2080 auto *operand = dyn_cast<OperandVariable>(Val: arg);
2081 auto *var = operand ? operand->getVar() : cast<ResultVariable>(Val: arg)->getVar();
2082 if (var->isVariadicOfVariadic())
2083 return body << llvm::formatv(Fmt: "{0}().join().getTypes()",
2084 Vals: op.getGetterName(name: var->name));
2085 if (var->isVariadic())
2086 return body << op.getGetterName(name: var->name) << "().getTypes()";
2087 if (var->isOptional())
2088 return body << llvm::formatv(
2089 Fmt: "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
2090 "::llvm::ArrayRef<::mlir::Type>())",
2091 Vals: op.getGetterName(name: var->name));
2092 if (useArrayRef)
2093 return body << "::llvm::ArrayRef<::mlir::Type>("
2094 << op.getGetterName(name: var->name) << "().getType())";
2095 return body << op.getGetterName(name: var->name) << "().getType()";
2096}
2097
2098/// Generate the printer for an enum attribute.
2099static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
2100 MethodBody &body) {
2101 Attribute baseAttr = var->attr.getBaseAttr();
2102 const EnumAttr &enumAttr = cast<EnumAttr>(Val&: baseAttr);
2103 std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
2104
2105 body << llvm::formatv(Fmt: enumAttrBeginPrinterCode,
2106 Vals: (var->attr.isOptional() ? "*" : "") +
2107 op.getGetterName(name: var->name),
2108 Vals: enumAttr.getSymbolToStringFnName());
2109
2110 // Get a string containing all of the cases that can't be represented with a
2111 // keyword.
2112 BitVector nonKeywordCases(cases.size());
2113 for (auto it : llvm::enumerate(First&: cases)) {
2114 if (!canFormatStringAsKeyword(value: it.value().getStr()))
2115 nonKeywordCases.set(it.index());
2116 }
2117
2118 // Otherwise if this is a bit enum attribute, don't allow cases that may
2119 // overlap with other cases. For simplicity sake, only allow cases with a
2120 // single bit value.
2121 if (enumAttr.isBitEnum()) {
2122 for (auto it : llvm::enumerate(First&: cases)) {
2123 int64_t value = it.value().getValue();
2124 if (value < 0 || !llvm::isPowerOf2_64(Value: value))
2125 nonKeywordCases.set(it.index());
2126 }
2127 }
2128
2129 // If there are any cases that can't be used with a keyword, switch on the
2130 // case value to determine when to print in the string form.
2131 if (nonKeywordCases.any()) {
2132 body << " switch (caseValue) {\n";
2133 StringRef cppNamespace = enumAttr.getCppNamespace();
2134 StringRef enumName = enumAttr.getEnumClassName();
2135 for (auto it : llvm::enumerate(First&: cases)) {
2136 if (nonKeywordCases.test(Idx: it.index()))
2137 continue;
2138 StringRef symbol = it.value().getSymbol();
2139 body << llvm::formatv(Fmt: " case {0}::{1}::{2}:\n", Vals&: cppNamespace, Vals&: enumName,
2140 Vals: llvm::isDigit(C: symbol.front()) ? ("_" + symbol)
2141 : symbol);
2142 }
2143 body << " _odsPrinter << caseValueStr;\n"
2144 " break;\n"
2145 " default:\n"
2146 " _odsPrinter << '\"' << caseValueStr << '\"';\n"
2147 " break;\n"
2148 " }\n"
2149 " }\n";
2150 return;
2151 }
2152
2153 body << " _odsPrinter << caseValueStr;\n"
2154 " }\n";
2155}
2156
2157/// Generate a check that a DefaultValuedAttr has a value that is non-default.
2158static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
2159 AttributeVariable &attrElement) {
2160 FmtContext fctx;
2161 Attribute attr = attrElement.getVar()->attr;
2162 fctx.withBuilder(subst: "::mlir::OpBuilder((*this)->getContext())");
2163 body << " && " << op.getGetterName(name: attrElement.getVar()->name) << "Attr() != "
2164 << tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, vals: attr.getDefaultValue());
2165}
2166
2167/// Generate the check for the anchor of an optional group.
2168static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
2169 const Operator &op,
2170 MethodBody &body) {
2171 TypeSwitch<FormatElement *>(anchor)
2172 .Case<OperandVariable, ResultVariable>(caseFn: [&](auto *element) {
2173 const NamedTypeConstraint *var = element->getVar();
2174 std::string name = op.getGetterName(name: var->name);
2175 if (var->isOptional())
2176 body << name << "()";
2177 else if (var->isVariadic())
2178 body << "!" << name << "().empty()";
2179 })
2180 .Case(caseFn: [&](RegionVariable *element) {
2181 const NamedRegion *var = element->getVar();
2182 std::string name = op.getGetterName(name: var->name);
2183 // TODO: Add a check for optional regions here when ODS supports it.
2184 body << "!" << name << "().empty()";
2185 })
2186 .Case(caseFn: [&](TypeDirective *element) {
2187 genOptionalGroupPrinterAnchor(anchor: element->getArg(), op, body);
2188 })
2189 .Case(caseFn: [&](FunctionalTypeDirective *element) {
2190 genOptionalGroupPrinterAnchor(anchor: element->getInputs(), op, body);
2191 })
2192 .Case(caseFn: [&](AttributeVariable *element) {
2193 Attribute attr = element->getVar()->attr;
2194 body << op.getGetterName(name: element->getVar()->name) << "Attr()";
2195 if (attr.isOptional())
2196 return; // done
2197 if (attr.hasDefaultValue()) {
2198 // Consider a default-valued attribute as present if it's not the
2199 // default value.
2200 genNonDefaultValueCheck(body, op, attrElement&: *element);
2201 return;
2202 }
2203 llvm_unreachable("attribute must be optional or default-valued");
2204 })
2205 .Case(caseFn: [&](CustomDirective *ele) {
2206 body << '(';
2207 llvm::interleave(
2208 c: ele->getArguments(), os&: body,
2209 each_fn: [&](FormatElement *child) {
2210 body << '(';
2211 genOptionalGroupPrinterAnchor(anchor: child, op, body);
2212 body << ')';
2213 },
2214 separator: " || ");
2215 body << ')';
2216 });
2217}
2218
2219void collect(FormatElement *element,
2220 SmallVectorImpl<VariableElement *> &variables) {
2221 TypeSwitch<FormatElement *>(element)
2222 .Case(caseFn: [&](VariableElement *var) { variables.emplace_back(Args&: var); })
2223 .Case(caseFn: [&](CustomDirective *ele) {
2224 for (FormatElement *arg : ele->getArguments())
2225 collect(element: arg, variables);
2226 })
2227 .Case(caseFn: [&](OptionalElement *ele) {
2228 for (FormatElement *arg : ele->getThenElements())
2229 collect(element: arg, variables);
2230 for (FormatElement *arg : ele->getElseElements())
2231 collect(element: arg, variables);
2232 })
2233 .Case(caseFn: [&](FunctionalTypeDirective *funcType) {
2234 collect(element: funcType->getInputs(), variables);
2235 collect(element: funcType->getResults(), variables);
2236 })
2237 .Case(caseFn: [&](OIListElement *oilist) {
2238 for (ArrayRef<FormatElement *> arg : oilist->getParsingElements())
2239 for (FormatElement *arg : arg)
2240 collect(element: arg, variables);
2241 });
2242}
2243
2244void OperationFormat::genElementPrinter(FormatElement *element,
2245 MethodBody &body, Operator &op,
2246 bool &shouldEmitSpace,
2247 bool &lastWasPunctuation) {
2248 if (LiteralElement *literal = dyn_cast<LiteralElement>(Val: element))
2249 return genLiteralPrinter(value: literal->getSpelling(), body, shouldEmitSpace,
2250 lastWasPunctuation);
2251
2252 // Emit a whitespace element.
2253 if (auto *space = dyn_cast<WhitespaceElement>(Val: element)) {
2254 if (space->getValue() == "\\n") {
2255 body << " _odsPrinter.printNewline();\n";
2256 } else {
2257 genSpacePrinter(value: !space->getValue().empty(), body, shouldEmitSpace,
2258 lastWasPunctuation);
2259 }
2260 return;
2261 }
2262
2263 // Emit an optional group.
2264 if (OptionalElement *optional = dyn_cast<OptionalElement>(Val: element)) {
2265 // Emit the check for the presence of the anchor element.
2266 FormatElement *anchor = optional->getAnchor();
2267 body << " if (";
2268 if (optional->isInverted())
2269 body << "!";
2270 genOptionalGroupPrinterAnchor(anchor, op, body);
2271 body << ") {\n";
2272 body.indent();
2273
2274 // If the anchor is a unit attribute, we don't need to print it. When
2275 // parsing, we will add this attribute if this group is present.
2276 ArrayRef<FormatElement *> thenElements = optional->getThenElements();
2277 ArrayRef<FormatElement *> elseElements = optional->getElseElements();
2278 FormatElement *elidedAnchorElement = nullptr;
2279 auto *anchorAttr = dyn_cast<AttributeVariable>(Val: anchor);
2280 if (anchorAttr && anchorAttr != thenElements.front() &&
2281 (elseElements.empty() || anchorAttr != elseElements.front()) &&
2282 anchorAttr->isUnitAttr()) {
2283 elidedAnchorElement = anchorAttr;
2284 }
2285 auto genElementPrinters = [&](ArrayRef<FormatElement *> elements) {
2286 for (FormatElement *childElement : elements) {
2287 if (childElement != elidedAnchorElement) {
2288 genElementPrinter(element: childElement, body, op, shouldEmitSpace,
2289 lastWasPunctuation);
2290 }
2291 }
2292 };
2293
2294 // Emit each of the elements.
2295 genElementPrinters(thenElements);
2296 body << "}";
2297
2298 // Emit each of the else elements.
2299 if (!elseElements.empty()) {
2300 body << " else {\n";
2301 genElementPrinters(elseElements);
2302 body << "}";
2303 }
2304
2305 body.unindent() << "\n";
2306 return;
2307 }
2308
2309 // Emit the OIList
2310 if (auto *oilist = dyn_cast<OIListElement>(Val: element)) {
2311 for (auto clause : oilist->getClauses()) {
2312 LiteralElement *lelement = std::get<0>(t&: clause);
2313 ArrayRef<FormatElement *> pelement = std::get<1>(t&: clause);
2314
2315 SmallVector<VariableElement *> vars;
2316 for (FormatElement *el : pelement)
2317 collect(element: el, variables&: vars);
2318 body << " if (false";
2319 for (VariableElement *var : vars) {
2320 TypeSwitch<FormatElement *>(var)
2321 .Case(caseFn: [&](AttributeVariable *attrEle) {
2322 body << " || (" << op.getGetterName(name: attrEle->getVar()->name)
2323 << "Attr()";
2324 Attribute attr = attrEle->getVar()->attr;
2325 if (attr.hasDefaultValue()) {
2326 // Don't print default-valued attributes.
2327 genNonDefaultValueCheck(body, op, attrElement&: *attrEle);
2328 }
2329 body << ")";
2330 })
2331 .Case(caseFn: [&](OperandVariable *ele) {
2332 if (ele->getVar()->isVariadic()) {
2333 body << " || " << op.getGetterName(name: ele->getVar()->name)
2334 << "().size()";
2335 } else {
2336 body << " || " << op.getGetterName(name: ele->getVar()->name) << "()";
2337 }
2338 })
2339 .Case(caseFn: [&](ResultVariable *ele) {
2340 if (ele->getVar()->isVariadic()) {
2341 body << " || " << op.getGetterName(name: ele->getVar()->name)
2342 << "().size()";
2343 } else {
2344 body << " || " << op.getGetterName(name: ele->getVar()->name) << "()";
2345 }
2346 })
2347 .Case(caseFn: [&](RegionVariable *reg) {
2348 body << " || " << op.getGetterName(name: reg->getVar()->name) << "()";
2349 });
2350 }
2351
2352 body << ") {\n";
2353 genLiteralPrinter(value: lelement->getSpelling(), body, shouldEmitSpace,
2354 lastWasPunctuation);
2355 if (oilist->getUnitAttrParsingElement(pelement) == nullptr) {
2356 for (FormatElement *element : pelement)
2357 genElementPrinter(element, body, op, shouldEmitSpace,
2358 lastWasPunctuation);
2359 }
2360 body << " }\n";
2361 }
2362 return;
2363 }
2364
2365 // Emit the attribute dictionary.
2366 if (auto *attrDict = dyn_cast<AttrDictDirective>(Val: element)) {
2367 genAttrDictPrinter(fmt&: *this, op, body, withKeyword: attrDict->isWithKeyword());
2368 lastWasPunctuation = false;
2369 return;
2370 }
2371
2372 // Emit the attribute dictionary.
2373 if (isa<PropDictDirective>(Val: element)) {
2374 genPropDictPrinter(fmt&: *this, op, body);
2375 lastWasPunctuation = false;
2376 return;
2377 }
2378
2379 // Optionally insert a space before the next element. The AttrDict printer
2380 // already adds a space as necessary.
2381 if (shouldEmitSpace || !lastWasPunctuation)
2382 body << " _odsPrinter << ' ';\n";
2383 lastWasPunctuation = false;
2384 shouldEmitSpace = true;
2385
2386 if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) {
2387 const NamedAttribute *var = attr->getVar();
2388
2389 // If we are formatting as an enum, symbolize the attribute as a string.
2390 if (canFormatEnumAttr(attr: var))
2391 return genEnumAttrPrinter(var, op, body);
2392
2393 // If we are formatting as a symbol name, handle it as a symbol name.
2394 if (shouldFormatSymbolNameAttr(attr: var)) {
2395 body << " _odsPrinter.printSymbolName(" << op.getGetterName(name: var->name)
2396 << "Attr().getValue());\n";
2397 return;
2398 }
2399
2400 // Elide the attribute type if it is buildable.
2401 if (attr->getTypeBuilder())
2402 body << " _odsPrinter.printAttributeWithoutType("
2403 << op.getGetterName(name: var->name) << "Attr());\n";
2404 else if (attr->shouldBeQualified() ||
2405 var->attr.getStorageType() == "::mlir::Attribute")
2406 body << " _odsPrinter.printAttribute(" << op.getGetterName(name: var->name)
2407 << "Attr());\n";
2408 else
2409 body << "_odsPrinter.printStrippedAttrOrType("
2410 << op.getGetterName(name: var->name) << "Attr());\n";
2411 } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) {
2412 if (operand->getVar()->isVariadicOfVariadic()) {
2413 body << " ::llvm::interleaveComma("
2414 << op.getGetterName(name: operand->getVar()->name)
2415 << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << "
2416 "\"(\" << operands << "
2417 "\")\"; });\n";
2418
2419 } else if (operand->getVar()->isOptional()) {
2420 body << " if (::mlir::Value value = "
2421 << op.getGetterName(name: operand->getVar()->name) << "())\n"
2422 << " _odsPrinter << value;\n";
2423 } else {
2424 body << " _odsPrinter << " << op.getGetterName(name: operand->getVar()->name)
2425 << "();\n";
2426 }
2427 } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) {
2428 const NamedRegion *var = region->getVar();
2429 std::string name = op.getGetterName(name: var->name);
2430 if (var->isVariadic()) {
2431 genVariadicRegionPrinter(regionListName: name + "()", body, hasImplicitTermTrait);
2432 } else {
2433 genRegionPrinter(regionName: name + "()", body, hasImplicitTermTrait);
2434 }
2435 } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) {
2436 const NamedSuccessor *var = successor->getVar();
2437 std::string name = op.getGetterName(name: var->name);
2438 if (var->isVariadic())
2439 body << " ::llvm::interleaveComma(" << name << "(), _odsPrinter);\n";
2440 else
2441 body << " _odsPrinter << " << name << "();\n";
2442 } else if (auto *dir = dyn_cast<CustomDirective>(Val: element)) {
2443 genCustomDirectivePrinter(customDir: dir, op, body);
2444 } else if (isa<OperandsDirective>(Val: element)) {
2445 body << " _odsPrinter << getOperation()->getOperands();\n";
2446 } else if (isa<RegionsDirective>(Val: element)) {
2447 genVariadicRegionPrinter(regionListName: "getOperation()->getRegions()", body,
2448 hasImplicitTermTrait);
2449 } else if (isa<SuccessorsDirective>(Val: element)) {
2450 body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), "
2451 "_odsPrinter);\n";
2452 } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) {
2453 if (auto *operand = dyn_cast<OperandVariable>(Val: dir->getArg())) {
2454 if (operand->getVar()->isVariadicOfVariadic()) {
2455 body << llvm::formatv(
2456 Fmt: " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
2457 "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
2458 "types << \")\"; });\n",
2459 Vals: op.getGetterName(name: operand->getVar()->name));
2460 return;
2461 }
2462 }
2463 const NamedTypeConstraint *var = nullptr;
2464 {
2465 if (auto *operand = dyn_cast<OperandVariable>(Val: dir->getArg()))
2466 var = operand->getVar();
2467 else if (auto *operand = dyn_cast<ResultVariable>(Val: dir->getArg()))
2468 var = operand->getVar();
2469 }
2470 if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
2471 !var->isOptional()) {
2472 std::string cppClass = var->constraint.getCPPClassName();
2473 if (dir->shouldBeQualified()) {
2474 body << " _odsPrinter << " << op.getGetterName(name: var->name)
2475 << "().getType();\n";
2476 return;
2477 }
2478 body << " {\n"
2479 << " auto type = " << op.getGetterName(name: var->name)
2480 << "().getType();\n"
2481 << " if (auto validType = ::llvm::dyn_cast<" << cppClass
2482 << ">(type))\n"
2483 << " _odsPrinter.printStrippedAttrOrType(validType);\n"
2484 << " else\n"
2485 << " _odsPrinter << type;\n"
2486 << " }\n";
2487 return;
2488 }
2489 body << " _odsPrinter << ";
2490 genTypeOperandPrinter(arg: dir->getArg(), op, body, /*useArrayRef=*/false)
2491 << ";\n";
2492 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(Val: element)) {
2493 body << " _odsPrinter.printFunctionalType(";
2494 genTypeOperandPrinter(arg: dir->getInputs(), op, body) << ", ";
2495 genTypeOperandPrinter(arg: dir->getResults(), op, body) << ");\n";
2496 } else {
2497 llvm_unreachable("unknown format element");
2498 }
2499}
2500
2501void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
2502 auto *method = opClass.addMethod(
2503 retType: "void", name: "print",
2504 args: MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
2505 auto &body = method->body();
2506
2507 // Flags for if we should emit a space, and if the last element was
2508 // punctuation.
2509 bool shouldEmitSpace = true, lastWasPunctuation = false;
2510 for (FormatElement *element : elements)
2511 genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation);
2512}
2513
2514//===----------------------------------------------------------------------===//
2515// OpFormatParser
2516//===----------------------------------------------------------------------===//
2517
2518/// Function to find an element within the given range that has the same name as
2519/// 'name'.
2520template <typename RangeT>
2521static auto findArg(RangeT &&range, StringRef name) {
2522 auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
2523 return it != range.end() ? &*it : nullptr;
2524}
2525
2526namespace {
2527/// This class implements a parser for an instance of an operation assembly
2528/// format.
2529class OpFormatParser : public FormatParser {
2530public:
2531 OpFormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
2532 : FormatParser(mgr, op.getLoc()[0]), fmt(format), op(op),
2533 seenOperandTypes(op.getNumOperands()),
2534 seenResultTypes(op.getNumResults()) {}
2535
2536protected:
2537 /// Verify the format elements.
2538 LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
2539 /// Verify the arguments to a custom directive.
2540 LogicalResult
2541 verifyCustomDirectiveArguments(SMLoc loc,
2542 ArrayRef<FormatElement *> arguments) override;
2543 /// Verify the elements of an optional group.
2544 LogicalResult verifyOptionalGroupElements(SMLoc loc,
2545 ArrayRef<FormatElement *> elements,
2546 FormatElement *anchor) override;
2547 LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
2548 bool isAnchor);
2549
2550 /// Parse an operation variable.
2551 FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
2552 Context ctx) override;
2553 /// Parse an operation format directive.
2554 FailureOr<FormatElement *>
2555 parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
2556
2557private:
2558 /// This struct represents a type resolution instance. It includes a specific
2559 /// type as well as an optional transformer to apply to that type in order to
2560 /// properly resolve the type of a variable.
2561 struct TypeResolutionInstance {
2562 ConstArgument resolver;
2563 std::optional<StringRef> transformer;
2564 };
2565
2566 /// Verify the state of operation attributes within the format.
2567 LogicalResult verifyAttributes(SMLoc loc, ArrayRef<FormatElement *> elements);
2568
2569 /// Verify that attributes elements aren't followed by colon literals.
2570 LogicalResult verifyAttributeColonType(SMLoc loc,
2571 ArrayRef<FormatElement *> elements);
2572 /// Verify that the attribute dictionary directive isn't followed by a region.
2573 LogicalResult verifyAttrDictRegion(SMLoc loc,
2574 ArrayRef<FormatElement *> elements);
2575
2576 /// Verify the state of operation operands within the format.
2577 LogicalResult
2578 verifyOperands(SMLoc loc,
2579 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2580
2581 /// Verify the state of operation regions within the format.
2582 LogicalResult verifyRegions(SMLoc loc);
2583
2584 /// Verify the state of operation results within the format.
2585 LogicalResult
2586 verifyResults(SMLoc loc,
2587 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2588
2589 /// Verify the state of operation successors within the format.
2590 LogicalResult verifySuccessors(SMLoc loc);
2591
2592 LogicalResult verifyOIListElements(SMLoc loc,
2593 ArrayRef<FormatElement *> elements);
2594
2595 /// Given the values of an `AllTypesMatch` trait, check for inferable type
2596 /// resolution.
2597 void handleAllTypesMatchConstraint(
2598 ArrayRef<StringRef> values,
2599 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2600 /// Check for inferable type resolution given all operands, and or results,
2601 /// have the same type. If 'includeResults' is true, the results also have the
2602 /// same type as all of the operands.
2603 void handleSameTypesConstraint(
2604 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2605 bool includeResults);
2606 /// Check for inferable type resolution based on another operand, result, or
2607 /// attribute.
2608 void handleTypesMatchConstraint(
2609 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2610 const llvm::Record &def);
2611
2612 /// Returns an argument or attribute with the given name that has been seen
2613 /// within the format.
2614 ConstArgument findSeenArg(StringRef name);
2615
2616 /// Parse the various different directives.
2617 FailureOr<FormatElement *> parsePropDictDirective(SMLoc loc, Context context);
2618 FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context,
2619 bool withKeyword);
2620 FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc,
2621 Context context);
2622 FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
2623 LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
2624 FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
2625 FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc,
2626 Context context);
2627 FailureOr<FormatElement *> parseReferenceDirective(SMLoc loc,
2628 Context context);
2629 FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
2630 FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
2631 FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
2632 Context context);
2633 FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context);
2634 FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc,
2635 bool isRefChild = false);
2636
2637 //===--------------------------------------------------------------------===//
2638 // Fields
2639 //===--------------------------------------------------------------------===//
2640
2641 OperationFormat &fmt;
2642 Operator &op;
2643
2644 // The following are various bits of format state used for verification
2645 // during parsing.
2646 bool hasAttrDict = false;
2647 bool hasPropDict = false;
2648 bool hasAllRegions = false, hasAllSuccessors = false;
2649 bool canInferResultTypes = false;
2650 llvm::SmallBitVector seenOperandTypes, seenResultTypes;
2651 llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
2652 llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
2653 llvm::DenseSet<const NamedRegion *> seenRegions;
2654 llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
2655 llvm::SmallSetVector<const NamedProperty *, 8> seenProperties;
2656};
2657} // namespace
2658
2659LogicalResult OpFormatParser::verify(SMLoc loc,
2660 ArrayRef<FormatElement *> elements) {
2661 // Check that the attribute dictionary is in the format.
2662 if (!hasAttrDict)
2663 return emitError(loc, msg: "'attr-dict' directive not found in "
2664 "custom assembly format");
2665
2666 // Check for any type traits that we can use for inferring types.
2667 llvm::StringMap<TypeResolutionInstance> variableTyResolver;
2668 for (const Trait &trait : op.getTraits()) {
2669 const llvm::Record &def = trait.getDef();
2670 if (def.isSubClassOf(Name: "AllTypesMatch")) {
2671 handleAllTypesMatchConstraint(values: def.getValueAsListOfStrings(FieldName: "values"),
2672 variableTyResolver);
2673 } else if (def.getName() == "SameTypeOperands") {
2674 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
2675 } else if (def.getName() == "SameOperandsAndResultType") {
2676 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
2677 } else if (def.isSubClassOf(Name: "TypesMatchWith")) {
2678 handleTypesMatchConstraint(variableTyResolver, def);
2679 } else if (!op.allResultTypesKnown()) {
2680 // This doesn't check the name directly to handle
2681 // DeclareOpInterfaceMethods<InferTypeOpInterface>
2682 // and the like.
2683 // TODO: Add hasCppInterface check.
2684 if (auto name = def.getValueAsOptionalString(FieldName: "cppInterfaceName")) {
2685 if (*name == "InferTypeOpInterface" &&
2686 def.getValueAsString(FieldName: "cppNamespace") == "::mlir")
2687 canInferResultTypes = true;
2688 }
2689 }
2690 }
2691
2692 // Verify the state of the various operation components.
2693 if (failed(result: verifyAttributes(loc, elements)) ||
2694 failed(result: verifyResults(loc, variableTyResolver)) ||
2695 failed(result: verifyOperands(loc, variableTyResolver)) ||
2696 failed(result: verifyRegions(loc)) || failed(result: verifySuccessors(loc)) ||
2697 failed(result: verifyOIListElements(loc, elements)))
2698 return failure();
2699
2700 // Collect the set of used attributes in the format.
2701 fmt.usedAttributes = std::move(seenAttrs);
2702 fmt.usedProperties = std::move(seenProperties);
2703
2704 // Set whether prop-dict is used in the format
2705 fmt.hasPropDict = hasPropDict;
2706 return success();
2707}
2708
2709LogicalResult
2710OpFormatParser::verifyAttributes(SMLoc loc,
2711 ArrayRef<FormatElement *> elements) {
2712 // Check that there are no `:` literals after an attribute without a constant
2713 // type. The attribute grammar contains an optional trailing colon type, which
2714 // can lead to unexpected and generally unintended behavior. Given that, it is
2715 // better to just error out here instead.
2716 if (failed(result: verifyAttributeColonType(loc, elements)))
2717 return failure();
2718 // Check that there are no region variables following an attribute dicitonary.
2719 // Both start with `{` and so the optional attribute dictionary can cause
2720 // format ambiguities.
2721 if (failed(result: verifyAttrDictRegion(loc, elements)))
2722 return failure();
2723
2724 // Check for VariadicOfVariadic variables. The segment attribute of those
2725 // variables will be infered.
2726 for (const NamedTypeConstraint *var : seenOperands) {
2727 if (var->constraint.isVariadicOfVariadic()) {
2728 fmt.inferredAttributes.insert(
2729 key: var->constraint.getVariadicOfVariadicSegmentSizeAttr());
2730 }
2731 }
2732
2733 return success();
2734}
2735
2736/// Returns whether the single format element is optionally parsed.
2737static bool isOptionallyParsed(FormatElement *el) {
2738 if (auto *attrVar = dyn_cast<AttributeVariable>(Val: el)) {
2739 Attribute attr = attrVar->getVar()->attr;
2740 return attr.isOptional() || attr.hasDefaultValue();
2741 }
2742 if (auto *operandVar = dyn_cast<OperandVariable>(Val: el)) {
2743 const NamedTypeConstraint *operand = operandVar->getVar();
2744 return operand->isOptional() || operand->isVariadic() ||
2745 operand->isVariadicOfVariadic();
2746 }
2747 if (auto *successorVar = dyn_cast<SuccessorVariable>(Val: el))
2748 return successorVar->getVar()->isVariadic();
2749 if (auto *regionVar = dyn_cast<RegionVariable>(Val: el))
2750 return regionVar->getVar()->isVariadic();
2751 return isa<WhitespaceElement, AttrDictDirective>(Val: el);
2752}
2753
2754/// Scan the given range of elements from the start for an invalid format
2755/// element that satisfies `isInvalid`, skipping any optionally-parsed elements.
2756/// If an optional group is encountered, this function recurses into the 'then'
2757/// and 'else' elements to check if they are invalid. Returns `success` if the
2758/// range is known to be valid or `std::nullopt` if scanning reached the end.
2759///
2760/// Since the guard element of an optional group is required, this function
2761/// accepts an optional element pointer to mark it as required.
2762static std::optional<LogicalResult> checkRangeForElement(
2763 FormatElement *base,
2764 function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
2765 iterator_range<ArrayRef<FormatElement *>::iterator> elementRange,
2766 FormatElement *optionalGuard = nullptr) {
2767 for (FormatElement *element : elementRange) {
2768 // If we encounter an invalid element, return an error.
2769 if (isInvalid(base, element))
2770 return failure();
2771
2772 // Recurse on optional groups.
2773 if (auto *optional = dyn_cast<OptionalElement>(Val: element)) {
2774 if (std::optional<LogicalResult> result = checkRangeForElement(
2775 base, isInvalid, elementRange: optional->getThenElements(),
2776 // The optional group guard is required for the group.
2777 optionalGuard: optional->getThenElements().front()))
2778 if (failed(result: *result))
2779 return failure();
2780 if (std::optional<LogicalResult> result = checkRangeForElement(
2781 base, isInvalid, elementRange: optional->getElseElements()))
2782 if (failed(result: *result))
2783 return failure();
2784 // Skip the optional group.
2785 continue;
2786 }
2787
2788 // Skip optionally parsed elements.
2789 if (element != optionalGuard && isOptionallyParsed(el: element))
2790 continue;
2791
2792 // We found a closing element that is valid.
2793 return success();
2794 }
2795 // Return std::nullopt to indicate that we reached the end.
2796 return std::nullopt;
2797}
2798
2799/// For the given elements, check whether any attributes are followed by a colon
2800/// literal, resulting in an ambiguous assembly format. Returns a non-null
2801/// attribute if verification of said attribute reached the end of the range.
2802/// Returns null if all attribute elements are verified.
2803static FailureOr<FormatElement *> verifyAdjacentElements(
2804 function_ref<bool(FormatElement *)> isBase,
2805 function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
2806 ArrayRef<FormatElement *> elements) {
2807 for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) {
2808 // The current attribute being verified.
2809 FormatElement *base;
2810
2811 if (isBase(*it)) {
2812 base = *it;
2813 } else if (auto *optional = dyn_cast<OptionalElement>(Val: *it)) {
2814 // Recurse on optional groups.
2815 FailureOr<FormatElement *> thenResult = verifyAdjacentElements(
2816 isBase, isInvalid, elements: optional->getThenElements());
2817 if (failed(result: thenResult))
2818 return failure();
2819 FailureOr<FormatElement *> elseResult = verifyAdjacentElements(
2820 isBase, isInvalid, elements: optional->getElseElements());
2821 if (failed(result: elseResult))
2822 return failure();
2823 // If either optional group has an unverified attribute, save it.
2824 // Otherwise, move on to the next element.
2825 if (!(base = *thenResult) && !(base = *elseResult))
2826 continue;
2827 } else {
2828 continue;
2829 }
2830
2831 // Verify subsequent elements for potential ambiguities.
2832 if (std::optional<LogicalResult> result =
2833 checkRangeForElement(base, isInvalid, elementRange: {std::next(x: it), e})) {
2834 if (failed(result: *result))
2835 return failure();
2836 } else {
2837 // Since we reached the end, return the attribute as unverified.
2838 return base;
2839 }
2840 }
2841 // All attribute elements are known to be verified.
2842 return nullptr;
2843}
2844
2845LogicalResult
2846OpFormatParser::verifyAttributeColonType(SMLoc loc,
2847 ArrayRef<FormatElement *> elements) {
2848 auto isBase = [](FormatElement *el) {
2849 auto *attr = dyn_cast<AttributeVariable>(Val: el);
2850 if (!attr)
2851 return false;
2852 // Check only attributes without type builders or that are known to call
2853 // the generic attribute parser.
2854 return !attr->getTypeBuilder() &&
2855 (attr->shouldBeQualified() ||
2856 attr->getVar()->attr.getStorageType() == "::mlir::Attribute");
2857 };
2858 auto isInvalid = [&](FormatElement *base, FormatElement *el) {
2859 auto *literal = dyn_cast<LiteralElement>(Val: el);
2860 if (!literal || literal->getSpelling() != ":")
2861 return false;
2862 // If we encounter `:`, the range is known to be invalid.
2863 (void)emitError(
2864 loc,
2865 msg: llvm::formatv(Fmt: "format ambiguity caused by `:` literal found after "
2866 "attribute `{0}` which does not have a buildable type",
2867 Vals: cast<AttributeVariable>(Val: base)->getVar()->name));
2868 return true;
2869 };
2870 return verifyAdjacentElements(isBase, isInvalid, elements);
2871}
2872
2873LogicalResult
2874OpFormatParser::verifyAttrDictRegion(SMLoc loc,
2875 ArrayRef<FormatElement *> elements) {
2876 auto isBase = [](FormatElement *el) {
2877 if (auto *attrDict = dyn_cast<AttrDictDirective>(Val: el))
2878 return !attrDict->isWithKeyword();
2879 return false;
2880 };
2881 auto isInvalid = [&](FormatElement *base, FormatElement *el) {
2882 auto *region = dyn_cast<RegionVariable>(Val: el);
2883 if (!region)
2884 return false;
2885 (void)emitErrorAndNote(
2886 loc,
2887 msg: llvm::formatv(Fmt: "format ambiguity caused by `attr-dict` directive "
2888 "followed by region `{0}`",
2889 Vals: region->getVar()->name),
2890 note: "try using `attr-dict-with-keyword` instead");
2891 return true;
2892 };
2893 return verifyAdjacentElements(isBase, isInvalid, elements);
2894}
2895
2896LogicalResult OpFormatParser::verifyOperands(
2897 SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2898 // Check that all of the operands are within the format, and their types can
2899 // be inferred.
2900 auto &buildableTypes = fmt.buildableTypes;
2901 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
2902 NamedTypeConstraint &operand = op.getOperand(index: i);
2903
2904 // Check that the operand itself is in the format.
2905 if (!fmt.allOperands && !seenOperands.count(V: &operand)) {
2906 return emitErrorAndNote(loc,
2907 msg: "operand #" + Twine(i) + ", named '" +
2908 operand.name + "', not found",
2909 note: "suggest adding a '$" + operand.name +
2910 "' directive to the custom assembly format");
2911 }
2912
2913 // Check that the operand type is in the format, or that it can be inferred.
2914 if (fmt.allOperandTypes || seenOperandTypes.test(Idx: i))
2915 continue;
2916
2917 // Check to see if we can infer this type from another variable.
2918 auto varResolverIt = variableTyResolver.find(Key: op.getOperand(index: i).name);
2919 if (varResolverIt != variableTyResolver.end()) {
2920 TypeResolutionInstance &resolver = varResolverIt->second;
2921 fmt.operandTypes[i].setResolver(arg: resolver.resolver, transformer: resolver.transformer);
2922 continue;
2923 }
2924
2925 // Similarly to results, allow a custom builder for resolving the type if
2926 // we aren't using the 'operands' directive.
2927 std::optional<StringRef> builder = operand.constraint.getBuilderCall();
2928 if (!builder || (fmt.allOperands && operand.isVariableLength())) {
2929 return emitErrorAndNote(
2930 loc,
2931 msg: "type of operand #" + Twine(i) + ", named '" + operand.name +
2932 "', is not buildable and a buildable type cannot be inferred",
2933 note: "suggest adding a type constraint to the operation or adding a "
2934 "'type($" +
2935 operand.name + ")' directive to the " + "custom assembly format");
2936 }
2937 auto it = buildableTypes.insert(KV: {*builder, buildableTypes.size()});
2938 fmt.operandTypes[i].setBuilderIdx(it.first->second);
2939 }
2940 return success();
2941}
2942
2943LogicalResult OpFormatParser::verifyRegions(SMLoc loc) {
2944 // Check that all of the regions are within the format.
2945 if (hasAllRegions)
2946 return success();
2947
2948 for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
2949 const NamedRegion &region = op.getRegion(index: i);
2950 if (!seenRegions.count(V: &region)) {
2951 return emitErrorAndNote(loc,
2952 msg: "region #" + Twine(i) + ", named '" +
2953 region.name + "', not found",
2954 note: "suggest adding a '$" + region.name +
2955 "' directive to the custom assembly format");
2956 }
2957 }
2958 return success();
2959}
2960
2961LogicalResult OpFormatParser::verifyResults(
2962 SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2963 // If we format all of the types together, there is nothing to check.
2964 if (fmt.allResultTypes)
2965 return success();
2966
2967 // If no result types are specified and we can infer them, infer all result
2968 // types
2969 if (op.getNumResults() > 0 && seenResultTypes.count() == 0 &&
2970 canInferResultTypes) {
2971 fmt.infersResultTypes = true;
2972 return success();
2973 }
2974
2975 // Check that all of the result types can be inferred.
2976 auto &buildableTypes = fmt.buildableTypes;
2977 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
2978 if (seenResultTypes.test(Idx: i))
2979 continue;
2980
2981 // Check to see if we can infer this type from another variable.
2982 auto varResolverIt = variableTyResolver.find(Key: op.getResultName(index: i));
2983 if (varResolverIt != variableTyResolver.end()) {
2984 TypeResolutionInstance resolver = varResolverIt->second;
2985 fmt.resultTypes[i].setResolver(arg: resolver.resolver, transformer: resolver.transformer);
2986 continue;
2987 }
2988
2989 // If the result is not variable length, allow for the case where the type
2990 // has a builder that we can use.
2991 NamedTypeConstraint &result = op.getResult(index: i);
2992 std::optional<StringRef> builder = result.constraint.getBuilderCall();
2993 if (!builder || result.isVariableLength()) {
2994 return emitErrorAndNote(
2995 loc,
2996 msg: "type of result #" + Twine(i) + ", named '" + result.name +
2997 "', is not buildable and a buildable type cannot be inferred",
2998 note: "suggest adding a type constraint to the operation or adding a "
2999 "'type($" +
3000 result.name + ")' directive to the " + "custom assembly format");
3001 }
3002 // Note in the format that this result uses the custom builder.
3003 auto it = buildableTypes.insert(KV: {*builder, buildableTypes.size()});
3004 fmt.resultTypes[i].setBuilderIdx(it.first->second);
3005 }
3006 return success();
3007}
3008
3009LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) {
3010 // Check that all of the successors are within the format.
3011 if (hasAllSuccessors)
3012 return success();
3013
3014 for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
3015 const NamedSuccessor &successor = op.getSuccessor(index: i);
3016 if (!seenSuccessors.count(V: &successor)) {
3017 return emitErrorAndNote(loc,
3018 msg: "successor #" + Twine(i) + ", named '" +
3019 successor.name + "', not found",
3020 note: "suggest adding a '$" + successor.name +
3021 "' directive to the custom assembly format");
3022 }
3023 }
3024 return success();
3025}
3026
3027LogicalResult
3028OpFormatParser::verifyOIListElements(SMLoc loc,
3029 ArrayRef<FormatElement *> elements) {
3030 // Check that all of the successors are within the format.
3031 SmallVector<StringRef> prohibitedLiterals;
3032 for (FormatElement *it : elements) {
3033 if (auto *oilist = dyn_cast<OIListElement>(Val: it)) {
3034 if (!prohibitedLiterals.empty()) {
3035 // We just saw an oilist element in last iteration. Literals should not
3036 // match.
3037 for (LiteralElement *literal : oilist->getLiteralElements()) {
3038 if (find(Range&: prohibitedLiterals, Val: literal->getSpelling()) !=
3039 prohibitedLiterals.end()) {
3040 return emitError(
3041 loc, msg: "format ambiguity because " + literal->getSpelling() +
3042 " is used in two adjacent oilist elements.");
3043 }
3044 }
3045 }
3046 for (LiteralElement *literal : oilist->getLiteralElements())
3047 prohibitedLiterals.push_back(Elt: literal->getSpelling());
3048 } else if (auto *literal = dyn_cast<LiteralElement>(Val: it)) {
3049 if (find(Range&: prohibitedLiterals, Val: literal->getSpelling()) !=
3050 prohibitedLiterals.end()) {
3051 return emitError(
3052 loc,
3053 msg: "format ambiguity because " + literal->getSpelling() +
3054 " is used both in oilist element and the adjacent literal.");
3055 }
3056 prohibitedLiterals.clear();
3057 } else {
3058 prohibitedLiterals.clear();
3059 }
3060 }
3061 return success();
3062}
3063
3064void OpFormatParser::handleAllTypesMatchConstraint(
3065 ArrayRef<StringRef> values,
3066 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
3067 for (unsigned i = 0, e = values.size(); i != e; ++i) {
3068 // Check to see if this value matches a resolved operand or result type.
3069 ConstArgument arg = findSeenArg(name: values[i]);
3070 if (!arg)
3071 continue;
3072
3073 // Mark this value as the type resolver for the other variables.
3074 for (unsigned j = 0; j != i; ++j)
3075 variableTyResolver[values[j]] = {.resolver: arg, .transformer: std::nullopt};
3076 for (unsigned j = i + 1; j != e; ++j)
3077 variableTyResolver[values[j]] = {.resolver: arg, .transformer: std::nullopt};
3078 }
3079}
3080
3081void OpFormatParser::handleSameTypesConstraint(
3082 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
3083 bool includeResults) {
3084 const NamedTypeConstraint *resolver = nullptr;
3085 int resolvedIt = -1;
3086
3087 // Check to see if there is an operand or result to use for the resolution.
3088 if ((resolvedIt = seenOperandTypes.find_first()) != -1)
3089 resolver = &op.getOperand(index: resolvedIt);
3090 else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
3091 resolver = &op.getResult(index: resolvedIt);
3092 else
3093 return;
3094
3095 // Set the resolvers for each operand and result.
3096 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
3097 if (!seenOperandTypes.test(Idx: i))
3098 variableTyResolver[op.getOperand(index: i).name] = {.resolver: resolver, .transformer: std::nullopt};
3099 if (includeResults) {
3100 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
3101 if (!seenResultTypes.test(Idx: i))
3102 variableTyResolver[op.getResultName(index: i)] = {.resolver: resolver, .transformer: std::nullopt};
3103 }
3104}
3105
3106void OpFormatParser::handleTypesMatchConstraint(
3107 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
3108 const llvm::Record &def) {
3109 StringRef lhsName = def.getValueAsString(FieldName: "lhs");
3110 StringRef rhsName = def.getValueAsString(FieldName: "rhs");
3111 StringRef transformer = def.getValueAsString(FieldName: "transformer");
3112 if (ConstArgument arg = findSeenArg(name: lhsName))
3113 variableTyResolver[rhsName] = {.resolver: arg, .transformer: transformer};
3114}
3115
3116ConstArgument OpFormatParser::findSeenArg(StringRef name) {
3117 if (const NamedTypeConstraint *arg = findArg(range: op.getOperands(), name))
3118 return seenOperandTypes.test(Idx: arg - op.operand_begin()) ? arg : nullptr;
3119 if (const NamedTypeConstraint *arg = findArg(range: op.getResults(), name))
3120 return seenResultTypes.test(Idx: arg - op.result_begin()) ? arg : nullptr;
3121 if (const NamedAttribute *attr = findArg(range: op.getAttributes(), name))
3122 return seenAttrs.count(key: attr) ? attr : nullptr;
3123 return nullptr;
3124}
3125
3126FailureOr<FormatElement *>
3127OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
3128 // Check that the parsed argument is something actually registered on the op.
3129 // Attributes
3130 if (const NamedAttribute *attr = findArg(range: op.getAttributes(), name)) {
3131 if (ctx == TypeDirectiveContext)
3132 return emitError(
3133 loc, msg: "attributes cannot be used as children to a `type` directive");
3134 if (ctx == RefDirectiveContext) {
3135 if (!seenAttrs.count(key: attr))
3136 return emitError(loc, msg: "attribute '" + name +
3137 "' must be bound before it is referenced");
3138 } else if (!seenAttrs.insert(X: attr)) {
3139 return emitError(loc, msg: "attribute '" + name + "' is already bound");
3140 }
3141
3142 return create<AttributeVariable>(args&: attr);
3143 }
3144
3145 if (const NamedProperty *property = findArg(range: op.getProperties(), name)) {
3146 if (ctx != CustomDirectiveContext && ctx != RefDirectiveContext)
3147 return emitError(
3148 loc, msg: "properties currently only supported in `custom` directive");
3149
3150 if (ctx == RefDirectiveContext) {
3151 if (!seenProperties.count(key: property))
3152 return emitError(loc, msg: "property '" + name +
3153 "' must be bound before it is referenced");
3154 } else {
3155 if (!seenProperties.insert(X: property))
3156 return emitError(loc, msg: "property '" + name + "' is already bound");
3157 }
3158
3159 return create<PropertyVariable>(args&: property);
3160 }
3161
3162 // Operands
3163 if (const NamedTypeConstraint *operand = findArg(range: op.getOperands(), name)) {
3164 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
3165 if (fmt.allOperands || !seenOperands.insert(V: operand).second)
3166 return emitError(loc, msg: "operand '" + name + "' is already bound");
3167 } else if (ctx == RefDirectiveContext && !seenOperands.count(V: operand)) {
3168 return emitError(loc, msg: "operand '" + name +
3169 "' must be bound before it is referenced");
3170 }
3171 return create<OperandVariable>(args&: operand);
3172 }
3173 // Regions
3174 if (const NamedRegion *region = findArg(range: op.getRegions(), name)) {
3175 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
3176 if (hasAllRegions || !seenRegions.insert(V: region).second)
3177 return emitError(loc, msg: "region '" + name + "' is already bound");
3178 } else if (ctx == RefDirectiveContext && !seenRegions.count(V: region)) {
3179 return emitError(loc, msg: "region '" + name +
3180 "' must be bound before it is referenced");
3181 } else {
3182 return emitError(loc, msg: "regions can only be used at the top level");
3183 }
3184 return create<RegionVariable>(args&: region);
3185 }
3186 // Results.
3187 if (const auto *result = findArg(range: op.getResults(), name)) {
3188 if (ctx != TypeDirectiveContext)
3189 return emitError(loc, msg: "result variables can can only be used as a child "
3190 "to a 'type' directive");
3191 return create<ResultVariable>(args&: result);
3192 }
3193 // Successors.
3194 if (const auto *successor = findArg(range: op.getSuccessors(), name)) {
3195 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
3196 if (hasAllSuccessors || !seenSuccessors.insert(V: successor).second)
3197 return emitError(loc, msg: "successor '" + name + "' is already bound");
3198 } else if (ctx == RefDirectiveContext && !seenSuccessors.count(V: successor)) {
3199 return emitError(loc, msg: "successor '" + name +
3200 "' must be bound before it is referenced");
3201 } else {
3202 return emitError(loc, msg: "successors can only be used at the top level");
3203 }
3204
3205 return create<SuccessorVariable>(args&: successor);
3206 }
3207 return emitError(loc, msg: "expected variable to refer to an argument, region, "
3208 "result, or successor");
3209}
3210
3211FailureOr<FormatElement *>
3212OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
3213 Context ctx) {
3214 switch (kind) {
3215 case FormatToken::kw_prop_dict:
3216 return parsePropDictDirective(loc, context: ctx);
3217 case FormatToken::kw_attr_dict:
3218 return parseAttrDictDirective(loc, context: ctx,
3219 /*withKeyword=*/false);
3220 case FormatToken::kw_attr_dict_w_keyword:
3221 return parseAttrDictDirective(loc, context: ctx,
3222 /*withKeyword=*/true);
3223 case FormatToken::kw_functional_type:
3224 return parseFunctionalTypeDirective(loc, context: ctx);
3225 case FormatToken::kw_operands:
3226 return parseOperandsDirective(loc, context: ctx);
3227 case FormatToken::kw_qualified:
3228 return parseQualifiedDirective(loc, context: ctx);
3229 case FormatToken::kw_regions:
3230 return parseRegionsDirective(loc, context: ctx);
3231 case FormatToken::kw_results:
3232 return parseResultsDirective(loc, context: ctx);
3233 case FormatToken::kw_successors:
3234 return parseSuccessorsDirective(loc, context: ctx);
3235 case FormatToken::kw_ref:
3236 return parseReferenceDirective(loc, context: ctx);
3237 case FormatToken::kw_type:
3238 return parseTypeDirective(loc, context: ctx);
3239 case FormatToken::kw_oilist:
3240 return parseOIListDirective(loc, context: ctx);
3241
3242 default:
3243 return emitError(loc, msg: "unsupported directive kind");
3244 }
3245}
3246
3247FailureOr<FormatElement *>
3248OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
3249 bool withKeyword) {
3250 if (context == TypeDirectiveContext)
3251 return emitError(loc, msg: "'attr-dict' directive can only be used as a "
3252 "top-level directive");
3253
3254 if (context == RefDirectiveContext) {
3255 if (!hasAttrDict)
3256 return emitError(loc, msg: "'ref' of 'attr-dict' is not bound by a prior "
3257 "'attr-dict' directive");
3258
3259 // Otherwise, this is a top-level context.
3260 } else {
3261 if (hasAttrDict)
3262 return emitError(loc, msg: "'attr-dict' directive has already been seen");
3263 hasAttrDict = true;
3264 }
3265
3266 return create<AttrDictDirective>(args&: withKeyword);
3267}
3268
3269FailureOr<FormatElement *>
3270OpFormatParser::parsePropDictDirective(SMLoc loc, Context context) {
3271 if (context == TypeDirectiveContext)
3272 return emitError(loc, msg: "'prop-dict' directive can only be used as a "
3273 "top-level directive");
3274
3275 if (context == RefDirectiveContext)
3276 llvm::report_fatal_error(reason: "'ref' of 'prop-dict' unsupported");
3277 // Otherwise, this is a top-level context.
3278
3279 if (hasPropDict)
3280 return emitError(loc, msg: "'prop-dict' directive has already been seen");
3281 hasPropDict = true;
3282
3283 return create<PropDictDirective>();
3284}
3285
3286LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
3287 SMLoc loc, ArrayRef<FormatElement *> arguments) {
3288 for (FormatElement *argument : arguments) {
3289 if (!isa<AttrDictDirective, PropDictDirective, AttributeVariable,
3290 OperandVariable, PropertyVariable, RefDirective, RegionVariable,
3291 SuccessorVariable, StringElement, TypeDirective>(Val: argument)) {
3292 // TODO: FormatElement should have location info attached.
3293 return emitError(loc, msg: "only variables and types may be used as "
3294 "parameters to a custom directive");
3295 }
3296 if (auto *type = dyn_cast<TypeDirective>(Val: argument)) {
3297 if (!isa<OperandVariable, ResultVariable>(Val: type->getArg())) {
3298 return emitError(loc, msg: "type directives within a custom directive may "
3299 "only refer to variables");
3300 }
3301 }
3302 }
3303 return success();
3304}
3305
3306FailureOr<FormatElement *>
3307OpFormatParser::parseFunctionalTypeDirective(SMLoc loc, Context context) {
3308 if (context != TopLevelContext)
3309 return emitError(
3310 loc, msg: "'functional-type' is only valid as a top-level directive");
3311
3312 // Parse the main operand.
3313 FailureOr<FormatElement *> inputs, results;
3314 if (failed(result: parseToken(kind: FormatToken::l_paren,
3315 msg: "expected '(' before argument list")) ||
3316 failed(result: inputs = parseTypeDirectiveOperand(loc)) ||
3317 failed(result: parseToken(kind: FormatToken::comma,
3318 msg: "expected ',' after inputs argument")) ||
3319 failed(result: results = parseTypeDirectiveOperand(loc)) ||
3320 failed(
3321 result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list")))
3322 return failure();
3323 return create<FunctionalTypeDirective>(args&: *inputs, args&: *results);
3324}
3325
3326FailureOr<FormatElement *>
3327OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) {
3328 if (context == RefDirectiveContext) {
3329 if (!fmt.allOperands)
3330 return emitError(loc, msg: "'ref' of 'operands' is not bound by a prior "
3331 "'operands' directive");
3332
3333 } else if (context == TopLevelContext || context == CustomDirectiveContext) {
3334 if (fmt.allOperands || !seenOperands.empty())
3335 return emitError(loc, msg: "'operands' directive creates overlap in format");
3336 fmt.allOperands = true;
3337 }
3338 return create<OperandsDirective>();
3339}
3340
3341FailureOr<FormatElement *>
3342OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) {
3343 if (context != CustomDirectiveContext)
3344 return emitError(loc, msg: "'ref' is only valid within a `custom` directive");
3345
3346 FailureOr<FormatElement *> arg;
3347 if (failed(result: parseToken(kind: FormatToken::l_paren,
3348 msg: "expected '(' before argument list")) ||
3349 failed(result: arg = parseElement(ctx: RefDirectiveContext)) ||
3350 failed(
3351 result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list")))
3352 return failure();
3353
3354 return create<RefDirective>(args&: *arg);
3355}
3356
3357FailureOr<FormatElement *>
3358OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) {
3359 if (context == TypeDirectiveContext)
3360 return emitError(loc, msg: "'regions' is only valid as a top-level directive");
3361 if (context == RefDirectiveContext) {
3362 if (!hasAllRegions)
3363 return emitError(loc, msg: "'ref' of 'regions' is not bound by a prior "
3364 "'regions' directive");
3365
3366 // Otherwise, this is a TopLevel directive.
3367 } else {
3368 if (hasAllRegions || !seenRegions.empty())
3369 return emitError(loc, msg: "'regions' directive creates overlap in format");
3370 hasAllRegions = true;
3371 }
3372 return create<RegionsDirective>();
3373}
3374
3375FailureOr<FormatElement *>
3376OpFormatParser::parseResultsDirective(SMLoc loc, Context context) {
3377 if (context != TypeDirectiveContext)
3378 return emitError(loc, msg: "'results' directive can can only be used as a child "
3379 "to a 'type' directive");
3380 return create<ResultsDirective>();
3381}
3382
3383FailureOr<FormatElement *>
3384OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) {
3385 if (context == TypeDirectiveContext)
3386 return emitError(loc,
3387 msg: "'successors' is only valid as a top-level directive");
3388 if (context == RefDirectiveContext) {
3389 if (!hasAllSuccessors)
3390 return emitError(loc, msg: "'ref' of 'successors' is not bound by a prior "
3391 "'successors' directive");
3392
3393 // Otherwise, this is a TopLevel directive.
3394 } else {
3395 if (hasAllSuccessors || !seenSuccessors.empty())
3396 return emitError(loc, msg: "'successors' directive creates overlap in format");
3397 hasAllSuccessors = true;
3398 }
3399 return create<SuccessorsDirective>();
3400}
3401
3402FailureOr<FormatElement *>
3403OpFormatParser::parseOIListDirective(SMLoc loc, Context context) {
3404 if (failed(result: parseToken(kind: FormatToken::l_paren,
3405 msg: "expected '(' before oilist argument list")))
3406 return failure();
3407 std::vector<FormatElement *> literalElements;
3408 std::vector<std::vector<FormatElement *>> parsingElements;
3409 do {
3410 FailureOr<FormatElement *> lelement = parseLiteral(ctx: context);
3411 if (failed(result: lelement))
3412 return failure();
3413 literalElements.push_back(x: *lelement);
3414 parsingElements.emplace_back();
3415 std::vector<FormatElement *> &currParsingElements = parsingElements.back();
3416 while (peekToken().getKind() != FormatToken::pipe &&
3417 peekToken().getKind() != FormatToken::r_paren) {
3418 FailureOr<FormatElement *> pelement = parseElement(ctx: context);
3419 if (failed(result: pelement) ||
3420 failed(result: verifyOIListParsingElement(element: *pelement, loc)))
3421 return failure();
3422 currParsingElements.push_back(x: *pelement);
3423 }
3424 if (peekToken().getKind() == FormatToken::pipe) {
3425 consumeToken();
3426 continue;
3427 }
3428 if (peekToken().getKind() == FormatToken::r_paren) {
3429 consumeToken();
3430 break;
3431 }
3432 } while (true);
3433
3434 return create<OIListElement>(args: std::move(literalElements),
3435 args: std::move(parsingElements));
3436}
3437
3438LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element,
3439 SMLoc loc) {
3440 SmallVector<VariableElement *> vars;
3441 collect(element, variables&: vars);
3442 for (VariableElement *elem : vars) {
3443 LogicalResult res =
3444 TypeSwitch<FormatElement *, LogicalResult>(elem)
3445 // Only optional attributes can be within an oilist parsing group.
3446 .Case(caseFn: [&](AttributeVariable *attrEle) {
3447 if (!attrEle->getVar()->attr.isOptional() &&
3448 !attrEle->getVar()->attr.hasDefaultValue())
3449 return emitError(loc, msg: "only optional attributes can be used in "
3450 "an oilist parsing group");
3451 return success();
3452 })
3453 // Only optional-like(i.e. variadic) operands can be within an
3454 // oilist parsing group.
3455 .Case(caseFn: [&](OperandVariable *ele) {
3456 if (!ele->getVar()->isVariableLength())
3457 return emitError(loc, msg: "only variable length operands can be "
3458 "used within an oilist parsing group");
3459 return success();
3460 })
3461 // Only optional-like(i.e. variadic) results can be within an oilist
3462 // parsing group.
3463 .Case(caseFn: [&](ResultVariable *ele) {
3464 if (!ele->getVar()->isVariableLength())
3465 return emitError(loc, msg: "only variable length results can be "
3466 "used within an oilist parsing group");
3467 return success();
3468 })
3469 .Case(caseFn: [&](RegionVariable *) { return success(); })
3470 .Default(defaultFn: [&](FormatElement *) {
3471 return emitError(loc,
3472 msg: "only literals, types, and variables can be "
3473 "used within an oilist group");
3474 });
3475 if (failed(result: res))
3476 return failure();
3477 }
3478 return success();
3479}
3480
3481FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
3482 Context context) {
3483 if (context == TypeDirectiveContext)
3484 return emitError(loc, msg: "'type' cannot be used as a child of another `type`");
3485
3486 bool isRefChild = context == RefDirectiveContext;
3487 FailureOr<FormatElement *> operand;
3488 if (failed(result: parseToken(kind: FormatToken::l_paren,
3489 msg: "expected '(' before argument list")) ||
3490 failed(result: operand = parseTypeDirectiveOperand(loc, isRefChild)) ||
3491 failed(
3492 result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list")))
3493 return failure();
3494
3495 return create<TypeDirective>(args&: *operand);
3496}
3497
3498FailureOr<FormatElement *>
3499OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) {
3500 FailureOr<FormatElement *> element;
3501 if (failed(result: parseToken(kind: FormatToken::l_paren,
3502 msg: "expected '(' before argument list")) ||
3503 failed(result: element = parseElement(ctx: context)) ||
3504 failed(
3505 result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list")))
3506 return failure();
3507 return TypeSwitch<FormatElement *, FailureOr<FormatElement *>>(*element)
3508 .Case<AttributeVariable, TypeDirective>(caseFn: [](auto *element) {
3509 element->setShouldBeQualified();
3510 return element;
3511 })
3512 .Default(defaultFn: [&](auto *element) {
3513 return this->emitError(
3514 loc,
3515 msg: "'qualified' directive expects an attribute or a `type` directive");
3516 });
3517}
3518
3519FailureOr<FormatElement *>
3520OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) {
3521 FailureOr<FormatElement *> result = parseElement(ctx: TypeDirectiveContext);
3522 if (failed(result))
3523 return failure();
3524
3525 FormatElement *element = *result;
3526 if (isa<LiteralElement>(Val: element))
3527 return emitError(
3528 loc, msg: "'type' directive operand expects variable or directive operand");
3529
3530 if (auto *var = dyn_cast<OperandVariable>(Val: element)) {
3531 unsigned opIdx = var->getVar() - op.operand_begin();
3532 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(Idx: opIdx)))
3533 return emitError(loc, msg: "'type' of '" + var->getVar()->name +
3534 "' is already bound");
3535 if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(Idx: opIdx)))
3536 return emitError(loc, msg: "'ref' of 'type($" + var->getVar()->name +
3537 ")' is not bound by a prior 'type' directive");
3538 seenOperandTypes.set(opIdx);
3539 } else if (auto *var = dyn_cast<ResultVariable>(Val: element)) {
3540 unsigned resIdx = var->getVar() - op.result_begin();
3541 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(Idx: resIdx)))
3542 return emitError(loc, msg: "'type' of '" + var->getVar()->name +
3543 "' is already bound");
3544 if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(Idx: resIdx)))
3545 return emitError(loc, msg: "'ref' of 'type($" + var->getVar()->name +
3546 ")' is not bound by a prior 'type' directive");
3547 seenResultTypes.set(resIdx);
3548 } else if (isa<OperandsDirective>(Val: &*element)) {
3549 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any()))
3550 return emitError(loc, msg: "'operands' 'type' is already bound");
3551 if (isRefChild && !fmt.allOperandTypes)
3552 return emitError(loc, msg: "'ref' of 'type(operands)' is not bound by a prior "
3553 "'type' directive");
3554 fmt.allOperandTypes = true;
3555 } else if (isa<ResultsDirective>(Val: &*element)) {
3556 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any()))
3557 return emitError(loc, msg: "'results' 'type' is already bound");
3558 if (isRefChild && !fmt.allResultTypes)
3559 return emitError(loc, msg: "'ref' of 'type(results)' is not bound by a prior "
3560 "'type' directive");
3561 fmt.allResultTypes = true;
3562 } else {
3563 return emitError(loc, msg: "invalid argument to 'type' directive");
3564 }
3565 return element;
3566}
3567
3568LogicalResult OpFormatParser::verifyOptionalGroupElements(
3569 SMLoc loc, ArrayRef<FormatElement *> elements, FormatElement *anchor) {
3570 for (FormatElement *element : elements) {
3571 if (failed(result: verifyOptionalGroupElement(loc, element, isAnchor: element == anchor)))
3572 return failure();
3573 }
3574 return success();
3575}
3576
3577LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
3578 FormatElement *element,
3579 bool isAnchor) {
3580 return TypeSwitch<FormatElement *, LogicalResult>(element)
3581 // All attributes can be within the optional group, but only optional
3582 // attributes can be the anchor.
3583 .Case(caseFn: [&](AttributeVariable *attrEle) {
3584 Attribute attr = attrEle->getVar()->attr;
3585 if (isAnchor && !(attr.isOptional() || attr.hasDefaultValue()))
3586 return emitError(loc, msg: "only optional or default-valued attributes "
3587 "can be used to anchor an optional group");
3588 return success();
3589 })
3590 // Only optional-like(i.e. variadic) operands can be within an optional
3591 // group.
3592 .Case(caseFn: [&](OperandVariable *ele) {
3593 if (!ele->getVar()->isVariableLength())
3594 return emitError(loc, msg: "only variable length operands can be used "
3595 "within an optional group");
3596 return success();
3597 })
3598 // Only optional-like(i.e. variadic) results can be within an optional
3599 // group.
3600 .Case(caseFn: [&](ResultVariable *ele) {
3601 if (!ele->getVar()->isVariableLength())
3602 return emitError(loc, msg: "only variable length results can be used "
3603 "within an optional group");
3604 return success();
3605 })
3606 .Case(caseFn: [&](RegionVariable *) {
3607 // TODO: When ODS has proper support for marking "optional" regions, add
3608 // a check here.
3609 return success();
3610 })
3611 .Case(caseFn: [&](TypeDirective *ele) {
3612 return verifyOptionalGroupElement(loc, element: ele->getArg(),
3613 /*isAnchor=*/false);
3614 })
3615 .Case(caseFn: [&](FunctionalTypeDirective *ele) {
3616 if (failed(result: verifyOptionalGroupElement(loc, element: ele->getInputs(),
3617 /*isAnchor=*/false)))
3618 return failure();
3619 return verifyOptionalGroupElement(loc, element: ele->getResults(),
3620 /*isAnchor=*/false);
3621 })
3622 .Case(caseFn: [&](CustomDirective *ele) {
3623 if (!isAnchor)
3624 return success();
3625 // Verify each child as being valid in an optional group. They are all
3626 // potential anchors if the custom directive was marked as one.
3627 for (FormatElement *child : ele->getArguments()) {
3628 if (isa<RefDirective>(Val: child))
3629 continue;
3630 if (failed(result: verifyOptionalGroupElement(loc, element: child, /*isAnchor=*/true)))
3631 return failure();
3632 }
3633 return success();
3634 })
3635 // Literals, whitespace, and custom directives may be used, but they can't
3636 // anchor the group.
3637 .Case<LiteralElement, WhitespaceElement, OptionalElement>(
3638 caseFn: [&](FormatElement *) {
3639 if (isAnchor)
3640 return emitError(loc, msg: "only variables and types can be used "
3641 "to anchor an optional group");
3642 return success();
3643 })
3644 .Default(defaultFn: [&](FormatElement *) {
3645 return emitError(loc, msg: "only literals, types, and variables can be "
3646 "used within an optional group");
3647 });
3648}
3649
3650//===----------------------------------------------------------------------===//
3651// Interface
3652//===----------------------------------------------------------------------===//
3653
3654void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) {
3655 // TODO: Operator doesn't expose all necessary functionality via
3656 // the const interface.
3657 Operator &op = const_cast<Operator &>(constOp);
3658 if (!op.hasAssemblyFormat())
3659 return;
3660
3661 // Parse the format description.
3662 llvm::SourceMgr mgr;
3663 mgr.AddNewSourceBuffer(
3664 F: llvm::MemoryBuffer::getMemBuffer(InputData: op.getAssemblyFormat()), IncludeLoc: SMLoc());
3665 OperationFormat format(op);
3666 OpFormatParser parser(mgr, format, op);
3667 FailureOr<std::vector<FormatElement *>> elements = parser.parse();
3668 if (failed(result: elements)) {
3669 // Exit the process if format errors are treated as fatal.
3670 if (formatErrorIsFatal) {
3671 // Invoke the interrupt handlers to run the file cleanup handlers.
3672 llvm::sys::RunInterruptHandlers();
3673 std::exit(status: 1);
3674 }
3675 return;
3676 }
3677 format.elements = std::move(*elements);
3678
3679 // Generate the printer and parser based on the parsed format.
3680 format.genParser(op, opClass);
3681 format.genPrinter(op, opClass);
3682}
3683

source code of mlir/tools/mlir-tblgen/OpFormatGen.cpp