1//===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "AttrOrTypeFormatGen.h"
10#include "FormatGen.h"
11#include "mlir/Support/LLVM.h"
12#include "mlir/TableGen/AttrOrTypeDef.h"
13#include "mlir/TableGen/Format.h"
14#include "mlir/TableGen/GenInfo.h"
15#include "llvm/ADT/BitVector.h"
16#include "llvm/ADT/SmallVectorExtras.h"
17#include "llvm/ADT/StringExtras.h"
18#include "llvm/ADT/StringSwitch.h"
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/MemoryBuffer.h"
21#include "llvm/Support/SaveAndRestore.h"
22#include "llvm/Support/SourceMgr.h"
23#include "llvm/TableGen/Error.h"
24#include "llvm/TableGen/TableGenBackend.h"
25
26using namespace mlir;
27using namespace mlir::tblgen;
28
29using llvm::formatv;
30
31//===----------------------------------------------------------------------===//
32// Element
33//===----------------------------------------------------------------------===//
34
35namespace {
36/// This class represents an instance of a variable element. A variable refers
37/// to an attribute or type parameter.
38class ParameterElement
39 : public VariableElementBase<VariableElement::Parameter> {
40public:
41 ParameterElement(AttrOrTypeParameter param) : param(param) {}
42
43 /// Get the parameter in the element.
44 const AttrOrTypeParameter &getParam() const { return param; }
45
46 /// Indicate if this variable is printed "qualified" (that is it is
47 /// prefixed with the `#dialect.mnemonic`).
48 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
49 void setShouldBeQualified(bool qualified = true) {
50 shouldBeQualifiedFlag = qualified;
51 }
52
53 /// Returns true if the element contains an optional parameter.
54 bool isOptional() const { return param.isOptional(); }
55
56 /// Returns the name of the parameter.
57 StringRef getName() const { return param.getName(); }
58
59 /// Return the code to check whether the parameter is present.
60 auto genIsPresent(FmtContext &ctx, const Twine &self) const {
61 assert(isOptional() && "cannot guard on a mandatory parameter");
62 std::string valueStr = tgfmt(fmt: *param.getDefaultValue(), ctx: &ctx).str();
63 ctx.addSubst(placeholder: "_lhs", subst: self).addSubst(placeholder: "_rhs", subst: valueStr);
64 return tgfmt(fmt: getParam().getComparator(), ctx: &ctx);
65 }
66
67 /// Generate the code to check whether the parameter should be printed.
68 MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const {
69 assert(isOptional() && "cannot guard on a mandatory parameter");
70 std::string self = param.getAccessorName() + "()";
71 return os << "!(" << genIsPresent(ctx, self) << ")";
72 }
73
74private:
75 bool shouldBeQualifiedFlag = false;
76 AttrOrTypeParameter param;
77};
78
79/// Utility to return the encapsulated parameter element for the provided format
80/// element. This parameter can originate from either a `ParameterElement`,
81/// `CustomDirective` with a single parameter argument or `RefDirective`.
82static ParameterElement *getEncapsulatedParameterElement(FormatElement *el) {
83 return TypeSwitch<FormatElement *, ParameterElement *>(el)
84 .Case<CustomDirective>(caseFn: [&](auto custom) {
85 FailureOr<ParameterElement *> maybeParam =
86 custom->template getFrontAs<ParameterElement>();
87 return *maybeParam;
88 })
89 .Case<ParameterElement>(caseFn: [&](auto param) { return param; })
90 .Case<RefDirective>(
91 caseFn: [&](auto ref) { return cast<ParameterElement>(ref->getArg()); })
92 .Default(defaultFn: [&](auto el) {
93 assert(false && "unexpected struct element type");
94 return nullptr;
95 });
96}
97
98/// Shorthand functions that can be used with ranged-based conditions.
99static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
100static bool formatIsOptional(FormatElement *el) {
101 ParameterElement *param = getEncapsulatedParameterElement(el);
102 return param != nullptr && param->isOptional();
103}
104static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
105static bool formatNotOptional(FormatElement *el) {
106 return !formatIsOptional(el);
107}
108
109/// This class represents a `params` directive that refers to all parameters
110/// of an attribute or type. When used as a top-level directive, it generates
111/// a format of the form:
112///
113/// (param-value (`,` param-value)*)?
114///
115/// When used as an argument to another directive that accepts variables,
116/// `params` can be used in place of manually listing all parameters of an
117/// attribute or type.
118class ParamsDirective
119 : public VectorDirectiveBase<DirectiveElement::Params, ParameterElement *> {
120public:
121 using Base::Base;
122
123 /// Returns true if there are optional parameters present.
124 bool hasOptionalElements() const {
125 return llvm::any_of(Range: getElements(), P: paramIsOptional);
126 }
127};
128
129/// This class represents a `struct` directive that generates a struct format
130/// of the form:
131///
132/// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
133///
134class StructDirective
135 : public VectorDirectiveBase<DirectiveElement::Struct, FormatElement *> {
136public:
137 using Base::Base;
138
139 /// Returns true if there are optional format elements present.
140 bool hasOptionalElements() const {
141 return llvm::any_of(Range: getElements(), P: formatIsOptional);
142 }
143};
144
145} // namespace
146
147//===----------------------------------------------------------------------===//
148// Format Strings
149//===----------------------------------------------------------------------===//
150
/// Default parser for attribute or type parameters.
///
/// $0: The C++ storage type of the parameter.
static const char *const defaultParameterParser =
    "::mlir::FieldParser<$0>::parse($_parser)";

/// Default printer for attribute or type parameters.
static const char *const defaultParameterPrinter =
    "$_printer.printStrippedAttrOrType($_self)";

/// Qualified printer for attribute or type parameters: it does not elide
/// dialect and mnemonic.
static const char *const qualifiedParameterPrinter = "$_printer << $_self";

/// Opening of an `emitError` call at the parser's current location, used when
/// failing to parse an element. The template takes no positional arguments;
/// users append the message string and the closing `)` themselves.
static const char *const parserErrorStr =
    "$_parser.emitError($_parser.getCurrentLocation(), ";

/// Code format to parse a variable. Separate by lines because variable parsers
/// may be generated inside other directives, which requires indentation.
///
/// {0}: The parameter name.
/// {1}: The parse code for the parameter.
/// {2}: Code template for printing an error.
/// {3}: Name of the attribute or type.
/// {4}: C++ class of the parameter.
/// {5}: Optional code to preload the dialect for this variable.
static const char *const variableParser = R"(
// Parse variable '{0}'{5}
_result_{0} = {1};
if (::mlir::failed(_result_{0})) {{
 {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
 return {{};
}
)";
186
187//===----------------------------------------------------------------------===//
188// DefFormat
189//===----------------------------------------------------------------------===//
190
191namespace {
192class DefFormat {
193public:
194 DefFormat(const AttrOrTypeDef &def, std::vector<FormatElement *> &&elements)
195 : def(def), elements(std::move(elements)) {}
196
197 /// Generate the attribute or type parser.
198 void genParser(MethodBody &os);
199 /// Generate the attribute or type printer.
200 void genPrinter(MethodBody &os);
201
202private:
203 /// Generate the parser code for a specific format element.
204 void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os);
205 /// Generate the parser code for a literal.
206 void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os,
207 bool isOptional = false);
208 /// Generate the parser code for a variable.
209 void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os);
210 /// Generate the parser code for a `params` directive.
211 void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
212 /// Generate the parser code for a `struct` directive.
213 void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
214 /// Generate the parser code for a `custom` directive.
215 void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os,
216 bool isOptional = false);
217 /// Generate the parser code for an optional group.
218 void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
219 MethodBody &os);
220
221 /// Generate the printer code for a specific format element.
222 void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os);
223 /// Generate the printer code for a literal.
224 void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os);
225 /// Generate the printer code for a variable.
226 void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
227 bool skipGuard = false);
228 /// Generate a printer for comma-separated format elements.
229 void genCommaSeparatedPrinter(ArrayRef<FormatElement *> params,
230 FmtContext &ctx, MethodBody &os,
231 function_ref<void(FormatElement *)> extra);
232 /// Generate the printer code for a `params` directive.
233 void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
234 /// Generate the printer code for a `struct` directive.
235 void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
236 /// Generate the printer code for a `custom` directive.
237 void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os);
238 /// Generate the printer code for an optional group.
239 void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
240 MethodBody &os);
241 /// Generate a printer (or space eraser) for a whitespace element.
242 void genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
243 MethodBody &os);
244
245 /// The ODS definition of the attribute or type whose format is being used to
246 /// generate a parser and printer.
247 const AttrOrTypeDef &def;
248 /// The list of top-level format elements returned by the assembly format
249 /// parser.
250 std::vector<FormatElement *> elements;
251
252 /// Flags for printing spaces.
253 bool shouldEmitSpace = false;
254 bool lastWasPunctuation = false;
255};
256} // namespace
257
258//===----------------------------------------------------------------------===//
259// ParserGen
260//===----------------------------------------------------------------------===//
261
262/// Generate a special-case "parser" for an attribute's self type parameter. The
263/// self type parameter has special handling in the assembly format in that it
264/// is derived from the optional trailing colon type after the attribute.
265static void genAttrSelfTypeParser(MethodBody &os, const FmtContext &ctx,
266 const AttributeSelfTypeParameter &param) {
267 // "Parser" for an attribute self type parameter that checks the
268 // optionally-parsed trailing colon type.
269 //
270 // $0: The C++ storage class of the type parameter.
271 // $1: The self type parameter name.
272 const char *const selfTypeParser = R"(
273if ($_type) {
274 if (auto reqType = ::llvm::dyn_cast<$0>($_type)) {
275 _result_$1 = reqType;
276 } else {
277 $_parser.emitError($_loc, "invalid kind of type specified");
278 return {};
279 }
280})";
281
282 // If the attribute self type parameter is required, emit code that emits an
283 // error if the trailing type was not parsed.
284 const char *const selfTypeRequired = R"( else {
285 $_parser.emitError($_loc, "expected a trailing type");
286 return {};
287})";
288
289 os << tgfmt(fmt: selfTypeParser, ctx: &ctx, vals: param.getCppStorageType(), vals: param.getName());
290 if (!param.isOptional())
291 os << tgfmt(fmt: selfTypeRequired, ctx: &ctx);
292 os << "\n";
293}
294
295void DefFormat::genParser(MethodBody &os) {
296 FmtContext ctx;
297 ctx.addSubst(placeholder: "_parser", subst: "odsParser");
298 ctx.addSubst(placeholder: "_ctxt", subst: "odsParser.getContext()");
299 ctx.withBuilder(subst: "odsBuilder");
300 if (isa<AttrDef>(Val: def))
301 ctx.addSubst(placeholder: "_type", subst: "odsType");
302 os.indent();
303 os << "::mlir::Builder odsBuilder(odsParser.getContext());\n";
304
305 // Store the initial location of the parser.
306 ctx.addSubst(placeholder: "_loc", subst: "odsLoc");
307 os << tgfmt(fmt: "::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
308 "(void) $_loc;\n",
309 ctx: &ctx);
310
311 // Declare variables to store all of the parameters. Allocated parameters
312 // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
313 // FailureOr<T> to defer type construction for parameters that are parsed in
314 // a loop (parsers return FailureOr anyways).
315 ArrayRef<AttrOrTypeParameter> params = def.getParameters();
316 for (const AttrOrTypeParameter &param : params) {
317 os << formatv(Fmt: "::mlir::FailureOr<{0}> _result_{1};\n",
318 Vals: param.getCppStorageType(), Vals: param.getName());
319 if (auto *selfTypeParam = dyn_cast<AttributeSelfTypeParameter>(Val: &param))
320 genAttrSelfTypeParser(os, ctx, param: *selfTypeParam);
321 }
322
323 // Generate call to each parameter parser.
324 for (FormatElement *el : elements)
325 genElementParser(el, ctx, os);
326
327 // Emit an assert for each mandatory parameter. Triggering an assert means
328 // the generated parser is incorrect (i.e. there is a bug in this code).
329 for (const AttrOrTypeParameter &param : params) {
330 if (param.isOptional())
331 continue;
332 os << formatv(Fmt: "assert(::mlir::succeeded(_result_{0}));\n", Vals: param.getName());
333 }
334
335 // Generate call to the attribute or type builder. Use the checked getter
336 // if one was generated.
337 if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) {
338 os << tgfmt(fmt: "return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
339 ctx: &ctx, vals: def.getCppClassName());
340 } else {
341 os << tgfmt(fmt: "return $0::get($_parser.getContext()", ctx: &ctx,
342 vals: def.getCppClassName());
343 }
344 for (const AttrOrTypeParameter &param : params) {
345 os << ",\n ";
346 std::string paramSelfStr;
347 llvm::raw_string_ostream selfOs(paramSelfStr);
348 if (std::optional<StringRef> defaultValue = param.getDefaultValue()) {
349 selfOs << formatv(Fmt: "(_result_{0}.value_or(", Vals: param.getName())
350 << tgfmt(fmt: *defaultValue, ctx: &ctx) << "))";
351 } else {
352 selfOs << formatv(Fmt: "(*_result_{0})", Vals: param.getName());
353 }
354 ctx.addSubst(placeholder: param.getName(), subst: selfOs.str());
355 os << param.getCppType() << "("
356 << tgfmt(fmt: param.getConvertFromStorage(), ctx: &ctx.withSelf(subst: selfOs.str()))
357 << ")";
358 }
359 os << ");";
360}
361
362void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
363 MethodBody &os) {
364 if (auto *literal = dyn_cast<LiteralElement>(Val: el))
365 return genLiteralParser(value: literal->getSpelling(), ctx, os);
366 if (auto *var = dyn_cast<ParameterElement>(Val: el))
367 return genVariableParser(el: var, ctx, os);
368 if (auto *params = dyn_cast<ParamsDirective>(Val: el))
369 return genParamsParser(el: params, ctx, os);
370 if (auto *strct = dyn_cast<StructDirective>(Val: el))
371 return genStructParser(el: strct, ctx, os);
372 if (auto *custom = dyn_cast<CustomDirective>(Val: el))
373 return genCustomParser(el: custom, ctx, os);
374 if (auto *optional = dyn_cast<OptionalElement>(Val: el))
375 return genOptionalGroupParser(el: optional, ctx, os);
376 if (isa<WhitespaceElement>(Val: el))
377 return;
378
379 llvm_unreachable("unknown format element");
380}
381
382void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx,
383 MethodBody &os, bool isOptional) {
384 os << "// Parse literal '" << value << "'\n";
385 os << tgfmt(fmt: "if ($_parser.parse", ctx: &ctx);
386 if (isOptional)
387 os << "Optional";
388 if (value.front() == '_' || isalpha(value.front())) {
389 os << "Keyword(\"" << value << "\")";
390 } else {
391 os << StringSwitch<StringRef>(value)
392 .Case(S: "->", Value: "Arrow")
393 .Case(S: ":", Value: "Colon")
394 .Case(S: ",", Value: "Comma")
395 .Case(S: "=", Value: "Equal")
396 .Case(S: "<", Value: "Less")
397 .Case(S: ">", Value: "Greater")
398 .Case(S: "{", Value: "LBrace")
399 .Case(S: "}", Value: "RBrace")
400 .Case(S: "(", Value: "LParen")
401 .Case(S: ")", Value: "RParen")
402 .Case(S: "[", Value: "LSquare")
403 .Case(S: "]", Value: "RSquare")
404 .Case(S: "?", Value: "Question")
405 .Case(S: "+", Value: "Plus")
406 .Case(S: "*", Value: "Star")
407 .Case(S: "...", Value: "Ellipsis")
408 << "()";
409 }
410 if (isOptional) {
411 // Leave the `if` unclosed to guard optional groups.
412 return;
413 }
414 // Parser will emit an error
415 os << ") return {};\n";
416}
417
418void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
419 MethodBody &os) {
420 // Check for a custom parser. Use the default attribute parser otherwise.
421 const AttrOrTypeParameter &param = el->getParam();
422 auto customParser = param.getParser();
423 auto parser =
424 customParser ? *customParser : StringRef(defaultParameterParser);
425
426 // If the variable points to a dialect specific entity (type of attribute),
427 // we force load the dialect now before trying to parse it.
428 std::string dialectLoading;
429 if (auto *defInit = dyn_cast<llvm::DefInit>(Val: param.getDef())) {
430 auto *dialectValue = defInit->getDef()->getValue(Name: "dialect");
431 if (dialectValue) {
432 if (auto *dialectInit =
433 dyn_cast<llvm::DefInit>(Val: dialectValue->getValue())) {
434 Dialect dialect(dialectInit->getDef());
435 auto cppNamespace = dialect.getCppNamespace();
436 std::string name = dialect.getCppClassName();
437 if (name != "BuiltinDialect" || cppNamespace != "::mlir") {
438 dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
439 cppNamespace + "::" + name + ">();")
440 .str();
441 }
442 }
443 }
444 }
445 os << formatv(Fmt: variableParser, Vals: param.getName(),
446 Vals: tgfmt(fmt: parser, ctx: &ctx, vals: param.getCppStorageType()),
447 Vals: tgfmt(fmt: parserErrorStr, ctx: &ctx), Vals: def.getName(), Vals: param.getCppType(),
448 Vals&: dialectLoading);
449}
450
451void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
452 MethodBody &os) {
453 os << "// Parse parameter list\n";
454
455 // If there are optional parameters, we need to switch to `parseOptionalComma`
456 // if there are no more required parameters after a certain point.
457 bool hasOptional = el->hasOptionalElements();
458 if (hasOptional) {
459 // Wrap everything in a do-while so that we can `break`.
460 os << "do {\n";
461 os.indent();
462 }
463
464 ArrayRef<ParameterElement *> params = el->getElements();
465 using IteratorT = ParameterElement *const *;
466 IteratorT it = params.begin();
467
468 // Find the last required parameter. Commas become optional aftewards.
469 // Note: IteratorT's copy assignment is deleted.
470 ParameterElement *lastReq = nullptr;
471 for (ParameterElement *param : params)
472 if (!param->isOptional())
473 lastReq = param;
474 IteratorT lastReqIt = lastReq ? llvm::find(Range&: params, Val: lastReq) : params.begin();
475
476 auto eachFn = [&](ParameterElement *el) { genVariableParser(el, ctx, os); };
477 auto betweenFn = [&](IteratorT it) {
478 ParameterElement *el = *std::prev(x: it);
479 // Parse a comma if the last optional parameter had a value.
480 if (el->isOptional()) {
481 os << formatv(Fmt: "if (::mlir::succeeded(_result_{0}) && !({1})) {{\n",
482 Vals: el->getName(),
483 Vals: el->genIsPresent(ctx, self: "(*_result_" + el->getName() + ")"));
484 os.indent();
485 }
486 if (it <= lastReqIt) {
487 genLiteralParser(value: ",", ctx, os);
488 } else {
489 genLiteralParser(value: ",", ctx, os, /*isOptional=*/true);
490 os << ") break;\n";
491 }
492 if (el->isOptional())
493 os.unindent() << "}\n";
494 };
495
496 // llvm::interleave
497 if (it != params.end()) {
498 eachFn(*it++);
499 for (IteratorT e = params.end(); it != e; ++it) {
500 betweenFn(it);
501 eachFn(*it);
502 }
503 }
504
505 if (hasOptional)
506 os.unindent() << "} while(false);\n";
507}
508
509void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
510 MethodBody &os) {
511 // Loop declaration for struct parser with only required parameters.
512 //
513 // $0: Number of expected parameters.
514 const char *const loopHeader = R"(
515 for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) {
516)";
517
518 // Loop body start for struct parser.
519 const char *const loopStart = R"(
520 ::llvm::StringRef _paramKey;
521 if ($_parser.parseKeyword(&_paramKey)) {
522 $_parser.emitError($_parser.getCurrentLocation(),
523 "expected a parameter name in struct");
524 return {};
525 }
526 if (!_loop_body(_paramKey)) return {};
527)";
528
529 // Struct parser loop end. Check for duplicate or unknown struct parameters.
530 //
531 // {0}: Code template for printing an error.
532 const char *const loopEnd = R"({{
533 {0}"duplicate or unknown struct parameter name: ") << _paramKey;
534 return {{};
535}
536)";
537
538 // Struct parser loop terminator. Parse a comma except on the last element.
539 //
540 // {0}: Number of elements in the struct.
541 const char *const loopTerminator = R"(
542 if ((odsStructIndex != {0} - 1) && odsParser.parseComma())
543 return {{};
544}
545)";
546
547 // Check that a mandatory parameter was parse.
548 //
549 // {0}: Name of the parameter.
550 const char *const checkParam = R"(
551 if (!_seen_{0}) {
552 {1}"struct is missing required parameter: ") << "{0}";
553 return {{};
554 }
555)";
556
557 // First iteration of the loop parsing an optional struct.
558 const char *const optionalStructFirst = R"(
559 ::llvm::StringRef _paramKey;
560 if (!$_parser.parseOptionalKeyword(&_paramKey)) {
561 if (!_loop_body(_paramKey)) return {};
562 while (!$_parser.parseOptionalComma()) {
563)";
564
565 const char *const checkParamKey = R"(
566 if (!_seen_{0} && _paramKey == "{0}") {
567 _seen_{0} = true;
568)";
569
570 os << "// Parse parameter struct\n";
571
572 // Declare a "seen" variable for each key.
573 for (FormatElement *arg : el->getElements()) {
574 ParameterElement *param = getEncapsulatedParameterElement(el: arg);
575 os << formatv(Fmt: "bool _seen_{0} = false;\n", Vals: param->getName());
576 }
577
578 // Generate the body of the parsing loop inside a lambda.
579 os << "{\n";
580 os.indent()
581 << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
582 genLiteralParser(value: "=", ctx, os&: os.indent());
583 for (FormatElement *arg : el->getElements()) {
584 ParameterElement *param = getEncapsulatedParameterElement(el: arg);
585 os.getStream().printReindented(str: strfmt(fmt: checkParamKey, parameters: param->getName()));
586 if (isa<ParameterElement>(Val: arg))
587 genVariableParser(el: param, ctx, os&: os.indent());
588 else if (auto custom = dyn_cast<CustomDirective>(Val: arg))
589 genCustomParser(el: custom, ctx, os&: os.indent());
590 os.unindent() << "} else ";
591 // Print the check for duplicate or unknown parameter.
592 }
593 os.getStream().printReindented(str: strfmt(fmt: loopEnd, parameters: tgfmt(fmt: parserErrorStr, ctx: &ctx)));
594 os << "return true;\n";
595 os.unindent() << "};\n";
596
597 // Generate the parsing loop. If optional parameters are present, then the
598 // parse loop is guarded by commas.
599 unsigned numOptional = llvm::count_if(Range: el->getElements(), P: formatIsOptional);
600 if (numOptional) {
601 // If the struct itself is optional, pull out the first iteration.
602 if (numOptional == el->getNumElements()) {
603 os.getStream().printReindented(str: tgfmt(fmt: optionalStructFirst, ctx: &ctx).str());
604 os.indent();
605 } else {
606 os << "do {\n";
607 }
608 } else {
609 os.getStream().printReindented(
610 str: tgfmt(fmt: loopHeader, ctx: &ctx, vals: el->getNumElements()).str());
611 }
612 os.indent();
613 os.getStream().printReindented(str: tgfmt(fmt: loopStart, ctx: &ctx).str());
614 os.unindent();
615
616 // Print the loop terminator. For optional parameters, we have to check that
617 // all mandatory parameters have been parsed.
618 // The whole struct is optional if all its parameters are optional.
619 if (numOptional) {
620 if (numOptional == el->getNumElements()) {
621 os << "}\n";
622 os.unindent() << "}\n";
623 } else {
624 os << tgfmt(fmt: "} while(!$_parser.parseOptionalComma());\n", ctx: &ctx);
625 for (FormatElement *arg : el->getElements()) {
626 ParameterElement *param = getEncapsulatedParameterElement(el: arg);
627 if (param->isOptional())
628 continue;
629 os.getStream().printReindented(
630 str: strfmt(fmt: checkParam, parameters: param->getName(), parameters: tgfmt(fmt: parserErrorStr, ctx: &ctx)));
631 }
632 }
633 } else {
634 // Because the loop loops N times and each non-failing iteration sets 1 of
635 // N flags, successfully exiting the loop means that all parameters have
636 // been seen. `parseOptionalComma` would cause issues with any formats that
637 // use "struct(...) `,`" beacuse structs aren't sounded by braces.
638 os.getStream().printReindented(
639 str: strfmt(fmt: loopTerminator, parameters: el->getNumElements()));
640 }
641 os.unindent() << "}\n";
642}
643
644void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
645 MethodBody &os, bool isOptional) {
646 os << "{\n";
647 os.indent();
648
649 // Bound variables are passed directly to the parser as `FailureOr<T> &`.
650 // Referenced variables are passed as `T`. The custom parser fails if it
651 // returns failure or if any of the required parameters failed.
652 os << tgfmt(fmt: "auto odsCustomLoc = $_parser.getCurrentLocation();\n", ctx: &ctx);
653 os << "(void)odsCustomLoc;\n";
654 os << tgfmt(fmt: "auto odsCustomResult = parse$0($_parser", ctx: &ctx, vals: el->getName());
655 os.indent();
656 for (FormatElement *arg : el->getElements()) {
657 os << ",\n";
658 if (auto *param = dyn_cast<ParameterElement>(Val: arg))
659 os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName()
660 << ")";
661 else if (auto *ref = dyn_cast<RefDirective>(Val: arg))
662 os << "*_result_" << cast<ParameterElement>(Val: ref->getArg())->getName();
663 else
664 os << tgfmt(fmt: cast<StringElement>(Val: arg)->getValue(), ctx: &ctx);
665 }
666 os.unindent() << ");\n";
667 if (isOptional) {
668 os << "if (!odsCustomResult.has_value()) return {};\n";
669 os << "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n";
670 } else {
671 os << "if (::mlir::failed(odsCustomResult)) return {};\n";
672 }
673 for (FormatElement *arg : el->getElements()) {
674 if (auto *param = dyn_cast<ParameterElement>(Val: arg)) {
675 if (param->isOptional())
676 continue;
677 os << formatv(Fmt: "if (::mlir::failed(_result_{0})) {{\n", Vals: param->getName());
678 os.indent() << tgfmt(fmt: "$_parser.emitError(odsCustomLoc, ", ctx: &ctx)
679 << "\"custom parser failed to parse parameter '"
680 << param->getName() << "'\");\n";
681 os << "return " << (isOptional ? "::mlir::failure()" : "{}") << ";\n";
682 os.unindent() << "}\n";
683 }
684 }
685
686 os.unindent() << "}\n";
687}
688
689void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
690 MethodBody &os) {
691 ArrayRef<FormatElement *> thenElements =
692 el->getThenElements(/*parseable=*/true);
693
694 FormatElement *first = thenElements.front();
695 const auto guardOn = [&](auto params) {
696 os << "if (!(";
697 llvm::interleave(
698 params, os,
699 [&](ParameterElement *el) {
700 os << formatv(Fmt: "(::mlir::succeeded(_result_{0}) && *_result_{0})",
701 Vals: el->getName());
702 },
703 " || ");
704 os << ")) {\n";
705 };
706 if (auto *literal = dyn_cast<LiteralElement>(Val: first)) {
707 genLiteralParser(value: literal->getSpelling(), ctx, os, /*isOptional=*/true);
708 os << ") {\n";
709 } else if (auto *param = dyn_cast<ParameterElement>(Val: first)) {
710 genVariableParser(el: param, ctx, os);
711 guardOn(llvm::ArrayRef(param));
712 } else if (auto *params = dyn_cast<ParamsDirective>(Val: first)) {
713 genParamsParser(el: params, ctx, os);
714 guardOn(params->getElements());
715 } else if (auto *custom = dyn_cast<CustomDirective>(Val: first)) {
716 os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
717 os.indent();
718 genCustomParser(el: custom, ctx, os, /*isOptional=*/true);
719 os << "return ::mlir::success();\n";
720 os.unindent();
721 os << "}(); result.has_value() && ::mlir::failed(*result)) {\n";
722 os.indent();
723 os << "return {};\n";
724 os.unindent();
725 os << "} else if (result.has_value()) {\n";
726 } else {
727 auto *strct = cast<StructDirective>(Val: first);
728 genStructParser(el: strct, ctx, os);
729 guardOn(params->getElements());
730 }
731 os.indent();
732
733 // Generate the parsers for the rest of the thenElements.
734 for (FormatElement *element : el->getElseElements(/*parseable=*/true))
735 genElementParser(el: element, ctx, os);
736 os.unindent() << "} else {\n";
737 os.indent();
738 for (FormatElement *element : thenElements.drop_front())
739 genElementParser(el: element, ctx, os);
740 os.unindent() << "}\n";
741}
742
743//===----------------------------------------------------------------------===//
744// PrinterGen
745//===----------------------------------------------------------------------===//
746
747void DefFormat::genPrinter(MethodBody &os) {
748 FmtContext ctx;
749 ctx.addSubst(placeholder: "_printer", subst: "odsPrinter");
750 ctx.addSubst(placeholder: "_ctxt", subst: "getContext()");
751 ctx.withBuilder(subst: "odsBuilder");
752 os.indent();
753 os << "::mlir::Builder odsBuilder(getContext());\n";
754
755 // Generate printers.
756 shouldEmitSpace = true;
757 lastWasPunctuation = false;
758 for (FormatElement *el : elements)
759 genElementPrinter(el, ctx, os);
760}
761
762void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
763 MethodBody &os) {
764 if (auto *literal = dyn_cast<LiteralElement>(Val: el))
765 return genLiteralPrinter(value: literal->getSpelling(), ctx, os);
766 if (auto *params = dyn_cast<ParamsDirective>(Val: el))
767 return genParamsPrinter(el: params, ctx, os);
768 if (auto *strct = dyn_cast<StructDirective>(Val: el))
769 return genStructPrinter(el: strct, ctx, os);
770 if (auto *custom = dyn_cast<CustomDirective>(Val: el))
771 return genCustomPrinter(el: custom, ctx, os);
772 if (auto *var = dyn_cast<ParameterElement>(Val: el))
773 return genVariablePrinter(el: var, ctx, os);
774 if (auto *optional = dyn_cast<OptionalElement>(Val: el))
775 return genOptionalGroupPrinter(el: optional, ctx, os);
776 if (auto *whitespace = dyn_cast<WhitespaceElement>(Val: el))
777 return genWhitespacePrinter(el: whitespace, ctx, os);
778
779 llvm::PrintFatalError(Msg: "unsupported format element");
780}
781
782void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
783 MethodBody &os) {
784 // Don't insert a space before certain punctuation.
785 bool needSpace =
786 shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
787 os << tgfmt(fmt: "$_printer$0 << \"$1\";\n", ctx: &ctx, vals: needSpace ? " << ' '" : "",
788 vals&: value);
789
790 // Update the flags.
791 shouldEmitSpace =
792 value.size() != 1 || !StringRef("<({[").contains(C: value.front());
793 lastWasPunctuation = value.front() != '_' && !isalpha(value.front());
794}
795
796void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
797 MethodBody &os, bool skipGuard) {
798 const AttrOrTypeParameter &param = el->getParam();
799 ctx.withSelf(subst: param.getAccessorName() + "()");
800
801 // Guard the printer on the presence of optional parameters and that they
802 // aren't equal to their default values (if they have one).
803 if (el->isOptional() && !skipGuard) {
804 el->genPrintGuard(ctx, os&: os << "if (") << ") {\n";
805 os.indent();
806 }
807
808 // Insert a space before the next parameter, if necessary.
809 if (shouldEmitSpace || !lastWasPunctuation)
810 os << tgfmt(fmt: "$_printer << ' ';\n", ctx: &ctx);
811 shouldEmitSpace = true;
812 lastWasPunctuation = false;
813
814 if (el->shouldBeQualified())
815 os << tgfmt(fmt: qualifiedParameterPrinter, ctx: &ctx) << ";\n";
816 else if (auto printer = param.getPrinter())
817 os << tgfmt(fmt: *printer, ctx: &ctx) << ";\n";
818 else
819 os << tgfmt(fmt: defaultParameterPrinter, ctx: &ctx) << ";\n";
820
821 if (el->isOptional() && !skipGuard)
822 os.unindent() << "}\n";
823}
824
825/// Generate code to guard printing on the presence of any optional parameters.
826template <typename ParameterRange>
827static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &&params,
828 bool inverted = false) {
829 os << "if (";
830 if (inverted)
831 os << "!(";
832 llvm::interleave(
833 params, os,
834 [&](ParameterElement *param) { param->genPrintGuard(ctx, os); }, " || ");
835 if (inverted)
836 os << ")";
837 os << ") {\n";
838 os.indent();
839}
840
841/// Generate code to guard printing on the presence of any optional format
842/// elements.
843template <typename FormatElemRange>
844static void guardOnAnyOptional(FmtContext &ctx, MethodBody &os,
845 FormatElemRange &&args, bool inverted = false) {
846 guardOnAny(ctx, os,
847 llvm::make_filter_range(
848 llvm::map_range(args, getEncapsulatedParameterElement),
849 [](ParameterElement *param) { return param->isOptional(); }),
850 inverted);
851}
852
853void DefFormat::genCommaSeparatedPrinter(
854 ArrayRef<FormatElement *> args, FmtContext &ctx, MethodBody &os,
855 function_ref<void(FormatElement *)> extra) {
856 // Emit a space if necessary, but only if the struct is present.
857 if (shouldEmitSpace || !lastWasPunctuation) {
858 bool allOptional = llvm::all_of(Range&: args, P: formatIsOptional);
859 if (allOptional)
860 guardOnAnyOptional(ctx, os, args);
861 os << tgfmt(fmt: "$_printer << ' ';\n", ctx: &ctx);
862 if (allOptional)
863 os.unindent() << "}\n";
864 }
865
866 // The first printed element does not need to emit a comma.
867 os << "{\n";
868 os.indent() << "bool _firstPrinted = true;\n";
869 for (FormatElement *arg : args) {
870 ParameterElement *param = getEncapsulatedParameterElement(el: arg);
871 if (param->isOptional()) {
872 param->genPrintGuard(ctx, os&: os << "if (") << ") {\n";
873 os.indent();
874 }
875 os << tgfmt(fmt: "if (!_firstPrinted) $_printer << \", \";\n", ctx: &ctx);
876 os << "_firstPrinted = false;\n";
877 extra(arg);
878 shouldEmitSpace = false;
879 lastWasPunctuation = true;
880 if (auto realParam = dyn_cast<ParameterElement>(Val: arg))
881 genVariablePrinter(el: realParam, ctx, os);
882 else if (auto custom = dyn_cast<CustomDirective>(Val: arg))
883 genCustomPrinter(el: custom, ctx, os);
884 if (param->isOptional())
885 os.unindent() << "}\n";
886 }
887 os.unindent() << "}\n";
888}
889
890void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
891 MethodBody &os) {
892 SmallVector<FormatElement *> args = llvm::map_to_vector(
893 C: el->getElements(), F: [](ParameterElement *param) -> FormatElement * {
894 return static_cast<FormatElement *>(param);
895 });
896 genCommaSeparatedPrinter(args, ctx, os, extra: [&](FormatElement *param) {});
897}
898
899void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
900 MethodBody &os) {
901 genCommaSeparatedPrinter(args: el->getElements(), ctx, os, extra: [&](FormatElement *arg) {
902 ParameterElement *param = getEncapsulatedParameterElement(el: arg);
903 os << tgfmt(fmt: "$_printer << \"$0 = \";\n", ctx: &ctx, vals: param->getName());
904 });
905}
906
907void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
908 MethodBody &os) {
909 // Insert a space before the custom directive, if necessary.
910 if (shouldEmitSpace || !lastWasPunctuation)
911 os << tgfmt(fmt: "$_printer << ' ';\n", ctx: &ctx);
912 shouldEmitSpace = true;
913 lastWasPunctuation = false;
914
915 os << tgfmt(fmt: "print$0($_printer", ctx: &ctx, vals: el->getName());
916 os.indent();
917 for (FormatElement *arg : el->getElements()) {
918 os << ",\n";
919 if (auto *param = dyn_cast<ParameterElement>(Val: arg)) {
920 os << param->getParam().getAccessorName() << "()";
921 } else if (auto *ref = dyn_cast<RefDirective>(Val: arg)) {
922 os << cast<ParameterElement>(Val: ref->getArg())->getParam().getAccessorName()
923 << "()";
924 } else {
925 os << tgfmt(fmt: cast<StringElement>(Val: arg)->getValue(), ctx: &ctx);
926 }
927 }
928 os.unindent() << ");\n";
929}
930
931void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
932 MethodBody &os) {
933 FormatElement *anchor = el->getAnchor();
934 if (auto *param = dyn_cast<ParameterElement>(Val: anchor)) {
935 guardOnAny(ctx, os, params: llvm::ArrayRef(param), inverted: el->isInverted());
936 } else if (auto *params = dyn_cast<ParamsDirective>(Val: anchor)) {
937 guardOnAny(ctx, os, params: params->getElements(), inverted: el->isInverted());
938 } else if (auto *strct = dyn_cast<StructDirective>(Val: anchor)) {
939 guardOnAnyOptional(ctx, os, args: strct->getElements(), inverted: el->isInverted());
940 } else {
941 auto *custom = cast<CustomDirective>(Val: anchor);
942 guardOnAnyOptional(ctx, os, args: custom->getElements(), inverted: el->isInverted());
943 }
944 // Generate the printer for the contained elements.
945 {
946 llvm::SaveAndRestore shouldEmitSpaceFlag(shouldEmitSpace);
947 llvm::SaveAndRestore lastWasPunctuationFlag(lastWasPunctuation);
948 for (FormatElement *element : el->getThenElements())
949 genElementPrinter(el: element, ctx, os);
950 }
951 os.unindent() << "} else {\n";
952 os.indent();
953 for (FormatElement *element : el->getElseElements())
954 genElementPrinter(el: element, ctx, os);
955 os.unindent() << "}\n";
956}
957
958void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
959 MethodBody &os) {
960 if (el->getValue() == "\\n") {
961 // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by
962 // the printer.
963 os << tgfmt(fmt: "$_printer << '\\n';\n", ctx: &ctx);
964 } else if (!el->getValue().empty()) {
965 os << tgfmt(fmt: "$_printer << \"$0\";\n", ctx: &ctx, vals: el->getValue());
966 } else {
967 lastWasPunctuation = true;
968 }
969 shouldEmitSpace = false;
970}
971
972//===----------------------------------------------------------------------===//
973// DefFormatParser
974//===----------------------------------------------------------------------===//
975
976namespace {
977class DefFormatParser : public FormatParser {
978public:
979 DefFormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
980 : FormatParser(mgr, def.getLoc()[0]), def(def),
981 seenParams(def.getNumParameters()) {}
982
983 /// Parse the attribute or type format and create the format elements.
984 FailureOr<DefFormat> parse();
985
986protected:
987 /// Verify the parsed elements.
988 LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
989 /// Verify the elements of a custom directive.
990 LogicalResult
991 verifyCustomDirectiveArguments(SMLoc loc,
992 ArrayRef<FormatElement *> arguments) override;
993 /// Verify the elements of an optional group.
994 LogicalResult verifyOptionalGroupElements(SMLoc loc,
995 ArrayRef<FormatElement *> elements,
996 FormatElement *anchor) override;
997 /// Verify the arguments to a struct directive.
998 LogicalResult verifyStructArguments(SMLoc loc,
999 ArrayRef<FormatElement *> arguments);
1000
1001 LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
1002
1003 /// Parse an attribute or type variable.
1004 FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
1005 Context ctx) override;
1006 /// Parse an attribute or type format directive.
1007 FailureOr<FormatElement *>
1008 parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
1009
1010private:
1011 /// Parse a `params` directive.
1012 FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
1013 /// Parse a `struct` directive.
1014 FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
1015
1016 /// Attribute or type tablegen def.
1017 const AttrOrTypeDef &def;
1018
1019 /// Seen attribute or type parameters.
1020 BitVector seenParams;
1021};
1022} // namespace
1023
1024LogicalResult DefFormatParser::verify(SMLoc loc,
1025 ArrayRef<FormatElement *> elements) {
1026 // Check that all parameters are referenced in the format.
1027 for (auto [index, param] : llvm::enumerate(First: def.getParameters())) {
1028 if (param.isOptional())
1029 continue;
1030 if (!seenParams.test(Idx: index)) {
1031 if (isa<AttributeSelfTypeParameter>(Val: param))
1032 continue;
1033 return emitError(loc, msg: "format is missing reference to parameter: " +
1034 param.getName());
1035 }
1036 if (isa<AttributeSelfTypeParameter>(Val: param)) {
1037 return emitError(loc,
1038 msg: "unexpected self type parameter in assembly format");
1039 }
1040 }
1041 if (elements.empty())
1042 return success();
1043 // A `struct` directive that contains optional parameters cannot be followed
1044 // by a comma literal, which is ambiguous.
1045 for (auto it : llvm::zip(t: elements.drop_back(), u: elements.drop_front())) {
1046 auto *structEl = dyn_cast<StructDirective>(Val: std::get<0>(t&: it));
1047 auto *literalEl = dyn_cast<LiteralElement>(Val: std::get<1>(t&: it));
1048 if (!structEl || !literalEl)
1049 continue;
1050 if (literalEl->getSpelling() == "," && structEl->hasOptionalElements()) {
1051 return emitError(loc, msg: "`struct` directive with optional parameters "
1052 "cannot be followed by a comma literal");
1053 }
1054 }
1055 return success();
1056}
1057
1058LogicalResult DefFormatParser::verifyCustomDirectiveArguments(
1059 SMLoc loc, ArrayRef<FormatElement *> arguments) {
1060 // Arguments are fully verified by the parser context.
1061 return success();
1062}
1063
1064LogicalResult
1065DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
1066 ArrayRef<FormatElement *> elements,
1067 FormatElement *anchor) {
1068 // `params` and `struct` directives are allowed only if all the contained
1069 // parameters are optional.
1070 for (FormatElement *el : elements) {
1071 if (auto *param = dyn_cast<ParameterElement>(Val: el)) {
1072 if (!param->isOptional()) {
1073 return emitError(loc,
1074 msg: "parameters in an optional group must be optional");
1075 }
1076 } else if (auto *params = dyn_cast<ParamsDirective>(Val: el)) {
1077 if (llvm::any_of(Range: params->getElements(), P: paramNotOptional)) {
1078 return emitError(loc, msg: "`params` directive allowed in optional group "
1079 "only if all parameters are optional");
1080 }
1081 } else if (auto *strct = dyn_cast<StructDirective>(Val: el)) {
1082 if (llvm::any_of(Range: strct->getElements(), P: formatNotOptional)) {
1083 return emitError(loc, msg: "`struct` is only allowed in an optional group "
1084 "if all captured parameters are optional");
1085 }
1086 } else if (auto *custom = dyn_cast<CustomDirective>(Val: el)) {
1087 for (FormatElement *el : custom->getElements()) {
1088 // If the custom argument is a variable, then it must be optional.
1089 if (auto *param = dyn_cast<ParameterElement>(Val: el))
1090 if (!param->isOptional())
1091 return emitError(loc,
1092 msg: "`custom` is only allowed in an optional group if "
1093 "all captured parameters are optional");
1094 }
1095 }
1096 }
1097 // The anchor must be a parameter or one of the aforementioned directives.
1098 if (anchor) {
1099 if (!isa<ParameterElement, ParamsDirective, StructDirective,
1100 CustomDirective>(Val: anchor)) {
1101 return emitError(
1102 loc, msg: "optional group anchor must be a parameter or directive");
1103 }
1104 // If the anchor is a custom directive, make sure at least one of its
1105 // arguments is a bound parameter.
1106 if (auto *custom = dyn_cast<CustomDirective>(Val: anchor)) {
1107 const auto *bound =
1108 llvm::find_if(Range: custom->getElements(), P: [](FormatElement *el) {
1109 return isa<ParameterElement>(Val: el);
1110 });
1111 if (bound == custom->getElements().end())
1112 return emitError(loc, msg: "`custom` directive with no bound parameters "
1113 "cannot be used as optional group anchor");
1114 }
1115 }
1116 return success();
1117}
1118
1119LogicalResult
1120DefFormatParser::verifyStructArguments(SMLoc loc,
1121 ArrayRef<FormatElement *> arguments) {
1122 for (FormatElement *el : arguments) {
1123 if (!isa<ParameterElement, CustomDirective, ParamsDirective>(Val: el)) {
1124 return emitError(loc, msg: "expected a parameter, custom directive or params "
1125 "directive in `struct` arguments list");
1126 }
1127 if (auto custom = dyn_cast<CustomDirective>(Val: el)) {
1128 if (custom->getNumElements() != 1) {
1129 return emitError(loc, msg: "`struct` can only contain `custom` directives "
1130 "with a single argument");
1131 }
1132 if (failed(Result: custom->getFrontAs<ParameterElement>())) {
1133 return emitError(loc, msg: "a `custom` directive nested within a `struct` "
1134 "must be passed a parameter");
1135 }
1136 }
1137 }
1138 return success();
1139}
1140
1141LogicalResult DefFormatParser::markQualified(SMLoc loc,
1142 FormatElement *element) {
1143 if (!isa<ParameterElement>(Val: element))
1144 return emitError(loc, msg: "`qualified` argument list expected a variable");
1145 cast<ParameterElement>(Val: element)->setShouldBeQualified();
1146 return success();
1147}
1148
1149FailureOr<DefFormat> DefFormatParser::parse() {
1150 FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
1151 if (failed(Result: elements))
1152 return failure();
1153 return DefFormat(def, std::move(*elements));
1154}
1155
1156FailureOr<FormatElement *>
1157DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
1158 // Lookup the parameter.
1159 ArrayRef<AttrOrTypeParameter> params = def.getParameters();
1160 auto *it = llvm::find_if(
1161 Range&: params, P: [&](auto &param) { return param.getName() == name; });
1162
1163 // Check that the parameter reference is valid.
1164 if (it == params.end()) {
1165 return emitError(loc,
1166 msg: def.getName() + " has no parameter named '" + name + "'");
1167 }
1168 auto idx = std::distance(first: params.begin(), last: it);
1169
1170 if (ctx != RefDirectiveContext) {
1171 // Check that the variable has not already been bound.
1172 if (seenParams.test(Idx: idx))
1173 return emitError(loc, msg: "duplicate parameter '" + name + "'");
1174 seenParams.set(idx);
1175
1176 // Otherwise, to be referenced, a variable must have been bound.
1177 } else if (!seenParams.test(Idx: idx) && !isa<AttributeSelfTypeParameter>(Val: *it)) {
1178 return emitError(loc, msg: "parameter '" + name +
1179 "' must be bound before it is referenced");
1180 }
1181
1182 return create<ParameterElement>(args: *it);
1183}
1184
1185FailureOr<FormatElement *>
1186DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
1187 Context ctx) {
1188
1189 switch (kind) {
1190 case FormatToken::kw_qualified:
1191 return parseQualifiedDirective(loc, ctx);
1192 case FormatToken::kw_params:
1193 return parseParamsDirective(loc, ctx);
1194 case FormatToken::kw_struct:
1195 return parseStructDirective(loc, ctx);
1196 default:
1197 return emitError(loc, msg: "unsupported directive kind");
1198 }
1199}
1200
1201FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
1202 Context ctx) {
1203 // It doesn't make sense to allow references to all parameters in a custom
1204 // directive because parameters are the only things that can be bound.
1205 if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
1206 return emitError(loc, msg: "`params` can only be used at the top-level context "
1207 "or within a `struct` directive");
1208 }
1209
1210 // Collect all of the attribute's or type's parameters and ensure that none of
1211 // the parameters have already been captured.
1212 std::vector<ParameterElement *> vars;
1213 for (const auto &it : llvm::enumerate(First: def.getParameters())) {
1214 if (seenParams.test(Idx: it.index())) {
1215 return emitError(loc, msg: "`params` captures duplicate parameter: " +
1216 it.value().getName());
1217 }
1218 // Self-type parameters are handled separately from the rest of the
1219 // parameters.
1220 if (isa<AttributeSelfTypeParameter>(Val: it.value()))
1221 continue;
1222 seenParams.set(it.index());
1223 vars.push_back(x: create<ParameterElement>(args: it.value()));
1224 }
1225 return create<ParamsDirective>(args: std::move(vars));
1226}
1227
1228FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
1229 Context ctx) {
1230 if (ctx != TopLevelContext)
1231 return emitError(loc, msg: "`struct` can only be used at the top-level context");
1232
1233 if (failed(Result: parseToken(kind: FormatToken::l_paren,
1234 msg: "expected '(' before `struct` argument list"))) {
1235 return failure();
1236 }
1237
1238 // Parse variables captured by `struct`.
1239 std::vector<FormatElement *> vars;
1240
1241 // Parse first captured parameter or a `params` directive.
1242 FailureOr<FormatElement *> var = parseElement(ctx: StructDirectiveContext);
1243 if (failed(Result: var) ||
1244 !isa<ParameterElement, ParamsDirective, CustomDirective>(Val: *var)) {
1245 return emitError(
1246 loc, msg: "`struct` argument list expected a parameter or directive");
1247 }
1248 if (isa<ParameterElement, CustomDirective>(Val: *var)) {
1249 // Parse any other parameters.
1250 vars.push_back(x: *var);
1251 while (peekToken().is(kind: FormatToken::comma)) {
1252 consumeToken();
1253 var = parseElement(ctx: StructDirectiveContext);
1254 if (failed(Result: var) || !isa<ParameterElement, CustomDirective>(Val: *var))
1255 return emitError(loc, msg: "expected a parameter or `custom` directive in "
1256 "`struct` argument list");
1257 vars.push_back(x: *var);
1258 }
1259 } else {
1260 // `struct(params)` captures all parameters in the attribute or type.
1261 ParamsDirective *params = cast<ParamsDirective>(Val: *var);
1262 vars.reserve(n: params->getNumElements());
1263 for (ParameterElement *el : params->takeElements())
1264 vars.push_back(x: cast<FormatElement>(Val: el));
1265 }
1266
1267 if (failed(Result: parseToken(kind: FormatToken::r_paren,
1268 msg: "expected ')' at the end of an argument list"))) {
1269 return failure();
1270 }
1271 if (failed(Result: verifyStructArguments(loc, arguments: vars)))
1272 return failure();
1273 return create<StructDirective>(args: std::move(vars));
1274}
1275
1276//===----------------------------------------------------------------------===//
1277// Interface
1278//===----------------------------------------------------------------------===//
1279
1280void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
1281 MethodBody &parser,
1282 MethodBody &printer) {
1283 llvm::SourceMgr mgr;
1284 mgr.AddNewSourceBuffer(
1285 F: llvm::MemoryBuffer::getMemBuffer(InputData: *def.getAssemblyFormat()), IncludeLoc: SMLoc());
1286
1287 // Parse the custom assembly format>
1288 DefFormatParser fmtParser(mgr, def);
1289 FailureOr<DefFormat> format = fmtParser.parse();
1290 if (failed(Result: format)) {
1291 if (formatErrorIsFatal)
1292 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "failed to parse assembly format");
1293 return;
1294 }
1295
1296 // Generate the parser and printer.
1297 format->genParser(os&: parser);
1298 format->genPrinter(os&: printer);
1299}
1300

// Source: mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp