1//===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "AttrOrTypeFormatGen.h"
10#include "CppGenUtilities.h"
11#include "mlir/TableGen/AttrOrTypeDef.h"
12#include "mlir/TableGen/Class.h"
13#include "mlir/TableGen/CodeGenHelpers.h"
14#include "mlir/TableGen/Format.h"
15#include "mlir/TableGen/GenInfo.h"
16#include "mlir/TableGen/Interfaces.h"
17#include "llvm/ADT/StringSet.h"
18#include "llvm/Support/CommandLine.h"
19#include "llvm/TableGen/Error.h"
20#include "llvm/TableGen/TableGenBackend.h"
21
22#define DEBUG_TYPE "mlir-tblgen-attrortypedefgen"
23
24using namespace mlir;
25using namespace mlir::tblgen;
26using llvm::Record;
27using llvm::RecordKeeper;
28
29//===----------------------------------------------------------------------===//
30// Utility Functions
31//===----------------------------------------------------------------------===//
32
33/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
34/// specified and can only find one dialect's defs, use that.
35static void collectAllDefs(StringRef selectedDialect,
36 ArrayRef<const Record *> records,
37 SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
38 // Nothing to do if no defs were found.
39 if (records.empty())
40 return;
41
42 auto defs = llvm::map_range(
43 C&: records, F: [&](const Record *rec) { return AttrOrTypeDef(rec); });
44 if (selectedDialect.empty()) {
45 // If a dialect was not specified, ensure that all found defs belong to the
46 // same dialect.
47 if (!llvm::all_equal(Range: llvm::map_range(
48 C&: defs, F: [](const auto &def) { return def.getDialect(); }))) {
49 llvm::PrintFatalError(Msg: "defs belonging to more than one dialect. Must "
50 "select one via '--(attr|type)defs-dialect'");
51 }
52 resultDefs.assign(in_start: defs.begin(), in_end: defs.end());
53 } else {
54 // Otherwise, generate the defs that belong to the selected dialect.
55 auto dialectDefs = llvm::make_filter_range(Range&: defs, Pred: [&](const auto &def) {
56 return def.getDialect().getName() == selectedDialect;
57 });
58 resultDefs.assign(in_start: dialectDefs.begin(), in_end: dialectDefs.end());
59 }
60}
61
62//===----------------------------------------------------------------------===//
63// DefGen
64//===----------------------------------------------------------------------===//
65
66namespace {
67class DefGen {
68public:
69 /// Create the attribute or type class.
70 DefGen(const AttrOrTypeDef &def);
71
72 void emitDecl(raw_ostream &os) const {
73 if (storageCls && def.genStorageClass()) {
74 NamespaceEmitter ns(os, def.getStorageNamespace());
75 os << "struct " << def.getStorageClassName() << ";\n";
76 }
77 defCls.writeDeclTo(rawOs&: os);
78 }
79 void emitDef(raw_ostream &os) const {
80 if (storageCls && def.genStorageClass()) {
81 NamespaceEmitter ns(os, def.getStorageNamespace());
82 storageCls->writeDeclTo(rawOs&: os); // everything is inline
83 }
84 defCls.writeDefTo(rawOs&: os);
85 }
86
87private:
88 /// Add traits from the TableGen definition to the class.
89 void createParentWithTraits();
90 /// Emit top-level declarations: using declarations and any extra class
91 /// declarations.
92 void emitTopLevelDeclarations();
93 /// Emit the function that returns the type or attribute name.
94 void emitName();
95 /// Emit the dialect name as a static member variable.
96 void emitDialectName();
97 /// Emit attribute or type builders.
98 void emitBuilders();
99 /// Emit a verifier declaration for custom verification (impl. provided by
100 /// the users).
101 void emitVerifierDecl();
102 /// Emit a verifier that checks type constraints.
103 void emitInvariantsVerifierImpl();
104 /// Emit an entry poiunt for verification that calls the invariants and
105 /// custom verifier.
106 void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier);
107 /// Emit parsers and printers.
108 void emitParserPrinter();
109 /// Emit parameter accessors, if required.
110 void emitAccessors();
111 /// Emit interface methods.
112 void emitInterfaceMethods();
113
114 //===--------------------------------------------------------------------===//
115 // Builder Emission
116
117 /// Emit the default builder `Attribute::get`
118 void emitDefaultBuilder();
119 /// Emit the checked builder `Attribute::getChecked`
120 void emitCheckedBuilder();
121 /// Emit a custom builder.
122 void emitCustomBuilder(const AttrOrTypeBuilder &builder);
123 /// Emit a checked custom builder.
124 void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);
125
126 //===--------------------------------------------------------------------===//
127 // Interface Method Emission
128
129 /// Emit methods for a trait.
130 void emitTraitMethods(const InterfaceTrait &trait);
131 /// Emit a trait method.
132 void emitTraitMethod(const InterfaceMethod &method);
133
134 //===--------------------------------------------------------------------===//
135 // OpAsm{Type,Attr}Interface Default Method Emission
136
137 /// Emit 'getAlias' method using mnemonic as alias.
138 void emitMnemonicAliasMethod();
139
140 //===--------------------------------------------------------------------===//
141 // Storage Class Emission
142 void emitStorageClass();
143 /// Generate the storage class constructor.
144 void emitStorageConstructor();
145 /// Emit the key type `KeyTy`.
146 void emitKeyType();
147 /// Emit the equality comparison operator.
148 void emitEquals();
149 /// Emit the key hash function.
150 void emitHashKey();
151 /// Emit the function to construct the storage class.
152 void emitConstruct();
153
154 //===--------------------------------------------------------------------===//
155 // Utility Function Declarations
156
157 /// Get the method parameters for a def builder, where the first several
158 /// parameters may be different.
159 SmallVector<MethodParameter>
160 getBuilderParams(std::initializer_list<MethodParameter> prefix) const;
161
162 //===--------------------------------------------------------------------===//
163 // Class fields
164
165 /// The attribute or type definition.
166 const AttrOrTypeDef &def;
167 /// The list of attribute or type parameters.
168 ArrayRef<AttrOrTypeParameter> params;
169 /// The attribute or type class.
170 Class defCls;
171 /// An optional attribute or type storage class. The storage class will
172 /// exist if and only if the def has more than zero parameters.
173 std::optional<Class> storageCls;
174
175 /// The C++ base value of the def, either "Attribute" or "Type".
176 StringRef valueType;
177 /// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
178 StringRef defType;
179};
180} // namespace
181
182DefGen::DefGen(const AttrOrTypeDef &def)
183 : def(def), params(def.getParameters()), defCls(def.getCppClassName()),
184 valueType(isa<AttrDef>(Val: def) ? "Attribute" : "Type"),
185 defType(isa<AttrDef>(Val: def) ? "Attr" : "Type") {
186 // Check that all parameters have names.
187 for (const AttrOrTypeParameter &param : def.getParameters())
188 if (param.isAnonymous())
189 llvm::PrintFatalError(Msg: "all parameters must have a name");
190
191 // If a storage class is needed, create one.
192 if (def.getNumParameters() > 0)
193 storageCls.emplace(args: def.getStorageClassName(), /*isStruct=*/args: true);
194
195 // Create the parent class with any indicated traits.
196 createParentWithTraits();
197 // Emit top-level declarations.
198 emitTopLevelDeclarations();
199 // Emit builders for defs with parameters
200 if (storageCls)
201 emitBuilders();
202 // Emit the type name.
203 emitName();
204 // Emit the dialect name.
205 emitDialectName();
206 // Emit verification of type constraints.
207 bool genVerifyInvariantsImpl = def.genVerifyInvariantsImpl();
208 if (storageCls && genVerifyInvariantsImpl)
209 emitInvariantsVerifierImpl();
210 // Emit the custom verifier (written by the user).
211 bool genVerifyDecl = def.genVerifyDecl();
212 if (storageCls && genVerifyDecl)
213 emitVerifierDecl();
214 // Emit the "verifyInvariants" function if there is any verification at all.
215 if (storageCls)
216 emitInvariantsVerifier(hasImpl: genVerifyInvariantsImpl, hasCustomVerifier: genVerifyDecl);
217 // Emit the mnemonic, if there is one, and any associated parser and printer.
218 if (def.getMnemonic())
219 emitParserPrinter();
220 // Emit accessors
221 if (def.genAccessors())
222 emitAccessors();
223 // Emit trait interface methods
224 emitInterfaceMethods();
225 // Emit OpAsm{Type,Attr}Interface default methods
226 if (def.genMnemonicAlias())
227 emitMnemonicAliasMethod();
228 defCls.finalize();
229 // Emit a storage class if one is needed
230 if (storageCls && def.genStorageClass())
231 emitStorageClass();
232}
233
234void DefGen::createParentWithTraits() {
235 ParentClass defParent(strfmt(fmt: "::mlir::{0}::{1}Base", parameters&: valueType, parameters&: defType));
236 defParent.addTemplateParam(param: def.getCppClassName());
237 defParent.addTemplateParam(param: def.getCppBaseClassName());
238 defParent.addTemplateParam(param: storageCls
239 ? strfmt(fmt: "{0}::{1}", parameters: def.getStorageNamespace(),
240 parameters: def.getStorageClassName())
241 : strfmt(fmt: "::mlir::{0}Storage", parameters&: valueType));
242 SmallVector<std::string> traitNames =
243 llvm::to_vector(Range: llvm::map_range(C: def.getTraits(), F: [](auto &trait) {
244 return isa<NativeTrait>(&trait)
245 ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
246 : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName();
247 }));
248 llvm::for_each(Range&: traitNames, F: [&](auto &traitName) {
249 defParent.addTemplateParam(traitName);
250 });
251
252 // Add OpAsmInterface::Trait if we automatically generate mnemonic alias
253 // method.
254 std::string opAsmInterfaceTraitName =
255 strfmt(fmt: "::mlir::OpAsm{0}Interface::Trait", parameters&: defType);
256 if (def.genMnemonicAlias() && llvm::none_of(Range&: traitNames, P: [&](auto &traitName) {
257 return traitName == opAsmInterfaceTraitName;
258 })) {
259 defParent.addTemplateParam(param: opAsmInterfaceTraitName);
260 }
261 defCls.addParent(parent: std::move(defParent));
262}
263
264/// Include declarations specified on NativeTrait
265static std::string formatExtraDeclarations(const AttrOrTypeDef &def) {
266 SmallVector<StringRef> extraDeclarations;
267 // Include extra class declarations from NativeTrait
268 for (const auto &trait : def.getTraits()) {
269 if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
270 StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration();
271 if (value.empty())
272 continue;
273 extraDeclarations.push_back(Elt: value);
274 }
275 }
276 if (std::optional<StringRef> extraDecl = def.getExtraDecls()) {
277 extraDeclarations.push_back(Elt: *extraDecl);
278 }
279 return llvm::join(R&: extraDeclarations, Separator: "\n");
280}
281
282/// Extra class definitions have a `$cppClass` substitution that is to be
283/// replaced by the C++ class name.
284static std::string formatExtraDefinitions(const AttrOrTypeDef &def) {
285 SmallVector<StringRef> extraDefinitions;
286 // Include extra class definitions from NativeTrait
287 for (const auto &trait : def.getTraits()) {
288 if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
289 StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition();
290 if (value.empty())
291 continue;
292 extraDefinitions.push_back(Elt: value);
293 }
294 }
295 if (std::optional<StringRef> extraDef = def.getExtraDefs()) {
296 extraDefinitions.push_back(Elt: *extraDef);
297 }
298 FmtContext ctx = FmtContext().addSubst(placeholder: "cppClass", subst: def.getCppClassName());
299 return tgfmt(fmt: llvm::join(R&: extraDefinitions, Separator: "\n"), ctx: &ctx).str();
300}
301
302void DefGen::emitTopLevelDeclarations() {
303 // Inherit constructors from the attribute or type class.
304 defCls.declare<VisibilityDeclaration>(args: Visibility::Public);
305 defCls.declare<UsingDeclaration>(args: "Base::Base");
306
307 // Emit the extra declarations first in case there's a definition in there.
308 std::string extraDecl = formatExtraDeclarations(def);
309 std::string extraDef = formatExtraDefinitions(def);
310 defCls.declare<ExtraClassDeclaration>(args: std::move(extraDecl),
311 args: std::move(extraDef));
312}
313
314void DefGen::emitName() {
315 StringRef name;
316 if (auto *attrDef = dyn_cast<AttrDef>(Val: &def)) {
317 name = attrDef->getAttrName();
318 } else {
319 auto *typeDef = cast<TypeDef>(Val: &def);
320 name = typeDef->getTypeName();
321 }
322 std::string nameDecl =
323 strfmt(fmt: "static constexpr ::llvm::StringLiteral name = \"{0}\";\n", parameters&: name);
324 defCls.declare<ExtraClassDeclaration>(args: std::move(nameDecl));
325}
326
327void DefGen::emitDialectName() {
328 std::string decl =
329 strfmt(fmt: "static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n",
330 parameters: def.getDialect().getName());
331 defCls.declare<ExtraClassDeclaration>(args: std::move(decl));
332}
333
334void DefGen::emitBuilders() {
335 if (!def.skipDefaultBuilders()) {
336 emitDefaultBuilder();
337 if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
338 emitCheckedBuilder();
339 }
340 for (auto &builder : def.getBuilders()) {
341 emitCustomBuilder(builder);
342 if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
343 emitCheckedCustomBuilder(builder);
344 }
345}
346
347void DefGen::emitVerifierDecl() {
348 defCls.declareStaticMethod(
349 retType: "::llvm::LogicalResult", name: "verify",
350 args: getBuilderParams(prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
351 "emitError"}}));
352}
353
354static const char *const patternParameterVerificationCode = R"(
355if (!({0})) {
356 emitError() << "failed to verify '{1}': {2}";
357 return ::mlir::failure();
358}
359)";
360
361void DefGen::emitInvariantsVerifierImpl() {
362 SmallVector<MethodParameter> builderParams = getBuilderParams(
363 prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
364 Method *verifier =
365 defCls.addMethod(retType: "::llvm::LogicalResult", name: "verifyInvariantsImpl",
366 properties: Method::Static, args&: builderParams);
367 verifier->body().indent();
368
369 // Generate verification for each parameter that is a type constraint.
370 for (auto it : llvm::enumerate(First: def.getParameters())) {
371 const AttrOrTypeParameter &param = it.value();
372 std::optional<Constraint> constraint = param.getConstraint();
373 // No verification needed for parameters that are not type constraints.
374 if (!constraint.has_value())
375 continue;
376 FmtContext ctx;
377 // Note: Skip over the first method parameter (`emitError`).
378 ctx.withSelf(subst: builderParams[it.index() + 1].getName());
379 std::string condition = tgfmt(fmt: constraint->getConditionTemplate(), ctx: &ctx);
380 verifier->body() << formatv(Fmt: patternParameterVerificationCode, Vals&: condition,
381 Vals: param.getName(), Vals: constraint->getSummary())
382 << "\n";
383 }
384 verifier->body() << "return ::mlir::success();";
385}
386
387void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) {
388 if (!hasImpl && !hasCustomVerifier)
389 return;
390 defCls.declare<UsingDeclaration>(args: "Base::getChecked");
391 SmallVector<MethodParameter> builderParams = getBuilderParams(
392 prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
393 Method *verifier =
394 defCls.addMethod(retType: "::llvm::LogicalResult", name: "verifyInvariants",
395 properties: Method::Static, args&: builderParams);
396 verifier->body().indent();
397
398 auto emitVerifierCall = [&](StringRef name) {
399 verifier->body() << strfmt(fmt: "if (::mlir::failed({0}(", parameters&: name);
400 llvm::interleaveComma(
401 c: llvm::map_range(C&: builderParams,
402 F: [](auto &param) { return param.getName(); }),
403 os&: verifier->body());
404 verifier->body() << ")))\n";
405 verifier->body() << " return ::mlir::failure();\n";
406 };
407
408 if (hasImpl) {
409 // Call the verifier that checks the type constraints.
410 emitVerifierCall("verifyInvariantsImpl");
411 }
412 if (hasCustomVerifier) {
413 // Call the custom verifier that is provided by the user.
414 emitVerifierCall("verify");
415 }
416 verifier->body() << "return ::mlir::success();";
417}
418
419void DefGen::emitParserPrinter() {
420 auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
421 retType: "::llvm::StringLiteral", name: "getMnemonic");
422 mnemonic->body().indent() << strfmt(fmt: "return {\"{0}\"};", parameters: *def.getMnemonic());
423
424 // Declare the parser and printer, if needed.
425 bool hasAssemblyFormat = def.getAssemblyFormat().has_value();
426 if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat)
427 return;
428
429 // Declare the parser.
430 SmallVector<MethodParameter> parserParams;
431 parserParams.emplace_back(Args: "::mlir::AsmParser &", Args: "odsParser");
432 if (isa<AttrDef>(Val: &def))
433 parserParams.emplace_back(Args: "::mlir::Type", Args: "odsType");
434 auto *parser = defCls.addMethod(retType: strfmt(fmt: "::mlir::{0}", parameters&: valueType), name: "parse",
435 properties: hasAssemblyFormat ? Method::Static
436 : Method::StaticDeclaration,
437 args: std::move(parserParams));
438 // Declare the printer.
439 auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration;
440 Method *printer =
441 defCls.addMethod(retType: "void", name: "print", properties: props,
442 args: MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
443 // Emit the bodies if we are using the declarative format.
444 if (hasAssemblyFormat)
445 return generateAttrOrTypeFormat(def, parser&: parser->body(), printer&: printer->body());
446}
447
448void DefGen::emitAccessors() {
449 for (auto &param : params) {
450 Method *m = defCls.addMethod(
451 retType: param.getCppAccessorType(), name: param.getAccessorName(),
452 properties: def.genStorageClass() ? Method::Const : Method::ConstDeclaration);
453 // Generate accessor definitions only if we also generate the storage
454 // class. Otherwise, let the user define the exact accessor definition.
455 if (!def.genStorageClass())
456 continue;
457 m->body().indent() << "return getImpl()->" << param.getName() << ";";
458 }
459}
460
461void DefGen::emitInterfaceMethods() {
462 for (auto &traitDef : def.getTraits())
463 if (auto *trait = dyn_cast<InterfaceTrait>(Val: &traitDef))
464 if (trait->shouldDeclareMethods())
465 emitTraitMethods(trait: *trait);
466}
467
468//===----------------------------------------------------------------------===//
469// Builder Emission
470//===----------------------------------------------------------------------===//
471
472SmallVector<MethodParameter>
473DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const {
474 SmallVector<MethodParameter> builderParams;
475 builderParams.append(in_start: prefix.begin(), in_end: prefix.end());
476 for (auto &param : params)
477 builderParams.emplace_back(Args: param.getCppType(), Args: param.getName());
478 return builderParams;
479}
480
481void DefGen::emitDefaultBuilder() {
482 Method *m = defCls.addStaticMethod(
483 retType: def.getCppClassName(), name: "get",
484 args: getBuilderParams(prefix: {{"::mlir::MLIRContext *", "context"}}));
485 MethodBody &body = m->body().indent();
486 auto scope = body.scope(open: "return Base::get(context", close: ");");
487 for (const auto &param : params)
488 body << ", std::move(" << param.getName() << ")";
489}
490
491void DefGen::emitCheckedBuilder() {
492 Method *m = defCls.addStaticMethod(
493 retType: def.getCppClassName(), name: "getChecked",
494 args: getBuilderParams(
495 prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"},
496 {"::mlir::MLIRContext *", "context"}}));
497 MethodBody &body = m->body().indent();
498 auto scope = body.scope(open: "return Base::getChecked(emitError, context", close: ");");
499 for (const auto &param : params)
500 body << ", " << param.getName();
501}
502
503static SmallVector<MethodParameter>
504getCustomBuilderParams(std::initializer_list<MethodParameter> prefix,
505 const AttrOrTypeBuilder &builder) {
506 auto params = builder.getParameters();
507 SmallVector<MethodParameter> builderParams;
508 builderParams.append(in_start: prefix.begin(), in_end: prefix.end());
509 if (!builder.hasInferredContextParameter())
510 builderParams.emplace_back(Args: "::mlir::MLIRContext *", Args: "context");
511 for (auto &param : params) {
512 builderParams.emplace_back(Args: param.getCppType(), Args: *param.getName(),
513 Args: param.getDefaultValue());
514 }
515 return builderParams;
516}
517
518void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
519 // Don't emit a body if there isn't one.
520 auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
521 StringRef returnType = def.getCppClassName();
522 if (std::optional<StringRef> builderReturnType = builder.getReturnType())
523 returnType = *builderReturnType;
524 Method *m = defCls.addMethod(retType&: returnType, name: "get", properties: props,
525 args: getCustomBuilderParams(prefix: {}, builder));
526 if (!builder.getBody())
527 return;
528
529 // Format the body and emit it.
530 FmtContext ctx;
531 ctx.addSubst(placeholder: "_get", subst: "Base::get");
532 if (!builder.hasInferredContextParameter())
533 ctx.addSubst(placeholder: "_ctxt", subst: "context");
534 std::string bodyStr = tgfmt(fmt: *builder.getBody(), ctx: &ctx);
535 m->body().indent().getStream().printReindented(str: bodyStr);
536}
537
538/// Replace all instances of 'from' to 'to' in `str` and return the new string.
539static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
540 size_t pos = 0;
541 while ((pos = str.find(s: from.data(), pos: pos, n: from.size())) != std::string::npos)
542 str.replace(pos: pos, n1: from.size(), s: to.data(), n2: to.size());
543 return str;
544}
545
546void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
547 // Don't emit a body if there isn't one.
548 auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
549 StringRef returnType = def.getCppClassName();
550 if (std::optional<StringRef> builderReturnType = builder.getReturnType())
551 returnType = *builderReturnType;
552 Method *m = defCls.addMethod(
553 retType&: returnType, name: "getChecked", properties: props,
554 args: getCustomBuilderParams(
555 prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
556 builder));
557 if (!builder.getBody())
558 return;
559
560 // Format the body and emit it. Replace $_get(...) with
561 // Base::getChecked(emitError, ...)
562 FmtContext ctx;
563 if (!builder.hasInferredContextParameter())
564 ctx.addSubst(placeholder: "_ctxt", subst: "context");
565 std::string bodyStr = replaceInStr(str: builder.getBody()->str(), from: "$_get(",
566 to: "Base::getChecked(emitError, ");
567 bodyStr = tgfmt(fmt: bodyStr, ctx: &ctx);
568 m->body().indent().getStream().printReindented(str: bodyStr);
569}
570
571//===----------------------------------------------------------------------===//
572// Interface Method Emission
573//===----------------------------------------------------------------------===//
574
575void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
576 // Get the set of methods that should always be declared.
577 auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods();
578 StringSet<> alwaysDeclared;
579 alwaysDeclared.insert_range(R&: alwaysDeclaredMethods);
580
581 Interface iface = trait.getInterface(); // causes strange bugs if elided
582 for (auto &method : iface.getMethods()) {
583 // Don't declare if the method has a body. Or if the method has a default
584 // implementation and the def didn't request that it always be declared.
585 if (method.getBody() || (method.getDefaultImplementation() &&
586 !alwaysDeclared.count(Key: method.getName())))
587 continue;
588 emitTraitMethod(method);
589 }
590}
591
592void DefGen::emitTraitMethod(const InterfaceMethod &method) {
593 // All interface methods are declaration-only.
594 auto props =
595 method.isStatic() ? Method::StaticDeclaration : Method::ConstDeclaration;
596 SmallVector<MethodParameter> params;
597 for (auto &param : method.getArguments())
598 params.emplace_back(Args: param.type, Args: param.name);
599 defCls.addMethod(retType: method.getReturnType(), name: method.getName(), properties: props,
600 args: std::move(params));
601}
602
603//===----------------------------------------------------------------------===//
604// OpAsm{Type,Attr}Interface Default Method Emission
605
606void DefGen::emitMnemonicAliasMethod() {
607 // If the mnemonic is not set, there is nothing to do.
608 if (!def.getMnemonic())
609 return;
610
611 // Emit the mnemonic alias method.
612 SmallVector<MethodParameter> params{{"::llvm::raw_ostream &", "os"}};
613 Method *m = defCls.addMethod<Method::Const>(retType: "::mlir::OpAsmAliasResult",
614 name: "getAlias", args: std::move(params));
615 m->body().indent() << strfmt(fmt: "os << \"{0}\";\n", parameters: *def.getMnemonic())
616 << "return ::mlir::OpAsmAliasResult::OverridableAlias;\n";
617}
618
619//===----------------------------------------------------------------------===//
620// Storage Class Emission
621//===----------------------------------------------------------------------===//
622
623void DefGen::emitStorageConstructor() {
624 Constructor *ctor =
625 storageCls->addConstructor<Method::Inline>(args: getBuilderParams(prefix: {}));
626 for (auto &param : params) {
627 std::string movedValue = ("std::move(" + param.getName() + ")").str();
628 ctor->addMemberInitializer(name: param.getName(), value&: movedValue);
629 }
630}
631
632void DefGen::emitKeyType() {
633 std::string keyType("std::tuple<");
634 llvm::raw_string_ostream os(keyType);
635 llvm::interleaveComma(c: params, os,
636 each_fn: [&](auto &param) { os << param.getCppType(); });
637 os << '>';
638 storageCls->declare<UsingDeclaration>(args: "KeyTy", args: std::move(os.str()));
639
640 // Add a method to construct the key type from the storage.
641 Method *m = storageCls->addConstMethod<Method::Inline>(retType: "KeyTy", name: "getAsKey");
642 m->body().indent() << "return KeyTy(";
643 llvm::interleaveComma(c: params, os&: m->body().indent(),
644 each_fn: [&](auto &param) { m->body() << param.getName(); });
645 m->body() << ");";
646}
647
648void DefGen::emitEquals() {
649 Method *eq = storageCls->addConstMethod<Method::Inline>(
650 retType: "bool", name: "operator==", args: MethodParameter("const KeyTy &", "tblgenKey"));
651 auto &body = eq->body().indent();
652 auto scope = body.scope(open: "return (", close: ");");
653 const auto eachFn = [&](auto it) {
654 FmtContext ctx({{"_lhs", it.value().getName()},
655 {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
656 body << tgfmt(it.value().getComparator(), &ctx);
657 };
658 llvm::interleave(c: llvm::enumerate(First&: params), os&: body, each_fn: eachFn, separator: ") && (");
659}
660
661void DefGen::emitHashKey() {
662 Method *hash = storageCls->addStaticInlineMethod(
663 retType: "::llvm::hash_code", name: "hashKey",
664 args: MethodParameter("const KeyTy &", "tblgenKey"));
665 auto &body = hash->body().indent();
666 auto scope = body.scope(open: "return ::llvm::hash_combine(", close: ");");
667 llvm::interleaveComma(c: llvm::enumerate(First&: params), os&: body, each_fn: [&](auto it) {
668 body << llvm::formatv("std::get<{0}>(tblgenKey)", it.index());
669 });
670}
671
672void DefGen::emitConstruct() {
673 Method *construct = storageCls->addMethod<Method::Inline>(
674 retType: strfmt(fmt: "{0} *", parameters: def.getStorageClassName()), name: "construct",
675 properties: def.hasStorageCustomConstructor() ? Method::StaticDeclaration
676 : Method::Static,
677 args: MethodParameter(strfmt(fmt: "::mlir::{0}StorageAllocator &", parameters&: valueType),
678 "allocator"),
679 args: MethodParameter("KeyTy &&", "tblgenKey"));
680 if (!def.hasStorageCustomConstructor()) {
681 auto &body = construct->body().indent();
682 for (const auto &it : llvm::enumerate(First&: params)) {
683 body << formatv(Fmt: "auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
684 Vals: it.value().getName(), Vals: it.index());
685 }
686 // Use the parameters' custom allocator code, if provided.
687 FmtContext ctx = FmtContext().addSubst(placeholder: "_allocator", subst: "allocator");
688 for (auto &param : params) {
689 if (std::optional<StringRef> allocCode = param.getAllocator()) {
690 ctx.withSelf(subst: param.getName()).addSubst(placeholder: "_dst", subst: param.getName());
691 body << tgfmt(fmt: *allocCode, ctx: &ctx) << '\n';
692 }
693 }
694 auto scope =
695 body.scope(open: strfmt(fmt: "return new (allocator.allocate<{0}>()) {0}(",
696 parameters: def.getStorageClassName()),
697 close: ");");
698 llvm::interleaveComma(c: params, os&: body, each_fn: [&](auto &param) {
699 body << "std::move(" << param.getName() << ")";
700 });
701 }
702}
703
704void DefGen::emitStorageClass() {
705 // Add the appropriate parent class.
706 storageCls->addParent(parent: strfmt(fmt: "::mlir::{0}Storage", parameters&: valueType));
707 // Add the constructor.
708 emitStorageConstructor();
709 // Declare the key type.
710 emitKeyType();
711 // Add the comparison method.
712 emitEquals();
713 // Emit the key hash method.
714 emitHashKey();
715 // Emit the storage constructor. Just declare it if the user wants to define
716 // it themself.
717 emitConstruct();
718 // Emit the storage class members as public, at the very end of the struct.
719 storageCls->finalize();
720 for (auto &param : params) {
721 if (param.getCppType().contains(Other: "APInt") && !param.hasCustomComparator()) {
722 PrintFatalError(
723 ErrorLoc: def.getLoc(),
724 Msg: "Using a raw APInt parameter without a custom comparator is "
725 "not supported because an assert in the equality operator is "
726 "triggered when the two APInts have different bit widths. This can "
727 "lead to unexpected crashes. Use an `APIntParameter` or "
728 "provide a custom comparator.");
729 }
730 storageCls->declare<Field>(args: param.getCppType(), args: param.getName());
731 }
732}
733
734//===----------------------------------------------------------------------===//
735// DefGenerator
736//===----------------------------------------------------------------------===//
737
738namespace {
739/// This struct is the base generator used when processing tablegen interfaces.
740class DefGenerator {
741public:
742 bool emitDecls(StringRef selectedDialect);
743 bool emitDefs(StringRef selectedDialect);
744
745protected:
746 DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os,
747 StringRef defType, StringRef valueType, bool isAttrGenerator)
748 : defRecords(defs), os(os), defType(defType), valueType(valueType),
749 isAttrGenerator(isAttrGenerator) {
750 // Sort by occurrence in file.
751 llvm::sort(C&: defRecords, Comp: [](const Record *lhs, const Record *rhs) {
752 return lhs->getID() < rhs->getID();
753 });
754 }
755
756 /// Emit the list of def type names.
757 void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
758 /// Emit the code to dispatch between different defs during parsing/printing.
759 void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
760
761 /// The set of def records to emit.
762 std::vector<const Record *> defRecords;
763 /// The attribute or type class to emit.
764 /// The stream to emit to.
765 raw_ostream &os;
766 /// The prefix of the tablegen def name, e.g. Attr or Type.
767 StringRef defType;
768 /// The C++ base value type of the def, e.g. Attribute or Type.
769 StringRef valueType;
770 /// Flag indicating if this generator is for Attributes. False if the
771 /// generator is for types.
772 bool isAttrGenerator;
773};
774
775/// A specialized generator for AttrDefs.
776struct AttrDefGenerator : public DefGenerator {
777 AttrDefGenerator(const RecordKeeper &records, raw_ostream &os)
778 : DefGenerator(records.getAllDerivedDefinitionsIfDefined(ClassName: "AttrDef"), os,
779 "Attr", "Attribute", /*isAttrGenerator=*/true) {}
780};
781/// A specialized generator for TypeDefs.
782struct TypeDefGenerator : public DefGenerator {
783 TypeDefGenerator(const RecordKeeper &records, raw_ostream &os)
784 : DefGenerator(records.getAllDerivedDefinitionsIfDefined(ClassName: "TypeDef"), os,
785 "Type", "Type", /*isAttrGenerator=*/false) {}
786};
787} // namespace
788
789//===----------------------------------------------------------------------===//
790// GEN: Declarations
791//===----------------------------------------------------------------------===//
792
793/// Preamble that DefGenerator::emitDecls streams ahead of all the other
794/// declarations. It forward-declares the `mlir::AsmParser` and
/// `mlir::AsmPrinter` classes inside namespace `mlir`.
795static const char *const typeDefDeclHeader = R"(
796namespace mlir {
797class AsmParser;
798class AsmPrinter;
799} // namespace mlir
800)";
801
802bool DefGenerator::emitDecls(StringRef selectedDialect) {
803 emitSourceFileHeader(Desc: (defType + "Def Declarations").str(), OS&: os);
804 IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
805
806 // Output the common "header".
807 os << typeDefDeclHeader;
808
809 SmallVector<AttrOrTypeDef, 16> defs;
810 collectAllDefs(selectedDialect, records: defRecords, resultDefs&: defs);
811 if (defs.empty())
812 return false;
813 {
814 NamespaceEmitter nsEmitter(os, defs.front().getDialect());
815
816 // Declare all the def classes first (in case they reference each other).
817 for (const AttrOrTypeDef &def : defs) {
818 std::string comments = tblgen::emitSummaryAndDescComments(
819 summary: def.getSummary(), description: def.getDescription());
820 if (!comments.empty()) {
821 os << comments << "\n";
822 }
823 os << "class " << def.getCppClassName() << ";\n";
824 }
825
826 // Emit the declarations.
827 for (const AttrOrTypeDef &def : defs)
828 DefGen(def).emitDecl(os);
829 }
830 // Emit the TypeID explicit specializations to have a single definition for
831 // each of these.
832 for (const AttrOrTypeDef &def : defs)
833 if (!def.getDialect().getCppNamespace().empty())
834 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID("
835 << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
836 << ")\n";
837
838 return false;
839}
840
841//===----------------------------------------------------------------------===//
842// GEN: Def List
843//===----------------------------------------------------------------------===//
844
845void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
846 IfDefScope scope("GET_" + defType.upper() + "DEF_LIST", os);
847 auto interleaveFn = [&](const AttrOrTypeDef &def) {
848 os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName();
849 };
850 llvm::interleave(c: defs, os, each_fn: interleaveFn, separator: ",\n");
851 os << "\n";
852}
853
854//===----------------------------------------------------------------------===//
855// GEN: Definitions
856//===----------------------------------------------------------------------===//
857
858/// The llvm::formatv template for the default attribute parser/printer
/// dispatch boilerplate, i.e. the dialect's `parseAttribute` and
/// `printAttribute` member definitions. Literal `{` characters that must
/// survive formatting are written as `{{`.
859/// {0}: the dialect C++ class name (Dialect::getCppClassName()).
860/// {1}: the optional code for the dynamic attribute parser dispatch.
861/// {2}: the optional code for the dynamic attribute printer dispatch.
862static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
863/// Parse an attribute registered to this dialect.
864::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
865 ::mlir::Type type) const {{
866 ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
867 ::llvm::StringRef attrTag;
868 {{
869 ::mlir::Attribute attr;
870 auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
871 if (parseResult.has_value())
872 return attr;
873 }
874 {1}
875 parser.emitError(typeLoc) << "unknown attribute `"
876 << attrTag << "` in dialect `" << getNamespace() << "`";
877 return {{};
878}
879/// Print an attribute registered to this dialect.
880void {0}::printAttribute(::mlir::Attribute attr,
881 ::mlir::DialectAsmPrinter &printer) const {{
882 if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
883 return;
884 {2}
885}
886)";
887
/// The code block for dynamic attribute parser dispatch boilerplate. It is
/// passed as the `{1}` argument of `dialectDefaultAttrPrinterParserDispatch`
/// when the dialect is extensible, so it is plain text and its braces are not
/// escaped.
///
/// Every name in the snippet that refers to an MLIR entity is spelled with a
/// leading `::mlir::`, including the failure return value: the snippet is
/// emitted inside the dialect's own namespace, where an unqualified
/// `Attribute` is not guaranteed to resolve. This matches the type snippet,
/// which returns `::mlir::Type()`.
static const char *const dialectDynamicAttrParserDispatch = R"(
  {
    ::mlir::Attribute genAttr;
    auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr);
    if (parseResult.has_value()) {
      if (::mlir::succeeded(parseResult.value()))
        return genAttr;
      return ::mlir::Attribute();
    }
  }
)";
900
901/// The code block for dynamic attribute printer dispatch boilerplate. It is
/// passed as the `{2}` argument of `dialectDefaultAttrPrinterParserDispatch`
/// when the dialect is extensible.
902static const char *const dialectDynamicAttrPrinterDispatch = R"(
903 if (::mlir::succeeded(printIfDynamicAttr(attr, printer)))
904 return;
905)";
906
907/// The llvm::formatv template for the default type parser/printer dispatch
/// boilerplate, i.e. the dialect's `parseType` and `printType` member
/// definitions. Literal `{` characters that must survive formatting are
/// written as `{{`.
908/// {0}: the dialect C++ class name (Dialect::getCppClassName()).
909/// {1}: the optional code for the dynamic type parser dispatch.
910/// {2}: the optional code for the dynamic type printer dispatch.
911static const char *const dialectDefaultTypePrinterParserDispatch = R"(
912/// Parse a type registered to this dialect.
913::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
914 ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
915 ::llvm::StringRef mnemonic;
916 ::mlir::Type genType;
917 auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
918 if (parseResult.has_value())
919 return genType;
920 {1}
921 parser.emitError(typeLoc) << "unknown type `"
922 << mnemonic << "` in dialect `" << getNamespace() << "`";
923 return {{};
924}
925/// Print a type registered to this dialect.
926void {0}::printType(::mlir::Type type,
927 ::mlir::DialectAsmPrinter &printer) const {{
928 if (::mlir::succeeded(generatedTypePrinter(type, printer)))
929 return;
930 {2}
931}
932)";
933
934/// The code block for dynamic type parser dispatch boilerplate. It is passed
/// as the `{1}` argument of `dialectDefaultTypePrinterParserDispatch` when the
/// dialect is extensible, and reuses the `mnemonic`, `parser` and `genType`
/// names declared by that template.
935static const char *const dialectDynamicTypeParserDispatch = R"(
936 {
937 auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
938 if (parseResult.has_value()) {
939 if (::mlir::succeeded(parseResult.value()))
940 return genType;
941 return ::mlir::Type();
942 }
943 }
944)";
945
946/// The code block for dynamic type printer dispatch boilerplate. It is passed
/// as the `{2}` argument of `dialectDefaultTypePrinterParserDispatch` when the
/// dialect is extensible.
947static const char *const dialectDynamicTypePrinterDispatch = R"(
948 if (::mlir::succeeded(printIfDynamicType(type, printer)))
949 return;
950)";
951
952/// Emit the dialect printer/parser dispatcher. User's code should call these
953/// functions from their dialect's print/parse methods.
///
/// Two static inline functions are written to `os`: `generated<ValueType>Parser`
/// and `generated<ValueType>Printer`, where `<ValueType>` is `valueType`. Only
/// defs that have a mnemonic take part in the dispatch; if no def in `defs` has
/// one, nothing is emitted.
954void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
 // Without any mnemonic there is nothing to dispatch on.
955 if (llvm::none_of(Range&: defs, P: [](const AttrOrTypeDef &def) {
956 return def.getMnemonic().has_value();
957 })) {
958 return;
959 }
960 // Declare the parser. The attribute flavor takes an extra `type` parameter.
961 SmallVector<MethodParameter> params = {{"::mlir::AsmParser &", "parser"},
962 {"::llvm::StringRef *", "mnemonic"}};
963 if (isAttrGenerator)
964 params.emplace_back(Args: "::mlir::Type", Args: "type");
965 params.emplace_back(Args: strfmt(fmt: "::mlir::{0} &", parameters&: valueType), Args: "value");
966 Method parse("::mlir::OptionalParseResult",
967 strfmt(fmt: "generated{0}Parser", parameters&: valueType), Method::StaticInline,
968 std::move(params));
969 // Declare the printer.
970 Method printer("::llvm::LogicalResult",
971 strfmt(fmt: "generated{0}Printer", parameters&: valueType), Method::StaticInline,
972 {{strfmt(fmt: "::mlir::{0}", parameters&: valueType), "def"},
973 {"::mlir::AsmPrinter &", "printer"}});
974
975 // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and
976 // calling the def's parse function.
977 parse.body() << " return "
978 "::mlir::AsmParser::KeywordSwitch<::mlir::"
979 "OptionalParseResult>(parser)\n";
 // Per-def parser case. {0}: the qualified def class, {1}: the expression
 // that produces the value (`parse(...)` or `get(...)`).
 // NOTE(review): the emitted lambda spells `llvm::StringRef`/`llvm::SMLoc`
 // without the leading `::` used by the other emitted names - confirm this
 // is intentional.
