1//===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "AttrOrTypeFormatGen.h"
10#include "FormatGen.h"
11#include "mlir/Support/LLVM.h"
12#include "mlir/Support/LogicalResult.h"
13#include "mlir/TableGen/AttrOrTypeDef.h"
14#include "mlir/TableGen/Format.h"
15#include "mlir/TableGen/GenInfo.h"
16#include "llvm/ADT/BitVector.h"
17#include "llvm/ADT/StringExtras.h"
18#include "llvm/ADT/StringSwitch.h"
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/MemoryBuffer.h"
21#include "llvm/Support/SaveAndRestore.h"
22#include "llvm/Support/SourceMgr.h"
23#include "llvm/TableGen/Error.h"
24#include "llvm/TableGen/TableGenBackend.h"
25
26using namespace mlir;
27using namespace mlir::tblgen;
28
29using llvm::formatv;
30
31//===----------------------------------------------------------------------===//
32// Element
33//===----------------------------------------------------------------------===//
34
35namespace {
36/// This class represents an instance of a variable element. A variable refers
37/// to an attribute or type parameter.
38class ParameterElement
39 : public VariableElementBase<VariableElement::Parameter> {
40public:
41 ParameterElement(AttrOrTypeParameter param) : param(param) {}
42
43 /// Get the parameter in the element.
44 const AttrOrTypeParameter &getParam() const { return param; }
45
46 /// Indicate if this variable is printed "qualified" (that is it is
47 /// prefixed with the `#dialect.mnemonic`).
48 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
49 void setShouldBeQualified(bool qualified = true) {
50 shouldBeQualifiedFlag = qualified;
51 }
52
53 /// Returns true if the element contains an optional parameter.
54 bool isOptional() const { return param.isOptional(); }
55
56 /// Returns the name of the parameter.
57 StringRef getName() const { return param.getName(); }
58
59 /// Return the code to check whether the parameter is present.
60 auto genIsPresent(FmtContext &ctx, const Twine &self) const {
61 assert(isOptional() && "cannot guard on a mandatory parameter");
62 std::string valueStr = tgfmt(fmt: *param.getDefaultValue(), ctx: &ctx).str();
63 ctx.addSubst(placeholder: "_lhs", subst: self).addSubst(placeholder: "_rhs", subst: valueStr);
64 return tgfmt(fmt: getParam().getComparator(), ctx: &ctx);
65 }
66
67 /// Generate the code to check whether the parameter should be printed.
68 MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const {
69 assert(isOptional() && "cannot guard on a mandatory parameter");
70 std::string self = param.getAccessorName() + "()";
71 return os << "!(" << genIsPresent(ctx, self) << ")";
72 }
73
74private:
75 bool shouldBeQualifiedFlag = false;
76 AttrOrTypeParameter param;
77};
78
79/// Shorthand functions that can be used with ranged-based conditions.
80static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
81static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
82
83/// Base class for a directive that contains references to multiple variables.
84template <DirectiveElement::Kind DirectiveKind>
85class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
86public:
87 using Base = ParamsDirectiveBase<DirectiveKind>;
88
89 ParamsDirectiveBase(std::vector<ParameterElement *> &&params)
90 : params(std::move(params)) {}
91
92 /// Get the parameters contained in this directive.
93 ArrayRef<ParameterElement *> getParams() const { return params; }
94
95 /// Get the number of parameters.
96 unsigned getNumParams() const { return params.size(); }
97
98 /// Take all of the parameters from this directive.
99 std::vector<ParameterElement *> takeParams() { return std::move(params); }
100
101 /// Returns true if there are optional parameters present.
102 bool hasOptionalParams() const {
103 return llvm::any_of(getParams(), paramIsOptional);
104 }
105
106private:
107 /// The parameters captured by this directive.
108 std::vector<ParameterElement *> params;
109};
110
111/// This class represents a `params` directive that refers to all parameters
112/// of an attribute or type. When used as a top-level directive, it generates
113/// a format of the form:
114///
115/// (param-value (`,` param-value)*)?
116///
117/// When used as an argument to another directive that accepts variables,
118/// `params` can be used in place of manually listing all parameters of an
119/// attribute or type.
120class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> {
121public:
122 using Base::Base;
123};
124
125/// This class represents a `struct` directive that generates a struct format
126/// of the form:
127///
128/// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
129///
130class StructDirective : public ParamsDirectiveBase<DirectiveElement::Struct> {
131public:
132 using Base::Base;
133};
134
135} // namespace
136
137//===----------------------------------------------------------------------===//
138// Format Strings
139//===----------------------------------------------------------------------===//
140
/// Default parser for attribute or type parameters.
///
/// $0: The C++ storage type of the parameter.
static const char *const defaultParameterParser =
    "::mlir::FieldParser<$0>::parse($_parser)";

/// Default printer for attribute or type parameters.
static const char *const defaultParameterPrinter =
    "$_printer.printStrippedAttrOrType($_self)";

/// Qualified printer for attribute or type parameters: it does not elide
/// dialect and mnemonic.
static const char *const qualifiedParameterPrinter = "$_printer << $_self";

/// Opening of a call that reports a parse error at the current parser
/// location. The call is left open: users append the error message string
/// (optionally followed by streamed values) and close it themselves. It takes
/// no positional placeholders.
static const char *const parserErrorStr =
    "$_parser.emitError($_parser.getCurrentLocation(), ";

