1 | //===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "AttrOrTypeFormatGen.h" |
10 | #include "FormatGen.h" |
11 | #include "mlir/Support/LLVM.h" |
12 | #include "mlir/Support/LogicalResult.h" |
13 | #include "mlir/TableGen/AttrOrTypeDef.h" |
14 | #include "mlir/TableGen/Format.h" |
15 | #include "mlir/TableGen/GenInfo.h" |
16 | #include "llvm/ADT/BitVector.h" |
17 | #include "llvm/ADT/StringExtras.h" |
18 | #include "llvm/ADT/StringSwitch.h" |
19 | #include "llvm/ADT/TypeSwitch.h" |
20 | #include "llvm/Support/MemoryBuffer.h" |
21 | #include "llvm/Support/SaveAndRestore.h" |
22 | #include "llvm/Support/SourceMgr.h" |
23 | #include "llvm/TableGen/Error.h" |
24 | #include "llvm/TableGen/TableGenBackend.h" |
25 | |
26 | using namespace mlir; |
27 | using namespace mlir::tblgen; |
28 | |
29 | using llvm::formatv; |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // Element |
33 | //===----------------------------------------------------------------------===// |
34 | |
35 | namespace { |
36 | /// This class represents an instance of a variable element. A variable refers |
37 | /// to an attribute or type parameter. |
38 | class ParameterElement |
39 | : public VariableElementBase<VariableElement::Parameter> { |
40 | public: |
41 | ParameterElement(AttrOrTypeParameter param) : param(param) {} |
42 | |
43 | /// Get the parameter in the element. |
44 | const AttrOrTypeParameter &getParam() const { return param; } |
45 | |
46 | /// Indicate if this variable is printed "qualified" (that is it is |
47 | /// prefixed with the `#dialect.mnemonic`). |
48 | bool shouldBeQualified() { return shouldBeQualifiedFlag; } |
49 | void setShouldBeQualified(bool qualified = true) { |
50 | shouldBeQualifiedFlag = qualified; |
51 | } |
52 | |
53 | /// Returns true if the element contains an optional parameter. |
54 | bool isOptional() const { return param.isOptional(); } |
55 | |
56 | /// Returns the name of the parameter. |
57 | StringRef getName() const { return param.getName(); } |
58 | |
59 | /// Return the code to check whether the parameter is present. |
60 | auto genIsPresent(FmtContext &ctx, const Twine &self) const { |
61 | assert(isOptional() && "cannot guard on a mandatory parameter"); |
62 | std::string valueStr = tgfmt(fmt: *param.getDefaultValue(), ctx: &ctx).str(); |
63 | ctx.addSubst(placeholder: "_lhs", subst: self).addSubst(placeholder: "_rhs", subst: valueStr); |
64 | return tgfmt(fmt: getParam().getComparator(), ctx: &ctx); |
65 | } |
66 | |
67 | /// Generate the code to check whether the parameter should be printed. |
68 | MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const { |
69 | assert(isOptional() && "cannot guard on a mandatory parameter"); |
70 | std::string self = param.getAccessorName() + "()"; |
71 | return os << "!("<< genIsPresent(ctx, self) << ")"; |
72 | } |
73 | |
74 | private: |
75 | bool shouldBeQualifiedFlag = false; |
76 | AttrOrTypeParameter param; |
77 | }; |
78 | |
79 | /// Shorthand functions that can be used with ranged-based conditions. |
80 | static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); } |
81 | static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); } |
82 | |
83 | /// Base class for a directive that contains references to multiple variables. |
84 | template <DirectiveElement::Kind DirectiveKind> |
85 | class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> { |
86 | public: |
87 | using Base = ParamsDirectiveBase<DirectiveKind>; |
88 | |
89 | ParamsDirectiveBase(std::vector<ParameterElement *> &¶ms) |
90 | : params(std::move(params)) {} |
91 | |
92 | /// Get the parameters contained in this directive. |
93 | ArrayRef<ParameterElement *> getParams() const { return params; } |
94 | |
95 | /// Get the number of parameters. |
96 | unsigned getNumParams() const { return params.size(); } |
97 | |
98 | /// Take all of the parameters from this directive. |
99 | std::vector<ParameterElement *> takeParams() { return std::move(params); } |
100 | |
101 | /// Returns true if there are optional parameters present. |
102 | bool hasOptionalParams() const { |
103 | return llvm::any_of(getParams(), paramIsOptional); |
104 | } |
105 | |
106 | private: |
107 | /// The parameters captured by this directive. |
108 | std::vector<ParameterElement *> params; |
109 | }; |
110 | |
111 | /// This class represents a `params` directive that refers to all parameters |
112 | /// of an attribute or type. When used as a top-level directive, it generates |
113 | /// a format of the form: |
114 | /// |
115 | /// (param-value (`,` param-value)*)? |
116 | /// |
117 | /// When used as an argument to another directive that accepts variables, |
118 | /// `params` can be used in place of manually listing all parameters of an |
119 | /// attribute or type. |
120 | class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> { |
121 | public: |
122 | using Base::Base; |
123 | }; |
124 | |
125 | /// This class represents a `struct` directive that generates a struct format |
126 | /// of the form: |
127 | /// |
128 | /// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}` |
129 | /// |
130 | class StructDirective : public ParamsDirectiveBase<DirectiveElement::Struct> { |
131 | public: |
132 | using Base::Base; |
133 | }; |
134 | |
135 | } // namespace |
136 | |
137 | //===----------------------------------------------------------------------===// |
138 | // Format Strings |
139 | //===----------------------------------------------------------------------===// |
140 | |
/// Default parser for attribute or type parameters.
///
/// $0: The C++ storage type of the parameter.
static const char *const defaultParameterParser =
    "::mlir::FieldParser<$0>::parse($_parser)";

/// Default printer for attribute or type parameters. It uses
/// `printStrippedAttrOrType`, unlike `qualifiedParameterPrinter` below.
static const char *const defaultParameterPrinter =
    "$_printer.printStrippedAttrOrType($_self)";

/// Qualified printer for attribute or type parameters: it does not elide
/// dialect and mnemonic.
static const char *const qualifiedParameterPrinter = "$_printer << $_self";

/// Print an error when failing to parse an element.
///
/// This is only the opening of an `emitError` call at the parser's current
/// location and takes no positional placeholders (only `$_parser` is
/// substituted). Users append the quoted message and the closing parenthesis,
/// as `variableParser` does with its `{2}` slot.
static const char *const parserErrorStr =
    "$_parser.emitError($_parser.getCurrentLocation(), ";

/// Code format to parse a variable. Separate by lines because variable parsers
/// may be generated inside other directives, which requires indentation.
/// This is a `formatv` template, so literal braces are escaped as `{{`.
///
/// {0}: The parameter name.
/// {1}: The parse code for the parameter.
/// {2}: Code template for printing an error.
/// {3}: Name of the attribute or type.
/// {4}: C++ class of the parameter.
static const char *const variableParser = R"(
// Parse variable '{0}'
_result_{0} = {1};
if (::mlir::failed(_result_{0})) {{
  {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
  return {{};
}
)";
175 | |
176 | //===----------------------------------------------------------------------===// |
177 | // DefFormat |
178 | //===----------------------------------------------------------------------===// |
179 | |
180 | namespace { |
181 | class DefFormat { |
182 | public: |
183 | DefFormat(const AttrOrTypeDef &def, std::vector<FormatElement *> &&elements) |
184 | : def(def), elements(std::move(elements)) {} |
185 | |
186 | /// Generate the attribute or type parser. |
187 | void genParser(MethodBody &os); |
188 | /// Generate the attribute or type printer. |
189 | void genPrinter(MethodBody &os); |
190 | |
191 | private: |
192 | /// Generate the parser code for a specific format element. |
193 | void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os); |
194 | /// Generate the parser code for a literal. |
195 | void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os, |
196 | bool isOptional = false); |
197 | /// Generate the parser code for a variable. |
198 | void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os); |
199 | /// Generate the parser code for a `params` directive. |
200 | void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os); |
201 | /// Generate the parser code for a `struct` directive. |
202 | void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os); |
203 | /// Generate the parser code for a `custom` directive. |
204 | void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os, |
205 | bool isOptional = false); |
206 | /// Generate the parser code for an optional group. |
207 | void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, |
208 | MethodBody &os); |
209 | |
210 | /// Generate the printer code for a specific format element. |
211 | void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os); |
212 | /// Generate the printer code for a literal. |
213 | void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os); |
214 | /// Generate the printer code for a variable. |
215 | void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os, |
216 | bool skipGuard = false); |
217 | /// Generate a printer for comma-separated parameters. |
218 | void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params, |
219 | FmtContext &ctx, MethodBody &os, |
220 | function_ref<void(ParameterElement *)> extra); |
221 | /// Generate the printer code for a `params` directive. |
222 | void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os); |
223 | /// Generate the printer code for a `struct` directive. |
224 | void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os); |
225 | /// Generate the printer code for a `custom` directive. |
226 | void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os); |
227 | /// Generate the printer code for an optional group. |
228 | void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, |
229 | MethodBody &os); |
230 | /// Generate a printer (or space eraser) for a whitespace element. |
231 | void genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx, |
232 | MethodBody &os); |
233 | |
234 | /// The ODS definition of the attribute or type whose format is being used to |
235 | /// generate a parser and printer. |
236 | const AttrOrTypeDef &def; |
237 | /// The list of top-level format elements returned by the assembly format |
238 | /// parser. |
239 | std::vector<FormatElement *> elements; |
240 | |
241 | /// Flags for printing spaces. |
242 | bool shouldEmitSpace = false; |
243 | bool lastWasPunctuation = false; |
244 | }; |
245 | } // namespace |
246 | |
247 | //===----------------------------------------------------------------------===// |
248 | // ParserGen |
249 | //===----------------------------------------------------------------------===// |
250 | |
251 | /// Generate a special-case "parser" for an attribute's self type parameter. The |
252 | /// self type parameter has special handling in the assembly format in that it |
253 | /// is derived from the optional trailing colon type after the attribute. |
254 | static void genAttrSelfTypeParser(MethodBody &os, const FmtContext &ctx, |
255 | const AttributeSelfTypeParameter ¶m) { |
256 | // "Parser" for an attribute self type parameter that checks the |
257 | // optionally-parsed trailing colon type. |
258 | // |
259 | // $0: The C++ storage class of the type parameter. |
260 | // $1: The self type parameter name. |
261 | const char *const selfTypeParser = R"( |
262 | if ($_type) { |
263 | if (auto reqType = ::llvm::dyn_cast<$0>($_type)) { |
264 | _result_$1 = reqType; |
265 | } else { |
266 | $_parser.emitError($_loc, "invalid kind of type specified"); |
267 | return {}; |
268 | } |
269 | })"; |
270 | |
271 | // If the attribute self type parameter is required, emit code that emits an |
272 | // error if the trailing type was not parsed. |
273 | const char *const selfTypeRequired = R"( else { |
274 | $_parser.emitError($_loc, "expected a trailing type"); |
275 | return {}; |
276 | })"; |
277 | |
278 | os << tgfmt(fmt: selfTypeParser, ctx: &ctx, vals: param.getCppStorageType(), vals: param.getName()); |
279 | if (!param.isOptional()) |
280 | os << tgfmt(fmt: selfTypeRequired, ctx: &ctx); |
281 | os << "\n"; |
282 | } |
283 | |
284 | void DefFormat::genParser(MethodBody &os) { |
285 | FmtContext ctx; |
286 | ctx.addSubst(placeholder: "_parser", subst: "odsParser"); |
287 | ctx.addSubst(placeholder: "_ctxt", subst: "odsParser.getContext()"); |
288 | ctx.withBuilder(subst: "odsBuilder"); |
289 | if (isa<AttrDef>(Val: def)) |
290 | ctx.addSubst(placeholder: "_type", subst: "odsType"); |
291 | os.indent(); |
292 | os << "::mlir::Builder odsBuilder(odsParser.getContext());\n"; |
293 | |
294 | // Store the initial location of the parser. |
295 | ctx.addSubst(placeholder: "_loc", subst: "odsLoc"); |
296 | os << tgfmt(fmt: "::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n" |
297 | "(void) $_loc;\n", |
298 | ctx: &ctx); |
299 | |
300 | // Declare variables to store all of the parameters. Allocated parameters |
301 | // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store |
302 | // FailureOr<T> to defer type construction for parameters that are parsed in |
303 | // a loop (parsers return FailureOr anyways). |
304 | ArrayRef<AttrOrTypeParameter> params = def.getParameters(); |
305 | for (const AttrOrTypeParameter ¶m : params) { |
306 | os << formatv(Fmt: "::mlir::FailureOr<{0}> _result_{1};\n", |
307 | Vals: param.getCppStorageType(), Vals: param.getName()); |
308 | if (auto *selfTypeParam = dyn_cast<AttributeSelfTypeParameter>(Val: ¶m)) |
309 | genAttrSelfTypeParser(os, ctx, param: *selfTypeParam); |
310 | } |
311 | |
312 | // Generate call to each parameter parser. |
313 | for (FormatElement *el : elements) |
314 | genElementParser(el, ctx, os); |
315 | |
316 | // Emit an assert for each mandatory parameter. Triggering an assert means |
317 | // the generated parser is incorrect (i.e. there is a bug in this code). |
318 | for (const AttrOrTypeParameter ¶m : params) { |
319 | if (param.isOptional()) |
320 | continue; |
321 | os << formatv(Fmt: "assert(::mlir::succeeded(_result_{0}));\n", Vals: param.getName()); |
322 | } |
323 | |
324 | // Generate call to the attribute or type builder. Use the checked getter |
325 | // if one was generated. |
326 | if (def.genVerifyDecl()) { |
327 | os << tgfmt(fmt: "return $_parser.getChecked<$0>($_loc, $_parser.getContext()", |
328 | ctx: &ctx, vals: def.getCppClassName()); |
329 | } else { |
330 | os << tgfmt(fmt: "return $0::get($_parser.getContext()", ctx: &ctx, |
331 | vals: def.getCppClassName()); |
332 | } |
333 | for (const AttrOrTypeParameter ¶m : params) { |
334 | os << ",\n "; |
335 | std::string paramSelfStr; |
336 | llvm::raw_string_ostream selfOs(paramSelfStr); |
337 | if (std::optional<StringRef> defaultValue = param.getDefaultValue()) { |
338 | selfOs << formatv(Fmt: "(_result_{0}.value_or(", Vals: param.getName()) |
339 | << tgfmt(fmt: *defaultValue, ctx: &ctx) << "))"; |
340 | } else { |
341 | selfOs << formatv(Fmt: "(*_result_{0})", Vals: param.getName()); |
342 | } |
343 | ctx.addSubst(placeholder: param.getName(), subst: selfOs.str()); |
344 | os << param.getCppType() << "(" |
345 | << tgfmt(fmt: param.getConvertFromStorage(), ctx: &ctx.withSelf(subst: selfOs.str())) |
346 | << ")"; |
347 | } |
348 | os << ");"; |
349 | } |
350 | |
351 | void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx, |
352 | MethodBody &os) { |
353 | if (auto *literal = dyn_cast<LiteralElement>(Val: el)) |
354 | return genLiteralParser(value: literal->getSpelling(), ctx, os); |
355 | if (auto *var = dyn_cast<ParameterElement>(Val: el)) |
356 | return genVariableParser(el: var, ctx, os); |
357 | if (auto *params = dyn_cast<ParamsDirective>(Val: el)) |
358 | return genParamsParser(el: params, ctx, os); |
359 | if (auto *strct = dyn_cast<StructDirective>(Val: el)) |
360 | return genStructParser(el: strct, ctx, os); |
361 | if (auto *custom = dyn_cast<CustomDirective>(Val: el)) |
362 | return genCustomParser(el: custom, ctx, os); |
363 | if (auto *optional = dyn_cast<OptionalElement>(Val: el)) |
364 | return genOptionalGroupParser(el: optional, ctx, os); |
365 | if (isa<WhitespaceElement>(Val: el)) |
366 | return; |
367 | |
368 | llvm_unreachable("unknown format element"); |
369 | } |
370 | |
371 | void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx, |
372 | MethodBody &os, bool isOptional) { |
373 | os << "// Parse literal '"<< value << "'\n"; |
374 | os << tgfmt(fmt: "if ($_parser.parse", ctx: &ctx); |
375 | if (isOptional) |
376 | os << "Optional"; |
377 | if (value.front() == '_' || isalpha(value.front())) { |
378 | os << "Keyword(\""<< value << "\")"; |
379 | } else { |
380 | os << StringSwitch<StringRef>(value) |
381 | .Case(S: "->", Value: "Arrow") |
382 | .Case(S: ":", Value: "Colon") |
383 | .Case(S: ",", Value: "Comma") |
384 | .Case(S: "=", Value: "Equal") |
385 | .Case(S: "<", Value: "Less") |
386 | .Case(S: ">", Value: "Greater") |
387 | .Case(S: "{", Value: "LBrace") |
388 | .Case(S: "}", Value: "RBrace") |
389 | .Case(S: "(", Value: "LParen") |
390 | .Case(S: ")", Value: "RParen") |
391 | .Case(S: "[", Value: "LSquare") |
392 | .Case(S: "]", Value: "RSquare") |
393 | .Case(S: "?", Value: "Question") |
394 | .Case(S: "+", Value: "Plus") |
395 | .Case(S: "*", Value: "Star") |
396 | .Case(S: "...", Value: "Ellipsis") |
397 | << "()"; |
398 | } |
399 | if (isOptional) { |
400 | // Leave the `if` unclosed to guard optional groups. |
401 | return; |
402 | } |
403 | // Parser will emit an error |
404 | os << ") return {};\n"; |
405 | } |
406 | |
407 | void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx, |
408 | MethodBody &os) { |
409 | // Check for a custom parser. Use the default attribute parser otherwise. |
410 | const AttrOrTypeParameter ¶m = el->getParam(); |
411 | auto customParser = param.getParser(); |
412 | auto parser = |
413 | customParser ? *customParser : StringRef(defaultParameterParser); |
414 | os << formatv(Fmt: variableParser, Vals: param.getName(), |
415 | Vals: tgfmt(fmt: parser, ctx: &ctx, vals: param.getCppStorageType()), |
416 | Vals: tgfmt(fmt: parserErrorStr, ctx: &ctx), Vals: def.getName(), Vals: param.getCppType()); |
417 | } |
418 | |
419 | void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx, |
420 | MethodBody &os) { |
421 | os << "// Parse parameter list\n"; |
422 | |
423 | // If there are optional parameters, we need to switch to `parseOptionalComma` |
424 | // if there are no more required parameters after a certain point. |
425 | bool hasOptional = el->hasOptionalParams(); |
426 | if (hasOptional) { |
427 | // Wrap everything in a do-while so that we can `break`. |
428 | os << "do {\n"; |
429 | os.indent(); |
430 | } |
431 | |
432 | ArrayRef<ParameterElement *> params = el->getParams(); |
433 | using IteratorT = ParameterElement *const *; |
434 | IteratorT it = params.begin(); |
435 | |
436 | // Find the last required parameter. Commas become optional aftewards. |
437 | // Note: IteratorT's copy assignment is deleted. |
438 | ParameterElement *lastReq = nullptr; |
439 | for (ParameterElement *param : params) |
440 | if (!param->isOptional()) |
441 | lastReq = param; |
442 | IteratorT lastReqIt = lastReq ? llvm::find(Range&: params, Val: lastReq) : params.begin(); |
443 | |
444 | auto eachFn = [&](ParameterElement *el) { genVariableParser(el, ctx, os); }; |
445 | auto betweenFn = [&](IteratorT it) { |
446 | ParameterElement *el = *std::prev(x: it); |
447 | // Parse a comma if the last optional parameter had a value. |
448 | if (el->isOptional()) { |
449 | os << formatv(Fmt: "if (::mlir::succeeded(_result_{0}) && !({1})) {{\n", |
450 | Vals: el->getName(), |
451 | Vals: el->genIsPresent(ctx, self: "(*_result_"+ el->getName() + ")")); |
452 | os.indent(); |
453 | } |
454 | if (it <= lastReqIt) { |
455 | genLiteralParser(value: ",", ctx, os); |
456 | } else { |
457 | genLiteralParser(value: ",", ctx, os, /*isOptional=*/true); |
458 | os << ") break;\n"; |
459 | } |
460 | if (el->isOptional()) |
461 | os.unindent() << "}\n"; |
462 | }; |
463 | |
464 | // llvm::interleave |
465 | if (it != params.end()) { |
466 | eachFn(*it++); |
467 | for (IteratorT e = params.end(); it != e; ++it) { |
468 | betweenFn(it); |
469 | eachFn(*it); |
470 | } |
471 | } |
472 | |
473 | if (hasOptional) |
474 | os.unindent() << "} while(false);\n"; |
475 | } |
476 | |
477 | void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx, |
478 | MethodBody &os) { |
479 | // Loop declaration for struct parser with only required parameters. |
480 | // |
481 | // $0: Number of expected parameters. |
482 | const char *const loopHeader = R"( |
483 | for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) { |
484 | )"; |
485 | |
486 | // Loop body start for struct parser. |
487 | const char *const loopStart = R"( |
488 | ::llvm::StringRef _paramKey; |
489 | if ($_parser.parseKeyword(&_paramKey)) { |
490 | $_parser.emitError($_parser.getCurrentLocation(), |
491 | "expected a parameter name in struct"); |
492 | return {}; |
493 | } |
494 | if (!_loop_body(_paramKey)) return {}; |
495 | )"; |
496 | |
497 | // Struct parser loop end. Check for duplicate or unknown struct parameters. |
498 | // |
499 | // {0}: Code template for printing an error. |
500 | const char *const loopEnd = R"({{ |
501 | {0}"duplicate or unknown struct parameter name: ") << _paramKey; |
502 | return {{}; |
503 | } |
504 | )"; |
505 | |
506 | // Struct parser loop terminator. Parse a comma except on the last element. |
507 | // |
508 | // {0}: Number of elements in the struct. |
509 | const char *const loopTerminator = R"( |
510 | if ((odsStructIndex != {0} - 1) && odsParser.parseComma()) |
511 | return {{}; |
512 | } |
513 | )"; |
514 | |
515 | // Check that a mandatory parameter was parse. |
516 | // |
517 | // {0}: Name of the parameter. |
518 | const char *const checkParam = R"( |
519 | if (!_seen_{0}) { |
520 | {1}"struct is missing required parameter: ") << "{0}"; |
521 | return {{}; |
522 | } |
523 | )"; |
524 | |
525 | // First iteration of the loop parsing an optional struct. |
526 | const char *const optionalStructFirst = R"( |
527 | ::llvm::StringRef _paramKey; |
528 | if (!$_parser.parseOptionalKeyword(&_paramKey)) { |
529 | if (!_loop_body(_paramKey)) return {}; |
530 | while (!$_parser.parseOptionalComma()) { |
531 | )"; |
532 | |
533 | os << "// Parse parameter struct\n"; |
534 | |
535 | // Declare a "seen" variable for each key. |
536 | for (ParameterElement *param : el->getParams()) |
537 | os << formatv(Fmt: "bool _seen_{0} = false;\n", Vals: param->getName()); |
538 | |
539 | // Generate the body of the parsing loop inside a lambda. |
540 | os << "{\n"; |
541 | os.indent() |
542 | << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n"; |
543 | genLiteralParser(value: "=", ctx, os&: os.indent()); |
544 | for (ParameterElement *param : el->getParams()) { |
545 | os << formatv(Fmt: "if (!_seen_{0} && _paramKey == \"{0}\") {\n" |
546 | " _seen_{0} = true;\n", |
547 | Vals: param->getName()); |
548 | genVariableParser(el: param, ctx, os&: os.indent()); |
549 | os.unindent() << "} else "; |
550 | // Print the check for duplicate or unknown parameter. |
551 | } |
552 | os.getStream().printReindented(str: strfmt(fmt: loopEnd, parameters: tgfmt(fmt: parserErrorStr, ctx: &ctx))); |
553 | os << "return true;\n"; |
554 | os.unindent() << "};\n"; |
555 | |
556 | // Generate the parsing loop. If optional parameters are present, then the |
557 | // parse loop is guarded by commas. |
558 | unsigned numOptional = llvm::count_if(Range: el->getParams(), P: paramIsOptional); |
559 | if (numOptional) { |
560 | // If the struct itself is optional, pull out the first iteration. |
561 | if (numOptional == el->getNumParams()) { |
562 | os.getStream().printReindented(str: tgfmt(fmt: optionalStructFirst, ctx: &ctx).str()); |
563 | os.indent(); |
564 | } else { |
565 | os << "do {\n"; |
566 | } |
567 | } else { |
568 | os.getStream().printReindented( |
569 | str: tgfmt(fmt: loopHeader, ctx: &ctx, vals: el->getNumParams()).str()); |
570 | } |
571 | os.indent(); |
572 | os.getStream().printReindented(str: tgfmt(fmt: loopStart, ctx: &ctx).str()); |
573 | os.unindent(); |
574 | |
575 | // Print the loop terminator. For optional parameters, we have to check that |
576 | // all mandatory parameters have been parsed. |
577 | // The whole struct is optional if all its parameters are optional. |
578 | if (numOptional) { |
579 | if (numOptional == el->getNumParams()) { |
580 | os << "}\n"; |
581 | os.unindent() << "}\n"; |
582 | } else { |
583 | os << tgfmt(fmt: "} while(!$_parser.parseOptionalComma());\n", ctx: &ctx); |
584 | for (ParameterElement *param : el->getParams()) { |
585 | if (param->isOptional()) |
586 | continue; |
587 | os.getStream().printReindented( |
588 | str: strfmt(fmt: checkParam, parameters: param->getName(), parameters: tgfmt(fmt: parserErrorStr, ctx: &ctx))); |
589 | } |
590 | } |
591 | } else { |
592 | // Because the loop loops N times and each non-failing iteration sets 1 of |
593 | // N flags, successfully exiting the loop means that all parameters have |
594 | // been seen. `parseOptionalComma` would cause issues with any formats that |
595 | // use "struct(...) `,`" beacuse structs aren't sounded by braces. |
596 | os.getStream().printReindented(str: strfmt(fmt: loopTerminator, parameters: el->getNumParams())); |
597 | } |
598 | os.unindent() << "}\n"; |
599 | } |
600 | |
601 | void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx, |
602 | MethodBody &os, bool isOptional) { |
603 | os << "{\n"; |
604 | os.indent(); |
605 | |
606 | // Bound variables are passed directly to the parser as `FailureOr<T> &`. |
607 | // Referenced variables are passed as `T`. The custom parser fails if it |
608 | // returns failure or if any of the required parameters failed. |
609 | os << tgfmt(fmt: "auto odsCustomLoc = $_parser.getCurrentLocation();\n", ctx: &ctx); |
610 | os << "(void)odsCustomLoc;\n"; |
611 | os << tgfmt(fmt: "auto odsCustomResult = parse$0($_parser", ctx: &ctx, vals: el->getName()); |
612 | os.indent(); |
613 | for (FormatElement *arg : el->getArguments()) { |
614 | os << ",\n"; |
615 | if (auto *param = dyn_cast<ParameterElement>(Val: arg)) |
616 | os << "::mlir::detail::unwrapForCustomParse(_result_"<< param->getName() |
617 | << ")"; |
618 | else if (auto *ref = dyn_cast<RefDirective>(Val: arg)) |
619 | os << "*_result_"<< cast<ParameterElement>(Val: ref->getArg())->getName(); |
620 | else |
621 | os << tgfmt(fmt: cast<StringElement>(Val: arg)->getValue(), ctx: &ctx); |
622 | } |
623 | os.unindent() << ");\n"; |
624 | if (isOptional) { |
625 | os << "if (!odsCustomResult.has_value()) return {};\n"; |
626 | os << "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n"; |
627 | } else { |
628 | os << "if (::mlir::failed(odsCustomResult)) return {};\n"; |
629 | } |
630 | for (FormatElement *arg : el->getArguments()) { |
631 | if (auto *param = dyn_cast<ParameterElement>(Val: arg)) { |
632 | if (param->isOptional()) |
633 | continue; |
634 | os << formatv(Fmt: "if (::mlir::failed(_result_{0})) {{\n", Vals: param->getName()); |
635 | os.indent() << tgfmt(fmt: "$_parser.emitError(odsCustomLoc, ", ctx: &ctx) |
636 | << "\"custom parser failed to parse parameter '" |
637 | << param->getName() << "'\");\n"; |
638 | os << "return "<< (isOptional ? "::mlir::failure()": "{}") << ";\n"; |
639 | os.unindent() << "}\n"; |
640 | } |
641 | } |
642 | |
643 | os.unindent() << "}\n"; |
644 | } |
645 | |
646 | void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, |
647 | MethodBody &os) { |
648 | ArrayRef<FormatElement *> thenElements = |
649 | el->getThenElements(/*parseable=*/true); |
650 | |
651 | FormatElement *first = thenElements.front(); |
652 | const auto guardOn = [&](auto params) { |
653 | os << "if (!("; |
654 | llvm::interleave( |
655 | params, os, |
656 | [&](ParameterElement *el) { |
657 | os << formatv(Fmt: "(::mlir::succeeded(_result_{0}) && *_result_{0})", |
658 | Vals: el->getName()); |
659 | }, |
660 | " || "); |
661 | os << ")) {\n"; |
662 | }; |
663 | if (auto *literal = dyn_cast<LiteralElement>(Val: first)) { |
664 | genLiteralParser(value: literal->getSpelling(), ctx, os, /*isOptional=*/true); |
665 | os << ") {\n"; |
666 | } else if (auto *param = dyn_cast<ParameterElement>(Val: first)) { |
667 | genVariableParser(el: param, ctx, os); |
668 | guardOn(llvm::ArrayRef(param)); |
669 | } else if (auto *params = dyn_cast<ParamsDirective>(Val: first)) { |
670 | genParamsParser(el: params, ctx, os); |
671 | guardOn(params->getParams()); |
672 | } else if (auto *custom = dyn_cast<CustomDirective>(Val: first)) { |
673 | os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n"; |
674 | os.indent(); |
675 | genCustomParser(el: custom, ctx, os, /*isOptional=*/true); |
676 | os << "return ::mlir::success();\n"; |
677 | os.unindent(); |
678 | os << "}(); result.has_value() && ::mlir::failed(*result)) {\n"; |
679 | os.indent(); |
680 | os << "return {};\n"; |
681 | os.unindent(); |
682 | os << "} else if (result.has_value()) {\n"; |
683 | } else { |
684 | auto *strct = cast<StructDirective>(Val: first); |
685 | genStructParser(el: strct, ctx, os); |
686 | guardOn(params->getParams()); |
687 | } |
688 | os.indent(); |
689 | |
690 | // Generate the parsers for the rest of the thenElements. |
691 | for (FormatElement *element : el->getElseElements(/*parseable=*/true)) |
692 | genElementParser(el: element, ctx, os); |
693 | os.unindent() << "} else {\n"; |
694 | os.indent(); |
695 | for (FormatElement *element : thenElements.drop_front()) |
696 | genElementParser(el: element, ctx, os); |
697 | os.unindent() << "}\n"; |
698 | } |
699 | |
700 | //===----------------------------------------------------------------------===// |
701 | // PrinterGen |
702 | //===----------------------------------------------------------------------===// |
703 | |
704 | void DefFormat::genPrinter(MethodBody &os) { |
705 | FmtContext ctx; |
706 | ctx.addSubst(placeholder: "_printer", subst: "odsPrinter"); |
707 | ctx.addSubst(placeholder: "_ctxt", subst: "getContext()"); |
708 | ctx.withBuilder(subst: "odsBuilder"); |
709 | os.indent(); |
710 | os << "::mlir::Builder odsBuilder(getContext());\n"; |
711 | |
712 | // Generate printers. |
713 | shouldEmitSpace = true; |
714 | lastWasPunctuation = false; |
715 | for (FormatElement *el : elements) |
716 | genElementPrinter(el, ctx, os); |
717 | } |
718 | |
719 | void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx, |
720 | MethodBody &os) { |
721 | if (auto *literal = dyn_cast<LiteralElement>(Val: el)) |
722 | return genLiteralPrinter(value: literal->getSpelling(), ctx, os); |
723 | if (auto *params = dyn_cast<ParamsDirective>(Val: el)) |
724 | return genParamsPrinter(el: params, ctx, os); |
725 | if (auto *strct = dyn_cast<StructDirective>(Val: el)) |
726 | return genStructPrinter(el: strct, ctx, os); |
727 | if (auto *custom = dyn_cast<CustomDirective>(Val: el)) |
728 | return genCustomPrinter(el: custom, ctx, os); |
729 | if (auto *var = dyn_cast<ParameterElement>(Val: el)) |
730 | return genVariablePrinter(el: var, ctx, os); |
731 | if (auto *optional = dyn_cast<OptionalElement>(Val: el)) |
732 | return genOptionalGroupPrinter(el: optional, ctx, os); |
733 | if (auto *whitespace = dyn_cast<WhitespaceElement>(Val: el)) |
734 | return genWhitespacePrinter(el: whitespace, ctx, os); |
735 | |
736 | llvm::PrintFatalError(Msg: "unsupported format element"); |
737 | } |
738 | |
739 | void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx, |
740 | MethodBody &os) { |
741 | // Don't insert a space before certain punctuation. |
742 | bool needSpace = |
743 | shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation); |
744 | os << tgfmt(fmt: "$_printer$0 << \"$1\";\n", ctx: &ctx, vals: needSpace ? " << ' '": "", |
745 | vals&: value); |
746 | |
747 | // Update the flags. |
748 | shouldEmitSpace = |
749 | value.size() != 1 || !StringRef("<({[").contains(C: value.front()); |
750 | lastWasPunctuation = value.front() != '_' && !isalpha(value.front()); |
751 | } |
752 | |
753 | void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx, |
754 | MethodBody &os, bool skipGuard) { |
755 | const AttrOrTypeParameter ¶m = el->getParam(); |
756 | ctx.withSelf(subst: param.getAccessorName() + "()"); |
757 | |
758 | // Guard the printer on the presence of optional parameters and that they |
759 | // aren't equal to their default values (if they have one). |
760 | if (el->isOptional() && !skipGuard) { |
761 | el->genPrintGuard(ctx, os&: os << "if (") << ") {\n"; |
762 | os.indent(); |
763 | } |
764 | |
765 | // Insert a space before the next parameter, if necessary. |
766 | if (shouldEmitSpace || !lastWasPunctuation) |
767 | os << tgfmt(fmt: "$_printer << ' ';\n", ctx: &ctx); |
768 | shouldEmitSpace = true; |
769 | lastWasPunctuation = false; |
770 | |
771 | if (el->shouldBeQualified()) |
772 | os << tgfmt(fmt: qualifiedParameterPrinter, ctx: &ctx) << ";\n"; |
773 | else if (auto printer = param.getPrinter()) |
774 | os << tgfmt(fmt: *printer, ctx: &ctx) << ";\n"; |
775 | else |
776 | os << tgfmt(fmt: defaultParameterPrinter, ctx: &ctx) << ";\n"; |
777 | |
778 | if (el->isOptional() && !skipGuard) |
779 | os.unindent() << "}\n"; |
780 | } |
781 | |
782 | /// Generate code to guard printing on the presence of any optional parameters. |
783 | template <typename ParameterRange> |
784 | static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &¶ms, |
785 | bool inverted = false) { |
786 | os << "if ("; |
787 | if (inverted) |
788 | os << "!("; |
789 | llvm::interleave( |
790 | params, os, |
791 | [&](ParameterElement *param) { param->genPrintGuard(ctx, os); }, " || "); |
792 | if (inverted) |
793 | os << ")"; |
794 | os << ") {\n"; |
795 | os.indent(); |
796 | } |
797 | |
798 | void DefFormat::genCommaSeparatedPrinter( |
799 | ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os, |
800 | function_ref<void(ParameterElement *)> extra) { |
801 | // Emit a space if necessary, but only if the struct is present. |
802 | if (shouldEmitSpace || !lastWasPunctuation) { |
803 | bool allOptional = llvm::all_of(Range&: params, P: paramIsOptional); |
804 | if (allOptional) |
805 | guardOnAny(ctx, os, params); |
806 | os << tgfmt(fmt: "$_printer << ' ';\n", ctx: &ctx); |
807 | if (allOptional) |
808 | os.unindent() << "}\n"; |
809 | } |
810 | |
811 | // The first printed element does not need to emit a comma. |
812 | os << "{\n"; |
813 | os.indent() << "bool _firstPrinted = true;\n"; |
814 | for (ParameterElement *param : params) { |
815 | if (param->isOptional()) { |
816 | param->genPrintGuard(ctx, os&: os << "if (") << ") {\n"; |
817 | os.indent(); |
818 | } |
819 | os << tgfmt(fmt: "if (!_firstPrinted) $_printer << \", \";\n", ctx: &ctx); |
820 | os << "_firstPrinted = false;\n"; |
821 | extra(param); |
822 | shouldEmitSpace = false; |
823 | lastWasPunctuation = true; |
824 | genVariablePrinter(el: param, ctx, os); |
825 | if (param->isOptional()) |
826 | os.unindent() << "}\n"; |
827 | } |
828 | os.unindent() << "}\n"; |
829 | } |
830 | |
831 | void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx, |
832 | MethodBody &os) { |
833 | genCommaSeparatedPrinter(params: llvm::to_vector(Range: el->getParams()), ctx, os, |
834 | extra: [&](ParameterElement *param) {}); |
835 | } |
836 | |
837 | void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx, |
838 | MethodBody &os) { |
839 | genCommaSeparatedPrinter( |
840 | params: llvm::to_vector(Range: el->getParams()), ctx, os, extra: [&](ParameterElement *param) { |
841 | os << tgfmt(fmt: "$_printer << \"$0 = \";\n", ctx: &ctx, vals: param->getName()); |
842 | }); |
843 | } |
844 | |
845 | void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx, |
846 | MethodBody &os) { |
847 | // Insert a space before the custom directive, if necessary. |
848 | if (shouldEmitSpace || !lastWasPunctuation) |
849 | os << tgfmt(fmt: "$_printer << ' ';\n", ctx: &ctx); |
850 | shouldEmitSpace = true; |
851 | lastWasPunctuation = false; |
852 | |
853 | os << tgfmt(fmt: "print$0($_printer", ctx: &ctx, vals: el->getName()); |
854 | os.indent(); |
855 | for (FormatElement *arg : el->getArguments()) { |
856 | os << ",\n"; |
857 | if (auto *param = dyn_cast<ParameterElement>(Val: arg)) { |
858 | os << param->getParam().getAccessorName() << "()"; |
859 | } else if (auto *ref = dyn_cast<RefDirective>(Val: arg)) { |
860 | os << cast<ParameterElement>(Val: ref->getArg())->getParam().getAccessorName() |
861 | << "()"; |
862 | } else { |
863 | os << tgfmt(fmt: cast<StringElement>(Val: arg)->getValue(), ctx: &ctx); |
864 | } |
865 | } |
866 | os.unindent() << ");\n"; |
867 | } |
868 | |
869 | void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, |
870 | MethodBody &os) { |
871 | FormatElement *anchor = el->getAnchor(); |
872 | if (auto *param = dyn_cast<ParameterElement>(Val: anchor)) { |
873 | guardOnAny(ctx, os, params: llvm::ArrayRef(param), inverted: el->isInverted()); |
874 | } else if (auto *params = dyn_cast<ParamsDirective>(Val: anchor)) { |
875 | guardOnAny(ctx, os, params: params->getParams(), inverted: el->isInverted()); |
876 | } else if (auto *strct = dyn_cast<StructDirective>(Val: anchor)) { |
877 | guardOnAny(ctx, os, params: strct->getParams(), inverted: el->isInverted()); |
878 | } else { |
879 | auto *custom = cast<CustomDirective>(Val: anchor); |
880 | guardOnAny(ctx, os, |
881 | params: llvm::make_filter_range( |
882 | Range: llvm::map_range(C: custom->getArguments(), |
883 | F: [](FormatElement *el) { |
884 | return dyn_cast<ParameterElement>(Val: el); |
885 | }), |
886 | Pred: [](ParameterElement *param) { return !!param; }), |
887 | inverted: el->isInverted()); |
888 | } |
889 | // Generate the printer for the contained elements. |
890 | { |
891 | llvm::SaveAndRestore shouldEmitSpaceFlag(shouldEmitSpace); |
892 | llvm::SaveAndRestore lastWasPunctuationFlag(lastWasPunctuation); |
893 | for (FormatElement *element : el->getThenElements()) |
894 | genElementPrinter(el: element, ctx, os); |
895 | } |
896 | os.unindent() << "} else {\n"; |
897 | os.indent(); |
898 | for (FormatElement *element : el->getElseElements()) |
899 | genElementPrinter(el: element, ctx, os); |
900 | os.unindent() << "}\n"; |
901 | } |
902 | |
903 | void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx, |
904 | MethodBody &os) { |
905 | if (el->getValue() == "\\n") { |
906 | // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by |
907 | // the printer. |
908 | os << tgfmt(fmt: "$_printer << '\\n';\n", ctx: &ctx); |
909 | } else if (!el->getValue().empty()) { |
910 | os << tgfmt(fmt: "$_printer << \"$0\";\n", ctx: &ctx, vals: el->getValue()); |
911 | } else { |
912 | lastWasPunctuation = true; |
913 | } |
914 | shouldEmitSpace = false; |
915 | } |
916 | |
917 | //===----------------------------------------------------------------------===// |
918 | // DefFormatParser |
919 | //===----------------------------------------------------------------------===// |
920 | |
921 | namespace { |
922 | class DefFormatParser : public FormatParser { |
923 | public: |
924 | DefFormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def) |
925 | : FormatParser(mgr, def.getLoc()[0]), def(def), |
926 | seenParams(def.getNumParameters()) {} |
927 | |
928 | /// Parse the attribute or type format and create the format elements. |
929 | FailureOr<DefFormat> parse(); |
930 | |
931 | protected: |
932 | /// Verify the parsed elements. |
933 | LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override; |
934 | /// Verify the elements of a custom directive. |
935 | LogicalResult |
936 | verifyCustomDirectiveArguments(SMLoc loc, |
937 | ArrayRef<FormatElement *> arguments) override; |
938 | /// Verify the elements of an optional group. |
939 | LogicalResult verifyOptionalGroupElements(SMLoc loc, |
940 | ArrayRef<FormatElement *> elements, |
941 | FormatElement *anchor) override; |
942 | |
943 | /// Parse an attribute or type variable. |
944 | FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name, |
945 | Context ctx) override; |
946 | /// Parse an attribute or type format directive. |
947 | FailureOr<FormatElement *> |
948 | parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override; |
949 | |
950 | private: |
951 | /// Parse a `params` directive. |
952 | FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx); |
953 | /// Parse a `qualified` directive. |
954 | FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx); |
955 | /// Parse a `struct` directive. |
956 | FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx); |
957 | /// Parse a `ref` directive. |
958 | FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx); |
959 | |
960 | /// Attribute or type tablegen def. |
961 | const AttrOrTypeDef &def; |
962 | |
963 | /// Seen attribute or type parameters. |
964 | BitVector seenParams; |
965 | }; |
966 | } // namespace |
967 | |
968 | LogicalResult DefFormatParser::verify(SMLoc loc, |
969 | ArrayRef<FormatElement *> elements) { |
970 | // Check that all parameters are referenced in the format. |
971 | for (auto [index, param] : llvm::enumerate(First: def.getParameters())) { |
972 | if (param.isOptional()) |
973 | continue; |
974 | if (!seenParams.test(Idx: index)) { |
975 | if (isa<AttributeSelfTypeParameter>(Val: param)) |
976 | continue; |
977 | return emitError(loc, msg: "format is missing reference to parameter: "+ |
978 | param.getName()); |
979 | } |
980 | if (isa<AttributeSelfTypeParameter>(Val: param)) { |
981 | return emitError(loc, |
982 | msg: "unexpected self type parameter in assembly format"); |
983 | } |
984 | } |
985 | if (elements.empty()) |
986 | return success(); |
987 | // A `struct` directive that contains optional parameters cannot be followed |
988 | // by a comma literal, which is ambiguous. |
989 | for (auto it : llvm::zip(t: elements.drop_back(), u: elements.drop_front())) { |
990 | auto *structEl = dyn_cast<StructDirective>(Val: std::get<0>(t&: it)); |
991 | auto *literalEl = dyn_cast<LiteralElement>(Val: std::get<1>(t&: it)); |
992 | if (!structEl || !literalEl) |
993 | continue; |
994 | if (literalEl->getSpelling() == ","&& structEl->hasOptionalParams()) { |
995 | return emitError(loc, msg: "`struct` directive with optional parameters " |
996 | "cannot be followed by a comma literal"); |
997 | } |
998 | } |
999 | return success(); |
1000 | } |
1001 | |
1002 | LogicalResult DefFormatParser::verifyCustomDirectiveArguments( |
1003 | SMLoc loc, ArrayRef<FormatElement *> arguments) { |
1004 | // Arguments are fully verified by the parser context. |
1005 | return success(); |
1006 | } |
1007 | |
1008 | LogicalResult |
1009 | DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc, |
1010 | ArrayRef<FormatElement *> elements, |
1011 | FormatElement *anchor) { |
1012 | // `params` and `struct` directives are allowed only if all the contained |
1013 | // parameters are optional. |
1014 | for (FormatElement *el : elements) { |
1015 | if (auto *param = dyn_cast<ParameterElement>(Val: el)) { |
1016 | if (!param->isOptional()) { |
1017 | return emitError(loc, |
1018 | msg: "parameters in an optional group must be optional"); |
1019 | } |
1020 | } else if (auto *params = dyn_cast<ParamsDirective>(Val: el)) { |
1021 | if (llvm::any_of(Range: params->getParams(), P: paramNotOptional)) { |
1022 | return emitError(loc, msg: "`params` directive allowed in optional group " |
1023 | "only if all parameters are optional"); |
1024 | } |
1025 | } else if (auto *strct = dyn_cast<StructDirective>(Val: el)) { |
1026 | if (llvm::any_of(Range: strct->getParams(), P: paramNotOptional)) { |
1027 | return emitError(loc, msg: "`struct` is only allowed in an optional group " |
1028 | "if all captured parameters are optional"); |
1029 | } |
1030 | } else if (auto *custom = dyn_cast<CustomDirective>(Val: el)) { |
1031 | for (FormatElement *el : custom->getArguments()) { |
1032 | // If the custom argument is a variable, then it must be optional. |
1033 | if (auto *param = dyn_cast<ParameterElement>(Val: el)) |
1034 | if (!param->isOptional()) |
1035 | return emitError(loc, |
1036 | msg: "`custom` is only allowed in an optional group if " |
1037 | "all captured parameters are optional"); |
1038 | } |
1039 | } |
1040 | } |
1041 | // The anchor must be a parameter or one of the aforementioned directives. |
1042 | if (anchor) { |
1043 | if (!isa<ParameterElement, ParamsDirective, StructDirective, |
1044 | CustomDirective>(Val: anchor)) { |
1045 | return emitError( |
1046 | loc, msg: "optional group anchor must be a parameter or directive"); |
1047 | } |
1048 | // If the anchor is a custom directive, make sure at least one of its |
1049 | // arguments is a bound parameter. |
1050 | if (auto *custom = dyn_cast<CustomDirective>(Val: anchor)) { |
1051 | const auto *bound = |
1052 | llvm::find_if(Range: custom->getArguments(), P: [](FormatElement *el) { |
1053 | return isa<ParameterElement>(Val: el); |
1054 | }); |
1055 | if (bound == custom->getArguments().end()) |
1056 | return emitError(loc, msg: "`custom` directive with no bound parameters " |
1057 | "cannot be used as optional group anchor"); |
1058 | } |
1059 | } |
1060 | return success(); |
1061 | } |
1062 | |
1063 | FailureOr<DefFormat> DefFormatParser::parse() { |
1064 | FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse(); |
1065 | if (failed(result: elements)) |
1066 | return failure(); |
1067 | return DefFormat(def, std::move(*elements)); |
1068 | } |
1069 | |
1070 | FailureOr<FormatElement *> |
1071 | DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) { |
1072 | // Lookup the parameter. |
1073 | ArrayRef<AttrOrTypeParameter> params = def.getParameters(); |
1074 | auto *it = llvm::find_if( |
1075 | Range&: params, P: [&](auto ¶m) { return param.getName() == name; }); |
1076 | |
1077 | // Check that the parameter reference is valid. |
1078 | if (it == params.end()) { |
1079 | return emitError(loc, |
1080 | msg: def.getName() + " has no parameter named '"+ name + "'"); |
1081 | } |
1082 | auto idx = std::distance(first: params.begin(), last: it); |
1083 | |
1084 | if (ctx != RefDirectiveContext) { |
1085 | // Check that the variable has not already been bound. |
1086 | if (seenParams.test(Idx: idx)) |
1087 | return emitError(loc, msg: "duplicate parameter '"+ name + "'"); |
1088 | seenParams.set(idx); |
1089 | |
1090 | // Otherwise, to be referenced, a variable must have been bound. |
1091 | } else if (!seenParams.test(Idx: idx) && !isa<AttributeSelfTypeParameter>(Val: *it)) { |
1092 | return emitError(loc, msg: "parameter '"+ name + |
1093 | "' must be bound before it is referenced"); |
1094 | } |
1095 | |
1096 | return create<ParameterElement>(args: *it); |
1097 | } |
1098 | |
1099 | FailureOr<FormatElement *> |
1100 | DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, |
1101 | Context ctx) { |
1102 | |
1103 | switch (kind) { |
1104 | case FormatToken::kw_qualified: |
1105 | return parseQualifiedDirective(loc, ctx); |
1106 | case FormatToken::kw_params: |
1107 | return parseParamsDirective(loc, ctx); |
1108 | case FormatToken::kw_struct: |
1109 | return parseStructDirective(loc, ctx); |
1110 | case FormatToken::kw_ref: |
1111 | return parseRefDirective(loc, ctx); |
1112 | case FormatToken::kw_custom: |
1113 | return parseCustomDirective(loc, ctx); |
1114 | |
1115 | default: |
1116 | return emitError(loc, msg: "unsupported directive kind"); |
1117 | } |
1118 | } |
1119 | |
1120 | FailureOr<FormatElement *> |
1121 | DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) { |
1122 | if (failed(result: parseToken(kind: FormatToken::l_paren, |
1123 | msg: "expected '(' before argument list"))) |
1124 | return failure(); |
1125 | FailureOr<FormatElement *> var = parseElement(ctx); |
1126 | if (failed(result: var)) |
1127 | return var; |
1128 | if (!isa<ParameterElement>(Val: *var)) |
1129 | return emitError(loc, msg: "`qualified` argument list expected a variable"); |
1130 | cast<ParameterElement>(Val: *var)->setShouldBeQualified(); |
1131 | if (failed( |
1132 | result: parseToken(kind: FormatToken::r_paren, msg: "expected ')' after argument list"))) |
1133 | return failure(); |
1134 | return var; |
1135 | } |
1136 | |
1137 | FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc, |
1138 | Context ctx) { |
1139 | // It doesn't make sense to allow references to all parameters in a custom |
1140 | // directive because parameters are the only things that can be bound. |
1141 | if (ctx != TopLevelContext && ctx != StructDirectiveContext) { |
1142 | return emitError(loc, msg: "`params` can only be used at the top-level context " |
1143 | "or within a `struct` directive"); |
1144 | } |
1145 | |
1146 | // Collect all of the attribute's or type's parameters and ensure that none of |
1147 | // the parameters have already been captured. |
1148 | std::vector<ParameterElement *> vars; |
1149 | for (const auto &it : llvm::enumerate(First: def.getParameters())) { |
1150 | if (seenParams.test(Idx: it.index())) { |
1151 | return emitError(loc, msg: "`params` captures duplicate parameter: "+ |
1152 | it.value().getName()); |
1153 | } |
1154 | // Self-type parameters are handled separately from the rest of the |
1155 | // parameters. |
1156 | if (isa<AttributeSelfTypeParameter>(Val: it.value())) |
1157 | continue; |
1158 | seenParams.set(it.index()); |
1159 | vars.push_back(x: create<ParameterElement>(args: it.value())); |
1160 | } |
1161 | return create<ParamsDirective>(args: std::move(vars)); |
1162 | } |
1163 | |
1164 | FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc, |
1165 | Context ctx) { |
1166 | if (ctx != TopLevelContext) |
1167 | return emitError(loc, msg: "`struct` can only be used at the top-level context"); |
1168 | |
1169 | if (failed(result: parseToken(kind: FormatToken::l_paren, |
1170 | msg: "expected '(' before `struct` argument list"))) |
1171 | return failure(); |
1172 | |
1173 | // Parse variables captured by `struct`. |
1174 | std::vector<ParameterElement *> vars; |
1175 | |
1176 | // Parse first captured parameter or a `params` directive. |
1177 | FailureOr<FormatElement *> var = parseElement(ctx: StructDirectiveContext); |
1178 | if (failed(result: var) || !isa<VariableElement, ParamsDirective>(Val: *var)) { |
1179 | return emitError(loc, |
1180 | msg: "`struct` argument list expected a variable or directive"); |
1181 | } |
1182 | if (isa<VariableElement>(Val: *var)) { |
1183 | // Parse any other parameters. |
1184 | vars.push_back(x: cast<ParameterElement>(Val: *var)); |
1185 | while (peekToken().is(kind: FormatToken::comma)) { |
1186 | consumeToken(); |
1187 | var = parseElement(ctx: StructDirectiveContext); |
1188 | if (failed(result: var) || !isa<VariableElement>(Val: *var)) |
1189 | return emitError(loc, msg: "expected a variable in `struct` argument list"); |
1190 | vars.push_back(x: cast<ParameterElement>(Val: *var)); |
1191 | } |
1192 | } else { |
1193 | // `struct(params)` captures all parameters in the attribute or type. |
1194 | vars = cast<ParamsDirective>(Val: *var)->takeParams(); |
1195 | } |
1196 | |
1197 | if (failed(result: parseToken(kind: FormatToken::r_paren, |
1198 | msg: "expected ')' at the end of an argument list"))) |
1199 | return failure(); |
1200 | |
1201 | return create<StructDirective>(args: std::move(vars)); |
1202 | } |
1203 | |
1204 | FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc, |
1205 | Context ctx) { |
1206 | if (ctx != CustomDirectiveContext) |
1207 | return emitError(loc, msg: "`ref` is only allowed inside custom directives"); |
1208 | |
1209 | // Parse the child parameter element. |
1210 | FailureOr<FormatElement *> child; |
1211 | if (failed(result: parseToken(kind: FormatToken::l_paren, msg: "expected '('")) || |
1212 | failed(result: child = parseElement(ctx: RefDirectiveContext)) || |
1213 | failed(result: parseToken(kind: FormatToken::r_paren, msg: "expeced ')'"))) |
1214 | return failure(); |
1215 | |
1216 | // Only parameter elements are allowed to be parsed under a `ref` directive. |
1217 | return create<RefDirective>(args&: *child); |
1218 | } |
1219 | |
1220 | //===----------------------------------------------------------------------===// |
1221 | // Interface |
1222 | //===----------------------------------------------------------------------===// |
1223 | |
1224 | void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def, |
1225 | MethodBody &parser, |
1226 | MethodBody &printer) { |
1227 | llvm::SourceMgr mgr; |
1228 | mgr.AddNewSourceBuffer( |
1229 | F: llvm::MemoryBuffer::getMemBuffer(InputData: *def.getAssemblyFormat()), IncludeLoc: SMLoc()); |
1230 | |
1231 | // Parse the custom assembly format> |
1232 | DefFormatParser fmtParser(mgr, def); |
1233 | FailureOr<DefFormat> format = fmtParser.parse(); |
1234 | if (failed(result: format)) { |
1235 | if (formatErrorIsFatal) |
1236 | PrintFatalError(ErrorLoc: def.getLoc(), Msg: "failed to parse assembly format"); |
1237 | return; |
1238 | } |
1239 | |
1240 | // Generate the parser and printer. |
1241 | format->genParser(os&: parser); |
1242 | format->genPrinter(os&: printer); |
1243 | } |
1244 |
Definitions
- ParameterElement
- ParameterElement
- getParam
- shouldBeQualified
- setShouldBeQualified
- isOptional
- getName
- genIsPresent
- genPrintGuard
- paramIsOptional
- paramNotOptional
- ParamsDirectiveBase
- ParamsDirectiveBase
- getParams
- getNumParams
- takeParams
- hasOptionalParams
- ParamsDirective
- StructDirective
- defaultParameterParser
- defaultParameterPrinter
- qualifiedParameterPrinter
- parserErrorStr
- variableParser
- DefFormat
- DefFormat
- genAttrSelfTypeParser
- genParser
- genElementParser
- genLiteralParser
- genVariableParser
- genParamsParser
- genStructParser
- genCustomParser
- genOptionalGroupParser
- genPrinter
- genElementPrinter
- genLiteralPrinter
- genVariablePrinter
- guardOnAny
- genCommaSeparatedPrinter
- genParamsPrinter
- genStructPrinter
- genCustomPrinter
- genOptionalGroupPrinter
- genWhitespacePrinter
- DefFormatParser
- DefFormatParser
- verify
- verifyCustomDirectiveArguments
- verifyOptionalGroupElements
- parse
- parseVariableImpl
- parseDirectiveImpl
- parseQualifiedDirective
- parseParamsDirective
- parseStructDirective
- parseRefDirective
Improve your Profiling and Debugging skills
Find out more