1 | //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "OpFormatGen.h" |
10 | #include "FormatGen.h" |
11 | #include "OpClass.h" |
12 | #include "mlir/Support/LLVM.h" |
13 | #include "mlir/TableGen/Class.h" |
14 | #include "mlir/TableGen/EnumInfo.h" |
15 | #include "mlir/TableGen/Format.h" |
16 | #include "mlir/TableGen/Operator.h" |
17 | #include "mlir/TableGen/Trait.h" |
18 | #include "llvm/ADT/MapVector.h" |
19 | #include "llvm/ADT/Sequence.h" |
20 | #include "llvm/ADT/SetVector.h" |
21 | #include "llvm/ADT/SmallBitVector.h" |
22 | #include "llvm/ADT/StringExtras.h" |
23 | #include "llvm/ADT/TypeSwitch.h" |
24 | #include "llvm/Support/Signals.h" |
25 | #include "llvm/Support/SourceMgr.h" |
26 | #include "llvm/TableGen/Record.h" |
27 | |
28 | #define DEBUG_TYPE "mlir-tblgen-opformatgen" |
29 | |
30 | using namespace mlir; |
31 | using namespace mlir::tblgen; |
32 | using llvm::formatv; |
33 | using llvm::Record; |
34 | using llvm::StringMap; |
35 | |
36 | //===----------------------------------------------------------------------===// |
37 | // VariableElement |
38 | //===----------------------------------------------------------------------===// |
39 | |
40 | namespace { |
41 | /// This class represents an instance of an op variable element. A variable |
42 | /// refers to something registered on the operation itself, e.g. an operand, |
43 | /// result, attribute, region, or successor. |
44 | template <typename VarT, VariableElement::Kind VariableKind> |
45 | class OpVariableElement : public VariableElementBase<VariableKind> { |
46 | public: |
47 | using Base = OpVariableElement<VarT, VariableKind>; |
48 | |
49 | /// Create an op variable element with the variable value. |
50 | OpVariableElement(const VarT *var) : var(var) {} |
51 | |
52 | /// Get the variable. |
53 | const VarT *getVar() const { return var; } |
54 | |
55 | protected: |
56 | /// The op variable, e.g. a type or attribute constraint. |
57 | const VarT *var; |
58 | }; |
59 | |
60 | /// This class represents a variable that refers to an attribute argument. |
61 | struct AttributeVariable |
62 | : public OpVariableElement<NamedAttribute, VariableElement::Attribute> { |
63 | using Base::Base; |
64 | |
65 | /// Return the constant builder call for the type of this attribute, or |
66 | /// std::nullopt if it doesn't have one. |
67 | std::optional<StringRef> getTypeBuilder() const { |
68 | std::optional<Type> attrType = var->attr.getValueType(); |
69 | return attrType ? attrType->getBuilderCall() : std::nullopt; |
70 | } |
71 | |
72 | /// Indicate if this attribute is printed "qualified" (that is it is |
73 | /// prefixed with the `#dialect.mnemonic`). |
74 | bool shouldBeQualified() { return shouldBeQualifiedFlag; } |
75 | void setShouldBeQualified(bool qualified = true) { |
76 | shouldBeQualifiedFlag = qualified; |
77 | } |
78 | |
79 | private: |
80 | bool shouldBeQualifiedFlag = false; |
81 | }; |
82 | |
83 | /// This class represents a variable that refers to an operand argument. |
84 | using OperandVariable = |
85 | OpVariableElement<NamedTypeConstraint, VariableElement::Operand>; |
86 | |
87 | /// This class represents a variable that refers to a result. |
88 | using ResultVariable = |
89 | OpVariableElement<NamedTypeConstraint, VariableElement::Result>; |
90 | |
91 | /// This class represents a variable that refers to a region. |
92 | using RegionVariable = OpVariableElement<NamedRegion, VariableElement::Region>; |
93 | |
94 | /// This class represents a variable that refers to a successor. |
95 | using SuccessorVariable = |
96 | OpVariableElement<NamedSuccessor, VariableElement::Successor>; |
97 | |
98 | /// This class represents a variable that refers to a property argument. |
99 | using PropertyVariable = |
100 | OpVariableElement<NamedProperty, VariableElement::Property>; |
101 | |
102 | /// LLVM RTTI helper for attribute-like variables, that is, attributes or |
103 | /// properties. This allows for common handling of attributes and properties in |
104 | /// parts of the code that are oblivious to whether something is stored as an |
105 | /// attribute or a property. |
106 | struct AttributeLikeVariable : public VariableElement { |
107 | enum { AttributeLike = 1 << 0 }; |
108 | |
109 | static bool classof(const VariableElement *ve) { |
110 | return ve->getKind() == VariableElement::Attribute || |
111 | ve->getKind() == VariableElement::Property; |
112 | } |
113 | |
114 | static bool classof(const FormatElement *fe) { |
115 | return isa<VariableElement>(Val: fe) && classof(ve: cast<VariableElement>(Val: fe)); |
116 | } |
117 | |
118 | /// Returns true if the variable is a UnitAttr or a UnitProp. |
119 | bool isUnit() const { |
120 | if (const auto *attr = dyn_cast<AttributeVariable>(Val: this)) |
121 | return attr->getVar()->attr.getBaseAttr().getAttrDefName() == "UnitAttr" ; |
122 | if (const auto *prop = dyn_cast<PropertyVariable>(Val: this)) { |
123 | StringRef baseDefName = |
124 | prop->getVar()->prop.getBaseProperty().getPropertyDefName(); |
125 | // Note: remove the `UnitProperty` case once the deprecation period is |
126 | // over. |
127 | return baseDefName == "UnitProp" || baseDefName == "UnitProperty" ; |
128 | } |
129 | llvm_unreachable("Type that wasn't listed in classof()" ); |
130 | } |
131 | |
132 | StringRef getName() const { |
133 | if (const auto *attr = dyn_cast<AttributeVariable>(Val: this)) |
134 | return attr->getVar()->name; |
135 | if (const auto *prop = dyn_cast<PropertyVariable>(Val: this)) |
136 | return prop->getVar()->name; |
137 | llvm_unreachable("Type that wasn't listed in classof()" ); |
138 | } |
139 | }; |
140 | } // namespace |
141 | |
142 | //===----------------------------------------------------------------------===// |
143 | // DirectiveElement |
144 | //===----------------------------------------------------------------------===// |
145 | |
146 | namespace { |
147 | /// This class represents the `operands` directive. This directive represents |
148 | /// all of the operands of an operation. |
149 | using OperandsDirective = DirectiveElementBase<DirectiveElement::Operands>; |
150 | |
151 | /// This class represents the `results` directive. This directive represents |
152 | /// all of the results of an operation. |
153 | using ResultsDirective = DirectiveElementBase<DirectiveElement::Results>; |
154 | |
155 | /// This class represents the `regions` directive. This directive represents |
156 | /// all of the regions of an operation. |
157 | using RegionsDirective = DirectiveElementBase<DirectiveElement::Regions>; |
158 | |
159 | /// This class represents the `successors` directive. This directive represents |
160 | /// all of the successors of an operation. |
161 | using SuccessorsDirective = DirectiveElementBase<DirectiveElement::Successors>; |
162 | |
163 | /// This class represents the `attr-dict` directive. This directive represents |
164 | /// the attribute dictionary of the operation. |
165 | class AttrDictDirective |
166 | : public DirectiveElementBase<DirectiveElement::AttrDict> { |
167 | public: |
168 | explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {} |
169 | |
170 | /// Return whether the dictionary should be printed with the 'attributes' |
171 | /// keyword. |
172 | bool isWithKeyword() const { return withKeyword; } |
173 | |
174 | private: |
175 | /// If the dictionary should be printed with the 'attributes' keyword. |
176 | bool withKeyword; |
177 | }; |
178 | |
179 | /// This class represents the `prop-dict` directive. This directive represents |
180 | /// the properties of the operation, expressed as a directionary. |
181 | class PropDictDirective |
182 | : public DirectiveElementBase<DirectiveElement::PropDict> { |
183 | public: |
184 | explicit PropDictDirective() = default; |
185 | }; |
186 | |
187 | /// This class represents the `functional-type` directive. This directive takes |
188 | /// two arguments and formats them, respectively, as the inputs and results of a |
189 | /// FunctionType. |
190 | class FunctionalTypeDirective |
191 | : public DirectiveElementBase<DirectiveElement::FunctionalType> { |
192 | public: |
193 | FunctionalTypeDirective(FormatElement *inputs, FormatElement *results) |
194 | : inputs(inputs), results(results) {} |
195 | |
196 | FormatElement *getInputs() const { return inputs; } |
197 | FormatElement *getResults() const { return results; } |
198 | |
199 | private: |
200 | /// The input and result arguments. |
201 | FormatElement *inputs, *results; |
202 | }; |
203 | |
204 | /// This class represents the `type` directive. |
205 | class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> { |
206 | public: |
207 | TypeDirective(FormatElement *arg) : arg(arg) {} |
208 | |
209 | FormatElement *getArg() const { return arg; } |
210 | |
211 | /// Indicate if this type is printed "qualified" (that is it is |
212 | /// prefixed with the `!dialect.mnemonic`). |
213 | bool shouldBeQualified() { return shouldBeQualifiedFlag; } |
214 | void setShouldBeQualified(bool qualified = true) { |
215 | shouldBeQualifiedFlag = qualified; |
216 | } |
217 | |
218 | private: |
219 | /// The argument that is used to format the directive. |
220 | FormatElement *arg; |
221 | |
222 | bool shouldBeQualifiedFlag = false; |
223 | }; |
224 | |
225 | /// This class represents a group of order-independent optional clauses. Each |
226 | /// clause starts with a literal element and has a coressponding parsing |
227 | /// element. A parsing element is a continous sequence of format elements. |
228 | /// Each clause can appear 0 or 1 time. |
229 | class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> { |
230 | public: |
231 | OIListElement(std::vector<FormatElement *> &&literalElements, |
232 | std::vector<std::vector<FormatElement *>> &&parsingElements) |
233 | : literalElements(std::move(literalElements)), |
234 | parsingElements(std::move(parsingElements)) {} |
235 | |
236 | /// Returns a range to iterate over the LiteralElements. |
237 | auto getLiteralElements() const { |
238 | return llvm::map_range(C: literalElements, F: [](FormatElement *el) { |
239 | return cast<LiteralElement>(Val: el); |
240 | }); |
241 | } |
242 | |
243 | /// Returns a range to iterate over the parsing elements corresponding to the |
244 | /// clauses. |
245 | ArrayRef<std::vector<FormatElement *>> getParsingElements() const { |
246 | return parsingElements; |
247 | } |
248 | |
249 | /// Returns a range to iterate over tuples of parsing and literal elements. |
250 | auto getClauses() const { |
251 | return llvm::zip(t: getLiteralElements(), u: getParsingElements()); |
252 | } |
253 | |
254 | /// If the parsing element is a single UnitAttr element, then it returns the |
255 | /// attribute variable. Otherwise, returns nullptr. |
256 | AttributeLikeVariable * |
257 | getUnitVariableParsingElement(ArrayRef<FormatElement *> pelement) { |
258 | if (pelement.size() == 1) { |
259 | auto *attrElem = dyn_cast<AttributeLikeVariable>(Val: pelement[0]); |
260 | if (attrElem && attrElem->isUnit()) |
261 | return attrElem; |
262 | } |
263 | return nullptr; |
264 | } |
265 | |
266 | private: |
267 | /// A vector of `LiteralElement` objects. Each element stores the keyword |
268 | /// for one case of oilist element. For example, an oilist element along with |
269 | /// the `literalElements` vector: |
270 | /// ``` |
271 | /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`] |
272 | /// literalElements = { `keyword`, `otherKeyword` } |
273 | /// ``` |
274 | std::vector<FormatElement *> literalElements; |
275 | |
276 | /// A vector of valid declarative assembly format vectors. Each object in |
277 | /// parsing elements is a vector of elements in assembly format syntax. |
278 | /// For example, an oilist element along with the parsingElements vector: |
279 | /// ``` |
280 | /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`] |
281 | /// parsingElements = { |
282 | /// { `=`, `(`, $arg0, `)` }, |
283 | /// { `<`, $arg1, `>` } |
284 | /// } |
285 | /// ``` |
286 | std::vector<std::vector<FormatElement *>> parsingElements; |
287 | }; |
288 | } // namespace |
289 | |
290 | //===----------------------------------------------------------------------===// |
291 | // OperationFormat |
292 | //===----------------------------------------------------------------------===// |
293 | |
294 | namespace { |
295 | |
296 | using ConstArgument = |
297 | llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>; |
298 | |
299 | struct OperationFormat { |
300 | /// This class represents a specific resolver for an operand or result type. |
301 | class TypeResolution { |
302 | public: |
303 | TypeResolution() = default; |
304 | |
305 | /// Get the index into the buildable types for this type, or std::nullopt. |
306 | std::optional<int> getBuilderIdx() const { return builderIdx; } |
307 | void setBuilderIdx(int idx) { builderIdx = idx; } |
308 | |
309 | /// Get the variable this type is resolved to, or nullptr. |
310 | const NamedTypeConstraint *getVariable() const { |
311 | return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(Val: resolver); |
312 | } |
313 | /// Get the attribute this type is resolved to, or nullptr. |
314 | const NamedAttribute *getAttribute() const { |
315 | return llvm::dyn_cast_if_present<const NamedAttribute *>(Val: resolver); |
316 | } |
317 | /// Get the transformer for the type of the variable, or std::nullopt. |
318 | std::optional<StringRef> getVarTransformer() const { |
319 | return variableTransformer; |
320 | } |
321 | void setResolver(ConstArgument arg, std::optional<StringRef> transformer) { |
322 | resolver = arg; |
323 | variableTransformer = transformer; |
324 | assert(getVariable() || getAttribute()); |
325 | } |
326 | |
327 | private: |
328 | /// If the type is resolved with a buildable type, this is the index into |
329 | /// 'buildableTypes' in the parent format. |
330 | std::optional<int> builderIdx; |
331 | /// If the type is resolved based upon another operand or result, this is |
332 | /// the variable or the attribute that this type is resolved to. |
333 | ConstArgument resolver; |
334 | /// If the type is resolved based upon another operand or result, this is |
335 | /// a transformer to apply to the variable when resolving. |
336 | std::optional<StringRef> variableTransformer; |
337 | }; |
338 | |
339 | /// The context in which an element is generated. |
340 | enum class GenContext { |
341 | /// The element is generated at the top-level or with the same behaviour. |
342 | Normal, |
343 | /// The element is generated inside an optional group. |
344 | Optional |
345 | }; |
346 | |
347 | OperationFormat(const Operator &op, bool hasProperties) |
348 | : useProperties(hasProperties), opCppClassName(op.getCppClassName()) { |
349 | operandTypes.resize(new_size: op.getNumOperands(), x: TypeResolution()); |
350 | resultTypes.resize(new_size: op.getNumResults(), x: TypeResolution()); |
351 | |
352 | hasImplicitTermTrait = llvm::any_of(Range: op.getTraits(), P: [](const Trait &trait) { |
353 | return trait.getDef().isSubClassOf(Name: "SingleBlockImplicitTerminatorImpl" ); |
354 | }); |
355 | |
356 | hasSingleBlockTrait = op.getTrait(trait: "::mlir::OpTrait::SingleBlock" ); |
357 | } |
358 | |
359 | /// Generate the operation parser from this format. |
360 | void genParser(Operator &op, OpClass &opClass); |
361 | /// Generate the parser code for a specific format element. |
362 | void genElementParser(FormatElement *element, MethodBody &body, |
363 | FmtContext &attrTypeCtx, |
364 | GenContext genCtx = GenContext::Normal); |
365 | /// Generate the C++ to resolve the types of operands and results during |
366 | /// parsing. |
367 | void genParserTypeResolution(Operator &op, MethodBody &body); |
368 | /// Generate the C++ to resolve the types of the operands during parsing. |
369 | void genParserOperandTypeResolution( |
370 | Operator &op, MethodBody &body, |
371 | function_ref<void(TypeResolution &, StringRef)> emitTypeResolver); |
372 | /// Generate the C++ to resolve regions during parsing. |
373 | void genParserRegionResolution(Operator &op, MethodBody &body); |
374 | /// Generate the C++ to resolve successors during parsing. |
375 | void genParserSuccessorResolution(Operator &op, MethodBody &body); |
376 | /// Generate the C++ to handling variadic segment size traits. |
377 | void genParserVariadicSegmentResolution(Operator &op, MethodBody &body); |
378 | |
379 | /// Generate the operation printer from this format. |
380 | void genPrinter(Operator &op, OpClass &opClass); |
381 | |
382 | /// Generate the printer code for a specific format element. |
383 | void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op, |
384 | bool &shouldEmitSpace, bool &lastWasPunctuation); |
385 | |
386 | /// The various elements in this format. |
387 | std::vector<FormatElement *> elements; |
388 | |
389 | /// A flag indicating if all operand/result types were seen. If the format |
390 | /// contains these, it can not contain individual type resolvers. |
391 | bool allOperands = false, allOperandTypes = false, allResultTypes = false; |
392 | |
393 | /// A flag indicating if this operation infers its result types |
394 | bool infersResultTypes = false; |
395 | |
396 | /// A flag indicating if this operation has the SingleBlockImplicitTerminator |
397 | /// trait. |
398 | bool hasImplicitTermTrait; |
399 | |
400 | /// A flag indicating if this operation has the SingleBlock trait. |
401 | bool hasSingleBlockTrait; |
402 | |
403 | /// Indicate whether we need to use properties for the current operator. |
404 | bool useProperties; |
405 | |
406 | /// Indicate whether prop-dict is used in the format |
407 | bool hasPropDict; |
408 | |
409 | /// The Operation class name |
410 | StringRef opCppClassName; |
411 | |
412 | /// A map of buildable types to indices. |
413 | llvm::MapVector<StringRef, int, StringMap<int>> buildableTypes; |
414 | |
415 | /// The index of the buildable type, if valid, for every operand and result. |
416 | std::vector<TypeResolution> operandTypes, resultTypes; |
417 | |
418 | /// The set of attributes explicitly used within the format. |
419 | llvm::SmallSetVector<const NamedAttribute *, 8> usedAttributes; |
420 | llvm::StringSet<> inferredAttributes; |
421 | |
422 | /// The set of properties explicitly used within the format. |
423 | llvm::SmallSetVector<const NamedProperty *, 8> usedProperties; |
424 | }; |
425 | } // namespace |
426 | |
427 | //===----------------------------------------------------------------------===// |
428 | // Parser Gen |
429 | //===----------------------------------------------------------------------===// |
430 | |
431 | /// Returns true if we can format the given attribute as an enum in the |
432 | /// parser format. |
433 | static bool canFormatEnumAttr(const NamedAttribute *attr) { |
434 | Attribute baseAttr = attr->attr.getBaseAttr(); |
435 | if (!baseAttr.isEnumAttr()) |
436 | return false; |
437 | EnumInfo enumInfo(&baseAttr.getDef()); |
438 | |
439 | // The attribute must have a valid underlying type and a constant builder. |
440 | return !enumInfo.getUnderlyingType().empty() && |
441 | !baseAttr.getConstBuilderTemplate().empty(); |
442 | } |
443 | |
444 | /// Returns if we should format the given attribute as an SymbolNameAttr. |
445 | static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) { |
446 | return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr" ; |
447 | } |
448 | |
/// Generated-code snippet: parse attribute `{0}Attr` through
/// `parser.parseCustomAttributeWithFallback`, returning failure on error.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
  if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
    return ::mlir::failure();
  }
)";

/// Generated-code snippet: parse attribute `{0}Attr` through the generic
/// `parser.parseAttribute`, returning failure on error.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
  if (parser.parseAttribute({0}Attr, {1}))
    return ::mlir::failure();
)";

/// Generated-code snippet: try to parse attribute `{0}Attr` through
/// `parser.parseOptionalAttribute`, returning failure if an attribute is
/// present but malformed.
///
/// Note that the snippet ends with an open `if (...)` header that tests for a
/// successfully parsed attribute; whatever statement is emitted right after it
/// becomes that `if`'s body.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const optionalAttrParserCode = R"(
  ::mlir::OptionalParseResult parseResult{0}Attr =
    parser.parseOptionalAttribute({0}Attr, {1});
  if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
    return ::mlir::failure();
  if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
)";

/// Generated-code snippet: parse the symbol-name attribute `{0}Attr`.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr))
    return ::mlir::failure();
)";

/// Generated-code snippet: optionally parse the symbol-name attribute
/// `{0}Attr`; the parse result is intentionally discarded.
///
/// {0}: The name of the attribute.
const char *const optionalSymbolNameAttrParserCode = R"(
  // Parsing an optional symbol name doesn't fail, so no need to check the
  // result.
  (void)parser.parseOptionalSymbolName({0}Attr);
)";
488 | |
/// The code snippet used to generate a parser call for an enum attribute.
///
/// The enum is accepted either as a bare keyword (one of `{4}`) or as a string
/// attribute; the resulting string is symbolized with `{1}::{2}` and an error
/// is emitted if it does not name a valid case.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
/// {6}: The attribute assignment expression, emitted after `{0}Attr` is set.
const char *const enumAttrParserCode = R"(
  {
    ::llvm::StringRef attrStr;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseOptionalKeyword(&attrStr, {4})) {
      ::mlir::StringAttr attrVal;
      ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalAttribute(attrVal,
                                      parser.getBuilder().getNoneType(),
                                      "{0}", attrStorage);
      if (parseResult.has_value()) {{
        if (failed(*parseResult))
          return ::mlir::failure();
        attrStr = attrVal.getValue();
      } else {
        {5}
      }
    }
    if (!attrStr.empty()) {
      auto attrOptional = {1}::{2}(attrStr);
      if (!attrOptional)
        return parser.emitError(loc, "invalid ")
               << "{0} attribute specification: \"" << attrStr << '"';

      {0}Attr = {3};
      {6}
    }
  }
)";
528 | |
/// Generated-code snippet: run a property's custom parser code on the op's
/// `Properties` storage and report a located error if it fails.
///
/// The parser code (`{2}`) runs inside an immediately-invoked lambda whose
/// `propStorage` parameter is bound to the `{0}` field of the properties.
///
/// {0}: The name of the property
/// {1}: The C++ class name of the operation
/// {2}: The property's parser code with appropriate substitutions performed
/// {3}: The description of the expected property for the error message.
const char *const propertyParserCode = R"(
  auto {0}PropLoc = parser.getCurrentLocation();
  auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::ParseResult {{
    {2}
    return ::mlir::success();
  }(result.getOrAddProperties<{1}::Properties>().{0});
  if (failed({0}PropParseResult)) {{
    return parser.emitError({0}PropLoc, "invalid value for property {0}, expected {3}");
  }
)";

/// Generated-code snippet: the optional-property counterpart of
/// `propertyParserCode`. The lambda yields a `::mlir::OptionalParseResult`,
/// and only a present-but-failed result aborts parsing (no error is emitted
/// here).
///
/// {0}: The name of the property
/// {1}: The C++ class name of the operation
/// {2}: The property's parser code with appropriate substitutions performed
const char *const optionalPropertyParserCode = R"(
  auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::OptionalParseResult {{
    {2}
    return ::mlir::success();
  }(result.getOrAddProperties<{1}::Properties>().{0});
  if ({0}PropParseResult.has_value() && failed(*{0}PropParseResult)) {{
    return ::mlir::failure();
  }
)";
558 | |
/// Generated-code snippet: record the location, then parse a comma-separated
/// list of operands into `{0}Operands`.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";

/// Generated-code snippet: parse zero or one operand, appending it to
/// `{0}Operands` when present.
///
/// {0}: The name of the operand.
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::UnresolvedOperand operand;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalOperand(operand);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";

/// Generated-code snippet: parse exactly one operand into `{0}RawOperand`.
///
/// {0}: The name of the operand.
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperand))
    return ::mlir::failure();
)";

/// Generated-code snippet: parse a comma-separated sequence of parenthesized
/// operand lists, appending every operand to `{0}Operands` and the size of
/// each group to `{0}OperandGroupSizes`.
///
/// {0}: The name of the operand.
/// The snippet references only `{0}`; no other placeholder appears in it.
const char *const variadicOfVariadicOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    int32_t curSize = 0;
    do {
      if (parser.parseOptionalLParen())
        break;
      if (parser.parseOperandList({0}Operands) || parser.parseRParen())
        return ::mlir::failure();
      {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
      curSize = {0}Operands.size();
    } while (succeeded(parser.parseOptionalComma()));
  }
)";
604 | |
/// Generated-code snippet: parse a comma-separated sequence of parenthesized
/// (possibly empty) type lists, appending every type to `{0}Types`.
///
/// {0}: The name for the type list.
const char *const variadicOfVariadicTypeParserCode = R"(
  do {
    if (parser.parseOptionalLParen())
      break;
    if (parser.parseOptionalRParen() &&
        (parser.parseTypeList({0}Types) || parser.parseRParen()))
      return ::mlir::failure();
  } while (succeeded(parser.parseOptionalComma()));
)";

/// Generated-code snippet: parse a type list into `{0}Types`.
///
/// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";

/// Generated-code snippet: parse zero or one type, appending it to `{0}Types`
/// when present.
///
/// {0}: The name for the type list.
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalType(optionalType);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";

/// Generated-code snippet: parse a type of C++ class `{0}` with
/// `parseCustomTypeWithFallback` and store it into `{1}RawType`.
///
/// {0}: The C++ type to parse (declared as a local variable).
/// {1}: The name of the variable whose `{1}RawType` receives the type.
const char *const typeParserCode = R"(
  {
    {0} type;
    if (parser.parseCustomTypeWithFallback(type))
      return ::mlir::failure();
    {1}RawType = type;
  }
)";

/// Generated-code snippet: parse a fully qualified type into `{1}RawType`.
/// Uses the same placeholders as `typeParserCode` but only references `{1}`.
const char *const qualifiedTypeParserCode = R"(
  if (parser.parseType({1}RawType))
    return ::mlir::failure();
)";

/// Generated-code snippet: parse a `::mlir::FunctionType` and split it into
/// `{0}Types` (its inputs) and `{1}Types` (its results).
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";

/// Generated-code snippet: call `{0}::inferReturnTypes` on the parsed state
/// and add the inferred types to the result.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
  ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
      result.location, result.operands,
      result.attributes.getDictionary(parser.getContext()),
      result.getRawProperties(),
      result.regions, inferredReturnTypes)))
    return ::mlir::failure();
  result.addTypes(inferredReturnTypes);
)";
671 | |
// NOTE: these snippets are declared `const char *const` (like every other
// snippet in this file) so they are immutable and have internal linkage; the
// previous `const char *` declarations were mutable, externally visible
// globals.

/// The code snippet used to generate a parser call for a region list: an
/// optional first region followed by any number of comma-separated regions,
/// each appended to `{0}Regions`.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.has_value()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a list of regions have a block.
///
/// {0}: The name of the region list.
const char *const regionListEnsureSingleBlockParserCode = R"(
  for (auto &region : {0}Regions)
    if (region->empty()) region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
  {
     auto parseResult = parser.parseOptionalRegion(*{0}Region);
     if (parseResult.has_value() && failed(*parseResult))
       return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block.
///
/// {0}: The name of the region.
const char *const regionEnsureSingleBlockParserCode = R"(
  if ({0}Region->empty()) {0}Region->emplaceBlock();
)";
743 | |
// NOTE: declared `const char *const` (matching the other snippets in this
// file) so the snippets are immutable and have internal linkage rather than
// being mutable, externally visible globals.

/// The code snippet used to generate a parser call for a successor list: an
/// optional first successor followed by any number of comma-separated
/// successors, each appended to `{0}Successors`.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
  {
    ::mlir::Block *succ;
    auto firstSucc = parser.parseOptionalSuccessor(succ);
    if (firstSucc.has_value()) {
      if (failed(*firstSucc))
        return ::mlir::failure();
      {0}Successors.emplace_back(succ);

      // Parse any trailing successors.
      while (succeeded(parser.parseOptionalComma())) {
        if (parser.parseSuccessor(succ))
          return ::mlir::failure();
        {0}Successors.emplace_back(succ);
      }
    }
  }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
  if (parser.parseSuccessor({0}Successor))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser for OIList: it rejects a clause
/// that was already seen and otherwise marks it as seen.
///
/// {0}: literal keyword corresponding to a case for oilist
const char *const oilistParserCode = R"(
  if ({0}Clause) {
    return parser.emitError(parser.getNameLoc())
          << "`{0}` clause can appear at most once in the expansion of the "
             "oilist directive";
  }
  {0}Clause = true;
)";
785 | |
namespace {
/// The type of length for a given parse argument (an operand, result, or the
/// type list of one). The generated parser picks the shape of the variable it
/// passes around (e.g. `...Operands` vs. `...RawOperand`) based on this kind.
enum class ArgumentLengthKind {
  /// The argument is a variadic of a variadic, and may contain 0->N range
  /// elements.
  VariadicOfVariadic,
  /// The argument is variadic, and may contain 0->N elements.
  Variadic,
  /// The argument is optional, and may contain 0 or 1 elements.
  Optional,
  /// The argument is a single element, i.e. always represents 1 element.
  Single
};
} // namespace
800 | |
801 | /// Get the length kind for the given constraint. |
802 | static ArgumentLengthKind |
803 | getArgumentLengthKind(const NamedTypeConstraint *var) { |
804 | if (var->isOptional()) |
805 | return ArgumentLengthKind::Optional; |
806 | if (var->isVariadicOfVariadic()) |
807 | return ArgumentLengthKind::VariadicOfVariadic; |
808 | if (var->isVariadic()) |
809 | return ArgumentLengthKind::Variadic; |
810 | return ArgumentLengthKind::Single; |
811 | } |
812 | |
813 | /// Get the name used for the type list for the given type directive operand. |
814 | /// 'lengthKind' to the corresponding kind for the given argument. |
815 | static StringRef getTypeListName(FormatElement *arg, |
816 | ArgumentLengthKind &lengthKind) { |
817 | if (auto *operand = dyn_cast<OperandVariable>(Val: arg)) { |
818 | lengthKind = getArgumentLengthKind(var: operand->getVar()); |
819 | return operand->getVar()->name; |
820 | } |
821 | if (auto *result = dyn_cast<ResultVariable>(Val: arg)) { |
822 | lengthKind = getArgumentLengthKind(var: result->getVar()); |
823 | return result->getVar()->name; |
824 | } |
825 | lengthKind = ArgumentLengthKind::Variadic; |
826 | if (isa<OperandsDirective>(Val: arg)) |
827 | return "allOperand" ; |
828 | if (isa<ResultsDirective>(Val: arg)) |
829 | return "allResult" ; |
830 | llvm_unreachable("unknown 'type' directive argument" ); |
831 | } |
832 | |
833 | /// Generate the parser for a literal value. |
834 | static void genLiteralParser(StringRef value, MethodBody &body) { |
835 | // Handle the case of a keyword/identifier. |
836 | if (value.front() == '_' || isalpha(value.front())) { |
837 | body << "Keyword(\"" << value << "\")" ; |
838 | return; |
839 | } |
840 | body << (StringRef)StringSwitch<StringRef>(value) |
841 | .Case(S: "->" , Value: "Arrow()" ) |
842 | .Case(S: ":" , Value: "Colon()" ) |
843 | .Case(S: "," , Value: "Comma()" ) |
844 | .Case(S: "=" , Value: "Equal()" ) |
845 | .Case(S: "<" , Value: "Less()" ) |
846 | .Case(S: ">" , Value: "Greater()" ) |
847 | .Case(S: "{" , Value: "LBrace()" ) |
848 | .Case(S: "}" , Value: "RBrace()" ) |
849 | .Case(S: "(" , Value: "LParen()" ) |
850 | .Case(S: ")" , Value: "RParen()" ) |
851 | .Case(S: "[" , Value: "LSquare()" ) |
852 | .Case(S: "]" , Value: "RSquare()" ) |
853 | .Case(S: "?" , Value: "Question()" ) |
854 | .Case(S: "+" , Value: "Plus()" ) |
855 | .Case(S: "*" , Value: "Star()" ) |
856 | .Case(S: "..." , Value: "Ellipsis()" ); |
857 | } |
858 | |
859 | /// Generate the storage code required for parsing the given element. |
860 | static void genElementParserStorage(FormatElement *element, const Operator &op, |
861 | MethodBody &body) { |
862 | if (auto *optional = dyn_cast<OptionalElement>(Val: element)) { |
863 | ArrayRef<FormatElement *> elements = optional->getThenElements(); |
864 | |
865 | // If the anchor is a unit attribute, it won't be parsed directly so elide |
866 | // it. |
867 | auto *anchor = dyn_cast<AttributeLikeVariable>(Val: optional->getAnchor()); |
868 | FormatElement *elidedAnchorElement = nullptr; |
869 | if (anchor && anchor != elements.front() && anchor->isUnit()) |
870 | elidedAnchorElement = anchor; |
871 | for (FormatElement *childElement : elements) |
872 | if (childElement != elidedAnchorElement) |
873 | genElementParserStorage(element: childElement, op, body); |
874 | for (FormatElement *childElement : optional->getElseElements()) |
875 | genElementParserStorage(element: childElement, op, body); |
876 | |
877 | } else if (auto *oilist = dyn_cast<OIListElement>(Val: element)) { |
878 | for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) { |
879 | if (!oilist->getUnitVariableParsingElement(pelement)) |
880 | for (FormatElement *element : pelement) |
881 | genElementParserStorage(element, op, body); |
882 | } |
883 | |
884 | } else if (auto *custom = dyn_cast<CustomDirective>(Val: element)) { |
885 | for (FormatElement *paramElement : custom->getElements()) |
886 | genElementParserStorage(element: paramElement, op, body); |
887 | |
888 | } else if (isa<OperandsDirective>(Val: element)) { |
889 | body << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> " |
890 | "allOperands;\n" ; |
891 | |
892 | } else if (isa<RegionsDirective>(Val: element)) { |
893 | body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> " |
894 | "fullRegions;\n" ; |
895 | |
896 | } else if (isa<SuccessorsDirective>(Val: element)) { |
897 | body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n" ; |
898 | |
899 | } else if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) { |
900 | const NamedAttribute *var = attr->getVar(); |
901 | body << formatv(Fmt: " {0} {1}Attr;\n" , Vals: var->attr.getStorageType(), Vals: var->name); |
902 | |
903 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) { |
904 | StringRef name = operand->getVar()->name; |
905 | if (operand->getVar()->isVariableLength()) { |
906 | body |
907 | << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> " |
908 | << name << "Operands;\n" ; |
909 | if (operand->getVar()->isVariadicOfVariadic()) { |
910 | body << " llvm::SmallVector<int32_t> " << name |
911 | << "OperandGroupSizes;\n" ; |
912 | } |
913 | } else { |
914 | body << " ::mlir::OpAsmParser::UnresolvedOperand " << name |
915 | << "RawOperand{};\n" |
916 | << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> " |
917 | << name << "Operands(&" << name << "RawOperand, 1);" ; |
918 | } |
919 | body << formatv(Fmt: " ::llvm::SMLoc {0}OperandsLoc;\n" |
920 | " (void){0}OperandsLoc;\n" , |
921 | Vals&: name); |
922 | |
923 | } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) { |
924 | StringRef name = region->getVar()->name; |
925 | if (region->getVar()->isVariadic()) { |
926 | body << formatv( |
927 | Fmt: " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> " |
928 | "{0}Regions;\n" , |
929 | Vals&: name); |
930 | } else { |
931 | body << formatv(Fmt: " std::unique_ptr<::mlir::Region> {0}Region = " |
932 | "std::make_unique<::mlir::Region>();\n" , |
933 | Vals&: name); |
934 | } |
935 | |
936 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) { |
937 | StringRef name = successor->getVar()->name; |
938 | if (successor->getVar()->isVariadic()) { |
939 | body << formatv(Fmt: " ::llvm::SmallVector<::mlir::Block *, 2> " |
940 | "{0}Successors;\n" , |
941 | Vals&: name); |
942 | } else { |
943 | body << formatv(Fmt: " ::mlir::Block *{0}Successor = nullptr;\n" , Vals&: name); |
944 | } |
945 | |
946 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) { |
947 | ArgumentLengthKind lengthKind; |
948 | StringRef name = getTypeListName(arg: dir->getArg(), lengthKind); |
949 | if (lengthKind != ArgumentLengthKind::Single) |
950 | body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n" ; |
951 | else |
952 | body |
953 | << formatv(Fmt: " ::mlir::Type {0}RawType{{};\n" , Vals&: name) |
954 | << formatv( |
955 | Fmt: " ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n" , |
956 | Vals&: name); |
957 | } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(Val: element)) { |
958 | ArgumentLengthKind ignored; |
959 | body << " ::llvm::ArrayRef<::mlir::Type> " |
960 | << getTypeListName(arg: dir->getInputs(), lengthKind&: ignored) << "Types;\n" ; |
961 | body << " ::llvm::ArrayRef<::mlir::Type> " |
962 | << getTypeListName(arg: dir->getResults(), lengthKind&: ignored) << "Types;\n" ; |
963 | } |
964 | } |
965 | |
966 | /// Generate the parser for a parameter to a custom directive. |
967 | static void genCustomParameterParser(FormatElement *param, MethodBody &body) { |
968 | if (auto *attr = dyn_cast<AttributeVariable>(Val: param)) { |
969 | body << attr->getVar()->name << "Attr" ; |
970 | } else if (isa<AttrDictDirective>(Val: param)) { |
971 | body << "result.attributes" ; |
972 | } else if (isa<PropDictDirective>(Val: param)) { |
973 | body << "result" ; |
974 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: param)) { |
975 | StringRef name = operand->getVar()->name; |
976 | ArgumentLengthKind lengthKind = getArgumentLengthKind(var: operand->getVar()); |
977 | if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) |
978 | body << formatv(Fmt: "{0}OperandGroups" , Vals&: name); |
979 | else if (lengthKind == ArgumentLengthKind::Variadic) |
980 | body << formatv(Fmt: "{0}Operands" , Vals&: name); |
981 | else if (lengthKind == ArgumentLengthKind::Optional) |
982 | body << formatv(Fmt: "{0}Operand" , Vals&: name); |
983 | else |
984 | body << formatv(Fmt: "{0}RawOperand" , Vals&: name); |
985 | |
986 | } else if (auto *region = dyn_cast<RegionVariable>(Val: param)) { |
987 | StringRef name = region->getVar()->name; |
988 | if (region->getVar()->isVariadic()) |
989 | body << formatv(Fmt: "{0}Regions" , Vals&: name); |
990 | else |
991 | body << formatv(Fmt: "*{0}Region" , Vals&: name); |
992 | |
993 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: param)) { |
994 | StringRef name = successor->getVar()->name; |
995 | if (successor->getVar()->isVariadic()) |
996 | body << formatv(Fmt: "{0}Successors" , Vals&: name); |
997 | else |
998 | body << formatv(Fmt: "{0}Successor" , Vals&: name); |
999 | |
1000 | } else if (auto *dir = dyn_cast<RefDirective>(Val: param)) { |
1001 | genCustomParameterParser(param: dir->getArg(), body); |
1002 | |
1003 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: param)) { |
1004 | ArgumentLengthKind lengthKind; |
1005 | StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind); |
1006 | if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) |
1007 | body << formatv(Fmt: "{0}TypeGroups" , Vals&: listName); |
1008 | else if (lengthKind == ArgumentLengthKind::Variadic) |
1009 | body << formatv(Fmt: "{0}Types" , Vals&: listName); |
1010 | else if (lengthKind == ArgumentLengthKind::Optional) |
1011 | body << formatv(Fmt: "{0}Type" , Vals&: listName); |
1012 | else |
1013 | body << formatv(Fmt: "{0}RawType" , Vals&: listName); |
1014 | |
1015 | } else if (auto *string = dyn_cast<StringElement>(Val: param)) { |
1016 | FmtContext ctx; |
1017 | ctx.withBuilder(subst: "parser.getBuilder()" ); |
1018 | ctx.addSubst(placeholder: "_ctxt" , subst: "parser.getContext()" ); |
1019 | body << tgfmt(fmt: string->getValue(), ctx: &ctx); |
1020 | |
1021 | } else if (auto *property = dyn_cast<PropertyVariable>(Val: param)) { |
1022 | body << formatv(Fmt: "result.getOrAddProperties<Properties>().{0}" , |
1023 | Vals: property->getVar()->name); |
1024 | } else { |
1025 | llvm_unreachable("unknown custom directive parameter" ); |
1026 | } |
1027 | } |
1028 | |
1029 | /// Generate the parser for a custom directive. |
1030 | static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, |
1031 | bool useProperties, |
1032 | StringRef opCppClassName, |
1033 | bool isOptional = false) { |
1034 | body << " {\n" ; |
1035 | |
1036 | // Preprocess the directive variables. |
1037 | // * Add a local variable for optional operands and types. This provides a |
1038 | // better API to the user defined parser methods. |
1039 | // * Set the location of operand variables. |
1040 | for (FormatElement *param : dir->getElements()) { |
1041 | if (auto *operand = dyn_cast<OperandVariable>(Val: param)) { |
1042 | auto *var = operand->getVar(); |
1043 | body << " " << var->name |
1044 | << "OperandsLoc = parser.getCurrentLocation();\n" ; |
1045 | if (var->isOptional()) { |
1046 | body << formatv( |
1047 | Fmt: " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> " |
1048 | "{0}Operand;\n" , |
1049 | Vals: var->name); |
1050 | } else if (var->isVariadicOfVariadic()) { |
1051 | body << formatv(Fmt: " " |
1052 | "::llvm::SmallVector<::llvm::SmallVector<::mlir::" |
1053 | "OpAsmParser::UnresolvedOperand>> " |
1054 | "{0}OperandGroups;\n" , |
1055 | Vals: var->name); |
1056 | } |
1057 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: param)) { |
1058 | ArgumentLengthKind lengthKind; |
1059 | StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind); |
1060 | if (lengthKind == ArgumentLengthKind::Optional) { |
1061 | body << formatv(Fmt: " ::mlir::Type {0}Type;\n" , Vals&: listName); |
1062 | } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { |
1063 | body << formatv( |
1064 | Fmt: " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> " |
1065 | "{0}TypeGroups;\n" , |
1066 | Vals&: listName); |
1067 | } |
1068 | } else if (auto *dir = dyn_cast<RefDirective>(Val: param)) { |
1069 | FormatElement *input = dir->getArg(); |
1070 | if (auto *operand = dyn_cast<OperandVariable>(Val: input)) { |
1071 | if (!operand->getVar()->isOptional()) |
1072 | continue; |
1073 | body << formatv( |
1074 | Fmt: " {0} {1}Operand = {1}Operands.empty() ? {0}() : " |
1075 | "{1}Operands[0];\n" , |
1076 | Vals: "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>" , |
1077 | Vals: operand->getVar()->name); |
1078 | |
1079 | } else if (auto *type = dyn_cast<TypeDirective>(Val: input)) { |
1080 | ArgumentLengthKind lengthKind; |
1081 | StringRef listName = getTypeListName(arg: type->getArg(), lengthKind); |
1082 | if (lengthKind == ArgumentLengthKind::Optional) { |
1083 | body << formatv(Fmt: " ::mlir::Type {0}Type = {0}Types.empty() ? " |
1084 | "::mlir::Type() : {0}Types[0];\n" , |
1085 | Vals&: listName); |
1086 | } |
1087 | } |
1088 | } |
1089 | } |
1090 | |
1091 | body << " auto odsResult = parse" << dir->getName() << "(parser" ; |
1092 | for (FormatElement *param : dir->getElements()) { |
1093 | body << ", " ; |
1094 | genCustomParameterParser(param, body); |
1095 | } |
1096 | body << ");\n" ; |
1097 | |
1098 | if (isOptional) { |
1099 | body << " if (!odsResult.has_value()) return {};\n" |
1100 | << " if (::mlir::failed(*odsResult)) return ::mlir::failure();\n" ; |
1101 | } else { |
1102 | body << " if (odsResult) return ::mlir::failure();\n" ; |
1103 | } |
1104 | |
1105 | // After parsing, add handling for any of the optional constructs. |
1106 | for (FormatElement *param : dir->getElements()) { |
1107 | if (auto *attr = dyn_cast<AttributeVariable>(Val: param)) { |
1108 | const NamedAttribute *var = attr->getVar(); |
1109 | if (var->attr.isOptional() || var->attr.hasDefaultValue()) |
1110 | body << formatv(Fmt: " if ({0}Attr)\n " , Vals: var->name); |
1111 | if (useProperties) { |
1112 | body << formatv( |
1113 | Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n" , |
1114 | Vals: var->name, Vals&: opCppClassName); |
1115 | } else { |
1116 | body << formatv(Fmt: " result.addAttribute(\"{0}\", {0}Attr);\n" , |
1117 | Vals: var->name); |
1118 | } |
1119 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: param)) { |
1120 | const NamedTypeConstraint *var = operand->getVar(); |
1121 | if (var->isOptional()) { |
1122 | body << formatv(Fmt: " if ({0}Operand.has_value())\n" |
1123 | " {0}Operands.push_back(*{0}Operand);\n" , |
1124 | Vals: var->name); |
1125 | } else if (var->isVariadicOfVariadic()) { |
1126 | body << formatv( |
1127 | Fmt: " for (const auto &subRange : {0}OperandGroups) {{\n" |
1128 | " {0}Operands.append(subRange.begin(), subRange.end());\n" |
1129 | " {0}OperandGroupSizes.push_back(subRange.size());\n" |
1130 | " }\n" , |
1131 | Vals: var->name); |
1132 | } |
1133 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: param)) { |
1134 | ArgumentLengthKind lengthKind; |
1135 | StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind); |
1136 | if (lengthKind == ArgumentLengthKind::Optional) { |
1137 | body << formatv(Fmt: " if ({0}Type)\n" |
1138 | " {0}Types.push_back({0}Type);\n" , |
1139 | Vals&: listName); |
1140 | } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { |
1141 | body << formatv( |
1142 | Fmt: " for (const auto &subRange : {0}TypeGroups)\n" |
1143 | " {0}Types.append(subRange.begin(), subRange.end());\n" , |
1144 | Vals&: listName); |
1145 | } |
1146 | } |
1147 | } |
1148 | |
1149 | body << " }\n" ; |
1150 | } |
1151 | |
1152 | /// Generate the parser for a enum attribute. |
1153 | static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body, |
1154 | FmtContext &attrTypeCtx, bool parseAsOptional, |
1155 | bool useProperties, StringRef opCppClassName) { |
1156 | Attribute baseAttr = var->attr.getBaseAttr(); |
1157 | EnumInfo enumInfo(&baseAttr.getDef()); |
1158 | std::vector<EnumCase> cases = enumInfo.getAllCases(); |
1159 | |
1160 | // Generate the code for building an attribute for this enum. |
1161 | std::string attrBuilderStr; |
1162 | { |
1163 | llvm::raw_string_ostream os(attrBuilderStr); |
1164 | os << tgfmt(fmt: baseAttr.getConstBuilderTemplate(), ctx: &attrTypeCtx, |
1165 | vals: "*attrOptional" ); |
1166 | } |
1167 | |
1168 | // Build a string containing the cases that can be formatted as a keyword. |
1169 | std::string validCaseKeywordsStr = "{" ; |
1170 | llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr); |
1171 | for (const EnumCase &attrCase : cases) |
1172 | if (canFormatStringAsKeyword(value: attrCase.getStr())) |
1173 | validCaseKeywordsOS << '"' << attrCase.getStr() << "\"," ; |
1174 | validCaseKeywordsOS.str().back() = '}'; |
1175 | |
1176 | // If the attribute is not optional, build an error message for the missing |
1177 | // attribute. |
1178 | std::string errorMessage; |
1179 | if (!parseAsOptional) { |
1180 | llvm::raw_string_ostream errorMessageOS(errorMessage); |
1181 | errorMessageOS |
1182 | << "return parser.emitError(loc, \"expected string or " |
1183 | "keyword containing one of the following enum values for attribute '" |
1184 | << var->name << "' [" ; |
1185 | llvm::interleaveComma(c: cases, os&: errorMessageOS, each_fn: [&](const auto &attrCase) { |
1186 | errorMessageOS << attrCase.getStr(); |
1187 | }); |
1188 | errorMessageOS << "]\");" ; |
1189 | } |
1190 | std::string attrAssignment; |
1191 | if (useProperties) { |
1192 | attrAssignment = |
1193 | formatv(Fmt: " " |
1194 | "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;" , |
1195 | Vals: var->name, Vals&: opCppClassName); |
1196 | } else { |
1197 | attrAssignment = |
1198 | formatv(Fmt: "result.addAttribute(\"{0}\", {0}Attr);" , Vals: var->name); |
1199 | } |
1200 | |
1201 | body << formatv(Fmt: enumAttrParserCode, Vals: var->name, Vals: enumInfo.getCppNamespace(), |
1202 | Vals: enumInfo.getStringToSymbolFnName(), Vals&: attrBuilderStr, |
1203 | Vals&: validCaseKeywordsStr, Vals&: errorMessage, Vals&: attrAssignment); |
1204 | } |
1205 | |
1206 | // Generate the parser for a property. |
1207 | static void genPropertyParser(PropertyVariable *propVar, MethodBody &body, |
1208 | StringRef opCppClassName, |
1209 | bool requireParse = true) { |
1210 | StringRef name = propVar->getVar()->name; |
1211 | const Property &prop = propVar->getVar()->prop; |
1212 | bool parseOptionally = |
1213 | prop.hasDefaultValue() && !requireParse && prop.hasOptionalParser(); |
1214 | FmtContext fmtContext; |
1215 | fmtContext.addSubst(placeholder: "_parser" , subst: "parser" ); |
1216 | fmtContext.addSubst(placeholder: "_ctxt" , subst: "parser.getContext()" ); |
1217 | fmtContext.addSubst(placeholder: "_storage" , subst: "propStorage" ); |
1218 | |
1219 | if (parseOptionally) { |
1220 | body << formatv(Fmt: optionalPropertyParserCode, Vals&: name, Vals&: opCppClassName, |
1221 | Vals: tgfmt(fmt: prop.getOptionalParserCall(), ctx: &fmtContext)); |
1222 | } else { |
1223 | body << formatv(Fmt: propertyParserCode, Vals&: name, Vals&: opCppClassName, |
1224 | Vals: tgfmt(fmt: prop.getParserCall(), ctx: &fmtContext), |
1225 | Vals: prop.getSummary()); |
1226 | } |
1227 | } |
1228 | |
1229 | // Generate the parser for an attribute. |
1230 | static void genAttrParser(AttributeVariable *attr, MethodBody &body, |
1231 | FmtContext &attrTypeCtx, bool parseAsOptional, |
1232 | bool useProperties, StringRef opCppClassName) { |
1233 | const NamedAttribute *var = attr->getVar(); |
1234 | |
1235 | // Check to see if we can parse this as an enum attribute. |
1236 | if (canFormatEnumAttr(attr: var)) |
1237 | return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional, |
1238 | useProperties, opCppClassName); |
1239 | |
1240 | // Check to see if we should parse this as a symbol name attribute. |
1241 | if (shouldFormatSymbolNameAttr(attr: var)) { |
1242 | body << formatv(Fmt: parseAsOptional ? optionalSymbolNameAttrParserCode |
1243 | : symbolNameAttrParserCode, |
1244 | Vals: var->name); |
1245 | } else { |
1246 | |
1247 | // If this attribute has a buildable type, use that when parsing the |
1248 | // attribute. |
1249 | std::string attrTypeStr; |
1250 | if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) { |
1251 | llvm::raw_string_ostream os(attrTypeStr); |
1252 | os << tgfmt(fmt: *typeBuilder, ctx: &attrTypeCtx); |
1253 | } else { |
1254 | attrTypeStr = "::mlir::Type{}" ; |
1255 | } |
1256 | if (parseAsOptional) { |
1257 | body << formatv(Fmt: optionalAttrParserCode, Vals: var->name, Vals&: attrTypeStr); |
1258 | } else { |
1259 | if (attr->shouldBeQualified() || |
1260 | var->attr.getStorageType() == "::mlir::Attribute" ) |
1261 | body << formatv(Fmt: genericAttrParserCode, Vals: var->name, Vals&: attrTypeStr); |
1262 | else |
1263 | body << formatv(Fmt: attrParserCode, Vals: var->name, Vals&: attrTypeStr); |
1264 | } |
1265 | } |
1266 | if (useProperties) { |
1267 | body << formatv( |
1268 | Fmt: " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = " |
1269 | "{0}Attr;\n" , |
1270 | Vals: var->name, Vals&: opCppClassName); |
1271 | } else { |
1272 | body << formatv( |
1273 | Fmt: " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n" , |
1274 | Vals: var->name); |
1275 | } |
1276 | } |
1277 | |
1278 | // Generates the 'setPropertiesFromParsedAttr' used to set properties from a |
1279 | // 'prop-dict' dictionary attr. |
1280 | static void genParsedAttrPropertiesSetter(OperationFormat &fmt, Operator &op, |
1281 | OpClass &opClass) { |
1282 | // Not required unless 'prop-dict' is present or we are not using properties. |
1283 | if (!fmt.hasPropDict || !fmt.useProperties) |
1284 | return; |
1285 | |
1286 | SmallVector<MethodParameter> paramList; |
1287 | paramList.emplace_back(Args: "Properties &" , Args: "prop" ); |
1288 | paramList.emplace_back(Args: "::mlir::Attribute" , Args: "attr" ); |
1289 | paramList.emplace_back(Args: "::llvm::function_ref<::mlir::InFlightDiagnostic()>" , |
1290 | Args: "emitError" ); |
1291 | |
1292 | Method *method = opClass.addStaticMethod(retType: "::llvm::LogicalResult" , |
1293 | name: "setPropertiesFromParsedAttr" , |
1294 | args: std::move(paramList)); |
1295 | MethodBody &body = method->body().indent(); |
1296 | |
1297 | body << R"decl( |
1298 | ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); |
1299 | if (!dict) { |
1300 | emitError() << "expected DictionaryAttr to set properties"; |
1301 | return ::mlir::failure(); |
1302 | } |
1303 | // keep track of used keys in the input dictionary to be able to error out |
1304 | // if there are some unknown ones. |
1305 | DenseSet<StringAttr> usedKeys; |
1306 | MLIRContext *ctx = dict.getContext(); |
1307 | (void)ctx; |
1308 | )decl" ; |
1309 | |
1310 | // {0}: fromAttribute call |
1311 | // {1}: property name |
1312 | // {2}: isRequired |
1313 | const char *propFromAttrFmt = R"decl( |
1314 | auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr, |
1315 | ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{ |
1316 | {0}; |
1317 | }; |
1318 | auto {1}AttrName = StringAttr::get(ctx, "{1}"); |
1319 | usedKeys.insert({1}AttrName); |
1320 | auto attr = dict.get({1}AttrName); |
1321 | if (!attr && {2}) {{ |
1322 | emitError() << "expected key entry for {1} in DictionaryAttr to set " |
1323 | "Properties."; |
1324 | return ::mlir::failure(); |
1325 | } |
1326 | if (attr && ::mlir::failed(setFromAttr(prop.{1}, attr, emitError))) |
1327 | return ::mlir::failure(); |
1328 | )decl" ; |
1329 | |
1330 | // Generate the setter for any property not parsed elsewhere. |
1331 | for (const NamedProperty &namedProperty : op.getProperties()) { |
1332 | if (fmt.usedProperties.contains(key: &namedProperty)) |
1333 | continue; |
1334 | |
1335 | auto scope = body.scope(open: "{\n" , close: "}\n" , /*indent=*/true); |
1336 | |
1337 | StringRef name = namedProperty.name; |
1338 | const Property &prop = namedProperty.prop; |
1339 | bool isRequired = !prop.hasDefaultValue(); |
1340 | FmtContext fctx; |
1341 | body << formatv(Fmt: propFromAttrFmt, |
1342 | Vals: tgfmt(fmt: prop.getConvertFromAttributeCall(), |
1343 | ctx: &fctx.addSubst(placeholder: "_attr" , subst: "propAttr" ) |
1344 | .addSubst(placeholder: "_storage" , subst: "propStorage" ) |
1345 | .addSubst(placeholder: "_diag" , subst: "emitError" )), |
1346 | Vals&: name, Vals&: isRequired); |
1347 | } |
1348 | |
1349 | // Generate the setter for any attribute not parsed elsewhere. |
1350 | for (const NamedAttribute &namedAttr : op.getAttributes()) { |
1351 | if (fmt.usedAttributes.contains(key: &namedAttr)) |
1352 | continue; |
1353 | |
1354 | const Attribute &attr = namedAttr.attr; |
1355 | // Derived attributes do not need to be parsed. |
1356 | if (attr.isDerivedAttr()) |
1357 | continue; |
1358 | |
1359 | auto scope = body.scope(open: "{\n" , close: "}\n" , /*indent=*/true); |
1360 | |
1361 | // If the attribute has a default value or is optional, it does not need to |
1362 | // be present in the parsed dictionary attribute. |
1363 | bool isRequired = !attr.isOptional() && !attr.hasDefaultValue(); |
1364 | body << formatv(Fmt: R"decl( |
1365 | auto &propStorage = prop.{0}; |
1366 | auto {0}AttrName = StringAttr::get(ctx, "{0}"); |
1367 | auto attr = dict.get({0}AttrName); |
1368 | usedKeys.insert(StringAttr::get(ctx, "{1}")); |
1369 | if (attr || /*isRequired=*/{1}) {{ |
1370 | if (!attr) {{ |
1371 | emitError() << "expected key entry for {0} in DictionaryAttr to set " |
1372 | "Properties."; |
1373 | return ::mlir::failure(); |
1374 | } |
1375 | auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr); |
1376 | if (convertedAttr) {{ |
1377 | propStorage = convertedAttr; |
1378 | } else {{ |
1379 | emitError() << "Invalid attribute `{0}` in property conversion: " << attr; |
1380 | return ::mlir::failure(); |
1381 | } |
1382 | } |
1383 | )decl" , |
1384 | Vals: namedAttr.name, Vals&: isRequired); |
1385 | } |
1386 | body << R"decl( |
1387 | for (NamedAttribute attr : dict) { |
1388 | if (!usedKeys.contains(attr.getName())) |
1389 | return emitError() << "unknown key '" << attr.getName() << |
1390 | "' when parsing properties dictionary"; |
1391 | } |
1392 | return ::mlir::success(); |
1393 | )decl" ; |
1394 | } |
1395 | |
1396 | void OperationFormat::genParser(Operator &op, OpClass &opClass) { |
1397 | SmallVector<MethodParameter> paramList; |
1398 | paramList.emplace_back(Args: "::mlir::OpAsmParser &" , Args: "parser" ); |
1399 | paramList.emplace_back(Args: "::mlir::OperationState &" , Args: "result" ); |
1400 | |
1401 | auto *method = opClass.addStaticMethod(retType: "::mlir::ParseResult" , name: "parse" , |
1402 | args: std::move(paramList)); |
1403 | auto &body = method->body(); |
1404 | |
1405 | // Generate variables to store the operands and type within the format. This |
1406 | // allows for referencing these variables in the presence of optional |
1407 | // groupings. |
1408 | for (FormatElement *element : elements) |
1409 | genElementParserStorage(element, op, body); |
1410 | |
1411 | // A format context used when parsing attributes with buildable types. |
1412 | FmtContext attrTypeCtx; |
1413 | attrTypeCtx.withBuilder(subst: "parser.getBuilder()" ); |
1414 | |
1415 | // Generate parsers for each of the elements. |
1416 | for (FormatElement *element : elements) |
1417 | genElementParser(element, body, attrTypeCtx); |
1418 | |
1419 | // Generate the code to resolve the operand/result types and successors now |
1420 | // that they have been parsed. |
1421 | genParserRegionResolution(op, body); |
1422 | genParserSuccessorResolution(op, body); |
1423 | genParserVariadicSegmentResolution(op, body); |
1424 | genParserTypeResolution(op, body); |
1425 | |
1426 | body << " return ::mlir::success();\n" ; |
1427 | |
1428 | genParsedAttrPropertiesSetter(fmt&: *this, op, opClass); |
1429 | } |
1430 | |
1431 | void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, |
1432 | FmtContext &attrTypeCtx, |
1433 | GenContext genCtx) { |
1434 | /// Optional Group. |
1435 | if (auto *optional = dyn_cast<OptionalElement>(Val: element)) { |
1436 | auto genElementParsers = [&](FormatElement *firstElement, |
1437 | ArrayRef<FormatElement *> elements, |
1438 | bool thenGroup) { |
1439 | // If the anchor is a unit attribute, we don't need to print it. When |
1440 | // parsing, we will add this attribute if this group is present. |
1441 | FormatElement *elidedAnchorElement = nullptr; |
1442 | auto *anchorVar = dyn_cast<AttributeLikeVariable>(Val: optional->getAnchor()); |
1443 | if (anchorVar && anchorVar != firstElement && anchorVar->isUnit()) { |
1444 | elidedAnchorElement = anchorVar; |
1445 | |
1446 | if (!thenGroup == optional->isInverted()) { |
1447 | // Add the anchor unit attribute or property to the operation state |
1448 | // or set the property to true. |
1449 | if (isa<PropertyVariable>(Val: anchorVar)) { |
1450 | body << formatv( |
1451 | Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = true;" , |
1452 | Vals: anchorVar->getName(), Vals&: opCppClassName); |
1453 | } else if (useProperties) { |
1454 | body << formatv( |
1455 | Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = " |
1456 | "parser.getBuilder().getUnitAttr();" , |
1457 | Vals: anchorVar->getName(), Vals&: opCppClassName); |
1458 | } else { |
1459 | body << " result.addAttribute(\"" << anchorVar->getName() |
1460 | << "\", parser.getBuilder().getUnitAttr());\n" ; |
1461 | } |
1462 | } |
1463 | } |
1464 | |
1465 | // Generate the rest of the elements inside an optional group. Elements in |
1466 | // an optional group after the guard are parsed as required. |
1467 | for (FormatElement *childElement : elements) |
1468 | if (childElement != elidedAnchorElement) |
1469 | genElementParser(element: childElement, body, attrTypeCtx, |
1470 | genCtx: GenContext::Optional); |
1471 | }; |
1472 | |
1473 | ArrayRef<FormatElement *> thenElements = |
1474 | optional->getThenElements(/*parseable=*/true); |
1475 | |
1476 | // Generate a special optional parser for the first element to gate the |
1477 | // parsing of the rest of the elements. |
1478 | FormatElement *firstElement = thenElements.front(); |
1479 | if (auto *attrVar = dyn_cast<AttributeVariable>(Val: firstElement)) { |
1480 | genAttrParser(attr: attrVar, body, attrTypeCtx, /*parseAsOptional=*/true, |
1481 | useProperties, opCppClassName); |
1482 | body << " if (" << attrVar->getVar()->name << "Attr) {\n" ; |
1483 | } else if (auto *propVar = dyn_cast<PropertyVariable>(Val: firstElement)) { |
1484 | genPropertyParser(propVar, body, opCppClassName, /*requireParse=*/false); |
1485 | body << formatv(Fmt: "if ({0}PropParseResult.has_value() && " |
1486 | "succeeded(*{0}PropParseResult)) " , |
1487 | Vals: propVar->getVar()->name) |
1488 | << " {\n" ; |
1489 | } else if (auto *literal = dyn_cast<LiteralElement>(Val: firstElement)) { |
1490 | body << " if (::mlir::succeeded(parser.parseOptional" ; |
1491 | genLiteralParser(value: literal->getSpelling(), body); |
1492 | body << ")) {\n" ; |
1493 | } else if (auto *opVar = dyn_cast<OperandVariable>(Val: firstElement)) { |
1494 | genElementParser(element: opVar, body, attrTypeCtx); |
1495 | body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n" ; |
1496 | } else if (auto *regionVar = dyn_cast<RegionVariable>(Val: firstElement)) { |
1497 | const NamedRegion *region = regionVar->getVar(); |
1498 | if (region->isVariadic()) { |
1499 | genElementParser(element: regionVar, body, attrTypeCtx); |
1500 | body << " if (!" << region->name << "Regions.empty()) {\n" ; |
1501 | } else { |
1502 | body << formatv(Fmt: optionalRegionParserCode, Vals: region->name); |
1503 | body << " if (!" << region->name << "Region->empty()) {\n " ; |
1504 | if (hasImplicitTermTrait) |
1505 | body << formatv(Fmt: regionEnsureTerminatorParserCode, Vals: region->name); |
1506 | else if (hasSingleBlockTrait) |
1507 | body << formatv(Fmt: regionEnsureSingleBlockParserCode, Vals: region->name); |
1508 | } |
1509 | } else if (auto *custom = dyn_cast<CustomDirective>(Val: firstElement)) { |
1510 | body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n" ; |
1511 | genCustomDirectiveParser(dir: custom, body, useProperties, opCppClassName, |
1512 | /*isOptional=*/true); |
1513 | body << " return ::mlir::success();\n" |
1514 | << " }(); optResult.has_value() && ::mlir::failed(*optResult)) {\n" |
1515 | << " return ::mlir::failure();\n" |
1516 | << " } else if (optResult.has_value()) {\n" ; |
1517 | } |
1518 | |
1519 | genElementParsers(firstElement, thenElements.drop_front(), |
1520 | /*thenGroup=*/true); |
1521 | body << " }" ; |
1522 | |
1523 | // Generate the else elements. |
1524 | auto elseElements = optional->getElseElements(); |
1525 | if (!elseElements.empty()) { |
1526 | body << " else {\n" ; |
1527 | ArrayRef<FormatElement *> elseElements = |
1528 | optional->getElseElements(/*parseable=*/true); |
1529 | genElementParsers(elseElements.front(), elseElements, |
1530 | /*thenGroup=*/false); |
1531 | body << " }" ; |
1532 | } |
1533 | body << "\n" ; |
1534 | |
1535 | /// OIList Directive |
1536 | } else if (OIListElement *oilist = dyn_cast<OIListElement>(Val: element)) { |
1537 | for (LiteralElement *le : oilist->getLiteralElements()) |
1538 | body << " bool " << le->getSpelling() << "Clause = false;\n" ; |
1539 | |
1540 | // Generate the parsing loop |
1541 | body << " while(true) {\n" ; |
1542 | for (auto clause : oilist->getClauses()) { |
1543 | LiteralElement *lelement = std::get<0>(t&: clause); |
1544 | ArrayRef<FormatElement *> pelement = std::get<1>(t&: clause); |
1545 | body << "if (succeeded(parser.parseOptional" ; |
1546 | genLiteralParser(value: lelement->getSpelling(), body); |
1547 | body << ")) {\n" ; |
1548 | StringRef lelementName = lelement->getSpelling(); |
1549 | body << formatv(Fmt: oilistParserCode, Vals&: lelementName); |
1550 | if (AttributeLikeVariable *unitVarElem = |
1551 | oilist->getUnitVariableParsingElement(pelement)) { |
1552 | if (isa<PropertyVariable>(Val: unitVarElem)) { |
1553 | body << formatv( |
1554 | Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = true;" , |
1555 | Vals: unitVarElem->getName(), Vals&: opCppClassName); |
1556 | } else if (useProperties) { |
1557 | body << formatv( |
1558 | Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = " |
1559 | "parser.getBuilder().getUnitAttr();" , |
1560 | Vals: unitVarElem->getName(), Vals&: opCppClassName); |
1561 | } else { |
1562 | body << " result.addAttribute(\"" << unitVarElem->getName() |
1563 | << "\", UnitAttr::get(parser.getContext()));\n" ; |
1564 | } |
1565 | } else { |
1566 | for (FormatElement *el : pelement) |
1567 | genElementParser(element: el, body, attrTypeCtx); |
1568 | } |
1569 | body << " } else " ; |
1570 | } |
1571 | body << " {\n" ; |
1572 | body << " break;\n" ; |
1573 | body << " }\n" ; |
1574 | body << "}\n" ; |
1575 | |
1576 | /// Literals. |
1577 | } else if (LiteralElement *literal = dyn_cast<LiteralElement>(Val: element)) { |
1578 | body << " if (parser.parse" ; |
1579 | genLiteralParser(value: literal->getSpelling(), body); |
1580 | body << ")\n return ::mlir::failure();\n" ; |
1581 | |
1582 | /// Whitespaces. |
1583 | } else if (isa<WhitespaceElement>(Val: element)) { |
1584 | // Nothing to parse. |
1585 | |
1586 | /// Arguments. |
1587 | } else if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) { |
1588 | bool parseAsOptional = |
1589 | (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional()); |
1590 | genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties, |
1591 | opCppClassName); |
1592 | } else if (auto *prop = dyn_cast<PropertyVariable>(Val: element)) { |
1593 | genPropertyParser(propVar: prop, body, opCppClassName); |
1594 | |
1595 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) { |
1596 | ArgumentLengthKind lengthKind = getArgumentLengthKind(var: operand->getVar()); |
1597 | StringRef name = operand->getVar()->name; |
1598 | if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) |
1599 | body << formatv(Fmt: variadicOfVariadicOperandParserCode, Vals&: name); |
1600 | else if (lengthKind == ArgumentLengthKind::Variadic) |
1601 | body << formatv(Fmt: variadicOperandParserCode, Vals&: name); |
1602 | else if (lengthKind == ArgumentLengthKind::Optional) |
1603 | body << formatv(Fmt: optionalOperandParserCode, Vals&: name); |
1604 | else |
1605 | body << formatv(Fmt: operandParserCode, Vals&: name); |
1606 | |
1607 | } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) { |
1608 | bool isVariadic = region->getVar()->isVariadic(); |
1609 | body << formatv(Fmt: isVariadic ? regionListParserCode : regionParserCode, |
1610 | Vals: region->getVar()->name); |
1611 | if (hasImplicitTermTrait) |
1612 | body << formatv(Fmt: isVariadic ? regionListEnsureTerminatorParserCode |
1613 | : regionEnsureTerminatorParserCode, |
1614 | Vals: region->getVar()->name); |
1615 | else if (hasSingleBlockTrait) |
1616 | body << formatv(Fmt: isVariadic ? regionListEnsureSingleBlockParserCode |
1617 | : regionEnsureSingleBlockParserCode, |
1618 | Vals: region->getVar()->name); |
1619 | |
1620 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) { |
1621 | bool isVariadic = successor->getVar()->isVariadic(); |
1622 | body << formatv(Fmt: isVariadic ? successorListParserCode : successorParserCode, |
1623 | Vals: successor->getVar()->name); |
1624 | |
1625 | /// Directives. |
1626 | } else if (auto *attrDict = dyn_cast<AttrDictDirective>(Val: element)) { |
1627 | body.indent() << "{\n" ; |
1628 | body.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n" |
1629 | << "if (parser.parseOptionalAttrDict" |
1630 | << (attrDict->isWithKeyword() ? "WithKeyword" : "" ) |
1631 | << "(result.attributes))\n" |
1632 | << " return ::mlir::failure();\n" ; |
1633 | if (useProperties) { |
1634 | body << "if (failed(verifyInherentAttrs(result.name, result.attributes, " |
1635 | "[&]() {\n" |
1636 | << " return parser.emitError(loc) << \"'\" << " |
1637 | "result.name.getStringRef() << \"' op \";\n" |
1638 | << " })))\n" |
1639 | << " return ::mlir::failure();\n" ; |
1640 | } |
1641 | body.unindent() << "}\n" ; |
1642 | body.unindent(); |
1643 | } else if (isa<PropDictDirective>(Val: element)) { |
1644 | if (useProperties) { |
1645 | body << " if (parseProperties(parser, result))\n" |
1646 | << " return ::mlir::failure();\n" ; |
1647 | } |
1648 | } else if (auto *customDir = dyn_cast<CustomDirective>(Val: element)) { |
1649 | genCustomDirectiveParser(dir: customDir, body, useProperties, opCppClassName); |
1650 | } else if (isa<OperandsDirective>(Val: element)) { |
1651 | body << " [[maybe_unused]] ::llvm::SMLoc allOperandLoc =" |
1652 | << " parser.getCurrentLocation();\n" |
1653 | << " if (parser.parseOperandList(allOperands))\n" |
1654 | << " return ::mlir::failure();\n" ; |
1655 | |
1656 | } else if (isa<RegionsDirective>(Val: element)) { |
1657 | body << formatv(Fmt: regionListParserCode, Vals: "full" ); |
1658 | if (hasImplicitTermTrait) |
1659 | body << formatv(Fmt: regionListEnsureTerminatorParserCode, Vals: "full" ); |
1660 | else if (hasSingleBlockTrait) |
1661 | body << formatv(Fmt: regionListEnsureSingleBlockParserCode, Vals: "full" ); |
1662 | |
1663 | } else if (isa<SuccessorsDirective>(Val: element)) { |
1664 | body << formatv(Fmt: successorListParserCode, Vals: "full" ); |
1665 | |
1666 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) { |
1667 | ArgumentLengthKind lengthKind; |
1668 | StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind); |
1669 | if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { |
1670 | body << formatv(Fmt: variadicOfVariadicTypeParserCode, Vals&: listName); |
1671 | } else if (lengthKind == ArgumentLengthKind::Variadic) { |
1672 | body << formatv(Fmt: variadicTypeParserCode, Vals&: listName); |
1673 | } else if (lengthKind == ArgumentLengthKind::Optional) { |
1674 | body << formatv(Fmt: optionalTypeParserCode, Vals&: listName); |
1675 | } else { |
1676 | const char *parserCode = |
1677 | dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode; |
1678 | TypeSwitch<FormatElement *>(dir->getArg()) |
1679 | .Case<OperandVariable, ResultVariable>(caseFn: [&](auto operand) { |
1680 | body << formatv(false, parserCode, |
1681 | operand->getVar()->constraint.getCppType(), |
1682 | listName); |
1683 | }) |
1684 | .Default(defaultFn: [&](auto operand) { |
1685 | body << formatv(Validate: false, Fmt: parserCode, Vals: "::mlir::Type" , Vals&: listName); |
1686 | }); |
1687 | } |
1688 | } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(Val: element)) { |
1689 | ArgumentLengthKind ignored; |
1690 | body << formatv(Fmt: functionalTypeParserCode, |
1691 | Vals: getTypeListName(arg: dir->getInputs(), lengthKind&: ignored), |
1692 | Vals: getTypeListName(arg: dir->getResults(), lengthKind&: ignored)); |
1693 | } else { |
1694 | llvm_unreachable("unknown format element" ); |
1695 | } |
1696 | } |
1697 | |
1698 | void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) { |
1699 | // If any of type resolutions use transformed variables, make sure that the |
1700 | // types of those variables are resolved. |
1701 | SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables; |
1702 | FmtContext verifierFCtx; |
1703 | for (TypeResolution &resolver : |
1704 | llvm::concat<TypeResolution>(Ranges&: resultTypes, Ranges&: operandTypes)) { |
1705 | std::optional<StringRef> transformer = resolver.getVarTransformer(); |
1706 | if (!transformer) |
1707 | continue; |
1708 | // Ensure that we don't verify the same variables twice. |
1709 | const NamedTypeConstraint *variable = resolver.getVariable(); |
1710 | if (!variable || !verifiedVariables.insert(Ptr: variable).second) |
1711 | continue; |
1712 | |
1713 | auto constraint = variable->constraint; |
1714 | body << " for (::mlir::Type type : " << variable->name << "Types) {\n" |
1715 | << " (void)type;\n" |
1716 | << " if (!(" |
1717 | << tgfmt(fmt: constraint.getConditionTemplate(), |
1718 | ctx: &verifierFCtx.withSelf(subst: "type" )) |
1719 | << ")) {\n" |
1720 | << formatv(Fmt: " return parser.emitError(parser.getNameLoc()) << " |
1721 | "\"'{0}' must be {1}, but got \" << type;\n" , |
1722 | Vals: variable->name, Vals: constraint.getSummary()) |
1723 | << " }\n" |
1724 | << " }\n" ; |
1725 | } |
1726 | |
1727 | // Initialize the set of buildable types. |
1728 | if (!buildableTypes.empty()) { |
1729 | FmtContext typeBuilderCtx; |
1730 | typeBuilderCtx.withBuilder(subst: "parser.getBuilder()" ); |
1731 | for (auto &it : buildableTypes) |
1732 | body << " ::mlir::Type odsBuildableType" << it.second << " = " |
1733 | << tgfmt(fmt: it.first, ctx: &typeBuilderCtx) << ";\n" ; |
1734 | } |
1735 | |
1736 | // Emit the code necessary for a type resolver. |
1737 | auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) { |
1738 | if (std::optional<int> val = resolver.getBuilderIdx()) { |
1739 | body << "odsBuildableType" << *val; |
1740 | } else if (const NamedTypeConstraint *var = resolver.getVariable()) { |
1741 | if (std::optional<StringRef> tform = resolver.getVarTransformer()) { |
1742 | FmtContext fmtContext; |
1743 | fmtContext.addSubst(placeholder: "_ctxt" , subst: "parser.getContext()" ); |
1744 | if (var->isVariadic()) |
1745 | fmtContext.withSelf(subst: var->name + "Types" ); |
1746 | else |
1747 | fmtContext.withSelf(subst: var->name + "Types[0]" ); |
1748 | body << tgfmt(fmt: *tform, ctx: &fmtContext); |
1749 | } else { |
1750 | body << var->name << "Types" ; |
1751 | if (!var->isVariadic()) |
1752 | body << "[0]" ; |
1753 | } |
1754 | } else if (const NamedAttribute *attr = resolver.getAttribute()) { |
1755 | if (std::optional<StringRef> tform = resolver.getVarTransformer()) |
1756 | body << tgfmt(fmt: *tform, |
1757 | ctx: &FmtContext().withSelf(subst: attr->name + "Attr.getType()" )); |
1758 | else |
1759 | body << attr->name << "Attr.getType()" ; |
1760 | } else { |
1761 | body << curVar << "Types" ; |
1762 | } |
1763 | }; |
1764 | |
1765 | // Resolve each of the result types. |
1766 | if (!infersResultTypes) { |
1767 | if (allResultTypes) { |
1768 | body << " result.addTypes(allResultTypes);\n" ; |
1769 | } else { |
1770 | for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { |
1771 | body << " result.addTypes(" ; |
1772 | emitTypeResolver(resultTypes[i], op.getResultName(index: i)); |
1773 | body << ");\n" ; |
1774 | } |
1775 | } |
1776 | } |
1777 | |
1778 | // Emit the operand type resolutions. |
1779 | genParserOperandTypeResolution(op, body, emitTypeResolver); |
1780 | |
1781 | // Handle return type inference once all operands have been resolved |
1782 | if (infersResultTypes) |
1783 | body << formatv(Fmt: inferReturnTypesParserCode, Vals: op.getCppClassName()); |
1784 | } |
1785 | |
1786 | void OperationFormat::genParserOperandTypeResolution( |
1787 | Operator &op, MethodBody &body, |
1788 | function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) { |
1789 | // Early exit if there are no operands. |
1790 | if (op.getNumOperands() == 0) |
1791 | return; |
1792 | |
1793 | // Handle the case where all operand types are grouped together with |
1794 | // "types(operands)". |
1795 | if (allOperandTypes) { |
1796 | // If `operands` was specified, use the full operand list directly. |
1797 | if (allOperands) { |
1798 | body << " if (parser.resolveOperands(allOperands, allOperandTypes, " |
1799 | "allOperandLoc, result.operands))\n" |
1800 | " return ::mlir::failure();\n" ; |
1801 | return; |
1802 | } |
1803 | |
1804 | // Otherwise, use llvm::concat to merge the disjoint operand lists together. |
1805 | // llvm::concat does not allow the case of a single range, so guard it here. |
1806 | body << " if (parser.resolveOperands(" ; |
1807 | if (op.getNumOperands() > 1) { |
1808 | body << "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(" ; |
1809 | llvm::interleaveComma(c: op.getOperands(), os&: body, each_fn: [&](auto &operand) { |
1810 | body << operand.name << "Operands" ; |
1811 | }); |
1812 | body << ")" ; |
1813 | } else { |
1814 | body << op.operand_begin()->name << "Operands" ; |
1815 | } |
1816 | body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n" |
1817 | << " return ::mlir::failure();\n" ; |
1818 | return; |
1819 | } |
1820 | |
1821 | // Handle the case where all operands are grouped together with "operands". |
1822 | if (allOperands) { |
1823 | body << " if (parser.resolveOperands(allOperands, " ; |
1824 | |
1825 | // Group all of the operand types together to perform the resolution all at |
1826 | // once. Use llvm::concat to perform the merge. llvm::concat does not allow |
1827 | // the case of a single range, so guard it here. |
1828 | if (op.getNumOperands() > 1) { |
1829 | body << "::llvm::concat<const ::mlir::Type>(" ; |
1830 | llvm::interleaveComma( |
1831 | c: llvm::seq<int>(Begin: 0, End: op.getNumOperands()), os&: body, each_fn: [&](int i) { |
1832 | body << "::llvm::ArrayRef<::mlir::Type>(" ; |
1833 | emitTypeResolver(operandTypes[i], op.getOperand(index: i).name); |
1834 | body << ")" ; |
1835 | }); |
1836 | body << ")" ; |
1837 | } else { |
1838 | emitTypeResolver(operandTypes.front(), op.getOperand(index: 0).name); |
1839 | } |
1840 | |
1841 | body << ", allOperandLoc, result.operands))\n return " |
1842 | "::mlir::failure();\n" ; |
1843 | return; |
1844 | } |
1845 | |
1846 | // The final case is the one where each of the operands types are resolved |
1847 | // separately. |
1848 | for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { |
1849 | NamedTypeConstraint &operand = op.getOperand(index: i); |
1850 | body << " if (parser.resolveOperands(" << operand.name << "Operands, " ; |
1851 | |
1852 | // Resolve the type of this operand. |
1853 | TypeResolution &operandType = operandTypes[i]; |
1854 | emitTypeResolver(operandType, operand.name); |
1855 | |
1856 | body << ", " << operand.name |
1857 | << "OperandsLoc, result.operands))\n return ::mlir::failure();\n" ; |
1858 | } |
1859 | } |
1860 | |
1861 | void OperationFormat::genParserRegionResolution(Operator &op, |
1862 | MethodBody &body) { |
1863 | // Check for the case where all regions were parsed. |
1864 | bool hasAllRegions = llvm::any_of( |
1865 | Range&: elements, P: [](FormatElement *elt) { return isa<RegionsDirective>(Val: elt); }); |
1866 | if (hasAllRegions) { |
1867 | body << " result.addRegions(fullRegions);\n" ; |
1868 | return; |
1869 | } |
1870 | |
1871 | // Otherwise, handle each region individually. |
1872 | for (const NamedRegion ®ion : op.getRegions()) { |
1873 | if (region.isVariadic()) |
1874 | body << " result.addRegions(" << region.name << "Regions);\n" ; |
1875 | else |
1876 | body << " result.addRegion(std::move(" << region.name << "Region));\n" ; |
1877 | } |
1878 | } |
1879 | |
1880 | void OperationFormat::genParserSuccessorResolution(Operator &op, |
1881 | MethodBody &body) { |
1882 | // Check for the case where all successors were parsed. |
1883 | bool hasAllSuccessors = llvm::any_of(Range&: elements, P: [](FormatElement *elt) { |
1884 | return isa<SuccessorsDirective>(Val: elt); |
1885 | }); |
1886 | if (hasAllSuccessors) { |
1887 | body << " result.addSuccessors(fullSuccessors);\n" ; |
1888 | return; |
1889 | } |
1890 | |
1891 | // Otherwise, handle each successor individually. |
1892 | for (const NamedSuccessor &successor : op.getSuccessors()) { |
1893 | if (successor.isVariadic()) |
1894 | body << " result.addSuccessors(" << successor.name << "Successors);\n" ; |
1895 | else |
1896 | body << " result.addSuccessors(" << successor.name << "Successor);\n" ; |
1897 | } |
1898 | } |
1899 | |
1900 | void OperationFormat::genParserVariadicSegmentResolution(Operator &op, |
1901 | MethodBody &body) { |
1902 | if (!allOperands) { |
1903 | if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments" )) { |
1904 | auto interleaveFn = [&](const NamedTypeConstraint &operand) { |
1905 | // If the operand is variadic emit the parsed size. |
1906 | if (operand.isVariableLength()) |
1907 | body << "static_cast<int32_t>(" << operand.name << "Operands.size())" ; |
1908 | else |
1909 | body << "1" ; |
1910 | }; |
1911 | if (op.getDialect().usePropertiesForAttributes()) { |
1912 | body << "::llvm::copy(::llvm::ArrayRef<int32_t>({" ; |
1913 | llvm::interleaveComma(c: op.getOperands(), os&: body, each_fn: interleaveFn); |
1914 | body << formatv(Fmt: "}), " |
1915 | "result.getOrAddProperties<{0}::Properties>()." |
1916 | "operandSegmentSizes.begin());\n" , |
1917 | Vals: op.getCppClassName()); |
1918 | } else { |
1919 | body << " result.addAttribute(\"operandSegmentSizes\", " |
1920 | << "parser.getBuilder().getDenseI32ArrayAttr({" ; |
1921 | llvm::interleaveComma(c: op.getOperands(), os&: body, each_fn: interleaveFn); |
1922 | body << "}));\n" ; |
1923 | } |
1924 | } |
1925 | for (const NamedTypeConstraint &operand : op.getOperands()) { |
1926 | if (!operand.isVariadicOfVariadic()) |
1927 | continue; |
1928 | if (op.getDialect().usePropertiesForAttributes()) { |
1929 | body << formatv( |
1930 | Fmt: " result.getOrAddProperties<{0}::Properties>().{1} = " |
1931 | "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n" , |
1932 | Vals: op.getCppClassName(), |
1933 | Vals: operand.constraint.getVariadicOfVariadicSegmentSizeAttr(), |
1934 | Vals: operand.name); |
1935 | } else { |
1936 | body << formatv( |
1937 | Fmt: " result.addAttribute(\"{0}\", " |
1938 | "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));" |
1939 | "\n" , |
1940 | Vals: operand.constraint.getVariadicOfVariadicSegmentSizeAttr(), |
1941 | Vals: operand.name); |
1942 | } |
1943 | } |
1944 | } |
1945 | |
1946 | if (!allResultTypes && |
1947 | op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments" )) { |
1948 | auto interleaveFn = [&](const NamedTypeConstraint &result) { |
1949 | // If the result is variadic emit the parsed size. |
1950 | if (result.isVariableLength()) |
1951 | body << "static_cast<int32_t>(" << result.name << "Types.size())" ; |
1952 | else |
1953 | body << "1" ; |
1954 | }; |
1955 | if (op.getDialect().usePropertiesForAttributes()) { |
1956 | body << "::llvm::copy(::llvm::ArrayRef<int32_t>({" ; |
1957 | llvm::interleaveComma(c: op.getResults(), os&: body, each_fn: interleaveFn); |
1958 | body << formatv(Fmt: "}), " |
1959 | "result.getOrAddProperties<{0}::Properties>()." |
1960 | "resultSegmentSizes.begin());\n" , |
1961 | Vals: op.getCppClassName()); |
1962 | } else { |
1963 | body << " result.addAttribute(\"resultSegmentSizes\", " |
1964 | << "parser.getBuilder().getDenseI32ArrayAttr({" ; |
1965 | llvm::interleaveComma(c: op.getResults(), os&: body, each_fn: interleaveFn); |
1966 | body << "}));\n" ; |
1967 | } |
1968 | } |
1969 | } |
1970 | |
1971 | //===----------------------------------------------------------------------===// |
1972 | // PrinterGen |
1973 | //===----------------------------------------------------------------------===// |
1974 | |
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait.
///
/// The terminator of the region's first block is printed only if it has a
/// non-empty attribute dictionary, operands, or results. For an empty region,
/// `printTerminator` stays `true`. `{{` is the llvm::formatv escape for a
/// literal `{`.
///
/// {0}: The name of the region.
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
{
  bool printTerminator = true;
  if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
    printTerminator = !term->getAttrDictionary().empty() ||
                      term->getNumOperands() != 0 ||
                      term->getNumResults() != 0;
  }
  _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
    /*printBlockTerminators=*/printTerminator);
}
)";
1991 | |
/// The code snippet used to generate a printer call for an enum that has cases
/// that can't be represented with a keyword.
///
/// The snippet opens a `{` scope that it does not close, and defines the
/// locals `caseValue` and `caseValueStr` inside it. Code emitted after this
/// snippet presumably uses those locals and closes the scope (verify at the
/// use site).
///
/// {0}: The name of the enum attribute accessor (emitted as the call `{0}()`).
/// {1}: The name of the enum attribute's symbolToString function.
const char *enumAttrBeginPrinterCode = R"(
{
  auto caseValue = {0}();
  auto caseValueStr = {1}(caseValue);
)";
2002 | |
2003 | /// Generate a check that an optional or default-valued attribute or property |
2004 | /// has a non-default value. For these purposes, the default value of an |
2005 | /// optional attribute is its presence, even if the attribute itself has a |
2006 | /// default value. |
2007 | static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, |
2008 | AttributeVariable &attrElement) { |
2009 | Attribute attr = attrElement.getVar()->attr; |
2010 | std::string getter = op.getGetterName(name: attrElement.getVar()->name); |
2011 | bool optionalAndDefault = attr.isOptional() && attr.hasDefaultValue(); |
2012 | if (optionalAndDefault) |
2013 | body << "(" ; |
2014 | if (attr.isOptional()) |
2015 | body << getter << "Attr()" ; |
2016 | if (optionalAndDefault) |
2017 | body << " && " ; |
2018 | if (attr.hasDefaultValue()) { |
2019 | FmtContext fctx; |
2020 | fctx.withBuilder(subst: "::mlir::OpBuilder((*this)->getContext())" ); |
2021 | body << getter << "Attr() != " |
2022 | << tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, |
2023 | vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx)); |
2024 | } |
2025 | if (optionalAndDefault) |
2026 | body << ")" ; |
2027 | } |
2028 | |
2029 | static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, |
2030 | PropertyVariable &propElement) { |
2031 | FmtContext fctx; |
2032 | fctx.withBuilder(subst: "::mlir::OpBuilder((*this)->getContext())" ); |
2033 | body << op.getGetterName(name: propElement.getVar()->name) << "() != " |
2034 | << tgfmt(fmt: propElement.getVar()->prop.getDefaultValue(), ctx: &fctx); |
2035 | } |
2036 | |
2037 | /// Elide the variadic segment size attributes if necessary. |
2038 | /// This pushes elided attribute names in `elidedStorage`. |
2039 | static void genVariadicSegmentElision(OperationFormat &fmt, Operator &op, |
2040 | MethodBody &body, |
2041 | const char *elidedStorage) { |
2042 | if (!fmt.allOperands && |
2043 | op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments" )) |
2044 | body << " " << elidedStorage << ".push_back(\"operandSegmentSizes\");\n" ; |
2045 | if (!fmt.allResultTypes && |
2046 | op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments" )) |
2047 | body << " " << elidedStorage << ".push_back(\"resultSegmentSizes\");\n" ; |
2048 | } |
2049 | |
2050 | /// Generate the printer for the 'prop-dict' directive. |
2051 | static void genPropDictPrinter(OperationFormat &fmt, Operator &op, |
2052 | MethodBody &body) { |
2053 | body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedProps;\n" ; |
2054 | |
2055 | genVariadicSegmentElision(fmt, op, body, elidedStorage: "elidedProps" ); |
2056 | |
2057 | for (const NamedProperty *namedProperty : fmt.usedProperties) |
2058 | body << " elidedProps.push_back(\"" << namedProperty->name << "\");\n" ; |
2059 | for (const NamedAttribute *namedAttr : fmt.usedAttributes) |
2060 | body << " elidedProps.push_back(\"" << namedAttr->name << "\");\n" ; |
2061 | |
2062 | // Add code to check attributes for equality with their default values. |
2063 | // Default-valued attributes will not be printed when their value matches the |
2064 | // default. |
2065 | for (const NamedAttribute &namedAttr : op.getAttributes()) { |
2066 | const Attribute &attr = namedAttr.attr; |
2067 | if (!attr.isDerivedAttr() && attr.hasDefaultValue()) { |
2068 | const StringRef &name = namedAttr.name; |
2069 | FmtContext fctx; |
2070 | fctx.withBuilder(subst: "odsBuilder" ); |
2071 | std::string defaultValue = |
2072 | std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, |
2073 | vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx))); |
2074 | body << " {\n" ; |
2075 | body << " ::mlir::Builder odsBuilder(getContext());\n" ; |
2076 | body << " ::mlir::Attribute attr = " << op.getGetterName(name) |
2077 | << "Attr();\n" ; |
2078 | body << " if(attr && (attr == " << defaultValue << "))\n" ; |
2079 | body << " elidedProps.push_back(\"" << name << "\");\n" ; |
2080 | body << " }\n" ; |
2081 | } |
2082 | } |
2083 | // Similarly, elide default-valued properties. |
2084 | for (const NamedProperty &prop : op.getProperties()) { |
2085 | if (prop.prop.hasDefaultValue()) { |
2086 | FmtContext fctx; |
2087 | fctx.withBuilder(subst: "odsBuilder" ); |
2088 | body << " if (" << op.getGetterName(name: prop.name) |
2089 | << "() == " << tgfmt(fmt: prop.prop.getDefaultValue(), ctx: &fctx) << ") {" ; |
2090 | body << " elidedProps.push_back(\"" << prop.name << "\");\n" ; |
2091 | body << " }\n" ; |
2092 | } |
2093 | } |
2094 | |
2095 | if (fmt.useProperties) { |
2096 | body << " _odsPrinter << \" \";\n" |
2097 | << " printProperties(this->getContext(), _odsPrinter, " |
2098 | "getProperties(), elidedProps);\n" ; |
2099 | } |
2100 | } |
2101 | |
2102 | /// Generate the printer for the 'attr-dict' directive. |
2103 | static void genAttrDictPrinter(OperationFormat &fmt, Operator &op, |
2104 | MethodBody &body, bool withKeyword) { |
2105 | body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;\n" ; |
2106 | |
2107 | genVariadicSegmentElision(fmt, op, body, elidedStorage: "elidedAttrs" ); |
2108 | |
2109 | for (const StringRef key : fmt.inferredAttributes.keys()) |
2110 | body << " elidedAttrs.push_back(\"" << key << "\");\n" ; |
2111 | for (const NamedAttribute *attr : fmt.usedAttributes) |
2112 | body << " elidedAttrs.push_back(\"" << attr->name << "\");\n" ; |
2113 | |
2114 | // Add code to check attributes for equality with their default values. |
2115 | // Default-valued attributes will not be printed when their value matches the |
2116 | // default. |
2117 | for (const NamedAttribute &namedAttr : op.getAttributes()) { |
2118 | const Attribute &attr = namedAttr.attr; |
2119 | if (!attr.isDerivedAttr() && attr.hasDefaultValue()) { |
2120 | const StringRef &name = namedAttr.name; |
2121 | FmtContext fctx; |
2122 | fctx.withBuilder(subst: "odsBuilder" ); |
2123 | std::string defaultValue = |
2124 | std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, |
2125 | vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx))); |
2126 | body << " {\n" ; |
2127 | body << " ::mlir::Builder odsBuilder(getContext());\n" ; |
2128 | body << " ::mlir::Attribute attr = " << op.getGetterName(name) |
2129 | << "Attr();\n" ; |
2130 | body << " if(attr && (attr == " << defaultValue << "))\n" ; |
2131 | body << " elidedAttrs.push_back(\"" << name << "\");\n" ; |
2132 | body << " }\n" ; |
2133 | } |
2134 | } |
2135 | if (fmt.hasPropDict) |
2136 | body << " _odsPrinter.printOptionalAttrDict" |
2137 | << (withKeyword ? "WithKeyword" : "" ) |
2138 | << "(llvm::to_vector((*this)->getDiscardableAttrs()), elidedAttrs);\n" ; |
2139 | else |
2140 | body << " _odsPrinter.printOptionalAttrDict" |
2141 | << (withKeyword ? "WithKeyword" : "" ) |
2142 | << "((*this)->getAttrs(), elidedAttrs);\n" ; |
2143 | } |
2144 | |
2145 | /// Generate the printer for a literal value. `shouldEmitSpace` is true if a |
2146 | /// space should be emitted before this element. `lastWasPunctuation` is true if |
2147 | /// the previous element was a punctuation literal. |
2148 | static void genLiteralPrinter(StringRef value, MethodBody &body, |
2149 | bool &shouldEmitSpace, bool &lastWasPunctuation) { |
2150 | body << " _odsPrinter" ; |
2151 | |
2152 | // Don't insert a space for certain punctuation. |
2153 | if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation)) |
2154 | body << " << ' '" ; |
2155 | body << " << \"" << value << "\";\n" ; |
2156 | |
2157 | // Insert a space after certain literals. |
2158 | shouldEmitSpace = |
2159 | value.size() != 1 || !StringRef("<({[" ).contains(C: value.front()); |
2160 | lastWasPunctuation = value.front() != '_' && !isalpha(value.front()); |
2161 | } |
2162 | |
2163 | /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation` |
2164 | /// are set to false. |
2165 | static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace, |
2166 | bool &lastWasPunctuation) { |
2167 | if (value) { |
2168 | body << " _odsPrinter << ' ';\n" ; |
2169 | lastWasPunctuation = false; |
2170 | } else { |
2171 | lastWasPunctuation = true; |
2172 | } |
2173 | shouldEmitSpace = false; |
2174 | } |
2175 | |
2176 | /// Generate the printer for a custom directive parameter. |
2177 | static void genCustomDirectiveParameterPrinter(FormatElement *element, |
2178 | const Operator &op, |
2179 | MethodBody &body) { |
2180 | if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) { |
2181 | body << op.getGetterName(name: attr->getVar()->name) << "Attr()" ; |
2182 | |
2183 | } else if (isa<AttrDictDirective>(Val: element)) { |
2184 | body << "getOperation()->getAttrDictionary()" ; |
2185 | |
2186 | } else if (isa<PropDictDirective>(Val: element)) { |
2187 | body << "getProperties()" ; |
2188 | |
2189 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) { |
2190 | body << op.getGetterName(name: operand->getVar()->name) << "()" ; |
2191 | |
2192 | } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) { |
2193 | body << op.getGetterName(name: region->getVar()->name) << "()" ; |
2194 | |
2195 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) { |
2196 | body << op.getGetterName(name: successor->getVar()->name) << "()" ; |
2197 | |
2198 | } else if (auto *dir = dyn_cast<RefDirective>(Val: element)) { |
2199 | genCustomDirectiveParameterPrinter(element: dir->getArg(), op, body); |
2200 | |
2201 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) { |
2202 | auto *typeOperand = dir->getArg(); |
2203 | auto *operand = dyn_cast<OperandVariable>(Val: typeOperand); |
2204 | auto *var = operand ? operand->getVar() |
2205 | : cast<ResultVariable>(Val: typeOperand)->getVar(); |
2206 | std::string name = op.getGetterName(name: var->name); |
2207 | if (var->isVariadic()) |
2208 | body << name << "().getTypes()" ; |
2209 | else if (var->isOptional()) |
2210 | body << formatv(Fmt: "({0}() ? {0}().getType() : ::mlir::Type())" , Vals&: name); |
2211 | else |
2212 | body << name << "().getType()" ; |
2213 | |
2214 | } else if (auto *string = dyn_cast<StringElement>(Val: element)) { |
2215 | FmtContext ctx; |
2216 | ctx.withBuilder(subst: "::mlir::Builder(getContext())" ); |
2217 | ctx.addSubst(placeholder: "_ctxt" , subst: "getContext()" ); |
2218 | body << tgfmt(fmt: string->getValue(), ctx: &ctx); |
2219 | |
2220 | } else if (auto *property = dyn_cast<PropertyVariable>(Val: element)) { |
2221 | FmtContext ctx; |
2222 | const NamedProperty *namedProperty = property->getVar(); |
2223 | ctx.addSubst(placeholder: "_storage" , subst: "getProperties()." + namedProperty->name); |
2224 | body << tgfmt(fmt: namedProperty->prop.getConvertFromStorageCall(), ctx: &ctx); |
2225 | } else { |
2226 | llvm_unreachable("unknown custom directive parameter" ); |
2227 | } |
2228 | } |
2229 | |
2230 | /// Generate the printer for a custom directive. |
2231 | static void genCustomDirectivePrinter(CustomDirective *customDir, |
2232 | const Operator &op, MethodBody &body) { |
2233 | body << " print" << customDir->getName() << "(_odsPrinter, *this" ; |
2234 | for (FormatElement *param : customDir->getElements()) { |
2235 | body << ", " ; |
2236 | genCustomDirectiveParameterPrinter(element: param, op, body); |
2237 | } |
2238 | body << ");\n" ; |
2239 | } |
2240 | |
2241 | /// Generate the printer for a region with the given variable name. |
2242 | static void genRegionPrinter(const Twine ®ionName, MethodBody &body, |
2243 | bool hasImplicitTermTrait) { |
2244 | if (hasImplicitTermTrait) |
2245 | body << formatv(Fmt: regionSingleBlockImplicitTerminatorPrinterCode, Vals: regionName); |
2246 | else |
2247 | body << " _odsPrinter.printRegion(" << regionName << ");\n" ; |
2248 | } |
2249 | static void genVariadicRegionPrinter(const Twine ®ionListName, |
2250 | MethodBody &body, |
2251 | bool hasImplicitTermTrait) { |
2252 | body << " llvm::interleaveComma(" << regionListName |
2253 | << ", _odsPrinter, [&](::mlir::Region ®ion) {\n " ; |
2254 | genRegionPrinter(regionName: "region" , body, hasImplicitTermTrait); |
2255 | body << " });\n" ; |
2256 | } |
2257 | |
2258 | /// Generate the C++ for an operand to a (*-)type directive. |
2259 | static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op, |
2260 | MethodBody &body, |
2261 | bool useArrayRef = true) { |
2262 | if (isa<OperandsDirective>(Val: arg)) |
2263 | return body << "getOperation()->getOperandTypes()" ; |
2264 | if (isa<ResultsDirective>(Val: arg)) |
2265 | return body << "getOperation()->getResultTypes()" ; |
2266 | auto *operand = dyn_cast<OperandVariable>(Val: arg); |
2267 | auto *var = operand ? operand->getVar() : cast<ResultVariable>(Val: arg)->getVar(); |
2268 | if (var->isVariadicOfVariadic()) |
2269 | return body << formatv(Fmt: "{0}().join().getTypes()" , |
2270 | Vals: op.getGetterName(name: var->name)); |
2271 | if (var->isVariadic()) |
2272 | return body << op.getGetterName(name: var->name) << "().getTypes()" ; |
2273 | if (var->isOptional()) |
2274 | return body << formatv( |
2275 | Fmt: "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : " |
2276 | "::llvm::ArrayRef<::mlir::Type>())" , |
2277 | Vals: op.getGetterName(name: var->name)); |
2278 | if (useArrayRef) |
2279 | return body << "::llvm::ArrayRef<::mlir::Type>(" |
2280 | << op.getGetterName(name: var->name) << "().getType())" ; |
2281 | return body << op.getGetterName(name: var->name) << "().getType()" ; |
2282 | } |
2283 | |
2284 | /// Generate the printer for an enum attribute. |
2285 | static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, |
2286 | MethodBody &body) { |
2287 | Attribute baseAttr = var->attr.getBaseAttr(); |
2288 | const EnumInfo enumInfo(&baseAttr.getDef()); |
2289 | std::vector<EnumCase> cases = enumInfo.getAllCases(); |
2290 | |
2291 | body << formatv(Fmt: enumAttrBeginPrinterCode, |
2292 | Vals: (var->attr.isOptional() ? "*" : "" ) + |
2293 | op.getGetterName(name: var->name), |
2294 | Vals: enumInfo.getSymbolToStringFnName()); |
2295 | |
2296 | // Get a string containing all of the cases that can't be represented with a |
2297 | // keyword. |
2298 | BitVector nonKeywordCases(cases.size()); |
2299 | for (auto it : llvm::enumerate(First&: cases)) { |
2300 | if (!canFormatStringAsKeyword(value: it.value().getStr())) |
2301 | nonKeywordCases.set(it.index()); |
2302 | } |
2303 | |
2304 | // Otherwise if this is a bit enum attribute, don't allow cases that may |
2305 | // overlap with other cases. For simplicity sake, only allow cases with a |
2306 | // single bit value. |
2307 | if (enumInfo.isBitEnum()) { |
2308 | for (auto it : llvm::enumerate(First&: cases)) { |
2309 | int64_t value = it.value().getValue(); |
2310 | if (value < 0 || !llvm::isPowerOf2_64(Value: value)) |
2311 | nonKeywordCases.set(it.index()); |
2312 | } |
2313 | } |
2314 | |
2315 | // If there are any cases that can't be used with a keyword, switch on the |
2316 | // case value to determine when to print in the string form. |
2317 | if (nonKeywordCases.any()) { |
2318 | body << " switch (caseValue) {\n" ; |
2319 | StringRef cppNamespace = enumInfo.getCppNamespace(); |
2320 | StringRef enumName = enumInfo.getEnumClassName(); |
2321 | for (auto it : llvm::enumerate(First&: cases)) { |
2322 | if (nonKeywordCases.test(Idx: it.index())) |
2323 | continue; |
2324 | StringRef symbol = it.value().getSymbol(); |
2325 | body << formatv(Fmt: " case {0}::{1}::{2}:\n" , Vals&: cppNamespace, Vals&: enumName, |
2326 | Vals: llvm::isDigit(C: symbol.front()) ? ("_" + symbol) : symbol); |
2327 | } |
2328 | body << " _odsPrinter << caseValueStr;\n" |
2329 | " break;\n" |
2330 | " default:\n" |
2331 | " _odsPrinter << '\"' << caseValueStr << '\"';\n" |
2332 | " break;\n" |
2333 | " }\n" |
2334 | " }\n" ; |
2335 | return; |
2336 | } |
2337 | |
2338 | body << " _odsPrinter << caseValueStr;\n" |
2339 | " }\n" ; |
2340 | } |
2341 | |
2342 | /// Generate the check for the anchor of an optional group. |
2343 | static void genOptionalGroupPrinterAnchor(FormatElement *anchor, |
2344 | const Operator &op, |
2345 | MethodBody &body) { |
2346 | TypeSwitch<FormatElement *>(anchor) |
2347 | .Case<OperandVariable, ResultVariable>(caseFn: [&](auto *element) { |
2348 | const NamedTypeConstraint *var = element->getVar(); |
2349 | std::string name = op.getGetterName(name: var->name); |
2350 | if (var->isOptional()) |
2351 | body << name << "()" ; |
2352 | else if (var->isVariadic()) |
2353 | body << "!" << name << "().empty()" ; |
2354 | }) |
2355 | .Case(caseFn: [&](RegionVariable *element) { |
2356 | const NamedRegion *var = element->getVar(); |
2357 | std::string name = op.getGetterName(name: var->name); |
2358 | // TODO: Add a check for optional regions here when ODS supports it. |
2359 | body << "!" << name << "().empty()" ; |
2360 | }) |
2361 | .Case(caseFn: [&](TypeDirective *element) { |
2362 | genOptionalGroupPrinterAnchor(anchor: element->getArg(), op, body); |
2363 | }) |
2364 | .Case(caseFn: [&](FunctionalTypeDirective *element) { |
2365 | genOptionalGroupPrinterAnchor(anchor: element->getInputs(), op, body); |
2366 | }) |
2367 | .Case(caseFn: [&](AttributeVariable *element) { |
2368 | // Consider a default-valued attribute as present if it's not the |
2369 | // default value and an optional one present if it is set. |
2370 | genNonDefaultValueCheck(body, op, attrElement&: *element); |
2371 | }) |
2372 | .Case(caseFn: [&](PropertyVariable *element) { |
2373 | genNonDefaultValueCheck(body, op, propElement&: *element); |
2374 | }) |
2375 | .Case(caseFn: [&](CustomDirective *ele) { |
2376 | body << '('; |
2377 | llvm::interleave( |
2378 | c: ele->getElements(), os&: body, |
2379 | each_fn: [&](FormatElement *child) { |
2380 | body << '('; |
2381 | genOptionalGroupPrinterAnchor(anchor: child, op, body); |
2382 | body << ')'; |
2383 | }, |
2384 | separator: " || " ); |
2385 | body << ')'; |
2386 | }); |
2387 | } |
2388 | |
2389 | void collect(FormatElement *element, |
2390 | SmallVectorImpl<VariableElement *> &variables) { |
2391 | TypeSwitch<FormatElement *>(element) |
2392 | .Case(caseFn: [&](VariableElement *var) { variables.emplace_back(Args&: var); }) |
2393 | .Case(caseFn: [&](CustomDirective *ele) { |
2394 | for (FormatElement *arg : ele->getElements()) |
2395 | collect(element: arg, variables); |
2396 | }) |
2397 | .Case(caseFn: [&](OptionalElement *ele) { |
2398 | for (FormatElement *arg : ele->getThenElements()) |
2399 | collect(element: arg, variables); |
2400 | for (FormatElement *arg : ele->getElseElements()) |
2401 | collect(element: arg, variables); |
2402 | }) |
2403 | .Case(caseFn: [&](FunctionalTypeDirective *funcType) { |
2404 | collect(element: funcType->getInputs(), variables); |
2405 | collect(element: funcType->getResults(), variables); |
2406 | }) |
2407 | .Case(caseFn: [&](OIListElement *oilist) { |
2408 | for (ArrayRef<FormatElement *> arg : oilist->getParsingElements()) |
2409 | for (FormatElement *arg : arg) |
2410 | collect(element: arg, variables); |
2411 | }); |
2412 | } |
2413 | |
2414 | void OperationFormat::genElementPrinter(FormatElement *element, |
2415 | MethodBody &body, Operator &op, |
2416 | bool &shouldEmitSpace, |
2417 | bool &lastWasPunctuation) { |
2418 | if (LiteralElement *literal = dyn_cast<LiteralElement>(Val: element)) |
2419 | return genLiteralPrinter(value: literal->getSpelling(), body, shouldEmitSpace, |
2420 | lastWasPunctuation); |
2421 | |
2422 | // Emit a whitespace element. |
2423 | if (auto *space = dyn_cast<WhitespaceElement>(Val: element)) { |
2424 | if (space->getValue() == "\\n" ) { |
2425 | body << " _odsPrinter.printNewline();\n" ; |
2426 | } else { |
2427 | genSpacePrinter(value: !space->getValue().empty(), body, shouldEmitSpace, |
2428 | lastWasPunctuation); |
2429 | } |
2430 | return; |
2431 | } |
2432 | |
2433 | // Emit an optional group. |
2434 | if (OptionalElement *optional = dyn_cast<OptionalElement>(Val: element)) { |
2435 | // Emit the check for the presence of the anchor element. |
2436 | FormatElement *anchor = optional->getAnchor(); |
2437 | body << " if (" ; |
2438 | if (optional->isInverted()) |
2439 | body << "!" ; |
2440 | genOptionalGroupPrinterAnchor(anchor, op, body); |
2441 | body << ") {\n" ; |
2442 | body.indent(); |
2443 | |
2444 | // If the anchor is a unit attribute, we don't need to print it. When |
2445 | // parsing, we will add this attribute if this group is present. |
2446 | ArrayRef<FormatElement *> thenElements = optional->getThenElements(); |
2447 | ArrayRef<FormatElement *> elseElements = optional->getElseElements(); |
2448 | FormatElement *elidedAnchorElement = nullptr; |
2449 | auto *anchorAttr = dyn_cast<AttributeLikeVariable>(Val: anchor); |
2450 | if (anchorAttr && anchorAttr != thenElements.front() && |
2451 | (elseElements.empty() || anchorAttr != elseElements.front()) && |
2452 | anchorAttr->isUnit()) { |
2453 | elidedAnchorElement = anchorAttr; |
2454 | } |
2455 | auto genElementPrinters = [&](ArrayRef<FormatElement *> elements) { |
2456 | for (FormatElement *childElement : elements) { |
2457 | if (childElement != elidedAnchorElement) { |
2458 | genElementPrinter(element: childElement, body, op, shouldEmitSpace, |
2459 | lastWasPunctuation); |
2460 | } |
2461 | } |
2462 | }; |
2463 | |
2464 | // Emit each of the elements. |
2465 | genElementPrinters(thenElements); |
2466 | body << "}" ; |
2467 | |
2468 | // Emit each of the else elements. |
2469 | if (!elseElements.empty()) { |
2470 | body << " else {\n" ; |
2471 | genElementPrinters(elseElements); |
2472 | body << "}" ; |
2473 | } |
2474 | |
2475 | body.unindent() << "\n" ; |
2476 | return; |
2477 | } |
2478 | |
2479 | // Emit the OIList |
2480 | if (auto *oilist = dyn_cast<OIListElement>(Val: element)) { |
2481 | for (auto clause : oilist->getClauses()) { |
2482 | LiteralElement *lelement = std::get<0>(t&: clause); |
2483 | ArrayRef<FormatElement *> pelement = std::get<1>(t&: clause); |
2484 | |
2485 | SmallVector<VariableElement *> vars; |
2486 | for (FormatElement *el : pelement) |
2487 | collect(element: el, variables&: vars); |
2488 | body << " if (false" ; |
2489 | for (VariableElement *var : vars) { |
2490 | TypeSwitch<FormatElement *>(var) |
2491 | .Case(caseFn: [&](AttributeVariable *attrEle) { |
2492 | body << " || (" ; |
2493 | genNonDefaultValueCheck(body, op, attrElement&: *attrEle); |
2494 | body << ")" ; |
2495 | }) |
2496 | .Case(caseFn: [&](PropertyVariable *propEle) { |
2497 | body << " || (" ; |
2498 | genNonDefaultValueCheck(body, op, propElement&: *propEle); |
2499 | body << ")" ; |
2500 | }) |
2501 | .Case(caseFn: [&](OperandVariable *ele) { |
2502 | if (ele->getVar()->isVariadic()) { |
2503 | body << " || " << op.getGetterName(name: ele->getVar()->name) |
2504 | << "().size()" ; |
2505 | } else { |
2506 | body << " || " << op.getGetterName(name: ele->getVar()->name) << "()" ; |
2507 | } |
2508 | }) |
2509 | .Case(caseFn: [&](ResultVariable *ele) { |
2510 | if (ele->getVar()->isVariadic()) { |
2511 | body << " || " << op.getGetterName(name: ele->getVar()->name) |
2512 | << "().size()" ; |
2513 | } else { |
2514 | body << " || " << op.getGetterName(name: ele->getVar()->name) << "()" ; |
2515 | } |
2516 | }) |
2517 | .Case(caseFn: [&](RegionVariable *reg) { |
2518 | body << " || " << op.getGetterName(name: reg->getVar()->name) << "()" ; |
2519 | }); |
2520 | } |
2521 | |
2522 | body << ") {\n" ; |
2523 | genLiteralPrinter(value: lelement->getSpelling(), body, shouldEmitSpace, |
2524 | lastWasPunctuation); |
2525 | if (oilist->getUnitVariableParsingElement(pelement) == nullptr) { |
2526 | for (FormatElement *element : pelement) |
2527 | genElementPrinter(element, body, op, shouldEmitSpace, |
2528 | lastWasPunctuation); |
2529 | } |
2530 | body << " }\n" ; |
2531 | } |
2532 | return; |
2533 | } |
2534 | |
2535 | // Emit the attribute dictionary. |
2536 | if (auto *attrDict = dyn_cast<AttrDictDirective>(Val: element)) { |
2537 | genAttrDictPrinter(fmt&: *this, op, body, withKeyword: attrDict->isWithKeyword()); |
2538 | lastWasPunctuation = false; |
2539 | return; |
2540 | } |
2541 | |
2542 | // Emit the property dictionary. |
2543 | if (isa<PropDictDirective>(Val: element)) { |
2544 | genPropDictPrinter(fmt&: *this, op, body); |
2545 | lastWasPunctuation = false; |
2546 | return; |
2547 | } |
2548 | |
2549 | // Optionally insert a space before the next element. The AttrDict printer |
2550 | // already adds a space as necessary. |
2551 | if (shouldEmitSpace || !lastWasPunctuation) |
2552 | body << " _odsPrinter << ' ';\n" ; |
2553 | lastWasPunctuation = false; |
2554 | shouldEmitSpace = true; |
2555 | |
2556 | if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) { |
2557 | const NamedAttribute *var = attr->getVar(); |
2558 | |
2559 | // If we are formatting as an enum, symbolize the attribute as a string. |
2560 | if (canFormatEnumAttr(attr: var)) |
2561 | return genEnumAttrPrinter(var, op, body); |
2562 | |
2563 | // If we are formatting as a symbol name, handle it as a symbol name. |
2564 | if (shouldFormatSymbolNameAttr(attr: var)) { |
2565 | body << " _odsPrinter.printSymbolName(" << op.getGetterName(name: var->name) |
2566 | << "Attr().getValue());\n" ; |
2567 | return; |
2568 | } |
2569 | |
2570 | // Elide the attribute type if it is buildable. |
2571 | if (attr->getTypeBuilder()) |
2572 | body << " _odsPrinter.printAttributeWithoutType(" |
2573 | << op.getGetterName(name: var->name) << "Attr());\n" ; |
2574 | else if (attr->shouldBeQualified() || |
2575 | var->attr.getStorageType() == "::mlir::Attribute" ) |
2576 | body << " _odsPrinter.printAttribute(" << op.getGetterName(name: var->name) |
2577 | << "Attr());\n" ; |
2578 | else |
2579 | body << "_odsPrinter.printStrippedAttrOrType(" |
2580 | << op.getGetterName(name: var->name) << "Attr());\n" ; |
2581 | } else if (auto *property = dyn_cast<PropertyVariable>(Val: element)) { |
2582 | const NamedProperty *var = property->getVar(); |
2583 | FmtContext fmtContext; |
2584 | fmtContext.addSubst(placeholder: "_printer" , subst: "_odsPrinter" ); |
2585 | fmtContext.addSubst(placeholder: "_ctxt" , subst: "getContext()" ); |
2586 | fmtContext.addSubst(placeholder: "_storage" , subst: "getProperties()." + var->name); |
2587 | body << tgfmt(fmt: var->prop.getPrinterCall(), ctx: &fmtContext) << ";\n" ; |
2588 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) { |
2589 | if (operand->getVar()->isVariadicOfVariadic()) { |
2590 | body << " ::llvm::interleaveComma(" |
2591 | << op.getGetterName(name: operand->getVar()->name) |
2592 | << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << " |
2593 | "\"(\" << operands << " |
2594 | "\")\"; });\n" ; |
2595 | |
2596 | } else if (operand->getVar()->isOptional()) { |
2597 | body << " if (::mlir::Value value = " |
2598 | << op.getGetterName(name: operand->getVar()->name) << "())\n" |
2599 | << " _odsPrinter << value;\n" ; |
2600 | } else { |
2601 | body << " _odsPrinter << " << op.getGetterName(name: operand->getVar()->name) |
2602 | << "();\n" ; |
2603 | } |
2604 | } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) { |
2605 | const NamedRegion *var = region->getVar(); |
2606 | std::string name = op.getGetterName(name: var->name); |
2607 | if (var->isVariadic()) { |
2608 | genVariadicRegionPrinter(regionListName: name + "()" , body, hasImplicitTermTrait); |
2609 | } else { |
2610 | genRegionPrinter(regionName: name + "()" , body, hasImplicitTermTrait); |
2611 | } |
2612 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) { |
2613 | const NamedSuccessor *var = successor->getVar(); |
2614 | std::string name = op.getGetterName(name: var->name); |
2615 | if (var->isVariadic()) |
2616 | body << " ::llvm::interleaveComma(" << name << "(), _odsPrinter);\n" ; |
2617 | else |
2618 | body << " _odsPrinter << " << name << "();\n" ; |
2619 | } else if (auto *dir = dyn_cast<CustomDirective>(Val: element)) { |
2620 | genCustomDirectivePrinter(customDir: dir, op, body); |
2621 | } else if (isa<OperandsDirective>(Val: element)) { |
2622 | body << " _odsPrinter << getOperation()->getOperands();\n" ; |
2623 | } else if (isa<RegionsDirective>(Val: element)) { |
2624 | genVariadicRegionPrinter(regionListName: "getOperation()->getRegions()" , body, |
2625 | hasImplicitTermTrait); |
2626 | } else if (isa<SuccessorsDirective>(Val: element)) { |
2627 | body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), " |
2628 | "_odsPrinter);\n" ; |
2629 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) { |
2630 | if (auto *operand = dyn_cast<OperandVariable>(Val: dir->getArg())) { |
2631 | if (operand->getVar()->isVariadicOfVariadic()) { |
2632 | body << formatv( |
2633 | Fmt: " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, " |
2634 | "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << " |
2635 | "types << \")\"; });\n" , |
2636 | Vals: op.getGetterName(name: operand->getVar()->name)); |
2637 | return; |
2638 | } |
2639 | } |
2640 | const NamedTypeConstraint *var = nullptr; |
2641 | { |
2642 | if (auto *operand = dyn_cast<OperandVariable>(Val: dir->getArg())) |
2643 | var = operand->getVar(); |
2644 | else if (auto *operand = dyn_cast<ResultVariable>(Val: dir->getArg())) |
2645 | var = operand->getVar(); |
2646 | } |
2647 | if (var && !var->isVariadicOfVariadic() && !var->isVariadic() && |
2648 | !var->isOptional()) { |
2649 | StringRef cppType = var->constraint.getCppType(); |
2650 | if (dir->shouldBeQualified()) { |
2651 | body << " _odsPrinter << " << op.getGetterName(name: var->name) |
2652 | << "().getType();\n" ; |
2653 | return; |
2654 | } |
2655 | body << " {\n" |
2656 | << " auto type = " << op.getGetterName(name: var->name) |
2657 | << "().getType();\n" |
2658 | << " if (auto validType = ::llvm::dyn_cast<" << cppType |
2659 | << ">(type))\n" |
2660 | << " _odsPrinter.printStrippedAttrOrType(validType);\n" |
2661 | << " else\n" |
2662 | << " _odsPrinter << type;\n" |
2663 | << " }\n" ; |
2664 | return; |
2665 | } |
2666 | body << " _odsPrinter << " ; |
2667 | genTypeOperandPrinter(arg: dir->getArg(), op, body, /*useArrayRef=*/false) |
2668 | << ";\n" ; |
2669 | } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(Val: element)) { |
2670 | body << " _odsPrinter.printFunctionalType(" ; |
2671 | genTypeOperandPrinter(arg: dir->getInputs(), op, body) << ", " ; |
2672 | genTypeOperandPrinter(arg: dir->getResults(), op, body) << ");\n" ; |
2673 | } else { |
2674 | llvm_unreachable("unknown format element" ); |
2675 | } |
2676 | } |
2677 | |
2678 | void OperationFormat::genPrinter(Operator &op, OpClass &opClass) { |
2679 | auto *method = opClass.addMethod( |
2680 | retType: "void" , name: "print" , |
2681 | args: MethodParameter("::mlir::OpAsmPrinter &" , "_odsPrinter" )); |
2682 | auto &body = method->body(); |
2683 | |
2684 | // Flags for if we should emit a space, and if the last element was |
2685 | // punctuation. |
2686 | bool shouldEmitSpace = true, lastWasPunctuation = false; |
2687 | for (FormatElement *element : elements) |
2688 | genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation); |
2689 | } |
2690 | |
2691 | //===----------------------------------------------------------------------===// |
2692 | // OpFormatParser |
2693 | //===----------------------------------------------------------------------===// |
2694 | |
2695 | /// Function to find an element within the given range that has the same name as |
2696 | /// 'name'. |
2697 | template <typename RangeT> |
2698 | static auto findArg(RangeT &&range, StringRef name) { |
2699 | auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; }); |
2700 | return it != range.end() ? &*it : nullptr; |
2701 | } |
2702 | |
namespace {
/// This class implements a parser for an instance of an operation assembly
/// format. While parsing, it records which parts of the operation the format
/// references. `verify` then publishes part of that state into the
/// `OperationFormat` the parser was constructed with.
class OpFormatParser : public FormatParser {
public:
  /// Create a parser for `op`'s format that writes into `format`. The base
  /// FormatParser is given the op's first source location. The seen-type bit
  /// vectors get one bit per operand and one per result.
  OpFormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
      : FormatParser(mgr, op.getLoc()[0]), fmt(format), op(op),
        seenOperandTypes(op.getNumOperands()),
        seenResultTypes(op.getNumResults()) {}

protected:
  /// Verify the format elements.
  LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
  /// Verify the arguments to a custom directive.
  LogicalResult
  verifyCustomDirectiveArguments(SMLoc loc,
                                 ArrayRef<FormatElement *> arguments) override;
  /// Verify the elements of an optional group.
  LogicalResult verifyOptionalGroupElements(SMLoc loc,
                                            ArrayRef<FormatElement *> elements,
                                            FormatElement *anchor) override;
  /// Verify a single element of an optional group. `isAnchor` indicates
  /// whether `element` is the group's anchor (definition not visible here).
  LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
                                           bool isAnchor);

  /// Mark `element` as qualified. NOTE(review): presumably sets the state the
  /// printer later reads through `shouldBeQualified()`; confirm against the
  /// definition.
  LogicalResult markQualified(SMLoc loc, FormatElement *element) override;

  /// Parse an operation variable.
  FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
                                               Context ctx) override;
  /// Parse an operation format directive.
  FailureOr<FormatElement *>
  parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;

private:
  /// This struct represents a type resolution instance. It includes a specific
  /// type as well as an optional transformer to apply to that type in order to
  /// properly resolve the type of a variable.
  struct TypeResolutionInstance {
    /// The argument whose type is used for resolution.
    ConstArgument resolver;
    /// Optional transformer applied to the resolver's type.
    std::optional<StringRef> transformer;
  };

  /// Verify the state of operation attributes within the format.
  LogicalResult verifyAttributes(SMLoc loc, ArrayRef<FormatElement *> elements);

  /// Verify that attributes elements aren't followed by colon literals.
  LogicalResult verifyAttributeColonType(SMLoc loc,
                                         ArrayRef<FormatElement *> elements);
  /// Verify that the attribute dictionary directive isn't followed by a region.
  LogicalResult verifyAttrDictRegion(SMLoc loc,
                                     ArrayRef<FormatElement *> elements);

  /// Verify the state of operation operands within the format.
  LogicalResult
  verifyOperands(SMLoc loc,
                 StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation regions within the format.
  LogicalResult verifyRegions(SMLoc loc);

  /// Verify the state of operation results within the format.
  LogicalResult
  verifyResults(SMLoc loc,
                StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation successors within the format.
  LogicalResult verifySuccessors(SMLoc loc);

  /// Verify the oilist elements among `elements` (definition not visible
  /// here).
  LogicalResult verifyOIListElements(SMLoc loc,
                                     ArrayRef<FormatElement *> elements);

  /// Given the values of an `AllTypesMatch` trait, check for inferable type
  /// resolution.
  void handleAllTypesMatchConstraint(
      ArrayRef<StringRef> values,
      StringMap<TypeResolutionInstance> &variableTyResolver);
  /// Check for inferable type resolution given all operands, and/or results,
  /// have the same type. If 'includeResults' is true, the results also have the
  /// same type as all of the operands.
  void handleSameTypesConstraint(
      StringMap<TypeResolutionInstance> &variableTyResolver,
      bool includeResults);
  /// Check for inferable type resolution based on another operand, result, or
  /// attribute.
  void handleTypesMatchConstraint(
      StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);

  /// Returns an argument or attribute with the given name that has been seen
  /// within the format.
  ConstArgument findSeenArg(StringRef name);

  /// Parse the various different directives.
  FailureOr<FormatElement *> parsePropDictDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context,
                                                    bool withKeyword);
  FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc,
                                                          Context context);
  FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
  /// Verify one parsing element of an oilist clause.
  LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
  FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
                                                      Context context);
  FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc,
                                                       bool isRefChild = false);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  // NOTE: `fmt` and `op` must stay declared before the seen-type bit vectors.
  // The constructor's initializer list relies on this order.
  OperationFormat &fmt;
  Operator &op;

  // The following are various bits of format state used for verification
  // during parsing.
  /// Whether `attr-dict` appeared; `verify` fails when it did not.
  bool hasAttrDict = false;
  /// Whether `prop-dict` appeared; copied to `fmt.hasPropDict` by `verify`.
  bool hasPropDict = false;
  /// NOTE(review): presumably set by the `regions` / `successors` directives.
  bool hasAllRegions = false, hasAllSuccessors = false;
  /// Set by `verify` when result types are not all known and the op
  /// implements `::mlir::InferTypeOpInterface`.
  bool canInferResultTypes = false;
  /// One bit per operand / result; presumably set once its type is referenced.
  llvm::SmallBitVector seenOperandTypes, seenResultTypes;
  /// Attributes referenced by the format; moved into `fmt.usedAttributes`.
  llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
  /// Operands referenced by the format.
  llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
  /// Regions and successors referenced by the format.
  llvm::DenseSet<const NamedRegion *> seenRegions;
  llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
  /// Properties referenced by the format; moved into `fmt.usedProperties`.
  llvm::SmallSetVector<const NamedProperty *, 8> seenProperties;
};
} // namespace
2832 | |
2833 | LogicalResult OpFormatParser::verify(SMLoc loc, |
2834 | ArrayRef<FormatElement *> elements) { |
2835 | // Check that the attribute dictionary is in the format. |
2836 | if (!hasAttrDict) |
2837 | return emitError(loc, msg: "'attr-dict' directive not found in " |
2838 | "custom assembly format" ); |
2839 | |
2840 | // Check for any type traits that we can use for inferring types. |
2841 | StringMap<TypeResolutionInstance> variableTyResolver; |
2842 | for (const Trait &trait : op.getTraits()) { |
2843 | const Record &def = trait.getDef(); |
2844 | if (def.isSubClassOf(Name: "AllTypesMatch" )) { |
2845 | handleAllTypesMatchConstraint(values: def.getValueAsListOfStrings(FieldName: "values" ), |
2846 | variableTyResolver); |
2847 | } else if (def.getName() == "SameTypeOperands" ) { |
2848 | handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false); |
2849 | } else if (def.getName() == "SameOperandsAndResultType" ) { |
2850 | handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); |
2851 | } else if (def.isSubClassOf(Name: "TypesMatchWith" )) { |
2852 | handleTypesMatchConstraint(variableTyResolver, def); |
2853 | } else if (!op.allResultTypesKnown()) { |
2854 | // This doesn't check the name directly to handle |
2855 | // DeclareOpInterfaceMethods<InferTypeOpInterface> |
2856 | // and the like. |
2857 | // TODO: Add hasCppInterface check. |
2858 | if (auto name = def.getValueAsOptionalString(FieldName: "cppInterfaceName" )) { |
2859 | if (*name == "InferTypeOpInterface" && |
2860 | def.getValueAsString(FieldName: "cppNamespace" ) == "::mlir" ) |
2861 | canInferResultTypes = true; |
2862 | } |
2863 | } |
2864 | } |
2865 | |
2866 | // Verify the state of the various operation components. |
2867 | if (failed(Result: verifyAttributes(loc, elements)) || |
2868 | failed(Result: verifyResults(loc, variableTyResolver)) || |
2869 | failed(Result: verifyOperands(loc, variableTyResolver)) || |
2870 | failed(Result: verifyRegions(loc)) || failed(Result: verifySuccessors(loc)) || |
2871 | failed(Result: verifyOIListElements(loc, elements))) |
2872 | return failure(); |
2873 | |
2874 | // Collect the set of used attributes in the format. |
2875 | fmt.usedAttributes = std::move(seenAttrs); |
2876 | fmt.usedProperties = std::move(seenProperties); |
2877 | |
2878 | // Set whether prop-dict is used in the format |
2879 | fmt.hasPropDict = hasPropDict; |
2880 | return success(); |
2881 | } |
2882 | |
2883 | LogicalResult |
2884 | OpFormatParser::verifyAttributes(SMLoc loc, |
2885 | ArrayRef<FormatElement *> elements) { |
2886 | // Check that there are no `:` literals after an attribute without a constant |
2887 | // type. The attribute grammar contains an optional trailing colon type, which |
2888 | // can lead to unexpected and generally unintended behavior. Given that, it is |
2889 | // better to just error out here instead. |
2890 | if (failed(Result: verifyAttributeColonType(loc, elements))) |
2891 | return failure(); |
2892 | // Check that there are no region variables following an attribute dicitonary. |
2893 | // Both start with `{` and so the optional attribute dictionary can cause |
2894 | // format ambiguities. |
2895 | if (failed(Result: verifyAttrDictRegion(loc, elements))) |
2896 | return failure(); |
2897 | |
2898 | // Check for VariadicOfVariadic variables. The segment attribute of those |
2899 | // variables will be infered. |
2900 | for (const NamedTypeConstraint *var : seenOperands) { |
2901 | if (var->constraint.isVariadicOfVariadic()) { |
2902 | fmt.inferredAttributes.insert( |
2903 | key: var->constraint.getVariadicOfVariadicSegmentSizeAttr()); |
2904 | } |
2905 | } |
2906 | |
2907 | return success(); |
2908 | } |
2909 | |
2910 | /// Returns whether the single format element is optionally parsed. |
2911 | static bool isOptionallyParsed(FormatElement *el) { |
2912 | if (auto *attrVar = dyn_cast<AttributeVariable>(Val: el)) { |
2913 | Attribute attr = attrVar->getVar()->attr; |
2914 | return attr.isOptional() || attr.hasDefaultValue(); |
2915 | } |
2916 | if (auto *propVar = dyn_cast<PropertyVariable>(Val: el)) { |
2917 | const Property &prop = propVar->getVar()->prop; |
2918 | return prop.hasDefaultValue() && prop.hasOptionalParser(); |
2919 | } |
2920 | if (auto *operandVar = dyn_cast<OperandVariable>(Val: el)) { |
2921 | const NamedTypeConstraint *operand = operandVar->getVar(); |
2922 | return operand->isOptional() || operand->isVariadic() || |
2923 | operand->isVariadicOfVariadic(); |
2924 | } |
2925 | if (auto *successorVar = dyn_cast<SuccessorVariable>(Val: el)) |
2926 | return successorVar->getVar()->isVariadic(); |
2927 | if (auto *regionVar = dyn_cast<RegionVariable>(Val: el)) |
2928 | return regionVar->getVar()->isVariadic(); |
2929 | return isa<WhitespaceElement, AttrDictDirective>(Val: el); |
2930 | } |
2931 | |
2932 | /// Scan the given range of elements from the start for an invalid format |
2933 | /// element that satisfies `isInvalid`, skipping any optionally-parsed elements. |
2934 | /// If an optional group is encountered, this function recurses into the 'then' |
2935 | /// and 'else' elements to check if they are invalid. Returns `success` if the |
2936 | /// range is known to be valid or `std::nullopt` if scanning reached the end. |
2937 | /// |
2938 | /// Since the guard element of an optional group is required, this function |
2939 | /// accepts an optional element pointer to mark it as required. |
2940 | static std::optional<LogicalResult> checkRangeForElement( |
2941 | FormatElement *base, |
2942 | function_ref<bool(FormatElement *, FormatElement *)> isInvalid, |
2943 | iterator_range<ArrayRef<FormatElement *>::iterator> elementRange, |
2944 | FormatElement *optionalGuard = nullptr) { |
2945 | for (FormatElement *element : elementRange) { |
2946 | // If we encounter an invalid element, return an error. |
2947 | if (isInvalid(base, element)) |
2948 | return failure(); |
2949 | |
2950 | // Recurse on optional groups. |
2951 | if (auto *optional = dyn_cast<OptionalElement>(Val: element)) { |
2952 | if (std::optional<LogicalResult> result = checkRangeForElement( |
2953 | base, isInvalid, elementRange: optional->getThenElements(), |
2954 | // The optional group guard is required for the group. |
2955 | optionalGuard: optional->getThenElements().front())) |
2956 | if (failed(Result: *result)) |
2957 | return failure(); |
2958 | if (std::optional<LogicalResult> result = checkRangeForElement( |
2959 | base, isInvalid, elementRange: optional->getElseElements())) |
2960 | if (failed(Result: *result)) |
2961 | return failure(); |
2962 | // Skip the optional group. |
2963 | continue; |
2964 | } |
2965 | |
2966 | // Skip optionally parsed elements. |
2967 | if (element != optionalGuard && isOptionallyParsed(el: element)) |
2968 | continue; |
2969 | |
2970 | // We found a closing element that is valid. |
2971 | return success(); |
2972 | } |
2973 | // Return std::nullopt to indicate that we reached the end. |
2974 | return std::nullopt; |
2975 | } |
2976 | |
2977 | /// For the given elements, check whether any attributes are followed by a colon |
2978 | /// literal, resulting in an ambiguous assembly format. Returns a non-null |
2979 | /// attribute if verification of said attribute reached the end of the range. |
2980 | /// Returns null if all attribute elements are verified. |
2981 | static FailureOr<FormatElement *> verifyAdjacentElements( |
2982 | function_ref<bool(FormatElement *)> isBase, |
2983 | function_ref<bool(FormatElement *, FormatElement *)> isInvalid, |
2984 | ArrayRef<FormatElement *> elements) { |
2985 | for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) { |
2986 | // The current attribute being verified. |
2987 | FormatElement *base; |
2988 | |
2989 | if (isBase(*it)) { |
2990 | base = *it; |
2991 | } else if (auto *optional = dyn_cast<OptionalElement>(Val: *it)) { |
2992 | // Recurse on optional groups. |
2993 | FailureOr<FormatElement *> thenResult = verifyAdjacentElements( |
2994 | isBase, isInvalid, elements: optional->getThenElements()); |
2995 | if (failed(Result: thenResult)) |
2996 | return failure(); |
2997 | FailureOr<FormatElement *> elseResult = verifyAdjacentElements( |
2998 | isBase, isInvalid, elements: optional->getElseElements()); |
2999 | if (failed(Result: elseResult)) |
3000 | return failure(); |
3001 | // If either optional group has an unverified attribute, save it. |
3002 | // Otherwise, move on to the next element. |
3003 | if (!(base = *thenResult) && !(base = *elseResult)) |
3004 | continue; |
3005 | } else { |
3006 | continue; |
3007 | } |
3008 | |
3009 | // Verify subsequent elements for potential ambiguities. |
3010 | if (std::optional<LogicalResult> result = |
3011 | checkRangeForElement(base, isInvalid, elementRange: {std::next(x: it), e})) { |
3012 | if (failed(Result: *result)) |
3013 | return failure(); |
3014 | } else { |
3015 | // Since we reached the end, return the attribute as unverified. |
3016 | return base; |
3017 | } |
3018 | } |
3019 | // All attribute elements are known to be verified. |
3020 | return nullptr; |
3021 | } |
3022 | |
3023 | LogicalResult |
3024 | OpFormatParser::verifyAttributeColonType(SMLoc loc, |
3025 | ArrayRef<FormatElement *> elements) { |
3026 | auto isBase = [](FormatElement *el) { |
3027 | auto *attr = dyn_cast<AttributeVariable>(Val: el); |
3028 | if (!attr) |
3029 | return false; |
3030 | // Check only attributes without type builders or that are known to call |
3031 | // the generic attribute parser. |
3032 | return !attr->getTypeBuilder() && |
3033 | (attr->shouldBeQualified() || |
3034 | attr->getVar()->attr.getStorageType() == "::mlir::Attribute" ); |
3035 | }; |
3036 | auto isInvalid = [&](FormatElement *base, FormatElement *el) { |
3037 | auto *literal = dyn_cast<LiteralElement>(Val: el); |
3038 | if (!literal || literal->getSpelling() != ":" ) |
3039 | return false; |
3040 | // If we encounter `:`, the range is known to be invalid. |
3041 | (void)emitError( |
3042 | loc, msg: formatv(Fmt: "format ambiguity caused by `:` literal found after " |
3043 | "attribute `{0}` which does not have a buildable type" , |
3044 | Vals: cast<AttributeVariable>(Val: base)->getVar()->name)); |
3045 | return true; |
3046 | }; |
3047 | return verifyAdjacentElements(isBase, isInvalid, elements); |
3048 | } |
3049 | |
3050 | LogicalResult |
3051 | OpFormatParser::verifyAttrDictRegion(SMLoc loc, |
3052 | ArrayRef<FormatElement *> elements) { |
3053 | auto isBase = [](FormatElement *el) { |
3054 | if (auto *attrDict = dyn_cast<AttrDictDirective>(Val: el)) |
3055 | return !attrDict->isWithKeyword(); |
3056 | return false; |
3057 | }; |
3058 | auto isInvalid = [&](FormatElement *base, FormatElement *el) { |
3059 | auto *region = dyn_cast<RegionVariable>(Val: el); |
3060 | if (!region) |
3061 | return false; |
3062 | (void)emitErrorAndNote( |
3063 | loc, |
3064 | msg: formatv(Fmt: "format ambiguity caused by `attr-dict` directive " |
3065 | "followed by region `{0}`" , |
3066 | Vals: region->getVar()->name), |
3067 | note: "try using `attr-dict-with-keyword` instead" ); |
3068 | return true; |
3069 | }; |
3070 | return verifyAdjacentElements(isBase, isInvalid, elements); |
3071 | } |
3072 | |
3073 | LogicalResult OpFormatParser::verifyOperands( |
3074 | SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) { |
3075 | // Check that all of the operands are within the format, and their types can |
3076 | // be inferred. |
3077 | auto &buildableTypes = fmt.buildableTypes; |
3078 | for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { |
3079 | NamedTypeConstraint &operand = op.getOperand(index: i); |
3080 | |
3081 | // Check that the operand itself is in the format. |
3082 | if (!fmt.allOperands && !seenOperands.count(V: &operand)) { |
3083 | return emitErrorAndNote(loc, |
3084 | msg: "operand #" + Twine(i) + ", named '" + |
3085 | operand.name + "', not found" , |
3086 | note: "suggest adding a '$" + operand.name + |
3087 | "' directive to the custom assembly format" ); |
3088 | } |
3089 | |
3090 | // Check that the operand type is in the format, or that it can be inferred. |
3091 | if (fmt.allOperandTypes || seenOperandTypes.test(Idx: i)) |
3092 | continue; |
3093 | |
3094 | // Check to see if we can infer this type from another variable. |
3095 | auto varResolverIt = variableTyResolver.find(Key: op.getOperand(index: i).name); |
3096 | if (varResolverIt != variableTyResolver.end()) { |
3097 | TypeResolutionInstance &resolver = varResolverIt->second; |
3098 | fmt.operandTypes[i].setResolver(arg: resolver.resolver, transformer: resolver.transformer); |
3099 | continue; |
3100 | } |
3101 | |
3102 | // Similarly to results, allow a custom builder for resolving the type if |
3103 | // we aren't using the 'operands' directive. |
3104 | std::optional<StringRef> builder = operand.constraint.getBuilderCall(); |
3105 | if (!builder || (fmt.allOperands && operand.isVariableLength())) { |
3106 | return emitErrorAndNote( |
3107 | loc, |
3108 | msg: "type of operand #" + Twine(i) + ", named '" + operand.name + |
3109 | "', is not buildable and a buildable type cannot be inferred" , |
3110 | note: "suggest adding a type constraint to the operation or adding a " |
3111 | "'type($" + |
3112 | operand.name + ")' directive to the " + "custom assembly format" ); |
3113 | } |
3114 | auto it = buildableTypes.insert(KV: {*builder, buildableTypes.size()}); |
3115 | fmt.operandTypes[i].setBuilderIdx(it.first->second); |
3116 | } |
3117 | return success(); |
3118 | } |
3119 | |
3120 | LogicalResult OpFormatParser::verifyRegions(SMLoc loc) { |
3121 | // Check that all of the regions are within the format. |
3122 | if (hasAllRegions) |
3123 | return success(); |
3124 | |
3125 | for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) { |
3126 | const NamedRegion ®ion = op.getRegion(index: i); |
3127 | if (!seenRegions.count(V: ®ion)) { |
3128 | return emitErrorAndNote(loc, |
3129 | msg: "region #" + Twine(i) + ", named '" + |
3130 | region.name + "', not found" , |
3131 | note: "suggest adding a '$" + region.name + |
3132 | "' directive to the custom assembly format" ); |
3133 | } |
3134 | } |
3135 | return success(); |
3136 | } |
3137 | |
3138 | LogicalResult OpFormatParser::verifyResults( |
3139 | SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) { |
3140 | // If we format all of the types together, there is nothing to check. |
3141 | if (fmt.allResultTypes) |
3142 | return success(); |
3143 | |
3144 | // If no result types are specified and we can infer them, infer all result |
3145 | // types |
3146 | if (op.getNumResults() > 0 && seenResultTypes.count() == 0 && |
3147 | canInferResultTypes) { |
3148 | fmt.infersResultTypes = true; |
3149 | return success(); |
3150 | } |
3151 | |
3152 | // Check that all of the result types can be inferred. |
3153 | auto &buildableTypes = fmt.buildableTypes; |
3154 | for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { |
3155 | if (seenResultTypes.test(Idx: i)) |
3156 | continue; |
3157 | |
3158 | // Check to see if we can infer this type from another variable. |
3159 | auto varResolverIt = variableTyResolver.find(Key: op.getResultName(index: i)); |
3160 | if (varResolverIt != variableTyResolver.end()) { |
3161 | TypeResolutionInstance resolver = varResolverIt->second; |
3162 | fmt.resultTypes[i].setResolver(arg: resolver.resolver, transformer: resolver.transformer); |
3163 | continue; |
3164 | } |
3165 | |
3166 | // If the result is not variable length, allow for the case where the type |
3167 | // has a builder that we can use. |
3168 | NamedTypeConstraint &result = op.getResult(index: i); |
3169 | std::optional<StringRef> builder = result.constraint.getBuilderCall(); |
3170 | if (!builder || result.isVariableLength()) { |
3171 | return emitErrorAndNote( |
3172 | loc, |
3173 | msg: "type of result #" + Twine(i) + ", named '" + result.name + |
3174 | "', is not buildable and a buildable type cannot be inferred" , |
3175 | note: "suggest adding a type constraint to the operation or adding a " |
3176 | "'type($" + |
3177 | result.name + ")' directive to the " + "custom assembly format" ); |
3178 | } |
3179 | // Note in the format that this result uses the custom builder. |
3180 | auto it = buildableTypes.insert(KV: {*builder, buildableTypes.size()}); |
3181 | fmt.resultTypes[i].setBuilderIdx(it.first->second); |
3182 | } |
3183 | return success(); |
3184 | } |
3185 | |
3186 | LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) { |
3187 | // Check that all of the successors are within the format. |
3188 | if (hasAllSuccessors) |
3189 | return success(); |
3190 | |
3191 | for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) { |
3192 | const NamedSuccessor &successor = op.getSuccessor(index: i); |
3193 | if (!seenSuccessors.count(V: &successor)) { |
3194 | return emitErrorAndNote(loc, |
3195 | msg: "successor #" + Twine(i) + ", named '" + |
3196 | successor.name + "', not found" , |
3197 | note: "suggest adding a '$" + successor.name + |
3198 | "' directive to the custom assembly format" ); |
3199 | } |
3200 | } |
3201 | return success(); |
3202 | } |
3203 | |
3204 | LogicalResult |
3205 | OpFormatParser::verifyOIListElements(SMLoc loc, |
3206 | ArrayRef<FormatElement *> elements) { |
3207 | // Check that all of the successors are within the format. |
3208 | SmallVector<StringRef> prohibitedLiterals; |
3209 | for (FormatElement *it : elements) { |
3210 | if (auto *oilist = dyn_cast<OIListElement>(Val: it)) { |
3211 | if (!prohibitedLiterals.empty()) { |
3212 | // We just saw an oilist element in last iteration. Literals should not |
3213 | // match. |
3214 | for (LiteralElement *literal : oilist->getLiteralElements()) { |
3215 | if (find(Range&: prohibitedLiterals, Val: literal->getSpelling()) != |
3216 | prohibitedLiterals.end()) { |
3217 | return emitError( |
3218 | loc, msg: "format ambiguity because " + literal->getSpelling() + |
3219 | " is used in two adjacent oilist elements." ); |
3220 | } |
3221 | } |
3222 | } |
3223 | for (LiteralElement *literal : oilist->getLiteralElements()) |
3224 | prohibitedLiterals.push_back(Elt: literal->getSpelling()); |
3225 | } else if (auto *literal = dyn_cast<LiteralElement>(Val: it)) { |
3226 | if (find(Range&: prohibitedLiterals, Val: literal->getSpelling()) != |
3227 | prohibitedLiterals.end()) { |
3228 | return emitError( |
3229 | loc, |
3230 | msg: "format ambiguity because " + literal->getSpelling() + |
3231 | " is used both in oilist element and the adjacent literal." ); |
3232 | } |
3233 | prohibitedLiterals.clear(); |
3234 | } else { |
3235 | prohibitedLiterals.clear(); |
3236 | } |
3237 | } |
3238 | return success(); |
3239 | } |
3240 | |
3241 | void OpFormatParser::handleAllTypesMatchConstraint( |
3242 | ArrayRef<StringRef> values, |
3243 | StringMap<TypeResolutionInstance> &variableTyResolver) { |
3244 | for (unsigned i = 0, e = values.size(); i != e; ++i) { |
3245 | // Check to see if this value matches a resolved operand or result type. |
3246 | ConstArgument arg = findSeenArg(name: values[i]); |
3247 | if (!arg) |
3248 | continue; |
3249 | |
3250 | // Mark this value as the type resolver for the other variables. |
3251 | for (unsigned j = 0; j != i; ++j) |
3252 | variableTyResolver[values[j]] = {.resolver: arg, .transformer: std::nullopt}; |
3253 | for (unsigned j = i + 1; j != e; ++j) |
3254 | variableTyResolver[values[j]] = {.resolver: arg, .transformer: std::nullopt}; |
3255 | } |
3256 | } |
3257 | |
3258 | void OpFormatParser::handleSameTypesConstraint( |
3259 | StringMap<TypeResolutionInstance> &variableTyResolver, |
3260 | bool includeResults) { |
3261 | const NamedTypeConstraint *resolver = nullptr; |
3262 | int resolvedIt = -1; |
3263 | |
3264 | // Check to see if there is an operand or result to use for the resolution. |
3265 | if ((resolvedIt = seenOperandTypes.find_first()) != -1) |
3266 | resolver = &op.getOperand(index: resolvedIt); |
3267 | else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1) |
3268 | resolver = &op.getResult(index: resolvedIt); |
3269 | else |
3270 | return; |
3271 | |
3272 | // Set the resolvers for each operand and result. |
3273 | for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) |
3274 | if (!seenOperandTypes.test(Idx: i)) |
3275 | variableTyResolver[op.getOperand(index: i).name] = {.resolver: resolver, .transformer: std::nullopt}; |
3276 | if (includeResults) { |
3277 | for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) |
3278 | if (!seenResultTypes.test(Idx: i)) |
3279 | variableTyResolver[op.getResultName(index: i)] = {.resolver: resolver, .transformer: std::nullopt}; |
3280 | } |
3281 | } |
3282 | |
3283 | void OpFormatParser::handleTypesMatchConstraint( |
3284 | StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) { |
3285 | StringRef lhsName = def.getValueAsString(FieldName: "lhs" ); |
3286 | StringRef rhsName = def.getValueAsString(FieldName: "rhs" ); |
3287 | StringRef transformer = def.getValueAsString(FieldName: "transformer" ); |
3288 | if (ConstArgument arg = findSeenArg(name: lhsName)) |
3289 | variableTyResolver[rhsName] = {.resolver: arg, .transformer: transformer}; |
3290 | } |
3291 | |
3292 | ConstArgument OpFormatParser::findSeenArg(StringRef name) { |
3293 | if (const NamedTypeConstraint *arg = findArg(range: op.getOperands(), name)) |
3294 | return seenOperandTypes.test(Idx: arg - op.operand_begin()) ? arg : nullptr; |
3295 | if (const NamedTypeConstraint *arg = findArg(range: op.getResults(), name)) |
3296 | return seenResultTypes.test(Idx: arg - op.result_begin()) ? arg : nullptr; |
3297 | if (const NamedAttribute *attr = findArg(range: op.getAttributes(), name)) |
3298 | return seenAttrs.count(key: attr) ? attr : nullptr; |
3299 | return nullptr; |
3300 | } |
3301 | |
3302 | FailureOr<FormatElement *> |
3303 | OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) { |
3304 | // Check that the parsed argument is something actually registered on the op. |
3305 | // Attributes |
3306 | if (const NamedAttribute *attr = findArg(range: op.getAttributes(), name)) { |
3307 | if (ctx == TypeDirectiveContext) |
3308 | return emitError( |
3309 | loc, msg: "attributes cannot be used as children to a `type` directive" ); |
3310 | if (ctx == RefDirectiveContext) { |
3311 | if (!seenAttrs.count(key: attr)) |
3312 | return emitError(loc, msg: "attribute '" + name + |
3313 | "' must be bound before it is referenced" ); |
3314 | } else if (!seenAttrs.insert(X: attr)) { |
3315 | return emitError(loc, msg: "attribute '" + name + "' is already bound" ); |
3316 | } |
3317 | |
3318 | return create<AttributeVariable>(args&: attr); |
3319 | } |
3320 | |
3321 | if (const NamedProperty *property = findArg(range: op.getProperties(), name)) { |
3322 | if (ctx == TypeDirectiveContext) |
3323 | return emitError( |
3324 | loc, msg: "properties cannot be used as children to a `type` directive" ); |
3325 | if (ctx == RefDirectiveContext) { |
3326 | if (!seenProperties.count(key: property)) |
3327 | return emitError(loc, msg: "property '" + name + |
3328 | "' must be bound before it is referenced" ); |
3329 | } else { |
3330 | if (!seenProperties.insert(X: property)) |
3331 | return emitError(loc, msg: "property '" + name + "' is already bound" ); |
3332 | } |
3333 | |
3334 | return create<PropertyVariable>(args&: property); |
3335 | } |
3336 | |
3337 | // Operands |
3338 | if (const NamedTypeConstraint *operand = findArg(range: op.getOperands(), name)) { |
3339 | if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { |
3340 | if (fmt.allOperands || !seenOperands.insert(V: operand).second) |
3341 | return emitError(loc, msg: "operand '" + name + "' is already bound" ); |
3342 | } else if (ctx == RefDirectiveContext && !seenOperands.count(V: operand)) { |
3343 | return emitError(loc, msg: "operand '" + name + |
3344 | "' must be bound before it is referenced" ); |
3345 | } |
3346 | return create<OperandVariable>(args&: operand); |
3347 | } |
3348 | // Regions |
3349 | if (const NamedRegion *region = findArg(range: op.getRegions(), name)) { |
3350 | if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { |
3351 | if (hasAllRegions || !seenRegions.insert(V: region).second) |
3352 | return emitError(loc, msg: "region '" + name + "' is already bound" ); |
3353 | } else if (ctx == RefDirectiveContext && !seenRegions.count(V: region)) { |
3354 | return emitError(loc, msg: "region '" + name + |
3355 | "' must be bound before it is referenced" ); |
3356 | } else { |
3357 | return emitError(loc, msg: "regions can only be used at the top level" ); |
3358 | } |
3359 | return create<RegionVariable>(args&: region); |
3360 | } |
3361 | // Results. |
3362 | if (const auto *result = findArg(range: op.getResults(), name)) { |
3363 | if (ctx != TypeDirectiveContext) |
3364 | return emitError(loc, msg: "result variables can can only be used as a child " |
3365 | "to a 'type' directive" ); |
3366 | return create<ResultVariable>(args&: result); |
3367 | } |
3368 | // Successors. |
3369 | if (const auto *successor = findArg(range: op.getSuccessors(), name)) { |
3370 | if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { |
3371 | if (hasAllSuccessors || !seenSuccessors.insert(V: successor).second) |
3372 | return emitError(loc, msg: "successor '" + name + "' is already bound" ); |
3373 | } else if (ctx == RefDirectiveContext && !seenSuccessors.count(V: successor)) { |
3374 | return emitError(loc, msg: "successor '" + name + |
3375 | "' must be bound before it is referenced" ); |
3376 | } else { |
3377 | return emitError(loc, msg: "successors can only be used at the top level" ); |
3378 | } |
3379 | |
3380 | return create<SuccessorVariable>(args&: successor); |
3381 | } |
3382 | return emitError(loc, msg: "expected variable to refer to an argument, region, " |
3383 | "result, or successor" ); |
3384 | } |
3385 | |
3386 | FailureOr<FormatElement *> |
3387 | OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, |
3388 | Context ctx) { |
3389 | switch (kind) { |
3390 | case FormatToken::kw_prop_dict: |
3391 | return parsePropDictDirective(loc, context: ctx); |
3392 | case FormatToken::kw_attr_dict: |
3393 | return parseAttrDictDirective(loc, context: ctx, |
3394 | /*withKeyword=*/false); |
3395 | case FormatToken::kw_attr_dict_w_keyword: |
3396 | return parseAttrDictDirective(loc, context: ctx, |
3397 | /*withKeyword=*/true); |
3398 | case FormatToken::kw_functional_type: |
3399 | return parseFunctionalTypeDirective(loc, context: ctx); |
3400 | case FormatToken::kw_operands: |
3401 | return parseOperandsDirective(loc, context: ctx); |
3402 | case FormatToken::kw_regions: |
3403 | return parseRegionsDirective(loc, context: ctx); |
3404 | case FormatToken::kw_results: |
3405 | return parseResultsDirective(loc, context: ctx); |
3406 | case FormatToken::kw_successors: |
3407 | return parseSuccessorsDirective(loc, context: ctx); |
3408 | case FormatToken::kw_type: |
3409 | return parseTypeDirective(loc, context: ctx); |
3410 | case FormatToken::kw_oilist: |
3411 | return parseOIListDirective(loc, context: ctx); |
3412 | |
3413 | default: |
3414 | return emitError(loc, msg: "unsupported directive kind" ); |
3415 | } |
3416 | } |
3417 | |
3418 | FailureOr<FormatElement *> |
3419 | OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context, |
3420 | bool withKeyword) { |
3421 | if (context == TypeDirectiveContext) |
3422 | return emitError(loc, msg: "'attr-dict' directive can only be used as a " |
3423 | "top-level directive" ); |
3424 | |
3425 | if (context == RefDirectiveContext) { |
3426 | if (!hasAttrDict) |
3427 | return emitError(loc, msg: "'ref' of 'attr-dict' is not bound by a prior " |
3428 | "'attr-dict' directive" ); |
3429 | |
3430 | // Otherwise, this is a top-level context. |
3431 | } else { |
3432 | if (hasAttrDict) |
3433 | return emitError(loc, msg: "'attr-dict' directive has already been seen" ); |
3434 | hasAttrDict = true; |
3435 | } |
3436 | |
3437 | return create<AttrDictDirective>(args&: withKeyword); |
3438 | } |
3439 | |
3440 | FailureOr<FormatElement *> |
3441 | OpFormatParser::parsePropDictDirective(SMLoc loc, Context context) { |
3442 | if (context == TypeDirectiveContext) |
3443 | return emitError(loc, msg: "'prop-dict' directive can only be used as a " |
3444 | "top-level directive" ); |
3445 | |
3446 | if (context == RefDirectiveContext) |
3447 | llvm::report_fatal_error(reason: "'ref' of 'prop-dict' unsupported" ); |
3448 | // Otherwise, this is a top-level context. |
3449 | |
3450 | if (hasPropDict) |
3451 | return emitError(loc, msg: "'prop-dict' directive has already been seen" ); |
3452 | hasPropDict = true; |
3453 | |
3454 | return create<PropDictDirective>(); |
3455 | } |
3456 | |
3457 | LogicalResult OpFormatParser::verifyCustomDirectiveArguments( |
3458 | SMLoc loc, ArrayRef<FormatElement *> arguments) { |
3459 | for (FormatElement *argument : arguments) { |
3460 | if (!isa<AttrDictDirective, PropDictDirective, AttributeVariable, |
3461 | OperandVariable, PropertyVariable, RefDirective, RegionVariable, |
3462 | SuccessorVariable, StringElement, TypeDirective>(Val: argument)) { |
3463 | // TODO: FormatElement should have location info attached. |
3464 | return emitError(loc, msg: "only variables and types may be used as " |
3465 | "parameters to a custom directive" ); |
3466 | } |
3467 | if (auto *type = dyn_cast<TypeDirective>(Val: argument)) { |
3468 | if (!isa<OperandVariable, ResultVariable>(Val: type->getArg())) { |
3469 | return emitError(loc, msg: "type directives within a custom directive may " |
3470 | "only refer to variables" ); |
3471 | } |
3472 | } |
3473 | } |
3474 | return success(); |
3475 | } |
3476 | |
3477 | FailureOr<FormatElement *> |
3478 | OpFormatParser::parseFunctionalTypeDirective(SMLoc loc, Context context) { |
3479 | if (context != TopLevelContext) |
3480 | return emitError( |
3481 | loc, msg: "'functional-type' is only valid as a top-level directive" ); |
3482 | |
3483 | // Parse the main operand. |
3484 | FailureOr<FormatElement *> inputs, results; |
3485 | if (failed(Result: parseToken(kind: FormatToken::l_paren, |
3486 | msg: "expected '(' before argument list" )) || |
3487 | failed(Result: inputs = parseTypeDirectiveOperand(loc)) || |
3488 | failed(Result: parseToken(kind: FormatToken::comma, |
3489 | msg: "expected ',' after inputs argument" )) || |
3490 | failed(Result: results = parseTypeDirectiveOperand(loc)) || |
3491 | failed( |
3492 | Result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list" ))) |
3493 | return failure(); |
3494 | return create<FunctionalTypeDirective>(args&: *inputs, args&: *results); |
3495 | } |
3496 | |
3497 | FailureOr<FormatElement *> |
3498 | OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) { |
3499 | if (context == RefDirectiveContext) { |
3500 | if (!fmt.allOperands) |
3501 | return emitError(loc, msg: "'ref' of 'operands' is not bound by a prior " |
3502 | "'operands' directive" ); |
3503 | |
3504 | } else if (context == TopLevelContext || context == CustomDirectiveContext) { |
3505 | if (fmt.allOperands || !seenOperands.empty()) |
3506 | return emitError(loc, msg: "'operands' directive creates overlap in format" ); |
3507 | fmt.allOperands = true; |
3508 | } |
3509 | return create<OperandsDirective>(); |
3510 | } |
3511 | |
3512 | FailureOr<FormatElement *> |
3513 | OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) { |
3514 | if (context == TypeDirectiveContext) |
3515 | return emitError(loc, msg: "'regions' is only valid as a top-level directive" ); |
3516 | if (context == RefDirectiveContext) { |
3517 | if (!hasAllRegions) |
3518 | return emitError(loc, msg: "'ref' of 'regions' is not bound by a prior " |
3519 | "'regions' directive" ); |
3520 | |
3521 | // Otherwise, this is a TopLevel directive. |
3522 | } else { |
3523 | if (hasAllRegions || !seenRegions.empty()) |
3524 | return emitError(loc, msg: "'regions' directive creates overlap in format" ); |
3525 | hasAllRegions = true; |
3526 | } |
3527 | return create<RegionsDirective>(); |
3528 | } |
3529 | |
3530 | FailureOr<FormatElement *> |
3531 | OpFormatParser::parseResultsDirective(SMLoc loc, Context context) { |
3532 | if (context != TypeDirectiveContext) |
3533 | return emitError(loc, msg: "'results' directive can can only be used as a child " |
3534 | "to a 'type' directive" ); |
3535 | return create<ResultsDirective>(); |
3536 | } |
3537 | |
3538 | FailureOr<FormatElement *> |
3539 | OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) { |
3540 | if (context == TypeDirectiveContext) |
3541 | return emitError(loc, |
3542 | msg: "'successors' is only valid as a top-level directive" ); |
3543 | if (context == RefDirectiveContext) { |
3544 | if (!hasAllSuccessors) |
3545 | return emitError(loc, msg: "'ref' of 'successors' is not bound by a prior " |
3546 | "'successors' directive" ); |
3547 | |
3548 | // Otherwise, this is a TopLevel directive. |
3549 | } else { |
3550 | if (hasAllSuccessors || !seenSuccessors.empty()) |
3551 | return emitError(loc, msg: "'successors' directive creates overlap in format" ); |
3552 | hasAllSuccessors = true; |
3553 | } |
3554 | return create<SuccessorsDirective>(); |
3555 | } |
3556 | |
3557 | FailureOr<FormatElement *> |
3558 | OpFormatParser::parseOIListDirective(SMLoc loc, Context context) { |
3559 | if (failed(Result: parseToken(kind: FormatToken::l_paren, |
3560 | msg: "expected '(' before oilist argument list" ))) |
3561 | return failure(); |
3562 | std::vector<FormatElement *> literalElements; |
3563 | std::vector<std::vector<FormatElement *>> parsingElements; |
3564 | do { |
3565 | FailureOr<FormatElement *> lelement = parseLiteral(ctx: context); |
3566 | if (failed(Result: lelement)) |
3567 | return failure(); |
3568 | literalElements.push_back(x: *lelement); |
3569 | parsingElements.emplace_back(); |
3570 | std::vector<FormatElement *> &currParsingElements = parsingElements.back(); |
3571 | while (peekToken().getKind() != FormatToken::pipe && |
3572 | peekToken().getKind() != FormatToken::r_paren) { |
3573 | FailureOr<FormatElement *> pelement = parseElement(ctx: context); |
3574 | if (failed(Result: pelement) || |
3575 | failed(Result: verifyOIListParsingElement(element: *pelement, loc))) |
3576 | return failure(); |
3577 | currParsingElements.push_back(x: *pelement); |
3578 | } |
3579 | if (peekToken().getKind() == FormatToken::pipe) { |
3580 | consumeToken(); |
3581 | continue; |
3582 | } |
3583 | if (peekToken().getKind() == FormatToken::r_paren) { |
3584 | consumeToken(); |
3585 | break; |
3586 | } |
3587 | } while (true); |
3588 | |
3589 | return create<OIListElement>(args: std::move(literalElements), |
3590 | args: std::move(parsingElements)); |
3591 | } |
3592 | |
3593 | LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element, |
3594 | SMLoc loc) { |
3595 | SmallVector<VariableElement *> vars; |
3596 | collect(element, variables&: vars); |
3597 | for (VariableElement *elem : vars) { |
3598 | LogicalResult res = |
3599 | TypeSwitch<FormatElement *, LogicalResult>(elem) |
3600 | // Only optional attributes can be within an oilist parsing group. |
3601 | .Case(caseFn: [&](AttributeVariable *attrEle) { |
3602 | if (!attrEle->getVar()->attr.isOptional() && |
3603 | !attrEle->getVar()->attr.hasDefaultValue()) |
3604 | return emitError(loc, msg: "only optional attributes can be used in " |
3605 | "an oilist parsing group" ); |
3606 | return success(); |
3607 | }) |
3608 | // Only optional properties can be within an oilist parsing group. |
3609 | .Case(caseFn: [&](PropertyVariable *propEle) { |
3610 | if (!propEle->getVar()->prop.hasDefaultValue()) |
3611 | return emitError( |
3612 | loc, |
3613 | msg: "only default-valued or optional properties can be used in " |
3614 | "an olist parsing group" ); |
3615 | return success(); |
3616 | }) |
3617 | // Only optional-like(i.e. variadic) operands can be within an |
3618 | // oilist parsing group. |
3619 | .Case(caseFn: [&](OperandVariable *ele) { |
3620 | if (!ele->getVar()->isVariableLength()) |
3621 | return emitError(loc, msg: "only variable length operands can be " |
3622 | "used within an oilist parsing group" ); |
3623 | return success(); |
3624 | }) |
3625 | // Only optional-like(i.e. variadic) results can be within an oilist |
3626 | // parsing group. |
3627 | .Case(caseFn: [&](ResultVariable *ele) { |
3628 | if (!ele->getVar()->isVariableLength()) |
3629 | return emitError(loc, msg: "only variable length results can be " |
3630 | "used within an oilist parsing group" ); |
3631 | return success(); |
3632 | }) |
3633 | .Case(caseFn: [&](RegionVariable *) { return success(); }) |
3634 | .Default(defaultFn: [&](FormatElement *) { |
3635 | return emitError(loc, |
3636 | msg: "only literals, types, and variables can be " |
3637 | "used within an oilist group" ); |
3638 | }); |
3639 | if (failed(Result: res)) |
3640 | return failure(); |
3641 | } |
3642 | return success(); |
3643 | } |
3644 | |
3645 | FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc, |
3646 | Context context) { |
3647 | if (context == TypeDirectiveContext) |
3648 | return emitError(loc, msg: "'type' cannot be used as a child of another `type`" ); |
3649 | |
3650 | bool isRefChild = context == RefDirectiveContext; |
3651 | FailureOr<FormatElement *> operand; |
3652 | if (failed(Result: parseToken(kind: FormatToken::l_paren, |
3653 | msg: "expected '(' before argument list" )) || |
3654 | failed(Result: operand = parseTypeDirectiveOperand(loc, isRefChild)) || |
3655 | failed( |
3656 | Result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list" ))) |
3657 | return failure(); |
3658 | |
3659 | return create<TypeDirective>(args&: *operand); |
3660 | } |
3661 | |
3662 | LogicalResult OpFormatParser::markQualified(SMLoc loc, FormatElement *element) { |
3663 | return TypeSwitch<FormatElement *, LogicalResult>(element) |
3664 | .Case<AttributeVariable, TypeDirective>(caseFn: [](auto *element) { |
3665 | element->setShouldBeQualified(); |
3666 | return success(); |
3667 | }) |
3668 | .Default(defaultFn: [&](auto *element) { |
3669 | return this->emitError( |
3670 | loc, |
3671 | msg: "'qualified' directive expects an attribute or a `type` directive" ); |
3672 | }); |
3673 | } |
3674 | |
3675 | FailureOr<FormatElement *> |
3676 | OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) { |
3677 | FailureOr<FormatElement *> result = parseElement(ctx: TypeDirectiveContext); |
3678 | if (failed(Result: result)) |
3679 | return failure(); |
3680 | |
3681 | FormatElement *element = *result; |
3682 | if (isa<LiteralElement>(Val: element)) |
3683 | return emitError( |
3684 | loc, msg: "'type' directive operand expects variable or directive operand" ); |
3685 | |
3686 | if (auto *var = dyn_cast<OperandVariable>(Val: element)) { |
3687 | unsigned opIdx = var->getVar() - op.operand_begin(); |
3688 | if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(Idx: opIdx))) |
3689 | return emitError(loc, msg: "'type' of '" + var->getVar()->name + |
3690 | "' is already bound" ); |
3691 | if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(Idx: opIdx))) |
3692 | return emitError(loc, msg: "'ref' of 'type($" + var->getVar()->name + |
3693 | ")' is not bound by a prior 'type' directive" ); |
3694 | seenOperandTypes.set(opIdx); |
3695 | } else if (auto *var = dyn_cast<ResultVariable>(Val: element)) { |
3696 | unsigned resIdx = var->getVar() - op.result_begin(); |
3697 | if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(Idx: resIdx))) |
3698 | return emitError(loc, msg: "'type' of '" + var->getVar()->name + |
3699 | "' is already bound" ); |
3700 | if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(Idx: resIdx))) |
3701 | return emitError(loc, msg: "'ref' of 'type($" + var->getVar()->name + |
3702 | ")' is not bound by a prior 'type' directive" ); |
3703 | seenResultTypes.set(resIdx); |
3704 | } else if (isa<OperandsDirective>(Val: &*element)) { |
3705 | if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any())) |
3706 | return emitError(loc, msg: "'operands' 'type' is already bound" ); |
3707 | if (isRefChild && !fmt.allOperandTypes) |
3708 | return emitError(loc, msg: "'ref' of 'type(operands)' is not bound by a prior " |
3709 | "'type' directive" ); |
3710 | fmt.allOperandTypes = true; |
3711 | } else if (isa<ResultsDirective>(Val: &*element)) { |
3712 | if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any())) |
3713 | return emitError(loc, msg: "'results' 'type' is already bound" ); |
3714 | if (isRefChild && !fmt.allResultTypes) |
3715 | return emitError(loc, msg: "'ref' of 'type(results)' is not bound by a prior " |
3716 | "'type' directive" ); |
3717 | fmt.allResultTypes = true; |
3718 | } else { |
3719 | return emitError(loc, msg: "invalid argument to 'type' directive" ); |
3720 | } |
3721 | return element; |
3722 | } |
3723 | |
3724 | LogicalResult OpFormatParser::verifyOptionalGroupElements( |
3725 | SMLoc loc, ArrayRef<FormatElement *> elements, FormatElement *anchor) { |
3726 | for (FormatElement *element : elements) { |
3727 | if (failed(Result: verifyOptionalGroupElement(loc, element, isAnchor: element == anchor))) |
3728 | return failure(); |
3729 | } |
3730 | return success(); |
3731 | } |
3732 | |
3733 | LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc, |
3734 | FormatElement *element, |
3735 | bool isAnchor) { |
3736 | return TypeSwitch<FormatElement *, LogicalResult>(element) |
3737 | // All attributes can be within the optional group, but only optional |
3738 | // attributes can be the anchor. |
3739 | .Case(caseFn: [&](AttributeVariable *attrEle) { |
3740 | Attribute attr = attrEle->getVar()->attr; |
3741 | if (isAnchor && !(attr.isOptional() || attr.hasDefaultValue())) |
3742 | return emitError(loc, msg: "only optional or default-valued attributes " |
3743 | "can be used to anchor an optional group" ); |
3744 | return success(); |
3745 | }) |
3746 | // All properties can be within the optional group, but only optional |
3747 | // properties can be the anchor. |
3748 | .Case(caseFn: [&](PropertyVariable *propEle) { |
3749 | Property prop = propEle->getVar()->prop; |
3750 | if (isAnchor && !(prop.hasDefaultValue() && prop.hasOptionalParser())) |
3751 | return emitError(loc, msg: "only properties with default values " |
3752 | "that can be optionally parsed " |
3753 | "can be used to anchor an optional group" ); |
3754 | return success(); |
3755 | }) |
3756 | // Only optional-like(i.e. variadic) operands can be within an optional |
3757 | // group. |
3758 | .Case(caseFn: [&](OperandVariable *ele) { |
3759 | if (!ele->getVar()->isVariableLength()) |
3760 | return emitError(loc, msg: "only variable length operands can be used " |
3761 | "within an optional group" ); |
3762 | return success(); |
3763 | }) |
3764 | // Only optional-like(i.e. variadic) results can be within an optional |
3765 | // group. |
3766 | .Case(caseFn: [&](ResultVariable *ele) { |
3767 | if (!ele->getVar()->isVariableLength()) |
3768 | return emitError(loc, msg: "only variable length results can be used " |
3769 | "within an optional group" ); |
3770 | return success(); |
3771 | }) |
3772 | .Case(caseFn: [&](RegionVariable *) { |
3773 | // TODO: When ODS has proper support for marking "optional" regions, add |
3774 | // a check here. |
3775 | return success(); |
3776 | }) |
3777 | .Case(caseFn: [&](TypeDirective *ele) { |
3778 | return verifyOptionalGroupElement(loc, element: ele->getArg(), |
3779 | /*isAnchor=*/false); |
3780 | }) |
3781 | .Case(caseFn: [&](FunctionalTypeDirective *ele) { |
3782 | if (failed(Result: verifyOptionalGroupElement(loc, element: ele->getInputs(), |
3783 | /*isAnchor=*/false))) |
3784 | return failure(); |
3785 | return verifyOptionalGroupElement(loc, element: ele->getResults(), |
3786 | /*isAnchor=*/false); |
3787 | }) |
3788 | .Case(caseFn: [&](CustomDirective *ele) { |
3789 | if (!isAnchor) |
3790 | return success(); |
3791 | // Verify each child as being valid in an optional group. They are all |
3792 | // potential anchors if the custom directive was marked as one. |
3793 | for (FormatElement *child : ele->getElements()) { |
3794 | if (isa<RefDirective>(Val: child)) |
3795 | continue; |
3796 | if (failed(Result: verifyOptionalGroupElement(loc, element: child, /*isAnchor=*/true))) |
3797 | return failure(); |
3798 | } |
3799 | return success(); |
3800 | }) |
3801 | // Literals, whitespace, and custom directives may be used, but they can't |
3802 | // anchor the group. |
3803 | .Case<LiteralElement, WhitespaceElement, OptionalElement>( |
3804 | caseFn: [&](FormatElement *) { |
3805 | if (isAnchor) |
3806 | return emitError(loc, msg: "only variables and types can be used " |
3807 | "to anchor an optional group" ); |
3808 | return success(); |
3809 | }) |
3810 | .Default(defaultFn: [&](FormatElement *) { |
3811 | return emitError(loc, msg: "only literals, types, and variables can be " |
3812 | "used within an optional group" ); |
3813 | }); |
3814 | } |
3815 | |
3816 | //===----------------------------------------------------------------------===// |
3817 | // Interface |
3818 | //===----------------------------------------------------------------------===// |
3819 | |
3820 | void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass, |
3821 | bool hasProperties) { |
3822 | // TODO: Operator doesn't expose all necessary functionality via |
3823 | // the const interface. |
3824 | Operator &op = const_cast<Operator &>(constOp); |
3825 | if (!op.hasAssemblyFormat()) |
3826 | return; |
3827 | |
3828 | // Parse the format description. |
3829 | llvm::SourceMgr mgr; |
3830 | mgr.AddNewSourceBuffer( |
3831 | F: llvm::MemoryBuffer::getMemBuffer(InputData: op.getAssemblyFormat()), IncludeLoc: SMLoc()); |
3832 | OperationFormat format(op, hasProperties); |
3833 | OpFormatParser parser(mgr, format, op); |
3834 | FailureOr<std::vector<FormatElement *>> elements = parser.parse(); |
3835 | if (failed(Result: elements)) { |
3836 | // Exit the process if format errors are treated as fatal. |
3837 | if (formatErrorIsFatal) { |
3838 | // Invoke the interrupt handlers to run the file cleanup handlers. |
3839 | llvm::sys::RunInterruptHandlers(); |
3840 | std::exit(status: 1); |
3841 | } |
3842 | return; |
3843 | } |
3844 | format.elements = std::move(*elements); |
3845 | |
3846 | // Generate the printer and parser based on the parsed format. |
3847 | format.genParser(op, opClass); |
3848 | format.genPrinter(op, opClass); |
3849 | } |
3850 | |