/// Code format to parse a variable. Separate by lines because variable parsers
/// may be generated inside other directives, which requires indentation.
///
/// {0}: The parameter name.
/// {1}: The parse code for the parameter.
/// {2}: Code template for printing an error.
/// {3}: Name of the attribute or type.
/// {4}: C++ class of the parameter.
static const char *const variableParser = R"(
// Parse variable '{0}'
_result_{0} = {1};
if (::mlir::failed(_result_{0})) {{
 {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
 return {{};
}
)";
175
176//===----------------------------------------------------------------------===//
177// DefFormat
178//===----------------------------------------------------------------------===//
179
180namespace {
181class DefFormat {
182public:
183 DefFormat(const AttrOrTypeDef &def, std::vector<FormatElement *> &&elements)
184 : def(def), elements(std::move(elements)) {}
185
186 /// Generate the attribute or type parser.
187 void genParser(MethodBody &os);
188 /// Generate the attribute or type printer.
189 void genPrinter(MethodBody &os);
190
191private:
192 /// Generate the parser code for a specific format element.
193 void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os);
194 /// Generate the parser code for a literal.
195 void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os,
196 bool isOptional = false);
197 /// Generate the parser code for a variable.
198 void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os);
199 /// Generate the parser code for a `params` directive.
200 void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
201 /// Generate the parser code for a `struct` directive.
202 void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
203 /// Generate the parser code for a `custom` directive.
204 void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os,
205 bool isOptional = false);
206 /// Generate the parser code for an optional group.
207 void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
208 MethodBody &os);
209
210 /// Generate the printer code for a specific format element.
211 void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os);
212 /// Generate the printer code for a literal.
213 void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os);
214 /// Generate the printer code for a variable.
215 void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
216 bool skipGuard = false);
217 /// Generate a printer for comma-separated parameters.
218 void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params,
219 FmtContext &ctx, MethodBody &os,
220 function_ref<void(ParameterElement *)> extra);
221 /// Generate the printer code for a `params` directive.
222 void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
223 /// Generate the printer code for a `struct` directive.
224 void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
225 /// Generate the printer code for a `custom` directive.
226 void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os);
227 /// Generate the printer code for an optional group.
228 void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
229 MethodBody &os);
230 /// Generate a printer (or space eraser) for a whitespace element.
231 void genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
232 MethodBody &os);
233
234 /// The ODS definition of the attribute or type whose format is being used to
235 /// generate a parser and printer.
236 const AttrOrTypeDef &def;
237 /// The list of top-level format elements returned by the assembly format
238 /// parser.
239 std::vector<FormatElement *> elements;
240
241 /// Flags for printing spaces.
242 bool shouldEmitSpace = false;
243 bool lastWasPunctuation = false;
244};
245} // namespace
246
247//===----------------------------------------------------------------------===//
248// ParserGen
249//===----------------------------------------------------------------------===//
250
251/// Generate a special-case "parser" for an attribute's self type parameter. The
252/// self type parameter has special handling in the assembly format in that it
253/// is derived from the optional trailing colon type after the attribute.
254static void genAttrSelfTypeParser(MethodBody &os, const FmtContext &ctx,
255 const AttributeSelfTypeParameter &param) {
256 // "Parser" for an attribute self type parameter that checks the
257 // optionally-parsed trailing colon type.
258 //
259 // $0: The C++ storage class of the type parameter.
260 // $1: The self type parameter name.
261 const char *const selfTypeParser = R"(
262if ($_type) {
263 if (auto reqType = ::llvm::dyn_cast<$0>($_type)) {
264 _result_$1 = reqType;
265 } else {
266 $_parser.emitError($_loc, "invalid kind of type specified");
267 return {};
268 }
269})";
270
271 // If the attribute self type parameter is required, emit code that emits an
272 // error if the trailing type was not parsed.
273 const char *const selfTypeRequired = R"( else {
274 $_parser.emitError($_loc, "expected a trailing type");
275 return {};
276})";
277
278 os << tgfmt(fmt: selfTypeParser, ctx: &ctx, vals: param.getCppStorageType(), vals: param.getName());
279 if (!param.isOptional())
280 os << tgfmt(fmt: selfTypeRequired, ctx: &ctx);
281 os << "\n";
282}
283
284void DefFormat::genParser(MethodBody &os) {
285 FmtContext ctx;
286 ctx.addSubst(placeholder: "_parser", subst: "odsParser");
287 ctx.addSubst(placeholder: "_ctxt", subst: "odsParser.getContext()");
288 ctx.withBuilder(subst: "odsBuilder");
289 if (isa<AttrDef>(Val: def))
290 ctx.addSubst(placeholder: "_type", subst: "odsType");
291 os.indent();
292 os << "::mlir::Builder odsBuilder(odsParser.getContext());\n";
293
294 // Store the initial location of the parser.
295 ctx.addSubst(placeholder: "_loc", subst: "odsLoc");
296 os << tgfmt(fmt: "::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
297 "(void) $_loc;\n",
298 ctx: &ctx);
299
300 // Declare variables to store all of the parameters. Allocated parameters
301 // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
302 // FailureOr<T> to defer type construction for parameters that are parsed in
303 // a loop (parsers return FailureOr anyways).
304 ArrayRef<AttrOrTypeParameter> params = def.getParameters();
305 for (const AttrOrTypeParameter &param : params) {
306 os << formatv(Fmt: "::mlir::FailureOr<{0}> _result_{1};\n",
307 Vals: param.getCppStorageType(), Vals: param.getName());
308 if (auto *selfTypeParam = dyn_cast<AttributeSelfTypeParameter>(Val: &param))
309 genAttrSelfTypeParser(os, ctx, param: *selfTypeParam);
310 }
311
312 // Generate call to each parameter parser.
313 for (FormatElement *el : elements)
314 genElementParser(el, ctx, os);
315
316 // Emit an assert for each mandatory parameter. Triggering an assert means
317 // the generated parser is incorrect (i.e. there is a bug in this code).
318 for (const AttrOrTypeParameter &param : params) {
319 if (param.isOptional())
320 continue;
321 os << formatv(Fmt: "assert(::mlir::succeeded(_result_{0}));\n", Vals: param.getName());
322 }
323
324 // Generate call to the attribute or type builder. Use the checked getter
325 // if one was generated.
326 if (def.genVerifyDecl()) {
327 os << tgfmt(fmt: "return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
328 ctx: &ctx, vals: def.getCppClassName());
329 } else {
330 os << tgfmt(fmt: "return $0::get($_parser.getContext()", ctx: &ctx,
331 vals: def.getCppClassName());
332 }
333 for (const AttrOrTypeParameter &param : params) {
334 os << ",\n ";
335 std::string paramSelfStr;
336 llvm::raw_string_ostream selfOs(paramSelfStr);
337 if (std::optional<StringRef> defaultValue = param.getDefaultValue()) {
338 selfOs << formatv(Fmt: "(_result_{0}.value_or(", Vals: param.getName())
339 << tgfmt(fmt: *defaultValue, ctx: &ctx) << "))";
340 } else {
341 selfOs << formatv(Fmt: "(*_result_{0})", Vals: param.getName());
342 }
343 ctx.addSubst(placeholder: param.getName(), subst: selfOs.str());
344 os << param.getCppType() << "("
345 << tgfmt(fmt: param.getConvertFromStorage(), ctx: &ctx.withSelf(subst: selfOs.str()))
346 << ")";
347 }
348 os << ");";
349}
350
351void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
352 MethodBody &os) {
353 if (auto *literal = dyn_cast<LiteralElement>(Val: el))
354 return genLiteralParser(value: literal->getSpelling(), ctx, os);
355 if (auto *var = dyn_cast<ParameterElement>(Val: el))
356 return genVariableParser(el: var, ctx, os);
357 if (auto *params = dyn_cast<ParamsDirective>(Val: el))
358 return genParamsParser(el: params, ctx, os);
359 if (auto *strct = dyn_cast<StructDirective>(Val: el))
360 return genStructParser(el: strct, ctx, os);
361 if (auto *custom = dyn_cast<CustomDirective>(Val: el))
362 return genCustomParser(el: custom, ctx, os);
363 if (auto *optional = dyn_cast<OptionalElement>(Val: el))
364 return genOptionalGroupParser(el: optional, ctx, os);
365 if (isa<WhitespaceElement>(Val: el))
366 return;
367
368 llvm_unreachable("unknown format element");
369}
370
371void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx,
372 MethodBody &os, bool isOptional) {
373 os << "// Parse literal '" << value << "'\n";
374 os << tgfmt(fmt: "if ($_parser.parse", ctx: &ctx);
375 if (isOptional)
376 os << "Optional";
377 if (value.front() == '_' || isalpha(value.front())) {
378 os << "Keyword(\"" << value << "\")";
379 } else {
380 os << StringSwitch<StringRef>(value)
381 .Case(S: "->", Value: "Arrow")
382 .Case(S: ":", Value: "Colon")
383 .Case(S: ",", Value: "Comma")
384 .Case(S: "=", Value: "Equal")
385 .Case(S: "<", Value: "Less")
386 .Case(S: ">", Value: "Greater")
387 .Case(S: "{", Value: "LBrace")
388 .Case(S: "}", Value: "RBrace")
389 .Case(S: "(", Value: "LParen")
390 .Case(S: ")", Value: "RParen")
391 .Case(S: "[", Value: "LSquare")
392 .Case(S: "]", Value: "RSquare")
393 .Case(S: "?", Value: "Question")
394 .Case(S: "+", Value: "Plus")
395 .Case(S: "*", Value: "Star")
396 .Case(S: "...", Value: "Ellipsis")
397 << "()";
398 }
399 if (isOptional) {
400 // Leave the `if` unclosed to guard optional groups.
401 return;
402 }
403 // Parser will emit an error
404 os << ") return {};\n";
405}
406
407void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
408 MethodBody &os) {
409 // Check for a custom parser. Use the default attribute parser otherwise.
410 const AttrOrTypeParameter &param = el->getParam();
411 auto customParser = param.getParser();
412 auto parser =
413 customParser ? *customParser : StringRef(defaultParameterParser);
414 os << formatv(Fmt: variableParser, Vals: param.getName(),
415 Vals: tgfmt(fmt: parser, ctx: &ctx, vals: param.getCppStorageType()),
416 Vals: tgfmt(fmt: parserErrorStr, ctx: &ctx), Vals: def.getName(), Vals: param.getCppType());
417}
418
419void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
420 MethodBody &os) {
421 os << "// Parse parameter list\n";
422
423 // If there are optional parameters, we need to switch to `parseOptionalComma`
424 // if there are no more required parameters after a certain point.
425 bool hasOptional = el->hasOptionalParams();
426 if (hasOptional) {
427 // Wrap everything in a do-while so that we can `break`.
428 os << "do {\n";
429 os.indent();
430 }
431
432 ArrayRef<ParameterElement *> params = el->getParams();
433 using IteratorT = ParameterElement *const *;
434 IteratorT it = params.begin();
435
436 // Find the last required parameter. Commas become optional aftewards.
437 // Note: IteratorT's copy assignment is deleted.
438 ParameterElement *lastReq = nullptr;
439 for (ParameterElement *param : params)
440 if (!param->isOptional())
441 lastReq = param;
442 IteratorT lastReqIt = lastReq ? llvm::find(Range&: params, Val: lastReq) : params.begin();
443
444 auto eachFn = [&](ParameterElement *el) { genVariableParser(el, ctx, os); };
445 auto betweenFn = [&](IteratorT it) {
446 ParameterElement *el = *std::prev(x: it);
447 // Parse a comma if the last optional parameter had a value.
448 if (el->isOptional()) {
449 os << formatv(Fmt: "if (::mlir::succeeded(_result_{0}) && !({1})) {{\n",
450 Vals: el->getName(),
451 Vals: el->genIsPresent(ctx, self: "(*_result_" + el->getName() + ")"));
452 os.indent();
453 }
454 if (it <= lastReqIt) {
455 genLiteralParser(value: ",", ctx, os);
456 } else {
457 genLiteralParser(value: ",", ctx, os, /*isOptional=*/true);
458 os << ") break;\n";
459 }
460 if (el->isOptional())
461 os.unindent() << "}\n";
462 };
463
464 // llvm::interleave
465 if (it != params.end()) {
466 eachFn(*it++);
467 for (IteratorT e = params.end(); it != e; ++it) {
468 betweenFn(it);
469 eachFn(*it);
470 }
471 }
472
473 if (hasOptional)
474 os.unindent() << "} while(false);\n";
475}
476
477void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
478 MethodBody &os) {
479 // Loop declaration for struct parser with only required parameters.
480 //
481 // $0: Number of expected parameters.
482 const char *const loopHeader = R"(
483 for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) {
484)";
485
486 // Loop body start for struct parser.
487 const char *const loopStart = R"(
488 ::llvm::StringRef _paramKey;
489 if ($_parser.parseKeyword(&_paramKey)) {
490 $_parser.emitError($_parser.getCurrentLocation(),
491 "expected a parameter name in struct");
492 return {};
493 }
494 if (!_loop_body(_paramKey)) return {};
495)";
496
497 // Struct parser loop end. Check for duplicate or unknown struct parameters.
498 //
499 // {0}: Code template for printing an error.
500 const char *const loopEnd = R"({{
501 {0}"duplicate or unknown struct parameter name: ") << _paramKey;
502 return {{};
503}
504)";
505
506 // Struct parser loop terminator. Parse a comma except on the last element.
507 //
508 // {0}: Number of elements in the struct.
509 const char *const loopTerminator = R"(
510 if ((odsStructIndex != {0} - 1) && odsParser.parseComma())
511 return {{};
512}
513)";
514
515 // Check that a mandatory parameter was parse.
516 //
517 // {0}: Name of the parameter.
518 const char *const checkParam = R"(
519 if (!_seen_{0}) {
520 {1}"struct is missing required parameter: ") << "{0}";
521 return {{};
522 }
523)";
524
525 // First iteration of the loop parsing an optional struct.
526 const char *const optionalStructFirst = R"(
527 ::llvm::StringRef _paramKey;
528 if (!$_parser.parseOptionalKeyword(&_paramKey)) {
529 if (!_loop_body(_paramKey)) return {};
530 while (!$_parser.parseOptionalComma()) {
531)";
532
533 os << "// Parse parameter struct\n";
534
535 // Declare a "seen" variable for each key.
536 for (ParameterElement *param : el->getParams())
537 os << formatv(Fmt: "bool _seen_{0} = false;\n", Vals: param->getName());
538
539 // Generate the body of the parsing loop inside a lambda.
540 os << "{\n";
541 os.indent()
542 << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
543 genLiteralParser(value: "=", ctx, os&: os.indent());
544 for (ParameterElement *param : el->getParams()) {
545 os << formatv(Fmt: "if (!_seen_{0} && _paramKey == \"{0}\") {\n"
546 " _seen_{0} = true;\n",
547 Vals: param->getName());
548 genVariableParser(el: param, ctx, os&: os.indent());
549 os.unindent() << "} else ";
550 // Print the check for duplicate or unknown parameter.
551 }
552 os.getStream().printReindented(str: strfmt(fmt: loopEnd, parameters: tgfmt(fmt: parserErrorStr, ctx: &ctx)));
553 os << "return true;\n";
554 os.unindent() << "};\n";
555
556 // Generate the parsing loop. If optional parameters are present, then the
557 // parse loop is guarded by commas.
558 unsigned numOptional = llvm::count_if(Range: el->getParams(), P: paramIsOptional);
559 if (numOptional) {
560 // If the struct itself is optional, pull out the first iteration.
561 if (numOptional == el->getNumParams()) {
562 os.getStream().printReindented(str: tgfmt(fmt: optionalStructFirst, ctx: &ctx).str());
563 os.indent();
564 } else {
565 os << "do {\n";
566 }
567 } else {
568 os.getStream().printReindented(
569 str: tgfmt(fmt: loopHeader, ctx: &ctx, vals: el->getNumParams()).str());
570 }
571 os.indent();
572 os.getStream().printReindented(str: tgfmt(fmt: loopStart, ctx: &ctx).str());
573 os.unindent();
574
575 // Print the loop terminator. For optional parameters, we have to check that
576 // all mandatory parameters have been parsed.
577 // The whole struct is optional if all its parameters are optional.
578 if (numOptional) {
579 if (numOptional == el->getNumParams()) {
580 os << "}\n";
581 os.unindent() << "}\n";
582 } else {
583 os << tgfmt(fmt: "} while(!$_parser.parseOptionalComma());\n", ctx: &ctx);
584 for (ParameterElement *param : el->getParams()) {
585 if (param->isOptional())
586 continue;
587 os.getStream().printReindented(
588 str: strfmt(fmt: checkParam, parameters: param->getName(), parameters: tgfmt(fmt: parserErrorStr, ctx: &ctx)));
589 }
590 }
591 } else {
592 // Because the loop loops N times and each non-failing iteration sets 1 of
593 // N flags, successfully exiting the loop means that all parameters have
594 // been seen. `parseOptionalComma` would cause issues with any formats that
595 // use "struct(...) `,`" beacuse structs aren't sounded by braces.
596 os.getStream().printReindented(str: strfmt(fmt: loopTerminator, parameters: el->getNumParams()));
597 }
598 os.unindent() << "}\n";
599}
600
601void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
602 MethodBody &os, bool isOptional) {
603 os << "{\n";
604 os.indent();
605
606 // Bound variables are passed directly to the parser as `FailureOr<T> &`.
607 // Referenced variables are passed as `T`. The custom parser fails if it
608 // returns failure or if any of the required parameters failed.
609 os << tgfmt(fmt: "auto odsCustomLoc = $_parser.getCurrentLocation();\n", ctx: &ctx);
610 os << "(void)odsCustomLoc;\n";
611 os << tgfmt(fmt: "auto odsCustomResult = parse$0($_parser", ctx: &ctx, vals: el->getName());
612 os.indent();
613 for (FormatElement *arg : el->getArguments()) {
614 os << ",\n";
615 if (auto *param = dyn_cast<ParameterElement>(Val: arg))
616 os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName()
617 << ")";
618 else if (auto *ref = dyn_cast<RefDirective>(Val: arg))
619 os << "*_result_" << cast<ParameterElement>(Val: ref->getArg())->getName();
620 else
621 os << tgfmt(fmt: cast<StringElement>(Val: arg)->getValue(), ctx: &ctx);
622 }
623 os.unindent() << ");\n";
624 if (isOptional) {
625 os << "if (!odsCustomResult.has_value()) return {};\n";
626 os << "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n";
627 } else {
628 os << "if (::mlir::failed(odsCustomResult)) return {};\n";
629 }
630 for (FormatElement *arg : el->getArguments()) {
631 if (auto *param = dyn_cast<ParameterElement>(Val: arg)) {
632 if (param->isOptional())
633 continue;
634 os << formatv(Fmt: "if (::mlir::failed(_result_{0})) {{\n", Vals: param->getName());
635 os.indent() << tgfmt(fmt: "$_parser.emitError(odsCustomLoc, ", ctx: &ctx)
636 << "\"custom parser failed to parse parameter '"
637 << param->getName() << "'\");\n";
638 os << "return " << (isOptional ? "::mlir::failure()" : "{}") << ";\n";
639 os.unindent() << "}\n";
640 }
641 }
642
643 os.unindent() << "}\n";
644}
645
646void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
647 MethodBody &os) {
648 ArrayRef<FormatElement *> thenElements =
649 el->getThenElements(/*parseable=*/true);
650
651 FormatElement *first = thenElements.front();
652 const auto guardOn = [&](auto params) {
653 os << "if (!(";
654 llvm::interleave(
655 params, os,
656 [&](ParameterElement *el) {
657 os << formatv(Fmt: "(::mlir::succeeded(_result_{0}) && *_result_{0})",
658 Vals: el->getName());
659 },
660 " || ");
661 os << ")) {\n";
662 };
663 if (auto *literal = dyn_cast<LiteralElement>(Val: first)) {
664 genLiteralParser(value: literal->getSpelling(), ctx, os, /*isOptional=*/true);
665 os << ") {\n";
666 } else if (auto *param = dyn_cast<ParameterElement>(Val: first)) {
667 genVariableParser(el: param, ctx, os);
668 guardOn(llvm::ArrayRef(param));
669 } else if (auto *params = dyn_cast<ParamsDirective>(Val: first)) {
670 genParamsParser(el: params, ctx, os);
671 guardOn(params->getParams());
672 } else if (auto *custom = dyn_cast<CustomDirective>(Val: first)) {
673 os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
674 os.indent();
675 genCustomParser(el: custom, ctx, os, /*isOptional=*/true);
676 os << "return ::mlir::success();\n";
677 os.unindent();
678 os << "}(); result.has_value() && ::mlir::failed(*result)) {\n";
679 os.indent();
680 os << "return {};\n";
681 os.unindent();
682 os << "} else if (result.has_value()) {\n";
683 } else {
684 auto *strct = cast<StructDirective>(Val: first);
685 genStructParser(el: strct, ctx, os);
686 guardOn(params->getParams());
687 }
688 os.indent();
689
690 // Generate the parsers for the rest of the thenElements.
691 for (FormatElement *element : el->getElseElements(/*parseable=*/true))
692 genElementParser(el: element, ctx, os);
693 os.unindent() << "} else {\n";
694 os.indent();
695 for (FormatElement *element : thenElements.drop_front())
696 genElementParser(el: element, ctx, os);
697 os.unindent() << "}\n";
698}
699
700//===----------------------------------------------------------------------===//
701// PrinterGen
702//===----------------------------------------------------------------------===//
703
704void DefFormat::genPrinter(MethodBody &os) {
705 FmtContext ctx;
706 ctx.addSubst(placeholder: "_printer", subst: "odsPrinter");
707 ctx.addSubst(placeholder: "_ctxt", subst: "getContext()");
708 ctx.withBuilder(subst: "odsBuilder");
709 os.indent();
710 os << "::mlir::Builder odsBuilder(getContext());\n";
711
712 // Generate printers.
713 shouldEmitSpace = true;
714 lastWasPunctuation = false;
715 for (FormatElement *el : elements)
716 genElementPrinter(el, ctx, os);
717}
718
719void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
720 MethodBody &os) {
721 if (auto *literal = dyn_cast<LiteralElement>(Val: el))
722 return genLiteralPrinter(value: literal->getSpelling(), ctx, os);
723 if (auto *params = dyn_cast<ParamsDirective>(Val: el))
724 return genParamsPrinter(el: params, ctx, os);
725 if (auto *strct = dyn_cast<StructDirective>(Val: el))
726 return genStructPrinter(el: strct, ctx, os);
727 if (auto *custom = dyn_cast<CustomDirective>(Val: el))
728 return genCustomPrinter(el: custom, ctx, os);
729 if (auto *var = dyn_cast<ParameterElement>(Val: el))
730 return genVariablePrinter(el: var, ctx, os);
731 if (auto *optional = dyn_cast<OptionalElement>(Val: el))
732 return genOptionalGroupPrinter(el: optional, ctx, os);
733 if (auto *whitespace = dyn_cast<WhitespaceElement>(Val: el))
734 return genWhitespacePrinter(el: whitespace, ctx, os);
735
736 llvm::PrintFatalError(Msg: "unsupported format element");
737}
738
739void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
740 MethodBody &os) {
741 // Don't insert a space before certain punctuation.
742 bool needSpace =
743 shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
744 os << tgfmt(fmt: "$_printer$0 << \"$1\";\n", ctx: &ctx, vals: needSpace ? " << ' '" : "",
745 vals&: value);
746
747 // Update the flags.
748 shouldEmitSpace =
749 value.size() != 1 || !StringRef("<({[").contains(C: value.front());
750 lastWasPunctuation = value.front() != '_' && !isalpha(value.front());
751}
752
753void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
754 MethodBody &os, bool skipGuard) {
755 const AttrOrTypeParameter &param = el->getParam();
756 ctx.withSelf(subst: param.getAccessorName() + "()");
757
758 // Guard the printer on the presence of optional parameters and that they
759 // aren't equal to their default values (if they have one).
760 if (el->isOptional() && !skipGuard) {
761 el->genPrintGuard(ctx, os&: os << "if (") << ") {\n";
762 os.indent();
763 }
764
765 // Insert a space before the next parameter, if necessary.
766 if (shouldEmitSpace || !lastWasPunctuation)
767 os << tgfmt(fmt: "$_printer << ' ';\n", ctx: &ctx);
768 shouldEmitSpace = true;
769 lastWasPunctuation = false;
770
771 if (el->shouldBeQualified())
772 os << tgfmt(fmt: qualifiedParameterPrinter, ctx: &ctx) << ";\n";
773 else if (auto printer = param.getPrinter())
774 os << tgfmt(fmt: *printer, ctx: &ctx) << ";\n";
775 else
776 os << tgfmt(fmt: defaultParameterPrinter, ctx: &ctx) << ";\n";
777
778 if (el->isOptional() && !skipGuard)
779 os.unindent() << "}\n";
780}
781
782/// Generate code to guard printing on the presence of any optional parameters.
783template <typename ParameterRange>
784static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &&params,
785 bool inverted = false) {
786 os << "if (";
787 if (inverted)
788 os << "!(";
789 llvm::interleave(
790 params, os,
791 [&](ParameterElement *param) { param->genPrintGuard(ctx, os); }, " || ");
792 if (inverted)
793 os << ")";
794 os << ") {\n";
795 os.indent();
796}
797
798void DefFormat::genCommaSeparatedPrinter(
799 ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
800 function_ref<void(ParameterElement *)> extra) {
801 // Emit a space if necessary, but only if the struct is present.
802 if (shouldEmitSpace || !lastWasPunctuation) {
803 bool allOptional = llvm::all_of(Range&: params, P: paramIsOptional);
804 if (allOptional)
805 guardOnAny(ctx, os, params);
806 os << tgfmt(fmt: "$_printer << ' ';\n", ctx: &ctx);
807 if (allOptional)
808 os.unindent() << "}\n";
809 }
810
811 // The first printed element does not need to emit a comma.
812 os << "{\n";
813 os.indent() << "bool _firstPrinted = true;\n";
814 for (ParameterElement *param : params) {
815 if (param->isOptional()) {
816 param->genPrintGuard(ctx, os&: os << "if (") << ") {\n";
817 os.indent();
818 }
819 os << tgfmt(fmt: "if (!_firstPrinted) $_printer << \", \";\n", ctx: &ctx);
820 os << "_firstPrinted = false;\n";
821 extra(param);
822 shouldEmitSpace = false;
823 lastWasPunctuation = true;
824 genVariablePrinter(el: param, ctx, os);
825 if (param->isOptional())
826 os.unindent() << "}\n";
827 }
828 os.unindent() << "}\n";
829}
830
831void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
832 MethodBody &os) {
833 genCommaSeparatedPrinter(params: llvm::to_vector(Range: el->getParams()), ctx, os,
834 extra: [&](ParameterElement *param) {});
835}
836
837void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
838 MethodBody &os) {
839 genCommaSeparatedPrinter(
840 params: llvm::to_vector(Range: el->getParams()), ctx, os, extra: [&](ParameterElement *param) {
841 os << tgfmt(fmt: "$_printer << \"$0 = \";\n", ctx: &ctx, vals: param->getName());
842 });
843}
844
845void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
846 MethodBody &os) {
847 // Insert a space before the custom directive, if necessary.
848 if (shouldEmitSpace || !lastWasPunctuation)
849 os << tgfmt(fmt: "$_printer << ' ';\n", ctx: &ctx);
850 shouldEmitSpace = true;
851 lastWasPunctuation = false;
852
853 os << tgfmt(fmt: "print$0($_printer", ctx: &ctx, vals: el->getName());
854 os.indent();
855 for (FormatElement *arg : el->getArguments()) {
856 os << ",\n";
857 if (auto *param = dyn_cast<ParameterElement>(Val: arg)) {
858 os << param->getParam().getAccessorName() << "()";
859 } else if (auto *ref = dyn_cast<RefDirective>(Val: arg)) {
860 os << cast<ParameterElement>(Val: ref->getArg())->getParam().getAccessorName()
861 << "()";
862 } else {
863 os << tgfmt(fmt: cast<StringElement>(Val: arg)->getValue(), ctx: &ctx);
864 }
865 }
866 os.unindent() << ");\n";
867}
868
869void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
870 MethodBody &os) {
871 FormatElement *anchor = el->getAnchor();
872 if (auto *param = dyn_cast<ParameterElement>(Val: anchor)) {
873 guardOnAny(ctx, os, params: llvm::ArrayRef(param), inverted: el->isInverted());
874 } else if (auto *params = dyn_cast<ParamsDirective>(Val: anchor)) {
875 guardOnAny(ctx, os, params: params->getParams(), inverted: el->isInverted());
876 } else if (auto *strct = dyn_cast<StructDirective>(Val: anchor)) {
877 guardOnAny(ctx, os, params: strct->getParams(), inverted: el->isInverted());
878 } else {
879 auto *custom = cast<CustomDirective>(Val: anchor);
880 guardOnAny(ctx, os,
881 params: llvm::make_filter_range(
882 Range: llvm::map_range(C: custom->getArguments(),
883 F: [](FormatElement *el) {
884 return dyn_cast<ParameterElement>(Val: el);
885 }),
886 Pred: [](ParameterElement *param) { return !!param; }),
887 inverted: el->isInverted());
888 }
889 // Generate the printer for the contained elements.
890 {
891 llvm::SaveAndRestore shouldEmitSpaceFlag(shouldEmitSpace);
892 llvm::SaveAndRestore lastWasPunctuationFlag(lastWasPunctuation);
893 for (FormatElement *element : el->getThenElements())
894 genElementPrinter(el: element, ctx, os);
895 }
896 os.unindent() << "} else {\n";
897 os.indent();
898 for (FormatElement *element : el->getElseElements())
899 genElementPrinter(el: element, ctx, os);
900 os.unindent() << "}\n";
901}
902
903void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
904 MethodBody &os) {
905 if (el->getValue() == "\\n") {
906 // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by
907 // the printer.
908 os << tgfmt(fmt: "$_printer << '\\n';\n", ctx: &ctx);
909 } else if (!el->getValue().empty()) {
910 os << tgfmt(fmt: "$_printer << \"$0\";\n", ctx: &ctx, vals: el->getValue());
911 } else {
912 lastWasPunctuation = true;
913 }
914 shouldEmitSpace = false;
915}
916
917//===----------------------------------------------------------------------===//
918// DefFormatParser
919//===----------------------------------------------------------------------===//
920
921namespace {
922class DefFormatParser : public FormatParser {
923public:
924 DefFormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
925 : FormatParser(mgr, def.getLoc()[0]), def(def),
926 seenParams(def.getNumParameters()) {}
927
928 /// Parse the attribute or type format and create the format elements.
929 FailureOr<DefFormat> parse();
930
931protected:
932 /// Verify the parsed elements.
933 LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
934 /// Verify the elements of a custom directive.
935 LogicalResult
936 verifyCustomDirectiveArguments(SMLoc loc,
937 ArrayRef<FormatElement *> arguments) override;
938 /// Verify the elements of an optional group.
939 LogicalResult verifyOptionalGroupElements(SMLoc loc,
940 ArrayRef<FormatElement *> elements,
941 FormatElement *anchor) override;
942
943 /// Parse an attribute or type variable.
944 FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
945 Context ctx) override;
946 /// Parse an attribute or type format directive.
947 FailureOr<FormatElement *>
948 parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
949
950private:
951 /// Parse a `params` directive.
952 FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
953 /// Parse a `qualified` directive.
954 FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
955 /// Parse a `struct` directive.
956 FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
957 /// Parse a `ref` directive.
958 FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx);
959
960 /// Attribute or type tablegen def.
961 const AttrOrTypeDef &def;
962
963 /// Seen attribute or type parameters.
964 BitVector seenParams;
965};
966} // namespace
967
968LogicalResult DefFormatParser::verify(SMLoc loc,
969 ArrayRef<FormatElement *> elements) {
970 // Check that all parameters are referenced in the format.
971 for (auto [index, param] : llvm::enumerate(First: def.getParameters())) {
972 if (param.isOptional())
973 continue;
974 if (!seenParams.test(Idx: index)) {
975 if (isa<AttributeSelfTypeParameter>(Val: param))
976 continue;
977 return emitError(loc, msg: "format is missing reference to parameter: " +
978 param.getName());
979 }
980 if (isa<AttributeSelfTypeParameter>(Val: param)) {
981 return emitError(loc,
982 msg: "unexpected self type parameter in assembly format");
983 }
984 }
985 if (elements.empty())
986 return success();
987 // A `struct` directive that contains optional parameters cannot be followed
988 // by a comma literal, which is ambiguous.
989 for (auto it : llvm::zip(t: elements.drop_back(), u: elements.drop_front())) {
990 auto *structEl = dyn_cast<StructDirective>(Val: std::get<0>(t&: it));
991 auto *literalEl = dyn_cast<LiteralElement>(Val: std::get<1>(t&: it));
992 if (!structEl || !literalEl)
993 continue;
994 if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) {
995 return emitError(loc, msg: "`struct` directive with optional parameters "
996 "cannot be followed by a comma literal");
997 }
998 }
999 return success();
1000}
1001
1002LogicalResult DefFormatParser::verifyCustomDirectiveArguments(
1003 SMLoc loc, ArrayRef<FormatElement *> arguments) {
1004 // Arguments are fully verified by the parser context.
1005 return success();
1006}
1007
1008LogicalResult
1009DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
1010 ArrayRef<FormatElement *> elements,
1011 FormatElement *anchor) {
1012 // `params` and `struct` directives are allowed only if all the contained
1013 // parameters are optional.
1014 for (FormatElement *el : elements) {
1015 if (auto *param = dyn_cast<ParameterElement>(Val: el)) {
1016 if (!param->isOptional()) {
1017 return emitError(loc,
1018 msg: "parameters in an optional group must be optional");
1019 }
1020 } else if (auto *params = dyn_cast<ParamsDirective>(Val: el)) {
1021 if (llvm::any_of(Range: params->getParams(), P: paramNotOptional)) {
1022 return emitError(loc, msg: "`params` directive allowed in optional group "
1023 "only if all parameters are optional");
1024 }
1025 } else if (auto *strct = dyn_cast<StructDirective>(Val: el)) {
1026 if (llvm::any_of(Range: strct->getParams(), P: paramNotOptional)) {
1027 return emitError(loc, msg: "`struct` is only allowed in an optional group "
1028 "if all captured parameters are optional");
1029 }
1030 } else if (auto *custom = dyn_cast<CustomDirective>(Val: el)) {
1031 for (FormatElement *el : custom->getArguments()) {
1032 // If the custom argument is a variable, then it must be optional.
1033 if (auto *param = dyn_cast<ParameterElement>(Val: el))
1034 if (!param->isOptional())
1035 return emitError(loc,
1036 msg: "`custom` is only allowed in an optional group if "
1037 "all captured parameters are optional");
1038 }
1039 }
1040 }
1041 // The anchor must be a parameter or one of the aforementioned directives.
1042 if (anchor) {
1043 if (!isa<ParameterElement, ParamsDirective, StructDirective,
1044 CustomDirective>(Val: anchor)) {
1045 return emitError(
1046 loc, msg: "optional group anchor must be a parameter or directive");
1047 }
1048 // If the anchor is a custom directive, make sure at least one of its
1049 // arguments is a bound parameter.
1050 if (auto *custom = dyn_cast<CustomDirective>(Val: anchor)) {
1051 const auto *bound =
1052 llvm::find_if(Range: custom->getArguments(), P: [](FormatElement *el) {
1053 return isa<ParameterElement>(Val: el);
1054 });
1055 if (bound == custom->getArguments().end())
1056 return emitError(loc, msg: "`custom` directive with no bound parameters "
1057 "cannot be used as optional group anchor");
1058 }
1059 }
1060 return success();
1061}
1062
1063FailureOr<DefFormat> DefFormatParser::parse() {
1064 FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
1065 if (failed(result: elements))
1066 return failure();
1067 return DefFormat(def, std::move(*elements));
1068}
1069
1070FailureOr<FormatElement *>
1071DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
1072 // Lookup the parameter.
1073 ArrayRef<AttrOrTypeParameter> params = def.getParameters();
1074 auto *it = llvm::find_if(
1075 Range&: params, P: [&](auto &param) { return param.getName() == name; });
1076
1077 // Check that the parameter reference is valid.
1078 if (it == params.end()) {
1079 return emitError(loc,
1080 msg: def.getName() + " has no parameter named '" + name + "'");
1081 }
1082 auto idx = std::distance(first: params.begin(), last: it);
1083
1084 if (ctx != RefDirectiveContext) {
1085 // Check that the variable has not already been bound.
1086 if (seenParams.test(Idx: idx))
1087 return emitError(loc, msg: "duplicate parameter '" + name + "'");
1088 seenParams.set(idx);
1089
1090 // Otherwise, to be referenced, a variable must have been bound.
1091 } else if (!seenParams.test(Idx: idx) && !isa<AttributeSelfTypeParameter>(Val: *it)) {
1092 return emitError(loc, msg: "parameter '" + name +
1093 "' must be bound before it is referenced");
1094 }
1095
1096 return create<ParameterElement>(args: *it);
1097}
1098
1099FailureOr<FormatElement *>
1100DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
1101 Context ctx) {
1102
1103 switch (kind) {
1104 case FormatToken::kw_qualified:
1105 return parseQualifiedDirective(loc, ctx);
1106 case FormatToken::kw_params:
1107 return parseParamsDirective(loc, ctx);
1108 case FormatToken::kw_struct:
1109 return parseStructDirective(loc, ctx);
1110 case FormatToken::kw_ref:
1111 return parseRefDirective(loc, ctx);
1112 case FormatToken::kw_custom:
1113 return parseCustomDirective(loc, ctx);
1114
1115 default:
1116 return emitError(loc, msg: "unsupported directive kind");
1117 }
1118}
1119
1120FailureOr<FormatElement *>
1121DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
1122 if (failed(result: parseToken(kind: FormatToken::l_paren,
1123 msg: "expected '(' before argument list")))
1124 return failure();
1125 FailureOr<FormatElement *> var = parseElement(ctx);
1126 if (failed(result: var))
1127 return var;
1128 if (!isa<ParameterElement>(Val: *var))
1129 return emitError(loc, msg: "`qualified` argument list expected a variable");
1130 cast<ParameterElement>(Val: *var)->setShouldBeQualified();
1131 if (failed(
1132 result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list")))
1133 return failure();
1134 return var;
1135}
1136
1137FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
1138 Context ctx) {
1139 // It doesn't make sense to allow references to all parameters in a custom
1140 // directive because parameters are the only things that can be bound.
1141 if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
1142 return emitError(loc, msg: "`params` can only be used at the top-level context "
1143 "or within a `struct` directive");
1144 }
1145
1146 // Collect all of the attribute's or type's parameters and ensure that none of
1147 // the parameters have already been captured.
1148 std::vector<ParameterElement *> vars;
1149 for (const auto &it : llvm::enumerate(First: def.getParameters())) {
1150 if (seenParams.test(Idx: it.index())) {
1151 return emitError(loc, msg: "`params` captures duplicate parameter: " +
1152 it.value().getName());
1153 }
1154 // Self-type parameters are handled separately from the rest of the
1155 // parameters.
1156 if (isa<AttributeSelfTypeParameter>(Val: it.value()))
1157 continue;
1158 seenParams.set(it.index());
1159 vars.push_back(x: create<ParameterElement>(args: it.value()));
1160 }
1161 return create<ParamsDirective>(args: std::move(vars));
1162}
1163
1164FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
1165 Context ctx) {
1166 if (ctx != TopLevelContext)
1167 return emitError(loc, msg: "`struct` can only be used at the top-level context");
1168
1169 if (failed(result: parseToken(kind: FormatToken::l_paren,
1170 msg: "expected '(' before `struct` argument list")))
1171 return failure();
1172
1173 // Parse variables captured by `struct`.
1174 std::vector<ParameterElement *> vars;
1175
1176 // Parse first captured parameter or a `params` directive.
1177 FailureOr<FormatElement *> var = parseElement(ctx: StructDirectiveContext);
1178 if (failed(result: var) || !isa<VariableElement, ParamsDirective>(Val: *var)) {
1179 return emitError(loc,
1180 msg: "`struct` argument list expected a variable or directive");
1181 }
1182 if (isa<VariableElement>(Val: *var)) {
1183 // Parse any other parameters.
1184 vars.push_back(x: cast<ParameterElement>(Val: *var));
1185 while (peekToken().is(kind: FormatToken::comma)) {
1186 consumeToken();
1187 var = parseElement(ctx: StructDirectiveContext);
1188 if (failed(result: var) || !isa<VariableElement>(Val: *var))
1189 return emitError(loc, msg: "expected a variable in `struct` argument list");
1190 vars.push_back(x: cast<ParameterElement>(Val: *var));
1191 }
1192 } else {
1193 // `struct(params)` captures all parameters in the attribute or type.
1194 vars = cast<ParamsDirective>(Val: *var)->takeParams();
1195 }
1196
1197 if (failed(result: parseToken(kind: FormatToken::r_paren,
1198 msg: "expected ')' at the end of an argument list")))
1199 return failure();
1200
1201 return create<StructDirective>(args: std::move(vars));
1202}
1203
1204FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc,
1205 Context ctx) {
1206 if (ctx != CustomDirectiveContext)
1207 return emitError(loc, msg: "`ref` is only allowed inside custom directives");
1208
1209 // Parse the child parameter element.
1210 FailureOr<FormatElement *> child;
1211 if (failed(result: parseToken(kind: FormatToken::l_paren, msg: "expected '('")) ||
1212 failed(result: child = parseElement(ctx: RefDirectiveContext)) ||
1213 failed(result: parseToken(kind: FormatToken::r_paren, msg: "expeced ')'")))
1214 return failure();
1215
1216 // Only parameter elements are allowed to be parsed under a `ref` directive.
1217 return create<RefDirective>(args&: *child);
1218}
1219
1220//===----------------------------------------------------------------------===//
1221// Interface
1222//===----------------------------------------------------------------------===//
1223
1224void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
1225 MethodBody &parser,
1226 MethodBody &printer) {
1227 llvm::SourceMgr mgr;
1228 mgr.AddNewSourceBuffer(
1229 F: llvm::MemoryBuffer::getMemBuffer(InputData: *def.getAssemblyFormat()), IncludeLoc: SMLoc());
1230
1231 // Parse the custom assembly format>
1232 DefFormatParser fmtParser(mgr, def);
1233 FailureOr<DefFormat> format = fmtParser.parse();
1234 if (failed(result: format)) {
1235 if (formatErrorIsFatal)
1236 PrintFatalError(ErrorLoc: def.getLoc(), Msg: "failed to parse assembly format");
1237 return;
1238 }
1239
1240 // Generate the parser and printer.
1241 format->genParser(os&: parser);
1242 format->genPrinter(os&: printer);
1243}
1244

source code of mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp