1 | //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "OpFormatGen.h" |
10 | #include "FormatGen.h" |
11 | #include "OpClass.h" |
12 | #include "mlir/Support/LLVM.h" |
13 | #include "mlir/TableGen/Class.h" |
14 | #include "mlir/TableGen/Format.h" |
15 | #include "mlir/TableGen/Operator.h" |
16 | #include "mlir/TableGen/Trait.h" |
17 | #include "llvm/ADT/MapVector.h" |
18 | #include "llvm/ADT/Sequence.h" |
19 | #include "llvm/ADT/SetVector.h" |
20 | #include "llvm/ADT/SmallBitVector.h" |
21 | #include "llvm/ADT/StringExtras.h" |
22 | #include "llvm/ADT/TypeSwitch.h" |
23 | #include "llvm/Support/Signals.h" |
24 | #include "llvm/Support/SourceMgr.h" |
25 | #include "llvm/TableGen/Record.h" |
26 | |
27 | #define DEBUG_TYPE "mlir-tblgen-opformatgen" |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::tblgen; |
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // VariableElement |
34 | |
35 | namespace { |
36 | /// This class represents an instance of an op variable element. A variable |
37 | /// refers to something registered on the operation itself, e.g. an operand, |
38 | /// result, attribute, region, or successor. |
39 | template <typename VarT, VariableElement::Kind VariableKind> |
40 | class OpVariableElement : public VariableElementBase<VariableKind> { |
41 | public: |
42 | using Base = OpVariableElement<VarT, VariableKind>; |
43 | |
44 | /// Create an op variable element with the variable value. |
45 | OpVariableElement(const VarT *var) : var(var) {} |
46 | |
47 | /// Get the variable. |
48 | const VarT *getVar() { return var; } |
49 | |
50 | protected: |
51 | /// The op variable, e.g. a type or attribute constraint. |
52 | const VarT *var; |
53 | }; |
54 | |
55 | /// This class represents a variable that refers to an attribute argument. |
56 | struct AttributeVariable |
57 | : public OpVariableElement<NamedAttribute, VariableElement::Attribute> { |
58 | using Base::Base; |
59 | |
60 | /// Return the constant builder call for the type of this attribute, or |
61 | /// std::nullopt if it doesn't have one. |
62 | std::optional<StringRef> getTypeBuilder() const { |
63 | std::optional<Type> attrType = var->attr.getValueType(); |
64 | return attrType ? attrType->getBuilderCall() : std::nullopt; |
65 | } |
66 | |
67 | /// Return if this attribute refers to a UnitAttr. |
68 | bool isUnitAttr() const { |
69 | return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr" ; |
70 | } |
71 | |
72 | /// Indicate if this attribute is printed "qualified" (that is it is |
73 | /// prefixed with the `#dialect.mnemonic`). |
74 | bool shouldBeQualified() { return shouldBeQualifiedFlag; } |
75 | void setShouldBeQualified(bool qualified = true) { |
76 | shouldBeQualifiedFlag = qualified; |
77 | } |
78 | |
79 | private: |
80 | bool shouldBeQualifiedFlag = false; |
81 | }; |
82 | |
83 | /// This class represents a variable that refers to an operand argument. |
84 | using OperandVariable = |
85 | OpVariableElement<NamedTypeConstraint, VariableElement::Operand>; |
86 | |
87 | /// This class represents a variable that refers to a result. |
88 | using ResultVariable = |
89 | OpVariableElement<NamedTypeConstraint, VariableElement::Result>; |
90 | |
91 | /// This class represents a variable that refers to a region. |
92 | using RegionVariable = OpVariableElement<NamedRegion, VariableElement::Region>; |
93 | |
94 | /// This class represents a variable that refers to a successor. |
95 | using SuccessorVariable = |
96 | OpVariableElement<NamedSuccessor, VariableElement::Successor>; |
97 | |
98 | /// This class represents a variable that refers to a property argument. |
99 | using PropertyVariable = |
100 | OpVariableElement<NamedProperty, VariableElement::Property>; |
101 | } // namespace |
102 | |
103 | //===----------------------------------------------------------------------===// |
104 | // DirectiveElement |
105 | |
106 | namespace { |
107 | /// This class represents the `operands` directive. This directive represents |
108 | /// all of the operands of an operation. |
109 | using OperandsDirective = DirectiveElementBase<DirectiveElement::Operands>; |
110 | |
111 | /// This class represents the `results` directive. This directive represents |
112 | /// all of the results of an operation. |
113 | using ResultsDirective = DirectiveElementBase<DirectiveElement::Results>; |
114 | |
115 | /// This class represents the `regions` directive. This directive represents |
116 | /// all of the regions of an operation. |
117 | using RegionsDirective = DirectiveElementBase<DirectiveElement::Regions>; |
118 | |
119 | /// This class represents the `successors` directive. This directive represents |
120 | /// all of the successors of an operation. |
121 | using SuccessorsDirective = DirectiveElementBase<DirectiveElement::Successors>; |
122 | |
123 | /// This class represents the `attr-dict` directive. This directive represents |
124 | /// the attribute dictionary of the operation. |
125 | class AttrDictDirective |
126 | : public DirectiveElementBase<DirectiveElement::AttrDict> { |
127 | public: |
128 | explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {} |
129 | |
130 | /// Return whether the dictionary should be printed with the 'attributes' |
131 | /// keyword. |
132 | bool isWithKeyword() const { return withKeyword; } |
133 | |
134 | private: |
135 | /// If the dictionary should be printed with the 'attributes' keyword. |
136 | bool withKeyword; |
137 | }; |
138 | |
139 | /// This class represents the `prop-dict` directive. This directive represents |
140 | /// the properties of the operation, expressed as a directionary. |
141 | class PropDictDirective |
142 | : public DirectiveElementBase<DirectiveElement::PropDict> { |
143 | public: |
144 | explicit PropDictDirective() = default; |
145 | }; |
146 | |
147 | /// This class represents the `functional-type` directive. This directive takes |
148 | /// two arguments and formats them, respectively, as the inputs and results of a |
149 | /// FunctionType. |
150 | class FunctionalTypeDirective |
151 | : public DirectiveElementBase<DirectiveElement::FunctionalType> { |
152 | public: |
153 | FunctionalTypeDirective(FormatElement *inputs, FormatElement *results) |
154 | : inputs(inputs), results(results) {} |
155 | |
156 | FormatElement *getInputs() const { return inputs; } |
157 | FormatElement *getResults() const { return results; } |
158 | |
159 | private: |
160 | /// The input and result arguments. |
161 | FormatElement *inputs, *results; |
162 | }; |
163 | |
164 | /// This class represents the `type` directive. |
165 | class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> { |
166 | public: |
167 | TypeDirective(FormatElement *arg) : arg(arg) {} |
168 | |
169 | FormatElement *getArg() const { return arg; } |
170 | |
171 | /// Indicate if this type is printed "qualified" (that is it is |
172 | /// prefixed with the `!dialect.mnemonic`). |
173 | bool shouldBeQualified() { return shouldBeQualifiedFlag; } |
174 | void setShouldBeQualified(bool qualified = true) { |
175 | shouldBeQualifiedFlag = qualified; |
176 | } |
177 | |
178 | private: |
179 | /// The argument that is used to format the directive. |
180 | FormatElement *arg; |
181 | |
182 | bool shouldBeQualifiedFlag = false; |
183 | }; |
184 | |
185 | /// This class represents a group of order-independent optional clauses. Each |
186 | /// clause starts with a literal element and has a coressponding parsing |
187 | /// element. A parsing element is a continous sequence of format elements. |
188 | /// Each clause can appear 0 or 1 time. |
189 | class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> { |
190 | public: |
191 | OIListElement(std::vector<FormatElement *> &&literalElements, |
192 | std::vector<std::vector<FormatElement *>> &&parsingElements) |
193 | : literalElements(std::move(literalElements)), |
194 | parsingElements(std::move(parsingElements)) {} |
195 | |
196 | /// Returns a range to iterate over the LiteralElements. |
197 | auto getLiteralElements() const { |
198 | function_ref<LiteralElement *(FormatElement * el)> |
199 | literalElementCastConverter = |
200 | [](FormatElement *el) { return cast<LiteralElement>(Val: el); }; |
201 | return llvm::map_range(C: literalElements, F: literalElementCastConverter); |
202 | } |
203 | |
204 | /// Returns a range to iterate over the parsing elements corresponding to the |
205 | /// clauses. |
206 | ArrayRef<std::vector<FormatElement *>> getParsingElements() const { |
207 | return parsingElements; |
208 | } |
209 | |
210 | /// Returns a range to iterate over tuples of parsing and literal elements. |
211 | auto getClauses() const { |
212 | return llvm::zip(t: getLiteralElements(), u: getParsingElements()); |
213 | } |
214 | |
215 | /// If the parsing element is a single UnitAttr element, then it returns the |
216 | /// attribute variable. Otherwise, returns nullptr. |
217 | AttributeVariable * |
218 | getUnitAttrParsingElement(ArrayRef<FormatElement *> pelement) { |
219 | if (pelement.size() == 1) { |
220 | auto *attrElem = dyn_cast<AttributeVariable>(Val: pelement[0]); |
221 | if (attrElem && attrElem->isUnitAttr()) |
222 | return attrElem; |
223 | } |
224 | return nullptr; |
225 | } |
226 | |
227 | private: |
228 | /// A vector of `LiteralElement` objects. Each element stores the keyword |
229 | /// for one case of oilist element. For example, an oilist element along with |
230 | /// the `literalElements` vector: |
231 | /// ``` |
232 | /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`] |
233 | /// literalElements = { `keyword`, `otherKeyword` } |
234 | /// ``` |
235 | std::vector<FormatElement *> literalElements; |
236 | |
237 | /// A vector of valid declarative assembly format vectors. Each object in |
238 | /// parsing elements is a vector of elements in assembly format syntax. |
239 | /// For example, an oilist element along with the parsingElements vector: |
240 | /// ``` |
241 | /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`] |
242 | /// parsingElements = { |
243 | /// { `=`, `(`, $arg0, `)` }, |
244 | /// { `<`, $arg1, `>` } |
245 | /// } |
246 | /// ``` |
247 | std::vector<std::vector<FormatElement *>> parsingElements; |
248 | }; |
249 | } // namespace |
250 | |
251 | //===----------------------------------------------------------------------===// |
252 | // OperationFormat |
253 | //===----------------------------------------------------------------------===// |
254 | |
255 | namespace { |
256 | |
257 | using ConstArgument = |
258 | llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>; |
259 | |
260 | struct OperationFormat { |
261 | /// This class represents a specific resolver for an operand or result type. |
262 | class TypeResolution { |
263 | public: |
264 | TypeResolution() = default; |
265 | |
266 | /// Get the index into the buildable types for this type, or std::nullopt. |
267 | std::optional<int> getBuilderIdx() const { return builderIdx; } |
268 | void setBuilderIdx(int idx) { builderIdx = idx; } |
269 | |
270 | /// Get the variable this type is resolved to, or nullptr. |
271 | const NamedTypeConstraint *getVariable() const { |
272 | return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(Val: resolver); |
273 | } |
274 | /// Get the attribute this type is resolved to, or nullptr. |
275 | const NamedAttribute *getAttribute() const { |
276 | return llvm::dyn_cast_if_present<const NamedAttribute *>(Val: resolver); |
277 | } |
278 | /// Get the transformer for the type of the variable, or std::nullopt. |
279 | std::optional<StringRef> getVarTransformer() const { |
280 | return variableTransformer; |
281 | } |
282 | void setResolver(ConstArgument arg, std::optional<StringRef> transformer) { |
283 | resolver = arg; |
284 | variableTransformer = transformer; |
285 | assert(getVariable() || getAttribute()); |
286 | } |
287 | |
288 | private: |
289 | /// If the type is resolved with a buildable type, this is the index into |
290 | /// 'buildableTypes' in the parent format. |
291 | std::optional<int> builderIdx; |
292 | /// If the type is resolved based upon another operand or result, this is |
293 | /// the variable or the attribute that this type is resolved to. |
294 | ConstArgument resolver; |
295 | /// If the type is resolved based upon another operand or result, this is |
296 | /// a transformer to apply to the variable when resolving. |
297 | std::optional<StringRef> variableTransformer; |
298 | }; |
299 | |
300 | /// The context in which an element is generated. |
301 | enum class GenContext { |
302 | /// The element is generated at the top-level or with the same behaviour. |
303 | Normal, |
304 | /// The element is generated inside an optional group. |
305 | Optional |
306 | }; |
307 | |
308 | OperationFormat(const Operator &op) |
309 | : useProperties(op.getDialect().usePropertiesForAttributes() && |
310 | !op.getAttributes().empty()), |
311 | opCppClassName(op.getCppClassName()) { |
312 | operandTypes.resize(new_size: op.getNumOperands(), x: TypeResolution()); |
313 | resultTypes.resize(new_size: op.getNumResults(), x: TypeResolution()); |
314 | |
315 | hasImplicitTermTrait = llvm::any_of(Range: op.getTraits(), P: [](const Trait &trait) { |
316 | return trait.getDef().isSubClassOf(Name: "SingleBlockImplicitTerminatorImpl" ); |
317 | }); |
318 | |
319 | hasSingleBlockTrait = op.getTrait(trait: "::mlir::OpTrait::SingleBlock" ); |
320 | } |
321 | |
322 | /// Generate the operation parser from this format. |
323 | void genParser(Operator &op, OpClass &opClass); |
324 | /// Generate the parser code for a specific format element. |
325 | void genElementParser(FormatElement *element, MethodBody &body, |
326 | FmtContext &attrTypeCtx, |
327 | GenContext genCtx = GenContext::Normal); |
328 | /// Generate the C++ to resolve the types of operands and results during |
329 | /// parsing. |
330 | void genParserTypeResolution(Operator &op, MethodBody &body); |
331 | /// Generate the C++ to resolve the types of the operands during parsing. |
332 | void genParserOperandTypeResolution( |
333 | Operator &op, MethodBody &body, |
334 | function_ref<void(TypeResolution &, StringRef)> emitTypeResolver); |
335 | /// Generate the C++ to resolve regions during parsing. |
336 | void genParserRegionResolution(Operator &op, MethodBody &body); |
337 | /// Generate the C++ to resolve successors during parsing. |
338 | void genParserSuccessorResolution(Operator &op, MethodBody &body); |
339 | /// Generate the C++ to handling variadic segment size traits. |
340 | void genParserVariadicSegmentResolution(Operator &op, MethodBody &body); |
341 | |
342 | /// Generate the operation printer from this format. |
343 | void genPrinter(Operator &op, OpClass &opClass); |
344 | |
345 | /// Generate the printer code for a specific format element. |
346 | void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op, |
347 | bool &shouldEmitSpace, bool &lastWasPunctuation); |
348 | |
349 | /// The various elements in this format. |
350 | std::vector<FormatElement *> elements; |
351 | |
352 | /// A flag indicating if all operand/result types were seen. If the format |
353 | /// contains these, it can not contain individual type resolvers. |
354 | bool allOperands = false, allOperandTypes = false, allResultTypes = false; |
355 | |
356 | /// A flag indicating if this operation infers its result types |
357 | bool infersResultTypes = false; |
358 | |
359 | /// A flag indicating if this operation has the SingleBlockImplicitTerminator |
360 | /// trait. |
361 | bool hasImplicitTermTrait; |
362 | |
363 | /// A flag indicating if this operation has the SingleBlock trait. |
364 | bool hasSingleBlockTrait; |
365 | |
366 | /// Indicate whether attribute are stored in properties. |
367 | bool useProperties; |
368 | |
369 | /// Indicate whether prop-dict is used in the format |
370 | bool hasPropDict; |
371 | |
372 | /// The Operation class name |
373 | StringRef opCppClassName; |
374 | |
375 | /// A map of buildable types to indices. |
376 | llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes; |
377 | |
378 | /// The index of the buildable type, if valid, for every operand and result. |
379 | std::vector<TypeResolution> operandTypes, resultTypes; |
380 | |
381 | /// The set of attributes explicitly used within the format. |
382 | llvm::SmallSetVector<const NamedAttribute *, 8> usedAttributes; |
383 | llvm::StringSet<> inferredAttributes; |
384 | |
385 | /// The set of properties explicitly used within the format. |
386 | llvm::SmallSetVector<const NamedProperty *, 8> usedProperties; |
387 | }; |
388 | } // namespace |
389 | |
390 | //===----------------------------------------------------------------------===// |
391 | // Parser Gen |
392 | |
393 | /// Returns true if we can format the given attribute as an EnumAttr in the |
394 | /// parser format. |
395 | static bool canFormatEnumAttr(const NamedAttribute *attr) { |
396 | Attribute baseAttr = attr->attr.getBaseAttr(); |
397 | const EnumAttr *enumAttr = dyn_cast<EnumAttr>(Val: &baseAttr); |
398 | if (!enumAttr) |
399 | return false; |
400 | |
401 | // The attribute must have a valid underlying type and a constant builder. |
402 | return !enumAttr->getUnderlyingType().empty() && |
403 | !enumAttr->getConstBuilderTemplate().empty(); |
404 | } |
405 | |
406 | /// Returns if we should format the given attribute as an SymbolNameAttr. |
407 | static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) { |
408 | return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr" ; |
409 | } |
410 | |
/// The code snippet used to generate a parser call for an attribute, via
/// `parseCustomAttributeWithFallback`.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
///
/// Note: `{{` is the llvm::formatv escape for a literal `{`.
const char *const attrParserCode = R"(
if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
return ::mlir::failure();
}
)";

/// The code snippet used to generate a generic parser call for an attribute,
/// via `parseAttribute`.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
if (parser.parseAttribute({0}Attr, {1}))
return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for an optional attribute.
/// If an attribute is present but fails to parse, the generated parser
/// returns failure.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
///
/// NOTE: the snippet deliberately ends with an `if (...)` that has no body;
/// presumably the caller appends the guarded statement — confirm at the use
/// site before editing.
const char *const optionalAttrParserCode = R"(
::mlir::OptionalParseResult parseResult{0}Attr =
parser.parseOptionalAttribute({0}Attr, {1});
if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
return ::mlir::failure();
if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
)";

/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
if (parser.parseSymbolName({0}Attr))
return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for an optional symbol name
/// attribute. The result of `parseOptionalSymbolName` is discarded, as the
/// comment emitted inside the snippet explains.
///
/// {0}: The name of the attribute.
const char *const optionalSymbolNameAttrParserCode = R"(
// Parsing an optional symbol name doesn't fail, so no need to check the
// result.
(void)parser.parseOptionalSymbolName({0}Attr);
)";
450 | |
/// The code snippet used to generate a parser call for an enum attribute.
///
/// The generated code first tries to parse a bare keyword from the allowed
/// set. Failing that, it tries a string attribute, and if neither is present
/// it emits the {5} error handling. A non-empty result is then symbolized;
/// on failure an "invalid ... attribute specification" error is reported at
/// the starting location.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
/// {6}: The attribute assignment expression
///
/// Note: only the `{{` on the `has_value()` line uses the llvm::formatv brace
/// escape. The other literal braces rely on how formatv scans for the next
/// placeholder, so edit this template with care and re-check the generated
/// output. NOTE(review): the emitted `'"';;` contains a harmless extra empty
/// statement.
const char *const enumAttrParserCode = R"(
{
::llvm::StringRef attrStr;
::mlir::NamedAttrList attrStorage;
auto loc = parser.getCurrentLocation();
if (parser.parseOptionalKeyword(&attrStr, {4})) {
::mlir::StringAttr attrVal;
::mlir::OptionalParseResult parseResult =
parser.parseOptionalAttribute(attrVal,
parser.getBuilder().getNoneType(),
"{0}", attrStorage);
if (parseResult.has_value()) {{
if (failed(*parseResult))
return ::mlir::failure();
attrStr = attrVal.getValue();
} else {
{5}
}
}
if (!attrStr.empty()) {
auto attrOptional = {1}::{2}(attrStr);
if (!attrOptional)
return parser.emitError(loc, "invalid ")
<< "{0} attribute specification: \"" << attrStr << '"';;

{0}Attr = {3};
{6}
}
}
)";
490 | |
/// The code snippets used to generate a parser call for an operand. Each one
/// records the current location in `{0}OperandsLoc` before parsing.
///
/// {0}: The name of the operand.
///
/// Variadic form: parses an operand list into `{0}Operands`.
const char *const variadicOperandParserCode = R"(
{0}OperandsLoc = parser.getCurrentLocation();
if (parser.parseOperandList({0}Operands))
return ::mlir::failure();
)";
/// Optional form: parses zero or one operand and appends it to `{0}Operands`
/// when present.
const char *const optionalOperandParserCode = R"(
{
{0}OperandsLoc = parser.getCurrentLocation();
::mlir::OpAsmParser::UnresolvedOperand operand;
::mlir::OptionalParseResult parseResult =
parser.parseOptionalOperand(operand);
if (parseResult.has_value()) {
if (failed(*parseResult))
return ::mlir::failure();
{0}Operands.push_back(operand);
}
}
)";
/// Single form: parses exactly one operand into `{0}RawOperand`.
const char *const operandParserCode = R"(
{0}OperandsLoc = parser.getCurrentLocation();
if (parser.parseOperand({0}RawOperand))
return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for a VariadicOfVariadic
/// operand. It parses a comma-separated sequence of parenthesized operand
/// lists into `{0}Operands` and pushes the size of each group onto
/// `{0}OperandGroupSizes`.
///
/// {0}: The name of the operand.
/// {1}: The name of segment size attribute.
///
/// NOTE(review): the snippet below references only {0}; the {1} description
/// looks stale — confirm against the caller.
const char *const variadicOfVariadicOperandParserCode = R"(
{
{0}OperandsLoc = parser.getCurrentLocation();
int32_t curSize = 0;
do {
if (parser.parseOptionalLParen())
break;
if (parser.parseOperandList({0}Operands) || parser.parseRParen())
return ::mlir::failure();
{0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
curSize = {0}Operands.size();
} while (succeeded(parser.parseOptionalComma()));
}
)";
536 | |
/// The code snippets used to generate a parser call for a type list.
///
/// {0}: The name for the type list.
///
/// VariadicOfVariadic form: parses a comma-separated sequence of
/// parenthesized type lists into `{0}Types`. An empty group `()` is accepted
/// because the list is parsed only when no `)` follows the `(` immediately.
const char *const variadicOfVariadicTypeParserCode = R"(
do {
if (parser.parseOptionalLParen())
break;
if (parser.parseOptionalRParen() &&
(parser.parseTypeList({0}Types) || parser.parseRParen()))
return ::mlir::failure();
} while (succeeded(parser.parseOptionalComma()));
)";
/// Variadic form: parses a type list into `{0}Types`.
const char *const variadicTypeParserCode = R"(
if (parser.parseTypeList({0}Types))
return ::mlir::failure();
)";
/// Optional form: parses zero or one type and appends it to `{0}Types` when
/// present.
const char *const optionalTypeParserCode = R"(
{
::mlir::Type optionalType;
::mlir::OptionalParseResult parseResult =
parser.parseOptionalType(optionalType);
if (parseResult.has_value()) {
if (failed(*parseResult))
return ::mlir::failure();
{0}Types.push_back(optionalType);
}
}
)";
/// Single (unqualified) form, via `parseCustomTypeWithFallback`.
/// {0}: The C++ type to parse into. {1}: The name; stored to `{1}RawType`.
const char *const typeParserCode = R"(
{
{0} type;
if (parser.parseCustomTypeWithFallback(type))
return ::mlir::failure();
{1}RawType = type;
}
)";
/// Single qualified form, via `parseType`. Only {1} (the name) is referenced;
/// {0} is unused here, presumably so callers can pass the same arguments as
/// for `typeParserCode` — confirm at the use site.
const char *const qualifiedTypeParserCode = R"(
if (parser.parseType({1}RawType))
return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for a functional type. It
/// parses a single FunctionType and splits it into the input and result lists.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
::mlir::FunctionType {0}__{1}_functionType;
if (parser.parseType({0}__{1}_functionType))
return ::mlir::failure();
{0}Types = {0}__{1}_functionType.getInputs();
{1}Types = {0}__{1}_functionType.getResults();
)";

/// The code snippet used to generate a parser call to infer return types. It
/// calls the op's static `inferReturnTypes` with the parsed state and adds the
/// inferred types to `result`.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
result.location, result.operands,
result.attributes.getDictionary(parser.getContext()),
result.getRawProperties(),
result.regions, inferredReturnTypes)))
return ::mlir::failure();
result.addTypes(inferredReturnTypes);
)";
603 | |
// All snippets below are `const char *const`, consistent with the other
// parser snippets in this file. This gives them internal linkage and prevents
// reassignment. (Previously they were mutable `const char *` globals with
// external linkage, which could clash with same-named symbols in other
// mlir-tblgen translation units.)

/// The code snippet used to generate a parser call for a region list. It
/// parses an optional first region and then any comma-separated trailing
/// regions, appending each to `{0}Regions`.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
{
std::unique_ptr<::mlir::Region> region;
auto firstRegionResult = parser.parseOptionalRegion(region);
if (firstRegionResult.has_value()) {
if (failed(*firstRegionResult))
return ::mlir::failure();
{0}Regions.emplace_back(std::move(region));

// Parse any trailing regions.
while (succeeded(parser.parseOptionalComma())) {
region = std::make_unique<::mlir::Region>();
if (parser.parseRegion(*region))
return ::mlir::failure();
{0}Regions.emplace_back(std::move(region));
}
}
}
)";

/// The code snippet used to ensure a list of regions have terminators.
/// (The loop variable must be `&region`; a mis-encoded `®ion` here would
/// emit uncompilable C++.)
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
for (auto &region : {0}Regions)
ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a list of regions have a block: an empty
/// region gets a block emplaced.
///
/// {0}: The name of the region list.
const char *const regionListEnsureSingleBlockParserCode = R"(
for (auto &region : {0}Regions)
if (region->empty()) region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
{
auto parseResult = parser.parseOptionalRegion(*{0}Region);
if (parseResult.has_value() && failed(*parseResult))
return ::mlir::failure();
}
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
if (parser.parseRegion(*{0}Region))
return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block.
///
/// {0}: The name of the region.
const char *const regionEnsureSingleBlockParserCode = R"(
if ({0}Region->empty()) {0}Region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for a successor list. It
/// parses an optional first successor and then any comma-separated trailing
/// successors, appending each to `{0}Successors`.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
{
::mlir::Block *succ;
auto firstSucc = parser.parseOptionalSuccessor(succ);
if (firstSucc.has_value()) {
if (failed(*firstSucc))
return ::mlir::failure();
{0}Successors.emplace_back(succ);

// Parse any trailing successors.
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseSuccessor(succ))
return ::mlir::failure();
{0}Successors.emplace_back(succ);
}
}
}
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
if (parser.parseSuccessor({0}Successor))
return ::mlir::failure();
)";

/// The code snippet used to generate a parser for OIList. It rejects a
/// second occurrence of the same clause and otherwise marks it as seen.
///
/// {0}: literal keyword corresponding to a case for oilist
const char *const oilistParserCode = R"(
if ({0}Clause) {
return parser.emitError(parser.getNameLoc())
<< "`{0}` clause can appear at most once in the expansion of the "
"oilist directive";
}
{0}Clause = true;
)";
717 | |
namespace {
/// Classifies how many values a single parse argument (an operand or result
/// variable, or an `operands`/`results` directive) may bind.
enum class ArgumentLengthKind {
  /// A list of 0..N groups, each of which holds 0..N elements.
  VariadicOfVariadic,
  /// Any number of elements, including none (0..N).
  Variadic,
  /// At most one element (0 or 1).
  Optional,
  /// Exactly one element, always present.
  Single
};
} // namespace
732 | |
733 | /// Get the length kind for the given constraint. |
734 | static ArgumentLengthKind |
735 | getArgumentLengthKind(const NamedTypeConstraint *var) { |
736 | if (var->isOptional()) |
737 | return ArgumentLengthKind::Optional; |
738 | if (var->isVariadicOfVariadic()) |
739 | return ArgumentLengthKind::VariadicOfVariadic; |
740 | if (var->isVariadic()) |
741 | return ArgumentLengthKind::Variadic; |
742 | return ArgumentLengthKind::Single; |
743 | } |
744 | |
745 | /// Get the name used for the type list for the given type directive operand. |
746 | /// 'lengthKind' to the corresponding kind for the given argument. |
747 | static StringRef getTypeListName(FormatElement *arg, |
748 | ArgumentLengthKind &lengthKind) { |
749 | if (auto *operand = dyn_cast<OperandVariable>(Val: arg)) { |
750 | lengthKind = getArgumentLengthKind(var: operand->getVar()); |
751 | return operand->getVar()->name; |
752 | } |
753 | if (auto *result = dyn_cast<ResultVariable>(Val: arg)) { |
754 | lengthKind = getArgumentLengthKind(var: result->getVar()); |
755 | return result->getVar()->name; |
756 | } |
757 | lengthKind = ArgumentLengthKind::Variadic; |
758 | if (isa<OperandsDirective>(Val: arg)) |
759 | return "allOperand" ; |
760 | if (isa<ResultsDirective>(Val: arg)) |
761 | return "allResult" ; |
762 | llvm_unreachable("unknown 'type' directive argument" ); |
763 | } |
764 | |
765 | /// Generate the parser for a literal value. |
766 | static void genLiteralParser(StringRef value, MethodBody &body) { |
767 | // Handle the case of a keyword/identifier. |
768 | if (value.front() == '_' || isalpha(value.front())) { |
769 | body << "Keyword(\"" << value << "\")" ; |
770 | return; |
771 | } |
772 | body << (StringRef)StringSwitch<StringRef>(value) |
773 | .Case(S: "->" , Value: "Arrow()" ) |
774 | .Case(S: ":" , Value: "Colon()" ) |
775 | .Case(S: "," , Value: "Comma()" ) |
776 | .Case(S: "=" , Value: "Equal()" ) |
777 | .Case(S: "<" , Value: "Less()" ) |
778 | .Case(S: ">" , Value: "Greater()" ) |
779 | .Case(S: "{" , Value: "LBrace()" ) |
780 | .Case(S: "}" , Value: "RBrace()" ) |
781 | .Case(S: "(" , Value: "LParen()" ) |
782 | .Case(S: ")" , Value: "RParen()" ) |
783 | .Case(S: "[" , Value: "LSquare()" ) |
784 | .Case(S: "]" , Value: "RSquare()" ) |
785 | .Case(S: "?" , Value: "Question()" ) |
786 | .Case(S: "+" , Value: "Plus()" ) |
787 | .Case(S: "*" , Value: "Star()" ) |
788 | .Case(S: "..." , Value: "Ellipsis()" ); |
789 | } |
790 | |
791 | /// Generate the storage code required for parsing the given element. |
792 | static void genElementParserStorage(FormatElement *element, const Operator &op, |
793 | MethodBody &body) { |
794 | if (auto *optional = dyn_cast<OptionalElement>(Val: element)) { |
795 | ArrayRef<FormatElement *> elements = optional->getThenElements(); |
796 | |
797 | // If the anchor is a unit attribute, it won't be parsed directly so elide |
798 | // it. |
799 | auto *anchor = dyn_cast<AttributeVariable>(Val: optional->getAnchor()); |
800 | FormatElement *elidedAnchorElement = nullptr; |
801 | if (anchor && anchor != elements.front() && anchor->isUnitAttr()) |
802 | elidedAnchorElement = anchor; |
803 | for (FormatElement *childElement : elements) |
804 | if (childElement != elidedAnchorElement) |
805 | genElementParserStorage(element: childElement, op, body); |
806 | for (FormatElement *childElement : optional->getElseElements()) |
807 | genElementParserStorage(element: childElement, op, body); |
808 | |
809 | } else if (auto *oilist = dyn_cast<OIListElement>(Val: element)) { |
810 | for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) { |
811 | if (!oilist->getUnitAttrParsingElement(pelement)) |
812 | for (FormatElement *element : pelement) |
813 | genElementParserStorage(element, op, body); |
814 | } |
815 | |
816 | } else if (auto *custom = dyn_cast<CustomDirective>(Val: element)) { |
817 | for (FormatElement *paramElement : custom->getArguments()) |
818 | genElementParserStorage(element: paramElement, op, body); |
819 | |
820 | } else if (isa<OperandsDirective>(Val: element)) { |
821 | body << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> " |
822 | "allOperands;\n" ; |
823 | |
824 | } else if (isa<RegionsDirective>(Val: element)) { |
825 | body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> " |
826 | "fullRegions;\n" ; |
827 | |
828 | } else if (isa<SuccessorsDirective>(Val: element)) { |
829 | body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n" ; |
830 | |
831 | } else if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) { |
832 | const NamedAttribute *var = attr->getVar(); |
833 | body << llvm::formatv(Fmt: " {0} {1}Attr;\n" , Vals: var->attr.getStorageType(), |
834 | Vals: var->name); |
835 | |
836 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) { |
837 | StringRef name = operand->getVar()->name; |
838 | if (operand->getVar()->isVariableLength()) { |
839 | body |
840 | << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> " |
841 | << name << "Operands;\n" ; |
842 | if (operand->getVar()->isVariadicOfVariadic()) { |
843 | body << " llvm::SmallVector<int32_t> " << name |
844 | << "OperandGroupSizes;\n" ; |
845 | } |
846 | } else { |
847 | body << " ::mlir::OpAsmParser::UnresolvedOperand " << name |
848 | << "RawOperand{};\n" |
849 | << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> " |
850 | << name << "Operands(&" << name << "RawOperand, 1);" ; |
851 | } |
852 | body << llvm::formatv(Fmt: " ::llvm::SMLoc {0}OperandsLoc;\n" |
853 | " (void){0}OperandsLoc;\n" , |
854 | Vals&: name); |
855 | |
856 | } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) { |
857 | StringRef name = region->getVar()->name; |
858 | if (region->getVar()->isVariadic()) { |
859 | body << llvm::formatv( |
860 | Fmt: " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> " |
861 | "{0}Regions;\n" , |
862 | Vals&: name); |
863 | } else { |
864 | body << llvm::formatv(Fmt: " std::unique_ptr<::mlir::Region> {0}Region = " |
865 | "std::make_unique<::mlir::Region>();\n" , |
866 | Vals&: name); |
867 | } |
868 | |
869 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) { |
870 | StringRef name = successor->getVar()->name; |
871 | if (successor->getVar()->isVariadic()) { |
872 | body << llvm::formatv(Fmt: " ::llvm::SmallVector<::mlir::Block *, 2> " |
873 | "{0}Successors;\n" , |
874 | Vals&: name); |
875 | } else { |
876 | body << llvm::formatv(Fmt: " ::mlir::Block *{0}Successor = nullptr;\n" , Vals&: name); |
877 | } |
878 | |
879 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) { |
880 | ArgumentLengthKind lengthKind; |
881 | StringRef name = getTypeListName(arg: dir->getArg(), lengthKind); |
882 | if (lengthKind != ArgumentLengthKind::Single) |
883 | body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n" ; |
884 | else |
885 | body |
886 | << llvm::formatv(Fmt: " ::mlir::Type {0}RawType{{};\n" , Vals&: name) |
887 | << llvm::formatv( |
888 | Fmt: " ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n" , |
889 | Vals&: name); |
890 | } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(Val: element)) { |
891 | ArgumentLengthKind ignored; |
892 | body << " ::llvm::ArrayRef<::mlir::Type> " |
893 | << getTypeListName(arg: dir->getInputs(), lengthKind&: ignored) << "Types;\n" ; |
894 | body << " ::llvm::ArrayRef<::mlir::Type> " |
895 | << getTypeListName(arg: dir->getResults(), lengthKind&: ignored) << "Types;\n" ; |
896 | } |
897 | } |
898 | |
899 | /// Generate the parser for a parameter to a custom directive. |
900 | static void genCustomParameterParser(FormatElement *param, MethodBody &body) { |
901 | if (auto *attr = dyn_cast<AttributeVariable>(Val: param)) { |
902 | body << attr->getVar()->name << "Attr" ; |
903 | } else if (isa<AttrDictDirective>(Val: param)) { |
904 | body << "result.attributes" ; |
905 | } else if (isa<PropDictDirective>(Val: param)) { |
906 | body << "result" ; |
907 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: param)) { |
908 | StringRef name = operand->getVar()->name; |
909 | ArgumentLengthKind lengthKind = getArgumentLengthKind(var: operand->getVar()); |
910 | if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) |
911 | body << llvm::formatv(Fmt: "{0}OperandGroups" , Vals&: name); |
912 | else if (lengthKind == ArgumentLengthKind::Variadic) |
913 | body << llvm::formatv(Fmt: "{0}Operands" , Vals&: name); |
914 | else if (lengthKind == ArgumentLengthKind::Optional) |
915 | body << llvm::formatv(Fmt: "{0}Operand" , Vals&: name); |
916 | else |
917 | body << formatv(Fmt: "{0}RawOperand" , Vals&: name); |
918 | |
919 | } else if (auto *region = dyn_cast<RegionVariable>(Val: param)) { |
920 | StringRef name = region->getVar()->name; |
921 | if (region->getVar()->isVariadic()) |
922 | body << llvm::formatv(Fmt: "{0}Regions" , Vals&: name); |
923 | else |
924 | body << llvm::formatv(Fmt: "*{0}Region" , Vals&: name); |
925 | |
926 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: param)) { |
927 | StringRef name = successor->getVar()->name; |
928 | if (successor->getVar()->isVariadic()) |
929 | body << llvm::formatv(Fmt: "{0}Successors" , Vals&: name); |
930 | else |
931 | body << llvm::formatv(Fmt: "{0}Successor" , Vals&: name); |
932 | |
933 | } else if (auto *dir = dyn_cast<RefDirective>(Val: param)) { |
934 | genCustomParameterParser(param: dir->getArg(), body); |
935 | |
936 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: param)) { |
937 | ArgumentLengthKind lengthKind; |
938 | StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind); |
939 | if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) |
940 | body << llvm::formatv(Fmt: "{0}TypeGroups" , Vals&: listName); |
941 | else if (lengthKind == ArgumentLengthKind::Variadic) |
942 | body << llvm::formatv(Fmt: "{0}Types" , Vals&: listName); |
943 | else if (lengthKind == ArgumentLengthKind::Optional) |
944 | body << llvm::formatv(Fmt: "{0}Type" , Vals&: listName); |
945 | else |
946 | body << formatv(Fmt: "{0}RawType" , Vals&: listName); |
947 | |
948 | } else if (auto *string = dyn_cast<StringElement>(Val: param)) { |
949 | FmtContext ctx; |
950 | ctx.withBuilder(subst: "parser.getBuilder()" ); |
951 | ctx.addSubst(placeholder: "_ctxt" , subst: "parser.getContext()" ); |
952 | body << tgfmt(fmt: string->getValue(), ctx: &ctx); |
953 | |
954 | } else if (auto *property = dyn_cast<PropertyVariable>(Val: param)) { |
955 | body << llvm::formatv(Fmt: "result.getOrAddProperties<Properties>().{0}" , |
956 | Vals: property->getVar()->name); |
957 | } else { |
958 | llvm_unreachable("unknown custom directive parameter" ); |
959 | } |
960 | } |
961 | |
962 | /// Generate the parser for a custom directive. |
963 | static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, |
964 | bool useProperties, |
965 | StringRef opCppClassName, |
966 | bool isOptional = false) { |
967 | body << " {\n" ; |
968 | |
969 | // Preprocess the directive variables. |
970 | // * Add a local variable for optional operands and types. This provides a |
971 | // better API to the user defined parser methods. |
972 | // * Set the location of operand variables. |
973 | for (FormatElement *param : dir->getArguments()) { |
974 | if (auto *operand = dyn_cast<OperandVariable>(Val: param)) { |
975 | auto *var = operand->getVar(); |
976 | body << " " << var->name |
977 | << "OperandsLoc = parser.getCurrentLocation();\n" ; |
978 | if (var->isOptional()) { |
979 | body << llvm::formatv( |
980 | Fmt: " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> " |
981 | "{0}Operand;\n" , |
982 | Vals: var->name); |
983 | } else if (var->isVariadicOfVariadic()) { |
984 | body << llvm::formatv(Fmt: " " |
985 | "::llvm::SmallVector<::llvm::SmallVector<::mlir::" |
986 | "OpAsmParser::UnresolvedOperand>> " |
987 | "{0}OperandGroups;\n" , |
988 | Vals: var->name); |
989 | } |
990 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: param)) { |
991 | ArgumentLengthKind lengthKind; |
992 | StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind); |
993 | if (lengthKind == ArgumentLengthKind::Optional) { |
994 | body << llvm::formatv(Fmt: " ::mlir::Type {0}Type;\n" , Vals&: listName); |
995 | } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { |
996 | body << llvm::formatv( |
997 | Fmt: " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> " |
998 | "{0}TypeGroups;\n" , |
999 | Vals&: listName); |
1000 | } |
1001 | } else if (auto *dir = dyn_cast<RefDirective>(Val: param)) { |
1002 | FormatElement *input = dir->getArg(); |
1003 | if (auto *operand = dyn_cast<OperandVariable>(Val: input)) { |
1004 | if (!operand->getVar()->isOptional()) |
1005 | continue; |
1006 | body << llvm::formatv( |
1007 | Fmt: " {0} {1}Operand = {1}Operands.empty() ? {0}() : " |
1008 | "{1}Operands[0];\n" , |
1009 | Vals: "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>" , |
1010 | Vals: operand->getVar()->name); |
1011 | |
1012 | } else if (auto *type = dyn_cast<TypeDirective>(Val: input)) { |
1013 | ArgumentLengthKind lengthKind; |
1014 | StringRef listName = getTypeListName(arg: type->getArg(), lengthKind); |
1015 | if (lengthKind == ArgumentLengthKind::Optional) { |
1016 | body << llvm::formatv(Fmt: " ::mlir::Type {0}Type = {0}Types.empty() ? " |
1017 | "::mlir::Type() : {0}Types[0];\n" , |
1018 | Vals&: listName); |
1019 | } |
1020 | } |
1021 | } |
1022 | } |
1023 | |
1024 | body << " auto odsResult = parse" << dir->getName() << "(parser" ; |
1025 | for (FormatElement *param : dir->getArguments()) { |
1026 | body << ", " ; |
1027 | genCustomParameterParser(param, body); |
1028 | } |
1029 | body << ");\n" ; |
1030 | |
1031 | if (isOptional) { |
1032 | body << " if (!odsResult.has_value()) return {};\n" |
1033 | << " if (::mlir::failed(*odsResult)) return ::mlir::failure();\n" ; |
1034 | } else { |
1035 | body << " if (odsResult) return ::mlir::failure();\n" ; |
1036 | } |
1037 | |
1038 | // After parsing, add handling for any of the optional constructs. |
1039 | for (FormatElement *param : dir->getArguments()) { |
1040 | if (auto *attr = dyn_cast<AttributeVariable>(Val: param)) { |
1041 | const NamedAttribute *var = attr->getVar(); |
1042 | if (var->attr.isOptional() || var->attr.hasDefaultValue()) |
1043 | body << llvm::formatv(Fmt: " if ({0}Attr)\n " , Vals: var->name); |
1044 | if (useProperties) { |
1045 | body << formatv( |
1046 | Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n" , |
1047 | Vals: var->name, Vals&: opCppClassName); |
1048 | } else { |
1049 | body << llvm::formatv(Fmt: " result.addAttribute(\"{0}\", {0}Attr);\n" , |
1050 | Vals: var->name); |
1051 | } |
1052 | |
1053 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: param)) { |
1054 | const NamedTypeConstraint *var = operand->getVar(); |
1055 | if (var->isOptional()) { |
1056 | body << llvm::formatv(Fmt: " if ({0}Operand.has_value())\n" |
1057 | " {0}Operands.push_back(*{0}Operand);\n" , |
1058 | Vals: var->name); |
1059 | } else if (var->isVariadicOfVariadic()) { |
1060 | body << llvm::formatv( |
1061 | Fmt: " for (const auto &subRange : {0}OperandGroups) {{\n" |
1062 | " {0}Operands.append(subRange.begin(), subRange.end());\n" |
1063 | " {0}OperandGroupSizes.push_back(subRange.size());\n" |
1064 | " }\n" , |
1065 | Vals: var->name, Vals: var->constraint.getVariadicOfVariadicSegmentSizeAttr()); |
1066 | } |
1067 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: param)) { |
1068 | ArgumentLengthKind lengthKind; |
1069 | StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind); |
1070 | if (lengthKind == ArgumentLengthKind::Optional) { |
1071 | body << llvm::formatv(Fmt: " if ({0}Type)\n" |
1072 | " {0}Types.push_back({0}Type);\n" , |
1073 | Vals&: listName); |
1074 | } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { |
1075 | body << llvm::formatv( |
1076 | Fmt: " for (const auto &subRange : {0}TypeGroups)\n" |
1077 | " {0}Types.append(subRange.begin(), subRange.end());\n" , |
1078 | Vals&: listName); |
1079 | } |
1080 | } |
1081 | } |
1082 | |
1083 | body << " }\n" ; |
1084 | } |
1085 | |
1086 | /// Generate the parser for a enum attribute. |
1087 | static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body, |
1088 | FmtContext &attrTypeCtx, bool parseAsOptional, |
1089 | bool useProperties, StringRef opCppClassName) { |
1090 | Attribute baseAttr = var->attr.getBaseAttr(); |
1091 | const EnumAttr &enumAttr = cast<EnumAttr>(Val&: baseAttr); |
1092 | std::vector<EnumAttrCase> cases = enumAttr.getAllCases(); |
1093 | |
1094 | // Generate the code for building an attribute for this enum. |
1095 | std::string attrBuilderStr; |
1096 | { |
1097 | llvm::raw_string_ostream os(attrBuilderStr); |
1098 | os << tgfmt(fmt: enumAttr.getConstBuilderTemplate(), ctx: &attrTypeCtx, |
1099 | vals: "*attrOptional" ); |
1100 | } |
1101 | |
1102 | // Build a string containing the cases that can be formatted as a keyword. |
1103 | std::string validCaseKeywordsStr = "{" ; |
1104 | llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr); |
1105 | for (const EnumAttrCase &attrCase : cases) |
1106 | if (canFormatStringAsKeyword(value: attrCase.getStr())) |
1107 | validCaseKeywordsOS << '"' << attrCase.getStr() << "\"," ; |
1108 | validCaseKeywordsOS.str().back() = '}'; |
1109 | |
1110 | // If the attribute is not optional, build an error message for the missing |
1111 | // attribute. |
1112 | std::string errorMessage; |
1113 | if (!parseAsOptional) { |
1114 | llvm::raw_string_ostream errorMessageOS(errorMessage); |
1115 | errorMessageOS |
1116 | << "return parser.emitError(loc, \"expected string or " |
1117 | "keyword containing one of the following enum values for attribute '" |
1118 | << var->name << "' [" ; |
1119 | llvm::interleaveComma(c: cases, os&: errorMessageOS, each_fn: [&](const auto &attrCase) { |
1120 | errorMessageOS << attrCase.getStr(); |
1121 | }); |
1122 | errorMessageOS << "]\");" ; |
1123 | } |
1124 | std::string attrAssignment; |
1125 | if (useProperties) { |
1126 | attrAssignment = |
1127 | formatv(Fmt: " " |
1128 | "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;" , |
1129 | Vals: var->name, Vals&: opCppClassName); |
1130 | } else { |
1131 | attrAssignment = |
1132 | formatv(Fmt: "result.addAttribute(\"{0}\", {0}Attr);" , Vals: var->name); |
1133 | } |
1134 | |
1135 | body << formatv(Fmt: enumAttrParserCode, Vals: var->name, Vals: enumAttr.getCppNamespace(), |
1136 | Vals: enumAttr.getStringToSymbolFnName(), Vals&: attrBuilderStr, |
1137 | Vals&: validCaseKeywordsStr, Vals&: errorMessage, Vals&: attrAssignment); |
1138 | } |
1139 | |
1140 | // Generate the parser for an attribute. |
1141 | static void genAttrParser(AttributeVariable *attr, MethodBody &body, |
1142 | FmtContext &attrTypeCtx, bool parseAsOptional, |
1143 | bool useProperties, StringRef opCppClassName) { |
1144 | const NamedAttribute *var = attr->getVar(); |
1145 | |
1146 | // Check to see if we can parse this as an enum attribute. |
1147 | if (canFormatEnumAttr(attr: var)) |
1148 | return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional, |
1149 | useProperties, opCppClassName); |
1150 | |
1151 | // Check to see if we should parse this as a symbol name attribute. |
1152 | if (shouldFormatSymbolNameAttr(attr: var)) { |
1153 | body << formatv(Fmt: parseAsOptional ? optionalSymbolNameAttrParserCode |
1154 | : symbolNameAttrParserCode, |
1155 | Vals: var->name); |
1156 | } else { |
1157 | |
1158 | // If this attribute has a buildable type, use that when parsing the |
1159 | // attribute. |
1160 | std::string attrTypeStr; |
1161 | if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) { |
1162 | llvm::raw_string_ostream os(attrTypeStr); |
1163 | os << tgfmt(fmt: *typeBuilder, ctx: &attrTypeCtx); |
1164 | } else { |
1165 | attrTypeStr = "::mlir::Type{}" ; |
1166 | } |
1167 | if (parseAsOptional) { |
1168 | body << formatv(Fmt: optionalAttrParserCode, Vals: var->name, Vals&: attrTypeStr); |
1169 | } else { |
1170 | if (attr->shouldBeQualified() || |
1171 | var->attr.getStorageType() == "::mlir::Attribute" ) |
1172 | body << formatv(Fmt: genericAttrParserCode, Vals: var->name, Vals&: attrTypeStr); |
1173 | else |
1174 | body << formatv(Fmt: attrParserCode, Vals: var->name, Vals&: attrTypeStr); |
1175 | } |
1176 | } |
1177 | if (useProperties) { |
1178 | body << formatv( |
1179 | Fmt: " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = " |
1180 | "{0}Attr;\n" , |
1181 | Vals: var->name, Vals&: opCppClassName); |
1182 | } else { |
1183 | body << formatv( |
1184 | Fmt: " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n" , |
1185 | Vals: var->name); |
1186 | } |
1187 | } |
1188 | |
// Generates the 'setPropertiesFromParsedAttr' used to set properties from a
// 'prop-dict' dictionary attr.
//
// The generated static method has this signature:
//   ::mlir::LogicalResult setPropertiesFromParsedAttr(
//       Properties &prop, ::mlir::Attribute attr,
//       ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError)
// It first checks that `attr` is a DictionaryAttr. It then fills every
// property and non-derived attribute that the format did not already parse
// elsewhere (i.e. not in `fmt.usedProperties` / `fmt.usedAttributes`).
// Nothing is generated when the format has no `prop-dict`.
static void genParsedAttrPropertiesSetter(OperationFormat &fmt, Operator &op,
                                          OpClass &opClass) {
  // Not required unless 'prop-dict' is present.
  if (!fmt.hasPropDict)
    return;

  SmallVector<MethodParameter> paramList;
  paramList.emplace_back("Properties &", "prop");
  paramList.emplace_back("::mlir::Attribute", "attr");
  paramList.emplace_back("::llvm::function_ref<::mlir::InFlightDiagnostic()>",
                         "emitError");

  Method *method = opClass.addStaticMethod("::mlir::LogicalResult",
                                           "setPropertiesFromParsedAttr",
                                           std::move(paramList));
  MethodBody &body = method->body().indent();

  // Reject anything that is not a dictionary up front.
  body << R"decl(
::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
if (!dict) {
emitError() << "expected DictionaryAttr to set properties";
return ::mlir::failure();
}
)decl";

  // Per-property template. {0} is the property's convert-from-attribute call
  // and {1} is the property name. Every such property is treated as required:
  // a missing key is an error.
  // TODO: properties might be optional as well.
  const char *propFromAttrFmt = R"decl(
auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
{0};
};
auto attr = dict.get("{1}");
if (!attr) {{
emitError() << "expected key entry for {1} in DictionaryAttr to set "
"Properties.";
return ::mlir::failure();
}
if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError)))
return ::mlir::failure();
)decl";

  // Generate the setter for any property not parsed elsewhere.
  for (const NamedProperty &namedProperty : op.getProperties()) {
    if (fmt.usedProperties.contains(&namedProperty))
      continue;

    // Each entry gets its own scope so the generated `setFromAttr` / `attr`
    // locals do not collide between properties.
    auto scope = body.scope("{\n", "}\n", /*indent=*/true);

    StringRef name = namedProperty.name;
    const Property &prop = namedProperty.prop;
    // The property's conversion snippet refers to $_attr, $_storage and
    // $_diag; bind them to the lambda parameters in the template above.
    FmtContext fctx;
    body << formatv(propFromAttrFmt,
                    tgfmt(prop.getConvertFromAttributeCall(),
                          &fctx.addSubst("_attr", "propAttr")
                               .addSubst("_storage", "propStorage")
                               .addSubst("_diag", "emitError")),
                    name);
  }

  // Generate the setter for any attribute not parsed elsewhere.
  for (const NamedAttribute &namedAttr : op.getAttributes()) {
    if (fmt.usedAttributes.contains(&namedAttr))
      continue;

    const Attribute &attr = namedAttr.attr;
    // Derived attributes do not need to be parsed.
    if (attr.isDerivedAttr())
      continue;

    // Separate generated scope per attribute, as for properties above.
    auto scope = body.scope("{\n", "}\n", /*indent=*/true);

    // If the attribute has a default value or is optional, it does not need to
    // be present in the parsed dictionary attribute.
    // The generated code dyn_casts the dictionary entry to the storage type of
    // `prop.<name>` and reports an error if that cast fails.
    bool isRequired = !attr.isOptional() && !attr.hasDefaultValue();
    body << formatv(R"decl(
auto &propStorage = prop.{0};
auto attr = dict.get("{0}");
if (attr || /*isRequired=*/{1}) {{
if (!attr) {{
emitError() << "expected key entry for {0} in DictionaryAttr to set "
"Properties.";
return ::mlir::failure();
}
auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
if (convertedAttr) {{
propStorage = convertedAttr;
} else {{
emitError() << "Invalid attribute `{0}` in property conversion: " << attr;
return ::mlir::failure();
}
}
)decl",
                    namedAttr.name, isRequired);
  }
  body << "return ::mlir::success();\n";
}
1287 | |
1288 | void OperationFormat::genParser(Operator &op, OpClass &opClass) { |
1289 | SmallVector<MethodParameter> paramList; |
1290 | paramList.emplace_back(Args: "::mlir::OpAsmParser &" , Args: "parser" ); |
1291 | paramList.emplace_back(Args: "::mlir::OperationState &" , Args: "result" ); |
1292 | |
1293 | auto *method = opClass.addStaticMethod(retType: "::mlir::ParseResult" , name: "parse" , |
1294 | args: std::move(paramList)); |
1295 | auto &body = method->body(); |
1296 | |
1297 | // Generate variables to store the operands and type within the format. This |
1298 | // allows for referencing these variables in the presence of optional |
1299 | // groupings. |
1300 | for (FormatElement *element : elements) |
1301 | genElementParserStorage(element, op, body); |
1302 | |
1303 | // A format context used when parsing attributes with buildable types. |
1304 | FmtContext attrTypeCtx; |
1305 | attrTypeCtx.withBuilder(subst: "parser.getBuilder()" ); |
1306 | |
1307 | // Generate parsers for each of the elements. |
1308 | for (FormatElement *element : elements) |
1309 | genElementParser(element, body, attrTypeCtx); |
1310 | |
1311 | // Generate the code to resolve the operand/result types and successors now |
1312 | // that they have been parsed. |
1313 | genParserRegionResolution(op, body); |
1314 | genParserSuccessorResolution(op, body); |
1315 | genParserVariadicSegmentResolution(op, body); |
1316 | genParserTypeResolution(op, body); |
1317 | |
1318 | body << " return ::mlir::success();\n" ; |
1319 | |
1320 | genParsedAttrPropertiesSetter(fmt&: *this, op, opClass); |
1321 | } |
1322 | |
1323 | void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, |
1324 | FmtContext &attrTypeCtx, |
1325 | GenContext genCtx) { |
1326 | /// Optional Group. |
1327 | if (auto *optional = dyn_cast<OptionalElement>(Val: element)) { |
1328 | auto genElementParsers = [&](FormatElement *firstElement, |
1329 | ArrayRef<FormatElement *> elements, |
1330 | bool thenGroup) { |
1331 | // If the anchor is a unit attribute, we don't need to print it. When |
1332 | // parsing, we will add this attribute if this group is present. |
1333 | FormatElement *elidedAnchorElement = nullptr; |
1334 | auto *anchorAttr = dyn_cast<AttributeVariable>(Val: optional->getAnchor()); |
1335 | if (anchorAttr && anchorAttr != firstElement && |
1336 | anchorAttr->isUnitAttr()) { |
1337 | elidedAnchorElement = anchorAttr; |
1338 | |
1339 | if (!thenGroup == optional->isInverted()) { |
1340 | // Add the anchor unit attribute to the operation state. |
1341 | if (useProperties) { |
1342 | body << formatv( |
1343 | Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = " |
1344 | "parser.getBuilder().getUnitAttr();" , |
1345 | Vals: anchorAttr->getVar()->name, Vals&: opCppClassName); |
1346 | } else { |
1347 | body << " result.addAttribute(\"" << anchorAttr->getVar()->name |
1348 | << "\", parser.getBuilder().getUnitAttr());\n" ; |
1349 | } |
1350 | } |
1351 | } |
1352 | |
1353 | // Generate the rest of the elements inside an optional group. Elements in |
1354 | // an optional group after the guard are parsed as required. |
1355 | for (FormatElement *childElement : elements) |
1356 | if (childElement != elidedAnchorElement) |
1357 | genElementParser(element: childElement, body, attrTypeCtx, |
1358 | genCtx: GenContext::Optional); |
1359 | }; |
1360 | |
1361 | ArrayRef<FormatElement *> thenElements = |
1362 | optional->getThenElements(/*parseable=*/true); |
1363 | |
1364 | // Generate a special optional parser for the first element to gate the |
1365 | // parsing of the rest of the elements. |
1366 | FormatElement *firstElement = thenElements.front(); |
1367 | if (auto *attrVar = dyn_cast<AttributeVariable>(Val: firstElement)) { |
1368 | genAttrParser(attr: attrVar, body, attrTypeCtx, /*parseAsOptional=*/true, |
1369 | useProperties, opCppClassName); |
1370 | body << " if (" << attrVar->getVar()->name << "Attr) {\n" ; |
1371 | } else if (auto *literal = dyn_cast<LiteralElement>(Val: firstElement)) { |
1372 | body << " if (::mlir::succeeded(parser.parseOptional" ; |
1373 | genLiteralParser(value: literal->getSpelling(), body); |
1374 | body << ")) {\n" ; |
1375 | } else if (auto *opVar = dyn_cast<OperandVariable>(Val: firstElement)) { |
1376 | genElementParser(element: opVar, body, attrTypeCtx); |
1377 | body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n" ; |
1378 | } else if (auto *regionVar = dyn_cast<RegionVariable>(Val: firstElement)) { |
1379 | const NamedRegion *region = regionVar->getVar(); |
1380 | if (region->isVariadic()) { |
1381 | genElementParser(element: regionVar, body, attrTypeCtx); |
1382 | body << " if (!" << region->name << "Regions.empty()) {\n" ; |
1383 | } else { |
1384 | body << llvm::formatv(Fmt: optionalRegionParserCode, Vals: region->name); |
1385 | body << " if (!" << region->name << "Region->empty()) {\n " ; |
1386 | if (hasImplicitTermTrait) |
1387 | body << llvm::formatv(Fmt: regionEnsureTerminatorParserCode, Vals: region->name); |
1388 | else if (hasSingleBlockTrait) |
1389 | body << llvm::formatv(Fmt: regionEnsureSingleBlockParserCode, |
1390 | Vals: region->name); |
1391 | } |
1392 | } else if (auto *custom = dyn_cast<CustomDirective>(Val: firstElement)) { |
1393 | body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n" ; |
1394 | genCustomDirectiveParser(dir: custom, body, useProperties, opCppClassName, |
1395 | /*isOptional=*/true); |
1396 | body << " return ::mlir::success();\n" |
1397 | << " }(); optResult.has_value() && ::mlir::failed(*optResult)) {\n" |
1398 | << " return ::mlir::failure();\n" |
1399 | << " } else if (optResult.has_value()) {\n" ; |
1400 | } |
1401 | |
1402 | genElementParsers(firstElement, thenElements.drop_front(), |
1403 | /*thenGroup=*/true); |
1404 | body << " }" ; |
1405 | |
1406 | // Generate the else elements. |
1407 | auto elseElements = optional->getElseElements(); |
1408 | if (!elseElements.empty()) { |
1409 | body << " else {\n" ; |
1410 | ArrayRef<FormatElement *> elseElements = |
1411 | optional->getElseElements(/*parseable=*/true); |
1412 | genElementParsers(elseElements.front(), elseElements, |
1413 | /*thenGroup=*/false); |
1414 | body << " }" ; |
1415 | } |
1416 | body << "\n" ; |
1417 | |
1418 | /// OIList Directive |
1419 | } else if (OIListElement *oilist = dyn_cast<OIListElement>(Val: element)) { |
1420 | for (LiteralElement *le : oilist->getLiteralElements()) |
1421 | body << " bool " << le->getSpelling() << "Clause = false;\n" ; |
1422 | |
1423 | // Generate the parsing loop |
1424 | body << " while(true) {\n" ; |
1425 | for (auto clause : oilist->getClauses()) { |
1426 | LiteralElement *lelement = std::get<0>(t&: clause); |
1427 | ArrayRef<FormatElement *> pelement = std::get<1>(t&: clause); |
1428 | body << "if (succeeded(parser.parseOptional" ; |
1429 | genLiteralParser(value: lelement->getSpelling(), body); |
1430 | body << ")) {\n" ; |
1431 | StringRef lelementName = lelement->getSpelling(); |
1432 | body << formatv(Fmt: oilistParserCode, Vals&: lelementName); |
1433 | if (AttributeVariable *unitAttrElem = |
1434 | oilist->getUnitAttrParsingElement(pelement)) { |
1435 | if (useProperties) { |
1436 | body << formatv( |
1437 | Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = " |
1438 | "parser.getBuilder().getUnitAttr();" , |
1439 | Vals: unitAttrElem->getVar()->name, Vals&: opCppClassName); |
1440 | } else { |
1441 | body << " result.addAttribute(\"" << unitAttrElem->getVar()->name |
1442 | << "\", UnitAttr::get(parser.getContext()));\n" ; |
1443 | } |
1444 | } else { |
1445 | for (FormatElement *el : pelement) |
1446 | genElementParser(element: el, body, attrTypeCtx); |
1447 | } |
1448 | body << " } else " ; |
1449 | } |
1450 | body << " {\n" ; |
1451 | body << " break;\n" ; |
1452 | body << " }\n" ; |
1453 | body << "}\n" ; |
1454 | |
1455 | /// Literals. |
1456 | } else if (LiteralElement *literal = dyn_cast<LiteralElement>(Val: element)) { |
1457 | body << " if (parser.parse" ; |
1458 | genLiteralParser(value: literal->getSpelling(), body); |
1459 | body << ")\n return ::mlir::failure();\n" ; |
1460 | |
1461 | /// Whitespaces. |
1462 | } else if (isa<WhitespaceElement>(Val: element)) { |
1463 | // Nothing to parse. |
1464 | |
1465 | /// Arguments. |
1466 | } else if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) { |
1467 | bool parseAsOptional = |
1468 | (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional()); |
1469 | genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties, |
1470 | opCppClassName); |
1471 | |
1472 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) { |
1473 | ArgumentLengthKind lengthKind = getArgumentLengthKind(var: operand->getVar()); |
1474 | StringRef name = operand->getVar()->name; |
1475 | if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) |
1476 | body << llvm::formatv( |
1477 | Fmt: variadicOfVariadicOperandParserCode, Vals&: name, |
1478 | Vals: operand->getVar()->constraint.getVariadicOfVariadicSegmentSizeAttr()); |
1479 | else if (lengthKind == ArgumentLengthKind::Variadic) |
1480 | body << llvm::formatv(Fmt: variadicOperandParserCode, Vals&: name); |
1481 | else if (lengthKind == ArgumentLengthKind::Optional) |
1482 | body << llvm::formatv(Fmt: optionalOperandParserCode, Vals&: name); |
1483 | else |
1484 | body << formatv(Fmt: operandParserCode, Vals&: name); |
1485 | |
1486 | } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) { |
1487 | bool isVariadic = region->getVar()->isVariadic(); |
1488 | body << llvm::formatv(Fmt: isVariadic ? regionListParserCode : regionParserCode, |
1489 | Vals: region->getVar()->name); |
1490 | if (hasImplicitTermTrait) |
1491 | body << llvm::formatv(Fmt: isVariadic ? regionListEnsureTerminatorParserCode |
1492 | : regionEnsureTerminatorParserCode, |
1493 | Vals: region->getVar()->name); |
1494 | else if (hasSingleBlockTrait) |
1495 | body << llvm::formatv(Fmt: isVariadic ? regionListEnsureSingleBlockParserCode |
1496 | : regionEnsureSingleBlockParserCode, |
1497 | Vals: region->getVar()->name); |
1498 | |
1499 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) { |
1500 | bool isVariadic = successor->getVar()->isVariadic(); |
1501 | body << formatv(Fmt: isVariadic ? successorListParserCode : successorParserCode, |
1502 | Vals: successor->getVar()->name); |
1503 | |
1504 | /// Directives. |
1505 | } else if (auto *attrDict = dyn_cast<AttrDictDirective>(Val: element)) { |
1506 | body.indent() << "{\n" ; |
1507 | body.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n" |
1508 | << "if (parser.parseOptionalAttrDict" |
1509 | << (attrDict->isWithKeyword() ? "WithKeyword" : "" ) |
1510 | << "(result.attributes))\n" |
1511 | << " return ::mlir::failure();\n" ; |
1512 | if (useProperties) { |
1513 | body << "if (failed(verifyInherentAttrs(result.name, result.attributes, " |
1514 | "[&]() {\n" |
1515 | << " return parser.emitError(loc) << \"'\" << " |
1516 | "result.name.getStringRef() << \"' op \";\n" |
1517 | << " })))\n" |
1518 | << " return ::mlir::failure();\n" ; |
1519 | } |
1520 | body.unindent() << "}\n" ; |
1521 | body.unindent(); |
1522 | } else if (isa<PropDictDirective>(Val: element)) { |
1523 | body << " if (parseProperties(parser, result))\n" |
1524 | << " return ::mlir::failure();\n" ; |
1525 | } else if (auto *customDir = dyn_cast<CustomDirective>(Val: element)) { |
1526 | genCustomDirectiveParser(dir: customDir, body, useProperties, opCppClassName); |
1527 | } else if (isa<OperandsDirective>(Val: element)) { |
1528 | body << " [[maybe_unused]] ::llvm::SMLoc allOperandLoc =" |
1529 | << " parser.getCurrentLocation();\n" |
1530 | << " if (parser.parseOperandList(allOperands))\n" |
1531 | << " return ::mlir::failure();\n" ; |
1532 | |
1533 | } else if (isa<RegionsDirective>(Val: element)) { |
1534 | body << llvm::formatv(Fmt: regionListParserCode, Vals: "full" ); |
1535 | if (hasImplicitTermTrait) |
1536 | body << llvm::formatv(Fmt: regionListEnsureTerminatorParserCode, Vals: "full" ); |
1537 | else if (hasSingleBlockTrait) |
1538 | body << llvm::formatv(Fmt: regionListEnsureSingleBlockParserCode, Vals: "full" ); |
1539 | |
1540 | } else if (isa<SuccessorsDirective>(Val: element)) { |
1541 | body << llvm::formatv(Fmt: successorListParserCode, Vals: "full" ); |
1542 | |
1543 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) { |
1544 | ArgumentLengthKind lengthKind; |
1545 | StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind); |
1546 | if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { |
1547 | body << llvm::formatv(Fmt: variadicOfVariadicTypeParserCode, Vals&: listName); |
1548 | } else if (lengthKind == ArgumentLengthKind::Variadic) { |
1549 | body << llvm::formatv(Fmt: variadicTypeParserCode, Vals&: listName); |
1550 | } else if (lengthKind == ArgumentLengthKind::Optional) { |
1551 | body << llvm::formatv(Fmt: optionalTypeParserCode, Vals&: listName); |
1552 | } else { |
1553 | const char *parserCode = |
1554 | dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode; |
1555 | TypeSwitch<FormatElement *>(dir->getArg()) |
1556 | .Case<OperandVariable, ResultVariable>(caseFn: [&](auto operand) { |
1557 | body << formatv(parserCode, |
1558 | operand->getVar()->constraint.getCPPClassName(), |
1559 | listName); |
1560 | }) |
1561 | .Default(defaultFn: [&](auto operand) { |
1562 | body << formatv(Fmt: parserCode, Vals: "::mlir::Type" , Vals&: listName); |
1563 | }); |
1564 | } |
1565 | } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(Val: element)) { |
1566 | ArgumentLengthKind ignored; |
1567 | body << formatv(Fmt: functionalTypeParserCode, |
1568 | Vals: getTypeListName(arg: dir->getInputs(), lengthKind&: ignored), |
1569 | Vals: getTypeListName(arg: dir->getResults(), lengthKind&: ignored)); |
1570 | } else { |
1571 | llvm_unreachable("unknown format element" ); |
1572 | } |
1573 | } |
1574 | |
1575 | void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) { |
1576 | // If any of type resolutions use transformed variables, make sure that the |
1577 | // types of those variables are resolved. |
1578 | SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables; |
1579 | FmtContext verifierFCtx; |
1580 | for (TypeResolution &resolver : |
1581 | llvm::concat<TypeResolution>(Ranges&: resultTypes, Ranges&: operandTypes)) { |
1582 | std::optional<StringRef> transformer = resolver.getVarTransformer(); |
1583 | if (!transformer) |
1584 | continue; |
1585 | // Ensure that we don't verify the same variables twice. |
1586 | const NamedTypeConstraint *variable = resolver.getVariable(); |
1587 | if (!variable || !verifiedVariables.insert(Ptr: variable).second) |
1588 | continue; |
1589 | |
1590 | auto constraint = variable->constraint; |
1591 | body << " for (::mlir::Type type : " << variable->name << "Types) {\n" |
1592 | << " (void)type;\n" |
1593 | << " if (!(" |
1594 | << tgfmt(fmt: constraint.getConditionTemplate(), |
1595 | ctx: &verifierFCtx.withSelf(subst: "type" )) |
1596 | << ")) {\n" |
1597 | << formatv(Fmt: " return parser.emitError(parser.getNameLoc()) << " |
1598 | "\"'{0}' must be {1}, but got \" << type;\n" , |
1599 | Vals: variable->name, Vals: constraint.getSummary()) |
1600 | << " }\n" |
1601 | << " }\n" ; |
1602 | } |
1603 | |
1604 | // Initialize the set of buildable types. |
1605 | if (!buildableTypes.empty()) { |
1606 | FmtContext typeBuilderCtx; |
1607 | typeBuilderCtx.withBuilder(subst: "parser.getBuilder()" ); |
1608 | for (auto &it : buildableTypes) |
1609 | body << " ::mlir::Type odsBuildableType" << it.second << " = " |
1610 | << tgfmt(fmt: it.first, ctx: &typeBuilderCtx) << ";\n" ; |
1611 | } |
1612 | |
1613 | // Emit the code necessary for a type resolver. |
1614 | auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) { |
1615 | if (std::optional<int> val = resolver.getBuilderIdx()) { |
1616 | body << "odsBuildableType" << *val; |
1617 | } else if (const NamedTypeConstraint *var = resolver.getVariable()) { |
1618 | if (std::optional<StringRef> tform = resolver.getVarTransformer()) { |
1619 | FmtContext fmtContext; |
1620 | fmtContext.addSubst(placeholder: "_ctxt" , subst: "parser.getContext()" ); |
1621 | if (var->isVariadic()) |
1622 | fmtContext.withSelf(subst: var->name + "Types" ); |
1623 | else |
1624 | fmtContext.withSelf(subst: var->name + "Types[0]" ); |
1625 | body << tgfmt(fmt: *tform, ctx: &fmtContext); |
1626 | } else { |
1627 | body << var->name << "Types" ; |
1628 | if (!var->isVariadic()) |
1629 | body << "[0]" ; |
1630 | } |
1631 | } else if (const NamedAttribute *attr = resolver.getAttribute()) { |
1632 | if (std::optional<StringRef> tform = resolver.getVarTransformer()) |
1633 | body << tgfmt(fmt: *tform, |
1634 | ctx: &FmtContext().withSelf(subst: attr->name + "Attr.getType()" )); |
1635 | else |
1636 | body << attr->name << "Attr.getType()" ; |
1637 | } else { |
1638 | body << curVar << "Types" ; |
1639 | } |
1640 | }; |
1641 | |
1642 | // Resolve each of the result types. |
1643 | if (!infersResultTypes) { |
1644 | if (allResultTypes) { |
1645 | body << " result.addTypes(allResultTypes);\n" ; |
1646 | } else { |
1647 | for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { |
1648 | body << " result.addTypes(" ; |
1649 | emitTypeResolver(resultTypes[i], op.getResultName(index: i)); |
1650 | body << ");\n" ; |
1651 | } |
1652 | } |
1653 | } |
1654 | |
1655 | // Emit the operand type resolutions. |
1656 | genParserOperandTypeResolution(op, body, emitTypeResolver); |
1657 | |
1658 | // Handle return type inference once all operands have been resolved |
1659 | if (infersResultTypes) |
1660 | body << formatv(Fmt: inferReturnTypesParserCode, Vals: op.getCppClassName()); |
1661 | } |
1662 | |
1663 | void OperationFormat::genParserOperandTypeResolution( |
1664 | Operator &op, MethodBody &body, |
1665 | function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) { |
1666 | // Early exit if there are no operands. |
1667 | if (op.getNumOperands() == 0) |
1668 | return; |
1669 | |
1670 | // Handle the case where all operand types are grouped together with |
1671 | // "types(operands)". |
1672 | if (allOperandTypes) { |
1673 | // If `operands` was specified, use the full operand list directly. |
1674 | if (allOperands) { |
1675 | body << " if (parser.resolveOperands(allOperands, allOperandTypes, " |
1676 | "allOperandLoc, result.operands))\n" |
1677 | " return ::mlir::failure();\n" ; |
1678 | return; |
1679 | } |
1680 | |
1681 | // Otherwise, use llvm::concat to merge the disjoint operand lists together. |
1682 | // llvm::concat does not allow the case of a single range, so guard it here. |
1683 | body << " if (parser.resolveOperands(" ; |
1684 | if (op.getNumOperands() > 1) { |
1685 | body << "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(" ; |
1686 | llvm::interleaveComma(c: op.getOperands(), os&: body, each_fn: [&](auto &operand) { |
1687 | body << operand.name << "Operands" ; |
1688 | }); |
1689 | body << ")" ; |
1690 | } else { |
1691 | body << op.operand_begin()->name << "Operands" ; |
1692 | } |
1693 | body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n" |
1694 | << " return ::mlir::failure();\n" ; |
1695 | return; |
1696 | } |
1697 | |
1698 | // Handle the case where all operands are grouped together with "operands". |
1699 | if (allOperands) { |
1700 | body << " if (parser.resolveOperands(allOperands, " ; |
1701 | |
1702 | // Group all of the operand types together to perform the resolution all at |
1703 | // once. Use llvm::concat to perform the merge. llvm::concat does not allow |
1704 | // the case of a single range, so guard it here. |
1705 | if (op.getNumOperands() > 1) { |
1706 | body << "::llvm::concat<const ::mlir::Type>(" ; |
1707 | llvm::interleaveComma( |
1708 | c: llvm::seq<int>(Begin: 0, End: op.getNumOperands()), os&: body, each_fn: [&](int i) { |
1709 | body << "::llvm::ArrayRef<::mlir::Type>(" ; |
1710 | emitTypeResolver(operandTypes[i], op.getOperand(index: i).name); |
1711 | body << ")" ; |
1712 | }); |
1713 | body << ")" ; |
1714 | } else { |
1715 | emitTypeResolver(operandTypes.front(), op.getOperand(index: 0).name); |
1716 | } |
1717 | |
1718 | body << ", allOperandLoc, result.operands))\n return " |
1719 | "::mlir::failure();\n" ; |
1720 | return; |
1721 | } |
1722 | |
1723 | // The final case is the one where each of the operands types are resolved |
1724 | // separately. |
1725 | for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { |
1726 | NamedTypeConstraint &operand = op.getOperand(index: i); |
1727 | body << " if (parser.resolveOperands(" << operand.name << "Operands, " ; |
1728 | |
1729 | // Resolve the type of this operand. |
1730 | TypeResolution &operandType = operandTypes[i]; |
1731 | emitTypeResolver(operandType, operand.name); |
1732 | |
1733 | body << ", " << operand.name |
1734 | << "OperandsLoc, result.operands))\n return ::mlir::failure();\n" ; |
1735 | } |
1736 | } |
1737 | |
1738 | void OperationFormat::genParserRegionResolution(Operator &op, |
1739 | MethodBody &body) { |
1740 | // Check for the case where all regions were parsed. |
1741 | bool hasAllRegions = llvm::any_of( |
1742 | Range&: elements, P: [](FormatElement *elt) { return isa<RegionsDirective>(Val: elt); }); |
1743 | if (hasAllRegions) { |
1744 | body << " result.addRegions(fullRegions);\n" ; |
1745 | return; |
1746 | } |
1747 | |
1748 | // Otherwise, handle each region individually. |
1749 | for (const NamedRegion ®ion : op.getRegions()) { |
1750 | if (region.isVariadic()) |
1751 | body << " result.addRegions(" << region.name << "Regions);\n" ; |
1752 | else |
1753 | body << " result.addRegion(std::move(" << region.name << "Region));\n" ; |
1754 | } |
1755 | } |
1756 | |
1757 | void OperationFormat::genParserSuccessorResolution(Operator &op, |
1758 | MethodBody &body) { |
1759 | // Check for the case where all successors were parsed. |
1760 | bool hasAllSuccessors = llvm::any_of(Range&: elements, P: [](FormatElement *elt) { |
1761 | return isa<SuccessorsDirective>(Val: elt); |
1762 | }); |
1763 | if (hasAllSuccessors) { |
1764 | body << " result.addSuccessors(fullSuccessors);\n" ; |
1765 | return; |
1766 | } |
1767 | |
1768 | // Otherwise, handle each successor individually. |
1769 | for (const NamedSuccessor &successor : op.getSuccessors()) { |
1770 | if (successor.isVariadic()) |
1771 | body << " result.addSuccessors(" << successor.name << "Successors);\n" ; |
1772 | else |
1773 | body << " result.addSuccessors(" << successor.name << "Successor);\n" ; |
1774 | } |
1775 | } |
1776 | |
1777 | void OperationFormat::genParserVariadicSegmentResolution(Operator &op, |
1778 | MethodBody &body) { |
1779 | if (!allOperands) { |
1780 | if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments" )) { |
1781 | auto interleaveFn = [&](const NamedTypeConstraint &operand) { |
1782 | // If the operand is variadic emit the parsed size. |
1783 | if (operand.isVariableLength()) |
1784 | body << "static_cast<int32_t>(" << operand.name << "Operands.size())" ; |
1785 | else |
1786 | body << "1" ; |
1787 | }; |
1788 | if (op.getDialect().usePropertiesForAttributes()) { |
1789 | body << "::llvm::copy(::llvm::ArrayRef<int32_t>({" ; |
1790 | llvm::interleaveComma(c: op.getOperands(), os&: body, each_fn: interleaveFn); |
1791 | body << formatv(Fmt: "}), " |
1792 | "result.getOrAddProperties<{0}::Properties>()." |
1793 | "operandSegmentSizes.begin());\n" , |
1794 | Vals: op.getCppClassName()); |
1795 | } else { |
1796 | body << " result.addAttribute(\"operandSegmentSizes\", " |
1797 | << "parser.getBuilder().getDenseI32ArrayAttr({" ; |
1798 | llvm::interleaveComma(c: op.getOperands(), os&: body, each_fn: interleaveFn); |
1799 | body << "}));\n" ; |
1800 | } |
1801 | } |
1802 | for (const NamedTypeConstraint &operand : op.getOperands()) { |
1803 | if (!operand.isVariadicOfVariadic()) |
1804 | continue; |
1805 | if (op.getDialect().usePropertiesForAttributes()) { |
1806 | body << llvm::formatv( |
1807 | Fmt: " result.getOrAddProperties<{0}::Properties>().{1} = " |
1808 | "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n" , |
1809 | Vals: op.getCppClassName(), |
1810 | Vals: operand.constraint.getVariadicOfVariadicSegmentSizeAttr(), |
1811 | Vals: operand.name); |
1812 | } else { |
1813 | body << llvm::formatv( |
1814 | Fmt: " result.addAttribute(\"{0}\", " |
1815 | "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));" |
1816 | "\n" , |
1817 | Vals: operand.constraint.getVariadicOfVariadicSegmentSizeAttr(), |
1818 | Vals: operand.name); |
1819 | } |
1820 | } |
1821 | } |
1822 | |
1823 | if (!allResultTypes && |
1824 | op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments" )) { |
1825 | auto interleaveFn = [&](const NamedTypeConstraint &result) { |
1826 | // If the result is variadic emit the parsed size. |
1827 | if (result.isVariableLength()) |
1828 | body << "static_cast<int32_t>(" << result.name << "Types.size())" ; |
1829 | else |
1830 | body << "1" ; |
1831 | }; |
1832 | if (op.getDialect().usePropertiesForAttributes()) { |
1833 | body << "::llvm::copy(::llvm::ArrayRef<int32_t>({" ; |
1834 | llvm::interleaveComma(c: op.getResults(), os&: body, each_fn: interleaveFn); |
1835 | body << formatv(Fmt: "}), " |
1836 | "result.getOrAddProperties<{0}::Properties>()." |
1837 | "resultSegmentSizes.begin());\n" , |
1838 | Vals: op.getCppClassName()); |
1839 | } else { |
1840 | body << " result.addAttribute(\"resultSegmentSizes\", " |
1841 | << "parser.getBuilder().getDenseI32ArrayAttr({" ; |
1842 | llvm::interleaveComma(c: op.getResults(), os&: body, each_fn: interleaveFn); |
1843 | body << "}));\n" ; |
1844 | } |
1845 | } |
1846 | } |
1847 | |
1848 | //===----------------------------------------------------------------------===// |
1849 | // PrinterGen |
1850 | |
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait.
///
/// The generated code always prints the entry block arguments. It prints
/// block terminators only when the region is empty, or when the terminator of
/// the region's first block has a non-empty attribute dictionary, operands, or
/// results. A bare terminator is therefore omitted from the output.
///
/// Note: `{{` is the formatv escape for a literal `{`.
///
/// {0}: The name of the region.
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
{
bool printTerminator = true;
if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
printTerminator = !term->getAttrDictionary().empty() ||
term->getNumOperands() != 0 ||
term->getNumResults() != 0;
}
_odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/printTerminator);
}
)" ;
1867 | |
/// The code snippet used to generate a printer call for an enum that has cases
/// that can't be represented with a keyword.
///
/// It opens a C++ scope, which the caller must close. Inside that scope it
/// stores the enum value in `caseValue` and its string form in
/// `caseValueStr`.
///
/// {0}: The C++ expression naming the enum attribute getter (genEnumAttrPrinter
///      prefixes it with `*` for optional attributes); it is invoked as `{0}()`.
/// {1}: The name of the enum attributes symbolToString function.
const char *enumAttrBeginPrinterCode = R"(
{
auto caseValue = {0}();
auto caseValueStr = {1}(caseValue);
)" ;
1878 | |
1879 | /// Generate the printer for the 'prop-dict' directive. |
1880 | static void genPropDictPrinter(OperationFormat &fmt, Operator &op, |
1881 | MethodBody &body) { |
1882 | body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedProps;\n" ; |
1883 | for (const NamedProperty *namedProperty : fmt.usedProperties) |
1884 | body << " elidedProps.push_back(\"" << namedProperty->name << "\");\n" ; |
1885 | for (const NamedAttribute *namedAttr : fmt.usedAttributes) |
1886 | body << " elidedProps.push_back(\"" << namedAttr->name << "\");\n" ; |
1887 | |
1888 | // Add code to check attributes for equality with the default value |
1889 | // for attributes with the elidePrintingDefaultValue bit set. |
1890 | for (const NamedAttribute &namedAttr : op.getAttributes()) { |
1891 | const Attribute &attr = namedAttr.attr; |
1892 | if (!attr.isDerivedAttr() && attr.hasDefaultValue()) { |
1893 | const StringRef &name = namedAttr.name; |
1894 | FmtContext fctx; |
1895 | fctx.withBuilder(subst: "odsBuilder" ); |
1896 | std::string defaultValue = std::string( |
1897 | tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, vals: attr.getDefaultValue())); |
1898 | body << " {\n" ; |
1899 | body << " ::mlir::Builder odsBuilder(getContext());\n" ; |
1900 | body << " ::mlir::Attribute attr = " << op.getGetterName(name) |
1901 | << "Attr();\n" ; |
1902 | body << " if(attr && (attr == " << defaultValue << "))\n" ; |
1903 | body << " elidedProps.push_back(\"" << name << "\");\n" ; |
1904 | body << " }\n" ; |
1905 | } |
1906 | } |
1907 | |
1908 | body << " _odsPrinter << \" \";\n" |
1909 | << " printProperties(this->getContext(), _odsPrinter, " |
1910 | "getProperties(), elidedProps);\n" ; |
1911 | } |
1912 | |
1913 | /// Generate the printer for the 'attr-dict' directive. |
1914 | static void genAttrDictPrinter(OperationFormat &fmt, Operator &op, |
1915 | MethodBody &body, bool withKeyword) { |
1916 | body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;\n" ; |
1917 | // Elide the variadic segment size attributes if necessary. |
1918 | if (!fmt.allOperands && |
1919 | op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments" )) |
1920 | body << " elidedAttrs.push_back(\"operandSegmentSizes\");\n" ; |
1921 | if (!fmt.allResultTypes && |
1922 | op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments" )) |
1923 | body << " elidedAttrs.push_back(\"resultSegmentSizes\");\n" ; |
1924 | for (const StringRef key : fmt.inferredAttributes.keys()) |
1925 | body << " elidedAttrs.push_back(\"" << key << "\");\n" ; |
1926 | for (const NamedAttribute *attr : fmt.usedAttributes) |
1927 | body << " elidedAttrs.push_back(\"" << attr->name << "\");\n" ; |
1928 | // Add code to check attributes for equality with the default value |
1929 | // for attributes with the elidePrintingDefaultValue bit set. |
1930 | for (const NamedAttribute &namedAttr : op.getAttributes()) { |
1931 | const Attribute &attr = namedAttr.attr; |
1932 | if (!attr.isDerivedAttr() && attr.hasDefaultValue()) { |
1933 | const StringRef &name = namedAttr.name; |
1934 | FmtContext fctx; |
1935 | fctx.withBuilder(subst: "odsBuilder" ); |
1936 | std::string defaultValue = std::string( |
1937 | tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, vals: attr.getDefaultValue())); |
1938 | body << " {\n" ; |
1939 | body << " ::mlir::Builder odsBuilder(getContext());\n" ; |
1940 | body << " ::mlir::Attribute attr = " << op.getGetterName(name) |
1941 | << "Attr();\n" ; |
1942 | body << " if(attr && (attr == " << defaultValue << "))\n" ; |
1943 | body << " elidedAttrs.push_back(\"" << name << "\");\n" ; |
1944 | body << " }\n" ; |
1945 | } |
1946 | } |
1947 | if (fmt.hasPropDict) |
1948 | body << " _odsPrinter.printOptionalAttrDict" |
1949 | << (withKeyword ? "WithKeyword" : "" ) |
1950 | << "(llvm::to_vector((*this)->getDiscardableAttrs()), elidedAttrs);\n" ; |
1951 | else |
1952 | body << " _odsPrinter.printOptionalAttrDict" |
1953 | << (withKeyword ? "WithKeyword" : "" ) |
1954 | << "((*this)->getAttrs(), elidedAttrs);\n" ; |
1955 | } |
1956 | |
1957 | /// Generate the printer for a literal value. `shouldEmitSpace` is true if a |
1958 | /// space should be emitted before this element. `lastWasPunctuation` is true if |
1959 | /// the previous element was a punctuation literal. |
1960 | static void genLiteralPrinter(StringRef value, MethodBody &body, |
1961 | bool &shouldEmitSpace, bool &lastWasPunctuation) { |
1962 | body << " _odsPrinter" ; |
1963 | |
1964 | // Don't insert a space for certain punctuation. |
1965 | if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation)) |
1966 | body << " << ' '" ; |
1967 | body << " << \"" << value << "\";\n" ; |
1968 | |
1969 | // Insert a space after certain literals. |
1970 | shouldEmitSpace = |
1971 | value.size() != 1 || !StringRef("<({[" ).contains(C: value.front()); |
1972 | lastWasPunctuation = value.front() != '_' && !isalpha(value.front()); |
1973 | } |
1974 | |
1975 | /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation` |
1976 | /// are set to false. |
1977 | static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace, |
1978 | bool &lastWasPunctuation) { |
1979 | if (value) { |
1980 | body << " _odsPrinter << ' ';\n" ; |
1981 | lastWasPunctuation = false; |
1982 | } else { |
1983 | lastWasPunctuation = true; |
1984 | } |
1985 | shouldEmitSpace = false; |
1986 | } |
1987 | |
1988 | /// Generate the printer for a custom directive parameter. |
1989 | static void genCustomDirectiveParameterPrinter(FormatElement *element, |
1990 | const Operator &op, |
1991 | MethodBody &body) { |
1992 | if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) { |
1993 | body << op.getGetterName(name: attr->getVar()->name) << "Attr()" ; |
1994 | |
1995 | } else if (isa<AttrDictDirective>(Val: element)) { |
1996 | body << "getOperation()->getAttrDictionary()" ; |
1997 | |
1998 | } else if (isa<PropDictDirective>(Val: element)) { |
1999 | body << "getProperties()" ; |
2000 | |
2001 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) { |
2002 | body << op.getGetterName(name: operand->getVar()->name) << "()" ; |
2003 | |
2004 | } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) { |
2005 | body << op.getGetterName(name: region->getVar()->name) << "()" ; |
2006 | |
2007 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) { |
2008 | body << op.getGetterName(name: successor->getVar()->name) << "()" ; |
2009 | |
2010 | } else if (auto *dir = dyn_cast<RefDirective>(Val: element)) { |
2011 | genCustomDirectiveParameterPrinter(element: dir->getArg(), op, body); |
2012 | |
2013 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) { |
2014 | auto *typeOperand = dir->getArg(); |
2015 | auto *operand = dyn_cast<OperandVariable>(Val: typeOperand); |
2016 | auto *var = operand ? operand->getVar() |
2017 | : cast<ResultVariable>(Val: typeOperand)->getVar(); |
2018 | std::string name = op.getGetterName(name: var->name); |
2019 | if (var->isVariadic()) |
2020 | body << name << "().getTypes()" ; |
2021 | else if (var->isOptional()) |
2022 | body << llvm::formatv(Fmt: "({0}() ? {0}().getType() : ::mlir::Type())" , Vals&: name); |
2023 | else |
2024 | body << name << "().getType()" ; |
2025 | |
2026 | } else if (auto *string = dyn_cast<StringElement>(Val: element)) { |
2027 | FmtContext ctx; |
2028 | ctx.withBuilder(subst: "::mlir::Builder(getContext())" ); |
2029 | ctx.addSubst(placeholder: "_ctxt" , subst: "getContext()" ); |
2030 | body << tgfmt(fmt: string->getValue(), ctx: &ctx); |
2031 | |
2032 | } else if (auto *property = dyn_cast<PropertyVariable>(Val: element)) { |
2033 | FmtContext ctx; |
2034 | ctx.addSubst(placeholder: "_ctxt" , subst: "getContext()" ); |
2035 | const NamedProperty *namedProperty = property->getVar(); |
2036 | ctx.addSubst(placeholder: "_storage" , subst: "getProperties()." + namedProperty->name); |
2037 | body << tgfmt(fmt: namedProperty->prop.getConvertFromStorageCall(), ctx: &ctx); |
2038 | } else { |
2039 | llvm_unreachable("unknown custom directive parameter" ); |
2040 | } |
2041 | } |
2042 | |
2043 | /// Generate the printer for a custom directive. |
2044 | static void genCustomDirectivePrinter(CustomDirective *customDir, |
2045 | const Operator &op, MethodBody &body) { |
2046 | body << " print" << customDir->getName() << "(_odsPrinter, *this" ; |
2047 | for (FormatElement *param : customDir->getArguments()) { |
2048 | body << ", " ; |
2049 | genCustomDirectiveParameterPrinter(element: param, op, body); |
2050 | } |
2051 | body << ");\n" ; |
2052 | } |
2053 | |
2054 | /// Generate the printer for a region with the given variable name. |
2055 | static void genRegionPrinter(const Twine ®ionName, MethodBody &body, |
2056 | bool hasImplicitTermTrait) { |
2057 | if (hasImplicitTermTrait) |
2058 | body << llvm::formatv(Fmt: regionSingleBlockImplicitTerminatorPrinterCode, |
2059 | Vals: regionName); |
2060 | else |
2061 | body << " _odsPrinter.printRegion(" << regionName << ");\n" ; |
2062 | } |
2063 | static void genVariadicRegionPrinter(const Twine ®ionListName, |
2064 | MethodBody &body, |
2065 | bool hasImplicitTermTrait) { |
2066 | body << " llvm::interleaveComma(" << regionListName |
2067 | << ", _odsPrinter, [&](::mlir::Region ®ion) {\n " ; |
2068 | genRegionPrinter(regionName: "region" , body, hasImplicitTermTrait); |
2069 | body << " });\n" ; |
2070 | } |
2071 | |
2072 | /// Generate the C++ for an operand to a (*-)type directive. |
2073 | static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op, |
2074 | MethodBody &body, |
2075 | bool useArrayRef = true) { |
2076 | if (isa<OperandsDirective>(Val: arg)) |
2077 | return body << "getOperation()->getOperandTypes()" ; |
2078 | if (isa<ResultsDirective>(Val: arg)) |
2079 | return body << "getOperation()->getResultTypes()" ; |
2080 | auto *operand = dyn_cast<OperandVariable>(Val: arg); |
2081 | auto *var = operand ? operand->getVar() : cast<ResultVariable>(Val: arg)->getVar(); |
2082 | if (var->isVariadicOfVariadic()) |
2083 | return body << llvm::formatv(Fmt: "{0}().join().getTypes()" , |
2084 | Vals: op.getGetterName(name: var->name)); |
2085 | if (var->isVariadic()) |
2086 | return body << op.getGetterName(name: var->name) << "().getTypes()" ; |
2087 | if (var->isOptional()) |
2088 | return body << llvm::formatv( |
2089 | Fmt: "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : " |
2090 | "::llvm::ArrayRef<::mlir::Type>())" , |
2091 | Vals: op.getGetterName(name: var->name)); |
2092 | if (useArrayRef) |
2093 | return body << "::llvm::ArrayRef<::mlir::Type>(" |
2094 | << op.getGetterName(name: var->name) << "().getType())" ; |
2095 | return body << op.getGetterName(name: var->name) << "().getType()" ; |
2096 | } |
2097 | |
2098 | /// Generate the printer for an enum attribute. |
2099 | static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, |
2100 | MethodBody &body) { |
2101 | Attribute baseAttr = var->attr.getBaseAttr(); |
2102 | const EnumAttr &enumAttr = cast<EnumAttr>(Val&: baseAttr); |
2103 | std::vector<EnumAttrCase> cases = enumAttr.getAllCases(); |
2104 | |
2105 | body << llvm::formatv(Fmt: enumAttrBeginPrinterCode, |
2106 | Vals: (var->attr.isOptional() ? "*" : "" ) + |
2107 | op.getGetterName(name: var->name), |
2108 | Vals: enumAttr.getSymbolToStringFnName()); |
2109 | |
2110 | // Get a string containing all of the cases that can't be represented with a |
2111 | // keyword. |
2112 | BitVector nonKeywordCases(cases.size()); |
2113 | for (auto it : llvm::enumerate(First&: cases)) { |
2114 | if (!canFormatStringAsKeyword(value: it.value().getStr())) |
2115 | nonKeywordCases.set(it.index()); |
2116 | } |
2117 | |
2118 | // Otherwise if this is a bit enum attribute, don't allow cases that may |
2119 | // overlap with other cases. For simplicity sake, only allow cases with a |
2120 | // single bit value. |
2121 | if (enumAttr.isBitEnum()) { |
2122 | for (auto it : llvm::enumerate(First&: cases)) { |
2123 | int64_t value = it.value().getValue(); |
2124 | if (value < 0 || !llvm::isPowerOf2_64(Value: value)) |
2125 | nonKeywordCases.set(it.index()); |
2126 | } |
2127 | } |
2128 | |
2129 | // If there are any cases that can't be used with a keyword, switch on the |
2130 | // case value to determine when to print in the string form. |
2131 | if (nonKeywordCases.any()) { |
2132 | body << " switch (caseValue) {\n" ; |
2133 | StringRef cppNamespace = enumAttr.getCppNamespace(); |
2134 | StringRef enumName = enumAttr.getEnumClassName(); |
2135 | for (auto it : llvm::enumerate(First&: cases)) { |
2136 | if (nonKeywordCases.test(Idx: it.index())) |
2137 | continue; |
2138 | StringRef symbol = it.value().getSymbol(); |
2139 | body << llvm::formatv(Fmt: " case {0}::{1}::{2}:\n" , Vals&: cppNamespace, Vals&: enumName, |
2140 | Vals: llvm::isDigit(C: symbol.front()) ? ("_" + symbol) |
2141 | : symbol); |
2142 | } |
2143 | body << " _odsPrinter << caseValueStr;\n" |
2144 | " break;\n" |
2145 | " default:\n" |
2146 | " _odsPrinter << '\"' << caseValueStr << '\"';\n" |
2147 | " break;\n" |
2148 | " }\n" |
2149 | " }\n" ; |
2150 | return; |
2151 | } |
2152 | |
2153 | body << " _odsPrinter << caseValueStr;\n" |
2154 | " }\n" ; |
2155 | } |
2156 | |
2157 | /// Generate a check that a DefaultValuedAttr has a value that is non-default. |
2158 | static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, |
2159 | AttributeVariable &attrElement) { |
2160 | FmtContext fctx; |
2161 | Attribute attr = attrElement.getVar()->attr; |
2162 | fctx.withBuilder(subst: "::mlir::OpBuilder((*this)->getContext())" ); |
2163 | body << " && " << op.getGetterName(name: attrElement.getVar()->name) << "Attr() != " |
2164 | << tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, vals: attr.getDefaultValue()); |
2165 | } |
2166 | |
2167 | /// Generate the check for the anchor of an optional group. |
2168 | static void genOptionalGroupPrinterAnchor(FormatElement *anchor, |
2169 | const Operator &op, |
2170 | MethodBody &body) { |
2171 | TypeSwitch<FormatElement *>(anchor) |
2172 | .Case<OperandVariable, ResultVariable>(caseFn: [&](auto *element) { |
2173 | const NamedTypeConstraint *var = element->getVar(); |
2174 | std::string name = op.getGetterName(name: var->name); |
2175 | if (var->isOptional()) |
2176 | body << name << "()" ; |
2177 | else if (var->isVariadic()) |
2178 | body << "!" << name << "().empty()" ; |
2179 | }) |
2180 | .Case(caseFn: [&](RegionVariable *element) { |
2181 | const NamedRegion *var = element->getVar(); |
2182 | std::string name = op.getGetterName(name: var->name); |
2183 | // TODO: Add a check for optional regions here when ODS supports it. |
2184 | body << "!" << name << "().empty()" ; |
2185 | }) |
2186 | .Case(caseFn: [&](TypeDirective *element) { |
2187 | genOptionalGroupPrinterAnchor(anchor: element->getArg(), op, body); |
2188 | }) |
2189 | .Case(caseFn: [&](FunctionalTypeDirective *element) { |
2190 | genOptionalGroupPrinterAnchor(anchor: element->getInputs(), op, body); |
2191 | }) |
2192 | .Case(caseFn: [&](AttributeVariable *element) { |
2193 | Attribute attr = element->getVar()->attr; |
2194 | body << op.getGetterName(name: element->getVar()->name) << "Attr()" ; |
2195 | if (attr.isOptional()) |
2196 | return; // done |
2197 | if (attr.hasDefaultValue()) { |
2198 | // Consider a default-valued attribute as present if it's not the |
2199 | // default value. |
2200 | genNonDefaultValueCheck(body, op, attrElement&: *element); |
2201 | return; |
2202 | } |
2203 | llvm_unreachable("attribute must be optional or default-valued" ); |
2204 | }) |
2205 | .Case(caseFn: [&](CustomDirective *ele) { |
2206 | body << '('; |
2207 | llvm::interleave( |
2208 | c: ele->getArguments(), os&: body, |
2209 | each_fn: [&](FormatElement *child) { |
2210 | body << '('; |
2211 | genOptionalGroupPrinterAnchor(anchor: child, op, body); |
2212 | body << ')'; |
2213 | }, |
2214 | separator: " || " ); |
2215 | body << ')'; |
2216 | }); |
2217 | } |
2218 | |
2219 | void collect(FormatElement *element, |
2220 | SmallVectorImpl<VariableElement *> &variables) { |
2221 | TypeSwitch<FormatElement *>(element) |
2222 | .Case(caseFn: [&](VariableElement *var) { variables.emplace_back(Args&: var); }) |
2223 | .Case(caseFn: [&](CustomDirective *ele) { |
2224 | for (FormatElement *arg : ele->getArguments()) |
2225 | collect(element: arg, variables); |
2226 | }) |
2227 | .Case(caseFn: [&](OptionalElement *ele) { |
2228 | for (FormatElement *arg : ele->getThenElements()) |
2229 | collect(element: arg, variables); |
2230 | for (FormatElement *arg : ele->getElseElements()) |
2231 | collect(element: arg, variables); |
2232 | }) |
2233 | .Case(caseFn: [&](FunctionalTypeDirective *funcType) { |
2234 | collect(element: funcType->getInputs(), variables); |
2235 | collect(element: funcType->getResults(), variables); |
2236 | }) |
2237 | .Case(caseFn: [&](OIListElement *oilist) { |
2238 | for (ArrayRef<FormatElement *> arg : oilist->getParsingElements()) |
2239 | for (FormatElement *arg : arg) |
2240 | collect(element: arg, variables); |
2241 | }); |
2242 | } |
2243 | |
2244 | void OperationFormat::genElementPrinter(FormatElement *element, |
2245 | MethodBody &body, Operator &op, |
2246 | bool &shouldEmitSpace, |
2247 | bool &lastWasPunctuation) { |
2248 | if (LiteralElement *literal = dyn_cast<LiteralElement>(Val: element)) |
2249 | return genLiteralPrinter(value: literal->getSpelling(), body, shouldEmitSpace, |
2250 | lastWasPunctuation); |
2251 | |
2252 | // Emit a whitespace element. |
2253 | if (auto *space = dyn_cast<WhitespaceElement>(Val: element)) { |
2254 | if (space->getValue() == "\\n" ) { |
2255 | body << " _odsPrinter.printNewline();\n" ; |
2256 | } else { |
2257 | genSpacePrinter(value: !space->getValue().empty(), body, shouldEmitSpace, |
2258 | lastWasPunctuation); |
2259 | } |
2260 | return; |
2261 | } |
2262 | |
2263 | // Emit an optional group. |
2264 | if (OptionalElement *optional = dyn_cast<OptionalElement>(Val: element)) { |
2265 | // Emit the check for the presence of the anchor element. |
2266 | FormatElement *anchor = optional->getAnchor(); |
2267 | body << " if (" ; |
2268 | if (optional->isInverted()) |
2269 | body << "!" ; |
2270 | genOptionalGroupPrinterAnchor(anchor, op, body); |
2271 | body << ") {\n" ; |
2272 | body.indent(); |
2273 | |
2274 | // If the anchor is a unit attribute, we don't need to print it. When |
2275 | // parsing, we will add this attribute if this group is present. |
2276 | ArrayRef<FormatElement *> thenElements = optional->getThenElements(); |
2277 | ArrayRef<FormatElement *> elseElements = optional->getElseElements(); |
2278 | FormatElement *elidedAnchorElement = nullptr; |
2279 | auto *anchorAttr = dyn_cast<AttributeVariable>(Val: anchor); |
2280 | if (anchorAttr && anchorAttr != thenElements.front() && |
2281 | (elseElements.empty() || anchorAttr != elseElements.front()) && |
2282 | anchorAttr->isUnitAttr()) { |
2283 | elidedAnchorElement = anchorAttr; |
2284 | } |
2285 | auto genElementPrinters = [&](ArrayRef<FormatElement *> elements) { |
2286 | for (FormatElement *childElement : elements) { |
2287 | if (childElement != elidedAnchorElement) { |
2288 | genElementPrinter(element: childElement, body, op, shouldEmitSpace, |
2289 | lastWasPunctuation); |
2290 | } |
2291 | } |
2292 | }; |
2293 | |
2294 | // Emit each of the elements. |
2295 | genElementPrinters(thenElements); |
2296 | body << "}" ; |
2297 | |
2298 | // Emit each of the else elements. |
2299 | if (!elseElements.empty()) { |
2300 | body << " else {\n" ; |
2301 | genElementPrinters(elseElements); |
2302 | body << "}" ; |
2303 | } |
2304 | |
2305 | body.unindent() << "\n" ; |
2306 | return; |
2307 | } |
2308 | |
2309 | // Emit the OIList |
2310 | if (auto *oilist = dyn_cast<OIListElement>(Val: element)) { |
2311 | for (auto clause : oilist->getClauses()) { |
2312 | LiteralElement *lelement = std::get<0>(t&: clause); |
2313 | ArrayRef<FormatElement *> pelement = std::get<1>(t&: clause); |
2314 | |
2315 | SmallVector<VariableElement *> vars; |
2316 | for (FormatElement *el : pelement) |
2317 | collect(element: el, variables&: vars); |
2318 | body << " if (false" ; |
2319 | for (VariableElement *var : vars) { |
2320 | TypeSwitch<FormatElement *>(var) |
2321 | .Case(caseFn: [&](AttributeVariable *attrEle) { |
2322 | body << " || (" << op.getGetterName(name: attrEle->getVar()->name) |
2323 | << "Attr()" ; |
2324 | Attribute attr = attrEle->getVar()->attr; |
2325 | if (attr.hasDefaultValue()) { |
2326 | // Don't print default-valued attributes. |
2327 | genNonDefaultValueCheck(body, op, attrElement&: *attrEle); |
2328 | } |
2329 | body << ")" ; |
2330 | }) |
2331 | .Case(caseFn: [&](OperandVariable *ele) { |
2332 | if (ele->getVar()->isVariadic()) { |
2333 | body << " || " << op.getGetterName(name: ele->getVar()->name) |
2334 | << "().size()" ; |
2335 | } else { |
2336 | body << " || " << op.getGetterName(name: ele->getVar()->name) << "()" ; |
2337 | } |
2338 | }) |
2339 | .Case(caseFn: [&](ResultVariable *ele) { |
2340 | if (ele->getVar()->isVariadic()) { |
2341 | body << " || " << op.getGetterName(name: ele->getVar()->name) |
2342 | << "().size()" ; |
2343 | } else { |
2344 | body << " || " << op.getGetterName(name: ele->getVar()->name) << "()" ; |
2345 | } |
2346 | }) |
2347 | .Case(caseFn: [&](RegionVariable *reg) { |
2348 | body << " || " << op.getGetterName(name: reg->getVar()->name) << "()" ; |
2349 | }); |
2350 | } |
2351 | |
2352 | body << ") {\n" ; |
2353 | genLiteralPrinter(value: lelement->getSpelling(), body, shouldEmitSpace, |
2354 | lastWasPunctuation); |
2355 | if (oilist->getUnitAttrParsingElement(pelement) == nullptr) { |
2356 | for (FormatElement *element : pelement) |
2357 | genElementPrinter(element, body, op, shouldEmitSpace, |
2358 | lastWasPunctuation); |
2359 | } |
2360 | body << " }\n" ; |
2361 | } |
2362 | return; |
2363 | } |
2364 | |
2365 | // Emit the attribute dictionary. |
2366 | if (auto *attrDict = dyn_cast<AttrDictDirective>(Val: element)) { |
2367 | genAttrDictPrinter(fmt&: *this, op, body, withKeyword: attrDict->isWithKeyword()); |
2368 | lastWasPunctuation = false; |
2369 | return; |
2370 | } |
2371 | |
2372 | // Emit the attribute dictionary. |
2373 | if (isa<PropDictDirective>(Val: element)) { |
2374 | genPropDictPrinter(fmt&: *this, op, body); |
2375 | lastWasPunctuation = false; |
2376 | return; |
2377 | } |
2378 | |
2379 | // Optionally insert a space before the next element. The AttrDict printer |
2380 | // already adds a space as necessary. |
2381 | if (shouldEmitSpace || !lastWasPunctuation) |
2382 | body << " _odsPrinter << ' ';\n" ; |
2383 | lastWasPunctuation = false; |
2384 | shouldEmitSpace = true; |
2385 | |
2386 | if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) { |
2387 | const NamedAttribute *var = attr->getVar(); |
2388 | |
2389 | // If we are formatting as an enum, symbolize the attribute as a string. |
2390 | if (canFormatEnumAttr(attr: var)) |
2391 | return genEnumAttrPrinter(var, op, body); |
2392 | |
2393 | // If we are formatting as a symbol name, handle it as a symbol name. |
2394 | if (shouldFormatSymbolNameAttr(attr: var)) { |
2395 | body << " _odsPrinter.printSymbolName(" << op.getGetterName(name: var->name) |
2396 | << "Attr().getValue());\n" ; |
2397 | return; |
2398 | } |
2399 | |
2400 | // Elide the attribute type if it is buildable. |
2401 | if (attr->getTypeBuilder()) |
2402 | body << " _odsPrinter.printAttributeWithoutType(" |
2403 | << op.getGetterName(name: var->name) << "Attr());\n" ; |
2404 | else if (attr->shouldBeQualified() || |
2405 | var->attr.getStorageType() == "::mlir::Attribute" ) |
2406 | body << " _odsPrinter.printAttribute(" << op.getGetterName(name: var->name) |
2407 | << "Attr());\n" ; |
2408 | else |
2409 | body << "_odsPrinter.printStrippedAttrOrType(" |
2410 | << op.getGetterName(name: var->name) << "Attr());\n" ; |
2411 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) { |
2412 | if (operand->getVar()->isVariadicOfVariadic()) { |
2413 | body << " ::llvm::interleaveComma(" |
2414 | << op.getGetterName(name: operand->getVar()->name) |
2415 | << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << " |
2416 | "\"(\" << operands << " |
2417 | "\")\"; });\n" ; |
2418 | |
2419 | } else if (operand->getVar()->isOptional()) { |
2420 | body << " if (::mlir::Value value = " |
2421 | << op.getGetterName(name: operand->getVar()->name) << "())\n" |
2422 | << " _odsPrinter << value;\n" ; |
2423 | } else { |
2424 | body << " _odsPrinter << " << op.getGetterName(name: operand->getVar()->name) |
2425 | << "();\n" ; |
2426 | } |
2427 | } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) { |
2428 | const NamedRegion *var = region->getVar(); |
2429 | std::string name = op.getGetterName(name: var->name); |
2430 | if (var->isVariadic()) { |
2431 | genVariadicRegionPrinter(regionListName: name + "()" , body, hasImplicitTermTrait); |
2432 | } else { |
2433 | genRegionPrinter(regionName: name + "()" , body, hasImplicitTermTrait); |
2434 | } |
2435 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) { |
2436 | const NamedSuccessor *var = successor->getVar(); |
2437 | std::string name = op.getGetterName(name: var->name); |
2438 | if (var->isVariadic()) |
2439 | body << " ::llvm::interleaveComma(" << name << "(), _odsPrinter);\n" ; |
2440 | else |
2441 | body << " _odsPrinter << " << name << "();\n" ; |
2442 | } else if (auto *dir = dyn_cast<CustomDirective>(Val: element)) { |
2443 | genCustomDirectivePrinter(customDir: dir, op, body); |
2444 | } else if (isa<OperandsDirective>(Val: element)) { |
2445 | body << " _odsPrinter << getOperation()->getOperands();\n" ; |
2446 | } else if (isa<RegionsDirective>(Val: element)) { |
2447 | genVariadicRegionPrinter(regionListName: "getOperation()->getRegions()" , body, |
2448 | hasImplicitTermTrait); |
2449 | } else if (isa<SuccessorsDirective>(Val: element)) { |
2450 | body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), " |
2451 | "_odsPrinter);\n" ; |
2452 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) { |
2453 | if (auto *operand = dyn_cast<OperandVariable>(Val: dir->getArg())) { |
2454 | if (operand->getVar()->isVariadicOfVariadic()) { |
2455 | body << llvm::formatv( |
2456 | Fmt: " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, " |
2457 | "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << " |
2458 | "types << \")\"; });\n" , |
2459 | Vals: op.getGetterName(name: operand->getVar()->name)); |
2460 | return; |
2461 | } |
2462 | } |
2463 | const NamedTypeConstraint *var = nullptr; |
2464 | { |
2465 | if (auto *operand = dyn_cast<OperandVariable>(Val: dir->getArg())) |
2466 | var = operand->getVar(); |
2467 | else if (auto *operand = dyn_cast<ResultVariable>(Val: dir->getArg())) |
2468 | var = operand->getVar(); |
2469 | } |
2470 | if (var && !var->isVariadicOfVariadic() && !var->isVariadic() && |
2471 | !var->isOptional()) { |
2472 | std::string cppClass = var->constraint.getCPPClassName(); |
2473 | if (dir->shouldBeQualified()) { |
2474 | body << " _odsPrinter << " << op.getGetterName(name: var->name) |
2475 | << "().getType();\n" ; |
2476 | return; |
2477 | } |
2478 | body << " {\n" |
2479 | << " auto type = " << op.getGetterName(name: var->name) |
2480 | << "().getType();\n" |
2481 | << " if (auto validType = ::llvm::dyn_cast<" << cppClass |
2482 | << ">(type))\n" |
2483 | << " _odsPrinter.printStrippedAttrOrType(validType);\n" |
2484 | << " else\n" |
2485 | << " _odsPrinter << type;\n" |
2486 | << " }\n" ; |
2487 | return; |
2488 | } |
2489 | body << " _odsPrinter << " ; |
2490 | genTypeOperandPrinter(arg: dir->getArg(), op, body, /*useArrayRef=*/false) |
2491 | << ";\n" ; |
2492 | } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(Val: element)) { |
2493 | body << " _odsPrinter.printFunctionalType(" ; |
2494 | genTypeOperandPrinter(arg: dir->getInputs(), op, body) << ", " ; |
2495 | genTypeOperandPrinter(arg: dir->getResults(), op, body) << ");\n" ; |
2496 | } else { |
2497 | llvm_unreachable("unknown format element" ); |
2498 | } |
2499 | } |
2500 | |
2501 | void OperationFormat::genPrinter(Operator &op, OpClass &opClass) { |
2502 | auto *method = opClass.addMethod( |
2503 | retType: "void" , name: "print" , |
2504 | args: MethodParameter("::mlir::OpAsmPrinter &" , "_odsPrinter" )); |
2505 | auto &body = method->body(); |
2506 | |
2507 | // Flags for if we should emit a space, and if the last element was |
2508 | // punctuation. |
2509 | bool shouldEmitSpace = true, lastWasPunctuation = false; |
2510 | for (FormatElement *element : elements) |
2511 | genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation); |
2512 | } |
2513 | |
2514 | //===----------------------------------------------------------------------===// |
2515 | // OpFormatParser |
2516 | //===----------------------------------------------------------------------===// |
2517 | |
2518 | /// Function to find an element within the given range that has the same name as |
2519 | /// 'name'. |
2520 | template <typename RangeT> |
2521 | static auto findArg(RangeT &&range, StringRef name) { |
2522 | auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; }); |
2523 | return it != range.end() ? &*it : nullptr; |
2524 | } |
2525 | |
2526 | namespace { |
2527 | /// This class implements a parser for an instance of an operation assembly |
2528 | /// format. |
2529 | class OpFormatParser : public FormatParser { |
2530 | public: |
2531 | OpFormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op) |
2532 | : FormatParser(mgr, op.getLoc()[0]), fmt(format), op(op), |
2533 | seenOperandTypes(op.getNumOperands()), |
2534 | seenResultTypes(op.getNumResults()) {} |
2535 | |
2536 | protected: |
2537 | /// Verify the format elements. |
2538 | LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override; |
2539 | /// Verify the arguments to a custom directive. |
2540 | LogicalResult |
2541 | verifyCustomDirectiveArguments(SMLoc loc, |
2542 | ArrayRef<FormatElement *> arguments) override; |
2543 | /// Verify the elements of an optional group. |
2544 | LogicalResult verifyOptionalGroupElements(SMLoc loc, |
2545 | ArrayRef<FormatElement *> elements, |
2546 | FormatElement *anchor) override; |
2547 | LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element, |
2548 | bool isAnchor); |
2549 | |
2550 | /// Parse an operation variable. |
2551 | FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name, |
2552 | Context ctx) override; |
2553 | /// Parse an operation format directive. |
2554 | FailureOr<FormatElement *> |
2555 | parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override; |
2556 | |
2557 | private: |
2558 | /// This struct represents a type resolution instance. It includes a specific |
2559 | /// type as well as an optional transformer to apply to that type in order to |
2560 | /// properly resolve the type of a variable. |
2561 | struct TypeResolutionInstance { |
2562 | ConstArgument resolver; |
2563 | std::optional<StringRef> transformer; |
2564 | }; |
2565 | |
2566 | /// Verify the state of operation attributes within the format. |
2567 | LogicalResult verifyAttributes(SMLoc loc, ArrayRef<FormatElement *> elements); |
2568 | |
2569 | /// Verify that attributes elements aren't followed by colon literals. |
2570 | LogicalResult verifyAttributeColonType(SMLoc loc, |
2571 | ArrayRef<FormatElement *> elements); |
2572 | /// Verify that the attribute dictionary directive isn't followed by a region. |
2573 | LogicalResult verifyAttrDictRegion(SMLoc loc, |
2574 | ArrayRef<FormatElement *> elements); |
2575 | |
2576 | /// Verify the state of operation operands within the format. |
2577 | LogicalResult |
2578 | verifyOperands(SMLoc loc, |
2579 | llvm::StringMap<TypeResolutionInstance> &variableTyResolver); |
2580 | |
2581 | /// Verify the state of operation regions within the format. |
2582 | LogicalResult verifyRegions(SMLoc loc); |
2583 | |
2584 | /// Verify the state of operation results within the format. |
2585 | LogicalResult |
2586 | verifyResults(SMLoc loc, |
2587 | llvm::StringMap<TypeResolutionInstance> &variableTyResolver); |
2588 | |
2589 | /// Verify the state of operation successors within the format. |
2590 | LogicalResult verifySuccessors(SMLoc loc); |
2591 | |
2592 | LogicalResult verifyOIListElements(SMLoc loc, |
2593 | ArrayRef<FormatElement *> elements); |
2594 | |
2595 | /// Given the values of an `AllTypesMatch` trait, check for inferable type |
2596 | /// resolution. |
2597 | void handleAllTypesMatchConstraint( |
2598 | ArrayRef<StringRef> values, |
2599 | llvm::StringMap<TypeResolutionInstance> &variableTyResolver); |
2600 | /// Check for inferable type resolution given all operands, and or results, |
2601 | /// have the same type. If 'includeResults' is true, the results also have the |
2602 | /// same type as all of the operands. |
2603 | void handleSameTypesConstraint( |
2604 | llvm::StringMap<TypeResolutionInstance> &variableTyResolver, |
2605 | bool includeResults); |
2606 | /// Check for inferable type resolution based on another operand, result, or |
2607 | /// attribute. |
2608 | void handleTypesMatchConstraint( |
2609 | llvm::StringMap<TypeResolutionInstance> &variableTyResolver, |
2610 | const llvm::Record &def); |
2611 | |
2612 | /// Returns an argument or attribute with the given name that has been seen |
2613 | /// within the format. |
2614 | ConstArgument findSeenArg(StringRef name); |
2615 | |
2616 | /// Parse the various different directives. |
2617 | FailureOr<FormatElement *> parsePropDictDirective(SMLoc loc, Context context); |
2618 | FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context, |
2619 | bool withKeyword); |
2620 | FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc, |
2621 | Context context); |
2622 | FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context); |
2623 | LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc); |
2624 | FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context); |
2625 | FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, |
2626 | Context context); |
2627 | FailureOr<FormatElement *> parseReferenceDirective(SMLoc loc, |
2628 | Context context); |
2629 | FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context); |
2630 | FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context); |
2631 | FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc, |
2632 | Context context); |
2633 | FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context); |
2634 | FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc, |
2635 | bool isRefChild = false); |
2636 | |
2637 | //===--------------------------------------------------------------------===// |
2638 | // Fields |
2639 | //===--------------------------------------------------------------------===// |
2640 | |
2641 | OperationFormat &fmt; |
2642 | Operator &op; |
2643 | |
2644 | // The following are various bits of format state used for verification |
2645 | // during parsing. |
2646 | bool hasAttrDict = false; |
2647 | bool hasPropDict = false; |
2648 | bool hasAllRegions = false, hasAllSuccessors = false; |
2649 | bool canInferResultTypes = false; |
2650 | llvm::SmallBitVector seenOperandTypes, seenResultTypes; |
2651 | llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs; |
2652 | llvm::DenseSet<const NamedTypeConstraint *> seenOperands; |
2653 | llvm::DenseSet<const NamedRegion *> seenRegions; |
2654 | llvm::DenseSet<const NamedSuccessor *> seenSuccessors; |
2655 | llvm::SmallSetVector<const NamedProperty *, 8> seenProperties; |
2656 | }; |
2657 | } // namespace |
2658 | |
2659 | LogicalResult OpFormatParser::verify(SMLoc loc, |
2660 | ArrayRef<FormatElement *> elements) { |
2661 | // Check that the attribute dictionary is in the format. |
2662 | if (!hasAttrDict) |
2663 | return emitError(loc, msg: "'attr-dict' directive not found in " |
2664 | "custom assembly format" ); |
2665 | |
2666 | // Check for any type traits that we can use for inferring types. |
2667 | llvm::StringMap<TypeResolutionInstance> variableTyResolver; |
2668 | for (const Trait &trait : op.getTraits()) { |
2669 | const llvm::Record &def = trait.getDef(); |
2670 | if (def.isSubClassOf(Name: "AllTypesMatch" )) { |
2671 | handleAllTypesMatchConstraint(values: def.getValueAsListOfStrings(FieldName: "values" ), |
2672 | variableTyResolver); |
2673 | } else if (def.getName() == "SameTypeOperands" ) { |
2674 | handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false); |
2675 | } else if (def.getName() == "SameOperandsAndResultType" ) { |
2676 | handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); |
2677 | } else if (def.isSubClassOf(Name: "TypesMatchWith" )) { |
2678 | handleTypesMatchConstraint(variableTyResolver, def); |
2679 | } else if (!op.allResultTypesKnown()) { |
2680 | // This doesn't check the name directly to handle |
2681 | // DeclareOpInterfaceMethods<InferTypeOpInterface> |
2682 | // and the like. |
2683 | // TODO: Add hasCppInterface check. |
2684 | if (auto name = def.getValueAsOptionalString(FieldName: "cppInterfaceName" )) { |
2685 | if (*name == "InferTypeOpInterface" && |
2686 | def.getValueAsString(FieldName: "cppNamespace" ) == "::mlir" ) |
2687 | canInferResultTypes = true; |
2688 | } |
2689 | } |
2690 | } |
2691 | |
2692 | // Verify the state of the various operation components. |
2693 | if (failed(result: verifyAttributes(loc, elements)) || |
2694 | failed(result: verifyResults(loc, variableTyResolver)) || |
2695 | failed(result: verifyOperands(loc, variableTyResolver)) || |
2696 | failed(result: verifyRegions(loc)) || failed(result: verifySuccessors(loc)) || |
2697 | failed(result: verifyOIListElements(loc, elements))) |
2698 | return failure(); |
2699 | |
2700 | // Collect the set of used attributes in the format. |
2701 | fmt.usedAttributes = std::move(seenAttrs); |
2702 | fmt.usedProperties = std::move(seenProperties); |
2703 | |
2704 | // Set whether prop-dict is used in the format |
2705 | fmt.hasPropDict = hasPropDict; |
2706 | return success(); |
2707 | } |
2708 | |
2709 | LogicalResult |
2710 | OpFormatParser::verifyAttributes(SMLoc loc, |
2711 | ArrayRef<FormatElement *> elements) { |
2712 | // Check that there are no `:` literals after an attribute without a constant |
2713 | // type. The attribute grammar contains an optional trailing colon type, which |
2714 | // can lead to unexpected and generally unintended behavior. Given that, it is |
2715 | // better to just error out here instead. |
2716 | if (failed(result: verifyAttributeColonType(loc, elements))) |
2717 | return failure(); |
2718 | // Check that there are no region variables following an attribute dicitonary. |
2719 | // Both start with `{` and so the optional attribute dictionary can cause |
2720 | // format ambiguities. |
2721 | if (failed(result: verifyAttrDictRegion(loc, elements))) |
2722 | return failure(); |
2723 | |
2724 | // Check for VariadicOfVariadic variables. The segment attribute of those |
2725 | // variables will be infered. |
2726 | for (const NamedTypeConstraint *var : seenOperands) { |
2727 | if (var->constraint.isVariadicOfVariadic()) { |
2728 | fmt.inferredAttributes.insert( |
2729 | key: var->constraint.getVariadicOfVariadicSegmentSizeAttr()); |
2730 | } |
2731 | } |
2732 | |
2733 | return success(); |
2734 | } |
2735 | |
2736 | /// Returns whether the single format element is optionally parsed. |
2737 | static bool isOptionallyParsed(FormatElement *el) { |
2738 | if (auto *attrVar = dyn_cast<AttributeVariable>(Val: el)) { |
2739 | Attribute attr = attrVar->getVar()->attr; |
2740 | return attr.isOptional() || attr.hasDefaultValue(); |
2741 | } |
2742 | if (auto *operandVar = dyn_cast<OperandVariable>(Val: el)) { |
2743 | const NamedTypeConstraint *operand = operandVar->getVar(); |
2744 | return operand->isOptional() || operand->isVariadic() || |
2745 | operand->isVariadicOfVariadic(); |
2746 | } |
2747 | if (auto *successorVar = dyn_cast<SuccessorVariable>(Val: el)) |
2748 | return successorVar->getVar()->isVariadic(); |
2749 | if (auto *regionVar = dyn_cast<RegionVariable>(Val: el)) |
2750 | return regionVar->getVar()->isVariadic(); |
2751 | return isa<WhitespaceElement, AttrDictDirective>(Val: el); |
2752 | } |
2753 | |
2754 | /// Scan the given range of elements from the start for an invalid format |
2755 | /// element that satisfies `isInvalid`, skipping any optionally-parsed elements. |
2756 | /// If an optional group is encountered, this function recurses into the 'then' |
2757 | /// and 'else' elements to check if they are invalid. Returns `success` if the |
2758 | /// range is known to be valid or `std::nullopt` if scanning reached the end. |
2759 | /// |
2760 | /// Since the guard element of an optional group is required, this function |
2761 | /// accepts an optional element pointer to mark it as required. |
2762 | static std::optional<LogicalResult> checkRangeForElement( |
2763 | FormatElement *base, |
2764 | function_ref<bool(FormatElement *, FormatElement *)> isInvalid, |
2765 | iterator_range<ArrayRef<FormatElement *>::iterator> elementRange, |
2766 | FormatElement *optionalGuard = nullptr) { |
2767 | for (FormatElement *element : elementRange) { |
2768 | // If we encounter an invalid element, return an error. |
2769 | if (isInvalid(base, element)) |
2770 | return failure(); |
2771 | |
2772 | // Recurse on optional groups. |
2773 | if (auto *optional = dyn_cast<OptionalElement>(Val: element)) { |
2774 | if (std::optional<LogicalResult> result = checkRangeForElement( |
2775 | base, isInvalid, elementRange: optional->getThenElements(), |
2776 | // The optional group guard is required for the group. |
2777 | optionalGuard: optional->getThenElements().front())) |
2778 | if (failed(result: *result)) |
2779 | return failure(); |
2780 | if (std::optional<LogicalResult> result = checkRangeForElement( |
2781 | base, isInvalid, elementRange: optional->getElseElements())) |
2782 | if (failed(result: *result)) |
2783 | return failure(); |
2784 | // Skip the optional group. |
2785 | continue; |
2786 | } |
2787 | |
2788 | // Skip optionally parsed elements. |
2789 | if (element != optionalGuard && isOptionallyParsed(el: element)) |
2790 | continue; |
2791 | |
2792 | // We found a closing element that is valid. |
2793 | return success(); |
2794 | } |
2795 | // Return std::nullopt to indicate that we reached the end. |
2796 | return std::nullopt; |
2797 | } |
2798 | |
2799 | /// For the given elements, check whether any attributes are followed by a colon |
2800 | /// literal, resulting in an ambiguous assembly format. Returns a non-null |
2801 | /// attribute if verification of said attribute reached the end of the range. |
2802 | /// Returns null if all attribute elements are verified. |
2803 | static FailureOr<FormatElement *> verifyAdjacentElements( |
2804 | function_ref<bool(FormatElement *)> isBase, |
2805 | function_ref<bool(FormatElement *, FormatElement *)> isInvalid, |
2806 | ArrayRef<FormatElement *> elements) { |
2807 | for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) { |
2808 | // The current attribute being verified. |
2809 | FormatElement *base; |
2810 | |
2811 | if (isBase(*it)) { |
2812 | base = *it; |
2813 | } else if (auto *optional = dyn_cast<OptionalElement>(Val: *it)) { |
2814 | // Recurse on optional groups. |
2815 | FailureOr<FormatElement *> thenResult = verifyAdjacentElements( |
2816 | isBase, isInvalid, elements: optional->getThenElements()); |
2817 | if (failed(result: thenResult)) |
2818 | return failure(); |
2819 | FailureOr<FormatElement *> elseResult = verifyAdjacentElements( |
2820 | isBase, isInvalid, elements: optional->getElseElements()); |
2821 | if (failed(result: elseResult)) |
2822 | return failure(); |
2823 | // If either optional group has an unverified attribute, save it. |
2824 | // Otherwise, move on to the next element. |
2825 | if (!(base = *thenResult) && !(base = *elseResult)) |
2826 | continue; |
2827 | } else { |
2828 | continue; |
2829 | } |
2830 | |
2831 | // Verify subsequent elements for potential ambiguities. |
2832 | if (std::optional<LogicalResult> result = |
2833 | checkRangeForElement(base, isInvalid, elementRange: {std::next(x: it), e})) { |
2834 | if (failed(result: *result)) |
2835 | return failure(); |
2836 | } else { |
2837 | // Since we reached the end, return the attribute as unverified. |
2838 | return base; |
2839 | } |
2840 | } |
2841 | // All attribute elements are known to be verified. |
2842 | return nullptr; |
2843 | } |
2844 | |
2845 | LogicalResult |
2846 | OpFormatParser::verifyAttributeColonType(SMLoc loc, |
2847 | ArrayRef<FormatElement *> elements) { |
2848 | auto isBase = [](FormatElement *el) { |
2849 | auto *attr = dyn_cast<AttributeVariable>(Val: el); |
2850 | if (!attr) |
2851 | return false; |
2852 | // Check only attributes without type builders or that are known to call |
2853 | // the generic attribute parser. |
2854 | return !attr->getTypeBuilder() && |
2855 | (attr->shouldBeQualified() || |
2856 | attr->getVar()->attr.getStorageType() == "::mlir::Attribute" ); |
2857 | }; |
2858 | auto isInvalid = [&](FormatElement *base, FormatElement *el) { |
2859 | auto *literal = dyn_cast<LiteralElement>(Val: el); |
2860 | if (!literal || literal->getSpelling() != ":" ) |
2861 | return false; |
2862 | // If we encounter `:`, the range is known to be invalid. |
2863 | (void)emitError( |
2864 | loc, |
2865 | msg: llvm::formatv(Fmt: "format ambiguity caused by `:` literal found after " |
2866 | "attribute `{0}` which does not have a buildable type" , |
2867 | Vals: cast<AttributeVariable>(Val: base)->getVar()->name)); |
2868 | return true; |
2869 | }; |
2870 | return verifyAdjacentElements(isBase, isInvalid, elements); |
2871 | } |
2872 | |
2873 | LogicalResult |
2874 | OpFormatParser::verifyAttrDictRegion(SMLoc loc, |
2875 | ArrayRef<FormatElement *> elements) { |
2876 | auto isBase = [](FormatElement *el) { |
2877 | if (auto *attrDict = dyn_cast<AttrDictDirective>(Val: el)) |
2878 | return !attrDict->isWithKeyword(); |
2879 | return false; |
2880 | }; |
2881 | auto isInvalid = [&](FormatElement *base, FormatElement *el) { |
2882 | auto *region = dyn_cast<RegionVariable>(Val: el); |
2883 | if (!region) |
2884 | return false; |
2885 | (void)emitErrorAndNote( |
2886 | loc, |
2887 | msg: llvm::formatv(Fmt: "format ambiguity caused by `attr-dict` directive " |
2888 | "followed by region `{0}`" , |
2889 | Vals: region->getVar()->name), |
2890 | note: "try using `attr-dict-with-keyword` instead" ); |
2891 | return true; |
2892 | }; |
2893 | return verifyAdjacentElements(isBase, isInvalid, elements); |
2894 | } |
2895 | |
2896 | LogicalResult OpFormatParser::verifyOperands( |
2897 | SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) { |
2898 | // Check that all of the operands are within the format, and their types can |
2899 | // be inferred. |
2900 | auto &buildableTypes = fmt.buildableTypes; |
2901 | for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { |
2902 | NamedTypeConstraint &operand = op.getOperand(index: i); |
2903 | |
2904 | // Check that the operand itself is in the format. |
2905 | if (!fmt.allOperands && !seenOperands.count(V: &operand)) { |
2906 | return emitErrorAndNote(loc, |
2907 | msg: "operand #" + Twine(i) + ", named '" + |
2908 | operand.name + "', not found" , |
2909 | note: "suggest adding a '$" + operand.name + |
2910 | "' directive to the custom assembly format" ); |
2911 | } |
2912 | |
2913 | // Check that the operand type is in the format, or that it can be inferred. |
2914 | if (fmt.allOperandTypes || seenOperandTypes.test(Idx: i)) |
2915 | continue; |
2916 | |
2917 | // Check to see if we can infer this type from another variable. |
2918 | auto varResolverIt = variableTyResolver.find(Key: op.getOperand(index: i).name); |
2919 | if (varResolverIt != variableTyResolver.end()) { |
2920 | TypeResolutionInstance &resolver = varResolverIt->second; |
2921 | fmt.operandTypes[i].setResolver(arg: resolver.resolver, transformer: resolver.transformer); |
2922 | continue; |
2923 | } |
2924 | |
2925 | // Similarly to results, allow a custom builder for resolving the type if |
2926 | // we aren't using the 'operands' directive. |
2927 | std::optional<StringRef> builder = operand.constraint.getBuilderCall(); |
2928 | if (!builder || (fmt.allOperands && operand.isVariableLength())) { |
2929 | return emitErrorAndNote( |
2930 | loc, |
2931 | msg: "type of operand #" + Twine(i) + ", named '" + operand.name + |
2932 | "', is not buildable and a buildable type cannot be inferred" , |
2933 | note: "suggest adding a type constraint to the operation or adding a " |
2934 | "'type($" + |
2935 | operand.name + ")' directive to the " + "custom assembly format" ); |
2936 | } |
2937 | auto it = buildableTypes.insert(KV: {*builder, buildableTypes.size()}); |
2938 | fmt.operandTypes[i].setBuilderIdx(it.first->second); |
2939 | } |
2940 | return success(); |
2941 | } |
2942 | |
2943 | LogicalResult OpFormatParser::verifyRegions(SMLoc loc) { |
2944 | // Check that all of the regions are within the format. |
2945 | if (hasAllRegions) |
2946 | return success(); |
2947 | |
2948 | for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) { |
2949 | const NamedRegion ®ion = op.getRegion(index: i); |
2950 | if (!seenRegions.count(V: ®ion)) { |
2951 | return emitErrorAndNote(loc, |
2952 | msg: "region #" + Twine(i) + ", named '" + |
2953 | region.name + "', not found" , |
2954 | note: "suggest adding a '$" + region.name + |
2955 | "' directive to the custom assembly format" ); |
2956 | } |
2957 | } |
2958 | return success(); |
2959 | } |
2960 | |
2961 | LogicalResult OpFormatParser::verifyResults( |
2962 | SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) { |
2963 | // If we format all of the types together, there is nothing to check. |
2964 | if (fmt.allResultTypes) |
2965 | return success(); |
2966 | |
2967 | // If no result types are specified and we can infer them, infer all result |
2968 | // types |
2969 | if (op.getNumResults() > 0 && seenResultTypes.count() == 0 && |
2970 | canInferResultTypes) { |
2971 | fmt.infersResultTypes = true; |
2972 | return success(); |
2973 | } |
2974 | |
2975 | // Check that all of the result types can be inferred. |
2976 | auto &buildableTypes = fmt.buildableTypes; |
2977 | for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { |
2978 | if (seenResultTypes.test(Idx: i)) |
2979 | continue; |
2980 | |
2981 | // Check to see if we can infer this type from another variable. |
2982 | auto varResolverIt = variableTyResolver.find(Key: op.getResultName(index: i)); |
2983 | if (varResolverIt != variableTyResolver.end()) { |
2984 | TypeResolutionInstance resolver = varResolverIt->second; |
2985 | fmt.resultTypes[i].setResolver(arg: resolver.resolver, transformer: resolver.transformer); |
2986 | continue; |
2987 | } |
2988 | |
2989 | // If the result is not variable length, allow for the case where the type |
2990 | // has a builder that we can use. |
2991 | NamedTypeConstraint &result = op.getResult(index: i); |
2992 | std::optional<StringRef> builder = result.constraint.getBuilderCall(); |
2993 | if (!builder || result.isVariableLength()) { |
2994 | return emitErrorAndNote( |
2995 | loc, |
2996 | msg: "type of result #" + Twine(i) + ", named '" + result.name + |
2997 | "', is not buildable and a buildable type cannot be inferred" , |
2998 | note: "suggest adding a type constraint to the operation or adding a " |
2999 | "'type($" + |
3000 | result.name + ")' directive to the " + "custom assembly format" ); |
3001 | } |
3002 | // Note in the format that this result uses the custom builder. |
3003 | auto it = buildableTypes.insert(KV: {*builder, buildableTypes.size()}); |
3004 | fmt.resultTypes[i].setBuilderIdx(it.first->second); |
3005 | } |
3006 | return success(); |
3007 | } |
3008 | |
3009 | LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) { |
3010 | // Check that all of the successors are within the format. |
3011 | if (hasAllSuccessors) |
3012 | return success(); |
3013 | |
3014 | for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) { |
3015 | const NamedSuccessor &successor = op.getSuccessor(index: i); |
3016 | if (!seenSuccessors.count(V: &successor)) { |
3017 | return emitErrorAndNote(loc, |
3018 | msg: "successor #" + Twine(i) + ", named '" + |
3019 | successor.name + "', not found" , |
3020 | note: "suggest adding a '$" + successor.name + |
3021 | "' directive to the custom assembly format" ); |
3022 | } |
3023 | } |
3024 | return success(); |
3025 | } |
3026 | |
3027 | LogicalResult |
3028 | OpFormatParser::verifyOIListElements(SMLoc loc, |
3029 | ArrayRef<FormatElement *> elements) { |
3030 | // Check that all of the successors are within the format. |
3031 | SmallVector<StringRef> prohibitedLiterals; |
3032 | for (FormatElement *it : elements) { |
3033 | if (auto *oilist = dyn_cast<OIListElement>(Val: it)) { |
3034 | if (!prohibitedLiterals.empty()) { |
3035 | // We just saw an oilist element in last iteration. Literals should not |
3036 | // match. |
3037 | for (LiteralElement *literal : oilist->getLiteralElements()) { |
3038 | if (find(Range&: prohibitedLiterals, Val: literal->getSpelling()) != |
3039 | prohibitedLiterals.end()) { |
3040 | return emitError( |
3041 | loc, msg: "format ambiguity because " + literal->getSpelling() + |
3042 | " is used in two adjacent oilist elements." ); |
3043 | } |
3044 | } |
3045 | } |
3046 | for (LiteralElement *literal : oilist->getLiteralElements()) |
3047 | prohibitedLiterals.push_back(Elt: literal->getSpelling()); |
3048 | } else if (auto *literal = dyn_cast<LiteralElement>(Val: it)) { |
3049 | if (find(Range&: prohibitedLiterals, Val: literal->getSpelling()) != |
3050 | prohibitedLiterals.end()) { |
3051 | return emitError( |
3052 | loc, |
3053 | msg: "format ambiguity because " + literal->getSpelling() + |
3054 | " is used both in oilist element and the adjacent literal." ); |
3055 | } |
3056 | prohibitedLiterals.clear(); |
3057 | } else { |
3058 | prohibitedLiterals.clear(); |
3059 | } |
3060 | } |
3061 | return success(); |
3062 | } |
3063 | |
3064 | void OpFormatParser::handleAllTypesMatchConstraint( |
3065 | ArrayRef<StringRef> values, |
3066 | llvm::StringMap<TypeResolutionInstance> &variableTyResolver) { |
3067 | for (unsigned i = 0, e = values.size(); i != e; ++i) { |
3068 | // Check to see if this value matches a resolved operand or result type. |
3069 | ConstArgument arg = findSeenArg(name: values[i]); |
3070 | if (!arg) |
3071 | continue; |
3072 | |
3073 | // Mark this value as the type resolver for the other variables. |
3074 | for (unsigned j = 0; j != i; ++j) |
3075 | variableTyResolver[values[j]] = {.resolver: arg, .transformer: std::nullopt}; |
3076 | for (unsigned j = i + 1; j != e; ++j) |
3077 | variableTyResolver[values[j]] = {.resolver: arg, .transformer: std::nullopt}; |
3078 | } |
3079 | } |
3080 | |
3081 | void OpFormatParser::handleSameTypesConstraint( |
3082 | llvm::StringMap<TypeResolutionInstance> &variableTyResolver, |
3083 | bool includeResults) { |
3084 | const NamedTypeConstraint *resolver = nullptr; |
3085 | int resolvedIt = -1; |
3086 | |
3087 | // Check to see if there is an operand or result to use for the resolution. |
3088 | if ((resolvedIt = seenOperandTypes.find_first()) != -1) |
3089 | resolver = &op.getOperand(index: resolvedIt); |
3090 | else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1) |
3091 | resolver = &op.getResult(index: resolvedIt); |
3092 | else |
3093 | return; |
3094 | |
3095 | // Set the resolvers for each operand and result. |
3096 | for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) |
3097 | if (!seenOperandTypes.test(Idx: i)) |
3098 | variableTyResolver[op.getOperand(index: i).name] = {.resolver: resolver, .transformer: std::nullopt}; |
3099 | if (includeResults) { |
3100 | for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) |
3101 | if (!seenResultTypes.test(Idx: i)) |
3102 | variableTyResolver[op.getResultName(index: i)] = {.resolver: resolver, .transformer: std::nullopt}; |
3103 | } |
3104 | } |
3105 | |
3106 | void OpFormatParser::handleTypesMatchConstraint( |
3107 | llvm::StringMap<TypeResolutionInstance> &variableTyResolver, |
3108 | const llvm::Record &def) { |
3109 | StringRef lhsName = def.getValueAsString(FieldName: "lhs" ); |
3110 | StringRef rhsName = def.getValueAsString(FieldName: "rhs" ); |
3111 | StringRef transformer = def.getValueAsString(FieldName: "transformer" ); |
3112 | if (ConstArgument arg = findSeenArg(name: lhsName)) |
3113 | variableTyResolver[rhsName] = {.resolver: arg, .transformer: transformer}; |
3114 | } |
3115 | |
3116 | ConstArgument OpFormatParser::findSeenArg(StringRef name) { |
3117 | if (const NamedTypeConstraint *arg = findArg(range: op.getOperands(), name)) |
3118 | return seenOperandTypes.test(Idx: arg - op.operand_begin()) ? arg : nullptr; |
3119 | if (const NamedTypeConstraint *arg = findArg(range: op.getResults(), name)) |
3120 | return seenResultTypes.test(Idx: arg - op.result_begin()) ? arg : nullptr; |
3121 | if (const NamedAttribute *attr = findArg(range: op.getAttributes(), name)) |
3122 | return seenAttrs.count(key: attr) ? attr : nullptr; |
3123 | return nullptr; |
3124 | } |
3125 | |
3126 | FailureOr<FormatElement *> |
3127 | OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) { |
3128 | // Check that the parsed argument is something actually registered on the op. |
3129 | // Attributes |
3130 | if (const NamedAttribute *attr = findArg(range: op.getAttributes(), name)) { |
3131 | if (ctx == TypeDirectiveContext) |
3132 | return emitError( |
3133 | loc, msg: "attributes cannot be used as children to a `type` directive" ); |
3134 | if (ctx == RefDirectiveContext) { |
3135 | if (!seenAttrs.count(key: attr)) |
3136 | return emitError(loc, msg: "attribute '" + name + |
3137 | "' must be bound before it is referenced" ); |
3138 | } else if (!seenAttrs.insert(X: attr)) { |
3139 | return emitError(loc, msg: "attribute '" + name + "' is already bound" ); |
3140 | } |
3141 | |
3142 | return create<AttributeVariable>(args&: attr); |
3143 | } |
3144 | |
3145 | if (const NamedProperty *property = findArg(range: op.getProperties(), name)) { |
3146 | if (ctx != CustomDirectiveContext && ctx != RefDirectiveContext) |
3147 | return emitError( |
3148 | loc, msg: "properties currently only supported in `custom` directive" ); |
3149 | |
3150 | if (ctx == RefDirectiveContext) { |
3151 | if (!seenProperties.count(key: property)) |
3152 | return emitError(loc, msg: "property '" + name + |
3153 | "' must be bound before it is referenced" ); |
3154 | } else { |
3155 | if (!seenProperties.insert(X: property)) |
3156 | return emitError(loc, msg: "property '" + name + "' is already bound" ); |
3157 | } |
3158 | |
3159 | return create<PropertyVariable>(args&: property); |
3160 | } |
3161 | |
3162 | // Operands |
3163 | if (const NamedTypeConstraint *operand = findArg(range: op.getOperands(), name)) { |
3164 | if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { |
3165 | if (fmt.allOperands || !seenOperands.insert(V: operand).second) |
3166 | return emitError(loc, msg: "operand '" + name + "' is already bound" ); |
3167 | } else if (ctx == RefDirectiveContext && !seenOperands.count(V: operand)) { |
3168 | return emitError(loc, msg: "operand '" + name + |
3169 | "' must be bound before it is referenced" ); |
3170 | } |
3171 | return create<OperandVariable>(args&: operand); |
3172 | } |
3173 | // Regions |
3174 | if (const NamedRegion *region = findArg(range: op.getRegions(), name)) { |
3175 | if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { |
3176 | if (hasAllRegions || !seenRegions.insert(V: region).second) |
3177 | return emitError(loc, msg: "region '" + name + "' is already bound" ); |
3178 | } else if (ctx == RefDirectiveContext && !seenRegions.count(V: region)) { |
3179 | return emitError(loc, msg: "region '" + name + |
3180 | "' must be bound before it is referenced" ); |
3181 | } else { |
3182 | return emitError(loc, msg: "regions can only be used at the top level" ); |
3183 | } |
3184 | return create<RegionVariable>(args&: region); |
3185 | } |
3186 | // Results. |
3187 | if (const auto *result = findArg(range: op.getResults(), name)) { |
3188 | if (ctx != TypeDirectiveContext) |
3189 | return emitError(loc, msg: "result variables can can only be used as a child " |
3190 | "to a 'type' directive" ); |
3191 | return create<ResultVariable>(args&: result); |
3192 | } |
3193 | // Successors. |
3194 | if (const auto *successor = findArg(range: op.getSuccessors(), name)) { |
3195 | if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { |
3196 | if (hasAllSuccessors || !seenSuccessors.insert(V: successor).second) |
3197 | return emitError(loc, msg: "successor '" + name + "' is already bound" ); |
3198 | } else if (ctx == RefDirectiveContext && !seenSuccessors.count(V: successor)) { |
3199 | return emitError(loc, msg: "successor '" + name + |
3200 | "' must be bound before it is referenced" ); |
3201 | } else { |
3202 | return emitError(loc, msg: "successors can only be used at the top level" ); |
3203 | } |
3204 | |
3205 | return create<SuccessorVariable>(args&: successor); |
3206 | } |
3207 | return emitError(loc, msg: "expected variable to refer to an argument, region, " |
3208 | "result, or successor" ); |
3209 | } |
3210 | |
3211 | FailureOr<FormatElement *> |
3212 | OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, |
3213 | Context ctx) { |
3214 | switch (kind) { |
3215 | case FormatToken::kw_prop_dict: |
3216 | return parsePropDictDirective(loc, context: ctx); |
3217 | case FormatToken::kw_attr_dict: |
3218 | return parseAttrDictDirective(loc, context: ctx, |
3219 | /*withKeyword=*/false); |
3220 | case FormatToken::kw_attr_dict_w_keyword: |
3221 | return parseAttrDictDirective(loc, context: ctx, |
3222 | /*withKeyword=*/true); |
3223 | case FormatToken::kw_functional_type: |
3224 | return parseFunctionalTypeDirective(loc, context: ctx); |
3225 | case FormatToken::kw_operands: |
3226 | return parseOperandsDirective(loc, context: ctx); |
3227 | case FormatToken::kw_qualified: |
3228 | return parseQualifiedDirective(loc, context: ctx); |
3229 | case FormatToken::kw_regions: |
3230 | return parseRegionsDirective(loc, context: ctx); |
3231 | case FormatToken::kw_results: |
3232 | return parseResultsDirective(loc, context: ctx); |
3233 | case FormatToken::kw_successors: |
3234 | return parseSuccessorsDirective(loc, context: ctx); |
3235 | case FormatToken::kw_ref: |
3236 | return parseReferenceDirective(loc, context: ctx); |
3237 | case FormatToken::kw_type: |
3238 | return parseTypeDirective(loc, context: ctx); |
3239 | case FormatToken::kw_oilist: |
3240 | return parseOIListDirective(loc, context: ctx); |
3241 | |
3242 | default: |
3243 | return emitError(loc, msg: "unsupported directive kind" ); |
3244 | } |
3245 | } |
3246 | |
3247 | FailureOr<FormatElement *> |
3248 | OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context, |
3249 | bool withKeyword) { |
3250 | if (context == TypeDirectiveContext) |
3251 | return emitError(loc, msg: "'attr-dict' directive can only be used as a " |
3252 | "top-level directive" ); |
3253 | |
3254 | if (context == RefDirectiveContext) { |
3255 | if (!hasAttrDict) |
3256 | return emitError(loc, msg: "'ref' of 'attr-dict' is not bound by a prior " |
3257 | "'attr-dict' directive" ); |
3258 | |
3259 | // Otherwise, this is a top-level context. |
3260 | } else { |
3261 | if (hasAttrDict) |
3262 | return emitError(loc, msg: "'attr-dict' directive has already been seen" ); |
3263 | hasAttrDict = true; |
3264 | } |
3265 | |
3266 | return create<AttrDictDirective>(args&: withKeyword); |
3267 | } |
3268 | |
3269 | FailureOr<FormatElement *> |
3270 | OpFormatParser::parsePropDictDirective(SMLoc loc, Context context) { |
3271 | if (context == TypeDirectiveContext) |
3272 | return emitError(loc, msg: "'prop-dict' directive can only be used as a " |
3273 | "top-level directive" ); |
3274 | |
3275 | if (context == RefDirectiveContext) |
3276 | llvm::report_fatal_error(reason: "'ref' of 'prop-dict' unsupported" ); |
3277 | // Otherwise, this is a top-level context. |
3278 | |
3279 | if (hasPropDict) |
3280 | return emitError(loc, msg: "'prop-dict' directive has already been seen" ); |
3281 | hasPropDict = true; |
3282 | |
3283 | return create<PropDictDirective>(); |
3284 | } |
3285 | |
3286 | LogicalResult OpFormatParser::verifyCustomDirectiveArguments( |
3287 | SMLoc loc, ArrayRef<FormatElement *> arguments) { |
3288 | for (FormatElement *argument : arguments) { |
3289 | if (!isa<AttrDictDirective, PropDictDirective, AttributeVariable, |
3290 | OperandVariable, PropertyVariable, RefDirective, RegionVariable, |
3291 | SuccessorVariable, StringElement, TypeDirective>(Val: argument)) { |
3292 | // TODO: FormatElement should have location info attached. |
3293 | return emitError(loc, msg: "only variables and types may be used as " |
3294 | "parameters to a custom directive" ); |
3295 | } |
3296 | if (auto *type = dyn_cast<TypeDirective>(Val: argument)) { |
3297 | if (!isa<OperandVariable, ResultVariable>(Val: type->getArg())) { |
3298 | return emitError(loc, msg: "type directives within a custom directive may " |
3299 | "only refer to variables" ); |
3300 | } |
3301 | } |
3302 | } |
3303 | return success(); |
3304 | } |
3305 | |
3306 | FailureOr<FormatElement *> |
3307 | OpFormatParser::parseFunctionalTypeDirective(SMLoc loc, Context context) { |
3308 | if (context != TopLevelContext) |
3309 | return emitError( |
3310 | loc, msg: "'functional-type' is only valid as a top-level directive" ); |
3311 | |
3312 | // Parse the main operand. |
3313 | FailureOr<FormatElement *> inputs, results; |
3314 | if (failed(result: parseToken(kind: FormatToken::l_paren, |
3315 | msg: "expected '(' before argument list" )) || |
3316 | failed(result: inputs = parseTypeDirectiveOperand(loc)) || |
3317 | failed(result: parseToken(kind: FormatToken::comma, |
3318 | msg: "expected ',' after inputs argument" )) || |
3319 | failed(result: results = parseTypeDirectiveOperand(loc)) || |
3320 | failed( |
3321 | result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list" ))) |
3322 | return failure(); |
3323 | return create<FunctionalTypeDirective>(args&: *inputs, args&: *results); |
3324 | } |
3325 | |
3326 | FailureOr<FormatElement *> |
3327 | OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) { |
3328 | if (context == RefDirectiveContext) { |
3329 | if (!fmt.allOperands) |
3330 | return emitError(loc, msg: "'ref' of 'operands' is not bound by a prior " |
3331 | "'operands' directive" ); |
3332 | |
3333 | } else if (context == TopLevelContext || context == CustomDirectiveContext) { |
3334 | if (fmt.allOperands || !seenOperands.empty()) |
3335 | return emitError(loc, msg: "'operands' directive creates overlap in format" ); |
3336 | fmt.allOperands = true; |
3337 | } |
3338 | return create<OperandsDirective>(); |
3339 | } |
3340 | |
3341 | FailureOr<FormatElement *> |
3342 | OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) { |
3343 | if (context != CustomDirectiveContext) |
3344 | return emitError(loc, msg: "'ref' is only valid within a `custom` directive" ); |
3345 | |
3346 | FailureOr<FormatElement *> arg; |
3347 | if (failed(result: parseToken(kind: FormatToken::l_paren, |
3348 | msg: "expected '(' before argument list" )) || |
3349 | failed(result: arg = parseElement(ctx: RefDirectiveContext)) || |
3350 | failed( |
3351 | result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list" ))) |
3352 | return failure(); |
3353 | |
3354 | return create<RefDirective>(args&: *arg); |
3355 | } |
3356 | |
3357 | FailureOr<FormatElement *> |
3358 | OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) { |
3359 | if (context == TypeDirectiveContext) |
3360 | return emitError(loc, msg: "'regions' is only valid as a top-level directive" ); |
3361 | if (context == RefDirectiveContext) { |
3362 | if (!hasAllRegions) |
3363 | return emitError(loc, msg: "'ref' of 'regions' is not bound by a prior " |
3364 | "'regions' directive" ); |
3365 | |
3366 | // Otherwise, this is a TopLevel directive. |
3367 | } else { |
3368 | if (hasAllRegions || !seenRegions.empty()) |
3369 | return emitError(loc, msg: "'regions' directive creates overlap in format" ); |
3370 | hasAllRegions = true; |
3371 | } |
3372 | return create<RegionsDirective>(); |
3373 | } |
3374 | |
3375 | FailureOr<FormatElement *> |
3376 | OpFormatParser::parseResultsDirective(SMLoc loc, Context context) { |
3377 | if (context != TypeDirectiveContext) |
3378 | return emitError(loc, msg: "'results' directive can can only be used as a child " |
3379 | "to a 'type' directive" ); |
3380 | return create<ResultsDirective>(); |
3381 | } |
3382 | |
3383 | FailureOr<FormatElement *> |
3384 | OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) { |
3385 | if (context == TypeDirectiveContext) |
3386 | return emitError(loc, |
3387 | msg: "'successors' is only valid as a top-level directive" ); |
3388 | if (context == RefDirectiveContext) { |
3389 | if (!hasAllSuccessors) |
3390 | return emitError(loc, msg: "'ref' of 'successors' is not bound by a prior " |
3391 | "'successors' directive" ); |
3392 | |
3393 | // Otherwise, this is a TopLevel directive. |
3394 | } else { |
3395 | if (hasAllSuccessors || !seenSuccessors.empty()) |
3396 | return emitError(loc, msg: "'successors' directive creates overlap in format" ); |
3397 | hasAllSuccessors = true; |
3398 | } |
3399 | return create<SuccessorsDirective>(); |
3400 | } |
3401 | |
3402 | FailureOr<FormatElement *> |
3403 | OpFormatParser::parseOIListDirective(SMLoc loc, Context context) { |
3404 | if (failed(result: parseToken(kind: FormatToken::l_paren, |
3405 | msg: "expected '(' before oilist argument list" ))) |
3406 | return failure(); |
3407 | std::vector<FormatElement *> literalElements; |
3408 | std::vector<std::vector<FormatElement *>> parsingElements; |
3409 | do { |
3410 | FailureOr<FormatElement *> lelement = parseLiteral(ctx: context); |
3411 | if (failed(result: lelement)) |
3412 | return failure(); |
3413 | literalElements.push_back(x: *lelement); |
3414 | parsingElements.emplace_back(); |
3415 | std::vector<FormatElement *> &currParsingElements = parsingElements.back(); |
3416 | while (peekToken().getKind() != FormatToken::pipe && |
3417 | peekToken().getKind() != FormatToken::r_paren) { |
3418 | FailureOr<FormatElement *> pelement = parseElement(ctx: context); |
3419 | if (failed(result: pelement) || |
3420 | failed(result: verifyOIListParsingElement(element: *pelement, loc))) |
3421 | return failure(); |
3422 | currParsingElements.push_back(x: *pelement); |
3423 | } |
3424 | if (peekToken().getKind() == FormatToken::pipe) { |
3425 | consumeToken(); |
3426 | continue; |
3427 | } |
3428 | if (peekToken().getKind() == FormatToken::r_paren) { |
3429 | consumeToken(); |
3430 | break; |
3431 | } |
3432 | } while (true); |
3433 | |
3434 | return create<OIListElement>(args: std::move(literalElements), |
3435 | args: std::move(parsingElements)); |
3436 | } |
3437 | |
3438 | LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element, |
3439 | SMLoc loc) { |
3440 | SmallVector<VariableElement *> vars; |
3441 | collect(element, variables&: vars); |
3442 | for (VariableElement *elem : vars) { |
3443 | LogicalResult res = |
3444 | TypeSwitch<FormatElement *, LogicalResult>(elem) |
3445 | // Only optional attributes can be within an oilist parsing group. |
3446 | .Case(caseFn: [&](AttributeVariable *attrEle) { |
3447 | if (!attrEle->getVar()->attr.isOptional() && |
3448 | !attrEle->getVar()->attr.hasDefaultValue()) |
3449 | return emitError(loc, msg: "only optional attributes can be used in " |
3450 | "an oilist parsing group" ); |
3451 | return success(); |
3452 | }) |
3453 | // Only optional-like(i.e. variadic) operands can be within an |
3454 | // oilist parsing group. |
3455 | .Case(caseFn: [&](OperandVariable *ele) { |
3456 | if (!ele->getVar()->isVariableLength()) |
3457 | return emitError(loc, msg: "only variable length operands can be " |
3458 | "used within an oilist parsing group" ); |
3459 | return success(); |
3460 | }) |
3461 | // Only optional-like(i.e. variadic) results can be within an oilist |
3462 | // parsing group. |
3463 | .Case(caseFn: [&](ResultVariable *ele) { |
3464 | if (!ele->getVar()->isVariableLength()) |
3465 | return emitError(loc, msg: "only variable length results can be " |
3466 | "used within an oilist parsing group" ); |
3467 | return success(); |
3468 | }) |
3469 | .Case(caseFn: [&](RegionVariable *) { return success(); }) |
3470 | .Default(defaultFn: [&](FormatElement *) { |
3471 | return emitError(loc, |
3472 | msg: "only literals, types, and variables can be " |
3473 | "used within an oilist group" ); |
3474 | }); |
3475 | if (failed(result: res)) |
3476 | return failure(); |
3477 | } |
3478 | return success(); |
3479 | } |
3480 | |
3481 | FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc, |
3482 | Context context) { |
3483 | if (context == TypeDirectiveContext) |
3484 | return emitError(loc, msg: "'type' cannot be used as a child of another `type`" ); |
3485 | |
3486 | bool isRefChild = context == RefDirectiveContext; |
3487 | FailureOr<FormatElement *> operand; |
3488 | if (failed(result: parseToken(kind: FormatToken::l_paren, |
3489 | msg: "expected '(' before argument list" )) || |
3490 | failed(result: operand = parseTypeDirectiveOperand(loc, isRefChild)) || |
3491 | failed( |
3492 | result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list" ))) |
3493 | return failure(); |
3494 | |
3495 | return create<TypeDirective>(args&: *operand); |
3496 | } |
3497 | |
3498 | FailureOr<FormatElement *> |
3499 | OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) { |
3500 | FailureOr<FormatElement *> element; |
3501 | if (failed(result: parseToken(kind: FormatToken::l_paren, |
3502 | msg: "expected '(' before argument list" )) || |
3503 | failed(result: element = parseElement(ctx: context)) || |
3504 | failed( |
3505 | result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list" ))) |
3506 | return failure(); |
3507 | return TypeSwitch<FormatElement *, FailureOr<FormatElement *>>(*element) |
3508 | .Case<AttributeVariable, TypeDirective>(caseFn: [](auto *element) { |
3509 | element->setShouldBeQualified(); |
3510 | return element; |
3511 | }) |
3512 | .Default(defaultFn: [&](auto *element) { |
3513 | return this->emitError( |
3514 | loc, |
3515 | msg: "'qualified' directive expects an attribute or a `type` directive" ); |
3516 | }); |
3517 | } |
3518 | |
3519 | FailureOr<FormatElement *> |
3520 | OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) { |
3521 | FailureOr<FormatElement *> result = parseElement(ctx: TypeDirectiveContext); |
3522 | if (failed(result)) |
3523 | return failure(); |
3524 | |
3525 | FormatElement *element = *result; |
3526 | if (isa<LiteralElement>(Val: element)) |
3527 | return emitError( |
3528 | loc, msg: "'type' directive operand expects variable or directive operand" ); |
3529 | |
3530 | if (auto *var = dyn_cast<OperandVariable>(Val: element)) { |
3531 | unsigned opIdx = var->getVar() - op.operand_begin(); |
3532 | if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(Idx: opIdx))) |
3533 | return emitError(loc, msg: "'type' of '" + var->getVar()->name + |
3534 | "' is already bound" ); |
3535 | if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(Idx: opIdx))) |
3536 | return emitError(loc, msg: "'ref' of 'type($" + var->getVar()->name + |
3537 | ")' is not bound by a prior 'type' directive" ); |
3538 | seenOperandTypes.set(opIdx); |
3539 | } else if (auto *var = dyn_cast<ResultVariable>(Val: element)) { |
3540 | unsigned resIdx = var->getVar() - op.result_begin(); |
3541 | if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(Idx: resIdx))) |
3542 | return emitError(loc, msg: "'type' of '" + var->getVar()->name + |
3543 | "' is already bound" ); |
3544 | if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(Idx: resIdx))) |
3545 | return emitError(loc, msg: "'ref' of 'type($" + var->getVar()->name + |
3546 | ")' is not bound by a prior 'type' directive" ); |
3547 | seenResultTypes.set(resIdx); |
3548 | } else if (isa<OperandsDirective>(Val: &*element)) { |
3549 | if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any())) |
3550 | return emitError(loc, msg: "'operands' 'type' is already bound" ); |
3551 | if (isRefChild && !fmt.allOperandTypes) |
3552 | return emitError(loc, msg: "'ref' of 'type(operands)' is not bound by a prior " |
3553 | "'type' directive" ); |
3554 | fmt.allOperandTypes = true; |
3555 | } else if (isa<ResultsDirective>(Val: &*element)) { |
3556 | if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any())) |
3557 | return emitError(loc, msg: "'results' 'type' is already bound" ); |
3558 | if (isRefChild && !fmt.allResultTypes) |
3559 | return emitError(loc, msg: "'ref' of 'type(results)' is not bound by a prior " |
3560 | "'type' directive" ); |
3561 | fmt.allResultTypes = true; |
3562 | } else { |
3563 | return emitError(loc, msg: "invalid argument to 'type' directive" ); |
3564 | } |
3565 | return element; |
3566 | } |
3567 | |
3568 | LogicalResult OpFormatParser::verifyOptionalGroupElements( |
3569 | SMLoc loc, ArrayRef<FormatElement *> elements, FormatElement *anchor) { |
3570 | for (FormatElement *element : elements) { |
3571 | if (failed(result: verifyOptionalGroupElement(loc, element, isAnchor: element == anchor))) |
3572 | return failure(); |
3573 | } |
3574 | return success(); |
3575 | } |
3576 | |
3577 | LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc, |
3578 | FormatElement *element, |
3579 | bool isAnchor) { |
3580 | return TypeSwitch<FormatElement *, LogicalResult>(element) |
3581 | // All attributes can be within the optional group, but only optional |
3582 | // attributes can be the anchor. |
3583 | .Case(caseFn: [&](AttributeVariable *attrEle) { |
3584 | Attribute attr = attrEle->getVar()->attr; |
3585 | if (isAnchor && !(attr.isOptional() || attr.hasDefaultValue())) |
3586 | return emitError(loc, msg: "only optional or default-valued attributes " |
3587 | "can be used to anchor an optional group" ); |
3588 | return success(); |
3589 | }) |
3590 | // Only optional-like(i.e. variadic) operands can be within an optional |
3591 | // group. |
3592 | .Case(caseFn: [&](OperandVariable *ele) { |
3593 | if (!ele->getVar()->isVariableLength()) |
3594 | return emitError(loc, msg: "only variable length operands can be used " |
3595 | "within an optional group" ); |
3596 | return success(); |
3597 | }) |
3598 | // Only optional-like(i.e. variadic) results can be within an optional |
3599 | // group. |
3600 | .Case(caseFn: [&](ResultVariable *ele) { |
3601 | if (!ele->getVar()->isVariableLength()) |
3602 | return emitError(loc, msg: "only variable length results can be used " |
3603 | "within an optional group" ); |
3604 | return success(); |
3605 | }) |
3606 | .Case(caseFn: [&](RegionVariable *) { |
3607 | // TODO: When ODS has proper support for marking "optional" regions, add |
3608 | // a check here. |
3609 | return success(); |
3610 | }) |
3611 | .Case(caseFn: [&](TypeDirective *ele) { |
3612 | return verifyOptionalGroupElement(loc, element: ele->getArg(), |
3613 | /*isAnchor=*/false); |
3614 | }) |
3615 | .Case(caseFn: [&](FunctionalTypeDirective *ele) { |
3616 | if (failed(result: verifyOptionalGroupElement(loc, element: ele->getInputs(), |
3617 | /*isAnchor=*/false))) |
3618 | return failure(); |
3619 | return verifyOptionalGroupElement(loc, element: ele->getResults(), |
3620 | /*isAnchor=*/false); |
3621 | }) |
3622 | .Case(caseFn: [&](CustomDirective *ele) { |
3623 | if (!isAnchor) |
3624 | return success(); |
3625 | // Verify each child as being valid in an optional group. They are all |
3626 | // potential anchors if the custom directive was marked as one. |
3627 | for (FormatElement *child : ele->getArguments()) { |
3628 | if (isa<RefDirective>(Val: child)) |
3629 | continue; |
3630 | if (failed(result: verifyOptionalGroupElement(loc, element: child, /*isAnchor=*/true))) |
3631 | return failure(); |
3632 | } |
3633 | return success(); |
3634 | }) |
3635 | // Literals, whitespace, and custom directives may be used, but they can't |
3636 | // anchor the group. |
3637 | .Case<LiteralElement, WhitespaceElement, OptionalElement>( |
3638 | caseFn: [&](FormatElement *) { |
3639 | if (isAnchor) |
3640 | return emitError(loc, msg: "only variables and types can be used " |
3641 | "to anchor an optional group" ); |
3642 | return success(); |
3643 | }) |
3644 | .Default(defaultFn: [&](FormatElement *) { |
3645 | return emitError(loc, msg: "only literals, types, and variables can be " |
3646 | "used within an optional group" ); |
3647 | }); |
3648 | } |
3649 | |
3650 | //===----------------------------------------------------------------------===// |
3651 | // Interface |
3652 | //===----------------------------------------------------------------------===// |
3653 | |
3654 | void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) { |
3655 | // TODO: Operator doesn't expose all necessary functionality via |
3656 | // the const interface. |
3657 | Operator &op = const_cast<Operator &>(constOp); |
3658 | if (!op.hasAssemblyFormat()) |
3659 | return; |
3660 | |
3661 | // Parse the format description. |
3662 | llvm::SourceMgr mgr; |
3663 | mgr.AddNewSourceBuffer( |
3664 | F: llvm::MemoryBuffer::getMemBuffer(InputData: op.getAssemblyFormat()), IncludeLoc: SMLoc()); |
3665 | OperationFormat format(op); |
3666 | OpFormatParser parser(mgr, format, op); |
3667 | FailureOr<std::vector<FormatElement *>> elements = parser.parse(); |
3668 | if (failed(result: elements)) { |
3669 | // Exit the process if format errors are treated as fatal. |
3670 | if (formatErrorIsFatal) { |
3671 | // Invoke the interrupt handlers to run the file cleanup handlers. |
3672 | llvm::sys::RunInterruptHandlers(); |
3673 | std::exit(status: 1); |
3674 | } |
3675 | return; |
3676 | } |
3677 | format.elements = std::move(*elements); |
3678 | |
3679 | // Generate the printer and parser based on the parsed format. |
3680 | format.genParser(op, opClass); |
3681 | format.genPrinter(op, opClass); |
3682 | } |
3683 | |