1//===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "AttrOrTypeFormatGen.h"
10#include "CppGenUtilities.h"
11#include "mlir/TableGen/AttrOrTypeDef.h"
12#include "mlir/TableGen/Class.h"
13#include "mlir/TableGen/CodeGenHelpers.h"
14#include "mlir/TableGen/Format.h"
15#include "mlir/TableGen/GenInfo.h"
16#include "mlir/TableGen/Interfaces.h"
17#include "llvm/ADT/StringSet.h"
18#include "llvm/Support/CommandLine.h"
19#include "llvm/TableGen/Error.h"
20#include "llvm/TableGen/TableGenBackend.h"
21
22#define DEBUG_TYPE "mlir-tblgen-attrortypedefgen"
23
24using namespace mlir;
25using namespace mlir::tblgen;
26using llvm::Record;
27using llvm::RecordKeeper;
28
29//===----------------------------------------------------------------------===//
30// Utility Functions
31//===----------------------------------------------------------------------===//
32
33/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
34/// specified and can only find one dialect's defs, use that.
35static void collectAllDefs(StringRef selectedDialect,
36 ArrayRef<const Record *> records,
37 SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
38 // Nothing to do if no defs were found.
39 if (records.empty())
40 return;
41
42 auto defs = llvm::map_range(
43 C&: records, F: [&](const Record *rec) { return AttrOrTypeDef(rec); });
44 if (selectedDialect.empty()) {
45 // If a dialect was not specified, ensure that all found defs belong to the
46 // same dialect.
47 if (!llvm::all_equal(Range: llvm::map_range(
48 C&: defs, F: [](const auto &def) { return def.getDialect(); }))) {
49 llvm::PrintFatalError(Msg: "defs belonging to more than one dialect. Must "
50 "select one via '--(attr|type)defs-dialect'");
51 }
52 resultDefs.assign(in_start: defs.begin(), in_end: defs.end());
53 } else {
54 // Otherwise, generate the defs that belong to the selected dialect.
55 auto dialectDefs = llvm::make_filter_range(Range&: defs, Pred: [&](const auto &def) {
56 return def.getDialect().getName() == selectedDialect;
57 });
58 resultDefs.assign(in_start: dialectDefs.begin(), in_end: dialectDefs.end());
59 }
60}
61
62//===----------------------------------------------------------------------===//
63// DefGen
64//===----------------------------------------------------------------------===//
65
namespace {
/// Builds the C++ class for one attribute or type definition and, when the
/// def has parameters, its companion storage class. All of the generation
/// work happens in the constructor; `emitDecl` and `emitDef` only write the
/// already-built classes to a stream.
class DefGen {
public:
  /// Create the attribute or type class (and the storage class, if any) for
  /// `def`. Reports a fatal error if any parameter is anonymous.
  DefGen(const AttrOrTypeDef &def);

  /// Write the declarations. When a storage class is generated, a forward
  /// declaration of the storage struct is emitted first, inside the storage
  /// namespace. The declaration of the attribute or type class follows.
  void emitDecl(raw_ostream &os) const {
    if (storageCls && def.genStorageClass()) {
      NamespaceEmitter ns(os, def.getStorageNamespace());
      os << "struct " << def.getStorageClassName() << ";\n";
    }
    defCls.writeDeclTo(os);
  }
  /// Write the definitions. When a storage class is generated, the complete
  /// storage struct is emitted first, inside the storage namespace. The
  /// definitions of the attribute or type class follow.
  void emitDef(raw_ostream &os) const {
    if (storageCls && def.genStorageClass()) {
      NamespaceEmitter ns(os, def.getStorageNamespace());
      storageCls->writeDeclTo(os); // everything is inline
    }
    defCls.writeDefTo(os);
  }

private:
  /// Add traits from the TableGen definition to the class.
  void createParentWithTraits();
  /// Emit top-level declarations: using declarations and any extra class
  /// declarations.
  void emitTopLevelDeclarations();
  /// Emit the function that returns the type or attribute name.
  void emitName();
  /// Emit the dialect name as a static member variable.
  void emitDialectName();
  /// Emit attribute or type builders.
  void emitBuilders();
  /// Emit a verifier declaration for custom verification (impl. provided by
  /// the users).
  void emitVerifierDecl();
  /// Emit a verifier that checks type constraints.
  void emitInvariantsVerifierImpl();
  /// Emit an entry point for verification that calls the invariants and
  /// custom verifier.
  void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier);
  /// Emit parsers and printers.
  void emitParserPrinter();
  /// Emit parameter accessors, if required.
  void emitAccessors();
  /// Emit interface methods.
  void emitInterfaceMethods();

  //===--------------------------------------------------------------------===//
  // Builder Emission

  /// Emit the default builder `Attribute::get`
  void emitDefaultBuilder();
  /// Emit the checked builder `Attribute::getChecked`
  void emitCheckedBuilder();
  /// Emit a custom builder.
  void emitCustomBuilder(const AttrOrTypeBuilder &builder);
  /// Emit a checked custom builder.
  void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);

  //===--------------------------------------------------------------------===//
  // Interface Method Emission

  /// Emit methods for a trait.
  void emitTraitMethods(const InterfaceTrait &trait);
  /// Emit a trait method.
  void emitTraitMethod(const InterfaceMethod &method);

  //===--------------------------------------------------------------------===//
  // OpAsm{Type,Attr}Interface Default Method Emission

  /// Emit 'getAlias' method using mnemonic as alias.
  void emitMnemonicAliasMethod();

  //===--------------------------------------------------------------------===//
  // Storage Class Emission

  /// Populate the storage class: parent, constructor, key type, comparison,
  /// hashing, construction function and fields.
  void emitStorageClass();
  /// Generate the storage class constructor.
  void emitStorageConstructor();
  /// Emit the key type `KeyTy`.
  void emitKeyType();
  /// Emit the equality comparison operator.
  void emitEquals();
  /// Emit the key hash function.
  void emitHashKey();
  /// Emit the function to construct the storage class.
  void emitConstruct();

  //===--------------------------------------------------------------------===//
  // Utility Function Declarations

  /// Get the method parameters for a def builder: `prefix` followed by one
  /// parameter per def parameter (its C++ type and name).
  SmallVector<MethodParameter>
  getBuilderParams(std::initializer_list<MethodParameter> prefix) const;

  //===--------------------------------------------------------------------===//
  // Class fields
  // NOTE: members are initialized in declaration order; the constructor's
  // initializer list follows this order.

  /// The attribute or type definition.
  const AttrOrTypeDef &def;
  /// The list of attribute or type parameters.
  ArrayRef<AttrOrTypeParameter> params;
  /// The attribute or type class.
  Class defCls;
  /// An optional attribute or type storage class. The storage class will
  /// exist if and only if the def has more than zero parameters.
  std::optional<Class> storageCls;

  /// The C++ base value of the def, either "Attribute" or "Type".
  StringRef valueType;
  /// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
  StringRef defType;
};
} // namespace
181
182DefGen::DefGen(const AttrOrTypeDef &def)
183 : def(def), params(def.getParameters()), defCls(def.getCppClassName()),
184 valueType(isa<AttrDef>(Val: def) ? "Attribute" : "Type"),
185 defType(isa<AttrDef>(Val: def) ? "Attr" : "Type") {
186 // Check that all parameters have names.
187 for (const AttrOrTypeParameter &param : def.getParameters())
188 if (param.isAnonymous())
189 llvm::PrintFatalError(Msg: "all parameters must have a name");
190
191 // If a storage class is needed, create one.
192 if (def.getNumParameters() > 0)
193 storageCls.emplace(args: def.getStorageClassName(), /*isStruct=*/args: true);
194
195 // Create the parent class with any indicated traits.
196 createParentWithTraits();
197 // Emit top-level declarations.
198 emitTopLevelDeclarations();
199 // Emit builders for defs with parameters
200 if (storageCls)
201 emitBuilders();
202 // Emit the type name.
203 emitName();
204 // Emit the dialect name.
205 emitDialectName();
206 // Emit verification of type constraints.
207 bool genVerifyInvariantsImpl = def.genVerifyInvariantsImpl();
208 if (storageCls && genVerifyInvariantsImpl)
209 emitInvariantsVerifierImpl();
210 // Emit the custom verifier (written by the user).
211 bool genVerifyDecl = def.genVerifyDecl();
212 if (storageCls && genVerifyDecl)
213 emitVerifierDecl();
214 // Emit the "verifyInvariants" function if there is any verification at all.
215 if (storageCls)
216 emitInvariantsVerifier(hasImpl: genVerifyInvariantsImpl, hasCustomVerifier: genVerifyDecl);
217 // Emit the mnemonic, if there is one, and any associated parser and printer.
218 if (def.getMnemonic())
219 emitParserPrinter();
220 // Emit accessors
221 if (def.genAccessors())
222 emitAccessors();
223 // Emit trait interface methods
224 emitInterfaceMethods();
225 // Emit OpAsm{Type,Attr}Interface default methods
226 if (def.genMnemonicAlias())
227 emitMnemonicAliasMethod();
228 defCls.finalize();
229 // Emit a storage class if one is needed
230 if (storageCls && def.genStorageClass())
231 emitStorageClass();
232}
233
234void DefGen::createParentWithTraits() {
235 ParentClass defParent(strfmt(fmt: "::mlir::{0}::{1}Base", parameters&: valueType, parameters&: defType));
236 defParent.addTemplateParam(param: def.getCppClassName());
237 defParent.addTemplateParam(param: def.getCppBaseClassName());
238 defParent.addTemplateParam(param: storageCls
239 ? strfmt(fmt: "{0}::{1}", parameters: def.getStorageNamespace(),
240 parameters: def.getStorageClassName())
241 : strfmt(fmt: "::mlir::{0}Storage", parameters&: valueType));
242 SmallVector<std::string> traitNames =
243 llvm::to_vector(Range: llvm::map_range(C: def.getTraits(), F: [](auto &trait) {
244 return isa<NativeTrait>(&trait)
245 ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
246 : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName();
247 }));
248 for (auto &traitName : traitNames)
249 defParent.addTemplateParam(param: traitName);
250
251 // Add OpAsmInterface::Trait if we automatically generate mnemonic alias
252 // method.
253 std::string opAsmInterfaceTraitName =
254 strfmt(fmt: "::mlir::OpAsm{0}Interface::Trait", parameters&: defType);
255 if (def.genMnemonicAlias() &&
256 !llvm::is_contained(Range&: traitNames, Element: opAsmInterfaceTraitName)) {
257 defParent.addTemplateParam(param: opAsmInterfaceTraitName);
258 }
259 defCls.addParent(parent: std::move(defParent));
260}
261
262/// Include declarations specified on NativeTrait
263static std::string formatExtraDeclarations(const AttrOrTypeDef &def) {
264 SmallVector<StringRef> extraDeclarations;
265 // Include extra class declarations from NativeTrait
266 for (const auto &trait : def.getTraits()) {
267 if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
268 StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration();
269 if (value.empty())
270 continue;
271 extraDeclarations.push_back(Elt: value);
272 }
273 }
274 if (std::optional<StringRef> extraDecl = def.getExtraDecls()) {
275 extraDeclarations.push_back(Elt: *extraDecl);
276 }
277 return llvm::join(R&: extraDeclarations, Separator: "\n");
278}
279
280/// Extra class definitions have a `$cppClass` substitution that is to be
281/// replaced by the C++ class name.
282static std::string formatExtraDefinitions(const AttrOrTypeDef &def) {
283 SmallVector<StringRef> extraDefinitions;
284 // Include extra class definitions from NativeTrait
285 for (const auto &trait : def.getTraits()) {
286 if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
287 StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition();
288 if (value.empty())
289 continue;
290 extraDefinitions.push_back(Elt: value);
291 }
292 }
293 if (std::optional<StringRef> extraDef = def.getExtraDefs()) {
294 extraDefinitions.push_back(Elt: *extraDef);
295 }
296 FmtContext ctx = FmtContext().addSubst(placeholder: "cppClass", subst: def.getCppClassName());
297 return tgfmt(fmt: llvm::join(R&: extraDefinitions, Separator: "\n"), ctx: &ctx).str();
298}
299
300void DefGen::emitTopLevelDeclarations() {
301 // Inherit constructors from the attribute or type class.
302 defCls.declare<VisibilityDeclaration>(args: Visibility::Public);
303 defCls.declare<UsingDeclaration>(args: "Base::Base");
304
305 // Emit the extra declarations first in case there's a definition in there.
306 std::string extraDecl = formatExtraDeclarations(def);
307 std::string extraDef = formatExtraDefinitions(def);
308 defCls.declare<ExtraClassDeclaration>(args: std::move(extraDecl),
309 args: std::move(extraDef));
310}
311
312void DefGen::emitName() {
313 StringRef name;
314 if (auto *attrDef = dyn_cast<AttrDef>(Val: &def)) {
315 name = attrDef->getAttrName();
316 } else {
317 auto *typeDef = cast<TypeDef>(Val: &def);
318 name = typeDef->getTypeName();
319 }
320 std::string nameDecl =
321 strfmt(fmt: "static constexpr ::llvm::StringLiteral name = \"{0}\";\n", parameters&: name);
322 defCls.declare<ExtraClassDeclaration>(args: std::move(nameDecl));
323}
324
325void DefGen::emitDialectName() {
326 std::string decl =
327 strfmt(fmt: "static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n",
328 parameters: def.getDialect().getName());
329 defCls.declare<ExtraClassDeclaration>(args: std::move(decl));
330}
331
332void DefGen::emitBuilders() {
333 if (!def.skipDefaultBuilders()) {
334 emitDefaultBuilder();
335 if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
336 emitCheckedBuilder();
337 }
338 for (auto &builder : def.getBuilders()) {
339 emitCustomBuilder(builder);
340 if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
341 emitCheckedCustomBuilder(builder);
342 }
343}
344
345void DefGen::emitVerifierDecl() {
346 defCls.declareStaticMethod(
347 retType: "::llvm::LogicalResult", name: "verify",
348 args: getBuilderParams(prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
349 "emitError"}}));
350}
351
/// Snippet appended to the body of `verifyInvariantsImpl` for each parameter
/// that has a constraint. It is expanded with `formatv` using:
///   {0} - the constraint's condition, with `$_self` bound to the parameter,
///   {1} - the parameter name,
///   {2} - the constraint summary.
static const char *const patternParameterVerificationCode = R"(
if (!({0})) {
 emitError() << "failed to verify '{1}': {2}";
 return ::mlir::failure();
}
)";
358
359void DefGen::emitInvariantsVerifierImpl() {
360 SmallVector<MethodParameter> builderParams = getBuilderParams(
361 prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
362 Method *verifier =
363 defCls.addMethod(retType: "::llvm::LogicalResult", name: "verifyInvariantsImpl",
364 properties: Method::Static, args&: builderParams);
365 verifier->body().indent();
366
367 // Generate verification for each parameter that is a type constraint.
368 for (auto it : llvm::enumerate(First: def.getParameters())) {
369 const AttrOrTypeParameter &param = it.value();
370 std::optional<Constraint> constraint = param.getConstraint();
371 // No verification needed for parameters that are not type constraints.
372 if (!constraint.has_value())
373 continue;
374 FmtContext ctx;
375 // Note: Skip over the first method parameter (`emitError`).
376 ctx.withSelf(subst: builderParams[it.index() + 1].getName());
377 std::string condition = tgfmt(fmt: constraint->getConditionTemplate(), ctx: &ctx);
378 verifier->body() << formatv(Fmt: patternParameterVerificationCode, Vals&: condition,
379 Vals: param.getName(), Vals: constraint->getSummary())
380 << "\n";
381 }
382 verifier->body() << "return ::mlir::success();";
383}
384
385void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) {
386 if (!hasImpl && !hasCustomVerifier)
387 return;
388 defCls.declare<UsingDeclaration>(args: "Base::getChecked");
389 SmallVector<MethodParameter> builderParams = getBuilderParams(
390 prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
391 Method *verifier =
392 defCls.addMethod(retType: "::llvm::LogicalResult", name: "verifyInvariants",
393 properties: Method::Static, args&: builderParams);
394 verifier->body().indent();
395
396 auto emitVerifierCall = [&](StringRef name) {
397 verifier->body() << strfmt(fmt: "if (::mlir::failed({0}(", parameters&: name);
398 llvm::interleaveComma(
399 c: llvm::map_range(C&: builderParams,
400 F: [](auto &param) { return param.getName(); }),
401 os&: verifier->body());
402 verifier->body() << ")))\n";
403 verifier->body() << " return ::mlir::failure();\n";
404 };
405
406 if (hasImpl) {
407 // Call the verifier that checks the type constraints.
408 emitVerifierCall("verifyInvariantsImpl");
409 }
410 if (hasCustomVerifier) {
411 // Call the custom verifier that is provided by the user.
412 emitVerifierCall("verify");
413 }
414 verifier->body() << "return ::mlir::success();";
415}
416
417void DefGen::emitParserPrinter() {
418 auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
419 retType: "::llvm::StringLiteral", name: "getMnemonic");
420 mnemonic->body().indent() << strfmt(fmt: "return {\"{0}\"};", parameters: *def.getMnemonic());
421
422 // Declare the parser and printer, if needed.
423 bool hasAssemblyFormat = def.getAssemblyFormat().has_value();
424 if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat)
425 return;
426
427 // Declare the parser.
428 SmallVector<MethodParameter> parserParams;
429 parserParams.emplace_back(Args: "::mlir::AsmParser &", Args: "odsParser");
430 if (isa<AttrDef>(Val: &def))
431 parserParams.emplace_back(Args: "::mlir::Type", Args: "odsType");
432 auto *parser = defCls.addMethod(retType: strfmt(fmt: "::mlir::{0}", parameters&: valueType), name: "parse",
433 properties: hasAssemblyFormat ? Method::Static
434 : Method::StaticDeclaration,
435 args: std::move(parserParams));
436 // Declare the printer.
437 auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration;
438 Method *printer =
439 defCls.addMethod(retType: "void", name: "print", properties: props,
440 args: MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
441 // Emit the bodies if we are using the declarative format.
442 if (hasAssemblyFormat)
443 return generateAttrOrTypeFormat(def, parser&: parser->body(), printer&: printer->body());
444}
445
446void DefGen::emitAccessors() {
447 for (auto &param : params) {
448 Method *m = defCls.addMethod(
449 retType: param.getCppAccessorType(), name: param.getAccessorName(),
450 properties: def.genStorageClass() ? Method::Const : Method::ConstDeclaration);
451 // Generate accessor definitions only if we also generate the storage
452 // class. Otherwise, let the user define the exact accessor definition.
453 if (!def.genStorageClass())
454 continue;
455 m->body().indent() << "return getImpl()->" << param.getName() << ";";
456 }
457}
458
459void DefGen::emitInterfaceMethods() {
460 for (auto &traitDef : def.getTraits())
461 if (auto *trait = dyn_cast<InterfaceTrait>(Val: &traitDef))
462 if (trait->shouldDeclareMethods())
463 emitTraitMethods(trait: *trait);
464}
465
466//===----------------------------------------------------------------------===//
467// Builder Emission
468//===----------------------------------------------------------------------===//
469
470SmallVector<MethodParameter>
471DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const {
472 SmallVector<MethodParameter> builderParams;
473 builderParams.append(in_start: prefix.begin(), in_end: prefix.end());
474 for (auto &param : params)
475 builderParams.emplace_back(Args: param.getCppType(), Args: param.getName());
476 return builderParams;
477}
478
479void DefGen::emitDefaultBuilder() {
480 Method *m = defCls.addStaticMethod(
481 retType: def.getCppClassName(), name: "get",
482 args: getBuilderParams(prefix: {{"::mlir::MLIRContext *", "context"}}));
483 MethodBody &body = m->body().indent();
484 auto scope = body.scope(open: "return Base::get(context", close: ");");
485 for (const auto &param : params)
486 body << ", std::move(" << param.getName() << ")";
487}
488
489void DefGen::emitCheckedBuilder() {
490 Method *m = defCls.addStaticMethod(
491 retType: def.getCppClassName(), name: "getChecked",
492 args: getBuilderParams(
493 prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"},
494 {"::mlir::MLIRContext *", "context"}}));
495 MethodBody &body = m->body().indent();
496 auto scope = body.scope(open: "return Base::getChecked(emitError, context", close: ");");
497 for (const auto &param : params)
498 body << ", " << param.getName();
499}
500
501static SmallVector<MethodParameter>
502getCustomBuilderParams(std::initializer_list<MethodParameter> prefix,
503 const AttrOrTypeBuilder &builder) {
504 auto params = builder.getParameters();
505 SmallVector<MethodParameter> builderParams;
506 builderParams.append(in_start: prefix.begin(), in_end: prefix.end());
507 if (!builder.hasInferredContextParameter())
508 builderParams.emplace_back(Args: "::mlir::MLIRContext *", Args: "context");
509 for (auto &param : params) {
510 builderParams.emplace_back(Args: param.getCppType(), Args: *param.getName(),
511 Args: param.getDefaultValue());
512 }
513 return builderParams;
514}
515
516void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
517 // Don't emit a body if there isn't one.
518 auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
519 StringRef returnType = def.getCppClassName();
520 if (std::optional<StringRef> builderReturnType = builder.getReturnType())
521 returnType = *builderReturnType;
522 Method *m = defCls.addMethod(retType&: returnType, name: "get", properties: props,
523 args: getCustomBuilderParams(prefix: {}, builder));
524 if (!builder.getBody())
525 return;
526
527 // Format the body and emit it.
528 FmtContext ctx;
529 ctx.addSubst(placeholder: "_get", subst: "Base::get");
530 if (!builder.hasInferredContextParameter())
531 ctx.addSubst(placeholder: "_ctxt", subst: "context");
532 std::string bodyStr = tgfmt(fmt: *builder.getBody(), ctx: &ctx);
533 m->body().indent().getStream().printReindented(str: bodyStr);
534}
535
536/// Replace all instances of 'from' to 'to' in `str` and return the new string.
537static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
538 size_t pos = 0;
539 while ((pos = str.find(s: from.data(), pos: pos, n: from.size())) != std::string::npos)
540 str.replace(pos: pos, n1: from.size(), s: to.data(), n2: to.size());
541 return str;
542}
543
544void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
545 // Don't emit a body if there isn't one.
546 auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
547 StringRef returnType = def.getCppClassName();
548 if (std::optional<StringRef> builderReturnType = builder.getReturnType())
549 returnType = *builderReturnType;
550 Method *m = defCls.addMethod(
551 retType&: returnType, name: "getChecked", properties: props,
552 args: getCustomBuilderParams(
553 prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
554 builder));
555 if (!builder.getBody())
556 return;
557
558 // Format the body and emit it. Replace $_get(...) with
559 // Base::getChecked(emitError, ...)
560 FmtContext ctx;
561 if (!builder.hasInferredContextParameter())
562 ctx.addSubst(placeholder: "_ctxt", subst: "context");
563 std::string bodyStr = replaceInStr(str: builder.getBody()->str(), from: "$_get(",
564 to: "Base::getChecked(emitError, ");
565 bodyStr = tgfmt(fmt: bodyStr, ctx: &ctx);
566 m->body().indent().getStream().printReindented(str: bodyStr);
567}
568
569//===----------------------------------------------------------------------===//
570// Interface Method Emission
571//===----------------------------------------------------------------------===//
572
573void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
574 // Get the set of methods that should always be declared.
575 auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods();
576 StringSet<> alwaysDeclared;
577 alwaysDeclared.insert_range(R&: alwaysDeclaredMethods);
578
579 Interface iface = trait.getInterface(); // causes strange bugs if elided
580 for (auto &method : iface.getMethods()) {
581 // Don't declare if the method has a body. Or if the method has a default
582 // implementation and the def didn't request that it always be declared.
583 if (method.getBody() || (method.getDefaultImplementation() &&
584 !alwaysDeclared.count(Key: method.getName())))
585 continue;
586 emitTraitMethod(method);
587 }
588}
589
590void DefGen::emitTraitMethod(const InterfaceMethod &method) {
591 // All interface methods are declaration-only.
592 auto props =
593 method.isStatic() ? Method::StaticDeclaration : Method::ConstDeclaration;
594 SmallVector<MethodParameter> params;
595 for (auto &param : method.getArguments())
596 params.emplace_back(Args: param.type, Args: param.name);
597 defCls.addMethod(retType: method.getReturnType(), name: method.getName(), properties: props,
598 args: std::move(params));
599}
600
601//===----------------------------------------------------------------------===//
602// OpAsm{Type,Attr}Interface Default Method Emission
603
604void DefGen::emitMnemonicAliasMethod() {
605 // If the mnemonic is not set, there is nothing to do.
606 if (!def.getMnemonic())
607 return;
608
609 // Emit the mnemonic alias method.
610 SmallVector<MethodParameter> params{{"::llvm::raw_ostream &", "os"}};
611 Method *m = defCls.addMethod<Method::Const>(retType: "::mlir::OpAsmAliasResult",
612 name: "getAlias", args: std::move(params));
613 m->body().indent() << strfmt(fmt: "os << \"{0}\";\n", parameters: *def.getMnemonic())
614 << "return ::mlir::OpAsmAliasResult::OverridableAlias;\n";
615}
616
617//===----------------------------------------------------------------------===//
618// Storage Class Emission
619//===----------------------------------------------------------------------===//
620
621void DefGen::emitStorageConstructor() {
622 Constructor *ctor =
623 storageCls->addConstructor<Method::Inline>(args: getBuilderParams(prefix: {}));
624 for (auto &param : params) {
625 std::string movedValue = ("std::move(" + param.getName() + ")").str();
626 ctor->addMemberInitializer(name: param.getName(), value&: movedValue);
627 }
628}
629
630void DefGen::emitKeyType() {
631 std::string keyType("std::tuple<");
632 llvm::raw_string_ostream os(keyType);
633 llvm::interleaveComma(c: params, os,
634 each_fn: [&](auto &param) { os << param.getCppType(); });
635 os << '>';
636 storageCls->declare<UsingDeclaration>(args: "KeyTy", args: std::move(os.str()));
637
638 // Add a method to construct the key type from the storage.
639 Method *m = storageCls->addConstMethod<Method::Inline>(retType: "KeyTy", name: "getAsKey");
640 m->body().indent() << "return KeyTy(";
641 llvm::interleaveComma(c: params, os&: m->body().indent(),
642 each_fn: [&](auto &param) { m->body() << param.getName(); });
643 m->body() << ");";
644}
645
646void DefGen::emitEquals() {
647 Method *eq = storageCls->addConstMethod<Method::Inline>(
648 retType: "bool", name: "operator==", args: MethodParameter("const KeyTy &", "tblgenKey"));
649 auto &body = eq->body().indent();
650 auto scope = body.scope(open: "return (", close: ");");
651 const auto eachFn = [&](auto it) {
652 FmtContext ctx({{"_lhs", it.value().getName()},
653 {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
654 body << tgfmt(it.value().getComparator(), &ctx);
655 };
656 llvm::interleave(c: llvm::enumerate(First&: params), os&: body, each_fn: eachFn, separator: ") && (");
657}
658
659void DefGen::emitHashKey() {
660 Method *hash = storageCls->addStaticInlineMethod(
661 retType: "::llvm::hash_code", name: "hashKey",
662 args: MethodParameter("const KeyTy &", "tblgenKey"));
663 auto &body = hash->body().indent();
664 auto scope = body.scope(open: "return ::llvm::hash_combine(", close: ");");
665 llvm::interleaveComma(c: llvm::enumerate(First&: params), os&: body, each_fn: [&](auto it) {
666 body << llvm::formatv("std::get<{0}>(tblgenKey)", it.index());
667 });
668}
669
670void DefGen::emitConstruct() {
671 Method *construct = storageCls->addMethod(
672 retType: strfmt(fmt: "{0} *", parameters: def.getStorageClassName()), name: "construct",
673 properties: def.hasStorageCustomConstructor() ? Method::StaticDeclaration
674 : Method::StaticInline,
675 args: MethodParameter(strfmt(fmt: "::mlir::{0}StorageAllocator &", parameters&: valueType),
676 "allocator"),
677 args: MethodParameter("KeyTy &&", "tblgenKey"));
678 if (!def.hasStorageCustomConstructor()) {
679 auto &body = construct->body().indent();
680 for (const auto &it : llvm::enumerate(First&: params)) {
681 body << formatv(Fmt: "auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
682 Vals: it.value().getName(), Vals: it.index());
683 }
684 // Use the parameters' custom allocator code, if provided.
685 FmtContext ctx = FmtContext().addSubst(placeholder: "_allocator", subst: "allocator");
686 for (auto &param : params) {
687 if (std::optional<StringRef> allocCode = param.getAllocator()) {
688 ctx.withSelf(subst: param.getName()).addSubst(placeholder: "_dst", subst: param.getName());
689 body << tgfmt(fmt: *allocCode, ctx: &ctx) << '\n';
690 }
691 }
692 auto scope =
693 body.scope(open: strfmt(fmt: "return new (allocator.allocate<{0}>()) {0}(",
694 parameters: def.getStorageClassName()),
695 close: ");");
696 llvm::interleaveComma(c: params, os&: body, each_fn: [&](auto &param) {
697 body << "std::move(" << param.getName() << ")";
698 });
699 }
700}
701
702void DefGen::emitStorageClass() {
703 // Add the appropriate parent class.
704 storageCls->addParent(parent: strfmt(fmt: "::mlir::{0}Storage", parameters&: valueType));
705 // Add the constructor.
706 emitStorageConstructor();
707 // Declare the key type.
708 emitKeyType();
709 // Add the comparison method.
710 emitEquals();
711 // Emit the key hash method.
712 emitHashKey();
713 // Emit the storage constructor. Just declare it if the user wants to define
714 // it themself.
715 emitConstruct();
716 // Emit the storage class members as public, at the very end of the struct.
717 storageCls->finalize();
718 for (auto &param : params) {
719 if (param.getCppType().contains(Other: "APInt") && !param.hasCustomComparator()) {
720 PrintFatalError(
721 ErrorLoc: def.getLoc(),
722 Msg: "Using a raw APInt parameter without a custom comparator is "
723 "not supported because an assert in the equality operator is "
724 "triggered when the two APInts have different bit widths. This can "
725 "lead to unexpected crashes. Use an `APIntParameter` or "
726 "provide a custom comparator.");
727 }
728 storageCls->declare<Field>(args: param.getCppType(), args: param.getName());
729 }
730}
731
732//===----------------------------------------------------------------------===//
733// DefGenerator
734//===----------------------------------------------------------------------===//
735
namespace {
/// Base generator that emits the declarations or definitions for all AttrDef
/// records or all TypeDef records (one kind per generator instance).
class DefGenerator {
public:
  /// Emit the class declarations for the defs of `selectedDialect` (or of the
  /// only dialect present when it is empty).
  bool emitDecls(StringRef selectedDialect);
  /// Emit the class definitions for the defs of `selectedDialect` (or of the
  /// only dialect present when it is empty).
  bool emitDefs(StringRef selectedDialect);

protected:
  /// `defs` are copied and sorted by record ID, i.e. by order of occurrence
  /// in the TableGen input.
  DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os,
               StringRef defType, StringRef valueType, bool isAttrGenerator)
      : defRecords(defs), os(os), defType(defType), valueType(valueType),
        isAttrGenerator(isAttrGenerator) {
    // Sort by occurrence in file.
    llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) {
      return lhs->getID() < rhs->getID();
    });
  }

  /// Emit the list of def type names.
  void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
  /// Emit the code to dispatch between different defs during parsing/printing.
  void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);

  /// The set of def records to emit.
  std::vector<const Record *> defRecords;
  /// The stream to emit to.
  raw_ostream &os;
  /// The prefix of the tablegen def name, e.g. Attr or Type.
  StringRef defType;
  /// The C++ base value type of the def, e.g. Attribute or Type.
  StringRef valueType;
  /// Flag indicating if this generator is for Attributes. False if the
  /// generator is for types.
  bool isAttrGenerator;
};

/// A specialized generator for AttrDefs: processes every record derived from
/// the TableGen class `AttrDef`.
struct AttrDefGenerator : public DefGenerator {
  AttrDefGenerator(const RecordKeeper &records, raw_ostream &os)
      : DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
                     "Attr", "Attribute", /*isAttrGenerator=*/true) {}
};
/// A specialized generator for TypeDefs: processes every record derived from
/// the TableGen class `TypeDef`.
struct TypeDefGenerator : public DefGenerator {
  TypeDefGenerator(const RecordKeeper &records, raw_ostream &os)
      : DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
                     "Type", "Type", /*isAttrGenerator=*/false) {}
};
} // namespace
786
787//===----------------------------------------------------------------------===//
788// GEN: Declarations
789//===----------------------------------------------------------------------===//
790
791/// Print this above all the other declarations. Contains type declarations used
792/// later on.
793static const char *const typeDefDeclHeader = R"(
794namespace mlir {
795class AsmParser;
796class AsmPrinter;
797} // namespace mlir
798)";
799
800bool DefGenerator::emitDecls(StringRef selectedDialect) {
801 emitSourceFileHeader(Desc: (defType + "Def Declarations").str(), OS&: os);
802 IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
803
804 // Output the common "header".
805 os << typeDefDeclHeader;
806
807 SmallVector<AttrOrTypeDef, 16> defs;
808 collectAllDefs(selectedDialect, records: defRecords, resultDefs&: defs);
809 if (defs.empty())
810 return false;
811 {
812 NamespaceEmitter nsEmitter(os, defs.front().getDialect());
813
814 // Declare all the def classes first (in case they reference each other).
815 for (const AttrOrTypeDef &def : defs) {
816 std::string comments = tblgen::emitSummaryAndDescComments(
817 summary: def.getSummary(), description: def.getDescription());
818 if (!comments.empty()) {
819 os << comments << "\n";
820 }
821 os << "class " << def.getCppClassName() << ";\n";
822 }
823
824 // Emit the declarations.
825 for (const AttrOrTypeDef &def : defs)
826 DefGen(def).emitDecl(os);
827 }
828 // Emit the TypeID explicit specializations to have a single definition for
829 // each of these.
830 for (const AttrOrTypeDef &def : defs)
831 if (!def.getDialect().getCppNamespace().empty())
832 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID("
833 << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
834 << ")\n";
835
836 return false;
837}
838
839//===----------------------------------------------------------------------===//
840// GEN: Def List
841//===----------------------------------------------------------------------===//
842
843void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
844 IfDefScope scope("GET_" + defType.upper() + "DEF_LIST", os);
845 auto interleaveFn = [&](const AttrOrTypeDef &def) {
846 os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName();
847 };
848 llvm::interleave(c: defs, os, each_fn: interleaveFn, separator: ",\n");
849 os << "\n";
850}
851
852//===----------------------------------------------------------------------===//
853// GEN: Definitions
854//===----------------------------------------------------------------------===//
855
/// The code block for default attribute parser/printer dispatch boilerplate.
/// It is used as an `llvm::formatv` format string, so literal braces in the
/// generated code are written as `{{`.
/// {0}: the dialect fully qualified class name.
/// {1}: the optional code for the dynamic attribute parser dispatch.
/// {2}: the optional code for the dynamic attribute printer dispatch.
static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
/// Parse an attribute registered to this dialect.
::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
                                      ::mlir::Type type) const {{
  ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
  ::llvm::StringRef attrTag;
  {{
    ::mlir::Attribute attr;
    auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
    if (parseResult.has_value())
      return attr;
  }
  {1}
  parser.emitError(typeLoc) << "unknown attribute `"
      << attrTag << "` in dialect `" << getNamespace() << "`";
  return {{};
}
/// Print an attribute registered to this dialect.
void {0}::printAttribute(::mlir::Attribute attr,
                         ::mlir::DialectAsmPrinter &printer) const {{
  if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
    return;
  {2}
}
)";

/// The code block for dynamic attribute parser dispatch boilerplate; it is
/// substituted verbatim into `{1}` of the default attribute dispatch.
/// The generated code is emitted inside the dialect's C++ namespace, so
/// `::mlir::Attribute` is fully qualified (as `::mlir::Type` is in the type
/// variant) rather than relying on `mlir` being visible there.
static const char *const dialectDynamicAttrParserDispatch = R"(
  {
    ::mlir::Attribute genAttr;
    auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr);
    if (parseResult.has_value()) {
      if (::mlir::succeeded(parseResult.value()))
        return genAttr;
      return ::mlir::Attribute();
    }
  }
)";

/// The code block for dynamic attribute printer dispatch boilerplate; it is
/// substituted verbatim into `{2}` of the default attribute dispatch.
static const char *const dialectDynamicAttrPrinterDispatch = R"(
  if (::mlir::succeeded(printIfDynamicAttr(attr, printer)))
    return;
)";

/// The code block for default type parser/printer dispatch boilerplate.
/// It is used as an `llvm::formatv` format string, so literal braces in the
/// generated code are written as `{{`.
/// {0}: the dialect fully qualified class name.
/// {1}: the optional code for the dynamic type parser dispatch.
/// {2}: the optional code for the dynamic type printer dispatch.
static const char *const dialectDefaultTypePrinterParserDispatch = R"(
/// Parse a type registered to this dialect.
::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
  ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
  ::llvm::StringRef mnemonic;
  ::mlir::Type genType;
  auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
  if (parseResult.has_value())
    return genType;
  {1}
  parser.emitError(typeLoc) << "unknown type `"
      << mnemonic << "` in dialect `" << getNamespace() << "`";
  return {{};
}
/// Print a type registered to this dialect.
void {0}::printType(::mlir::Type type,
                    ::mlir::DialectAsmPrinter &printer) const {{
  if (::mlir::succeeded(generatedTypePrinter(type, printer)))
    return;
  {2}
}
)";

/// The code block for dynamic type parser dispatch boilerplate; it is
/// substituted verbatim into `{1}` of the default type dispatch.
static const char *const dialectDynamicTypeParserDispatch = R"(
  {
    auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
    if (parseResult.has_value()) {
      if (::mlir::succeeded(parseResult.value()))
        return genType;
      return ::mlir::Type();
    }
  }
)";