980 const char *const getValueForMnemonic =
981 R"( .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{
982 value = {0}::{1};
983 return ::mlir::success(!!value);
984 })
985)";
986
987 // The printer dispatch uses llvm::TypeSwitch to find and call the correct
988 // printer.
989 printer.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType
990 << ", ::llvm::LogicalResult>(def)";
 // Per-def printer case. {0}: the qualified def class, {1}: the optional
 // extra statement that prints the def after its mnemonic.
991 const char *const printValue = R"( .Case<{0}>([&](auto t) {{
992 printer << {0}::getMnemonic();{1}
993 return ::mlir::success();
994 })
995)";
996 for (auto &def : defs) {
 // Defs without a mnemonic cannot be dispatched to.
997 if (!def.getMnemonic())
998 continue;
 // True if the def has a custom assembly format or an assembly format.
999 bool hasParserPrinterDecl =
1000 def.hasCustomAssemblyFormat() || def.getAssemblyFormat();
1001 std::string defClass = strfmt(
1002 fmt: "{0}::{1}", parameters: def.getDialect().getCppNamespace(), parameters: def.getCppClassName());
1003
1004 // If the def has no parameters or parser code, invoke a normal `get`.
1005 std::string parseOrGet =
1006 hasParserPrinterDecl
1007 ? strfmt(fmt: "parse(parser{0})", parameters: isAttrGenerator ? ", type" : "")
1008 : "get(parser.getContext())";
1009 parse.body() << llvm::formatv(Fmt: getValueForMnemonic, Vals&: defClass, Vals&: parseOrGet);
1010
1011 // If the def has no parameters and no printer, just print the mnemonic.
1012 StringRef printDef = "";
1013 if (hasParserPrinterDecl)
1014 printDef = "\nt.print(printer);";
1015 printer.body() << llvm::formatv(Fmt: printValue, Vals&: defClass, Vals&: printDef);
1016 }
 // Unknown keyword: hand the keyword back through `mnemonic` and return
 // std::nullopt from the parser; the printer's default case reports failure.
1017 parse.body() << " .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n"
1018 " *mnemonic = keyword;\n"
1019 " return std::nullopt;\n"
1020 " });";
1021 printer.body() << " .Default([](auto) { return ::mlir::failure(); });";
1022
 // Write both functions to the output stream through an indenting wrapper.
1023 raw_indented_ostream indentedOs(os);
1024 parse.writeDeclTo(os&: indentedOs);
1025 printer.writeDeclTo(os&: indentedOs);
1026}
1027
1028bool DefGenerator::emitDefs(StringRef selectedDialect) {
1029 emitSourceFileHeader(Desc: (defType + "Def Definitions").str(), OS&: os);
1030
1031 SmallVector<AttrOrTypeDef, 16> defs;
1032 collectAllDefs(selectedDialect, records: defRecords, resultDefs&: defs);
1033 if (defs.empty())
1034 return false;
1035 emitTypeDefList(defs);
1036
1037 IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
1038 emitParsePrintDispatch(defs);
1039 for (const AttrOrTypeDef &def : defs) {
1040 {
1041 NamespaceEmitter ns(os, def.getDialect());
1042 DefGen gen(def);
1043 gen.emitDef(os);
1044 }
1045 // Emit the TypeID explicit specializations to have a single symbol def.
1046 if (!def.getDialect().getCppNamespace().empty())
1047 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID("
1048 << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
1049 << ")\n";
1050 }
1051
1052 Dialect firstDialect = defs.front().getDialect();
1053
1054 // Emit the default parser/printer for Attributes if the dialect asked for it.
1055 if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser()) {
1056 NamespaceEmitter nsEmitter(os, firstDialect);
1057 if (firstDialect.isExtensible()) {
1058 os << llvm::formatv(Fmt: dialectDefaultAttrPrinterParserDispatch,
1059 Vals: firstDialect.getCppClassName(),
1060 Vals: dialectDynamicAttrParserDispatch,
1061 Vals: dialectDynamicAttrPrinterDispatch);
1062 } else {
1063 os << llvm::formatv(Fmt: dialectDefaultAttrPrinterParserDispatch,
1064 Vals: firstDialect.getCppClassName(), Vals: "", Vals: "");
1065 }
1066 }
1067
1068 // Emit the default parser/printer for Types if the dialect asked for it.
1069 if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser()) {
1070 NamespaceEmitter nsEmitter(os, firstDialect);
1071 if (firstDialect.isExtensible()) {
1072 os << llvm::formatv(Fmt: dialectDefaultTypePrinterParserDispatch,
1073 Vals: firstDialect.getCppClassName(),
1074 Vals: dialectDynamicTypeParserDispatch,
1075 Vals: dialectDynamicTypePrinterDispatch);
1076 } else {
1077 os << llvm::formatv(Fmt: dialectDefaultTypePrinterParserDispatch,
1078 Vals: firstDialect.getCppClassName(), Vals: "", Vals: "");
1079 }
1080 }
1081
1082 return false;
1083}
1084
1085//===----------------------------------------------------------------------===//
1086// Type Constraints
1087//===----------------------------------------------------------------------===//
1088
1089/// Find all type constraints for which a C++ function should be generated.
1090static std::vector<Constraint>
1091getAllTypeConstraints(const RecordKeeper &records) {
1092 std::vector<Constraint> result;
1093 for (const Record *def :
1094 records.getAllDerivedDefinitionsIfDefined(ClassName: "TypeConstraint")) {
1095 // Ignore constraints defined outside of the top-level file.
1096 if (llvm::SrcMgr.FindBufferContainingLoc(Loc: def->getLoc()[0]) !=
1097 llvm::SrcMgr.getMainFileID())
1098 continue;
1099 Constraint constr(def);
1100 // Generate C++ function only if "cppFunctionName" is set.
1101 if (!constr.getCppFunctionName())
1102 continue;
1103 result.push_back(x: constr);
1104 }
1105 return result;
1106}
1107
1108static void emitTypeConstraintDecls(const RecordKeeper &records,
1109 raw_ostream &os) {
1110 static const char *const typeConstraintDecl = R"(
1111bool {0}(::mlir::Type type);
1112)";
1113
1114 for (Constraint constr : getAllTypeConstraints(records))
1115 os << strfmt(fmt: typeConstraintDecl, parameters: *constr.getCppFunctionName());
1116}
1117
/// Emit the definition of one `bool <name>(::mlir::Type type)` function for
/// every type constraint returned by getAllTypeConstraints. The function body
/// returns the constraint's condition template, instantiated with `type`
/// substituted for the self placeholder.
1118static void emitTypeConstraintDefs(const RecordKeeper &records,
1119 raw_ostream &os) {
 // {0}: the C++ function name, {1}: the instantiated condition.
1120 static const char *const typeConstraintDef = R"(
1121bool {0}(::mlir::Type type) {
1122 return ({1});
1123}
1124)";
1125
1126 for (Constraint constr : getAllTypeConstraints(records)) {
 // The generated function's parameter is named `type`, so that is what the
 // condition has to refer to.
1127 FmtContext ctx;
1128 ctx.withSelf(subst: "type");
1129 std::string condition = tgfmt(fmt: constr.getConditionTemplate(), ctx: &ctx);
1130 os << strfmt(fmt: typeConstraintDef, parameters: *constr.getCppFunctionName(), parameters&: condition);
1131 }
1132}
1133
1134//===----------------------------------------------------------------------===//
1135// GEN: Registration hooks
1136//===----------------------------------------------------------------------===//
1137
1138//===----------------------------------------------------------------------===//
1139// AttrDef
1140//===----------------------------------------------------------------------===//
1141
/// Command-line category holding the options of the AttrDef generators.
1142static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*");
/// `--attrdefs-dialect`: the dialect whose attributes are generated. The value
/// is passed as `selectedDialect` to the AttrDef generators registered below.
/// NOTE(review): `cl::CommaSeparated` is normally paired with `cl::list`;
/// confirm it is intended on this single-valued option.
1143static llvm::cl::opt<std::string>
1144 attrDialect("attrdefs-dialect",
1145 llvm::cl::desc("Generate attributes for this dialect"),
1146 llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated);
1147
/// Registers the `gen-attrdef-defs` generator, which runs
/// AttrDefGenerator::emitDefs for the dialect selected by `attrDialect`.
1148static mlir::GenRegistration
1149 genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
1150 [](const RecordKeeper &records, raw_ostream &os) {
1151 AttrDefGenerator generator(records, os);
1152 return generator.emitDefs(selectedDialect: attrDialect);
1153 });
/// Registers the `gen-attrdef-decls` generator, which runs
/// AttrDefGenerator::emitDecls for the dialect selected by `attrDialect`.
1154static mlir::GenRegistration
1155 genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
1156 [](const RecordKeeper &records, raw_ostream &os) {
1157 AttrDefGenerator generator(records, os);
1158 return generator.emitDecls(selectedDialect: attrDialect);
1159 });
1160
1161//===----------------------------------------------------------------------===//
1162// TypeDef
1163//===----------------------------------------------------------------------===//
1164
/// Command-line category holding the options of the TypeDef generators.
1165static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
/// `--typedefs-dialect`: the dialect whose types are generated. The value is
/// passed as `selectedDialect` to the TypeDef generators registered below.
/// NOTE(review): `cl::CommaSeparated` is normally paired with `cl::list`;
/// confirm it is intended on this single-valued option.
1166static llvm::cl::opt<std::string>
1167 typeDialect("typedefs-dialect",
1168 llvm::cl::desc("Generate types for this dialect"),
1169 llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
1170
/// Registers the `gen-typedef-defs` generator, which runs
/// TypeDefGenerator::emitDefs for the dialect selected by `typeDialect`.
1171static mlir::GenRegistration
1172 genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
1173 [](const RecordKeeper &records, raw_ostream &os) {
1174 TypeDefGenerator generator(records, os);
1175 return generator.emitDefs(selectedDialect: typeDialect);
1176 });
/// Registers the `gen-typedef-decls` generator, which runs
/// TypeDefGenerator::emitDecls for the dialect selected by `typeDialect`.
1177static mlir::GenRegistration
1178 genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
1179 [](const RecordKeeper &records, raw_ostream &os) {
1180 TypeDefGenerator generator(records, os);
1181 return generator.emitDecls(selectedDialect: typeDialect);
1182 });
1183
/// Registers the `gen-type-constraint-defs` generator, which emits the type
/// constraint function definitions and always returns false.
1184static mlir::GenRegistration
1185 genTypeConstrDefs("gen-type-constraint-defs",
1186 "Generate type constraint definitions",
1187 [](const RecordKeeper &records, raw_ostream &os) {
1188 emitTypeConstraintDefs(records, os);
1189 return false;
1190 });
/// Registers the `gen-type-constraint-decls` generator, which emits the type
/// constraint function declarations and always returns false.
1191static mlir::GenRegistration
1192 genTypeConstrDecls("gen-type-constraint-decls",
1193 "Generate type constraint declarations",
1194 [](const RecordKeeper &records, raw_ostream &os) {
1195 emitTypeConstraintDecls(records, os);
1196 return false;
1197 });
1198

source code of mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp