| 1 | //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "OpFormatGen.h" |
| 10 | #include "FormatGen.h" |
| 11 | #include "OpClass.h" |
| 12 | #include "mlir/Support/LLVM.h" |
| 13 | #include "mlir/TableGen/Class.h" |
| 14 | #include "mlir/TableGen/EnumInfo.h" |
| 15 | #include "mlir/TableGen/Format.h" |
| 16 | #include "mlir/TableGen/Operator.h" |
| 17 | #include "mlir/TableGen/Trait.h" |
| 18 | #include "llvm/ADT/MapVector.h" |
| 19 | #include "llvm/ADT/Sequence.h" |
| 20 | #include "llvm/ADT/SetVector.h" |
| 21 | #include "llvm/ADT/SmallBitVector.h" |
| 22 | #include "llvm/ADT/StringExtras.h" |
| 23 | #include "llvm/ADT/TypeSwitch.h" |
| 24 | #include "llvm/Support/Signals.h" |
| 25 | #include "llvm/Support/SourceMgr.h" |
| 26 | #include "llvm/TableGen/Record.h" |
| 27 | |
| 28 | #define DEBUG_TYPE "mlir-tblgen-opformatgen" |
| 29 | |
| 30 | using namespace mlir; |
| 31 | using namespace mlir::tblgen; |
| 32 | using llvm::formatv; |
| 33 | using llvm::Record; |
| 34 | using llvm::StringMap; |
| 35 | |
| 36 | //===----------------------------------------------------------------------===// |
| 37 | // VariableElement |
| 38 | //===----------------------------------------------------------------------===// |
| 39 | |
| 40 | namespace { |
| 41 | /// This class represents an instance of an op variable element. A variable |
| 42 | /// refers to something registered on the operation itself, e.g. an operand, |
| 43 | /// result, attribute, region, or successor. |
| 44 | template <typename VarT, VariableElement::Kind VariableKind> |
| 45 | class OpVariableElement : public VariableElementBase<VariableKind> { |
| 46 | public: |
| 47 | using Base = OpVariableElement<VarT, VariableKind>; |
| 48 | |
| 49 | /// Create an op variable element with the variable value. |
| 50 | OpVariableElement(const VarT *var) : var(var) {} |
| 51 | |
| 52 | /// Get the variable. |
| 53 | const VarT *getVar() const { return var; } |
| 54 | |
| 55 | protected: |
| 56 | /// The op variable, e.g. a type or attribute constraint. |
| 57 | const VarT *var; |
| 58 | }; |
| 59 | |
| 60 | /// This class represents a variable that refers to an attribute argument. |
| 61 | struct AttributeVariable |
| 62 | : public OpVariableElement<NamedAttribute, VariableElement::Attribute> { |
| 63 | using Base::Base; |
| 64 | |
| 65 | /// Return the constant builder call for the type of this attribute, or |
| 66 | /// std::nullopt if it doesn't have one. |
| 67 | std::optional<StringRef> getTypeBuilder() const { |
| 68 | std::optional<Type> attrType = var->attr.getValueType(); |
| 69 | return attrType ? attrType->getBuilderCall() : std::nullopt; |
| 70 | } |
| 71 | |
| 72 | /// Indicate if this attribute is printed "qualified" (that is it is |
| 73 | /// prefixed with the `#dialect.mnemonic`). |
| 74 | bool shouldBeQualified() { return shouldBeQualifiedFlag; } |
| 75 | void setShouldBeQualified(bool qualified = true) { |
| 76 | shouldBeQualifiedFlag = qualified; |
| 77 | } |
| 78 | |
| 79 | private: |
| 80 | bool shouldBeQualifiedFlag = false; |
| 81 | }; |
| 82 | |
| 83 | /// This class represents a variable that refers to an operand argument. |
| 84 | using OperandVariable = |
| 85 | OpVariableElement<NamedTypeConstraint, VariableElement::Operand>; |
| 86 | |
| 87 | /// This class represents a variable that refers to a result. |
| 88 | using ResultVariable = |
| 89 | OpVariableElement<NamedTypeConstraint, VariableElement::Result>; |
| 90 | |
| 91 | /// This class represents a variable that refers to a region. |
| 92 | using RegionVariable = OpVariableElement<NamedRegion, VariableElement::Region>; |
| 93 | |
| 94 | /// This class represents a variable that refers to a successor. |
| 95 | using SuccessorVariable = |
| 96 | OpVariableElement<NamedSuccessor, VariableElement::Successor>; |
| 97 | |
| 98 | /// This class represents a variable that refers to a property argument. |
| 99 | using PropertyVariable = |
| 100 | OpVariableElement<NamedProperty, VariableElement::Property>; |
| 101 | |
| 102 | /// LLVM RTTI helper for attribute-like variables, that is, attributes or |
| 103 | /// properties. This allows for common handling of attributes and properties in |
| 104 | /// parts of the code that are oblivious to whether something is stored as an |
| 105 | /// attribute or a property. |
| 106 | struct AttributeLikeVariable : public VariableElement { |
| 107 | enum { AttributeLike = 1 << 0 }; |
| 108 | |
| 109 | static bool classof(const VariableElement *ve) { |
| 110 | return ve->getKind() == VariableElement::Attribute || |
| 111 | ve->getKind() == VariableElement::Property; |
| 112 | } |
| 113 | |
| 114 | static bool classof(const FormatElement *fe) { |
| 115 | return isa<VariableElement>(Val: fe) && classof(ve: cast<VariableElement>(Val: fe)); |
| 116 | } |
| 117 | |
| 118 | /// Returns true if the variable is a UnitAttr or a UnitProp. |
| 119 | bool isUnit() const { |
| 120 | if (const auto *attr = dyn_cast<AttributeVariable>(Val: this)) |
| 121 | return attr->getVar()->attr.getBaseAttr().getAttrDefName() == "UnitAttr" ; |
| 122 | if (const auto *prop = dyn_cast<PropertyVariable>(Val: this)) { |
| 123 | StringRef baseDefName = |
| 124 | prop->getVar()->prop.getBaseProperty().getPropertyDefName(); |
| 125 | // Note: remove the `UnitProperty` case once the deprecation period is |
| 126 | // over. |
| 127 | return baseDefName == "UnitProp" || baseDefName == "UnitProperty" ; |
| 128 | } |
| 129 | llvm_unreachable("Type that wasn't listed in classof()" ); |
| 130 | } |
| 131 | |
| 132 | StringRef getName() const { |
| 133 | if (const auto *attr = dyn_cast<AttributeVariable>(Val: this)) |
| 134 | return attr->getVar()->name; |
| 135 | if (const auto *prop = dyn_cast<PropertyVariable>(Val: this)) |
| 136 | return prop->getVar()->name; |
| 137 | llvm_unreachable("Type that wasn't listed in classof()" ); |
| 138 | } |
| 139 | }; |
| 140 | } // namespace |
| 141 | |
| 142 | //===----------------------------------------------------------------------===// |
| 143 | // DirectiveElement |
| 144 | //===----------------------------------------------------------------------===// |
| 145 | |
| 146 | namespace { |
| 147 | /// This class represents the `operands` directive. This directive represents |
| 148 | /// all of the operands of an operation. |
| 149 | using OperandsDirective = DirectiveElementBase<DirectiveElement::Operands>; |
| 150 | |
| 151 | /// This class represents the `results` directive. This directive represents |
| 152 | /// all of the results of an operation. |
| 153 | using ResultsDirective = DirectiveElementBase<DirectiveElement::Results>; |
| 154 | |
| 155 | /// This class represents the `regions` directive. This directive represents |
| 156 | /// all of the regions of an operation. |
| 157 | using RegionsDirective = DirectiveElementBase<DirectiveElement::Regions>; |
| 158 | |
| 159 | /// This class represents the `successors` directive. This directive represents |
| 160 | /// all of the successors of an operation. |
| 161 | using SuccessorsDirective = DirectiveElementBase<DirectiveElement::Successors>; |
| 162 | |
| 163 | /// This class represents the `attr-dict` directive. This directive represents |
| 164 | /// the attribute dictionary of the operation. |
| 165 | class AttrDictDirective |
| 166 | : public DirectiveElementBase<DirectiveElement::AttrDict> { |
| 167 | public: |
| 168 | explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {} |
| 169 | |
| 170 | /// Return whether the dictionary should be printed with the 'attributes' |
| 171 | /// keyword. |
| 172 | bool isWithKeyword() const { return withKeyword; } |
| 173 | |
| 174 | private: |
| 175 | /// If the dictionary should be printed with the 'attributes' keyword. |
| 176 | bool withKeyword; |
| 177 | }; |
| 178 | |
| 179 | /// This class represents the `prop-dict` directive. This directive represents |
| 180 | /// the properties of the operation, expressed as a directionary. |
| 181 | class PropDictDirective |
| 182 | : public DirectiveElementBase<DirectiveElement::PropDict> { |
| 183 | public: |
| 184 | explicit PropDictDirective() = default; |
| 185 | }; |
| 186 | |
| 187 | /// This class represents the `functional-type` directive. This directive takes |
| 188 | /// two arguments and formats them, respectively, as the inputs and results of a |
| 189 | /// FunctionType. |
| 190 | class FunctionalTypeDirective |
| 191 | : public DirectiveElementBase<DirectiveElement::FunctionalType> { |
| 192 | public: |
| 193 | FunctionalTypeDirective(FormatElement *inputs, FormatElement *results) |
| 194 | : inputs(inputs), results(results) {} |
| 195 | |
| 196 | FormatElement *getInputs() const { return inputs; } |
| 197 | FormatElement *getResults() const { return results; } |
| 198 | |
| 199 | private: |
| 200 | /// The input and result arguments. |
| 201 | FormatElement *inputs, *results; |
| 202 | }; |
| 203 | |
| 204 | /// This class represents the `type` directive. |
| 205 | class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> { |
| 206 | public: |
| 207 | TypeDirective(FormatElement *arg) : arg(arg) {} |
| 208 | |
| 209 | FormatElement *getArg() const { return arg; } |
| 210 | |
| 211 | /// Indicate if this type is printed "qualified" (that is it is |
| 212 | /// prefixed with the `!dialect.mnemonic`). |
| 213 | bool shouldBeQualified() { return shouldBeQualifiedFlag; } |
| 214 | void setShouldBeQualified(bool qualified = true) { |
| 215 | shouldBeQualifiedFlag = qualified; |
| 216 | } |
| 217 | |
| 218 | private: |
| 219 | /// The argument that is used to format the directive. |
| 220 | FormatElement *arg; |
| 221 | |
| 222 | bool shouldBeQualifiedFlag = false; |
| 223 | }; |
| 224 | |
| 225 | /// This class represents a group of order-independent optional clauses. Each |
| 226 | /// clause starts with a literal element and has a coressponding parsing |
| 227 | /// element. A parsing element is a continous sequence of format elements. |
| 228 | /// Each clause can appear 0 or 1 time. |
| 229 | class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> { |
| 230 | public: |
| 231 | OIListElement(std::vector<FormatElement *> &&literalElements, |
| 232 | std::vector<std::vector<FormatElement *>> &&parsingElements) |
| 233 | : literalElements(std::move(literalElements)), |
| 234 | parsingElements(std::move(parsingElements)) {} |
| 235 | |
| 236 | /// Returns a range to iterate over the LiteralElements. |
| 237 | auto getLiteralElements() const { |
| 238 | return llvm::map_range(C: literalElements, F: [](FormatElement *el) { |
| 239 | return cast<LiteralElement>(Val: el); |
| 240 | }); |
| 241 | } |
| 242 | |
| 243 | /// Returns a range to iterate over the parsing elements corresponding to the |
| 244 | /// clauses. |
| 245 | ArrayRef<std::vector<FormatElement *>> getParsingElements() const { |
| 246 | return parsingElements; |
| 247 | } |
| 248 | |
| 249 | /// Returns a range to iterate over tuples of parsing and literal elements. |
| 250 | auto getClauses() const { |
| 251 | return llvm::zip(t: getLiteralElements(), u: getParsingElements()); |
| 252 | } |
| 253 | |
| 254 | /// If the parsing element is a single UnitAttr element, then it returns the |
| 255 | /// attribute variable. Otherwise, returns nullptr. |
| 256 | AttributeLikeVariable * |
| 257 | getUnitVariableParsingElement(ArrayRef<FormatElement *> pelement) { |
| 258 | if (pelement.size() == 1) { |
| 259 | auto *attrElem = dyn_cast<AttributeLikeVariable>(Val: pelement[0]); |
| 260 | if (attrElem && attrElem->isUnit()) |
| 261 | return attrElem; |
| 262 | } |
| 263 | return nullptr; |
| 264 | } |
| 265 | |
| 266 | private: |
| 267 | /// A vector of `LiteralElement` objects. Each element stores the keyword |
| 268 | /// for one case of oilist element. For example, an oilist element along with |
| 269 | /// the `literalElements` vector: |
| 270 | /// ``` |
| 271 | /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`] |
| 272 | /// literalElements = { `keyword`, `otherKeyword` } |
| 273 | /// ``` |
| 274 | std::vector<FormatElement *> literalElements; |
| 275 | |
| 276 | /// A vector of valid declarative assembly format vectors. Each object in |
| 277 | /// parsing elements is a vector of elements in assembly format syntax. |
| 278 | /// For example, an oilist element along with the parsingElements vector: |
| 279 | /// ``` |
| 280 | /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`] |
| 281 | /// parsingElements = { |
| 282 | /// { `=`, `(`, $arg0, `)` }, |
| 283 | /// { `<`, $arg1, `>` } |
| 284 | /// } |
| 285 | /// ``` |
| 286 | std::vector<std::vector<FormatElement *>> parsingElements; |
| 287 | }; |
| 288 | } // namespace |
| 289 | |
| 290 | //===----------------------------------------------------------------------===// |
| 291 | // OperationFormat |
| 292 | //===----------------------------------------------------------------------===// |
| 293 | |
| 294 | namespace { |
| 295 | |
| 296 | using ConstArgument = |
| 297 | llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>; |
| 298 | |
| 299 | struct OperationFormat { |
| 300 | /// This class represents a specific resolver for an operand or result type. |
| 301 | class TypeResolution { |
| 302 | public: |
| 303 | TypeResolution() = default; |
| 304 | |
| 305 | /// Get the index into the buildable types for this type, or std::nullopt. |
| 306 | std::optional<int> getBuilderIdx() const { return builderIdx; } |
| 307 | void setBuilderIdx(int idx) { builderIdx = idx; } |
| 308 | |
| 309 | /// Get the variable this type is resolved to, or nullptr. |
| 310 | const NamedTypeConstraint *getVariable() const { |
| 311 | return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(Val: resolver); |
| 312 | } |
| 313 | /// Get the attribute this type is resolved to, or nullptr. |
| 314 | const NamedAttribute *getAttribute() const { |
| 315 | return llvm::dyn_cast_if_present<const NamedAttribute *>(Val: resolver); |
| 316 | } |
| 317 | /// Get the transformer for the type of the variable, or std::nullopt. |
| 318 | std::optional<StringRef> getVarTransformer() const { |
| 319 | return variableTransformer; |
| 320 | } |
| 321 | void setResolver(ConstArgument arg, std::optional<StringRef> transformer) { |
| 322 | resolver = arg; |
| 323 | variableTransformer = transformer; |
| 324 | assert(getVariable() || getAttribute()); |
| 325 | } |
| 326 | |
| 327 | private: |
| 328 | /// If the type is resolved with a buildable type, this is the index into |
| 329 | /// 'buildableTypes' in the parent format. |
| 330 | std::optional<int> builderIdx; |
| 331 | /// If the type is resolved based upon another operand or result, this is |
| 332 | /// the variable or the attribute that this type is resolved to. |
| 333 | ConstArgument resolver; |
| 334 | /// If the type is resolved based upon another operand or result, this is |
| 335 | /// a transformer to apply to the variable when resolving. |
| 336 | std::optional<StringRef> variableTransformer; |
| 337 | }; |
| 338 | |
| 339 | /// The context in which an element is generated. |
| 340 | enum class GenContext { |
| 341 | /// The element is generated at the top-level or with the same behaviour. |
| 342 | Normal, |
| 343 | /// The element is generated inside an optional group. |
| 344 | Optional |
| 345 | }; |
| 346 | |
| 347 | OperationFormat(const Operator &op, bool hasProperties) |
| 348 | : useProperties(hasProperties), opCppClassName(op.getCppClassName()) { |
| 349 | operandTypes.resize(new_size: op.getNumOperands(), x: TypeResolution()); |
| 350 | resultTypes.resize(new_size: op.getNumResults(), x: TypeResolution()); |
| 351 | |
| 352 | hasImplicitTermTrait = llvm::any_of(Range: op.getTraits(), P: [](const Trait &trait) { |
| 353 | return trait.getDef().isSubClassOf(Name: "SingleBlockImplicitTerminatorImpl" ); |
| 354 | }); |
| 355 | |
| 356 | hasSingleBlockTrait = op.getTrait(trait: "::mlir::OpTrait::SingleBlock" ); |
| 357 | } |
| 358 | |
| 359 | /// Generate the operation parser from this format. |
| 360 | void genParser(Operator &op, OpClass &opClass); |
| 361 | /// Generate the parser code for a specific format element. |
| 362 | void genElementParser(FormatElement *element, MethodBody &body, |
| 363 | FmtContext &attrTypeCtx, |
| 364 | GenContext genCtx = GenContext::Normal); |
| 365 | /// Generate the C++ to resolve the types of operands and results during |
| 366 | /// parsing. |
| 367 | void genParserTypeResolution(Operator &op, MethodBody &body); |
| 368 | /// Generate the C++ to resolve the types of the operands during parsing. |
| 369 | void genParserOperandTypeResolution( |
| 370 | Operator &op, MethodBody &body, |
| 371 | function_ref<void(TypeResolution &, StringRef)> emitTypeResolver); |
| 372 | /// Generate the C++ to resolve regions during parsing. |
| 373 | void genParserRegionResolution(Operator &op, MethodBody &body); |
| 374 | /// Generate the C++ to resolve successors during parsing. |
| 375 | void genParserSuccessorResolution(Operator &op, MethodBody &body); |
| 376 | /// Generate the C++ to handling variadic segment size traits. |
| 377 | void genParserVariadicSegmentResolution(Operator &op, MethodBody &body); |
| 378 | |
| 379 | /// Generate the operation printer from this format. |
| 380 | void genPrinter(Operator &op, OpClass &opClass); |
| 381 | |
| 382 | /// Generate the printer code for a specific format element. |
| 383 | void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op, |
| 384 | bool &shouldEmitSpace, bool &lastWasPunctuation); |
| 385 | |
| 386 | /// The various elements in this format. |
| 387 | std::vector<FormatElement *> elements; |
| 388 | |
| 389 | /// A flag indicating if all operand/result types were seen. If the format |
| 390 | /// contains these, it can not contain individual type resolvers. |
| 391 | bool allOperands = false, allOperandTypes = false, allResultTypes = false; |
| 392 | |
| 393 | /// A flag indicating if this operation infers its result types |
| 394 | bool infersResultTypes = false; |
| 395 | |
| 396 | /// A flag indicating if this operation has the SingleBlockImplicitTerminator |
| 397 | /// trait. |
| 398 | bool hasImplicitTermTrait; |
| 399 | |
| 400 | /// A flag indicating if this operation has the SingleBlock trait. |
| 401 | bool hasSingleBlockTrait; |
| 402 | |
| 403 | /// Indicate whether we need to use properties for the current operator. |
| 404 | bool useProperties; |
| 405 | |
| 406 | /// Indicate whether prop-dict is used in the format |
| 407 | bool hasPropDict; |
| 408 | |
| 409 | /// The Operation class name |
| 410 | StringRef opCppClassName; |
| 411 | |
| 412 | /// A map of buildable types to indices. |
| 413 | llvm::MapVector<StringRef, int, StringMap<int>> buildableTypes; |
| 414 | |
| 415 | /// The index of the buildable type, if valid, for every operand and result. |
| 416 | std::vector<TypeResolution> operandTypes, resultTypes; |
| 417 | |
| 418 | /// The set of attributes explicitly used within the format. |
| 419 | llvm::SmallSetVector<const NamedAttribute *, 8> usedAttributes; |
| 420 | llvm::StringSet<> inferredAttributes; |
| 421 | |
| 422 | /// The set of properties explicitly used within the format. |
| 423 | llvm::SmallSetVector<const NamedProperty *, 8> usedProperties; |
| 424 | }; |
| 425 | } // namespace |
| 426 | |
| 427 | //===----------------------------------------------------------------------===// |
| 428 | // Parser Gen |
| 429 | //===----------------------------------------------------------------------===// |
| 430 | |
| 431 | /// Returns true if we can format the given attribute as an enum in the |
| 432 | /// parser format. |
| 433 | static bool canFormatEnumAttr(const NamedAttribute *attr) { |
| 434 | Attribute baseAttr = attr->attr.getBaseAttr(); |
| 435 | if (!baseAttr.isEnumAttr()) |
| 436 | return false; |
| 437 | EnumInfo enumInfo(&baseAttr.getDef()); |
| 438 | |
| 439 | // The attribute must have a valid underlying type and a constant builder. |
| 440 | return !enumInfo.getUnderlyingType().empty() && |
| 441 | !baseAttr.getConstBuilderTemplate().empty(); |
| 442 | } |
| 443 | |
| 444 | /// Returns if we should format the given attribute as an SymbolNameAttr. |
| 445 | static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) { |
| 446 | return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr" ; |
| 447 | } |
| 448 | |
/// Parser snippet for an attribute parsed through
/// `parseCustomAttributeWithFallback`. (`{{` is the formatv escape for a
/// literal `{`.)
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
  if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
    return ::mlir::failure();
  }
)";

/// Parser snippet for an attribute parsed in its generic form through
/// `parseAttribute`.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
  if (parser.parseAttribute({0}Attr, {1}))
    return ::mlir::failure();
)";

/// Parser snippet for an optional attribute, parsed through
/// `parseOptionalAttribute`. The snippet ends with a bare
/// `if (... succeeded(...))` header, so the generated code is only complete
/// once the caller appends the statement guarded by that condition.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const optionalAttrParserCode = R"(
  ::mlir::OptionalParseResult parseResult{0}Attr =
    parser.parseOptionalAttribute({0}Attr, {1});
  if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
    return ::mlir::failure();
  if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
)";
| 475 | |
/// Parser snippet for a required symbol-name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr))
    return ::mlir::failure();
)";

/// Parser snippet for an optional symbol-name attribute. The generated code
/// discards the parse result, as explained by the comment it emits.
///
/// {0}: The name of the attribute.
const char *const optionalSymbolNameAttrParserCode = R"(
  // Parsing an optional symbol name doesn't fail, so no need to check the
  // result.
  (void)parser.parseOptionalSymbolName({0}Attr);
)";
| 488 | |
/// The code snippet used to generate a parser call for an enum attribute.
/// The value is parsed either as a keyword from {4} or, when no such keyword
/// is present, as a string attribute. The resulting text is then symbolized
/// with `{1}::{2}` and turned into an attribute with {3}.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
/// {6}: The attribute assignment expression
const char *const enumAttrParserCode = R"(
  {
    ::llvm::StringRef attrStr;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseOptionalKeyword(&attrStr, {4})) {
      ::mlir::StringAttr attrVal;
      ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalAttribute(attrVal,
                                      parser.getBuilder().getNoneType(),
                                      "{0}", attrStorage);
      if (parseResult.has_value()) {{
        if (failed(*parseResult))
          return ::mlir::failure();
        attrStr = attrVal.getValue();
      } else {
        {5}
      }
    }
    if (!attrStr.empty()) {
      auto attrOptional = {1}::{2}(attrStr);
      if (!attrOptional)
        return parser.emitError(loc, "invalid ")
               << "{0} attribute specification: \"" << attrStr << '"';

      {0}Attr = {3};
      {6}
    }
  }
)";
| 528 | |
/// Parser snippet for a required property. The property's own parser code
/// ({2}) is wrapped in a lambda that receives the operation's property storage
/// as `propStorage`. On failure, an error is emitted at the location where
/// parsing started.
///
/// {0}: The name of the property
/// {1}: The C++ class name of the operation
/// {2}: The property's parser code with appropriate substitutions performed
/// {3}: The description of the expected property for the error message.
const char *const propertyParserCode = R"(
  auto {0}PropLoc = parser.getCurrentLocation();
  auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::ParseResult {{
    {2}
    return ::mlir::success();
  }(result.getOrAddProperties<{1}::Properties>().{0});
  if (failed({0}PropParseResult)) {{
    return parser.emitError({0}PropLoc, "invalid value for property {0}, expected {3}");
  }
)";

/// Parser snippet for an optional property. It has the same lambda wrapping
/// as the required form, but yields an OptionalParseResult and fails only when
/// a value was present and could not be parsed; no error is emitted here.
///
/// {0}: The name of the property
/// {1}: The C++ class name of the operation
/// {2}: The property's parser code with appropriate substitutions performed
const char *const optionalPropertyParserCode = R"(
  auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::OptionalParseResult {{
    {2}
    return ::mlir::success();
  }(result.getOrAddProperties<{1}::Properties>().{0});
  if ({0}PropParseResult.has_value() && failed(*{0}PropParseResult)) {{
    return ::mlir::failure();
  }
)";
| 558 | |
/// Parser snippets for operands. Each one records the current location in
/// `{0}OperandsLoc` before parsing.
///
/// Variadic operand: a whole operand list into `{0}Operands`.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";
/// Optional operand: at most one operand, appended to `{0}Operands` if present.
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::UnresolvedOperand operand;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalOperand(operand);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";
/// Single operand: exactly one operand into `{0}RawOperand`.
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperand))
    return ::mlir::failure();
)";
/// VariadicOfVariadic operand: a comma-separated sequence of parenthesized
/// operand lists. The size of each group is appended to
/// `{0}OperandGroupSizes`.
///
/// {0}: The name of the operand.
/// {1}: The name of segment size attribute. (Not referenced by the snippet.)
const char *const variadicOfVariadicOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    int32_t curSize = 0;
    do {
      if (parser.parseOptionalLParen())
        break;
      if (parser.parseOperandList({0}Operands) || parser.parseRParen())
        return ::mlir::failure();
      {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
      curSize = {0}Operands.size();
    } while (succeeded(parser.parseOptionalComma()));
  }
)";
| 604 | |
/// Parser snippets for types.
///
/// VariadicOfVariadic type list: a comma-separated sequence of parenthesized
/// (possibly empty) type lists, all appended to `{0}Types`.
///
/// {0}: The name for the type list.
const char *const variadicOfVariadicTypeParserCode = R"(
  do {
    if (parser.parseOptionalLParen())
      break;
    if (parser.parseOptionalRParen() &&
        (parser.parseTypeList({0}Types) || parser.parseRParen()))
      return ::mlir::failure();
  } while (succeeded(parser.parseOptionalComma()));
)";
/// Variadic type list into `{0}Types`.
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";
/// Optional type: at most one type, appended to `{0}Types` if present.
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalType(optionalType);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";
/// Single type parsed with `parseCustomTypeWithFallback`.
///
/// {0}: The C++ type of the local the type is parsed into.
/// {1}: The name whose `{1}RawType` receives the parsed type.
const char *const typeParserCode = R"(
  {
    {0} type;
    if (parser.parseCustomTypeWithFallback(type))
      return ::mlir::failure();
    {1}RawType = type;
  }
)";
/// Single type parsed in its qualified form with `parseType`. Only {1} is
/// referenced.
const char *const qualifiedTypeParserCode = R"(
  if (parser.parseType({1}RawType))
    return ::mlir::failure();
)";
| 645 | |
/// Parser snippet for a functional type. A single FunctionType is parsed, and
/// its inputs and results are split into `{0}Types` and `{1}Types`.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";

/// Parser snippet that calls the operation's `inferReturnTypes` with the state
/// parsed so far and adds the inferred types to the result.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
  ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
      result.location, result.operands,
      result.attributes.getDictionary(parser.getContext()),
      result.getRawProperties(),
      result.regions, inferredReturnTypes)))
    return ::mlir::failure();
  result.addTypes(inferredReturnTypes);
)";
| 671 | |
// The region snippets below are declared `const char *const`, like every other
// snippet in this file. With a non-const pointer they would be reassignable
// globals with external linkage.

/// The code snippet used to generate a parser call for a region list: an
/// optional first region followed by any number of comma-separated regions.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.has_value()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a list of regions have a block.
///
/// {0}: The name of the region list.
const char *const regionListEnsureSingleBlockParserCode = R"(
  for (auto &region : {0}Regions)
    if (region->empty()) region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
  {
     auto parseResult = parser.parseOptionalRegion(*{0}Region);
     if (parseResult.has_value() && failed(*parseResult))
       return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block.
///
/// {0}: The name of the region.
const char *const regionEnsureSingleBlockParserCode = R"(
  if ({0}Region->empty()) {0}Region->emplaceBlock();
)";
| 743 | |
/// The code snippet used to generate a parser call for a successor list.
/// The generated code tries to parse an optional first successor. If one is
/// present, it keeps parsing successors for as long as a comma follows. Every
/// parsed block is appended to `{0}Successors`, and any parse failure makes
/// the generated parser return failure.
///
/// {0}: The name for the successor list.
const char *successorListParserCode = R"(
{
::mlir::Block *succ;
auto firstSucc = parser.parseOptionalSuccessor(succ);
if (firstSucc.has_value()) {
if (failed(*firstSucc))
return ::mlir::failure();
{0}Successors.emplace_back(succ);

// Parse any trailing successors.
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseSuccessor(succ))
return ::mlir::failure();
{0}Successors.emplace_back(succ);
}
}
}
)" ;

/// The code snippet used to generate a parser call for a successor.
/// Parses a single successor into `{0}Successor`; a failure reported by
/// `parser.parseSuccessor` is propagated as failure.
///
/// {0}: The name of the successor.
const char *successorParserCode = R"(
if (parser.parseSuccessor({0}Successor))
return ::mlir::failure();
)" ;
| 773 | |
/// The code snippet used to generate a parser for OIList
///
/// The generated code rejects a second occurrence of the clause introduced by
/// the keyword `{0}`: it emits an error at the operation name location if
/// `{0}Clause` is already set, and otherwise sets it to true. The `{0}Clause`
/// flag is not declared here; its declaration is emitted elsewhere.
///
/// {0}: literal keyword corresponding to a case for oilist
const char *oilistParserCode = R"(
if ({0}Clause) {
return parser.emitError(parser.getNameLoc())
<< "`{0}` clause can appear at most once in the expansion of the "
"oilist directive";
}
{0}Clause = true;
)" ;
| 785 | |
namespace {
/// Describes how many values a parsed operand, result or type argument may
/// carry. Enumerators are listed from the most general to the most specific.
enum class ArgumentLengthKind {
  VariadicOfVariadic, ///< 0->N groups, each holding 0->N elements.
  Variadic,           ///< 0->N elements.
  Optional,           ///< Either 0 or 1 element.
  Single              ///< Exactly one element.
};
} // namespace
| 800 | |
| 801 | /// Get the length kind for the given constraint. |
| 802 | static ArgumentLengthKind |
| 803 | getArgumentLengthKind(const NamedTypeConstraint *var) { |
| 804 | if (var->isOptional()) |
| 805 | return ArgumentLengthKind::Optional; |
| 806 | if (var->isVariadicOfVariadic()) |
| 807 | return ArgumentLengthKind::VariadicOfVariadic; |
| 808 | if (var->isVariadic()) |
| 809 | return ArgumentLengthKind::Variadic; |
| 810 | return ArgumentLengthKind::Single; |
| 811 | } |
| 812 | |
| 813 | /// Get the name used for the type list for the given type directive operand. |
| 814 | /// 'lengthKind' to the corresponding kind for the given argument. |
| 815 | static StringRef getTypeListName(FormatElement *arg, |
| 816 | ArgumentLengthKind &lengthKind) { |
| 817 | if (auto *operand = dyn_cast<OperandVariable>(Val: arg)) { |
| 818 | lengthKind = getArgumentLengthKind(var: operand->getVar()); |
| 819 | return operand->getVar()->name; |
| 820 | } |
| 821 | if (auto *result = dyn_cast<ResultVariable>(Val: arg)) { |
| 822 | lengthKind = getArgumentLengthKind(var: result->getVar()); |
| 823 | return result->getVar()->name; |
| 824 | } |
| 825 | lengthKind = ArgumentLengthKind::Variadic; |
| 826 | if (isa<OperandsDirective>(Val: arg)) |
| 827 | return "allOperand" ; |
| 828 | if (isa<ResultsDirective>(Val: arg)) |
| 829 | return "allResult" ; |
| 830 | llvm_unreachable("unknown 'type' directive argument" ); |
| 831 | } |
| 832 | |
| 833 | /// Generate the parser for a literal value. |
| 834 | static void genLiteralParser(StringRef value, MethodBody &body) { |
| 835 | // Handle the case of a keyword/identifier. |
| 836 | if (value.front() == '_' || isalpha(value.front())) { |
| 837 | body << "Keyword(\"" << value << "\")" ; |
| 838 | return; |
| 839 | } |
| 840 | body << (StringRef)StringSwitch<StringRef>(value) |
| 841 | .Case(S: "->" , Value: "Arrow()" ) |
| 842 | .Case(S: ":" , Value: "Colon()" ) |
| 843 | .Case(S: "," , Value: "Comma()" ) |
| 844 | .Case(S: "=" , Value: "Equal()" ) |
| 845 | .Case(S: "<" , Value: "Less()" ) |
| 846 | .Case(S: ">" , Value: "Greater()" ) |
| 847 | .Case(S: "{" , Value: "LBrace()" ) |
| 848 | .Case(S: "}" , Value: "RBrace()" ) |
| 849 | .Case(S: "(" , Value: "LParen()" ) |
| 850 | .Case(S: ")" , Value: "RParen()" ) |
| 851 | .Case(S: "[" , Value: "LSquare()" ) |
| 852 | .Case(S: "]" , Value: "RSquare()" ) |
| 853 | .Case(S: "?" , Value: "Question()" ) |
| 854 | .Case(S: "+" , Value: "Plus()" ) |
| 855 | .Case(S: "*" , Value: "Star()" ) |
| 856 | .Case(S: "..." , Value: "Ellipsis()" ); |
| 857 | } |
| 858 | |
| 859 | /// Generate the storage code required for parsing the given element. |
| 860 | static void genElementParserStorage(FormatElement *element, const Operator &op, |
| 861 | MethodBody &body) { |
| 862 | if (auto *optional = dyn_cast<OptionalElement>(Val: element)) { |
| 863 | ArrayRef<FormatElement *> elements = optional->getThenElements(); |
| 864 | |
| 865 | // If the anchor is a unit attribute, it won't be parsed directly so elide |
| 866 | // it. |
| 867 | auto *anchor = dyn_cast<AttributeLikeVariable>(Val: optional->getAnchor()); |
| 868 | FormatElement *elidedAnchorElement = nullptr; |
| 869 | if (anchor && anchor != elements.front() && anchor->isUnit()) |
| 870 | elidedAnchorElement = anchor; |
| 871 | for (FormatElement *childElement : elements) |
| 872 | if (childElement != elidedAnchorElement) |
| 873 | genElementParserStorage(element: childElement, op, body); |
| 874 | for (FormatElement *childElement : optional->getElseElements()) |
| 875 | genElementParserStorage(element: childElement, op, body); |
| 876 | |
| 877 | } else if (auto *oilist = dyn_cast<OIListElement>(Val: element)) { |
| 878 | for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) { |
| 879 | if (!oilist->getUnitVariableParsingElement(pelement)) |
| 880 | for (FormatElement *element : pelement) |
| 881 | genElementParserStorage(element, op, body); |
| 882 | } |
| 883 | |
| 884 | } else if (auto *custom = dyn_cast<CustomDirective>(Val: element)) { |
| 885 | for (FormatElement *paramElement : custom->getElements()) |
| 886 | genElementParserStorage(element: paramElement, op, body); |
| 887 | |
| 888 | } else if (isa<OperandsDirective>(Val: element)) { |
| 889 | body << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> " |
| 890 | "allOperands;\n" ; |
| 891 | |
| 892 | } else if (isa<RegionsDirective>(Val: element)) { |
| 893 | body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> " |
| 894 | "fullRegions;\n" ; |
| 895 | |
| 896 | } else if (isa<SuccessorsDirective>(Val: element)) { |
| 897 | body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n" ; |
| 898 | |
| 899 | } else if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) { |
| 900 | const NamedAttribute *var = attr->getVar(); |
| 901 | body << formatv(Fmt: " {0} {1}Attr;\n" , Vals: var->attr.getStorageType(), Vals: var->name); |
| 902 | |
| 903 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) { |
| 904 | StringRef name = operand->getVar()->name; |
| 905 | if (operand->getVar()->isVariableLength()) { |
| 906 | body |
| 907 | << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> " |
| 908 | << name << "Operands;\n" ; |
| 909 | if (operand->getVar()->isVariadicOfVariadic()) { |
| 910 | body << " llvm::SmallVector<int32_t> " << name |
| 911 | << "OperandGroupSizes;\n" ; |
| 912 | } |
| 913 | } else { |
| 914 | body << " ::mlir::OpAsmParser::UnresolvedOperand " << name |
| 915 | << "RawOperand{};\n" |
| 916 | << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> " |
| 917 | << name << "Operands(&" << name << "RawOperand, 1);" ; |
| 918 | } |
| 919 | body << formatv(Fmt: " ::llvm::SMLoc {0}OperandsLoc;\n" |
| 920 | " (void){0}OperandsLoc;\n" , |
| 921 | Vals&: name); |
| 922 | |
| 923 | } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) { |
| 924 | StringRef name = region->getVar()->name; |
| 925 | if (region->getVar()->isVariadic()) { |
| 926 | body << formatv( |
| 927 | Fmt: " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> " |
| 928 | "{0}Regions;\n" , |
| 929 | Vals&: name); |
| 930 | } else { |
| 931 | body << formatv(Fmt: " std::unique_ptr<::mlir::Region> {0}Region = " |
| 932 | "std::make_unique<::mlir::Region>();\n" , |
| 933 | Vals&: name); |
| 934 | } |
| 935 | |
| 936 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) { |
| 937 | StringRef name = successor->getVar()->name; |
| 938 | if (successor->getVar()->isVariadic()) { |
| 939 | body << formatv(Fmt: " ::llvm::SmallVector<::mlir::Block *, 2> " |
| 940 | "{0}Successors;\n" , |
| 941 | Vals&: name); |
| 942 | } else { |
| 943 | body << formatv(Fmt: " ::mlir::Block *{0}Successor = nullptr;\n" , Vals&: name); |
| 944 | } |
| 945 | |
| 946 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) { |
| 947 | ArgumentLengthKind lengthKind; |
| 948 | StringRef name = getTypeListName(arg: dir->getArg(), lengthKind); |
| 949 | if (lengthKind != ArgumentLengthKind::Single) |
| 950 | body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n" ; |
| 951 | else |
| 952 | body |
| 953 | << formatv(Fmt: " ::mlir::Type {0}RawType{{};\n" , Vals&: name) |
| 954 | << formatv( |
| 955 | Fmt: " ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n" , |
| 956 | Vals&: name); |
| 957 | } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(Val: element)) { |
| 958 | ArgumentLengthKind ignored; |
| 959 | body << " ::llvm::ArrayRef<::mlir::Type> " |
| 960 | << getTypeListName(arg: dir->getInputs(), lengthKind&: ignored) << "Types;\n" ; |
| 961 | body << " ::llvm::ArrayRef<::mlir::Type> " |
| 962 | << getTypeListName(arg: dir->getResults(), lengthKind&: ignored) << "Types;\n" ; |
| 963 | } |
| 964 | } |
| 965 | |
| 966 | /// Generate the parser for a parameter to a custom directive. |
| 967 | static void genCustomParameterParser(FormatElement *param, MethodBody &body) { |
| 968 | if (auto *attr = dyn_cast<AttributeVariable>(Val: param)) { |
| 969 | body << attr->getVar()->name << "Attr" ; |
| 970 | } else if (isa<AttrDictDirective>(Val: param)) { |
| 971 | body << "result.attributes" ; |
| 972 | } else if (isa<PropDictDirective>(Val: param)) { |
| 973 | body << "result" ; |
| 974 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: param)) { |
| 975 | StringRef name = operand->getVar()->name; |
| 976 | ArgumentLengthKind lengthKind = getArgumentLengthKind(var: operand->getVar()); |
| 977 | if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) |
| 978 | body << formatv(Fmt: "{0}OperandGroups" , Vals&: name); |
| 979 | else if (lengthKind == ArgumentLengthKind::Variadic) |
| 980 | body << formatv(Fmt: "{0}Operands" , Vals&: name); |
| 981 | else if (lengthKind == ArgumentLengthKind::Optional) |
| 982 | body << formatv(Fmt: "{0}Operand" , Vals&: name); |
| 983 | else |
| 984 | body << formatv(Fmt: "{0}RawOperand" , Vals&: name); |
| 985 | |
| 986 | } else if (auto *region = dyn_cast<RegionVariable>(Val: param)) { |
| 987 | StringRef name = region->getVar()->name; |
| 988 | if (region->getVar()->isVariadic()) |
| 989 | body << formatv(Fmt: "{0}Regions" , Vals&: name); |
| 990 | else |
| 991 | body << formatv(Fmt: "*{0}Region" , Vals&: name); |
| 992 | |
| 993 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: param)) { |
| 994 | StringRef name = successor->getVar()->name; |
| 995 | if (successor->getVar()->isVariadic()) |
| 996 | body << formatv(Fmt: "{0}Successors" , Vals&: name); |
| 997 | else |
| 998 | body << formatv(Fmt: "{0}Successor" , Vals&: name); |
| 999 | |
| 1000 | } else if (auto *dir = dyn_cast<RefDirective>(Val: param)) { |
| 1001 | genCustomParameterParser(param: dir->getArg(), body); |
| 1002 | |
| 1003 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: param)) { |
| 1004 | ArgumentLengthKind lengthKind; |
| 1005 | StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind); |
| 1006 | if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) |
| 1007 | body << formatv(Fmt: "{0}TypeGroups" , Vals&: listName); |
| 1008 | else if (lengthKind == ArgumentLengthKind::Variadic) |
| 1009 | body << formatv(Fmt: "{0}Types" , Vals&: listName); |
| 1010 | else if (lengthKind == ArgumentLengthKind::Optional) |
| 1011 | body << formatv(Fmt: "{0}Type" , Vals&: listName); |
| 1012 | else |
| 1013 | body << formatv(Fmt: "{0}RawType" , Vals&: listName); |
| 1014 | |
| 1015 | } else if (auto *string = dyn_cast<StringElement>(Val: param)) { |
| 1016 | FmtContext ctx; |
| 1017 | ctx.withBuilder(subst: "parser.getBuilder()" ); |
| 1018 | ctx.addSubst(placeholder: "_ctxt" , subst: "parser.getContext()" ); |
| 1019 | body << tgfmt(fmt: string->getValue(), ctx: &ctx); |
| 1020 | |
| 1021 | } else if (auto *property = dyn_cast<PropertyVariable>(Val: param)) { |
| 1022 | body << formatv(Fmt: "result.getOrAddProperties<Properties>().{0}" , |
| 1023 | Vals: property->getVar()->name); |
| 1024 | } else { |
| 1025 | llvm_unreachable("unknown custom directive parameter" ); |
| 1026 | } |
| 1027 | } |
| 1028 | |
| 1029 | /// Generate the parser for a custom directive. |
| 1030 | static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, |
| 1031 | bool useProperties, |
| 1032 | StringRef opCppClassName, |
| 1033 | bool isOptional = false) { |
| 1034 | body << " {\n" ; |
| 1035 | |
| 1036 | // Preprocess the directive variables. |
| 1037 | // * Add a local variable for optional operands and types. This provides a |
| 1038 | // better API to the user defined parser methods. |
| 1039 | // * Set the location of operand variables. |
| 1040 | for (FormatElement *param : dir->getElements()) { |
| 1041 | if (auto *operand = dyn_cast<OperandVariable>(Val: param)) { |
| 1042 | auto *var = operand->getVar(); |
| 1043 | body << " " << var->name |
| 1044 | << "OperandsLoc = parser.getCurrentLocation();\n" ; |
| 1045 | if (var->isOptional()) { |
| 1046 | body << formatv( |
| 1047 | Fmt: " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> " |
| 1048 | "{0}Operand;\n" , |
| 1049 | Vals: var->name); |
| 1050 | } else if (var->isVariadicOfVariadic()) { |
| 1051 | body << formatv(Fmt: " " |
| 1052 | "::llvm::SmallVector<::llvm::SmallVector<::mlir::" |
| 1053 | "OpAsmParser::UnresolvedOperand>> " |
| 1054 | "{0}OperandGroups;\n" , |
| 1055 | Vals: var->name); |
| 1056 | } |
| 1057 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: param)) { |
| 1058 | ArgumentLengthKind lengthKind; |
| 1059 | StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind); |
| 1060 | if (lengthKind == ArgumentLengthKind::Optional) { |
| 1061 | body << formatv(Fmt: " ::mlir::Type {0}Type;\n" , Vals&: listName); |
| 1062 | } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { |
| 1063 | body << formatv( |
| 1064 | Fmt: " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> " |
| 1065 | "{0}TypeGroups;\n" , |
| 1066 | Vals&: listName); |
| 1067 | } |
| 1068 | } else if (auto *dir = dyn_cast<RefDirective>(Val: param)) { |
| 1069 | FormatElement *input = dir->getArg(); |
| 1070 | if (auto *operand = dyn_cast<OperandVariable>(Val: input)) { |
| 1071 | if (!operand->getVar()->isOptional()) |
| 1072 | continue; |
| 1073 | body << formatv( |
| 1074 | Fmt: " {0} {1}Operand = {1}Operands.empty() ? {0}() : " |
| 1075 | "{1}Operands[0];\n" , |
| 1076 | Vals: "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>" , |
| 1077 | Vals: operand->getVar()->name); |
| 1078 | |
| 1079 | } else if (auto *type = dyn_cast<TypeDirective>(Val: input)) { |
| 1080 | ArgumentLengthKind lengthKind; |
| 1081 | StringRef listName = getTypeListName(arg: type->getArg(), lengthKind); |
| 1082 | if (lengthKind == ArgumentLengthKind::Optional) { |
| 1083 | body << formatv(Fmt: " ::mlir::Type {0}Type = {0}Types.empty() ? " |
| 1084 | "::mlir::Type() : {0}Types[0];\n" , |
| 1085 | Vals&: listName); |
| 1086 | } |
| 1087 | } |
| 1088 | } |
| 1089 | } |
| 1090 | |
| 1091 | body << " auto odsResult = parse" << dir->getName() << "(parser" ; |
| 1092 | for (FormatElement *param : dir->getElements()) { |
| 1093 | body << ", " ; |
| 1094 | genCustomParameterParser(param, body); |
| 1095 | } |
| 1096 | body << ");\n" ; |
| 1097 | |
| 1098 | if (isOptional) { |
| 1099 | body << " if (!odsResult.has_value()) return {};\n" |
| 1100 | << " if (::mlir::failed(*odsResult)) return ::mlir::failure();\n" ; |
| 1101 | } else { |
| 1102 | body << " if (odsResult) return ::mlir::failure();\n" ; |
| 1103 | } |
| 1104 | |
| 1105 | // After parsing, add handling for any of the optional constructs. |
| 1106 | for (FormatElement *param : dir->getElements()) { |
| 1107 | if (auto *attr = dyn_cast<AttributeVariable>(Val: param)) { |
| 1108 | const NamedAttribute *var = attr->getVar(); |
| 1109 | if (var->attr.isOptional() || var->attr.hasDefaultValue()) |
| 1110 | body << formatv(Fmt: " if ({0}Attr)\n " , Vals: var->name); |
| 1111 | if (useProperties) { |
| 1112 | body << formatv( |
| 1113 | Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n" , |
| 1114 | Vals: var->name, Vals&: opCppClassName); |
| 1115 | } else { |
| 1116 | body << formatv(Fmt: " result.addAttribute(\"{0}\", {0}Attr);\n" , |
| 1117 | Vals: var->name); |
| 1118 | } |
| 1119 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: param)) { |
| 1120 | const NamedTypeConstraint *var = operand->getVar(); |
| 1121 | if (var->isOptional()) { |
| 1122 | body << formatv(Fmt: " if ({0}Operand.has_value())\n" |
| 1123 | " {0}Operands.push_back(*{0}Operand);\n" , |
| 1124 | Vals: var->name); |
| 1125 | } else if (var->isVariadicOfVariadic()) { |
| 1126 | body << formatv( |
| 1127 | Fmt: " for (const auto &subRange : {0}OperandGroups) {{\n" |
| 1128 | " {0}Operands.append(subRange.begin(), subRange.end());\n" |
| 1129 | " {0}OperandGroupSizes.push_back(subRange.size());\n" |
| 1130 | " }\n" , |
| 1131 | Vals: var->name); |
| 1132 | } |
| 1133 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: param)) { |
| 1134 | ArgumentLengthKind lengthKind; |
| 1135 | StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind); |
| 1136 | if (lengthKind == ArgumentLengthKind::Optional) { |
| 1137 | body << formatv(Fmt: " if ({0}Type)\n" |
| 1138 | " {0}Types.push_back({0}Type);\n" , |
| 1139 | Vals&: listName); |
| 1140 | } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { |
| 1141 | body << formatv( |
| 1142 | Fmt: " for (const auto &subRange : {0}TypeGroups)\n" |
| 1143 | " {0}Types.append(subRange.begin(), subRange.end());\n" , |
| 1144 | Vals&: listName); |
| 1145 | } |
| 1146 | } |
| 1147 | } |
| 1148 | |
| 1149 | body << " }\n" ; |
| 1150 | } |
| 1151 | |
| 1152 | /// Generate the parser for a enum attribute. |
| 1153 | static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body, |
| 1154 | FmtContext &attrTypeCtx, bool parseAsOptional, |
| 1155 | bool useProperties, StringRef opCppClassName) { |
| 1156 | Attribute baseAttr = var->attr.getBaseAttr(); |
| 1157 | EnumInfo enumInfo(&baseAttr.getDef()); |
| 1158 | std::vector<EnumCase> cases = enumInfo.getAllCases(); |
| 1159 | |
| 1160 | // Generate the code for building an attribute for this enum. |
| 1161 | std::string attrBuilderStr; |
| 1162 | { |
| 1163 | llvm::raw_string_ostream os(attrBuilderStr); |
| 1164 | os << tgfmt(fmt: baseAttr.getConstBuilderTemplate(), ctx: &attrTypeCtx, |
| 1165 | vals: "*attrOptional" ); |
| 1166 | } |
| 1167 | |
| 1168 | // Build a string containing the cases that can be formatted as a keyword. |
| 1169 | std::string validCaseKeywordsStr = "{" ; |
| 1170 | llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr); |
| 1171 | for (const EnumCase &attrCase : cases) |
| 1172 | if (canFormatStringAsKeyword(value: attrCase.getStr())) |
| 1173 | validCaseKeywordsOS << '"' << attrCase.getStr() << "\"," ; |
| 1174 | validCaseKeywordsOS.str().back() = '}'; |
| 1175 | |
| 1176 | // If the attribute is not optional, build an error message for the missing |
| 1177 | // attribute. |
| 1178 | std::string errorMessage; |
| 1179 | if (!parseAsOptional) { |
| 1180 | llvm::raw_string_ostream errorMessageOS(errorMessage); |
| 1181 | errorMessageOS |
| 1182 | << "return parser.emitError(loc, \"expected string or " |
| 1183 | "keyword containing one of the following enum values for attribute '" |
| 1184 | << var->name << "' [" ; |
| 1185 | llvm::interleaveComma(c: cases, os&: errorMessageOS, each_fn: [&](const auto &attrCase) { |
| 1186 | errorMessageOS << attrCase.getStr(); |
| 1187 | }); |
| 1188 | errorMessageOS << "]\");" ; |
| 1189 | } |
| 1190 | std::string attrAssignment; |
| 1191 | if (useProperties) { |
| 1192 | attrAssignment = |
| 1193 | formatv(Fmt: " " |
| 1194 | "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;" , |
| 1195 | Vals: var->name, Vals&: opCppClassName); |
| 1196 | } else { |
| 1197 | attrAssignment = |
| 1198 | formatv(Fmt: "result.addAttribute(\"{0}\", {0}Attr);" , Vals: var->name); |
| 1199 | } |
| 1200 | |
| 1201 | body << formatv(Fmt: enumAttrParserCode, Vals: var->name, Vals: enumInfo.getCppNamespace(), |
| 1202 | Vals: enumInfo.getStringToSymbolFnName(), Vals&: attrBuilderStr, |
| 1203 | Vals&: validCaseKeywordsStr, Vals&: errorMessage, Vals&: attrAssignment); |
| 1204 | } |
| 1205 | |
| 1206 | // Generate the parser for a property. |
| 1207 | static void genPropertyParser(PropertyVariable *propVar, MethodBody &body, |
| 1208 | StringRef opCppClassName, |
| 1209 | bool requireParse = true) { |
| 1210 | StringRef name = propVar->getVar()->name; |
| 1211 | const Property &prop = propVar->getVar()->prop; |
| 1212 | bool parseOptionally = |
| 1213 | prop.hasDefaultValue() && !requireParse && prop.hasOptionalParser(); |
| 1214 | FmtContext fmtContext; |
| 1215 | fmtContext.addSubst(placeholder: "_parser" , subst: "parser" ); |
| 1216 | fmtContext.addSubst(placeholder: "_ctxt" , subst: "parser.getContext()" ); |
| 1217 | fmtContext.addSubst(placeholder: "_storage" , subst: "propStorage" ); |
| 1218 | |
| 1219 | if (parseOptionally) { |
| 1220 | body << formatv(Fmt: optionalPropertyParserCode, Vals&: name, Vals&: opCppClassName, |
| 1221 | Vals: tgfmt(fmt: prop.getOptionalParserCall(), ctx: &fmtContext)); |
| 1222 | } else { |
| 1223 | body << formatv(Fmt: propertyParserCode, Vals&: name, Vals&: opCppClassName, |
| 1224 | Vals: tgfmt(fmt: prop.getParserCall(), ctx: &fmtContext), |
| 1225 | Vals: prop.getSummary()); |
| 1226 | } |
| 1227 | } |
| 1228 | |
| 1229 | // Generate the parser for an attribute. |
| 1230 | static void genAttrParser(AttributeVariable *attr, MethodBody &body, |
| 1231 | FmtContext &attrTypeCtx, bool parseAsOptional, |
| 1232 | bool useProperties, StringRef opCppClassName) { |
| 1233 | const NamedAttribute *var = attr->getVar(); |
| 1234 | |
| 1235 | // Check to see if we can parse this as an enum attribute. |
| 1236 | if (canFormatEnumAttr(attr: var)) |
| 1237 | return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional, |
| 1238 | useProperties, opCppClassName); |
| 1239 | |
| 1240 | // Check to see if we should parse this as a symbol name attribute. |
| 1241 | if (shouldFormatSymbolNameAttr(attr: var)) { |
| 1242 | body << formatv(Fmt: parseAsOptional ? optionalSymbolNameAttrParserCode |
| 1243 | : symbolNameAttrParserCode, |
| 1244 | Vals: var->name); |
| 1245 | } else { |
| 1246 | |
| 1247 | // If this attribute has a buildable type, use that when parsing the |
| 1248 | // attribute. |
| 1249 | std::string attrTypeStr; |
| 1250 | if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) { |
| 1251 | llvm::raw_string_ostream os(attrTypeStr); |
| 1252 | os << tgfmt(fmt: *typeBuilder, ctx: &attrTypeCtx); |
| 1253 | } else { |
| 1254 | attrTypeStr = "::mlir::Type{}" ; |
| 1255 | } |
| 1256 | if (parseAsOptional) { |
| 1257 | body << formatv(Fmt: optionalAttrParserCode, Vals: var->name, Vals&: attrTypeStr); |
| 1258 | } else { |
| 1259 | if (attr->shouldBeQualified() || |
| 1260 | var->attr.getStorageType() == "::mlir::Attribute" ) |
| 1261 | body << formatv(Fmt: genericAttrParserCode, Vals: var->name, Vals&: attrTypeStr); |
| 1262 | else |
| 1263 | body << formatv(Fmt: attrParserCode, Vals: var->name, Vals&: attrTypeStr); |
| 1264 | } |
| 1265 | } |
| 1266 | if (useProperties) { |
| 1267 | body << formatv( |
| 1268 | Fmt: " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = " |
| 1269 | "{0}Attr;\n" , |
| 1270 | Vals: var->name, Vals&: opCppClassName); |
| 1271 | } else { |
| 1272 | body << formatv( |
| 1273 | Fmt: " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n" , |
| 1274 | Vals: var->name); |
| 1275 | } |
| 1276 | } |
| 1277 | |
| 1278 | // Generates the 'setPropertiesFromParsedAttr' used to set properties from a |
| 1279 | // 'prop-dict' dictionary attr. |
| 1280 | static void genParsedAttrPropertiesSetter(OperationFormat &fmt, Operator &op, |
| 1281 | OpClass &opClass) { |
| 1282 | // Not required unless 'prop-dict' is present or we are not using properties. |
| 1283 | if (!fmt.hasPropDict || !fmt.useProperties) |
| 1284 | return; |
| 1285 | |
| 1286 | SmallVector<MethodParameter> paramList; |
| 1287 | paramList.emplace_back(Args: "Properties &" , Args: "prop" ); |
| 1288 | paramList.emplace_back(Args: "::mlir::Attribute" , Args: "attr" ); |
| 1289 | paramList.emplace_back(Args: "::llvm::function_ref<::mlir::InFlightDiagnostic()>" , |
| 1290 | Args: "emitError" ); |
| 1291 | |
| 1292 | Method *method = opClass.addStaticMethod(retType: "::llvm::LogicalResult" , |
| 1293 | name: "setPropertiesFromParsedAttr" , |
| 1294 | args: std::move(paramList)); |
| 1295 | MethodBody &body = method->body().indent(); |
| 1296 | |
| 1297 | body << R"decl( |
| 1298 | ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); |
| 1299 | if (!dict) { |
| 1300 | emitError() << "expected DictionaryAttr to set properties"; |
| 1301 | return ::mlir::failure(); |
| 1302 | } |
| 1303 | // keep track of used keys in the input dictionary to be able to error out |
| 1304 | // if there are some unknown ones. |
| 1305 | DenseSet<StringAttr> usedKeys; |
| 1306 | MLIRContext *ctx = dict.getContext(); |
| 1307 | (void)ctx; |
| 1308 | )decl" ; |
| 1309 | |
| 1310 | // {0}: fromAttribute call |
| 1311 | // {1}: property name |
| 1312 | // {2}: isRequired |
| 1313 | const char *propFromAttrFmt = R"decl( |
| 1314 | auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr, |
| 1315 | ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{ |
| 1316 | {0}; |
| 1317 | }; |
| 1318 | auto {1}AttrName = StringAttr::get(ctx, "{1}"); |
| 1319 | usedKeys.insert({1}AttrName); |
| 1320 | auto attr = dict.get({1}AttrName); |
| 1321 | if (!attr && {2}) {{ |
| 1322 | emitError() << "expected key entry for {1} in DictionaryAttr to set " |
| 1323 | "Properties."; |
| 1324 | return ::mlir::failure(); |
| 1325 | } |
| 1326 | if (attr && ::mlir::failed(setFromAttr(prop.{1}, attr, emitError))) |
| 1327 | return ::mlir::failure(); |
| 1328 | )decl" ; |
| 1329 | |
| 1330 | // Generate the setter for any property not parsed elsewhere. |
| 1331 | for (const NamedProperty &namedProperty : op.getProperties()) { |
| 1332 | if (fmt.usedProperties.contains(key: &namedProperty)) |
| 1333 | continue; |
| 1334 | |
| 1335 | auto scope = body.scope(open: "{\n" , close: "}\n" , /*indent=*/true); |
| 1336 | |
| 1337 | StringRef name = namedProperty.name; |
| 1338 | const Property &prop = namedProperty.prop; |
| 1339 | bool isRequired = !prop.hasDefaultValue(); |
| 1340 | FmtContext fctx; |
| 1341 | body << formatv(Fmt: propFromAttrFmt, |
| 1342 | Vals: tgfmt(fmt: prop.getConvertFromAttributeCall(), |
| 1343 | ctx: &fctx.addSubst(placeholder: "_attr" , subst: "propAttr" ) |
| 1344 | .addSubst(placeholder: "_storage" , subst: "propStorage" ) |
| 1345 | .addSubst(placeholder: "_diag" , subst: "emitError" )), |
| 1346 | Vals&: name, Vals&: isRequired); |
| 1347 | } |
| 1348 | |
| 1349 | // Generate the setter for any attribute not parsed elsewhere. |
| 1350 | for (const NamedAttribute &namedAttr : op.getAttributes()) { |
| 1351 | if (fmt.usedAttributes.contains(key: &namedAttr)) |
| 1352 | continue; |
| 1353 | |
| 1354 | const Attribute &attr = namedAttr.attr; |
| 1355 | // Derived attributes do not need to be parsed. |
| 1356 | if (attr.isDerivedAttr()) |
| 1357 | continue; |
| 1358 | |
| 1359 | auto scope = body.scope(open: "{\n" , close: "}\n" , /*indent=*/true); |
| 1360 | |
| 1361 | // If the attribute has a default value or is optional, it does not need to |
| 1362 | // be present in the parsed dictionary attribute. |
| 1363 | bool isRequired = !attr.isOptional() && !attr.hasDefaultValue(); |
| 1364 | body << formatv(Fmt: R"decl( |
| 1365 | auto &propStorage = prop.{0}; |
| 1366 | auto {0}AttrName = StringAttr::get(ctx, "{0}"); |
| 1367 | auto attr = dict.get({0}AttrName); |
| 1368 | usedKeys.insert(StringAttr::get(ctx, "{1}")); |
| 1369 | if (attr || /*isRequired=*/{1}) {{ |
| 1370 | if (!attr) {{ |
| 1371 | emitError() << "expected key entry for {0} in DictionaryAttr to set " |
| 1372 | "Properties."; |
| 1373 | return ::mlir::failure(); |
| 1374 | } |
| 1375 | auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr); |
| 1376 | if (convertedAttr) {{ |
| 1377 | propStorage = convertedAttr; |
| 1378 | } else {{ |
| 1379 | emitError() << "Invalid attribute `{0}` in property conversion: " << attr; |
| 1380 | return ::mlir::failure(); |
| 1381 | } |
| 1382 | } |
| 1383 | )decl" , |
| 1384 | Vals: namedAttr.name, Vals&: isRequired); |
| 1385 | } |
| 1386 | body << R"decl( |
| 1387 | for (NamedAttribute attr : dict) { |
| 1388 | if (!usedKeys.contains(attr.getName())) |
| 1389 | return emitError() << "unknown key '" << attr.getName() << |
| 1390 | "' when parsing properties dictionary"; |
| 1391 | } |
| 1392 | return ::mlir::success(); |
| 1393 | )decl" ; |
| 1394 | } |
| 1395 | |
| 1396 | void OperationFormat::genParser(Operator &op, OpClass &opClass) { |
| 1397 | SmallVector<MethodParameter> paramList; |
| 1398 | paramList.emplace_back(Args: "::mlir::OpAsmParser &" , Args: "parser" ); |
| 1399 | paramList.emplace_back(Args: "::mlir::OperationState &" , Args: "result" ); |
| 1400 | |
| 1401 | auto *method = opClass.addStaticMethod(retType: "::mlir::ParseResult" , name: "parse" , |
| 1402 | args: std::move(paramList)); |
| 1403 | auto &body = method->body(); |
| 1404 | |
| 1405 | // Generate variables to store the operands and type within the format. This |
| 1406 | // allows for referencing these variables in the presence of optional |
| 1407 | // groupings. |
| 1408 | for (FormatElement *element : elements) |
| 1409 | genElementParserStorage(element, op, body); |
| 1410 | |
| 1411 | // A format context used when parsing attributes with buildable types. |
| 1412 | FmtContext attrTypeCtx; |
| 1413 | attrTypeCtx.withBuilder(subst: "parser.getBuilder()" ); |
| 1414 | |
| 1415 | // Generate parsers for each of the elements. |
| 1416 | for (FormatElement *element : elements) |
| 1417 | genElementParser(element, body, attrTypeCtx); |
| 1418 | |
| 1419 | // Generate the code to resolve the operand/result types and successors now |
| 1420 | // that they have been parsed. |
| 1421 | genParserRegionResolution(op, body); |
| 1422 | genParserSuccessorResolution(op, body); |
| 1423 | genParserVariadicSegmentResolution(op, body); |
| 1424 | genParserTypeResolution(op, body); |
| 1425 | |
| 1426 | body << " return ::mlir::success();\n" ; |
| 1427 | |
| 1428 | genParsedAttrPropertiesSetter(fmt&: *this, op, opClass); |
| 1429 | } |
| 1430 | |
| 1431 | void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, |
| 1432 | FmtContext &attrTypeCtx, |
| 1433 | GenContext genCtx) { |
| 1434 | /// Optional Group. |
| 1435 | if (auto *optional = dyn_cast<OptionalElement>(Val: element)) { |
| 1436 | auto genElementParsers = [&](FormatElement *firstElement, |
| 1437 | ArrayRef<FormatElement *> elements, |
| 1438 | bool thenGroup) { |
| 1439 | // If the anchor is a unit attribute, we don't need to print it. When |
| 1440 | // parsing, we will add this attribute if this group is present. |
| 1441 | FormatElement *elidedAnchorElement = nullptr; |
| 1442 | auto *anchorVar = dyn_cast<AttributeLikeVariable>(Val: optional->getAnchor()); |
| 1443 | if (anchorVar && anchorVar != firstElement && anchorVar->isUnit()) { |
| 1444 | elidedAnchorElement = anchorVar; |
| 1445 | |
| 1446 | if (!thenGroup == optional->isInverted()) { |
| 1447 | // Add the anchor unit attribute or property to the operation state |
| 1448 | // or set the property to true. |
| 1449 | if (isa<PropertyVariable>(Val: anchorVar)) { |
| 1450 | body << formatv( |
| 1451 | Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = true;" , |
| 1452 | Vals: anchorVar->getName(), Vals&: opCppClassName); |
| 1453 | } else if (useProperties) { |
| 1454 | body << formatv( |
| 1455 | Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = " |
| 1456 | "parser.getBuilder().getUnitAttr();" , |
| 1457 | Vals: anchorVar->getName(), Vals&: opCppClassName); |
| 1458 | } else { |
| 1459 | body << " result.addAttribute(\"" << anchorVar->getName() |
| 1460 | << "\", parser.getBuilder().getUnitAttr());\n" ; |
| 1461 | } |
| 1462 | } |
| 1463 | } |
| 1464 | |
| 1465 | // Generate the rest of the elements inside an optional group. Elements in |
| 1466 | // an optional group after the guard are parsed as required. |
| 1467 | for (FormatElement *childElement : elements) |
| 1468 | if (childElement != elidedAnchorElement) |
| 1469 | genElementParser(element: childElement, body, attrTypeCtx, |
| 1470 | genCtx: GenContext::Optional); |
| 1471 | }; |
| 1472 | |
| 1473 | ArrayRef<FormatElement *> thenElements = |
| 1474 | optional->getThenElements(/*parseable=*/true); |
| 1475 | |
| 1476 | // Generate a special optional parser for the first element to gate the |
| 1477 | // parsing of the rest of the elements. |
| 1478 | FormatElement *firstElement = thenElements.front(); |
| 1479 | if (auto *attrVar = dyn_cast<AttributeVariable>(Val: firstElement)) { |
| 1480 | genAttrParser(attr: attrVar, body, attrTypeCtx, /*parseAsOptional=*/true, |
| 1481 | useProperties, opCppClassName); |
| 1482 | body << " if (" << attrVar->getVar()->name << "Attr) {\n" ; |
| 1483 | } else if (auto *propVar = dyn_cast<PropertyVariable>(Val: firstElement)) { |
| 1484 | genPropertyParser(propVar, body, opCppClassName, /*requireParse=*/false); |
| 1485 | body << formatv(Fmt: "if ({0}PropParseResult.has_value() && " |
| 1486 | "succeeded(*{0}PropParseResult)) " , |
| 1487 | Vals: propVar->getVar()->name) |
| 1488 | << " {\n" ; |
| 1489 | } else if (auto *literal = dyn_cast<LiteralElement>(Val: firstElement)) { |
| 1490 | body << " if (::mlir::succeeded(parser.parseOptional" ; |
| 1491 | genLiteralParser(value: literal->getSpelling(), body); |
| 1492 | body << ")) {\n" ; |
| 1493 | } else if (auto *opVar = dyn_cast<OperandVariable>(Val: firstElement)) { |
| 1494 | genElementParser(element: opVar, body, attrTypeCtx); |
| 1495 | body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n" ; |
| 1496 | } else if (auto *regionVar = dyn_cast<RegionVariable>(Val: firstElement)) { |
| 1497 | const NamedRegion *region = regionVar->getVar(); |
| 1498 | if (region->isVariadic()) { |
| 1499 | genElementParser(element: regionVar, body, attrTypeCtx); |
| 1500 | body << " if (!" << region->name << "Regions.empty()) {\n" ; |
| 1501 | } else { |
| 1502 | body << formatv(Fmt: optionalRegionParserCode, Vals: region->name); |
| 1503 | body << " if (!" << region->name << "Region->empty()) {\n " ; |
| 1504 | if (hasImplicitTermTrait) |
| 1505 | body << formatv(Fmt: regionEnsureTerminatorParserCode, Vals: region->name); |
| 1506 | else if (hasSingleBlockTrait) |
| 1507 | body << formatv(Fmt: regionEnsureSingleBlockParserCode, Vals: region->name); |
| 1508 | } |
| 1509 | } else if (auto *custom = dyn_cast<CustomDirective>(Val: firstElement)) { |
| 1510 | body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n" ; |
| 1511 | genCustomDirectiveParser(dir: custom, body, useProperties, opCppClassName, |
| 1512 | /*isOptional=*/true); |
| 1513 | body << " return ::mlir::success();\n" |
| 1514 | << " }(); optResult.has_value() && ::mlir::failed(*optResult)) {\n" |
| 1515 | << " return ::mlir::failure();\n" |
| 1516 | << " } else if (optResult.has_value()) {\n" ; |
| 1517 | } |
| 1518 | |
| 1519 | genElementParsers(firstElement, thenElements.drop_front(), |
| 1520 | /*thenGroup=*/true); |
| 1521 | body << " }" ; |
| 1522 | |
| 1523 | // Generate the else elements. |
| 1524 | auto elseElements = optional->getElseElements(); |
| 1525 | if (!elseElements.empty()) { |
| 1526 | body << " else {\n" ; |
| 1527 | ArrayRef<FormatElement *> elseElements = |
| 1528 | optional->getElseElements(/*parseable=*/true); |
| 1529 | genElementParsers(elseElements.front(), elseElements, |
| 1530 | /*thenGroup=*/false); |
| 1531 | body << " }" ; |
| 1532 | } |
| 1533 | body << "\n" ; |
| 1534 | |
| 1535 | /// OIList Directive |
| 1536 | } else if (OIListElement *oilist = dyn_cast<OIListElement>(Val: element)) { |
| 1537 | for (LiteralElement *le : oilist->getLiteralElements()) |
| 1538 | body << " bool " << le->getSpelling() << "Clause = false;\n" ; |
| 1539 | |
| 1540 | // Generate the parsing loop |
| 1541 | body << " while(true) {\n" ; |
| 1542 | for (auto clause : oilist->getClauses()) { |
| 1543 | LiteralElement *lelement = std::get<0>(t&: clause); |
| 1544 | ArrayRef<FormatElement *> pelement = std::get<1>(t&: clause); |
| 1545 | body << "if (succeeded(parser.parseOptional" ; |
| 1546 | genLiteralParser(value: lelement->getSpelling(), body); |
| 1547 | body << ")) {\n" ; |
| 1548 | StringRef lelementName = lelement->getSpelling(); |
| 1549 | body << formatv(Fmt: oilistParserCode, Vals&: lelementName); |
| 1550 | if (AttributeLikeVariable *unitVarElem = |
| 1551 | oilist->getUnitVariableParsingElement(pelement)) { |
| 1552 | if (isa<PropertyVariable>(Val: unitVarElem)) { |
| 1553 | body << formatv( |
| 1554 | Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = true;" , |
| 1555 | Vals: unitVarElem->getName(), Vals&: opCppClassName); |
| 1556 | } else if (useProperties) { |
| 1557 | body << formatv( |
| 1558 | Fmt: " result.getOrAddProperties<{1}::Properties>().{0} = " |
| 1559 | "parser.getBuilder().getUnitAttr();" , |
| 1560 | Vals: unitVarElem->getName(), Vals&: opCppClassName); |
| 1561 | } else { |
| 1562 | body << " result.addAttribute(\"" << unitVarElem->getName() |
| 1563 | << "\", UnitAttr::get(parser.getContext()));\n" ; |
| 1564 | } |
| 1565 | } else { |
| 1566 | for (FormatElement *el : pelement) |
| 1567 | genElementParser(element: el, body, attrTypeCtx); |
| 1568 | } |
| 1569 | body << " } else " ; |
| 1570 | } |
| 1571 | body << " {\n" ; |
| 1572 | body << " break;\n" ; |
| 1573 | body << " }\n" ; |
| 1574 | body << "}\n" ; |
| 1575 | |
| 1576 | /// Literals. |
| 1577 | } else if (LiteralElement *literal = dyn_cast<LiteralElement>(Val: element)) { |
| 1578 | body << " if (parser.parse" ; |
| 1579 | genLiteralParser(value: literal->getSpelling(), body); |
| 1580 | body << ")\n return ::mlir::failure();\n" ; |
| 1581 | |
| 1582 | /// Whitespaces. |
| 1583 | } else if (isa<WhitespaceElement>(Val: element)) { |
| 1584 | // Nothing to parse. |
| 1585 | |
| 1586 | /// Arguments. |
| 1587 | } else if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) { |
| 1588 | bool parseAsOptional = |
| 1589 | (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional()); |
| 1590 | genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties, |
| 1591 | opCppClassName); |
| 1592 | } else if (auto *prop = dyn_cast<PropertyVariable>(Val: element)) { |
| 1593 | genPropertyParser(propVar: prop, body, opCppClassName); |
| 1594 | |
| 1595 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) { |
| 1596 | ArgumentLengthKind lengthKind = getArgumentLengthKind(var: operand->getVar()); |
| 1597 | StringRef name = operand->getVar()->name; |
| 1598 | if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) |
| 1599 | body << formatv(Fmt: variadicOfVariadicOperandParserCode, Vals&: name); |
| 1600 | else if (lengthKind == ArgumentLengthKind::Variadic) |
| 1601 | body << formatv(Fmt: variadicOperandParserCode, Vals&: name); |
| 1602 | else if (lengthKind == ArgumentLengthKind::Optional) |
| 1603 | body << formatv(Fmt: optionalOperandParserCode, Vals&: name); |
| 1604 | else |
| 1605 | body << formatv(Fmt: operandParserCode, Vals&: name); |
| 1606 | |
| 1607 | } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) { |
| 1608 | bool isVariadic = region->getVar()->isVariadic(); |
| 1609 | body << formatv(Fmt: isVariadic ? regionListParserCode : regionParserCode, |
| 1610 | Vals: region->getVar()->name); |
| 1611 | if (hasImplicitTermTrait) |
| 1612 | body << formatv(Fmt: isVariadic ? regionListEnsureTerminatorParserCode |
| 1613 | : regionEnsureTerminatorParserCode, |
| 1614 | Vals: region->getVar()->name); |
| 1615 | else if (hasSingleBlockTrait) |
| 1616 | body << formatv(Fmt: isVariadic ? regionListEnsureSingleBlockParserCode |
| 1617 | : regionEnsureSingleBlockParserCode, |
| 1618 | Vals: region->getVar()->name); |
| 1619 | |
| 1620 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) { |
| 1621 | bool isVariadic = successor->getVar()->isVariadic(); |
| 1622 | body << formatv(Fmt: isVariadic ? successorListParserCode : successorParserCode, |
| 1623 | Vals: successor->getVar()->name); |
| 1624 | |
| 1625 | /// Directives. |
| 1626 | } else if (auto *attrDict = dyn_cast<AttrDictDirective>(Val: element)) { |
| 1627 | body.indent() << "{\n" ; |
| 1628 | body.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n" |
| 1629 | << "if (parser.parseOptionalAttrDict" |
| 1630 | << (attrDict->isWithKeyword() ? "WithKeyword" : "" ) |
| 1631 | << "(result.attributes))\n" |
| 1632 | << " return ::mlir::failure();\n" ; |
| 1633 | if (useProperties) { |
| 1634 | body << "if (failed(verifyInherentAttrs(result.name, result.attributes, " |
| 1635 | "[&]() {\n" |
| 1636 | << " return parser.emitError(loc) << \"'\" << " |
| 1637 | "result.name.getStringRef() << \"' op \";\n" |
| 1638 | << " })))\n" |
| 1639 | << " return ::mlir::failure();\n" ; |
| 1640 | } |
| 1641 | body.unindent() << "}\n" ; |
| 1642 | body.unindent(); |
| 1643 | } else if (isa<PropDictDirective>(Val: element)) { |
| 1644 | if (useProperties) { |
| 1645 | body << " if (parseProperties(parser, result))\n" |
| 1646 | << " return ::mlir::failure();\n" ; |
| 1647 | } |
| 1648 | } else if (auto *customDir = dyn_cast<CustomDirective>(Val: element)) { |
| 1649 | genCustomDirectiveParser(dir: customDir, body, useProperties, opCppClassName); |
| 1650 | } else if (isa<OperandsDirective>(Val: element)) { |
| 1651 | body << " [[maybe_unused]] ::llvm::SMLoc allOperandLoc =" |
| 1652 | << " parser.getCurrentLocation();\n" |
| 1653 | << " if (parser.parseOperandList(allOperands))\n" |
| 1654 | << " return ::mlir::failure();\n" ; |
| 1655 | |
| 1656 | } else if (isa<RegionsDirective>(Val: element)) { |
| 1657 | body << formatv(Fmt: regionListParserCode, Vals: "full" ); |
| 1658 | if (hasImplicitTermTrait) |
| 1659 | body << formatv(Fmt: regionListEnsureTerminatorParserCode, Vals: "full" ); |
| 1660 | else if (hasSingleBlockTrait) |
| 1661 | body << formatv(Fmt: regionListEnsureSingleBlockParserCode, Vals: "full" ); |
| 1662 | |
| 1663 | } else if (isa<SuccessorsDirective>(Val: element)) { |
| 1664 | body << formatv(Fmt: successorListParserCode, Vals: "full" ); |
| 1665 | |
| 1666 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) { |
| 1667 | ArgumentLengthKind lengthKind; |
| 1668 | StringRef listName = getTypeListName(arg: dir->getArg(), lengthKind); |
| 1669 | if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { |
| 1670 | body << formatv(Fmt: variadicOfVariadicTypeParserCode, Vals&: listName); |
| 1671 | } else if (lengthKind == ArgumentLengthKind::Variadic) { |
| 1672 | body << formatv(Fmt: variadicTypeParserCode, Vals&: listName); |
| 1673 | } else if (lengthKind == ArgumentLengthKind::Optional) { |
| 1674 | body << formatv(Fmt: optionalTypeParserCode, Vals&: listName); |
| 1675 | } else { |
| 1676 | const char *parserCode = |
| 1677 | dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode; |
| 1678 | TypeSwitch<FormatElement *>(dir->getArg()) |
| 1679 | .Case<OperandVariable, ResultVariable>(caseFn: [&](auto operand) { |
| 1680 | body << formatv(false, parserCode, |
| 1681 | operand->getVar()->constraint.getCppType(), |
| 1682 | listName); |
| 1683 | }) |
| 1684 | .Default(defaultFn: [&](auto operand) { |
| 1685 | body << formatv(Validate: false, Fmt: parserCode, Vals: "::mlir::Type" , Vals&: listName); |
| 1686 | }); |
| 1687 | } |
| 1688 | } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(Val: element)) { |
| 1689 | ArgumentLengthKind ignored; |
| 1690 | body << formatv(Fmt: functionalTypeParserCode, |
| 1691 | Vals: getTypeListName(arg: dir->getInputs(), lengthKind&: ignored), |
| 1692 | Vals: getTypeListName(arg: dir->getResults(), lengthKind&: ignored)); |
| 1693 | } else { |
| 1694 | llvm_unreachable("unknown format element" ); |
| 1695 | } |
| 1696 | } |
| 1697 | |
| 1698 | void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) { |
| 1699 | // If any of type resolutions use transformed variables, make sure that the |
| 1700 | // types of those variables are resolved. |
| 1701 | SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables; |
| 1702 | FmtContext verifierFCtx; |
| 1703 | for (TypeResolution &resolver : |
| 1704 | llvm::concat<TypeResolution>(Ranges&: resultTypes, Ranges&: operandTypes)) { |
| 1705 | std::optional<StringRef> transformer = resolver.getVarTransformer(); |
| 1706 | if (!transformer) |
| 1707 | continue; |
| 1708 | // Ensure that we don't verify the same variables twice. |
| 1709 | const NamedTypeConstraint *variable = resolver.getVariable(); |
| 1710 | if (!variable || !verifiedVariables.insert(Ptr: variable).second) |
| 1711 | continue; |
| 1712 | |
| 1713 | auto constraint = variable->constraint; |
| 1714 | body << " for (::mlir::Type type : " << variable->name << "Types) {\n" |
| 1715 | << " (void)type;\n" |
| 1716 | << " if (!(" |
| 1717 | << tgfmt(fmt: constraint.getConditionTemplate(), |
| 1718 | ctx: &verifierFCtx.withSelf(subst: "type" )) |
| 1719 | << ")) {\n" |
| 1720 | << formatv(Fmt: " return parser.emitError(parser.getNameLoc()) << " |
| 1721 | "\"'{0}' must be {1}, but got \" << type;\n" , |
| 1722 | Vals: variable->name, Vals: constraint.getSummary()) |
| 1723 | << " }\n" |
| 1724 | << " }\n" ; |
| 1725 | } |
| 1726 | |
| 1727 | // Initialize the set of buildable types. |
| 1728 | if (!buildableTypes.empty()) { |
| 1729 | FmtContext typeBuilderCtx; |
| 1730 | typeBuilderCtx.withBuilder(subst: "parser.getBuilder()" ); |
| 1731 | for (auto &it : buildableTypes) |
| 1732 | body << " ::mlir::Type odsBuildableType" << it.second << " = " |
| 1733 | << tgfmt(fmt: it.first, ctx: &typeBuilderCtx) << ";\n" ; |
| 1734 | } |
| 1735 | |
| 1736 | // Emit the code necessary for a type resolver. |
| 1737 | auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) { |
| 1738 | if (std::optional<int> val = resolver.getBuilderIdx()) { |
| 1739 | body << "odsBuildableType" << *val; |
| 1740 | } else if (const NamedTypeConstraint *var = resolver.getVariable()) { |
| 1741 | if (std::optional<StringRef> tform = resolver.getVarTransformer()) { |
| 1742 | FmtContext fmtContext; |
| 1743 | fmtContext.addSubst(placeholder: "_ctxt" , subst: "parser.getContext()" ); |
| 1744 | if (var->isVariadic()) |
| 1745 | fmtContext.withSelf(subst: var->name + "Types" ); |
| 1746 | else |
| 1747 | fmtContext.withSelf(subst: var->name + "Types[0]" ); |
| 1748 | body << tgfmt(fmt: *tform, ctx: &fmtContext); |
| 1749 | } else { |
| 1750 | body << var->name << "Types" ; |
| 1751 | if (!var->isVariadic()) |
| 1752 | body << "[0]" ; |
| 1753 | } |
| 1754 | } else if (const NamedAttribute *attr = resolver.getAttribute()) { |
| 1755 | if (std::optional<StringRef> tform = resolver.getVarTransformer()) |
| 1756 | body << tgfmt(fmt: *tform, |
| 1757 | ctx: &FmtContext().withSelf(subst: attr->name + "Attr.getType()" )); |
| 1758 | else |
| 1759 | body << attr->name << "Attr.getType()" ; |
| 1760 | } else { |
| 1761 | body << curVar << "Types" ; |
| 1762 | } |
| 1763 | }; |
| 1764 | |
| 1765 | // Resolve each of the result types. |
| 1766 | if (!infersResultTypes) { |
| 1767 | if (allResultTypes) { |
| 1768 | body << " result.addTypes(allResultTypes);\n" ; |
| 1769 | } else { |
| 1770 | for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { |
| 1771 | body << " result.addTypes(" ; |
| 1772 | emitTypeResolver(resultTypes[i], op.getResultName(index: i)); |
| 1773 | body << ");\n" ; |
| 1774 | } |
| 1775 | } |
| 1776 | } |
| 1777 | |
| 1778 | // Emit the operand type resolutions. |
| 1779 | genParserOperandTypeResolution(op, body, emitTypeResolver); |
| 1780 | |
| 1781 | // Handle return type inference once all operands have been resolved |
| 1782 | if (infersResultTypes) |
| 1783 | body << formatv(Fmt: inferReturnTypesParserCode, Vals: op.getCppClassName()); |
| 1784 | } |
| 1785 | |
| 1786 | void OperationFormat::genParserOperandTypeResolution( |
| 1787 | Operator &op, MethodBody &body, |
| 1788 | function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) { |
| 1789 | // Early exit if there are no operands. |
| 1790 | if (op.getNumOperands() == 0) |
| 1791 | return; |
| 1792 | |
| 1793 | // Handle the case where all operand types are grouped together with |
| 1794 | // "types(operands)". |
| 1795 | if (allOperandTypes) { |
| 1796 | // If `operands` was specified, use the full operand list directly. |
| 1797 | if (allOperands) { |
| 1798 | body << " if (parser.resolveOperands(allOperands, allOperandTypes, " |
| 1799 | "allOperandLoc, result.operands))\n" |
| 1800 | " return ::mlir::failure();\n" ; |
| 1801 | return; |
| 1802 | } |
| 1803 | |
| 1804 | // Otherwise, use llvm::concat to merge the disjoint operand lists together. |
| 1805 | // llvm::concat does not allow the case of a single range, so guard it here. |
| 1806 | body << " if (parser.resolveOperands(" ; |
| 1807 | if (op.getNumOperands() > 1) { |
| 1808 | body << "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(" ; |
| 1809 | llvm::interleaveComma(c: op.getOperands(), os&: body, each_fn: [&](auto &operand) { |
| 1810 | body << operand.name << "Operands" ; |
| 1811 | }); |
| 1812 | body << ")" ; |
| 1813 | } else { |
| 1814 | body << op.operand_begin()->name << "Operands" ; |
| 1815 | } |
| 1816 | body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n" |
| 1817 | << " return ::mlir::failure();\n" ; |
| 1818 | return; |
| 1819 | } |
| 1820 | |
| 1821 | // Handle the case where all operands are grouped together with "operands". |
| 1822 | if (allOperands) { |
| 1823 | body << " if (parser.resolveOperands(allOperands, " ; |
| 1824 | |
| 1825 | // Group all of the operand types together to perform the resolution all at |
| 1826 | // once. Use llvm::concat to perform the merge. llvm::concat does not allow |
| 1827 | // the case of a single range, so guard it here. |
| 1828 | if (op.getNumOperands() > 1) { |
| 1829 | body << "::llvm::concat<const ::mlir::Type>(" ; |
| 1830 | llvm::interleaveComma( |
| 1831 | c: llvm::seq<int>(Begin: 0, End: op.getNumOperands()), os&: body, each_fn: [&](int i) { |
| 1832 | body << "::llvm::ArrayRef<::mlir::Type>(" ; |
| 1833 | emitTypeResolver(operandTypes[i], op.getOperand(index: i).name); |
| 1834 | body << ")" ; |
| 1835 | }); |
| 1836 | body << ")" ; |
| 1837 | } else { |
| 1838 | emitTypeResolver(operandTypes.front(), op.getOperand(index: 0).name); |
| 1839 | } |
| 1840 | |
| 1841 | body << ", allOperandLoc, result.operands))\n return " |
| 1842 | "::mlir::failure();\n" ; |
| 1843 | return; |
| 1844 | } |
| 1845 | |
| 1846 | // The final case is the one where each of the operands types are resolved |
| 1847 | // separately. |
| 1848 | for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { |
| 1849 | NamedTypeConstraint &operand = op.getOperand(index: i); |
| 1850 | body << " if (parser.resolveOperands(" << operand.name << "Operands, " ; |
| 1851 | |
| 1852 | // Resolve the type of this operand. |
| 1853 | TypeResolution &operandType = operandTypes[i]; |
| 1854 | emitTypeResolver(operandType, operand.name); |
| 1855 | |
| 1856 | body << ", " << operand.name |
| 1857 | << "OperandsLoc, result.operands))\n return ::mlir::failure();\n" ; |
| 1858 | } |
| 1859 | } |
| 1860 | |
| 1861 | void OperationFormat::genParserRegionResolution(Operator &op, |
| 1862 | MethodBody &body) { |
| 1863 | // Check for the case where all regions were parsed. |
| 1864 | bool hasAllRegions = llvm::any_of( |
| 1865 | Range&: elements, P: [](FormatElement *elt) { return isa<RegionsDirective>(Val: elt); }); |
| 1866 | if (hasAllRegions) { |
| 1867 | body << " result.addRegions(fullRegions);\n" ; |
| 1868 | return; |
| 1869 | } |
| 1870 | |
| 1871 | // Otherwise, handle each region individually. |
| 1872 | for (const NamedRegion ®ion : op.getRegions()) { |
| 1873 | if (region.isVariadic()) |
| 1874 | body << " result.addRegions(" << region.name << "Regions);\n" ; |
| 1875 | else |
| 1876 | body << " result.addRegion(std::move(" << region.name << "Region));\n" ; |
| 1877 | } |
| 1878 | } |
| 1879 | |
| 1880 | void OperationFormat::genParserSuccessorResolution(Operator &op, |
| 1881 | MethodBody &body) { |
| 1882 | // Check for the case where all successors were parsed. |
| 1883 | bool hasAllSuccessors = llvm::any_of(Range&: elements, P: [](FormatElement *elt) { |
| 1884 | return isa<SuccessorsDirective>(Val: elt); |
| 1885 | }); |
| 1886 | if (hasAllSuccessors) { |
| 1887 | body << " result.addSuccessors(fullSuccessors);\n" ; |
| 1888 | return; |
| 1889 | } |
| 1890 | |
| 1891 | // Otherwise, handle each successor individually. |
| 1892 | for (const NamedSuccessor &successor : op.getSuccessors()) { |
| 1893 | if (successor.isVariadic()) |
| 1894 | body << " result.addSuccessors(" << successor.name << "Successors);\n" ; |
| 1895 | else |
| 1896 | body << " result.addSuccessors(" << successor.name << "Successor);\n" ; |
| 1897 | } |
| 1898 | } |
| 1899 | |
| 1900 | void OperationFormat::genParserVariadicSegmentResolution(Operator &op, |
| 1901 | MethodBody &body) { |
| 1902 | if (!allOperands) { |
| 1903 | if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments" )) { |
| 1904 | auto interleaveFn = [&](const NamedTypeConstraint &operand) { |
| 1905 | // If the operand is variadic emit the parsed size. |
| 1906 | if (operand.isVariableLength()) |
| 1907 | body << "static_cast<int32_t>(" << operand.name << "Operands.size())" ; |
| 1908 | else |
| 1909 | body << "1" ; |
| 1910 | }; |
| 1911 | if (op.getDialect().usePropertiesForAttributes()) { |
| 1912 | body << "::llvm::copy(::llvm::ArrayRef<int32_t>({" ; |
| 1913 | llvm::interleaveComma(c: op.getOperands(), os&: body, each_fn: interleaveFn); |
| 1914 | body << formatv(Fmt: "}), " |
| 1915 | "result.getOrAddProperties<{0}::Properties>()." |
| 1916 | "operandSegmentSizes.begin());\n" , |
| 1917 | Vals: op.getCppClassName()); |
| 1918 | } else { |
| 1919 | body << " result.addAttribute(\"operandSegmentSizes\", " |
| 1920 | << "parser.getBuilder().getDenseI32ArrayAttr({" ; |
| 1921 | llvm::interleaveComma(c: op.getOperands(), os&: body, each_fn: interleaveFn); |
| 1922 | body << "}));\n" ; |
| 1923 | } |
| 1924 | } |
| 1925 | for (const NamedTypeConstraint &operand : op.getOperands()) { |
| 1926 | if (!operand.isVariadicOfVariadic()) |
| 1927 | continue; |
| 1928 | if (op.getDialect().usePropertiesForAttributes()) { |
| 1929 | body << formatv( |
| 1930 | Fmt: " result.getOrAddProperties<{0}::Properties>().{1} = " |
| 1931 | "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n" , |
| 1932 | Vals: op.getCppClassName(), |
| 1933 | Vals: operand.constraint.getVariadicOfVariadicSegmentSizeAttr(), |
| 1934 | Vals: operand.name); |
| 1935 | } else { |
| 1936 | body << formatv( |
| 1937 | Fmt: " result.addAttribute(\"{0}\", " |
| 1938 | "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));" |
| 1939 | "\n" , |
| 1940 | Vals: operand.constraint.getVariadicOfVariadicSegmentSizeAttr(), |
| 1941 | Vals: operand.name); |
| 1942 | } |
| 1943 | } |
| 1944 | } |
| 1945 | |
| 1946 | if (!allResultTypes && |
| 1947 | op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments" )) { |
| 1948 | auto interleaveFn = [&](const NamedTypeConstraint &result) { |
| 1949 | // If the result is variadic emit the parsed size. |
| 1950 | if (result.isVariableLength()) |
| 1951 | body << "static_cast<int32_t>(" << result.name << "Types.size())" ; |
| 1952 | else |
| 1953 | body << "1" ; |
| 1954 | }; |
| 1955 | if (op.getDialect().usePropertiesForAttributes()) { |
| 1956 | body << "::llvm::copy(::llvm::ArrayRef<int32_t>({" ; |
| 1957 | llvm::interleaveComma(c: op.getResults(), os&: body, each_fn: interleaveFn); |
| 1958 | body << formatv(Fmt: "}), " |
| 1959 | "result.getOrAddProperties<{0}::Properties>()." |
| 1960 | "resultSegmentSizes.begin());\n" , |
| 1961 | Vals: op.getCppClassName()); |
| 1962 | } else { |
| 1963 | body << " result.addAttribute(\"resultSegmentSizes\", " |
| 1964 | << "parser.getBuilder().getDenseI32ArrayAttr({" ; |
| 1965 | llvm::interleaveComma(c: op.getResults(), os&: body, each_fn: interleaveFn); |
| 1966 | body << "}));\n" ; |
| 1967 | } |
| 1968 | } |
| 1969 | } |
| 1970 | |
| 1971 | //===----------------------------------------------------------------------===// |
| 1972 | // PrinterGen |
| 1973 | //===----------------------------------------------------------------------===// |
| 1974 | |
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait.
///
/// The emitted code prints the terminator of the region's first block only
/// when it has a non-empty attribute dictionary, operands, or results.
/// Otherwise block terminators are not printed. For an empty region,
/// `printTerminator` stays true.
///
/// `{{` below is the formatv escape for a literal `{`.
///
/// {0}: The name of the region.
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
{
bool printTerminator = true;
if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
printTerminator = !term->getAttrDictionary().empty() ||
term->getNumOperands() != 0 ||
term->getNumResults() != 0;
}
_odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/printTerminator);
}
)" ;
| 1991 | |
/// The code snippet used to generate a printer call for an enum that has cases
/// that can't be represented with a keyword.
///
/// The snippet opens a scope and binds `caseValue` and `caseValueStr` inside
/// it, but does not close that scope. NOTE(review): presumably the code
/// emitted after this snippet closes it; confirm at the use site.
///
/// {0}: The name of the enum attribute. It is invoked as `{0}()`, so it is
///      presumably the attribute's getter name.
/// {1}: The name of the enum attribute's symbolToString function.
const char *enumAttrBeginPrinterCode = R"(
{
auto caseValue = {0}();
auto caseValueStr = {1}(caseValue);
)" ;
| 2002 | |
| 2003 | /// Generate a check that an optional or default-valued attribute or property |
| 2004 | /// has a non-default value. For these purposes, the default value of an |
| 2005 | /// optional attribute is its presence, even if the attribute itself has a |
| 2006 | /// default value. |
| 2007 | static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, |
| 2008 | AttributeVariable &attrElement) { |
| 2009 | Attribute attr = attrElement.getVar()->attr; |
| 2010 | std::string getter = op.getGetterName(name: attrElement.getVar()->name); |
| 2011 | bool optionalAndDefault = attr.isOptional() && attr.hasDefaultValue(); |
| 2012 | if (optionalAndDefault) |
| 2013 | body << "(" ; |
| 2014 | if (attr.isOptional()) |
| 2015 | body << getter << "Attr()" ; |
| 2016 | if (optionalAndDefault) |
| 2017 | body << " && " ; |
| 2018 | if (attr.hasDefaultValue()) { |
| 2019 | FmtContext fctx; |
| 2020 | fctx.withBuilder(subst: "::mlir::OpBuilder((*this)->getContext())" ); |
| 2021 | body << getter << "Attr() != " |
| 2022 | << tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, |
| 2023 | vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx)); |
| 2024 | } |
| 2025 | if (optionalAndDefault) |
| 2026 | body << ")" ; |
| 2027 | } |
| 2028 | |
| 2029 | static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, |
| 2030 | PropertyVariable &propElement) { |
| 2031 | FmtContext fctx; |
| 2032 | fctx.withBuilder(subst: "::mlir::OpBuilder((*this)->getContext())" ); |
| 2033 | body << op.getGetterName(name: propElement.getVar()->name) << "() != " |
| 2034 | << tgfmt(fmt: propElement.getVar()->prop.getDefaultValue(), ctx: &fctx); |
| 2035 | } |
| 2036 | |
| 2037 | /// Elide the variadic segment size attributes if necessary. |
| 2038 | /// This pushes elided attribute names in `elidedStorage`. |
| 2039 | static void genVariadicSegmentElision(OperationFormat &fmt, Operator &op, |
| 2040 | MethodBody &body, |
| 2041 | const char *elidedStorage) { |
| 2042 | if (!fmt.allOperands && |
| 2043 | op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments" )) |
| 2044 | body << " " << elidedStorage << ".push_back(\"operandSegmentSizes\");\n" ; |
| 2045 | if (!fmt.allResultTypes && |
| 2046 | op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments" )) |
| 2047 | body << " " << elidedStorage << ".push_back(\"resultSegmentSizes\");\n" ; |
| 2048 | } |
| 2049 | |
| 2050 | /// Generate the printer for the 'prop-dict' directive. |
| 2051 | static void genPropDictPrinter(OperationFormat &fmt, Operator &op, |
| 2052 | MethodBody &body) { |
| 2053 | body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedProps;\n" ; |
| 2054 | |
| 2055 | genVariadicSegmentElision(fmt, op, body, elidedStorage: "elidedProps" ); |
| 2056 | |
| 2057 | for (const NamedProperty *namedProperty : fmt.usedProperties) |
| 2058 | body << " elidedProps.push_back(\"" << namedProperty->name << "\");\n" ; |
| 2059 | for (const NamedAttribute *namedAttr : fmt.usedAttributes) |
| 2060 | body << " elidedProps.push_back(\"" << namedAttr->name << "\");\n" ; |
| 2061 | |
| 2062 | // Add code to check attributes for equality with their default values. |
| 2063 | // Default-valued attributes will not be printed when their value matches the |
| 2064 | // default. |
| 2065 | for (const NamedAttribute &namedAttr : op.getAttributes()) { |
| 2066 | const Attribute &attr = namedAttr.attr; |
| 2067 | if (!attr.isDerivedAttr() && attr.hasDefaultValue()) { |
| 2068 | const StringRef &name = namedAttr.name; |
| 2069 | FmtContext fctx; |
| 2070 | fctx.withBuilder(subst: "odsBuilder" ); |
| 2071 | std::string defaultValue = |
| 2072 | std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, |
| 2073 | vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx))); |
| 2074 | body << " {\n" ; |
| 2075 | body << " ::mlir::Builder odsBuilder(getContext());\n" ; |
| 2076 | body << " ::mlir::Attribute attr = " << op.getGetterName(name) |
| 2077 | << "Attr();\n" ; |
| 2078 | body << " if(attr && (attr == " << defaultValue << "))\n" ; |
| 2079 | body << " elidedProps.push_back(\"" << name << "\");\n" ; |
| 2080 | body << " }\n" ; |
| 2081 | } |
| 2082 | } |
| 2083 | // Similarly, elide default-valued properties. |
| 2084 | for (const NamedProperty &prop : op.getProperties()) { |
| 2085 | if (prop.prop.hasDefaultValue()) { |
| 2086 | FmtContext fctx; |
| 2087 | fctx.withBuilder(subst: "odsBuilder" ); |
| 2088 | body << " if (" << op.getGetterName(name: prop.name) |
| 2089 | << "() == " << tgfmt(fmt: prop.prop.getDefaultValue(), ctx: &fctx) << ") {" ; |
| 2090 | body << " elidedProps.push_back(\"" << prop.name << "\");\n" ; |
| 2091 | body << " }\n" ; |
| 2092 | } |
| 2093 | } |
| 2094 | |
| 2095 | if (fmt.useProperties) { |
| 2096 | body << " _odsPrinter << \" \";\n" |
| 2097 | << " printProperties(this->getContext(), _odsPrinter, " |
| 2098 | "getProperties(), elidedProps);\n" ; |
| 2099 | } |
| 2100 | } |
| 2101 | |
| 2102 | /// Generate the printer for the 'attr-dict' directive. |
| 2103 | static void genAttrDictPrinter(OperationFormat &fmt, Operator &op, |
| 2104 | MethodBody &body, bool withKeyword) { |
| 2105 | body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;\n" ; |
| 2106 | |
| 2107 | genVariadicSegmentElision(fmt, op, body, elidedStorage: "elidedAttrs" ); |
| 2108 | |
| 2109 | for (const StringRef key : fmt.inferredAttributes.keys()) |
| 2110 | body << " elidedAttrs.push_back(\"" << key << "\");\n" ; |
| 2111 | for (const NamedAttribute *attr : fmt.usedAttributes) |
| 2112 | body << " elidedAttrs.push_back(\"" << attr->name << "\");\n" ; |
| 2113 | |
| 2114 | // Add code to check attributes for equality with their default values. |
| 2115 | // Default-valued attributes will not be printed when their value matches the |
| 2116 | // default. |
| 2117 | for (const NamedAttribute &namedAttr : op.getAttributes()) { |
| 2118 | const Attribute &attr = namedAttr.attr; |
| 2119 | if (!attr.isDerivedAttr() && attr.hasDefaultValue()) { |
| 2120 | const StringRef &name = namedAttr.name; |
| 2121 | FmtContext fctx; |
| 2122 | fctx.withBuilder(subst: "odsBuilder" ); |
| 2123 | std::string defaultValue = |
| 2124 | std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, |
| 2125 | vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx))); |
| 2126 | body << " {\n" ; |
| 2127 | body << " ::mlir::Builder odsBuilder(getContext());\n" ; |
| 2128 | body << " ::mlir::Attribute attr = " << op.getGetterName(name) |
| 2129 | << "Attr();\n" ; |
| 2130 | body << " if(attr && (attr == " << defaultValue << "))\n" ; |
| 2131 | body << " elidedAttrs.push_back(\"" << name << "\");\n" ; |
| 2132 | body << " }\n" ; |
| 2133 | } |
| 2134 | } |
| 2135 | if (fmt.hasPropDict) |
| 2136 | body << " _odsPrinter.printOptionalAttrDict" |
| 2137 | << (withKeyword ? "WithKeyword" : "" ) |
| 2138 | << "(llvm::to_vector((*this)->getDiscardableAttrs()), elidedAttrs);\n" ; |
| 2139 | else |
| 2140 | body << " _odsPrinter.printOptionalAttrDict" |
| 2141 | << (withKeyword ? "WithKeyword" : "" ) |
| 2142 | << "((*this)->getAttrs(), elidedAttrs);\n" ; |
| 2143 | } |
| 2144 | |
| 2145 | /// Generate the printer for a literal value. `shouldEmitSpace` is true if a |
| 2146 | /// space should be emitted before this element. `lastWasPunctuation` is true if |
| 2147 | /// the previous element was a punctuation literal. |
| 2148 | static void genLiteralPrinter(StringRef value, MethodBody &body, |
| 2149 | bool &shouldEmitSpace, bool &lastWasPunctuation) { |
| 2150 | body << " _odsPrinter" ; |
| 2151 | |
| 2152 | // Don't insert a space for certain punctuation. |
| 2153 | if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation)) |
| 2154 | body << " << ' '" ; |
| 2155 | body << " << \"" << value << "\";\n" ; |
| 2156 | |
| 2157 | // Insert a space after certain literals. |
| 2158 | shouldEmitSpace = |
| 2159 | value.size() != 1 || !StringRef("<({[" ).contains(C: value.front()); |
| 2160 | lastWasPunctuation = value.front() != '_' && !isalpha(value.front()); |
| 2161 | } |
| 2162 | |
| 2163 | /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation` |
| 2164 | /// are set to false. |
| 2165 | static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace, |
| 2166 | bool &lastWasPunctuation) { |
| 2167 | if (value) { |
| 2168 | body << " _odsPrinter << ' ';\n" ; |
| 2169 | lastWasPunctuation = false; |
| 2170 | } else { |
| 2171 | lastWasPunctuation = true; |
| 2172 | } |
| 2173 | shouldEmitSpace = false; |
| 2174 | } |
| 2175 | |
| 2176 | /// Generate the printer for a custom directive parameter. |
| 2177 | static void genCustomDirectiveParameterPrinter(FormatElement *element, |
| 2178 | const Operator &op, |
| 2179 | MethodBody &body) { |
| 2180 | if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) { |
| 2181 | body << op.getGetterName(name: attr->getVar()->name) << "Attr()" ; |
| 2182 | |
| 2183 | } else if (isa<AttrDictDirective>(Val: element)) { |
| 2184 | body << "getOperation()->getAttrDictionary()" ; |
| 2185 | |
| 2186 | } else if (isa<PropDictDirective>(Val: element)) { |
| 2187 | body << "getProperties()" ; |
| 2188 | |
| 2189 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) { |
| 2190 | body << op.getGetterName(name: operand->getVar()->name) << "()" ; |
| 2191 | |
| 2192 | } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) { |
| 2193 | body << op.getGetterName(name: region->getVar()->name) << "()" ; |
| 2194 | |
| 2195 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) { |
| 2196 | body << op.getGetterName(name: successor->getVar()->name) << "()" ; |
| 2197 | |
| 2198 | } else if (auto *dir = dyn_cast<RefDirective>(Val: element)) { |
| 2199 | genCustomDirectiveParameterPrinter(element: dir->getArg(), op, body); |
| 2200 | |
| 2201 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) { |
| 2202 | auto *typeOperand = dir->getArg(); |
| 2203 | auto *operand = dyn_cast<OperandVariable>(Val: typeOperand); |
| 2204 | auto *var = operand ? operand->getVar() |
| 2205 | : cast<ResultVariable>(Val: typeOperand)->getVar(); |
| 2206 | std::string name = op.getGetterName(name: var->name); |
| 2207 | if (var->isVariadic()) |
| 2208 | body << name << "().getTypes()" ; |
| 2209 | else if (var->isOptional()) |
| 2210 | body << formatv(Fmt: "({0}() ? {0}().getType() : ::mlir::Type())" , Vals&: name); |
| 2211 | else |
| 2212 | body << name << "().getType()" ; |
| 2213 | |
| 2214 | } else if (auto *string = dyn_cast<StringElement>(Val: element)) { |
| 2215 | FmtContext ctx; |
| 2216 | ctx.withBuilder(subst: "::mlir::Builder(getContext())" ); |
| 2217 | ctx.addSubst(placeholder: "_ctxt" , subst: "getContext()" ); |
| 2218 | body << tgfmt(fmt: string->getValue(), ctx: &ctx); |
| 2219 | |
| 2220 | } else if (auto *property = dyn_cast<PropertyVariable>(Val: element)) { |
| 2221 | FmtContext ctx; |
| 2222 | const NamedProperty *namedProperty = property->getVar(); |
| 2223 | ctx.addSubst(placeholder: "_storage" , subst: "getProperties()." + namedProperty->name); |
| 2224 | body << tgfmt(fmt: namedProperty->prop.getConvertFromStorageCall(), ctx: &ctx); |
| 2225 | } else { |
| 2226 | llvm_unreachable("unknown custom directive parameter" ); |
| 2227 | } |
| 2228 | } |
| 2229 | |
| 2230 | /// Generate the printer for a custom directive. |
| 2231 | static void genCustomDirectivePrinter(CustomDirective *customDir, |
| 2232 | const Operator &op, MethodBody &body) { |
| 2233 | body << " print" << customDir->getName() << "(_odsPrinter, *this" ; |
| 2234 | for (FormatElement *param : customDir->getElements()) { |
| 2235 | body << ", " ; |
| 2236 | genCustomDirectiveParameterPrinter(element: param, op, body); |
| 2237 | } |
| 2238 | body << ");\n" ; |
| 2239 | } |
| 2240 | |
| 2241 | /// Generate the printer for a region with the given variable name. |
| 2242 | static void genRegionPrinter(const Twine ®ionName, MethodBody &body, |
| 2243 | bool hasImplicitTermTrait) { |
| 2244 | if (hasImplicitTermTrait) |
| 2245 | body << formatv(Fmt: regionSingleBlockImplicitTerminatorPrinterCode, Vals: regionName); |
| 2246 | else |
| 2247 | body << " _odsPrinter.printRegion(" << regionName << ");\n" ; |
| 2248 | } |
| 2249 | static void genVariadicRegionPrinter(const Twine ®ionListName, |
| 2250 | MethodBody &body, |
| 2251 | bool hasImplicitTermTrait) { |
| 2252 | body << " llvm::interleaveComma(" << regionListName |
| 2253 | << ", _odsPrinter, [&](::mlir::Region ®ion) {\n " ; |
| 2254 | genRegionPrinter(regionName: "region" , body, hasImplicitTermTrait); |
| 2255 | body << " });\n" ; |
| 2256 | } |
| 2257 | |
| 2258 | /// Generate the C++ for an operand to a (*-)type directive. |
| 2259 | static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op, |
| 2260 | MethodBody &body, |
| 2261 | bool useArrayRef = true) { |
| 2262 | if (isa<OperandsDirective>(Val: arg)) |
| 2263 | return body << "getOperation()->getOperandTypes()" ; |
| 2264 | if (isa<ResultsDirective>(Val: arg)) |
| 2265 | return body << "getOperation()->getResultTypes()" ; |
| 2266 | auto *operand = dyn_cast<OperandVariable>(Val: arg); |
| 2267 | auto *var = operand ? operand->getVar() : cast<ResultVariable>(Val: arg)->getVar(); |
| 2268 | if (var->isVariadicOfVariadic()) |
| 2269 | return body << formatv(Fmt: "{0}().join().getTypes()" , |
| 2270 | Vals: op.getGetterName(name: var->name)); |
| 2271 | if (var->isVariadic()) |
| 2272 | return body << op.getGetterName(name: var->name) << "().getTypes()" ; |
| 2273 | if (var->isOptional()) |
| 2274 | return body << formatv( |
| 2275 | Fmt: "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : " |
| 2276 | "::llvm::ArrayRef<::mlir::Type>())" , |
| 2277 | Vals: op.getGetterName(name: var->name)); |
| 2278 | if (useArrayRef) |
| 2279 | return body << "::llvm::ArrayRef<::mlir::Type>(" |
| 2280 | << op.getGetterName(name: var->name) << "().getType())" ; |
| 2281 | return body << op.getGetterName(name: var->name) << "().getType()" ; |
| 2282 | } |
| 2283 | |
| 2284 | /// Generate the printer for an enum attribute. |
| 2285 | static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, |
| 2286 | MethodBody &body) { |
| 2287 | Attribute baseAttr = var->attr.getBaseAttr(); |
| 2288 | const EnumInfo enumInfo(&baseAttr.getDef()); |
| 2289 | std::vector<EnumCase> cases = enumInfo.getAllCases(); |
| 2290 | |
| 2291 | body << formatv(Fmt: enumAttrBeginPrinterCode, |
| 2292 | Vals: (var->attr.isOptional() ? "*" : "" ) + |
| 2293 | op.getGetterName(name: var->name), |
| 2294 | Vals: enumInfo.getSymbolToStringFnName()); |
| 2295 | |
| 2296 | // Get a string containing all of the cases that can't be represented with a |
| 2297 | // keyword. |
| 2298 | BitVector nonKeywordCases(cases.size()); |
| 2299 | for (auto it : llvm::enumerate(First&: cases)) { |
| 2300 | if (!canFormatStringAsKeyword(value: it.value().getStr())) |
| 2301 | nonKeywordCases.set(it.index()); |
| 2302 | } |
| 2303 | |
| 2304 | // Otherwise if this is a bit enum attribute, don't allow cases that may |
| 2305 | // overlap with other cases. For simplicity sake, only allow cases with a |
| 2306 | // single bit value. |
| 2307 | if (enumInfo.isBitEnum()) { |
| 2308 | for (auto it : llvm::enumerate(First&: cases)) { |
| 2309 | int64_t value = it.value().getValue(); |
| 2310 | if (value < 0 || !llvm::isPowerOf2_64(Value: value)) |
| 2311 | nonKeywordCases.set(it.index()); |
| 2312 | } |
| 2313 | } |
| 2314 | |
| 2315 | // If there are any cases that can't be used with a keyword, switch on the |
| 2316 | // case value to determine when to print in the string form. |
| 2317 | if (nonKeywordCases.any()) { |
| 2318 | body << " switch (caseValue) {\n" ; |
| 2319 | StringRef cppNamespace = enumInfo.getCppNamespace(); |
| 2320 | StringRef enumName = enumInfo.getEnumClassName(); |
| 2321 | for (auto it : llvm::enumerate(First&: cases)) { |
| 2322 | if (nonKeywordCases.test(Idx: it.index())) |
| 2323 | continue; |
| 2324 | StringRef symbol = it.value().getSymbol(); |
| 2325 | body << formatv(Fmt: " case {0}::{1}::{2}:\n" , Vals&: cppNamespace, Vals&: enumName, |
| 2326 | Vals: llvm::isDigit(C: symbol.front()) ? ("_" + symbol) : symbol); |
| 2327 | } |
| 2328 | body << " _odsPrinter << caseValueStr;\n" |
| 2329 | " break;\n" |
| 2330 | " default:\n" |
| 2331 | " _odsPrinter << '\"' << caseValueStr << '\"';\n" |
| 2332 | " break;\n" |
| 2333 | " }\n" |
| 2334 | " }\n" ; |
| 2335 | return; |
| 2336 | } |
| 2337 | |
| 2338 | body << " _odsPrinter << caseValueStr;\n" |
| 2339 | " }\n" ; |
| 2340 | } |
| 2341 | |
| 2342 | /// Generate the check for the anchor of an optional group. |
| 2343 | static void genOptionalGroupPrinterAnchor(FormatElement *anchor, |
| 2344 | const Operator &op, |
| 2345 | MethodBody &body) { |
| 2346 | TypeSwitch<FormatElement *>(anchor) |
| 2347 | .Case<OperandVariable, ResultVariable>(caseFn: [&](auto *element) { |
| 2348 | const NamedTypeConstraint *var = element->getVar(); |
| 2349 | std::string name = op.getGetterName(name: var->name); |
| 2350 | if (var->isOptional()) |
| 2351 | body << name << "()" ; |
| 2352 | else if (var->isVariadic()) |
| 2353 | body << "!" << name << "().empty()" ; |
| 2354 | }) |
| 2355 | .Case(caseFn: [&](RegionVariable *element) { |
| 2356 | const NamedRegion *var = element->getVar(); |
| 2357 | std::string name = op.getGetterName(name: var->name); |
| 2358 | // TODO: Add a check for optional regions here when ODS supports it. |
| 2359 | body << "!" << name << "().empty()" ; |
| 2360 | }) |
| 2361 | .Case(caseFn: [&](TypeDirective *element) { |
| 2362 | genOptionalGroupPrinterAnchor(anchor: element->getArg(), op, body); |
| 2363 | }) |
| 2364 | .Case(caseFn: [&](FunctionalTypeDirective *element) { |
| 2365 | genOptionalGroupPrinterAnchor(anchor: element->getInputs(), op, body); |
| 2366 | }) |
| 2367 | .Case(caseFn: [&](AttributeVariable *element) { |
| 2368 | // Consider a default-valued attribute as present if it's not the |
| 2369 | // default value and an optional one present if it is set. |
| 2370 | genNonDefaultValueCheck(body, op, attrElement&: *element); |
| 2371 | }) |
| 2372 | .Case(caseFn: [&](PropertyVariable *element) { |
| 2373 | genNonDefaultValueCheck(body, op, propElement&: *element); |
| 2374 | }) |
| 2375 | .Case(caseFn: [&](CustomDirective *ele) { |
| 2376 | body << '('; |
| 2377 | llvm::interleave( |
| 2378 | c: ele->getElements(), os&: body, |
| 2379 | each_fn: [&](FormatElement *child) { |
| 2380 | body << '('; |
| 2381 | genOptionalGroupPrinterAnchor(anchor: child, op, body); |
| 2382 | body << ')'; |
| 2383 | }, |
| 2384 | separator: " || " ); |
| 2385 | body << ')'; |
| 2386 | }); |
| 2387 | } |
| 2388 | |
| 2389 | void collect(FormatElement *element, |
| 2390 | SmallVectorImpl<VariableElement *> &variables) { |
| 2391 | TypeSwitch<FormatElement *>(element) |
| 2392 | .Case(caseFn: [&](VariableElement *var) { variables.emplace_back(Args&: var); }) |
| 2393 | .Case(caseFn: [&](CustomDirective *ele) { |
| 2394 | for (FormatElement *arg : ele->getElements()) |
| 2395 | collect(element: arg, variables); |
| 2396 | }) |
| 2397 | .Case(caseFn: [&](OptionalElement *ele) { |
| 2398 | for (FormatElement *arg : ele->getThenElements()) |
| 2399 | collect(element: arg, variables); |
| 2400 | for (FormatElement *arg : ele->getElseElements()) |
| 2401 | collect(element: arg, variables); |
| 2402 | }) |
| 2403 | .Case(caseFn: [&](FunctionalTypeDirective *funcType) { |
| 2404 | collect(element: funcType->getInputs(), variables); |
| 2405 | collect(element: funcType->getResults(), variables); |
| 2406 | }) |
| 2407 | .Case(caseFn: [&](OIListElement *oilist) { |
| 2408 | for (ArrayRef<FormatElement *> arg : oilist->getParsingElements()) |
| 2409 | for (FormatElement *arg : arg) |
| 2410 | collect(element: arg, variables); |
| 2411 | }); |
| 2412 | } |
| 2413 | |
| 2414 | void OperationFormat::genElementPrinter(FormatElement *element, |
| 2415 | MethodBody &body, Operator &op, |
| 2416 | bool &shouldEmitSpace, |
| 2417 | bool &lastWasPunctuation) { |
| 2418 | if (LiteralElement *literal = dyn_cast<LiteralElement>(Val: element)) |
| 2419 | return genLiteralPrinter(value: literal->getSpelling(), body, shouldEmitSpace, |
| 2420 | lastWasPunctuation); |
| 2421 | |
| 2422 | // Emit a whitespace element. |
| 2423 | if (auto *space = dyn_cast<WhitespaceElement>(Val: element)) { |
| 2424 | if (space->getValue() == "\\n" ) { |
| 2425 | body << " _odsPrinter.printNewline();\n" ; |
| 2426 | } else { |
| 2427 | genSpacePrinter(value: !space->getValue().empty(), body, shouldEmitSpace, |
| 2428 | lastWasPunctuation); |
| 2429 | } |
| 2430 | return; |
| 2431 | } |
| 2432 | |
| 2433 | // Emit an optional group. |
| 2434 | if (OptionalElement *optional = dyn_cast<OptionalElement>(Val: element)) { |
| 2435 | // Emit the check for the presence of the anchor element. |
| 2436 | FormatElement *anchor = optional->getAnchor(); |
| 2437 | body << " if (" ; |
| 2438 | if (optional->isInverted()) |
| 2439 | body << "!" ; |
| 2440 | genOptionalGroupPrinterAnchor(anchor, op, body); |
| 2441 | body << ") {\n" ; |
| 2442 | body.indent(); |
| 2443 | |
| 2444 | // If the anchor is a unit attribute, we don't need to print it. When |
| 2445 | // parsing, we will add this attribute if this group is present. |
| 2446 | ArrayRef<FormatElement *> thenElements = optional->getThenElements(); |
| 2447 | ArrayRef<FormatElement *> elseElements = optional->getElseElements(); |
| 2448 | FormatElement *elidedAnchorElement = nullptr; |
| 2449 | auto *anchorAttr = dyn_cast<AttributeLikeVariable>(Val: anchor); |
| 2450 | if (anchorAttr && anchorAttr != thenElements.front() && |
| 2451 | (elseElements.empty() || anchorAttr != elseElements.front()) && |
| 2452 | anchorAttr->isUnit()) { |
| 2453 | elidedAnchorElement = anchorAttr; |
| 2454 | } |
| 2455 | auto genElementPrinters = [&](ArrayRef<FormatElement *> elements) { |
| 2456 | for (FormatElement *childElement : elements) { |
| 2457 | if (childElement != elidedAnchorElement) { |
| 2458 | genElementPrinter(element: childElement, body, op, shouldEmitSpace, |
| 2459 | lastWasPunctuation); |
| 2460 | } |
| 2461 | } |
| 2462 | }; |
| 2463 | |
| 2464 | // Emit each of the elements. |
| 2465 | genElementPrinters(thenElements); |
| 2466 | body << "}" ; |
| 2467 | |
| 2468 | // Emit each of the else elements. |
| 2469 | if (!elseElements.empty()) { |
| 2470 | body << " else {\n" ; |
| 2471 | genElementPrinters(elseElements); |
| 2472 | body << "}" ; |
| 2473 | } |
| 2474 | |
| 2475 | body.unindent() << "\n" ; |
| 2476 | return; |
| 2477 | } |
| 2478 | |
| 2479 | // Emit the OIList |
| 2480 | if (auto *oilist = dyn_cast<OIListElement>(Val: element)) { |
| 2481 | for (auto clause : oilist->getClauses()) { |
| 2482 | LiteralElement *lelement = std::get<0>(t&: clause); |
| 2483 | ArrayRef<FormatElement *> pelement = std::get<1>(t&: clause); |
| 2484 | |
| 2485 | SmallVector<VariableElement *> vars; |
| 2486 | for (FormatElement *el : pelement) |
| 2487 | collect(element: el, variables&: vars); |
| 2488 | body << " if (false" ; |
| 2489 | for (VariableElement *var : vars) { |
| 2490 | TypeSwitch<FormatElement *>(var) |
| 2491 | .Case(caseFn: [&](AttributeVariable *attrEle) { |
| 2492 | body << " || (" ; |
| 2493 | genNonDefaultValueCheck(body, op, attrElement&: *attrEle); |
| 2494 | body << ")" ; |
| 2495 | }) |
| 2496 | .Case(caseFn: [&](PropertyVariable *propEle) { |
| 2497 | body << " || (" ; |
| 2498 | genNonDefaultValueCheck(body, op, propElement&: *propEle); |
| 2499 | body << ")" ; |
| 2500 | }) |
| 2501 | .Case(caseFn: [&](OperandVariable *ele) { |
| 2502 | if (ele->getVar()->isVariadic()) { |
| 2503 | body << " || " << op.getGetterName(name: ele->getVar()->name) |
| 2504 | << "().size()" ; |
| 2505 | } else { |
| 2506 | body << " || " << op.getGetterName(name: ele->getVar()->name) << "()" ; |
| 2507 | } |
| 2508 | }) |
| 2509 | .Case(caseFn: [&](ResultVariable *ele) { |
| 2510 | if (ele->getVar()->isVariadic()) { |
| 2511 | body << " || " << op.getGetterName(name: ele->getVar()->name) |
| 2512 | << "().size()" ; |
| 2513 | } else { |
| 2514 | body << " || " << op.getGetterName(name: ele->getVar()->name) << "()" ; |
| 2515 | } |
| 2516 | }) |
| 2517 | .Case(caseFn: [&](RegionVariable *reg) { |
| 2518 | body << " || " << op.getGetterName(name: reg->getVar()->name) << "()" ; |
| 2519 | }); |
| 2520 | } |
| 2521 | |
| 2522 | body << ") {\n" ; |
| 2523 | genLiteralPrinter(value: lelement->getSpelling(), body, shouldEmitSpace, |
| 2524 | lastWasPunctuation); |
| 2525 | if (oilist->getUnitVariableParsingElement(pelement) == nullptr) { |
| 2526 | for (FormatElement *element : pelement) |
| 2527 | genElementPrinter(element, body, op, shouldEmitSpace, |
| 2528 | lastWasPunctuation); |
| 2529 | } |
| 2530 | body << " }\n" ; |
| 2531 | } |
| 2532 | return; |
| 2533 | } |
| 2534 | |
| 2535 | // Emit the attribute dictionary. |
| 2536 | if (auto *attrDict = dyn_cast<AttrDictDirective>(Val: element)) { |
| 2537 | genAttrDictPrinter(fmt&: *this, op, body, withKeyword: attrDict->isWithKeyword()); |
| 2538 | lastWasPunctuation = false; |
| 2539 | return; |
| 2540 | } |
| 2541 | |
| 2542 | // Emit the property dictionary. |
| 2543 | if (isa<PropDictDirective>(Val: element)) { |
| 2544 | genPropDictPrinter(fmt&: *this, op, body); |
| 2545 | lastWasPunctuation = false; |
| 2546 | return; |
| 2547 | } |
| 2548 | |
| 2549 | // Optionally insert a space before the next element. The AttrDict printer |
| 2550 | // already adds a space as necessary. |
| 2551 | if (shouldEmitSpace || !lastWasPunctuation) |
| 2552 | body << " _odsPrinter << ' ';\n" ; |
| 2553 | lastWasPunctuation = false; |
| 2554 | shouldEmitSpace = true; |
| 2555 | |
| 2556 | if (auto *attr = dyn_cast<AttributeVariable>(Val: element)) { |
| 2557 | const NamedAttribute *var = attr->getVar(); |
| 2558 | |
| 2559 | // If we are formatting as an enum, symbolize the attribute as a string. |
| 2560 | if (canFormatEnumAttr(attr: var)) |
| 2561 | return genEnumAttrPrinter(var, op, body); |
| 2562 | |
| 2563 | // If we are formatting as a symbol name, handle it as a symbol name. |
| 2564 | if (shouldFormatSymbolNameAttr(attr: var)) { |
| 2565 | body << " _odsPrinter.printSymbolName(" << op.getGetterName(name: var->name) |
| 2566 | << "Attr().getValue());\n" ; |
| 2567 | return; |
| 2568 | } |
| 2569 | |
| 2570 | // Elide the attribute type if it is buildable. |
| 2571 | if (attr->getTypeBuilder()) |
| 2572 | body << " _odsPrinter.printAttributeWithoutType(" |
| 2573 | << op.getGetterName(name: var->name) << "Attr());\n" ; |
| 2574 | else if (attr->shouldBeQualified() || |
| 2575 | var->attr.getStorageType() == "::mlir::Attribute" ) |
| 2576 | body << " _odsPrinter.printAttribute(" << op.getGetterName(name: var->name) |
| 2577 | << "Attr());\n" ; |
| 2578 | else |
| 2579 | body << "_odsPrinter.printStrippedAttrOrType(" |
| 2580 | << op.getGetterName(name: var->name) << "Attr());\n" ; |
| 2581 | } else if (auto *property = dyn_cast<PropertyVariable>(Val: element)) { |
| 2582 | const NamedProperty *var = property->getVar(); |
| 2583 | FmtContext fmtContext; |
| 2584 | fmtContext.addSubst(placeholder: "_printer" , subst: "_odsPrinter" ); |
| 2585 | fmtContext.addSubst(placeholder: "_ctxt" , subst: "getContext()" ); |
| 2586 | fmtContext.addSubst(placeholder: "_storage" , subst: "getProperties()." + var->name); |
| 2587 | body << tgfmt(fmt: var->prop.getPrinterCall(), ctx: &fmtContext) << ";\n" ; |
| 2588 | } else if (auto *operand = dyn_cast<OperandVariable>(Val: element)) { |
| 2589 | if (operand->getVar()->isVariadicOfVariadic()) { |
| 2590 | body << " ::llvm::interleaveComma(" |
| 2591 | << op.getGetterName(name: operand->getVar()->name) |
| 2592 | << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << " |
| 2593 | "\"(\" << operands << " |
| 2594 | "\")\"; });\n" ; |
| 2595 | |
| 2596 | } else if (operand->getVar()->isOptional()) { |
| 2597 | body << " if (::mlir::Value value = " |
| 2598 | << op.getGetterName(name: operand->getVar()->name) << "())\n" |
| 2599 | << " _odsPrinter << value;\n" ; |
| 2600 | } else { |
| 2601 | body << " _odsPrinter << " << op.getGetterName(name: operand->getVar()->name) |
| 2602 | << "();\n" ; |
| 2603 | } |
| 2604 | } else if (auto *region = dyn_cast<RegionVariable>(Val: element)) { |
| 2605 | const NamedRegion *var = region->getVar(); |
| 2606 | std::string name = op.getGetterName(name: var->name); |
| 2607 | if (var->isVariadic()) { |
| 2608 | genVariadicRegionPrinter(regionListName: name + "()" , body, hasImplicitTermTrait); |
| 2609 | } else { |
| 2610 | genRegionPrinter(regionName: name + "()" , body, hasImplicitTermTrait); |
| 2611 | } |
| 2612 | } else if (auto *successor = dyn_cast<SuccessorVariable>(Val: element)) { |
| 2613 | const NamedSuccessor *var = successor->getVar(); |
| 2614 | std::string name = op.getGetterName(name: var->name); |
| 2615 | if (var->isVariadic()) |
| 2616 | body << " ::llvm::interleaveComma(" << name << "(), _odsPrinter);\n" ; |
| 2617 | else |
| 2618 | body << " _odsPrinter << " << name << "();\n" ; |
| 2619 | } else if (auto *dir = dyn_cast<CustomDirective>(Val: element)) { |
| 2620 | genCustomDirectivePrinter(customDir: dir, op, body); |
| 2621 | } else if (isa<OperandsDirective>(Val: element)) { |
| 2622 | body << " _odsPrinter << getOperation()->getOperands();\n" ; |
| 2623 | } else if (isa<RegionsDirective>(Val: element)) { |
| 2624 | genVariadicRegionPrinter(regionListName: "getOperation()->getRegions()" , body, |
| 2625 | hasImplicitTermTrait); |
| 2626 | } else if (isa<SuccessorsDirective>(Val: element)) { |
| 2627 | body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), " |
| 2628 | "_odsPrinter);\n" ; |
| 2629 | } else if (auto *dir = dyn_cast<TypeDirective>(Val: element)) { |
| 2630 | if (auto *operand = dyn_cast<OperandVariable>(Val: dir->getArg())) { |
| 2631 | if (operand->getVar()->isVariadicOfVariadic()) { |
| 2632 | body << formatv( |
| 2633 | Fmt: " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, " |
| 2634 | "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << " |
| 2635 | "types << \")\"; });\n" , |
| 2636 | Vals: op.getGetterName(name: operand->getVar()->name)); |
| 2637 | return; |
| 2638 | } |
| 2639 | } |
| 2640 | const NamedTypeConstraint *var = nullptr; |
| 2641 | { |
| 2642 | if (auto *operand = dyn_cast<OperandVariable>(Val: dir->getArg())) |
| 2643 | var = operand->getVar(); |
| 2644 | else if (auto *operand = dyn_cast<ResultVariable>(Val: dir->getArg())) |
| 2645 | var = operand->getVar(); |
| 2646 | } |
| 2647 | if (var && !var->isVariadicOfVariadic() && !var->isVariadic() && |
| 2648 | !var->isOptional()) { |
| 2649 | StringRef cppType = var->constraint.getCppType(); |
| 2650 | if (dir->shouldBeQualified()) { |
| 2651 | body << " _odsPrinter << " << op.getGetterName(name: var->name) |
| 2652 | << "().getType();\n" ; |
| 2653 | return; |
| 2654 | } |
| 2655 | body << " {\n" |
| 2656 | << " auto type = " << op.getGetterName(name: var->name) |
| 2657 | << "().getType();\n" |
| 2658 | << " if (auto validType = ::llvm::dyn_cast<" << cppType |
| 2659 | << ">(type))\n" |
| 2660 | << " _odsPrinter.printStrippedAttrOrType(validType);\n" |
| 2661 | << " else\n" |
| 2662 | << " _odsPrinter << type;\n" |
| 2663 | << " }\n" ; |
| 2664 | return; |
| 2665 | } |
| 2666 | body << " _odsPrinter << " ; |
| 2667 | genTypeOperandPrinter(arg: dir->getArg(), op, body, /*useArrayRef=*/false) |
| 2668 | << ";\n" ; |
| 2669 | } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(Val: element)) { |
| 2670 | body << " _odsPrinter.printFunctionalType(" ; |
| 2671 | genTypeOperandPrinter(arg: dir->getInputs(), op, body) << ", " ; |
| 2672 | genTypeOperandPrinter(arg: dir->getResults(), op, body) << ");\n" ; |
| 2673 | } else { |
| 2674 | llvm_unreachable("unknown format element" ); |
| 2675 | } |
| 2676 | } |
| 2677 | |
| 2678 | void OperationFormat::genPrinter(Operator &op, OpClass &opClass) { |
| 2679 | auto *method = opClass.addMethod( |
| 2680 | retType: "void" , name: "print" , |
| 2681 | args: MethodParameter("::mlir::OpAsmPrinter &" , "_odsPrinter" )); |
| 2682 | auto &body = method->body(); |
| 2683 | |
| 2684 | // Flags for if we should emit a space, and if the last element was |
| 2685 | // punctuation. |
| 2686 | bool shouldEmitSpace = true, lastWasPunctuation = false; |
| 2687 | for (FormatElement *element : elements) |
| 2688 | genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation); |
| 2689 | } |
| 2690 | |
| 2691 | //===----------------------------------------------------------------------===// |
| 2692 | // OpFormatParser |
| 2693 | //===----------------------------------------------------------------------===// |
| 2694 | |
| 2695 | /// Function to find an element within the given range that has the same name as |
| 2696 | /// 'name'. |
| 2697 | template <typename RangeT> |
| 2698 | static auto findArg(RangeT &&range, StringRef name) { |
| 2699 | auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; }); |
| 2700 | return it != range.end() ? &*it : nullptr; |
| 2701 | } |
| 2702 | |
| 2703 | namespace { |
| 2704 | /// This class implements a parser for an instance of an operation assembly |
| 2705 | /// format. |
| 2706 | class OpFormatParser : public FormatParser { |
| 2707 | public: |
| 2708 | OpFormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op) |
| 2709 | : FormatParser(mgr, op.getLoc()[0]), fmt(format), op(op), |
| 2710 | seenOperandTypes(op.getNumOperands()), |
| 2711 | seenResultTypes(op.getNumResults()) {} |
| 2712 | |
| 2713 | protected: |
| 2714 | /// Verify the format elements. |
| 2715 | LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override; |
| 2716 | /// Verify the arguments to a custom directive. |
| 2717 | LogicalResult |
| 2718 | verifyCustomDirectiveArguments(SMLoc loc, |
| 2719 | ArrayRef<FormatElement *> arguments) override; |
| 2720 | /// Verify the elements of an optional group. |
| 2721 | LogicalResult verifyOptionalGroupElements(SMLoc loc, |
| 2722 | ArrayRef<FormatElement *> elements, |
| 2723 | FormatElement *anchor) override; |
| 2724 | LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element, |
| 2725 | bool isAnchor); |
| 2726 | |
| 2727 | LogicalResult markQualified(SMLoc loc, FormatElement *element) override; |
| 2728 | |
| 2729 | /// Parse an operation variable. |
| 2730 | FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name, |
| 2731 | Context ctx) override; |
| 2732 | /// Parse an operation format directive. |
| 2733 | FailureOr<FormatElement *> |
| 2734 | parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override; |
| 2735 | |
| 2736 | private: |
| 2737 | /// This struct represents a type resolution instance. It includes a specific |
| 2738 | /// type as well as an optional transformer to apply to that type in order to |
| 2739 | /// properly resolve the type of a variable. |
| 2740 | struct TypeResolutionInstance { |
| 2741 | ConstArgument resolver; |
| 2742 | std::optional<StringRef> transformer; |
| 2743 | }; |
| 2744 | |
| 2745 | /// Verify the state of operation attributes within the format. |
| 2746 | LogicalResult verifyAttributes(SMLoc loc, ArrayRef<FormatElement *> elements); |
| 2747 | |
| 2748 | /// Verify that attributes elements aren't followed by colon literals. |
| 2749 | LogicalResult verifyAttributeColonType(SMLoc loc, |
| 2750 | ArrayRef<FormatElement *> elements); |
| 2751 | /// Verify that the attribute dictionary directive isn't followed by a region. |
| 2752 | LogicalResult verifyAttrDictRegion(SMLoc loc, |
| 2753 | ArrayRef<FormatElement *> elements); |
| 2754 | |
| 2755 | /// Verify the state of operation operands within the format. |
| 2756 | LogicalResult |
| 2757 | verifyOperands(SMLoc loc, |
| 2758 | StringMap<TypeResolutionInstance> &variableTyResolver); |
| 2759 | |
| 2760 | /// Verify the state of operation regions within the format. |
| 2761 | LogicalResult verifyRegions(SMLoc loc); |
| 2762 | |
| 2763 | /// Verify the state of operation results within the format. |
| 2764 | LogicalResult |
| 2765 | verifyResults(SMLoc loc, |
| 2766 | StringMap<TypeResolutionInstance> &variableTyResolver); |
| 2767 | |
| 2768 | /// Verify the state of operation successors within the format. |
| 2769 | LogicalResult verifySuccessors(SMLoc loc); |
| 2770 | |
| 2771 | LogicalResult verifyOIListElements(SMLoc loc, |
| 2772 | ArrayRef<FormatElement *> elements); |
| 2773 | |
| 2774 | /// Given the values of an `AllTypesMatch` trait, check for inferable type |
| 2775 | /// resolution. |
| 2776 | void handleAllTypesMatchConstraint( |
| 2777 | ArrayRef<StringRef> values, |
| 2778 | StringMap<TypeResolutionInstance> &variableTyResolver); |
| 2779 | /// Check for inferable type resolution given all operands, and or results, |
| 2780 | /// have the same type. If 'includeResults' is true, the results also have the |
| 2781 | /// same type as all of the operands. |
| 2782 | void handleSameTypesConstraint( |
| 2783 | StringMap<TypeResolutionInstance> &variableTyResolver, |
| 2784 | bool includeResults); |
| 2785 | /// Check for inferable type resolution based on another operand, result, or |
| 2786 | /// attribute. |
| 2787 | void handleTypesMatchConstraint( |
| 2788 | StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def); |
| 2789 | |
| 2790 | /// Returns an argument or attribute with the given name that has been seen |
| 2791 | /// within the format. |
| 2792 | ConstArgument findSeenArg(StringRef name); |
| 2793 | |
| 2794 | /// Parse the various different directives. |
| 2795 | FailureOr<FormatElement *> parsePropDictDirective(SMLoc loc, Context context); |
| 2796 | FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context, |
| 2797 | bool withKeyword); |
| 2798 | FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc, |
| 2799 | Context context); |
| 2800 | FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context); |
| 2801 | LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc); |
| 2802 | FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context); |
| 2803 | FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context); |
| 2804 | FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context); |
| 2805 | FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc, |
| 2806 | Context context); |
| 2807 | FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context); |
| 2808 | FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc, |
| 2809 | bool isRefChild = false); |
| 2810 | |
| 2811 | //===--------------------------------------------------------------------===// |
| 2812 | // Fields |
| 2813 | //===--------------------------------------------------------------------===// |
| 2814 | |
| 2815 | OperationFormat &fmt; |
| 2816 | Operator &op; |
| 2817 | |
| 2818 | // The following are various bits of format state used for verification |
| 2819 | // during parsing. |
| 2820 | bool hasAttrDict = false; |
| 2821 | bool hasPropDict = false; |
| 2822 | bool hasAllRegions = false, hasAllSuccessors = false; |
| 2823 | bool canInferResultTypes = false; |
| 2824 | llvm::SmallBitVector seenOperandTypes, seenResultTypes; |
| 2825 | llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs; |
| 2826 | llvm::DenseSet<const NamedTypeConstraint *> seenOperands; |
| 2827 | llvm::DenseSet<const NamedRegion *> seenRegions; |
| 2828 | llvm::DenseSet<const NamedSuccessor *> seenSuccessors; |
| 2829 | llvm::SmallSetVector<const NamedProperty *, 8> seenProperties; |
| 2830 | }; |
| 2831 | } // namespace |
| 2832 | |
| 2833 | LogicalResult OpFormatParser::verify(SMLoc loc, |
| 2834 | ArrayRef<FormatElement *> elements) { |
| 2835 | // Check that the attribute dictionary is in the format. |
| 2836 | if (!hasAttrDict) |
| 2837 | return emitError(loc, msg: "'attr-dict' directive not found in " |
| 2838 | "custom assembly format" ); |
| 2839 | |
| 2840 | // Check for any type traits that we can use for inferring types. |
| 2841 | StringMap<TypeResolutionInstance> variableTyResolver; |
| 2842 | for (const Trait &trait : op.getTraits()) { |
| 2843 | const Record &def = trait.getDef(); |
| 2844 | if (def.isSubClassOf(Name: "AllTypesMatch" )) { |
| 2845 | handleAllTypesMatchConstraint(values: def.getValueAsListOfStrings(FieldName: "values" ), |
| 2846 | variableTyResolver); |
| 2847 | } else if (def.getName() == "SameTypeOperands" ) { |
| 2848 | handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false); |
| 2849 | } else if (def.getName() == "SameOperandsAndResultType" ) { |
| 2850 | handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); |
| 2851 | } else if (def.isSubClassOf(Name: "TypesMatchWith" )) { |
| 2852 | handleTypesMatchConstraint(variableTyResolver, def); |
| 2853 | } else if (!op.allResultTypesKnown()) { |
| 2854 | // This doesn't check the name directly to handle |
| 2855 | // DeclareOpInterfaceMethods<InferTypeOpInterface> |
| 2856 | // and the like. |
| 2857 | // TODO: Add hasCppInterface check. |
| 2858 | if (auto name = def.getValueAsOptionalString(FieldName: "cppInterfaceName" )) { |
| 2859 | if (*name == "InferTypeOpInterface" && |
| 2860 | def.getValueAsString(FieldName: "cppNamespace" ) == "::mlir" ) |
| 2861 | canInferResultTypes = true; |
| 2862 | } |
| 2863 | } |
| 2864 | } |
| 2865 | |
| 2866 | // Verify the state of the various operation components. |
| 2867 | if (failed(Result: verifyAttributes(loc, elements)) || |
| 2868 | failed(Result: verifyResults(loc, variableTyResolver)) || |
| 2869 | failed(Result: verifyOperands(loc, variableTyResolver)) || |
| 2870 | failed(Result: verifyRegions(loc)) || failed(Result: verifySuccessors(loc)) || |
| 2871 | failed(Result: verifyOIListElements(loc, elements))) |
| 2872 | return failure(); |
| 2873 | |
| 2874 | // Collect the set of used attributes in the format. |
| 2875 | fmt.usedAttributes = std::move(seenAttrs); |
| 2876 | fmt.usedProperties = std::move(seenProperties); |
| 2877 | |
| 2878 | // Set whether prop-dict is used in the format |
| 2879 | fmt.hasPropDict = hasPropDict; |
| 2880 | return success(); |
| 2881 | } |
| 2882 | |
| 2883 | LogicalResult |
| 2884 | OpFormatParser::verifyAttributes(SMLoc loc, |
| 2885 | ArrayRef<FormatElement *> elements) { |
| 2886 | // Check that there are no `:` literals after an attribute without a constant |
| 2887 | // type. The attribute grammar contains an optional trailing colon type, which |
| 2888 | // can lead to unexpected and generally unintended behavior. Given that, it is |
| 2889 | // better to just error out here instead. |
| 2890 | if (failed(Result: verifyAttributeColonType(loc, elements))) |
| 2891 | return failure(); |
| 2892 | // Check that there are no region variables following an attribute dicitonary. |
| 2893 | // Both start with `{` and so the optional attribute dictionary can cause |
| 2894 | // format ambiguities. |
| 2895 | if (failed(Result: verifyAttrDictRegion(loc, elements))) |
| 2896 | return failure(); |
| 2897 | |
| 2898 | // Check for VariadicOfVariadic variables. The segment attribute of those |
| 2899 | // variables will be infered. |
| 2900 | for (const NamedTypeConstraint *var : seenOperands) { |
| 2901 | if (var->constraint.isVariadicOfVariadic()) { |
| 2902 | fmt.inferredAttributes.insert( |
| 2903 | key: var->constraint.getVariadicOfVariadicSegmentSizeAttr()); |
| 2904 | } |
| 2905 | } |
| 2906 | |
| 2907 | return success(); |
| 2908 | } |
| 2909 | |
| 2910 | /// Returns whether the single format element is optionally parsed. |
| 2911 | static bool isOptionallyParsed(FormatElement *el) { |
| 2912 | if (auto *attrVar = dyn_cast<AttributeVariable>(Val: el)) { |
| 2913 | Attribute attr = attrVar->getVar()->attr; |
| 2914 | return attr.isOptional() || attr.hasDefaultValue(); |
| 2915 | } |
| 2916 | if (auto *propVar = dyn_cast<PropertyVariable>(Val: el)) { |
| 2917 | const Property &prop = propVar->getVar()->prop; |
| 2918 | return prop.hasDefaultValue() && prop.hasOptionalParser(); |
| 2919 | } |
| 2920 | if (auto *operandVar = dyn_cast<OperandVariable>(Val: el)) { |
| 2921 | const NamedTypeConstraint *operand = operandVar->getVar(); |
| 2922 | return operand->isOptional() || operand->isVariadic() || |
| 2923 | operand->isVariadicOfVariadic(); |
| 2924 | } |
| 2925 | if (auto *successorVar = dyn_cast<SuccessorVariable>(Val: el)) |
| 2926 | return successorVar->getVar()->isVariadic(); |
| 2927 | if (auto *regionVar = dyn_cast<RegionVariable>(Val: el)) |
| 2928 | return regionVar->getVar()->isVariadic(); |
| 2929 | return isa<WhitespaceElement, AttrDictDirective>(Val: el); |
| 2930 | } |
| 2931 | |
| 2932 | /// Scan the given range of elements from the start for an invalid format |
| 2933 | /// element that satisfies `isInvalid`, skipping any optionally-parsed elements. |
| 2934 | /// If an optional group is encountered, this function recurses into the 'then' |
| 2935 | /// and 'else' elements to check if they are invalid. Returns `success` if the |
| 2936 | /// range is known to be valid or `std::nullopt` if scanning reached the end. |
| 2937 | /// |
| 2938 | /// Since the guard element of an optional group is required, this function |
| 2939 | /// accepts an optional element pointer to mark it as required. |
| 2940 | static std::optional<LogicalResult> checkRangeForElement( |
| 2941 | FormatElement *base, |
| 2942 | function_ref<bool(FormatElement *, FormatElement *)> isInvalid, |
| 2943 | iterator_range<ArrayRef<FormatElement *>::iterator> elementRange, |
| 2944 | FormatElement *optionalGuard = nullptr) { |
| 2945 | for (FormatElement *element : elementRange) { |
| 2946 | // If we encounter an invalid element, return an error. |
| 2947 | if (isInvalid(base, element)) |
| 2948 | return failure(); |
| 2949 | |
| 2950 | // Recurse on optional groups. |
| 2951 | if (auto *optional = dyn_cast<OptionalElement>(Val: element)) { |
| 2952 | if (std::optional<LogicalResult> result = checkRangeForElement( |
| 2953 | base, isInvalid, elementRange: optional->getThenElements(), |
| 2954 | // The optional group guard is required for the group. |
| 2955 | optionalGuard: optional->getThenElements().front())) |
| 2956 | if (failed(Result: *result)) |
| 2957 | return failure(); |
| 2958 | if (std::optional<LogicalResult> result = checkRangeForElement( |
| 2959 | base, isInvalid, elementRange: optional->getElseElements())) |
| 2960 | if (failed(Result: *result)) |
| 2961 | return failure(); |
| 2962 | // Skip the optional group. |
| 2963 | continue; |
| 2964 | } |
| 2965 | |
| 2966 | // Skip optionally parsed elements. |
| 2967 | if (element != optionalGuard && isOptionallyParsed(el: element)) |
| 2968 | continue; |
| 2969 | |
| 2970 | // We found a closing element that is valid. |
| 2971 | return success(); |
| 2972 | } |
| 2973 | // Return std::nullopt to indicate that we reached the end. |
| 2974 | return std::nullopt; |
| 2975 | } |
| 2976 | |
| 2977 | /// For the given elements, check whether any attributes are followed by a colon |
| 2978 | /// literal, resulting in an ambiguous assembly format. Returns a non-null |
| 2979 | /// attribute if verification of said attribute reached the end of the range. |
| 2980 | /// Returns null if all attribute elements are verified. |
| 2981 | static FailureOr<FormatElement *> verifyAdjacentElements( |
| 2982 | function_ref<bool(FormatElement *)> isBase, |
| 2983 | function_ref<bool(FormatElement *, FormatElement *)> isInvalid, |
| 2984 | ArrayRef<FormatElement *> elements) { |
| 2985 | for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) { |
| 2986 | // The current attribute being verified. |
| 2987 | FormatElement *base; |
| 2988 | |
| 2989 | if (isBase(*it)) { |
| 2990 | base = *it; |
| 2991 | } else if (auto *optional = dyn_cast<OptionalElement>(Val: *it)) { |
| 2992 | // Recurse on optional groups. |
| 2993 | FailureOr<FormatElement *> thenResult = verifyAdjacentElements( |
| 2994 | isBase, isInvalid, elements: optional->getThenElements()); |
| 2995 | if (failed(Result: thenResult)) |
| 2996 | return failure(); |
| 2997 | FailureOr<FormatElement *> elseResult = verifyAdjacentElements( |
| 2998 | isBase, isInvalid, elements: optional->getElseElements()); |
| 2999 | if (failed(Result: elseResult)) |
| 3000 | return failure(); |
| 3001 | // If either optional group has an unverified attribute, save it. |
| 3002 | // Otherwise, move on to the next element. |
| 3003 | if (!(base = *thenResult) && !(base = *elseResult)) |
| 3004 | continue; |
| 3005 | } else { |
| 3006 | continue; |
| 3007 | } |
| 3008 | |
| 3009 | // Verify subsequent elements for potential ambiguities. |
| 3010 | if (std::optional<LogicalResult> result = |
| 3011 | checkRangeForElement(base, isInvalid, elementRange: {std::next(x: it), e})) { |
| 3012 | if (failed(Result: *result)) |
| 3013 | return failure(); |
| 3014 | } else { |
| 3015 | // Since we reached the end, return the attribute as unverified. |
| 3016 | return base; |
| 3017 | } |
| 3018 | } |
| 3019 | // All attribute elements are known to be verified. |
| 3020 | return nullptr; |
| 3021 | } |
| 3022 | |
| 3023 | LogicalResult |
| 3024 | OpFormatParser::verifyAttributeColonType(SMLoc loc, |
| 3025 | ArrayRef<FormatElement *> elements) { |
| 3026 | auto isBase = [](FormatElement *el) { |
| 3027 | auto *attr = dyn_cast<AttributeVariable>(Val: el); |
| 3028 | if (!attr) |
| 3029 | return false; |
| 3030 | // Check only attributes without type builders or that are known to call |
| 3031 | // the generic attribute parser. |
| 3032 | return !attr->getTypeBuilder() && |
| 3033 | (attr->shouldBeQualified() || |
| 3034 | attr->getVar()->attr.getStorageType() == "::mlir::Attribute" ); |
| 3035 | }; |
| 3036 | auto isInvalid = [&](FormatElement *base, FormatElement *el) { |
| 3037 | auto *literal = dyn_cast<LiteralElement>(Val: el); |
| 3038 | if (!literal || literal->getSpelling() != ":" ) |
| 3039 | return false; |
| 3040 | // If we encounter `:`, the range is known to be invalid. |
| 3041 | (void)emitError( |
| 3042 | loc, msg: formatv(Fmt: "format ambiguity caused by `:` literal found after " |
| 3043 | "attribute `{0}` which does not have a buildable type" , |
| 3044 | Vals: cast<AttributeVariable>(Val: base)->getVar()->name)); |
| 3045 | return true; |
| 3046 | }; |
| 3047 | return verifyAdjacentElements(isBase, isInvalid, elements); |
| 3048 | } |
| 3049 | |
| 3050 | LogicalResult |
| 3051 | OpFormatParser::verifyAttrDictRegion(SMLoc loc, |
| 3052 | ArrayRef<FormatElement *> elements) { |
| 3053 | auto isBase = [](FormatElement *el) { |
| 3054 | if (auto *attrDict = dyn_cast<AttrDictDirective>(Val: el)) |
| 3055 | return !attrDict->isWithKeyword(); |
| 3056 | return false; |
| 3057 | }; |
| 3058 | auto isInvalid = [&](FormatElement *base, FormatElement *el) { |
| 3059 | auto *region = dyn_cast<RegionVariable>(Val: el); |
| 3060 | if (!region) |
| 3061 | return false; |
| 3062 | (void)emitErrorAndNote( |
| 3063 | loc, |
| 3064 | msg: formatv(Fmt: "format ambiguity caused by `attr-dict` directive " |
| 3065 | "followed by region `{0}`" , |
| 3066 | Vals: region->getVar()->name), |
| 3067 | note: "try using `attr-dict-with-keyword` instead" ); |
| 3068 | return true; |
| 3069 | }; |
| 3070 | return verifyAdjacentElements(isBase, isInvalid, elements); |
| 3071 | } |
| 3072 | |
| 3073 | LogicalResult OpFormatParser::verifyOperands( |
| 3074 | SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) { |
| 3075 | // Check that all of the operands are within the format, and their types can |
| 3076 | // be inferred. |
| 3077 | auto &buildableTypes = fmt.buildableTypes; |
| 3078 | for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { |
| 3079 | NamedTypeConstraint &operand = op.getOperand(index: i); |
| 3080 | |
| 3081 | // Check that the operand itself is in the format. |
| 3082 | if (!fmt.allOperands && !seenOperands.count(V: &operand)) { |
| 3083 | return emitErrorAndNote(loc, |
| 3084 | msg: "operand #" + Twine(i) + ", named '" + |
| 3085 | operand.name + "', not found" , |
| 3086 | note: "suggest adding a '$" + operand.name + |
| 3087 | "' directive to the custom assembly format" ); |
| 3088 | } |
| 3089 | |
| 3090 | // Check that the operand type is in the format, or that it can be inferred. |
| 3091 | if (fmt.allOperandTypes || seenOperandTypes.test(Idx: i)) |
| 3092 | continue; |
| 3093 | |
| 3094 | // Check to see if we can infer this type from another variable. |
| 3095 | auto varResolverIt = variableTyResolver.find(Key: op.getOperand(index: i).name); |
| 3096 | if (varResolverIt != variableTyResolver.end()) { |
| 3097 | TypeResolutionInstance &resolver = varResolverIt->second; |
| 3098 | fmt.operandTypes[i].setResolver(arg: resolver.resolver, transformer: resolver.transformer); |
| 3099 | continue; |
| 3100 | } |
| 3101 | |
| 3102 | // Similarly to results, allow a custom builder for resolving the type if |
| 3103 | // we aren't using the 'operands' directive. |
| 3104 | std::optional<StringRef> builder = operand.constraint.getBuilderCall(); |
| 3105 | if (!builder || (fmt.allOperands && operand.isVariableLength())) { |
| 3106 | return emitErrorAndNote( |
| 3107 | loc, |
| 3108 | msg: "type of operand #" + Twine(i) + ", named '" + operand.name + |
| 3109 | "', is not buildable and a buildable type cannot be inferred" , |
| 3110 | note: "suggest adding a type constraint to the operation or adding a " |
| 3111 | "'type($" + |
| 3112 | operand.name + ")' directive to the " + "custom assembly format" ); |
| 3113 | } |
| 3114 | auto it = buildableTypes.insert(KV: {*builder, buildableTypes.size()}); |
| 3115 | fmt.operandTypes[i].setBuilderIdx(it.first->second); |
| 3116 | } |
| 3117 | return success(); |
| 3118 | } |
| 3119 | |
| 3120 | LogicalResult OpFormatParser::verifyRegions(SMLoc loc) { |
| 3121 | // Check that all of the regions are within the format. |
| 3122 | if (hasAllRegions) |
| 3123 | return success(); |
| 3124 | |
| 3125 | for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) { |
| 3126 | const NamedRegion ®ion = op.getRegion(index: i); |
| 3127 | if (!seenRegions.count(V: ®ion)) { |
| 3128 | return emitErrorAndNote(loc, |
| 3129 | msg: "region #" + Twine(i) + ", named '" + |
| 3130 | region.name + "', not found" , |
| 3131 | note: "suggest adding a '$" + region.name + |
| 3132 | "' directive to the custom assembly format" ); |
| 3133 | } |
| 3134 | } |
| 3135 | return success(); |
| 3136 | } |
| 3137 | |
| 3138 | LogicalResult OpFormatParser::verifyResults( |
| 3139 | SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) { |
| 3140 | // If we format all of the types together, there is nothing to check. |
| 3141 | if (fmt.allResultTypes) |
| 3142 | return success(); |
| 3143 | |
| 3144 | // If no result types are specified and we can infer them, infer all result |
| 3145 | // types |
| 3146 | if (op.getNumResults() > 0 && seenResultTypes.count() == 0 && |
| 3147 | canInferResultTypes) { |
| 3148 | fmt.infersResultTypes = true; |
| 3149 | return success(); |
| 3150 | } |
| 3151 | |
| 3152 | // Check that all of the result types can be inferred. |
| 3153 | auto &buildableTypes = fmt.buildableTypes; |
| 3154 | for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { |
| 3155 | if (seenResultTypes.test(Idx: i)) |
| 3156 | continue; |
| 3157 | |
| 3158 | // Check to see if we can infer this type from another variable. |
| 3159 | auto varResolverIt = variableTyResolver.find(Key: op.getResultName(index: i)); |
| 3160 | if (varResolverIt != variableTyResolver.end()) { |
| 3161 | TypeResolutionInstance resolver = varResolverIt->second; |
| 3162 | fmt.resultTypes[i].setResolver(arg: resolver.resolver, transformer: resolver.transformer); |
| 3163 | continue; |
| 3164 | } |
| 3165 | |
| 3166 | // If the result is not variable length, allow for the case where the type |
| 3167 | // has a builder that we can use. |
| 3168 | NamedTypeConstraint &result = op.getResult(index: i); |
| 3169 | std::optional<StringRef> builder = result.constraint.getBuilderCall(); |
| 3170 | if (!builder || result.isVariableLength()) { |
| 3171 | return emitErrorAndNote( |
| 3172 | loc, |
| 3173 | msg: "type of result #" + Twine(i) + ", named '" + result.name + |
| 3174 | "', is not buildable and a buildable type cannot be inferred" , |
| 3175 | note: "suggest adding a type constraint to the operation or adding a " |
| 3176 | "'type($" + |
| 3177 | result.name + ")' directive to the " + "custom assembly format" ); |
| 3178 | } |
| 3179 | // Note in the format that this result uses the custom builder. |
| 3180 | auto it = buildableTypes.insert(KV: {*builder, buildableTypes.size()}); |
| 3181 | fmt.resultTypes[i].setBuilderIdx(it.first->second); |
| 3182 | } |
| 3183 | return success(); |
| 3184 | } |
| 3185 | |
| 3186 | LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) { |
| 3187 | // Check that all of the successors are within the format. |
| 3188 | if (hasAllSuccessors) |
| 3189 | return success(); |
| 3190 | |
| 3191 | for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) { |
| 3192 | const NamedSuccessor &successor = op.getSuccessor(index: i); |
| 3193 | if (!seenSuccessors.count(V: &successor)) { |
| 3194 | return emitErrorAndNote(loc, |
| 3195 | msg: "successor #" + Twine(i) + ", named '" + |
| 3196 | successor.name + "', not found" , |
| 3197 | note: "suggest adding a '$" + successor.name + |
| 3198 | "' directive to the custom assembly format" ); |
| 3199 | } |
| 3200 | } |
| 3201 | return success(); |
| 3202 | } |
| 3203 | |
| 3204 | LogicalResult |
| 3205 | OpFormatParser::verifyOIListElements(SMLoc loc, |
| 3206 | ArrayRef<FormatElement *> elements) { |
| 3207 | // Check that all of the successors are within the format. |
| 3208 | SmallVector<StringRef> prohibitedLiterals; |
| 3209 | for (FormatElement *it : elements) { |
| 3210 | if (auto *oilist = dyn_cast<OIListElement>(Val: it)) { |
| 3211 | if (!prohibitedLiterals.empty()) { |
| 3212 | // We just saw an oilist element in last iteration. Literals should not |
| 3213 | // match. |
| 3214 | for (LiteralElement *literal : oilist->getLiteralElements()) { |
| 3215 | if (find(Range&: prohibitedLiterals, Val: literal->getSpelling()) != |
| 3216 | prohibitedLiterals.end()) { |
| 3217 | return emitError( |
| 3218 | loc, msg: "format ambiguity because " + literal->getSpelling() + |
| 3219 | " is used in two adjacent oilist elements." ); |
| 3220 | } |
| 3221 | } |
| 3222 | } |
| 3223 | for (LiteralElement *literal : oilist->getLiteralElements()) |
| 3224 | prohibitedLiterals.push_back(Elt: literal->getSpelling()); |
| 3225 | } else if (auto *literal = dyn_cast<LiteralElement>(Val: it)) { |
| 3226 | if (find(Range&: prohibitedLiterals, Val: literal->getSpelling()) != |
| 3227 | prohibitedLiterals.end()) { |
| 3228 | return emitError( |
| 3229 | loc, |
| 3230 | msg: "format ambiguity because " + literal->getSpelling() + |
| 3231 | " is used both in oilist element and the adjacent literal." ); |
| 3232 | } |
| 3233 | prohibitedLiterals.clear(); |
| 3234 | } else { |
| 3235 | prohibitedLiterals.clear(); |
| 3236 | } |
| 3237 | } |
| 3238 | return success(); |
| 3239 | } |
| 3240 | |
| 3241 | void OpFormatParser::handleAllTypesMatchConstraint( |
| 3242 | ArrayRef<StringRef> values, |
| 3243 | StringMap<TypeResolutionInstance> &variableTyResolver) { |
| 3244 | for (unsigned i = 0, e = values.size(); i != e; ++i) { |
| 3245 | // Check to see if this value matches a resolved operand or result type. |
| 3246 | ConstArgument arg = findSeenArg(name: values[i]); |
| 3247 | if (!arg) |
| 3248 | continue; |
| 3249 | |
| 3250 | // Mark this value as the type resolver for the other variables. |
| 3251 | for (unsigned j = 0; j != i; ++j) |
| 3252 | variableTyResolver[values[j]] = {.resolver: arg, .transformer: std::nullopt}; |
| 3253 | for (unsigned j = i + 1; j != e; ++j) |
| 3254 | variableTyResolver[values[j]] = {.resolver: arg, .transformer: std::nullopt}; |
| 3255 | } |
| 3256 | } |
| 3257 | |
| 3258 | void OpFormatParser::handleSameTypesConstraint( |
| 3259 | StringMap<TypeResolutionInstance> &variableTyResolver, |
| 3260 | bool includeResults) { |
| 3261 | const NamedTypeConstraint *resolver = nullptr; |
| 3262 | int resolvedIt = -1; |
| 3263 | |
| 3264 | // Check to see if there is an operand or result to use for the resolution. |
| 3265 | if ((resolvedIt = seenOperandTypes.find_first()) != -1) |
| 3266 | resolver = &op.getOperand(index: resolvedIt); |
| 3267 | else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1) |
| 3268 | resolver = &op.getResult(index: resolvedIt); |
| 3269 | else |
| 3270 | return; |
| 3271 | |
| 3272 | // Set the resolvers for each operand and result. |
| 3273 | for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) |
| 3274 | if (!seenOperandTypes.test(Idx: i)) |
| 3275 | variableTyResolver[op.getOperand(index: i).name] = {.resolver: resolver, .transformer: std::nullopt}; |
| 3276 | if (includeResults) { |
| 3277 | for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) |
| 3278 | if (!seenResultTypes.test(Idx: i)) |
| 3279 | variableTyResolver[op.getResultName(index: i)] = {.resolver: resolver, .transformer: std::nullopt}; |
| 3280 | } |
| 3281 | } |
| 3282 | |
| 3283 | void OpFormatParser::handleTypesMatchConstraint( |
| 3284 | StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) { |
| 3285 | StringRef lhsName = def.getValueAsString(FieldName: "lhs" ); |
| 3286 | StringRef rhsName = def.getValueAsString(FieldName: "rhs" ); |
| 3287 | StringRef transformer = def.getValueAsString(FieldName: "transformer" ); |
| 3288 | if (ConstArgument arg = findSeenArg(name: lhsName)) |
| 3289 | variableTyResolver[rhsName] = {.resolver: arg, .transformer: transformer}; |
| 3290 | } |
| 3291 | |
| 3292 | ConstArgument OpFormatParser::findSeenArg(StringRef name) { |
| 3293 | if (const NamedTypeConstraint *arg = findArg(range: op.getOperands(), name)) |
| 3294 | return seenOperandTypes.test(Idx: arg - op.operand_begin()) ? arg : nullptr; |
| 3295 | if (const NamedTypeConstraint *arg = findArg(range: op.getResults(), name)) |
| 3296 | return seenResultTypes.test(Idx: arg - op.result_begin()) ? arg : nullptr; |
| 3297 | if (const NamedAttribute *attr = findArg(range: op.getAttributes(), name)) |
| 3298 | return seenAttrs.count(key: attr) ? attr : nullptr; |
| 3299 | return nullptr; |
| 3300 | } |
| 3301 | |
| 3302 | FailureOr<FormatElement *> |
| 3303 | OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) { |
| 3304 | // Check that the parsed argument is something actually registered on the op. |
| 3305 | // Attributes |
| 3306 | if (const NamedAttribute *attr = findArg(range: op.getAttributes(), name)) { |
| 3307 | if (ctx == TypeDirectiveContext) |
| 3308 | return emitError( |
| 3309 | loc, msg: "attributes cannot be used as children to a `type` directive" ); |
| 3310 | if (ctx == RefDirectiveContext) { |
| 3311 | if (!seenAttrs.count(key: attr)) |
| 3312 | return emitError(loc, msg: "attribute '" + name + |
| 3313 | "' must be bound before it is referenced" ); |
| 3314 | } else if (!seenAttrs.insert(X: attr)) { |
| 3315 | return emitError(loc, msg: "attribute '" + name + "' is already bound" ); |
| 3316 | } |
| 3317 | |
| 3318 | return create<AttributeVariable>(args&: attr); |
| 3319 | } |
| 3320 | |
| 3321 | if (const NamedProperty *property = findArg(range: op.getProperties(), name)) { |
| 3322 | if (ctx == TypeDirectiveContext) |
| 3323 | return emitError( |
| 3324 | loc, msg: "properties cannot be used as children to a `type` directive" ); |
| 3325 | if (ctx == RefDirectiveContext) { |
| 3326 | if (!seenProperties.count(key: property)) |
| 3327 | return emitError(loc, msg: "property '" + name + |
| 3328 | "' must be bound before it is referenced" ); |
| 3329 | } else { |
| 3330 | if (!seenProperties.insert(X: property)) |
| 3331 | return emitError(loc, msg: "property '" + name + "' is already bound" ); |
| 3332 | } |
| 3333 | |
| 3334 | return create<PropertyVariable>(args&: property); |
| 3335 | } |
| 3336 | |
| 3337 | // Operands |
| 3338 | if (const NamedTypeConstraint *operand = findArg(range: op.getOperands(), name)) { |
| 3339 | if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { |
| 3340 | if (fmt.allOperands || !seenOperands.insert(V: operand).second) |
| 3341 | return emitError(loc, msg: "operand '" + name + "' is already bound" ); |
| 3342 | } else if (ctx == RefDirectiveContext && !seenOperands.count(V: operand)) { |
| 3343 | return emitError(loc, msg: "operand '" + name + |
| 3344 | "' must be bound before it is referenced" ); |
| 3345 | } |
| 3346 | return create<OperandVariable>(args&: operand); |
| 3347 | } |
| 3348 | // Regions |
| 3349 | if (const NamedRegion *region = findArg(range: op.getRegions(), name)) { |
| 3350 | if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { |
| 3351 | if (hasAllRegions || !seenRegions.insert(V: region).second) |
| 3352 | return emitError(loc, msg: "region '" + name + "' is already bound" ); |
| 3353 | } else if (ctx == RefDirectiveContext && !seenRegions.count(V: region)) { |
| 3354 | return emitError(loc, msg: "region '" + name + |
| 3355 | "' must be bound before it is referenced" ); |
| 3356 | } else { |
| 3357 | return emitError(loc, msg: "regions can only be used at the top level" ); |
| 3358 | } |
| 3359 | return create<RegionVariable>(args&: region); |
| 3360 | } |
| 3361 | // Results. |
| 3362 | if (const auto *result = findArg(range: op.getResults(), name)) { |
| 3363 | if (ctx != TypeDirectiveContext) |
| 3364 | return emitError(loc, msg: "result variables can can only be used as a child " |
| 3365 | "to a 'type' directive" ); |
| 3366 | return create<ResultVariable>(args&: result); |
| 3367 | } |
| 3368 | // Successors. |
| 3369 | if (const auto *successor = findArg(range: op.getSuccessors(), name)) { |
| 3370 | if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { |
| 3371 | if (hasAllSuccessors || !seenSuccessors.insert(V: successor).second) |
| 3372 | return emitError(loc, msg: "successor '" + name + "' is already bound" ); |
| 3373 | } else if (ctx == RefDirectiveContext && !seenSuccessors.count(V: successor)) { |
| 3374 | return emitError(loc, msg: "successor '" + name + |
| 3375 | "' must be bound before it is referenced" ); |
| 3376 | } else { |
| 3377 | return emitError(loc, msg: "successors can only be used at the top level" ); |
| 3378 | } |
| 3379 | |
| 3380 | return create<SuccessorVariable>(args&: successor); |
| 3381 | } |
| 3382 | return emitError(loc, msg: "expected variable to refer to an argument, region, " |
| 3383 | "result, or successor" ); |
| 3384 | } |
| 3385 | |
| 3386 | FailureOr<FormatElement *> |
| 3387 | OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, |
| 3388 | Context ctx) { |
| 3389 | switch (kind) { |
| 3390 | case FormatToken::kw_prop_dict: |
| 3391 | return parsePropDictDirective(loc, context: ctx); |
| 3392 | case FormatToken::kw_attr_dict: |
| 3393 | return parseAttrDictDirective(loc, context: ctx, |
| 3394 | /*withKeyword=*/false); |
| 3395 | case FormatToken::kw_attr_dict_w_keyword: |
| 3396 | return parseAttrDictDirective(loc, context: ctx, |
| 3397 | /*withKeyword=*/true); |
| 3398 | case FormatToken::kw_functional_type: |
| 3399 | return parseFunctionalTypeDirective(loc, context: ctx); |
| 3400 | case FormatToken::kw_operands: |
| 3401 | return parseOperandsDirective(loc, context: ctx); |
| 3402 | case FormatToken::kw_regions: |
| 3403 | return parseRegionsDirective(loc, context: ctx); |
| 3404 | case FormatToken::kw_results: |
| 3405 | return parseResultsDirective(loc, context: ctx); |
| 3406 | case FormatToken::kw_successors: |
| 3407 | return parseSuccessorsDirective(loc, context: ctx); |
| 3408 | case FormatToken::kw_type: |
| 3409 | return parseTypeDirective(loc, context: ctx); |
| 3410 | case FormatToken::kw_oilist: |
| 3411 | return parseOIListDirective(loc, context: ctx); |
| 3412 | |
| 3413 | default: |
| 3414 | return emitError(loc, msg: "unsupported directive kind" ); |
| 3415 | } |
| 3416 | } |
| 3417 | |
| 3418 | FailureOr<FormatElement *> |
| 3419 | OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context, |
| 3420 | bool withKeyword) { |
| 3421 | if (context == TypeDirectiveContext) |
| 3422 | return emitError(loc, msg: "'attr-dict' directive can only be used as a " |
| 3423 | "top-level directive" ); |
| 3424 | |
| 3425 | if (context == RefDirectiveContext) { |
| 3426 | if (!hasAttrDict) |
| 3427 | return emitError(loc, msg: "'ref' of 'attr-dict' is not bound by a prior " |
| 3428 | "'attr-dict' directive" ); |
| 3429 | |
| 3430 | // Otherwise, this is a top-level context. |
| 3431 | } else { |
| 3432 | if (hasAttrDict) |
| 3433 | return emitError(loc, msg: "'attr-dict' directive has already been seen" ); |
| 3434 | hasAttrDict = true; |
| 3435 | } |
| 3436 | |
| 3437 | return create<AttrDictDirective>(args&: withKeyword); |
| 3438 | } |
| 3439 | |
| 3440 | FailureOr<FormatElement *> |
| 3441 | OpFormatParser::parsePropDictDirective(SMLoc loc, Context context) { |
| 3442 | if (context == TypeDirectiveContext) |
| 3443 | return emitError(loc, msg: "'prop-dict' directive can only be used as a " |
| 3444 | "top-level directive" ); |
| 3445 | |
| 3446 | if (context == RefDirectiveContext) |
| 3447 | llvm::report_fatal_error(reason: "'ref' of 'prop-dict' unsupported" ); |
| 3448 | // Otherwise, this is a top-level context. |
| 3449 | |
| 3450 | if (hasPropDict) |
| 3451 | return emitError(loc, msg: "'prop-dict' directive has already been seen" ); |
| 3452 | hasPropDict = true; |
| 3453 | |
| 3454 | return create<PropDictDirective>(); |
| 3455 | } |
| 3456 | |
| 3457 | LogicalResult OpFormatParser::verifyCustomDirectiveArguments( |
| 3458 | SMLoc loc, ArrayRef<FormatElement *> arguments) { |
| 3459 | for (FormatElement *argument : arguments) { |
| 3460 | if (!isa<AttrDictDirective, PropDictDirective, AttributeVariable, |
| 3461 | OperandVariable, PropertyVariable, RefDirective, RegionVariable, |
| 3462 | SuccessorVariable, StringElement, TypeDirective>(Val: argument)) { |
| 3463 | // TODO: FormatElement should have location info attached. |
| 3464 | return emitError(loc, msg: "only variables and types may be used as " |
| 3465 | "parameters to a custom directive" ); |
| 3466 | } |
| 3467 | if (auto *type = dyn_cast<TypeDirective>(Val: argument)) { |
| 3468 | if (!isa<OperandVariable, ResultVariable>(Val: type->getArg())) { |
| 3469 | return emitError(loc, msg: "type directives within a custom directive may " |
| 3470 | "only refer to variables" ); |
| 3471 | } |
| 3472 | } |
| 3473 | } |
| 3474 | return success(); |
| 3475 | } |
| 3476 | |
| 3477 | FailureOr<FormatElement *> |
| 3478 | OpFormatParser::parseFunctionalTypeDirective(SMLoc loc, Context context) { |
| 3479 | if (context != TopLevelContext) |
| 3480 | return emitError( |
| 3481 | loc, msg: "'functional-type' is only valid as a top-level directive" ); |
| 3482 | |
| 3483 | // Parse the main operand. |
| 3484 | FailureOr<FormatElement *> inputs, results; |
| 3485 | if (failed(Result: parseToken(kind: FormatToken::l_paren, |
| 3486 | msg: "expected '(' before argument list" )) || |
| 3487 | failed(Result: inputs = parseTypeDirectiveOperand(loc)) || |
| 3488 | failed(Result: parseToken(kind: FormatToken::comma, |
| 3489 | msg: "expected ',' after inputs argument" )) || |
| 3490 | failed(Result: results = parseTypeDirectiveOperand(loc)) || |
| 3491 | failed( |
| 3492 | Result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list" ))) |
| 3493 | return failure(); |
| 3494 | return create<FunctionalTypeDirective>(args&: *inputs, args&: *results); |
| 3495 | } |
| 3496 | |
| 3497 | FailureOr<FormatElement *> |
| 3498 | OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) { |
| 3499 | if (context == RefDirectiveContext) { |
| 3500 | if (!fmt.allOperands) |
| 3501 | return emitError(loc, msg: "'ref' of 'operands' is not bound by a prior " |
| 3502 | "'operands' directive" ); |
| 3503 | |
| 3504 | } else if (context == TopLevelContext || context == CustomDirectiveContext) { |
| 3505 | if (fmt.allOperands || !seenOperands.empty()) |
| 3506 | return emitError(loc, msg: "'operands' directive creates overlap in format" ); |
| 3507 | fmt.allOperands = true; |
| 3508 | } |
| 3509 | return create<OperandsDirective>(); |
| 3510 | } |
| 3511 | |
| 3512 | FailureOr<FormatElement *> |
| 3513 | OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) { |
| 3514 | if (context == TypeDirectiveContext) |
| 3515 | return emitError(loc, msg: "'regions' is only valid as a top-level directive" ); |
| 3516 | if (context == RefDirectiveContext) { |
| 3517 | if (!hasAllRegions) |
| 3518 | return emitError(loc, msg: "'ref' of 'regions' is not bound by a prior " |
| 3519 | "'regions' directive" ); |
| 3520 | |
| 3521 | // Otherwise, this is a TopLevel directive. |
| 3522 | } else { |
| 3523 | if (hasAllRegions || !seenRegions.empty()) |
| 3524 | return emitError(loc, msg: "'regions' directive creates overlap in format" ); |
| 3525 | hasAllRegions = true; |
| 3526 | } |
| 3527 | return create<RegionsDirective>(); |
| 3528 | } |
| 3529 | |
| 3530 | FailureOr<FormatElement *> |
| 3531 | OpFormatParser::parseResultsDirective(SMLoc loc, Context context) { |
| 3532 | if (context != TypeDirectiveContext) |
| 3533 | return emitError(loc, msg: "'results' directive can can only be used as a child " |
| 3534 | "to a 'type' directive" ); |
| 3535 | return create<ResultsDirective>(); |
| 3536 | } |
| 3537 | |
| 3538 | FailureOr<FormatElement *> |
| 3539 | OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) { |
| 3540 | if (context == TypeDirectiveContext) |
| 3541 | return emitError(loc, |
| 3542 | msg: "'successors' is only valid as a top-level directive" ); |
| 3543 | if (context == RefDirectiveContext) { |
| 3544 | if (!hasAllSuccessors) |
| 3545 | return emitError(loc, msg: "'ref' of 'successors' is not bound by a prior " |
| 3546 | "'successors' directive" ); |
| 3547 | |
| 3548 | // Otherwise, this is a TopLevel directive. |
| 3549 | } else { |
| 3550 | if (hasAllSuccessors || !seenSuccessors.empty()) |
| 3551 | return emitError(loc, msg: "'successors' directive creates overlap in format" ); |
| 3552 | hasAllSuccessors = true; |
| 3553 | } |
| 3554 | return create<SuccessorsDirective>(); |
| 3555 | } |
| 3556 | |
| 3557 | FailureOr<FormatElement *> |
| 3558 | OpFormatParser::parseOIListDirective(SMLoc loc, Context context) { |
| 3559 | if (failed(Result: parseToken(kind: FormatToken::l_paren, |
| 3560 | msg: "expected '(' before oilist argument list" ))) |
| 3561 | return failure(); |
| 3562 | std::vector<FormatElement *> literalElements; |
| 3563 | std::vector<std::vector<FormatElement *>> parsingElements; |
| 3564 | do { |
| 3565 | FailureOr<FormatElement *> lelement = parseLiteral(ctx: context); |
| 3566 | if (failed(Result: lelement)) |
| 3567 | return failure(); |
| 3568 | literalElements.push_back(x: *lelement); |
| 3569 | parsingElements.emplace_back(); |
| 3570 | std::vector<FormatElement *> &currParsingElements = parsingElements.back(); |
| 3571 | while (peekToken().getKind() != FormatToken::pipe && |
| 3572 | peekToken().getKind() != FormatToken::r_paren) { |
| 3573 | FailureOr<FormatElement *> pelement = parseElement(ctx: context); |
| 3574 | if (failed(Result: pelement) || |
| 3575 | failed(Result: verifyOIListParsingElement(element: *pelement, loc))) |
| 3576 | return failure(); |
| 3577 | currParsingElements.push_back(x: *pelement); |
| 3578 | } |
| 3579 | if (peekToken().getKind() == FormatToken::pipe) { |
| 3580 | consumeToken(); |
| 3581 | continue; |
| 3582 | } |
| 3583 | if (peekToken().getKind() == FormatToken::r_paren) { |
| 3584 | consumeToken(); |
| 3585 | break; |
| 3586 | } |
| 3587 | } while (true); |
| 3588 | |
| 3589 | return create<OIListElement>(args: std::move(literalElements), |
| 3590 | args: std::move(parsingElements)); |
| 3591 | } |
| 3592 | |
| 3593 | LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element, |
| 3594 | SMLoc loc) { |
| 3595 | SmallVector<VariableElement *> vars; |
| 3596 | collect(element, variables&: vars); |
| 3597 | for (VariableElement *elem : vars) { |
| 3598 | LogicalResult res = |
| 3599 | TypeSwitch<FormatElement *, LogicalResult>(elem) |
| 3600 | // Only optional attributes can be within an oilist parsing group. |
| 3601 | .Case(caseFn: [&](AttributeVariable *attrEle) { |
| 3602 | if (!attrEle->getVar()->attr.isOptional() && |
| 3603 | !attrEle->getVar()->attr.hasDefaultValue()) |
| 3604 | return emitError(loc, msg: "only optional attributes can be used in " |
| 3605 | "an oilist parsing group" ); |
| 3606 | return success(); |
| 3607 | }) |
| 3608 | // Only optional properties can be within an oilist parsing group. |
| 3609 | .Case(caseFn: [&](PropertyVariable *propEle) { |
| 3610 | if (!propEle->getVar()->prop.hasDefaultValue()) |
| 3611 | return emitError( |
| 3612 | loc, |
| 3613 | msg: "only default-valued or optional properties can be used in " |
| 3614 | "an olist parsing group" ); |
| 3615 | return success(); |
| 3616 | }) |
| 3617 | // Only optional-like(i.e. variadic) operands can be within an |
| 3618 | // oilist parsing group. |
| 3619 | .Case(caseFn: [&](OperandVariable *ele) { |
| 3620 | if (!ele->getVar()->isVariableLength()) |
| 3621 | return emitError(loc, msg: "only variable length operands can be " |
| 3622 | "used within an oilist parsing group" ); |
| 3623 | return success(); |
| 3624 | }) |
| 3625 | // Only optional-like(i.e. variadic) results can be within an oilist |
| 3626 | // parsing group. |
| 3627 | .Case(caseFn: [&](ResultVariable *ele) { |
| 3628 | if (!ele->getVar()->isVariableLength()) |
| 3629 | return emitError(loc, msg: "only variable length results can be " |
| 3630 | "used within an oilist parsing group" ); |
| 3631 | return success(); |
| 3632 | }) |
| 3633 | .Case(caseFn: [&](RegionVariable *) { return success(); }) |
| 3634 | .Default(defaultFn: [&](FormatElement *) { |
| 3635 | return emitError(loc, |
| 3636 | msg: "only literals, types, and variables can be " |
| 3637 | "used within an oilist group" ); |
| 3638 | }); |
| 3639 | if (failed(Result: res)) |
| 3640 | return failure(); |
| 3641 | } |
| 3642 | return success(); |
| 3643 | } |
| 3644 | |
| 3645 | FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc, |
| 3646 | Context context) { |
| 3647 | if (context == TypeDirectiveContext) |
| 3648 | return emitError(loc, msg: "'type' cannot be used as a child of another `type`" ); |
| 3649 | |
| 3650 | bool isRefChild = context == RefDirectiveContext; |
| 3651 | FailureOr<FormatElement *> operand; |
| 3652 | if (failed(Result: parseToken(kind: FormatToken::l_paren, |
| 3653 | msg: "expected '(' before argument list" )) || |
| 3654 | failed(Result: operand = parseTypeDirectiveOperand(loc, isRefChild)) || |
| 3655 | failed( |
| 3656 | Result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list" ))) |
| 3657 | return failure(); |
| 3658 | |
| 3659 | return create<TypeDirective>(args&: *operand); |
| 3660 | } |
| 3661 | |
| 3662 | LogicalResult OpFormatParser::markQualified(SMLoc loc, FormatElement *element) { |
| 3663 | return TypeSwitch<FormatElement *, LogicalResult>(element) |
| 3664 | .Case<AttributeVariable, TypeDirective>(caseFn: [](auto *element) { |
| 3665 | element->setShouldBeQualified(); |
| 3666 | return success(); |
| 3667 | }) |
| 3668 | .Default(defaultFn: [&](auto *element) { |
| 3669 | return this->emitError( |
| 3670 | loc, |
| 3671 | msg: "'qualified' directive expects an attribute or a `type` directive" ); |
| 3672 | }); |
| 3673 | } |
| 3674 | |
| 3675 | FailureOr<FormatElement *> |
| 3676 | OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) { |
| 3677 | FailureOr<FormatElement *> result = parseElement(ctx: TypeDirectiveContext); |
| 3678 | if (failed(Result: result)) |
| 3679 | return failure(); |
| 3680 | |
| 3681 | FormatElement *element = *result; |
| 3682 | if (isa<LiteralElement>(Val: element)) |
| 3683 | return emitError( |
| 3684 | loc, msg: "'type' directive operand expects variable or directive operand" ); |
| 3685 | |
| 3686 | if (auto *var = dyn_cast<OperandVariable>(Val: element)) { |
| 3687 | unsigned opIdx = var->getVar() - op.operand_begin(); |
| 3688 | if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(Idx: opIdx))) |
| 3689 | return emitError(loc, msg: "'type' of '" + var->getVar()->name + |
| 3690 | "' is already bound" ); |
| 3691 | if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(Idx: opIdx))) |
| 3692 | return emitError(loc, msg: "'ref' of 'type($" + var->getVar()->name + |
| 3693 | ")' is not bound by a prior 'type' directive" ); |
| 3694 | seenOperandTypes.set(opIdx); |
| 3695 | } else if (auto *var = dyn_cast<ResultVariable>(Val: element)) { |
| 3696 | unsigned resIdx = var->getVar() - op.result_begin(); |
| 3697 | if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(Idx: resIdx))) |
| 3698 | return emitError(loc, msg: "'type' of '" + var->getVar()->name + |
| 3699 | "' is already bound" ); |
| 3700 | if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(Idx: resIdx))) |
| 3701 | return emitError(loc, msg: "'ref' of 'type($" + var->getVar()->name + |
| 3702 | ")' is not bound by a prior 'type' directive" ); |
| 3703 | seenResultTypes.set(resIdx); |
| 3704 | } else if (isa<OperandsDirective>(Val: &*element)) { |
| 3705 | if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any())) |
| 3706 | return emitError(loc, msg: "'operands' 'type' is already bound" ); |
| 3707 | if (isRefChild && !fmt.allOperandTypes) |
| 3708 | return emitError(loc, msg: "'ref' of 'type(operands)' is not bound by a prior " |
| 3709 | "'type' directive" ); |
| 3710 | fmt.allOperandTypes = true; |
| 3711 | } else if (isa<ResultsDirective>(Val: &*element)) { |
| 3712 | if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any())) |
| 3713 | return emitError(loc, msg: "'results' 'type' is already bound" ); |
| 3714 | if (isRefChild && !fmt.allResultTypes) |
| 3715 | return emitError(loc, msg: "'ref' of 'type(results)' is not bound by a prior " |
| 3716 | "'type' directive" ); |
| 3717 | fmt.allResultTypes = true; |
| 3718 | } else { |
| 3719 | return emitError(loc, msg: "invalid argument to 'type' directive" ); |
| 3720 | } |
| 3721 | return element; |
| 3722 | } |
| 3723 | |
| 3724 | LogicalResult OpFormatParser::verifyOptionalGroupElements( |
| 3725 | SMLoc loc, ArrayRef<FormatElement *> elements, FormatElement *anchor) { |
| 3726 | for (FormatElement *element : elements) { |
| 3727 | if (failed(Result: verifyOptionalGroupElement(loc, element, isAnchor: element == anchor))) |
| 3728 | return failure(); |
| 3729 | } |
| 3730 | return success(); |
| 3731 | } |
| 3732 | |
| 3733 | LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc, |
| 3734 | FormatElement *element, |
| 3735 | bool isAnchor) { |
| 3736 | return TypeSwitch<FormatElement *, LogicalResult>(element) |
| 3737 | // All attributes can be within the optional group, but only optional |
| 3738 | // attributes can be the anchor. |
| 3739 | .Case(caseFn: [&](AttributeVariable *attrEle) { |
| 3740 | Attribute attr = attrEle->getVar()->attr; |
| 3741 | if (isAnchor && !(attr.isOptional() || attr.hasDefaultValue())) |
| 3742 | return emitError(loc, msg: "only optional or default-valued attributes " |
| 3743 | "can be used to anchor an optional group" ); |
| 3744 | return success(); |
| 3745 | }) |
| 3746 | // All properties can be within the optional group, but only optional |
| 3747 | // properties can be the anchor. |
| 3748 | .Case(caseFn: [&](PropertyVariable *propEle) { |
| 3749 | Property prop = propEle->getVar()->prop; |
| 3750 | if (isAnchor && !(prop.hasDefaultValue() && prop.hasOptionalParser())) |
| 3751 | return emitError(loc, msg: "only properties with default values " |
| 3752 | "that can be optionally parsed " |
| 3753 | "can be used to anchor an optional group" ); |
| 3754 | return success(); |
| 3755 | }) |
| 3756 | // Only optional-like(i.e. variadic) operands can be within an optional |
| 3757 | // group. |
| 3758 | .Case(caseFn: [&](OperandVariable *ele) { |
| 3759 | if (!ele->getVar()->isVariableLength()) |
| 3760 | return emitError(loc, msg: "only variable length operands can be used " |
| 3761 | "within an optional group" ); |
| 3762 | return success(); |
| 3763 | }) |
| 3764 | // Only optional-like(i.e. variadic) results can be within an optional |
| 3765 | // group. |
| 3766 | .Case(caseFn: [&](ResultVariable *ele) { |
| 3767 | if (!ele->getVar()->isVariableLength()) |
| 3768 | return emitError(loc, msg: "only variable length results can be used " |
| 3769 | "within an optional group" ); |
| 3770 | return success(); |
| 3771 | }) |
| 3772 | .Case(caseFn: [&](RegionVariable *) { |
| 3773 | // TODO: When ODS has proper support for marking "optional" regions, add |
| 3774 | // a check here. |
| 3775 | return success(); |
| 3776 | }) |
| 3777 | .Case(caseFn: [&](TypeDirective *ele) { |
| 3778 | return verifyOptionalGroupElement(loc, element: ele->getArg(), |
| 3779 | /*isAnchor=*/false); |
| 3780 | }) |
| 3781 | .Case(caseFn: [&](FunctionalTypeDirective *ele) { |
| 3782 | if (failed(Result: verifyOptionalGroupElement(loc, element: ele->getInputs(), |
| 3783 | /*isAnchor=*/false))) |
| 3784 | return failure(); |
| 3785 | return verifyOptionalGroupElement(loc, element: ele->getResults(), |
| 3786 | /*isAnchor=*/false); |
| 3787 | }) |
| 3788 | .Case(caseFn: [&](CustomDirective *ele) { |
| 3789 | if (!isAnchor) |
| 3790 | return success(); |
| 3791 | // Verify each child as being valid in an optional group. They are all |
| 3792 | // potential anchors if the custom directive was marked as one. |
| 3793 | for (FormatElement *child : ele->getElements()) { |
| 3794 | if (isa<RefDirective>(Val: child)) |
| 3795 | continue; |
| 3796 | if (failed(Result: verifyOptionalGroupElement(loc, element: child, /*isAnchor=*/true))) |
| 3797 | return failure(); |
| 3798 | } |
| 3799 | return success(); |
| 3800 | }) |
| 3801 | // Literals, whitespace, and custom directives may be used, but they can't |
| 3802 | // anchor the group. |
| 3803 | .Case<LiteralElement, WhitespaceElement, OptionalElement>( |
| 3804 | caseFn: [&](FormatElement *) { |
| 3805 | if (isAnchor) |
| 3806 | return emitError(loc, msg: "only variables and types can be used " |
| 3807 | "to anchor an optional group" ); |
| 3808 | return success(); |
| 3809 | }) |
| 3810 | .Default(defaultFn: [&](FormatElement *) { |
| 3811 | return emitError(loc, msg: "only literals, types, and variables can be " |
| 3812 | "used within an optional group" ); |
| 3813 | }); |
| 3814 | } |
| 3815 | |
| 3816 | //===----------------------------------------------------------------------===// |
| 3817 | // Interface |
| 3818 | //===----------------------------------------------------------------------===// |
| 3819 | |
| 3820 | void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass, |
| 3821 | bool hasProperties) { |
| 3822 | // TODO: Operator doesn't expose all necessary functionality via |
| 3823 | // the const interface. |
| 3824 | Operator &op = const_cast<Operator &>(constOp); |
| 3825 | if (!op.hasAssemblyFormat()) |
| 3826 | return; |
| 3827 | |
| 3828 | // Parse the format description. |
| 3829 | llvm::SourceMgr mgr; |
| 3830 | mgr.AddNewSourceBuffer( |
| 3831 | F: llvm::MemoryBuffer::getMemBuffer(InputData: op.getAssemblyFormat()), IncludeLoc: SMLoc()); |
| 3832 | OperationFormat format(op, hasProperties); |
| 3833 | OpFormatParser parser(mgr, format, op); |
| 3834 | FailureOr<std::vector<FormatElement *>> elements = parser.parse(); |
| 3835 | if (failed(Result: elements)) { |
| 3836 | // Exit the process if format errors are treated as fatal. |
| 3837 | if (formatErrorIsFatal) { |
| 3838 | // Invoke the interrupt handlers to run the file cleanup handlers. |
| 3839 | llvm::sys::RunInterruptHandlers(); |
| 3840 | std::exit(status: 1); |
| 3841 | } |
| 3842 | return; |
| 3843 | } |
| 3844 | format.elements = std::move(*elements); |
| 3845 | |
| 3846 | // Generate the printer and parser based on the parsed format. |
| 3847 | format.genParser(op, opClass); |
| 3848 | format.genPrinter(op, opClass); |
| 3849 | } |
| 3850 | |