/// The code block for dynamic type printer dispatch boilerplate; it is
/// substituted verbatim into `{2}` of the default type dispatch.
static const char *const dialectDynamicTypePrinterDispatch = R"(
  if (::mlir::succeeded(printIfDynamicType(type, printer)))
    return;
)";
949
950/// Emit the dialect printer/parser dispatcher. User's code should call these
951/// functions from their dialect's print/parse methods.
952void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
953 if (llvm::none_of(Range&: defs, P: [](const AttrOrTypeDef &def) {
954 return def.getMnemonic().has_value();
955 })) {
956 return;
957 }
958 // Declare the parser.
959 SmallVector<MethodParameter> params = {{"::mlir::AsmParser &", "parser"},
960 {"::llvm::StringRef *", "mnemonic"}};
961 if (isAttrGenerator)
962 params.emplace_back(Args: "::mlir::Type", Args: "type");
963 params.emplace_back(Args: strfmt(fmt: "::mlir::{0} &", parameters&: valueType), Args: "value");
964 Method parse("::mlir::OptionalParseResult",
965 strfmt(fmt: "generated{0}Parser", parameters&: valueType), Method::StaticInline,
966 std::move(params));
967 // Declare the printer.
968 Method printer("::llvm::LogicalResult",
969 strfmt(fmt: "generated{0}Printer", parameters&: valueType), Method::StaticInline,
970 {{strfmt(fmt: "::mlir::{0}", parameters&: valueType), "def"},
971 {"::mlir::AsmPrinter &", "printer"}});
972
973 // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and
974 // calling the def's parse function.
975 parse.body() << " return "
976 "::mlir::AsmParser::KeywordSwitch<::mlir::"
977 "OptionalParseResult>(parser)\n";
978 const char *const getValueForMnemonic =
979 R"( .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{
980 value = {0}::{1};
981 return ::mlir::success(!!value);
982 })
983)";
984
985 // The printer dispatch uses llvm::TypeSwitch to find and call the correct
986 // printer.
987 printer.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType
988 << ", ::llvm::LogicalResult>(def)";
989 const char *const printValue = R"( .Case<{0}>([&](auto t) {{
990 printer << {0}::getMnemonic();{1}
991 return ::mlir::success();
992 })
993)";
994 for (auto &def : defs) {
995 if (!def.getMnemonic())
996 continue;
997 bool hasParserPrinterDecl =
998 def.hasCustomAssemblyFormat() || def.getAssemblyFormat();
999 std::string defClass = strfmt(
1000 fmt: "{0}::{1}", parameters: def.getDialect().getCppNamespace(), parameters: def.getCppClassName());
1001
1002 // If the def has no parameters or parser code, invoke a normal `get`.
1003 std::string parseOrGet =
1004 hasParserPrinterDecl
1005 ? strfmt(fmt: "parse(parser{0})", parameters: isAttrGenerator ? ", type" : "")
1006 : "get(parser.getContext())";
1007 parse.body() << llvm::formatv(Fmt: getValueForMnemonic, Vals&: defClass, Vals&: parseOrGet);
1008
1009 // If the def has no parameters and no printer, just print the mnemonic.
1010 StringRef printDef = "";
1011 if (hasParserPrinterDecl)
1012 printDef = "\nt.print(printer);";
1013 printer.body() << llvm::formatv(Fmt: printValue, Vals&: defClass, Vals&: printDef);
1014 }
1015 parse.body() << " .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n"
1016 " *mnemonic = keyword;\n"
1017 " return std::nullopt;\n"
1018 " });";
1019 printer.body() << " .Default([](auto) { return ::mlir::failure(); });";
1020
1021 raw_indented_ostream indentedOs(os);
1022 parse.writeDeclTo(os&: indentedOs);
1023 printer.writeDeclTo(os&: indentedOs);
1024}
1025
1026bool DefGenerator::emitDefs(StringRef selectedDialect) {
1027 emitSourceFileHeader(Desc: (defType + "Def Definitions").str(), OS&: os);
1028
1029 SmallVector<AttrOrTypeDef, 16> defs;
1030 collectAllDefs(selectedDialect, records: defRecords, resultDefs&: defs);
1031 if (defs.empty())
1032 return false;
1033 emitTypeDefList(defs);
1034
1035 IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
1036 emitParsePrintDispatch(defs);
1037 for (const AttrOrTypeDef &def : defs) {
1038 {
1039 NamespaceEmitter ns(os, def.getDialect());
1040 DefGen gen(def);
1041 gen.emitDef(os);
1042 }
1043 // Emit the TypeID explicit specializations to have a single symbol def.
1044 if (!def.getDialect().getCppNamespace().empty())
1045 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID("
1046 << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
1047 << ")\n";
1048 }
1049
1050 Dialect firstDialect = defs.front().getDialect();
1051
1052 // Emit the default parser/printer for Attributes if the dialect asked for it.
1053 if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser()) {
1054 NamespaceEmitter nsEmitter(os, firstDialect);
1055 if (firstDialect.isExtensible()) {
1056 os << llvm::formatv(Fmt: dialectDefaultAttrPrinterParserDispatch,
1057 Vals: firstDialect.getCppClassName(),
1058 Vals: dialectDynamicAttrParserDispatch,
1059 Vals: dialectDynamicAttrPrinterDispatch);
1060 } else {
1061 os << llvm::formatv(Fmt: dialectDefaultAttrPrinterParserDispatch,
1062 Vals: firstDialect.getCppClassName(), Vals: "", Vals: "");
1063 }
1064 }
1065
1066 // Emit the default parser/printer for Types if the dialect asked for it.
1067 if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser()) {
1068 NamespaceEmitter nsEmitter(os, firstDialect);
1069 if (firstDialect.isExtensible()) {
1070 os << llvm::formatv(Fmt: dialectDefaultTypePrinterParserDispatch,
1071 Vals: firstDialect.getCppClassName(),
1072 Vals: dialectDynamicTypeParserDispatch,
1073 Vals: dialectDynamicTypePrinterDispatch);
1074 } else {
1075 os << llvm::formatv(Fmt: dialectDefaultTypePrinterParserDispatch,
1076 Vals: firstDialect.getCppClassName(), Vals: "", Vals: "");
1077 }
1078 }
1079
1080 return false;
1081}
1082
1083//===----------------------------------------------------------------------===//
1084// Constraints
1085//===----------------------------------------------------------------------===//
1086
1087/// Find all type constraints for which a C++ function should be generated.
1088static std::vector<Constraint> getAllCppConstraints(const RecordKeeper &records,
1089 StringRef constraintKind) {
1090 std::vector<Constraint> result;
1091 for (const Record *def :
1092 records.getAllDerivedDefinitionsIfDefined(ClassName: constraintKind)) {
1093 // Ignore constraints defined outside of the top-level file.
1094 if (llvm::SrcMgr.FindBufferContainingLoc(Loc: def->getLoc()[0]) !=
1095 llvm::SrcMgr.getMainFileID())
1096 continue;
1097 Constraint constr(def);
1098 // Generate C++ function only if "cppFunctionName" is set.
1099 if (!constr.getCppFunctionName())
1100 continue;
1101 result.push_back(x: constr);
1102 }
1103 return result;
1104}
1105
1106static std::vector<Constraint>
1107getAllCppTypeConstraints(const RecordKeeper &records) {
1108 return getAllCppConstraints(records, constraintKind: "TypeConstraint");
1109}
1110
1111static std::vector<Constraint>
1112getAllCppAttrConstraints(const RecordKeeper &records) {
1113 return getAllCppConstraints(records, constraintKind: "AttrConstraint");
1114}
1115
1116/// Emit the declarations for the given constraints, of the form:
1117/// `bool <constraintCppFunctionName>(<parameterTypeName> <parameterName>);`
1118static void emitConstraintDecls(const std::vector<Constraint> &constraints,
1119 raw_ostream &os, StringRef parameterTypeName,
1120 StringRef parameterName) {
1121 static const char *const constraintDecl = "bool {0}({1} {2});\n";
1122 for (Constraint constr : constraints)
1123 os << strfmt(fmt: constraintDecl, parameters: *constr.getCppFunctionName(),
1124 parameters&: parameterTypeName, parameters&: parameterName);
1125}
1126
1127static void emitTypeConstraintDecls(const RecordKeeper &records,
1128 raw_ostream &os) {
1129 emitConstraintDecls(constraints: getAllCppTypeConstraints(records), os, parameterTypeName: "::mlir::Type",
1130 parameterName: "type");
1131}
1132
1133static void emitAttrConstraintDecls(const RecordKeeper &records,
1134 raw_ostream &os) {
1135 emitConstraintDecls(constraints: getAllCppAttrConstraints(records), os,
1136 parameterTypeName: "::mlir::Attribute", parameterName: "attr");
1137}
1138
1139/// Emit the definitions for the given constraints, of the form:
1140/// `bool <constraintCppFunctionName>(<parameterTypeName> <parameterName>) {
1141/// return (<condition>); }`
1142/// where `<condition>` is the condition template with the `self` variable
1143/// replaced with the `selfName` parameter.
1144static void emitConstraintDefs(const std::vector<Constraint> &constraints,
1145 raw_ostream &os, StringRef parameterTypeName,
1146 StringRef selfName) {
1147 static const char *const constraintDef = R"(
1148bool {0}({1} {2}) {
1149return ({3});
1150}
1151)";
1152
1153 for (Constraint constr : constraints) {
1154 FmtContext ctx;
1155 ctx.withSelf(subst: selfName);
1156 std::string condition = tgfmt(fmt: constr.getConditionTemplate(), ctx: &ctx);
1157 os << strfmt(fmt: constraintDef, parameters: *constr.getCppFunctionName(), parameters&: parameterTypeName,
1158 parameters&: selfName, parameters&: condition);
1159 }
1160}
1161
1162static void emitTypeConstraintDefs(const RecordKeeper &records,
1163 raw_ostream &os) {
1164 emitConstraintDefs(constraints: getAllCppTypeConstraints(records), os, parameterTypeName: "::mlir::Type",
1165 selfName: "type");
1166}
1167
1168static void emitAttrConstraintDefs(const RecordKeeper &records,
1169 raw_ostream &os) {
1170 emitConstraintDefs(constraints: getAllCppAttrConstraints(records), os, parameterTypeName: "::mlir::Attribute",
1171 selfName: "attr");
1172}
1173
1174//===----------------------------------------------------------------------===//
1175// GEN: Registration hooks
1176//===----------------------------------------------------------------------===//
1177
1178//===----------------------------------------------------------------------===//
1179// AttrDef
1180//===----------------------------------------------------------------------===//
1181
1182static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*");
1183static llvm::cl::opt<std::string>
1184 attrDialect("attrdefs-dialect",
1185 llvm::cl::desc("Generate attributes for this dialect"),
1186 llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated);
1187
1188static mlir::GenRegistration
1189 genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
1190 [](const RecordKeeper &records, raw_ostream &os) {
1191 AttrDefGenerator generator(records, os);
1192 return generator.emitDefs(selectedDialect: attrDialect);
1193 });
1194static mlir::GenRegistration
1195 genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
1196 [](const RecordKeeper &records, raw_ostream &os) {
1197 AttrDefGenerator generator(records, os);
1198 return generator.emitDecls(selectedDialect: attrDialect);
1199 });
1200
1201static mlir::GenRegistration
1202 genAttrConstrDefs("gen-attr-constraint-defs",
1203 "Generate attribute constraint definitions",
1204 [](const RecordKeeper &records, raw_ostream &os) {
1205 emitAttrConstraintDefs(records, os);
1206 return false;
1207 });
1208static mlir::GenRegistration
1209 genAttrConstrDecls("gen-attr-constraint-decls",
1210 "Generate attribute constraint declarations",
1211 [](const RecordKeeper &records, raw_ostream &os) {
1212 emitAttrConstraintDecls(records, os);
1213 return false;
1214 });
1215
1216//===----------------------------------------------------------------------===//
1217// TypeDef
1218//===----------------------------------------------------------------------===//
1219
1220static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
1221static llvm::cl::opt<std::string>
1222 typeDialect("typedefs-dialect",
1223 llvm::cl::desc("Generate types for this dialect"),
1224 llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
1225
1226static mlir::GenRegistration
1227 genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
1228 [](const RecordKeeper &records, raw_ostream &os) {
1229 TypeDefGenerator generator(records, os);
1230 return generator.emitDefs(selectedDialect: typeDialect);
1231 });
1232static mlir::GenRegistration
1233 genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
1234 [](const RecordKeeper &records, raw_ostream &os) {
1235 TypeDefGenerator generator(records, os);
1236 return generator.emitDecls(selectedDialect: typeDialect);
1237 });
1238
1239static mlir::GenRegistration
1240 genTypeConstrDefs("gen-type-constraint-defs",
1241 "Generate type constraint definitions",
1242 [](const RecordKeeper &records, raw_ostream &os) {
1243 emitTypeConstraintDefs(records, os);
1244 return false;
1245 });
1246static mlir::GenRegistration
1247 genTypeConstrDecls("gen-type-constraint-decls",
1248 "Generate type constraint declarations",
1249 [](const RecordKeeper &records, raw_ostream &os) {
1250 emitTypeConstraintDecls(records, os);
1251 return false;
1252 });
1253

source code of mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp