1//===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "AttrOrTypeFormatGen.h"
10#include "mlir/TableGen/AttrOrTypeDef.h"
11#include "mlir/TableGen/Class.h"
12#include "mlir/TableGen/CodeGenHelpers.h"
13#include "mlir/TableGen/Format.h"
14#include "mlir/TableGen/GenInfo.h"
15#include "mlir/TableGen/Interfaces.h"
16#include "llvm/ADT/StringSet.h"
17#include "llvm/Support/CommandLine.h"
18#include "llvm/TableGen/Error.h"
19#include "llvm/TableGen/TableGenBackend.h"
20
21#define DEBUG_TYPE "mlir-tblgen-attrortypedefgen"
22
23using namespace mlir;
24using namespace mlir::tblgen;
25
26//===----------------------------------------------------------------------===//
27// Utility Functions
28//===----------------------------------------------------------------------===//
29
30/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
31/// specified and can only find one dialect's defs, use that.
32static void collectAllDefs(StringRef selectedDialect,
33 std::vector<llvm::Record *> records,
34 SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
35 // Nothing to do if no defs were found.
36 if (records.empty())
37 return;
38
39 auto defs = llvm::map_range(
40 C&: records, F: [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); });
41 if (selectedDialect.empty()) {
42 // If a dialect was not specified, ensure that all found defs belong to the
43 // same dialect.
44 if (!llvm::all_equal(Range: llvm::map_range(
45 C&: defs, F: [](const auto &def) { return def.getDialect(); }))) {
46 llvm::PrintFatalError(Msg: "defs belonging to more than one dialect. Must "
47 "select one via '--(attr|type)defs-dialect'");
48 }
49 resultDefs.assign(in_start: defs.begin(), in_end: defs.end());
50 } else {
51 // Otherwise, generate the defs that belong to the selected dialect.
52 auto dialectDefs = llvm::make_filter_range(Range&: defs, Pred: [&](const auto &def) {
53 return def.getDialect().getName().equals(selectedDialect);
54 });
55 resultDefs.assign(in_start: dialectDefs.begin(), in_end: dialectDefs.end());
56 }
57}
58
59//===----------------------------------------------------------------------===//
60// DefGen
61//===----------------------------------------------------------------------===//
62
63namespace {
64class DefGen {
65public:
66 /// Create the attribute or type class.
67 DefGen(const AttrOrTypeDef &def);
68
69 void emitDecl(raw_ostream &os) const {
70 if (storageCls && def.genStorageClass()) {
71 NamespaceEmitter ns(os, def.getStorageNamespace());
72 os << "struct " << def.getStorageClassName() << ";\n";
73 }
74 defCls.writeDeclTo(rawOs&: os);
75 }
76 void emitDef(raw_ostream &os) const {
77 if (storageCls && def.genStorageClass()) {
78 NamespaceEmitter ns(os, def.getStorageNamespace());
79 storageCls->writeDeclTo(rawOs&: os); // everything is inline
80 }
81 defCls.writeDefTo(rawOs&: os);
82 }
83
84private:
85 /// Add traits from the TableGen definition to the class.
86 void createParentWithTraits();
87 /// Emit top-level declarations: using declarations and any extra class
88 /// declarations.
89 void emitTopLevelDeclarations();
90 /// Emit the function that returns the type or attribute name.
91 void emitName();
92 /// Emit attribute or type builders.
93 void emitBuilders();
94 /// Emit a verifier for the def.
95 void emitVerifier();
96 /// Emit parsers and printers.
97 void emitParserPrinter();
98 /// Emit parameter accessors, if required.
99 void emitAccessors();
100 /// Emit interface methods.
101 void emitInterfaceMethods();
102
103 //===--------------------------------------------------------------------===//
104 // Builder Emission
105
106 /// Emit the default builder `Attribute::get`
107 void emitDefaultBuilder();
108 /// Emit the checked builder `Attribute::getChecked`
109 void emitCheckedBuilder();
110 /// Emit a custom builder.
111 void emitCustomBuilder(const AttrOrTypeBuilder &builder);
112 /// Emit a checked custom builder.
113 void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);
114
115 //===--------------------------------------------------------------------===//
116 // Interface Method Emission
117
118 /// Emit methods for a trait.
119 void emitTraitMethods(const InterfaceTrait &trait);
120 /// Emit a trait method.
121 void emitTraitMethod(const InterfaceMethod &method);
122
123 //===--------------------------------------------------------------------===//
124 // Storage Class Emission
125 void emitStorageClass();
126 /// Generate the storage class constructor.
127 void emitStorageConstructor();
128 /// Emit the key type `KeyTy`.
129 void emitKeyType();
130 /// Emit the equality comparison operator.
131 void emitEquals();
132 /// Emit the key hash function.
133 void emitHashKey();
134 /// Emit the function to construct the storage class.
135 void emitConstruct();
136
137 //===--------------------------------------------------------------------===//
138 // Utility Function Declarations
139
140 /// Get the method parameters for a def builder, where the first several
141 /// parameters may be different.
142 SmallVector<MethodParameter>
143 getBuilderParams(std::initializer_list<MethodParameter> prefix) const;
144
145 //===--------------------------------------------------------------------===//
146 // Class fields
147
148 /// The attribute or type definition.
149 const AttrOrTypeDef &def;
150 /// The list of attribute or type parameters.
151 ArrayRef<AttrOrTypeParameter> params;
152 /// The attribute or type class.
153 Class defCls;
154 /// An optional attribute or type storage class. The storage class will
155 /// exist if and only if the def has more than zero parameters.
156 std::optional<Class> storageCls;
157
158 /// The C++ base value of the def, either "Attribute" or "Type".
159 StringRef valueType;
160 /// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
161 StringRef defType;
162};
163} // namespace
164
165DefGen::DefGen(const AttrOrTypeDef &def)
166 : def(def), params(def.getParameters()), defCls(def.getCppClassName()),
167 valueType(isa<AttrDef>(Val: def) ? "Attribute" : "Type"),
168 defType(isa<AttrDef>(Val: def) ? "Attr" : "Type") {
169 // Check that all parameters have names.
170 for (const AttrOrTypeParameter &param : def.getParameters())
171 if (param.isAnonymous())
172 llvm::PrintFatalError(Msg: "all parameters must have a name");
173
174 // If a storage class is needed, create one.
175 if (def.getNumParameters() > 0)
176 storageCls.emplace(args: def.getStorageClassName(), /*isStruct=*/args: true);
177
178 // Create the parent class with any indicated traits.
179 createParentWithTraits();
180 // Emit top-level declarations.
181 emitTopLevelDeclarations();
182 // Emit builders for defs with parameters
183 if (storageCls)
184 emitBuilders();
185 // Emit the type name.
186 emitName();
187 // Emit the verifier.
188 if (storageCls && def.genVerifyDecl())
189 emitVerifier();
190 // Emit the mnemonic, if there is one, and any associated parser and printer.
191 if (def.getMnemonic())
192 emitParserPrinter();
193 // Emit accessors
194 if (def.genAccessors())
195 emitAccessors();
196 // Emit trait interface methods
197 emitInterfaceMethods();
198 defCls.finalize();
199 // Emit a storage class if one is needed
200 if (storageCls && def.genStorageClass())
201 emitStorageClass();
202}
203
204void DefGen::createParentWithTraits() {
205 ParentClass defParent(strfmt(fmt: "::mlir::{0}::{1}Base", parameters&: valueType, parameters&: defType));
206 defParent.addTemplateParam(param: def.getCppClassName());
207 defParent.addTemplateParam(param: def.getCppBaseClassName());
208 defParent.addTemplateParam(param: storageCls
209 ? strfmt(fmt: "{0}::{1}", parameters: def.getStorageNamespace(),
210 parameters: def.getStorageClassName())
211 : strfmt(fmt: "::mlir::{0}Storage", parameters&: valueType));
212 for (auto &trait : def.getTraits()) {
213 defParent.addTemplateParam(
214 param: isa<NativeTrait>(Val: &trait)
215 ? cast<NativeTrait>(Val: &trait)->getFullyQualifiedTraitName()
216 : cast<InterfaceTrait>(Val: &trait)->getFullyQualifiedTraitName());
217 }
218 defCls.addParent(parent: std::move(defParent));
219}
220
221/// Include declarations specified on NativeTrait
222static std::string formatExtraDeclarations(const AttrOrTypeDef &def) {
223 SmallVector<StringRef> extraDeclarations;
224 // Include extra class declarations from NativeTrait
225 for (const auto &trait : def.getTraits()) {
226 if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
227 StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration();
228 if (value.empty())
229 continue;
230 extraDeclarations.push_back(Elt: value);
231 }
232 }
233 if (std::optional<StringRef> extraDecl = def.getExtraDecls()) {
234 extraDeclarations.push_back(Elt: *extraDecl);
235 }
236 return llvm::join(R&: extraDeclarations, Separator: "\n");
237}
238
239/// Extra class definitions have a `$cppClass` substitution that is to be
240/// replaced by the C++ class name.
241static std::string formatExtraDefinitions(const AttrOrTypeDef &def) {
242 SmallVector<StringRef> extraDefinitions;
243 // Include extra class definitions from NativeTrait
244 for (const auto &trait : def.getTraits()) {
245 if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
246 StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition();
247 if (value.empty())
248 continue;
249 extraDefinitions.push_back(Elt: value);
250 }
251 }
252 if (std::optional<StringRef> extraDef = def.getExtraDefs()) {
253 extraDefinitions.push_back(Elt: *extraDef);
254 }
255 FmtContext ctx = FmtContext().addSubst(placeholder: "cppClass", subst: def.getCppClassName());
256 return tgfmt(fmt: llvm::join(R&: extraDefinitions, Separator: "\n"), ctx: &ctx).str();
257}
258
259void DefGen::emitTopLevelDeclarations() {
260 // Inherit constructors from the attribute or type class.
261 defCls.declare<VisibilityDeclaration>(args: Visibility::Public);
262 defCls.declare<UsingDeclaration>(args: "Base::Base");
263
264 // Emit the extra declarations first in case there's a definition in there.
265 std::string extraDecl = formatExtraDeclarations(def);
266 std::string extraDef = formatExtraDefinitions(def);
267 defCls.declare<ExtraClassDeclaration>(args: std::move(extraDecl),
268 args: std::move(extraDef));
269}
270
271void DefGen::emitName() {
272 StringRef name;
273 if (auto *attrDef = dyn_cast<AttrDef>(Val: &def)) {
274 name = attrDef->getAttrName();
275 } else {
276 auto *typeDef = cast<TypeDef>(Val: &def);
277 name = typeDef->getTypeName();
278 }
279 std::string nameDecl =
280 strfmt(fmt: "static constexpr ::llvm::StringLiteral name = \"{0}\";\n", parameters&: name);
281 defCls.declare<ExtraClassDeclaration>(args: std::move(nameDecl));
282}
283
284void DefGen::emitBuilders() {
285 if (!def.skipDefaultBuilders()) {
286 emitDefaultBuilder();
287 if (def.genVerifyDecl())
288 emitCheckedBuilder();
289 }
290 for (auto &builder : def.getBuilders()) {
291 emitCustomBuilder(builder);
292 if (def.genVerifyDecl())
293 emitCheckedCustomBuilder(builder);
294 }
295}
296
297void DefGen::emitVerifier() {
298 defCls.declare<UsingDeclaration>(args: "Base::getChecked");
299 defCls.declareStaticMethod(
300 retType: "::mlir::LogicalResult", name: "verify",
301 args: getBuilderParams(prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
302 "emitError"}}));
303}
304
305void DefGen::emitParserPrinter() {
306 auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
307 retType: "::llvm::StringLiteral", name: "getMnemonic");
308 mnemonic->body().indent() << strfmt(fmt: "return {\"{0}\"};", parameters: *def.getMnemonic());
309
310 // Declare the parser and printer, if needed.
311 bool hasAssemblyFormat = def.getAssemblyFormat().has_value();
312 if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat)
313 return;
314
315 // Declare the parser.
316 SmallVector<MethodParameter> parserParams;
317 parserParams.emplace_back(Args: "::mlir::AsmParser &", Args: "odsParser");
318 if (isa<AttrDef>(Val: &def))
319 parserParams.emplace_back(Args: "::mlir::Type", Args: "odsType");
320 auto *parser = defCls.addMethod(retType: strfmt(fmt: "::mlir::{0}", parameters&: valueType), name: "parse",
321 properties: hasAssemblyFormat ? Method::Static
322 : Method::StaticDeclaration,
323 args: std::move(parserParams));
324 // Declare the printer.
325 auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration;
326 Method *printer =
327 defCls.addMethod(retType: "void", name: "print", properties: props,
328 args: MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
329 // Emit the bodies if we are using the declarative format.
330 if (hasAssemblyFormat)
331 return generateAttrOrTypeFormat(def, parser&: parser->body(), printer&: printer->body());
332}
333
334void DefGen::emitAccessors() {
335 for (auto &param : params) {
336 Method *m = defCls.addMethod(
337 retType: param.getCppAccessorType(), name: param.getAccessorName(),
338 properties: def.genStorageClass() ? Method::Const : Method::ConstDeclaration);
339 // Generate accessor definitions only if we also generate the storage
340 // class. Otherwise, let the user define the exact accessor definition.
341 if (!def.genStorageClass())
342 continue;
343 m->body().indent() << "return getImpl()->" << param.getName() << ";";
344 }
345}
346
347void DefGen::emitInterfaceMethods() {
348 for (auto &traitDef : def.getTraits())
349 if (auto *trait = dyn_cast<InterfaceTrait>(Val: &traitDef))
350 if (trait->shouldDeclareMethods())
351 emitTraitMethods(trait: *trait);
352}
353
354//===----------------------------------------------------------------------===//
355// Builder Emission
356
357SmallVector<MethodParameter>
358DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const {
359 SmallVector<MethodParameter> builderParams;
360 builderParams.append(in_start: prefix.begin(), in_end: prefix.end());
361 for (auto &param : params)
362 builderParams.emplace_back(Args: param.getCppType(), Args: param.getName());
363 return builderParams;
364}
365
366void DefGen::emitDefaultBuilder() {
367 Method *m = defCls.addStaticMethod(
368 retType: def.getCppClassName(), name: "get",
369 args: getBuilderParams(prefix: {{"::mlir::MLIRContext *", "context"}}));
370 MethodBody &body = m->body().indent();
371 auto scope = body.scope(open: "return Base::get(context", close: ");");
372 for (const auto &param : params)
373 body << ", std::move(" << param.getName() << ")";
374}
375
376void DefGen::emitCheckedBuilder() {
377 Method *m = defCls.addStaticMethod(
378 retType: def.getCppClassName(), name: "getChecked",
379 args: getBuilderParams(
380 prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"},
381 {"::mlir::MLIRContext *", "context"}}));
382 MethodBody &body = m->body().indent();
383 auto scope = body.scope(open: "return Base::getChecked(emitError, context", close: ");");
384 for (const auto &param : params)
385 body << ", " << param.getName();
386}
387
388static SmallVector<MethodParameter>
389getCustomBuilderParams(std::initializer_list<MethodParameter> prefix,
390 const AttrOrTypeBuilder &builder) {
391 auto params = builder.getParameters();
392 SmallVector<MethodParameter> builderParams;
393 builderParams.append(in_start: prefix.begin(), in_end: prefix.end());
394 if (!builder.hasInferredContextParameter())
395 builderParams.emplace_back(Args: "::mlir::MLIRContext *", Args: "context");
396 for (auto &param : params) {
397 builderParams.emplace_back(Args: param.getCppType(), Args: *param.getName(),
398 Args: param.getDefaultValue());
399 }
400 return builderParams;
401}
402
403void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
404 // Don't emit a body if there isn't one.
405 auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
406 StringRef returnType = def.getCppClassName();
407 if (std::optional<StringRef> builderReturnType = builder.getReturnType())
408 returnType = *builderReturnType;
409 Method *m = defCls.addMethod(retType&: returnType, name: "get", properties: props,
410 args: getCustomBuilderParams(prefix: {}, builder));
411 if (!builder.getBody())
412 return;
413
414 // Format the body and emit it.
415 FmtContext ctx;
416 ctx.addSubst(placeholder: "_get", subst: "Base::get");
417 if (!builder.hasInferredContextParameter())
418 ctx.addSubst(placeholder: "_ctxt", subst: "context");
419 std::string bodyStr = tgfmt(fmt: *builder.getBody(), ctx: &ctx);
420 m->body().indent().getStream().printReindented(str: bodyStr);
421}
422
423/// Replace all instances of 'from' to 'to' in `str` and return the new string.
424static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
425 size_t pos = 0;
426 while ((pos = str.find(s: from.data(), pos: pos, n: from.size())) != std::string::npos)
427 str.replace(pos: pos, n1: from.size(), s: to.data(), n2: to.size());
428 return str;
429}
430
431void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
432 // Don't emit a body if there isn't one.
433 auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
434 StringRef returnType = def.getCppClassName();
435 if (std::optional<StringRef> builderReturnType = builder.getReturnType())
436 returnType = *builderReturnType;
437 Method *m = defCls.addMethod(
438 retType&: returnType, name: "getChecked", properties: props,
439 args: getCustomBuilderParams(
440 prefix: {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
441 builder));
442 if (!builder.getBody())
443 return;
444
445 // Format the body and emit it. Replace $_get(...) with
446 // Base::getChecked(emitError, ...)
447 FmtContext ctx;
448 if (!builder.hasInferredContextParameter())
449 ctx.addSubst(placeholder: "_ctxt", subst: "context");
450 std::string bodyStr = replaceInStr(str: builder.getBody()->str(), from: "$_get(",
451 to: "Base::getChecked(emitError, ");
452 bodyStr = tgfmt(fmt: bodyStr, ctx: &ctx);
453 m->body().indent().getStream().printReindented(str: bodyStr);
454}
455
456//===----------------------------------------------------------------------===//
457// Interface Method Emission
458
459void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
460 // Get the set of methods that should always be declared.
461 auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods();
462 StringSet<> alwaysDeclared;
463 alwaysDeclared.insert(begin: alwaysDeclaredMethods.begin(),
464 end: alwaysDeclaredMethods.end());
465
466 Interface iface = trait.getInterface(); // causes strange bugs if elided
467 for (auto &method : iface.getMethods()) {
468 // Don't declare if the method has a body. Or if the method has a default
469 // implementation and the def didn't request that it always be declared.
470 if (method.getBody() || (method.getDefaultImplementation() &&
471 !alwaysDeclared.count(Key: method.getName())))
472 continue;
473 emitTraitMethod(method);
474 }
475}
476
477void DefGen::emitTraitMethod(const InterfaceMethod &method) {
478 // All interface methods are declaration-only.
479 auto props =
480 method.isStatic() ? Method::StaticDeclaration : Method::ConstDeclaration;
481 SmallVector<MethodParameter> params;
482 for (auto &param : method.getArguments())
483 params.emplace_back(Args: param.type, Args: param.name);
484 defCls.addMethod(retType: method.getReturnType(), name: method.getName(), properties: props,
485 args: std::move(params));
486}
487
488//===----------------------------------------------------------------------===//
489// Storage Class Emission
490
491void DefGen::emitStorageConstructor() {
492 Constructor *ctor =
493 storageCls->addConstructor<Method::Inline>(args: getBuilderParams(prefix: {}));
494 for (auto &param : params) {
495 std::string movedValue = ("std::move(" + param.getName() + ")").str();
496 ctor->addMemberInitializer(name: param.getName(), value&: movedValue);
497 }
498}
499
500void DefGen::emitKeyType() {
501 std::string keyType("std::tuple<");
502 llvm::raw_string_ostream os(keyType);
503 llvm::interleaveComma(c: params, os,
504 each_fn: [&](auto &param) { os << param.getCppType(); });
505 os << '>';
506 storageCls->declare<UsingDeclaration>(args: "KeyTy", args: std::move(os.str()));
507
508 // Add a method to construct the key type from the storage.
509 Method *m = storageCls->addConstMethod<Method::Inline>(retType: "KeyTy", name: "getAsKey");
510 m->body().indent() << "return KeyTy(";
511 llvm::interleaveComma(c: params, os&: m->body().indent(),
512 each_fn: [&](auto &param) { m->body() << param.getName(); });
513 m->body() << ");";
514}
515
516void DefGen::emitEquals() {
517 Method *eq = storageCls->addConstMethod<Method::Inline>(
518 retType: "bool", name: "operator==", args: MethodParameter("const KeyTy &", "tblgenKey"));
519 auto &body = eq->body().indent();
520 auto scope = body.scope(open: "return (", close: ");");
521 const auto eachFn = [&](auto it) {
522 FmtContext ctx({{"_lhs", it.value().getName()},
523 {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
524 body << tgfmt(it.value().getComparator(), &ctx);
525 };
526 llvm::interleave(c: llvm::enumerate(First&: params), os&: body, each_fn: eachFn, separator: ") && (");
527}
528
529void DefGen::emitHashKey() {
530 Method *hash = storageCls->addStaticInlineMethod(
531 retType: "::llvm::hash_code", name: "hashKey",
532 args: MethodParameter("const KeyTy &", "tblgenKey"));
533 auto &body = hash->body().indent();
534 auto scope = body.scope(open: "return ::llvm::hash_combine(", close: ");");
535 llvm::interleaveComma(c: llvm::enumerate(First&: params), os&: body, each_fn: [&](auto it) {
536 body << llvm::formatv("std::get<{0}>(tblgenKey)", it.index());
537 });
538}
539
540void DefGen::emitConstruct() {
541 Method *construct = storageCls->addMethod<Method::Inline>(
542 retType: strfmt(fmt: "{0} *", parameters: def.getStorageClassName()), name: "construct",
543 properties: def.hasStorageCustomConstructor() ? Method::StaticDeclaration
544 : Method::Static,
545 args: MethodParameter(strfmt(fmt: "::mlir::{0}StorageAllocator &", parameters&: valueType),
546 "allocator"),
547 args: MethodParameter("KeyTy &&", "tblgenKey"));
548 if (!def.hasStorageCustomConstructor()) {
549 auto &body = construct->body().indent();
550 for (const auto &it : llvm::enumerate(First&: params)) {
551 body << formatv(Fmt: "auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
552 Vals: it.value().getName(), Vals: it.index());
553 }
554 // Use the parameters' custom allocator code, if provided.
555 FmtContext ctx = FmtContext().addSubst(placeholder: "_allocator", subst: "allocator");
556 for (auto &param : params) {
557 if (std::optional<StringRef> allocCode = param.getAllocator()) {
558 ctx.withSelf(subst: param.getName()).addSubst(placeholder: "_dst", subst: param.getName());
559 body << tgfmt(fmt: *allocCode, ctx: &ctx) << '\n';
560 }
561 }
562 auto scope =
563 body.scope(open: strfmt(fmt: "return new (allocator.allocate<{0}>()) {0}(",
564 parameters: def.getStorageClassName()),
565 close: ");");
566 llvm::interleaveComma(c: params, os&: body, each_fn: [&](auto &param) {
567 body << "std::move(" << param.getName() << ")";
568 });
569 }
570}
571
572void DefGen::emitStorageClass() {
573 // Add the appropriate parent class.
574 storageCls->addParent(parent: strfmt(fmt: "::mlir::{0}Storage", parameters&: valueType));
575 // Add the constructor.
576 emitStorageConstructor();
577 // Declare the key type.
578 emitKeyType();
579 // Add the comparison method.
580 emitEquals();
581 // Emit the key hash method.
582 emitHashKey();
583 // Emit the storage constructor. Just declare it if the user wants to define
584 // it themself.
585 emitConstruct();
586 // Emit the storage class members as public, at the very end of the struct.
587 storageCls->finalize();
588 for (auto &param : params)
589 storageCls->declare<Field>(args: param.getCppType(), args: param.getName());
590}
591
592//===----------------------------------------------------------------------===//
593// DefGenerator
594//===----------------------------------------------------------------------===//
595
596namespace {
597/// This struct is the base generator used when processing tablegen interfaces.
598class DefGenerator {
599public:
600 bool emitDecls(StringRef selectedDialect);
601 bool emitDefs(StringRef selectedDialect);
602
603protected:
604 DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os,
605 StringRef defType, StringRef valueType, bool isAttrGenerator)
606 : defRecords(std::move(defs)), os(os), defType(defType),
607 valueType(valueType), isAttrGenerator(isAttrGenerator) {
608 // Sort by occurrence in file.
609 llvm::sort(C&: defRecords, Comp: [](llvm::Record *lhs, llvm::Record *rhs) {
610 return lhs->getID() < rhs->getID();
611 });
612 }
613
614 /// Emit the list of def type names.
615 void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
616 /// Emit the code to dispatch between different defs during parsing/printing.
617 void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
618
619 /// The set of def records to emit.
620 std::vector<llvm::Record *> defRecords;
621 /// The attribute or type class to emit.
622 /// The stream to emit to.
623 raw_ostream &os;
624 /// The prefix of the tablegen def name, e.g. Attr or Type.
625 StringRef defType;
626 /// The C++ base value type of the def, e.g. Attribute or Type.
627 StringRef valueType;
628 /// Flag indicating if this generator is for Attributes. False if the
629 /// generator is for types.
630 bool isAttrGenerator;
631};
632
633/// A specialized generator for AttrDefs.
634struct AttrDefGenerator : public DefGenerator {
635 AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
636 : DefGenerator(records.getAllDerivedDefinitionsIfDefined(ClassName: "AttrDef"), os,
637 "Attr", "Attribute", /*isAttrGenerator=*/true) {}
638};
639/// A specialized generator for TypeDefs.
640struct TypeDefGenerator : public DefGenerator {
641 TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
642 : DefGenerator(records.getAllDerivedDefinitionsIfDefined(ClassName: "TypeDef"), os,
643 "Type", "Type", /*isAttrGenerator=*/false) {}
644};
645} // namespace
646
647//===----------------------------------------------------------------------===//
648// GEN: Declarations
649//===----------------------------------------------------------------------===//
650
/// Preamble printed before all other generated declarations. It
/// forward-declares the MLIR assembly parser/printer classes that the
/// generated declarations refer to.
static constexpr const char *typeDefDeclHeader = R"(
namespace mlir {
class AsmParser;
class AsmPrinter;
} // namespace mlir
)";
659
660bool DefGenerator::emitDecls(StringRef selectedDialect) {
661 emitSourceFileHeader(Desc: (defType + "Def Declarations").str(), OS&: os);
662 IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
663
664 // Output the common "header".
665 os << typeDefDeclHeader;
666
667 SmallVector<AttrOrTypeDef, 16> defs;
668 collectAllDefs(selectedDialect, records: defRecords, resultDefs&: defs);
669 if (defs.empty())
670 return false;
671 {
672 NamespaceEmitter nsEmitter(os, defs.front().getDialect());
673
674 // Declare all the def classes first (in case they reference each other).
675 for (const AttrOrTypeDef &def : defs)
676 os << "class " << def.getCppClassName() << ";\n";
677
678 // Emit the declarations.
679 for (const AttrOrTypeDef &def : defs)
680 DefGen(def).emitDecl(os);
681 }
682 // Emit the TypeID explicit specializations to have a single definition for
683 // each of these.
684 for (const AttrOrTypeDef &def : defs)
685 if (!def.getDialect().getCppNamespace().empty())
686 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID("
687 << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
688 << ")\n";
689
690 return false;
691}
692
693//===----------------------------------------------------------------------===//
694// GEN: Def List
695//===----------------------------------------------------------------------===//
696
697void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
698 IfDefScope scope("GET_" + defType.upper() + "DEF_LIST", os);
699 auto interleaveFn = [&](const AttrOrTypeDef &def) {
700 os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName();
701 };
702 llvm::interleave(c: defs, os, each_fn: interleaveFn, separator: ",\n");
703 os << "\n";
704}
705
706//===----------------------------------------------------------------------===//
707// GEN: Definitions
708//===----------------------------------------------------------------------===//
709
/// Format string for the dialect's default attribute parser/printer dispatch.
/// Placeholders:
///   {0}: the fully qualified dialect class name.
///   {1}: optional code for the dynamic attribute parser dispatch.
///   {2}: optional code for the dynamic attribute printer dispatch.
/// Literal braces are escaped as `{{`.
static constexpr const char *dialectDefaultAttrPrinterParserDispatch = R"(
/// Parse an attribute registered to this dialect.
::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
                                      ::mlir::Type type) const {{
  ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
  ::llvm::StringRef attrTag;
  {{
    ::mlir::Attribute attr;
    auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
    if (parseResult.has_value())
      return attr;
  }
  {1}
  parser.emitError(typeLoc) << "unknown attribute `"
      << attrTag << "` in dialect `" << getNamespace() << "`";
  return {{};
}
/// Print an attribute registered to this dialect.
void {0}::printAttribute(::mlir::Attribute attr,
                         ::mlir::DialectAsmPrinter &printer) const {{
  if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
    return;
  {2}
}
)";
739
/// The code block for dynamic attribute parser dispatch boilerplate.
///
/// This is substituted for `{1}` in `dialectDefaultAttrPrinterParserDispatch`.
/// Every MLIR name is fully qualified, including the null attribute returned on
/// failure, so the generated code does not depend on `mlir::Attribute` being
/// visible by unqualified lookup where the generated file is included. This
/// matches `::mlir::Type()` in the dynamic type parser dispatch.
static const char *const dialectDynamicAttrParserDispatch = R"(
  {
    ::mlir::Attribute genAttr;
    auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr);
    if (parseResult.has_value()) {
      if (::mlir::succeeded(parseResult.value()))
        return genAttr;
      return ::mlir::Attribute();
    }
  }
)";
752
/// Snippet for the dynamic attribute printer dispatch. It is substituted for
/// `{2}` in `dialectDefaultAttrPrinterParserDispatch`.
static constexpr const char *dialectDynamicAttrPrinterDispatch = R"(
  if (::mlir::succeeded(printIfDynamicAttr(attr, printer)))
    return;
)";
758
/// Format string for the dialect's default type parser/printer dispatch.
/// Placeholders:
///   {0}: the fully qualified dialect class name.
///   {1}: optional code for the dynamic type parser dispatch.
///   {2}: optional code for the dynamic type printer dispatch.
/// Literal braces are escaped as `{{`.
static constexpr const char *dialectDefaultTypePrinterParserDispatch = R"(
/// Parse a type registered to this dialect.
::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
  ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
  ::llvm::StringRef mnemonic;
  ::mlir::Type genType;
  auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
  if (parseResult.has_value())
    return genType;
  {1}
  parser.emitError(typeLoc) << "unknown type `"
      << mnemonic << "` in dialect `" << getNamespace() << "`";
  return {{};
}
/// Print a type registered to this dialect.
void {0}::printType(::mlir::Type type,
                    ::mlir::DialectAsmPrinter &printer) const {{
  if (::mlir::succeeded(generatedTypePrinter(type, printer)))
    return;
  {2}
}
)";
785
/// Snippet for the dynamic type parser dispatch. It is substituted for `{1}`
/// in `dialectDefaultTypePrinterParserDispatch` and reuses that function's
/// `mnemonic`, `parser` and `genType` locals.
static constexpr const char *dialectDynamicTypeParserDispatch = R"(
  {
    auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
    if (parseResult.has_value()) {
      if (::mlir::succeeded(parseResult.value()))
        return genType;
      return ::mlir::Type();
    }
  }
)";
797
/// The code block for dynamic type printer dispatch boilerplate.
///
/// emitDefs substitutes this for {2} of dialectDefaultTypePrinterParserDispatch
/// when the dialect is extensible. It therefore only runs after
/// generatedTypePrinter failed. It uses the `type` and `printer` parameters of
/// the emitted printType.
static const char *const dialectDynamicTypePrinterDispatch = R"(
 if (::mlir::succeeded(printIfDynamicType(type, printer)))
 return;
)";
803
804/// Emit the dialect printer/parser dispatcher. User's code should call these
805/// functions from their dialect's print/parse methods.
806void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
807 if (llvm::none_of(Range&: defs, P: [](const AttrOrTypeDef &def) {
808 return def.getMnemonic().has_value();
809 })) {
810 return;
811 }
812 // Declare the parser.
813 SmallVector<MethodParameter> params = {{"::mlir::AsmParser &", "parser"},
814 {"::llvm::StringRef *", "mnemonic"}};
815 if (isAttrGenerator)
816 params.emplace_back(Args: "::mlir::Type", Args: "type");
817 params.emplace_back(Args: strfmt(fmt: "::mlir::{0} &", parameters&: valueType), Args: "value");
818 Method parse("::mlir::OptionalParseResult",
819 strfmt(fmt: "generated{0}Parser", parameters&: valueType), Method::StaticInline,
820 std::move(params));
821 // Declare the printer.
822 Method printer("::mlir::LogicalResult",
823 strfmt(fmt: "generated{0}Printer", parameters&: valueType), Method::StaticInline,
824 {{strfmt(fmt: "::mlir::{0}", parameters&: valueType), "def"},
825 {"::mlir::AsmPrinter &", "printer"}});
826
827 // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and
828 // calling the def's parse function.
829 parse.body() << " return "
830 "::mlir::AsmParser::KeywordSwitch<::mlir::"
831 "OptionalParseResult>(parser)\n";
832 const char *const getValueForMnemonic =
833 R"( .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{
834 value = {0}::{1};
835 return ::mlir::success(!!value);
836 })
837)";
838
839 // The printer dispatch uses llvm::TypeSwitch to find and call the correct
840 // printer.
841 printer.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType
842 << ", ::mlir::LogicalResult>(def)";
843 const char *const printValue = R"( .Case<{0}>([&](auto t) {{
844 printer << {0}::getMnemonic();{1}
845 return ::mlir::success();
846 })
847)";
848 for (auto &def : defs) {
849 if (!def.getMnemonic())
850 continue;
851 bool hasParserPrinterDecl =
852 def.hasCustomAssemblyFormat() || def.getAssemblyFormat();
853 std::string defClass = strfmt(
854 fmt: "{0}::{1}", parameters: def.getDialect().getCppNamespace(), parameters: def.getCppClassName());
855
856 // If the def has no parameters or parser code, invoke a normal `get`.
857 std::string parseOrGet =
858 hasParserPrinterDecl
859 ? strfmt(fmt: "parse(parser{0})", parameters: isAttrGenerator ? ", type" : "")
860 : "get(parser.getContext())";
861 parse.body() << llvm::formatv(Fmt: getValueForMnemonic, Vals&: defClass, Vals&: parseOrGet);
862
863 // If the def has no parameters and no printer, just print the mnemonic.
864 StringRef printDef = "";
865 if (hasParserPrinterDecl)
866 printDef = "\nt.print(printer);";
867 printer.body() << llvm::formatv(Fmt: printValue, Vals&: defClass, Vals&: printDef);
868 }
869 parse.body() << " .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n"
870 " *mnemonic = keyword;\n"
871 " return std::nullopt;\n"
872 " });";
873 printer.body() << " .Default([](auto) { return ::mlir::failure(); });";
874
875 raw_indented_ostream indentedOs(os);
876 parse.writeDeclTo(os&: indentedOs);
877 printer.writeDeclTo(os&: indentedOs);
878}
879
880bool DefGenerator::emitDefs(StringRef selectedDialect) {
881 emitSourceFileHeader(Desc: (defType + "Def Definitions").str(), OS&: os);
882
883 SmallVector<AttrOrTypeDef, 16> defs;
884 collectAllDefs(selectedDialect, records: defRecords, resultDefs&: defs);
885 if (defs.empty())
886 return false;
887 emitTypeDefList(defs);
888
889 IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
890 emitParsePrintDispatch(defs);
891 for (const AttrOrTypeDef &def : defs) {
892 {
893 NamespaceEmitter ns(os, def.getDialect());
894 DefGen gen(def);
895 gen.emitDef(os);
896 }
897 // Emit the TypeID explicit specializations to have a single symbol def.
898 if (!def.getDialect().getCppNamespace().empty())
899 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID("
900 << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
901 << ")\n";
902 }
903
904 Dialect firstDialect = defs.front().getDialect();
905
906 // Emit the default parser/printer for Attributes if the dialect asked for it.
907 if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser()) {
908 NamespaceEmitter nsEmitter(os, firstDialect);
909 if (firstDialect.isExtensible()) {
910 os << llvm::formatv(Fmt: dialectDefaultAttrPrinterParserDispatch,
911 Vals: firstDialect.getCppClassName(),
912 Vals: dialectDynamicAttrParserDispatch,
913 Vals: dialectDynamicAttrPrinterDispatch);
914 } else {
915 os << llvm::formatv(Fmt: dialectDefaultAttrPrinterParserDispatch,
916 Vals: firstDialect.getCppClassName(), Vals: "", Vals: "");
917 }
918 }
919
920 // Emit the default parser/printer for Types if the dialect asked for it.
921 if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser()) {
922 NamespaceEmitter nsEmitter(os, firstDialect);
923 if (firstDialect.isExtensible()) {
924 os << llvm::formatv(Fmt: dialectDefaultTypePrinterParserDispatch,
925 Vals: firstDialect.getCppClassName(),
926 Vals: dialectDynamicTypeParserDispatch,
927 Vals: dialectDynamicTypePrinterDispatch);
928 } else {
929 os << llvm::formatv(Fmt: dialectDefaultTypePrinterParserDispatch,
930 Vals: firstDialect.getCppClassName(), Vals: "", Vals: "");
931 }
932 }
933
934 return false;
935}
936
937//===----------------------------------------------------------------------===//
938// GEN: Registration hooks
939//===----------------------------------------------------------------------===//
940
941//===----------------------------------------------------------------------===//
942// AttrDef
943
944static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*");
945static llvm::cl::opt<std::string>
946 attrDialect("attrdefs-dialect",
947 llvm::cl::desc("Generate attributes for this dialect"),
948 llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated);
949
950static mlir::GenRegistration
951 genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
952 [](const llvm::RecordKeeper &records, raw_ostream &os) {
953 AttrDefGenerator generator(records, os);
954 return generator.emitDefs(selectedDialect: attrDialect);
955 });
956static mlir::GenRegistration
957 genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
958 [](const llvm::RecordKeeper &records, raw_ostream &os) {
959 AttrDefGenerator generator(records, os);
960 return generator.emitDecls(selectedDialect: attrDialect);
961 });
962
963//===----------------------------------------------------------------------===//
964// TypeDef
965
966static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
967static llvm::cl::opt<std::string>
968 typeDialect("typedefs-dialect",
969 llvm::cl::desc("Generate types for this dialect"),
970 llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
971
972static mlir::GenRegistration
973 genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
974 [](const llvm::RecordKeeper &records, raw_ostream &os) {
975 TypeDefGenerator generator(records, os);
976 return generator.emitDefs(selectedDialect: typeDialect);
977 });
978static mlir::GenRegistration
979 genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
980 [](const llvm::RecordKeeper &records, raw_ostream &os) {
981 TypeDefGenerator generator(records, os);
982 return generator.emitDecls(selectedDialect: typeDialect);
983 });
984

source code of mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp